Skip to content

Commit 68a29ae

Browse files
committed
JavaCL/Blas: integrate @fvlankvelt's patch
Block multiplication doesn't work on CPU, when maxWorkItemSizes < 16x16. Reintroduced old, naive multiplication code as a fallback, and introduced stride and blockSize throughout the matrix code (default block size is defined in CLDefaultMatrix2D. DEFAULT_BLOCK_SIZE).
1 parent 406f77e commit 68a29ae

File tree

13 files changed

+247
-112
lines changed

13 files changed

+247
-112
lines changed

Blas/src/main/java/com/nativelibs4java/opencl/blas/CLDefaultMatrix2D.java

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,18 +20,25 @@
2020
public class CLDefaultMatrix2D<T> implements CLMatrix2D<T> {
2121
protected final Primitive primitive;
2222
protected final Class<T> primitiveClass;
23-
protected final long rows, columns, length;
23+
protected final long rows, columns, stride, length;
24+
protected final int blockSize;
2425

2526
protected final CLKernels kernels;
2627
protected final CLBuffer<T> buffer;
2728
protected final CLQueue queue;
2829
protected final CLContext context;
2930
protected CLEvents _events = new CLEvents();
3031

32+
public static final int DEFAULT_BLOCK_SIZE = 16;
33+
3134
public CLDefaultMatrix2D(Primitive primitive, CLBuffer<T> buffer, long rows, long columns, CLKernels kernels) {
35+
this(primitive, buffer, rows, columns, DEFAULT_BLOCK_SIZE, kernels);
36+
}
37+
public CLDefaultMatrix2D(Primitive primitive, CLBuffer<T> buffer, long rows, long columns, int blockSize, CLKernels kernels) {
3238
this.primitive = primitive;
3339
this.primitiveClass = (Class<T>)primitive.primitiveType;
34-
this.length = CLMatrixUtils.roundUp(rows) * CLMatrixUtils.roundUp(columns);
40+
this.stride = CLMatrixUtils.roundUp(columns, blockSize);
41+
this.length = this.stride * CLMatrixUtils.roundUp(rows, blockSize);
3542
if (buffer != null) {
3643
if (buffer.getElementCount() < this.length) {
3744
throw new IllegalArgumentException("Buffer size too small; buffer of size " + this.length + " expected, size " + buffer.getByteCount() + " was given");
@@ -45,13 +52,17 @@ public CLDefaultMatrix2D(Primitive primitive, CLBuffer<T> buffer, long rows, lon
4552
this.columns = columns;
4653
this.queue = kernels.getQueue();
4754
this.context = kernels.getContext();
55+
this.blockSize = blockSize;
56+
57+
assert getBuffer().getElementCount() >= stride * rows &&
58+
getBuffer().getElementCount() <= stride * CLMatrixUtils.roundUp(rows, getBlockSize());
4859
}
4960

5061
public CLMatrix2D<T> blankClone() {
5162
return blankMatrix(getRowCount(), getColumnCount());
5263
}
5364
public CLMatrix2D<T> blankMatrix(long rows, long columns) {
54-
return new CLDefaultMatrix2D<T>(primitive, null, rows, columns, kernels);
65+
return new CLDefaultMatrix2D<T>(primitive, null, rows, columns, blockSize, kernels);
5566
}
5667

5768
public long getRowCount() {
@@ -62,6 +73,14 @@ public long getColumnCount() {
6273
return columns;
6374
}
6475

76+
public long getStride() {
77+
return stride;
78+
}
79+
80+
public int getBlockSize() {
81+
return blockSize;
82+
}
83+
6584
public CLEvents getEvents() {
6685
return _events;
6786
}

Blas/src/main/java/com/nativelibs4java/opencl/blas/CLKernels.java

Lines changed: 94 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
import com.nativelibs4java.opencl.util.ParallelMath;
2727
import com.nativelibs4java.opencl.util.Primitive;
2828

29-
import static com.nativelibs4java.opencl.blas.CLMatrix2D.BLOCK_SIZE;
29+
import static com.nativelibs4java.opencl.blas.CLMatrixUtils.roundUp;
3030
import static org.bridj.Pointer.pointerToInt;
3131

3232
/**
@@ -70,10 +70,10 @@ public CLKernels(CLQueue queue) throws IOException, CLBuildException {
7070
this.queue = queue;
7171
}
7272

73-
public <T> CLEvent op1(Primitive prim, Fun1 fun, CLBuffer<T> a, long rows, long columns, CLBuffer<T> out, CLEvent... eventsToWaitFor) throws CLBuildException {
74-
long length = rows * columns;
75-
if (out == null || out.getElementCount() != length)
76-
throw new IllegalArgumentException("Expected buffer of length " + length + ", got " + out);
73+
public <T> CLEvent op1(Primitive prim, Fun1 fun, CLBuffer<T> a, long rows, long columns, long stride, CLBuffer<T> out, CLEvent... eventsToWaitFor) throws CLBuildException {
74+
long length = rows * stride;
75+
if (out == null || out.getElementCount() < length)
76+
throw new IllegalArgumentException("Expected buffer of length >= " + length + ", got " + out);
7777
//if (out != null)
7878
// out = (CLBuffer<T>)context.createBuffer(Usage.Output, prim.primitiveType, length);
7979

@@ -85,10 +85,10 @@ public <T> CLEvent op1(Primitive prim, Fun1 fun, CLBuffer<T> a, long rows, long
8585
}
8686
}
8787

88-
public <T> CLEvent op2(Primitive prim, Fun2 fun, CLBuffer<T> a, CLBuffer<T> b, long rows, long columns, CLBuffer<T> out, CLEvent... eventsToWaitFor) throws CLBuildException {
89-
long length = rows * columns;
90-
if (out == null || out.getElementCount() != length)
91-
throw new IllegalArgumentException("Expected buffer of length " + length + ", got " + out);
88+
public <T> CLEvent op2(Primitive prim, Fun2 fun, CLBuffer<T> a, CLBuffer<T> b, long rows, long columns, long stride, CLBuffer<T> out, CLEvent... eventsToWaitFor) throws CLBuildException {
89+
long length = rows * stride;
90+
if (out == null || out.getElementCount() < length)
91+
throw new IllegalArgumentException("Expected buffer of length >= " + length + ", got " + out.getElementCount());
9292
//if (out != null)
9393
// out = (CLBuffer<T>)context.createBuffer(Usage.Output, prim.primitiveType, length);
9494

@@ -100,10 +100,10 @@ public <T> CLEvent op2(Primitive prim, Fun2 fun, CLBuffer<T> a, CLBuffer<T> b, l
100100
}
101101
}
102102

103-
public <T> CLEvent op2(Primitive prim, Fun2 fun, CLBuffer<T> a, T b, long rows, long columns, CLBuffer<T> out, CLEvent... eventsToWaitFor) throws CLBuildException {
104-
long length = rows * columns;
105-
if (out == null || out.getElementCount() != length)
106-
throw new IllegalArgumentException("Expected buffer of length " + length + ", got " + out.getElementCount());
103+
public <T> CLEvent op2(Primitive prim, Fun2 fun, CLBuffer<T> a, T b, long rows, long columns, long stride, CLBuffer<T> out, CLEvent... eventsToWaitFor) throws CLBuildException {
104+
long length = rows * stride;
105+
if (out == null || out.getElementCount() < length)
106+
throw new IllegalArgumentException("Expected buffer of length >= " + length + ", got " + out.getElementCount());
107107
//if (out != null)
108108
// out = (CLBuffer<T>)context.createBuffer(Usage.Output, prim.primitiveType, length);
109109

@@ -178,20 +178,42 @@ public <V> CLEvent clear(Primitive primitive, CLBuffer<V> buffer, long length, C
178178
}
179179
}
180180

181-
Map<Primitive, CLKernel> matrixMultiplyKernels = new HashMap<Primitive, CLKernel>();
182-
public <T> CLEvent matrixMultiply(Primitive prim, CLBuffer<T> a, long aRows, long aColumns, CLBuffer<T> b, long bRows, long bColumns, CLBuffer<T> out, CLEvent... eventsToWaitFor) throws CLBuildException {
181+
Map<String, CLKernel> matrixMultiplyKernels = new HashMap<String, CLKernel>();
182+
public <T> CLEvent matrixMultiply(Primitive prim,
183+
CLBuffer<T> a, long aRows, long aColumns, int aBlockSize,
184+
CLBuffer<T> b, long bRows, long bColumns, int bBlockSize,
185+
CLBuffer<T> out, CLEvent... eventsToWaitFor) throws CLBuildException {
186+
boolean useBlocks = false;
187+
int blockSize = aBlockSize;
188+
if (blockSize > 1 && blockSize == bBlockSize) {
189+
long[] maxWorkItemSizes = queue.getDevice().getMaxWorkItemSizes();
190+
useBlocks = maxWorkItemSizes.length >= 2 &&
191+
maxWorkItemSizes[0] >= blockSize &&
192+
maxWorkItemSizes[1] >= blockSize;
193+
}
194+
if (useBlocks) {
195+
return blockMatrixMultiply(
196+
blockSize, prim,
197+
a, roundUp(aRows, blockSize), roundUp(aColumns, blockSize),
198+
b, roundUp(bRows, blockSize), roundUp(bColumns, blockSize),
199+
out, eventsToWaitFor);
200+
} else {
201+
return naiveMatrixMultiply(prim, a, aRows, aColumns, b, bRows, bColumns, out, eventsToWaitFor);
202+
}
203+
}
204+
public <T> CLEvent blockMatrixMultiply(int blockSize, Primitive prim, CLBuffer<T> a, long aRows, long aColumns, CLBuffer<T> b, long bRows, long bColumns, CLBuffer<T> out, CLEvent... eventsToWaitFor) throws CLBuildException {
183205
if (out == null)
184206
throw new IllegalArgumentException("Null output matrix !");
185207
//if (out != null)
186208
// out = (CLBuffer<T>)context.createBuffer(Usage.Output, prim.primitiveType, aRows * bColumns);
187209

188210
CLKernel kernel;
211+
String key = "block_" + blockSize + "_" + prim;
189212
synchronized (matrixMultiplyKernels) {
190-
kernel = matrixMultiplyKernels.get(prim);
213+
kernel = matrixMultiplyKernels.get(key);
191214
if (kernel == null) {
192-
String src =
193-
prim.getRequiredPragmas() +
194-
"#define BLOCK_SIZE " + BLOCK_SIZE + "\n" +
215+
String src = prim.getRequiredPragmas() +
216+
"#define BLOCK_SIZE " + blockSize + "\n" +
195217
"#define AS(i, j) As[j + i * BLOCK_SIZE]\n" +
196218
"#define BS(i, j) Bs[j + i * BLOCK_SIZE]\n" +
197219
"\n" +
@@ -263,23 +285,65 @@ public <T> CLEvent matrixMultiply(Primitive prim, CLBuffer<T> a, long aRows, lon
263285
String clTypeName = prim.clTypeName();
264286
src = src.replaceAll("double", clTypeName);
265287
kernel = context.createProgram(src).createKernel("mulMat");
266-
matrixMultiplyKernels.put(prim, kernel);
288+
matrixMultiplyKernels.put(key, kernel);
267289
}
268290
}
269291
synchronized (kernel) {
270292
kernel.setArgs(a, (int) aColumns, b, (int) bColumns, out,
271-
LocalSize.ofFloatArray(BLOCK_SIZE * BLOCK_SIZE),
272-
LocalSize.ofFloatArray(BLOCK_SIZE * BLOCK_SIZE));
293+
LocalSize.ofFloatArray(blockSize * blockSize),
294+
LocalSize.ofFloatArray(blockSize * blockSize));
273295
CLEvent evt = kernel.enqueueNDRange(queue,
274296
new int[]{(int) aRows, (int) bColumns},
275-
new int[]{BLOCK_SIZE, BLOCK_SIZE},
297+
new int[]{blockSize, blockSize},
276298
eventsToWaitFor);
277299
return evt;
278300
}
279301
}
280302

303+
public <T> CLEvent naiveMatrixMultiply(Primitive prim, CLBuffer<T> a, long aRows, long aColumns, CLBuffer<T> b, long bRows, long bColumns, CLBuffer<T> out, CLEvent... eventsToWaitFor) throws CLBuildException {
304+
if (out == null)
305+
throw new IllegalArgumentException("Null output matrix !");
306+
//if (out != null)
307+
// out = (CLBuffer<T>)context.createBuffer(Usage.Output, prim.primitiveType, aRows * bColumns);
308+
309+
CLKernel kernel;
310+
String key = "naive_" + prim;
311+
synchronized (matrixMultiplyKernels) {
312+
kernel = matrixMultiplyKernels.get(key);
313+
if (kernel == null) {
314+
String src = prim.getRequiredPragmas() +
315+
"__kernel void mulMat( " +
316+
" __global const double* a, int aRows, int aColumns, " +
317+
" __global const double* b, int bColumns, " +
318+
" __global double* c " +
319+
") { " +
320+
" int i = get_global_id(0); " +
321+
" int j = get_global_id(1); " +
322+
" " +
323+
" if (i >= aRows || j >= bColumns) return; " +
324+
" double total = 0; " +
325+
" size_t iOff = i * (size_t)aColumns; " +
326+
" for (long k = 0; k < aColumns; k++) { " +
327+
" total += a[iOff + k] * b[k * (size_t)bColumns + j]; " +
328+
" } " +
329+
" c[i * (size_t)bColumns + j] = total; " +
330+
"} "
331+
;
332+
String clTypeName = prim.clTypeName();
333+
src = src.replaceAll("double", clTypeName);
334+
kernel = context.createProgram(src).createKernel("mulMat");
335+
matrixMultiplyKernels.put(key, kernel);
336+
}
337+
}
338+
synchronized (kernel) {
339+
kernel.setArgs(a, (int)aRows, (int)aColumns, b, (int)bColumns, out);
340+
CLEvent evt = kernel.enqueueNDRange(queue, new int [] { (int)aRows, (int)bColumns }, eventsToWaitFor);
341+
return evt;
342+
}
343+
}
344+
281345
Map<Primitive, CLKernel[]> matrixTransposeKernels = new HashMap<Primitive, CLKernel[]>();
282-
public <T> CLEvent matrixTranspose(Primitive prim, CLBuffer<T> a, long aRows, long aColumns, CLBuffer<T> out, CLEvent... eventsToWaitFor) throws CLBuildException {
346+
public <T> CLEvent matrixTranspose(Primitive prim, CLBuffer<T> a, long aRows, long aColumns, long aStride, CLBuffer<T> out, CLEvent... eventsToWaitFor) throws CLBuildException {
283347
if (out == null)
284348
throw new IllegalArgumentException("Null output matrix !");
285349
//if (out != null)
@@ -292,29 +356,29 @@ public <T> CLEvent matrixTranspose(Primitive prim, CLBuffer<T> a, long aRows, lo
292356
String src =
293357
prim.getRequiredPragmas() +
294358
"__kernel void transposeSelf( \n" +
295-
" __global double* a, int aRows, int aColumns \n" +
359+
" __global double* a, int aRows, int aColumns, int aStride \n" +
296360
") { \n" +
297361
" int i = get_global_id(0); \n" +
298362
" int j = get_global_id(1); \n" +
299363
" \n" +
300364
" if (i >= aRows || j >= aColumns || j >= i) return; \n" +
301365
" \n" +
302-
" size_t aIndex = i * aColumns + j; \n" +
366+
" size_t aIndex = i * aStride + j; \n" +
303367
" size_t outIndex = j * aRows + i; \n" +
304368
" double temp = a[outIndex]; \n" +
305369
" a[outIndex] = a[aIndex]; \n" +
306370
" a[aIndex] = temp; \n" +
307371
"} \n" +
308372
"__kernel void transposeOther( \n" +
309-
" __global const double* a, int aRows, int aColumns, \n" +
373+
" __global const double* a, int aRows, int aColumns, int aStride, \n" +
310374
" __global double* out \n" +
311375
") { \n" +
312376
" int i = get_global_id(0); \n" +
313377
" int j = get_global_id(1); \n" +
314378
" \n" +
315379
" if (i >= aRows || j >= aColumns) return; \n" +
316380
" \n" +
317-
" size_t aIndex = i * aColumns + j; \n" +
381+
" size_t aIndex = i * aStride + j; \n" +
318382
" size_t outIndex = j * aRows + i; \n" +
319383
" out[outIndex] = a[aIndex]; \n" +
320384
"} \n"
@@ -330,9 +394,9 @@ public <T> CLEvent matrixTranspose(Primitive prim, CLBuffer<T> a, long aRows, lo
330394
CLKernel kernel = kernels[self ? 0 : 1];
331395
synchronized (kernel) {
332396
if (self)
333-
kernel.setArgs(a, (int)aRows, (int)aColumns);
397+
kernel.setArgs(a, (int)aRows, (int)aColumns, (int)aStride);
334398
else
335-
kernel.setArgs(a, (int)aRows, (int)aColumns, out);
399+
kernel.setArgs(a, (int)aRows, (int)aColumns, (int)aStride, out);
336400

337401
CLEvent evt = kernel.enqueueNDRange(queue, new int [] { (int)aRows, (int)aColumns }, eventsToWaitFor);
338402
return evt;

Blas/src/main/java/com/nativelibs4java/opencl/blas/CLMatrix2D.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,6 @@
1616
*/
1717
public interface CLMatrix2D<T> {
1818

19-
int BLOCK_SIZE = 16;
20-
2119
Primitive getPrimitive();
2220
Class<T> getPrimitiveClass();
2321
CLEvents getEvents();
@@ -26,6 +24,8 @@ public interface CLMatrix2D<T> {
2624
CLQueue getQueue();
2725
long getRowCount();
2826
long getColumnCount();
27+
long getStride();
28+
int getBlockSize();
2929
CLMatrix2D<T> blankClone();
3030
CLMatrix2D<T> blankMatrix(long rows, long columns);
3131
CLKernels getKernels();

0 commit comments

Comments
 (0)