Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 28 additions & 22 deletions src/backend/cpu/kernel/sort_by_key_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,11 @@ void sort0ByKeyIterative(Array<Tk> okey, Array<Tv> oval)
Tk *okey_ptr = okey.get();
Tv *oval_ptr = oval.get();

std::vector<IndexPair<Tk, Tv> > X;
X.reserve(okey.dims()[0]);
typedef IndexPair<Tk, Tv> CurrentPair;

dim_t size = okey.dims()[0];
size_t bytes = size * sizeof(CurrentPair);
CurrentPair *pairKeyVal = (CurrentPair *)memAlloc<char>(bytes);

for(dim_t w = 0; w < okey.dims()[3]; w++) {
dim_t okeyW = w * okey.strides()[3];
Expand All @@ -47,23 +50,24 @@ void sort0ByKeyIterative(Array<Tk> okey, Array<Tv> oval)
dim_t okeyOffset = okeyWZ + y * okey.strides()[1];
dim_t ovalOffset = ovalWZ + y * oval.strides()[1];

X.clear();
std::transform(okey_ptr + okeyOffset, okey_ptr + okeyOffset + okey.dims()[0],
oval_ptr + ovalOffset,
std::back_inserter(X),
[](Tk v_, Tv i_) { return std::make_pair(v_, i_); }
);
Tk *okey_col_ptr = okey_ptr + okeyOffset;
Tv *oval_col_ptr = oval_ptr + ovalOffset;

for(dim_t x = 0; x < size; x++) {
pairKeyVal[x] = std::make_tuple(okey_col_ptr[x], oval_col_ptr[x]);
}

std::stable_sort(X.begin(), X.end(), IPCompare<Tk, Tv, isAscending>());
std::stable_sort(pairKeyVal, pairKeyVal + size, IPCompare<Tk, Tv, isAscending>());

for(unsigned it = 0; it < X.size(); it++) {
okey_ptr[okeyOffset + it] = X[it].first;
oval_ptr[ovalOffset + it] = X[it].second;
for(unsigned x = 0; x < size; x++) {
okey_ptr[okeyOffset + x] = std::get<0>(pairKeyVal[x]);
oval_ptr[ovalOffset + x] = std::get<1>(pairKeyVal[x]);
}
}
}
}

memFree((char *)pairKeyVal);
return;
}

Expand Down Expand Up @@ -108,24 +112,27 @@ void sortByKeyBatched(Array<Tk> okey, Array<Tv> oval)
Tk *okey_ptr = okey.get();
Tv *oval_ptr = oval.get();

std::vector<KeyIndexPair<Tk, Tv> > X;
X.reserve(okey.elements());
typedef KeyIndexPair<Tk, Tv> CurrentTuple;
size_t size = okey.elements();
size_t bytes = okey.elements() * sizeof(CurrentTuple);
CurrentTuple *tupleKeyValIdx = (CurrentTuple *)memAlloc<char>(bytes);

for(unsigned i = 0; i < okey.elements(); i++) {
X.push_back(std::make_pair(std::make_pair(okey_ptr[i], oval_ptr[i]), key[i]));
for(unsigned i = 0; i < size; i++) {
tupleKeyValIdx[i] = std::make_tuple(okey_ptr[i], oval_ptr[i], key[i]);
}

memFree(key); // key is no longer required

std::stable_sort(X.begin(), X.end(), KIPCompareV<Tk, Tv, isAscending>());
std::stable_sort(tupleKeyValIdx, tupleKeyValIdx + size, KIPCompareV<Tk, Tv, isAscending>());

std::stable_sort(X.begin(), X.end(), KIPCompareK<Tk, Tv, true>());
std::stable_sort(tupleKeyValIdx, tupleKeyValIdx + size, KIPCompareK<Tk, Tv, true>());

for(unsigned it = 0; it < okey.elements(); it++) {
okey_ptr[it] = X[it].first.first;
oval_ptr[it] = X[it].first.second;
for(unsigned x = 0; x < okey.elements(); x++) {
okey_ptr[x] = std::get<0>(tupleKeyValIdx[x]);
oval_ptr[x] = std::get<1>(tupleKeyValIdx[x]);
}

memFree((char *)tupleKeyValIdx);
return;
}

Expand Down Expand Up @@ -163,4 +170,3 @@ void sort0ByKey(Array<Tk> okey, Array<Tv> oval)
INSTANTIATE(Tk, uintl , dr)
}
}

22 changes: 14 additions & 8 deletions src/backend/cpu/kernel/sort_helper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,30 +14,34 @@ namespace cpu
namespace kernel
{
/// Key/value record used when sorting along dimension 0:
/// element 0 holds the key, element 1 holds the value that travels with it.
template <typename Tk, typename Tv>
using IndexPair = std::tuple<Tk, Tv>;

/// Comparator over IndexPair that orders records by their key (element 0) only.
///
/// @tparam Tk           key type; must support operator< and operator>
/// @tparam Tv           value type; never inspected by the comparison
/// @tparam isAscending  true  -> lhs precedes rhs when lhs.key < rhs.key
///                      false -> lhs precedes rhs when lhs.key > rhs.key
template <typename Tk, typename Tv, bool isAscending>
struct IPCompare
{
    /// @return true when lhs must be placed before rhs.
    bool operator()(const IndexPair<Tk, Tv> &lhs, const IndexPair<Tk, Tv> &rhs) const
    {
        const Tk &lhsKey = std::get<0>(lhs);
        const Tk &rhsKey = std::get<0>(rhs);

        // The comparison is strict: records with equal keys compare false in
        // both directions, so std::stable_sort leaves them in input order.
        return isAscending ? (lhsKey < rhsKey) : (lhsKey > rhsKey);
    }
};

/// Record used by the batched sort: element 0 is the key, element 1 the
/// value that travels with it, element 2 an unsigned index attached to the
/// record (compared by KIPCompareK, not here).
template <typename Tk, typename Tv>
using KeyIndexPair = std::tuple<Tk, Tv, uint>;

/// Comparator over KeyIndexPair that orders records by element 0 only.
///
/// @tparam Tk           type of element 0; must support operator< and operator>
/// @tparam Tv           type of element 1; never inspected by the comparison
/// @tparam isAscending  true  -> lhs precedes rhs when get<0>(lhs) < get<0>(rhs)
///                      false -> lhs precedes rhs when get<0>(lhs) > get<0>(rhs)
template <typename Tk, typename Tv, bool isAscending>
struct KIPCompareV
{
    /// @return true when lhs must be placed before rhs.
    bool operator()(const KeyIndexPair<Tk, Tv> &lhs, const KeyIndexPair<Tk, Tv> &rhs) const
    {
        const Tk &lhsKey = std::get<0>(lhs);
        const Tk &rhsKey = std::get<0>(rhs);

        // The comparison is strict: records whose element 0 is equal compare
        // false both ways, so a stable sort preserves their relative order.
        return isAscending ? (lhsKey < rhsKey) : (lhsKey > rhsKey);
    }
};

Expand All @@ -46,8 +50,10 @@ namespace cpu
{
bool operator()(const KeyIndexPair<Tk, Tv> &lhs, const KeyIndexPair<Tk, Tv> &rhs)
{
if(isAscending) return (lhs.second < rhs.second);
else return (lhs.second > rhs.second);
uint lhsVal = std::get<2>(lhs);
uint rhsVal = std::get<2>(rhs);
if(isAscending) return (lhsVal < rhsVal);
else return (lhsVal > rhsVal);
}
};
}
Expand Down