|
12 | 12 | #include <Param.hpp> |
13 | 13 | #include <err_cuda.hpp> |
14 | 14 | #include <debug_cuda.hpp> |
15 | | -#include <kernel/sort_helper.hpp> |
16 | | -#include <kernel/iota.hpp> |
17 | | - |
18 | | -#include <thrust/device_ptr.h> |
19 | | -#include <thrust/device_vector.h> |
20 | | -#include <thrust/sort.h> |
21 | 15 |
|
22 | 16 | namespace cuda |
23 | 17 | { |
24 | 18 | namespace kernel |
25 | 19 | { |
26 | | - /////////////////////////////////////////////////////////////////////////// |
27 | | - // Wrapper functions |
28 | | - /////////////////////////////////////////////////////////////////////////// |
29 | 20 | template<typename Tk, typename Tv, bool isAscending> |
30 | | - void sort0ByKeyIterative(Param<Tk> okey, Param<Tv> oval) |
31 | | - { |
32 | | - thrust::device_ptr<Tk> okey_ptr = thrust::device_pointer_cast(okey.ptr); |
33 | | - thrust::device_ptr<Tv> oval_ptr = thrust::device_pointer_cast(oval.ptr); |
34 | | - |
35 | | - for(int w = 0; w < okey.dims[3]; w++) { |
36 | | - int okeyW = w * okey.strides[3]; |
37 | | - int ovalW = w * oval.strides[3]; |
38 | | - for(int z = 0; z < okey.dims[2]; z++) { |
39 | | - int okeyWZ = okeyW + z * okey.strides[2]; |
40 | | - int ovalWZ = ovalW + z * oval.strides[2]; |
41 | | - for(int y = 0; y < okey.dims[1]; y++) { |
42 | | - |
43 | | - int okeyOffset = okeyWZ + y * okey.strides[1]; |
44 | | - int ovalOffset = ovalWZ + y * oval.strides[1]; |
45 | | - |
46 | | - if(isAscending) { |
47 | | - THRUST_SELECT(thrust::stable_sort_by_key, |
48 | | - okey_ptr + okeyOffset, |
49 | | - okey_ptr + okeyOffset + okey.dims[0], |
50 | | - oval_ptr + ovalOffset); |
51 | | - } else { |
52 | | - THRUST_SELECT(thrust::stable_sort_by_key, |
53 | | - okey_ptr + okeyOffset, |
54 | | - okey_ptr + okeyOffset + okey.dims[0], |
55 | | - oval_ptr + ovalOffset, thrust::greater<Tk>()); |
56 | | - } |
57 | | - } |
58 | | - } |
59 | | - } |
60 | | - POST_LAUNCH_CHECK(); |
61 | | - } |
| 21 | + void sort0ByKeyIterative(Param<Tk> okey, Param<Tv> oval); |
62 | 22 |
|
63 | 23 | template<typename Tk, typename Tv, bool isAscending, int dim> |
64 | | - void sortByKeyBatched(Param<Tk> pKey, Param<Tv> pVal) |
65 | | - { |
66 | | - af::dim4 inDims; |
67 | | - for(int i = 0; i < 4; i++) |
68 | | - inDims[i] = pKey.dims[i]; |
69 | | - |
70 | | - // Sort dimension |
71 | | - // tileDims * seqDims = inDims |
72 | | - af::dim4 tileDims(1); |
73 | | - af::dim4 seqDims = inDims; |
74 | | - tileDims[dim] = inDims[dim]; |
75 | | - seqDims[dim] = 1; |
76 | | - |
77 | | - // Create/call iota |
78 | | - // Array<uint> key = iota<uint>(seqDims, tileDims); |
79 | | - dim4 keydims = inDims; |
80 | | - uint* key = memAlloc<uint>(keydims.elements()); |
81 | | - Param<uint> pSeq; |
82 | | - pSeq.ptr = key; |
83 | | - pSeq.strides[0] = 1; |
84 | | - pSeq.dims[0] = keydims[0]; |
85 | | - for(int i = 1; i < 4; i++) { |
86 | | - pSeq.dims[i] = keydims[i]; |
87 | | - pSeq.strides[i] = pSeq.strides[i - 1] * pSeq.dims[i - 1]; |
88 | | - } |
89 | | - cuda::kernel::iota<uint>(pSeq, seqDims, tileDims); |
90 | | - |
91 | | - // Make pkey, pVal into a pair |
92 | | - thrust::device_vector<IndexPair<Tk, Tv> > X(inDims.elements()); |
93 | | - IndexPair<Tk, Tv> *Xptr = thrust::raw_pointer_cast(X.data()); |
94 | | - |
95 | | - const int threads = 256; |
96 | | - int blocks = divup(inDims.elements(), threads * copyPairIter); |
97 | | - CUDA_LAUNCH((makeIndexPair<Tk, Tv>), blocks, threads, |
98 | | - Xptr, pKey.ptr, pVal.ptr, inDims.elements()); |
99 | | - POST_LAUNCH_CHECK(); |
100 | | - |
101 | | - // Sort indices |
102 | | - // Need to convert pSeq to thrust::device_ptr, otherwise thrust |
103 | | - // throws weird errors for all *64 data types (double, intl, uintl etc) |
104 | | - thrust::device_ptr<uint> dSeq = thrust::device_pointer_cast(pSeq.ptr); |
105 | | - THRUST_SELECT(thrust::stable_sort_by_key, |
106 | | - X.begin(), X.end(), |
107 | | - dSeq, |
108 | | - IPCompare<Tk, Tv, isAscending>()); |
109 | | - POST_LAUNCH_CHECK(); |
110 | | - |
111 | | - // Needs to be ascending (true) in order to maintain the indices properly |
112 | | - //kernel::sort0_by_key<uint, T, true>(pKey, pVal); |
113 | | - THRUST_SELECT(thrust::stable_sort_by_key, |
114 | | - dSeq, |
115 | | - dSeq + inDims.elements(), |
116 | | - X.begin()); |
117 | | - POST_LAUNCH_CHECK(); |
118 | | - |
119 | | - CUDA_LAUNCH((splitIndexPair<Tk, Tv>), blocks, threads, |
120 | | - pKey.ptr, pVal.ptr, Xptr, inDims.elements()); |
121 | | - POST_LAUNCH_CHECK(); |
122 | | - |
123 | | - // No need of doing moddims here because the original Array<T> |
124 | | - // dimensions have not been changed |
125 | | - //val.modDims(inDims); |
126 | | - |
127 | | - memFree(key); |
128 | | - } |
| 24 | + void sortByKeyBatched(Param<Tk> pKey, Param<Tv> pVal); |
129 | 25 |
|
130 | 26 | template<typename Tk, typename Tv, bool isAscending> |
131 | | - void sort0ByKey(Param<Tk> okey, Param<Tv> oval) |
132 | | - { |
133 | | - int higherDims = okey.dims[1] * okey.dims[2] * okey.dims[3]; |
134 | | - // TODO Make a better heurisitic |
135 | | - if(higherDims > 5) |
136 | | - kernel::sortByKeyBatched<Tk, Tv, isAscending, 0>(okey, oval); |
137 | | - else |
138 | | - kernel::sort0ByKeyIterative<Tk, Tv, isAscending>(okey, oval); |
139 | | - } |
| 27 | + void sort0ByKey(Param<Tk> okey, Param<Tv> oval); |
| 28 | + |
140 | 29 | } |
141 | 30 | } |
0 commit comments