Skip to content

Commit 5b671c2

Browse files
committed
Merge pull request arrayfire#971 from pavanky/assign
FEAT: Adding support for linear assignment in C API
2 parents 2c7044b + 654dffb commit 5b671c2

2 files changed

Lines changed: 216 additions & 19 deletions

File tree

src/api/c/assign.cpp

Lines changed: 61 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include <copy.hpp>
2020
#include <assign.hpp>
2121
#include <math.hpp>
22+
#include <tile.hpp>
2223

2324
using namespace detail;
2425
using std::vector;
@@ -49,13 +50,21 @@ void assign(Array<Tout> &out, const unsigned &ndims, const af_seq *index, const
4950
is_vector &= oDims[i] == 1;
5051
}
5152

52-
if (is_vector && in_.isVector()) {
53-
if (oDims.elements() != (dim_t)in_.elements()) {
53+
is_vector &= in_.isVector() || in_.isScalar();
54+
55+
for (dim_t i = ndims; i < (int)in_.ndims(); i++) {
56+
oDims[i] = 1;
57+
}
58+
59+
60+
if (is_vector) {
61+
if (oDims.elements() != (dim_t)in_.elements() &&
62+
in_.elements() != 1) {
5463
AF_ERROR("Size mismatch between input and output", AF_ERR_SIZE);
5564
}
5665

5766
// If both out and in are vectors of equal elements, reshape in to out dims
58-
Array<Tin> in = modDims(in_, oDims);
67+
Array<Tin> in = in_.elements() == 1 ? tile(in_, oDims) : modDims(in_, oDims);
5968
Array<Tout> dst = createSubArray<Tout>(out, index_, false);
6069

6170
copyArray<Tin , Tout>(dst, in);
@@ -112,6 +121,18 @@ af_err af_assign_seq(af_array *out,
112121
ARG_ASSERT(1, (ndims>0));
113122
ARG_ASSERT(3, (rhs!=0));
114123

124+
ArrayInfo lInfo = getInfo(lhs);
125+
126+
if (ndims == 1 && ndims != (dim_t)lInfo.ndims()) {
127+
af_array tmp_in, tmp_out;
128+
AF_CHECK(af_flat(&tmp_in, lhs));
129+
AF_CHECK(af_assign_seq(&tmp_out, tmp_in, ndims, index, rhs));
130+
AF_CHECK(af_moddims(out, tmp_out, lInfo.ndims(), lInfo.dims().get()));
131+
AF_CHECK(af_release_array(tmp_in));
132+
AF_CHECK(af_release_array(tmp_out));
133+
return AF_SUCCESS;
134+
}
135+
115136
for(dim_t i=0; i<(dim_t)ndims; ++i) {
116137
ARG_ASSERT(2, (index[i].step>=0));
117138
}
@@ -200,30 +221,42 @@ af_err af_assign_gen(af_array *out,
200221
ARG_ASSERT(1, (lhs!=0));
201222
ARG_ASSERT(4, (rhs!=0));
202223

203-
if (*out != lhs) {
204-
int count = 0;
205-
AF_CHECK(af_get_data_ref_count(&count, lhs));
206-
if (count > 1) {
207-
AF_CHECK(af_copy_array(&output, lhs));
208-
} else {
209-
AF_CHECK(af_retain_array(&output, lhs));
210-
}
211-
} else {
212-
output = lhs;
213-
}
214-
215224
ArrayInfo lInfo = getInfo(lhs);
216225
ArrayInfo rInfo = getInfo(rhs);
217226
dim4 lhsDims = lInfo.dims();
218227
dim4 rhsDims = rInfo.dims();
219228
af_dtype lhsType= lInfo.getType();
220229
af_dtype rhsType= rInfo.getType();
221230

231+
ARG_ASSERT(2, (ndims == 1) || (ndims == (dim_t)lInfo.ndims()));
232+
233+
if (ndims == 1 && ndims != (dim_t)lInfo.ndims()) {
234+
af_array tmp_in, tmp_out;
235+
AF_CHECK(af_flat(&tmp_in, lhs));
236+
AF_CHECK(af_assign_gen(&tmp_out, tmp_in, ndims, indexs, rhs_));
237+
AF_CHECK(af_moddims(out, tmp_out, lInfo.ndims(), lInfo.dims().get()));
238+
AF_CHECK(af_release_array(tmp_in));
239+
AF_CHECK(af_release_array(tmp_out));
240+
return AF_SUCCESS;
241+
}
242+
222243
ARG_ASSERT(1, (lhsType==rhsType));
223244
ARG_ASSERT(3, (rhsDims.ndims()>0));
224245
ARG_ASSERT(1, (lhsDims.ndims()>=rhsDims.ndims()));
225246
ARG_ASSERT(2, (lhsDims.ndims()>=ndims));
226247

248+
if (*out != lhs) {
249+
int count = 0;
250+
AF_CHECK(af_get_data_ref_count(&count, lhs));
251+
if (count > 1) {
252+
AF_CHECK(af_copy_array(&output, lhs));
253+
} else {
254+
AF_CHECK(af_retain_array(&output, lhs));
255+
}
256+
} else {
257+
output = lhs;
258+
}
259+
227260
dim4 oDims = toDims(seqs, lhsDims);
228261
// if af_array are indexs along any
229262
// particular dimension, set the length of
@@ -234,20 +267,29 @@ af_err af_assign_gen(af_array *out,
234267
}
235268
}
236269

270+
for (dim_t i = ndims; i < (dim_t)lInfo.ndims(); i++) {
271+
oDims[i] = 1;
272+
}
273+
237274
bool is_vector = true;
238275
for (int i = 0; is_vector && i < oDims.ndims() - 1; i++) {
239276
is_vector &= oDims[i] == 1;
240277
}
241278

242279
//TODO: Move logic out of this
243-
is_vector &= rInfo.isVector();
280+
is_vector &= rInfo.isVector() || rInfo.isScalar();
244281
if (is_vector) {
245-
if (oDims.elements() != (dim_t)rInfo.elements()) {
282+
if (oDims.elements() != (dim_t)rInfo.elements() &&
283+
rInfo.elements() != 1) {
246284
AF_ERROR("Size mismatch between input and output", AF_ERR_SIZE);
247285
}
248286

249-
// If both out and rhs are vectors of equal elements, reshape rhs to out dims
250-
AF_CHECK(af_moddims(&rhs, rhs_, oDims.ndims(), oDims.get()));
287+
if (rInfo.elements() == 1) {
288+
AF_CHECK(af_tile(&rhs, rhs_, oDims[0], oDims[1], oDims[2], oDims[3]));
289+
} else {
290+
// If both out and rhs are vectors of equal elements, reshape rhs to out dims
291+
AF_CHECK(af_moddims(&rhs, rhs_, oDims.ndims(), oDims.get()));
292+
}
251293
} else {
252294
for (int i = 0; i < 4; i++) {
253295
if (oDims[i] != rhsDims[i]) {

test/assign.cpp

Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -830,3 +830,158 @@ TEST(Assign, Copy)
830830
delete[] h_a;
831831
delete[] h_b;
832832
}
833+
834+
TEST(Asssign, LinearCPP)
835+
{
836+
using af::array;
837+
const int nx = 5;
838+
const int ny = 4;
839+
const float val = 3;
840+
841+
const int st = nx - 2;
842+
const int en = nx * (ny - 1);
843+
844+
array a = af::randu(nx, ny);
845+
array a_copy = a;
846+
af::index idx = af::seq(st, en);
847+
a(idx) = 3;
848+
849+
ASSERT_EQ(a.dims(0), a_copy.dims(0));
850+
ASSERT_EQ(a.dims(1), a_copy.dims(1));
851+
852+
std::vector<float> ha(nx * ny);
853+
std::vector<float> ha_copy(nx * ny);
854+
855+
a.host(&ha[0]);
856+
a_copy.host(&ha_copy[0]);
857+
858+
for (int i = 0; i < nx * ny; i++) {
859+
if (i < st || i > en)
860+
ASSERT_EQ(ha[i], ha_copy[i]) << "at " << i;
861+
else
862+
ASSERT_EQ(ha[i], val) << "at " << i;
863+
}
864+
}
865+
866+
TEST(Asssign, LinearAssignSeq)
867+
{
868+
using af::array;
869+
const int nx = 5;
870+
const int ny = 4;
871+
const float val = 3;
872+
const array rhs = af::constant(val, 1, 1);
873+
874+
const int st = nx - 2;
875+
const int en = nx * (ny - 1);
876+
877+
array a = af::randu(nx, ny);
878+
af::index idx = af::seq(st, en);
879+
880+
af_array in_arr = a.get();
881+
af_index_t ii = idx.get();
882+
af_array rhs_arr = rhs.get();
883+
af_array out_arr;
884+
885+
ASSERT_EQ(AF_SUCCESS,
886+
af_assign_seq(&out_arr, in_arr, 1, &ii.idx.seq, rhs_arr));
887+
888+
af::array out(out_arr);
889+
890+
ASSERT_EQ(a.dims(0), out.dims(0));
891+
ASSERT_EQ(a.dims(1), out.dims(1));
892+
893+
std::vector<float> hout(nx * ny);
894+
std::vector<float> ha(nx * ny);
895+
896+
a.host(&ha[0]);
897+
out.host(&hout[0]);
898+
899+
for (int i = 0; i < nx * ny; i++) {
900+
if (i < st || i > en)
901+
ASSERT_EQ(hout[i], ha[i]) << "at " << i;
902+
else
903+
ASSERT_EQ(hout[i], val) << "at " << i;
904+
}
905+
}
906+
907+
TEST(Asssign, LinearAssignGenSeq)
908+
{
909+
using af::array;
910+
const int nx = 5;
911+
const int ny = 4;
912+
const float val = 3;
913+
const array rhs = af::constant(val, 1, 1);
914+
915+
const int st = nx - 2;
916+
const int en = nx * (ny - 1);
917+
918+
array a = af::randu(nx, ny);
919+
af::index idx = af::seq(st, en);
920+
921+
af_array in_arr = a.get();
922+
af_index_t ii = idx.get();
923+
af_array rhs_arr = rhs.get();
924+
af_array out_arr;
925+
926+
ASSERT_EQ(AF_SUCCESS,
927+
af_assign_gen(&out_arr, in_arr, 1, &ii, rhs_arr));
928+
929+
af::array out(out_arr);
930+
931+
ASSERT_EQ(a.dims(0), out.dims(0));
932+
ASSERT_EQ(a.dims(1), out.dims(1));
933+
934+
std::vector<float> hout(nx * ny);
935+
std::vector<float> ha(nx * ny);
936+
937+
a.host(&ha[0]);
938+
out.host(&hout[0]);
939+
940+
for (int i = 0; i < nx * ny; i++) {
941+
if (i < st || i > en)
942+
ASSERT_EQ(hout[i], ha[i]) << "at " << i;
943+
else
944+
ASSERT_EQ(hout[i], val) << "at " << i;
945+
}
946+
}
947+
948+
TEST(Asssign, LinearAssignGenArr)
949+
{
950+
using af::array;
951+
const int nx = 5;
952+
const int ny = 4;
953+
const float val = 3;
954+
const array rhs = af::constant(val, 1, 1);
955+
956+
const int st = nx - 2;
957+
const int en = nx * (ny - 1);
958+
959+
array a = af::randu(nx, ny);
960+
af::index idx = af::array(af::seq(st, en));
961+
962+
af_array in_arr = a.get();
963+
af_index_t ii = idx.get();
964+
af_array rhs_arr = rhs.get();
965+
af_array out_arr;
966+
967+
ASSERT_EQ(AF_SUCCESS,
968+
af_assign_gen(&out_arr, in_arr, 1, &ii, rhs_arr));
969+
970+
af::array out(out_arr);
971+
972+
ASSERT_EQ(a.dims(0), out.dims(0));
973+
ASSERT_EQ(a.dims(1), out.dims(1));
974+
975+
std::vector<float> hout(nx * ny);
976+
std::vector<float> ha(nx * ny);
977+
978+
a.host(&ha[0]);
979+
out.host(&hout[0]);
980+
981+
for (int i = 0; i < nx * ny; i++) {
982+
if (i < st || i > en)
983+
ASSERT_EQ(hout[i], ha[i]) << "at " << i;
984+
else
985+
ASSERT_EQ(hout[i], val) << "at " << i;
986+
}
987+
}

0 commit comments

Comments
 (0)