1919#include < copy.hpp>
2020#include < assign.hpp>
2121#include < math.hpp>
22+ #include < tile.hpp>
2223
2324using namespace detail ;
2425using std::vector;
@@ -49,13 +50,21 @@ void assign(Array<Tout> &out, const unsigned &ndims, const af_seq *index, const
4950 is_vector &= oDims[i] == 1 ;
5051 }
5152
52- if (is_vector && in_.isVector ()) {
53- if (oDims.elements () != (dim_t )in_.elements ()) {
53+ is_vector &= in_.isVector () || in_.isScalar ();
54+
55+ for (dim_t i = ndims; i < (int )in_.ndims (); i++) {
56+ oDims[i] = 1 ;
57+ }
58+
59+
60+ if (is_vector) {
61+ if (oDims.elements () != (dim_t )in_.elements () &&
62+ in_.elements () != 1 ) {
5463 AF_ERROR (" Size mismatch between input and output" , AF_ERR_SIZE);
5564 }
5665
5766 // If both out and in are vectors of equal elements, reshape in to out dims
58- Array<Tin> in = modDims (in_, oDims);
67+ Array<Tin> in = in_. elements () == 1 ? tile (in_, oDims) : modDims (in_, oDims);
5968 Array<Tout> dst = createSubArray<Tout>(out, index_, false );
6069
6170 copyArray<Tin , Tout>(dst, in);
@@ -112,6 +121,18 @@ af_err af_assign_seq(af_array *out,
112121 ARG_ASSERT (1 , (ndims>0 ));
113122 ARG_ASSERT (3 , (rhs!=0 ));
114123
124+ ArrayInfo lInfo = getInfo (lhs);
125+
126+ if (ndims == 1 && ndims != (dim_t )lInfo.ndims ()) {
127+ af_array tmp_in, tmp_out;
128+ AF_CHECK (af_flat (&tmp_in, lhs));
129+ AF_CHECK (af_assign_seq (&tmp_out, tmp_in, ndims, index, rhs));
130+ AF_CHECK (af_moddims (out, tmp_out, lInfo.ndims (), lInfo.dims ().get ()));
131+ AF_CHECK (af_release_array (tmp_in));
132+ AF_CHECK (af_release_array (tmp_out));
133+ return AF_SUCCESS;
134+ }
135+
115136 for (dim_t i=0 ; i<(dim_t )ndims; ++i) {
116137 ARG_ASSERT (2 , (index[i].step >=0 ));
117138 }
@@ -200,30 +221,42 @@ af_err af_assign_gen(af_array *out,
200221 ARG_ASSERT (1 , (lhs!=0 ));
201222 ARG_ASSERT (4 , (rhs!=0 ));
202223
203- if (*out != lhs) {
204- int count = 0 ;
205- AF_CHECK (af_get_data_ref_count (&count, lhs));
206- if (count > 1 ) {
207- AF_CHECK (af_copy_array (&output, lhs));
208- } else {
209- AF_CHECK (af_retain_array (&output, lhs));
210- }
211- } else {
212- output = lhs;
213- }
214-
215224 ArrayInfo lInfo = getInfo (lhs);
216225 ArrayInfo rInfo = getInfo (rhs);
217226 dim4 lhsDims = lInfo.dims ();
218227 dim4 rhsDims = rInfo.dims ();
219228 af_dtype lhsType= lInfo.getType ();
220229 af_dtype rhsType= rInfo.getType ();
221230
231+ ARG_ASSERT (2 , (ndims == 1 ) || (ndims == (dim_t )lInfo.ndims ()));
232+
233+ if (ndims == 1 && ndims != (dim_t )lInfo.ndims ()) {
234+ af_array tmp_in, tmp_out;
235+ AF_CHECK (af_flat (&tmp_in, lhs));
236+ AF_CHECK (af_assign_gen (&tmp_out, tmp_in, ndims, indexs, rhs_));
237+ AF_CHECK (af_moddims (out, tmp_out, lInfo.ndims (), lInfo.dims ().get ()));
238+ AF_CHECK (af_release_array (tmp_in));
239+ AF_CHECK (af_release_array (tmp_out));
240+ return AF_SUCCESS;
241+ }
242+
222243 ARG_ASSERT (1 , (lhsType==rhsType));
223244 ARG_ASSERT (3 , (rhsDims.ndims ()>0 ));
224245 ARG_ASSERT (1 , (lhsDims.ndims ()>=rhsDims.ndims ()));
225246 ARG_ASSERT (2 , (lhsDims.ndims ()>=ndims));
226247
248+ if (*out != lhs) {
249+ int count = 0 ;
250+ AF_CHECK (af_get_data_ref_count (&count, lhs));
251+ if (count > 1 ) {
252+ AF_CHECK (af_copy_array (&output, lhs));
253+ } else {
254+ AF_CHECK (af_retain_array (&output, lhs));
255+ }
256+ } else {
257+ output = lhs;
258+ }
259+
227260 dim4 oDims = toDims (seqs, lhsDims);
228261 // if af_array are indexs along any
229262 // particular dimension, set the length of
@@ -234,20 +267,29 @@ af_err af_assign_gen(af_array *out,
234267 }
235268 }
236269
270+ for (dim_t i = ndims; i < (dim_t )lInfo.ndims (); i++) {
271+ oDims[i] = 1 ;
272+ }
273+
237274 bool is_vector = true ;
238275 for (int i = 0 ; is_vector && i < oDims.ndims () - 1 ; i++) {
239276 is_vector &= oDims[i] == 1 ;
240277 }
241278
242279 // TODO: Move logic out of this
243- is_vector &= rInfo.isVector ();
280+ is_vector &= rInfo.isVector () || rInfo. isScalar () ;
244281 if (is_vector) {
245- if (oDims.elements () != (dim_t )rInfo.elements ()) {
282+ if (oDims.elements () != (dim_t )rInfo.elements () &&
283+ rInfo.elements () != 1 ) {
246284 AF_ERROR (" Size mismatch between input and output" , AF_ERR_SIZE);
247285 }
248286
249- // If both out and rhs are vectors of equal elements, reshape rhs to out dims
250- AF_CHECK (af_moddims (&rhs, rhs_, oDims.ndims (), oDims.get ()));
287+ if (rInfo.elements () == 1 ) {
288+ AF_CHECK (af_tile (&rhs, rhs_, oDims[0 ], oDims[1 ], oDims[2 ], oDims[3 ]));
289+ } else {
290+ // If both out and rhs are vectors of equal elements, reshape rhs to out dims
291+ AF_CHECK (af_moddims (&rhs, rhs_, oDims.ndims (), oDims.get ()));
292+ }
251293 } else {
252294 for (int i = 0 ; i < 4 ; i++) {
253295 if (oDims[i] != rhsDims[i]) {
0 commit comments