Skip to content

Commit 34f8113

Browse files
Added dynamic work-group size computation to sycl_buffer example
1 parent b1ca555 commit 34f8113

File tree

1 file changed

+7
-4
lines changed

1 file changed

+7
-4
lines changed

examples/cython/sycl_buffer/src/use_sycl_buffer.hpp

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
#pragma once
3131

3232
#include <CL/sycl.hpp>
33-
#include <iostream>
33+
#include <algorithm>
3434

3535
inline size_t upper_multiple(size_t n, size_t wg)
3636
{
@@ -54,7 +54,10 @@ void columnwise_total(sycl::queue q,
5454
[=](sycl::id<1> i) { ct_acc[i] = dataT(0); });
5555
});
5656

57-
constexpr size_t wg = 256;
57+
const sycl::device &d = q.get_device();
58+
const auto &sg_sizes = d.get_info<sycl::info::device::sub_group_sizes>();
59+
size_t wg =
60+
2 * (*std::max_element(std::begin(sg_sizes), std::end(sg_sizes)));
5861

5962
q.submit([&](sycl::handler &h) {
6063
sycl::accessor mat_acc{mat_buffer, h, sycl::read_only};
@@ -66,12 +69,12 @@ void columnwise_total(sycl::queue q,
6669
h.parallel_for(
6770
sycl::nd_range<2>(global, local), [=](sycl::nd_item<2> it) {
6871
size_t i = it.get_global_id(0);
69-
size_t j = it.get_global_id(1);
7072
dataT group_sum = sycl::reduce_over_group(
7173
it.get_group(),
7274
(i < n) ? mat_acc[it.get_global_id()] : dataT(0),
7375
std::plus<dataT>());
74-
if (it.get_local_id(0) == 0) {
76+
if (it.get_group().leader()) {
77+
size_t j = it.get_global_id(1);
7578
sycl::atomic_ref<dataT, sycl::memory_order::relaxed,
7679
sycl::memory_scope::system,
7780
sycl::access::address_space::global_space>(

0 commit comments

Comments
 (0)