Skip to content

Commit 0a66851

Browse files
authored
Ragged reduction (#2786)
* Initial ragged max API and CUDA implementation
* Move ragged lengths into a single ireduce kernel implementation
* Add OpenCL and CPU ragged max to ireduce
* Fix issue with CUDA bounds for higher dimensions; add range-based tests
* OpenCL kernel updates for higher dimensions
* Check for out-of-bounds access in the lengths array
* Fix incorrect nullptr for empty buffer in the OpenCL backend; clang-format
* Update API
* Remove old tests
1 parent d137012 commit 0a66851

17 files changed

Lines changed: 507 additions & 62 deletions

File tree

include/af/algorithm.h

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,24 @@ namespace af
216216
const int dim = -1);
217217
#endif
218218

219+
#if AF_API_VERSION >= 38
220+
/**
221+
C++ Interface for finding ragged max values in an array
222+
Uses an additional input array to determine the number of elements to use along the reduction axis.
223+
224+
\param[out] val will contain the maximum ragged values in \p in along \p dim according to \p ragged_len
225+
\param[out] idx will contain the locations of the maximum ragged values in \p in along \p dim according to \p ragged_len
226+
\param[in] in contains the input values to be reduced
227+
\param[in] ragged_len array containing number of elements to use when reducing along \p dim
228+
\param[in] dim The dimension along which the max operation occurs
229+
230+
\ingroup reduce_func_max
231+
232+
\note NaN values are ignored
233+
*/
234+
AFAPI void max(array &val, array &idx, const array &in, const array &ragged_len, const int dim);
235+
#endif
236+
219237
/**
220238
C++ Interface for checking all true values in an array
221239
@@ -838,6 +856,25 @@ extern "C" {
838856
const int dim);
839857
#endif
840858

859+
#if AF_API_VERSION >= 38
860+
/**
861+
C Interface for finding ragged max values in an array
862+
Uses an additional input array to determine the number of elements to use along the reduction axis.
863+
864+
\param[out] val will contain the maximum ragged values in \p in along \p dim according to \p ragged_len
865+
\param[out] idx will contain the locations of the maximum ragged values in \p in along \p dim according to \p ragged_len
866+
\param[in] in contains the input values to be reduced
867+
\param[in] ragged_len array containing number of elements to use when reducing along \p dim
868+
\param[in] dim The dimension along which the max operation occurs
869+
\return \ref AF_SUCCESS if the execution completes properly
870+
871+
\ingroup reduce_func_max
872+
873+
\note NaN values are ignored
874+
*/
875+
AFAPI af_err af_max_ragged(af_array *val, af_array *idx, const af_array in, const af_array ragged_len, const int dim);
876+
#endif
877+
841878
/**
842879
C Interface for checking all true values in an array
843880

src/api/c/reduce.cpp

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -752,6 +752,22 @@ static inline void ireduce(af_array *res, af_array *loc, const af_array in,
752752
*loc = getHandle(Loc);
753753
}
754754

755+
template<af_op_t op, typename T>
756+
static inline void rreduce(af_array *res, af_array *loc, const af_array in,
757+
const int dim, const af_array ragged_len) {
758+
const Array<T> In = getArray<T>(in);
759+
const Array<uint> Len = getArray<uint>(ragged_len);
760+
dim4 odims = In.dims();
761+
odims[dim] = 1;
762+
763+
Array<T> Res = createEmptyArray<T>(odims);
764+
Array<uint> Loc = createEmptyArray<uint>(odims);
765+
rreduce<op, T>(Res, Loc, In, dim, Len);
766+
767+
*res = getHandle(Res);
768+
*loc = getHandle(Loc);
769+
}
770+
755771
template<af_op_t op>
756772
static af_err ireduce_common(af_array *val, af_array *idx, const af_array in,
757773
const int dim) {
@@ -804,6 +820,78 @@ af_err af_imax(af_array *val, af_array *idx, const af_array in, const int dim) {
804820
return ireduce_common<af_max_t>(val, idx, in, dim);
805821
}
806822

823+
template<af_op_t op>
824+
static af_err rreduce_common(af_array *val, af_array *idx, const af_array in,
825+
const af_array ragged_len, const int dim) {
826+
try {
827+
ARG_ASSERT(3, dim >= 0);
828+
ARG_ASSERT(3, dim < 4);
829+
830+
const ArrayInfo &in_info = getInfo(in);
831+
ARG_ASSERT(2, in_info.ndims() > 0);
832+
833+
if (dim >= (int)in_info.ndims()) {
834+
*val = retain(in);
835+
*idx = createHandleFromValue<uint>(in_info.dims(), 0);
836+
return AF_SUCCESS;
837+
}
838+
839+
// TODO: make sure ragged_len.dims == in.dims(), except on reduced dim
840+
const ArrayInfo &ragged_info = getInfo(ragged_len);
841+
dim4 test_dim = in_info.dims();
842+
test_dim[dim] = 1;
843+
ARG_ASSERT(4, test_dim == ragged_info.dims());
844+
845+
af_dtype keytype = ragged_info.getType();
846+
if (keytype != u32) { TYPE_ERROR(4, keytype); }
847+
848+
af_dtype type = in_info.getType();
849+
af_array res, loc;
850+
851+
switch (type) {
852+
case f32:
853+
rreduce<op, float>(&res, &loc, in, dim, ragged_len);
854+
break;
855+
case f64:
856+
rreduce<op, double>(&res, &loc, in, dim, ragged_len);
857+
break;
858+
case c32:
859+
rreduce<op, cfloat>(&res, &loc, in, dim, ragged_len);
860+
break;
861+
case c64:
862+
rreduce<op, cdouble>(&res, &loc, in, dim, ragged_len);
863+
break;
864+
case u32: rreduce<op, uint>(&res, &loc, in, dim, ragged_len); break;
865+
case s32: rreduce<op, int>(&res, &loc, in, dim, ragged_len); break;
866+
case u64:
867+
rreduce<op, uintl>(&res, &loc, in, dim, ragged_len);
868+
break;
869+
case s64: rreduce<op, intl>(&res, &loc, in, dim, ragged_len); break;
870+
case u16:
871+
rreduce<op, ushort>(&res, &loc, in, dim, ragged_len);
872+
break;
873+
case s16:
874+
rreduce<op, short>(&res, &loc, in, dim, ragged_len);
875+
break;
876+
case b8: rreduce<op, char>(&res, &loc, in, dim, ragged_len); break;
877+
case u8: rreduce<op, uchar>(&res, &loc, in, dim, ragged_len); break;
878+
case f16: rreduce<op, half>(&res, &loc, in, dim, ragged_len); break;
879+
default: TYPE_ERROR(2, type);
880+
}
881+
882+
std::swap(*val, res);
883+
std::swap(*idx, loc);
884+
}
885+
CATCHALL;
886+
887+
return AF_SUCCESS;
888+
}
889+
890+
/**
 * C entry point for the ragged max reduction: runs rreduce_common with the
 * af_max_t operation and returns its status unchanged.
 */
