1111#include < platform.hpp>
1212#include < handle.hpp>
1313#include < backend.hpp>
14+ #include < sparse_handle.hpp>
1415
1516using namespace detail ;
1617
1718const ArrayInfo&
18- getInfo (const af_array arr, bool check)
19- {
20- const ArrayInfo *info = static_cast <ArrayInfo*>(reinterpret_cast <void *>(arr));
21-
22- // Check Sparse
23- ARG_ASSERT (0 , info->isSparse () == false );
24-
25- if (check && info->getDevId () != detail::getActiveDeviceId ()) {
26- AF_ERROR (" Input Array not created on current device" , AF_ERR_DEVICE );
27- }
28-
29- return *info;
30- }
31-
32- const ArrayInfo&
33- getSparseInfo (const af_array arr, bool sparseCheck, bool check)
19+ getInfo (const af_array arr, bool sparse_check, bool device_check)
3420{
3521 const ArrayInfo *info = static_cast <ArrayInfo*>(reinterpret_cast <void *>(arr));
3622
3723 // Check Sparse -> If false, then both standard Array<T> and SparseArray<T> are accepted
38- if (sparseCheck) {
39- ARG_ASSERT (0 , info->isSparse () == true );
24+ // Otherwise only regular Array<T> is accepted
25+ if (sparse_check) {
26+ ARG_ASSERT (0 , info->isSparse () == false );
4027 }
4128
42- if (check && info->getDevId () != detail::getActiveDeviceId ()) {
29+ if (device_check && info->getDevId () != detail::getActiveDeviceId ()) {
4330 AF_ERROR (" Input Array not created on current device" , AF_ERR_DEVICE );
4431 }
4532
@@ -169,7 +156,7 @@ af_err af_copy_array(af_array *out, const af_array in)
169156af_err af_get_data_ref_count (int *use_count, const af_array in)
170157{
171158 try {
172- ArrayInfo info = getSparseInfo (in, false , false );
159+ ArrayInfo info = getInfo (in, false , false );
173160 const af_dtype type = info.getType ();
174161
175162 int res;
@@ -199,29 +186,39 @@ af_err af_release_array(af_array arr)
199186 try {
200187 int dev = getActiveDeviceId ();
201188
202- ArrayInfo info = getSparseInfo (arr, false , false );
203-
204- setDevice (info.getDevId ());
205-
189+ ArrayInfo info = getInfo (arr, false , false );
206190 af_dtype type = info.getType ();
207191
208- switch (type) {
209- case f32 : releaseHandle<float >(arr); break ;
210- case c32: releaseHandle<cfloat >(arr); break ;
211- case f64 : releaseHandle<double >(arr); break ;
212- case c64: releaseHandle<cdouble >(arr); break ;
213- case b8: releaseHandle<char >(arr); break ;
214- case s32: releaseHandle<int >(arr); break ;
215- case u32 : releaseHandle<uint >(arr); break ;
216- case u8 : releaseHandle<uchar >(arr); break ;
217- case s64: releaseHandle<intl >(arr); break ;
218- case u64 : releaseHandle<uintl >(arr); break ;
219- case s16: releaseHandle<short >(arr); break ;
220- case u16 : releaseHandle<ushort >(arr); break ;
221- default : TYPE_ERROR (0 , type);
192+ if (info.isSparse ()) {
193+ switch (type) {
194+ case f32 : releaseSparseHandle<float >(arr); break ;
195+ case f64 : releaseSparseHandle<double >(arr); break ;
196+ case c32: releaseSparseHandle<cfloat >(arr); break ;
197+ case c64: releaseSparseHandle<cdouble>(arr); break ;
198+ default : TYPE_ERROR (0 , type);
199+ }
200+ } else {
201+
202+ setDevice (info.getDevId ());
203+
204+ switch (type) {
205+ case f32 : releaseHandle<float >(arr); break ;
206+ case c32: releaseHandle<cfloat >(arr); break ;
207+ case f64 : releaseHandle<double >(arr); break ;
208+ case c64: releaseHandle<cdouble >(arr); break ;
209+ case b8: releaseHandle<char >(arr); break ;
210+ case s32: releaseHandle<int >(arr); break ;
211+ case u32 : releaseHandle<uint >(arr); break ;
212+ case u8 : releaseHandle<uchar >(arr); break ;
213+ case s64: releaseHandle<intl >(arr); break ;
214+ case u64 : releaseHandle<uintl >(arr); break ;
215+ case s16: releaseHandle<short >(arr); break ;
216+ case u16 : releaseHandle<ushort >(arr); break ;
217+ default : TYPE_ERROR (0 , type);
218+ }
219+
220+ setDevice (dev);
222221 }
223-
224- setDevice (dev);
225222 }
226223 CATCHALL
227224
@@ -240,22 +237,33 @@ static af_array retainHandle(const af_array in)
240237
241238af_array retain (const af_array in)
242239{
243- af_dtype ty = getSparseInfo (in, false , false ).getType ();
244- switch (ty) {
245- case f32 : return retainHandle<float >(in);
246- case f64 : return retainHandle<double >(in);
247- case s32: return retainHandle<int >(in);
248- case u32 : return retainHandle<uint >(in);
249- case u8 : return retainHandle<uchar >(in);
250- case c32: return retainHandle<detail::cfloat >(in);
251- case c64: return retainHandle<detail::cdouble >(in);
252- case b8: return retainHandle<char >(in);
253- case s64: return retainHandle<intl >(in);
254- case u64 : return retainHandle<uintl >(in);
255- case s16: return retainHandle<short >(in);
256- case u16 : return retainHandle<ushort >(in);
257- default :
258- TYPE_ERROR (1 , ty);
240+ ArrayInfo info = getInfo (in, false , false );
241+ af_dtype ty = info.getType ();
242+
243+ if (info.isSparse ()) {
244+ switch (ty) {
245+ case f32 : return retainSparseHandle<float >(in);
246+ case f64 : return retainSparseHandle<double >(in);
247+ case c32: return retainSparseHandle<detail::cfloat >(in);
248+ case c64: return retainSparseHandle<detail::cdouble>(in);
249+ default : TYPE_ERROR (1 , ty);
250+ }
251+ } else {
252+ switch (ty) {
253+ case f32 : return retainHandle<float >(in);
254+ case f64 : return retainHandle<double >(in);
255+ case s32: return retainHandle<int >(in);
256+ case u32 : return retainHandle<uint >(in);
257+ case u8 : return retainHandle<uchar >(in);
258+ case c32: return retainHandle<detail::cfloat >(in);
259+ case c64: return retainHandle<detail::cdouble >(in);
260+ case b8: return retainHandle<char >(in);
261+ case s64: return retainHandle<intl >(in);
262+ case u64 : return retainHandle<uintl >(in);
263+ case s16: return retainHandle<short >(in);
264+ case u16 : return retainHandle<ushort >(in);
265+ default : TYPE_ERROR (1 , ty);
266+ }
259267 }
260268}
261269
@@ -309,7 +317,7 @@ af_err af_get_elements(dim_t *elems, const af_array arr)
309317{
310318 try {
311319 // Do not check for device mismatch
312- *elems = getSparseInfo (arr, false , false ).elements ();
320+ *elems = getInfo (arr, false , false ).elements ();
313321 } CATCHALL
314322 return AF_SUCCESS ;
315323}
@@ -355,7 +363,7 @@ af_err af_get_numdims(unsigned *nd, const af_array in)
355363 af_err fn1 (bool *result, const af_array in) \
356364 { \
357365 try { \
358- ArrayInfo info = getSparseInfo (in, false , false ); \
366+ ArrayInfo info = getInfo (in, false , false ); \
359367 *result = info.fn2 (); \
360368 } \
361369 CATCHALL \
0 commit comments