Expand contiguous reduce specializations #6338
prsabahrami wants to merge 2 commits into modular:main
Conversation
Pull request overview
This PR updates the GPU reduction backend to improve contiguous-axis reduction performance and fixes a benchmark host buffer sizing issue.
Changes:
- Fix host-side result buffer sizing in bench_reduce.mojo.
- Collapse per-thread SIMD row accumulators to scalars before block_reduce.
- Add a warp-level contiguous-axis fast path (warp_reduce_kernel) and update kernel dispatch/SIMD selection in reduce_launch.
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| mojo/stdlib/std/algorithm/backend/gpu/reduction.mojo | Adds warp-level contiguous reduction kernel, changes SIMD→scalar collapse before block reduce, and adjusts dispatch/SIMD width selection. |
| max/kernels/benchmarks/gpu/bench_reduce.mojo | Fixes host-side buffer allocation to match the copied device buffer size. |
```mojo
var tail_col = (row_size // VEC_STRIDE) * VEC_STRIDE + lid
while tail_col < row_size:
    row_coords[axis] = tail_col
    var v = input_fn[dtype, 1, rank](row_coords).cast[accum_type]()
    comptime for i in range(num_reductions):
        scalar_accum[i] = reduce_fn[accum_type, 1, i](scalar_accum[i], v)
    tail_col += WARP_SIZE
```
In warp_reduce_kernel, tail_col starts at (row_size // VEC_STRIDE) * VEC_STRIDE + lid, but the vector loop already covers all elements < vec_limit (including the remainder when row_size < VEC_STRIDE or simd_width == 1). This causes some elements to be reduced twice (e.g. row_size=64, simd_width=8; or row_size=65, simd_width=1), producing incorrect results. Compute the scalar-tail start from the number of full SIMD vectors instead (e.g. tail_start = (row_size // simd_width) * simd_width, then tail_col = tail_start + lid) so only the final row_size % simd_width elements are handled in the tail loop.
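To illustrate the overlap this comment describes, here is a small Python simulation. It is a sketch, not the kernel: `WARP_SIZE = 32`, the lane-to-column assignment in the vector loop, and `vec_limit = (row_size // simd_width) * simd_width` are all assumptions about the kernel's behavior inferred from the review text.

```python
WARP_SIZE = 32  # assumed warp size

def visit_counts(row_size, simd_width, fixed_tail):
    """Simulate one warp's per-element visit counts for the vector loop
    plus the scalar tail loop of a contiguous-axis row reduction."""
    VEC_STRIDE = WARP_SIZE * simd_width
    # assumption: the vector loop covers every full SIMD vector in the row
    vec_limit = (row_size // simd_width) * simd_width
    visits = [0] * row_size
    for lid in range(WARP_SIZE):
        # vector loop: lane lid loads simd_width elements per stride
        vec_col = lid * simd_width
        while vec_col < vec_limit:
            for j in range(simd_width):
                visits[vec_col + j] += 1
            vec_col += VEC_STRIDE
    # scalar tail loop: buggy start vs. the suggested fix
    tail_start = vec_limit if fixed_tail else (row_size // VEC_STRIDE) * VEC_STRIDE
    for lid in range(WARP_SIZE):
        tail_col = tail_start + lid
        while tail_col < row_size:
            visits[tail_col] += 1
            tail_col += WARP_SIZE
    return visits

# row_size=64, simd_width=8: the buggy tail start revisits every element
assert max(visit_counts(64, 8, fixed_tail=False)) == 2
# with tail_start = (row_size // simd_width) * simd_width, each element
# is reduced exactly once
assert visit_counts(64, 8, fixed_tail=True) == [1] * 64
```

With `row_size=65, simd_width=1` the same simulation shows element 64 visited twice under the buggy start, matching the second failure case cited above.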
```mojo
while vec_col < vec_limit:
    row_coords[axis] = vec_col
    var v = input_fn[dtype, simd_width, rank](row_coords).cast[accum_type]()
    comptime for i in range(num_reductions):
        accum[i] = reduce_fn[accum_type, simd_width, i](accum[i], v)
    vec_col += VEC_STRIDE
```
The per-thread accumulation in warp_reduce_kernel uses reduce_fn(accum, v) (accumulator as the first operand). Elsewhere in the GPU backend (and the CPU backend) the reduction wrapper is consistently applied as reduce_fn(val, acc) (new value first), so this flips operand order for non-commutative reductions and makes results kernel-dependent. Swap the argument order here (and similarly in the scalar tail / lane-collapse paths) to match the established reduce_fn(val, acc) convention.
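A quick Python sketch of why operand order matters for non-commutative reductions. String concatenation stands in for a custom `reduce_fn` that is associative but not commutative; the `fold` helper and names are hypothetical, used only to contrast the two calling conventions.

```python
def fold(op, xs, init, value_first):
    """Fold xs into an accumulator, applying op with either the new
    value first (op(val, acc)) or the accumulator first (op(acc, val))."""
    acc = init
    for v in xs:
        acc = op(v, acc) if value_first else op(acc, v)
    return acc

# concatenation: associative but non-commutative, like some custom reductions
concat = lambda a, b: a + b
xs = ["a", "b", "c"]

# the two conventions produce different results, so mixing them across
# kernels would make the output depend on which kernel was dispatched
assert fold(concat, xs, "", value_first=True) == "cba"
assert fold(concat, xs, "", value_first=False) == "abc"
```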
```mojo
var lane_accum = accum[i][0]
comptime for lane in range(1, simd_width):
    lane_accum = reduce_fn[accum_type, 1, i](lane_accum, accum[i][lane])
scalar_vals[i] = lane_accum
```
row_reduce now manually collapses the SIMD accumulator to a scalar before block_reduce, but the lane fold uses reduce_fn(lane_accum, accum[i][lane]) (accumulator first). Other call sites in this file reduce as reduce_fn(val, acc) (value first), and the std reduction wrapper contract uses (val, acc). To keep semantics consistent across kernels (especially for non-commutative reductions), flip the operand order in this lane-collapse fold.
!sync
```diff
  var in_host = alloc[Scalar[dtype]](cb_in.alloc_size())
- var res_host = alloc[Scalar[dtype]](out_size)
+ var res_host = alloc[Scalar[dtype]](in_size)
```
This doesn't seem right to me. The host-side results array should be of out_size. Shouldn't res_buffer instead be sized to out_size?
```mojo
comptime contig_simd = simd_width_of[dtype, get_gpu_target()]()
comptime for ax in range(rank):
    if axis == ax:
        comptime is_contig = (ax == rank - 1)
```
There's no need to recompute is_contig here. Just use reduce_contig_dim, and pull reduce_simd out of the for loop.
joeatodd left a comment:
Hi @prsabahrami, thanks for this PR. Did you get a chance to look over my review comments, and those from Copilot?
Apologies for the delay here, will address these tomorrow!
Summary

- Touched files: reduction.mojo, bench_reduce.mojo
- Adds 3072 fixed-turn ILP2 and 4096 warp0-epilogue paths for contiguous axis-2 reductions

Benchmark Notes

- 3072 path, 1x1024x3072 axis=2: +27.53% bf16 and +32.53% f16, versus the conservative control's +19.65% bf16 and +27.08% f16
- 4096 tail path, 1x256x4096 axis=2: 241.07 bf16 / 237.48 f16 GElems/s (+28.58% bf16 and +23.55% f16)
- 1x1024x3072 axis=2: 376.77 bf16 / 365.46 f16 GElems/s
- 1x256x4096 axis=2: 161.52 bf16 / 158.92 f16 GElems/s
- 32x1024x256x1024 axis=3: 1178.02 bf16 / 1385.40 f16 GElems/s