/******************************************************* * Copyright (c) 2014, ArrayFire * All rights reserved. * * This file is distributed under 3-clause BSD license. * The complete license agreement can be obtained at: * http://arrayfire.com/licenses/BSD-3-Clause ********************************************************/ #include #include #include #include #include #include #include #include #include #include #include #include #include #include "stats.h" using namespace detail; template static outType stdev(const af_array& in) { Array input = cast(getArray(in)); Array meanCnst= createValueArray(input.dims(), mean(input)); Array diff = detail::arithOp(input, meanCnst, input.dims()); Array diffSq = detail::arithOp(diff, diff, diff.dims()); outType result = division(reduce_all(diffSq), input.elements()); return sqrt(result); } template static af_array stdev(const af_array& in, int dim) { Array input = cast(getArray(in)); dim4 iDims = input.dims(); Array meanArr = mean(input, dim); dim4 oDims = meanArr.dims(); /* now tile meanArr along dim and use it for variance computation */ dim4 tileDims(1); tileDims[dim] = iDims[dim]; Array tMeanArr = detail::tile(meanArr, tileDims); /* now mean array is ready */ Array diff = detail::arithOp(input, tMeanArr, tMeanArr.dims()); Array diffSq = detail::arithOp(diff, diff, diff.dims()); Array redDiff = reduce(diffSq, dim); oDims = redDiff.dims(); Array divArr = createValueArray(oDims, scalar(iDims[dim])); Array varArr = detail::arithOp(redDiff, divArr, redDiff.dims()); Array result = detail::unaryOp(varArr); return getHandle(result); } af_err af_stdev_all(double *realVal, double *imagVal, const af_array in) { try { ArrayInfo info = getInfo(in); af_dtype type = info.getType(); switch(type) { case f64: *realVal = stdev(in); break; case f32: *realVal = stdev(in); break; case s32: *realVal = stdev(in); break; case u32: *realVal = stdev(in); break; case u8: *realVal = stdev(in); break; case b8: *realVal = stdev(in); break; // TODO: FIXME: sqrt(complex) is not present in cuda/opencl backend //case c32: { // cfloat tmp = stdev(in); // *realVal = real(tmp); // *imagVal = imag(tmp); // } break; //case c64: { // cdouble tmp = stdev(in); // *realVal = real(tmp); // *imagVal = imag(tmp); // } break; default : TYPE_ERROR(1, type); } } CATCHALL; return AF_SUCCESS; } af_err af_stdev(af_array *out, const af_array in, dim_type dim) { try { ARG_ASSERT(2, (dim>=0 && dim<=3)); af_array output = 0; ArrayInfo info = getInfo(in); af_dtype type = info.getType(); switch(type) { case f64: output = stdev(in, dim); break; case f32: output = stdev(in, dim); break; case s32: output = stdev(in, dim); break; case u32: output = stdev(in, dim); break; case u8: output = stdev(in, dim); break; case b8: output = stdev(in, dim); break; // TODO: FIXME: sqrt(complex) is not present in cuda/opencl backend //case c32: output = stdev(in, dim); break; //case c64: output = stdev(in, dim); break; default : TYPE_ERROR(1, type); } std::swap(*out, output); } CATCHALL; return AF_SUCCESS; }