Skip to content

Commit ea61dd4

Browse files
author
Kumar Aatish
committed
CUDA Backend for scan by key
1 parent 5bf26fd commit ea61dd4

File tree

19 files changed

+1530
-86
lines changed

19 files changed

+1530
-86
lines changed

include/af/algorithm.h

Lines changed: 35 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -318,17 +318,31 @@ namespace af
318318

319319
#if AF_API_VERSION >=34
320320
/**
321-
C++ Interface exclusive sum (cumulative sum) of an array
321+
C++ Interface generalized scan of an array
322322
323323
\param[in] in is the input array
324-
\param[in] dim The dimension along which exclusive sum is performed
324+
\param[in] dim The dimension along which scan is performed
325+
\param[in] op is the type of binary operation used
326+
\param[in] inclusive_scan is flag specifying whether scan is inclusive
327+
\return the output containing scan of the input
328+
329+
\ingroup scan_func_scan
330+
*/
331+
AFAPI array scan(const array &in, const int dim = 0, af_binary_op op = AF_BINARY_ADD, bool inclusive_scan = true);
332+
333+
/**
334+
C++ Interface generalized scan by key of an array
335+
336+
\param[in] key is the key array
337+
\param[in] in is the input array
338+
\param[in] dim The dimension along which scan is performed
325339
\param[in] op is the type of binary operations used
326340
\param[in] inclusive_scan is flag specifying whether scan is inclusive
327-
\return the output containing exclusive sums of the input
341+
\return the output containing scan of the input
328342
329343
\ingroup scan_func_scan
330344
*/
331-
AFAPI array scan(const array &in, const int dim = 0, af_binary_op op = AF_ADD, bool inclusive_scan = true);
345+
AFAPI array scanByKey(const array &key, const array& in, const int dim = 0, af_binary_op op = AF_BINARY_ADD, bool inclusive_scan = true);
332346
#endif
333347

334348
/**
@@ -762,16 +776,31 @@ extern "C" {
762776
/**
763777
C Interface generalized scan of an array
764778
765-
\param[out] out will contain exclusive sums of the input
779+
\param[out] out will contain scan of the input
766780
\param[in] in is the input array
767-
\param[in] dim The dimension along which exclusive sum is performed
781+
\param[in] dim The dimension along which scan is performed
768782
\param[in] op is the type of binary operations used
769783
\param[in] inclusive_scan is flag specifying whether scan is inclusive
770784
\return \ref AF_SUCCESS if the execution completes properly
771785
772786
\ingroup scan_func_scan
773787
*/
774788
AFAPI af_err af_scan(af_array *out, const af_array in, const int dim, af_binary_op op, bool inclusive_scan);
789+
790+
/**
791+
C Interface generalized scan by key of an array
792+
793+
\param[out] out will contain scan of the input
794+
\param[in] key is the key array
795+
\param[in] in is the input array
796+
\param[in] dim The dimension along which scan is performed
797+
\param[in] op is the type of binary operations used
798+
\param[in] inclusive_scan is flag specifying whether scan is inclusive
799+
\return \ref AF_SUCCESS if the execution completes properly
800+
801+
\ingroup scan_func_scan
802+
*/
803+
AFAPI af_err af_scan_by_key(af_array *out, const af_array key, const af_array in, const int dim, af_binary_op op, bool inclusive_scan);
775804
#endif
776805

