/******************************************************* * Copyright (c) 2014, ArrayFire * All rights reserved. * * This file is distributed under 3-clause BSD license. * The complete license agreement can be obtained at: * http://arrayfire.com/licenses/BSD-3-Clause ********************************************************/ #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include using af::dim4; using common::half; using detail::arithOp; using detail::arithOpD; using detail::cdouble; using detail::cfloat; using detail::intl; using detail::uchar; using detail::uint; using detail::uintl; using detail::ushort; template static inline af_array arithOp(const af_array lhs, const af_array rhs, const dim4 &odims) { af_array res = getHandle(arithOp(castArray(lhs), castArray(rhs), odims)); return res; } template static inline af_array sparseArithOp(const af_array lhs, const af_array rhs) { auto res = arithOp(getSparseArray(lhs), getSparseArray(rhs)); return getHandle(res); } template static inline af_array arithSparseDenseOp(const af_array lhs, const af_array rhs, const bool reverse) { if (op == af_add_t || op == af_sub_t) { return getHandle( arithOpD(castSparse(lhs), castArray(rhs), reverse)); } if (op == af_mul_t || op == af_div_t) { return getHandle( arithOp(castSparse(lhs), castArray(rhs), reverse)); } } template static af_err af_arith(af_array *out, const af_array lhs, const af_array rhs, const bool batchMode) { try { const ArrayInfo &linfo = getInfo(lhs); const ArrayInfo &rinfo = getInfo(rhs); dim4 odims = getOutDims(linfo.dims(), rinfo.dims(), batchMode); const af_dtype otype = implicit(linfo.getType(), rinfo.getType()); af_array res; switch (otype) { case f32: res = arithOp(lhs, rhs, odims); break; case f64: res = arithOp(lhs, rhs, odims); break; case c32: res = arithOp(lhs, rhs, odims); break; case c64: res = arithOp(lhs, rhs, odims); break; case s32: res = arithOp(lhs, rhs, odims); break; case u32: res = arithOp(lhs, rhs, odims); break; case u8: res = arithOp(lhs, rhs, odims); break; case b8: res = arithOp(lhs, rhs, odims); break; case s64: res = arithOp(lhs, rhs, odims); break; case u64: res = arithOp(lhs, rhs, odims); break; case s16: res = arithOp(lhs, rhs, odims); break; case u16: res = arithOp(lhs, rhs, odims); break; case f16: res = arithOp(lhs, rhs, odims); break; default: TYPE_ERROR(0, otype); } std::swap(*out, res); } CATCHALL; return AF_SUCCESS; } template static af_err af_arith_real(af_array *out, const af_array lhs, const af_array rhs, const bool batchMode) { try { const ArrayInfo &linfo = getInfo(lhs); const ArrayInfo &rinfo = getInfo(rhs); dim4 odims = getOutDims(linfo.dims(), rinfo.dims(), batchMode); const af_dtype otype = implicit(linfo.getType(), rinfo.getType()); af_array res; switch (otype) { case f32: res = arithOp(lhs, rhs, odims); break; case f64: res = arithOp(lhs, rhs, odims); break; case s32: res = arithOp(lhs, rhs, odims); break; case u32: res = arithOp(lhs, rhs, odims); break; case u8: res = arithOp(lhs, rhs, odims); break; case b8: res = arithOp(lhs, rhs, odims); break; case s64: res = arithOp(lhs, rhs, odims); break; case u64: res = arithOp(lhs, rhs, odims); break; case s16: res = arithOp(lhs, rhs, odims); break; case u16: res = arithOp(lhs, rhs, odims); break; case f16: res = arithOp(lhs, rhs, odims); break; default: TYPE_ERROR(0, otype); } std::swap(*out, res); } CATCHALL; return AF_SUCCESS; } template static af_err af_arith_sparse(af_array *out, const af_array lhs, const af_array rhs) { try { const common::SparseArrayBase linfo = getSparseArrayBase(lhs); const common::SparseArrayBase rinfo = getSparseArrayBase(rhs); ARG_ASSERT(1, (linfo.getStorage() == rinfo.getStorage())); ARG_ASSERT(1, (linfo.dims() == rinfo.dims())); ARG_ASSERT(1, (linfo.getStorage() == AF_STORAGE_CSR)); const af_dtype otype = implicit(linfo.getType(), rinfo.getType()); af_array res; switch (otype) { case f32: res = sparseArithOp(lhs, rhs); break; case f64: res = sparseArithOp(lhs, rhs); break; case c32: res = sparseArithOp(lhs, rhs); break; case c64: res = sparseArithOp(lhs, rhs); break; default: TYPE_ERROR(0, otype); } std::swap(*out, res); } CATCHALL; return AF_SUCCESS; } template static af_err af_arith_sparse_dense(af_array *out, const af_array lhs, const af_array rhs, const bool reverse = false) { try { const common::SparseArrayBase linfo = getSparseArrayBase(lhs); const ArrayInfo &rinfo = getInfo(rhs); const af_dtype otype = implicit(linfo.getType(), rinfo.getType()); af_array res; switch (otype) { case f32: res = arithSparseDenseOp(lhs, rhs, reverse); break; case f64: res = arithSparseDenseOp(lhs, rhs, reverse); break; case c32: res = arithSparseDenseOp(lhs, rhs, reverse); break; case c64: res = arithSparseDenseOp(lhs, rhs, reverse); break; default: TYPE_ERROR(0, otype); } std::swap(*out, res); } CATCHALL; return AF_SUCCESS; } af_err af_add(af_array *out, const af_array lhs, const af_array rhs, const bool batchMode) { // Check if inputs are sparse const ArrayInfo &linfo = getInfo(lhs, false, true); const ArrayInfo &rinfo = getInfo(rhs, false, true); if (linfo.isSparse() && rinfo.isSparse()) { return af_arith_sparse(out, lhs, rhs); } if (linfo.isSparse() && !rinfo.isSparse()) { return af_arith_sparse_dense(out, lhs, rhs); } if (!linfo.isSparse() && rinfo.isSparse()) { // second operand(Array) of af_arith call should be dense return af_arith_sparse_dense(out, rhs, lhs, true); } return af_arith(out, lhs, rhs, batchMode); } af_err af_mul(af_array *out, const af_array lhs, const af_array rhs, const bool batchMode) { // Check if inputs are sparse const ArrayInfo &linfo = getInfo(lhs, false, true); const ArrayInfo &rinfo = getInfo(rhs, false, true); if (linfo.isSparse() && rinfo.isSparse()) { // return af_arith_sparse(out, lhs, rhs); // MKL doesn't have mul or div support yet, hence // this is commented out although alternative cpu code exists return AF_ERR_NOT_SUPPORTED; } if (linfo.isSparse() && !rinfo.isSparse()) { return af_arith_sparse_dense(out, lhs, rhs); } if (!linfo.isSparse() && rinfo.isSparse()) { return af_arith_sparse_dense(out, rhs, lhs, true); // dense should be rhs } return af_arith(out, lhs, rhs, batchMode); } af_err af_sub(af_array *out, const af_array lhs, const af_array rhs, const bool batchMode) { // Check if inputs are sparse const ArrayInfo &linfo = getInfo(lhs, false, true); const ArrayInfo &rinfo = getInfo(rhs, false, true); if (linfo.isSparse() && rinfo.isSparse()) { return af_arith_sparse(out, lhs, rhs); } if (linfo.isSparse() && !rinfo.isSparse()) { return af_arith_sparse_dense(out, lhs, rhs); } if (!linfo.isSparse() && rinfo.isSparse()) { return af_arith_sparse_dense(out, rhs, lhs, true); // dense should be rhs } return af_arith(out, lhs, rhs, batchMode); } af_err af_div(af_array *out, const af_array lhs, const af_array rhs, const bool batchMode) { // Check if inputs are sparse const ArrayInfo &linfo = getInfo(lhs, false, true); const ArrayInfo &rinfo = getInfo(rhs, false, true); if (linfo.isSparse() && rinfo.isSparse()) { // return af_arith_sparse(out, lhs, rhs); // MKL doesn't have mul or div support yet, hence // this is commented out although alternative cpu code exists return AF_ERR_NOT_SUPPORTED; } if (linfo.isSparse() && !rinfo.isSparse()) { return af_arith_sparse_dense(out, lhs, rhs); } if (!linfo.isSparse() && rinfo.isSparse()) { // Division by sparse is currently not allowed - for convinence of // dealing with division by 0 // return af_arith_sparse_dense(out, rhs, lhs, true); // dense // should be rhs return AF_ERR_NOT_SUPPORTED; } return af_arith(out, lhs, rhs, batchMode); } af_err af_maxof(af_array *out, const af_array lhs, const af_array rhs, const bool batchMode) { return af_arith(out, lhs, rhs, batchMode); } af_err af_minof(af_array *out, const af_array lhs, const af_array rhs, const bool batchMode) { return af_arith(out, lhs, rhs, batchMode); } af_err af_rem(af_array *out, const af_array lhs, const af_array rhs, const bool batchMode) { return af_arith_real(out, lhs, rhs, batchMode); } af_err af_mod(af_array *out, const af_array lhs, const af_array rhs, const bool batchMode) { return af_arith_real(out, lhs, rhs, batchMode); } af_err af_pow(af_array *out, const af_array lhs, const af_array rhs, const bool batchMode) { try { const ArrayInfo &linfo = getInfo(lhs); const ArrayInfo &rinfo = getInfo(rhs); if (rinfo.isComplex()) { af_array log_lhs, log_res; af_array res; AF_CHECK(af_log(&log_lhs, lhs)); AF_CHECK(af_mul(&log_res, log_lhs, rhs, batchMode)); AF_CHECK(af_exp(&res, log_res)); AF_CHECK(af_release_array(log_lhs)); AF_CHECK(af_release_array(log_res)); std::swap(*out, res); return AF_SUCCESS; } if (linfo.isComplex()) { af_array mag, angle; af_array mag_res, angle_res; af_array real_res, imag_res, cplx_res; af_array res; AF_CHECK(af_abs(&mag, lhs)); AF_CHECK(af_arg(&angle, lhs)); AF_CHECK(af_pow(&mag_res, mag, rhs, batchMode)); AF_CHECK(af_mul(&angle_res, angle, rhs, batchMode)); AF_CHECK(af_cos(&real_res, angle_res)); AF_CHECK(af_sin(&imag_res, angle_res)); AF_CHECK(af_cplx2(&cplx_res, real_res, imag_res, batchMode)); AF_CHECK(af_mul(&res, mag_res, cplx_res, batchMode)); AF_CHECK(af_release_array(mag)); AF_CHECK(af_release_array(angle)); AF_CHECK(af_release_array(mag_res)); AF_CHECK(af_release_array(angle_res)); AF_CHECK(af_release_array(real_res)); AF_CHECK(af_release_array(imag_res)); AF_CHECK(af_release_array(cplx_res)); std::swap(*out, res); return AF_SUCCESS; } } CATCHALL; return af_arith_real(out, lhs, rhs, batchMode); } af_err af_root(af_array *out, const af_array lhs, const af_array rhs, const bool batchMode) { try { const ArrayInfo &linfo = getInfo(lhs); const ArrayInfo &rinfo = getInfo(rhs); if (linfo.isComplex() || rinfo.isComplex()) { af_array log_lhs, log_res; af_array res; AF_CHECK(af_log(&log_lhs, lhs)); AF_CHECK(af_div(&log_res, log_lhs, rhs, batchMode)); AF_CHECK(af_exp(&res, log_res)); std::swap(*out, res); return AF_SUCCESS; } af_array one; AF_CHECK(af_constant(&one, 1, linfo.ndims(), linfo.dims().get(), linfo.getType())); af_array inv_lhs; AF_CHECK(af_div(&inv_lhs, one, lhs, batchMode)); AF_CHECK(af_arith_real(out, rhs, inv_lhs, batchMode)); AF_CHECK(af_release_array(one)); AF_CHECK(af_release_array(inv_lhs)); } CATCHALL; return AF_SUCCESS; } af_err af_atan2(af_array *out, const af_array lhs, const af_array rhs, const bool batchMode) { try { const af_dtype type = implicit(lhs, rhs); if (type != f32 && type != f64) { AF_ERROR("Only floating point arrays are supported for atan2 ", AF_ERR_NOT_SUPPORTED); } const ArrayInfo &linfo = getInfo(lhs); const ArrayInfo &rinfo = getInfo(rhs); dim4 odims = getOutDims(linfo.dims(), rinfo.dims(), batchMode); af_array res; switch (type) { case f32: res = arithOp(lhs, rhs, odims); break; case f64: res = arithOp(lhs, rhs, odims); break; default: TYPE_ERROR(0, type); } std::swap(*out, res); } CATCHALL; return AF_SUCCESS; } af_err af_hypot(af_array *out, const af_array lhs, const af_array rhs, const bool batchMode) { try { const af_dtype type = implicit(lhs, rhs); if (type != f32 && type != f64) { AF_ERROR("Only floating point arrays are supported for hypot ", AF_ERR_NOT_SUPPORTED); } const ArrayInfo &linfo = getInfo(lhs); const ArrayInfo &rinfo = getInfo(rhs); dim4 odims = getOutDims(linfo.dims(), rinfo.dims(), batchMode); af_array res; switch (type) { case f32: res = arithOp(lhs, rhs, odims); break; case f64: res = arithOp(lhs, rhs, odims); break; default: TYPE_ERROR(0, type); } std::swap(*out, res); } CATCHALL; return AF_SUCCESS; } template static inline af_array logicOp(const af_array lhs, const af_array rhs, const dim4 &odims) { af_array res = getHandle(logicOp(castArray(lhs), castArray(rhs), odims)); return res; } template static af_err af_logic(af_array *out, const af_array lhs, const af_array rhs, const bool batchMode) { try { const af_dtype type = implicit(lhs, rhs); const ArrayInfo &linfo = getInfo(lhs); const ArrayInfo &rinfo = getInfo(rhs); dim4 odims = getOutDims(linfo.dims(), rinfo.dims(), batchMode); af_array res; switch (type) { case f32: res = logicOp(lhs, rhs, odims); break; case f64: res = logicOp(lhs, rhs, odims); break; case c32: res = logicOp(lhs, rhs, odims); break; case c64: res = logicOp(lhs, rhs, odims); break; case s32: res = logicOp(lhs, rhs, odims); break; case u32: res = logicOp(lhs, rhs, odims); break; case u8: res = logicOp(lhs, rhs, odims); break; case b8: res = logicOp(lhs, rhs, odims); break; case s64: res = logicOp(lhs, rhs, odims); break; case u64: res = logicOp(lhs, rhs, odims); break; case s16: res = logicOp(lhs, rhs, odims); break; case u16: res = logicOp(lhs, rhs, odims); break; case f16: res = logicOp(lhs, rhs, odims); break; default: TYPE_ERROR(0, type); } std::swap(*out, res); } CATCHALL; return AF_SUCCESS; } af_err af_eq(af_array *out, const af_array lhs, const af_array rhs, const bool batchMode) { return af_logic(out, lhs, rhs, batchMode); } af_err af_neq(af_array *out, const af_array lhs, const af_array rhs, const bool batchMode) { return af_logic(out, lhs, rhs, batchMode); } af_err af_gt(af_array *out, const af_array lhs, const af_array rhs, const bool batchMode) { return af_logic(out, lhs, rhs, batchMode); } af_err af_ge(af_array *out, const af_array lhs, const af_array rhs, const bool batchMode) { return af_logic(out, lhs, rhs, batchMode); } af_err af_lt(af_array *out, const af_array lhs, const af_array rhs, const bool batchMode) { return af_logic(out, lhs, rhs, batchMode); } af_err af_le(af_array *out, const af_array lhs, const af_array rhs, const bool batchMode) { return af_logic(out, lhs, rhs, batchMode); } af_err af_and(af_array *out, const af_array lhs, const af_array rhs, const bool batchMode) { return af_logic(out, lhs, rhs, batchMode); } af_err af_or(af_array *out, const af_array lhs, const af_array rhs, const bool batchMode) { return af_logic(out, lhs, rhs, batchMode); } template static inline af_array bitOp(const af_array lhs, const af_array rhs, const dim4 &odims) { af_array res = getHandle(bitOp(castArray(lhs), castArray(rhs), odims)); return res; } template static af_err af_bitwise(af_array *out, const af_array lhs, const af_array rhs, const bool batchMode) { try { const af_dtype type = implicit(lhs, rhs); const ArrayInfo &linfo = getInfo(lhs); const ArrayInfo &rinfo = getInfo(rhs); dim4 odims = getOutDims(linfo.dims(), rinfo.dims(), batchMode); if (odims.ndims() == 0) { return af_create_handle(out, 0, nullptr, type); } af_array res; switch (type) { case s32: res = bitOp(lhs, rhs, odims); break; case u32: res = bitOp(lhs, rhs, odims); break; case u8: res = bitOp(lhs, rhs, odims); break; case b8: res = bitOp(lhs, rhs, odims); break; case s64: res = bitOp(lhs, rhs, odims); break; case u64: res = bitOp(lhs, rhs, odims); break; case s16: res = bitOp(lhs, rhs, odims); break; case u16: res = bitOp(lhs, rhs, odims); break; default: TYPE_ERROR(0, type); } std::swap(*out, res); } CATCHALL; return AF_SUCCESS; } af_err af_bitand(af_array *out, const af_array lhs, const af_array rhs, const bool batchMode) { return af_bitwise(out, lhs, rhs, batchMode); } af_err af_bitor(af_array *out, const af_array lhs, const af_array rhs, const bool batchMode) { return af_bitwise(out, lhs, rhs, batchMode); } af_err af_bitxor(af_array *out, const af_array lhs, const af_array rhs, const bool batchMode) { return af_bitwise(out, lhs, rhs, batchMode); } af_err af_bitshiftl(af_array *out, const af_array lhs, const af_array rhs, const bool batchMode) { return af_bitwise(out, lhs, rhs, batchMode); } af_err af_bitshiftr(af_array *out, const af_array lhs, const af_array rhs, const bool batchMode) { return af_bitwise(out, lhs, rhs, batchMode); }