/******************************************************* * Copyright (c) 2014, ArrayFire * All rights reserved. * * This file is distributed under 3-clause BSD license. * The complete license agreement can be obtained at: * http://arrayfire.com/licenses/BSD-3-Clause ********************************************************/ #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include using af::dim4; using namespace detail; using namespace std; dim4 verifyDims(const unsigned ndims, const dim_t * const dims) { DIM_ASSERT(1, ndims >= 1); dim4 d(1, 1, 1, 1); for(unsigned i = 0; i < ndims; i++) { d[i] = dims[i]; DIM_ASSERT(2, dims[i] >= 1); } return d; } //Strong Exception Guarantee af_err af_constant(af_array *result, const double value, const unsigned ndims, const dim_t * const dims, const af_dtype type) { try { af_array out; AF_CHECK(af_init()); dim4 d = verifyDims(ndims, dims); switch(type) { case f32: out = createHandleFromValue(d, value); break; case c32: out = createHandleFromValue(d, value); break; case f64: out = createHandleFromValue(d, value); break; case c64: out = createHandleFromValue(d, value); break; case b8: out = createHandleFromValue(d, value); break; case s32: out = createHandleFromValue(d, value); break; case u32: out = createHandleFromValue(d, value); break; case u8: out = createHandleFromValue(d, value); break; case s64: out = createHandleFromValue(d, value); break; case u64: out = createHandleFromValue(d, value); break; case s16: out = createHandleFromValue(d, value); break; case u16: out = createHandleFromValue(d, value); break; default: TYPE_ERROR(4, type); } std::swap(*result, out); } CATCHALL return AF_SUCCESS; } template static inline af_array createCplx(dim4 dims, const Ti real, const Ti imag) { To cval = scalar(real, imag); af_array out = getHandle(createValueArray(dims, cval)); return out; } af_err af_constant_complex(af_array *result, const double real, const double imag, const unsigned ndims, const dim_t * const dims, af_dtype type) { try { af_array out; AF_CHECK(af_init()); dim4 d = verifyDims(ndims, dims); switch (type) { case c32: out = createCplx(d, real, imag); break; case c64: out = createCplx(d, real, imag); break; default: TYPE_ERROR(5, type); } std::swap(*result, out); } CATCHALL return AF_SUCCESS; } af_err af_constant_long(af_array *result, const intl val, const unsigned ndims, const dim_t * const dims) { try { af_array out; AF_CHECK(af_init()); dim4 d = verifyDims(ndims, dims); out = getHandle(createValueArray(d, val)); std::swap(*result, out); } CATCHALL; return AF_SUCCESS; } af_err af_constant_ulong(af_array *result, const uintl val, const unsigned ndims, const dim_t * const dims) { try { af_array out; AF_CHECK(af_init()); dim4 d = verifyDims(ndims, dims); out = getHandle(createValueArray(d, val)); std::swap(*result, out); } CATCHALL; return AF_SUCCESS; } template static inline af_array randn_(const af::dim4 &dims) { return getHandle(randn(dims)); } template static inline af_array randu_(const af::dim4 &dims) { return getHandle(randu(dims)); } template static inline af_array identity_(const af::dim4 &dims) { return getHandle(detail::identity(dims)); } af_err af_randu(af_array *out, const unsigned ndims, const dim_t * const dims, const af_dtype type) { try { af_array result; AF_CHECK(af_init()); dim4 d = verifyDims(ndims, dims); switch(type) { case f32: result = randu_(d); break; case c32: result = randu_(d); break; case f64: result = randu_(d); break; case c64: result = randu_(d); break; case s32: result = randu_(d); break; case u32: result = randu_(d); break; case s64: result = randu_(d); break; case u64: result = randu_(d); break; case s16: result = randu_(d); break; case u16: result = randu_(d); break; case u8: result = randu_(d); break; case b8: result = randu_(d); break; default: TYPE_ERROR(3, type); } std::swap(*out, result); } CATCHALL return AF_SUCCESS; } af_err af_randn(af_array *out, const unsigned ndims, const dim_t * const dims, const af_dtype type) { try { af_array result; AF_CHECK(af_init()); dim4 d = verifyDims(ndims, dims); switch(type) { case f32: result = randn_(d); break; case c32: result = randn_(d); break; case f64: result = randn_(d); break; case c64: result = randn_(d); break; default: TYPE_ERROR(3, type); } std::swap(*out, result); } CATCHALL return AF_SUCCESS; } af_err af_set_seed(const uintl seed) { try { setSeed(seed); } CATCHALL; return AF_SUCCESS; } af_err af_get_seed(uintl *seed) { try { *seed = getSeed(); } CATCHALL; return AF_SUCCESS; } af_err af_identity(af_array *out, const unsigned ndims, const dim_t * const dims, const af_dtype type) { try { af_array result; AF_CHECK(af_init()); dim4 d = verifyDims(ndims, dims); switch(type) { case f32: result = identity_(d); break; case c32: result = identity_(d); break; case f64: result = identity_(d); break; case c64: result = identity_(d); break; case s32: result = identity_(d); break; case u32: result = identity_(d); break; case u8: result = identity_(d); break; case u64: result = identity_(d); break; case s64: result = identity_(d); break; case u16: result = identity_(d); break; case s16: result = identity_(d); break; // Removed because of bool type. Functions implementations exist. case b8: result = identity_(d); break; default: TYPE_ERROR(3, type); } std::swap(*out, result); } CATCHALL return AF_SUCCESS; } template static inline af_array range_(const dim4& d, const int seq_dim) { return getHandle(range(d, seq_dim)); } //Strong Exception Guarantee af_err af_range(af_array *result, const unsigned ndims, const dim_t * const dims, const int seq_dim, const af_dtype type) { try { af_array out; AF_CHECK(af_init()); dim4 d = verifyDims(ndims, dims); switch(type) { case f32: out = range_(d, seq_dim); break; case f64: out = range_(d, seq_dim); break; case s32: out = range_(d, seq_dim); break; case u32: out = range_(d, seq_dim); break; case s64: out = range_(d, seq_dim); break; case u64: out = range_(d, seq_dim); break; case s16: out = range_(d, seq_dim); break; case u16: out = range_(d, seq_dim); break; case u8: out = range_(d, seq_dim); break; default: TYPE_ERROR(4, type); } std::swap(*result, out); } CATCHALL return AF_SUCCESS; } template static inline af_array iota_(const dim4 &dims, const dim4 &tile_dims) { return getHandle(iota(dims, tile_dims)); } //Strong Exception Guarantee af_err af_iota(af_array *result, const unsigned ndims, const dim_t * const dims, const unsigned t_ndims, const dim_t * const tdims, const af_dtype type) { try { af_array out; AF_CHECK(af_init()); DIM_ASSERT(1, ndims > 0 && ndims <= 4); DIM_ASSERT(3, t_ndims > 0 && t_ndims <= 4); dim4 d = verifyDims(ndims, dims); dim4 t = verifyDims(t_ndims, tdims); switch(type) { case f32: out = iota_(d, t); break; case f64: out = iota_(d, t); break; case s32: out = iota_(d, t); break; case u32: out = iota_(d, t); break; case s64: out = iota_(d, t); break; case u64: out = iota_(d, t); break; case s16: out = iota_(d, t); break; case u16: out = iota_(d, t); break; case u8: out = iota_(d, t); break; default: TYPE_ERROR(4, type); } std::swap(*result, out); } CATCHALL return AF_SUCCESS; } template static inline af_array diagCreate(const af_array in, const int num) { return getHandle(diagCreate(getArray(in), num)); } template static inline af_array diagExtract(const af_array in, const int num) { return getHandle(diagExtract(getArray(in), num)); } af_err af_diag_create(af_array *out, const af_array in, const int num) { try { ArrayInfo in_info = getInfo(in); DIM_ASSERT(1, in_info.ndims() <= 2); af_dtype type = in_info.getType(); af_array result; switch(type) { case f32: result = diagCreate(in, num); break; case c32: result = diagCreate(in, num); break; case f64: result = diagCreate(in, num); break; case c64: result = diagCreate(in, num); break; case s32: result = diagCreate(in, num); break; case u32: result = diagCreate(in, num); break; case s64: result = diagCreate(in, num); break; case u64: result = diagCreate(in, num); break; case s16: result = diagCreate(in, num); break; case u16: result = diagCreate(in, num); break; case u8: result = diagCreate(in, num); break; // Removed because of bool type. Functions implementations exist. case b8: result = diagCreate(in, num); break; default: TYPE_ERROR(1, type); } std::swap(*out, result); } CATCHALL; return AF_SUCCESS; } af_err af_diag_extract(af_array *out, const af_array in, const int num) { try { ArrayInfo in_info = getInfo(in); DIM_ASSERT(1, in_info.ndims() >= 2); af_dtype type = in_info.getType(); af_array result; switch(type) { case f32: result = diagExtract(in, num); break; case c32: result = diagExtract(in, num); break; case f64: result = diagExtract(in, num); break; case c64: result = diagExtract(in, num); break; case s32: result = diagExtract(in, num); break; case u32: result = diagExtract(in, num); break; case s64: result = diagExtract(in, num); break; case u64: result = diagExtract(in, num); break; case s16: result = diagExtract(in, num); break; case u16: result = diagExtract(in, num); break; case u8: result = diagExtract(in, num); break; // Removed because of bool type. Functions implementations exist. case b8: result = diagExtract(in, num); break; default: TYPE_ERROR(1, type); } std::swap(*out, result); } CATCHALL; return AF_SUCCESS; } template af_array triangle(const af_array in, bool is_unit_diag) { if (is_unit_diag) return getHandle(triangle(getArray(in))); else return getHandle(triangle(getArray(in))); } af_err af_lower(af_array *out, const af_array in, bool is_unit_diag) { try { af_dtype type = getInfo(in).getType(); af_array res; switch(type) { case f32: res = triangle(in, is_unit_diag); break; case f64: res = triangle(in, is_unit_diag); break; case c32: res = triangle(in, is_unit_diag); break; case c64: res = triangle(in, is_unit_diag); break; case s32: res = triangle(in, is_unit_diag); break; case u32: res = triangle(in, is_unit_diag); break; case s64: res = triangle(in, is_unit_diag); break; case u64: res = triangle(in, is_unit_diag); break; case s16: res = triangle(in, is_unit_diag); break; case u16: res = triangle(in, is_unit_diag); break; case u8 : res = triangle(in, is_unit_diag); break; case b8 : res = triangle(in, is_unit_diag); break; } std::swap(*out, res); } CATCHALL return AF_SUCCESS; } af_err af_upper(af_array *out, const af_array in, bool is_unit_diag) { try { af_dtype type = getInfo(in).getType(); af_array res; switch(type) { case f32: res = triangle(in, is_unit_diag); break; case f64: res = triangle(in, is_unit_diag); break; case c32: res = triangle(in, is_unit_diag); break; case c64: res = triangle(in, is_unit_diag); break; case s32: res = triangle(in, is_unit_diag); break; case u32: res = triangle(in, is_unit_diag); break; case s64: res = triangle(in, is_unit_diag); break; case u64: res = triangle(in, is_unit_diag); break; case s16: res = triangle(in, is_unit_diag); break; case u16: res = triangle(in, is_unit_diag); break; case u8 : res = triangle(in, is_unit_diag); break; case b8 : res = triangle(in, is_unit_diag); break; } std::swap(*out, res); } CATCHALL return AF_SUCCESS; }