Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 61 additions & 19 deletions src/api/c/assign.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include <copy.hpp>
#include <assign.hpp>
#include <math.hpp>
#include <tile.hpp>

using namespace detail;
using std::vector;
Expand Down Expand Up @@ -49,13 +50,21 @@ void assign(Array<Tout> &out, const unsigned &ndims, const af_seq *index, const
is_vector &= oDims[i] == 1;
}

if (is_vector && in_.isVector()) {
if (oDims.elements() != (dim_t)in_.elements()) {
is_vector &= in_.isVector() || in_.isScalar();

for (dim_t i = ndims; i < (int)in_.ndims(); i++) {
oDims[i] = 1;
}


if (is_vector) {
if (oDims.elements() != (dim_t)in_.elements() &&
in_.elements() != 1) {
AF_ERROR("Size mismatch between input and output", AF_ERR_SIZE);
}

// If both out and in are vectors of equal elements, reshape in to out dims
Array<Tin> in = modDims(in_, oDims);
Array<Tin> in = in_.elements() == 1 ? tile(in_, oDims) : modDims(in_, oDims);
Array<Tout> dst = createSubArray<Tout>(out, index_, false);

copyArray<Tin , Tout>(dst, in);
Expand Down Expand Up @@ -112,6 +121,18 @@ af_err af_assign_seq(af_array *out,
ARG_ASSERT(1, (ndims>0));
ARG_ASSERT(3, (rhs!=0));

ArrayInfo lInfo = getInfo(lhs);

if (ndims == 1 && ndims != (dim_t)lInfo.ndims()) {
af_array tmp_in, tmp_out;
AF_CHECK(af_flat(&tmp_in, lhs));
AF_CHECK(af_assign_seq(&tmp_out, tmp_in, ndims, index, rhs));
AF_CHECK(af_moddims(out, tmp_out, lInfo.ndims(), lInfo.dims().get()));
AF_CHECK(af_release_array(tmp_in));
AF_CHECK(af_release_array(tmp_out));
return AF_SUCCESS;
}

for(dim_t i=0; i<(dim_t)ndims; ++i) {
ARG_ASSERT(2, (index[i].step>=0));
}
Expand Down Expand Up @@ -200,30 +221,42 @@ af_err af_assign_gen(af_array *out,
ARG_ASSERT(1, (lhs!=0));
ARG_ASSERT(4, (rhs!=0));

if (*out != lhs) {
int count = 0;
AF_CHECK(af_get_data_ref_count(&count, lhs));
if (count > 1) {
AF_CHECK(af_copy_array(&output, lhs));
} else {
AF_CHECK(af_retain_array(&output, lhs));
}
} else {
output = lhs;
}

ArrayInfo lInfo = getInfo(lhs);
ArrayInfo rInfo = getInfo(rhs);
dim4 lhsDims = lInfo.dims();
dim4 rhsDims = rInfo.dims();
af_dtype lhsType= lInfo.getType();
af_dtype rhsType= rInfo.getType();

ARG_ASSERT(2, (ndims == 1) || (ndims == (dim_t)lInfo.ndims()));

if (ndims == 1 && ndims != (dim_t)lInfo.ndims()) {
af_array tmp_in, tmp_out;
AF_CHECK(af_flat(&tmp_in, lhs));
AF_CHECK(af_assign_gen(&tmp_out, tmp_in, ndims, indexs, rhs_));
AF_CHECK(af_moddims(out, tmp_out, lInfo.ndims(), lInfo.dims().get()));
AF_CHECK(af_release_array(tmp_in));
AF_CHECK(af_release_array(tmp_out));
return AF_SUCCESS;
}

ARG_ASSERT(1, (lhsType==rhsType));
ARG_ASSERT(3, (rhsDims.ndims()>0));
ARG_ASSERT(1, (lhsDims.ndims()>=rhsDims.ndims()));
ARG_ASSERT(2, (lhsDims.ndims()>=ndims));

if (*out != lhs) {
int count = 0;
AF_CHECK(af_get_data_ref_count(&count, lhs));
if (count > 1) {
AF_CHECK(af_copy_array(&output, lhs));
} else {
AF_CHECK(af_retain_array(&output, lhs));
}
} else {
output = lhs;
}

dim4 oDims = toDims(seqs, lhsDims);
// if af_array are indexs along any
// particular dimension, set the length of
Expand All @@ -234,20 +267,29 @@ af_err af_assign_gen(af_array *out,
}
}

for (dim_t i = ndims; i < (dim_t)lInfo.ndims(); i++) {
oDims[i] = 1;
}

bool is_vector = true;
for (int i = 0; is_vector && i < oDims.ndims() - 1; i++) {
is_vector &= oDims[i] == 1;
}

//TODO: Move logic out of this
is_vector &= rInfo.isVector();
is_vector &= rInfo.isVector() || rInfo.isScalar();
if (is_vector) {
if (oDims.elements() != (dim_t)rInfo.elements()) {
if (oDims.elements() != (dim_t)rInfo.elements() &&
rInfo.elements() != 1) {
AF_ERROR("Size mismatch between input and output", AF_ERR_SIZE);
}

// If both out and rhs are vectors of equal elements, reshape rhs to out dims
AF_CHECK(af_moddims(&rhs, rhs_, oDims.ndims(), oDims.get()));
if (rInfo.elements() == 1) {
AF_CHECK(af_tile(&rhs, rhs_, oDims[0], oDims[1], oDims[2], oDims[3]));
} else {
// If both out and rhs are vectors of equal elements, reshape rhs to out dims
AF_CHECK(af_moddims(&rhs, rhs_, oDims.ndims(), oDims.get()));
}
} else {
for (int i = 0; i < 4; i++) {
if (oDims[i] != rhsDims[i]) {
Expand Down
155 changes: 155 additions & 0 deletions test/assign.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -830,3 +830,158 @@ TEST(Assign, Copy)
delete[] h_a;
delete[] h_b;
}

TEST(Asssign, LinearCPP)
{
using af::array;
const int nx = 5;
const int ny = 4;
const float val = 3;

const int st = nx - 2;
const int en = nx * (ny - 1);

array a = af::randu(nx, ny);
array a_copy = a;
af::index idx = af::seq(st, en);
a(idx) = 3;

ASSERT_EQ(a.dims(0), a_copy.dims(0));
ASSERT_EQ(a.dims(1), a_copy.dims(1));

std::vector<float> ha(nx * ny);
std::vector<float> ha_copy(nx * ny);

a.host(&ha[0]);
a_copy.host(&ha_copy[0]);

for (int i = 0; i < nx * ny; i++) {
if (i < st || i > en)
ASSERT_EQ(ha[i], ha_copy[i]) << "at " << i;
else
ASSERT_EQ(ha[i], val) << "at " << i;
}
}

TEST(Asssign, LinearAssignSeq)
{
using af::array;
const int nx = 5;
const int ny = 4;
const float val = 3;
const array rhs = af::constant(val, 1, 1);

const int st = nx - 2;
const int en = nx * (ny - 1);

array a = af::randu(nx, ny);
af::index idx = af::seq(st, en);

af_array in_arr = a.get();
af_index_t ii = idx.get();
af_array rhs_arr = rhs.get();
af_array out_arr;

ASSERT_EQ(AF_SUCCESS,
af_assign_seq(&out_arr, in_arr, 1, &ii.idx.seq, rhs_arr));

af::array out(out_arr);

ASSERT_EQ(a.dims(0), out.dims(0));
ASSERT_EQ(a.dims(1), out.dims(1));

std::vector<float> hout(nx * ny);
std::vector<float> ha(nx * ny);

a.host(&ha[0]);
out.host(&hout[0]);

for (int i = 0; i < nx * ny; i++) {
if (i < st || i > en)
ASSERT_EQ(hout[i], ha[i]) << "at " << i;
else
ASSERT_EQ(hout[i], val) << "at " << i;
}
}

TEST(Asssign, LinearAssignGenSeq)
{
using af::array;
const int nx = 5;
const int ny = 4;
const float val = 3;
const array rhs = af::constant(val, 1, 1);

const int st = nx - 2;
const int en = nx * (ny - 1);

array a = af::randu(nx, ny);
af::index idx = af::seq(st, en);

af_array in_arr = a.get();
af_index_t ii = idx.get();
af_array rhs_arr = rhs.get();
af_array out_arr;

ASSERT_EQ(AF_SUCCESS,
af_assign_gen(&out_arr, in_arr, 1, &ii, rhs_arr));

af::array out(out_arr);

ASSERT_EQ(a.dims(0), out.dims(0));
ASSERT_EQ(a.dims(1), out.dims(1));

std::vector<float> hout(nx * ny);
std::vector<float> ha(nx * ny);

a.host(&ha[0]);
out.host(&hout[0]);

for (int i = 0; i < nx * ny; i++) {
if (i < st || i > en)
ASSERT_EQ(hout[i], ha[i]) << "at " << i;
else
ASSERT_EQ(hout[i], val) << "at " << i;
}
}

TEST(Asssign, LinearAssignGenArr)
{
using af::array;
const int nx = 5;
const int ny = 4;
const float val = 3;
const array rhs = af::constant(val, 1, 1);

const int st = nx - 2;
const int en = nx * (ny - 1);

array a = af::randu(nx, ny);
af::index idx = af::array(af::seq(st, en));

af_array in_arr = a.get();
af_index_t ii = idx.get();
af_array rhs_arr = rhs.get();
af_array out_arr;

ASSERT_EQ(AF_SUCCESS,
af_assign_gen(&out_arr, in_arr, 1, &ii, rhs_arr));

af::array out(out_arr);

ASSERT_EQ(a.dims(0), out.dims(0));
ASSERT_EQ(a.dims(1), out.dims(1));

std::vector<float> hout(nx * ny);
std::vector<float> ha(nx * ny);

a.host(&ha[0]);
out.host(&hout[0]);

for (int i = 0; i < nx * ny; i++) {
if (i < st || i > en)
ASSERT_EQ(hout[i], ha[i]) << "at " << i;
else
ASSERT_EQ(hout[i], val) << "at " << i;
}
}