Skip to content

Commit 5360328

Browse files
author
pradeep
committed
Address perf regression in approx after dim based interop was introduced
1 parent fe0c8d5 commit 5360328

8 files changed

Lines changed: 215 additions & 221 deletions

File tree

src/backend/cuda/kernel/approx.hpp

Lines changed: 10 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -31,9 +31,10 @@ void approx1(Param<Ty> yo, CParam<Ty> yi, CParam<Tp> xo, const int xdim,
3131
const af::interpType method, const int order) {
3232
static const std::string source(approx1_cuh, approx1_cuh_len);
3333

34-
auto approx1 = common::getKernel(
35-
"cuda::approx1", {source},
36-
{TemplateTypename<Ty>(), TemplateTypename<Tp>(), TemplateArg(order)});
34+
auto approx1 =
35+
common::getKernel("cuda::approx1", {source},
36+
{TemplateTypename<Ty>(), TemplateTypename<Tp>(),
37+
TemplateArg(xdim), TemplateArg(order)});
3738

3839
dim3 threads(THREADS, 1, 1);
3940
int blocksPerMat = divup(yo.dims[0], threads.x);
@@ -48,7 +49,7 @@ void approx1(Param<Ty> yo, CParam<Ty> yi, CParam<Tp> xo, const int xdim,
4849

4950
EnqueueArgs qArgs(blocks, threads, getActiveStream());
5051

51-
approx1(qArgs, yo, yi, xo, xdim, xi_beg, xi_step, offGrid, blocksPerMat,
52+
approx1(qArgs, yo, yi, xo, xi_beg, Tp(1) / xi_step, offGrid, blocksPerMat,
5253
batch, method);
5354

5455
POST_LAUNCH_CHECK();
@@ -63,7 +64,8 @@ void approx2(Param<Ty> zo, CParam<Ty> zi, CParam<Tp> xo, const int xdim,
6364

6465
auto approx2 = common::getKernel(
6566
"cuda::approx2", {source},
66-
{TemplateTypename<Ty>(), TemplateTypename<Tp>(), TemplateArg(order)});
67+
{TemplateTypename<Ty>(), TemplateTypename<Tp>(), TemplateArg(xdim),
68+
TemplateArg(ydim), TemplateArg(order)});
6769

6870
dim3 threads(TX, TY, 1);
6971
int blocksPerMatX = divup(zo.dims[0], threads.x);
@@ -79,8 +81,9 @@ void approx2(Param<Ty> zo, CParam<Ty> zi, CParam<Tp> xo, const int xdim,
7981

8082
EnqueueArgs qArgs(blocks, threads, getActiveStream());
8183

82-
approx2(qArgs, zo, zi, xo, xdim, xi_beg, xi_step, yo, ydim, yi_beg, yi_step,
83-
offGrid, blocksPerMatX, blocksPerMatY, batch, method);
84+
approx2(qArgs, zo, zi, xo, xi_beg, Tp(1) / xi_step, yo, yi_beg,
85+
Tp(1) / yi_step, offGrid, blocksPerMatX, blocksPerMatY, batch,
86+
method);
8487

8588
POST_LAUNCH_CHECK();
8689
}

src/backend/cuda/kernel/approx1.cuh

Lines changed: 32 additions & 28 deletions
Original file line number | Diff line number | Diff line change
@@ -14,13 +14,11 @@
1414

1515
namespace cuda {
1616

17-
template<typename Ty, typename Tp, int order>
18-
__global__
19-
void approx1(Param<Ty> yo, CParam<Ty> yi, CParam<Tp> xo,
20-
const int xdim, const Tp xi_beg,
21-
const Tp xi_step, const float offGrid,
22-
const int blocksMatX, const bool batch,
23-
af::interpType method) {
17+
template<typename Ty, typename Tp, int xdim, int order>
18+
__global__ void approx1(Param<Ty> yo, CParam<Ty> yi, CParam<Tp> xo,
19+
const Tp xi_beg, const Tp xi_step_reproc,
20+
const float offGrid, const int blocksMatX,
21+
const bool batch, af::interpType method) {
2422
const int idy = blockIdx.x / blocksMatX;
2523
const int blockIdx_x = blockIdx.x - idy * blocksMatX;
2624
const int idx = blockIdx_x * blockDim.x + threadIdx.x;
@@ -32,36 +30,42 @@ void approx1(Param<Ty> yo, CParam<Ty> yi, CParam<Tp> xo,
3230
idw >= yo.dims[3])
3331
return;
3432

35-
bool is_xo_off[] = {xo.dims[0] > 1, xo.dims[1] > 1, xo.dims[2] > 1,
36-
xo.dims[3] > 1};
37-
bool is_yi_off[] = {true, true, true, true};
38-
is_yi_off[xdim] = false;
33+
// FIXME: Only cubic interpolation is doing clamping
34+
// We need to make it consistent across all methods
35+
// Not changing the behavior because tests will fail
36+
const bool clamp = order == 3;
37+
38+
bool is_off[] = {xo.dims[0] > 1, xo.dims[1] > 1, xo.dims[2] > 1,
39+
xo.dims[3] > 1};
40+
41+
int xo_idx = idx * is_off[0];
42+
if (batch) {
43+
xo_idx += idw * xo.strides[3] * is_off[3];
44+
xo_idx += idz * xo.strides[2] * is_off[2];
45+
xo_idx += idy * xo.strides[1] * is_off[1];
46+
}
47+
48+
const Tp x = (xo.ptr[xo_idx] - xi_beg) * xi_step_reproc;
3949

4050
const int yo_idx =
4151
idw * yo.strides[3] + idz * yo.strides[2] + idy * yo.strides[1] + idx;
42-
int xo_idx = idx * is_xo_off[0];
43-
xo_idx += idw * xo.strides[3] * is_xo_off[3];
44-
xo_idx += idz * xo.strides[2] * is_xo_off[2];
45-
xo_idx += idy * xo.strides[1] * is_xo_off[1];
4652

47-
const Tp x = (xo.ptr[xo_idx] - xi_beg) / xi_step;
53+
#pragma unroll
54+
for (int flagIdx = 0; flagIdx < 4; ++flagIdx) { is_off[flagIdx] = true; }
55+
is_off[xdim] = false;
56+
4857
if (x < 0 || yi.dims[xdim] < x + 1) {
4958
yo.ptr[yo_idx] = scalar<Ty>(offGrid);
5059
return;
5160
}
5261

53-
int yi_idx = idx * is_yi_off[0];
54-
yi_idx += idw * yi.strides[3] * is_yi_off[3];
55-
yi_idx += idz * yi.strides[2] * is_yi_off[2];
56-
yi_idx += idy * yi.strides[1] * is_yi_off[1];
57-
58-
// FIXME: Only cubic interpolation is doing clamping
59-
// We need to make it consistent across all methods
60-
// Not changing the behavior because tests will fail
61-
bool clamp = order == 3;
62+
int yi_idx = idx * is_off[0];
63+
yi_idx += idw * yi.strides[3] * is_off[3];
64+
yi_idx += idz * yi.strides[2] * is_off[2];
65+
yi_idx += idy * yi.strides[1] * is_off[1];
6266

63-
Interp1<Ty, Tp, order> interp;
64-
interp(yo, yo_idx, yi, yi_idx, x, method, 1, clamp, xdim);
67+
Interp1<Ty, Tp, xdim, order> interp;
68+
interp(yo, yo_idx, yi, yi_idx, x, method, 1, clamp);
6569
}
6670

67-
}
71+
} // namespace cuda

src/backend/cuda/kernel/approx2.cuh

Lines changed: 35 additions & 33 deletions
Original file line number | Diff line number | Diff line change
@@ -14,15 +14,13 @@
1414

1515
namespace cuda {
1616

17-
template<typename Ty, typename Tp, int order>
18-
__global__
19-
void approx2(Param<Ty> zo, CParam<Ty> zi, CParam<Tp> xo,
20-
const int xdim, const Tp xi_beg,
21-
const Tp xi_step, CParam<Tp> yo, const int ydim,
22-
const Tp yi_beg, const Tp yi_step,
23-
const float offGrid, const int blocksMatX,
24-
const int blocksMatY, const bool batch,
25-
af::interpType method) {
17+
template<typename Ty, typename Tp, int xdim, int ydim, int order>
18+
__global__ void approx2(Param<Ty> zo, CParam<Ty> zi, CParam<Tp> xo,
19+
const Tp xi_beg, const Tp xi_step_reproc, CParam<Tp> yo,
20+
const Tp yi_beg, const Tp yi_step_reproc,
21+
const float offGrid, const int blocksMatX,
22+
const int blocksMatY, const bool batch,
23+
af::interpType method) {
2624
const int idz = blockIdx.x / blocksMatX;
2725
const int blockIdx_x = blockIdx.x - idz * blocksMatX;
2826
const int idx = threadIdx.x + blockIdx_x * blockDim.x;
@@ -36,39 +34,43 @@ void approx2(Param<Ty> zo, CParam<Ty> zi, CParam<Tp> xo,
3634
idw >= zo.dims[3])
3735
return;
3836

39-
bool is_xo_off[] = {xo.dims[0] > 1, xo.dims[1] > 1, xo.dims[2] > 1,
40-
xo.dims[3] > 1};
41-
bool is_zi_off[] = {true, true, true, true};
42-
is_zi_off[xdim] = false;
43-
is_zi_off[ydim] = false;
37+
// FIXME: Only cubic interpolation is doing clamping
38+
// We need to make it consistent across all methods
39+
// Not changing the behavior because tests will fail
40+
const bool clamp = order == 3;
41+
42+
bool is_off[] = {xo.dims[0] > 1, xo.dims[1] > 1, xo.dims[2] > 1,
43+
xo.dims[3] > 1};
4444

4545
const int zo_idx =
4646
idw * zo.strides[3] + idz * zo.strides[2] + idy * zo.strides[1] + idx;
47-
int xo_idx = idy * xo.strides[1] * is_xo_off[1] + idx * is_xo_off[0];
48-
int yo_idx = idy * yo.strides[1] * is_xo_off[1] + idx * is_xo_off[0];
49-
xo_idx +=
50-
idw * xo.strides[3] * is_xo_off[3] + idz * xo.strides[2] * is_xo_off[2];
51-
yo_idx +=
52-
idw * yo.strides[3] * is_xo_off[3] + idz * yo.strides[2] * is_xo_off[2];
47+
int xo_idx = idy * xo.strides[1] * is_off[1] + idx * is_off[0];
48+
int yo_idx = idy * yo.strides[1] * is_off[1] + idx * is_off[0];
49+
if (batch) {
50+
xo_idx +=
51+
idw * xo.strides[3] * is_off[3] + idz * xo.strides[2] * is_off[2];
52+
yo_idx +=
53+
idw * yo.strides[3] * is_off[3] + idz * yo.strides[2] * is_off[2];
54+
}
55+
56+
const Tp x = (xo.ptr[xo_idx] - xi_beg) * xi_step_reproc;
57+
const Tp y = (yo.ptr[yo_idx] - yi_beg) * yi_step_reproc;
58+
59+
#pragma unroll
60+
for (int flagIdx = 0; flagIdx < 4; ++flagIdx) { is_off[flagIdx] = true; }
61+
is_off[xdim] = false;
62+
is_off[ydim] = false;
5363

54-
const Tp x = (xo.ptr[xo_idx] - xi_beg) / xi_step;
55-
const Tp y = (yo.ptr[yo_idx] - yi_beg) / yi_step;
5664
if (x < 0 || y < 0 || zi.dims[xdim] < x + 1 || zi.dims[ydim] < y + 1) {
5765
zo.ptr[zo_idx] = scalar<Ty>(offGrid);
5866
return;
5967
}
6068

61-
int zi_idx = idy * zi.strides[1] * is_zi_off[1] + idx * is_zi_off[0];
62-
zi_idx +=
63-
idw * zi.strides[3] * is_zi_off[3] + idz * zi.strides[2] * is_zi_off[2];
64-
65-
// FIXME: Only cubic interpolation is doing clamping
66-
// We need to make it consistent across all methods
67-
// Not changing the behavior because tests will fail
68-
bool clamp = order == 3;
69+
int zi_idx = idy * zi.strides[1] * is_off[1] + idx * is_off[0];
70+
zi_idx += idw * zi.strides[3] * is_off[3] + idz * zi.strides[2] * is_off[2];
6971

70-
Interp2<Ty, Tp, order> interp;
71-
interp(zo, zo_idx, zi, zi_idx, x, y, method, 1, clamp, xdim, ydim);
72+
Interp2<Ty, Tp, xdim, ydim, order> interp;
73+
interp(zo, zo_idx, zi, zi_idx, x, y, method, 1, clamp);
7274
}
7375

74-
}
76+
} // namespace cuda

src/backend/cuda/kernel/interp.hpp

Lines changed: 20 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -85,14 +85,14 @@ __device__ inline static Ty bicubicInterpFunc(Ty val[4][4], Tp xratio,
8585
return cubicInterpFunc(res, yratio, spline);
8686
}
8787

88-
template<typename Ty, typename Tp, int order>
88+
template<typename Ty, typename Tp, int xdim, int order>
8989
struct Interp1 {};
9090

91-
template<typename Ty, typename Tp>
92-
struct Interp1<Ty, Tp, 1> {
91+
template<typename Ty, typename Tp, int xdim>
92+
struct Interp1<Ty, Tp, xdim, 1> {
9393
__device__ void operator()(Param<Ty> out, int ooff, CParam<Ty> in, int ioff,
9494
Tp x, af::interpType method, int batch,
95-
bool clamp, int xdim = 0, int batch_dim = 1) {
95+
bool clamp, int batch_dim = 1) {
9696
Ty zero = scalar<Ty>(0);
9797

9898
const int x_lim = in.dims[xdim];
@@ -113,11 +113,11 @@ struct Interp1<Ty, Tp, 1> {
113113
}
114114
};
115115

116-
template<typename Ty, typename Tp>
117-
struct Interp1<Ty, Tp, 2> {
116+
template<typename Ty, typename Tp, int xdim>
117+
struct Interp1<Ty, Tp, xdim, 2> {
118118
__device__ void operator()(Param<Ty> out, int ooff, CParam<Ty> in, int ioff,
119119
Tp x, af::interpType method, int batch,
120-
bool clamp, int xdim = 0, int batch_dim = 1) {
120+
bool clamp, int batch_dim = 1) {
121121
typedef typename itype_t<Tp>::wtype WT;
122122
typedef typename itype_t<Ty>::vtype VT;
123123

@@ -149,11 +149,11 @@ struct Interp1<Ty, Tp, 2> {
149149
}
150150
};
151151

152-
template<typename Ty, typename Tp>
153-
struct Interp1<Ty, Tp, 3> {
152+
template<typename Ty, typename Tp, int xdim>
153+
struct Interp1<Ty, Tp, xdim, 3> {
154154
__device__ void operator()(Param<Ty> out, int ooff, CParam<Ty> in, int ioff,
155155
Tp x, af::interpType method, int batch,
156-
bool clamp, int xdim = 0, int batch_dim = 1) {
156+
bool clamp, int batch_dim = 1) {
157157
typedef typename itype_t<Tp>::wtype WT;
158158
typedef typename itype_t<Ty>::vtype VT;
159159

@@ -184,15 +184,14 @@ struct Interp1<Ty, Tp, 3> {
184184
}
185185
};
186186

187-
template<typename Ty, typename Tp, int order>
187+
template<typename Ty, typename Tp, int xdim, int ydim, int order>
188188
struct Interp2 {};
189189

190-
template<typename Ty, typename Tp>
191-
struct Interp2<Ty, Tp, 1> {
190+
template<typename Ty, typename Tp, int xdim, int ydim>
191+
struct Interp2<Ty, Tp, xdim, ydim, 1> {
192192
__device__ void operator()(Param<Ty> out, int ooff, CParam<Ty> in, int ioff,
193193
Tp x, Tp y, af::interpType method, int batch,
194-
bool clamp, int xdim = 0, int ydim = 1,
195-
int batch_dim = 2) {
194+
bool clamp, int batch_dim = 2) {
196195
int xid = (method == AF_INTERP_LOWER ? floor(x) : round(x));
197196
int yid = (method == AF_INTERP_LOWER ? floor(y) : round(y));
198197

@@ -222,12 +221,11 @@ struct Interp2<Ty, Tp, 1> {
222221
}
223222
};
224223

225-
template<typename Ty, typename Tp>
226-
struct Interp2<Ty, Tp, 2> {
224+
template<typename Ty, typename Tp, int xdim, int ydim>
225+
struct Interp2<Ty, Tp, xdim, ydim, 2> {
227226
__device__ void operator()(Param<Ty> out, int ooff, CParam<Ty> in, int ioff,
228227
Tp x, Tp y, af::interpType method, int batch,
229-
bool clamp, int xdim = 0, int ydim = 1,
230-
int batch_dim = 2) {
228+
bool clamp, int batch_dim = 2) {
231229
typedef typename itype_t<Tp>::wtype WT;
232230
typedef typename itype_t<Ty>::vtype VT;
233231

@@ -275,12 +273,11 @@ struct Interp2<Ty, Tp, 2> {
275273
}
276274
};
277275

278-
template<typename Ty, typename Tp>
279-
struct Interp2<Ty, Tp, 3> {
276+
template<typename Ty, typename Tp, int xdim, int ydim>
277+
struct Interp2<Ty, Tp, xdim, ydim, 3> {
280278
__device__ void operator()(Param<Ty> out, int ooff, CParam<Ty> in, int ioff,
281279
Tp x, Tp y, af::interpType method, int batch,
282-
bool clamp, int xdim = 0, int ydim = 1,
283-
int batch_dim = 2) {
280+
bool clamp, int batch_dim = 2) {
284281
typedef typename itype_t<Tp>::wtype WT;
285282
typedef typename itype_t<Ty>::vtype VT;
286283

src/backend/opencl/kernel/approx.hpp

Lines changed: 11 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -33,7 +33,7 @@ inline std::string interpSrc() {
3333
}
3434

3535
template<typename Ty, typename Tp>
36-
auto genCompileOptions(const int order) {
36+
auto genCompileOptions(const int order, const int xdim, const int ydim = -1) {
3737
constexpr bool isComplex =
3838
static_cast<af_dtype>(dtype_traits<Ty>::af_type) == c32 ||
3939
static_cast<af_dtype>(dtype_traits<Ty>::af_type) == c64;
@@ -47,9 +47,11 @@ auto genCompileOptions(const int order) {
4747
DefineKeyValue(InterpValTy, dtype_traits<Ty>::getName()),
4848
DefineKeyValue(InterpPosTy, dtype_traits<Tp>::getName()),
4949
DefineKeyValue(ZERO, toNumStr(scalar<Ty>(0))),
50+
DefineKeyValue(XDIM, xdim),
5051
DefineKeyValue(INTERP_ORDER, order),
5152
DefineKeyValue(IS_CPLX, (isComplex ? 1 : 0)),
5253
};
54+
if (ydim != -1) { compileOpts.emplace_back(DefineKeyValue(YDIM, ydim)); }
5355
compileOpts.emplace_back(getTypeBuildDefinition<Ty>());
5456
addInterpEnumOptions(compileOpts);
5557

@@ -72,9 +74,10 @@ void approx1(Param yo, const Param yi, const Param xo, const int xdim,
7274
vector<TemplateArg> tmpltArgs = {
7375
TemplateTypename<Ty>(),
7476
TemplateTypename<Tp>(),
77+
TemplateArg(xdim),
7578
TemplateArg(order),
7679
};
77-
auto compileOpts = genCompileOptions<Ty, Tp>(order);
80+
auto compileOpts = genCompileOptions<Ty, Tp>(order, xdim);
7881

7982
auto approx1 = common::getKernel("approx1", {interpSrc(), src}, tmpltArgs,
8083
compileOpts);
@@ -89,7 +92,7 @@ void approx1(Param yo, const Param yi, const Param xo, const int xdim,
8992
!(xo.info.dims[1] == 1 && xo.info.dims[2] == 1 && xo.info.dims[3] == 1);
9093

9194
approx1(EnqueueArgs(getQueue(), global, local), *yo.data, yo.info, *yi.data,
92-
yi.info, *xo.data, xo.info, xdim, xi_beg, xi_step,
95+
yi.info, *xo.data, xo.info, xi_beg, Tp(1) / xi_step,
9396
scalar<Ty>(offGrid), (int)blocksPerMat, (int)batch, (int)method);
9497
CL_DEBUG_FINISH(getQueue());
9598
}
@@ -111,11 +114,10 @@ void approx2(Param zo, const Param zi, const Param xo, const int xdim,
111114
static const string src(approx2_cl, approx2_cl_len);
112115

113116
vector<TemplateArg> tmpltArgs = {
114-
TemplateTypename<Ty>(),
115-
TemplateTypename<Tp>(),
116-
TemplateArg(order),
117+
TemplateTypename<Ty>(), TemplateTypename<Tp>(), TemplateArg(xdim),
118+
TemplateArg(ydim), TemplateArg(order),
117119
};
118-
auto compileOpts = genCompileOptions<Ty, Tp>(order);
120+
auto compileOpts = genCompileOptions<Ty, Tp>(order, xdim, ydim);
119121

120122
auto approx2 = common::getKernel("approx2", {interpSrc(), src}, tmpltArgs,
121123
compileOpts);
@@ -130,8 +132,8 @@ void approx2(Param zo, const Param zi, const Param xo, const int xdim,
130132
bool batch = !(xo.info.dims[2] == 1 && xo.info.dims[3] == 1);
131133

132134
approx2(EnqueueArgs(getQueue(), global, local), *zo.data, zo.info, *zi.data,
133-
zi.info, *xo.data, xo.info, xdim, *yo.data, yo.info, ydim, xi_beg,
134-
xi_step, yi_beg, yi_step, scalar<Ty>(offGrid),
135+
zi.info, *xo.data, xo.info, *yo.data, yo.info, xi_beg,
136+
Tp(1) / xi_step, yi_beg, Tp(1) / yi_step, scalar<Ty>(offGrid),
135137
static_cast<int>(blocksPerMatX), static_cast<int>(blocksPerMatY),
136138
static_cast<int>(batch), static_cast<int>(method));
137139
CL_DEBUG_FINISH(getQueue());

0 commit comments

Comments
 (0)