|
| 1 | +import java.util.Random; |
| 2 | +import com.arrayfire.Array; |
| 3 | + |
| 4 | +public class MonteCarloPi { |
| 5 | + |
| 6 | + public static double hostCalcPi(int size) { |
| 7 | + Random rand = new Random(); |
| 8 | + int count = 0; |
| 9 | + for (int i = 0; i < size; i++) { |
| 10 | + float x = rand.nextFloat(); |
| 11 | + float y = rand.nextFloat(); |
| 12 | + boolean lt1 = (x * x + y * y) < 1; |
| 13 | + if (lt1) count++; |
| 14 | + } |
| 15 | + |
| 16 | + return 4.0 * ((double)(count)) / size; |
| 17 | + } |
| 18 | + |
| 19 | + public static double deviceCalcPi(int size) throws Exception { |
| 20 | + int[] dims = new int[] {size, 1}; |
| 21 | + |
| 22 | + Array x = Array.randu(dims, Array.FloatType); |
| 23 | + Array y = Array.randu(dims, Array.FloatType); |
| 24 | + |
| 25 | + Array x2 = Array.mul(x, x); |
| 26 | + Array y2 = Array.mul(y, y); |
| 27 | + Array res = Array.lt(Array.add(x2, y2), 1); |
| 28 | + |
| 29 | + double count = Array.sumAll(res); |
| 30 | + return 4.0 * ((double)(count)) / size; |
| 31 | + } |
| 32 | + |
| 33 | + public static void main(String[] args) { |
| 34 | + |
| 35 | + try { |
| 36 | + int size = 5000000; |
| 37 | + int iter = 100; |
| 38 | + double hostPi = hostCalcPi(size); |
| 39 | + double devicePi = deviceCalcPi(size); |
| 40 | + |
| 41 | + System.out.println("Results from host: " + hostPi); |
| 42 | + System.out.println("Results from device: " + devicePi); |
| 43 | + |
| 44 | + long hostStart = System.currentTimeMillis(); |
| 45 | + for (int i = 0; i < iter; i++) { |
| 46 | + hostPi = hostCalcPi(size); |
| 47 | + } |
| 48 | + double hostElapsed = (double)(System.currentTimeMillis() - hostStart)/iter; |
| 49 | + |
| 50 | + long deviceStart = System.currentTimeMillis(); |
| 51 | + for (int i = 0; i < iter; i++) { |
| 52 | + devicePi = deviceCalcPi(size); |
| 53 | + } |
| 54 | + double deviceElapsed = (double)(System.currentTimeMillis() - deviceStart)/iter; |
| 55 | + |
| 56 | + System.out.println("Time taken for host (ms): " + hostElapsed); |
| 57 | + System.out.println("Time taken for device (ms): " + deviceElapsed); |
| 58 | + System.out.println("Speedup: " + Math.round((hostElapsed) / (deviceElapsed))); |
| 59 | + |
| 60 | + } catch (Exception e) { |
| 61 | + System.out.println(e.getMessage()); |
| 62 | + } |
| 63 | + } |
| 64 | +} |
0 commit comments