/******************************************************* * Copyright (c) 2014, ArrayFire * All rights reserved. * * This file is distributed under 3-clause BSD license. * The complete license agreement can be obtained at: * http://arrayfire.com/licenses/BSD-3-Clause ********************************************************/ #include #include #include #include #include #include #include #include using af::dim4; using common::half; using detail::Array; using detail::cdouble; using detail::cfloat; using detail::intl; using detail::uchar; using detail::uint; using detail::uintl; using detail::ushort; using std::swap; using std::vector; template static inline af_array join(const int dim, const af_array first, const af_array second) { return getHandle(join(dim, getArray(first), getArray(second))); } template static inline af_array join_many(const int dim, const unsigned n_arrays, const af_array *inputs) { vector> inputs_; inputs_.reserve(n_arrays); for (unsigned i = 0; i < n_arrays; i++) { inputs_.push_back(getArray(inputs[i])); } return getHandle(join(dim, inputs_)); } af_err af_join(af_array *out, const int dim, const af_array first, const af_array second) { try { const ArrayInfo &finfo = getInfo(first); const ArrayInfo &sinfo = getInfo(second); dim4 fdims = finfo.dims(); dim4 sdims = sinfo.dims(); ARG_ASSERT(1, dim >= 0 && dim < 4); ARG_ASSERT(2, finfo.getType() == sinfo.getType()); if (sinfo.elements() == 0) { return af_retain_array(out, first); } if (finfo.elements() == 0) { return af_retain_array(out, second); } DIM_ASSERT(2, sinfo.elements() > 0); DIM_ASSERT(3, finfo.elements() > 0); // All dimensions except join dimension must be equal // Compute output dims for (int i = 0; i < 4; i++) { if (i != dim) { DIM_ASSERT(2, fdims[i] == sdims[i]); } } af_array output; switch (finfo.getType()) { case f32: output = join(dim, first, second); break; case c32: output = join(dim, first, second); break; case f64: output = join(dim, first, second); break; case c64: output = join(dim, first, second); break; case b8: output = join(dim, first, second); break; case s32: output = join(dim, first, second); break; case u32: output = join(dim, first, second); break; case s64: output = join(dim, first, second); break; case u64: output = join(dim, first, second); break; case s16: output = join(dim, first, second); break; case u16: output = join(dim, first, second); break; case u8: output = join(dim, first, second); break; case f16: output = join(dim, first, second); break; default: TYPE_ERROR(1, finfo.getType()); } std::swap(*out, output); } CATCHALL; return AF_SUCCESS; } af_err af_join_many(af_array *out, const int dim, const unsigned n_arrays, const af_array *inputs) { try { ARG_ASSERT(3, inputs != nullptr); if (n_arrays == 1) { af_array ret = nullptr; AF_CHECK(af_retain_array(&ret, inputs[0])); std::swap(*out, ret); return AF_SUCCESS; } vector info; info.reserve(n_arrays); vector dims(n_arrays); for (unsigned i = 0; i < n_arrays; i++) { info.push_back(getInfo(inputs[i])); dims[i] = info[i].dims(); } ARG_ASSERT(1, dim >= 0 && dim < 4); for (unsigned i = 1; i < n_arrays; i++) { ARG_ASSERT(3, info[0].getType() == info[i].getType()); DIM_ASSERT(3, info[i].elements() > 0); } // All dimensions except join dimension must be equal // Compute output dims for (int i = 0; i < 4; i++) { if (i != dim) { for (unsigned j = 1; j < n_arrays; j++) { DIM_ASSERT(3, dims[0][i] == dims[j][i]); } } } af_array output; switch (info[0].getType()) { case f32: output = join_many(dim, n_arrays, inputs); break; case c32: output = join_many(dim, n_arrays, inputs); break; case f64: output = join_many(dim, n_arrays, inputs); break; case c64: output = join_many(dim, n_arrays, inputs); break; case b8: output = join_many(dim, n_arrays, inputs); break; case s32: output = join_many(dim, n_arrays, inputs); break; case u32: output = join_many(dim, n_arrays, inputs); break; case s64: output = join_many(dim, n_arrays, inputs); break; case u64: output = join_many(dim, n_arrays, inputs); break; case s16: output = join_many(dim, n_arrays, inputs); break; case u16: output = join_many(dim, n_arrays, inputs); break; case u8: output = join_many(dim, n_arrays, inputs); break; case f16: output = join_many(dim, n_arrays, inputs); break; default: TYPE_ERROR(1, info[0].getType()); } swap(*out, output); } CATCHALL; return AF_SUCCESS; }