diff --git a/src/backend/cuda/blas.cu b/src/backend/cuda/blas.cu index 08df398a8d..bd77c27f89 100644 --- a/src/backend/cuda/blas.cu +++ b/src/backend/cuda/blas.cu @@ -31,6 +31,7 @@ #include #include #include +#include #include using arrayfire::common::half; @@ -244,6 +245,62 @@ cublasStatus_t gemmBatchedDispatch(BlasHandle handle, cublasOperation_t lOpts, #endif } +static void getInputBatchStrides(const dim4 &outDims, const dim4 &inDims, + const dim4 &inStrides, dim_t &stride2, + dim_t &stride3) { + // Convert A/B batch broadcasting into effective batch strides. + // If the input does not vary in a batch dimension, its effective stride + // in that dimension is zero. + stride2 = (outDims[2] == inDims[2]) ? inStrides[2] : 0; + stride3 = (outDims[3] == inDims[3]) ? inStrides[3] : 0; +} + +static bool getStridedBatchStride(dim_t dim2, dim_t dim3, dim_t stride2, + dim_t stride3, dim_t &batchStride) { + // stride2/stride3 are already effective batch strides. + // Check whether X(z,w) = X0 + z*stride2 + w*stride3 can be flattened as + // X_n = X0 + n*batchStride with n = z + w*dim2. + + if (dim2 <= 1 && dim3 <= 1) { + batchStride = 0; + return true; + } + + if (dim3 <= 1) { + batchStride = stride2; + return true; + } + + if (dim2 <= 1) { + batchStride = stride3; + return true; + } + + const dim_t inRowStep = stride2; + const dim_t rowBoundaryStep = stride3 - (dim2 - 1) * stride2; + + batchStride = inRowStep; + return rowBoundaryStep == inRowStep; +} + +#if __CUDACC_VER_MAJOR__ >= 10 +template +cublasStatus_t gemmStridedBatchedDispatch( + BlasHandle handle, cublasOperation_t lOpts, cublasOperation_t rOpts, int M, + int N, int K, const To *alpha, const Array &lhs, dim_t lStride, + dim_t lBatchStride, const Array &rhs, dim_t rStride, + dim_t rBatchStride, const To *beta, Array &out, dim_t oStride, + dim_t oBatchStride, int batchSize) { + return cublasGemmStridedBatchedEx( + blasHandle(), lOpts, rOpts, M, N, K, alpha, lhs.get(), getType(), + lStride, static_cast(lBatchStride), rhs.get(), + getType(), rStride, static_cast(rBatchStride), beta, + out.get(), getType(), oStride, + static_cast(oBatchStride), batchSize, + getComputeType(), selectGEMMAlgorithm()); +} +#endif + template void gemm(Array &out, af_mat_prop optLhs, af_mat_prop optRhs, const To *alpha, const Array &lhs, const Array &rhs, const To *beta) { @@ -270,54 +327,98 @@ void gemm(Array &out, af_mat_prop optLhs, af_mat_prop optRhs, const To *alph lhs, lStrides[1], rhs, rStrides[1], beta, out, oStrides[1]))); } else { - int batchSize = oDims[2] * oDims[3]; - vector lptrs(batchSize); - vector rptrs(batchSize); - vector optrs(batchSize); - - bool is_l_d2_batched = oDims[2] == lDims[2]; - bool is_l_d3_batched = oDims[3] == lDims[3]; - - bool is_r_d2_batched = oDims[2] == rDims[2]; - bool is_r_d3_batched = oDims[3] == rDims[3]; - - const Ti *lptr = lhs.get(); - const Ti *rptr = rhs.get(); - To *optr = out.get(); - - for (int n = 0; n < batchSize; n++) { - int w = n / oDims[2]; - int z = n - w * oDims[2]; - int loff = z * (is_l_d2_batched * lStrides[2]) + - w * (is_l_d3_batched * lStrides[3]); - int roff = z * (is_r_d2_batched * rStrides[2]) + - w * (is_r_d3_batched * rStrides[3]); - lptrs[n] = lptr + loff; - rptrs[n] = rptr + roff; - optrs[n] = optr + z * oStrides[2] + w * oStrides[3]; - } + const dim_t batchDim2 = oDims[2]; + const dim_t batchDim3 = oDims[3]; + int batchSize = batchDim2 * batchDim3; - size_t bytes = batchSize * sizeof(Ti **); - auto d_lptrs = memAlloc(bytes); - auto d_rptrs = memAlloc(bytes); - auto d_optrs = memAlloc(bytes); - CUDA_CHECK(cudaMemcpyAsync(d_lptrs.get(), lptrs.data(), bytes, - cudaMemcpyHostToDevice, getActiveStream())); - CUDA_CHECK(cudaMemcpyAsync(d_rptrs.get(), rptrs.data(), bytes, - cudaMemcpyHostToDevice, getActiveStream())); - CUDA_CHECK(cudaMemcpyAsync(d_optrs.get(), optrs.data(), bytes, - cudaMemcpyHostToDevice, getActiveStream())); - - // Call this before the gemm call so that you don't have to wait for the - // computation. Even though it would make more sense to put it - // afterwards - CUDA_CHECK(cudaStreamSynchronize(getActiveStream())); + dim_t lStride2 = 0; + dim_t lStride3 = 0; + dim_t rStride2 = 0; + dim_t rStride3 = 0; - using Nt = typename common::kernel_type::native; - CUBLAS_CHECK(gemmBatchedDispatch( - blasHandle(), lOpts, rOpts, M, N, K, alpha, - (const Ti **)d_lptrs.get(), lStrides[1], (const Ti **)d_rptrs.get(), - rStrides[1], beta, (To **)d_optrs.get(), oStrides[1], batchSize)); + getInputBatchStrides(oDims, lDims, lStrides, lStride2, lStride3); + getInputBatchStrides(oDims, rDims, rStrides, rStride2, rStride3); + + const dim_t oStride2 = oStrides[2]; + const dim_t oStride3 = oStrides[3]; + + dim_t lBatchStride = 0; + dim_t rBatchStride = 0; + dim_t oBatchStride = 0; + + bool can_use_strided_batched = false; + +#if __CUDACC_VER_MAJOR__ >= 10 + { + auto prop = getDeviceProp(getActiveDeviceId()); + + const bool l_is_strided = getStridedBatchStride( + batchDim2, batchDim3, lStride2, lStride3, lBatchStride); + + const bool r_is_strided = getStridedBatchStride( + batchDim2, batchDim3, rStride2, rStride3, rBatchStride); + + const bool o_is_strided = getStridedBatchStride( + batchDim2, batchDim3, oStride2, oStride3, oBatchStride); + + can_use_strided_batched = + prop.major > 3 && is_same::value && l_is_strided && + r_is_strided && o_is_strided; + } +#endif + + if (can_use_strided_batched) { +#if __CUDACC_VER_MAJOR__ >= 10 + CUBLAS_CHECK((gemmStridedBatchedDispatch( + blasHandle(), lOpts, rOpts, M, N, K, alpha, lhs, lStrides[1], + lBatchStride, rhs, rStrides[1], rBatchStride, beta, out, + oStrides[1], oBatchStride, batchSize))); +#endif + } else { + vector lptrs(batchSize); + vector rptrs(batchSize); + vector optrs(batchSize); + + const Ti *lptr = lhs.get(); + const Ti *rptr = rhs.get(); + To *optr = out.get(); + + for (int n = 0; n < batchSize; n++) { + int w = n / batchDim2; + int z = n - w * batchDim2; + dim_t loff = z * lStride2 + w * lStride3; + dim_t roff = z * rStride2 + w * rStride3; + dim_t ooff = z * oStride2 + w * oStride3; + lptrs[n] = lptr + loff; + rptrs[n] = rptr + roff; + optrs[n] = optr + ooff; + } + + size_t bytes = batchSize * sizeof(Ti **); + auto d_lptrs = memAlloc(bytes); + auto d_rptrs = memAlloc(bytes); + auto d_optrs = memAlloc(bytes); + CUDA_CHECK(cudaMemcpyAsync(d_lptrs.get(), lptrs.data(), bytes, + cudaMemcpyHostToDevice, + getActiveStream())); + CUDA_CHECK(cudaMemcpyAsync(d_rptrs.get(), rptrs.data(), bytes, + cudaMemcpyHostToDevice, + getActiveStream())); + CUDA_CHECK(cudaMemcpyAsync(d_optrs.get(), optrs.data(), bytes, + cudaMemcpyHostToDevice, + getActiveStream())); + + // Call this before the gemm call so that you don't have to wait for + // the computation. Even though it would make more sense to put it + // afterwards + CUDA_CHECK(cudaStreamSynchronize(getActiveStream())); + + CUBLAS_CHECK(gemmBatchedDispatch( + blasHandle(), lOpts, rOpts, M, N, K, alpha, + (const Ti **)d_lptrs.get(), lStrides[1], + (const Ti **)d_rptrs.get(), rStrides[1], beta, + (To **)d_optrs.get(), oStrides[1], batchSize)); + } } }