Skip to content

Commit 1ea8a45

Browse files
committed
use size_t for threshold computation
1 parent e4ba190 commit 1ea8a45

1 file changed

Lines changed: 4 additions & 8 deletions

File tree

src/gpuarray_blas_cuda_cublas.c

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -288,10 +288,8 @@ static int sgemmBatch(cb_order order, cb_transpose transA, cb_transpose transB,
288288
}
289289

290290
// use parallel cublasSgemm calls rather than cublasSgemmBatched for large products
291-
// (compute products in double because they can be large and we don't need to be exact)
292-
const double threshold = 650;
293-
const int multiple_dispatch = ((double)M * (double)N * (double)K >
294-
threshold * threshold * threshold);
291+
const size_t threshold = 650;
292+
const int multiple_dispatch = M * N * K > threshold * threshold * threshold;
295293
if (multiple_dispatch) {
296294
for (i = 0; i < batchCount; i++) {
297295
ASSERT_BUF(A[i]);
@@ -406,10 +404,8 @@ static int dgemmBatch(cb_order order, cb_transpose transA, cb_transpose transB,
406404
}
407405

408406
// use parallel cublasDgemm calls rather than cublasDgemmBatched for large products
409-
// (compute products in double because they can be large and we don't need to be exact)
410-
const double threshold = 650;
411-
const int multiple_dispatch = ((double)M * (double)N * (double)K >
412-
threshold * threshold * threshold);
407+
const size_t threshold = 650;
408+
const int multiple_dispatch = M * N * K > threshold * threshold * threshold;
413409
if (multiple_dispatch) {
414410
for (i = 0; i < batchCount; i++) {
415411
ASSERT_BUF(A[i]);

0 commit comments

Comments
 (0)