-
Notifications
You must be signed in to change notification settings - Fork 548
Expand file tree
/
Copy pathcusparse.hpp
More file actions
93 lines (83 loc) · 3.93 KB
/
cusparse.hpp
File metadata and controls
93 lines (83 loc) · 3.93 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
/*******************************************************
* Copyright (c) 2014, ArrayFire
* All rights reserved.
*
* This file is distributed under 3-clause BSD license.
* The complete license agreement can be obtained at:
* http://arrayfire.com/licenses/BSD-3-Clause
********************************************************/
#pragma once
#include <common/SparseArray.hpp>
#include <common/defines.hpp>
#include <common/unique_handle.hpp>
#include <cudaDataType.hpp>
#include <cusparseModule.hpp>
#include <cusparse_v2.h>
#include <err_cuda.hpp>
#if defined(AF_USE_NEW_CUSPARSE_API)
namespace arrayfire {
namespace cuda {

/// Creates a cuSPARSE generic-API sparse-matrix descriptor that describes an
/// ArrayFire SparseArray.
///
/// The descriptor is built from the raw pointers of arr's row-index,
/// column-index and value arrays. cuSPARSE does not copy these buffers, so
/// `arr` must outlive every use of `*out`. Indices are declared as 32-bit
/// (CUSPARSE_INDEX_32I) and zero-based (CUSPARSE_INDEX_BASE_ZERO). The value
/// type comes from getType<T>().
///
/// @param[out] out  Receives the new descriptor on success.
/// @param[in]  arr  Sparse array in CSR, CSC or COO storage.
/// @return The status of the matching cusparseCreate{Csr,Csc,Coo} call, or
///         CUSPARSE_STATUS_INVALID_VALUE when arr's storage is not one of the
///         formats handled here (in that case `*out` is not written).
template<typename T>
cusparseStatus_t createSpMatDescr(
    cusparseSpMatDescr_t *out, const arrayfire::common::SparseArray<T> &arr) {
    auto &_ = arrayfire::cuda::getCusparsePlugin();
    switch (arr.getStorage()) {
        case AF_STORAGE_CSR: {
            // cusparseCreateCsr: row offsets first, then column indices.
            return _.cusparseCreateCsr(
                out, arr.dims()[0], arr.dims()[1], arr.getNNZ(),
                (void *)arr.getRowIdx().get(), (void *)arr.getColIdx().get(),
                (void *)arr.getValues().get(), CUSPARSE_INDEX_32I,
                CUSPARSE_INDEX_32I, CUSPARSE_INDEX_BASE_ZERO, getType<T>());
        }
#if CUSPARSE_VERSION >= 11300
        case AF_STORAGE_CSC: {
            // cusparseCreateCsc: column offsets first, then row indices.
            return _.cusparseCreateCsc(
                out, arr.dims()[0], arr.dims()[1], arr.getNNZ(),
                (void *)arr.getColIdx().get(), (void *)arr.getRowIdx().get(),
                (void *)arr.getValues().get(), CUSPARSE_INDEX_32I,
                CUSPARSE_INDEX_32I, CUSPARSE_INDEX_BASE_ZERO, getType<T>());
        }
#else
        case AF_STORAGE_CSC:
            // CSC descriptors are only created when CUSPARSE_VERSION >= 11300.
            CUDA_NOT_SUPPORTED(
                "Sparse not supported for CSC on this version of the CUDA "
                "Toolkit");
            // CUDA_NOT_SUPPORTED presumably throws. The explicit return makes
            // sure this case can never fall through into the COO branch.
            return CUSPARSE_STATUS_MATRIX_TYPE_NOT_SUPPORTED;
#endif
        case AF_STORAGE_COO: {
            // cusparseCreateCoo takes the row indices first, then the column
            // indices (same row/column order as the CSR case above).
            return _.cusparseCreateCoo(
                out, arr.dims()[0], arr.dims()[1], arr.getNNZ(),
                (void *)arr.getRowIdx().get(), (void *)arr.getColIdx().get(),
                (void *)arr.getValues().get(), CUSPARSE_INDEX_32I,
                CUSPARSE_INDEX_BASE_ZERO, getType<T>());
        }
        default: break;
    }
    // Any other storage (e.g. AF_STORAGE_DENSE) cannot be described by a
    // sparse descriptor. Report failure instead of claiming success with an
    // uninitialised *out.
    return CUSPARSE_STATUS_INVALID_VALUE;
}

}  // namespace cuda
}  // namespace arrayfire
#endif
// clang-format off
// Register create/destroy function pairs for cuSPARSE handle types using the
// DEFINE_HANDLER macro from common/unique_handle.hpp. Presumably this lets
// these handles be owned by a unique_handle wrapper that calls the destroy
// function on cleanup -- confirm against unique_handle.hpp.
// cuSPARSE library context handle.
DEFINE_HANDLER(cusparseHandle_t, arrayfire::cuda::getCusparsePlugin().cusparseCreate, arrayfire::cuda::getCusparsePlugin().cusparseDestroy);
// Legacy-API matrix descriptor.
DEFINE_HANDLER(cusparseMatDescr_t, arrayfire::cuda::getCusparsePlugin().cusparseCreateMatDescr, arrayfire::cuda::getCusparsePlugin().cusparseDestroyMatDescr);
#if defined(AF_USE_NEW_CUSPARSE_API)
// Generic-API descriptors, only registered when AF_USE_NEW_CUSPARSE_API is
// defined. The sparse-matrix descriptor is created by createSpMatDescr (defined
// earlier in this file), which takes a SparseArray. The dense vector/matrix
// descriptors use the plugin's cusparseCreateDnVec / cusparseCreateDnMat.
DEFINE_HANDLER(cusparseSpMatDescr_t, arrayfire::cuda::createSpMatDescr, arrayfire::cuda::getCusparsePlugin().cusparseDestroySpMat);
DEFINE_HANDLER(cusparseDnVecDescr_t, arrayfire::cuda::getCusparsePlugin().cusparseCreateDnVec, arrayfire::cuda::getCusparsePlugin().cusparseDestroyDnVec);
DEFINE_HANDLER(cusparseDnMatDescr_t, arrayfire::cuda::getCusparsePlugin().cusparseCreateDnMat, arrayfire::cuda::getCusparsePlugin().cusparseDestroyDnMat);
#endif
// clang-format on
namespace arrayfire {
namespace cuda {

/// Returns a human-readable description of a cuSPARSE status code.
/// Declared here; defined elsewhere.
const char *errorString(cusparseStatus_t err);

/// CUSPARSE_CHECK(fn): evaluates `fn` (an expression yielding
/// cusparseStatus_t) exactly once. If the result is not
/// CUSPARSE_STATUS_SUCCESS, it reports the failure through AF_ERROR with code
/// AF_ERR_INTERNAL. The message includes the numeric status and
/// arrayfire::cuda::errorString(status).
///
/// The argument is parenthesised, and the macro's locals use a
/// macro-specific prefix. This means a call such as
/// CUSPARSE_CHECK(f(_error)) cannot accidentally bind to the macro's own
/// uninitialised variable. (Macros ignore namespaces; the namespace scope
/// here is only organisational.)
#define CUSPARSE_CHECK(fn)                                                  \
    do {                                                                    \
        cusparseStatus_t _cusparse_check_status = (fn);                     \
        if (_cusparse_check_status != CUSPARSE_STATUS_SUCCESS) {            \
            char _cusparse_check_msg[1024];                                 \
            snprintf(_cusparse_check_msg, sizeof(_cusparse_check_msg),      \
                     "CUSPARSE Error (%d): %s\n",                            \
                     (int)(_cusparse_check_status),                         \
                     arrayfire::cuda::errorString(_cusparse_check_status)); \
                                                                            \
            AF_ERROR(_cusparse_check_msg, AF_ERR_INTERNAL);                 \
        }                                                                   \
    } while (0)

}  // namespace cuda
}  // namespace arrayfire