Skip to content

Commit b6de1e7

Browse files
committed
Refactored sparse class to use af_array
* Added SparseArray and SparseArrayBase wrapper classes * Added functions to handle/get etc
1 parent b50dfb8 commit b6de1e7

15 files changed

Lines changed: 918 additions & 338 deletions

include/af/sparse.h

Lines changed: 15 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,6 @@
1010
#pragma once
1111
#include <af/defines.h>
1212

13-
typedef void * af_sparse_array;
14-
1513
#ifdef __cplusplus
1614
namespace af
1715
{
@@ -25,46 +23,42 @@ extern "C" {
2523
#endif
2624

2725
AFAPI af_err af_create_sparse_array(
28-
af_sparse_array *out,
26+
af_array *out,
2927
const dim_t nRows, const dim_t nCols, const dim_t nNZ,
3028
const af_array values, const af_array rowIdx, const af_array colIdx,
3129
const af_sparse_storage storage);
3230

33-
AFAPI af_err af_create_sparse_array_from_host(
34-
af_sparse_array *out,
31+
AFAPI af_err af_create_sparse_array_from_ptr(
32+
af_array *out,
3533
const dim_t nRows, const dim_t nCols, const dim_t nNZ,
3634
const void * const values,
3735
const int * const rowIdx, const int * const colIdx,
38-
const af_dtype type, const af_sparse_storage storage);
36+
const af_dtype type, const af_sparse_storage storage,
37+
const af_source source);
3938

4039
AFAPI af_err af_create_sparse_array_from_dense(
41-
af_sparse_array *out, const af_array in,
40+
af_array *out, const af_array in,
4241
const af_sparse_storage storage);
4342

44-
AFAPI af_err af_retain_sparse_array(af_sparse_array *out, const af_sparse_array in);
43+
AFAPI af_err af_sparse_get_arrays(af_array *values, af_array *rows, af_array *cols, const af_array in);
4544

46-
AFAPI af_err af_sparse_get_values(af_array *out, const af_sparse_array in);
45+
AFAPI af_err af_sparse_get_values(af_array *out, const af_array in);
4746

48-
AFAPI af_err af_sparse_get_rows(af_array *out, const af_sparse_array in);
47+
AFAPI af_err af_sparse_get_rows(af_array *out, const af_array in);
4948

50-
AFAPI af_err af_sparse_get_cols(af_array *out, const af_sparse_array in);
49+
AFAPI af_err af_sparse_get_cols(af_array *out, const af_array in);
5150

52-
AFAPI af_err af_sparse_get_num_values(dim_t *out, const af_sparse_array in);
51+
AFAPI af_err af_sparse_get_num_values(dim_t *out, const af_array in);
5352

54-
AFAPI af_err af_sparse_get_num_rows(dim_t *out, const af_sparse_array in);
53+
AFAPI af_err af_sparse_get_num_rows(dim_t *out, const af_array in);
5554

56-
AFAPI af_err af_sparse_get_num_cols(dim_t *out, const af_sparse_array in);
55+
AFAPI af_err af_sparse_get_num_cols(dim_t *out, const af_array in);
5756

58-
AFAPI af_err af_sparse_get_storage(af_sparse_storage *out, const af_sparse_array in);
57+
AFAPI af_err af_sparse_get_storage(af_sparse_storage *out, const af_array in);
5958

60-
AFAPI af_err af_sparse_convert_storage(af_sparse_array *out, const af_sparse_array in,
59+
AFAPI af_err af_sparse_convert_storage(af_array *out, const af_array in,
6160
const af_sparse_storage destStorage);
6261

63-
AFAPI af_err af_release_sparse_array(af_sparse_array in);
64-
65-
AFAPI af_err af_sparse_matmul(af_sparse_array *out,
66-
const af_sparse_array lhs, const af_sparse_array rhs,
67-
const af_mat_prop optLhs, const af_mat_prop optRhs);
6862
#ifdef __cplusplus
6963
}
7064
#endif

src/api/c/array.cpp

Lines changed: 66 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -11,35 +11,22 @@
1111
#include <platform.hpp>
1212
#include <handle.hpp>
1313
#include <backend.hpp>
14+
#include <sparse_handle.hpp>
1415

1516
using namespace detail;
1617

1718
const ArrayInfo&
18-
getInfo(const af_array arr, bool check)
19-
{
20-
const ArrayInfo *info = static_cast<ArrayInfo*>(reinterpret_cast<void *>(arr));
21-
22-
// Check Sparse
23-
ARG_ASSERT(0, info->isSparse() == false);
24-
25-
if (check && info->getDevId() != detail::getActiveDeviceId()) {
26-
AF_ERROR("Input Array not created on current device", AF_ERR_DEVICE);
27-
}
28-
29-
return *info;
30-
}
31-
32-
const ArrayInfo&
33-
getSparseInfo(const af_array arr, bool sparseCheck, bool check)
19+
getInfo(const af_array arr, bool sparse_check, bool device_check)
3420
{
3521
const ArrayInfo *info = static_cast<ArrayInfo*>(reinterpret_cast<void *>(arr));
3622

3723
// Check Sparse -> If false, then both standard Array<T> and SparseArray<T> are accepted
38-
if(sparseCheck) {
39-
ARG_ASSERT(0, info->isSparse() == true);
24+
// Otherwise only regular Array<T> is accepted
25+
if(sparse_check) {
26+
ARG_ASSERT(0, info->isSparse() == false);
4027
}
4128

42-
if (check && info->getDevId() != detail::getActiveDeviceId()) {
29+
if (device_check && info->getDevId() != detail::getActiveDeviceId()) {
4330
AF_ERROR("Input Array not created on current device", AF_ERR_DEVICE);
4431
}
4532

@@ -169,7 +156,7 @@ af_err af_copy_array(af_array *out, const af_array in)
169156
af_err af_get_data_ref_count(int *use_count, const af_array in)
170157
{
171158
try {
172-
ArrayInfo info = getSparseInfo(in, false, false);
159+
ArrayInfo info = getInfo(in, false, false);
173160
const af_dtype type = info.getType();
174161

175162
int res;
@@ -199,29 +186,39 @@ af_err af_release_array(af_array arr)
199186
try {
200187
int dev = getActiveDeviceId();
201188

202-
ArrayInfo info = getSparseInfo(arr, false, false);
203-
204-
setDevice(info.getDevId());
205-
189+
ArrayInfo info = getInfo(arr, false, false);
206190
af_dtype type = info.getType();
207191

208-
switch(type) {
209-
case f32: releaseHandle<float >(arr); break;
210-
case c32: releaseHandle<cfloat >(arr); break;
211-
case f64: releaseHandle<double >(arr); break;
212-
case c64: releaseHandle<cdouble >(arr); break;
213-
case b8: releaseHandle<char >(arr); break;
214-
case s32: releaseHandle<int >(arr); break;
215-
case u32: releaseHandle<uint >(arr); break;
216-
case u8: releaseHandle<uchar >(arr); break;
217-
case s64: releaseHandle<intl >(arr); break;
218-
case u64: releaseHandle<uintl >(arr); break;
219-
case s16: releaseHandle<short >(arr); break;
220-
case u16: releaseHandle<ushort >(arr); break;
221-
default: TYPE_ERROR(0, type);
192+
if(info.isSparse()) {
193+
switch(type) {
194+
case f32: releaseSparseHandle<float >(arr); break;
195+
case f64: releaseSparseHandle<double >(arr); break;
196+
case c32: releaseSparseHandle<cfloat >(arr); break;
197+
case c64: releaseSparseHandle<cdouble>(arr); break;
198+
default : TYPE_ERROR(0, type);
199+
}
200+
} else {
201+
202+
setDevice(info.getDevId());
203+
204+
switch(type) {
205+
case f32: releaseHandle<float >(arr); break;
206+
case c32: releaseHandle<cfloat >(arr); break;
207+
case f64: releaseHandle<double >(arr); break;
208+
case c64: releaseHandle<cdouble >(arr); break;
209+
case b8: releaseHandle<char >(arr); break;
210+
case s32: releaseHandle<int >(arr); break;
211+
case u32: releaseHandle<uint >(arr); break;
212+
case u8: releaseHandle<uchar >(arr); break;
213+
case s64: releaseHandle<intl >(arr); break;
214+
case u64: releaseHandle<uintl >(arr); break;
215+
case s16: releaseHandle<short >(arr); break;
216+
case u16: releaseHandle<ushort >(arr); break;
217+
default: TYPE_ERROR(0, type);
218+
}
219+
220+
setDevice(dev);
222221
}
223-
224-
setDevice(dev);
225222
}
226223
CATCHALL
227224

@@ -240,22 +237,33 @@ static af_array retainHandle(const af_array in)
240237

241238
af_array retain(const af_array in)
242239
{
243-
af_dtype ty = getSparseInfo(in, false, false).getType();
244-
switch(ty) {
245-
case f32: return retainHandle<float >(in);
246-
case f64: return retainHandle<double >(in);
247-
case s32: return retainHandle<int >(in);
248-
case u32: return retainHandle<uint >(in);
249-
case u8: return retainHandle<uchar >(in);
250-
case c32: return retainHandle<detail::cfloat >(in);
251-
case c64: return retainHandle<detail::cdouble >(in);
252-
case b8: return retainHandle<char >(in);
253-
case s64: return retainHandle<intl >(in);
254-
case u64: return retainHandle<uintl >(in);
255-
case s16: return retainHandle<short >(in);
256-
case u16: return retainHandle<ushort >(in);
257-
default:
258-
TYPE_ERROR(1, ty);
240+
ArrayInfo info = getInfo(in, false, false);
241+
af_dtype ty = info.getType();
242+
243+
if(info.isSparse()) {
244+
switch(ty) {
245+
case f32: return retainSparseHandle<float >(in);
246+
case f64: return retainSparseHandle<double >(in);
247+
case c32: return retainSparseHandle<detail::cfloat >(in);
248+
case c64: return retainSparseHandle<detail::cdouble>(in);
249+
default: TYPE_ERROR(1, ty);
250+
}
251+
} else {
252+
switch(ty) {
253+
case f32: return retainHandle<float >(in);
254+
case f64: return retainHandle<double >(in);
255+
case s32: return retainHandle<int >(in);
256+
case u32: return retainHandle<uint >(in);
257+
case u8: return retainHandle<uchar >(in);
258+
case c32: return retainHandle<detail::cfloat >(in);
259+
case c64: return retainHandle<detail::cdouble >(in);
260+
case b8: return retainHandle<char >(in);
261+
case s64: return retainHandle<intl >(in);
262+
case u64: return retainHandle<uintl >(in);
263+
case s16: return retainHandle<short >(in);
264+
case u16: return retainHandle<ushort >(in);
265+
default: TYPE_ERROR(1, ty);
266+
}
259267
}
260268
}
261269

@@ -309,7 +317,7 @@ af_err af_get_elements(dim_t *elems, const af_array arr)
309317
{
310318
try {
311319
// Do not check for device mismatch
312-
*elems = getSparseInfo(arr, false, false).elements();
320+
*elems = getInfo(arr, false, false).elements();
313321
} CATCHALL
314322
return AF_SUCCESS;
315323
}
@@ -355,7 +363,7 @@ af_err af_get_numdims(unsigned *nd, const af_array in)
355363
af_err fn1(bool *result, const af_array in) \
356364
{ \
357365
try { \
358-
ArrayInfo info = getSparseInfo(in, false, false); \
366+
ArrayInfo info = getInfo(in, false, false); \
359367
*result = info.fn2(); \
360368
} \
361369
CATCHALL \

src/api/c/blas.cpp

Lines changed: 77 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,19 @@
1414
#include <af/array.h>
1515
#include <af/defines.h>
1616
#include <ArrayInfo.hpp>
17+
#include <sparse_handle.hpp>
18+
#include <sparse_blas.hpp>
1719
#include <err_common.hpp>
1820
#include <backend.hpp>
1921

22+
template<typename T>
23+
static inline af_array sparseMatmul(const af_array lhs, const af_array rhs,
24+
af_mat_prop optLhs, af_mat_prop optRhs)
25+
{
26+
return getHandle(detail::matmul<T>(getSparseArray<T>(lhs), getArray<T>(rhs),
27+
optLhs, optRhs));
28+
}
29+
2030
template<typename T>
2131
static inline af_array matmul(const af_array lhs, const af_array rhs,
2232
af_mat_prop optLhs, af_mat_prop optRhs)
@@ -31,16 +41,74 @@ static inline af_array dot(const af_array lhs, const af_array rhs,
3141
return getHandle(detail::dot<T>(getArray<T>(lhs), getArray<T>(rhs), optLhs, optRhs));
3242
}
3343

34-
af_err af_matmul( af_array *out,
35-
const af_array lhs, const af_array rhs,
36-
const af_mat_prop optLhs, const af_mat_prop optRhs)
44+
af_err af_sparse_matmul(af_array *out,
45+
const af_array lhs, const af_array rhs,
46+
const af_mat_prop optLhs, const af_mat_prop optRhs)
3747
{
3848
using namespace detail;
3949

4050
try {
41-
ArrayInfo lhsInfo = getInfo(lhs);
51+
common::SparseArrayBase lhsInfo = getSparseArrayBase(lhs);
4252
ArrayInfo rhsInfo = getInfo(rhs);
4353

54+
ARG_ASSERT(2, lhsInfo.isSparse() == true && rhsInfo.isSparse() == false);
55+
56+
af_dtype lhs_type = lhsInfo.getType();
57+
af_dtype rhs_type = rhsInfo.getType();
58+
59+
ARG_ASSERT(1, lhsInfo.getStorage() == AF_SPARSE_CSR);
60+
61+
if (!(optLhs == AF_MAT_NONE ||
62+
optLhs == AF_MAT_TRANS ||
63+
optLhs == AF_MAT_CTRANS)) { // Note the ! operator.
64+
AF_ERROR("Using this property is not yet supported in sparse matmul", AF_ERR_NOT_SUPPORTED);
65+
}
66+
67+
// No transpose options for RHS
68+
if (optRhs != AF_MAT_NONE) {
69+
AF_ERROR("Using this property is not yet supported in matmul", AF_ERR_NOT_SUPPORTED);
70+
}
71+
72+
if (rhsInfo.ndims() > 2) {
73+
AF_ERROR("Sparse matmul can not be used in batch mode", AF_ERR_BATCH);
74+
}
75+
76+
TYPE_ASSERT(lhs_type == rhs_type);
77+
78+
af::dim4 ldims = lhsInfo.dims();
79+
int lColDim = (optLhs == AF_MAT_NONE) ? 1 : 0;
80+
int rRowDim = (optRhs == AF_MAT_NONE) ? 0 : 1;
81+
82+
DIM_ASSERT(1, ldims[lColDim] == rhsInfo.dims()[rRowDim]);
83+
84+
af_array output = 0;
85+
switch(lhs_type) {
86+
case f32: output = sparseMatmul<float >(lhs, rhs, optLhs, optRhs); break;
87+
case c32: output = sparseMatmul<cfloat >(lhs, rhs, optLhs, optRhs); break;
88+
case f64: output = sparseMatmul<double >(lhs, rhs, optLhs, optRhs); break;
89+
case c64: output = sparseMatmul<cdouble>(lhs, rhs, optLhs, optRhs); break;
90+
default: TYPE_ERROR(1, lhs_type);
91+
}
92+
std::swap(*out, output);
93+
94+
} CATCHALL;
95+
96+
return AF_SUCCESS;
97+
}
98+
99+
af_err af_matmul(af_array *out,
100+
const af_array lhs, const af_array rhs,
101+
const af_mat_prop optLhs, const af_mat_prop optRhs)
102+
{
103+
using namespace detail;
104+
105+
try {
106+
ArrayInfo lhsInfo = getInfo(lhs, false, false);
107+
ArrayInfo rhsInfo = getInfo(rhs, true, false);
108+
109+
if(lhsInfo.isSparse())
110+
return af_sparse_matmul(out, lhs, rhs, optLhs, optRhs);
111+
44112
af_dtype lhs_type = lhsInfo.getType();
45113
af_dtype rhs_type = rhsInfo.getType();
46114

@@ -71,11 +139,11 @@ af_err af_matmul( af_array *out,
71139
DIM_ASSERT(1, lhsInfo.dims()[aColDim] == rhsInfo.dims()[bRowDim]);
72140

73141
switch(lhs_type) {
74-
case f32: output = matmul<float >(lhs, rhs, optLhs, optRhs); break;
75-
case c32: output = matmul<cfloat >(lhs, rhs, optLhs, optRhs); break;
76-
case f64: output = matmul<double >(lhs, rhs, optLhs, optRhs); break;
77-
case c64: output = matmul<cdouble>(lhs, rhs, optLhs, optRhs); break;
78-
default: TYPE_ERROR(1, lhs_type);
142+
case f32: output = matmul<float >(lhs, rhs, optLhs, optRhs); break;
143+
case c32: output = matmul<cfloat >(lhs, rhs, optLhs, optRhs); break;
144+
case f64: output = matmul<double >(lhs, rhs, optLhs, optRhs); break;
145+
case c64: output = matmul<cdouble>(lhs, rhs, optLhs, optRhs); break;
146+
default: TYPE_ERROR(1, lhs_type);
79147
}
80148
std::swap(*out, output);
81149
}

src/api/c/handle.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
#include <cast.hpp>
1818
#include <af/dim4.hpp>
1919

20-
const ArrayInfo& getInfo(const af_array arr, bool check = true);
20+
const ArrayInfo& getInfo(const af_array arr, bool device_check = true, bool sparse_check = true);
2121

2222
template<typename T>
2323
static const detail::Array<T> &

0 commit comments

Comments
 (0)