@@ -18,7 +18,6 @@ limitations under the License.
1818#define EIGEN_USE_GPU
1919
2020#include " unsupported/Eigen/CXX11/Tensor" // from @eigen_archive
21- #include " xla/tsl/platform/statusor.h"
2221#include " tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
2322#include " tensorflow/core/framework/register_types.h"
2423#include " tensorflow/core/framework/tensor_types.h"
@@ -55,8 +54,7 @@ typename T::ConstPointerType to_pointers(const T& x) {
// Launches `func` on `device` using a 1-D launch configuration sized for
// `size` work elements. Every caller-side argument is passed through
// to_pointers() before being forwarded to the kernel, and the launch config
// itself is forwarded as the kernel's first argument. Returns the launch
// status from GpuLaunchKernel.
//
// NOTE(review): GetGpuLaunchConfig takes the element count as `int`, so a
// Tindex `size` above INT_MAX would be truncated — confirm callers never
// exceed 2^31-1 elements.
template <typename Tindex, typename... CallerArgs, typename... KernelArgs>
Status wrap_kernel_call(void (*func)(KernelArgs...), const GPUDevice& device,
                        Tindex size, CallerArgs... args) {
  const auto config = GetGpuLaunchConfig(size, device);
  return GpuLaunchKernel(func, config.block_count, config.thread_per_block,
                         /*shared_memory_size_bytes=*/0, device.stream(),
                         config, to_pointers(args)...);
}
@@ -79,10 +77,10 @@ struct CastFunctor {
7977// true if the indices are not ordered by row.
8078template <typename Tindex>
8179__global__ __launch_bounds__ (1024 ) void CountElementsPerRowKernel(
82- GpuLaunchConfig64 cfg, Tindex dense_rows, int rank, const Tindex* indices,
80+ GpuLaunchConfig cfg, Tindex dense_rows, int rank, const Tindex* indices,
8381 Tindex* elements_per_row, int * rows_are_not_ordered,
8482 int * first_invalid_index) {
85- for ( int64_t i : GpuGridRangeX ( cfg.virtual_thread_count ) ) {
83+ GPU_1D_KERNEL_LOOP (i, cfg.virtual_thread_count ) {
8684 Tindex row = indices[i * rank];
8785 if (row < 0 || row >= dense_rows) {
8886 GpuAtomicMin (first_invalid_index, i);
@@ -100,19 +98,18 @@ __global__ __launch_bounds__(1024) void CountElementsPerRowKernel(
10098
// Copies the row coordinate of each sparse index into a dense output array:
// row_indices[i] = indices[i * rank]. `indices` is the flattened [N, rank]
// coordinate matrix; the row is its leading component per entry.
template <typename Tindex>
__global__ __launch_bounds__(1024) void CopyRowIndicesKernel(
    GpuLaunchConfig cfg, int rank, const Tindex* indices, Tindex* row_indices) {
  GPU_1D_KERNEL_LOOP(idx, cfg.virtual_thread_count) {
    // Stride by `rank` to pick out the first (row) coordinate of each index.
    row_indices[idx] = indices[idx * rank];
  }
}
109106
110107// Sets empty_row_indicator[row] to whether the row is empty.
// Sets empty_row_indicator[row] to whether the row is empty, i.e. whether
// elements_per_row[row] is zero.
template <typename Tindex>
__global__ __launch_bounds__(1024) void ComputeEmptyRowIndicatorKernel(
    GpuLaunchConfig cfg, const Tindex* elements_per_row,
    bool* empty_row_indicator) {
  GPU_1D_KERNEL_LOOP(r, cfg.virtual_thread_count) {
    // A row holding zero stored elements is flagged as empty.
    empty_row_indicator[r] = (elements_per_row[r] == 0);
  }
}
@@ -122,11 +119,11 @@ __global__ __launch_bounds__(1024) void ComputeEmptyRowIndicatorKernel(
122119// empty row.
123120template <typename T, typename Tindex>
124121__global__ __launch_bounds__ (1024 ) void ScatterInputElementsKernel(
125- GpuLaunchConfig64 cfg, Tindex dense_rows, int rank,
122+ GpuLaunchConfig cfg, Tindex dense_rows, int rank,
126123 const Tindex* input_index_map, const Tindex* indices, const T* values,
127124 const Tindex* num_new_rows_before, Tindex* output_indices, T* output_values,
128125 Tindex* reverse_index_map) {
129- for ( int64_t i : :: tensorflow::GpuGridRangeX ( cfg.virtual_thread_count ) ) {
126+ GPU_1D_KERNEL_LOOP (i, cfg.virtual_thread_count ) {
130127 Tindex input_i = input_index_map ? input_index_map[i] : i;
131128 Tindex row = indices[input_i * rank];
132129 Tindex output_i = i + num_new_rows_before[row];
@@ -144,10 +141,10 @@ __global__ __launch_bounds__(1024) void ScatterInputElementsKernel(
144141// input) in output_indices and output_values.
145142template <typename T, typename Tindex>
146143__global__ __launch_bounds__ (1024 ) void ScatterNewElementsKernel(
147- GpuLaunchConfig64 cfg, int rank, const T* default_value,
144+ GpuLaunchConfig cfg, int rank, const T* default_value,
148145 const Tindex* num_new_rows_through, const Tindex* input_row_ends,
149146 const bool * empty_row_indicator, Tindex* output_indices, T* output_values) {
150- for ( int64_t row : :: tensorflow::GpuGridRangeX ( cfg.virtual_thread_count ) ) {
147+ GPU_1D_KERNEL_LOOP ( row, cfg.virtual_thread_count ) {
151148 if (!empty_row_indicator[row]) continue ; // Only process empty rows
152149 Tindex input_i = (row == 0 ? 0 : input_row_ends[row - 1 ]);
153150 Tindex output_i = input_i + (row == 0 ? 0 : num_new_rows_through[row - 1 ]);
@@ -492,9 +489,9 @@ namespace {
492489
493490template <typename T, typename Tindex>
494491__global__ __launch_bounds__ (1024 ) void GatherOriginalGradValuesKernel(
495- GpuLaunchConfig64 cfg, const Tindex* reverse_index_map,
496- const T* grad_values, T* d_values, bool * visited, Tindex N_full) {
497- for ( int64_t input_i : GpuGridRangeX ( cfg.virtual_thread_count ) ) {
492+ GpuLaunchConfig cfg, const Tindex* reverse_index_map, const T* grad_values ,
493+ T* d_values, bool * visited, Tindex N_full) {
494+ GPU_1D_KERNEL_LOOP ( input_i, cfg.virtual_thread_count ) {
498495 Tindex output_i = reverse_index_map[input_i];
499496 if (output_i >= 0 && output_i < N_full) {
500497 d_values[input_i] = grad_values[output_i];
0 commit comments