@@ -752,6 +752,22 @@ static inline void ireduce(af_array *res, af_array *loc, const af_array in,
752752 *loc = getHandle (Loc);
753753}
754754
755+ template <af_op_t op, typename T>
756+ static inline void rreduce (af_array *res, af_array *loc, const af_array in,
757+ const int dim, const af_array ragged_len) {
758+ const Array<T> In = getArray<T>(in);
759+ const Array<uint> Len = getArray<uint>(ragged_len);
760+ dim4 odims = In.dims ();
761+ odims[dim] = 1 ;
762+
763+ Array<T> Res = createEmptyArray<T>(odims);
764+ Array<uint> Loc = createEmptyArray<uint>(odims);
765+ rreduce<op, T>(Res, Loc, In, dim, Len);
766+
767+ *res = getHandle (Res);
768+ *loc = getHandle (Loc);
769+ }
770+
755771template <af_op_t op>
756772static af_err ireduce_common (af_array *val, af_array *idx, const af_array in,
757773 const int dim) {
@@ -804,6 +820,78 @@ af_err af_imax(af_array *val, af_array *idx, const af_array in, const int dim) {
804820 return ireduce_common<af_max_t >(val, idx, in, dim);
805821}
806822
823+ template <af_op_t op>
824+ static af_err rreduce_common (af_array *val, af_array *idx, const af_array in,
825+ const af_array ragged_len, const int dim) {
826+ try {
827+ ARG_ASSERT (3 , dim >= 0 );
828+ ARG_ASSERT (3 , dim < 4 );
829+
830+ const ArrayInfo &in_info = getInfo (in);
831+ ARG_ASSERT (2 , in_info.ndims () > 0 );
832+
833+ if (dim >= (int )in_info.ndims ()) {
834+ *val = retain (in);
835+ *idx = createHandleFromValue<uint>(in_info.dims (), 0 );
836+ return AF_SUCCESS ;
837+ }
838+
839+ // TODO: make sure ragged_len.dims == in.dims(), except on reduced dim
840+ const ArrayInfo &ragged_info = getInfo (ragged_len);
841+ dim4 test_dim = in_info.dims ();
842+ test_dim[dim] = 1 ;
843+ ARG_ASSERT (4 , test_dim == ragged_info.dims ());
844+
845+ af_dtype keytype = ragged_info.getType ();
846+ if (keytype != u32 ) { TYPE_ERROR (4 , keytype); }
847+
848+ af_dtype type = in_info.getType ();
849+ af_array res, loc;
850+
851+ switch (type) {
852+ case f32 :
853+ rreduce<op, float >(&res, &loc, in, dim, ragged_len);
854+ break ;
855+ case f64 :
856+ rreduce<op, double >(&res, &loc, in, dim, ragged_len);
857+ break ;
858+ case c32:
859+ rreduce<op, cfloat>(&res, &loc, in, dim, ragged_len);
860+ break ;
861+ case c64:
862+ rreduce<op, cdouble>(&res, &loc, in, dim, ragged_len);
863+ break ;
864+ case u32 : rreduce<op, uint>(&res, &loc, in, dim, ragged_len); break ;
865+ case s32: rreduce<op, int >(&res, &loc, in, dim, ragged_len); break ;
866+ case u64 :
867+ rreduce<op, uintl>(&res, &loc, in, dim, ragged_len);
868+ break ;
869+ case s64: rreduce<op, intl>(&res, &loc, in, dim, ragged_len); break ;
870+ case u16 :
871+ rreduce<op, ushort>(&res, &loc, in, dim, ragged_len);
872+ break ;
873+ case s16:
874+ rreduce<op, short >(&res, &loc, in, dim, ragged_len);
875+ break ;
876+ case b8: rreduce<op, char >(&res, &loc, in, dim, ragged_len); break ;
877+ case u8 : rreduce<op, uchar>(&res, &loc, in, dim, ragged_len); break ;
878+ case f16 : rreduce<op, half>(&res, &loc, in, dim, ragged_len); break ;
879+ default : TYPE_ERROR (2 , type);
880+ }
881+
882+ std::swap (*val, res);
883+ std::swap (*idx, loc);
884+ }
885+ CATCHALL ;
886+
887+ return AF_SUCCESS ;
888+ }
889+
890+ af_err af_max_ragged (af_array *val, af_array *idx, const af_array in,
891+ const af_array ragged_len, const int dim) {
892+ return rreduce_common<af_max_t >(val, idx, in, ragged_len, dim);
893+ }
894+
807895template <af_op_t op, typename T>
808896static inline T ireduce_all (unsigned *loc, const af_array in) {
809897 return ireduce_all<op, T>(loc, getArray<T>(in));
0 commit comments