/******************************************************* * Copyright (c) 2014, ArrayFire * All rights reserved. * * This file is distributed under 3-clause BSD license. * The complete license agreement can be obtained at: * http://arrayfire.com/licenses/BSD-3-Clause ********************************************************/ #include #include #include #include #include #include #include #include #include #include #include #include using namespace detail; using af::dim4; template static inline af_array arithOp(const af_array lhs, const af_array rhs, const dim4 &odims) { af_array res = getHandle(arithOp(castArray(lhs), castArray(rhs), odims)); return res; } template static af_err af_arith(af_array *out, const af_array lhs, const af_array rhs, bool batchMode) { try { const af_dtype otype = implicit(lhs, rhs); ArrayInfo linfo = getInfo(lhs); ArrayInfo rinfo = getInfo(rhs); dim4 odims = getOutDims(linfo.dims(), rinfo.dims(), batchMode); af_array res; switch (otype) { case f32: res = arithOp(lhs, rhs, odims); break; case f64: res = arithOp(lhs, rhs, odims); break; case c32: res = arithOp(lhs, rhs, odims); break; case c64: res = arithOp(lhs, rhs, odims); break; case s32: res = arithOp(lhs, rhs, odims); break; case u32: res = arithOp(lhs, rhs, odims); break; case u8 : res = arithOp(lhs, rhs, odims); break; case b8 : res = arithOp(lhs, rhs, odims); break; case s64: res = arithOp(lhs, rhs, odims); break; case u64: res = arithOp(lhs, rhs, odims); break; default: TYPE_ERROR(0, otype); } std::swap(*out, res); } CATCHALL; return AF_SUCCESS; } template static af_err af_arith_real(af_array *out, const af_array lhs, const af_array rhs, bool batchMode) { try { const af_dtype otype = implicit(lhs, rhs); ArrayInfo linfo = getInfo(lhs); ArrayInfo rinfo = getInfo(rhs); dim4 odims = getOutDims(linfo.dims(), rinfo.dims(), batchMode); af_array res; switch (otype) { case f32: res = arithOp(lhs, rhs, odims); break; case f64: res = arithOp(lhs, rhs, odims); break; case s32: res = arithOp(lhs, rhs, odims); break; case u32: res = arithOp(lhs, rhs, odims); break; case u8 : res = arithOp(lhs, rhs, odims); break; case b8 : res = arithOp(lhs, rhs, odims); break; case s64: res = arithOp(lhs, rhs, odims); break; case u64: res = arithOp(lhs, rhs, odims); break; default: TYPE_ERROR(0, otype); } std::swap(*out, res); } CATCHALL; return AF_SUCCESS; } af_err af_add(af_array *out, const af_array lhs, const af_array rhs, bool batchMode) { return af_arith(out, lhs, rhs, batchMode); } af_err af_mul(af_array *out, const af_array lhs, const af_array rhs, bool batchMode) { return af_arith(out, lhs, rhs, batchMode); } af_err af_sub(af_array *out, const af_array lhs, const af_array rhs, bool batchMode) { return af_arith(out, lhs, rhs, batchMode); } af_err af_div(af_array *out, const af_array lhs, const af_array rhs, bool batchMode) { return af_arith(out, lhs, rhs, batchMode); } af_err af_maxof(af_array *out, const af_array lhs, const af_array rhs, bool batchMode) { return af_arith(out, lhs, rhs, batchMode); } af_err af_minof(af_array *out, const af_array lhs, const af_array rhs, bool batchMode) { return af_arith(out, lhs, rhs, batchMode); } af_err af_rem(af_array *out, const af_array lhs, const af_array rhs, bool batchMode) { return af_arith_real(out, lhs, rhs, batchMode); } af_err af_mod(af_array *out, const af_array lhs, const af_array rhs, bool batchMode) { return af_arith_real(out, lhs, rhs, batchMode); } af_err af_pow(af_array *out, const af_array lhs, const af_array rhs, bool batchMode) { try { ArrayInfo linfo = getInfo(lhs); ArrayInfo rinfo = getInfo(rhs); if (linfo.isComplex() || rinfo.isComplex()) { AF_ERROR("Powers of Complex numbers not supported", AF_ERR_NOT_SUPPORTED); } } CATCHALL; return af_arith_real(out, lhs, rhs, batchMode); } af_err af_atan2(af_array *out, const af_array lhs, const af_array rhs, bool batchMode) { try { const af_dtype type = implicit(lhs, rhs); if (type != f32 && type != f64) { AF_ERROR("Only floating point arrays are supported for atan2 ", AF_ERR_NOT_SUPPORTED); } ArrayInfo linfo = getInfo(lhs); ArrayInfo rinfo = getInfo(rhs); dim4 odims = getOutDims(linfo.dims(), rinfo.dims(), batchMode); af_array res; switch (type) { case f32: res = arithOp(lhs, rhs, odims); break; case f64: res = arithOp(lhs, rhs, odims); break; default: TYPE_ERROR(0, type); } std::swap(*out, res); } CATCHALL; return AF_SUCCESS; } af_err af_hypot(af_array *out, const af_array lhs, const af_array rhs, bool batchMode) { try { const af_dtype type = implicit(lhs, rhs); if (type != f32 && type != f64) { AF_ERROR("Only floating point arrays are supported for hypot ", AF_ERR_NOT_SUPPORTED); } ArrayInfo linfo = getInfo(lhs); ArrayInfo rinfo = getInfo(rhs); dim4 odims = getOutDims(linfo.dims(), rinfo.dims(), batchMode); af_array res; switch (type) { case f32: res = arithOp(lhs, rhs, odims); break; case f64: res = arithOp(lhs, rhs, odims); break; default: TYPE_ERROR(0, type); } std::swap(*out, res); } CATCHALL; return AF_SUCCESS; } template static inline af_array logicOp(const af_array lhs, const af_array rhs, const dim4 &odims) { af_array res = getHandle(logicOp(getArray(lhs), getArray(rhs), odims)); return res; } template static af_err af_logic(af_array *out, const af_array lhs, const af_array rhs, bool batchMode) { try { const af_dtype type = implicit(lhs, rhs); ArrayInfo linfo = getInfo(lhs); ArrayInfo rinfo = getInfo(rhs); dim4 odims = getOutDims(linfo.dims(), rinfo.dims(), batchMode); af_array res; switch (type) { case f32: res = logicOp(lhs, rhs, odims); break; case f64: res = logicOp(lhs, rhs, odims); break; case c32: res = logicOp(lhs, rhs, odims); break; case c64: res = logicOp(lhs, rhs, odims); break; case s32: res = logicOp(lhs, rhs, odims); break; case u32: res = logicOp(lhs, rhs, odims); break; case u8 : res = logicOp(lhs, rhs, odims); break; case b8 : res = logicOp(lhs, rhs, odims); break; case s64: res = logicOp(lhs, rhs, odims); break; case u64: res = logicOp(lhs, rhs, odims); break; default: TYPE_ERROR(0, type); } std::swap(*out, res); } CATCHALL; return AF_SUCCESS; } af_err af_eq(af_array *out, const af_array lhs, const af_array rhs, bool batchMode) { return af_logic(out, lhs, rhs, batchMode); } af_err af_neq(af_array *out, const af_array lhs, const af_array rhs, bool batchMode) { return af_logic(out, lhs, rhs, batchMode); } af_err af_gt(af_array *out, const af_array lhs, const af_array rhs, bool batchMode) { return af_logic(out, lhs, rhs, batchMode); } af_err af_ge(af_array *out, const af_array lhs, const af_array rhs, bool batchMode) { return af_logic(out, lhs, rhs, batchMode); } af_err af_lt(af_array *out, const af_array lhs, const af_array rhs, bool batchMode) { return af_logic(out, lhs, rhs, batchMode); } af_err af_le(af_array *out, const af_array lhs, const af_array rhs, bool batchMode) { return af_logic(out, lhs, rhs, batchMode); } af_err af_and(af_array *out, const af_array lhs, const af_array rhs, bool batchMode) { return af_logic(out, lhs, rhs, batchMode); } af_err af_or(af_array *out, const af_array lhs, const af_array rhs, bool batchMode) { return af_logic(out, lhs, rhs, batchMode); } template static inline af_array bitOp(const af_array lhs, const af_array rhs, const dim4 &odims) { af_array res = getHandle(bitOp(getArray(lhs), getArray(rhs), odims)); return res; } template static af_err af_bitwise(af_array *out, const af_array lhs, const af_array rhs, bool batchMode) { try { const af_dtype type = implicit(lhs, rhs); ArrayInfo linfo = getInfo(lhs); ArrayInfo rinfo = getInfo(rhs); dim4 odims = getOutDims(linfo.dims(), rinfo.dims(), batchMode); af_array res; switch (type) { case s32: res = bitOp(lhs, rhs, odims); break; case u32: res = bitOp(lhs, rhs, odims); break; case u8 : res = bitOp(lhs, rhs, odims); break; case b8 : res = bitOp(lhs, rhs, odims); break; case s64: res = bitOp(lhs, rhs, odims); break; case u64: res = bitOp(lhs, rhs, odims); break; default: TYPE_ERROR(0, type); } std::swap(*out, res); } CATCHALL; return AF_SUCCESS; } af_err af_bitand(af_array *out, const af_array lhs, const af_array rhs, bool batchMode) { return af_bitwise(out, lhs, rhs, batchMode); } af_err af_bitor(af_array *out, const af_array lhs, const af_array rhs, bool batchMode) { return af_bitwise(out, lhs, rhs, batchMode); } af_err af_bitxor(af_array *out, const af_array lhs, const af_array rhs, bool batchMode) { return af_bitwise(out, lhs, rhs, batchMode); } af_err af_bitshiftl(af_array *out, const af_array lhs, const af_array rhs, bool batchMode) { return af_bitwise(out, lhs, rhs, batchMode); } af_err af_bitshiftr(af_array *out, const af_array lhs, const af_array rhs, bool batchMode) { return af_bitwise(out, lhs, rhs, batchMode); }