2222#include < logic.hpp>
2323
2424using namespace detail ;
25+ using af::dim4;
26+
27+ static dim4 getOutDims (const dim4 ldims, const dim4 rdims, bool batchMode)
28+ {
29+ if (!batchMode) {
30+ DIM_ASSERT (1 , ldims == rdims);
31+ return ldims;
32+ }
33+
34+ AF_ERROR (" Batch mode not supported yet" , AF_ERR_NOT_SUPPORTED );
35+ }
2536
2637template <typename T, af_op_t op>
27- static inline af_array arithOp (const af_array lhs, const af_array rhs)
38+ static inline af_array arithOp (const af_array lhs, const af_array rhs,
39+ const dim4 &odims)
2840{
29- af_array res = getHandle (*arithOp<T, op>(getArray<T>(lhs), getArray<T>(rhs)));
41+ af_array res = getHandle (*arithOp<T, op>(getArray<T>(lhs), getArray<T>(rhs), odims ));
3042 // All inputs to this function are temporary references
3143 // Delete the temporary references
3244 destroyHandle<T>(lhs);
@@ -45,19 +57,18 @@ static af_err af_arith(af_array *out, const af_array lhs, const af_array rhs, bo
4557 ArrayInfo linfo = getInfo (lhs);
4658 ArrayInfo rinfo = getInfo (rhs);
4759
48- if (!batchMode) DIM_ASSERT (1 , linfo.dims () == rinfo.dims ());
49- else AF_ERROR (" Batch mode not supported yet" , AF_ERR_NOT_SUPPORTED );
60+ dim4 odims = getOutDims (linfo.dims (), rinfo.dims (), batchMode);
5061
5162 af_array res;
5263 switch (otype) {
53- case f32 : res = arithOp<float , op>(left, right); break ;
54- case f64 : res = arithOp<double , op>(left, right); break ;
55- case c32: res = arithOp<cfloat , op>(left, right); break ;
56- case c64: res = arithOp<cdouble, op>(left, right); break ;
57- case s32: res = arithOp<int , op>(left, right); break ;
58- case u32 : res = arithOp<uint , op>(left, right); break ;
59- case u8 : res = arithOp<uchar , op>(left, right); break ;
60- case b8 : res = arithOp<char , op>(left, right); break ;
64+ case f32 : res = arithOp<float , op>(left, right, odims ); break ;
65+ case f64 : res = arithOp<double , op>(left, right, odims ); break ;
66+ case c32: res = arithOp<cfloat , op>(left, right, odims ); break ;
67+ case c64: res = arithOp<cdouble, op>(left, right, odims ); break ;
68+ case s32: res = arithOp<int , op>(left, right, odims ); break ;
69+ case u32 : res = arithOp<uint , op>(left, right, odims ); break ;
70+ case u8 : res = arithOp<uchar , op>(left, right, odims ); break ;
71+ case b8 : res = arithOp<char , op>(left, right, odims ); break ;
6172 default : TYPE_ERROR (0 , otype);
6273 }
6374
@@ -78,17 +89,16 @@ static af_err af_arith_real(af_array *out, const af_array lhs, const af_array rh
7889 ArrayInfo linfo = getInfo (lhs);
7990 ArrayInfo rinfo = getInfo (rhs);
8091
81- if (!batchMode) DIM_ASSERT (1 , linfo.dims () == rinfo.dims ());
82- else AF_ERROR (" Batch mode not supported yet" , AF_ERR_NOT_SUPPORTED );
92+ dim4 odims = getOutDims (linfo.dims (), rinfo.dims (), batchMode);
8393
8494 af_array res;
8595 switch (otype) {
86- case f32 : res = arithOp<float , op>(left, right); break ;
87- case f64 : res = arithOp<double , op>(left, right); break ;
88- case s32: res = arithOp<int , op>(left, right); break ;
89- case u32 : res = arithOp<uint , op>(left, right); break ;
90- case u8 : res = arithOp<uchar , op>(left, right); break ;
91- case b8 : res = arithOp<char , op>(left, right); break ;
96+ case f32 : res = arithOp<float , op>(left, right, odims ); break ;
97+ case f64 : res = arithOp<double , op>(left, right, odims ); break ;
98+ case s32: res = arithOp<int , op>(left, right, odims ); break ;
99+ case u32 : res = arithOp<uint , op>(left, right, odims ); break ;
100+ case u8 : res = arithOp<uchar , op>(left, right, odims ); break ;
101+ case b8 : res = arithOp<char , op>(left, right, odims ); break ;
92102 default : TYPE_ERROR (0 , otype);
93103 }
94104
@@ -169,14 +179,12 @@ af_err af_atan2(af_array *out, const af_array lhs, const af_array rhs, bool batc
169179 ArrayInfo linfo = getInfo (lhs);
170180 ArrayInfo rinfo = getInfo (rhs);
171181
172- if (!batchMode) DIM_ASSERT (1 , linfo.dims () == rinfo.dims ());
173- else AF_ERROR (" Batch mode not supported yet" , AF_ERR_NOT_SUPPORTED );
182+ dim4 odims = getOutDims (linfo.dims (), rinfo.dims (), batchMode);
174183
175184 af_array res;
176-
177185 switch (type) {
178- case f32 : res = arithOp<float , af_atan2_t >(left, right); break ;
179- case f64 : res = arithOp<double , af_atan2_t >(left, right); break ;
186+ case f32 : res = arithOp<float , af_atan2_t >(left, right, odims ); break ;
187+ case f64 : res = arithOp<double , af_atan2_t >(left, right, odims ); break ;
180188 default : TYPE_ERROR (0 , type);
181189 }
182190
@@ -203,14 +211,12 @@ af_err af_hypot(af_array *out, const af_array lhs, const af_array rhs, bool batc
203211 ArrayInfo linfo = getInfo (lhs);
204212 ArrayInfo rinfo = getInfo (rhs);
205213
206- if (!batchMode) DIM_ASSERT (1 , linfo.dims () == rinfo.dims ());
207- else AF_ERROR (" Batch mode not supported yet" , AF_ERR_NOT_SUPPORTED );
214+ dim4 odims = getOutDims (linfo.dims (), rinfo.dims (), batchMode);
208215
209216 af_array res;
210-
211217 switch (type) {
212- case f32 : res = arithOp<float , af_hypot_t >(left, right); break ;
213- case f64 : res = arithOp<double , af_hypot_t >(left, right); break ;
218+ case f32 : res = arithOp<float , af_hypot_t >(left, right, odims ); break ;
219+ case f64 : res = arithOp<double , af_hypot_t >(left, right, odims ); break ;
214220 default : TYPE_ERROR (0 , type);
215221 }
216222
@@ -221,9 +227,9 @@ af_err af_hypot(af_array *out, const af_array lhs, const af_array rhs, bool batc
221227}
222228
223229template <typename T, af_op_t op>
224- static inline af_array logicOp (const af_array lhs, const af_array rhs)
230+ static inline af_array logicOp (const af_array lhs, const af_array rhs, const dim4 &odims )
225231{
226- af_array res = getHandle (*logicOp<T, op>(getArray<T>(lhs), getArray<T>(rhs)));
232+ af_array res = getHandle (*logicOp<T, op>(getArray<T>(lhs), getArray<T>(rhs), odims ));
227233 // All inputs to this function are temporary references
228234 // Delete the temporary references
229235 destroyHandle<T>(lhs);
@@ -243,19 +249,18 @@ static af_err af_logic(af_array *out, const af_array lhs, const af_array rhs, bo
243249 ArrayInfo linfo = getInfo (lhs);
244250 ArrayInfo rinfo = getInfo (rhs);
245251
246- if (!batchMode) DIM_ASSERT (1 , linfo.dims () == rinfo.dims ());
247- else AF_ERROR (" Batch mode not supported yet" , AF_ERR_NOT_SUPPORTED );
252+ dim4 odims = getOutDims (linfo.dims (), rinfo.dims (), batchMode);
248253
249254 af_array res;
250255 switch (type) {
251- case f32 : res = logicOp<float , op>(left, right); break ;
252- case f64 : res = logicOp<double , op>(left, right); break ;
253- case c32: res = logicOp<cfloat , op>(left, right); break ;
254- case c64: res = logicOp<cdouble, op>(left, right); break ;
255- case s32: res = logicOp<int , op>(left, right); break ;
256- case u32 : res = logicOp<uint , op>(left, right); break ;
257- case u8 : res = logicOp<uchar , op>(left, right); break ;
258- case b8 : res = logicOp<char , op>(left, right); break ;
256+ case f32 : res = logicOp<float , op>(left, right, odims ); break ;
257+ case f64 : res = logicOp<double , op>(left, right, odims ); break ;
258+ case c32: res = logicOp<cfloat , op>(left, right, odims ); break ;
259+ case c64: res = logicOp<cdouble, op>(left, right, odims ); break ;
260+ case s32: res = logicOp<int , op>(left, right, odims ); break ;
261+ case u32 : res = logicOp<uint , op>(left, right, odims ); break ;
262+ case u8 : res = logicOp<uchar , op>(left, right, odims ); break ;
263+ case b8 : res = logicOp<char , op>(left, right, odims ); break ;
259264 default : TYPE_ERROR (0 , type);
260265 }
261266
0 commit comments