3030#pragma once
3131
3232#include < CL/sycl.hpp>
33- #include < iostream >
33+ #include < algorithm >
3434
3535inline size_t upper_multiple (size_t n, size_t wg)
3636{
@@ -54,7 +54,10 @@ void columnwise_total(sycl::queue q,
5454 [=](sycl::id<1 > i) { ct_acc[i] = dataT (0 ); });
5555 });
5656
57- constexpr size_t wg = 256 ;
57+ const sycl::device &d = q.get_device ();
58+ const auto &sg_sizes = d.get_info <sycl::info::device::sub_group_sizes>();
59+ size_t wg =
60+ 2 * (*std::max_element (std::begin (sg_sizes), std::end (sg_sizes)));
5861
5962 q.submit ([&](sycl::handler &h) {
6063 sycl::accessor mat_acc{mat_buffer, h, sycl::read_only};
@@ -66,12 +69,12 @@ void columnwise_total(sycl::queue q,
6669 h.parallel_for (
6770 sycl::nd_range<2 >(global, local), [=](sycl::nd_item<2 > it) {
6871 size_t i = it.get_global_id (0 );
69- size_t j = it.get_global_id (1 );
7072 dataT group_sum = sycl::reduce_over_group (
7173 it.get_group (),
7274 (i < n) ? mat_acc[it.get_global_id ()] : dataT (0 ),
7375 std::plus<dataT>());
74- if (it.get_local_id (0 ) == 0 ) {
76+ if (it.get_group ().leader ()) {
77+ size_t j = it.get_global_id (1 );
7578 sycl::atomic_ref<dataT, sycl::memory_order::relaxed,
7679 sycl::memory_scope::system,
7780 sycl::access::address_space::global_space>(
0 commit comments