1616#include < af/index.h>
1717#include < af/device.h>
1818#include < af/gfor.h>
19+ #include < af/algorithm.h>
1920#include " error.hpp"
2021
2122namespace af
2223{
24+ static void copyIndices (af_index_t inds[4 ], af_index_t indices[4 ])
25+ {
26+ for (int i = 0 ; i < 4 ; i++) {
27+ if (!indices[i].mIsSeq ) {
28+ AF_THROW (af_weak_copy (&inds[i].mIndexer .arr , indices[i].mIndexer .arr ));
29+ } else {
30+ inds[i].mIndexer .seq = indices[i].mIndexer .seq ;
31+ }
32+ inds[i].mIsSeq = indices[i].mIsSeq ;
33+ inds[i].isBatch = indices[i].isBatch ;
34+ }
35+ }
36+
2337 static af_index_t toIndices (const seq &s)
2438 {
2539 af_index_t res;
@@ -29,15 +43,31 @@ namespace af
2943 return res;
3044 }
3145
32- static af_index_t toIndices (const array &idx )
46+ static af_index_t toIndices (const array &idx0 )
3347 {
3448 af_index_t res;
35- res.mIndexer .arr = idx.get ();
49+
50+ array idx = idx0.isbool () ? where (idx0) : idx0;
51+ af_array arr = 0 ;
52+ AF_THROW (af_weak_copy (&arr, idx.get ()));
53+ res.mIndexer .arr = arr;
54+
3655 res.mIsSeq = false ;
3756 res.isBatch = false ;
3857 return res;
3958 }
4059
60+ void cleanIndices (af_index_t indices[4 ])
61+ {
62+ for (int i = 0 ; i < 4 ; i++) {
63+ if (!indices[i].mIsSeq ) {
64+ AF_THROW (af_destroy_array (indices[i].mIndexer .arr ));
65+ }
66+ // Just to be safe
67+ indices[i] = toIndices (span);
68+ }
69+ }
70+
4171 static int gforDim (af_index_t seqs[4 ])
4272 {
4373 for (int i = 0 ; i < 4 ; i++) {
@@ -246,6 +276,8 @@ namespace af
246276 arr = temp;
247277 }
248278
279+ cleanIndices (indices);
280+
249281 isRef = false ;
250282 AF_THROW (err);
251283 return arr;
@@ -497,12 +529,12 @@ namespace af
497529
498530 af_array tmp;
499531 AF_THROW (af_assign_gen (&tmp, arr, nd, indices, other_arr));
500- parent-> set (tmp );
532+ cleanIndices (indices );
501533
534+ parent->set (tmp);
502535 if (dim >= 0 && is_reordered) AF_THROW (af_destroy_array (other_arr));
503536
504537 isRef = false ;
505-
506538 } else {
507539
508540 if (this ->get () == other.get ()) {
@@ -527,13 +559,17 @@ namespace af
527559 af_array lhs; \
528560 int dim = gforDim (this ->indices ); \
529561 AF_THROW (af_weak_copy (&lhs, this ->arr )); \
562+ af_index_t inds[4 ]; \
563+ /* FIXME: Figure out a way to not perform the copy*/ \
564+ copyIndices (inds, indices); \
530565 unsigned ndims = numDims (lhs); \
531566 /* FIXME: Unify with other af_assign_gen */ \
532567 array tmp = *this op1 other; \
533568 af_array tmp_arr = tmp.get (); \
534569 af_array out = 0 ; \
535570 tmp_arr = (dim == -1 ) ? tmp_arr : gforReorder (tmp_arr, dim); \
536- AF_THROW (af_assign_gen (&out, lhs, ndims, indices, tmp_arr)); \
571+ AF_THROW (af_assign_gen (&out, lhs, ndims, inds, tmp_arr)); \
572+ cleanIndices (indices); \
537573 AF_THROW (af_destroy_array (this ->arr )); \
538574 if (dim >= 0 ) AF_THROW (af_destroy_array (tmp_arr)); \
539575 this ->arr = lhs; \
0 commit comments