Skip to content

Commit 2262bb9

Browse files
committed
FEAT: Add missing element wise functions
- arg, factorial, trunc, pow2, root, sign
1 parent 61debc5 commit 2262bb9

11 files changed

Lines changed: 293 additions & 5 deletions

File tree

include/af/arith.h

Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,28 @@ namespace af
156156
*/
157157
AFAPI array abs (const array &in);
158158

159+
/**
160+
C++ Interface for arg
161+
162+
\param[in] in is input array
163+
\return phase of \p in
164+
165+
\ingroup arith_func_arg
166+
*/
167+
AFAPI array arg (const array &in);
168+
169+
/**
170+
C++ Interface for getting the sign of input
171+
172+
\param[in] in is input array
173+
\return the sign of each element of input
174+
175+
\note output is 1 for negative numbers and 0 for positive numbers
176+
177+
\ingroup arith_func_sign
178+
*/
179+
AFAPI array sign (const array &in);
180+
159181
/**
160182
C++ Interface for rounding an array of numbers
161183
@@ -168,6 +190,17 @@ namespace af
168190
*/
169191
AFAPI array round (const array &in);
170192

193+
/**
194+
C++ Interface for truncating an array of numbers
195+
196+
\param[in] in is input array
197+
\return values truncated to nearest integer not greater than input values
198+
199+
\ingroup arith_func_trunc
200+
*/
201+
AFAPI array trunc (const array &in);
202+
203+
171204
/**
172205
C++ Interface for flooring an array of numbers
173206
@@ -447,6 +480,39 @@ namespace af
447480
*/
448481
AFAPI array atanh (const array &in);
449482

483+
/**
484+
C++ Interface for nth root
485+
486+
\param[in] lhs is nth root
487+
\param[in] rhs is value
488+
\return \p lhs th root of \p rhs
489+
490+
\ingroup arith_func_root
491+
*/
492+
AFAPI array root (const array &lhs, const array &rhs);
493+
494+
/**
495+
C++ Interface for nth root
496+
497+
\param[in] lhs is nth root
498+
\param[in] rhs is value
499+
\return \p lhs th root of \p rhs
500+
501+
\ingroup arith_func_root
502+
*/
503+
AFAPI array root (const array &lhs, const double rhs);
504+
505+
/**
506+
C++ Interface for nth root
507+
508+
\param[in] lhs is nth root
509+
\param[in] rhs is value
510+
\return \p lhs th root of \p rhs
511+
512+
\ingroup arith_func_root
513+
*/
514+
AFAPI array root (const double lhs, const array &rhs);
515+
450516

451517
/**
452518
C++ Interface for power when base and exponent are arrays
@@ -481,6 +547,16 @@ namespace af
481547
*/
482548
AFAPI array pow (const double lhs, const array &rhs);
483549

550+
/**
551+
C++ Interface for power of 2
552+
553+
\param[in] in is exponent
554+
\return 2 raised to power of \p in
555+
556+
\ingroup arith_func_pow2
557+
*/
558+
AFAPI array pow2 (const array &in);
559+
484560
/**
485561
C++ Interface for exponential of an array
486562
@@ -585,6 +661,16 @@ namespace af
585661
*/
586662
AFAPI array cbrt (const array &in);
587663

664+
/**
665+
C++ Interface for factorial of input
666+
667+
\param[in] in is input
668+
\return the factorial function of input
669+
670+
\ingroup arith_func_factorial
671+
*/
672+
AFAPI array factorial (const array &in);
673+
588674
/**
589675
C++ Interface for gamma function of input
590676
@@ -948,6 +1034,30 @@ extern "C" {
9481034
*/
9491035
AFAPI af_err af_abs (af_array *out, const af_array in);
9501036

1037+
/**
1038+
C Interface for finding the phase
1039+
1040+
\param[out] out will the phase of \p in
1041+
\param[in] in is input array
1042+
\return \ref AF_SUCCESS if the execution completes properly
1043+
1044+
\ingroup arith_func_arg
1045+
*/
1046+
AFAPI af_err af_arg (af_array *out, const af_array in);
1047+
1048+
/**
1049+
C Interface for finding the sign of the input
1050+
1051+
\param[out] out will contain the sign of each element of the input arrays
1052+
\param[in] in is input array
1053+
\return \ref AF_SUCCESS if the execution completes properly
1054+
1055+
\note output is 1 for negative numbers and 0 for positive numbers
1056+
1057+
\ingroup arith_func_round
1058+
*/
1059+
AFAPI af_err af_sign (af_array *out, const af_array in);
1060+
9511061
/**
9521062
C Interface for rounding an array of numbers
9531063
@@ -961,6 +1071,17 @@ extern "C" {
9611071
*/
9621072
AFAPI af_err af_round (af_array *out, const af_array in);
9631073

