Skip to content

Commit 584d207

Browse files
alkantt authored and tensorflower-gardener committed
Allow Gather, SegmentReduction, and FillEmptyRows GPU Ops to handle int64 work element counts safely; fix unsafe DivUp arithmetic.
* Introduces `GpuLaunchConfig64` and updates `GetGpuLaunchConfig64` to return `absl::StatusOr<GpuLaunchConfig64>` to safely handle 64-bit work element counts and propagate errors without crashing. The `gather`, `segment_reduction`, and `fill_empty_rows` ops have been updated to use these new 64-bit utilities and to handle the returned statuses.
* Replaces deprecated 1D grid iterators with new ones that support `int64`-typed loop variables, preventing truncation issues when iterating over large grids.
* Fixes unsafe DivUp arithmetic in GPU kernel size computations.

PiperOrigin-RevId: 896629242
1 parent d1f33e5 commit 584d207

6 files changed

Lines changed: 241 additions & 118 deletions

File tree

tensorflow/core/kernels/depthwise_conv_op_gpu.h

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ limitations under the License.
3737
#endif
3838

3939
namespace tensorflow {
40+
using Eigen::numext::div_ceil;
4041

4142
namespace detail {
4243
template <typename T>
@@ -640,7 +641,7 @@ Status LaunchDepthwiseConv2dGPUSmall(OpKernelContext* ctx,
640641
case FORMAT_NHWC:
641642
block_dim = dim3(kBlockDepth, args.in_cols, block_height);
642643
block_count =
643-
args.batch * DivUp(args.out_depth, kBlockDepth) * kBlockDepth;
644+
args.batch * div_ceil(args.out_depth, kBlockDepth) * kBlockDepth;
644645
kernel =
645646
DepthwiseConv2dGPUKernelNHWCSmall<T, kDirection, kKnownFilterWidth,
646647
kKnownFilterHeight, kBlockDepth,
@@ -649,7 +650,7 @@ Status LaunchDepthwiseConv2dGPUSmall(OpKernelContext* ctx,
649650
case FORMAT_NCHW:
650651
block_dim = dim3(args.in_cols, block_height, kBlockDepth);
651652
block_count =
652-
DivUp(args.batch * args.out_depth, kBlockDepth) * kBlockDepth;
653+
div_ceil(args.batch * args.out_depth, kBlockDepth) * kBlockDepth;
653654
kernel =
654655
DepthwiseConv2dGPUKernelNCHWSmall<T, kDirection, kKnownFilterWidth,
655656
kKnownFilterHeight, kBlockDepth,
@@ -1567,14 +1568,14 @@ Status TryLaunchDepthwiseConv2dBackpropFilterGPUSmall(
15671568
case FORMAT_NHWC:
15681569
block_dim = dim3(kBlockDepth, args.in_cols, block_height);
15691570
block_count =
1570-
args.batch * DivUp(args.out_depth, kBlockDepth) * kBlockDepth;
1571+
args.batch * div_ceil(args.out_depth, kBlockDepth) * kBlockDepth;
15711572
kernel = DepthwiseConv2dBackpropFilterGPUKernelNHWCSmall<
15721573
T, kKnownFilterWidth, kKnownFilterHeight, kBlockDepth, kAccumPixels>;
15731574
break;
15741575
case FORMAT_NCHW:
15751576
block_dim = dim3(args.in_cols, block_height, kBlockDepth);
15761577
block_count =
1577-
DivUp(args.batch * args.out_depth, kBlockDepth) * kBlockDepth;
1578+
div_ceil(args.batch * args.out_depth, kBlockDepth) * kBlockDepth;
15781579
kernel = DepthwiseConv2dBackpropFilterGPUKernelNCHWSmall<
15791580
T, kKnownFilterWidth, kKnownFilterHeight, kBlockDepth, kAccumPixels>;
15801581
break;

tensorflow/core/kernels/fill_empty_rows_functor_gpu.cu.cc

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ limitations under the License.
1818
#define EIGEN_USE_GPU
1919

2020
#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive
21+
#include "xla/tsl/platform/statusor.h"
2122
#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
2223
#include "tensorflow/core/framework/register_types.h"
2324
#include "tensorflow/core/framework/tensor_types.h"
@@ -54,7 +55,8 @@ typename T::ConstPointerType to_pointers(const T& x) {
5455
template <typename Tindex, typename... CallerArgs, typename... KernelArgs>
5556
Status wrap_kernel_call(void (*func)(KernelArgs...), const GPUDevice& device,
5657
Tindex size, CallerArgs... args) {
57-
auto config = GetGpuLaunchConfig(size, device);
58+
TF_ASSIGN_OR_RETURN(GpuLaunchConfig64 config,
59+
GetGpuLaunchConfig64(size, device));
5860
return GpuLaunchKernel(func, config.block_count, config.thread_per_block, 0,
5961
device.stream(), config, to_pointers(args)...);
6062
}
@@ -77,10 +79,10 @@ struct CastFunctor {
7779
// true if the indices are not ordered by row.
7880
template <typename Tindex>
7981
__global__ __launch_bounds__(1024) void CountElementsPerRowKernel(
80-
GpuLaunchConfig cfg, Tindex dense_rows, int rank, const Tindex* indices,
82+
GpuLaunchConfig64 cfg, Tindex dense_rows, int rank, const Tindex* indices,
8183
Tindex* elements_per_row, int* rows_are_not_ordered,
8284
int* first_invalid_index) {
83-
GPU_1D_KERNEL_LOOP(i, cfg.virtual_thread_count) {
85+
for (int64_t i : GpuGridRangeX(cfg.virtual_thread_count)) {
8486
Tindex row = indices[i * rank];
8587
if (row < 0 || row >= dense_rows) {
8688
GpuAtomicMin(first_invalid_index, i);
@@ -98,18 +100,19 @@ __global__ __launch_bounds__(1024) void CountElementsPerRowKernel(
98100

99101
template <typename Tindex>
100102
__global__ __launch_bounds__(1024) void CopyRowIndicesKernel(
101-
GpuLaunchConfig cfg, int rank, const Tindex* indices, Tindex* row_indices) {
102-
GPU_1D_KERNEL_LOOP(i, cfg.virtual_thread_count) {
103+
GpuLaunchConfig64 cfg, int rank, const Tindex* indices,
104+
Tindex* row_indices) {
105+
for (int64_t i : GpuGridRangeX(cfg.virtual_thread_count)) {
103106
row_indices[i] = indices[i * rank];
104107
}
105108
}
106109

107110
// Sets empty_row_indicator[row] to whether the row is empty.
108111
template <typename Tindex>
109112
__global__ __launch_bounds__(1024) void ComputeEmptyRowIndicatorKernel(
110-
GpuLaunchConfig cfg, const Tindex* elements_per_row,
113+
GpuLaunchConfig64 cfg, const Tindex* elements_per_row,
111114
bool* empty_row_indicator) {
112-
GPU_1D_KERNEL_LOOP(row, cfg.virtual_thread_count) {
115+
for (int64_t row : GpuGridRangeX(cfg.virtual_thread_count)) {
113116
empty_row_indicator[row] = elements_per_row[row] == 0;
114117
}
115118
}
@@ -119,11 +122,11 @@ __global__ __launch_bounds__(1024) void ComputeEmptyRowIndicatorKernel(
119122
// empty row.
120123
template <typename T, typename Tindex>
121124
__global__ __launch_bounds__(1024) void ScatterInputElementsKernel(
122-
GpuLaunchConfig cfg, Tindex dense_rows, int rank,
125+
GpuLaunchConfig64 cfg, Tindex dense_rows, int rank,
123126
const Tindex* input_index_map, const Tindex* indices, const T* values,
124127
const Tindex* num_new_rows_before, Tindex* output_indices, T* output_values,
125128
Tindex* reverse_index_map) {
126-
GPU_1D_KERNEL_LOOP(i, cfg.virtual_thread_count) {
129+
for (int64_t i : ::tensorflow::GpuGridRangeX(cfg.virtual_thread_count)) {
127130
Tindex input_i = input_index_map ? input_index_map[i] : i;
128131
Tindex row = indices[input_i * rank];
129132
Tindex output_i = i + num_new_rows_before[row];
@@ -141,10 +144,10 @@ __global__ __launch_bounds__(1024) void ScatterInputElementsKernel(
141144
// input) in output_indices and output_values.
142145
template <typename T, typename Tindex>
143146
__global__ __launch_bounds__(1024) void ScatterNewElementsKernel(
144-
GpuLaunchConfig cfg, int rank, const T* default_value,
147+
GpuLaunchConfig64 cfg, int rank, const T* default_value,
145148
const Tindex* num_new_rows_through, const Tindex* input_row_ends,
146149
const bool* empty_row_indicator, Tindex* output_indices, T* output_values) {
147-
GPU_1D_KERNEL_LOOP(row, cfg.virtual_thread_count) {
150+
for (int64_t row : ::tensorflow::GpuGridRangeX(cfg.virtual_thread_count)) {
148151
if (!empty_row_indicator[row]) continue; // Only process empty rows
149152
Tindex input_i = (row == 0 ? 0 : input_row_ends[row - 1]);
150153
Tindex output_i = input_i + (row == 0 ? 0 : num_new_rows_through[row - 1]);
@@ -489,9 +492,9 @@ namespace {
489492

490493
template <typename T, typename Tindex>
491494
__global__ __launch_bounds__(1024) void GatherOriginalGradValuesKernel(
492-
GpuLaunchConfig cfg, const Tindex* reverse_index_map, const T* grad_values,
493-
T* d_values, bool* visited, Tindex N_full) {
494-
GPU_1D_KERNEL_LOOP(input_i, cfg.virtual_thread_count) {
495+
GpuLaunchConfig64 cfg, const Tindex* reverse_index_map,
496+
const T* grad_values, T* d_values, bool* visited, Tindex N_full) {
497+
for (int64_t input_i : GpuGridRangeX(cfg.virtual_thread_count)) {
495498
Tindex output_i = reverse_index_map[input_i];
496499
if (output_i >= 0 && output_i < N_full) {
497500
d_values[input_i] = grad_values[output_i];

tensorflow/core/kernels/gather_functor_gpu.cu.h

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ limitations under the License.
2020

2121
#define EIGEN_USE_GPU
2222

23+
#include "xla/tsl/platform/statusor.h"
2324
#include "tensorflow/core/framework/register_types.h"
2425
#include "tensorflow/core/kernels/gather_functor.h"
2526
#include "tensorflow/core/platform/types.h"
@@ -35,7 +36,7 @@ __global__ void GatherOpKernel(const ValueOrVec* __restrict__ params,
3536
ValueOrVec* __restrict__ out,
3637
int64 gather_dim_size, int64 indices_size,
3738
int64 slice_size, int64 out_size) {
38-
GPU_1D_KERNEL_LOOP(i, out_size) {
39+
for (int64_t i : GpuGridRangeX(out_size)) {
3940
Index batch_i = 0;
4041
Index indices_i = 0;
4142
Index slice_i = 0;
@@ -91,9 +92,12 @@ struct LaunchGatherKernelVectorized {
9192
const Tvec* params_vec = reinterpret_cast<const Tvec*>(params);
9293
Tvec* out_vec = reinterpret_cast<Tvec*>(out);
9394

94-
GpuLaunchConfig config = GetGpuLaunchConfig(
95-
out_size_vec, d, &GatherOpKernel<Tvec, Index, is_axis_zero>,
96-
/*dynamic_shared_memory_size=*/0, /*block_size_limit=*/0);
95+
TF_ASSIGN_OR_RETURN(
96+
GpuLaunchConfig64 config,
97+
GetGpuLaunchConfig64(out_size_vec, d,
98+
&GatherOpKernel<Tvec, Index, is_axis_zero>,
99+
/*dynamic_shared_memory_size=*/0,
100+
/*block_size_limit=*/0));
97101
return GpuLaunchKernel(
98102
GatherOpKernel<Tvec, Index, is_axis_zero>, config.block_count,
99103
config.thread_per_block, 0, d.stream(), params_vec, indices, out_vec,
@@ -142,13 +146,21 @@ struct GatherFunctor<GPUDevice, T, Index> {
142146
const int64 slice_size = params.dimension(2);
143147

144148
if (is_axis_zero) {
145-
TF_CHECK_OK(LaunchGatherKernel<true>(d, params.data(), indices.data(),
146-
out.data(), gather_dim_size,
147-
indices_size, slice_size, out_size));
149+
Status status = LaunchGatherKernel<true>(
150+
d, params.data(), indices.data(), out.data(), gather_dim_size,
151+
indices_size, slice_size, out_size);
152+
if (!status.ok()) {
153+
ctx->CtxFailure(__FILE__, __LINE__, status);
154+
return -1;
155+
}
148156
} else {
149-
TF_CHECK_OK(LaunchGatherKernel<false>(
157+
Status status = LaunchGatherKernel<false>(
150158
d, params.data(), indices.data(), out.data(), gather_dim_size,
151-
indices_size, slice_size, out_size));
159+
indices_size, slice_size, out_size);
160+
if (!status.ok()) {
161+
ctx->CtxFailure(__FILE__, __LINE__, status);
162+
return -1;
163+
}
152164
}
153165
// TODO(fpmc): enable indices validation on GPU.
154166
// Right now checking for indices out of bound in the kernel would

0 commit comments

Comments (0)