Skip to content

Commit 72952ed

Browse files
committed
Linear indexing now flattens the arrays before the operation
1 parent 1f5e3f9 commit 72952ed

2 files changed

Lines changed: 19 additions & 2 deletions

File tree

src/api/cpp/array.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -299,6 +299,14 @@ namespace af
299299
array array::operator()(const array& idx) const
300300
{
301301
eval();
302+
303+
// Special case of indexing linearly
304+
// Flatten the current array and index accordingly
305+
if (this->numdims() > 1) {
306+
array tmp = flat(*this);
307+
return tmp(idx);
308+
}
309+
302310
af_array out = 0;
303311
AF_THROW(af_lookup(&out, this->get(), idx.get(), 0));
304312
return array(out);
@@ -307,6 +315,14 @@ namespace af
307315
array array::operator()(const seq &s0) const
308316
{
309317
eval();
318+
319+
// Special case of indexing linearly
320+
// Flatten the current array and index accordingly
321+
if (this->numdims() > 1) {
322+
array tmp = flat(*this);
323+
return tmp(s0);
324+
}
325+
310326
af_array out = 0;
311327
seq indices[] = {s0, span, span, span};
312328
//FIXME: check if this->s has same dimensions as numdims

test/index.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ using std::endl;
2929
using std::ostream_iterator;
3030
using af::dtype_traits;
3131

32+
3233
template<typename T, typename OP>
3334
void
3435
checkValues(const af_seq &seq, const T* data, const T* indexed_data, OP compair_op) {
@@ -627,7 +628,7 @@ TEST(lookup, CPP)
627628

628629
array input(dims0, &(in[0].front()));
629630
array indices(dims1, &(in[1].front()));
630-
array output = input(indices);
631+
array output = af::lookup(input, indices, 0);
631632

632633
vector<float> currGoldBar = tests[0];
633634
size_t nElems = currGoldBar.size();
@@ -728,7 +729,7 @@ TEST(SeqIndex, CPPLarge)
728729

729730
array input(dims0, &(in[0].front()));
730731
array indices(dims1, &(in[1].front()));
731-
array output = input(indices);
732+
array output = af::lookup(input, indices, 0);
732733

733734
vector<float> currGoldBar = tests[0];
734735
size_t nElems = currGoldBar.size();

0 commit comments

Comments
 (0)