Skip to content

Commit 450340b

Browse files
author
pradeep
committed
FEAT: sparse-sparse add/sub support
CPU and OpenCL backends have support for mul/div but they are disabled to have feature parity with CUDA which doesn't have support for mul/div. The output of sub/div/mul is not guaranteed to have only non-zero results of the arithmetic operation. The user has to take care of pruning the zero results from the output.
1 parent fad8d38 commit 450340b

12 files changed

Lines changed: 665 additions & 89 deletions

src/api/c/binary.cpp

Lines changed: 53 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,14 @@ static inline af_array arithOp(const af_array lhs, const af_array rhs,
3535
return res;
3636
}
3737

38+
template<typename T, af_op_t op>
39+
static inline
40+
af_array sparseArithOp(const af_array lhs, const af_array rhs)
41+
{
42+
auto res = arithOp<T, op>(getSparseArray<T>(lhs), getSparseArray<T>(rhs));
43+
return getHandle(res);
44+
}
45+
3846
template<typename T, af_op_t op>
3947
static inline af_array arithSparseDenseOp(const af_array lhs, const af_array rhs,
4048
const bool reverse)
@@ -80,10 +88,11 @@ static af_err af_arith(af_array *out, const af_array lhs, const af_array rhs, co
8088
}
8189

8290
template<af_op_t op>
83-
static af_err af_arith_real(af_array *out, const af_array lhs, const af_array rhs, const bool batchMode)
91+
static
92+
af_err af_arith_real(af_array *out, const af_array lhs, const af_array rhs,
93+
const bool batchMode)
8494
{
8595
try {
86-
8796
const ArrayInfo& linfo = getInfo(lhs);
8897
const ArrayInfo& rinfo = getInfo(rhs);
8998

@@ -111,38 +120,41 @@ static af_err af_arith_real(af_array *out, const af_array lhs, const af_array rh
111120
return AF_SUCCESS;
112121
}
113122

114-
//template<af_op_t op>
115-
//static af_err af_arith_sparse(af_array *out, const af_array lhs, const af_array rhs)
116-
//{
117-
// try {
118-
// SparseArrayBase linfo = getSparseArrayBase(lhs);
119-
// SparseArrayBase rinfo = getSparseArrayBase(rhs);
120-
//
121-
// dim4 odims = getOutDims(linfo.dims(), rinfo.dims(), batchMode);
122-
//
123-
// const af_dtype otype = implicit(linfo.getType(), rinfo.getType());
124-
// af_array res;
125-
// switch (otype) {
126-
// case f32: res = arithOp<float , op>(lhs, rhs, odims); break;
127-
// case f64: res = arithOp<double , op>(lhs, rhs, odims); break;
128-
// case c32: res = arithOp<cfloat , op>(lhs, rhs, odims); break;
129-
// case c64: res = arithOp<cdouble, op>(lhs, rhs, odims); break;
130-
// default: TYPE_ERROR(0, otype);
131-
// }
132-
//
133-
// std::swap(*out, res);
134-
// }
135-
// CATCHALL;
136-
// return AF_SUCCESS;
137-
//}
123+
template<af_op_t op>
124+
static af_err
125+
af_arith_sparse(af_array *out, const af_array lhs, const af_array rhs)
126+
{
127+
try {
128+
common::SparseArrayBase linfo = getSparseArrayBase(lhs);
129+
common::SparseArrayBase rinfo = getSparseArrayBase(rhs);
130+
131+
ARG_ASSERT(1, (linfo.getStorage()==rinfo.getStorage()));
132+
ARG_ASSERT(1, (linfo.dims()==rinfo.dims()));
133+
ARG_ASSERT(1, (linfo.getStorage()==AF_STORAGE_CSR));
134+
135+
const af_dtype otype = implicit(linfo.getType(), rinfo.getType());
136+
af_array res;
137+
switch (otype) {
138+
case f32: res = sparseArithOp<float , op>(lhs, rhs); break;
139+
case f64: res = sparseArithOp<double , op>(lhs, rhs); break;
140+
case c32: res = sparseArithOp<cfloat , op>(lhs, rhs); break;
141+
case c64: res = sparseArithOp<cdouble, op>(lhs, rhs); break;
142+
default: TYPE_ERROR(0, otype);
143+
}
144+
145+
std::swap(*out, res);
146+
}
147+
CATCHALL;
148+
return AF_SUCCESS;
149+
}
138150

139151
template<af_op_t op>
140152
static af_err af_arith_sparse_dense(af_array *out, const af_array lhs, const af_array rhs,
141153
const bool reverse = false)
142154
{
143155
using namespace common;
144156
try {
145-
SparseArrayBase linfo = getSparseArrayBase(lhs);
157+
common::SparseArrayBase linfo = getSparseArrayBase(lhs);
146158
ArrayInfo rinfo = getInfo(rhs);
147159

148160
const af_dtype otype = implicit(linfo.getType(), rinfo.getType());
@@ -161,18 +173,20 @@ static af_err af_arith_sparse_dense(af_array *out, const af_array lhs, const af_
161173
return AF_SUCCESS;
162174
}
163175

164-
af_err af_add(af_array *out, const af_array lhs, const af_array rhs, const bool batchMode)
176+
af_err af_add(af_array *out, const af_array lhs, const af_array rhs,
177+
const bool batchMode)
165178
{
166179
// Check if inputs are sparse
167180
ArrayInfo linfo = getInfo(lhs, false, true);
168181
ArrayInfo rinfo = getInfo(rhs, false, true);
169182

170183
if(linfo.isSparse() && rinfo.isSparse()) {
171-
return AF_ERR_NOT_SUPPORTED; //af_arith_sparse<af_add_t>(out, lhs, rhs);
184+
return af_arith_sparse<af_add_t>(out, lhs, rhs);
172185
} else if(linfo.isSparse() && !rinfo.isSparse()) {
173186
return af_arith_sparse_dense<af_add_t>(out, lhs, rhs);
174187
} else if(!linfo.isSparse() && rinfo.isSparse()) {
175-
return af_arith_sparse_dense<af_add_t>(out, rhs, lhs, true); // dense should be rhs
188+
// second operand(Array) of af_arith call should be dense
189+
return af_arith_sparse_dense<af_add_t>(out, rhs, lhs, true);
176190
} else {
177191
return af_arith<af_add_t>(out, lhs, rhs, batchMode);
178192
}
@@ -185,7 +199,10 @@ af_err af_mul(af_array *out, const af_array lhs, const af_array rhs, const bool
185199
ArrayInfo rinfo = getInfo(rhs, false, true);
186200

187201
if(linfo.isSparse() && rinfo.isSparse()) {
188-
return AF_ERR_NOT_SUPPORTED; //af_arith_sparse<af_mul_t>(out, lhs, rhs);
202+
//return af_arith_sparse<af_mul_t>(out, lhs, rhs);
203+
//MKL doesn't have mul or div support yet, hence
204+
//this is commented out although alternative cpu code exists
205+
return AF_ERR_NOT_SUPPORTED;
189206
} else if(linfo.isSparse() && !rinfo.isSparse()) {
190207
return af_arith_sparse_dense<af_mul_t>(out, lhs, rhs);
191208
} else if(!linfo.isSparse() && rinfo.isSparse()) {
@@ -202,7 +219,7 @@ af_err af_sub(af_array *out, const af_array lhs, const af_array rhs, const bool
202219
ArrayInfo rinfo = getInfo(rhs, false, true);
203220

204221
if(linfo.isSparse() && rinfo.isSparse()) {
205-
return AF_ERR_NOT_SUPPORTED; //af_arith_sparse<af_sub_t>(out, lhs, rhs);
222+
return af_arith_sparse<af_sub_t>(out, lhs, rhs);
206223
} else if(linfo.isSparse() && !rinfo.isSparse()) {
207224
return af_arith_sparse_dense<af_sub_t>(out, lhs, rhs);
208225
} else if(!linfo.isSparse() && rinfo.isSparse()) {
@@ -219,7 +236,10 @@ af_err af_div(af_array *out, const af_array lhs, const af_array rhs, const bool
219236
ArrayInfo rinfo = getInfo(rhs, false, true);
220237

221238
if(linfo.isSparse() && rinfo.isSparse()) {
222-
return AF_ERR_NOT_SUPPORTED; //af_arith_sparse<af_div_t>(out, lhs, rhs);
239+
//return af_arith_sparse<af_div_t>(out, lhs, rhs);
240+
//MKL doesn't have mul or div support yet, hence
241+
//this is commented out although alternative cpu code exists
242+
return AF_ERR_NOT_SUPPORTED;
223243
} else if(linfo.isSparse() && !rinfo.isSparse()) {
224244
return af_arith_sparse_dense<af_div_t>(out, lhs, rhs);
225245
} else if(!linfo.isSparse() && rinfo.isSparse()) {

src/backend/cpu/kernel/sparse_arith.hpp

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
#include <Param.hpp>
1212
#include <math.hpp>
1313

14+
#include <cmath>
15+
1416
namespace cpu
1517
{
1618
namespace kernel
@@ -143,5 +145,104 @@ void sparseArithOpS(Param<T> values, Param<int> rowIdx, Param<int> colIdx,
143145
}
144146
}
145147

148+
// The following functions can handle CSR
149+
// storage format only as of now.
150+
static
151+
void calcOutNNZ(Param<int> outRowIdx,
152+
const uint M, const uint N,
153+
CParam<int> lRowIdx, CParam<int> lColIdx,
154+
CParam<int> rRowIdx, CParam<int> rColIdx)
155+
{
156+
int *orPtr = outRowIdx.get();
157+
const int *lrPtr = lRowIdx.get();
158+
const int *lcPtr = lColIdx.get();
159+
const int *rrPtr = rRowIdx.get();
160+
const int *rcPtr = rColIdx.get();
161+
162+
unsigned csrOutCount = 0;
163+
for (uint row=0; row<M; ++row) {
164+
const int lEnd = lrPtr[row+1];
165+
const int rEnd = rrPtr[row+1];
166+
167+
uint rowNNZ = 0;
168+
int l = lrPtr[row];
169+
int r = rrPtr[row];
170+
while (l < lEnd && r < rEnd) {
171+
int lci = lcPtr[l];
172+
int rci = rcPtr[r];
173+
174+
l += (lci <= rci);
175+
r += (lci >= rci);
176+
rowNNZ++;
177+
}
178+
// Elements from lhs or rhs are exhausted.
179+
// Just count left over elements
180+
rowNNZ += (lEnd-l);
181+
rowNNZ += (rEnd-r);
182+
183+
orPtr[row] = csrOutCount;
184+
csrOutCount += rowNNZ;
185+
}
186+
//Write out the Rows+1 entry
187+
orPtr[M] = csrOutCount;
188+
}
189+
190+
template<typename T, af_op_t op>
191+
void sparseArithOp(Param<T> oVals, Param<int> oColIdx,
192+
CParam<int> oRowIdx, const uint Rows,
193+
CParam<T> lvals, CParam<int> lRowIdx, CParam<int> lColIdx,
194+
CParam<T> rvals, CParam<int> rRowIdx, CParam<int> rColIdx)
195+
{
196+
const int *orPtr = oRowIdx.get();
197+
const T *lvPtr = lvals.get();
198+
const int *lrPtr = lRowIdx.get();
199+
const int *lcPtr = lColIdx.get();
200+
const T *rvPtr = rvals.get();
201+
const int *rrPtr = rRowIdx.get();
202+
const int *rcPtr = rColIdx.get();
203+
204+
arith_op<T, op> binOp;
205+
206+
auto ZERO = scalar<T>(0);
207+
208+
for (uint row=0; row<Rows; ++row) {
209+
const int lEnd = lrPtr[row+1];
210+
const int rEnd = rrPtr[row+1];
211+
const int offs = orPtr[row];
212+
213+
T *ovPtr = oVals.get() + offs;
214+
int *ocPtr = oColIdx.get() + offs;
215+
216+
uint rowNNZ = 0;
217+
int l = lrPtr[row];
218+
int r = rrPtr[row];
219+
while (l < lEnd && r < rEnd) {
220+
int lci = lcPtr[l];
221+
int rci = rcPtr[r];
222+
223+
T lhs = (lci <= rci ? lvPtr[l] : ZERO);
224+
T rhs = (lci >= rci ? rvPtr[r] : ZERO);
225+
226+
ovPtr[ rowNNZ ] = binOp(lhs, rhs);
227+
ocPtr[ rowNNZ ] = (lci <= rci) ? lci : rci;
228+
229+
l += (lci <= rci);
230+
r += (lci >= rci);
231+
rowNNZ++;
232+
}
233+
while (l < lEnd) {
234+
ovPtr[ rowNNZ ] = binOp(lvPtr[l], ZERO);
235+
ocPtr[ rowNNZ ] = lcPtr[l];
236+
l++;
237+
rowNNZ++;
238+
}
239+
while (r < rEnd) {
240+
ovPtr[ rowNNZ ] = binOp(ZERO, rvPtr[r]);
241+
ocPtr[ rowNNZ ] = rcPtr[r];
242+
r++;
243+
rowNNZ++;
244+
}
245+
}
246+
}
146247
}
147248
}

src/backend/cpu/sparse_arith.cpp

Lines changed: 49 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,8 @@
99

1010
#include <sparse_arith.hpp>
1111
#include <common/SparseArray.hpp>
12-
#include <optypes.hpp>
1312
#include <sparse.hpp>
14-
15-
#include <kernel/sparse_arith.hpp>
16-
17-
#include <stdexcept>
18-
#include <string>
19-
13+
#include <optypes.hpp>
2014
#include <af/dim4.hpp>
2115
#include <arith.hpp>
2216
#include <complex.hpp>
@@ -26,6 +20,13 @@
2620
#include <platform.hpp>
2721
#include <queue.hpp>
2822

23+
#include <kernel/sparse_arith.hpp>
24+
25+
#include <algorithm>
26+
#include <stdexcept>
27+
#include <string>
28+
#include <vector>
29+
2930
namespace cpu
3031
{
3132

@@ -115,6 +116,39 @@ SparseArray<T> arithOpS(const SparseArray<T> &lhs, const Array<T> &rhs, const bo
115116
return out;
116117
}
117118

119+
template<typename T, af_op_t op>
120+
SparseArray<T> arithOp(const SparseArray<T> &lhs, const SparseArray<T> &rhs)
121+
{
122+
af::storage sfmt = lhs.getStorage();
123+
124+
lhs.eval();
125+
rhs.eval();
126+
127+
const dim4 dims = lhs.dims();
128+
const uint M = dims[0];
129+
const uint N = dims[1];
130+
131+
auto rowArr = createEmptyArray<int>(dim4(M+1));
132+
133+
getQueue().enqueue(kernel::calcOutNNZ, rowArr, M, N,
134+
lhs.getRowIdx(), lhs.getColIdx(),
135+
rhs.getRowIdx(), rhs.getColIdx());
136+
getQueue().sync();
137+
138+
uint nnz = rowArr.get()[M];
139+
auto out = createEmptySparseArray<T>(dims, nnz, sfmt);
140+
out.eval();
141+
142+
copyArray(out.getRowIdx(), rowArr);
143+
144+
getQueue().enqueue(kernel::sparseArithOp<T, op>,
145+
out.getValues(), out.getColIdx(),
146+
out.getRowIdx(), M,
147+
lhs.getValues(), lhs.getRowIdx(), lhs.getColIdx(),
148+
rhs.getValues(), rhs.getRowIdx(), rhs.getColIdx());
149+
return out;
150+
}
151+
118152
#define INSTANTIATE(T) \
119153
template Array<T> arithOpD<T, af_add_t>(const SparseArray<T> &lhs, const Array<T> &rhs, \
120154
const bool reverse); \
@@ -132,6 +166,14 @@ SparseArray<T> arithOpS(const SparseArray<T> &lhs, const Array<T> &rhs, const bo
132166
const bool reverse); \
133167
template SparseArray<T> arithOpS<T, af_div_t>(const SparseArray<T> &lhs, const Array<T> &rhs, \
134168
const bool reverse); \
169+
template SparseArray<T> arithOp<T, af_add_t>(const common::SparseArray<T> &lhs, \
170+
const common::SparseArray<T> &rhs); \
171+
template SparseArray<T> arithOp<T, af_sub_t>(const common::SparseArray<T> &lhs, \
172+
const common::SparseArray<T> &rhs); \
173+
template SparseArray<T> arithOp<T, af_mul_t>(const common::SparseArray<T> &lhs, \
174+
const common::SparseArray<T> &rhs); \
175+
template SparseArray<T> arithOp<T, af_div_t>(const common::SparseArray<T> &lhs, \
176+
const common::SparseArray<T> &rhs);
135177

136178
INSTANTIATE(float )
137179
INSTANTIATE(double )

src/backend/cpu/sparse_arith.hpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,15 @@
77
* http://arrayfire.com/licenses/BSD-3-Clause
88
********************************************************/
99

10+
#pragma once
11+
1012
#include <Array.hpp>
1113
#include <common/SparseArray.hpp>
1214
#include <sparse.hpp>
1315
#include <optypes.hpp>
1416

1517
namespace cpu
1618
{
17-
1819
// These two functions cannot be overloaded by return type.
1920
// So have to give them separate names.
2021
template<typename T, af_op_t op>
@@ -25,4 +26,7 @@ template<typename T, af_op_t op>
2526
common::SparseArray<T> arithOpS(const common::SparseArray<T> &lhs, const Array<T> &rhs,
2627
const bool reverse = false);
2728

29+
template<typename T, af_op_t op>
30+
common::SparseArray<T> arithOp(const common::SparseArray<T> &lhs,
31+
const common::SparseArray<T> &rhs);
2832
}

0 commit comments

Comments
 (0)