Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
193 changes: 147 additions & 46 deletions src/backend/cuda/blas.cu
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#include <functional>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <vector>

using arrayfire::common::half;
Expand Down Expand Up @@ -244,6 +245,62 @@ cublasStatus_t gemmBatchedDispatch(BlasHandle handle, cublasOperation_t lOpts,
#endif
}

/// Computes the strides used to step an input operand through the two batch
/// dimensions (2 and 3) of a batched GEMM.
///
/// \param[in]  outDims   dimensions of the output array
/// \param[in]  inDims    dimensions of the input operand
/// \param[in]  inStrides strides of the input operand
/// \param[out] stride2   inStrides[2] when the input's extent in dimension 2
///                       equals the output's, otherwise 0
/// \param[out] stride3   inStrides[3] when the input's extent in dimension 3
///                       equals the output's, otherwise 0
static void getInputBatchStrides(const dim4 &outDims, const dim4 &inDims,
                                 const dim4 &inStrides, dim_t &stride2,
                                 dim_t &stride3) {
    // An operand whose extent differs from the output in a batch dimension is
    // reused for every index of that dimension, which is expressed here as a
    // zero stride.
    if (inDims[2] == outDims[2]) {
        stride2 = inStrides[2];
    } else {
        stride2 = 0;
    }

    if (inDims[3] == outDims[3]) {
        stride3 = inStrides[3];
    } else {
        stride3 = 0;
    }
}

/// Decides whether a two-level batch layout can be addressed with one
/// constant stride per batch.
///
/// The element offset of batch (z, w) is z * stride2 + w * stride3. With the
/// batches numbered n = z + w * dim2, this checks whether that offset can be
/// written as n * batchStride for every batch.
///
/// \param[in]  dim2        batch extent along dimension 2
/// \param[in]  dim3        batch extent along dimension 3
/// \param[in]  stride2     effective stride along dimension 2
/// \param[in]  stride3     effective stride along dimension 3
/// \param[out] batchStride the single stride to use; always written. When the
///                         function returns false it holds stride2 and must
///                         not be used to flatten the batch.
/// \returns true when the layout can be flattened to a single stride
static bool getStridedBatchStride(dim_t dim2, dim_t dim3, dim_t stride2,
                                  dim_t stride3, dim_t &batchStride) {
    const bool singleAlong2 = dim2 <= 1;
    const bool singleAlong3 = dim3 <= 1;

    if (singleAlong2 || singleAlong3) {
        // At most one batch dimension has more than one entry, so the layout
        // is trivially a single stride: 0 when there is only one batch,
        // otherwise the stride of the dimension that actually varies.
        batchStride = singleAlong3 ? (singleAlong2 ? 0 : stride2) : stride3;
        return true;
    }

    // Both dimensions vary. Consecutive batches inside one w-row are stride2
    // apart, so that is the only possible flat stride.
    batchStride = stride2;

    // Going from the last z of one w-row to the first z of the next row moves
    // by stride3 - (dim2 - 1) * stride2; the layout is flat only when that
    // jump equals the in-row step.
    return stride3 - (dim2 - 1) * stride2 == stride2;
}

