/******************************************************* * Copyright (c) 2014, ArrayFire * All rights reserved. * * This file is distributed under 3-clause BSD license. * The complete license agreement can be obtained at: * http://arrayfire.com/licenses/BSD-3-Clause ********************************************************/ #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include "stats.h" using af::dim4; using detail::Array; using detail::cast; using detail::cdouble; using detail::cfloat; using detail::createValueArray; using detail::division; using detail::intl; using detail::mean; using detail::reduce; using detail::reduce_all; using detail::scalar; using detail::uchar; using detail::uint; using detail::uintl; using detail::ushort; template static outType stdev(const af_array& in, const af_var_bias bias) { using weightType = typename baseOutType::type; const Array _in = getArray(in); Array input = cast(_in); Array meanCnst = createValueArray( input.dims(), mean(_in)); Array diff = detail::arithOp(input, meanCnst, input.dims()); Array diffSq = detail::arithOp(diff, diff, diff.dims()); outType result = division(reduce_all(diffSq), (input.elements() - (bias == AF_VARIANCE_SAMPLE))); return sqrt(result); } template static af_array stdev(const af_array& in, int dim, const af_var_bias bias) { using weightType = typename baseOutType::type; const Array _in = getArray(in); Array input = cast(_in); dim4 iDims = input.dims(); Array meanArr = mean(_in, dim); /* now tile meanArr along dim and use it for variance computation */ dim4 tileDims(1); tileDims[dim] = iDims[dim]; Array tMeanArr = detail::tile(meanArr, tileDims); /* now mean array is ready */ Array diff = detail::arithOp(input, tMeanArr, tMeanArr.dims()); Array diffSq = detail::arithOp(diff, diff, diff.dims()); Array redDiff = reduce(diffSq, dim); const dim4& oDims = redDiff.dims(); Array divArr = createValueArray( oDims, scalar((iDims[dim] - (bias == AF_VARIANCE_SAMPLE)))); Array varArr = detail::arithOp(redDiff, divArr, redDiff.dims()); Array result = detail::unaryOp(varArr); return getHandle(result); } // NOLINTNEXTLINE(readability-non-const-parameter) af_err af_stdev_all(double* realVal, double* imagVal, const af_array in) { return af_stdev_all_v2(realVal, imagVal, in, AF_VARIANCE_POPULATION); } af_err af_stdev_all_v2(double* realVal, double* imagVal, const af_array in, const af_var_bias bias) { UNUSED(imagVal); // TODO implement for complex values try { const ArrayInfo& info = getInfo(in); af_dtype type = info.getType(); switch (type) { case f64: *realVal = stdev(in, bias); break; case f32: *realVal = stdev(in, bias); break; case s32: *realVal = stdev(in, bias); break; case u32: *realVal = stdev(in, bias); break; case s16: *realVal = stdev(in, bias); break; case u16: *realVal = stdev(in, bias); break; case s64: *realVal = stdev(in, bias); break; case u64: *realVal = stdev(in, bias); break; case u8: *realVal = stdev(in, bias); break; case b8: *realVal = stdev(in, bias); break; // TODO(umar): FIXME: sqrt(complex) is not present in cuda/opencl // backend case c32: { // cfloat tmp = stdev(in); // *realVal = real(tmp); // *imagVal = imag(tmp); // } break; // case c64: { // cdouble tmp = stdev(in); // *realVal = real(tmp); // *imagVal = imag(tmp); // } break; default: TYPE_ERROR(1, type); } } CATCHALL; return AF_SUCCESS; } af_err af_stdev(af_array* out, const af_array in, const dim_t dim) { return af_stdev_v2(out, in, AF_VARIANCE_POPULATION, dim); } af_err af_stdev_v2(af_array* out, const af_array in, const af_var_bias bias, const dim_t dim) { try { ARG_ASSERT(2, (dim >= 0 && dim <= 3)); af_array output = 0; const ArrayInfo& info = getInfo(in); af_dtype type = info.getType(); switch (type) { case f64: output = stdev(in, dim, bias); break; case f32: output = stdev(in, dim, bias); break; case s32: output = stdev(in, dim, bias); break; case u32: output = stdev(in, dim, bias); break; case s16: output = stdev(in, dim, bias); break; case u16: output = stdev(in, dim, bias); break; case s64: output = stdev(in, dim, bias); break; case u64: output = stdev(in, dim, bias); break; case u8: output = stdev(in, dim, bias); break; case b8: output = stdev(in, dim, bias); break; // TODO(umar): FIXME: sqrt(complex) is not present in cuda/opencl // backend case c32: output = stdev(in, dim); // break; case c64: output = stdev(in, dim); break; default: TYPE_ERROR(1, type); } std::swap(*out, output); } CATCHALL; return AF_SUCCESS; }