/******************************************************* * Copyright (c) 2015, ArrayFire * All rights reserved. * * This file is distributed under 3-clause BSD license. * The complete license agreement can be obtained at: * http://arrayfire.com/licenses/BSD-3-Clause ********************************************************/ #include #include #include #include #include #include #include #include #include #include #include using namespace detail; using af::dim4; template void replace(af_array a, const af_array cond, const af_array b) { select(getCopyOnWriteArray(a), getArray(cond), getArray(a), getArray(b)); } af_err af_replace(af_array a, const af_array cond, const af_array b) { try { const ArrayInfo& ainfo = getInfo(a); const ArrayInfo& binfo = getInfo(b); const ArrayInfo& cinfo = getInfo(cond); if (cinfo.ndims() == 0) { return AF_SUCCESS; } ARG_ASSERT(2, ainfo.getType() == binfo.getType()); ARG_ASSERT(1, cinfo.getType() == b8); DIM_ASSERT(1, ainfo.ndims() >= binfo.ndims()); DIM_ASSERT(1, cinfo.ndims() == std::min(ainfo.ndims(), binfo.ndims())); dim4 adims = ainfo.dims(); dim4 bdims = binfo.dims(); dim4 cdims = cinfo.dims(); for (int i = 0; i < 4; i++) { DIM_ASSERT(1, cdims[i] == std::min(adims[i], bdims[i])); DIM_ASSERT(2, adims[i] == bdims[i] || bdims[i] == 1); } switch (ainfo.getType()) { case f32: replace(a, cond, b); break; case f64: replace(a, cond, b); break; case c32: replace(a, cond, b); break; case c64: replace(a, cond, b); break; case s32: replace(a, cond, b); break; case u32: replace(a, cond, b); break; case s64: replace(a, cond, b); break; case u64: replace(a, cond, b); break; case s16: replace(a, cond, b); break; case u16: replace(a, cond, b); break; case u8: replace(a, cond, b); break; case b8: replace(a, cond, b); break; default: TYPE_ERROR(2, ainfo.getType()); } } CATCHALL; return AF_SUCCESS; } template void replace_scalar(af_array a, const af_array cond, const double b) { select_scalar(getCopyOnWriteArray(a), getArray(cond), getArray(a), b); } 
/// C API entry point: validates `a` and `cond`, then dispatches on `a`'s
/// element type to replace_scalar<T>, which updates `a` in place using the
/// scalar `b`.
///
/// Validation performed here:
///  - `cond` must be b8;
///  - `cond` must match `a` exactly: same number of dimensions and the same
///    extent in every dimension. No broadcasting is allowed, unlike
///    af_replace.
///
/// There is no early return for an empty `cond` here, unlike af_replace.
/// An empty mask passes only if `a` has the same (empty) shape.
///
/// @param a     array modified in place.
/// @param cond  b8 mask with exactly `a`'s shape.
/// @param b     scalar replacement value.
/// @return AF_SUCCESS, or the error code produced by CATCHALL on failure.
af_err af_replace_scalar(af_array a, const af_array cond, const double b) {
    try {
        const ArrayInfo& ainfo = getInfo(a);
        const ArrayInfo& cinfo = getInfo(cond);

        ARG_ASSERT(1, cinfo.getType() == b8);
        DIM_ASSERT(1, cinfo.ndims() == ainfo.ndims());

        const dim4 adims = ainfo.dims();
        const dim4 cdims = cinfo.dims();
        for (int i = 0; i < 4; i++) { DIM_ASSERT(1, cdims[i] == adims[i]); }

        switch (ainfo.getType()) {
            case f32: replace_scalar<float>(a, cond, b); break;
            case f64: replace_scalar<double>(a, cond, b); break;
            case c32: replace_scalar<cfloat>(a, cond, b); break;
            case c64: replace_scalar<cdouble>(a, cond, b); break;
            case s32: replace_scalar<int>(a, cond, b); break;
            case u32: replace_scalar<uint>(a, cond, b); break;
            case s64: replace_scalar<intl>(a, cond, b); break;
            case u64: replace_scalar<uintl>(a, cond, b); break;
            case s16: replace_scalar<short>(a, cond, b); break;
            case u16: replace_scalar<ushort>(a, cond, b); break;
            case u8: replace_scalar<uchar>(a, cond, b); break;
            case b8: replace_scalar<char>(a, cond, b); break;
            // `b` is a plain double and carries no af_dtype. The rejected
            // type belongs to `a`, which is argument 0.
            default: TYPE_ERROR(0, ainfo.getType());
        }
    }
    CATCHALL;

    return AF_SUCCESS;
}