#if __CUDACC_VER_MAJOR__ >= 10
/// Runs a batched GEMM through cublasGemmStridedBatchedEx, where every batch
/// of an operand is located at a constant offset from the previous one.
///
/// \tparam Ti element type of the lhs and rhs operands
/// \tparam To element type of the output and of alpha/beta
///
/// \param[in]     handle       cuBLAS handle the call is issued on
/// \param[in]     lOpts        operation applied to lhs
/// \param[in]     rOpts        operation applied to rhs
/// \param[in]     M, N, K      GEMM problem dimensions
/// \param[in]     alpha        pointer to the scalar multiplying op(lhs)*op(rhs)
/// \param[in]     lhs          left operand; its base pointer is batch 0
/// \param[in]     lStride      leading dimension of lhs
/// \param[in]     lBatchStride offset between consecutive lhs batches
/// \param[in]     rhs          right operand; its base pointer is batch 0
/// \param[in]     rStride      leading dimension of rhs
/// \param[in]     rBatchStride offset between consecutive rhs batches
/// \param[in]     beta         pointer to the scalar multiplying out
/// \param[in,out] out          output array; its base pointer is batch 0
/// \param[in]     oStride      leading dimension of out
/// \param[in]     oBatchStride offset between consecutive out batches
/// \param[in]     batchSize    number of GEMMs in the batch
/// \returns the status reported by cuBLAS
///
/// \note Leading dimensions and batch strides are passed to cuBLAS unchanged
///       apart from the integer conversions below; cuBLAS documents them as
///       element counts.
template<typename Ti, typename To = Ti>
cublasStatus_t gemmStridedBatchedDispatch(
    BlasHandle handle, cublasOperation_t lOpts, cublasOperation_t rOpts, int M,
    int N, int K, const To *alpha, const Array<Ti> &lhs, dim_t lStride,
    dim_t lBatchStride, const Array<Ti> &rhs, dim_t rStride,
    dim_t rBatchStride, const To *beta, Array<To> &out, dim_t oStride,
    dim_t oBatchStride, int batchSize) {
    // Issue the call on the handle supplied by the caller rather than
    // re-querying the global one, so the parameter is actually honoured.
    // cuBLAS takes leading dimensions as int and batch strides as
    // long long int; the narrowing of the leading dimensions is made
    // explicit here.
    return cublasGemmStridedBatchedEx(
        handle, lOpts, rOpts, M, N, K, alpha, lhs.get(), getType<Ti>(),
        static_cast<int>(lStride), static_cast<long long int>(lBatchStride),
        rhs.get(), getType<Ti>(), static_cast<int>(rStride),
        static_cast<long long int>(rBatchStride), beta, out.get(),
        getType<To>(), static_cast<int>(oStride),
        static_cast<long long int>(oBatchStride), batchSize,
        getComputeType<To>(), selectGEMMAlgorithm<Ti>());
}
#endif

