Skip to content

Commit 7618e38

Browse files
committed
JavaCL / BLAS: more sensible CLKernels fix
1 parent 933f352 commit 7618e38

File tree

1 file changed: +7 -7 lines changed

Blas/src/main/java/com/nativelibs4java/opencl/blas/CLKernels.java

Lines changed: 7 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -120,12 +120,12 @@ public <V> boolean containsValue(Primitive primitive, CLBuffer<V> buffer, long l
120120
kernel = containsValueKernels.get(primitive);
121121
if (kernel == null) {
122122
kernel = context.createProgram((
123-
PRAGMA_DOUBLE +
123+
(primitive.primitiveType == double.class ? Primitive. PRAGMA_DOUBLE : "") +
124124
"__kernel void containsValue( \n" +
125125
" __global const double* a, \n" +
126126
" int length, \n" +
127127
" double value, \n" +
128-
" __global char* pOut \n" +
128+
" __global int* pOut \n" +
129129
") { \n" +
130130
" int i = get_global_id(0);\n" +
131131
" if (i >= length) \n" +
@@ -139,10 +139,10 @@ public <V> boolean containsValue(Primitive primitive, CLBuffer<V> buffer, long l
139139
}
140140
}
141141
synchronized(kernel) {
142-
CLBuffer<Byte> pOut = context.createBuffer(Usage.Output, Byte.class, 1);
142+
CLBuffer<Integer> pOut = context.createBuffer(Usage.Output, Integer.class, 1);
143143
kernel.setArgs(buffer, (int)length, value, pOut);
144144
kernel.enqueueNDRange(queue, new int[] { (int)length }, eventsToWaitFor).waitFor();
145-
return pOut.read(queue).getBoolean();
145+
return pOut.read(queue).getInt() != 0;
146146
}
147147
}
148148

@@ -153,7 +153,7 @@ public <V> CLEvent clear(Primitive primitive, CLBuffer<V> buffer, long length, C
153153
kernel = clearKernels.get(primitive);
154154
if (kernel == null) {
155155
kernel = context.createProgram((
156-
PRAGMA_DOUBLE +
156+
(primitive.primitiveType == double.class ? Primitive. PRAGMA_DOUBLE : "") +
157157
"__kernel void clear_buffer( \n" +
158158
" __global double* a, \n" +
159159
" int length \n" +
@@ -188,7 +188,7 @@ public <T> CLEvent matrixMultiply(Primitive prim, CLBuffer<T> a, long aRows, lon
188188
kernel = matrixMultiplyKernels.get(prim);
189189
if (kernel == null) {
190190
String src =
191-
PRAGMA_DOUBLE +
191+
(prim.primitiveType == double.class ? Primitive. PRAGMA_DOUBLE : "") +
192192
"__kernel void mulMat( " +
193193
" __global const double* a, int aRows, int aColumns, " +
194194
" __global const double* b, int bColumns, " +
@@ -231,7 +231,7 @@ public <T> CLEvent matrixTranspose(Primitive prim, CLBuffer<T> a, long aRows, lo
231231
kernels = matrixTransposeKernels.get(prim);
232232
if (kernels == null) {
233233
String src =
234-
PRAGMA_DOUBLE +
234+
(prim.primitiveType == double.class ? Primitive. PRAGMA_DOUBLE : "") +
235235
"__kernel void transposeSelf( \n" +
236236
" __global double* a, int aRows, int aColumns \n" +
237237
") { \n" +

0 commit comments

Comments (0)