/******************************************************* * Copyright (c) 2014, ArrayFire * All rights reserved. * * This file is distributed under 3-clause BSD license. * The complete license agreement can be obtained at: * http://arrayfire.com/licenses/BSD-3-Clause ********************************************************/ /* Implements the C entry points af_assign_seq (every index is an af_seq) and af_assign_gen (each index is an af_seq or an af_array) together with their per-type dispatch helpers. NOTE(review): in this copy the #include targets and all template parameter/argument lists are blank; restore them from the upstream file before relying on this text. */ #include #include #include #include #include #include #include #include #include #include #include #include #include using namespace detail; using std::vector; using std::swap; // From src/api/c/moddims.cpp TODO: move to header? template Array modDims(const Array& in, const af::dim4 &newDims); /* assign: writes in_ into the part of out selected by the first ndims entries of index. out.eval() is called first. If every selected dim except the last is 1 and in_ is a vector or scalar, in_ is tiled (single element) or reshaped with modDims to the selected dims, and an element-count mismatch raises AF_ERR_SIZE; otherwise each of the four dims must equal the matching dim of in_, else AF_ERR_SIZE. The destination is the sub-array returned by createSubArray(out, index_, false). */ template static void assign(Array &out, const unsigned &ndims, const af_seq *index, const Array &in_) { dim4 const outDs = out.dims(); dim4 const iDims = in_.dims(); DIM_ASSERT(0, (outDs.ndims()>=iDims.ndims())); DIM_ASSERT(0, (outDs.ndims()>=(dim_t)ndims)); out.eval(); vector index_(index, index+ndims); dim4 oDims = toDims(index_, outDs); bool is_vector = true; for (int i = 0; is_vector && i < (int)oDims.ndims() - 1; i++) { is_vector &= oDims[i] == 1; } is_vector &= in_.isVector() || in_.isScalar(); for (dim_t i = ndims; i < (int)in_.ndims(); i++) { oDims[i] = 1; } if (is_vector) { if (oDims.elements() != (dim_t)in_.elements() && in_.elements() != 1) { AF_ERROR("Size mismatch between input and output", AF_ERR_SIZE); } // If both out and in are vectors of equal elements, reshape in to out dims Array in = in_.elements() == 1 ? 
tile(in_, oDims) : modDims(in_, oDims); Array dst = createSubArray(out, index_, false); copyArray(dst, in); } else { for (int i = 0; i < 4; i++) { if (oDims[i] != iDims[i]) { AF_ERROR("Size mismatch between input and output", AF_ERR_SIZE); } } Array dst = createSubArray(out, index_, false); copyArray(dst, in_); } } /* assign_helper: looks up the element type of in_ and calls assign(out, ndims, index, getArray(in_)) for it. A c64/c32 out accepts only c64/c32 inputs; any other out accepts f64, f32, s32, u32, s64, u64, s16, u16, u8 and b8. Unlisted input types go to TYPE_ERROR(1, iType). */ template static void assign_helper(Array &out, const unsigned &ndims, const af_seq *index, const af_array &in_) { ArrayInfo iInfo = getInfo(in_); af_dtype iType = iInfo.getType(); if(out.getType() == c64 || out.getType() == c32) { switch(iType) { case c64: assign(out, ndims, index, getArray(in_)); break; case c32: assign(out, ndims, index, getArray(in_)); break; default : TYPE_ERROR(1, iType); break; } } else { switch(iType) { case f64: assign(out, ndims, index, getArray(in_)); break; case f32: assign(out, ndims, index, getArray(in_)); break; case s32: assign(out, ndims, index, getArray(in_)); break; case u32: assign(out, ndims, index, getArray(in_)); break; case s64: assign(out, ndims, index, getArray(in_)); break; case u64: assign(out, ndims, index, getArray(in_)); break; case s16: assign(out, ndims, index, getArray(in_)); break; case u16: assign(out, ndims, index, getArray(in_)); break; case u8 : assign(out, ndims, index, getArray(in_)); break; case b8 : assign(out, ndims, index, getArray(in_)); break; default : TYPE_ERROR(1, iType); break; } } } /* af_assign_seq: assigns rhs into the region of lhs described by the ndims sequences in index and stores the resulting handle in *out. Asserts that lhs and rhs are non-zero, that ndims > 0 and that every index[i].step >= 0. When ndims is 1 but lhs has a different rank, lhs is flattened, the call recurses, and the result is reshaped to the dims of lhs. Returns AF_SUCCESS on the normal path; failures pass through CATCHALL (defined elsewhere). NOTE(review): *out is compared with lhs before it is ever written, so its value on entry matters. */ af_err af_assign_seq(af_array *out, const af_array lhs, const unsigned ndims, const af_seq *index, const af_array rhs) { try { ARG_ASSERT(0, (lhs!=0)); ARG_ASSERT(1, (ndims>0)); ARG_ASSERT(3, (rhs!=0)); ArrayInfo lInfo = getInfo(lhs); if (ndims == 1 && ndims != lInfo.ndims()) { af_array tmp_in, tmp_out; AF_CHECK(af_flat(&tmp_in, lhs)); AF_CHECK(af_assign_seq(&tmp_out, tmp_in, ndims, index, rhs)); AF_CHECK(af_moddims(out, tmp_out, lInfo.ndims(), lInfo.dims().get())); AF_CHECK(af_release_array(tmp_in)); AF_CHECK(af_release_array(tmp_out)); return AF_SUCCESS; } for(dim_t i=0; i<(dim_t)ndims; ++i) { ARG_ASSERT(2, 
(index[i].step>=0)); } /* res is the handle that gets written: a copy of lhs when af_get_data_ref_count reports more than one reference, a retained reference otherwise, or lhs itself when *out already equals lhs */ af_array res = 0; if (*out != lhs) { int count = 0; AF_CHECK(af_get_data_ref_count(&count, lhs)); if (count > 1) { AF_CHECK(af_copy_array(&res, lhs)); } else { AF_CHECK(af_retain_array(&res, lhs)); } } else { res = lhs; } /* the element copy is skipped when lhs and rhs are the same handle; on any exception res is released and the exception is rethrown. NOTE(review): that release is unconditional, so it also runs when res is lhs itself, whereas af_assign_gen releases only when *out != lhs; confirm which is intended */ try { if (lhs != rhs) { ArrayInfo oInfo = getInfo(lhs); af_dtype oType = oInfo.getType(); switch(oType) { case c64: assign_helper(getWritableArray(res), ndims, index, rhs); break; case c32: assign_helper(getWritableArray(res), ndims, index, rhs); break; case f64: assign_helper(getWritableArray(res), ndims, index, rhs); break; case f32: assign_helper(getWritableArray(res), ndims, index, rhs); break; case s32: assign_helper(getWritableArray(res), ndims, index, rhs); break; case u32: assign_helper(getWritableArray(res), ndims, index, rhs); break; case s64: assign_helper(getWritableArray(res), ndims, index, rhs); break; case u64: assign_helper(getWritableArray(res), ndims, index, rhs); break; case s16: assign_helper(getWritableArray(res), ndims, index, rhs); break; case u16: assign_helper(getWritableArray(res), ndims, index, rhs); break; case u8 : assign_helper(getWritableArray(res), ndims, index, rhs); break; case b8 : assign_helper(getWritableArray(res), ndims, index, rhs); break; default : TYPE_ERROR(1, oType); break; } } } catch(...) 
{ af_release_array(res); throw; } std::swap(*out, res); } CATCHALL; return AF_SUCCESS; } /* genAssign: forwards to detail::assign(getWritableArray(out), indexs, getArray(rhs)). */ template static void genAssign(af_array& out, const af_index_t* indexs, const af_array& rhs) { detail::assign(getWritableArray(out), indexs, getArray(rhs)); } /* af_assign_gen: assigns rhs_ into lhs using ndims indices that may each be a sequence or an array, storing the resulting handle in *out. If every index is a sequence the call is redirected to af_assign_seq. ndims must be 1 or the rank of lhs, and lhs and rhs must have the same type; a single index into a higher-rank lhs is handled by flattening lhs, recursing and reshaping back to the dims of lhs. NOTE(review): in this copy the body jumps from the loop header "for (dim_t i=0; i" directly to "(output, idxrs, rhs);", so the declarations of is_vector and idxrs and the start of the type switch are not visible here; confirm against the upstream file. */ af_err af_assign_gen(af_array *out, const af_array lhs, const dim_t ndims, const af_index_t* indexs, const af_array rhs_) { af_array output = 0; af_array rhs = rhs_; // spanner is sequence index used for indexing along the // dimensions after ndims af_index_t spanner; spanner.idx.seq = af_span; spanner.isSeq = true; try { ARG_ASSERT(2, (ndims>0)); ARG_ASSERT(3, (indexs!=NULL)); /* NOTE(review): seqs has four entries but is indexed with i < ndims before any upper bound on ndims is asserted */ int track = 0; vector seqs(4, af_span); for (dim_t i = 0; i < ndims; i++) { if (indexs[i].isSeq) { track++; seqs[i] = indexs[i].idx.seq; } } if (track==(int)ndims) { // all indexs are sequences, redirecting to af_assign return af_assign_seq(out, lhs, ndims, &(seqs.front()), rhs); } ARG_ASSERT(1, (lhs!=0)); ARG_ASSERT(4, (rhs!=0)); ArrayInfo lInfo = getInfo(lhs); ArrayInfo rInfo = getInfo(rhs); dim4 lhsDims = lInfo.dims(); dim4 rhsDims = rInfo.dims(); af_dtype lhsType= lInfo.getType(); af_dtype rhsType= rInfo.getType(); ARG_ASSERT(2, (ndims == 1) || (ndims == (dim_t)lInfo.ndims())); if (ndims == 1 && ndims != (dim_t)lInfo.ndims()) { af_array tmp_in = 0, tmp_out = 0; AF_CHECK(af_flat(&tmp_in, lhs)); AF_CHECK(af_assign_gen(&tmp_out, tmp_in, ndims, indexs, rhs_)); AF_CHECK(af_moddims(out, tmp_out, lInfo.ndims(), lInfo.dims().get())); AF_CHECK(af_release_array(tmp_in)); AF_CHECK(af_release_array(tmp_out)); return AF_SUCCESS; } ARG_ASSERT(1, (lhsType==rhsType)); ARG_ASSERT(3, (rhsDims.ndims()>0)); ARG_ASSERT(1, (lhsDims.ndims()>=rhsDims.ndims())); ARG_ASSERT(2, (lhsDims.ndims()>=ndims)); if (*out != lhs) { int count = 0; AF_CHECK(af_get_data_ref_count(&count, lhs)); if (count > 1) { AF_CHECK(af_copy_array(&output, lhs)); } else { AF_CHECK(af_retain_array(&output, lhs)); } } else { output = lhs; } dim4 oDims = toDims(seqs, lhsDims); // if af_array are 
indexs along any // particular dimension, set the length of // that dimension accordingly before any checks for (dim_t i=0; i(output, idxrs, rhs); break; case f64: genAssign(output, idxrs, rhs); break; case c32: genAssign(output, idxrs, rhs); break; case f32: genAssign(output, idxrs, rhs); break; case u64: genAssign(output, idxrs, rhs); break; case u32: genAssign(output, idxrs, rhs); break; case s64: genAssign(output, idxrs, rhs); break; case s32: genAssign(output, idxrs, rhs); break; case s16: genAssign(output, idxrs, rhs); break; case u16: genAssign(output, idxrs, rhs); break; case u8: genAssign(output, idxrs, rhs); break; case b8: genAssign(output, idxrs, rhs); break; default: TYPE_ERROR(1, rhsType); } } /* NOTE(review): on this error path rhs is released only when *out != lhs, while the success path below releases it whenever is_vector is set; confirm the asymmetry is intended */ catch(...) { if (*out != lhs) { AF_CHECK(af_release_array(output)); if (is_vector) { AF_CHECK(af_release_array(rhs)); } } throw; } if (is_vector) { AF_CHECK(af_release_array(rhs)); } std::swap(*out, output); } CATCHALL; return AF_SUCCESS; }