Skip to content

Commit 793de5a

Browse files
committed
Adding Monte Carlo calculation of Pi example
1 parent 045247e commit 793de5a

File tree

1 file changed

+64
-0
lines changed

1 file changed

+64
-0
lines changed

examples/MonteCarloPi.java

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
import java.util.Random;
2+
import com.arrayfire.Array;
3+
4+
public class MonteCarloPi {
5+
6+
public static double hostCalcPi(int size) {
7+
Random rand = new Random();
8+
int count = 0;
9+
for (int i = 0; i < size; i++) {
10+
float x = rand.nextFloat();
11+
float y = rand.nextFloat();
12+
boolean lt1 = (x * x + y * y) < 1;
13+
if (lt1) count++;
14+
}
15+
16+
return 4.0 * ((double)(count)) / size;
17+
}
18+
19+
public static double deviceCalcPi(int size) throws Exception {
20+
int[] dims = new int[] {size, 1};
21+
22+
Array x = Array.randu(dims, Array.FloatType);
23+
Array y = Array.randu(dims, Array.FloatType);
24+
25+
Array x2 = Array.mul(x, x);
26+
Array y2 = Array.mul(y, y);
27+
Array res = Array.lt(Array.add(x2, y2), 1);
28+
29+
double count = Array.sumAll(res);
30+
return 4.0 * ((double)(count)) / size;
31+
}
32+
33+
public static void main(String[] args) {
34+
35+
try {
36+
int size = 5000000;
37+
int iter = 100;
38+
double hostPi = hostCalcPi(size);
39+
double devicePi = deviceCalcPi(size);
40+
41+
System.out.println("Results from host: " + hostPi);
42+
System.out.println("Results from device: " + devicePi);
43+
44+
long hostStart = System.currentTimeMillis();
45+
for (int i = 0; i < iter; i++) {
46+
hostPi = hostCalcPi(size);
47+
}
48+
double hostElapsed = (double)(System.currentTimeMillis() - hostStart)/iter;
49+
50+
long deviceStart = System.currentTimeMillis();
51+
for (int i = 0; i < iter; i++) {
52+
devicePi = deviceCalcPi(size);
53+
}
54+
double deviceElapsed = (double)(System.currentTimeMillis() - deviceStart)/iter;
55+
56+
System.out.println("Time taken for host (ms): " + hostElapsed);
57+
System.out.println("Time taken for device (ms): " + deviceElapsed);
58+
System.out.println("Speedup: " + Math.round((hostElapsed) / (deviceElapsed)));
59+
60+
} catch (Exception e) {
61+
System.out.println(e.getMessage());
62+
}
63+
}
64+
}

0 commit comments

Comments
 (0)