Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions include/af/cuda.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,33 @@
#pragma once
#include <af/defines.h>
#include <af/exception.h>

/// This file contains functions that apply only to the CUDA backend. It will
/// include the CUDA headers when it is built with NVCC. Otherwise, you can
/// define AF_DEFINE_CUDA_TYPES before including this file and it will
/// define the CUDA types used in this header.

#ifdef __NVCC__
#include <cuda.h>
#include <cuda_runtime.h>
#include <cublas_v2.h>
#else
#ifdef AF_DEFINE_CUDA_TYPES
/* Fallback declaration of the CUDA stream handle type as a pointer to an
   incomplete struct. Only compiled when this header is included without
   NVCC and with AF_DEFINE_CUDA_TYPES defined. */
typedef struct CUstream_st *cudaStream_t;

/* Fallback declaration of the cuBLAS math-mode enum (default math mode vs.
   tensor-op math mode), provided under the same conditions as above so that
   afcu_cublasSetMathMode can be declared without <cublas_v2.h>.
   NOTE(review): the enumerator values are presumably meant to match the ones
   in <cublas_v2.h> -- confirm against the cuBLAS header in use. */
typedef enum {
CUBLAS_DEFAULT_MATH = 0,
CUBLAS_TENSOR_OP_MATH = 1
} cublasMath_t;
#endif
#endif

