@@ -18,6 +18,7 @@ limitations under the License.
1818#define EIGEN_USE_GPU
1919
2020#include " unsupported/Eigen/CXX11/Tensor" // from @eigen_archive
21+ #include " xla/tsl/platform/statusor.h"
2122#include " tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
2223#include " tensorflow/core/framework/register_types.h"
2324#include " tensorflow/core/framework/tensor_types.h"
@@ -54,7 +55,8 @@ typename T::ConstPointerType to_pointers(const T& x) {
5455template <typename Tindex, typename ... CallerArgs, typename ... KernelArgs>
5556Status wrap_kernel_call (void (*func)(KernelArgs...), const GPUDevice& device,
5657 Tindex size, CallerArgs... args) {
57- auto config = GetGpuLaunchConfig (size, device);
58+ TF_ASSIGN_OR_RETURN (GpuLaunchConfig64 config,
59+ GetGpuLaunchConfig64 (size, device));
5860 return GpuLaunchKernel (func, config.block_count , config.thread_per_block , 0 ,
5961 device.stream (), config, to_pointers (args)...);
6062}
@@ -77,10 +79,10 @@ struct CastFunctor {
7779// true if the indices are not ordered by row.
7880template <typename Tindex>
7981__global__ __launch_bounds__ (1024 ) void CountElementsPerRowKernel(
80- GpuLaunchConfig cfg, Tindex dense_rows, int rank, const Tindex* indices,
82+ GpuLaunchConfig64 cfg, Tindex dense_rows, int rank, const Tindex* indices,
8183 Tindex* elements_per_row, int * rows_are_not_ordered,
8284 int * first_invalid_index) {
83- GPU_1D_KERNEL_LOOP (i, cfg.virtual_thread_count ) {
85+ for ( int64_t i : GpuGridRangeX ( cfg.virtual_thread_count ) ) {
8486 Tindex row = indices[i * rank];
8587 if (row < 0 || row >= dense_rows) {
8688 GpuAtomicMin (first_invalid_index, i);
@@ -98,18 +100,19 @@ __global__ __launch_bounds__(1024) void CountElementsPerRowKernel(
98100
99101template <typename Tindex>
100102__global__ __launch_bounds__ (1024 ) void CopyRowIndicesKernel(
101- GpuLaunchConfig cfg, int rank, const Tindex* indices, Tindex* row_indices) {
102- GPU_1D_KERNEL_LOOP (i, cfg.virtual_thread_count ) {
103+ GpuLaunchConfig64 cfg, int rank, const Tindex* indices,
104+ Tindex* row_indices) {
105+ for (int64_t i : GpuGridRangeX (cfg.virtual_thread_count )) {
103106 row_indices[i] = indices[i * rank];
104107 }
105108}
106109
107110// Sets empty_row_indicator[row] to whether the row is empty.
108111template <typename Tindex>
109112__global__ __launch_bounds__ (1024 ) void ComputeEmptyRowIndicatorKernel(
110- GpuLaunchConfig cfg, const Tindex* elements_per_row,
113+ GpuLaunchConfig64 cfg, const Tindex* elements_per_row,
111114 bool * empty_row_indicator) {
112- GPU_1D_KERNEL_LOOP ( row, cfg.virtual_thread_count ) {
115+ for ( int64_t row : GpuGridRangeX ( cfg.virtual_thread_count ) ) {
113116 empty_row_indicator[row] = elements_per_row[row] == 0 ;
114117 }
115118}
@@ -119,11 +122,11 @@ __global__ __launch_bounds__(1024) void ComputeEmptyRowIndicatorKernel(
119122// empty row.
120123template <typename T, typename Tindex>
121124__global__ __launch_bounds__ (1024 ) void ScatterInputElementsKernel(
122- GpuLaunchConfig cfg, Tindex dense_rows, int rank,
125+ GpuLaunchConfig64 cfg, Tindex dense_rows, int rank,
123126 const Tindex* input_index_map, const Tindex* indices, const T* values,
124127 const Tindex* num_new_rows_before, Tindex* output_indices, T* output_values,
125128 Tindex* reverse_index_map) {
126- GPU_1D_KERNEL_LOOP (i, cfg.virtual_thread_count ) {
129+ for ( int64_t i : :: tensorflow::GpuGridRangeX ( cfg.virtual_thread_count ) ) {
127130 Tindex input_i = input_index_map ? input_index_map[i] : i;
128131 Tindex row = indices[input_i * rank];
129132 Tindex output_i = i + num_new_rows_before[row];
@@ -141,10 +144,10 @@ __global__ __launch_bounds__(1024) void ScatterInputElementsKernel(
141144// input) in output_indices and output_values.
142145template <typename T, typename Tindex>
143146__global__ __launch_bounds__ (1024 ) void ScatterNewElementsKernel(
144- GpuLaunchConfig cfg, int rank, const T* default_value,
147+ GpuLaunchConfig64 cfg, int rank, const T* default_value,
145148 const Tindex* num_new_rows_through, const Tindex* input_row_ends,
146149 const bool * empty_row_indicator, Tindex* output_indices, T* output_values) {
147- GPU_1D_KERNEL_LOOP ( row, cfg.virtual_thread_count ) {
150+ for ( int64_t row : :: tensorflow::GpuGridRangeX ( cfg.virtual_thread_count ) ) {
148151 if (!empty_row_indicator[row]) continue ; // Only process empty rows
149152 Tindex input_i = (row == 0 ? 0 : input_row_ends[row - 1 ]);
150153 Tindex output_i = input_i + (row == 0 ? 0 : num_new_rows_through[row - 1 ]);
@@ -489,9 +492,9 @@ namespace {
489492
490493template <typename T, typename Tindex>
491494__global__ __launch_bounds__ (1024 ) void GatherOriginalGradValuesKernel(
492- GpuLaunchConfig cfg, const Tindex* reverse_index_map, const T* grad_values ,
493- T* d_values, bool * visited, Tindex N_full) {
494- GPU_1D_KERNEL_LOOP ( input_i, cfg.virtual_thread_count ) {
495+ GpuLaunchConfig64 cfg, const Tindex* reverse_index_map,
496+ const T* grad_values, T* d_values, bool * visited, Tindex N_full) {
497+ for ( int64_t input_i : GpuGridRangeX ( cfg.virtual_thread_count ) ) {
495498 Tindex output_i = reverse_index_map[input_i];
496499 if (output_i >= 0 && output_i < N_full) {
497500 d_values[input_i] = grad_values[output_i];
0 commit comments