diff --git a/include/af/cuda.h b/include/af/cuda.h index dbf1480a80..27908427e7 100644 --- a/include/af/cuda.h +++ b/include/af/cuda.h @@ -10,13 +10,33 @@ #pragma once #include #include + +/// This file contain functions that apply only to the CUDA backend. It will +/// include cuda headers when it is built with NVCC. Otherwise the you can +/// define the AF_DEFINE_CUDA_TYPES before including this file and it will +/// define the cuda types used in this header. + +#ifdef __NVCC__ #include #include +#include +#else +#ifdef AF_DEFINE_CUDA_TYPES +typedef struct CUstream_st *cudaStream_t; + +/*Enum for default math mode/tensor operation*/ +typedef enum { + CUBLAS_DEFAULT_MATH = 0, + CUBLAS_TENSOR_OP_MATH = 1 +} cublasMath_t; +#endif +#endif #ifdef __cplusplus extern "C" { #endif + #if AF_API_VERSION >= 31 /** Get the stream for the CUDA device with \p id in ArrayFire context @@ -55,6 +75,20 @@ AFAPI af_err afcu_get_native_id(int* nativeid, int id); AFAPI af_err afcu_set_native_id(int nativeid); #endif +#if AF_API_VERSION >= 37 +/** + Sets the cuBLAS math mode for the internal handle + + See the cuBLAS documentation for additional details + + \param[in] mode The cublasMath_t type to set + \returns \ref af_err error code + + \ingroup cuda_mat +*/ +AFAPI af_err afcu_cublasSetMathMode(cublasMath_t mode); +#endif + #ifdef __cplusplus } #endif diff --git a/src/api/unified/CMakeLists.txt b/src/api/unified/CMakeLists.txt index f6fb2404de..a1c588ce5c 100644 --- a/src/api/unified/CMakeLists.txt +++ b/src/api/unified/CMakeLists.txt @@ -9,6 +9,7 @@ target_sources(af ${CMAKE_CURRENT_SOURCE_DIR}/arith.cpp ${CMAKE_CURRENT_SOURCE_DIR}/array.cpp ${CMAKE_CURRENT_SOURCE_DIR}/blas.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/cuda.cpp ${CMAKE_CURRENT_SOURCE_DIR}/data.cpp ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp ${CMAKE_CURRENT_SOURCE_DIR}/error.cpp diff --git a/src/api/unified/cuda.cpp b/src/api/unified/cuda.cpp new file mode 100644 index 0000000000..451b0ebf78 --- /dev/null +++ b/src/api/unified/cuda.cpp @@ -0,0 +1,52 @@ +/******************************************************* + * Copyright (c) 2019, ArrayFire + * All rights reserved. + * + * This file is distributed under 3-clause BSD license. + * The complete license agreement can be obtained at: + * http://arrayfire.com/licenses/BSD-3-Clause + ********************************************************/ + +#include "symbol_manager.hpp" +#include + +#define AF_DEFINE_CUDA_TYPES +#include + +af_err afcu_get_stream(cudaStream_t* stream, int id) { + af_backend backend; + af_get_active_backend(&backend); + if(backend == AF_BACKEND_CUDA) { + return CALL(stream, id); + } + return AF_ERR_NOT_SUPPORTED; +} + +af_err afcu_get_native_id(int* nativeid, int id) { + af_backend backend; + af_get_active_backend(&backend); + if(backend == AF_BACKEND_CUDA) { + return CALL(nativeid, id); + } + return AF_ERR_NOT_SUPPORTED; +} + + +af_err afcu_set_native_id(int nativeid) { + af_backend backend; + af_get_active_backend(&backend); + if(backend == AF_BACKEND_CUDA) { + return CALL(nativeid); + } + return AF_ERR_NOT_SUPPORTED; +} + + +af_err afcu_cublasSetMathMode(cublasMath_t mode) { + af_backend backend; + af_get_active_backend(&backend); + if(backend == AF_BACKEND_CUDA) { + return CALL(mode); + } + return AF_ERR_NOT_SUPPORTED; +} diff --git a/src/backend/cuda/device_manager.cpp b/src/backend/cuda/device_manager.cpp index 944be4ee87..b4ccaf4736 100644 --- a/src/backend/cuda/device_manager.cpp +++ b/src/backend/cuda/device_manager.cpp @@ -23,6 +23,7 @@ #include #include #include +#include // needed for af/cuda.h #include #include // cuda_gl_interop.h does not include OpenGL headers for ARM diff --git a/src/backend/cuda/platform.cpp b/src/backend/cuda/platform.cpp index aa5d2ed373..9228ffaacc 100644 --- a/src/backend/cuda/platform.cpp +++ b/src/backend/cuda/platform.cpp @@ -435,3 +435,11 @@ af_err afcu_set_native_id(int nativeid) { CATCHALL; return AF_SUCCESS; } + +af_err afcu_cublasSetMathMode(cublasMath_t mode) { + try { + CUBLAS_CHECK(cublasSetMathMode(cuda::blasHandle(), mode)); + } + CATCHALL; + return AF_SUCCESS; +}