Skip to content

Commit 75b1a6b

Browse files
committed
Instantiate sort_by_key kernels in separately
1 parent b6a6a87 commit 75b1a6b

26 files changed

Lines changed: 331 additions & 220 deletions

src/backend/cuda/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ FILE(GLOB cuda_headers
158158
FILE(GLOB cuda_sources
159159
"*.cu"
160160
"*.cpp"
161-
"sort_by_key/*.cu"
161+
"kernel/sort_by_key/*.cu"
162162
"kernel/*.cu")
163163

164164
FILE(GLOB jit_sources

src/backend/cuda/kernel/iota.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
* http://arrayfire.com/licenses/BSD-3-Clause
88
********************************************************/
99

10+
#include <af/dim4.hpp>
1011
#include <math.hpp>
1112
#include <dispatch.hpp>
1213
#include <Param.hpp>
@@ -69,7 +70,7 @@ namespace cuda
6970
// Wrapper functions
7071
///////////////////////////////////////////////////////////////////////////
7172
template<typename T>
72-
void iota(Param<T> out, const dim4 &sdims, const dim4 &tdims)
73+
void iota(Param<T> out, const af::dim4 &sdims, const af::dim4 &tdims)
7374
{
7475
dim3 threads(IOTA_TX, IOTA_TY, 1);
7576

src/backend/cuda/kernel/sort_by_key.hpp

Lines changed: 4 additions & 115 deletions
Original file line numberDiff line numberDiff line change
@@ -12,130 +12,19 @@
1212
#include <Param.hpp>
1313
#include <err_cuda.hpp>
1414
#include <debug_cuda.hpp>
15-
#include <kernel/sort_helper.hpp>
16-
#include <kernel/iota.hpp>
17-
18-
#include <thrust/device_ptr.h>
19-
#include <thrust/device_vector.h>
20-
#include <thrust/sort.h>
2115

2216
namespace cuda
2317
{
2418
namespace kernel
2519
{
26-
///////////////////////////////////////////////////////////////////////////
27-
// Wrapper functions
28-
///////////////////////////////////////////////////////////////////////////
2920
template<typename Tk, typename Tv, bool isAscending>
30-
void sort0ByKeyIterative(Param<Tk> okey, Param<Tv> oval)
31-
{
32-
thrust::device_ptr<Tk> okey_ptr = thrust::device_pointer_cast(okey.ptr);
33-
thrust::device_ptr<Tv> oval_ptr = thrust::device_pointer_cast(oval.ptr);
34-
35-
for(int w = 0; w < okey.dims[3]; w++) {
36-
int okeyW = w * okey.strides[3];
37-
int ovalW = w * oval.strides[3];
38-
for(int z = 0; z < okey.dims[2]; z++) {
39-
int okeyWZ = okeyW + z * okey.strides[2];
40-
int ovalWZ = ovalW + z * oval.strides[2];
41-
for(int y = 0; y < okey.dims[1]; y++) {
42-
43-
int okeyOffset = okeyWZ + y * okey.strides[1];
44-
int ovalOffset = ovalWZ + y * oval.strides[1];
45-
46-
if(isAscending) {
47-
THRUST_SELECT(thrust::stable_sort_by_key,
48-
okey_ptr + okeyOffset,
49-
okey_ptr + okeyOffset + okey.dims[0],
50-
oval_ptr + ovalOffset);
51-
} else {
52-
THRUST_SELECT(thrust::stable_sort_by_key,
53-
okey_ptr + okeyOffset,
54-
okey_ptr + okeyOffset + okey.dims[0],
55-
oval_ptr + ovalOffset, thrust::greater<Tk>());
56-
}
57-
}
58-
}
59-
}
60-
POST_LAUNCH_CHECK();
61-
}
21+
void sort0ByKeyIterative(Param<Tk> okey, Param<Tv> oval);
6222

6323
template<typename Tk, typename Tv, bool isAscending, int dim>
64-
void sortByKeyBatched(Param<Tk> pKey, Param<Tv> pVal)
65-
{
66-
af::dim4 inDims;
67-
for(int i = 0; i < 4; i++)
68-
inDims[i] = pKey.dims[i];
69-
70-
// Sort dimension
71-
// tileDims * seqDims = inDims
72-
af::dim4 tileDims(1);
73-
af::dim4 seqDims = inDims;
74-
tileDims[dim] = inDims[dim];
75-
seqDims[dim] = 1;
76-
77-
// Create/call iota
78-
// Array<uint> key = iota<uint>(seqDims, tileDims);
79-
dim4 keydims = inDims;
80-
uint* key = memAlloc<uint>(keydims.elements());
81-
Param<uint> pSeq;
82-
pSeq.ptr = key;
83-
pSeq.strides[0] = 1;
84-
pSeq.dims[0] = keydims[0];
85-
for(int i = 1; i < 4; i++) {
86-
pSeq.dims[i] = keydims[i];
87-
pSeq.strides[i] = pSeq.strides[i - 1] * pSeq.dims[i - 1];
88-
}
89-
cuda::kernel::iota<uint>(pSeq, seqDims, tileDims);
90-
91-
// Make pkey, pVal into a pair
92-
thrust::device_vector<IndexPair<Tk, Tv> > X(inDims.elements());
93-
IndexPair<Tk, Tv> *Xptr = thrust::raw_pointer_cast(X.data());
94-
95-
const int threads = 256;
96-
int blocks = divup(inDims.elements(), threads * copyPairIter);
97-
CUDA_LAUNCH((makeIndexPair<Tk, Tv>), blocks, threads,
98-
Xptr, pKey.ptr, pVal.ptr, inDims.elements());
99-
POST_LAUNCH_CHECK();
100-
101-
// Sort indices
102-
// Need to convert pSeq to thrust::device_ptr, otherwise thrust
103-
// throws weird errors for all *64 data types (double, intl, uintl etc)
104-
thrust::device_ptr<uint> dSeq = thrust::device_pointer_cast(pSeq.ptr);
105-
THRUST_SELECT(thrust::stable_sort_by_key,
106-
X.begin(), X.end(),
107-
dSeq,
108-
IPCompare<Tk, Tv, isAscending>());
109-
POST_LAUNCH_CHECK();
110-
111-
// Needs to be ascending (true) in order to maintain the indices properly
112-
//kernel::sort0_by_key<uint, T, true>(pKey, pVal);
113-
THRUST_SELECT(thrust::stable_sort_by_key,
114-
dSeq,
115-
dSeq + inDims.elements(),
116-
X.begin());
117-
POST_LAUNCH_CHECK();
118-
119-
CUDA_LAUNCH((splitIndexPair<Tk, Tv>), blocks, threads,
120-
pKey.ptr, pVal.ptr, Xptr, inDims.elements());
121-
POST_LAUNCH_CHECK();
122-
123-
// No need of doing moddims here because the original Array<T>
124-
// dimensions have not been changed
125-
//val.modDims(inDims);
126-
127-
memFree(key);
128-
}
24+
void sortByKeyBatched(Param<Tk> pKey, Param<Tv> pVal);
12925

13026
template<typename Tk, typename Tv, bool isAscending>
131-
void sort0ByKey(Param<Tk> okey, Param<Tv> oval)
132-
{
133-
int higherDims = okey.dims[1] * okey.dims[2] * okey.dims[3];
134-
// TODO Make a better heurisitic
135-
if(higherDims > 5)
136-
kernel::sortByKeyBatched<Tk, Tv, isAscending, 0>(okey, oval);
137-
else
138-
kernel::sort0ByKeyIterative<Tk, Tv, isAscending>(okey, oval);
139-
}
27+
void sort0ByKey(Param<Tk> okey, Param<Tv> oval);
28+
14029
}
14130
}

src/backend/cuda/sort_by_key/ascd_f32.cu renamed to src/backend/cuda/kernel/sort_by_key/ascd_f32.cu

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,12 @@
77
* http://arrayfire.com/licenses/BSD-3-Clause
88
********************************************************/
99

10-
#include <sort_by_key_impl.hpp>
10+
#include <kernel/sort_by_key_impl.hpp>
1111

1212
namespace cuda
13+
{
14+
namespace kernel
1315
{
1416
INSTANTIATE1(float, true)
1517
}
18+
}

src/backend/cuda/sort_by_key/ascd_f64.cu renamed to src/backend/cuda/kernel/sort_by_key/ascd_f64.cu

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,12 @@
77
* http://arrayfire.com/licenses/BSD-3-Clause
88
********************************************************/
99

10-
#include <sort_by_key_impl.hpp>
10+
#include <kernel/sort_by_key_impl.hpp>
1111

1212
namespace cuda
13+
{
14+
namespace kernel
1315
{
1416
INSTANTIATE1(double, true)
1517
}
18+
}

src/backend/cuda/sort_by_key/ascd_s16.cu renamed to src/backend/cuda/kernel/sort_by_key/ascd_s16.cu

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,12 @@
77
* http://arrayfire.com/licenses/BSD-3-Clause
88
********************************************************/
99

10-
#include <sort_by_key_impl.hpp>
10+
#include <kernel/sort_by_key_impl.hpp>
1111

1212
namespace cuda
13+
{
14+
namespace kernel
1315
{
1416
INSTANTIATE1(short, true)
1517
}
18+
}

src/backend/cuda/sort_by_key/ascd_s32.cu renamed to src/backend/cuda/kernel/sort_by_key/ascd_s32.cu

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,12 @@
77
* http://arrayfire.com/licenses/BSD-3-Clause
88
********************************************************/
99

10-
#include <sort_by_key_impl.hpp>
10+
#include <kernel/sort_by_key_impl.hpp>
1111

1212
namespace cuda
13+
{
14+
namespace kernel
1315
{
1416
INSTANTIATE1(int, true)
1517
}
18+
}

src/backend/cuda/sort_by_key/ascd_s64.cu renamed to src/backend/cuda/kernel/sort_by_key/ascd_s64.cu

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,12 @@
77
* http://arrayfire.com/licenses/BSD-3-Clause
88
********************************************************/
99

10-
#include <sort_by_key_impl.hpp>
10+
#include <kernel/sort_by_key_impl.hpp>
1111

1212
namespace cuda
13+
{
14+
namespace kernel
1315
{
1416
INSTANTIATE1(intl, true)
1517
}
18+
}

src/backend/cuda/sort_by_key/ascd_s8.cu renamed to src/backend/cuda/kernel/sort_by_key/ascd_s8.cu

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,12 @@
77
* http://arrayfire.com/licenses/BSD-3-Clause
88
********************************************************/
99

10-
#include <sort_by_key_impl.hpp>
10+
#include <kernel/sort_by_key_impl.hpp>
1111

1212
namespace cuda
13+
{
14+
namespace kernel
1315
{
1416
INSTANTIATE1(char, true)
1517
}
18+
}

src/backend/cuda/sort_by_key/ascd_u16.cu renamed to src/backend/cuda/kernel/sort_by_key/ascd_u16.cu

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,12 @@
77
* http://arrayfire.com/licenses/BSD-3-Clause
88
********************************************************/
99

10-
#include <sort_by_key_impl.hpp>
10+
#include <kernel/sort_by_key_impl.hpp>
1111

1212
namespace cuda
13+
{
14+
namespace kernel
1315
{
1416
INSTANTIATE1(ushort, true)
1517
}
18+
}

0 commit comments

Comments
 (0)