1616#include " error.hpp"
1717#include < af/defines.h>
1818
19- namespace af
20- {
19+ #include < type_traits>
20+
21+ using std::enable_if;
22+ using af::array;
23+ using af::dim4;
24+ using af::dtype;
25+
26+ namespace {
27+ template <typename T> struct is_complex { static const bool value = false ; };
28+ template <> struct is_complex <af::cfloat> { static const bool value = true ; };
29+ template <> struct is_complex <af::cdouble> { static const bool value = true ; };
2130
2231 template <typename T>
23- array
24- constant (T val, const dim4 & dims, const af:: dtype type)
32+ typename enable_if<is_complex<T>::value == false , array>::type
33+ constant (T val, const dim4& dims, const dtype type)
2534 {
2635 af_array res;
2736 if (type != s64 && type != u64 ) {
@@ -40,11 +49,12 @@ namespace af
4049 return array (res);
4150 }
4251
43- template <>
44- AFAPI array constant (cfloat val, const dim4 &dims, const af::dtype type)
52+ template <typename T>
53+ typename enable_if<is_complex<T>::value == true , array>::type
54+ constant (T val, const dim4& dims, const dtype type)
4555 {
4656 if (type != c32 && type != c64) {
47- return constant (real (val), dims, type);
57+ return :: constant (real (val), dims, type);
4858 }
4959 af_array res;
5060 AF_THROW (af_constant_complex (&res,
@@ -54,57 +64,53 @@ namespace af
5464 dims.get (), type));
5565 return array (res);
5666 }
67+ }
5768
58- template <>
59- AFAPI array constant (cdouble val, const dim4 &dims, const af::dtype type)
60- {
61- if (type != c32 && type != c64) {
62- return constant (real (val), dims, type);
63- }
64- af_array res;
65- AF_THROW (af_constant_complex (&res,
66- real (val),
67- imag (val),
68- dims.ndims (),
69- dims.get (), type));
70- return array (res);
71- }
69+ namespace af
70+ {
7271 template <typename T>
73- array constant (T val, const dim_t d0, const af::dtype ty)
74- {
75- return constant (val, dim4 (d0), ty);
72+ array constant (T val, const dim4& dims, const af::dtype type) {
73+ return ::constant (val, dims, type);
7674 }
7775
76+ template <typename T>
77+ array constant (T val, const dim_t d0, const af::dtype ty)
78+ {
79+ return ::constant (val, dim4 (d0), ty);
80+ }
81+
7882 template <typename T>
7983 array constant (T val, const dim_t d0, const dim_t d1, const af::dtype ty)
8084 {
81- return constant (val, dim4 (d0, d1), ty);
85+ return :: constant (val, dim4 (d0, d1), ty);
8286 }
8387
8488 template <typename T>
8589 array constant (T val, const dim_t d0, const dim_t d1, const dim_t d2, const af::dtype ty)
8690 {
87- return constant (val, dim4 (d0, d1, d2), ty);
91+ return :: constant (val, dim4 (d0, d1, d2), ty);
8892 }
8993
9094 template <typename T>
9195 array constant (T val, const dim_t d0, const dim_t d1, const dim_t d2, const dim_t d3, const af::dtype ty)
9296 {
93- return constant (val, dim4 (d0, d1, d2, d3), ty);
94- }
95-
96- #define CONSTANT (TYPE ) \
97- template AFAPI array constant<TYPE>(TYPE val, const dim4 &dims, const af::dtype ty); \
98- template AFAPI array constant<TYPE>(TYPE val, const dim_t d0, const af::dtype ty); \
99- template AFAPI array constant<TYPE>(TYPE val, const dim_t d0, \
100- const dim_t d1, const af::dtype ty); \
101- template AFAPI array constant<TYPE>(TYPE val, const dim_t d0, \
102- const dim_t d1, \
103- const dim_t d2, const af::dtype ty); \
104- template AFAPI array constant<TYPE>(TYPE val, const dim_t d0, \
105- const dim_t d1, \
106- const dim_t d2, \
107- const dim_t d3, const af::dtype ty);
97+ return ::constant (val, dim4 (d0, d1, d2, d3), ty);
98+ }
99+
100+ #define CONSTANT (TYPE ) \
101+ template AFAPI array constant<TYPE>(TYPE val, const dim4& dims, \
102+ const af::dtype ty); \
103+ template AFAPI array constant<TYPE>(TYPE val, const dim_t d0, \
104+ const af::dtype ty); \
105+ template AFAPI array constant<TYPE>(TYPE val, const dim_t d0, \
106+ const dim_t d1, const af::dtype ty); \
107+ template AFAPI array constant<TYPE>(TYPE val, const dim_t d0, \
108+ const dim_t d1, const dim_t d2, \
109+ const af::dtype ty); \
110+ template AFAPI array constant<TYPE>(TYPE val, const dim_t d0, \
111+ const dim_t d1, \
112+ const dim_t d2, \
113+ const dim_t d3, const af::dtype ty);
108114 CONSTANT (double );
109115 CONSTANT (float );
110116 CONSTANT (int );
0 commit comments