forked from arrayfire/arrayfire
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathcorrcoef.cpp
More file actions
97 lines (83 loc) · 3.17 KB
/
corrcoef.cpp
File metadata and controls
97 lines (83 loc) · 3.17 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
/*******************************************************
* Copyright (c) 2014, ArrayFire
* All rights reserved.
*
* This file is distributed under 3-clause BSD license.
* The complete license agreement can be obtained at:
* http://arrayfire.com/licenses/BSD-3-Clause
********************************************************/
#include <arith.hpp>
#include <backend.hpp>
#include <cast.hpp>
#include <common/err_common.hpp>
#include <handle.hpp>
#include <math.hpp>
#include <reduce.hpp>
#include <stats.h>
#include <types.hpp>
#include <af/defines.h>
#include <af/dim4.hpp>
#include <af/statistics.h>
#include <cmath>
using af::dim4;
using detail::arithOp;
using detail::Array;
using detail::cast;
using detail::createValueArray;
using detail::intl;
using detail::reduce_all;
using detail::uchar;
using detail::uint;
using detail::uintl;
using detail::ushort;
template<typename Ti, typename To>
static To corrcoef(const af_array& X, const af_array& Y) {
Array<To> xIn = cast<To>(getArray<Ti>(X));
Array<To> yIn = cast<To>(getArray<Ti>(Y));
const dim4& dims = xIn.dims();
dim_t n = xIn.elements();
To xSum = reduce_all<af_add_t, To, To>(xIn);
To ySum = reduce_all<af_add_t, To, To>(yIn);
Array<To> xSq = arithOp<To, af_mul_t>(xIn, xIn, dims);
Array<To> ySq = arithOp<To, af_mul_t>(yIn, yIn, dims);
Array<To> xy = arithOp<To, af_mul_t>(xIn, yIn, dims);
To xSqSum = reduce_all<af_add_t, To, To>(xSq);
To ySqSum = reduce_all<af_add_t, To, To>(ySq);
To xySum = reduce_all<af_add_t, To, To>(xy);
To result =
(n * xySum - xSum * ySum) / (std::sqrt(n * xSqSum - xSum * xSum) *
std::sqrt(n * ySqSum - ySum * ySum));
return result;
}
// NOLINTNEXTLINE
af_err af_corrcoef(double* realVal, double* imagVal, const af_array X,
                   const af_array Y) {
    // C entry point: validates X and Y (same type, same dims) and returns
    // their Pearson correlation coefficient through *realVal. When imagVal
    // is non-null it receives 0, since only real element types are accepted.
    //
    // Argument indices passed to ARG_ASSERT / TYPE_ERROR are 0-based
    // positions in this signature:
    //   0 = realVal, 1 = imagVal, 2 = X, 3 = Y.
    try {
        ARG_ASSERT(0, (realVal != nullptr));

        const ArrayInfo& xInfo = getInfo(X);
        const ArrayInfo& yInfo = getInfo(Y);

        const af_dtype xType = xInfo.getType();

        // X is the reference; any mismatch is reported against Y.
        ARG_ASSERT(3, (yInfo.getType() == xType));
        // Compare the complete dim4 rather than only the first ndims()
        // extents, so that no dimension can escape the check.
        ARG_ASSERT(3, (xInfo.dims() == yInfo.dims()));

        // Integer/bool inputs up to 32 bits accumulate in float. 64-bit
        // integers and f64 accumulate in double.
        double result = 0.0;
        switch (xType) {
            case f64: result = corrcoef<double, double>(X, Y); break;
            case f32: result = corrcoef<float, float>(X, Y); break;
            case s32: result = corrcoef<int, float>(X, Y); break;
            case u32: result = corrcoef<uint, float>(X, Y); break;
            case s64: result = corrcoef<intl, double>(X, Y); break;
            case u64: result = corrcoef<uintl, double>(X, Y); break;
            case s16: result = corrcoef<short, float>(X, Y); break;
            case u16: result = corrcoef<ushort, float>(X, Y); break;
            case u8: result = corrcoef<uchar, float>(X, Y); break;
            case b8: result = corrcoef<char, float>(X, Y); break;
            // TODO(umar): implement for complex types
            default: TYPE_ERROR(2, xType);
        }

        // Complex types are rejected above, so the imaginary part is zero.
        // Write it before realVal so a caller that aliases both pointers
        // still receives the real result.
        if (imagVal != nullptr) { *imagVal = 0.0; }
        *realVal = result;
    }
    CATCHALL;
    return AF_SUCCESS;
}