/******************************************************* * Copyright (c) 2015, ArrayFire * All rights reserved. * * This file is distributed under 3-clause BSD license. * The complete license agreement can be obtained at: * http://arrayfire.com/licenses/BSD-3-Clause ********************************************************/ #include #include #include #include #include #include #include #include #include #include #include #include using af::dim4; using common::half; using detail::Array; using detail::cdouble; using detail::cfloat; using detail::createSelectNode; using detail::intl; using detail::uchar; using detail::uint; using detail::uintl; using detail::ushort; template af_array select(const af_array cond, const af_array a, const af_array b, const dim4& odims) { Array out = createSelectNode(getArray(cond), getArray(a), getArray(b), odims); return getHandle(out); } af_err af_select(af_array* out, const af_array cond, const af_array a, const af_array b) { try { const ArrayInfo& ainfo = getInfo(a); const ArrayInfo& binfo = getInfo(b); const ArrayInfo& cond_info = getInfo(cond); if (cond_info.ndims() == 0) { return af_retain_array(out, cond); } ARG_ASSERT(2, ainfo.getType() == binfo.getType()); ARG_ASSERT(1, cond_info.getType() == b8); dim4 adims = ainfo.dims(); dim4 bdims = binfo.dims(); dim4 cond_dims = cond_info.dims(); dim4 odims(1, 1, 1, 1); for (int i = 0; i < 4; i++) { DIM_ASSERT(2, (adims[i] == bdims[i] && adims[i] == cond_dims[i]) || adims[i] == 1 || bdims[i] == 1 || cond_dims[i] == 1); odims[i] = std::max(std::max(adims[i], bdims[i]), cond_dims[i]); } af_array res; switch (ainfo.getType()) { case f32: res = select(cond, a, b, odims); break; case f64: res = select(cond, a, b, odims); break; case c32: res = select(cond, a, b, odims); break; case c64: res = select(cond, a, b, odims); break; case s32: res = select(cond, a, b, odims); break; case u32: res = select(cond, a, b, odims); break; case s64: res = select(cond, a, b, odims); break; case u64: res = select(cond, a, b, odims); break; case s16: res = select(cond, a, b, odims); break; case u16: res = select(cond, a, b, odims); break; case u8: res = select(cond, a, b, odims); break; case b8: res = select(cond, a, b, odims); break; case f16: res = select(cond, a, b, odims); break; default: TYPE_ERROR(2, ainfo.getType()); } std::swap(*out, res); } CATCHALL; return AF_SUCCESS; } template af_array select_scalar(const af_array cond, const af_array a, const double b, const dim4& odims) { Array out = createSelectNode(getArray(cond), getArray(a), b, odims); return getHandle(out); } af_err af_select_scalar_r(af_array* out, const af_array cond, const af_array a, const double b) { try { const ArrayInfo& ainfo = getInfo(a); const ArrayInfo& cinfo = getInfo(cond); ARG_ASSERT(1, cinfo.getType() == b8); dim4 adims = ainfo.dims(); dim4 cond_dims = cinfo.dims(); dim4 odims(1); for (int i = 0; i < 4; i++) { DIM_ASSERT(1, cond_dims[i] == adims[i] || cond_dims[i] == 1 || adims[i] == 1); odims[i] = std::max(cond_dims[i], adims[i]); } af_array res; switch (ainfo.getType()) { case f16: res = select_scalar(cond, a, b, odims); break; case f32: res = select_scalar(cond, a, b, odims); break; case f64: res = select_scalar(cond, a, b, odims); break; case c32: res = select_scalar(cond, a, b, odims); break; case c64: res = select_scalar(cond, a, b, odims); break; case s32: res = select_scalar(cond, a, b, odims); break; case u32: res = select_scalar(cond, a, b, odims); break; case s16: res = select_scalar(cond, a, b, odims); break; case u16: res = select_scalar(cond, a, b, odims); break; case s64: res = select_scalar(cond, a, b, odims); break; case u64: res = select_scalar(cond, a, b, odims); break; case u8: res = select_scalar(cond, a, b, odims); break; case b8: res = select_scalar(cond, a, b, odims); break; default: TYPE_ERROR(2, ainfo.getType()); } std::swap(*out, res); } CATCHALL; return AF_SUCCESS; } af_err af_select_scalar_l(af_array* out, const af_array cond, const double a, const af_array b) { try { const ArrayInfo& binfo = getInfo(b); const ArrayInfo& cinfo = getInfo(cond); ARG_ASSERT(1, cinfo.getType() == b8); dim4 bdims = binfo.dims(); dim4 cond_dims = cinfo.dims(); dim4 odims(1); for (int i = 0; i < 4; i++) { DIM_ASSERT(1, cond_dims[i] == bdims[i] || cond_dims[i] == 1 || bdims[i] == 1); odims[i] = std::max(cond_dims[i], bdims[i]); } af_array res; switch (binfo.getType()) { case f16: res = select_scalar(cond, b, a, odims); break; case f32: res = select_scalar(cond, b, a, odims); break; case f64: res = select_scalar(cond, b, a, odims); break; case c32: res = select_scalar(cond, b, a, odims); break; case c64: res = select_scalar(cond, b, a, odims); break; case s32: res = select_scalar(cond, b, a, odims); break; case u32: res = select_scalar(cond, b, a, odims); break; case s16: res = select_scalar(cond, b, a, odims); break; case u16: res = select_scalar(cond, b, a, odims); break; case s64: res = select_scalar(cond, b, a, odims); break; case u64: res = select_scalar(cond, b, a, odims); break; case u8: res = select_scalar(cond, b, a, odims); break; case b8: res = select_scalar(cond, b, a, odims); break; default: TYPE_ERROR(2, binfo.getType()); } std::swap(*out, res); } CATCHALL; return AF_SUCCESS; }