#ifdef __cplusplus
extern "C" {
#endif


#if AF_API_VERSION >= 31
/**
Get the stream for the CUDA device with \p id in ArrayFire context
Expand Down Expand Up @@ -55,6 +75,20 @@ AFAPI af_err afcu_get_native_id(int* nativeid, int id);
AFAPI af_err afcu_set_native_id(int nativeid);
#endif

#if AF_API_VERSION >= 37
/**
Sets the cuBLAS math mode for the internal handle

See the cuBLAS documentation for additional details

\param[in] mode The cublasMath_t type to set
\returns \ref af_err error code

\ingroup cuda_mat
*/
AFAPI af_err afcu_cublasSetMathMode(cublasMath_t mode);
#endif

#ifdef __cplusplus
}
#endif
Expand Down
1 change: 1 addition & 0 deletions src/api/unified/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ target_sources(af
${CMAKE_CURRENT_SOURCE_DIR}/arith.cpp
${CMAKE_CURRENT_SOURCE_DIR}/array.cpp
${CMAKE_CURRENT_SOURCE_DIR}/blas.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cuda.cpp
${CMAKE_CURRENT_SOURCE_DIR}/data.cpp
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
${CMAKE_CURRENT_SOURCE_DIR}/error.cpp
Expand Down
52 changes: 52 additions & 0 deletions src/api/unified/cuda.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
/*******************************************************
* Copyright (c) 2019, ArrayFire
* All rights reserved.
*
* This file is distributed under 3-clause BSD license.
* The complete license agreement can be obtained at:
* http://arrayfire.com/licenses/BSD-3-Clause
********************************************************/

#include "symbol_manager.hpp"
#include <af/backend.h>

#define AF_DEFINE_CUDA_TYPES
#include <af/cuda.h>

/// Unified-API entry point for afcu_get_stream.
///
/// Queries the active backend and, only when it is the CUDA backend,
/// dispatches the request through the CALL macro.
///
/// \param[out] stream  forwarded unchanged to CALL
/// \param[in]  id      forwarded unchanged to CALL
/// \returns the error reported by af_get_active_backend if that query fails,
///          AF_ERR_NOT_SUPPORTED if the active backend is not CUDA,
///          otherwise the result of the dispatched call
af_err afcu_get_stream(cudaStream_t* stream, int id) {
    af_backend backend;
    const af_err status = af_get_active_backend(&backend);
    // `backend` is only meaningful when the query succeeded; reading it after
    // a failure would inspect an uninitialized value.
    if (status != AF_SUCCESS) { return status; }
    if (backend != AF_BACKEND_CUDA) { return AF_ERR_NOT_SUPPORTED; }
    // NOTE(review): CALL is given only the arguments, so it presumably derives
    // the target symbol from the enclosing function's name -- keep this call
    // directly in this function; confirm against symbol_manager.hpp.
    return CALL(stream, id);
}

/// Unified-API entry point for afcu_get_native_id.
///
/// Queries the active backend and, only when it is the CUDA backend,
/// dispatches the request through the CALL macro.
///
/// \param[out] nativeid  forwarded unchanged to CALL
/// \param[in]  id        forwarded unchanged to CALL
/// \returns the error reported by af_get_active_backend if that query fails,
///          AF_ERR_NOT_SUPPORTED if the active backend is not CUDA,
///          otherwise the result of the dispatched call
af_err afcu_get_native_id(int* nativeid, int id) {
    af_backend backend;
    const af_err status = af_get_active_backend(&backend);
    // `backend` is only meaningful when the query succeeded; reading it after
    // a failure would inspect an uninitialized value.
    if (status != AF_SUCCESS) { return status; }
    if (backend != AF_BACKEND_CUDA) { return AF_ERR_NOT_SUPPORTED; }
    // NOTE(review): CALL is given only the arguments, so it presumably derives
    // the target symbol from the enclosing function's name -- keep this call
    // directly in this function; confirm against symbol_manager.hpp.
    return CALL(nativeid, id);
}


/// Unified-API entry point for afcu_set_native_id.
///
/// Queries the active backend and, only when it is the CUDA backend,
/// dispatches the request through the CALL macro.
///
/// \param[in] nativeid  forwarded unchanged to CALL
/// \returns the error reported by af_get_active_backend if that query fails,
///          AF_ERR_NOT_SUPPORTED if the active backend is not CUDA,
///          otherwise the result of the dispatched call
af_err afcu_set_native_id(int nativeid) {
    af_backend backend;
    const af_err status = af_get_active_backend(&backend);
    // `backend` is only meaningful when the query succeeded; reading it after
    // a failure would inspect an uninitialized value.
    if (status != AF_SUCCESS) { return status; }
    if (backend != AF_BACKEND_CUDA) { return AF_ERR_NOT_SUPPORTED; }
    // NOTE(review): CALL is given only the arguments, so it presumably derives
    // the target symbol from the enclosing function's name -- keep this call
    // directly in this function; confirm against symbol_manager.hpp.
    return CALL(nativeid);
}


/// Unified-API entry point for afcu_cublasSetMathMode.
///
/// Queries the active backend and, only when it is the CUDA backend,
/// dispatches the request through the CALL macro.
///
/// \param[in] mode  cublasMath_t value forwarded unchanged to CALL
/// \returns the error reported by af_get_active_backend if that query fails,
///          AF_ERR_NOT_SUPPORTED if the active backend is not CUDA,
///          otherwise the result of the dispatched call
af_err afcu_cublasSetMathMode(cublasMath_t mode) {
    af_backend backend;
    const af_err status = af_get_active_backend(&backend);
    // `backend` is only meaningful when the query succeeded; reading it after
    // a failure would inspect an uninitialized value.
    if (status != AF_SUCCESS) { return status; }
    if (backend != AF_BACKEND_CUDA) { return AF_ERR_NOT_SUPPORTED; }
    // NOTE(review): CALL is given only the arguments, so it presumably derives
    // the target symbol from the enclosing function's name -- keep this call
    // directly in this function; confirm against symbol_manager.hpp.
    return CALL(mode);
}
1 change: 1 addition & 0 deletions src/backend/cuda/device_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include <platform.hpp>
#include <spdlog/spdlog.h>
#include <version.hpp>
#include <cublas_v2.h> // needed for af/cuda.h
#include <af/cuda.h>
#include <af/version.h>
// cuda_gl_interop.h does not include OpenGL headers for ARM
Expand Down
8 changes: 8 additions & 0 deletions src/backend/cuda/platform.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -435,3 +435,11 @@ af_err afcu_set_native_id(int nativeid) {
CATCHALL;
return AF_SUCCESS;
}

// CUDA-backend implementation of afcu_cublasSetMathMode.
//
// Passes `mode` to cublasSetMathMode on the handle returned by
// cuda::blasHandle().
//
// \param[in] mode  math mode forwarded unchanged to cublasSetMathMode
// \returns AF_SUCCESS when the try block completes normally.
//          NOTE(review): on failure the result presumably comes from the
//          handler that CATCHALL expands to -- confirm against the macro.
af_err afcu_cublasSetMathMode(cublasMath_t mode) {
try {
// NOTE(review): CUBLAS_CHECK presumably raises on a non-success cuBLAS
// status so that CATCHALL can translate it -- confirm against the macro.
CUBLAS_CHECK(cublasSetMathMode(cuda::blasHandle(), mode));
}
CATCHALL;
return AF_SUCCESS;
}