2626import com .nativelibs4java .opencl .util .ParallelMath ;
2727import com .nativelibs4java .opencl .util .Primitive ;
2828
29- import static com .nativelibs4java .opencl .blas .CLMatrix2D . BLOCK_SIZE ;
29+ import static com .nativelibs4java .opencl .blas .CLMatrixUtils . roundUp ;
3030import static org .bridj .Pointer .pointerToInt ;
3131
3232/**
@@ -70,10 +70,10 @@ public CLKernels(CLQueue queue) throws IOException, CLBuildException {
7070 this .queue = queue ;
7171 }
7272
73- public <T > CLEvent op1 (Primitive prim , Fun1 fun , CLBuffer <T > a , long rows , long columns , CLBuffer <T > out , CLEvent ... eventsToWaitFor ) throws CLBuildException {
74- long length = rows * columns ;
75- if (out == null || out .getElementCount () != length )
76- throw new IllegalArgumentException ("Expected buffer of length " + length + ", got " + out );
73+ public <T > CLEvent op1 (Primitive prim , Fun1 fun , CLBuffer <T > a , long rows , long columns , long stride , CLBuffer <T > out , CLEvent ... eventsToWaitFor ) throws CLBuildException {
74+ long length = rows * stride ;
75+ if (out == null || out .getElementCount () < length )
76+ throw new IllegalArgumentException ("Expected buffer of length >= " + length + ", got " + out );
7777 //if (out != null)
7878 // out = (CLBuffer<T>)context.createBuffer(Usage.Output, prim.primitiveType, length);
7979
@@ -85,10 +85,10 @@ public <T> CLEvent op1(Primitive prim, Fun1 fun, CLBuffer<T> a, long rows, long
8585 }
8686 }
8787
88- public <T > CLEvent op2 (Primitive prim , Fun2 fun , CLBuffer <T > a , CLBuffer <T > b , long rows , long columns , CLBuffer <T > out , CLEvent ... eventsToWaitFor ) throws CLBuildException {
89- long length = rows * columns ;
90- if (out == null || out .getElementCount () != length )
91- throw new IllegalArgumentException ("Expected buffer of length " + length + ", got " + out );
88+ public <T > CLEvent op2 (Primitive prim , Fun2 fun , CLBuffer <T > a , CLBuffer <T > b , long rows , long columns , long stride , CLBuffer <T > out , CLEvent ... eventsToWaitFor ) throws CLBuildException {
89+ long length = rows * stride ;
90+ if (out == null || out .getElementCount () < length )
91+ throw new IllegalArgumentException ("Expected buffer of length >= " + length + ", got " + out . getElementCount () );
9292 //if (out != null)
9393 // out = (CLBuffer<T>)context.createBuffer(Usage.Output, prim.primitiveType, length);
9494
@@ -100,10 +100,10 @@ public <T> CLEvent op2(Primitive prim, Fun2 fun, CLBuffer<T> a, CLBuffer<T> b, l
100100 }
101101 }
102102
103- public <T > CLEvent op2 (Primitive prim , Fun2 fun , CLBuffer <T > a , T b , long rows , long columns , CLBuffer <T > out , CLEvent ... eventsToWaitFor ) throws CLBuildException {
104- long length = rows * columns ;
105- if (out == null || out .getElementCount () != length )
106- throw new IllegalArgumentException ("Expected buffer of length " + length + ", got " + out .getElementCount ());
103+ public <T > CLEvent op2 (Primitive prim , Fun2 fun , CLBuffer <T > a , T b , long rows , long columns , long stride , CLBuffer <T > out , CLEvent ... eventsToWaitFor ) throws CLBuildException {
104+ long length = rows * stride ;
105+ if (out == null || out .getElementCount () < length )
106+ throw new IllegalArgumentException ("Expected buffer of length >= " + length + ", got " + out .getElementCount ());
107107 //if (out != null)
108108 // out = (CLBuffer<T>)context.createBuffer(Usage.Output, prim.primitiveType, length);
109109
@@ -178,20 +178,42 @@ public <V> CLEvent clear(Primitive primitive, CLBuffer<V> buffer, long length, C
178178 }
179179 }
180180
181- Map <Primitive , CLKernel > matrixMultiplyKernels = new HashMap <Primitive , CLKernel >();
182- public <T > CLEvent matrixMultiply (Primitive prim , CLBuffer <T > a , long aRows , long aColumns , CLBuffer <T > b , long bRows , long bColumns , CLBuffer <T > out , CLEvent ... eventsToWaitFor ) throws CLBuildException {
181+ Map <String , CLKernel > matrixMultiplyKernels = new HashMap <String , CLKernel >();
182+ public <T > CLEvent matrixMultiply (Primitive prim ,
183+ CLBuffer <T > a , long aRows , long aColumns , int aBlockSize ,
184+ CLBuffer <T > b , long bRows , long bColumns , int bBlockSize ,
185+ CLBuffer <T > out , CLEvent ... eventsToWaitFor ) throws CLBuildException {
186+ boolean useBlocks = false ;
187+ int blockSize = aBlockSize ;
188+ if (blockSize > 1 && blockSize == bBlockSize ) {
189+ long [] maxWorkItemSizes = queue .getDevice ().getMaxWorkItemSizes ();
190+ useBlocks = maxWorkItemSizes .length >= 2 &&
191+ maxWorkItemSizes [0 ] >= blockSize &&
192+ maxWorkItemSizes [1 ] >= blockSize ;
193+ }
194+ if (useBlocks ) {
195+ return blockMatrixMultiply (
196+ blockSize , prim ,
197+ a , roundUp (aRows , blockSize ), roundUp (aColumns , blockSize ),
198+ b , roundUp (bRows , blockSize ), roundUp (bColumns , blockSize ),
199+ out , eventsToWaitFor );
200+ } else {
201+ return naiveMatrixMultiply (prim , a , aRows , aColumns , b , bRows , bColumns , out , eventsToWaitFor );
202+ }
203+ }
204+ public <T > CLEvent blockMatrixMultiply (int blockSize , Primitive prim , CLBuffer <T > a , long aRows , long aColumns , CLBuffer <T > b , long bRows , long bColumns , CLBuffer <T > out , CLEvent ... eventsToWaitFor ) throws CLBuildException {
183205 if (out == null )
184206 throw new IllegalArgumentException ("Null output matrix !" );
185207 //if (out != null)
186208 // out = (CLBuffer<T>)context.createBuffer(Usage.Output, prim.primitiveType, aRows * bColumns);
187209
188210 CLKernel kernel ;
211+ String key = "block_" + blockSize + "_" + prim ;
189212 synchronized (matrixMultiplyKernels ) {
190- kernel = matrixMultiplyKernels .get (prim );
213+ kernel = matrixMultiplyKernels .get (key );
191214 if (kernel == null ) {
192- String src =
193- prim .getRequiredPragmas () +
194- "#define BLOCK_SIZE " + BLOCK_SIZE + "\n " +
215+ String src = prim .getRequiredPragmas () +
216+ "#define BLOCK_SIZE " + blockSize + "\n " +
195217 "#define AS(i, j) As[j + i * BLOCK_SIZE]\n " +
196218 "#define BS(i, j) Bs[j + i * BLOCK_SIZE]\n " +
197219 "\n " +
@@ -263,23 +285,65 @@ public <T> CLEvent matrixMultiply(Primitive prim, CLBuffer<T> a, long aRows, lon
263285 String clTypeName = prim .clTypeName ();
264286 src = src .replaceAll ("double" , clTypeName );
265287 kernel = context .createProgram (src ).createKernel ("mulMat" );
266- matrixMultiplyKernels .put (prim , kernel );
288+ matrixMultiplyKernels .put (key , kernel );
267289 }
268290 }
269291 synchronized (kernel ) {
270292 kernel .setArgs (a , (int ) aColumns , b , (int ) bColumns , out ,
271- LocalSize .ofFloatArray (BLOCK_SIZE * BLOCK_SIZE ),
272- LocalSize .ofFloatArray (BLOCK_SIZE * BLOCK_SIZE ));
293+ LocalSize .ofFloatArray (blockSize * blockSize ),
294+ LocalSize .ofFloatArray (blockSize * blockSize ));
273295 CLEvent evt = kernel .enqueueNDRange (queue ,
274296 new int []{(int ) aRows , (int ) bColumns },
275- new int []{BLOCK_SIZE , BLOCK_SIZE },
297+ new int []{blockSize , blockSize },
276298 eventsToWaitFor );
277299 return evt ;
278300 }
279301 }
280302
303+ public <T > CLEvent naiveMatrixMultiply (Primitive prim , CLBuffer <T > a , long aRows , long aColumns , CLBuffer <T > b , long bRows , long bColumns , CLBuffer <T > out , CLEvent ... eventsToWaitFor ) throws CLBuildException {
304+ if (out == null )
305+ throw new IllegalArgumentException ("Null output matrix !" );
306+ //if (out != null)
307+ // out = (CLBuffer<T>)context.createBuffer(Usage.Output, prim.primitiveType, aRows * bColumns);
308+
309+ CLKernel kernel ;
310+ String key = "naive_" + prim ;
311+ synchronized (matrixMultiplyKernels ) {
312+ kernel = matrixMultiplyKernels .get (key );
313+ if (kernel == null ) {
314+ String src = prim .getRequiredPragmas () +
315+ "__kernel void mulMat( " +
316+ " __global const double* a, int aRows, int aColumns, " +
317+ " __global const double* b, int bColumns, " +
318+ " __global double* c " +
319+ ") { " +
320+ " int i = get_global_id(0); " +
321+ " int j = get_global_id(1); " +
322+ " " +
323+ " if (i >= aRows || j >= bColumns) return; " +
324+ " double total = 0; " +
325+ " size_t iOff = i * (size_t)aColumns; " +
326+ " for (long k = 0; k < aColumns; k++) { " +
327+ " total += a[iOff + k] * b[k * (size_t)bColumns + j]; " +
328+ " } " +
329+ " c[i * (size_t)bColumns + j] = total; " +
330+ "} "
331+ ;
332+ String clTypeName = prim .clTypeName ();
333+ src = src .replaceAll ("double" , clTypeName );
334+ kernel = context .createProgram (src ).createKernel ("mulMat" );
335+ matrixMultiplyKernels .put (key , kernel );
336+ }
337+ }
338+ synchronized (kernel ) {
339+ kernel .setArgs (a , (int )aRows , (int )aColumns , b , (int )bColumns , out );
340+ CLEvent evt = kernel .enqueueNDRange (queue , new int [] { (int )aRows , (int )bColumns }, eventsToWaitFor );
341+ return evt ;
342+ }
343+ }
344+
281345 Map <Primitive , CLKernel []> matrixTransposeKernels = new HashMap <Primitive , CLKernel []>();
282- public <T > CLEvent matrixTranspose (Primitive prim , CLBuffer <T > a , long aRows , long aColumns , CLBuffer <T > out , CLEvent ... eventsToWaitFor ) throws CLBuildException {
346+ public <T > CLEvent matrixTranspose (Primitive prim , CLBuffer <T > a , long aRows , long aColumns , long aStride , CLBuffer <T > out , CLEvent ... eventsToWaitFor ) throws CLBuildException {
283347 if (out == null )
284348 throw new IllegalArgumentException ("Null output matrix !" );
285349 //if (out != null)
@@ -292,29 +356,29 @@ public <T> CLEvent matrixTranspose(Primitive prim, CLBuffer<T> a, long aRows, lo
292356 String src =
293357 prim .getRequiredPragmas () +
294358 "__kernel void transposeSelf( \n " +
295- " __global double* a, int aRows, int aColumns \n " +
359+ " __global double* a, int aRows, int aColumns, int aStride \n " +
296360 ") { \n " +
297361 " int i = get_global_id(0); \n " +
298362 " int j = get_global_id(1); \n " +
299363 " \n " +
300364 " if (i >= aRows || j >= aColumns || j >= i) return; \n " +
301365 " \n " +
302- " size_t aIndex = i * aColumns + j; \n " +
366+ " size_t aIndex = i * aStride + j; \n " +
303367 " size_t outIndex = j * aRows + i; \n " +
304368 " double temp = a[outIndex]; \n " +
305369 " a[outIndex] = a[aIndex]; \n " +
306370 " a[aIndex] = temp; \n " +
307371 "} \n " +
308372 "__kernel void transposeOther( \n " +
309- " __global const double* a, int aRows, int aColumns, \n " +
373+ " __global const double* a, int aRows, int aColumns, int aStride, \n " +
310374 " __global double* out \n " +
311375 ") { \n " +
312376 " int i = get_global_id(0); \n " +
313377 " int j = get_global_id(1); \n " +
314378 " \n " +
315379 " if (i >= aRows || j >= aColumns) return; \n " +
316380 " \n " +
317- " size_t aIndex = i * aColumns + j; \n " +
381+ " size_t aIndex = i * aStride + j; \n " +
318382 " size_t outIndex = j * aRows + i; \n " +
319383 " out[outIndex] = a[aIndex]; \n " +
320384 "} \n "
@@ -330,9 +394,9 @@ public <T> CLEvent matrixTranspose(Primitive prim, CLBuffer<T> a, long aRows, lo
330394 CLKernel kernel = kernels [self ? 0 : 1 ];
331395 synchronized (kernel ) {
332396 if (self )
333- kernel .setArgs (a , (int )aRows , (int )aColumns );
397+ kernel .setArgs (a , (int )aRows , (int )aColumns , ( int ) aStride );
334398 else
335- kernel .setArgs (a , (int )aRows , (int )aColumns , out );
399+ kernel .setArgs (a , (int )aRows , (int )aColumns , ( int ) aStride , out );
336400
337401 CLEvent evt = kernel .enqueueNDRange (queue , new int [] { (int )aRows , (int )aColumns }, eventsToWaitFor );
338402 return evt ;
0 commit comments