Skip to content

Commit d42fc21

Browse files
committed
binaryNode now accepts output dimension size
1 parent 2d37a95 commit d42fc21

14 files changed

Lines changed: 100 additions & 73 deletions

File tree

src/api/c/binary.cpp

Lines changed: 47 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,23 @@
2222
#include <logic.hpp>
2323

2424
using namespace detail;
25+
using af::dim4;
26+
27+
static dim4 getOutDims(const dim4 ldims, const dim4 rdims, bool batchMode)
28+
{
29+
if (!batchMode) {
30+
DIM_ASSERT(1, ldims == rdims);
31+
return ldims;
32+
}
33+
34+
AF_ERROR("Batch mode not supported yet", AF_ERR_NOT_SUPPORTED);
35+
}
2536

2637
template<typename T, af_op_t op>
27-
static inline af_array arithOp(const af_array lhs, const af_array rhs)
38+
static inline af_array arithOp(const af_array lhs, const af_array rhs,
39+
const dim4 &odims)
2840
{
29-
af_array res = getHandle(*arithOp<T, op>(getArray<T>(lhs), getArray<T>(rhs)));
41+
af_array res = getHandle(*arithOp<T, op>(getArray<T>(lhs), getArray<T>(rhs), odims));
3042
// All inputs to this function are temporary references
3143
// Delete the temporary references
3244
destroyHandle<T>(lhs);
@@ -45,19 +57,18 @@ static af_err af_arith(af_array *out, const af_array lhs, const af_array rhs, bo
4557
ArrayInfo linfo = getInfo(lhs);
4658
ArrayInfo rinfo = getInfo(rhs);
4759

48-
if (!batchMode) DIM_ASSERT(1, linfo.dims() == rinfo.dims());
49-
else AF_ERROR("Batch mode not supported yet", AF_ERR_NOT_SUPPORTED);
60+
dim4 odims = getOutDims(linfo.dims(), rinfo.dims(), batchMode);
5061

5162
af_array res;
5263
switch (otype) {
53-
case f32: res = arithOp<float , op>(left, right); break;
54-
case f64: res = arithOp<double , op>(left, right); break;
55-
case c32: res = arithOp<cfloat , op>(left, right); break;
56-
case c64: res = arithOp<cdouble, op>(left, right); break;
57-
case s32: res = arithOp<int , op>(left, right); break;
58-
case u32: res = arithOp<uint , op>(left, right); break;
59-
case u8 : res = arithOp<uchar , op>(left, right); break;
60-
case b8 : res = arithOp<char , op>(left, right); break;
64+
case f32: res = arithOp<float , op>(left, right, odims); break;
65+
case f64: res = arithOp<double , op>(left, right, odims); break;
66+
case c32: res = arithOp<cfloat , op>(left, right, odims); break;
67+
case c64: res = arithOp<cdouble, op>(left, right, odims); break;
68+
case s32: res = arithOp<int , op>(left, right, odims); break;
69+
case u32: res = arithOp<uint , op>(left, right, odims); break;
70+
case u8 : res = arithOp<uchar , op>(left, right, odims); break;
71+
case b8 : res = arithOp<char , op>(left, right, odims); break;
6172
default: TYPE_ERROR(0, otype);
6273
}
6374

@@ -78,17 +89,16 @@ static af_err af_arith_real(af_array *out, const af_array lhs, const af_array rh
7889
ArrayInfo linfo = getInfo(lhs);
7990
ArrayInfo rinfo = getInfo(rhs);
8091

81-
if (!batchMode) DIM_ASSERT(1, linfo.dims() == rinfo.dims());
82-
else AF_ERROR("Batch mode not supported yet", AF_ERR_NOT_SUPPORTED);
92+
dim4 odims = getOutDims(linfo.dims(), rinfo.dims(), batchMode);
8393

8494
af_array res;
8595
switch (otype) {
86-
case f32: res = arithOp<float , op>(left, right); break;
87-
case f64: res = arithOp<double , op>(left, right); break;
88-
case s32: res = arithOp<int , op>(left, right); break;
89-
case u32: res = arithOp<uint , op>(left, right); break;
90-
case u8 : res = arithOp<uchar , op>(left, right); break;
91-
case b8 : res = arithOp<char , op>(left, right); break;
96+
case f32: res = arithOp<float , op>(left, right, odims); break;
97+
case f64: res = arithOp<double , op>(left, right, odims); break;
98+
case s32: res = arithOp<int , op>(left, right, odims); break;
99+
case u32: res = arithOp<uint , op>(left, right, odims); break;
100+
case u8 : res = arithOp<uchar , op>(left, right, odims); break;
101+
case b8 : res = arithOp<char , op>(left, right, odims); break;
92102
default: TYPE_ERROR(0, otype);
93103
}
94104

@@ -169,14 +179,12 @@ af_err af_atan2(af_array *out, const af_array lhs, const af_array rhs, bool batc
169179
ArrayInfo linfo = getInfo(lhs);
170180
ArrayInfo rinfo = getInfo(rhs);
171181

172-
if (!batchMode) DIM_ASSERT(1, linfo.dims() == rinfo.dims());
173-
else AF_ERROR("Batch mode not supported yet", AF_ERR_NOT_SUPPORTED);
182+
dim4 odims = getOutDims(linfo.dims(), rinfo.dims(), batchMode);
174183

175184
af_array res;
176-
177185
switch (type) {
178-
case f32: res = arithOp<float , af_atan2_t>(left, right); break;
179-
case f64: res = arithOp<double, af_atan2_t>(left, right); break;
186+
case f32: res = arithOp<float , af_atan2_t>(left, right, odims); break;
187+
case f64: res = arithOp<double, af_atan2_t>(left, right, odims); break;
180188
default: TYPE_ERROR(0, type);
181189
}
182190

@@ -203,14 +211,12 @@ af_err af_hypot(af_array *out, const af_array lhs, const af_array rhs, bool batc
203211
ArrayInfo linfo = getInfo(lhs);
204212
ArrayInfo rinfo = getInfo(rhs);
205213

206-
if (!batchMode) DIM_ASSERT(1, linfo.dims() == rinfo.dims());
207-
else AF_ERROR("Batch mode not supported yet", AF_ERR_NOT_SUPPORTED);
214+
dim4 odims = getOutDims(linfo.dims(), rinfo.dims(), batchMode);
208215

209216
af_array res;
210-
211217
switch (type) {
212-
case f32: res = arithOp<float , af_hypot_t>(left, right); break;
213-
case f64: res = arithOp<double, af_hypot_t>(left, right); break;
218+
case f32: res = arithOp<float , af_hypot_t>(left, right, odims); break;
219+
case f64: res = arithOp<double, af_hypot_t>(left, right, odims); break;
214220
default: TYPE_ERROR(0, type);
215221
}
216222

@@ -221,9 +227,9 @@ af_err af_hypot(af_array *out, const af_array lhs, const af_array rhs, bool batc
221227
}
222228

223229
template<typename T, af_op_t op>
224-
static inline af_array logicOp(const af_array lhs, const af_array rhs)
230+
static inline af_array logicOp(const af_array lhs, const af_array rhs, const dim4 &odims)
225231
{
226-
af_array res = getHandle(*logicOp<T, op>(getArray<T>(lhs), getArray<T>(rhs)));
232+
af_array res = getHandle(*logicOp<T, op>(getArray<T>(lhs), getArray<T>(rhs), odims));
227233
// All inputs to this function are temporary references
228234
// Delete the temporary references
229235
destroyHandle<T>(lhs);
@@ -243,19 +249,18 @@ static af_err af_logic(af_array *out, const af_array lhs, const af_array rhs, bo
243249
ArrayInfo linfo = getInfo(lhs);
244250
ArrayInfo rinfo = getInfo(rhs);
245251

246-
if (!batchMode) DIM_ASSERT(1, linfo.dims() == rinfo.dims());
247-
else AF_ERROR("Batch mode not supported yet", AF_ERR_NOT_SUPPORTED);
252+
dim4 odims = getOutDims(linfo.dims(), rinfo.dims(), batchMode);
248253

249254
af_array res;
250255
switch (type) {
251-
case f32: res = logicOp<float , op>(left, right); break;
252-
case f64: res = logicOp<double , op>(left, right); break;
253-
case c32: res = logicOp<cfloat , op>(left, right); break;
254-
case c64: res = logicOp<cdouble, op>(left, right); break;
255-
case s32: res = logicOp<int , op>(left, right); break;
256-
case u32: res = logicOp<uint , op>(left, right); break;
257-
case u8 : res = logicOp<uchar , op>(left, right); break;
258-
case b8 : res = logicOp<char , op>(left, right); break;
256+
case f32: res = logicOp<float , op>(left, right, odims); break;
257+
case f64: res = logicOp<double , op>(left, right, odims); break;
258+
case c32: res = logicOp<cfloat , op>(left, right, odims); break;
259+
case c64: res = logicOp<cdouble, op>(left, right, odims); break;
260+
case s32: res = logicOp<int , op>(left, right, odims); break;
261+
case u32: res = logicOp<uint , op>(left, right, odims); break;
262+
case u8 : res = logicOp<uchar , op>(left, right, odims); break;
263+
case b8 : res = logicOp<char , op>(left, right, odims); break;
259264
default: TYPE_ERROR(0, type);
260265
}
261266

src/api/c/complex.cpp

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,23 @@
2121
#include <complex.hpp>
2222

2323
using namespace detail;
24+
using af::dim4;
25+
26+
static dim4 getOutDims(const dim4 ldims, const dim4 rdims, bool batchMode)
27+
{
28+
if (!batchMode) {
29+
DIM_ASSERT(1, ldims == rdims);
30+
return ldims;
31+
}
32+
33+
AF_ERROR("Batch mode not supported yet", AF_ERR_NOT_SUPPORTED);
34+
}
2435

2536
template<typename To, typename Ti>
26-
static inline af_array cplx(const af_array lhs, const af_array rhs, bool destroy=true)
37+
static inline af_array cplx(const af_array lhs, const af_array rhs,
38+
const dim4 &odims, bool destroy=true)
2739
{
28-
af_array res = getHandle(*cplx<To, Ti>(getArray<Ti>(lhs), getArray<Ti>(rhs)));
40+
af_array res = getHandle(*cplx<To, Ti>(getArray<Ti>(lhs), getArray<Ti>(rhs), odims));
2941
if (destroy) {
3042
// All inputs to this function are temporary references
3143
// Delete the temporary references
@@ -47,16 +59,15 @@ af_err af_cplx2(af_array *out, const af_array lhs, const af_array rhs, bool batc
4759

4860
if (type != f64) type = f32;
4961

50-
if (!batchMode) DIM_ASSERT(1, getInfo(lhs).dims() == getInfo(rhs).dims());
51-
else AF_ERROR("Batch mode not supported yet", AF_ERR_NOT_SUPPORTED);
62+
dim4 odims = getOutDims(getInfo(lhs).dims(), getInfo(rhs).dims(), batchMode);
5263

5364
const af_array left = cast(lhs, type);
5465
const af_array right = cast(rhs, type);
5566

5667
af_array res;
5768
switch (type) {
58-
case f32: res = cplx<cfloat , float>(left, right); break;
59-
case f64: res = cplx<cdouble, double>(left, right); break;
69+
case f32: res = cplx<cfloat , float >(left, right, odims); break;
70+
case f64: res = cplx<cdouble, double>(left, right, odims); break;
6071
default: TYPE_ERROR(0, type);
6172
}
6273

@@ -86,8 +97,8 @@ af_err af_cplx(af_array *out, const af_array in)
8697
af_array res;
8798
switch (type) {
8899

89-
case f32: res = cplx<cfloat , float >(in, tmp, false); break;
90-
case f64: res = cplx<cdouble, double>(in, tmp, false); break;
100+
case f32: res = cplx<cfloat , float >(in, tmp, info.dims(), false); break;
101+
case f64: res = cplx<cdouble, double>(in, tmp, info.dims(), false); break;
91102

92103
default: TYPE_ERROR(0, type);
93104
}

src/api/c/data.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ static inline af_array createCplx(dim4 dims, const Ti real, const Ti imag)
117117
{
118118
Array<Ti> *Real = createValueArray<Ti>(dims, real);
119119
Array<Ti> *Imag = createValueArray<Ti>(dims, imag);
120-
Array<To> *Cplx = cplx<To, Ti>(*Real, *Imag);
120+
Array<To> *Cplx = cplx<To, Ti>(*Real, *Imag, dims);
121121
af_array out = getHandle(*Cplx);
122122

123123
destroyArray(*Real);

src/backend/cpu/arith.hpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
#include <af/defines.h>
1111
#include <af/array.h>
12+
#include <af/dim4.hpp>
1213
#include <Array.hpp>
1314
#include <optypes.hpp>
1415
#include <err_cpu.hpp>
@@ -55,14 +56,14 @@ NUMERIC_FN(af_atan2_t, atan2)
5556
NUMERIC_FN(af_hypot_t, hypot)
5657

5758
template<typename T, af_op_t op>
58-
Array<T>* arithOp(const Array<T> &lhs, const Array<T> &rhs)
59+
Array<T>* arithOp(const Array<T> &lhs, const Array<T> &rhs, const af::dim4 &odims)
5960
{
6061
TNJ::Node_ptr lhs_node = lhs.getNode();
6162
TNJ::Node_ptr rhs_node = rhs.getNode();
6263

6364
TNJ::BinaryNode<T, T, op> *node = new TNJ::BinaryNode<T, T, op>(lhs_node, rhs_node);
6465

65-
return createNodeArray<T>(lhs.dims(), TNJ::Node_ptr(
66+
return createNodeArray<T>(odims, TNJ::Node_ptr(
6667
reinterpret_cast<TNJ::Node *>(node)));
6768
}
6869
}

src/backend/cpu/complex.hpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include <complex>
1111
#include <af/defines.h>
1212
#include <af/array.h>
13+
#include <af/dim4.hpp>
1314
#include <Array.hpp>
1415
#include <optypes.hpp>
1516
#include <err_cpu.hpp>
@@ -29,15 +30,15 @@ namespace cpu
2930
};
3031

3132
template<typename To, typename Ti>
32-
Array<To>* cplx(const Array<Ti> &lhs, const Array<Ti> &rhs)
33+
Array<To>* cplx(const Array<Ti> &lhs, const Array<Ti> &rhs, const af::dim4 &odims)
3334
{
3435
TNJ::Node_ptr lhs_node = lhs.getNode();
3536
TNJ::Node_ptr rhs_node = rhs.getNode();
3637

3738
TNJ::BinaryNode<To, Ti, af_cplx2_t> *node =
3839
new TNJ::BinaryNode<To, Ti, af_cplx2_t>(lhs_node, rhs_node);
3940

40-
return createNodeArray<To>(lhs.dims(), TNJ::Node_ptr(
41+
return createNodeArray<To>(odims, TNJ::Node_ptr(
4142
reinterpret_cast<TNJ::Node *>(node)));
4243
}
4344

src/backend/cpu/logic.hpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
#include <af/defines.h>
1111
#include <af/array.h>
12+
#include <af/dim4.hpp>
1213
#include <Array.hpp>
1314
#include <optypes.hpp>
1415
#include <err_cpu.hpp>
@@ -69,14 +70,14 @@ LOGIC_CPLX_FN(double, af_or_t, ||)
6970
#undef LOGIC_CPLX_FN
7071

7172
template<typename T, af_op_t op>
72-
Array<uchar>* logicOp(const Array<T> &lhs, const Array<T> &rhs)
73+
Array<uchar>* logicOp(const Array<T> &lhs, const Array<T> &rhs, const af::dim4 &odims)
7374
{
7475
TNJ::Node_ptr lhs_node = lhs.getNode();
7576
TNJ::Node_ptr rhs_node = rhs.getNode();
7677

7778
TNJ::BinaryNode<uchar, T, op> *node = new TNJ::BinaryNode<uchar, T, op>(lhs_node, rhs_node);
7879

79-
return createNodeArray<uchar>(lhs.dims(), TNJ::Node_ptr(
80+
return createNodeArray<uchar>(odims, TNJ::Node_ptr(
8081
reinterpret_cast<TNJ::Node *>(node)));
8182
}
8283
}

src/backend/cuda/arith.hpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
#include <af/defines.h>
1111
#include <af/array.h>
12+
#include <af/dim4.hpp>
1213
#include <Array.hpp>
1314
#include <optypes.hpp>
1415
#include <err_cuda.hpp>
@@ -17,8 +18,8 @@
1718
namespace cuda
1819
{
1920
template<typename T, af_op_t op>
20-
Array<T>* arithOp(const Array<T> &lhs, const Array<T> &rhs)
21+
Array<T>* arithOp(const Array<T> &lhs, const Array<T> &rhs, const af::dim4 &odims)
2122
{
22-
return createBinaryNode<T, T, op>(lhs, rhs);
23+
return createBinaryNode<T, T, op>(lhs, rhs, odims);
2324
}
2425
}

src/backend/cuda/binary.hpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
* http://arrayfire.com/licenses/BSD-3-Clause
88
********************************************************/
99

10+
#include <af/dim4.hpp>
1011
#include <Array.hpp>
1112
#include <optypes.hpp>
1213
#include <math.hpp>
@@ -64,7 +65,7 @@ BINARY(hypot)
6465
#undef BINARY
6566

6667
template<typename To, typename Ti, af_op_t op>
67-
Array<To> *createBinaryNode(const Array<Ti> &lhs, const Array<Ti> &rhs)
68+
Array<To> *createBinaryNode(const Array<Ti> &lhs, const Array<Ti> &rhs, const af::dim4 &odims)
6869
{
6970
BinOp<To, Ti, op> bop;
7071

@@ -76,7 +77,7 @@ Array<To> *createBinaryNode(const Array<Ti> &lhs, const Array<Ti> &rhs)
7677
lhs_node,
7778
rhs_node, (int)(op));
7879

79-
return createNodeArray<To>(lhs.dims(), JIT::Node_ptr(reinterpret_cast<JIT::Node *>(node)));
80+
return createNodeArray<To>(odims, JIT::Node_ptr(reinterpret_cast<JIT::Node *>(node)));
8081
}
8182

8283
}

src/backend/cuda/complex.hpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
#include <af/defines.h>
1111
#include <af/array.h>
12+
#include <af/dim4.hpp>
1213
#include <Array.hpp>
1314
#include <optypes.hpp>
1415
#include <err_cuda.hpp>
@@ -40,7 +41,7 @@ namespace cuda
4041
template<> STATIC_ const char *conj_name<cdouble>() { return "@___conjz"; }
4142

4243
template<typename To, typename Ti>
43-
Array<To>* cplx(const Array<Ti> &lhs, const Array<Ti> &rhs)
44+
Array<To>* cplx(const Array<Ti> &lhs, const Array<Ti> &rhs, const af::dim4 &odims)
4445
{
4546
JIT::Node_ptr lhs_node = lhs.getNode();
4647
JIT::Node_ptr rhs_node = rhs.getNode();
@@ -50,7 +51,7 @@ namespace cuda
5051
lhs_node,
5152
rhs_node, (int)(af_cplx2_t));
5253

53-
return createNodeArray<To>(lhs.dims(), JIT::Node_ptr(reinterpret_cast<JIT::Node *>(node)));
54+
return createNodeArray<To>(odims, JIT::Node_ptr(reinterpret_cast<JIT::Node *>(node)));
5455
}
5556

5657
template<typename To, typename Ti>

src/backend/cuda/logic.hpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,16 @@
99

1010
#include <af/defines.h>
1111
#include <af/array.h>
12+
#include <af/dim4.hpp>
1213
#include <Array.hpp>
1314
#include <optypes.hpp>
1415
#include <err_cuda.hpp>
1516

1617
namespace cuda
1718
{
1819
template<typename T, af_op_t op>
19-
Array<uchar>* logicOp(const Array<T> &lhs, const Array<T> &rhs)
20+
Array<uchar>* logicOp(const Array<T> &lhs, const Array<T> &rhs, const af::dim4 &odims)
2021
{
21-
return createBinaryNode<uchar, T, op>(lhs, rhs);
22+
return createBinaryNode<uchar, T, op>(lhs, rhs, odims);
2223
}
2324
}

0 commit comments

Comments
 (0)