template<typename Ti, typename To>
void gemm(Array<To> &out, af_mat_prop optLhs, af_mat_prop optRhs, const To *alpha,
const Array<Ti> &lhs, const Array<Ti> &rhs, const To *beta) {
Expand All @@ -270,54 +327,98 @@ void gemm(Array<To> &out, af_mat_prop optLhs, af_mat_prop optRhs, const To *alph
lhs, lStrides[1], rhs, rStrides[1], beta,
out, oStrides[1])));
} else {
int batchSize = oDims[2] * oDims[3];
vector<const Ti *> lptrs(batchSize);
vector<const Ti *> rptrs(batchSize);
vector<To *> optrs(batchSize);

bool is_l_d2_batched = oDims[2] == lDims[2];
bool is_l_d3_batched = oDims[3] == lDims[3];

bool is_r_d2_batched = oDims[2] == rDims[2];
bool is_r_d3_batched = oDims[3] == rDims[3];

const Ti *lptr = lhs.get();
const Ti *rptr = rhs.get();
To *optr = out.get();

for (int n = 0; n < batchSize; n++) {
int w = n / oDims[2];
int z = n - w * oDims[2];
int loff = z * (is_l_d2_batched * lStrides[2]) +
w * (is_l_d3_batched * lStrides[3]);
int roff = z * (is_r_d2_batched * rStrides[2]) +
w * (is_r_d3_batched * rStrides[3]);
lptrs[n] = lptr + loff;
rptrs[n] = rptr + roff;
optrs[n] = optr + z * oStrides[2] + w * oStrides[3];
}
const dim_t batchDim2 = oDims[2];
const dim_t batchDim3 = oDims[3];
int batchSize = batchDim2 * batchDim3;

size_t bytes = batchSize * sizeof(Ti **);
auto d_lptrs = memAlloc<uchar>(bytes);
auto d_rptrs = memAlloc<uchar>(bytes);
auto d_optrs = memAlloc<uchar>(bytes);
CUDA_CHECK(cudaMemcpyAsync(d_lptrs.get(), lptrs.data(), bytes,
cudaMemcpyHostToDevice, getActiveStream()));
CUDA_CHECK(cudaMemcpyAsync(d_rptrs.get(), rptrs.data(), bytes,
cudaMemcpyHostToDevice, getActiveStream()));
CUDA_CHECK(cudaMemcpyAsync(d_optrs.get(), optrs.data(), bytes,
cudaMemcpyHostToDevice, getActiveStream()));

// Call this before the gemm call so that you don't have to wait for the
// computation. Even though it would make more sense to put it
// afterwards
CUDA_CHECK(cudaStreamSynchronize(getActiveStream()));
dim_t lStride2 = 0;
dim_t lStride3 = 0;
dim_t rStride2 = 0;
dim_t rStride3 = 0;

using Nt = typename common::kernel_type<Ti>::native;
CUBLAS_CHECK(gemmBatchedDispatch(
blasHandle(), lOpts, rOpts, M, N, K, alpha,
(const Ti **)d_lptrs.get(), lStrides[1], (const Ti **)d_rptrs.get(),
rStrides[1], beta, (To **)d_optrs.get(), oStrides[1], batchSize));
getInputBatchStrides(oDims, lDims, lStrides, lStride2, lStride3);
getInputBatchStrides(oDims, rDims, rStrides, rStride2, rStride3);

const dim_t oStride2 = oStrides[2];
const dim_t oStride3 = oStrides[3];

dim_t lBatchStride = 0;
dim_t rBatchStride = 0;
dim_t oBatchStride = 0;

bool can_use_strided_batched = false;

#if __CUDACC_VER_MAJOR__ >= 10
{
auto prop = getDeviceProp(getActiveDeviceId());

const bool l_is_strided = getStridedBatchStride(
batchDim2, batchDim3, lStride2, lStride3, lBatchStride);

const bool r_is_strided = getStridedBatchStride(
batchDim2, batchDim3, rStride2, rStride3, rBatchStride);

const bool o_is_strided = getStridedBatchStride(
batchDim2, batchDim3, oStride2, oStride3, oBatchStride);

can_use_strided_batched =
prop.major > 3 && is_same<Ti, To>::value && l_is_strided &&
r_is_strided && o_is_strided;
}
#endif

if (can_use_strided_batched) {
#if __CUDACC_VER_MAJOR__ >= 10
CUBLAS_CHECK((gemmStridedBatchedDispatch<Ti, To>(
blasHandle(), lOpts, rOpts, M, N, K, alpha, lhs, lStrides[1],
lBatchStride, rhs, rStrides[1], rBatchStride, beta, out,
oStrides[1], oBatchStride, batchSize)));
#endif
} else {
vector<const Ti *> lptrs(batchSize);
vector<const Ti *> rptrs(batchSize);
vector<To *> optrs(batchSize);

const Ti *lptr = lhs.get();
const Ti *rptr = rhs.get();
To *optr = out.get();

for (int n = 0; n < batchSize; n++) {
int w = n / batchDim2;
int z = n - w * batchDim2;
dim_t loff = z * lStride2 + w * lStride3;
dim_t roff = z * rStride2 + w * rStride3;
dim_t ooff = z * oStride2 + w * oStride3;
lptrs[n] = lptr + loff;
rptrs[n] = rptr + roff;
optrs[n] = optr + ooff;
}

size_t bytes = batchSize * sizeof(Ti **);
auto d_lptrs = memAlloc<uchar>(bytes);
auto d_rptrs = memAlloc<uchar>(bytes);
auto d_optrs = memAlloc<uchar>(bytes);
CUDA_CHECK(cudaMemcpyAsync(d_lptrs.get(), lptrs.data(), bytes,
cudaMemcpyHostToDevice,
getActiveStream()));
CUDA_CHECK(cudaMemcpyAsync(d_rptrs.get(), rptrs.data(), bytes,
cudaMemcpyHostToDevice,
getActiveStream()));
CUDA_CHECK(cudaMemcpyAsync(d_optrs.get(), optrs.data(), bytes,
cudaMemcpyHostToDevice,
getActiveStream()));

// Synchronize the stream here, before launching the batched GEMM, so
// that the host waits only for the pointer-table copies issued above
// and not for the GEMM itself, even though synchronizing after the
// call might look more natural.
CUDA_CHECK(cudaStreamSynchronize(getActiveStream()));

CUBLAS_CHECK(gemmBatchedDispatch(
blasHandle(), lOpts, rOpts, M, N, K, alpha,
(const Ti **)d_lptrs.get(), lStrides[1],
(const Ti **)d_rptrs.get(), rStrides[1], beta,
(To **)d_optrs.get(), oStrides[1], batchSize));
}
}
}

Expand Down