Skip to content

Commit 30e6daf

Browse files
Add (make public) method ComputeLogicalBufferUnpaddedSizes.
Reverts 584d207 PiperOrigin-RevId: 896637498
1 parent 5bd1b3d commit 30e6daf

9 files changed

Lines changed: 151 additions & 243 deletions

tensorflow/core/kernels/depthwise_conv_op_gpu.h

Lines changed: 4 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -37,7 +37,6 @@ limitations under the License.
3737
#endif
3838

3939
namespace tensorflow {
40-
using Eigen::numext::div_ceil;
4140

4241
namespace detail {
4342
template <typename T>
@@ -641,7 +640,7 @@ Status LaunchDepthwiseConv2dGPUSmall(OpKernelContext* ctx,
641640
case FORMAT_NHWC:
642641
block_dim = dim3(kBlockDepth, args.in_cols, block_height);
643642
block_count =
644-
args.batch * div_ceil(args.out_depth, kBlockDepth) * kBlockDepth;
643+
args.batch * DivUp(args.out_depth, kBlockDepth) * kBlockDepth;
645644
kernel =
646645
DepthwiseConv2dGPUKernelNHWCSmall<T, kDirection, kKnownFilterWidth,
647646
kKnownFilterHeight, kBlockDepth,
@@ -650,7 +649,7 @@ Status LaunchDepthwiseConv2dGPUSmall(OpKernelContext* ctx,
650649
case FORMAT_NCHW:
651650
block_dim = dim3(args.in_cols, block_height, kBlockDepth);
652651
block_count =
653-
div_ceil(args.batch * args.out_depth, kBlockDepth) * kBlockDepth;
652+
DivUp(args.batch * args.out_depth, kBlockDepth) * kBlockDepth;
654653
kernel =
655654
DepthwiseConv2dGPUKernelNCHWSmall<T, kDirection, kKnownFilterWidth,
656655
kKnownFilterHeight, kBlockDepth,
@@ -1568,14 +1567,14 @@ Status TryLaunchDepthwiseConv2dBackpropFilterGPUSmall(
15681567
case FORMAT_NHWC:
15691568
block_dim = dim3(kBlockDepth, args.in_cols, block_height);
15701569
block_count =
1571-
args.batch * div_ceil(args.out_depth, kBlockDepth) * kBlockDepth;
1570+
args.batch * DivUp(args.out_depth, kBlockDepth) * kBlockDepth;
15721571
kernel = DepthwiseConv2dBackpropFilterGPUKernelNHWCSmall<
15731572
T, kKnownFilterWidth, kKnownFilterHeight, kBlockDepth, kAccumPixels>;
15741573
break;
15751574
case FORMAT_NCHW:
15761575
block_dim = dim3(args.in_cols, block_height, kBlockDepth);
15771576
block_count =
1578-
div_ceil(args.batch * args.out_depth, kBlockDepth) * kBlockDepth;
1577+
DivUp(args.batch * args.out_depth, kBlockDepth) * kBlockDepth;
15791578
kernel = DepthwiseConv2dBackpropFilterGPUKernelNCHWSmall<
15801579
T, kKnownFilterWidth, kKnownFilterHeight, kBlockDepth, kAccumPixels>;
15811580
break;

tensorflow/core/kernels/fill_empty_rows_functor_gpu.cu.cc

Lines changed: 14 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -18,7 +18,6 @@ limitations under the License.
1818
#define EIGEN_USE_GPU
1919

2020
#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive
21-
#include "xla/tsl/platform/statusor.h"
2221
#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
2322
#include "tensorflow/core/framework/register_types.h"
2423
#include "tensorflow/core/framework/tensor_types.h"
@@ -55,8 +54,7 @@ typename T::ConstPointerType to_pointers(const T& x) {
5554
template <typename Tindex, typename... CallerArgs, typename... KernelArgs>
5655
Status wrap_kernel_call(void (*func)(KernelArgs...), const GPUDevice& device,
5756
Tindex size, CallerArgs... args) {
58-
TF_ASSIGN_OR_RETURN(GpuLaunchConfig64 config,
59-
GetGpuLaunchConfig64(size, device));
57+
auto config = GetGpuLaunchConfig(size, device);
6058
return GpuLaunchKernel(func, config.block_count, config.thread_per_block, 0,
6159
device.stream(), config, to_pointers(args)...);
6260
}
@@ -79,10 +77,10 @@ struct CastFunctor {
7977
// true if the indices are not ordered by row.
8078
template <typename Tindex>
8179
__global__ __launch_bounds__(1024) void CountElementsPerRowKernel(
82-
GpuLaunchConfig64 cfg, Tindex dense_rows, int rank, const Tindex* indices,
80+
GpuLaunchConfig cfg, Tindex dense_rows, int rank, const Tindex* indices,
8381
Tindex* elements_per_row, int* rows_are_not_ordered,
8482
int* first_invalid_index) {
85-
for (int64_t i : GpuGridRangeX(cfg.virtual_thread_count)) {
83+
GPU_1D_KERNEL_LOOP(i, cfg.virtual_thread_count) {
8684
Tindex row = indices[i * rank];
8785
if (row < 0 || row >= dense_rows) {
8886
GpuAtomicMin(first_invalid_index, i);
@@ -100,19 +98,18 @@ __global__ __launch_bounds__(1024) void CountElementsPerRowKernel(
10098

10199
template <typename Tindex>
102100
__global__ __launch_bounds__(1024) void CopyRowIndicesKernel(
103-
GpuLaunchConfig64 cfg, int rank, const Tindex* indices,
104-
Tindex* row_indices) {
105-
for (int64_t i : GpuGridRangeX(cfg.virtual_thread_count)) {
101+
GpuLaunchConfig cfg, int rank, const Tindex* indices, Tindex* row_indices) {
102+
GPU_1D_KERNEL_LOOP(i, cfg.virtual_thread_count) {
106103
row_indices[i] = indices[i * rank];
107104
}
108105
}
109106

110107
// Sets empty_row_indicator[row] to whether the row is empty.
111108
template <typename Tindex>
112109
__global__ __launch_bounds__(1024) void ComputeEmptyRowIndicatorKernel(
113-
GpuLaunchConfig64 cfg, const Tindex* elements_per_row,
110+
GpuLaunchConfig cfg, const Tindex* elements_per_row,
114111
bool* empty_row_indicator) {
115-
for (int64_t row : GpuGridRangeX(cfg.virtual_thread_count)) {
112+
GPU_1D_KERNEL_LOOP(row, cfg.virtual_thread_count) {
116113
empty_row_indicator[row] = elements_per_row[row] == 0;
117114
}
118115
}
@@ -122,11 +119,11 @@ __global__ __launch_bounds__(1024) void ComputeEmptyRowIndicatorKernel(
122119
// empty row.
123120
template <typename T, typename Tindex>
124121
__global__ __launch_bounds__(1024) void ScatterInputElementsKernel(
125-
GpuLaunchConfig64 cfg, Tindex dense_rows, int rank,
122+
GpuLaunchConfig cfg, Tindex dense_rows, int rank,
126123
const Tindex* input_index_map, const Tindex* indices, const T* values,
127124
const Tindex* num_new_rows_before, Tindex* output_indices, T* output_values,
128125
Tindex* reverse_index_map) {
129-
for (int64_t i : ::tensorflow::GpuGridRangeX(cfg.virtual_thread_count)) {
126+
GPU_1D_KERNEL_LOOP(i, cfg.virtual_thread_count) {
130127
Tindex input_i = input_index_map ? input_index_map[i] : i;
131128
Tindex row = indices[input_i * rank];
132129
Tindex output_i = i + num_new_rows_before[row];
@@ -144,10 +141,10 @@ __global__ __launch_bounds__(1024) void ScatterInputElementsKernel(
144141
// input) in output_indices and output_values.
145142
template <typename T, typename Tindex>
146143
__global__ __launch_bounds__(1024) void ScatterNewElementsKernel(
147-
GpuLaunchConfig64 cfg, int rank, const T* default_value,
144+
GpuLaunchConfig cfg, int rank, const T* default_value,
148145
const Tindex* num_new_rows_through, const Tindex* input_row_ends,
149146
const bool* empty_row_indicator, Tindex* output_indices, T* output_values) {
150-
for (int64_t row : ::tensorflow::GpuGridRangeX(cfg.virtual_thread_count)) {
147+
GPU_1D_KERNEL_LOOP(row, cfg.virtual_thread_count) {
151148
if (!empty_row_indicator[row]) continue; // Only process empty rows
152149
Tindex input_i = (row == 0 ? 0 : input_row_ends[row - 1]);
153150
Tindex output_i = input_i + (row == 0 ? 0 : num_new_rows_through[row - 1]);
@@ -492,9 +489,9 @@ namespace {
492489

493490
template <typename T, typename Tindex>
494491
__global__ __launch_bounds__(1024) void GatherOriginalGradValuesKernel(
495-
GpuLaunchConfig64 cfg, const Tindex* reverse_index_map,
496-
const T* grad_values, T* d_values, bool* visited, Tindex N_full) {
497-
for (int64_t input_i : GpuGridRangeX(cfg.virtual_thread_count)) {
492+
GpuLaunchConfig cfg, const Tindex* reverse_index_map, const T* grad_values,
493+
T* d_values, bool* visited, Tindex N_full) {
494+
GPU_1D_KERNEL_LOOP(input_i, cfg.virtual_thread_count) {
498495
Tindex output_i = reverse_index_map[input_i];
499496
if (output_i >= 0 && output_i < N_full) {
500497
d_values[input_i] = grad_values[output_i];

tensorflow/core/kernels/gather_functor_gpu.cu.h

Lines changed: 9 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -20,7 +20,6 @@ limitations under the License.
2020

2121
#define EIGEN_USE_GPU
2222

23-
#include "xla/tsl/platform/statusor.h"
2423
#include "tensorflow/core/framework/register_types.h"
2524
#include "tensorflow/core/kernels/gather_functor.h"
2625
#include "tensorflow/core/platform/types.h"
@@ -36,7 +35,7 @@ __global__ void GatherOpKernel(const ValueOrVec* __restrict__ params,
3635
ValueOrVec* __restrict__ out,
3736
int64 gather_dim_size, int64 indices_size,
3837
int64 slice_size, int64 out_size) {
39-
for (int64_t i : GpuGridRangeX(out_size)) {
38+
GPU_1D_KERNEL_LOOP(i, out_size) {
4039
Index batch_i = 0;
4140
Index indices_i = 0;
4241
Index slice_i = 0;
@@ -92,12 +91,9 @@ struct LaunchGatherKernelVectorized {
9291
const Tvec* params_vec = reinterpret_cast<const Tvec*>(params);
9392
Tvec* out_vec = reinterpret_cast<Tvec*>(out);
9493

95-
TF_ASSIGN_OR_RETURN(
96-
GpuLaunchConfig64 config,
97-
GetGpuLaunchConfig64(out_size_vec, d,
98-
&GatherOpKernel<Tvec, Index, is_axis_zero>,
99-
/*dynamic_shared_memory_size=*/0,
100-
/*block_size_limit=*/0));
94+
GpuLaunchConfig config = GetGpuLaunchConfig(
95+
out_size_vec, d, &GatherOpKernel<Tvec, Index, is_axis_zero>,
96+
/*dynamic_shared_memory_size=*/0, /*block_size_limit=*/0);
10197
return GpuLaunchKernel(
10298
GatherOpKernel<Tvec, Index, is_axis_zero>, config.block_count,
10399
config.thread_per_block, 0, d.stream(), params_vec, indices, out_vec,
@@ -146,21 +142,13 @@ struct GatherFunctor<GPUDevice, T, Index> {
146142
const int64 slice_size = params.dimension(2);
147143

148144
if (is_axis_zero) {
149-
Status status = LaunchGatherKernel<true>(
150-
d, params.data(), indices.data(), out.data(), gather_dim_size,
151-
indices_size, slice_size, out_size);
152-
if (!status.ok()) {
153-
ctx->CtxFailure(__FILE__, __LINE__, status);
154-
return -1;
155-
}
145+
TF_CHECK_OK(LaunchGatherKernel<true>(d, params.data(), indices.data(),
146+
out.data(), gather_dim_size,
147+
indices_size, slice_size, out_size));
156148
} else {
157-
Status status = LaunchGatherKernel<false>(
149+
TF_CHECK_OK(LaunchGatherKernel<false>(
158150
d, params.data(), indices.data(), out.data(), gather_dim_size,
159-
indices_size, slice_size, out_size);
160-
if (!status.ok()) {
161-
ctx->CtxFailure(__FILE__, __LINE__, status);
162-
return -1;
163-
}
151+
indices_size, slice_size, out_size));
164152
}
165153
// TODO(fpmc): enable indices validation on GPU.
166154
// Right now checking for indices out of bound in the kernel would

0 commit comments

Comments (0)