Skip to content

Commit f7af337

Browse files
ShadyBoukharyumar456
authored andcommitted
Combined non-complex and complex statistics functions.
1 parent c64a418 commit f7af337

File tree

4 files changed

+64
-175
lines changed

4 files changed

+64
-175
lines changed

com/arrayfire/Statistics.java

Lines changed: 36 additions & 112 deletions
Original file line numberDiff line numberDiff line change
@@ -5,41 +5,21 @@ public class Statistics extends ArrayFire {
55

66
static private native long afMeanWeighted(long ref, long weightsRef, int dim);
77

8-
static private native double afMeanAll(long ref);
8+
static private native DoubleComplex afMeanAll(long ref);
99

10-
static private native double afMeanAllWeighted(long ref, long weightsRef);
11-
12-
static private native FloatComplex afMeanAllFloatComplex(long ref);
13-
14-
static private native DoubleComplex afMeanAllDoubleComplex(long ref);
15-
16-
static private native FloatComplex afMeanAllFloatComplexWeighted(long ref, long weightsRef);
17-
18-
static private native DoubleComplex afMeanAllDoubleComplexWeighted(long ref, long weightsRef);
10+
static private native DoubleComplex afMeanAllWeighted(long ref, long weightsRef);
1911

2012
static private native long afVar(long ref, boolean isBiased, int dim);
2113

2214
static private native long afVarWeighted(long ref, long weightsRef, int dim);
2315

24-
static private native double afVarAll(long ref, boolean isBiased);
25-
26-
static private native double afVarAllWeighted(long ref, long weightsRef);
27-
28-
static private native FloatComplex afVarAllFloatComplex(long ref, boolean isBiased);
16+
static private native DoubleComplex afVarAll(long ref, boolean isBiased);
2917

30-
static private native DoubleComplex afVarAllDoubleComplex(long ref, boolean isBiased);
31-
32-
static private native FloatComplex afVarAllFloatComplexWeighted(long ref, long weightsRef);
33-
34-
static private native DoubleComplex afVarAllDoubleComplexWeighted(long ref, long weightsRef);
18+
static private native DoubleComplex afVarAllWeighted(long ref, long weightsRef);
3519

3620
static private native long afStdev(long ref, int dim);
3721

38-
static private native double afStdevAll(long ref);
39-
40-
static private native FloatComplex afStdevAllFloatComplex(long ref);
41-
42-
static private native DoubleComplex afStdevAllDoubleComplex(long ref);
22+
static private native DoubleComplex afStdevAll(long ref);
4323

4424
static public Array mean(final Array in, int dim) {
4525
return new Array(afMean(in.ref, dim));
@@ -50,43 +30,13 @@ static public Array mean(final Array in, final Array weights, int dim) {
5030
}
5131

5232
static public <T> T mean(final Array in, Class<T> type) throws Exception {
53-
if (type == FloatComplex.class) {
54-
FloatComplex res = (FloatComplex) afMeanAllFloatComplex(in.ref);
55-
return type.cast(res);
56-
} else if (type == DoubleComplex.class) {
57-
DoubleComplex res = (DoubleComplex) afMeanAllDoubleComplex(in.ref);
58-
return type.cast(res);
59-
}
60-
61-
double res = afMeanAll(in.ref);
62-
if (type == Float.class) {
63-
return type.cast(Float.valueOf((float) res));
64-
} else if (type == Double.class) {
65-
return type.cast(Double.valueOf((double) res));
66-
} else if (type == Integer.class) {
67-
return type.cast(Integer.valueOf((int) res));
68-
}
69-
throw new Exception("Unknown type");
33+
DoubleComplex res = afMeanAll(in.ref);
34+
return castResult(res, type);
7035
}
7136

7237
static public <T> T mean(final Array in, final Array weights, Class<T> type) throws Exception {
73-
if (type == FloatComplex.class) {
74-
FloatComplex res = (FloatComplex) afMeanAllFloatComplexWeighted(in.ref, weights.ref);
75-
return type.cast(res);
76-
} else if (type == DoubleComplex.class) {
77-
DoubleComplex res = (DoubleComplex) afMeanAllDoubleComplexWeighted(in.ref, weights.ref);
78-
return type.cast(res);
79-
}
80-
81-
double res = afMeanAllWeighted(in.ref, weights.ref);
82-
if (type == Float.class) {
83-
return type.cast(Float.valueOf((float) res));
84-
} else if (type == Double.class) {
85-
return type.cast(Double.valueOf((double) res));
86-
} else if (type == Integer.class) {
87-
return type.cast(Integer.valueOf((int) res));
88-
}
89-
throw new Exception("Unknown type");
38+
DoubleComplex res = afMeanAllWeighted(in.ref, weights.ref);
39+
return castResult(res, type);
9040
}
9141

9242
static public Array var(final Array in, boolean isBiased, int dim) {
@@ -98,66 +48,40 @@ static public Array var(final Array in, final Array weights, int dim) {
9848
}
9949

10050
static public <T> T var(final Array in, boolean isBiased, Class<T> type) throws Exception {
101-
if (type == FloatComplex.class) {
102-
FloatComplex res = (FloatComplex) afVarAllFloatComplex(in.ref, isBiased);
103-
return type.cast(res);
104-
} else if (type == DoubleComplex.class) {
105-
DoubleComplex res = (DoubleComplex) afVarAllDoubleComplex(in.ref, isBiased);
106-
return type.cast(res);
107-
}
108-
109-
double res = afVarAll(in.ref, isBiased);
110-
if (type == Float.class) {
111-
return type.cast(Float.valueOf((float) res));
112-
} else if (type == Double.class) {
113-
return type.cast(Double.valueOf((double) res));
114-
} else if (type == Integer.class) {
115-
return type.cast(Integer.valueOf((int) res));
116-
}
117-
throw new Exception("Unknown type");
51+
DoubleComplex res = afVarAll(in.ref, isBiased);
52+
return castResult(res, type);
11853
}
11954

12055
static public <T> T var(final Array in, final Array weights, Class<T> type) throws Exception {
121-
if (type == FloatComplex.class) {
122-
FloatComplex res = (FloatComplex) afVarAllFloatComplexWeighted(in.ref, weights.ref);
123-
return type.cast(res);
124-
} else if (type == DoubleComplex.class) {
125-
DoubleComplex res = (DoubleComplex) afVarAllDoubleComplexWeighted(in.ref, weights.ref);
126-
return type.cast(res);
127-
}
56+
DoubleComplex res = afVarAllWeighted(in.ref, weights.ref);
57+
return castResult(res, type);
58+
}
59+
60+
static public Array stdev(final Array in, int dim) {
61+
return new Array(afStdev(in.ref, dim));
62+
}
12863

129-
double res = afVarAllWeighted(in.ref, weights.ref);
64+
static public <T> T stdev(final Array in, Class<T> type) throws Exception {
65+
DoubleComplex res = afStdevAll(in.ref);
66+
return castResult(res, type);
67+
}
68+
69+
static public <T> T castResult(DoubleComplex res, Class<T> type) throws Exception {
70+
Object ret;
13071
if (type == Float.class) {
131-
return type.cast(Float.valueOf((float) res));
72+
ret = Float.valueOf((float) res.real());
13273
} else if (type == Double.class) {
133-
return type.cast(Double.valueOf((double) res));
74+
ret = Double.valueOf((double) res.real());
13475
} else if (type == Integer.class) {
135-
return type.cast(Integer.valueOf((int) res));
136-
}
137-
throw new Exception("Unknown type");
138-
}
139-
140-
static public Array stdev(final Array in, int dim) {
141-
return new Array(afStdev(in.ref, dim));
76+
ret = Integer.valueOf((int) res.real());
77+
} else if (type == FloatComplex.class) {
78+
ret = new FloatComplex((float) res.real(), (float) res.imag());
79+
} else if (type == DoubleComplex.class) {
80+
ret = res;
81+
} else {
82+
throw new Exception("Unknown type");
14283
}
14384

144-
static public <T> T stdev(final Array in, Class<T> type) throws Exception {
145-
if (type == FloatComplex.class) {
146-
FloatComplex res = (FloatComplex)afStdevAllFloatComplex(in.ref);
147-
return type.cast(res);
148-
} else if (type == DoubleComplex.class) {
149-
DoubleComplex res = (DoubleComplex)afStdevAllDoubleComplex(in.ref);
150-
return type.cast(res);
151-
}
152-
153-
double res = afStdevAll(in.ref);
154-
if (type == Float.class) {
155-
return type.cast(Float.valueOf((float) res));
156-
} else if (type == Double.class) {
157-
return type.cast(Double.valueOf((double) res));
158-
} else if (type == Integer.class) {
159-
return type.cast(Integer.valueOf((int) res));
160-
}
161-
throw new Exception("Unknown type");
162-
}
85+
return type.cast(ret);
86+
}
16387
}

examples/HelloWorld.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@ public static void main(String[] args) {
2828
System.out.println("Calculate weighted variance.");
2929
Array forVar = new Array();
3030
Array weights = new Array();
31-
Data.randn(forVar, new int[] { 5, 3 }, Array.DoubleType);
32-
Data.randn(weights, new int[] { 5, 3 }, Array.DoubleType);
31+
Data.randn(forVar, new int[] { 5, 5 }, Array.DoubleType);
32+
Data.randn(weights, new int[] { 5, 5 }, Array.DoubleType);
3333
System.out.println(forVar.toString("forVar"));
3434

3535
double abc = Statistics.var(forVar, weights, Double.class);
@@ -40,7 +40,7 @@ public static void main(String[] args) {
4040
System.out.println("Calculate standard deviation");
4141
Array forStdev = new Array();
4242
Data.randu(forStdev, new int[] {5, 3}, Array.DoubleType);
43-
System.out.println(forStdev.toString("forVar"));
43+
System.out.println(forStdev.toString("forStdev"));
4444
double stdev = Statistics.stdev(forStdev, Double.class);
4545

4646
System.out.println(String.format("Stdev is: %f", stdev));

src/java/java.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -76,19 +76,19 @@ template <typename... Args>
7676
jobject createJavaObject(JNIEnv *env, JavaObjects objectType, Args... args) {
7777
switch (objectType) {
7878
case JavaObjects::FloatComplex: {
79-
static jclass cls = env->FindClass("com/arrayfire/FloatComplex");
80-
static std::string sig = generateFunctionSignature(
79+
jclass cls = env->FindClass("com/arrayfire/FloatComplex");
80+
std::string sig = generateFunctionSignature(
8181
JavaType::Void, {JavaType::Float, JavaType::Float});
82-
static jmethodID id = env->GetMethodID(cls, "<init>", sig.c_str());
82+
jmethodID id = env->GetMethodID(cls, "<init>", sig.c_str());
8383
jobject obj = env->NewObject(cls, id, args...);
8484
return obj;
8585

8686
} break;
8787
case JavaObjects::DoubleComplex: {
88-
static jclass cls = env->FindClass("com/arrayfire/DoubleComplex");
89-
static std::string sig = generateFunctionSignature(
88+
jclass cls = env->FindClass("com/arrayfire/DoubleComplex");
89+
std::string sig = generateFunctionSignature(
9090
JavaType::Void, {JavaType::Double, JavaType::Double});
91-
static jmethodID id = env->GetMethodID(cls, "<init>", sig.c_str());
91+
jmethodID id = env->GetMethodID(cls, "<init>", sig.c_str());
9292
jobject obj = env->NewObject(cls, id, args...);
9393
return obj;
9494
} break;

src/statistics.cpp

Lines changed: 19 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -5,22 +5,14 @@ BEGIN_EXTERN_C
55

66
#define STATISTICS_FUNC(FUNC) AF_MANGLE(Statistics, FUNC)
77

8-
#define INSTANTIATE_STAT_WEIGHTED_COMPLEX(jtype, Name, name) \
9-
JNIEXPORT jobject JNICALL STATISTICS_FUNC(af##Name##All##jtype##Weighted)( \
8+
#define INSTANTIATE_STAT_ALL_WEIGHTED(Name, name) \
9+
JNIEXPORT jobject JNICALL STATISTICS_FUNC(af##Name##AllWeighted)( \
1010
JNIEnv * env, jclass clazz, jlong ref, jlong weightsRef) { \
1111
double real = 0, img = 0; \
1212
AF_CHECK( \
1313
af_##name##_all_weighted(&real, &img, ARRAY(ref), ARRAY(weightsRef))); \
14-
return java::createJavaObject(env, java::JavaObjects::jtype, real, img); \
15-
}
16-
17-
#define INSTANTIATE_STAT_ALL_WEIGHTED(Name, name) \
18-
JNIEXPORT jdouble JNICALL STATISTICS_FUNC(af##Name##AllWeighted)( \
19-
JNIEnv * env, jclass clazz, jlong ref, jlong weightsRef) { \
20-
double ret = 0; \
21-
AF_CHECK( \
22-
af_##name##_all_weighted(&ret, NULL, ARRAY(ref), ARRAY(weightsRef))); \
23-
return (jdouble)ret; \
14+
return java::createJavaObject(env, java::JavaObjects::DoubleComplex, real, \
15+
img); \
2416
}
2517

2618
#define INSTANTIATE_STAT_WEIGHTED(Name, name) \
@@ -31,14 +23,6 @@ BEGIN_EXTERN_C
3123
return JLONG(ret); \
3224
}
3325

34-
#define INSTANTIATE_VAR(jtype) \
35-
JNIEXPORT jobject JNICALL STATISTICS_FUNC(afVarAll##jtype)( \
36-
JNIEnv * env, jclass clazz, jlong ref, jboolean isBiased) { \
37-
double real = 0, img = 0; \
38-
AF_CHECK(af_var_all(&real, &img, ARRAY(ref), isBiased)); \
39-
return java::createJavaObject(env, java::JavaObjects::jtype, real, img); \
40-
}
41-
4226
#define INSTANTIATE_STAT(Name, name) \
4327
JNIEXPORT jlong JNICALL STATISTICS_FUNC(af##Name)( \
4428
JNIEnv * env, jclass clazz, jlong ref, jint dim) { \
@@ -47,32 +31,22 @@ BEGIN_EXTERN_C
4731
return JLONG(ret); \
4832
}
4933

50-
#define INSTANTIATE_STAT_ALL(Name, name) \
51-
JNIEXPORT jdouble JNICALL STATISTICS_FUNC(af##Name##All)( \
52-
JNIEnv * env, jclass clazz, jlong ref) { \
53-
double ret = 0; \
54-
AF_CHECK(af_##name##_all(&ret, NULL, ARRAY(ref))); \
55-
return (jdouble)ret; \
56-
}
57-
58-
#define INSTANTIATE_STAT_ALL_COMPLEX(Name, name, jtype) \
59-
JNIEXPORT jobject JNICALL STATISTICS_FUNC(af##Name##All##jtype)( \
60-
JNIEnv * env, jclass clazz, jlong ref) { \
61-
double real = 0, img = 0; \
62-
AF_CHECK(af_##name##_all(&real, &img, ARRAY(ref))); \
63-
return java::createJavaObject(env, java::JavaObjects::jtype, real, img); \
34+
#define INSTANTIATE_STAT_ALL(Name, name) \
35+
JNIEXPORT jobject JNICALL STATISTICS_FUNC(af##Name##All)( \
36+
JNIEnv * env, jclass clazz, jlong ref) { \
37+
double real = 0, img = 0; \
38+
AF_CHECK(af_##name##_all(&real, &img, ARRAY(ref))); \
39+
return java::createJavaObject(env, java::JavaObjects::DoubleComplex, real, \
40+
img); \
6441
}
6542

6643
// Mean
6744
INSTANTIATE_STAT(Mean, mean)
6845
INSTANTIATE_STAT_ALL(Mean, mean)
69-
INSTANTIATE_STAT_ALL_COMPLEX(Mean, mean, FloatComplex)
70-
INSTANTIATE_STAT_ALL_COMPLEX(Mean, mean, DoubleComplex)
7146
INSTANTIATE_STAT_ALL_WEIGHTED(Mean, mean)
7247
INSTANTIATE_STAT_WEIGHTED(Mean, mean)
73-
INSTANTIATE_STAT_WEIGHTED_COMPLEX(FloatComplex, Mean, mean)
74-
INSTANTIATE_STAT_WEIGHTED_COMPLEX(DoubleComplex, Mean, mean)
7548

49+
// Variance
7650
JNIEXPORT jlong JNICALL STATISTICS_FUNC(afVar)(JNIEnv *env, jclass clazz,
7751
jlong ref, jboolean isBiased,
7852
jint dim) {
@@ -81,34 +55,25 @@ JNIEXPORT jlong JNICALL STATISTICS_FUNC(afVar)(JNIEnv *env, jclass clazz,
8155
return JLONG(ret);
8256
}
8357

84-
JNIEXPORT jdouble JNICALL STATISTICS_FUNC(afVarAll)(JNIEnv *env, jclass clazz,
58+
JNIEXPORT jobject JNICALL STATISTICS_FUNC(afVarAll)(JNIEnv *env, jclass clazz,
8559
jlong ref,
8660
jboolean isBiased) {
87-
double ret = 0;
88-
AF_CHECK(af_var_all(&ret, NULL, ARRAY(ref), isBiased));
89-
return (jdouble)ret;
61+
double real = 0, img = 0;
62+
AF_CHECK(af_var_all(&real, &img, ARRAY(ref), isBiased));
63+
return java::createJavaObject(env, java::JavaObjects::DoubleComplex, real,
64+
img);
9065
}
9166

92-
// Variance
93-
INSTANTIATE_VAR(FloatComplex)
94-
INSTANTIATE_VAR(DoubleComplex)
9567
INSTANTIATE_STAT_WEIGHTED(Var, var)
9668
INSTANTIATE_STAT_ALL_WEIGHTED(Var, var)
97-
INSTANTIATE_STAT_WEIGHTED_COMPLEX(FloatComplex, Var, var)
98-
INSTANTIATE_STAT_WEIGHTED_COMPLEX(DoubleComplex, Var, var)
9969

10070
// Standard dev
10171
INSTANTIATE_STAT(Stdev, stdev)
10272
INSTANTIATE_STAT_ALL(Stdev, stdev)
103-
INSTANTIATE_STAT_ALL_COMPLEX(Stdev, stdev, FloatComplex)
104-
INSTANTIATE_STAT_ALL_COMPLEX(Stdev, stdev, DoubleComplex)
10573

106-
#undef INSTANTIATE_VAR
107-
#undef INSTANTIATE_STAT_WEIGHTED_COMPLEX
108-
#undef INSTANTIATE_STAT_WEIGHTED
109-
#undef INSTANTIATE_STAT_ALL_WEIGHTED
11074
#undef INSTANTIATE_STAT
11175
#undef INSTANTIATE_STAT_ALL
112-
#undef INSTANTIATE_STAT_ALL_COMPLEX
76+
#undef INSTANTIATE_STAT_WEIGHTED
77+
#undef INSTANTIATE_STAT_ALL_WEIGHTED
11378

11479
END_EXTERN_C

0 commit comments

Comments
 (0)