@@ -35,6 +35,14 @@ static inline af_array arithOp(const af_array lhs, const af_array rhs,
3535 return res;
3636}
3737
38+ template <typename T, af_op_t op>
39+ static inline
40+ af_array sparseArithOp (const af_array lhs, const af_array rhs)
41+ {
42+ auto res = arithOp<T, op>(getSparseArray<T>(lhs), getSparseArray<T>(rhs));
43+ return getHandle (res);
44+ }
45+
3846template <typename T, af_op_t op>
3947static inline af_array arithSparseDenseOp (const af_array lhs, const af_array rhs,
4048 const bool reverse)
@@ -80,10 +88,11 @@ static af_err af_arith(af_array *out, const af_array lhs, const af_array rhs, co
8088}
8189
8290template <af_op_t op>
83- static af_err af_arith_real (af_array *out, const af_array lhs, const af_array rhs, const bool batchMode)
91+ static
92+ af_err af_arith_real (af_array *out, const af_array lhs, const af_array rhs,
93+ const bool batchMode)
8494{
8595 try {
86-
8796 const ArrayInfo& linfo = getInfo (lhs);
8897 const ArrayInfo& rinfo = getInfo (rhs);
8998
@@ -111,38 +120,41 @@ static af_err af_arith_real(af_array *out, const af_array lhs, const af_array rh
111120 return AF_SUCCESS;
112121}
113122
114- // template<af_op_t op>
115- // static af_err af_arith_sparse(af_array *out, const af_array lhs, const af_array rhs)
116- // {
117- // try {
118- // SparseArrayBase linfo = getSparseArrayBase(lhs);
119- // SparseArrayBase rinfo = getSparseArrayBase(rhs);
120- //
121- // dim4 odims = getOutDims(linfo.dims(), rinfo.dims(), batchMode);
122- //
123- // const af_dtype otype = implicit(linfo.getType(), rinfo.getType());
124- // af_array res;
125- // switch (otype) {
126- // case f32: res = arithOp<float , op>(lhs, rhs, odims); break;
127- // case f64: res = arithOp<double , op>(lhs, rhs, odims); break;
128- // case c32: res = arithOp<cfloat , op>(lhs, rhs, odims); break;
129- // case c64: res = arithOp<cdouble, op>(lhs, rhs, odims); break;
130- // default: TYPE_ERROR(0, otype);
131- // }
132- //
133- // std::swap(*out, res);
134- // }
135- // CATCHALL;
136- // return AF_SUCCESS;
137- // }
123+ template <af_op_t op>
124+ static af_err
125+ af_arith_sparse (af_array *out, const af_array lhs, const af_array rhs)
126+ {
127+ try {
128+ common::SparseArrayBase linfo = getSparseArrayBase (lhs);
129+ common::SparseArrayBase rinfo = getSparseArrayBase (rhs);
130+
131+ ARG_ASSERT (1 , (linfo.getStorage ()==rinfo.getStorage ()));
132+ ARG_ASSERT (1 , (linfo.dims ()==rinfo.dims ()));
133+ ARG_ASSERT (1 , (linfo.getStorage ()==AF_STORAGE_CSR));
134+
135+ const af_dtype otype = implicit (linfo.getType (), rinfo.getType ());
136+ af_array res;
137+ switch (otype) {
138+ case f32 : res = sparseArithOp<float , op>(lhs, rhs); break ;
139+ case f64 : res = sparseArithOp<double , op>(lhs, rhs); break ;
140+ case c32: res = sparseArithOp<cfloat , op>(lhs, rhs); break ;
141+ case c64: res = sparseArithOp<cdouble, op>(lhs, rhs); break ;
142+ default : TYPE_ERROR (0 , otype);
143+ }
144+
145+ std::swap (*out, res);
146+ }
147+ CATCHALL;
148+ return AF_SUCCESS;
149+ }
138150
139151template <af_op_t op>
140152static af_err af_arith_sparse_dense (af_array *out, const af_array lhs, const af_array rhs,
141153 const bool reverse = false )
142154{
143155 using namespace common ;
144156 try {
145- SparseArrayBase linfo = getSparseArrayBase (lhs);
157+ common:: SparseArrayBase linfo = getSparseArrayBase (lhs);
146158 ArrayInfo rinfo = getInfo (rhs);
147159
148160 const af_dtype otype = implicit (linfo.getType (), rinfo.getType ());
@@ -161,18 +173,20 @@ static af_err af_arith_sparse_dense(af_array *out, const af_array lhs, const af_
161173 return AF_SUCCESS;
162174}
163175
164- af_err af_add (af_array *out, const af_array lhs, const af_array rhs, const bool batchMode)
176+ af_err af_add (af_array *out, const af_array lhs, const af_array rhs,
177+ const bool batchMode)
165178{
166179 // Check if inputs are sparse
167180 ArrayInfo linfo = getInfo (lhs, false , true );
168181 ArrayInfo rinfo = getInfo (rhs, false , true );
169182
170183 if (linfo.isSparse () && rinfo.isSparse ()) {
171- return AF_ERR_NOT_SUPPORTED; // af_arith_sparse<af_add_t>(out, lhs, rhs);
184+ return af_arith_sparse<af_add_t >(out, lhs, rhs);
172185 } else if (linfo.isSparse () && !rinfo.isSparse ()) {
173186 return af_arith_sparse_dense<af_add_t >(out, lhs, rhs);
174187 } else if (!linfo.isSparse () && rinfo.isSparse ()) {
175- return af_arith_sparse_dense<af_add_t >(out, rhs, lhs, true ); // dense should be rhs
188+ // second operand(Array) of af_arith call should be dense
189+ return af_arith_sparse_dense<af_add_t >(out, rhs, lhs, true );
176190 } else {
177191 return af_arith<af_add_t >(out, lhs, rhs, batchMode);
178192 }
@@ -185,7 +199,10 @@ af_err af_mul(af_array *out, const af_array lhs, const af_array rhs, const bool
185199 ArrayInfo rinfo = getInfo (rhs, false , true );
186200
187201 if (linfo.isSparse () && rinfo.isSparse ()) {
188- return AF_ERR_NOT_SUPPORTED; // af_arith_sparse<af_mul_t>(out, lhs, rhs);
202+ // return af_arith_sparse<af_mul_t>(out, lhs, rhs);
203+ // MKL doesn't have mul or div support yet, hence
204+ // this is commented out although alternative cpu code exists
205+ return AF_ERR_NOT_SUPPORTED;
189206 } else if (linfo.isSparse () && !rinfo.isSparse ()) {
190207 return af_arith_sparse_dense<af_mul_t >(out, lhs, rhs);
191208 } else if (!linfo.isSparse () && rinfo.isSparse ()) {
@@ -202,7 +219,7 @@ af_err af_sub(af_array *out, const af_array lhs, const af_array rhs, const bool
202219 ArrayInfo rinfo = getInfo (rhs, false , true );
203220
204221 if (linfo.isSparse () && rinfo.isSparse ()) {
205- return AF_ERR_NOT_SUPPORTED; // af_arith_sparse<af_sub_t>(out, lhs, rhs);
222+ return af_arith_sparse<af_sub_t >(out, lhs, rhs);
206223 } else if (linfo.isSparse () && !rinfo.isSparse ()) {
207224 return af_arith_sparse_dense<af_sub_t >(out, lhs, rhs);
208225 } else if (!linfo.isSparse () && rinfo.isSparse ()) {
@@ -219,7 +236,10 @@ af_err af_div(af_array *out, const af_array lhs, const af_array rhs, const bool
219236 ArrayInfo rinfo = getInfo (rhs, false , true );
220237
221238 if (linfo.isSparse () && rinfo.isSparse ()) {
222- return AF_ERR_NOT_SUPPORTED; // af_arith_sparse<af_div_t>(out, lhs, rhs);
239+ // return af_arith_sparse<af_div_t>(out, lhs, rhs);
240+ // MKL doesn't have mul or div support yet, hence
241+ // this is commented out although alternative cpu code exists
242+ return AF_ERR_NOT_SUPPORTED;
223243 } else if (linfo.isSparse () && !rinfo.isSparse ()) {
224244 return af_arith_sparse_dense<af_div_t >(out, lhs, rhs);
225245 } else if (!linfo.isSparse () && rinfo.isSparse ()) {
0 commit comments