Skip to content

Commit 71005e0

Browse files
authored
Add files via upload
1 parent a532f0c commit 71005e0

1 file changed

Lines changed: 72 additions & 61 deletions

File tree

Main.java

Lines changed: 72 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -1,83 +1,94 @@
11
import CudaLib.CudaMemObj;
2+
import CudaLib.NDCuArray;
3+
import CudaLib.NDCuSlice;
4+
25
import static java.lang.foreign.ValueLayout.*;
36

47
import java.lang.foreign.MemorySegment;
58
import java.lang.foreign.ValueLayout;
9+
import java.util.Arrays;
610

711
import static CudaLib.CudaNumLib.*;
812
import static CudaLib.CudaNumLib.cudaMemcpyKind.*;
13+
import CudaLib.NDCuArray;
14+
import static CudaLib.NDCuSlice.Sliceof;
915

1016
public class Main {
1117
public static void main(String[] args) throws Throwable {
12-
int num_elements = 1000, threadsPerBlock = 256;
13-
14-
CudaMemObj a_arr = hostMalloc(JAVA_DOUBLE, num_elements);
15-
CudaMemObj b_arr = hostMalloc(JAVA_DOUBLE, num_elements);
16-
CudaMemObj c_arr = hostMalloc(JAVA_DOUBLE, num_elements);
18+
// Array test.
1719

18-
for (int i = 0; i < 5; i++) {
19-
a_arr.get_ptr().setAtIndex(JAVA_DOUBLE, i, 10.0);
20-
b_arr.get_ptr().setAtIndex(JAVA_DOUBLE, i, 10.0);
21-
}
20+
NDCuArray arr = new NDCuArray(new double[]{ 1, 2, 3, 4, 5, 6, 7, 8, 9,
21+
10, 11, 12, 13, 14, 15, 16, 17, 18,
22+
19, 20, 21, 22, 23, 24, 25, 26, 27,
23+
28, 29, 30, 31, 32, 33, 34, 35, 36 });
24+
25+
IO.println(arr.arr_size);
26+
arr.reshape(2, 2, 9);
27+
double get_val = arr.get(1, 1, 3); // depth - row - col
28+
IO.println(get_val);
29+
arr.set(1234.0, 1, 1, 3);
30+
double get_val_02 = arr.get(1, 1, 3);
31+
IO.println(get_val_02);
2232

23-
int byte_size_arr = (int)a_arr.get_ptr().byteSize();
24-
CudaMemObj cu_a_arr = cudaMalloc(byte_size_arr);
25-
CudaMemObj cu_b_arr = cudaMalloc(byte_size_arr);
26-
CudaMemObj cu_c_arr = cudaMalloc(byte_size_arr);
27-
28-
cudaMemcpy(cu_a_arr, a_arr, byte_size_arr, cudaMemcpyHostToDevice);
29-
cudaMemcpy(cu_b_arr, b_arr, byte_size_arr, cudaMemcpyHostToDevice);
30-
cuda_add(threadsPerBlock, num_elements, cu_a_arr, cu_b_arr, cu_c_arr);
31-
cuda__powf(threadsPerBlock, num_elements, cu_c_arr, 2.0, cu_c_arr);
32-
cudaDeviceSynchronize();
33-
cudaMemcpy(c_arr, cu_c_arr, byte_size_arr, cudaMemcpyDeviceToHost);
33+
NDCuArray arr02 = new NDCuArray(new double[]{ 1, 2, 3, 4, 5, 6, 7, 8, 9,
34+
10, 11, 12, 13, 14, 15, 16, 17, 18,
35+
19, 20, 21, 22, 23, 24, 25, 26, 27,
36+
28, 29, 30, 31, 32, 33, 34, 35, 36 });
37+
38+
NDCuArray arr03 = arr.add(arr02);
39+
IO.println(arr03.get(1, 1, 3));
3440

35-
for (int i = 0; i < 5; i++) {
36-
IO.println(c_arr.get_ptr().getAtIndex(JAVA_DOUBLE, i));
37-
}
41+
NDCuArray arr04 = new NDCuArray(new double[]{ 1, 2, 3, 4, 5, 6, 7, 8, 9,
42+
10, 11, 12, 13, 14, 15, 16, 17, 18,
43+
19, 20, 21, 22, 23, 24 });
44+
arr04.reshape(2, 3, 4);
45+
// 1, 2, 3, 4 // 13, 14, 15, 16
46+
// 5, [6, 7, 8 // 17, [18, 19, 20
47+
// 9, 10, 11,] 12 // 21, 22, 23,] 24
3848

39-
cudaFree(cu_a_arr);
40-
cudaFree(cu_b_arr);
41-
cudaFree(cu_c_arr);
49+
NDCuArray arr06 = arr04.get(Sliceof(0, 2), Sliceof(1, 3), Sliceof(1, 3));
50+
IO.println(Arrays.toString(arr06.shape));
51+
IO.println(Arrays.toString(arr04.shape));
4252

43-
// Conv2d
44-
double[] data_arr = { 1, 2, 3, 4, 5, 6, 7, 8, 9,
45-
10, 11, 12, 13, 14, 15, 16, 17, 18,
46-
19, 20, 21, 22, 23, 24, 25, 26, 27,
47-
28, 29, 30, 31, 32, 33, 34, 35, 36 };
48-
double[] mask_arr = { 3, 4, 5,
49-
6, 7, 8,
50-
9, 10, 11 };
53+
IO.println(arr06.get(0, 0, 0));
54+
IO.println(arr06.get(0, 0, 1));
55+
IO.println(arr06.get(0, 1, 0));
56+
IO.println(arr06.get(0, 1, 1));
57+
IO.println(arr06.get(1, 0, 0));
58+
IO.println(arr06.get(1, 0, 1));
59+
IO.println(arr06.get(1, 1, 0));
60+
IO.println(arr06.get(1, 1, 1));
5161

52-
int data_row = 4, data_col = 9;
53-
int mask_row = 3, mask_col = 3;
62+
arr06.set(100.0, 0, 0, 0);
63+
arr06.set(200.0, 0, 0, 1);
64+
arr06.set(300.0,0, 1, 0);
65+
arr06.set(400.0,0, 1, 1);
66+
arr06.set(500.0,1, 0, 0);
67+
arr06.set(600.0,1, 0, 1);
68+
arr06.set(700.0,1, 1, 0);
69+
arr06.set(800.0,1, 1, 1);
5470

55-
CudaMemObj data_arr_ptr = hostMalloc(JAVA_DOUBLE, data_arr.length);
56-
CudaMemObj mask_arr_ptr = hostMalloc(JAVA_DOUBLE, mask_arr.length);
57-
CudaMemObj result_arr_ptr = hostMalloc(JAVA_DOUBLE, (data_row-mask_row+1)*(data_col-mask_col+1));
71+
IO.println(arr06.get(0, 0, 0));
72+
IO.println(arr06.get(0, 0, 1));
73+
IO.println(arr06.get(0, 1, 0));
74+
IO.println(arr06.get(0, 1, 1));
75+
IO.println(arr06.get(1, 0, 0));
76+
IO.println(arr06.get(1, 0, 1));
77+
IO.println(arr06.get(1, 1, 0));
78+
IO.println(arr06.get(1, 1, 1));
5879

59-
for (int i = 0; i < data_arr.length; i++) {
60-
data_arr_ptr.get_ptr().setAtIndex(JAVA_DOUBLE, i, data_arr[i]);
61-
}
62-
for (int i = 0; i < mask_arr.length; i++) {
63-
mask_arr_ptr.get_ptr().setAtIndex(JAVA_DOUBLE, i, mask_arr[i]);
64-
}
65-
66-
CudaMemObj cu_d_a_arr = cudaMalloc((int)data_arr_ptr.get_ptr().byteSize());
67-
CudaMemObj cu_m_b_arr = cudaMalloc((int)mask_arr_ptr.get_ptr().byteSize());
68-
CudaMemObj cu_res_c_arr = cudaMalloc((int)result_arr_ptr.get_ptr().byteSize());
69-
70-
cudaMemcpy(cu_d_a_arr, data_arr_ptr, (int)data_arr_ptr.get_ptr().byteSize(), cudaMemcpyHostToDevice);
71-
cudaMemcpy(cu_m_b_arr, mask_arr_ptr, (int)mask_arr_ptr.get_ptr().byteSize(), cudaMemcpyHostToDevice);
72-
cuda_conv2d(16, data_row, data_col, mask_row, mask_col, cu_d_a_arr, cu_m_b_arr, cu_res_c_arr);
73-
cudaDeviceSynchronize();
74-
75-
IO.println((int)result_arr_ptr.get_ptr().byteSize());
76-
IO.println(((data_row-mask_row+1)*(data_col-mask_col+1))*8);
77-
cudaMemcpy(result_arr_ptr, cu_res_c_arr, (int)result_arr_ptr.get_ptr().byteSize(), cudaMemcpyDeviceToHost);
80+
arr04.set(arr06, Sliceof(0, 2), Sliceof(1, 3), Sliceof(2, 4));
81+
IO.println(arr04.get(0, 1, 2));
82+
IO.println(arr04.get(0, 1, 3));
83+
IO.println(arr04.get(0, 2, 2));
84+
IO.println(arr04.get(0, 2, 3));
85+
IO.println(arr04.get(1, 1, 2));
86+
IO.println(arr04.get(1, 1, 3));
87+
IO.println(arr04.get(1, 2, 2));
88+
IO.println(arr04.get(1, 2, 3));
89+
arr04.print();
7890

79-
for (int i = 0; i < (data_row-mask_row+1)*(data_col-mask_col+1); i++) {
80-
IO.println(result_arr_ptr.get_ptr().getAtIndex(JAVA_DOUBLE, i));
81-
}
91+
NDCuArray cu_arr_zeros = new NDCuArray(2, 3, 4);
92+
cu_arr_zeros.print();
8293
}
8394
}

0 commit comments

Comments
 (0)