1074+
/**
1075+
C Interface for truncing an array of numbers
1076+
1077+
\param[out] out will contain values truncated to nearest integer not greater than input
1078+
\param[in] in is input array
1079+
\return \ref AF_SUCCESS if the execution completes properly
1080+
1081+
\ingroup arith_func_trunc
1082+
*/
1083+
AFAPI af_err af_trunc (af_array *out, const af_array in);
1084+
9641085
/**
9651086
C Interface for flooring an array of numbers
9661087
@@ -1198,6 +1319,20 @@ extern "C" {
11981319
*/
11991320
AFAPI af_err af_atanh (af_array *out, const af_array in);
12001321

1322+
/**
1323+
C Interface for root
1324+
1325+
\param[out] out will contain \p lhs th root of \p rhs
1326+
\param[in] lhs is nth root
1327+
\param[in] rhs is value
1328+
\param[in] batch specifies if operations need to be performed in batch mode
1329+
\return \ref AF_SUCCESS if the execution completes properly
1330+
1331+
\ingroup arith_func_root
1332+
*/
1333+
AFAPI af_err af_root (af_array *out, const af_array lhs, const af_array rhs, const bool batch);
1334+
1335+
12011336
/**
12021337
C Interface for power
12031338
@@ -1211,6 +1346,17 @@ extern "C" {
12111346
*/
12121347
AFAPI af_err af_pow (af_array *out, const af_array lhs, const af_array rhs, const bool batch);
12131348

1349+
/**
1350+
C Interface for power of two
1351+
1352+
\param[out] out will contain the values of 2 to the power \p in
1353+
\param[in] in is exponent
1354+
\return \ref AF_SUCCESS if the execution completes properly
1355+
1356+
\ingroup arith_func_pow2
1357+
*/
1358+
AFAPI af_err af_pow2 (af_array *out, const af_array in);
1359+
12141360
/**
12151361
C Interface for exponential of an array
12161362
@@ -1321,6 +1467,17 @@ extern "C" {
13211467
*/
13221468
AFAPI af_err af_cbrt (af_array *out, const af_array in);
13231469

