/******************************************************* * Copyright (c) 2016, ArrayFire * All rights reserved. * * This file is distributed under 3-clause BSD license. * The complete license agreement can be obtained at: * http://arrayfire.com/licenses/BSD-3-Clause ********************************************************/ #include #include #include #include #include #include #include #include #include using af::dim4; using common::half; using common::SparseArrayBase; using detail::cdouble; using detail::cfloat; using detail::intl; using detail::uchar; using detail::uint; using detail::uintl; using detail::ushort; af_array createHandle(const dim4 &d, af_dtype dtype) { // clang-format off switch (dtype) { case f32: return createHandle(d); case c32: return createHandle(d); case f64: return createHandle(d); case c64: return createHandle(d); case b8: return createHandle(d); case s32: return createHandle(d); case u32: return createHandle(d); case u8: return createHandle(d); case s64: return createHandle(d); case u64: return createHandle(d); case s16: return createHandle(d); case u16: return createHandle(d); case f16: return createHandle(d); default: TYPE_ERROR(3, dtype); } // clang-format on } af_array createHandleFromValue(const dim4 &d, double val, af_dtype dtype) { // clang-format off switch (dtype) { case f32: return createHandleFromValue(d, val); case c32: return createHandleFromValue(d, val); case f64: return createHandleFromValue(d, val); case c64: return createHandleFromValue(d, val); case b8: return createHandleFromValue(d, val); case s32: return createHandleFromValue(d, val); case u32: return createHandleFromValue(d, val); case u8: return createHandleFromValue(d, val); case s64: return createHandleFromValue(d, val); case u64: return createHandleFromValue(d, val); case s16: return createHandleFromValue(d, val); case u16: return createHandleFromValue(d, val); case f16: return createHandleFromValue(d, val); default: TYPE_ERROR(3, dtype); } // clang-format on } af_err af_get_data_ptr(void *data, const af_array arr) { try { af_dtype type = getInfo(arr).getType(); // clang-format off switch (type) { case f32: copyData(static_cast(data), arr); break; case c32: copyData(static_cast(data), arr); break; case f64: copyData(static_cast(data), arr); break; case c64: copyData(static_cast(data), arr); break; case b8: copyData(static_cast(data), arr); break; case s32: copyData(static_cast(data), arr); break; case u32: copyData(static_cast(data), arr); break; case u8: copyData(static_cast(data), arr); break; case s64: copyData(static_cast(data), arr); break; case u64: copyData(static_cast(data), arr); break; case s16: copyData(static_cast(data), arr); break; case u16: copyData(static_cast(data), arr); break; case f16: copyData(static_cast(data), arr); break; default: TYPE_ERROR(1, type); } // clang-format on } CATCHALL; return AF_SUCCESS; } // Strong Exception Guarantee af_err af_create_array(af_array *result, const void *const data, const unsigned ndims, const dim_t *const dims, const af_dtype type) { try { af_array out; AF_CHECK(af_init()); dim4 d = verifyDims(ndims, dims); switch (type) { case f32: out = createHandleFromData(d, static_cast(data)); break; case c32: out = createHandleFromData(d, static_cast(data)); break; case f64: out = createHandleFromData(d, static_cast(data)); break; case c64: out = createHandleFromData(d, static_cast(data)); break; case b8: out = createHandleFromData(d, static_cast(data)); break; case s32: out = createHandleFromData(d, static_cast(data)); break; case u32: out = createHandleFromData(d, static_cast(data)); break; case u8: out = createHandleFromData(d, static_cast(data)); break; case s64: out = createHandleFromData(d, static_cast(data)); break; case u64: out = createHandleFromData(d, static_cast(data)); break; case s16: out = createHandleFromData(d, static_cast(data)); break; case u16: out = createHandleFromData(d, static_cast(data)); break; case f16: out = createHandleFromData(d, static_cast(data)); break; default: TYPE_ERROR(4, type); } std::swap(*result, out); } CATCHALL return AF_SUCCESS; } // Strong Exception Guarantee af_err af_create_handle(af_array *result, const unsigned ndims, const dim_t *const dims, const af_dtype type) { try { AF_CHECK(af_init()); if (ndims > 0) { ARG_ASSERT(2, ndims > 0 && dims != NULL); } dim4 d(0); for (unsigned i = 0; i < ndims; i++) { d[i] = dims[i]; } af_array out = createHandle(d, type); std::swap(*result, out); } CATCHALL return AF_SUCCESS; } // Strong Exception Guarantee af_err af_copy_array(af_array *out, const af_array in) { try { const ArrayInfo &info = getInfo(in, false); const af_dtype type = info.getType(); af_array res = 0; if (info.isSparse()) { const SparseArrayBase sbase = getSparseArrayBase(in); if (info.ndims() == 0) { return af_create_sparse_array_from_ptr( out, info.dims()[0], info.dims()[1], 0, nullptr, nullptr, nullptr, type, sbase.getStorage(), afDevice); } switch (type) { case f32: res = copySparseArray(in); break; case f64: res = copySparseArray(in); break; case c32: res = copySparseArray(in); break; case c64: res = copySparseArray(in); break; default: TYPE_ERROR(0, type); } } else { if (info.ndims() == 0) { return af_create_handle(out, 0, nullptr, type); } switch (type) { case f32: res = copyArray(in); break; case c32: res = copyArray(in); break; case f64: res = copyArray(in); break; case c64: res = copyArray(in); break; case b8: res = copyArray(in); break; case s32: res = copyArray(in); break; case u32: res = copyArray(in); break; case u8: res = copyArray(in); break; case s64: res = copyArray(in); break; case u64: res = copyArray(in); break; case s16: res = copyArray(in); break; case u16: res = copyArray(in); break; case f16: res = copyArray(in); break; default: TYPE_ERROR(1, type); } } std::swap(*out, res); } CATCHALL return AF_SUCCESS; } // Strong Exception Guarantee af_err af_get_data_ref_count(int *use_count, const af_array in) { try { const ArrayInfo &info = getInfo(in, false, false); const af_dtype type = info.getType(); int res; switch (type) { case f32: res = getArray(in).useCount(); break; case c32: res = getArray(in).useCount(); break; case f64: res = getArray(in).useCount(); break; case c64: res = getArray(in).useCount(); break; case b8: res = getArray(in).useCount(); break; case s32: res = getArray(in).useCount(); break; case u32: res = getArray(in).useCount(); break; case u8: res = getArray(in).useCount(); break; case s64: res = getArray(in).useCount(); break; case u64: res = getArray(in).useCount(); break; case s16: res = getArray(in).useCount(); break; case u16: res = getArray(in).useCount(); break; case f16: res = getArray(in).useCount(); break; default: TYPE_ERROR(1, type); } std::swap(*use_count, res); } CATCHALL return AF_SUCCESS; } af_err af_release_array(af_array arr) { try { if (arr == 0) { return AF_SUCCESS; } const ArrayInfo &info = getInfo(arr, false, false); af_dtype type = info.getType(); if (info.isSparse()) { switch (type) { case f32: releaseSparseHandle(arr); break; case f64: releaseSparseHandle(arr); break; case c32: releaseSparseHandle(arr); break; case c64: releaseSparseHandle(arr); break; default: TYPE_ERROR(0, type); } } else { switch (type) { case f32: releaseHandle(arr); break; case c32: releaseHandle(arr); break; case f64: releaseHandle(arr); break; case c64: releaseHandle(arr); break; case b8: releaseHandle(arr); break; case s32: releaseHandle(arr); break; case u32: releaseHandle(arr); break; case u8: releaseHandle(arr); break; case s64: releaseHandle(arr); break; case u64: releaseHandle(arr); break; case s16: releaseHandle(arr); break; case u16: releaseHandle(arr); break; case f16: releaseHandle(arr); break; default: TYPE_ERROR(0, type); } } } CATCHALL return AF_SUCCESS; } af_array retain(const af_array in) { const ArrayInfo &info = getInfo(in, false, false); af_dtype ty = info.getType(); if (info.isSparse()) { switch (ty) { case f32: return retainSparseHandle(in); case f64: return retainSparseHandle(in); case c32: return retainSparseHandle(in); case c64: return retainSparseHandle(in); default: TYPE_ERROR(1, ty); } } else { switch (ty) { case f32: return retainHandle(in); case f64: return retainHandle(in); case s32: return retainHandle(in); case u32: return retainHandle(in); case u8: return retainHandle(in); case c32: return retainHandle(in); case c64: return retainHandle(in); case b8: return retainHandle(in); case s64: return retainHandle(in); case u64: return retainHandle(in); case s16: return retainHandle(in); case u16: return retainHandle(in); case f16: return retainHandle(in); default: TYPE_ERROR(1, ty); } } } af_err af_retain_array(af_array *out, const af_array in) { try { *out = retain(in); } CATCHALL; return AF_SUCCESS; } template void write_array(af_array arr, const T *const data, const size_t bytes, af_source src) { if (src == afHost) { writeHostDataArray(getArray(arr), data, bytes); } else { writeDeviceDataArray(getArray(arr), data, bytes); } } af_err af_write_array(af_array arr, const void *data, const size_t bytes, af_source src) { try { af_dtype type = getInfo(arr).getType(); // DIM_ASSERT(2, bytes <= getInfo(arr).bytes()); switch (type) { case f32: write_array(arr, static_cast(data), bytes, src); break; case c32: write_array(arr, static_cast(data), bytes, src); break; case f64: write_array(arr, static_cast(data), bytes, src); break; case c64: write_array(arr, static_cast(data), bytes, src); break; case b8: write_array(arr, static_cast(data), bytes, src); break; case s32: write_array(arr, static_cast(data), bytes, src); break; case u32: write_array(arr, static_cast(data), bytes, src); break; case u8: write_array(arr, static_cast(data), bytes, src); break; case s64: write_array(arr, static_cast(data), bytes, src); break; case u64: write_array(arr, static_cast(data), bytes, src); break; case s16: write_array(arr, static_cast(data), bytes, src); break; case u16: write_array(arr, static_cast(data), bytes, src); break; case f16: write_array(arr, static_cast(data), bytes, src); break; default: TYPE_ERROR(4, type); } } CATCHALL return AF_SUCCESS; } af_err af_get_elements(dim_t *elems, const af_array arr) { try { // Do not check for device mismatch *elems = getInfo(arr, false, false).elements(); } CATCHALL return AF_SUCCESS; } af_err af_get_type(af_dtype *type, const af_array arr) { try { // Do not check for device mismatch *type = getInfo(arr, false, false).getType(); } CATCHALL return AF_SUCCESS; } af_err af_get_dims(dim_t *d0, dim_t *d1, dim_t *d2, dim_t *d3, const af_array in) { try { // Do not check for device mismatch const ArrayInfo &info = getInfo(in, false, false); *d0 = info.dims()[0]; *d1 = info.dims()[1]; *d2 = info.dims()[2]; *d3 = info.dims()[3]; } CATCHALL return AF_SUCCESS; } af_err af_get_numdims(unsigned *nd, const af_array in) { try { // Do not check for device mismatch const ArrayInfo &info = getInfo(in, false, false); *nd = info.ndims(); } CATCHALL return AF_SUCCESS; } #undef INSTANTIATE #define INSTANTIATE(fn1, fn2) \ af_err fn1(bool *result, const af_array in) { \ try { \ const ArrayInfo &info = getInfo(in, false, false); \ *result = info.fn2(); \ } \ CATCHALL \ return AF_SUCCESS; \ } INSTANTIATE(af_is_empty, isEmpty) INSTANTIATE(af_is_scalar, isScalar) INSTANTIATE(af_is_row, isRow) INSTANTIATE(af_is_column, isColumn) INSTANTIATE(af_is_vector, isVector) INSTANTIATE(af_is_complex, isComplex) INSTANTIATE(af_is_real, isReal) INSTANTIATE(af_is_double, isDouble) INSTANTIATE(af_is_single, isSingle) INSTANTIATE(af_is_half, isHalf) INSTANTIATE(af_is_realfloating, isRealFloating) INSTANTIATE(af_is_floating, isFloating) INSTANTIATE(af_is_integer, isInteger) INSTANTIATE(af_is_bool, isBool) INSTANTIATE(af_is_sparse, isSparse) #undef INSTANTIATE template inline void getScalar(T *out, const af_array &arr) { out[0] = getScalar(getArray(arr)); } af_err af_get_scalar(void *output_value, const af_array arr) { try { ARG_ASSERT(0, (output_value != NULL)); const ArrayInfo &info = getInfo(arr); const af_dtype type = info.getType(); switch (type) { case f32: getScalar(reinterpret_cast(output_value), arr); break; case f64: getScalar(reinterpret_cast(output_value), arr); break; case b8: getScalar(reinterpret_cast(output_value), arr); break; case s32: getScalar(reinterpret_cast(output_value), arr); break; case u32: getScalar(reinterpret_cast(output_value), arr); break; case u8: getScalar(reinterpret_cast(output_value), arr); break; case s64: getScalar(reinterpret_cast(output_value), arr); break; case u64: getScalar(reinterpret_cast(output_value), arr); break; case s16: getScalar(reinterpret_cast(output_value), arr); break; case u16: getScalar(reinterpret_cast(output_value), arr); break; case c32: getScalar(reinterpret_cast(output_value), arr); break; case c64: getScalar(reinterpret_cast(output_value), arr); break; case f16: getScalar(static_cast(output_value), arr); break; default: TYPE_ERROR(4, type); } } CATCHALL; return AF_SUCCESS; }