777806
/**

include/af/defines.h

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -383,12 +383,10 @@ typedef enum {
383383

384384
#if AF_API_VERSION >=34
385385
typedef enum {
386-
AF_ADD = 0,
387-
AF_SUB = 1,
388-
AF_MUL = 2,
389-
AF_DIV = 3,
390-
AF_MIN = 4,
391-
AF_MAX = 5
386+
AF_BINARY_ADD = 0,
387+
AF_BINARY_MUL = 1,
388+
AF_BINARY_MIN = 2,
389+
AF_BINARY_MAX = 3
392390
} af_binary_op;
393391
#endif
394392

src/api/c/scan.cpp

Lines changed: 79 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include <handle.hpp>
1616
#include <ops.hpp>
1717
#include <scan.hpp>
18+
#include <scan_by_key.hpp>
1819
#include <backend.hpp>
1920

2021
using af::dim4;
@@ -26,18 +27,49 @@ static inline af_array scan(const af_array in, const int dim, bool inclusive_sca
2627
return getHandle(scan<op,Ti,To>(getArray<Ti>(in), dim, inclusive_scan));
2728
}
2829

30+
template<af_op_t op, typename Ti, typename To>
31+
static inline af_array scan_key(const af_array key, const af_array in, const int dim, bool inclusive_scan = true)
32+
{
33+
const ArrayInfo& key_info = getInfo(key);
34+
af_dtype type = key_info.getType();
35+
af_array out;
36+
37+
switch(type) {
38+
case s32: out = getHandle(scan<op, Ti, int, To>(getArray< int>(key), getArray<Ti>(in), dim, inclusive_scan)); break;
39+
case u32: out = getHandle(scan<op, Ti, uint, To>(getArray< uint>(key), getArray<Ti>(in), dim, inclusive_scan)); break;
40+
case s64: out = getHandle(scan<op, Ti, intl, To>(getArray< intl>(key), getArray<Ti>(in), dim, inclusive_scan)); break;
41+
case u64: out = getHandle(scan<op, Ti, uintl, To>(getArray<uintl>(key), getArray<Ti>(in), dim, inclusive_scan)); break;
42+
default:
43+
TYPE_ERROR(1, type);
44+
}
45+
return out;
46+
}
47+
48+
template<typename Ti, typename To>
49+
static inline af_array scan_op(const af_array key, const af_array in, const int dim, af_binary_op op, bool inclusive_scan = true)
50+
{
51+
af_array out;
52+
53+
switch(op) {
54+
case AF_BINARY_ADD: out = scan_key<af_add_t, Ti, To>(key, in, dim, inclusive_scan); break;
55+
case AF_BINARY_MUL: out = scan_key<af_mul_t, Ti, To>(key, in, dim, inclusive_scan); break;
56+
case AF_BINARY_MIN: out = scan_key<af_min_t, Ti, To>(key, in, dim, inclusive_scan); break;
57+
case AF_BINARY_MAX: out = scan_key<af_max_t, Ti, To>(key, in, dim, inclusive_scan); break;
58+
//TODO Error for op in default case
59+
}
60+
return out;
61+
}
62+
2963
template<typename Ti, typename To>
3064
static inline af_array scan_op(const af_array in, const int dim, af_binary_op op, bool inclusive_scan)
3165
{
3266
af_array out;
3367

3468
switch(op) {
35-
case AF_ADD: out = scan<af_add_t, Ti, To>(in, dim, inclusive_scan); break;
36-
case AF_SUB: out = scan<af_sub_t, Ti, To>(in, dim, inclusive_scan); break;
37-
case AF_MUL: out = scan<af_mul_t, Ti, To>(in, dim, inclusive_scan); break;
38-
case AF_DIV: out = scan<af_div_t, Ti, To>(in, dim, inclusive_scan); break;
39-
case AF_MIN: out = scan<af_min_t, Ti, To>(in, dim, inclusive_scan); break;
40-
case AF_MAX: out = scan<af_max_t, Ti, To>(in, dim, inclusive_scan); break;
69+
case AF_BINARY_ADD: out = scan<af_add_t, Ti, To>(in, dim, inclusive_scan); break;
70+
case AF_BINARY_MUL: out = scan<af_mul_t, Ti, To>(in, dim, inclusive_scan); break;
71+
case AF_BINARY_MIN: out = scan<af_min_t, Ti, To>(in, dim, inclusive_scan); break;
72+
case AF_BINARY_MAX: out = scan<af_max_t, Ti, To>(in, dim, inclusive_scan); break;
4173
//TODO Error for op in default case
4274
}
4375
return out;
@@ -125,3 +157,44 @@ af_err af_scan(af_array *out, const af_array in, const int dim, af_binary_op op,
125157

126158
return AF_SUCCESS;
127159
}
160+
161+
// C API entry point for scan by key: dispatches on the type of `in` and
// forwards key, input, dimension, operation and inclusive flag to scan_op.
// Argument indices reported in errors are 0-based positions in this
// signature (out = 0, key = 1, in = 2, dim = 3), matching the index that
// scan_key uses for the key array.
// NOTE(review): the ARG_ASSERT checks sit outside the try block; confirm
// that the macro does not throw across the C boundary.
af_err af_scan_by_key(af_array *out, const af_array key, const af_array in, const int dim, af_binary_op op, bool inclusive_scan)
162+
{
163+
ARG_ASSERT(3, dim >= 0);
164+
ARG_ASSERT(3, dim < 4);
165+
166+
try {
167+
168+
const ArrayInfo& in_info = getInfo(in);
169+
170+
// Scanning along a dimension the input does not have is a no-op.
if (dim >= (int)in_info.ndims()) {
171+
*out = retain(in);
172+
return AF_SUCCESS;
173+
}
174+
175+
af_dtype type = in_info.getType();
176+
af_array res;
177+
178+
// Second template argument is the output type; 8/16-bit inputs widen to 32-bit.
switch(type) {
179+
case f32: res = scan_op<float , float >(key, in, dim, op, inclusive_scan); break;
180+
case f64: res = scan_op<double , double >(key, in, dim, op, inclusive_scan); break;
181+
case c32: res = scan_op<cfloat , cfloat >(key, in, dim, op, inclusive_scan); break;
182+
case c64: res = scan_op<cdouble, cdouble>(key, in, dim, op, inclusive_scan); break;
183+
case u32: res = scan_op<uint , uint >(key, in, dim, op, inclusive_scan); break;
184+
case s32: res = scan_op<int , int >(key, in, dim, op, inclusive_scan); break;
185+
case u64: res = scan_op<uintl , uintl >(key, in, dim, op, inclusive_scan); break;
186+
case s64: res = scan_op<intl , intl >(key, in, dim, op, inclusive_scan); break;
187+
case u16: res = scan_op<ushort , uint >(key, in, dim, op, inclusive_scan); break;
188+
case s16: res = scan_op<short , int >(key, in, dim, op, inclusive_scan); break;
189+
case u8: res = scan_op<uchar , uint >(key, in, dim, op, inclusive_scan); break;
190+
case b8: res = scan_op<char , uint >(key, in, dim, op, inclusive_scan); break;
191+
default:
192+
// `in` is argument 2 of this function (argument 1 is the key).
TYPE_ERROR(2, type);
193+
}
194+
195+
std::swap(*out, res);
196+
}
197+
CATCHALL;
198+
199+
return AF_SUCCESS;
200+
}

src/api/cpp/scan.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,4 +26,11 @@ namespace af
2626
AF_THROW(af_scan(&out, in.get(), dim, op, inclusive_scan));
2727
return array(out);
2828
}
29+
30+
array scanByKey(const array& key, const array& in, const int dim, af_binary_op op, bool inclusive_scan)
31+
{
32+
af_array out = 0;
33+
AF_THROW(af_scan_by_key(&out, key.get(), in.get(), dim, op, inclusive_scan));
34+
return array(out);
35+
}
2936
}

src/backend/cpu/scan.cpp

Lines changed: 51 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -71,30 +71,59 @@ namespace cpu
7171
return out;
7272
}
7373

74-
#define INSTANTIATE(ROp, Ti, To)\
74+
// Keyed scan entry point for the CPU backend. The key array is not used
// yet: the call forwards to the plain scan over `in`, so the result is not
// segmented by key. TODO: implement a key-aware scan.
template<af_op_t op, typename Ti, typename Tk, typename To>
75+
Array<To> scan(const Array<Tk>& key, const Array<Ti>& in, const int dim, bool inclusive_scan)
76+
{
77+
// op and To appear in no function argument, so they must be given explicitly.
return scan<op, Ti, To>(in, dim, inclusive_scan);
78+
}
79+
80+
#define INSTANTIATE_SCAN(ROp, Ti, To)\
7581
template Array<To> scan<ROp, Ti, To>(const Array<Ti> &in, const int dim, bool inclusive_scan);
7682

77-
#define INSTANTIATE_SCAN(ROp) \
78-
INSTANTIATE(ROp, float , float ) \
79-
INSTANTIATE(ROp, double , double ) \
80-
INSTANTIATE(ROp, cfloat , cfloat ) \
81-
INSTANTIATE(ROp, cdouble, cdouble) \
82-
INSTANTIATE(ROp, int , int ) \
83-
INSTANTIATE(ROp, uint , uint ) \
84-
INSTANTIATE(ROp, intl , intl ) \
85-
INSTANTIATE(ROp, uintl , uintl ) \
86-
INSTANTIATE(ROp, char , int ) \
87-
INSTANTIATE(ROp, char , uint ) \
88-
INSTANTIATE(ROp, uchar , uint ) \
89-
INSTANTIATE(ROp, short , int ) \
90-
INSTANTIATE(ROp, ushort , uint )
83+
// Explicitly instantiates the keyed scan overload for one (op, Ti, Tk, To)
// combination. The template argument list is spelled out because ROp cannot
// be deduced from the function signature.
#define INSTANTIATE_SCAN_BY_KEY(ROp, Ti, Tk, To)\
84+
template Array<To> scan<ROp, Ti, Tk, To>(const Array<Tk>& key, const Array<Ti>& in, const int dim, bool inclusive_scan);
85+
86+
#define INSTANTIATE_SCAN_ALL(ROp) \
87+
INSTANTIATE_SCAN(ROp, float , float ) \
88+
INSTANTIATE_SCAN(ROp, double , double ) \
89+
INSTANTIATE_SCAN(ROp, cfloat , cfloat ) \
90+
INSTANTIATE_SCAN(ROp, cdouble, cdouble) \
91+
INSTANTIATE_SCAN(ROp, int , int ) \
92+
INSTANTIATE_SCAN(ROp, uint , uint ) \
93+
INSTANTIATE_SCAN(ROp, intl , intl ) \
94+
INSTANTIATE_SCAN(ROp, uintl , uintl ) \
95+
INSTANTIATE_SCAN(ROp, char , int ) \
96+
INSTANTIATE_SCAN(ROp, char , uint ) \
97+
INSTANTIATE_SCAN(ROp, uchar , uint ) \
98+
INSTANTIATE_SCAN(ROp, short , int ) \
99+
INSTANTIATE_SCAN(ROp, ushort , uint )
100+
101+
#define INSTANTIATE_SCAN_BY_KEY_ALL(ROp, Tk) \
102+
INSTANTIATE_SCAN_BY_KEY(ROp, float , Tk, float ) \
103+
INSTANTIATE_SCAN_BY_KEY(ROp, double , Tk, double ) \
104+
INSTANTIATE_SCAN_BY_KEY(ROp, cfloat , Tk, cfloat ) \
105+
INSTANTIATE_SCAN_BY_KEY(ROp, cdouble, Tk, cdouble) \
106+
INSTANTIATE_SCAN_BY_KEY(ROp, int , Tk, int ) \
107+
INSTANTIATE_SCAN_BY_KEY(ROp, uint , Tk, uint ) \
108+
INSTANTIATE_SCAN_BY_KEY(ROp, intl , Tk, intl ) \
109+
INSTANTIATE_SCAN_BY_KEY(ROp, uintl , Tk, uintl ) \
110+
INSTANTIATE_SCAN_BY_KEY(ROp, char , Tk, int ) \
111+
INSTANTIATE_SCAN_BY_KEY(ROp, char , Tk, uint ) \
112+
INSTANTIATE_SCAN_BY_KEY(ROp, uchar , Tk, uint ) \
113+
INSTANTIATE_SCAN_BY_KEY(ROp, short , Tk, int ) \
114+
INSTANTIATE_SCAN_BY_KEY(ROp, ushort , Tk, uint )
115+
116+
// Instantiates the plain scan and the keyed scan for one binary operation.
// Key types are int, uint, intl and uintl, the same type names the C API
// dispatch passes as the key template argument.
#define INSTANTIATE_SCAN_OP(ROp) \
117+
INSTANTIATE_SCAN_ALL(ROp) \
118+
INSTANTIATE_SCAN_BY_KEY_ALL(ROp, int) \
119+
INSTANTIATE_SCAN_BY_KEY_ALL(ROp, uint) \
120+
INSTANTIATE_SCAN_BY_KEY_ALL(ROp, intl) \
121+
INSTANTIATE_SCAN_BY_KEY_ALL(ROp, uintl)
91122

92123
//accum
93-
INSTANTIATE(af_notzero_t, char , uint)
94-
INSTANTIATE_SCAN(af_add_t)
95-
INSTANTIATE_SCAN(af_sub_t)
96-
INSTANTIATE_SCAN(af_mul_t)
97-
INSTANTIATE_SCAN(af_div_t)
98-
INSTANTIATE_SCAN(af_min_t)
99-
INSTANTIATE_SCAN(af_max_t)
124+
INSTANTIATE_SCAN(af_notzero_t, char, uint)
125+
INSTANTIATE_SCAN_OP(af_add_t)
126+
INSTANTIATE_SCAN_OP(af_mul_t)
127+
INSTANTIATE_SCAN_OP(af_min_t)
128+
INSTANTIATE_SCAN_OP(af_max_t)
100129
}

src/backend/cpu/scan.hpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,4 +14,7 @@ namespace cpu
1414
{
1515
template<af_op_t op, typename Ti, typename To>
1616
Array<To> scan(const Array<Ti>& in, const int dim, bool inclusive_scan = true);
17+
18+
// Keyed scan along `dim`. The key array comes first, matching the definition
// in scan.cpp and the call made from the C API layer.
template<af_op_t op, typename Ti, typename Tk, typename To>
19+
Array<To> scan(const Array<Tk>& key, const Array<Ti>& in, const int dim, bool inclusive_scan = true);
1720
}

src/backend/cuda/CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -260,6 +260,8 @@ SOURCE_GROUP(api\\cpp\\Sources FILES ${cpp_sources})
260260

261261
INCLUDE("${CMAKE_CURRENT_SOURCE_DIR}/kernel/sort_by_key/CMakeLists.txt")
262262

263+
INCLUDE("${CMAKE_CURRENT_SOURCE_DIR}/kernel/scan_by_key/CMakeLists.txt")
264+
263265
LIST(LENGTH COMPUTE_VERSIONS COMPUTE_COUNT)
264266
IF(${COMPUTE_COUNT} EQUAL 1)
265267
SET(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} ${CUDA_GENERATE_CODE}")
@@ -391,6 +393,7 @@ MY_CUDA_ADD_LIBRARY(afcuda SHARED
391393
${c_sources}
392394
${cpp_sources}
393395
${sort_by_key_sources}
396+
${scan_by_key_sources}
394397
OPTIONS ${CUDA_GENERATE_CODE})
395398

396399
ADD_DEPENDENCIES(afcuda ${ptx_targets})
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
FILE(STRINGS "${CMAKE_CURRENT_SOURCE_DIR}/kernel/scan_by_key/scan_by_key_impl.cu.in" FILESTRINGS)
2+
3+
FOREACH(STR ${FILESTRINGS})
4+
IF(${STR} MATCHES "// SBK_BINARY_OPS")
5+
STRING(REPLACE "// SBK_BINARY_OPS:" "" TEMP ${STR})
6+
STRING(REPLACE " " ";" SBK_BINARY_OPS ${TEMP})
7+
ENDIF()
8+
ENDFOREACH()
9+
10+
FOREACH(SBK_BINARY_OP ${SBK_BINARY_OPS})
11+
CONFIGURE_FILE("${CMAKE_CURRENT_SOURCE_DIR}/kernel/scan_by_key/scan_by_key_impl.cu.in"
12+
"${CMAKE_CURRENT_BINARY_DIR}/scan_by_key/scan_by_key_impl_${SBK_BINARY_OP}.cu")
13+
ADD_CUSTOM_COMMAND(
14+
OUTPUT "${CMAKE_CURRENT_BINARY_DIR}/scan_by_key/scan_by_key_impl_${SBK_BINARY_OP}.cu"
15+
COMMAND ${CMAKE_COMMAND} -E touch "${CMAKE_CURRENT_BINARY_DIR}/scan_by_key/scan_by_key_impl_${SBK_BINARY_OP}.cu"
16+
DEPENDS "${CMAKE_CURRENT_SOURCE_DIR}/kernel/scan_first_by_key_impl.hpp" "${CMAKE_CURRENT_SOURCE_DIR}/kernel/scan_dim_by_key_impl.hpp")
17+
ENDFOREACH(SBK_BINARY_OP ${SBK_BINARY_OPS})
18+
19+
FILE(GLOB scan_by_key_sources
20+
"${CMAKE_CURRENT_BINARY_DIR}/scan_by_key/*.cu"
21+
)
22+
23+
LIST(SORT scan_by_key_sources)
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
/*******************************************************
2+
* Copyright (c) 2014, ArrayFire
3+
* All rights reserved.
4+
*
5+
* This file is distributed under 3-clause BSD license.
6+
* The complete license agreement can be obtained at:
7+
* http://arrayfire.com/licenses/BSD-3-Clause
8+
********************************************************/
9+
10+
#include <ops.hpp>
11+
#include <backend.hpp>
12+
#include <kernel/scan_first_by_key_impl.hpp>
13+
#include <kernel/scan_dim_by_key_impl.hpp>
14+
15+
// This file instantiates scan_dim_by_key as separate object files from CMake
16+
// The line below is read by CMake to determenine the instantiations
17+
// SBK_BINARY_OPS:af_add_t af_mul_t af_max_t af_min_t
18+
19+
namespace cuda
20+
{
21+
namespace kernel
22+
{
23+
INSTANTIATE_SCAN_FIRST_BY_KEY_OP(@SBK_BINARY_OP@)
24+
INSTANTIATE_SCAN_DIM_BY_KEY_OP(@SBK_BINARY_OP@)
25+
}
26+
}
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
/*******************************************************
2+
* Copyright (c) 2014, ArrayFire
3+
* All rights reserved.
4+
*
5+
* This file is distributed under 3-clause BSD license.
6+
* The complete license agreement can be obtained at:
7+
* http://arrayfire.com/licenses/BSD-3-Clause
8+
********************************************************/
9+
#pragma once
10+
11+
#include <ops.hpp>
12+
#include <Param.hpp>
13+
14+
namespace cuda
15+
{
16+
namespace kernel
17+
{
18+
template<typename Ti, typename Tk, typename To, af_op_t op, int dim, bool inclusive_scan>
19+
void scan_dim_by_key(Param<To> out, CParam<Ti> in, CParam<Tk> key);
20+
}
21+
}

0 commit comments

Comments
 (0)