1470+
/**
1471+
C Interface for the factorial
1472+
1473+
\param[out] out will contain the result of factorial of \p in
1474+
\param[in] in is input
1475+
\return \ref AF_SUCCESS if the execution completes properly
1476+
1477+
\ingroup arith_func_factorial
1478+
*/
1479+
AFAPI af_err af_factorial (af_array *out, const af_array in);
1480+
13241481
/**
13251482
C Interface for the gamma function
13261483

src/api/c/binary.cpp

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,31 @@ af_err af_pow(af_array *out, const af_array lhs, const af_array rhs, const bool
147147
return af_arith_real<af_pow_t>(out, lhs, rhs, batchMode);
148148
}
149149

150+
af_err af_root(af_array *out, const af_array lhs, const af_array rhs, const bool batchMode)
151+
{
152+
try {
153+
ArrayInfo linfo = getInfo(lhs);
154+
ArrayInfo rinfo = getInfo(rhs);
155+
if (linfo.isComplex() || rinfo.isComplex()) {
156+
AF_ERROR("Powers of Complex numbers not supported", AF_ERR_NOT_SUPPORTED);
157+
}
158+
159+
af_array one;
160+
AF_CHECK(af_constant(&one, 1, linfo.ndims(), linfo.dims().get(), linfo.getType()));
161+
162+
af_array inv_lhs;
163+
AF_CHECK(af_div(&inv_lhs, one, lhs, batchMode));
164+
165+
AF_CHECK(af_arith_real<af_pow_t>(out, rhs, inv_lhs, batchMode));
166+
167+
AF_CHECK(af_release_array(one));
168+
AF_CHECK(af_release_array(inv_lhs));
169+
170+
} CATCHALL;
171+
172+
return AF_SUCCESS;
173+
}
174+
150175
af_err af_atan2(af_array *out, const af_array lhs, const af_array rhs, const bool batchMode)
151176
{
152177
try {

src/api/c/optypes.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,8 @@ typedef enum {
7474
af_floor_t,
7575
af_ceil_t,
7676
af_round_t,
77+
af_trunc_t,
78+
af_sign_t,
7779

7880
af_rem_t,
7981
af_mod_t,

src/api/c/unary.cpp

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,8 @@ UNARY(asinh)
7979
UNARY(acosh)
8080
UNARY(atanh)
8181

82+
UNARY(trunc)
83+
UNARY(sign)
8284
UNARY(round)
8385
UNARY(floor)
8486
UNARY(ceil)
@@ -99,6 +101,7 @@ UNARY(cbrt)
99101
UNARY(tgamma)
100102
UNARY(lgamma)
101103

104+
102105
af_err af_not(af_array *out, const af_array in)
103106
{
104107
try {
@@ -118,6 +121,75 @@ af_err af_not(af_array *out, const af_array in)
118121
return AF_SUCCESS;
119122
}
120123

124+
af_err af_arg(af_array *out, const af_array in)
125+
{
126+
try {
127+
128+
ArrayInfo in_info = getInfo(in);
129+
130+
if (!in_info.isComplex()) {
131+
return af_constant(out, 0,
132+
in_info.ndims(),
133+
in_info.dims().get(), in_info.getType());
134+
}
135+
136+
af_array real;
137+
af_array imag;
138+
139+
AF_CHECK(af_real(&real, in));
140+
AF_CHECK(af_imag(&imag, in));
141+
142+
AF_CHECK(af_atan2(out, imag, real, false));
143+
144+
AF_CHECK(af_release_array(real));
145+
AF_CHECK(af_release_array(imag));
146+
} CATCHALL;
147+
148+
return AF_SUCCESS;
149+
}
150+
151+
af_err af_pow2(af_array *out, const af_array in)
152+
{
153+
try {
154+
155+
af_array two;
156+
ArrayInfo in_info = getInfo(in);
157+
158+
AF_CHECK(af_constant(&two, 2,
159+
in_info.ndims(),
160+
in_info.dims().get(), in_info.getType()));
161+
162+
AF_CHECK(af_pow(out, two, in, false));
163+
164+
AF_CHECK(af_release_array(two));
165+
} CATCHALL;
166+
167+
return AF_SUCCESS;
168+
}
169+
170+
af_err af_factorial(af_array *out, const af_array in)
171+
{
172+
try {
173+
174+
af_array one;
175+
ArrayInfo in_info = getInfo(in);
176+
177+
AF_CHECK(af_constant(&one, 1,
178+
in_info.ndims(),
179+
in_info.dims().get(), in_info.getType()));
180+
181+
af_array inp1;
182+
AF_CHECK(af_add(&inp1, one, in, false));
183+
184+
AF_CHECK(af_tgamma(out, inp1));
185+
186+
AF_CHECK(af_release_array(one));
187+
AF_CHECK(af_release_array(inp1));
188+
} CATCHALL;
189+
190+
return AF_SUCCESS;
191+
}
192+
121193
template<typename T, af_op_t op>
122194
static inline af_array checkOp(const af_array in)
123195
{

src/api/cpp/binary.cpp

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,12 @@ namespace af
2424
return array(out); \
2525
}
2626

27-
INSTANTIATE(min, af_minof)
28-
INSTANTIATE(max, af_maxof)
29-
INSTANTIATE(pow, af_pow )
30-
INSTANTIATE(rem, af_rem )
31-
INSTANTIATE(mod, af_mod )
27+
INSTANTIATE(min , af_minof)
28+
INSTANTIATE(max , af_maxof)
29+
INSTANTIATE(pow , af_pow )
30+
INSTANTIATE(root, af_root )
31+
INSTANTIATE(rem , af_rem )
32+
INSTANTIATE(mod , af_mod )
3233

3334
INSTANTIATE(complex, af_cplx2)
3435
INSTANTIATE(atan2, af_atan2)
@@ -47,6 +48,7 @@ namespace af
4748
WRAPPER(min)
4849
WRAPPER(max)
4950
WRAPPER(pow)
51+
WRAPPER(root)
5052
WRAPPER(rem)
5153
WRAPPER(mod)
5254
WRAPPER(complex)

0 commit comments

Comments
 (0)