/******************************************************* * Copyright (c) 2014, ArrayFire * All rights reserved. * * This file is distributed under 3-clause BSD license. * The complete license agreement can be obtained at: * http://arrayfire.com/licenses/BSD-3-Clause ********************************************************/ #include #include #include #include #include #include #include #include #include #include #include #include using af::dim4; using namespace detail; template static inline af_array sort(const af_array in, const unsigned dim, const bool isAscending) { const Array &inArray = getArray(in); if(isAscending) { return getHandle(sort(inArray, dim)); } else { return getHandle(sort(inArray, dim)); } } af_err af_sort(af_array *out, const af_array in, const unsigned dim, const bool isAscending) { try { ArrayInfo info = getInfo(in); af_dtype type = info.getType(); DIM_ASSERT(1, info.elements() > 0); // Only Dim 0 supported ARG_ASSERT(2, dim == 0); af_array val; switch(type) { case f32: val = sort(in, dim, isAscending); break; case f64: val = sort(in, dim, isAscending); break; case s32: val = sort(in, dim, isAscending); break; case u32: val = sort(in, dim, isAscending); break; case s16: val = sort(in, dim, isAscending); break; case u16: val = sort(in, dim, isAscending); break; case s64: val = sort(in, dim, isAscending); break; case u64: val = sort(in, dim, isAscending); break; case u8: val = sort(in, dim, isAscending); break; case b8: val = sort(in, dim, isAscending); break; default: TYPE_ERROR(1, type); } std::swap(*out, val); } CATCHALL; return AF_SUCCESS; } template static inline void sort_index(af_array *val, af_array *idx, const af_array in, const unsigned dim, const bool isAscending) { const Array &inArray = getArray(in); // Initialize Dummy Arrays Array valArray = createEmptyArray(af::dim4()); Array idxArray = createEmptyArray(af::dim4()); if(isAscending) { sort_index(valArray, idxArray, inArray, dim); } else { sort_index(valArray, idxArray, inArray, dim); } *val = getHandle(valArray); *idx = getHandle(idxArray); } af_err af_sort_index(af_array *out, af_array *indices, const af_array in, const unsigned dim, const bool isAscending) { try { ArrayInfo info = getInfo(in); af_dtype type = info.getType(); DIM_ASSERT(2, info.elements() > 0); // Only Dim 0 supported ARG_ASSERT(3, dim == 0); af_array val; af_array idx; switch(type) { case f32: sort_index(&val, &idx, in, dim, isAscending); break; case f64: sort_index(&val, &idx, in, dim, isAscending); break; case s32: sort_index(&val, &idx, in, dim, isAscending); break; case u32: sort_index(&val, &idx, in, dim, isAscending); break; case s16: sort_index(&val, &idx, in, dim, isAscending); break; case u16: sort_index(&val, &idx, in, dim, isAscending); break; case s64: sort_index(&val, &idx, in, dim, isAscending); break; case u64: sort_index(&val, &idx, in, dim, isAscending); break; case u8: sort_index(&val, &idx, in, dim, isAscending); break; case b8: sort_index(&val, &idx, in, dim, isAscending); break; default: TYPE_ERROR(1, type); } std::swap(*out , val); std::swap(*indices, idx); } CATCHALL; return AF_SUCCESS; } template static inline void sort_by_key(af_array *okey, af_array *oval, const af_array ikey, const af_array ival, const unsigned dim, const bool isAscending) { const Array &ikeyArray = getArray(ikey); const Array &ivalArray = getArray(ival); // Initialize Dummy Arrays Array okeyArray = createEmptyArray(af::dim4()); Array ovalArray = createEmptyArray(af::dim4()); if(isAscending) { sort_by_key(okeyArray, ovalArray, ikeyArray, ivalArray, dim); } else { sort_by_key(okeyArray, ovalArray, ikeyArray, ivalArray, dim); } *okey = getHandle(okeyArray); *oval = getHandle(ovalArray); } template void sort_by_key_tmplt(af_array *okey, af_array *oval, const af_array ikey, const af_array ival, const unsigned dim, const bool isAscending) { ArrayInfo info = getInfo(ival); af_dtype vtype = info.getType(); switch(vtype) { case f32: sort_by_key(okey, oval, ikey, ival, dim, isAscending); break; case f64: sort_by_key(okey, oval, ikey, ival, dim, isAscending); break; case s32: sort_by_key(okey, oval, ikey, ival, dim, isAscending); break; case u32: sort_by_key(okey, oval, ikey, ival, dim, isAscending); break; case s16: sort_by_key(okey, oval, ikey, ival, dim, isAscending); break; case u16: sort_by_key(okey, oval, ikey, ival, dim, isAscending); break; case s64: sort_by_key(okey, oval, ikey, ival, dim, isAscending); break; case u64: sort_by_key(okey, oval, ikey, ival, dim, isAscending); break; case u8: sort_by_key(okey, oval, ikey, ival, dim, isAscending); break; case b8: sort_by_key(okey, oval, ikey, ival, dim, isAscending); break; default: TYPE_ERROR(1, vtype); } return; } af_err af_sort_by_key(af_array *out_keys, af_array *out_values, const af_array keys, const af_array values, const unsigned dim, const bool isAscending) { try { ArrayInfo info = getInfo(keys); af_dtype type = info.getType(); ArrayInfo vinfo = getInfo(values); DIM_ASSERT(3, info.elements() > 0); DIM_ASSERT(4, info.dims() == vinfo.dims()); // Only Dim 0 supported ARG_ASSERT(5, dim == 0); af_array oKey; af_array oVal; switch(type) { case f32: sort_by_key_tmplt(&oKey, &oVal, keys, values, dim, isAscending); break; case f64: sort_by_key_tmplt(&oKey, &oVal, keys, values, dim, isAscending); break; case s32: sort_by_key_tmplt(&oKey, &oVal, keys, values, dim, isAscending); break; case u32: sort_by_key_tmplt(&oKey, &oVal, keys, values, dim, isAscending); break; case s16: sort_by_key_tmplt(&oKey, &oVal, keys, values, dim, isAscending); break; case u16: sort_by_key_tmplt(&oKey, &oVal, keys, values, dim, isAscending); break; case s64: sort_by_key_tmplt(&oKey, &oVal, keys, values, dim, isAscending); break; case u64: sort_by_key_tmplt(&oKey, &oVal, keys, values, dim, isAscending); break; case u8: sort_by_key_tmplt(&oKey, &oVal, keys, values, dim, isAscending); break; case b8: sort_by_key_tmplt(&oKey, &oVal, keys, values, dim, isAscending); break; default: TYPE_ERROR(1, type); } std::swap(*out_keys , oKey); std::swap(*out_values , oVal); } CATCHALL; return AF_SUCCESS; }