af_err af_max_ragged(af_array *val, af_array *idx, const af_array in,
                     const af_array ragged_len, const int dim) {
    const af_err status =
        rreduce_common<af_max_t>(val, idx, in, ragged_len, dim);
    return status;
}
894+
807895
template<af_op_t op, typename T>
808896
static inline T ireduce_all(unsigned *loc, const af_array in) {
809897
return ireduce_all<op, T>(loc, getArray<T>(in));

src/api/cpp/reduce.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,14 @@ void maxByKey(array &keys_out, array &vals_out, const array &keys,
106106
vals_out = array(ovals);
107107
}
108108

109+
/**
 * C++ wrapper for the ragged max reduction.
 *
 * Calls af_max_ragged on the underlying handles and wraps the two resulting
 * handles in af::array objects. AF_THROW handles a failing status.
 *
 * \param[out] val        maximum values along `dim`
 * \param[out] idx        locations of the maximum values
 * \param[in]  in         input array
 * \param[in]  ragged_len number of elements to use along `dim`
 * \param[in]  dim        dimension along which the reduction runs
 */
void max(array &val, array &idx, const array &in, const array &ragged_len,
         const int dim) {
    // Start from null handles so the C API never observes indeterminate
    // values through its output pointers.
    af_array oval = 0;
    af_array oidx = 0;
    AF_THROW(af_max_ragged(&oval, &oidx, in.get(), ragged_len.get(), dim));
    val = array(oval);
    idx = array(oidx);
}
116+
109117
// 2.1 compatibility
110118
array alltrue(const array &in, const int dim) { return allTrue(in, dim); }
111119
array allTrue(const array &in, const int dim) {

src/api/unified/algorithm.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,3 +176,9 @@ af_err af_set_intersect(af_array *out, const af_array first,
176176
CHECK_ARRAYS(first, second);
177177
CALL(af_set_intersect, out, first, second, is_unique);
178178
}
179+
180+
// Unified-backend entry point for af_max_ragged.
// NOTE(review): CHECK_ARRAYS and CALL are macros defined outside this view;
// presumably CHECK_ARRAYS validates that `in` and `ragged_len` are usable
// together and CALL forwards to the active backend's af_max_ragged and
// returns its result - confirm against their definitions.
af_err af_max_ragged(af_array *vals, af_array *idx, const af_array in,
                     const af_array ragged_len, const int dim) {
    CHECK_ARRAYS(in, ragged_len);
    CALL(af_max_ragged, vals, idx, in, ragged_len, dim);
}

src/backend/cpu/ireduce.cpp

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,19 +23,36 @@ using common::half;
2323
namespace cpu {
2424

2525
template<af_op_t op, typename T>
26-
using ireduce_dim_func = std::function<void(Param<T>, Param<uint>, const dim_t,
27-
CParam<T>, const dim_t, const int)>;
26+
using ireduce_dim_func =
27+
std::function<void(Param<T>, Param<uint>, const dim_t, CParam<T>,
28+
const dim_t, const int, CParam<uint>)>;
2829

2930
template<af_op_t op, typename T>
3031
void ireduce(Array<T> &out, Array<uint> &loc, const Array<T> &in,
3132
const int dim) {
32-
dim4 odims = in.dims();
33-
odims[dim] = 1;
33+
dim4 odims = in.dims();
34+
odims[dim] = 1;
35+
Array<uint> rlen = createEmptyArray<uint>(af::dim4(0));
3436
static const ireduce_dim_func<op, T> ireduce_funcs[] = {
3537
kernel::ireduce_dim<op, T, 1>(), kernel::ireduce_dim<op, T, 2>(),
3638
kernel::ireduce_dim<op, T, 3>(), kernel::ireduce_dim<op, T, 4>()};
3739

38-
getQueue().enqueue(ireduce_funcs[in.ndims() - 1], out, loc, 0, in, 0, dim);
40+
getQueue().enqueue(ireduce_funcs[in.ndims() - 1], out, loc, 0, in, 0, dim,
41+
rlen);
42+
}
43+
44+
template<af_op_t op, typename T>
45+
void rreduce(Array<T> &out, Array<uint> &loc, const Array<T> &in, const int dim,
46+
const Array<uint> &rlen) {
47+
dim4 odims = in.dims();
48+
odims[dim] = 1;
49+
50+
static const ireduce_dim_func<op, T> ireduce_funcs[] = {
51+
kernel::ireduce_dim<op, T, 1>(), kernel::ireduce_dim<op, T, 2>(),
52+
kernel::ireduce_dim<op, T, 3>(), kernel::ireduce_dim<op, T, 4>()};
53+
54+
getQueue().enqueue(ireduce_funcs[in.ndims() - 1], out, loc, 0, in, 0, dim,
55+
rlen);
3956
}
4057

4158
template<af_op_t op, typename T>
@@ -72,6 +89,9 @@ T ireduce_all(unsigned *loc, const Array<T> &in) {
7289
#define INSTANTIATE(ROp, T) \
7390
template void ireduce<ROp, T>(Array<T> & out, Array<uint> & loc, \
7491
const Array<T> &in, const int dim); \
92+
template void rreduce<ROp, T>(Array<T> & out, Array<uint> & loc, \
93+
const Array<T> &in, const int dim, \
94+
const Array<uint> &rlen); \
7595
template T ireduce_all<ROp, T>(unsigned *loc, const Array<T> &in);
7696

7797
// min

src/backend/cpu/ireduce.hpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,10 @@ template<af_op_t op, typename T>
1515
void ireduce(Array<T> &out, Array<uint> &loc, const Array<T> &in,
1616
const int dim);
1717

18+
template<af_op_t op, typename T>
19+
void rreduce(Array<T> &out, Array<uint> &loc, const Array<T> &in, const int dim,
20+
const Array<uint> &rlen);
21+
1822
template<af_op_t op, typename T>
1923
T ireduce_all(unsigned *loc, const Array<T> &in);
2024
} // namespace cpu

src/backend/cpu/kernel/ireduce.hpp

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#pragma once
1111
#include <Param.hpp>
1212
#include <ops.hpp>
13+
#include <algorithm>
1314

1415
namespace cpu {
1516
namespace kernel {
@@ -64,15 +65,15 @@ template<af_op_t op, typename T, int D>
6465
struct ireduce_dim {
6566
void operator()(Param<T> output, Param<uint> locParam,
6667
const dim_t outOffset, CParam<T> input,
67-
const dim_t inOffset, const int dim) {
68+
const dim_t inOffset, const int dim, CParam<uint> rlen) {
6869
const af::dim4 odims = output.dims();
6970
const af::dim4 ostrides = output.strides();
7071
const af::dim4 istrides = input.strides();
7172
const int D1 = D - 1;
7273
for (dim_t i = 0; i < odims[D1]; i++) {
7374
ireduce_dim<op, T, D1>()(output, locParam,
7475
outOffset + i * ostrides[D1], input,
75-
inOffset + i * istrides[D1], dim);
76+
inOffset + i * istrides[D1], dim, rlen);
7677
}
7778
}
7879
};
@@ -81,19 +82,20 @@ template<af_op_t op, typename T>
8182
struct ireduce_dim<op, T, 0> {
8283
void operator()(Param<T> output, Param<uint> locParam,
8384
const dim_t outOffset, CParam<T> input,
84-
const dim_t inOffset, const int dim) {
85+
const dim_t inOffset, const int dim, CParam<uint> rlen) {
8586
const af::dim4 idims = input.dims();
8687
const af::dim4 istrides = input.strides();
8788

88-
T const *const in = input.get();
89-
T *out = output.get();
90-
uint *loc = locParam.get();
89+
T const *const in = input.get();
90+
T *out = output.get();
91+
uint *loc = locParam.get();
92+
const uint *rlenptr = (rlen.get()) ? rlen.get() + outOffset : nullptr;
9193

9294
dim_t stride = istrides[dim];
9395
MinMaxOp<op, T> Op(in[inOffset], 0);
94-
for (dim_t i = 0; i < idims[dim]; i++) {
95-
Op(in[inOffset + i * stride], i);
96-
}
96+
int lim =
97+
(rlenptr) ? std::min(idims[dim], (dim_t)*rlenptr) : idims[dim];
98+
for (dim_t i = 0; i < lim; i++) { Op(in[inOffset + i * stride], i); }
9799

98100
out[outOffset] = Op.m_val;
99101
loc[outOffset] = Op.m_idx;

src/backend/cuda/ireduce.cpp

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,14 @@ namespace cuda {
2626
template<af_op_t op, typename T>
2727
void ireduce(Array<T> &out, Array<uint> &loc, const Array<T> &in,
2828
const int dim) {
29-
kernel::ireduce<T, op>(out, loc.get(), in, dim);
29+
Array<uint> rlen = createEmptyArray<uint>(af::dim4(0));
30+
kernel::ireduce<T, op>(out, loc.get(), in, dim, rlen);
31+
}
32+
33+
template<af_op_t op, typename T>
34+
void rreduce(Array<T> &out, Array<uint> &loc, const Array<T> &in, const int dim,
35+
const Array<uint> &rlen) {
36+
kernel::ireduce<T, op>(out, loc.get(), in, dim, rlen);
3037
}
3138

3239
template<af_op_t op, typename T>
@@ -37,6 +44,9 @@ T ireduce_all(unsigned *loc, const Array<T> &in) {
3744
#define INSTANTIATE(ROp, T) \
3845
template void ireduce<ROp, T>(Array<T> & out, Array<uint> & loc, \
3946
const Array<T> &in, const int dim); \
47+
template void rreduce<ROp, T>(Array<T> & out, Array<uint> & loc, \
48+
const Array<T> &in, const int dim, \
49+
const Array<uint> &rlen); \
4050
template T ireduce_all<ROp, T>(unsigned *loc, const Array<T> &in);
4151

4252
// min

src/backend/cuda/ireduce.hpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,10 @@ template<af_op_t op, typename T>
1515
void ireduce(Array<T> &out, Array<uint> &loc, const Array<T> &in,
1616
const int dim);
1717

18+
template<af_op_t op, typename T>
19+
void rreduce(Array<T> &out, Array<uint> &loc, const Array<T> &in, const int dim,
20+
const Array<uint> &rlen);
21+
1822
template<af_op_t op, typename T>
1923
T ireduce_all(unsigned *loc, const Array<T> &in);
2024
} // namespace cuda

0 commit comments

Comments
 (0)