/******************************************************* * Copyright (c) 2014, ArrayFire * All rights reserved. * * This file is distributed under 3-clause BSD license. * The complete license agreement can be obtained at: * http://arrayfire.com/licenses/BSD-3-Clause ********************************************************/ #include #include #include #include #include using namespace af; using std::vector; template class Constant : public ::testing::Test { }; typedef ::testing::Types TestTypes; TYPED_TEST_CASE(Constant, TestTypes); template void ConstantCPPCheck(T value) { if (noDoubleTests()) return; const int num = 1000; T val = value; dtype dty = (dtype) dtype_traits::af_type; af::array in = constant(val, num, dty); vector h_in(num); in.host(&h_in.front()); for (int i = 0; i < num; i++) { ASSERT_EQ(h_in[i], val); } } template void ConstantCCheck(T value) { if (noDoubleTests()) return; const int num = 1000; typedef typename af::dtype_traits::base_type BT; BT val = ::real(value); dtype dty = (dtype) dtype_traits::af_type; af_array out; dim_t dim[] = {(dim_t)num}; ASSERT_EQ(AF_SUCCESS, af_constant(&out, val, 1, dim, dty)); vector h_in(num); af_get_data_ptr(&h_in.front(), out); for (int i = 0; i < num; i++) { ASSERT_EQ(::real(h_in[i]), val); } } template void IdentityCPPCheck() { if (noDoubleTests()) return; int num = 1000; dtype dty = (dtype) dtype_traits::af_type; array out = af::identity(num, num, dty); vector h_in(num*num); out.host(&h_in.front()); for (int i = 0; i < num; i++) { for (int j = 0; j < num; j++) { if(j == i) ASSERT_EQ(h_in[i * num + j], T(1)); else ASSERT_EQ(h_in[i * num + j], T(0)); } } num = 100; out = af::identity(num, num, num, dty); h_in.resize(num*num*num); out.host(&h_in.front()); for (int h = 0; h < num; h++) { for (int i = 0; i < num; i++) { for (int j = 0; j < num; j++) { if(j == i) ASSERT_EQ(h_in[i * num + j], T(1)); else ASSERT_EQ(h_in[i * num + j], T(0)); } } } } template void IdentityCCheck() { if (noDoubleTests()) return; static const int num = 1000; dtype dty = (dtype) dtype_traits::af_type; af_array out; dim_t dim[] = {(dim_t)num, (dim_t)num}; ASSERT_EQ(AF_SUCCESS, af_identity(&out, 2, dim, dty)); vector h_in(num*num); af_get_data_ptr(&h_in.front(), out); for (int i = 0; i < num; i++) { for (int j = 0; j < num; j++) { if(j == i) ASSERT_EQ(h_in[i * num + j], T(1)); else ASSERT_EQ(h_in[i * num + j], T(0)); } } } template void IdentityCPPError() { if (noDoubleTests()) return; static const int num = 1000; dtype dty = (dtype) dtype_traits::af_type; try { array out = af::identity(num, 0, 10, dty); } catch(const af::exception &ex) { SUCCEED(); return; } FAIL() << "Failed to throw an exception"; } TYPED_TEST(Constant, basicCPP) { ConstantCPPCheck(5); } TYPED_TEST(Constant, basicC) { ConstantCCheck(5); } TYPED_TEST(Constant, IdentityC) { IdentityCCheck(); } TYPED_TEST(Constant, IdentityCPP) { IdentityCPPCheck(); } TYPED_TEST(Constant, IdentityCPPError) { IdentityCPPError(); }