1515#include < handle.hpp>
1616#include < ops.hpp>
1717#include < scan.hpp>
18+ #include < scan_by_key.hpp>
1819#include < backend.hpp>
1920
2021using af::dim4;
@@ -26,18 +27,49 @@ static inline af_array scan(const af_array in, const int dim, bool inclusive_sca
2627 return getHandle (scan<op,Ti,To>(getArray<Ti>(in), dim, inclusive_scan));
2728}
2829
30+ template <af_op_t op, typename Ti, typename To>
31+ static inline af_array scan_key (const af_array key, const af_array in, const int dim, bool inclusive_scan = true )
32+ {
33+ const ArrayInfo& key_info = getInfo (key);
34+ af_dtype type = key_info.getType ();
35+ af_array out;
36+
37+ switch (type) {
38+ case s32: out = getHandle (scan<op, Ti, int , To>(getArray< int >(key), getArray<Ti>(in), dim, inclusive_scan)); break ;
39+ case u32 : out = getHandle (scan<op, Ti, uint, To>(getArray< uint>(key), getArray<Ti>(in), dim, inclusive_scan)); break ;
40+ case s64: out = getHandle (scan<op, Ti, intl, To>(getArray< intl>(key), getArray<Ti>(in), dim, inclusive_scan)); break ;
41+ case u64 : out = getHandle (scan<op, Ti, uintl, To>(getArray<uintl>(key), getArray<Ti>(in), dim, inclusive_scan)); break ;
42+ default :
43+ TYPE_ERROR (1 , type);
44+ }
45+ return out;
46+ }
47+
48+ template <typename Ti, typename To>
49+ static inline af_array scan_op (const af_array key, const af_array in, const int dim, af_binary_op op, bool inclusive_scan = true )
50+ {
51+ af_array out;
52+
53+ switch (op) {
54+ case AF_BINARY_ADD: out = scan_key<af_add_t , Ti, To>(key, in, dim, inclusive_scan); break ;
55+ case AF_BINARY_MUL: out = scan_key<af_mul_t , Ti, To>(key, in, dim, inclusive_scan); break ;
56+ case AF_BINARY_MIN: out = scan_key<af_min_t , Ti, To>(key, in, dim, inclusive_scan); break ;
57+ case AF_BINARY_MAX: out = scan_key<af_max_t , Ti, To>(key, in, dim, inclusive_scan); break ;
58+ // TODO Error for op in default case
59+ }
60+ return out;
61+ }
62+
2963template <typename Ti, typename To>
3064static inline af_array scan_op (const af_array in, const int dim, af_binary_op op, bool inclusive_scan)
3165{
3266 af_array out;
3367
3468 switch (op) {
35- case AF_ADD: out = scan<af_add_t , Ti, To>(in, dim, inclusive_scan); break ;
36- case AF_SUB: out = scan<af_sub_t , Ti, To>(in, dim, inclusive_scan); break ;
37- case AF_MUL: out = scan<af_mul_t , Ti, To>(in, dim, inclusive_scan); break ;
38- case AF_DIV: out = scan<af_div_t , Ti, To>(in, dim, inclusive_scan); break ;
39- case AF_MIN: out = scan<af_min_t , Ti, To>(in, dim, inclusive_scan); break ;
40- case AF_MAX: out = scan<af_max_t , Ti, To>(in, dim, inclusive_scan); break ;
69+ case AF_BINARY_ADD: out = scan<af_add_t , Ti, To>(in, dim, inclusive_scan); break ;
70+ case AF_BINARY_MUL: out = scan<af_mul_t , Ti, To>(in, dim, inclusive_scan); break ;
71+ case AF_BINARY_MIN: out = scan<af_min_t , Ti, To>(in, dim, inclusive_scan); break ;
72+ case AF_BINARY_MAX: out = scan<af_max_t , Ti, To>(in, dim, inclusive_scan); break ;
4173 // TODO Error for op in default case
4274 }
4375 return out;
@@ -125,3 +157,44 @@ af_err af_scan(af_array *out, const af_array in, const int dim, af_binary_op op,
125157
126158 return AF_SUCCESS;
127159}
160+
161+ af_err af_scan_by_key (af_array *out, const af_array key, const af_array in, const int dim, af_binary_op op, bool inclusive_scan)
162+ {
163+ ARG_ASSERT (2 , dim >= 0 );
164+ ARG_ASSERT (2 , dim < 4 );
165+
166+ try {
167+
168+ const ArrayInfo& in_info = getInfo (in);
169+
170+ if (dim >= (int )in_info.ndims ()) {
171+ *out = retain (in);
172+ return AF_SUCCESS;
173+ }
174+
175+ af_dtype type = in_info.getType ();
176+ af_array res;
177+
178+ switch (type) {
179+ case f32 : res = scan_op<float , float >(key, in, dim, op, inclusive_scan); break ;
180+ case f64 : res = scan_op<double , double >(key, in, dim, op, inclusive_scan); break ;
181+ case c32: res = scan_op<cfloat , cfloat >(key, in, dim, op, inclusive_scan); break ;
182+ case c64: res = scan_op<cdouble, cdouble>(key, in, dim, op, inclusive_scan); break ;
183+ case u32 : res = scan_op<uint , uint >(key, in, dim, op, inclusive_scan); break ;
184+ case s32: res = scan_op<int , int >(key, in, dim, op, inclusive_scan); break ;
185+ case u64 : res = scan_op<uintl , uintl >(key, in, dim, op, inclusive_scan); break ;
186+ case s64: res = scan_op<intl , intl >(key, in, dim, op, inclusive_scan); break ;
187+ case u16 : res = scan_op<ushort , uint >(key, in, dim, op, inclusive_scan); break ;
188+ case s16: res = scan_op<short , int >(key, in, dim, op, inclusive_scan); break ;
189+ case u8 : res = scan_op<uchar , uint >(key, in, dim, op, inclusive_scan); break ;
190+ case b8: res = scan_op<char , uint >(key, in, dim, op, inclusive_scan); break ;
191+ default :
192+ TYPE_ERROR (1 , type);
193+ }
194+
195+ std::swap (*out, res);
196+ }
197+ CATCHALL;
198+
199+ return AF_SUCCESS;
200+ }
0 commit comments