Skip to content

Commit 6b177a2

Browse files
umar456pradeep
authored andcommitted
Add a function to set the cuBLAS Math Mode (arrayfire#2584)
* Add a function to set the cuBLAS Math Mode * Include missing cuda.cpp file. Add AF_DEFINE_CUDA_TYPES definition
1 parent e73488d commit 6b177a2

5 files changed

Lines changed: 96 additions & 0 deletions

File tree

include/af/cuda.h

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,33 @@
1010
#pragma once
1111
#include <af/defines.h>
1212
#include <af/exception.h>
13+
14+
/// This file contain functions that apply only to the CUDA backend. It will
15+
/// include cuda headers when it is built with NVCC. Otherwise the you can
16+
/// define the AF_DEFINE_CUDA_TYPES before including this file and it will
17+
/// define the cuda types used in this header.
18+
19+
#ifdef __NVCC__
1320
#include <cuda.h>
1421
#include <cuda_runtime.h>
22+
#include <cublas_v2.h>
23+
#else
24+
#ifdef AF_DEFINE_CUDA_TYPES
25+
typedef struct CUstream_st *cudaStream_t;
26+
27+
/*Enum for default math mode/tensor operation*/
28+
typedef enum {
29+
CUBLAS_DEFAULT_MATH = 0,
30+
CUBLAS_TENSOR_OP_MATH = 1
31+
} cublasMath_t;
32+
#endif
33+
#endif
1534

1635
#ifdef __cplusplus
1736
extern "C" {
1837
#endif
1938

39+
2040
#if AF_API_VERSION >= 31
2141
/**
2242
Get the stream for the CUDA device with \p id in ArrayFire context
@@ -55,6 +75,20 @@ AFAPI af_err afcu_get_native_id(int* nativeid, int id);
5575
AFAPI af_err afcu_set_native_id(int nativeid);
5676
#endif
5777

78+
#if AF_API_VERSION >= 37
79+
/**
80+
Sets the cuBLAS math mode for the internal handle
81+
82+
See the cuBLAS documentation for additional details
83+
84+
\param[in] mode The cublasMath_t type to set
85+
\returns \ref af_err error code
86+
87+
\ingroup cuda_mat
88+
*/
89+
AFAPI af_err afcu_cublasSetMathMode(cublasMath_t mode);
90+
#endif
91+
5892
#ifdef __cplusplus
5993
}
6094
#endif

src/api/unified/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ target_sources(af
99
${CMAKE_CURRENT_SOURCE_DIR}/arith.cpp
1010
${CMAKE_CURRENT_SOURCE_DIR}/array.cpp
1111
${CMAKE_CURRENT_SOURCE_DIR}/blas.cpp
12+
${CMAKE_CURRENT_SOURCE_DIR}/cuda.cpp
1213
${CMAKE_CURRENT_SOURCE_DIR}/data.cpp
1314
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
1415
${CMAKE_CURRENT_SOURCE_DIR}/error.cpp

src/api/unified/cuda.cpp

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
/*******************************************************
2+
* Copyright (c) 2019, ArrayFire
3+
* All rights reserved.
4+
*
5+
* This file is distributed under 3-clause BSD license.
6+
* The complete license agreement can be obtained at:
7+
* http://arrayfire.com/licenses/BSD-3-Clause
8+
********************************************************/
9+
10+
#include "symbol_manager.hpp"
11+
#include <af/backend.h>
12+
13+
#define AF_DEFINE_CUDA_TYPES
14+
#include <af/cuda.h>
15+
16+
af_err afcu_get_stream(cudaStream_t* stream, int id) {
17+
af_backend backend;
18+
af_get_active_backend(&backend);
19+
if(backend == AF_BACKEND_CUDA) {
20+
return CALL(stream, id);
21+
}
22+
return AF_ERR_NOT_SUPPORTED;
23+
}
24+
25+
af_err afcu_get_native_id(int* nativeid, int id) {
26+
af_backend backend;
27+
af_get_active_backend(&backend);
28+
if(backend == AF_BACKEND_CUDA) {
29+
return CALL(nativeid, id);
30+
}
31+
return AF_ERR_NOT_SUPPORTED;
32+
}
33+
34+
35+
af_err afcu_set_native_id(int nativeid) {
36+
af_backend backend;
37+
af_get_active_backend(&backend);
38+
if(backend == AF_BACKEND_CUDA) {
39+
return CALL(nativeid);
40+
}
41+
return AF_ERR_NOT_SUPPORTED;
42+
}
43+
44+
45+
af_err afcu_cublasSetMathMode(cublasMath_t mode) {
46+
af_backend backend;
47+
af_get_active_backend(&backend);
48+
if(backend == AF_BACKEND_CUDA) {
49+
return CALL(mode);
50+
}
51+
return AF_ERR_NOT_SUPPORTED;
52+
}

src/backend/cuda/device_manager.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
#include <platform.hpp>
2424
#include <spdlog/spdlog.h>
2525
#include <version.hpp>
26+
#include <cublas_v2.h> // needed for af/cuda.h
2627
#include <af/cuda.h>
2728
#include <af/version.h>
2829
// cuda_gl_interop.h does not include OpenGL headers for ARM

src/backend/cuda/platform.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -435,3 +435,11 @@ af_err afcu_set_native_id(int nativeid) {
435435
CATCHALL;
436436
return AF_SUCCESS;
437437
}
438+
439+
af_err afcu_cublasSetMathMode(cublasMath_t mode) {
440+
try {
441+
CUBLAS_CHECK(cublasSetMathMode(cuda::blasHandle(), mode));
442+
}
443+
CATCHALL;
444+
return AF_SUCCESS;
445+
}

0 commit comments

Comments
 (0)