diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index ccd54f26668a3a..1d24a509c2edf6 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -1,5 +1,5 @@ load("@xla//third_party/rules_python/python:py_library.bzl", "py_library") -load("//tensorflow:strict.default.bzl", "py_strict_library", "py_strict_test") +load("//tensorflow:tensorflow.bzl", "py_test") load("//tensorflow:tensorflow.default.bzl", "cuda_py_strict_test", "tf_cuda_cc_test") load("//tensorflow/compiler/aot:tfcompile.bzl", "tf_library") load("//tensorflow/compiler/tests:build_combined_defs.bzl", "tf_xla_combined_py_test") @@ -44,7 +44,7 @@ package_group( generate_backend_suites() -py_strict_library( +py_library( name = "xla_test", testonly = 1, srcs = ["xla_test.py"], @@ -82,7 +82,7 @@ py_library( ], ) -py_strict_test( +py_test( name = "xla_test_test", size = "small", srcs = ["xla_test_test.py"], diff --git a/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc b/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc index df982a41602d9b..b849caf76638d4 100644 --- a/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc +++ b/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc @@ -48,7 +48,9 @@ limitations under the License. #include "tensorflow/core/distributed_runtime/worker_cache.h" #include "tensorflow/core/distributed_runtime/worker_env.h" #include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/rendezvous.h" +#include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/nccl/collective_communicator.h" #include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/host_info.h" @@ -66,62 +68,15 @@ absl::Status GetNumRetvals( FunctionLibraryDefinition* func_lib_def, const std::string& op_name, const google::protobuf::Map& attrs, int* num_retvals) { - const tensorflow::OpRegistrationData* op_reg_data = nullptr; - auto status = tensorflow::OpRegistry::Global()->LookUp(op_name, &op_reg_data); - if (absl::IsNotFound(status)) { - status = func_lib_def->LookUp(op_name, &op_reg_data); - } - TF_RETURN_IF_ERROR(status); + const OpDef* op_def = nullptr; + TF_RETURN_IF_ERROR(func_lib_def->LookUpOpDef(op_name, &op_def)); - const tensorflow::OpDef& op_def = op_reg_data->op_def; + NodeDef ndef; + ndef.set_op(op_name); + *ndef.mutable_attr() = attrs; + AddDefaultsToNodeDef(*op_def, &ndef); - for (const auto& output_arg : op_def.output_arg()) { - if (!output_arg.number_attr().empty()) { - auto iter = attrs.find(output_arg.number_attr()); - if (iter == attrs.end()) { - return absl::InvalidArgumentError( - absl::StrCat("Unable to find number_attr ", - output_arg.number_attr(), " for Op: ", op_name)); - } - int64_t repeats = iter->second.i(); - if (repeats < 0) { - return absl::InvalidArgumentError( - absl::StrCat("Expected >= 0 number_attr for Op: ", op_name, - ", but got ", repeats)); - } - if (repeats > std::numeric_limits::max() - *num_retvals) { - return absl::InvalidArgumentError( - absl::StrCat("Too many return values for Op: ", op_name)); - } - *num_retvals += repeats; - } else if (!output_arg.type_list_attr().empty()) { - auto iter = attrs.find(output_arg.type_list_attr()); - if (iter == attrs.end()) { - return absl::InvalidArgumentError( - absl::StrCat("Unable to find type_list_attr ", - output_arg.type_list_attr(), " for Op: ", op_name)); - } - int64_t repeats = iter->second.list().type_size(); - if (repeats < 0) { - return absl::InvalidArgumentError( - absl::StrCat("Expected >= 0 type_list_attr size for Op: ", op_name, - ", but got ", repeats)); - } - if (repeats > std::numeric_limits::max() - *num_retvals) { - return absl::InvalidArgumentError( - absl::StrCat("Too many return values for Op: ", op_name)); - } - *num_retvals += repeats; - } else { - if (*num_retvals >= std::numeric_limits::max()) { - return absl::InvalidArgumentError( - absl::StrCat("Too many return values for Op: ", op_name)); - } - *num_retvals += 1; - } - } - - return absl::OkStatus(); + return NumOutputsForNode(ndef, *op_def, num_retvals); } absl::Status GetEagerOperationAndNumRetvals(const Operation& operation, @@ -904,6 +859,13 @@ absl::Status EagerServiceImpl::SendPackedHandle( std::vector handles; handles.resize(send_packed_handle.handles_size()); + // Cleanup handles in case of early exit due to errors. + auto cleanup = tensorflow::gtl::MakeCleanup([&handles] { + for (auto* h : handles) { + if (h) h->Unref(); + } + }); + for (int i = 0; i < send_packed_handle.handles_size(); ++i) { const auto& item = send_packed_handle.handles(i); if (item.has_local_handle()) { @@ -914,24 +876,43 @@ absl::Status EagerServiceImpl::SendPackedHandle( item.local_handle().tensor().DebugString())); } Device* op_device = nullptr; - TF_RETURN_IF_ERROR(eager_context->FindDeviceFromName( - item.local_handle().device().c_str(), &op_device)); + absl::Status status = eager_context->FindDeviceFromName( + item.local_handle().device().c_str(), &op_device); + if (!status.ok()) { + return status; + } handles[i] = TensorHandle::CreateLocalHandle( std::move(tensor), /*d=*/nullptr, op_device, eager_context); } else { - TF_RETURN_IF_ERROR( + absl::Status status = eager_context->RemoteMgr()->DeserializeRemoteTensorHandle( - item.remote_handle(), &handles[i])); + item.remote_handle(), &handles[i]); + if (!status.ok()) { + return status; + } + } + } + + tensorflow::DataType dtype = handles.at(0)->dtype; + for (int i = 1; i < handles.size(); ++i) { + if (handles.at(i)->dtype != dtype) { + return absl::InvalidArgumentError("Handles do not have the same dtype."); } } tensorflow::TensorHandle* packed_handle = nullptr; std::vector handles_to_pack = handles; // Create a unshaped packed TensorHandle. - TF_RETURN_IF_ERROR(TensorHandle::CreatePackedHandle( + absl::Status s = TensorHandle::CreatePackedHandle( std::move(handles_to_pack), handles.at(0)->dtype, TensorShape(), - send_packed_handle.device_name(), eager_context, &packed_handle)); + send_packed_handle.device_name(), eager_context, &packed_handle); + if (!s.ok()) { + return s; + } + // Cancel the cleanup for the individual handles, as they are now refcounted + // by `packed_handle`. + cleanup.release(); for (auto* h : handles) { // Unref handle since it has a ref in the packed handle now. h->Unref(); diff --git a/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc b/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc index e9be274d4fea19..3c2dd7e368d231 100644 --- a/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc +++ b/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc @@ -1404,6 +1404,7 @@ TEST_F(EagerServiceImplTest, SendPackedHandleTest) { remote_handle->set_output_num(5); remote_handle->set_op_device(device2); remote_handle->set_device(device2); + remote_handle->set_dtype(tensorflow::DataType::DT_FLOAT); TF_ASSERT_OK(eager_service_impl.Enqueue(nullptr, &remote_enqueue_request, &remote_enqueue_response)); diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index aee63513314c19..1caa74b7cd87cc 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -549,6 +549,12 @@ cc_library( cc_library( name = "pooling_ops_gpu_hdrs", hdrs = ["maxpooling_op_gpu.h"], + deps = [ + "//tensorflow/core:portable_gif_internal", + "//tensorflow/core/framework:tensor_types", + "//tensorflow/core/util:tensor_format", + "@com_google_absl//absl/status", + ], ) # We keep this target only because some contrib/ targets depend on it. The @@ -4350,6 +4356,8 @@ tf_kernel_library( ":redux_functor", "//tensorflow/core/profiler/lib:scoped_annotation", "//tensorflow/core/util:determinism_for_kernels", + "//tensorflow/core/util:overflow", + "@com_google_absl//absl/status", ] + if_cuda_or_rocm([ ":reduction_ops", "@xla//xla/stream_executor:event_based_timer", @@ -4642,8 +4650,13 @@ tf_kernel_library( "//tensorflow/core:lib_internal", "//tensorflow/core/framework:bounds_check", "//tensorflow/core/platform:stream_executor", + "//tensorflow/core/util:overflow", "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", "@eigen_archive//:eigen3", + "@local_config_cuda//cuda:cudnn_header", + "@xla//xla/tsl/framework/fixedpoint", ], ) @@ -4675,7 +4688,19 @@ cc_library( deps = [ ":eigen_helpers", ":ops_util_hdrs", + ":pooling_ops_gpu_hdrs", + "//tensorflow/core:framework", + "//tensorflow/core:framework_lite", + "//tensorflow/core:portable_gif_internal", + "//tensorflow/core/framework:bounds_check", + "//tensorflow/core/framework:numeric_types", + "//tensorflow/core/framework:tensor_shape", + "//tensorflow/core/framework:tensor_types", + "//tensorflow/core/util:padding", + "//tensorflow/core/util:tensor_format", + "@com_google_absl//absl/status", "@eigen_archive//:eigen3", + "@xla//xla/tsl/framework/fixedpoint", ], ) @@ -4753,6 +4778,8 @@ tf_kernel_library( deps = [ "//tensorflow/core:framework", "//tensorflow/core:lib", + "//tensorflow/core/util:overflow", + "@com_google_absl//absl/status", "@eigen_archive//:eigen3", ], ) @@ -6661,7 +6688,6 @@ filegroup( "transpose_op.h", "where_op.h", "xent_op.h", - ] + [ "//tensorflow/core/kernels/data:portable_all_op_kernels_headers", "//tensorflow/core/kernels/image:adjust_contrast_op.h", "//tensorflow/core/kernels/image:adjust_hue_op.h", @@ -6810,9 +6836,6 @@ filegroup( "population_count_op.h", "winograd_transform.h", ":portable_extended_ops_headers", - "@xla//xla/tsl/framework/contraction:eigen_contraction_kernel.cc", - "@xla//xla/tsl/framework/contraction:eigen_contraction_kernel.h", - ] + [ "//tensorflow/core/kernels/image:colorspace_op.cc", "//tensorflow/core/kernels/image:crop_and_resize_op.cc", "//tensorflow/core/kernels/image:crop_and_resize_op.h", @@ -6830,6 +6853,8 @@ filegroup( "//tensorflow/core/kernels/linalg:einsum_op_impl_int32.cc", "//tensorflow/core/kernels/linalg:einsum_op_impl_int64.cc", "//tensorflow/core/kernels/uniform_quant_ops:portable_all_op_kernels", + "@xla//xla/tsl/framework/contraction:eigen_contraction_kernel.cc", + "@xla//xla/tsl/framework/contraction:eigen_contraction_kernel.h", ], ) @@ -7201,6 +7226,7 @@ cc_library( linkopts = if_android(["-ldl"]), tags = [ "manual", + "nofixdeps", "notap", ], # These headers are not self-contained, so should be included in textual_hdrs only. @@ -8023,8 +8049,8 @@ cc_library( ) # For a more maintainable build this target should not exist and the headers -# should be split into the existing cc_library targets, but this change was -# automatically done so that we can remove long standing issues and complexity +# should be split into the existing cc_library targets, but this change was +# automatically done so that we can remove long standing issues and complexity # in the build system. It's up to the OWNERS of this package to get rid of it or # not. The use of the textual_hdrs attribute is discouraged, use hdrs instead. # Here it is used to avoid header parsing errors in packages where the feature @@ -8035,16 +8061,14 @@ cc_library( tags = ["avoid_dep"], textual_hdrs = glob(["*.h"]), visibility = [ - "//visibility:public", - ], - deps = [ - "//tensorflow/core/framework:graph_proto_cc", - "//tensorflow/core/framework:node_def_proto_cc", - "//tensorflow/core/framework:types_proto_cc", - "@com_google_absl//absl/synchronization", + "//smartass/brain:__subpackages__", + "//tensorflow:__subpackages__", ], ) +# Deleted deps +# [ + tf_kernel_library( name = "stochastic_cast_op", features = ["-layering_check"], diff --git a/tensorflow/core/kernels/avgpooling_op.cc b/tensorflow/core/kernels/avgpooling_op.cc index 1ba5683142edc7..627e4cbad49236 100644 --- a/tensorflow/core/kernels/avgpooling_op.cc +++ b/tensorflow/core/kernels/avgpooling_op.cc @@ -623,21 +623,23 @@ class AvgPoolingGradOpCustomGPUKernel : public OpKernel { in_cols, window_cols, /*dilation_rate=*/1, col_stride, padding_, &out_width, &pad_cols)); - RunAvePoolBackwardNHWC(out_backprop.flat().data(), // top_diff - out_backprop_batch, // num - in_rows, // height - in_cols, // width - out_backprop_depth, // channels - out_backprop_rows, // pooled_height - out_backprop_cols, // pooled_width - window_rows, // kernel_h - window_cols, // kernel_w - row_stride, // stride_h - col_stride, // stride_w - pad_rows, // pad_t - pad_cols, // pad_l - output->flat().data(), // bottom_diff - context->eigen_gpu_device()); // d + OP_REQUIRES_OK( + context, + RunAvePoolBackwardNHWC(out_backprop.flat().data(), // top_diff + out_backprop_batch, // num + in_rows, // height + in_cols, // width + out_backprop_depth, // channels + out_backprop_rows, // pooled_height + out_backprop_cols, // pooled_width + window_rows, // kernel_h + window_cols, // kernel_w + row_stride, // stride_h + col_stride, // stride_w + pad_rows, // pad_t + pad_cols, // pad_l + output->flat().data(), // bottom_diff + context->eigen_gpu_device())); // d } else { DnnPoolingGradOp::Compute(context, se::dnn::PoolingMode::kAverage, ksize_, stride_, padding_, diff --git a/tensorflow/core/kernels/avgpooling_op.h b/tensorflow/core/kernels/avgpooling_op.h index 8008c3c43bc0b1..07b8e734089a2c 100644 --- a/tensorflow/core/kernels/avgpooling_op.h +++ b/tensorflow/core/kernels/avgpooling_op.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_CORE_KERNELS_AVGPOOLING_OP_H_ // Functor definition for AvgPoolingOp, must be compilable by nvcc. +#include "absl/status/status.h" #include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/kernels/eigen_pooling.h" #include "tensorflow/core/platform/types.h" @@ -62,14 +63,12 @@ typedef Eigen::GpuDevice GPUDevice; // pad_l: padding size to the left side // bottom_diff: backprop to the input of the pooling layer. template -bool RunAvePoolBackwardNHWC(const T* const top_diff, const int num, - const int height, const int width, - const int channels, const int pooled_height, - const int pooled_width, const int kernel_h, - const int kernel_w, const int stride_h, - const int stride_w, const int pad_t, - const int pad_l, T* const bottom_diff, - const GPUDevice& d); +absl::Status RunAvePoolBackwardNHWC(const T* top_diff, int num, int height, + int width, int channels, int pooled_height, + int pooled_width, int kernel_h, + int kernel_w, int stride_h, int stride_w, + int pad_t, int pad_l, T* bottom_diff, + const GPUDevice& d); } // namespace tensorflow diff --git a/tensorflow/core/kernels/avgpooling_op_gpu.cu.cc b/tensorflow/core/kernels/avgpooling_op_gpu.cu.cc index 59d79a1ed7876a..17b74a29ddf275 100644 --- a/tensorflow/core/kernels/avgpooling_op_gpu.cu.cc +++ b/tensorflow/core/kernels/avgpooling_op_gpu.cu.cc @@ -20,11 +20,13 @@ limitations under the License. #include #include +#include #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/kernels/avgpooling_op.h" #include "tensorflow/core/util/gpu_kernel_helper.h" +#include "tensorflow/core/util/overflow.h" namespace tensorflow { @@ -80,15 +82,23 @@ __global__ void AvePoolBackwardNHWC( } template -bool RunAvePoolBackwardNHWC(const T* const top_diff, const int num, - const int height, const int width, - const int channels, const int pooled_height, - const int pooled_width, const int kernel_h, - const int kernel_w, const int stride_h, - const int stride_w, const int pad_t, - const int pad_l, T* const bottom_diff, - const GPUDevice& d) { - int x_size = num * height * width * channels; +absl::Status RunAvePoolBackwardNHWC(const T* const top_diff, const int num, + const int height, const int width, + const int channels, const int pooled_height, + const int pooled_width, const int kernel_h, + const int kernel_w, const int stride_h, + const int stride_w, const int pad_t, + const int pad_l, T* const bottom_diff, + const GPUDevice& d) { + int64_t size_1 = MultiplyWithoutOverflow(num, height); + int64_t size_2 = MultiplyWithoutOverflow(size_1, width); + int64_t x_size = MultiplyWithoutOverflow(size_2, channels); + if (x_size < 0 || x_size > std::numeric_limits::max()) { + return absl::InternalError( + "RunAvePoolBackwardNHWC: num * height * width * channels exceeds " + "int32 bounds"); + } + GpuLaunchConfig config = GetGpuLaunchConfig(x_size, d); TF_CHECK_OK(GpuLaunchKernel( AvePoolBackwardNHWC, config.block_count, config.thread_per_block, 0, @@ -96,11 +106,12 @@ bool RunAvePoolBackwardNHWC(const T* const top_diff, const int num, channels, pooled_height, pooled_width, kernel_h, kernel_w, stride_h, stride_w, pad_t, pad_t, bottom_diff)); - return d.ok(); + return d.ok() ? absl::OkStatus() + : absl::InternalError("GPU execution failed"); } #define DECLARE_GPU_SPEC(T) \ - template bool RunAvePoolBackwardNHWC( \ + template absl::Status RunAvePoolBackwardNHWC( \ const T* const top_diff, const int num, const int height, \ const int width, const int channels, const int pooled_height, \ const int pooled_width, const int kernel_h, const int kernel_w, \ diff --git a/tensorflow/core/kernels/bias_op.cc b/tensorflow/core/kernels/bias_op.cc index 16f9b1574df8c1..4fc49d741492d4 100644 --- a/tensorflow/core/kernels/bias_op.cc +++ b/tensorflow/core/kernels/bias_op.cc @@ -270,10 +270,11 @@ class BiasOp : public BinaryOp { OP_REQUIRES_OK(context, context->forward_input_or_allocate_output( {0}, 0, input.shape(), &output)); if (input.NumElements() > 0) { - BiasGPU::compute(context->template eigen_device(), - input.flat().data(), bias.flat().data(), - output->flat().data(), batch, width, height, depth, - channel, data_format_); + OP_REQUIRES_OK(context, BiasGPU::compute( + context->template eigen_device(), + input.flat().data(), bias.flat().data(), + output->flat().data(), batch, width, + height, depth, channel, data_format_)); } } @@ -399,10 +400,11 @@ class BiasGradOp : public OpKernel { const Tensor& output_backprop, int32_t batch, int32_t width, int32_t height, int32_t depth, int32_t channel, Tensor* output) { - BiasGradGPU::compute(context->template eigen_device(), - output_backprop.template flat().data(), - output->flat().data(), batch, width, height, - depth, channel, data_format_); + OP_REQUIRES_OK(context, BiasGradGPU::compute( + context->template eigen_device(), + output_backprop.template flat().data(), + output->flat().data(), batch, width, height, + depth, channel, data_format_)); } void ComputeWithReduceSum(OpKernelContext* context, diff --git a/tensorflow/core/kernels/bias_op_gpu.cu.cc b/tensorflow/core/kernels/bias_op_gpu.cu.cc index 4a90e24b2cb17c..f754e5e70df33c 100644 --- a/tensorflow/core/kernels/bias_op_gpu.cu.cc +++ b/tensorflow/core/kernels/bias_op_gpu.cu.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/core/kernels/bias_op_gpu.h" #include +#include #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" @@ -30,6 +31,7 @@ limitations under the License. #include "tensorflow/core/kernels/reduction_ops_common.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/util/gpu_kernel_helper.h" +#include "tensorflow/core/util/overflow.h" namespace tensorflow { @@ -81,15 +83,28 @@ __global__ void BiasNCHWKernel(int32_t nthreads, const T* __restrict__ input, // Add "bias" to "input", broadcasting it on all dimensions but the bias // dimension. template -void BiasGPU::compute(const GPUDevice& d, const T* input, const T* bias, - T* output, int32_t batch, int32_t height, - int32_t width, int depth, int32_t channel, - TensorFormat data_format) { +absl::Status BiasGPU::compute(const GPUDevice& d, const T* input, + const T* bias, T* output, int32_t batch, + int32_t height, int32_t width, int depth, + int32_t channel, TensorFormat data_format) { const int32_t bias_size = channel; - const int32_t image_size = height * width * depth; - const int32_t total_count = batch * bias_size * image_size; + int64_t image_size_64 = + MultiplyWithoutOverflow(MultiplyWithoutOverflow(height, width), depth); + int64_t total_count_64 = MultiplyWithoutOverflow( + MultiplyWithoutOverflow(batch, bias_size), image_size_64); + + if (total_count_64 < 0 || + total_count_64 > std::numeric_limits::max() || + image_size_64 < 0 || + image_size_64 > std::numeric_limits::max()) { + return absl::InternalError("BiasGPU: dimensions exceed int32 bounds"); + } + + const int32_t image_size = image_size_64; + const int32_t total_count = total_count_64; + if (total_count == 0) { - return; + return absl::OkStatus(); } if (data_format == FORMAT_NHWC) { GpuLaunchConfig config = @@ -106,6 +121,7 @@ void BiasGPU::compute(const GPUDevice& d, const T* input, const T* bias, config.virtual_thread_count, input, bias, output, bias_size, image_size)); } + return absl::OkStatus(); } // A naive implementation that is functional on all cases. @@ -219,15 +235,30 @@ __global__ void BiasGradNCHW_SharedAtomics( } template -void BiasGradGPU::compute(const GPUDevice& d, const T* output_backprop, - T* bias_backprop, int32_t batch, int32_t height, - int32_t width, int32_t depth, int32_t channel, - TensorFormat data_format) { +absl::Status BiasGradGPU::compute(const GPUDevice& d, + const T* output_backprop, T* bias_backprop, + int32_t batch, int32_t height, + int32_t width, int32_t depth, + int32_t channel, + TensorFormat data_format) { const int32_t bias_size = channel; - const int32_t image_size = height * width * depth; - const int32_t total_count = batch * bias_size * image_size; + int64_t image_size_64 = + MultiplyWithoutOverflow(MultiplyWithoutOverflow(height, width), depth); + int64_t total_count_64 = MultiplyWithoutOverflow( + MultiplyWithoutOverflow(batch, bias_size), image_size_64); + + if (total_count_64 < 0 || + total_count_64 > std::numeric_limits::max() || + image_size_64 < 0 || + image_size_64 > std::numeric_limits::max()) { + return absl::InternalError("BiasGradGPU: dimensions exceed int32 bounds"); + } + + const int32_t image_size = image_size_64; + const int32_t total_count = total_count_64; + if (total_count == 0) { - return; + return absl::OkStatus(); } static constexpr int32_t kWarpSize = 32; GpuLaunchConfig config = GetGpuLaunchConfig(total_count, d); @@ -271,6 +302,7 @@ void BiasGradGPU::compute(const GPUDevice& d, const T* output_backprop, bias_size, image_size)); } } + return absl::OkStatus(); } template diff --git a/tensorflow/core/kernels/bias_op_gpu.h b/tensorflow/core/kernels/bias_op_gpu.h index 60f17e6de240de..b07a9a911a5498 100644 --- a/tensorflow/core/kernels/bias_op_gpu.h +++ b/tensorflow/core/kernels/bias_op_gpu.h @@ -18,6 +18,7 @@ limitations under the License. #define EIGEN_USE_GPU +#include "absl/status/status.h" #include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor_types.h" @@ -30,17 +31,18 @@ typedef Eigen::GpuDevice GPUDevice; template struct BiasGPU { - static void compute(const GPUDevice& d, const T* input, const T* bias, - T* output, int32_t batch, int32_t height, int32_t width, - int32_t depth, int32_t channel, TensorFormat data_format); + static absl::Status compute(const GPUDevice& d, const T* input, const T* bias, + T* output, int32_t batch, int32_t height, + int32_t width, int32_t depth, int32_t channel, + TensorFormat data_format); }; template struct BiasGradGPU { - static void compute(const GPUDevice& device, const T* output_backprop, - T* bias_backprop, int32_t batch, int32_t height, - int32_t width, int32_t depth, int32_t channel, - TensorFormat data_format); + static absl::Status compute(const GPUDevice& device, const T* output_backprop, + T* bias_backprop, int32_t batch, int32_t height, + int32_t width, int32_t depth, int32_t channel, + TensorFormat data_format); static void DoRowReduction(OpKernelContext* context, T* output, const T* input, int rows, int cols); diff --git a/tensorflow/core/kernels/depthtospace_op.cc b/tensorflow/core/kernels/depthtospace_op.cc index 09f92bb964c82d..06397c867bcd4e 100644 --- a/tensorflow/core/kernels/depthtospace_op.cc +++ b/tensorflow/core/kernels/depthtospace_op.cc @@ -25,6 +25,7 @@ limitations under the License. #include #include +#include "absl/status/status.h" #include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_kernel.h" @@ -126,12 +127,13 @@ class DepthToSpaceOp : public OpKernel { auto Toutput_v = outputs_tensor->reinterpret_last_dimension(); functor::DepthToSpaceOpFunctor functor; - functor(context->eigen_device(), Tinput_v, block_size_, - Toutput_v); + OP_REQUIRES_OK(context, functor(context->eigen_device(), + Tinput_v, block_size_, Toutput_v)); return; } else if (data_format_ == FORMAT_NCHW) { functor::DepthToSpaceOpFunctor functor; - functor(context->eigen_device(), Tinput, block_size_, Toutput); + OP_REQUIRES_OK(context, functor(context->eigen_device(), Tinput, + block_size_, Toutput)); return; } } @@ -141,7 +143,8 @@ class DepthToSpaceOp : public OpKernel { if (!is_int8x4) { functor::DepthToSpaceOpFunctor functor; - functor(context->eigen_device(), Tinput, block_size_, Toutput); + OP_REQUIRES_OK(context, functor(context->eigen_device(), Tinput, + block_size_, Toutput)); } }; @@ -155,8 +158,10 @@ class DepthToSpaceOp : public OpKernel { namespace functor { template struct DepthToSpaceOpFunctor { - void operator()(const CPUDevice& d, typename TTypes::ConstTensor input, - int block_size, typename TTypes::Tensor output) { + absl::Status operator()(const CPUDevice& d, + typename TTypes::ConstTensor input, + int block_size, + typename TTypes::Tensor output) { const int batch_size = output.dimension(0); const int output_height = output.dimension(1); const int output_width = output.dimension(2); @@ -178,6 +183,7 @@ struct DepthToSpaceOpFunctor { } } } + return absl::OkStatus(); } }; } // namespace functor diff --git a/tensorflow/core/kernels/depthtospace_op.h b/tensorflow/core/kernels/depthtospace_op.h index 63dba5d0d5fc7c..4cadac810b14c0 100644 --- a/tensorflow/core/kernels/depthtospace_op.h +++ b/tensorflow/core/kernels/depthtospace_op.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_CORE_KERNELS_DEPTHTOSPACE_OP_H_ #define TENSORFLOW_CORE_KERNELS_DEPTHTOSPACE_OP_H_ +#include "absl/status/status.h" #include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive #include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/util/tensor_format.h" @@ -42,12 +43,14 @@ namespace functor { // n,iY,bY,iX,bX,oC template struct DepthToSpaceOpFunctor { - void operator()(const Device& d, typename TTypes::ConstTensor input, - int block_size, typename TTypes::Tensor output); + absl::Status operator()(const Device& d, + typename TTypes::ConstTensor input, + int block_size, typename TTypes::Tensor output); // This 5-D version is to support NCHW_VECT_C. - void operator()(const Device& d, typename TTypes::ConstTensor input, - int block_size, typename TTypes::Tensor output); + absl::Status operator()(const Device& d, + typename TTypes::ConstTensor input, + int block_size, typename TTypes::Tensor output); }; } // namespace functor diff --git a/tensorflow/core/kernels/depthtospace_op_gpu.cu.cc b/tensorflow/core/kernels/depthtospace_op_gpu.cu.cc index 651167dbed1463..e0a8bdc27935fc 100644 --- a/tensorflow/core/kernels/depthtospace_op_gpu.cu.cc +++ b/tensorflow/core/kernels/depthtospace_op_gpu.cu.cc @@ -17,10 +17,13 @@ limitations under the License. #define EIGEN_USE_GPU +#include + #include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/kernels/depthtospace_op.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/util/gpu_kernel_helper.h" +#include "tensorflow/core/util/overflow.h" namespace tensorflow { namespace { @@ -145,8 +148,10 @@ namespace functor { template struct DepthToSpaceOpFunctor { - void operator()(const GPUDevice& d, typename TTypes::ConstTensor input, - int block_size, typename TTypes::Tensor output) { + absl::Status operator()(const GPUDevice& d, + typename TTypes::ConstTensor input, + int block_size, + typename TTypes::Tensor output) { const int batch_size = output.dimension(0); const int input_height = input.dimension(1); const int input_width = input.dimension(2); @@ -155,10 +160,19 @@ struct DepthToSpaceOpFunctor { const int output_width = output.dimension(2); const int output_depth = output.dimension(3); - const int total_count = - batch_size * output_height * output_width * output_depth; + int64_t total_count_64 = MultiplyWithoutOverflow( + MultiplyWithoutOverflow( + MultiplyWithoutOverflow(batch_size, output_height), output_width), + output_depth); + if (total_count_64 < 0 || + total_count_64 > std::numeric_limits::max()) { + return absl::InternalError( + "DepthToSpaceOpFunctor NHWC: total_count exceeds int32 bounds"); + } + const int total_count = total_count_64; + if (total_count == 0) { - return; + return absl::OkStatus(); } GpuLaunchConfig config = GetGpuLaunchConfig(total_count, d); TF_CHECK_OK(GpuLaunchKernel( @@ -166,33 +180,73 @@ struct DepthToSpaceOpFunctor { config.virtual_thread_count, input.data(), block_size, batch_size, input_height, input_width, input_depth, output_height, output_width, output_depth, output.data())); + return d.ok() ? absl::OkStatus() + : absl::InternalError("GPU execution failed"); } - void operator()(const GPUDevice& d, typename TTypes::ConstTensor input, - int block_size, typename TTypes::Tensor output) { - LOG(FATAL) << "5-D tensors should not be used with NHWC format"; + absl::Status operator()(const GPUDevice& d, + typename TTypes::ConstTensor input, + int block_size, + typename TTypes::Tensor output) { + return absl::InternalError( + "5-D tensors should not be used with NHWC format"); } }; template struct DepthToSpaceOpFunctor { - void operator()(const GPUDevice& d, typename TTypes::ConstTensor input, - int block_size, typename TTypes::Tensor output) { + absl::Status operator()(const GPUDevice& d, + typename TTypes::ConstTensor input, + int block_size, + typename TTypes::Tensor output) { const int batch_size = input.dimension(0); const int input_depth = input.dimension(1); const int input_height = input.dimension(2); const int input_width = input.dimension(3); const int output_depth = output.dimension(1); - const int input_area = input_width * input_height; - const int input_depth_by_input_area = input_depth * input_area; + + int64_t input_area_64 = MultiplyWithoutOverflow(input_width, input_height); + if (input_area_64 < 0 || + input_area_64 > std::numeric_limits::max()) { + return absl::InternalError( + "DepthToSpaceOpFunctor NCHW: input_area exceeds int32 bounds"); + } + const int input_area = input_area_64; + + int64_t input_depth_by_input_area_64 = + MultiplyWithoutOverflow(input_depth, input_area); + if (input_depth_by_input_area_64 < 0 || + input_depth_by_input_area_64 > std::numeric_limits::max()) { + return absl::InternalError( + "DepthToSpaceOpFunctor NCHW: input_depth_by_input_area exceeds int32 " + "bounds"); + } + const int input_depth_by_input_area = input_depth_by_input_area_64; // We improve performance by generating instantiations of the loop kernel // for the most common block sizes. if (block_size <= 4) { const int output_width = output.dimension(3); - const int output_depth_by_input_area = output_depth * input_area; - const int total_count = batch_size * output_depth_by_input_area; + int64_t output_depth_by_input_area_64 = + MultiplyWithoutOverflow(output_depth, input_area); + if (output_depth_by_input_area_64 < 0 || + output_depth_by_input_area_64 > std::numeric_limits::max()) { + return absl::InternalError( + "DepthToSpaceOpFunctor NCHW: output_depth_by_input_area exceeds " + "int32 bounds"); + } + const int output_depth_by_input_area = output_depth_by_input_area_64; + + int64_t total_count_64 = + MultiplyWithoutOverflow(batch_size, output_depth_by_input_area); + if (total_count_64 < 0 || + total_count_64 > std::numeric_limits::max()) { + return absl::InternalError( + "DepthToSpaceOpFunctor NCHW: total_count exceeds int32 bounds"); + } + const int total_count = total_count_64; + if (total_count == 0) { - return; + return absl::OkStatus(); } GpuLaunchConfig config = GetGpuLaunchConfig(total_count, d); switch (block_size) { @@ -202,38 +256,54 @@ struct DepthToSpaceOpFunctor { 0, d.stream(), total_count, input.data(), input_width, output_width, output_depth_by_input_area, input_depth_by_input_area, output.data())); - return; + return d.ok() ? absl::OkStatus() + : absl::InternalError("GPU execution failed"); case 3: TF_CHECK_OK(GpuLaunchKernel( D2S_NCHW_LOOP, config.block_count, config.thread_per_block, 0, d.stream(), total_count, input.data(), input_width, output_width, output_depth_by_input_area, input_depth_by_input_area, output.data())); - return; + return d.ok() ? absl::OkStatus() + : absl::InternalError("GPU execution failed"); case 4: TF_CHECK_OK(GpuLaunchKernel( D2S_NCHW_LOOP, config.block_count, config.thread_per_block, 0, d.stream(), total_count, input.data(), input_width, output_width, output_depth_by_input_area, input_depth_by_input_area, output.data())); - return; + return d.ok() ? absl::OkStatus() + : absl::InternalError("GPU execution failed"); } } // Other block sizes are processed by the generic kernel. - const int total_count = batch_size * input_depth_by_input_area; + int64_t total_count_64 = + MultiplyWithoutOverflow(batch_size, input_depth_by_input_area); + if (total_count_64 < 0 || + total_count_64 > std::numeric_limits::max()) { + return absl::InternalError( + "DepthToSpaceOpFunctor NCHW: total_count exceeds int32 bounds"); + } + const int total_count = total_count_64; + if (total_count == 0) { - return; + return absl::OkStatus(); } auto config = GetGpuLaunchConfig(total_count, d); TF_CHECK_OK(GpuLaunchKernel( D2S_NCHW, config.block_count, config.thread_per_block, 0, d.stream(), config.virtual_thread_count, input.data(), block_size, input_width, output_depth * input_height, output.data())); + return d.ok() ? absl::OkStatus() + : absl::InternalError("GPU execution failed"); } - void operator()(const GPUDevice& d, typename TTypes::ConstTensor input, - int block_size, typename TTypes::Tensor output) { - LOG(FATAL) << "5-D tensors should not be used with NCHW format"; + absl::Status operator()(const GPUDevice& d, + typename TTypes::ConstTensor input, + int block_size, + typename TTypes::Tensor output) { + return absl::InternalError( + "5-D tensors should not be used with NCHW format"); } }; } // end namespace functor diff --git a/tensorflow/core/kernels/maxpooling_op.cc b/tensorflow/core/kernels/maxpooling_op.cc index e3eaacd8919fa7..bbcbbe8269649c 100644 --- a/tensorflow/core/kernels/maxpooling_op.cc +++ b/tensorflow/core/kernels/maxpooling_op.cc @@ -19,8 +19,6 @@ limitations under the License. #include "absl/strings/str_join.h" #define EIGEN_USE_THREADS -#include "tensorflow/core/kernels/maxpooling_op.h" - #include #include @@ -35,6 +33,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor_slice.h" #include "tensorflow/core/kernels/conv_2d.h" #include "tensorflow/core/kernels/eigen_pooling.h" +#include "tensorflow/core/kernels/maxpooling_op.h" // IWYU pragma: keep #include "tensorflow/core/kernels/ops_util.h" #include "tensorflow/core/kernels/pooling_ops_common.h" #include "tensorflow/core/lib/core/errors.h" @@ -796,7 +795,7 @@ class MaxPoolingGradGradOp : public OpKernel { errors::InvalidArgument("Expected grad shape to be ", tensor_in.shape(), ", but got ", out_grad_backprop.shape())); - functor::MaxPoolGradBackwardNoMask()( + absl::Status status = functor::MaxPoolGradBackwardNoMask()( data_format_, tensor_in.flat().data(), tensor_out.flat().data(), params.tensor_in_batch, params.out_height, params.out_width, params.depth, params.tensor_in_rows, params.tensor_in_cols, @@ -804,6 +803,9 @@ class MaxPoolingGradGradOp : public OpKernel { params.col_stride, params.pad_top, params.pad_left, out_grad_backprop.flat().data(), output->flat().data(), context->eigen_device()); + if (!status.ok()) { + context->SetStatus(status); + } } private: @@ -1494,16 +1496,15 @@ template struct LaunchMaxPoolingNoMask { static void launch(OpKernelContext* context, const PoolParameters& params, const Tensor& input, Tensor* output, bool propagate_nans) { - bool status = functor::MaxPoolForwardWithOptionalArgmax()( + absl::Status status = functor::MaxPoolForwardWithOptionalArgmax()( input.flat().data(), params.tensor_in_batch, params.tensor_in_rows, params.tensor_in_cols, params.depth, params.out_height, params.out_width, params.window_rows, params.window_cols, params.row_stride, params.col_stride, params.pad_top, params.pad_left, output->flat().data(), nullptr, context->eigen_gpu_device(), propagate_nans, false); - if (!status) { - context->SetStatus( - absl::InternalError("Failed launching MaxPoolForwardNoMask")); + if (!status.ok()) { + context->SetStatus(status); } } }; @@ -1513,7 +1514,7 @@ struct LaunchMaxPoolingWithArgmax { static void launch(OpKernelContext* context, const PoolParameters& params, const Tensor& input, Tensor* output, Tensor* argmax, bool propagate_nans, bool include_batch_in_index) { - bool status = functor::MaxPoolForwardWithOptionalArgmax()( + absl::Status status = functor::MaxPoolForwardWithOptionalArgmax()( input.flat().data(), params.tensor_in_batch, params.tensor_in_rows, params.tensor_in_cols, params.depth, params.out_height, params.out_width, params.window_rows, params.window_cols, @@ -1521,9 +1522,8 @@ struct LaunchMaxPoolingWithArgmax { output->flat().data(), reinterpret_cast(argmax->flat().data()), context->eigen_gpu_device(), propagate_nans, include_batch_in_index); - if (!status) { - context->SetStatus( - absl::InternalError("Failed launching MaxPoolForwardWithArgmax")); + if (!status.ok()) { + context->SetStatus(status); } } }; @@ -1540,14 +1540,13 @@ struct LaunchMaxPoolingGradWithArgmax { const int top_offset = params.out_height * params.out_width * params.depth; const int bottom_offset = params.tensor_in_rows * params.tensor_in_cols * params.depth; - bool status = functor::MaxPoolBackwardWithArgmax()( + absl::Status status = functor::MaxPoolBackwardWithArgmax()( output_size, input_size, grad_in.flat().data(), reinterpret_cast(argmax.flat().data()), top_offset, bottom_offset, grad_out->flat().data(), context->eigen_gpu_device(), include_batch_in_index); - if (!status) { - context->SetStatus( - absl::InternalError("Failed launching MaxPoolBackwardWithArgmax")); + if (!status.ok()) { + context->SetStatus(status); } } }; @@ -1565,14 +1564,13 @@ struct LaunchMaxPoolingGradGradWithArgmax { params.tensor_in_rows * params.tensor_in_cols * params.depth; const int bottom_offset = params.out_width * params.out_height * params.depth; - bool status = functor::MaxPoolGradBackwardWithArgmax()( + absl::Status status = functor::MaxPoolGradBackwardWithArgmax()( output_size, input_size, grad_in.flat().data(), reinterpret_cast(argmax.flat().data()), top_offset, bottom_offset, grad_out->flat().data(), context->eigen_gpu_device(), include_batch_in_index); - if (!status) { - context->SetStatus(absl::InternalError( - "Failed launching MaxPoolGradBackwardWithArgmax")); + if (!status.ok()) { + context->SetStatus(status); } } }; diff --git a/tensorflow/core/kernels/maxpooling_op_gpu.cu.cc b/tensorflow/core/kernels/maxpooling_op_gpu.cu.cc index 99f52259e8e7a4..86ca4f756bea86 100644 --- a/tensorflow/core/kernels/maxpooling_op_gpu.cu.cc +++ b/tensorflow/core/kernels/maxpooling_op_gpu.cu.cc @@ -22,12 +22,14 @@ limitations under the License. #include #include +#include #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/framework/type_traits.h" #include "tensorflow/core/kernels/maxpooling_op.h" #include "tensorflow/core/util/gpu_kernel_helper.h" +#include "tensorflow/core/util/overflow.h" namespace tensorflow { namespace { @@ -356,27 +358,37 @@ namespace functor { #if GOOGLE_CUDA // Note: channels is the outer channels (dim 1) which has already been // divided by 4. -bool MaxPoolForwardNoMask_NCHW_VECT_C::operator()( +absl::Status MaxPoolForwardNoMask_NCHW_VECT_C::operator()( const int32_t* bottom_data, const int batch, const int height, const int width, int channels, const int pooled_height, const int pooled_width, const int kernel_h, const int kernel_w, const int stride_h, const int stride_w, const int pad_t, const int pad_l, int32_t* top_data, const Eigen::GpuDevice& d) { const int kThreadsPerBlock = 1024; - const int output_size = batch * channels * pooled_height * pooled_width; - if (output_size == 0) return true; + int64_t output_size_64 = MultiplyWithoutOverflow( + MultiplyWithoutOverflow(MultiplyWithoutOverflow(batch, channels), + pooled_height), + pooled_width); + if (output_size_64 < 0 || + output_size_64 > std::numeric_limits::max()) { + return absl::InternalError( + "MaxPoolForwardNoMask_NCHW_VECT_C: output size exceeds int32 bounds"); + } + const int output_size = output_size_64; + if (output_size == 0) return absl::OkStatus(); TF_CHECK_OK(GpuLaunchKernel( MaxPoolForwardNoMaskKernel_NCHW_VECT_C, (output_size + kThreadsPerBlock - 1) / kThreadsPerBlock, kThreadsPerBlock, 0, d.stream(), output_size, bottom_data, height, width, channels, pooled_height, pooled_width, kernel_h, kernel_w, stride_h, stride_w, pad_t, pad_l, top_data)); - return d.ok(); + return d.ok() ? absl::OkStatus() + : absl::InternalError("GPU execution failed"); } #endif // GOOGLE_CUDA template -bool MaxPoolForwardWithOptionalArgmax::operator()( +absl::Status MaxPoolForwardWithOptionalArgmax::operator()( const T* bottom_data, const int batch, const int height, const int width, const int channels, const int pooled_height, const int pooled_width, const int kernel_h, const int kernel_w, const int stride_h, @@ -384,8 +396,17 @@ bool MaxPoolForwardWithOptionalArgmax::operator()( int64_t* mask, const Eigen::GpuDevice& d, bool propagate_nans, const bool include_batch_in_index) { const int kThreadsPerBlock = 1024; - const int output_size = batch * channels * pooled_height * pooled_width; - if (output_size == 0) return true; + int64_t output_size_64 = MultiplyWithoutOverflow( + MultiplyWithoutOverflow(MultiplyWithoutOverflow(batch, channels), + pooled_height), + pooled_width); + if (output_size_64 < 0 || + output_size_64 > std::numeric_limits::max()) { + return absl::InternalError( + "MaxPoolForwardWithOptionalArgmax: output size exceeds int32 bounds"); + } + const int output_size = output_size_64; + if (output_size == 0) return absl::OkStatus(); if (propagate_nans) { TF_CHECK_OK( GpuLaunchKernel(MaxPoolForwardNHWC, @@ -403,17 +424,18 @@ bool MaxPoolForwardWithOptionalArgmax::operator()( pooled_width, kernel_h, kernel_w, stride_h, stride_w, pad_t, pad_l, top_data, mask, include_batch_in_index)); } - return d.ok(); + return d.ok() ? absl::OkStatus() + : absl::InternalError("GPU execution failed"); } template -bool MaxPoolBackwardWithArgmax::operator()( +absl::Status MaxPoolBackwardWithArgmax::operator()( const int output_size, const int input_size, const T* top_diff, const int64_t* mask, const int top_offset, const int bottom_offset, T* bottom_diff, const Eigen::GpuDevice& d, const bool include_batch_in_index) { const int kThreadsPerBlock = 1024; - if (input_size == 0) return true; + if (input_size == 0) return absl::OkStatus(); TF_CHECK_OK(GpuLaunchKernel( SetZero, (input_size + kThreadsPerBlock - 1) / kThreadsPerBlock, kThreadsPerBlock, 0, d.stream(), input_size, bottom_diff)); @@ -422,19 +444,29 @@ bool MaxPoolBackwardWithArgmax::operator()( (output_size + kThreadsPerBlock - 1) / kThreadsPerBlock, kThreadsPerBlock, 0, d.stream(), output_size, top_diff, mask, top_offset, bottom_offset, bottom_diff, include_batch_in_index)); - return d.ok(); + return d.ok() ? absl::OkStatus() + : absl::InternalError("GPU execution failed"); } template -bool MaxPoolGradBackwardNoMask::operator()( +absl::Status MaxPoolGradBackwardNoMask::operator()( TensorFormat data_format, const T* bottom_data, const T* output_data, const int batch, const int pooled_height, const int pooled_width, const int channels, const int height, const int width, const int kernel_h, const int kernel_w, const int stride_h, const int stride_w, const int pad_t, const int pad_l, const T* top_diff, T* bottom_diff, const Eigen::GpuDevice& d) { - const int num_kernels = batch * channels * pooled_height * pooled_width; - if (num_kernels == 0) return true; + int64_t num_kernels_64 = MultiplyWithoutOverflow( + MultiplyWithoutOverflow(MultiplyWithoutOverflow(batch, channels), + pooled_height), + pooled_width); + if (num_kernels_64 < 0 || + num_kernels_64 > std::numeric_limits::max()) { + return absl::InternalError( + "MaxPoolGradBackwardNoMask: num_kernels exceeds int32 bounds"); + } + const int num_kernels = num_kernels_64; + if (num_kernels == 0) return absl::OkStatus(); GpuLaunchConfig config = GetGpuLaunchConfig(num_kernels, d); if (data_format == FORMAT_NHWC) { @@ -452,22 +484,24 @@ bool MaxPoolGradBackwardNoMask::operator()( channels, height, width, kernel_h, kernel_w, stride_h, stride_w, pad_t, pad_l, top_diff, bottom_diff)); } - return d.ok(); + return d.ok() ? absl::OkStatus() + : absl::InternalError("GPU execution failed"); } template -bool MaxPoolGradBackwardWithArgmax::operator()( +absl::Status MaxPoolGradBackwardWithArgmax::operator()( const int output_size, const int input_size, const T* top_diff, const int64_t* mask, const int top_offset, const int bottom_offset, T* bottom_diff, const Eigen::GpuDevice& d, const bool include_batch_in_index) { - if (input_size == 0) return true; + if (input_size == 0) return absl::OkStatus(); GpuLaunchConfig config = GetGpuLaunchConfig(output_size, d); TF_CHECK_OK(GpuLaunchKernel( MaxPoolGradBackward, config.block_count, config.thread_per_block, 0, d.stream(), output_size, top_diff, mask, top_offset, bottom_offset, bottom_diff, include_batch_in_index, input_size)); - return d.ok(); + return d.ok() ? absl::OkStatus() + : absl::InternalError("GPU execution failed"); } typedef Eigen::GpuDevice GPUDevice; diff --git a/tensorflow/core/kernels/maxpooling_op_gpu.h b/tensorflow/core/kernels/maxpooling_op_gpu.h index 3e8ba784d9714e..74e015f6ef522a 100644 --- a/tensorflow/core/kernels/maxpooling_op_gpu.h +++ b/tensorflow/core/kernels/maxpooling_op_gpu.h @@ -22,6 +22,7 @@ limitations under the License. #define EIGEN_USE_GPU +#include "absl/status/status.h" #include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/util/tensor_format.h" @@ -34,49 +35,55 @@ namespace functor { // argmax indices are not written. template struct MaxPoolForwardWithOptionalArgmax { - bool operator()(const T* bottom_data, const int batch, const int height, - const int width, const int channels, const int pooled_height, - const int pooled_width, const int kernel_h, - const int kernel_w, const int stride_h, const int stride_w, - const int pad_t, const int pad_l, T* top_data, int64_t* mask, - const Eigen::GpuDevice& d, bool propagate_nans, - const bool include_batch_in_index); + absl::Status operator()(const T* bottom_data, const int batch, + const int height, const int width, const int channels, + const int pooled_height, const int pooled_width, + const int kernel_h, const int kernel_w, + const int stride_h, const int stride_w, + const int pad_t, const int pad_l, T* top_data, + int64_t* mask, const Eigen::GpuDevice& d, + bool propagate_nans, + const bool include_batch_in_index); }; struct MaxPoolForwardNoMask_NCHW_VECT_C { - bool operator()(const int32_t* bottom_data, const int batch, const int height, - const int width, int channels, const int pooled_height, - const int pooled_width, const int kernel_h, - const int kernel_w, const int stride_h, const int stride_w, - const int pad_t, const int pad_l, int32_t* top_data, - const Eigen::GpuDevice& d); + absl::Status operator()(const int32_t* bottom_data, const int batch, + const int height, const int width, int channels, + const int pooled_height, const int pooled_width, + const int kernel_h, const int kernel_w, + const int stride_h, const int stride_w, + const int pad_t, const int pad_l, int32_t* top_data, + const Eigen::GpuDevice& d); }; template struct MaxPoolBackwardWithArgmax { - bool operator()(const int output_size, const int input_size, - const T* top_diff, const int64_t* mask, const int top_offset, - const int bottom_offset, T* bottom_diff, - const Eigen::GpuDevice& d, const bool include_batch_in_index); + absl::Status operator()(const int output_size, const int input_size, + const T* top_diff, const int64_t* mask, + const int top_offset, const int bottom_offset, + T* bottom_diff, const Eigen::GpuDevice& d, + const bool include_batch_in_index); }; template struct MaxPoolGradBackwardWithArgmax { - bool operator()(const int output_size, const int input_size, - const T* top_diff, const int64_t* mask, const int top_offset, - const int bottom_offset, T* bottom_diff, - const Eigen::GpuDevice& d, const bool include_batch_in_index); + absl::Status operator()(const int output_size, const int input_size, + const T* top_diff, const int64_t* mask, + const int top_offset, const int bottom_offset, + T* bottom_diff, const Eigen::GpuDevice& d, + const bool include_batch_in_index); }; template struct MaxPoolGradBackwardNoMask { - bool operator()(TensorFormat data_format, const T* bottom_data, - const T* output_data, const int batch, - const int pooled_height, const int pooled_width, - const int channels, const int height, const int width, - const int kernel_h, const int kernel_w, const int stride_h, - const int stride_w, const int pad_t, const int pad_l, - const T* top_diff, T* bottom_diff, const Eigen::GpuDevice& d); + absl::Status operator()(TensorFormat data_format, const T* bottom_data, + const T* output_data, const int batch, + const int pooled_height, const int pooled_width, + const int channels, const int height, const int width, + const int kernel_h, const int kernel_w, + const int stride_h, const int stride_w, + const int pad_t, const int pad_l, const T* top_diff, + T* bottom_diff, const Eigen::GpuDevice& d); }; } // namespace functor diff --git a/tensorflow/core/kernels/pooling_ops_common.h b/tensorflow/core/kernels/pooling_ops_common.h index cced70b25d4a39..c3e64c7030c38a 100644 --- a/tensorflow/core/kernels/pooling_ops_common.h +++ b/tensorflow/core/kernels/pooling_ops_common.h @@ -313,7 +313,7 @@ struct LaunchMaxPoolingNoMask_NCHW_VECT_C { static void launch(OpKernelContext* context, const PoolParameters& params, const Tensor& input, Tensor* output) { #if GOOGLE_CUDA - bool status = functor::MaxPoolForwardNoMask_NCHW_VECT_C()( + absl::Status status = functor::MaxPoolForwardNoMask_NCHW_VECT_C()( reinterpret_cast(input.flat().data()), params.tensor_in_batch, params.tensor_in_rows, params.tensor_in_cols, params.depth, params.out_height, params.out_width, params.window_rows, @@ -321,9 +321,8 @@ struct LaunchMaxPoolingNoMask_NCHW_VECT_C { params.pad_top, params.pad_left, reinterpret_cast(output->flat().data()), context->eigen_gpu_device()); - if (!status) { - context->SetStatus(errors::Internal( - "Failed launching LaunchMaxPoolingNoMask_NCHW_VECT_C")); + if (!status.ok()) { + context->SetStatus(status); } #else // ROCm TODO: add support __vmaxs4 on ROCm diff --git a/tensorflow/core/kernels/spacetodepth_op.cc b/tensorflow/core/kernels/spacetodepth_op.cc index a3dc63ff9a0800..d15d9a0856c859 100644 --- a/tensorflow/core/kernels/spacetodepth_op.cc +++ b/tensorflow/core/kernels/spacetodepth_op.cc @@ -23,6 +23,7 @@ limitations under the License. #include #include +#include "absl/status/status.h" #include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_kernel.h" @@ -135,25 +136,28 @@ class SpaceToDepthOp : public OpKernel { auto Toutput_v = outputs_tensor->reinterpret_last_dimension(); functor::SpaceToDepthOpFunctor functor; - functor(context->eigen_device(), Tinput_v, block_size_, - Toutput_v); + OP_REQUIRES_OK(context, functor(context->eigen_device(), + Tinput_v, block_size_, Toutput_v)); } else if (data_format_ == FORMAT_NCHW) { CHECK((std::is_same::value)); functor::SpaceToDepthOpFunctor functor; - functor(context->eigen_device(), input.tensor(), - block_size_, outputs_tensor->tensor()); + OP_REQUIRES_OK(context, functor(context->eigen_device(), + input.tensor(), block_size_, + outputs_tensor->tensor())); } else { CHECK((std::is_same::value)); functor::SpaceToDepthOpFunctor functor; - functor(context->eigen_device(), input.tensor(), - block_size_, outputs_tensor->tensor()); + OP_REQUIRES_OK(context, functor(context->eigen_device(), + input.tensor(), block_size_, + outputs_tensor->tensor())); } } else { // NOTE: Assumes data_format_ == FORMAT_NHWC here, since we have rejected // (CPU && data_format_ != FORMAT_NHWC) in the constructor. functor::SpaceToDepthOpFunctor functor; - functor(context->eigen_device(), input.tensor(), - block_size_, outputs_tensor->tensor()); + OP_REQUIRES_OK(context, functor(context->eigen_device(), + input.tensor(), block_size_, + outputs_tensor->tensor())); } }; @@ -166,8 +170,10 @@ class SpaceToDepthOp : public OpKernel { namespace functor { template struct SpaceToDepthOpFunctor { - void operator()(const CPUDevice& d, typename TTypes::ConstTensor input, - int block_size, typename TTypes::Tensor output) { + absl::Status operator()(const CPUDevice& d, + typename TTypes::ConstTensor input, + int block_size, + typename TTypes::Tensor output) { const int batch_size = output.dimension(0); const int input_height = input.dimension(1); const int input_width = input.dimension(2); @@ -188,6 +194,7 @@ struct SpaceToDepthOpFunctor { } } } + return absl::OkStatus(); } }; } // namespace functor diff --git a/tensorflow/core/kernels/spacetodepth_op.h b/tensorflow/core/kernels/spacetodepth_op.h index 3cb1df5b0c7318..d83f8cef595d3d 100644 --- a/tensorflow/core/kernels/spacetodepth_op.h +++ b/tensorflow/core/kernels/spacetodepth_op.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_CORE_KERNELS_SPACETODEPTH_OP_H_ // Functor definition for XentOp, must be compilable by nvcc. +#include "absl/status/status.h" #include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive #include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/util/tensor_format.h" @@ -43,12 +44,14 @@ namespace functor { // n,oY,oX,bY,bX,iC template struct SpaceToDepthOpFunctor { - void operator()(const Device& d, typename TTypes::ConstTensor input, - int block_size, typename TTypes::Tensor output); + absl::Status operator()(const Device& d, + typename TTypes::ConstTensor input, + int block_size, typename TTypes::Tensor output); // This 5-D version is to support NCHW_VECT_C. - void operator()(const Device& d, typename TTypes::ConstTensor input, - int block_size, typename TTypes::Tensor output); + absl::Status operator()(const Device& d, + typename TTypes::ConstTensor input, + int block_size, typename TTypes::Tensor output); }; } // namespace functor diff --git a/tensorflow/core/kernels/spacetodepth_op_gpu.cu.cc b/tensorflow/core/kernels/spacetodepth_op_gpu.cu.cc index 97acca5442890d..94f16fb3f73269 100644 --- a/tensorflow/core/kernels/spacetodepth_op_gpu.cu.cc +++ b/tensorflow/core/kernels/spacetodepth_op_gpu.cu.cc @@ -17,10 +17,13 @@ limitations under the License. #define EIGEN_USE_GPU +#include + #include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/kernels/spacetodepth_op.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/util/gpu_kernel_helper.h" +#include "tensorflow/core/util/overflow.h" namespace tensorflow { @@ -142,8 +145,10 @@ __global__ void S2D_NCHW_LOOP(const int32_t nthreads, namespace functor { template struct SpaceToDepthOpFunctor { - void operator()(const GPUDevice& d, typename TTypes::ConstTensor input, - int block_size, typename TTypes::Tensor output) { + absl::Status operator()(const GPUDevice& d, + typename TTypes::ConstTensor input, + int block_size, + typename TTypes::Tensor output) { const int batch_size = output.dimension(0); const int input_height = input.dimension(1); const int input_width = input.dimension(2); @@ -152,10 +157,19 @@ struct SpaceToDepthOpFunctor { const int output_width = output.dimension(2); const int output_depth = output.dimension(3); - const int total_count = - batch_size * input_height * input_width * input_depth; + int64_t total_count_64 = MultiplyWithoutOverflow( + MultiplyWithoutOverflow( + MultiplyWithoutOverflow(batch_size, input_height), input_width), + input_depth); + if (total_count_64 < 0 || + total_count_64 > std::numeric_limits::max()) { + return absl::InternalError( + "SpaceToDepthOpFunctor NHWC: total_count exceeds int32 bounds"); + } + const int total_count = total_count_64; + if (total_count == 0) { - return; + return absl::OkStatus(); } GpuLaunchConfig config = GetGpuLaunchConfig(total_count, d); TF_CHECK_OK(GpuLaunchKernel( @@ -163,33 +177,74 @@ struct SpaceToDepthOpFunctor { config.virtual_thread_count, input.data(), block_size, batch_size, input_height, input_width, input_depth, output_height, output_width, output_depth, output.data())); + return d.ok() ? absl::OkStatus() + : absl::InternalError("GPU execution failed"); } - void operator()(const GPUDevice& d, typename TTypes::ConstTensor input, - int block_size, typename TTypes::Tensor output) { - LOG(FATAL) << "5-D tensors should not be used with NHWC format"; + absl::Status operator()(const GPUDevice& d, + typename TTypes::ConstTensor input, + int block_size, + typename TTypes::Tensor output) { + return absl::InternalError( + "5-D tensors should not be used with NHWC format"); } }; template struct SpaceToDepthOpFunctor { - void operator()(const GPUDevice& d, typename TTypes::ConstTensor input, - int block_size, typename TTypes::Tensor output) { + absl::Status operator()(const GPUDevice& d, + typename TTypes::ConstTensor input, + int block_size, + typename TTypes::Tensor output) { const int batch_size = output.dimension(0); const int input_depth = input.dimension(1); const int output_depth = output.dimension(1); const int output_height = output.dimension(2); const int output_width = output.dimension(3); - const int output_area = output_width * output_height; - const int output_depth_by_output_area = output_depth * output_area; + + int64_t output_area_64 = + MultiplyWithoutOverflow(output_width, output_height); + if (output_area_64 < 0 || + output_area_64 > std::numeric_limits::max()) { + return absl::InternalError( + "SpaceToDepthOpFunctor NCHW: output_area exceeds int32 bounds"); + } + const int output_area = output_area_64; + + int64_t output_depth_by_output_area_64 = + MultiplyWithoutOverflow(output_depth, output_area); + if (output_depth_by_output_area_64 < 0 || + output_depth_by_output_area_64 > std::numeric_limits::max()) { + return absl::InternalError( + "SpaceToDepthOpFunctor NCHW: output_depth_by_output_area exceeds " + "int32 bounds"); + } + const int output_depth_by_output_area = output_depth_by_output_area_64; // We improve performance by generating instantiations of the loop kernel // for the most common block sizes. if (block_size <= 4) { const int input_width = input.dimension(3); - const int input_depth_by_output_area = input_depth * output_area; - const int total_count = batch_size * input_depth_by_output_area; + int64_t input_depth_by_output_area_64 = + MultiplyWithoutOverflow(input_depth, output_area); + if (input_depth_by_output_area_64 < 0 || + input_depth_by_output_area_64 > std::numeric_limits::max()) { + return absl::InternalError( + "SpaceToDepthOpFunctor NCHW: input_depth_by_output_area exceeds " + "int32 bounds"); + } + const int input_depth_by_output_area = input_depth_by_output_area_64; + + int64_t total_count_64 = + MultiplyWithoutOverflow(batch_size, input_depth_by_output_area); + if (total_count_64 < 0 || + total_count_64 > std::numeric_limits::max()) { + return absl::InternalError( + "SpaceToDepthOpFunctor NCHW: total_count exceeds int32 bounds"); + } + const int total_count = total_count_64; + if (total_count == 0) { - return; + return absl::OkStatus(); } GpuLaunchConfig config = GetGpuLaunchConfig(total_count, d); switch (block_size) { @@ -199,38 +254,54 @@ struct SpaceToDepthOpFunctor { 0, d.stream(), total_count, input.data(), output_width, input_width, input_depth_by_output_area, output_depth_by_output_area, output.data())); - return; + return d.ok() ? absl::OkStatus() + : absl::InternalError("GPU execution failed"); case 3: TF_CHECK_OK(GpuLaunchKernel( S2D_NCHW_LOOP, config.block_count, config.thread_per_block, 0, d.stream(), total_count, input.data(), output_width, input_width, input_depth_by_output_area, output_depth_by_output_area, output.data())); - return; + return d.ok() ? absl::OkStatus() + : absl::InternalError("GPU execution failed"); case 4: TF_CHECK_OK(GpuLaunchKernel( S2D_NCHW_LOOP, config.block_count, config.thread_per_block, 0, d.stream(), total_count, input.data(), output_width, input_width, input_depth_by_output_area, output_depth_by_output_area, output.data())); - return; + return d.ok() ? absl::OkStatus() + : absl::InternalError("GPU execution failed"); } } // Other block sizes are processed by the generic kernel. - const int total_count = batch_size * output_depth_by_output_area; + int64_t total_count_64 = + MultiplyWithoutOverflow(batch_size, output_depth_by_output_area); + if (total_count_64 < 0 || + total_count_64 > std::numeric_limits::max()) { + return absl::InternalError( + "SpaceToDepthOpFunctor NCHW: total_count exceeds int32 bounds"); + } + const int total_count = total_count_64; + if (total_count == 0) { - return; + return absl::OkStatus(); } GpuLaunchConfig config = GetGpuLaunchConfig(total_count, d); TF_CHECK_OK(GpuLaunchKernel( S2D_NCHW, config.block_count, config.thread_per_block, 0, d.stream(), config.virtual_thread_count, input.data(), block_size, output_width, input_depth * output_height, output.data())); + return d.ok() ? absl::OkStatus() + : absl::InternalError("GPU execution failed"); } - void operator()(const GPUDevice& d, typename TTypes::ConstTensor input, - int block_size, typename TTypes::Tensor output) { - LOG(FATAL) << "5-D tensors should not be used with NCHW format"; + absl::Status operator()(const GPUDevice& d, + typename TTypes::ConstTensor input, + int block_size, + typename TTypes::Tensor output) { + return absl::InternalError( + "5-D tensors should not be used with NCHW format"); } }; } // end namespace functor diff --git a/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc b/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc index 8668187f55fe31..3b4e9779ec39d6 100644 --- a/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc +++ b/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc @@ -311,6 +311,12 @@ xnn_datatype GetXNNPackDatatype(TfLiteContext* context, return xnn_datatype_invalid; } return xnn_datatype_qint4; + case kTfLiteInt2: + if (!CheckZeroPointForPerTensorQuantization( + context, tensor, t, -2, 1, *quantization_zero_point)) { + return xnn_datatype_invalid; + } + return xnn_datatype_qint2; default: TF_LITE_KERNEL_LOG( context, @@ -537,6 +543,7 @@ TfLiteStatus DefineXNNPACKValue(TfLiteContext* context, xnn_subgraph_t subgraph, xnn_status status = xnn_status_success; switch (datatype) { + case xnn_datatype_qint2: case xnn_datatype_qint4: case xnn_datatype_qint8: case xnn_datatype_quint8: @@ -2916,6 +2923,7 @@ class Subgraph { case kTfLiteBuiltinHardSwish: case kTfLiteBuiltinLeakyRelu: case kTfLiteBuiltinLogistic: + case kTfLiteBuiltinLog: case kTfLiteBuiltinNeg: case kTfLiteBuiltinQuantize: case kTfLiteBuiltinRelu: @@ -4263,6 +4271,7 @@ class Subgraph { case BuiltinOperator_FLOOR: case BuiltinOperator_GELU: case BuiltinOperator_HARD_SWISH: + case BuiltinOperator_LOG: case BuiltinOperator_NEG: case BuiltinOperator_RELU_N1_TO_1: case BuiltinOperator_RELU: @@ -4441,6 +4450,9 @@ class Subgraph { unary_op_type = xnn_unary_leaky_relu; break; } + case BuiltinOperator_LOG: + unary_op_type = xnn_unary_log; + break; case BuiltinOperator_LOGISTIC: unary_op_type = xnn_unary_sigmoid; break; @@ -4732,10 +4744,12 @@ class Subgraph { xnn_datatype filter_datatype = GetXNNPackDatatype( logging_context, filter_tensor, filter_tensor_id); if (filter_datatype == xnn_datatype_qint8 || - filter_datatype == xnn_datatype_qint4) { - filter_datatype = filter_datatype == xnn_datatype_qint8 - ? xnn_datatype_qcint8 - : xnn_datatype_qcint4; + filter_datatype == xnn_datatype_qint4 || + filter_datatype == xnn_datatype_qint2) { + filter_datatype = + filter_datatype == xnn_datatype_qint8 ? xnn_datatype_qcint8 + : filter_datatype == xnn_datatype_qint4 ? xnn_datatype_qcint4 + : xnn_datatype_qcint2; // Check whether we have to re-allocated the scale.. if (output_channels > 1) { TfLiteFloatArrayFree(filter_quant_params->scale); diff --git a/tensorflow/lite/kernels/BUILD b/tensorflow/lite/kernels/BUILD index 9cdde18590b7b1..f6efebfe2b4d92 100644 --- a/tensorflow/lite/kernels/BUILD +++ b/tensorflow/lite/kernels/BUILD @@ -157,6 +157,7 @@ cc_library( "//tensorflow/lite/core/kernels:builtin_ops", "//tensorflow/lite/delegates/nnapi:acceleration_test_util", "//tensorflow/lite/delegates/nnapi:nnapi_delegate", + "//tensorflow/lite/kernels/internal:runtime_shape", "//tensorflow/lite/kernels/internal:tensor_ctypes", "//tensorflow/lite/kernels/internal:tensor_utils", "//tensorflow/lite/kernels/internal/utils:sparsity_format_converter", diff --git a/tensorflow/lite/kernels/embedding_lookup.cc b/tensorflow/lite/kernels/embedding_lookup.cc index 1fdd500d41b6f4..4766a8a95e0013 100644 --- a/tensorflow/lite/kernels/embedding_lookup.cc +++ b/tensorflow/lite/kernels/embedding_lookup.cc @@ -32,6 +32,8 @@ limitations under the License. #include #include #include +#include +#include #include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/core/c/common.h" @@ -189,10 +191,16 @@ void Unpack2Bit(float scaling_factor, int col_size, const int8_t* value_ptr, TfLiteStatus EvalBlockwise(TfLiteContext* context, TfLiteNode* node, const TfLiteTensor* lookup, const TfLiteTensor* value, TfLiteTensor* output) { - if (value->type != kTfLiteInt4) { - TF_LITE_KERNEL_LOG( - context, - "Embedding Lookup: Blockwise embedding lookup only supports Int4 data"); + if (value->type != kTfLiteInt4 && value->type != kTfLiteInt2) { + TF_LITE_KERNEL_LOG(context, + "Embedding Lookup: Blockwise embedding lookup only " + "supports Int4 and Int2 data"); + return kTfLiteError; + } + if (output->type != kTfLiteFloat32 && output->type != kTfLiteFloat16) { + TF_LITE_KERNEL_LOG(context, + "Embedding Lookup: Blockwise embedding lookup only " + "supports Float32 and Float16 outputs"); return kTfLiteError; } if (value->dims->size != 2) { @@ -209,19 +217,44 @@ TfLiteStatus EvalBlockwise(TfLiteContext* context, TfLiteNode* node, col_size *= SizeOfDimension(value, i); } - float* output_fp32_ptr = - output->type == kTfLiteFloat32 ? GetTensorData(output) : nullptr; - half* output_fp16_ptr = - output->type == kTfLiteFloat16 ? GetTensorData(output) : nullptr; - const int8_t* value_ptr = GetTensorData(value); - const int32_t* lookup_data = GetTensorData(lookup); - const auto quantization_params = reinterpret_cast( value->quantization.params); const TfLiteTensor& scale = context->tensors[quantization_params->scale]; const int blocksize = quantization_params->blocksize; const int dimension_size = SizeOfDimension(lookup, 0); + + float* output_fp32_ptr = GetTensorData(output); + half* output_fp16_ptr = GetTensorData(output); + + const int8_t* value_ptr = GetTensorData(value); + const int32_t* lookup_data = GetTensorData(lookup); + + // Wrap the correct 2/4-bit float32/float16 unpacking function. + auto [unpack_to_fp32, unpack_to_fp16] = + value->type == kTfLiteInt2 + ? std::make_pair(Unpack2Bit, Unpack2Bit) + : std::make_pair(Unpack4Bit, Unpack4Bit); + const int values_per_byte = value->type == kTfLiteInt2 ? 4 : 2; + std::function unpack; + if (output->type == kTfLiteFloat32) { + unpack = [&, unpack = unpack_to_fp32](float scaling_factor, + size_t value_offset, + size_t output_offset) { + unpack(scaling_factor, blocksize, + &value_ptr[value_offset / values_per_byte], + &output_fp32_ptr[output_offset]); + }; + } else { + unpack = [&, unpack = unpack_to_fp16](float scaling_factor, + size_t value_offset, + size_t output_offset) { + unpack(scaling_factor, blocksize, + &value_ptr[value_offset / values_per_byte], + &output_fp16_ptr[output_offset]); + }; + } + if (col_size % blocksize != 0) { TF_LITE_KERNEL_LOG(context, "Embedding Lookup: lookup dimension %d must be " @@ -243,16 +276,8 @@ TfLiteStatus EvalBlockwise(TfLiteContext* context, TfLiteNode* node, const size_t value_offset = static_cast(idx) * col_size; for (int j = 0; j < num_blocks; ++j) { float scaling_factor = GetTensorData(&scale)[scale_offset + j]; - - if (output_fp32_ptr) { - Unpack4Bit(scaling_factor, blocksize, - &value_ptr[(value_offset + j * blocksize) / 2], - &output_fp32_ptr[j * blocksize + i * col_size]); - } else { - Unpack4Bit(scaling_factor, blocksize, - &value_ptr[(value_offset + j * blocksize) / 2], - &output_fp16_ptr[j * blocksize + i * col_size]); - } + unpack(scaling_factor, (value_offset + j * blocksize), + j * blocksize + i * col_size); } } return kTfLiteOk; diff --git a/tensorflow/lite/kernels/embedding_lookup_test.cc b/tensorflow/lite/kernels/embedding_lookup_test.cc index 8597b8bf7dea97..213a24bd80a81e 100644 --- a/tensorflow/lite/kernels/embedding_lookup_test.cc +++ b/tensorflow/lite/kernels/embedding_lookup_test.cc @@ -533,6 +533,39 @@ TEST(PerBlockHybridEmbeddingLookupHybridOpTest, PerBlockSimple2DTestInt4) { kTestTolerance))); } +TEST(PerBlockHybridEmbeddingLookupHybridOpTest, PerBlockSimple2DTestInt2) { + PerBlockHybridEmbeddingLookupOpModel m( + /*index_shape=*/{3}, + /*weight_shape=*/{3, 16}, + /*weights_type=*/TensorType_INT2, + /*blocksize=*/8, + /*scales=*/{1.0, 2.0, 0.5, 0.25, 4.0, 0.5}); + m.SetInput({1, 0, 2}); + m.SetSignedWeight({ + 0.0, -1.0, 0.0, 1.0, 1.0, -1.0, 0.0, 1.0, // Row 0 + 0.0, -2.0, 0.0, 2.0, 2.0, -2.0, 0.0, 2.0, + -0.5, -0.5, 0.0, 0.5, 0.5, 0.0, -0.5, 0.5, // Row 1 + 0.25, 0.0, -0.25, 0.25, -0.25, -0.25, 0.0, 0.25, + 4.0, -4.0, 0.0, 4.0, -4.0, -4.0, 0.0, 4.0, // Row 2 + 0.5, -0.5, 0.0, 0.5, -0.5, -0.5, 0.0, 0.5, + }); + + ASSERT_EQ(m.Invoke(), kTfLiteOk); + + EXPECT_THAT( + m.GetOutput(), + ElementsAreArray(ArrayFloatNear( + { + -0.5, -0.5, 0.0, 0.5, 0.5, 0.0, -0.5, 0.5, // Row 1 + 0.25, 0.0, -0.25, 0.25, -0.25, -0.25, 0.0, 0.25, + 0.0, -1.0, 0.0, 1.0, 1.0, -1.0, 0.0, 1.0, // Row 0 + 0.0, -2.0, 0.0, 2.0, 2.0, -2.0, 0.0, 2.0, + 4.0, -4.0, 0.0, 4.0, -4.0, -4.0, 0.0, 4.0, // Row 2 + 0.5, -0.5, 0.0, 0.5, -0.5, -0.5, 0.0, 0.5, + }, + kTestTolerance))); +} + TEST(PerBlockHybridEmbeddingLookupHybridOpTest, PerBlockSimple2DTestInt4Float16) { PerBlockHybridEmbeddingLookupOpModel m( @@ -562,6 +595,41 @@ TEST(PerBlockHybridEmbeddingLookupHybridOpTest, kFp16TestTolerance))); } +TEST(PerBlockHybridEmbeddingLookupHybridOpTest, + PerBlockSimple2DTestInt2Float16) { + PerBlockHybridEmbeddingLookupOpModel m( + /*index_shape=*/{3}, + /*weight_shape=*/{3, 16}, + /*weights_type=*/TensorType_INT2, + /*blocksize=*/8, + /*scales=*/{1.0, 2.0, 0.5, 0.25, 4.0, 0.5}, + /*output_type=*/TensorType_FLOAT16); + m.SetInput({1, 0, 2}); + m.SetSignedWeight({ + 0.0, -1.0, 0.0, 1.0, 1.0, -1.0, 0.0, 1.0, // Row 0 + 0.0, -2.0, 0.0, 2.0, 2.0, -2.0, 0.0, 2.0, + -0.5, -0.5, 0.0, 0.5, 0.5, 0.0, -0.5, 0.5, // Row 1 + 0.25, 0.0, -0.25, 0.25, -0.25, -0.25, 0.0, 0.25, + 4.0, -4.0, 0.0, 4.0, -4.0, -4.0, 0.0, 4.0, // Row 2 + 0.5, -0.5, 0.0, 0.5, -0.5, -0.5, 0.0, 0.5, + }); + + ASSERT_EQ(m.Invoke(), kTfLiteOk); + + EXPECT_THAT( + m.GetOutput(), + ElementsAreArray(ArrayFloatNear( + { + -0.5, -0.5, 0.0, 0.5, 0.5, 0.0, -0.5, 0.5, // Row 1 + 0.25, 0.0, -0.25, 0.25, -0.25, -0.25, 0.0, 0.25, + 0.0, -1.0, 0.0, 1.0, 1.0, -1.0, 0.0, 1.0, // Row 0 + 0.0, -2.0, 0.0, 2.0, 2.0, -2.0, 0.0, 2.0, + 4.0, -4.0, 0.0, 4.0, -4.0, -4.0, 0.0, 4.0, // Row 2 + 0.5, -0.5, 0.0, 0.5, -0.5, -0.5, 0.0, 0.5, + }, + kFp16TestTolerance))); +} + TEST(PerAxisHybridEmbeddingLookupHybridOpTest, PerAxisSimple2DTestInt4) { PerAxisHybridEmbeddingLookupOpModel m( /*index_shape=*/{3}, /*weight_shape=*/{3, 8}, diff --git a/tensorflow/lite/kernels/test_util.h b/tensorflow/lite/kernels/test_util.h index b3af8a6cfd4b55..f74fb7297f25ae 100644 --- a/tensorflow/lite/kernels/test_util.h +++ b/tensorflow/lite/kernels/test_util.h @@ -33,6 +33,7 @@ limitations under the License. #include #include #include +#include #include #include @@ -43,10 +44,13 @@ limitations under the License. #include "absl/log/absl_log.h" #include "absl/types/span.h" #include "Eigen/Core" // from @eigen_archive +#include "flatbuffers/buffer.h" // from @flatbuffers +#include "flatbuffers/flatbuffer_builder.h" // from @flatbuffers #include "tensorflow/lite/core/api/op_resolver.h" #include "tensorflow/lite/core/c/common.h" #include "tensorflow/lite/core/interpreter.h" #include "tensorflow/lite/kernels/internal/portable_tensor_utils.h" +#include "tensorflow/lite/kernels/internal/runtime_shape.h" #include "tensorflow/lite/kernels/internal/tensor_ctypes.h" #include "tensorflow/lite/kernels/internal/utils/sparsity_format_converter.h" #include "tensorflow/lite/kernels/kernel_util.h" @@ -738,6 +742,9 @@ class SingleOpModel { if (t->type == kTfLiteInt4) { PopulateTensor4bit(index, /*offset=*/0, quantized_output.data(), quantized_output.data() + quantized_output.size()); + } else if (t->type == kTfLiteInt2) { + PopulateTensor2bit(index, /*offset=*/0, quantized_output.data(), + quantized_output.data() + quantized_output.size()); } } } diff --git a/tensorflow/lite/tools/cmake/modules/cpuinfo.cmake b/tensorflow/lite/tools/cmake/modules/cpuinfo.cmake index 52c0f67f61d4f6..7dbcd3b9ab48b3 100644 --- a/tensorflow/lite/tools/cmake/modules/cpuinfo.cmake +++ b/tensorflow/lite/tools/cmake/modules/cpuinfo.cmake @@ -23,7 +23,7 @@ OverridableFetchContent_Declare( cpuinfo GIT_REPOSITORY https://github.com/pytorch/cpuinfo # Sync with tensorflow/workspace2.bzl - GIT_TAG 8a9210069b5a37dd89ed118a783945502a30a4ae + GIT_TAG bc3c01e230c6974283e4b89421cfb0e232435589 GIT_PROGRESS TRUE SOURCE_DIR "${CMAKE_BINARY_DIR}/cpuinfo" ) diff --git a/tensorflow/lite/tools/optimize/quantization_utils.cc b/tensorflow/lite/tools/optimize/quantization_utils.cc index 819b0ef89c8c3d..d116dd7ec502c6 100644 --- a/tensorflow/lite/tools/optimize/quantization_utils.cc +++ b/tensorflow/lite/tools/optimize/quantization_utils.cc @@ -49,6 +49,8 @@ const int8_t kMaxQuantizedValue4bit = 7; // The maximum number of dimensions supported in per-channel quantization. constexpr int kPerChannelMaxDim = 4; // LINT.ThenChange(//tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/quantization_utils.cc:QuantizationUtilsConstants) +const int8_t kMinQuantizedValue2bit = -1; +const int8_t kMaxQuantizedValue2bit = 1; } // namespace // LINT.IfChange(NumElements) @@ -423,6 +425,10 @@ void SymmetricPerBlockQuantizeValues( output_value->at(index) = std::min( kMaxQuantizedValue4bit, std::max(kMinQuantizedValue4bit, quantized_value)); + } else if (type == kTfLiteInt2) { + output_value->at(index) = std::min( + kMaxQuantizedValue2bit, + std::max(kMinQuantizedValue2bit, quantized_value)); } } } diff --git a/tensorflow/python/distribute/BUILD b/tensorflow/python/distribute/BUILD index 33d1f9632ed2c6..f55b20de528416 100644 --- a/tensorflow/python/distribute/BUILD +++ b/tensorflow/python/distribute/BUILD @@ -1,6 +1,5 @@ load("@xla//third_party/rules_python/python:py_library.bzl", "py_library") load("//tensorflow:pytype.default.bzl", "pytype_strict_library") -load("//tensorflow:strict.default.bzl", "py_strict_library") load("//tensorflow:tensorflow.bzl", "py_test") load("//tensorflow:tensorflow.default.bzl", "cuda_py_strict_test", "tf_py_strict_test") load("//tensorflow/core/platform:distribute.bzl", "distribute_py_strict_test") diff --git a/tensorflow/python/distribute/cluster_resolver/BUILD b/tensorflow/python/distribute/cluster_resolver/BUILD index da7a4dc23e5929..4cec3106ae9487 100644 --- a/tensorflow/python/distribute/cluster_resolver/BUILD +++ b/tensorflow/python/distribute/cluster_resolver/BUILD @@ -1,7 +1,5 @@ -load("@xla//third_party/rules_python/python:py_library.bzl", "py_library") - # Description: Operations defined for Cluster Resolvers -load("//tensorflow:strict.default.bzl", "py_strict_library") +load("@xla//third_party/rules_python/python:py_library.bzl", "py_library") load("//tensorflow:tensorflow.default.bzl", "tf_py_strict_test") package( diff --git a/tensorflow/python/framework/BUILD b/tensorflow/python/framework/BUILD index e63a5b1edd82a2..5b56f1269babe6 100644 --- a/tensorflow/python/framework/BUILD +++ b/tensorflow/python/framework/BUILD @@ -3,8 +3,8 @@ load("@rules_cc//cc:cc_library.bzl", "cc_library") load("@rules_python//python:proto.bzl", "py_proto_library") load("@xla//third_party/rules_python/python:py_library.bzl", "py_library") +load("@xla//third_party/rules_python/python:py_test.bzl", "py_test") load("//tensorflow:pytype.default.bzl", "pytype_strict_library") -load("//tensorflow:strict.default.bzl", "py_strict_library") load( "//tensorflow:tensorflow.bzl", "if_cuda_tools", @@ -2662,7 +2662,7 @@ tf_gen_op_wrapper_py( "//tensorflow/python/util:dispatch", "//tensorflow/python/util:tf_export", ], - py_lib_rule = py_strict_library, + py_lib_rule = py_library, deps = [":test_ops_kernels"], ) @@ -2755,7 +2755,7 @@ tf_gen_op_wrapper_py( "//tensorflow/python/util:tf_export", ], op_allowlist = ["Namespace>TestStringOutput"], - py_lib_rule = py_strict_library, + py_lib_rule = py_library, deps = [ ":test_ops_kernels", ], diff --git a/tensorflow/python/ops/linalg/sparse/BUILD b/tensorflow/python/ops/linalg/sparse/BUILD index cdce88da3e2356..bc3ae3df1373aa 100644 --- a/tensorflow/python/ops/linalg/sparse/BUILD +++ b/tensorflow/python/ops/linalg/sparse/BUILD @@ -1,6 +1,5 @@ load("@xla//third_party/rules_python/python:py_library.bzl", "py_library") load("//tensorflow:pytype.default.bzl", "pytype_strict_library") -load("//tensorflow:strict.default.bzl", "py_strict_library") # Description: Sparse CSR support for TensorFlow. load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py") @@ -21,7 +20,7 @@ tf_gen_op_wrapper_py( "//tensorflow/python/util:deprecation", "//tensorflow/python/util:tf_export", ], - py_lib_rule = py_strict_library, + py_lib_rule = py_library, visibility = ["//visibility:private"], deps = ["//tensorflow/core:sparse_csr_matrix_ops_op_lib"], ) diff --git a/tensorflow/python/platform/BUILD b/tensorflow/python/platform/BUILD index 8df3884ae491ad..1165cb72562ba6 100644 --- a/tensorflow/python/platform/BUILD +++ b/tensorflow/python/platform/BUILD @@ -1,6 +1,6 @@ # platform package -load("//tensorflow:strict.default.bzl", "py_strict_library") +load("@xla//third_party/rules_python/python:py_library.bzl", "py_library") load( "//tensorflow:tensorflow.bzl", "if_oss", @@ -29,17 +29,17 @@ tf_py_build_info_genrule( out = "build_info.py", ) -py_strict_library( +py_library( name = "build_info", srcs = ["build_info.py"], ) -py_strict_library( +py_library( name = "windows_lib_diagnostics", srcs = ["windows_lib_diagnostics.py"], ) -py_strict_library( +py_library( name = "self_check", srcs = ["self_check.py"], deps = if_oss( @@ -50,7 +50,7 @@ py_strict_library( ), ) -py_strict_library( +py_library( name = "benchmark", srcs = ["benchmark.py"], visibility = visibility + ["//tensorflow:internal"], @@ -66,12 +66,12 @@ py_strict_library( ], ) -py_strict_library( +py_library( name = "analytics", srcs = ["analytics.py"], ) -py_strict_library( +py_library( name = "device_context", srcs = ["device_context.py"], deps = [ @@ -80,7 +80,7 @@ py_strict_library( ], ) -py_strict_library( +py_library( name = "test", srcs = ["googletest.py"], # copybara:uncomment_begin(google-only) @@ -218,7 +218,7 @@ tf_python_pybind_extension( ], ) -py_strict_library( +py_library( name = "client_testlib", srcs = ["test.py"], # copybara:uncomment_begin(google-only) @@ -241,7 +241,7 @@ py_strict_library( ], ) -py_strict_library( +py_library( name = "app", srcs = ["app.py"], deps = [ @@ -252,7 +252,7 @@ py_strict_library( ], ) -py_strict_library( +py_library( name = "sysconfig", srcs = ["sysconfig.py"], deps = [ @@ -263,33 +263,33 @@ py_strict_library( ], ) -py_strict_library( +py_library( name = "__init__", srcs = ["__init__.py"], deps = [ ], ) -py_strict_library( +py_library( name = "control_imports", srcs = ["control_imports.py"], deps = [ ], ) -py_strict_library( +py_library( name = "parameterized", srcs = ["parameterized.py"], ) -py_strict_library( +py_library( name = "remote_utils", srcs = ["remote_utils.py"], deps = [ ], ) -py_strict_library( +py_library( name = "gfile", srcs = ["gfile.py"], # copybara:uncomment_begin(google-only) @@ -306,7 +306,7 @@ py_strict_library( ], ) -py_strict_library( +py_library( name = "tf_logging", srcs = ["tf_logging.py"], # copybara:uncomment_begin(google-only) @@ -326,7 +326,7 @@ py_strict_library( ], ) -py_strict_library( +py_library( name = "flags", srcs = ["flags.py"], visibility = ["//visibility:public"], @@ -336,7 +336,7 @@ py_strict_library( ], ) -py_strict_library( +py_library( name = "resource_loader", srcs = ["resource_loader.py"], visibility = ["//visibility:public"], diff --git a/tensorflow/python/user_ops/BUILD b/tensorflow/python/user_ops/BUILD index e1bc63bf421790..7e0898a23e1689 100644 --- a/tensorflow/python/user_ops/BUILD +++ b/tensorflow/python/user_ops/BUILD @@ -2,7 +2,6 @@ # Contains User Ops (internal TensorFlow version). load("@xla//third_party/rules_python/python:py_library.bzl", "py_library") -load("//tensorflow:strict.default.bzl", "py_strict_library") load("//tensorflow/python:build_defs.bzl", "tf_gen_op_wrapper_private_py") visibility = [ @@ -18,7 +17,7 @@ package( tf_gen_op_wrapper_private_py( name = "user_ops_gen", out = "ops/gen_user_ops.py", - py_lib_rule = py_strict_library, + py_lib_rule = py_library, ) # This target is deprecated. diff --git a/tensorflow/workspace2.bzl b/tensorflow/workspace2.bzl index d79723d04d0d92..59ef2629b320a0 100644 --- a/tensorflow/workspace2.bzl +++ b/tensorflow/workspace2.bzl @@ -26,6 +26,7 @@ load("@xla//third_party/gpus:sycl_configure.bzl", "sycl_configure") load("@xla//third_party/highwayhash:workspace.bzl", highwayhash = "repo") load("@xla//third_party/hwloc:workspace.bzl", hwloc = "repo") load("@xla//third_party/implib_so:workspace.bzl", implib_so = "repo") +load("@xla//third_party/libdrm:workspace.bzl", libdrm = "repo") load("@xla//third_party/llvm:workspace.bzl", llvm = "repo") load("@xla//third_party/mkl_dnn:workspace.bzl", onednn = "repo") load("@xla//third_party/nanobind:workspace.bzl", nanobind = "repo") @@ -119,6 +120,7 @@ def _initialize_third_party(): hwloc() icu() implib_so() + libdrm() jpeg() jpegxl() kissfft() diff --git a/third_party/xla/.github/workflows/ci.yml b/third_party/xla/.github/workflows/ci.yml index b470301489ccbc..9533a9b15f827a 100644 --- a/third_party/xla/.github/workflows/ci.yml +++ b/third_party/xla/.github/workflows/ci.yml @@ -72,12 +72,6 @@ jobs: name: "XLA Linux x86 GPU ROCm", repo: "openxla/xla", }, - { - pool: "linux-x86-n2-16", - container: "us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/ml-build:latest", - name: "XLA Linux x86 GPU ROCm Local Sysroot", - repo: "openxla/xla", - }, { pool: "linux-x86-g2-16-l4-1gpu", container: "us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/ml-build:latest", diff --git a/third_party/xla/.github/workflows/pr_size_check.yml b/third_party/xla/.github/workflows/pr_size_check.yml index cc052d3fbd7672..2f11d0fea22d39 100644 --- a/third_party/xla/.github/workflows/pr_size_check.yml +++ b/third_party/xla/.github/workflows/pr_size_check.yml @@ -16,6 +16,7 @@ name: PR Size Check permissions: pull-requests: write on: + # zizmor: ignore[dangerous-triggers] pull_request_target: branches: - main diff --git a/third_party/xla/.github/workflows/scorecards-analysis.yml b/third_party/xla/.github/workflows/scorecards-analysis.yml index 26df5ad1639f57..7fbe937c46d322 100644 --- a/third_party/xla/.github/workflows/scorecards-analysis.yml +++ b/third_party/xla/.github/workflows/scorecards-analysis.yml @@ -67,6 +67,6 @@ jobs: # Upload the results to GitHub's code scanning dashboard (optional). # Commenting out will disable upload of results to your repo's Code Scanning dashboard - name: "Upload to code-scanning" - uses: github/codeql-action/upload-sarif@e46ed2cbd01164d986452f91f178727624ae40d7 # v4.35.3 + uses: github/codeql-action/upload-sarif@68bde559dea0fdcac2102bfdf6230c5f70eb485e # v4.35.4 with: sarif_file: results.sarif diff --git a/third_party/xla/MODULE.bazel b/third_party/xla/MODULE.bazel index c99c2e647c3ca8..2811edfbdf0ec4 100644 --- a/third_party/xla/MODULE.bazel +++ b/third_party/xla/MODULE.bazel @@ -45,9 +45,9 @@ bazel_dep(name = "rules_ml_toolchain") # echo "sha256-${HASH}" archive_override( module_name = "rules_ml_toolchain", - integrity = "sha256-C0L2k6YMYFDYfbHgoOrrhKs/VBkfzglNhjNPrtyAfaA=", - strip_prefix = "rules_ml_toolchain-398d613aea7a4c294da49b79a6d6f3f8732bd84c", - urls = ["https://github.com/google-ml-infra/rules_ml_toolchain/archive/398d613aea7a4c294da49b79a6d6f3f8732bd84c.tar.gz"], + integrity = "sha256-YCSIAeg0IqV2nXop07s6ZqSEghyEbfwP9iBF1k5+yYI=", + strip_prefix = "rules_ml_toolchain-a3bf5be11de756adf6e212b46017769530073766", + urls = ["https://github.com/google-ml-infra/rules_ml_toolchain/archive/a3bf5be11de756adf6e212b46017769530073766.tar.gz"], ) # TODO: Upstream the patch? @@ -309,3 +309,7 @@ use_repo(pjrt_nightly_timestamp, "nightly_timestamp") pjrt_rc_number = use_extension("//build_tools/pjrt_wheels:release_candidate.bzl", "rc_number_repo_bzlmod") use_repo(pjrt_rc_number, "rc_number") + +### libdrm (dedicated extension to avoid Bzlmod merging issues) +libdrm_ext = use_extension("//third_party/libdrm:extension.bzl", "libdrm_ext") +use_repo(libdrm_ext, "libdrm") diff --git a/third_party/xla/build_tools/ci/build.py b/third_party/xla/build_tools/ci/build.py index 15ffa423326d2d..82a48f5dd0ca6a 100755 --- a/third_party/xla/build_tools/ci/build.py +++ b/third_party/xla/build_tools/ci/build.py @@ -103,7 +103,6 @@ class BuildType(enum.Enum): XLA_LINUX_X86_GPU_8X_H100_GITHUB_ACTIONS = enum.auto() XLA_LINUX_X86_GPU_ONEAPI_GITHUB_ACTIONS = enum.auto() XLA_LINUX_X86_GPU_ROCM_GITHUB_ACTIONS = enum.auto() - XLA_LINUX_X86_GPU_ROCM_LOCAL_SYSROOT_GITHUB_ACTIONS = enum.auto() # Presubmit builds for regression testing. XLA_LINUX_ARM64_CPU_48_VCPU_PRESUBMIT_GITHUB_ACTIONS = enum.auto() @@ -552,18 +551,6 @@ def nvidia_gpu_build_with_compute_capability( subcommand="build", ) -# ROCm builds - hermetic LLVM with local sysroot -Build( - type_=BuildType.XLA_LINUX_X86_GPU_ROCM_LOCAL_SYSROOT_GITHUB_ACTIONS, - repo="openxla/xla", - configs=("warnings", "rbe_linux_cpu", "rocm_clang_hermetic_local_sysroot"), - target_patterns=_XLA_DEFAULT_TARGET_PATTERNS, - build_tag_filters=rocm_tag_filter, - test_tag_filters=rocm_tag_filter, - options={**_DEFAULT_BAZEL_OPTIONS, "//xla/tsl:ci_build": True}, - subcommand="build", -) - Build( type_=BuildType.XLA_LINUX_X86_CPU_128_VCPU_PRESUBMIT_GITHUB_ACTIONS, repo="openxla/xla", diff --git a/third_party/xla/build_tools/ci/golden_commands.txt b/third_party/xla/build_tools/ci/golden_commands.txt index 6aa55e8669d6be..61fef1ff9c4dde 100644 --- a/third_party/xla/build_tools/ci/golden_commands.txt +++ b/third_party/xla/build_tools/ci/golden_commands.txt @@ -107,11 +107,6 @@ parallel --ungroup --retries 3 --delay 15 --nonall -- bazel build --build_tag_fi bazel build --build_tag_filters=-no_gpu,-requires-gpu-intel,-requires-gpu-nvidia,-cuda-only,-oneapi-only,-requires-gpu-sm60,-requires-gpu-sm60-only,-requires-gpu-sm70,-requires-gpu-sm70-only,-requires-gpu-sm80,-requires-gpu-sm80-only,-requires-gpu-sm86,-requires-gpu-sm86-only,-requires-gpu-sm89,-requires-gpu-sm89-only,-requires-gpu-sm90,-requires-gpu-sm90-only,-skip_rocprofiler_sdk,-no_oss,-oss_excluded,-oss_serial,gpu --test_tag_filters=-no_gpu,-requires-gpu-intel,-requires-gpu-nvidia,-cuda-only,-oneapi-only,-requires-gpu-sm60,-requires-gpu-sm60-only,-requires-gpu-sm70,-requires-gpu-sm70-only,-requires-gpu-sm80,-requires-gpu-sm80-only,-requires-gpu-sm86,-requires-gpu-sm86-only,-requires-gpu-sm89,-requires-gpu-sm89-only,-requires-gpu-sm90,-requires-gpu-sm90-only,-skip_rocprofiler_sdk,-no_oss,-oss_excluded,-oss_serial,gpu --config=warnings --config=rbe_linux_cpu --config=rocm_clang_hermetic --color=yes --test_output=errors --verbose_failures --keep_going --nobuild_tests_only --profile=profile.json.gz --flaky_test_attempts=3 --jobs=150 --bes_upload_mode=fully_async --//xla/tsl:ci_build -- //xla/... //build_tools/... @tsl//tsl/... bazel analyze-profile profile.json.gz # END BuildType.XLA_LINUX_X86_GPU_ROCM_GITHUB_ACTIONS -# BEGIN BuildType.XLA_LINUX_X86_GPU_ROCM_LOCAL_SYSROOT_GITHUB_ACTIONS -parallel --ungroup --retries 3 --delay 15 --nonall -- bazel build --build_tag_filters=-no_gpu,-requires-gpu-intel,-requires-gpu-nvidia,-cuda-only,-oneapi-only,-requires-gpu-sm60,-requires-gpu-sm60-only,-requires-gpu-sm70,-requires-gpu-sm70-only,-requires-gpu-sm80,-requires-gpu-sm80-only,-requires-gpu-sm86,-requires-gpu-sm86-only,-requires-gpu-sm89,-requires-gpu-sm89-only,-requires-gpu-sm90,-requires-gpu-sm90-only,-skip_rocprofiler_sdk,-no_oss,-oss_excluded,-oss_serial,gpu --test_tag_filters=-no_gpu,-requires-gpu-intel,-requires-gpu-nvidia,-cuda-only,-oneapi-only,-requires-gpu-sm60,-requires-gpu-sm60-only,-requires-gpu-sm70,-requires-gpu-sm70-only,-requires-gpu-sm80,-requires-gpu-sm80-only,-requires-gpu-sm86,-requires-gpu-sm86-only,-requires-gpu-sm89,-requires-gpu-sm89-only,-requires-gpu-sm90,-requires-gpu-sm90-only,-skip_rocprofiler_sdk,-no_oss,-oss_excluded,-oss_serial,gpu --config=warnings --config=rbe_linux_cpu --config=rocm_clang_hermetic_local_sysroot --color=yes --test_output=errors --verbose_failures --keep_going --nobuild_tests_only --profile=profile.json.gz --flaky_test_attempts=3 --jobs=150 --bes_upload_mode=fully_async --//xla/tsl:ci_build --nobuild -- //xla/... //build_tools/... @tsl//tsl/... -bazel build --build_tag_filters=-no_gpu,-requires-gpu-intel,-requires-gpu-nvidia,-cuda-only,-oneapi-only,-requires-gpu-sm60,-requires-gpu-sm60-only,-requires-gpu-sm70,-requires-gpu-sm70-only,-requires-gpu-sm80,-requires-gpu-sm80-only,-requires-gpu-sm86,-requires-gpu-sm86-only,-requires-gpu-sm89,-requires-gpu-sm89-only,-requires-gpu-sm90,-requires-gpu-sm90-only,-skip_rocprofiler_sdk,-no_oss,-oss_excluded,-oss_serial,gpu --test_tag_filters=-no_gpu,-requires-gpu-intel,-requires-gpu-nvidia,-cuda-only,-oneapi-only,-requires-gpu-sm60,-requires-gpu-sm60-only,-requires-gpu-sm70,-requires-gpu-sm70-only,-requires-gpu-sm80,-requires-gpu-sm80-only,-requires-gpu-sm86,-requires-gpu-sm86-only,-requires-gpu-sm89,-requires-gpu-sm89-only,-requires-gpu-sm90,-requires-gpu-sm90-only,-skip_rocprofiler_sdk,-no_oss,-oss_excluded,-oss_serial,gpu --config=warnings --config=rbe_linux_cpu --config=rocm_clang_hermetic_local_sysroot --color=yes --test_output=errors --verbose_failures --keep_going --nobuild_tests_only --profile=profile.json.gz --flaky_test_attempts=3 --jobs=150 --bes_upload_mode=fully_async --//xla/tsl:ci_build -- //xla/... //build_tools/... @tsl//tsl/... -bazel analyze-profile profile.json.gz -# END BuildType.XLA_LINUX_X86_GPU_ROCM_LOCAL_SYSROOT_GITHUB_ACTIONS # BEGIN BuildType.XLA_MACOS_ARM64_CPU_KOKORO df -h bazel --version diff --git a/third_party/xla/opensource_only.files b/third_party/xla/opensource_only.files index fd34da96516859..1fc0a04dd2fea4 100644 --- a/third_party/xla/opensource_only.files +++ b/third_party/xla/opensource_only.files @@ -120,6 +120,7 @@ xla/third_party/hwloc/BUILD.system: xla/third_party/hwloc/static-components.h: xla/third_party/implib_so/get_symbols.py: xla/third_party/implib_so/make_stub.py: +xla/third_party/libdrm/BUILD: xla/third_party/llvm/run_lit.sh: xla/third_party/llvm_openmp/cmake_vars.bzl: xla/third_party/llvm_openmp/expand_cmake_vars.py: diff --git a/third_party/xla/tensorflow.bazelrc b/third_party/xla/tensorflow.bazelrc index 8812c24571b9dc..b04f749647c505 100644 --- a/third_party/xla/tensorflow.bazelrc +++ b/third_party/xla/tensorflow.bazelrc @@ -285,6 +285,7 @@ common:rocm_base --repo_env TF_NEED_ROCM=1 common:rocm_base --define=using_rocm_hipcc=true common:rocm_base --define=tensorflow_mkldnn_contraction_kernel=0 common:rocm_base --action_env=HIPCC_COMPILE_FLAGS_APPEND="--offload-compress" +common:rocm_base --repo_env=TF_ROCM_AMDGPU_TARGETS="gfx908,gfx90a,gfx942,gfx950,gfx1030,gfx1100,gfx1101,gfx1200,gfx1201" common:rocm_clang_local --config=rocm_base common:rocm_clang_local --config=clang_local @@ -299,21 +300,18 @@ common:rocm_clang_hermetic --@rules_ml_toolchain//common:enable_cuda=False common:rocm_clang_hermetic --@rules_ml_toolchain//common:enable_sycl=False common:rocm_clang_hermetic --@rules_ml_toolchain//common:enable_hermetic_cc=True common:rocm_clang_hermetic --@local_config_rocm//rocm:rocm_path_type=hermetic -common:rocm_clang_hermetic --repo_env=TF_ROCM_AMDGPU_TARGETS=gfx90a,gfx942 +common:rocm_clang_hermetic --repo_env=SYSROOT_DIST=linux_glibc_2_35 common:rocm_clang_hermetic --strategy=CppLink=local +common:rocm_clang_hermetic --@rules_ml_toolchain//common:static_libcxx=True -# ROCm with hermetic clang but local (system) sysroot -# Use this when ROCm libraries are built against system libstdc++ to avoid ABI mismatches -common:rocm_clang_hermetic_local_sysroot --config=rocm_clang_hermetic -common:rocm_clang_hermetic_local_sysroot --@rules_ml_toolchain//cc/sysroots:use_local_sysroot=True - -common:rocm --config=rocm_clang_hermetic_local_sysroot +common:rocm --config=rocm_clang_hermetic common:rocm_ci --config=rocm common:rocm_ci --@local_config_rocm//rocm:rocm_path_type=hermetic common:rocm_ci_hermetic --dynamic_mode=off -common:rocm_ci_hermetic --config=rocm_clang_hermetic_local_sysroot -common:rocm_ci_hermetic --repo_env="ROCM_DISTRO_VERSION=rocm_7.10.0_gfx90X" +common:rocm_ci_hermetic --config=rocm_clang_hermetic +common:rocm_ci_hermetic --config="gfx908,gfx90a" +common:rocm_ci_hermetic --repo_env="ROCM_DISTRO_VERSION=rocm_7.12.0_gfx90X" common:rocm_ci_hermetic --@local_config_rocm//rocm:rocm_path_type=hermetic # This config option is used for SYCL as GPU backend. diff --git a/third_party/xla/third_party/libdrm/BUILD b/third_party/xla/third_party/libdrm/BUILD new file mode 100644 index 00000000000000..c4d881aeff97c5 --- /dev/null +++ b/third_party/xla/third_party/libdrm/BUILD @@ -0,0 +1,11 @@ +load("@rules_cc//cc:defs.bzl", "cc_library") + +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) + +cc_library( + name = "drm_headers", + visibility = ["//visibility:public"], + deps = ["@libdrm//:drm_headers"], +) diff --git a/third_party/xla/third_party/libdrm/extension.bzl b/third_party/xla/third_party/libdrm/extension.bzl new file mode 100644 index 00000000000000..70dcff44cf3317 --- /dev/null +++ b/third_party/xla/third_party/libdrm/extension.bzl @@ -0,0 +1,10 @@ +"""Module extension for libdrm.""" + +load("//third_party/libdrm:workspace.bzl", libdrm = "repo") + +def _libdrm_ext_impl(mctx): # @unused + libdrm() + +libdrm_ext = module_extension( + implementation = _libdrm_ext_impl, +) diff --git a/third_party/xla/third_party/libdrm/libdrm.BUILD b/third_party/xla/third_party/libdrm/libdrm.BUILD new file mode 100644 index 00000000000000..b0807f5e4ef72a --- /dev/null +++ b/third_party/xla/third_party/libdrm/libdrm.BUILD @@ -0,0 +1,16 @@ +"""BUILD file for libdrm headers.""" + +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) # MIT-style license + +# Export just the headers needed by ROCm +cc_library( + name = "drm_headers", + hdrs = glob([ + "include/drm/*.h", + ]), + include_prefix = "libdrm", + strip_include_prefix = "include/drm", + visibility = ["//visibility:public"], +) diff --git a/third_party/xla/third_party/libdrm/workspace.bzl b/third_party/xla/third_party/libdrm/workspace.bzl new file mode 100644 index 00000000000000..6f3f53a56fded3 --- /dev/null +++ b/third_party/xla/third_party/libdrm/workspace.bzl @@ -0,0 +1,17 @@ +"""Loads libdrm headers for ROCm compatibility.""" + +load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") + +def repo(): + """Import libdrm headers.""" + + # libdrm 2.4.120 - a recent stable version + tf_http_archive( + name = "libdrm", + build_file = str(Label("//third_party/libdrm:libdrm.BUILD")), + sha256 = "3bf55363f76c7250946441ab51d3a6cc0ae518055c0ff017324ab76cdefb327a", + strip_prefix = "libdrm-2.4.120", + urls = tf_mirror_urls( + "https://dri.freedesktop.org/libdrm/libdrm-2.4.120.tar.xz", + ), + ) diff --git a/third_party/xla/third_party/stablehlo/temporary.patch b/third_party/xla/third_party/stablehlo/temporary.patch index fb9de46a2de356..e8cfc34826632b 100644 --- a/third_party/xla/third_party/stablehlo/temporary.patch +++ b/third_party/xla/third_party/stablehlo/temporary.patch @@ -19,6 +19,47 @@ diff --ruN a/stablehlo/stablehlo/dialect/Base.td b/stablehlo/stablehlo/dialect/B //===----------------------------------------------------------------------===// // HLO combined type definitions. +diff --ruN a/stablehlo/stablehlo/dialect/ChloOps.cpp b/stablehlo/stablehlo/dialect/ChloOps.cpp +--- stablehlo/stablehlo/dialect/ChloOps.cpp ++++ stablehlo/stablehlo/dialect/ChloOps.cpp +@@ -93,6 +93,7 @@ + INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(ErfcOp) + INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(ErfInvOp) + INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(LgammaOp) ++INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(MulhiOp) + INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(NextAfterOp) + INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(PolygammaOp) + INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(SinhOp) +diff --ruN a/stablehlo/stablehlo/dialect/ChloOps.td b/stablehlo/stablehlo/dialect/ChloOps.td +--- stablehlo/stablehlo/dialect/ChloOps.td ++++ stablehlo/stablehlo/dialect/ChloOps.td +@@ -988,4 +988,26 @@ + let hasVerifier = 1; + } + ++ ++def CHLO_MulhiOp : CHLO_Op<"mulhi", [Commutative, Pure, ++ HLO_CompatibleOperandsAndResultType]> { ++ let summary = "Mulhi operation"; ++ let description = [{ ++ Performs element-wise multiplication of two N-bit integer tensors ++ `lhs` and `rhs`, returning a N-bit integer `result` tensor containing ++ the most significant N bits of the upcasted (N+N-bit) product. ++ ++ $$ ++ \text{mulhi}(x, y) = \text{downcast}((\text{upcast}(x) * \text{upcast}(y)) >> N) ++ $$ ++ }]; ++ let arguments = (ins HLO_IntTensor:$lhs, HLO_IntTensor:$rhs); ++ let results = (outs HLO_IntTensor:$result); ++ ++ let assemblyFormat = [{ ++ $lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result) ++ }]; ++} ++ + #endif // STABLEHLO_DIALECT_CHLO_OPS ++ diff --ruN a/stablehlo/stablehlo/dialect/Version.h b/stablehlo/stablehlo/dialect/Version.h --- stablehlo/stablehlo/dialect/Version.h +++ stablehlo/stablehlo/dialect/Version.h @@ -632,6 +673,2001 @@ diff --ruN a/stablehlo/stablehlo/reference/Tensor.cpp b/stablehlo/stablehlo/refe return std::complex(value.real().convertToDouble(), value.imag().convertToDouble()); }); +diff --ruN a/stablehlo/stablehlo/tests/chlo/chlo_legalize_to_stablehlo.mlir b/stablehlo/stablehlo/tests/chlo/chlo_legalize_to_stablehlo.mlir +--- stablehlo/stablehlo/tests/chlo/chlo_legalize_to_stablehlo.mlir ++++ stablehlo/stablehlo/tests/chlo/chlo_legalize_to_stablehlo.mlir +@@ -89,24 +89,24 @@ + // CHECK: %[[SQRT_0:.*]] = stablehlo.sqrt %[[CONSTANT_0]] : tensor + // CHECK: %[[CONSTANT_1:.*]] = stablehlo.constant dense<8.000000e+00> : tensor + // CHECK: %[[DIVIDE_0:.*]] = stablehlo.divide %[[SQRT_0]], %[[CONSTANT_1]] : tensor +-// CHECK: %[[COMPARE_0:.*]] = stablehlo.compare GE, %[[MAXIMUM_0]], %[[DIVIDE_0]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_0:.*]] = stablehlo.compare GE, %[[MAXIMUM_0]], %[[DIVIDE_0]] : (tensor, tensor) -> tensor + // CHECK: %[[CONSTANT_2:.*]] = stablehlo.constant dense<1.000000e+00> : tensor +-// CHECK: %[[COMPARE_1:.*]] = stablehlo.compare LE, %[[ABS_0]], %[[CONSTANT_2]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_1:.*]] = stablehlo.compare LE, %[[ABS_0]], %[[CONSTANT_2]] : (tensor, tensor) -> tensor + // CHECK: %[[CONSTANT_3:.*]] = stablehlo.constant dense<5.000000e-01> : tensor + // CHECK: %[[ADD_0:.*]] = stablehlo.add %[[ABS_0]], %[[CONSTANT_2]] : tensor + // CHECK: %[[ABS_2:.*]] = stablehlo.abs %[[ADD_0]] : tensor + // CHECK: %[[MAXIMUM_1:.*]] = stablehlo.maximum %[[ABS_2]], %[[ABS_1]] : tensor + // CHECK: %[[MINIMUM_0:.*]] = stablehlo.minimum %[[ABS_2]], %[[ABS_1]] : tensor +-// CHECK: %[[COMPARE_2:.*]] = stablehlo.compare EQ, %[[MAXIMUM_1]], %[[MINIMUM_0]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_2:.*]] = stablehlo.compare EQ, %[[MAXIMUM_1]], %[[MINIMUM_0]] : (tensor, tensor) -> tensor + // CHECK: %[[CONSTANT_4:.*]] = stablehlo.constant dense<1.41421354> : tensor + // CHECK: %[[MULTIPLY_0:.*]] = stablehlo.multiply %[[CONSTANT_4]], %[[MAXIMUM_1]] : tensor + // CHECK: %[[DIVIDE_1:.*]] = stablehlo.divide %[[MINIMUM_0]], %[[MAXIMUM_1]] : tensor + // CHECK: %[[MULTIPLY_1:.*]] = stablehlo.multiply %[[DIVIDE_1]], %[[DIVIDE_1]] : tensor + // CHECK: %[[ADD_1:.*]] = stablehlo.add %[[CONSTANT_2]], %[[MULTIPLY_1]] : tensor + // CHECK: %[[SQRT_1:.*]] = stablehlo.sqrt %[[ADD_1]] : tensor +-// CHECK: %[[COMPARE_3:.*]] = stablehlo.compare EQ, %[[SQRT_1]], %[[CONSTANT_2]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_3:.*]] = stablehlo.compare EQ, %[[SQRT_1]], %[[CONSTANT_2]] : (tensor, tensor) -> tensor + // CHECK: %[[CONSTANT_5:.*]] = stablehlo.constant dense<0.000000e+00> : tensor +-// CHECK: %[[COMPARE_4:.*]] = stablehlo.compare GT, %[[MULTIPLY_1]], %[[CONSTANT_5]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_4:.*]] = stablehlo.compare GT, %[[MULTIPLY_1]], %[[CONSTANT_5]] : (tensor, tensor) -> tensor + // CHECK: %[[AND_0:.*]] = stablehlo.and %[[COMPARE_3]], %[[COMPARE_4]] : tensor + // CHECK: %[[MULTIPLY_2:.*]] = stablehlo.multiply %[[MAXIMUM_1]], %[[MULTIPLY_1]] : tensor + // CHECK: %[[CONSTANT_6:.*]] = stablehlo.constant dense<2.000000e+00> : tensor +@@ -119,14 +119,14 @@ + // CHECK: %[[ABS_3:.*]] = stablehlo.abs %[[SUBTRACT_0]] : tensor + // CHECK: %[[MAXIMUM_2:.*]] = stablehlo.maximum %[[ABS_3]], %[[ABS_1]] : tensor + // CHECK: %[[MINIMUM_1:.*]] = stablehlo.minimum %[[ABS_3]], %[[ABS_1]] : tensor +-// CHECK: %[[COMPARE_5:.*]] = stablehlo.compare EQ, %[[MAXIMUM_2]], %[[MINIMUM_1]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_5:.*]] = stablehlo.compare EQ, %[[MAXIMUM_2]], %[[MINIMUM_1]] : (tensor, tensor) -> tensor + // CHECK: %[[MULTIPLY_4:.*]] = stablehlo.multiply %[[CONSTANT_4]], %[[MAXIMUM_2]] : tensor + // CHECK: %[[DIVIDE_3:.*]] = stablehlo.divide %[[MINIMUM_1]], %[[MAXIMUM_2]] : tensor + // CHECK: %[[MULTIPLY_5:.*]] = stablehlo.multiply %[[DIVIDE_3]], %[[DIVIDE_3]] : tensor + // CHECK: %[[ADD_3:.*]] = stablehlo.add %[[CONSTANT_2]], %[[MULTIPLY_5]] : tensor + // CHECK: %[[SQRT_2:.*]] = stablehlo.sqrt %[[ADD_3]] : tensor +-// CHECK: %[[COMPARE_6:.*]] = stablehlo.compare EQ, %[[SQRT_2]], %[[CONSTANT_2]] : (tensor, tensor) -> tensor +-// CHECK: %[[COMPARE_7:.*]] = stablehlo.compare GT, %[[MULTIPLY_5]], %[[CONSTANT_5]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_6:.*]] = stablehlo.compare EQ, %[[SQRT_2]], %[[CONSTANT_2]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_7:.*]] = stablehlo.compare GT, %[[MULTIPLY_5]], %[[CONSTANT_5]] : (tensor, tensor) -> tensor + // CHECK: %[[AND_1:.*]] = stablehlo.and %[[COMPARE_6]], %[[COMPARE_7]] : tensor + // CHECK: %[[MULTIPLY_6:.*]] = stablehlo.multiply %[[MAXIMUM_2]], %[[MULTIPLY_5]] : tensor + // CHECK: %[[DIVIDE_4:.*]] = stablehlo.divide %[[MULTIPLY_6]], %[[CONSTANT_6]] : tensor +@@ -155,21 +155,21 @@ + // CHECK: %[[SELECT_5:.*]] = stablehlo.select %[[COMPARE_0]], %[[ABS_1]], %[[SELECT_4]] : tensor, tensor + // CHECK: %[[CONSTANT_7:.*]] = stablehlo.constant dense<9.99999995E+11> : tensor + // CHECK: %[[MULTIPLY_13:.*]] = stablehlo.multiply %[[DIVIDE_0]], %[[CONSTANT_7]] : tensor +-// CHECK: %[[COMPARE_8:.*]] = stablehlo.compare LT, %[[ABS_0]], %[[MULTIPLY_13]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_8:.*]] = stablehlo.compare LT, %[[ABS_0]], %[[MULTIPLY_13]] : (tensor, tensor) -> tensor + // CHECK: %[[CONSTANT_8:.*]] = stablehlo.constant dense<9.99999997E-7> : tensor + // CHECK: %[[MULTIPLY_14:.*]] = stablehlo.multiply %[[DIVIDE_0]], %[[CONSTANT_8]] : tensor + // CHECK: %[[CONSTANT_9:.*]] = stablehlo.constant dense<1.000000e+02> : tensor + // CHECK: %[[MULTIPLY_15:.*]] = stablehlo.multiply %[[DIVIDE_0]], %[[CONSTANT_9]] : tensor + // CHECK: %[[SELECT_6:.*]] = stablehlo.select %[[COMPARE_8]], %[[MULTIPLY_14]], %[[MULTIPLY_15]] : tensor, tensor +-// CHECK: %[[COMPARE_9:.*]] = stablehlo.compare GE, %[[ABS_1]], %[[SELECT_6]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_9:.*]] = stablehlo.compare GE, %[[ABS_1]], %[[SELECT_6]] : (tensor, tensor) -> tensor + // CHECK: %[[SELECT_7:.*]] = stablehlo.select %[[COMPARE_9]], %[[ABS_1]], %[[ABS_0]] : tensor, tensor + // CHECK: %[[SELECT_8:.*]] = stablehlo.select %[[COMPARE_9]], %[[SELECT_6]], %[[DIVIDE_0]] : tensor, tensor +-// CHECK: %[[COMPARE_10:.*]] = stablehlo.compare GE, %[[SELECT_7]], %[[SELECT_8]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_10:.*]] = stablehlo.compare GE, %[[SELECT_7]], %[[SELECT_8]] : (tensor, tensor) -> tensor + // CHECK: %[[LOG_0:.*]] = stablehlo.log %[[CONSTANT_6]] : tensor + // CHECK: %[[LOG_1:.*]] = stablehlo.log %[[SELECT_7]] : tensor + // CHECK: %[[ADD_11:.*]] = stablehlo.add %[[LOG_0]], %[[LOG_1]] : tensor + // CHECK: %[[CONSTANT_10:.*]] = stablehlo.constant dense<0x7F800000> : tensor +-// CHECK: %[[COMPARE_11:.*]] = stablehlo.compare EQ, %[[ABS_1]], %[[CONSTANT_10]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_11:.*]] = stablehlo.compare EQ, %[[ABS_1]], %[[CONSTANT_10]] : (tensor, tensor) -> tensor + // CHECK: %[[NOT_0:.*]] = stablehlo.not %[[COMPARE_11]] : tensor + // CHECK: %[[AND_2:.*]] = stablehlo.and %[[COMPARE_9]], %[[NOT_0]] : tensor + // CHECK: %[[DIVIDE_8:.*]] = stablehlo.divide %[[ABS_0]], %[[ABS_1]] : tensor +@@ -182,20 +182,20 @@ + // CHECK: %[[SQRT_5:.*]] = stablehlo.sqrt %[[CONSTANT_11]] : tensor + // CHECK: %[[CONSTANT_12:.*]] = stablehlo.constant dense<4.000000e+00> : tensor + // CHECK: %[[MULTIPLY_18:.*]] = stablehlo.multiply %[[SQRT_5]], %[[CONSTANT_12]] : tensor +-// CHECK: %[[COMPARE_12:.*]] = stablehlo.compare LT, %[[ABS_1]], %[[MULTIPLY_18]] : (tensor, tensor) -> tensor +-// CHECK: %[[COMPARE_13:.*]] = stablehlo.compare LT, %[[ABS_0]], %[[CONSTANT_2]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_12:.*]] = stablehlo.compare LT, %[[ABS_1]], %[[MULTIPLY_18]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_13:.*]] = stablehlo.compare LT, %[[ABS_0]], %[[CONSTANT_2]] : (tensor, tensor) -> tensor + // CHECK: %[[AND_3:.*]] = stablehlo.and %[[COMPARE_12]], %[[COMPARE_13]] : tensor + // CHECK: %[[MULTIPLY_19:.*]] = stablehlo.multiply %[[ADD_0]], %[[SUBTRACT_0]] : tensor + // CHECK: %[[ADD_13:.*]] = stablehlo.add %[[MULTIPLY_8]], %[[CONSTANT_2]] : tensor + // CHECK: %[[DIVIDE_9:.*]] = stablehlo.divide %[[MULTIPLY_19]], %[[ADD_13]] : tensor + // CHECK: %[[NEGATE_0:.*]] = stablehlo.negate %[[DIVIDE_9]] : tensor +-// CHECK: %[[COMPARE_14:.*]] = stablehlo.compare GE, %[[ABS_0]], %[[CONSTANT_2]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_14:.*]] = stablehlo.compare GE, %[[ABS_0]], %[[CONSTANT_2]] : (tensor, tensor) -> tensor + // CHECK: %[[MULTIPLY_20:.*]] = stablehlo.multiply %[[CONSTANT_3]], %[[MULTIPLY_10]] : tensor + // CHECK: %[[DIVIDE_10:.*]] = stablehlo.divide %[[MULTIPLY_20]], %[[ADD_7]] : tensor + // CHECK: %[[MULTIPLY_21:.*]] = stablehlo.multiply %[[CONSTANT_3]], %[[ADD_9]] : tensor + // CHECK: %[[ADD_14:.*]] = stablehlo.add %[[DIVIDE_10]], %[[MULTIPLY_21]] : tensor + // CHECK: %[[CONSTANT_13:.*]] = stablehlo.constant dense<1.500000e+00> : tensor +-// CHECK: %[[COMPARE_15:.*]] = stablehlo.compare LE, %[[MULTIPLY_8]], %[[CONSTANT_13]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_15:.*]] = stablehlo.compare LE, %[[MULTIPLY_8]], %[[CONSTANT_13]] : (tensor, tensor) -> tensor + // CHECK: %[[DIVIDE_11:.*]] = stablehlo.divide %[[MULTIPLY_20]], %[[SUBTRACT_1]] : tensor + // CHECK: %[[ADD_15:.*]] = stablehlo.add %[[DIVIDE_10]], %[[DIVIDE_11]] : tensor + // CHECK: %[[SUBTRACT_2:.*]] = stablehlo.subtract %[[MULTIPLY_8]], %[[CONSTANT_2]] : tensor +@@ -214,7 +214,7 @@ + // CHECK: %[[VAL_0:.*]] = stablehlo.atan2 %[[REAL_0]], %[[REAL_2]] : tensor + // CHECK: %[[IMAG_1:.*]] = stablehlo.imag %[[ARG0]] : (tensor>) -> tensor + // CHECK: %[[CONSTANT_14:.*]] = stablehlo.constant dense<0.000000e+00> : tensor +-// CHECK: %[[COMPARE_16:.*]] = stablehlo.compare LT, %[[IMAG_1]], %[[CONSTANT_14]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_16:.*]] = stablehlo.compare LT, %[[IMAG_1]], %[[CONSTANT_14]] : (tensor, tensor) -> tensor + // CHECK: %[[IMAG_2:.*]] = stablehlo.imag %[[COMPLEX_0]] : (tensor>) -> tensor + // CHECK: %[[NEGATE_1:.*]] = stablehlo.negate %[[IMAG_2]] : tensor + // CHECK: %[[SELECT_15:.*]] = stablehlo.select %[[COMPARE_16]], %[[NEGATE_1]], %[[IMAG_2]] : tensor, tensor +@@ -238,168 +238,168 @@ + // CHECK: %[[MAXIMUM_0:.*]] = stablehlo.maximum %[[ABS_0]], %[[ABS_1]] : tensor + // CHECK: %[[CONSTANT_0:.*]] = stablehlo.constant dense<1.7976931348623157E+308> : tensor + // CHECK: %[[SHAPE_OF_0:.*]] = shape.shape_of %[[REAL_1]] : tensor -> tensor<1xindex> +-// CHECK: %[[ASSUMING_0:.*]] = stablehlo.dynamic_broadcast_in_dim %[[CONSTANT_0]], %[[SHAPE_OF_0]], dims = [] : (tensor, tensor<1xindex>) -> tensor +-// CHECK: %[[SQRT_0:.*]] = stablehlo.sqrt %[[ASSUMING_0]] : tensor ++// CHECK: %[[DYNAMIC_BROADCAST_IN_DIM_0:.*]] = stablehlo.dynamic_broadcast_in_dim %[[CONSTANT_0]], %[[SHAPE_OF_0]], dims = [] : (tensor, tensor<1xindex>) -> tensor ++// CHECK: %[[SQRT_0:.*]] = stablehlo.sqrt %[[DYNAMIC_BROADCAST_IN_DIM_0]] : tensor + // CHECK: %[[CONSTANT_1:.*]] = stablehlo.constant dense<8.000000e+00> : tensor ++// CHECK: %[[SHAPE_OF_1:.*]] = shape.shape_of %[[REAL_1]] : tensor -> tensor<1xindex> ++// CHECK: %[[DYNAMIC_BROADCAST_IN_DIM_1:.*]] = stablehlo.dynamic_broadcast_in_dim %[[CONSTANT_1]], %[[SHAPE_OF_1]], dims = [] : (tensor, tensor<1xindex>) -> tensor ++// CHECK: %[[DIVIDE_0:.*]] = stablehlo.divide %[[SQRT_0]], %[[DYNAMIC_BROADCAST_IN_DIM_1]] : tensor ++// CHECK: %[[COMPARE_0:.*]] = stablehlo.compare GE, %[[MAXIMUM_0]], %[[DIVIDE_0]] : (tensor, tensor) -> tensor ++// CHECK: %[[CONSTANT_2:.*]] = stablehlo.constant dense<1.000000e+00> : tensor ++// CHECK: %[[SHAPE_OF_2:.*]] = shape.shape_of %[[REAL_1]] : tensor -> tensor<1xindex> ++// CHECK: %[[DYNAMIC_BROADCAST_IN_DIM_2:.*]] = stablehlo.dynamic_broadcast_in_dim %[[CONSTANT_2]], %[[SHAPE_OF_2]], dims = [] : (tensor, tensor<1xindex>) -> tensor ++// CHECK: %[[COMPARE_1:.*]] = stablehlo.compare LE, %[[ABS_0]], %[[DYNAMIC_BROADCAST_IN_DIM_2]] : (tensor, tensor) -> tensor ++// CHECK: %[[CONSTANT_3:.*]] = stablehlo.constant dense<5.000000e-01> : tensor + // CHECK: %[[SHAPE_OF_3:.*]] = shape.shape_of %[[REAL_1]] : tensor -> tensor<1xindex> +-// CHECK: %[[ASSUMING_1:.*]] = stablehlo.dynamic_broadcast_in_dim %[[CONSTANT_1]], %[[SHAPE_OF_3]], dims = [] : (tensor, tensor<1xindex>) -> tensor +-// CHECK: %[[DIVIDE_0:.*]] = stablehlo.divide %[[SQRT_0]], %[[ASSUMING_1]] : tensor +-// CHECK: %[[COMPARE_0:.*]] = stablehlo.compare GE, %[[MAXIMUM_0]], %[[DIVIDE_0]] : (tensor, tensor) -> tensor +-// CHECK: %[[CONSTANT_2:.*]] = stablehlo.constant dense<1.000000e+00> : tensor +-// CHECK: %[[SHAPE_OF_5:.*]] = shape.shape_of %[[REAL_1]] : tensor -> tensor<1xindex> +-// CHECK: %[[ASSUMING_2:.*]] = stablehlo.dynamic_broadcast_in_dim %[[CONSTANT_2]], %[[SHAPE_OF_5]], dims = [] : (tensor, tensor<1xindex>) -> tensor +-// CHECK: %[[COMPARE_1:.*]] = stablehlo.compare LE, %[[ABS_0]], %[[ASSUMING_2]] : (tensor, tensor) -> tensor +-// CHECK: %[[CONSTANT_3:.*]] = stablehlo.constant dense<5.000000e-01> : tensor +-// CHECK: %[[SHAPE_OF_7:.*]] = shape.shape_of %[[REAL_1]] : tensor -> tensor<1xindex> +-// CHECK: %[[ASSUMING_3:.*]] = stablehlo.dynamic_broadcast_in_dim %[[CONSTANT_3]], %[[SHAPE_OF_7]], dims = [] : (tensor, tensor<1xindex>) -> tensor +-// CHECK: %[[ADD_0:.*]] = stablehlo.add %[[ABS_0]], %[[ASSUMING_2]] : tensor ++// CHECK: %[[DYNAMIC_BROADCAST_IN_DIM_3:.*]] = stablehlo.dynamic_broadcast_in_dim %[[CONSTANT_3]], %[[SHAPE_OF_3]], dims = [] : (tensor, tensor<1xindex>) -> tensor ++// CHECK: %[[ADD_0:.*]] = stablehlo.add %[[ABS_0]], %[[DYNAMIC_BROADCAST_IN_DIM_2]] : tensor + // CHECK: %[[ABS_2:.*]] = stablehlo.abs %[[ADD_0]] : tensor + // CHECK: %[[MAXIMUM_1:.*]] = stablehlo.maximum %[[ABS_2]], %[[ABS_1]] : tensor + // CHECK: %[[MINIMUM_0:.*]] = stablehlo.minimum %[[ABS_2]], %[[ABS_1]] : tensor +-// CHECK: %[[COMPARE_2:.*]] = stablehlo.compare EQ, %[[MAXIMUM_1]], %[[MINIMUM_0]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_2:.*]] = stablehlo.compare EQ, %[[MAXIMUM_1]], %[[MINIMUM_0]] : (tensor, tensor) -> tensor + // CHECK: %[[CONSTANT_4:.*]] = stablehlo.constant dense<1.4142135623730951> : tensor +-// CHECK: %[[SHAPE_OF_9:.*]] = shape.shape_of %[[REAL_1]] : tensor -> tensor<1xindex> +-// CHECK: %[[ASSUMING_4:.*]] = stablehlo.dynamic_broadcast_in_dim %[[CONSTANT_4]], %[[SHAPE_OF_9]], dims = [] : (tensor, tensor<1xindex>) -> tensor +-// CHECK: %[[MULTIPLY_0:.*]] = stablehlo.multiply %[[ASSUMING_4]], %[[MAXIMUM_1]] : tensor ++// CHECK: %[[SHAPE_OF_4:.*]] = shape.shape_of %[[REAL_1]] : tensor -> tensor<1xindex> ++// CHECK: %[[DYNAMIC_BROADCAST_IN_DIM_4:.*]] = stablehlo.dynamic_broadcast_in_dim %[[CONSTANT_4]], %[[SHAPE_OF_4]], dims = [] : (tensor, tensor<1xindex>) -> tensor ++// CHECK: %[[MULTIPLY_0:.*]] = stablehlo.multiply %[[DYNAMIC_BROADCAST_IN_DIM_4]], %[[MAXIMUM_1]] : tensor + // CHECK: %[[DIVIDE_1:.*]] = stablehlo.divide %[[MINIMUM_0]], %[[MAXIMUM_1]] : tensor + // CHECK: %[[MULTIPLY_1:.*]] = stablehlo.multiply %[[DIVIDE_1]], %[[DIVIDE_1]] : tensor +-// CHECK: %[[ADD_1:.*]] = stablehlo.add %[[ASSUMING_2]], %[[MULTIPLY_1]] : tensor ++// CHECK: %[[ADD_1:.*]] = stablehlo.add %[[DYNAMIC_BROADCAST_IN_DIM_2]], %[[MULTIPLY_1]] : tensor + // CHECK: %[[SQRT_1:.*]] = stablehlo.sqrt %[[ADD_1]] : tensor +-// CHECK: %[[COMPARE_3:.*]] = stablehlo.compare EQ, %[[SQRT_1]], %[[ASSUMING_2]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_3:.*]] = stablehlo.compare EQ, %[[SQRT_1]], %[[DYNAMIC_BROADCAST_IN_DIM_2]] : (tensor, tensor) -> tensor + // CHECK: %[[CONSTANT_5:.*]] = stablehlo.constant dense<0.000000e+00> : tensor +-// CHECK: %[[SHAPE_OF_11:.*]] = shape.shape_of %[[REAL_1]] : tensor -> tensor<1xindex> +-// CHECK: %[[ASSUMING_5:.*]] = stablehlo.dynamic_broadcast_in_dim %[[CONSTANT_5]], %[[SHAPE_OF_11]], dims = [] : (tensor, tensor<1xindex>) -> tensor +-// CHECK: %[[COMPARE_4:.*]] = stablehlo.compare GT, %[[MULTIPLY_1]], %[[ASSUMING_5]] : (tensor, tensor) -> tensor ++// CHECK: %[[SHAPE_OF_5:.*]] = shape.shape_of %[[REAL_1]] : tensor -> tensor<1xindex> ++// CHECK: %[[DYNAMIC_BROADCAST_IN_DIM_5:.*]] = stablehlo.dynamic_broadcast_in_dim %[[CONSTANT_5]], %[[SHAPE_OF_5]], dims = [] : (tensor, tensor<1xindex>) -> tensor ++// CHECK: %[[COMPARE_4:.*]] = stablehlo.compare GT, %[[MULTIPLY_1]], %[[DYNAMIC_BROADCAST_IN_DIM_5]] : (tensor, tensor) -> tensor + // CHECK: %[[AND_0:.*]] = stablehlo.and %[[COMPARE_3]], %[[COMPARE_4]] : tensor + // CHECK: %[[MULTIPLY_2:.*]] = stablehlo.multiply %[[MAXIMUM_1]], %[[MULTIPLY_1]] : tensor + // CHECK: %[[CONSTANT_6:.*]] = stablehlo.constant dense<2.000000e+00> : tensor +-// CHECK: %[[SHAPE_OF_13:.*]] = shape.shape_of %[[REAL_1]] : tensor -> tensor<1xindex> +-// CHECK: %[[ASSUMING_6:.*]] = stablehlo.dynamic_broadcast_in_dim %[[CONSTANT_6]], %[[SHAPE_OF_13]], dims = [] : (tensor, tensor<1xindex>) -> tensor +-// CHECK: %[[DIVIDE_2:.*]] = stablehlo.divide %[[MULTIPLY_2]], %[[ASSUMING_6]] : tensor ++// CHECK: %[[SHAPE_OF_6:.*]] = shape.shape_of %[[REAL_1]] : tensor -> tensor<1xindex> ++// CHECK: %[[DYNAMIC_BROADCAST_IN_DIM_6:.*]] = stablehlo.dynamic_broadcast_in_dim %[[CONSTANT_6]], %[[SHAPE_OF_6]], dims = [] : (tensor, tensor<1xindex>) -> tensor ++// CHECK: %[[DIVIDE_2:.*]] = stablehlo.divide %[[MULTIPLY_2]], %[[DYNAMIC_BROADCAST_IN_DIM_6]] : tensor + // CHECK: %[[ADD_2:.*]] = stablehlo.add %[[MAXIMUM_1]], %[[DIVIDE_2]] : tensor + // CHECK: %[[MULTIPLY_3:.*]] = stablehlo.multiply %[[MAXIMUM_1]], %[[SQRT_1]] : tensor +-// CHECK: %[[ASSUMING_7:.*]] = stablehlo.select %[[AND_0]], %[[ADD_2]], %[[MULTIPLY_3]] : tensor, tensor +-// CHECK: %[[ASSUMING_8:.*]] = stablehlo.select %[[COMPARE_2]], %[[MULTIPLY_0]], %[[ASSUMING_7]] : tensor, tensor +-// CHECK: %[[SUBTRACT_0:.*]] = stablehlo.subtract %[[ABS_0]], %[[ASSUMING_2]] : tensor ++// CHECK: %[[SELECT_0:.*]] = stablehlo.select %[[AND_0]], %[[ADD_2]], %[[MULTIPLY_3]] : tensor, tensor ++// CHECK: %[[SELECT_1:.*]] = stablehlo.select %[[COMPARE_2]], %[[MULTIPLY_0]], %[[SELECT_0]] : tensor, tensor ++// CHECK: %[[SUBTRACT_0:.*]] = stablehlo.subtract %[[ABS_0]], %[[DYNAMIC_BROADCAST_IN_DIM_2]] : tensor + // CHECK: %[[ABS_3:.*]] = stablehlo.abs %[[SUBTRACT_0]] : tensor + // CHECK: %[[MAXIMUM_2:.*]] = stablehlo.maximum %[[ABS_3]], %[[ABS_1]] : tensor + // CHECK: %[[MINIMUM_1:.*]] = stablehlo.minimum %[[ABS_3]], %[[ABS_1]] : tensor +-// CHECK: %[[COMPARE_5:.*]] = stablehlo.compare EQ, %[[MAXIMUM_2]], %[[MINIMUM_1]] : (tensor, tensor) -> tensor +-// CHECK: %[[MULTIPLY_4:.*]] = stablehlo.multiply %[[ASSUMING_4]], %[[MAXIMUM_2]] : tensor ++// CHECK: %[[COMPARE_5:.*]] = stablehlo.compare EQ, %[[MAXIMUM_2]], %[[MINIMUM_1]] : (tensor, tensor) -> tensor ++// CHECK: %[[MULTIPLY_4:.*]] = stablehlo.multiply %[[DYNAMIC_BROADCAST_IN_DIM_4]], %[[MAXIMUM_2]] : tensor + // CHECK: %[[DIVIDE_3:.*]] = stablehlo.divide %[[MINIMUM_1]], %[[MAXIMUM_2]] : tensor + // CHECK: %[[MULTIPLY_5:.*]] = stablehlo.multiply %[[DIVIDE_3]], %[[DIVIDE_3]] : tensor +-// CHECK: %[[ADD_3:.*]] = stablehlo.add %[[ASSUMING_2]], %[[MULTIPLY_5]] : tensor ++// CHECK: %[[ADD_3:.*]] = stablehlo.add %[[DYNAMIC_BROADCAST_IN_DIM_2]], %[[MULTIPLY_5]] : tensor + // CHECK: %[[SQRT_2:.*]] = stablehlo.sqrt %[[ADD_3]] : tensor +-// CHECK: %[[COMPARE_6:.*]] = stablehlo.compare EQ, %[[SQRT_2]], %[[ASSUMING_2]] : (tensor, tensor) -> tensor +-// CHECK: %[[COMPARE_7:.*]] = stablehlo.compare GT, %[[MULTIPLY_5]], %[[ASSUMING_5]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_6:.*]] = stablehlo.compare EQ, %[[SQRT_2]], %[[DYNAMIC_BROADCAST_IN_DIM_2]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_7:.*]] = stablehlo.compare GT, %[[MULTIPLY_5]], %[[DYNAMIC_BROADCAST_IN_DIM_5]] : (tensor, tensor) -> tensor + // CHECK: %[[AND_1:.*]] = stablehlo.and %[[COMPARE_6]], %[[COMPARE_7]] : tensor + // CHECK: %[[MULTIPLY_6:.*]] = stablehlo.multiply %[[MAXIMUM_2]], %[[MULTIPLY_5]] : tensor +-// CHECK: %[[DIVIDE_4:.*]] = stablehlo.divide %[[MULTIPLY_6]], %[[ASSUMING_6]] : tensor ++// CHECK: %[[DIVIDE_4:.*]] = stablehlo.divide %[[MULTIPLY_6]], %[[DYNAMIC_BROADCAST_IN_DIM_6]] : tensor + // CHECK: %[[ADD_4:.*]] = stablehlo.add %[[MAXIMUM_2]], %[[DIVIDE_4]] : tensor + // CHECK: %[[MULTIPLY_7:.*]] = stablehlo.multiply %[[MAXIMUM_2]], %[[SQRT_2]] : tensor +-// CHECK: %[[ASSUMING_9:.*]] = stablehlo.select %[[AND_1]], %[[ADD_4]], %[[MULTIPLY_7]] : tensor, tensor +-// CHECK: %[[ASSUMING_10:.*]] = stablehlo.select %[[COMPARE_5]], %[[MULTIPLY_4]], %[[ASSUMING_9]] : tensor, tensor +-// CHECK: %[[ADD_5:.*]] = stablehlo.add %[[ASSUMING_8]], %[[ASSUMING_10]] : tensor +-// CHECK: %[[MULTIPLY_8:.*]] = stablehlo.multiply %[[ASSUMING_3]], %[[ADD_5]] : tensor ++// CHECK: %[[SELECT_2:.*]] = stablehlo.select %[[AND_1]], %[[ADD_4]], %[[MULTIPLY_7]] : tensor, tensor ++// CHECK: %[[SELECT_3:.*]] = stablehlo.select %[[COMPARE_5]], %[[MULTIPLY_4]], %[[SELECT_2]] : tensor, tensor ++// CHECK: %[[ADD_5:.*]] = stablehlo.add %[[SELECT_1]], %[[SELECT_3]] : tensor ++// CHECK: %[[MULTIPLY_8:.*]] = stablehlo.multiply %[[DYNAMIC_BROADCAST_IN_DIM_3]], %[[ADD_5]] : tensor + // CHECK: %[[ADD_6:.*]] = stablehlo.add %[[MULTIPLY_8]], %[[ABS_0]] : tensor +-// CHECK: %[[MULTIPLY_9:.*]] = stablehlo.multiply %[[ASSUMING_3]], %[[ADD_6]] : tensor ++// CHECK: %[[MULTIPLY_9:.*]] = stablehlo.multiply %[[DYNAMIC_BROADCAST_IN_DIM_3]], %[[ADD_6]] : tensor + // CHECK: %[[MULTIPLY_10:.*]] = stablehlo.multiply %[[ABS_1]], %[[ABS_1]] : tensor +-// CHECK: %[[ADD_7:.*]] = stablehlo.add %[[ASSUMING_8]], %[[ADD_0]] : tensor ++// CHECK: %[[ADD_7:.*]] = stablehlo.add %[[SELECT_1]], %[[ADD_0]] : tensor + // CHECK: %[[DIVIDE_5:.*]] = stablehlo.divide %[[MULTIPLY_10]], %[[ADD_7]] : tensor +-// CHECK: %[[SUBTRACT_1:.*]] = stablehlo.subtract %[[ASSUMING_10]], %[[SUBTRACT_0]] : tensor ++// CHECK: %[[SUBTRACT_1:.*]] = stablehlo.subtract %[[SELECT_3]], %[[SUBTRACT_0]] : tensor + // CHECK: %[[ADD_8:.*]] = stablehlo.add %[[DIVIDE_5]], %[[SUBTRACT_1]] : tensor + // CHECK: %[[MULTIPLY_11:.*]] = stablehlo.multiply %[[MULTIPLY_9]], %[[ADD_8]] : tensor + // CHECK: %[[SQRT_3:.*]] = stablehlo.sqrt %[[MULTIPLY_11]] : tensor + // CHECK: %[[DIVIDE_6:.*]] = stablehlo.divide %[[MULTIPLY_9]], %[[ADD_7]] : tensor +-// CHECK: %[[ADD_9:.*]] = stablehlo.add %[[ASSUMING_10]], %[[SUBTRACT_0]] : tensor ++// CHECK: %[[ADD_9:.*]] = stablehlo.add %[[SELECT_3]], %[[SUBTRACT_0]] : tensor + // CHECK: %[[DIVIDE_7:.*]] = stablehlo.divide %[[MULTIPLY_9]], %[[ADD_9]] : tensor + // CHECK: %[[ADD_10:.*]] = stablehlo.add %[[DIVIDE_6]], %[[DIVIDE_7]] : tensor + // CHECK: %[[SQRT_4:.*]] = stablehlo.sqrt %[[ADD_10]] : tensor + // CHECK: %[[MULTIPLY_12:.*]] = stablehlo.multiply %[[ABS_1]], %[[SQRT_4]] : tensor +-// CHECK: %[[ASSUMING_11:.*]] = stablehlo.select %[[COMPARE_1]], %[[SQRT_3]], %[[MULTIPLY_12]] : tensor, tensor +-// CHECK: %[[ASSUMING_12:.*]] = stablehlo.select %[[COMPARE_0]], %[[ABS_1]], %[[ASSUMING_11]] : tensor, tensor ++// CHECK: %[[SELECT_4:.*]] = stablehlo.select %[[COMPARE_1]], %[[SQRT_3]], %[[MULTIPLY_12]] : tensor, tensor ++// CHECK: %[[SELECT_5:.*]] = stablehlo.select %[[COMPARE_0]], %[[ABS_1]], %[[SELECT_4]] : tensor, tensor + // CHECK: %[[CONSTANT_7:.*]] = stablehlo.constant dense<1.000000e+12> : tensor +-// CHECK: %[[SHAPE_OF_50:.*]] = shape.shape_of %[[REAL_1]] : tensor -> tensor<1xindex> +-// CHECK: %[[DYNAMIC_BROADCAST_IN_DIM_25:.*]] = stablehlo.dynamic_broadcast_in_dim %[[CONSTANT_7]], %[[SHAPE_OF_50]], dims = [] : (tensor, tensor<1xindex>) -> tensor +-// CHECK: %[[MULTIPLY_13:.*]] = stablehlo.multiply %[[DIVIDE_0]], %[[DYNAMIC_BROADCAST_IN_DIM_25]] : tensor +-// CHECK: %[[COMPARE_8:.*]] = stablehlo.compare LT, %[[ABS_0]], %[[MULTIPLY_13]] : (tensor, tensor) -> tensor ++// CHECK: %[[SHAPE_OF_7:.*]] = shape.shape_of %[[REAL_1]] : tensor -> tensor<1xindex> ++// CHECK: %[[DYNAMIC_BROADCAST_IN_DIM_7:.*]] = stablehlo.dynamic_broadcast_in_dim %[[CONSTANT_7]], %[[SHAPE_OF_7]], dims = [] : (tensor, tensor<1xindex>) -> tensor ++// CHECK: %[[MULTIPLY_13:.*]] = stablehlo.multiply %[[DIVIDE_0]], %[[DYNAMIC_BROADCAST_IN_DIM_7]] : tensor ++// CHECK: %[[COMPARE_8:.*]] = stablehlo.compare LT, %[[ABS_0]], %[[MULTIPLY_13]] : (tensor, tensor) -> tensor + // CHECK: %[[CONSTANT_8:.*]] = stablehlo.constant dense<9.9999999999999995E-7> : tensor +-// CHECK: %[[SHAPE_OF_51:.*]] = shape.shape_of %[[REAL_1]] : tensor -> tensor<1xindex> +-// CHECK: %[[DYNAMIC_BROADCAST_IN_DIM_26:.*]] = stablehlo.dynamic_broadcast_in_dim %[[CONSTANT_8]], %[[SHAPE_OF_51]], dims = [] : (tensor, tensor<1xindex>) -> tensor +-// CHECK: %[[MULTIPLY_14:.*]] = stablehlo.multiply %[[DIVIDE_0]], %[[DYNAMIC_BROADCAST_IN_DIM_26]] : tensor ++// CHECK: %[[SHAPE_OF_8:.*]] = shape.shape_of %[[REAL_1]] : tensor -> tensor<1xindex> ++// CHECK: %[[DYNAMIC_BROADCAST_IN_DIM_8:.*]] = stablehlo.dynamic_broadcast_in_dim %[[CONSTANT_8]], %[[SHAPE_OF_8]], dims = [] : (tensor, tensor<1xindex>) -> tensor ++// CHECK: %[[MULTIPLY_14:.*]] = stablehlo.multiply %[[DIVIDE_0]], %[[DYNAMIC_BROADCAST_IN_DIM_8]] : tensor + // CHECK: %[[CONSTANT_9:.*]] = stablehlo.constant dense<1.000000e+02> : tensor +-// CHECK: %[[SHAPE_OF_52:.*]] = shape.shape_of %[[REAL_1]] : tensor -> tensor<1xindex> +-// CHECK: %[[DYNAMIC_BROADCAST_IN_DIM_27:.*]] = stablehlo.dynamic_broadcast_in_dim %[[CONSTANT_9]], %[[SHAPE_OF_52]], dims = [] : (tensor, tensor<1xindex>) -> tensor +-// CHECK: %[[MULTIPLY_15:.*]] = stablehlo.multiply %[[DIVIDE_0]], %[[DYNAMIC_BROADCAST_IN_DIM_27]] : tensor +-// CHECK: %[[ASSUMING_13:.*]] = stablehlo.select %[[COMPARE_8]], %[[MULTIPLY_14]], %[[MULTIPLY_15]] : tensor, tensor +-// CHECK: %[[COMPARE_9:.*]] = stablehlo.compare GE, %[[ABS_1]], %[[ASSUMING_13]] : (tensor, tensor) -> tensor +-// CHECK: %[[ASSUMING_14:.*]] = stablehlo.select %[[COMPARE_9]], %[[ABS_1]], %[[ABS_0]] : tensor, tensor +-// CHECK: %[[ASSUMING_15:.*]] = stablehlo.select %[[COMPARE_9]], %[[ASSUMING_13]], %[[DIVIDE_0]] : tensor, tensor +-// CHECK: %[[COMPARE_10:.*]] = stablehlo.compare GE, %[[ASSUMING_14]], %[[ASSUMING_15]] : (tensor, tensor) -> tensor +-// CHECK: %[[LOG_0:.*]] = stablehlo.log %[[ASSUMING_6]] : tensor +-// CHECK: %[[LOG_1:.*]] = stablehlo.log %[[ASSUMING_14]] : tensor ++// CHECK: %[[SHAPE_OF_9:.*]] = shape.shape_of %[[REAL_1]] : tensor -> tensor<1xindex> ++// CHECK: %[[DYNAMIC_BROADCAST_IN_DIM_9:.*]] = stablehlo.dynamic_broadcast_in_dim %[[CONSTANT_9]], %[[SHAPE_OF_9]], dims = [] : (tensor, tensor<1xindex>) -> tensor ++// CHECK: %[[MULTIPLY_15:.*]] = stablehlo.multiply %[[DIVIDE_0]], %[[DYNAMIC_BROADCAST_IN_DIM_9]] : tensor ++// CHECK: %[[SELECT_6:.*]] = stablehlo.select %[[COMPARE_8]], %[[MULTIPLY_14]], %[[MULTIPLY_15]] : tensor, tensor ++// CHECK: %[[COMPARE_9:.*]] = stablehlo.compare GE, %[[ABS_1]], %[[SELECT_6]] : (tensor, tensor) -> tensor ++// CHECK: %[[SELECT_7:.*]] = stablehlo.select %[[COMPARE_9]], %[[ABS_1]], %[[ABS_0]] : tensor, tensor ++// CHECK: %[[SELECT_8:.*]] = stablehlo.select %[[COMPARE_9]], %[[SELECT_6]], %[[DIVIDE_0]] : tensor, tensor ++// CHECK: %[[COMPARE_10:.*]] = stablehlo.compare GE, %[[SELECT_7]], %[[SELECT_8]] : (tensor, tensor) -> tensor ++// CHECK: %[[LOG_0:.*]] = stablehlo.log %[[DYNAMIC_BROADCAST_IN_DIM_6]] : tensor ++// CHECK: %[[LOG_1:.*]] = stablehlo.log %[[SELECT_7]] : tensor + // CHECK: %[[ADD_11:.*]] = stablehlo.add %[[LOG_0]], %[[LOG_1]] : tensor + // CHECK: %[[CONSTANT_10:.*]] = stablehlo.constant dense<0x7FF0000000000000> : tensor +-// CHECK: %[[SHAPE_OF_71:.*]] = shape.shape_of %[[IMAG_0]] : tensor -> tensor<1xindex> +-// CHECK: %[[DYNAMIC_BROADCAST_IN_DIM_37:.*]] = stablehlo.dynamic_broadcast_in_dim %[[CONSTANT_10]], %[[SHAPE_OF_71]], dims = [] : (tensor, tensor<1xindex>) -> tensor +-// CHECK: %[[COMPARE_11:.*]] = stablehlo.compare EQ, %[[ABS_1]], %[[DYNAMIC_BROADCAST_IN_DIM_37]] : (tensor, tensor) -> tensor ++// CHECK: %[[SHAPE_OF_10:.*]] = shape.shape_of %[[IMAG_0]] : tensor -> tensor<1xindex> ++// CHECK: %[[DYNAMIC_BROADCAST_IN_DIM_10:.*]] = stablehlo.dynamic_broadcast_in_dim %[[CONSTANT_10]], %[[SHAPE_OF_10]], dims = [] : (tensor, tensor<1xindex>) -> tensor ++// CHECK: %[[COMPARE_11:.*]] = stablehlo.compare EQ, %[[ABS_1]], %[[DYNAMIC_BROADCAST_IN_DIM_10]] : (tensor, tensor) -> tensor + // CHECK: %[[NOT_0:.*]] = stablehlo.not %[[COMPARE_11]] : tensor + // CHECK: %[[AND_2:.*]] = stablehlo.and %[[COMPARE_9]], %[[NOT_0]] : tensor + // CHECK: %[[DIVIDE_8:.*]] = stablehlo.divide %[[ABS_0]], %[[ABS_1]] : tensor +-// CHECK: %[[ASSUMING_16:.*]] = stablehlo.select %[[AND_2]], %[[DIVIDE_8]], %[[ASSUMING_5]] : tensor, tensor +-// CHECK: %[[MULTIPLY_16:.*]] = stablehlo.multiply %[[ASSUMING_16]], %[[ASSUMING_16]] : tensor ++// CHECK: %[[SELECT_9:.*]] = stablehlo.select %[[AND_2]], %[[DIVIDE_8]], %[[DYNAMIC_BROADCAST_IN_DIM_5]] : tensor, tensor ++// CHECK: %[[MULTIPLY_16:.*]] = stablehlo.multiply %[[SELECT_9]], %[[SELECT_9]] : tensor + // CHECK: %[[LOG_PLUS_ONE_0:.*]] = stablehlo.log_plus_one %[[MULTIPLY_16]] : tensor +-// CHECK: %[[MULTIPLY_17:.*]] = stablehlo.multiply %[[ASSUMING_3]], %[[LOG_PLUS_ONE_0]] : tensor ++// CHECK: %[[MULTIPLY_17:.*]] = stablehlo.multiply %[[DYNAMIC_BROADCAST_IN_DIM_3]], %[[LOG_PLUS_ONE_0]] : tensor + // CHECK: %[[ADD_12:.*]] = stablehlo.add %[[ADD_11]], %[[MULTIPLY_17]] : tensor + // CHECK: %[[CONSTANT_11:.*]] = stablehlo.constant dense<2.2250738585072014E-308> : tensor +-// CHECK: %[[SHAPE_OF_78:.*]] = shape.shape_of %[[REAL_1]] : tensor -> tensor<1xindex> +-// CHECK: %[[DYNAMIC_BROADCAST_IN_DIM_41:.*]] = stablehlo.dynamic_broadcast_in_dim %[[CONSTANT_11]], %[[SHAPE_OF_78]], dims = [] : (tensor, tensor<1xindex>) -> tensor +-// CHECK: %[[SQRT_5:.*]] = stablehlo.sqrt %[[DYNAMIC_BROADCAST_IN_DIM_41]] : tensor ++// CHECK: %[[SHAPE_OF_11:.*]] = shape.shape_of %[[REAL_1]] : tensor -> tensor<1xindex> ++// CHECK: %[[DYNAMIC_BROADCAST_IN_DIM_11:.*]] = stablehlo.dynamic_broadcast_in_dim %[[CONSTANT_11]], %[[SHAPE_OF_11]], dims = [] : (tensor, tensor<1xindex>) -> tensor ++// CHECK: %[[SQRT_5:.*]] = stablehlo.sqrt %[[DYNAMIC_BROADCAST_IN_DIM_11]] : tensor + // CHECK: %[[CONSTANT_12:.*]] = stablehlo.constant dense<4.000000e+00> : tensor +-// CHECK: %[[SHAPE_OF_79:.*]] = shape.shape_of %[[REAL_1]] : tensor -> tensor<1xindex> +-// CHECK: %[[DYNAMIC_BROADCAST_IN_DIM_42:.*]] = stablehlo.dynamic_broadcast_in_dim %[[CONSTANT_12]], %[[SHAPE_OF_79]], dims = [] : (tensor, tensor<1xindex>) -> tensor +-// CHECK: %[[MULTIPLY_18:.*]] = stablehlo.multiply %[[SQRT_5]], %[[DYNAMIC_BROADCAST_IN_DIM_42]] : tensor +-// CHECK: %[[COMPARE_12:.*]] = stablehlo.compare LT, %[[ABS_1]], %[[MULTIPLY_18]] : (tensor, tensor) -> tensor +-// CHECK: %[[COMPARE_13:.*]] = stablehlo.compare LT, %[[ABS_0]], %[[ASSUMING_2]] : (tensor, tensor) -> tensor ++// CHECK: %[[SHAPE_OF_12:.*]] = shape.shape_of %[[REAL_1]] : tensor -> tensor<1xindex> ++// CHECK: %[[DYNAMIC_BROADCAST_IN_DIM_12:.*]] = stablehlo.dynamic_broadcast_in_dim %[[CONSTANT_12]], %[[SHAPE_OF_12]], dims = [] : (tensor, tensor<1xindex>) -> tensor ++// CHECK: %[[MULTIPLY_18:.*]] = stablehlo.multiply %[[SQRT_5]], %[[DYNAMIC_BROADCAST_IN_DIM_12]] : tensor ++// CHECK: %[[COMPARE_12:.*]] = stablehlo.compare LT, %[[ABS_1]], %[[MULTIPLY_18]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_13:.*]] = stablehlo.compare LT, %[[ABS_0]], %[[DYNAMIC_BROADCAST_IN_DIM_2]] : (tensor, tensor) -> tensor + // CHECK: %[[AND_3:.*]] = stablehlo.and %[[COMPARE_12]], %[[COMPARE_13]] : tensor + // CHECK: %[[MULTIPLY_19:.*]] = stablehlo.multiply %[[ADD_0]], %[[SUBTRACT_0]] : tensor +-// CHECK: %[[ADD_13:.*]] = stablehlo.add %[[MULTIPLY_8]], %[[ASSUMING_2]] : tensor ++// CHECK: %[[ADD_13:.*]] = stablehlo.add %[[MULTIPLY_8]], %[[DYNAMIC_BROADCAST_IN_DIM_2]] : tensor + // CHECK: %[[DIVIDE_9:.*]] = stablehlo.divide %[[MULTIPLY_19]], %[[ADD_13]] : tensor + // CHECK: %[[NEGATE_0:.*]] = stablehlo.negate %[[DIVIDE_9]] : tensor +-// CHECK: %[[COMPARE_14:.*]] = stablehlo.compare GE, %[[ABS_0]], %[[ASSUMING_2]] : (tensor, tensor) -> tensor +-// CHECK: %[[MULTIPLY_20:.*]] = stablehlo.multiply %[[ASSUMING_3]], %[[MULTIPLY_10]] : tensor ++// CHECK: %[[COMPARE_14:.*]] = stablehlo.compare GE, %[[ABS_0]], %[[DYNAMIC_BROADCAST_IN_DIM_2]] : (tensor, tensor) -> tensor ++// CHECK: %[[MULTIPLY_20:.*]] = stablehlo.multiply %[[DYNAMIC_BROADCAST_IN_DIM_3]], %[[MULTIPLY_10]] : tensor + // CHECK: %[[DIVIDE_10:.*]] = stablehlo.divide %[[MULTIPLY_20]], %[[ADD_7]] : tensor +-// CHECK: %[[MULTIPLY_21:.*]] = stablehlo.multiply %[[ASSUMING_3]], %[[ADD_9]] : tensor ++// CHECK: %[[MULTIPLY_21:.*]] = stablehlo.multiply %[[DYNAMIC_BROADCAST_IN_DIM_3]], %[[ADD_9]] : tensor + // CHECK: %[[ADD_14:.*]] = stablehlo.add %[[DIVIDE_10]], %[[MULTIPLY_21]] : tensor + // CHECK: %[[CONSTANT_13:.*]] = stablehlo.constant dense<1.500000e+00> : tensor +-// CHECK: %[[SHAPE_OF_80:.*]] = shape.shape_of %[[REAL_1]] : tensor -> tensor<1xindex> +-// CHECK: %[[DYNAMIC_BROADCAST_IN_DIM_43:.*]] = stablehlo.dynamic_broadcast_in_dim %[[CONSTANT_13]], %[[SHAPE_OF_80]], dims = [] : (tensor, tensor<1xindex>) -> tensor +-// CHECK: %[[COMPARE_15:.*]] = stablehlo.compare LE, %[[MULTIPLY_8]], %[[DYNAMIC_BROADCAST_IN_DIM_43]] : (tensor, tensor) -> tensor ++// CHECK: %[[SHAPE_OF_13:.*]] = shape.shape_of %[[REAL_1]] : tensor -> tensor<1xindex> ++// CHECK: %[[DYNAMIC_BROADCAST_IN_DIM_13:.*]] = stablehlo.dynamic_broadcast_in_dim %[[CONSTANT_13]], %[[SHAPE_OF_13]], dims = [] : (tensor, tensor<1xindex>) -> tensor ++// CHECK: %[[COMPARE_15:.*]] = stablehlo.compare LE, %[[MULTIPLY_8]], %[[DYNAMIC_BROADCAST_IN_DIM_13]] : (tensor, tensor) -> tensor + // CHECK: %[[DIVIDE_11:.*]] = stablehlo.divide %[[MULTIPLY_20]], %[[SUBTRACT_1]] : tensor + // CHECK: %[[ADD_15:.*]] = stablehlo.add %[[DIVIDE_10]], %[[DIVIDE_11]] : tensor +-// CHECK: %[[SUBTRACT_2:.*]] = stablehlo.subtract %[[MULTIPLY_8]], %[[ASSUMING_2]] : tensor +-// CHECK: %[[ASSUMING_17:.*]] = stablehlo.select %[[COMPARE_15]], %[[ADD_15]], %[[SUBTRACT_2]] : tensor, tensor +-// CHECK: %[[ASSUMING_18:.*]] = stablehlo.select %[[COMPARE_14]], %[[ADD_14]], %[[ASSUMING_17]] : tensor, tensor +-// CHECK: %[[ASSUMING_19:.*]] = stablehlo.select %[[AND_3]], %[[NEGATE_0]], %[[ASSUMING_18]] : tensor, tensor +-// CHECK: %[[MULTIPLY_22:.*]] = stablehlo.multiply %[[ASSUMING_19]], %[[ADD_13]] : tensor ++// CHECK: %[[SUBTRACT_2:.*]] = stablehlo.subtract %[[MULTIPLY_8]], %[[DYNAMIC_BROADCAST_IN_DIM_2]] : tensor ++// CHECK: %[[SELECT_10:.*]] = stablehlo.select %[[COMPARE_15]], %[[ADD_15]], %[[SUBTRACT_2]] : tensor, tensor ++// CHECK: %[[SELECT_11:.*]] = stablehlo.select %[[COMPARE_14]], %[[ADD_14]], %[[SELECT_10]] : tensor, tensor ++// CHECK: %[[SELECT_12:.*]] = stablehlo.select %[[AND_3]], %[[NEGATE_0]], %[[SELECT_11]] : tensor, tensor ++// CHECK: %[[MULTIPLY_22:.*]] = stablehlo.multiply %[[SELECT_12]], %[[ADD_13]] : tensor + // CHECK: %[[SQRT_6:.*]] = stablehlo.sqrt %[[MULTIPLY_22]] : tensor + // CHECK: %[[DIVIDE_12:.*]] = stablehlo.divide %[[ABS_1]], %[[SQRT_6]] : tensor +-// CHECK: %[[ADD_16:.*]] = stablehlo.add %[[ASSUMING_19]], %[[SQRT_6]] : tensor ++// CHECK: %[[ADD_16:.*]] = stablehlo.add %[[SELECT_12]], %[[SQRT_6]] : tensor + // CHECK: %[[LOG_PLUS_ONE_1:.*]] = stablehlo.log_plus_one %[[ADD_16]] : tensor +-// CHECK: %[[ASSUMING_20:.*]] = stablehlo.select %[[AND_3]], %[[DIVIDE_12]], %[[LOG_PLUS_ONE_1]] : tensor, tensor +-// CHECK: %[[ASSUMING_21:.*]] = stablehlo.select %[[COMPARE_10]], %[[ADD_12]], %[[ASSUMING_20]] : tensor, tensor +-// CHECK: %[[COMPLEX_0:.*]] = stablehlo.complex %[[ASSUMING_12]], %[[ASSUMING_21]] : tensor> ++// CHECK: %[[SELECT_13:.*]] = stablehlo.select %[[AND_3]], %[[DIVIDE_12]], %[[LOG_PLUS_ONE_1]] : tensor, tensor ++// CHECK: %[[SELECT_14:.*]] = stablehlo.select %[[COMPARE_10]], %[[ADD_12]], %[[SELECT_13]] : tensor, tensor ++// CHECK: %[[COMPLEX_0:.*]] = stablehlo.complex %[[SELECT_5]], %[[SELECT_14]] : tensor> + // CHECK: %[[REAL_2:.*]] = stablehlo.real %[[COMPLEX_0]] : (tensor>) -> tensor + // CHECK: %[[VAL_0:.*]] = stablehlo.atan2 %[[REAL_0]], %[[REAL_2]] : tensor + // CHECK: %[[IMAG_1:.*]] = stablehlo.imag %[[ARG0]] : (tensor>) -> tensor + // CHECK: %[[CONSTANT_14:.*]] = stablehlo.constant dense<0.000000e+00> : tensor +-// CHECK: %[[SHAPE_OF_111:.*]] = shape.shape_of %[[REAL_0]] : tensor -> tensor<1xindex> +-// CHECK: %[[DYNAMIC_BROADCAST_IN_DIM_59:.*]] = stablehlo.dynamic_broadcast_in_dim %[[CONSTANT_14]], %[[SHAPE_OF_111]], dims = [] : (tensor, tensor<1xindex>) -> tensor +-// CHECK: %[[COMPARE_16:.*]] = stablehlo.compare LT, %[[IMAG_1]], %[[DYNAMIC_BROADCAST_IN_DIM_59]] : (tensor, tensor) -> tensor ++// CHECK: %[[SHAPE_OF_14:.*]] = shape.shape_of %[[REAL_0]] : tensor -> tensor<1xindex> ++// CHECK: %[[DYNAMIC_BROADCAST_IN_DIM_14:.*]] = stablehlo.dynamic_broadcast_in_dim %[[CONSTANT_14]], %[[SHAPE_OF_14]], dims = [] : (tensor, tensor<1xindex>) -> tensor ++// CHECK: %[[COMPARE_16:.*]] = stablehlo.compare LT, %[[IMAG_1]], %[[DYNAMIC_BROADCAST_IN_DIM_14]] : (tensor, tensor) -> tensor + // CHECK: %[[IMAG_2:.*]] = stablehlo.imag %[[COMPLEX_0]] : (tensor>) -> tensor + // CHECK: %[[NEGATE_1:.*]] = stablehlo.negate %[[IMAG_2]] : tensor +-// CHECK: %[[ASSUMING_22:.*]] = stablehlo.select %[[COMPARE_16]], %[[NEGATE_1]], %[[IMAG_2]] : tensor, tensor +-// CHECK: %[[COMPLEX_1:.*]] = stablehlo.complex %[[VAL_0]], %[[ASSUMING_22]] : tensor> ++// CHECK: %[[SELECT_15:.*]] = stablehlo.select %[[COMPARE_16]], %[[NEGATE_1]], %[[IMAG_2]] : tensor, tensor ++// CHECK: %[[COMPLEX_1:.*]] = stablehlo.complex %[[VAL_0]], %[[SELECT_15]] : tensor> + // CHECK: return %[[COMPLEX_1]] : tensor> + // CHECK: } + func.func @asin_complex_f64_dynamic(%arg : tensor>) -> tensor> { +@@ -415,7 +415,7 @@ + // CHECK: %[[ABS_0:.*]] = stablehlo.abs %[[ARG0]] : tensor + // CHECK: %[[CONSTANT_0:.*]] = stablehlo.constant dense<3.389530e+38> : tensor + // CHECK: %[[SQRT_0:.*]] = stablehlo.sqrt %[[CONSTANT_0]] : tensor +-// CHECK: %[[COMPARE_0:.*]] = stablehlo.compare GE, %[[ABS_0]], %[[SQRT_0]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_0:.*]] = stablehlo.compare GE, %[[ABS_0]], %[[SQRT_0]] : (tensor, tensor) -> tensor + // CHECK: %[[CONSTANT_1:.*]] = stablehlo.constant dense<2.000000e+00> : tensor + // CHECK: %[[LOG_0:.*]] = stablehlo.log %[[CONSTANT_1]] : tensor + // CHECK: %[[LOG_1:.*]] = stablehlo.log %[[ABS_0]] : tensor +@@ -446,7 +446,7 @@ + // CHECK: %[[ABS_0:.*]] = stablehlo.abs %[[ARG0]] : tensor + // CHECK: %[[CONSTANT_0:.*]] = stablehlo.constant dense<6.550400e+04> : tensor + // CHECK: %[[SQRT_0:.*]] = stablehlo.sqrt %[[CONSTANT_0]] : tensor +-// CHECK: %[[COMPARE_0:.*]] = stablehlo.compare GE, %[[ABS_0]], %[[SQRT_0]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_0:.*]] = stablehlo.compare GE, %[[ABS_0]], %[[SQRT_0]] : (tensor, tensor) -> tensor + // CHECK: %[[CONSTANT_1:.*]] = stablehlo.constant dense<2.000000e+00> : tensor + // CHECK: %[[LOG_0:.*]] = stablehlo.log %[[CONSTANT_1]] : tensor + // CHECK: %[[LOG_1:.*]] = stablehlo.log %[[ABS_0]] : tensor +@@ -477,7 +477,7 @@ + // CHECK: %[[ABS_0:.*]] = stablehlo.abs %[[ARG0]] : tensor + // CHECK: %[[CONSTANT_0:.*]] = stablehlo.constant dense<3.40282347E+38> : tensor + // CHECK: %[[SQRT_0:.*]] = stablehlo.sqrt %[[CONSTANT_0]] : tensor +-// CHECK: %[[COMPARE_0:.*]] = stablehlo.compare GE, %[[ABS_0]], %[[SQRT_0]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_0:.*]] = stablehlo.compare GE, %[[ABS_0]], %[[SQRT_0]] : (tensor, tensor) -> tensor + // CHECK: %[[CONSTANT_1:.*]] = stablehlo.constant dense<2.000000e+00> : tensor + // CHECK: %[[LOG_0:.*]] = stablehlo.log %[[CONSTANT_1]] : tensor + // CHECK: %[[LOG_1:.*]] = stablehlo.log %[[ABS_0]] : tensor +@@ -508,7 +508,7 @@ + // CHECK: %[[ABS_0:.*]] = stablehlo.abs %[[ARG0]] : tensor + // CHECK: %[[CONSTANT_0:.*]] = stablehlo.constant dense<1.7976931348623157E+308> : tensor + // CHECK: %[[SQRT_0:.*]] = stablehlo.sqrt %[[CONSTANT_0]] : tensor +-// CHECK: %[[COMPARE_0:.*]] = stablehlo.compare GE, %[[ABS_0]], %[[SQRT_0]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_0:.*]] = stablehlo.compare GE, %[[ABS_0]], %[[SQRT_0]] : (tensor, tensor) -> tensor + // CHECK: %[[CONSTANT_1:.*]] = stablehlo.constant dense<2.000000e+00> : tensor + // CHECK: %[[LOG_0:.*]] = stablehlo.log %[[CONSTANT_1]] : tensor + // CHECK: %[[LOG_1:.*]] = stablehlo.log %[[ABS_0]] : tensor +@@ -536,7 +536,7 @@ + // CHECK-SAME: %[[ARG0:.*]]: tensor>) -> tensor> { + // CHECK: %[[REAL_0:.*]] = stablehlo.real %[[ARG0]] : (tensor>) -> tensor + // CHECK: %[[CONSTANT_0:.*]] = stablehlo.constant dense<0.000000e+00> : tensor +-// CHECK: %[[COMPARE_0:.*]] = stablehlo.compare LT, %[[REAL_0]], %[[CONSTANT_0]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_0:.*]] = stablehlo.compare LT, %[[REAL_0]], %[[CONSTANT_0]] : (tensor, tensor) -> tensor + // CHECK: %[[IMAG_0:.*]] = stablehlo.imag %[[ARG0]] : (tensor>) -> tensor + // CHECK: %[[NEGATE_0:.*]] = stablehlo.negate %[[IMAG_0]] : tensor + // CHECK: %[[COMPLEX_0:.*]] = stablehlo.complex %[[NEGATE_0]], %[[REAL_0]] : tensor> +@@ -549,24 +549,24 @@ + // CHECK: %[[SQRT_0:.*]] = stablehlo.sqrt %[[CONSTANT_1]] : tensor + // CHECK: %[[CONSTANT_2:.*]] = stablehlo.constant dense<8.000000e+00> : tensor + // CHECK: %[[DIVIDE_0:.*]] = stablehlo.divide %[[SQRT_0]], %[[CONSTANT_2]] : tensor +-// CHECK: %[[COMPARE_1:.*]] = stablehlo.compare GE, %[[MAXIMUM_0]], %[[DIVIDE_0]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_1:.*]] = stablehlo.compare GE, %[[MAXIMUM_0]], %[[DIVIDE_0]] : (tensor, tensor) -> tensor + // CHECK: %[[CONSTANT_3:.*]] = stablehlo.constant dense<1.000000e+00> : tensor +-// CHECK: %[[COMPARE_2:.*]] = stablehlo.compare LE, %[[ABS_0]], %[[CONSTANT_3]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_2:.*]] = stablehlo.compare LE, %[[ABS_0]], %[[CONSTANT_3]] : (tensor, tensor) -> tensor + // CHECK: %[[CONSTANT_4:.*]] = stablehlo.constant dense<5.000000e-01> : tensor + // CHECK: %[[ADD_0:.*]] = stablehlo.add %[[ABS_0]], %[[CONSTANT_3]] : tensor + // CHECK: %[[ABS_2:.*]] = stablehlo.abs %[[ADD_0]] : tensor + // CHECK: %[[MAXIMUM_1:.*]] = stablehlo.maximum %[[ABS_2]], %[[ABS_1]] : tensor + // CHECK: %[[MINIMUM_0:.*]] = stablehlo.minimum %[[ABS_2]], %[[ABS_1]] : tensor +-// CHECK: %[[COMPARE_3:.*]] = stablehlo.compare EQ, %[[MAXIMUM_1]], %[[MINIMUM_0]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_3:.*]] = stablehlo.compare EQ, %[[MAXIMUM_1]], %[[MINIMUM_0]] : (tensor, tensor) -> tensor + // CHECK: %[[CONSTANT_5:.*]] = stablehlo.constant dense<1.41421354> : tensor + // CHECK: %[[MULTIPLY_0:.*]] = stablehlo.multiply %[[CONSTANT_5]], %[[MAXIMUM_1]] : tensor + // CHECK: %[[DIVIDE_1:.*]] = stablehlo.divide %[[MINIMUM_0]], %[[MAXIMUM_1]] : tensor + // CHECK: %[[MULTIPLY_1:.*]] = stablehlo.multiply %[[DIVIDE_1]], %[[DIVIDE_1]] : tensor + // CHECK: %[[ADD_1:.*]] = stablehlo.add %[[CONSTANT_3]], %[[MULTIPLY_1]] : tensor + // CHECK: %[[SQRT_1:.*]] = stablehlo.sqrt %[[ADD_1]] : tensor +-// CHECK: %[[COMPARE_4:.*]] = stablehlo.compare EQ, %[[SQRT_1]], %[[CONSTANT_3]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_4:.*]] = stablehlo.compare EQ, %[[SQRT_1]], %[[CONSTANT_3]] : (tensor, tensor) -> tensor + // CHECK: %[[CONSTANT_6:.*]] = stablehlo.constant dense<0.000000e+00> : tensor +-// CHECK: %[[COMPARE_5:.*]] = stablehlo.compare GT, %[[MULTIPLY_1]], %[[CONSTANT_6]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_5:.*]] = stablehlo.compare GT, %[[MULTIPLY_1]], %[[CONSTANT_6]] : (tensor, tensor) -> tensor + // CHECK: %[[AND_0:.*]] = stablehlo.and %[[COMPARE_4]], %[[COMPARE_5]] : tensor + // CHECK: %[[MULTIPLY_2:.*]] = stablehlo.multiply %[[MAXIMUM_1]], %[[MULTIPLY_1]] : tensor + // CHECK: %[[CONSTANT_7:.*]] = stablehlo.constant dense<2.000000e+00> : tensor +@@ -579,14 +579,14 @@ + // CHECK: %[[ABS_3:.*]] = stablehlo.abs %[[SUBTRACT_0]] : tensor + // CHECK: %[[MAXIMUM_2:.*]] = stablehlo.maximum %[[ABS_3]], %[[ABS_1]] : tensor + // CHECK: %[[MINIMUM_1:.*]] = stablehlo.minimum %[[ABS_3]], %[[ABS_1]] : tensor +-// CHECK: %[[COMPARE_6:.*]] = stablehlo.compare EQ, %[[MAXIMUM_2]], %[[MINIMUM_1]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_6:.*]] = stablehlo.compare EQ, %[[MAXIMUM_2]], %[[MINIMUM_1]] : (tensor, tensor) -> tensor + // CHECK: %[[MULTIPLY_4:.*]] = stablehlo.multiply %[[CONSTANT_5]], %[[MAXIMUM_2]] : tensor + // CHECK: %[[DIVIDE_3:.*]] = stablehlo.divide %[[MINIMUM_1]], %[[MAXIMUM_2]] : tensor + // CHECK: %[[MULTIPLY_5:.*]] = stablehlo.multiply %[[DIVIDE_3]], %[[DIVIDE_3]] : tensor + // CHECK: %[[ADD_3:.*]] = stablehlo.add %[[CONSTANT_3]], %[[MULTIPLY_5]] : tensor + // CHECK: %[[SQRT_2:.*]] = stablehlo.sqrt %[[ADD_3]] : tensor +-// CHECK: %[[COMPARE_7:.*]] = stablehlo.compare EQ, %[[SQRT_2]], %[[CONSTANT_3]] : (tensor, tensor) -> tensor +-// CHECK: %[[COMPARE_8:.*]] = stablehlo.compare GT, %[[MULTIPLY_5]], %[[CONSTANT_6]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_7:.*]] = stablehlo.compare EQ, %[[SQRT_2]], %[[CONSTANT_3]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_8:.*]] = stablehlo.compare GT, %[[MULTIPLY_5]], %[[CONSTANT_6]] : (tensor, tensor) -> tensor + // CHECK: %[[AND_1:.*]] = stablehlo.and %[[COMPARE_7]], %[[COMPARE_8]] : tensor + // CHECK: %[[MULTIPLY_6:.*]] = stablehlo.multiply %[[MAXIMUM_2]], %[[MULTIPLY_5]] : tensor + // CHECK: %[[DIVIDE_4:.*]] = stablehlo.divide %[[MULTIPLY_6]], %[[CONSTANT_7]] : tensor +@@ -615,21 +615,21 @@ + // CHECK: %[[SELECT_5:.*]] = stablehlo.select %[[COMPARE_1]], %[[ABS_1]], %[[SELECT_4]] : tensor, tensor + // CHECK: %[[CONSTANT_8:.*]] = stablehlo.constant dense<9.99999995E+11> : tensor + // CHECK: %[[MULTIPLY_13:.*]] = stablehlo.multiply %[[DIVIDE_0]], %[[CONSTANT_8]] : tensor +-// CHECK: %[[COMPARE_9:.*]] = stablehlo.compare LT, %[[ABS_0]], %[[MULTIPLY_13]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_9:.*]] = stablehlo.compare LT, %[[ABS_0]], %[[MULTIPLY_13]] : (tensor, tensor) -> tensor + // CHECK: %[[CONSTANT_9:.*]] = stablehlo.constant dense<9.99999997E-7> : tensor + // CHECK: %[[MULTIPLY_14:.*]] = stablehlo.multiply %[[DIVIDE_0]], %[[CONSTANT_9]] : tensor + // CHECK: %[[CONSTANT_10:.*]] = stablehlo.constant dense<1.000000e+02> : tensor + // CHECK: %[[MULTIPLY_15:.*]] = stablehlo.multiply %[[DIVIDE_0]], %[[CONSTANT_10]] : tensor + // CHECK: %[[SELECT_6:.*]] = stablehlo.select %[[COMPARE_9]], %[[MULTIPLY_14]], %[[MULTIPLY_15]] : tensor, tensor +-// CHECK: %[[COMPARE_10:.*]] = stablehlo.compare GE, %[[ABS_1]], %[[SELECT_6]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_10:.*]] = stablehlo.compare GE, %[[ABS_1]], %[[SELECT_6]] : (tensor, tensor) -> tensor + // CHECK: %[[SELECT_7:.*]] = stablehlo.select %[[COMPARE_10]], %[[ABS_1]], %[[ABS_0]] : tensor, tensor + // CHECK: %[[SELECT_8:.*]] = stablehlo.select %[[COMPARE_10]], %[[SELECT_6]], %[[DIVIDE_0]] : tensor, tensor +-// CHECK: %[[COMPARE_11:.*]] = stablehlo.compare GE, %[[SELECT_7]], %[[SELECT_8]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_11:.*]] = stablehlo.compare GE, %[[SELECT_7]], %[[SELECT_8]] : (tensor, tensor) -> tensor + // CHECK: %[[LOG_0:.*]] = stablehlo.log %[[CONSTANT_7]] : tensor + // CHECK: %[[LOG_1:.*]] = stablehlo.log %[[SELECT_7]] : tensor + // CHECK: %[[ADD_11:.*]] = stablehlo.add %[[LOG_0]], %[[LOG_1]] : tensor + // CHECK: %[[CONSTANT_11:.*]] = stablehlo.constant dense<0x7F800000> : tensor +-// CHECK: %[[COMPARE_12:.*]] = stablehlo.compare EQ, %[[ABS_1]], %[[CONSTANT_11]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_12:.*]] = stablehlo.compare EQ, %[[ABS_1]], %[[CONSTANT_11]] : (tensor, tensor) -> tensor + // CHECK: %[[NOT_0:.*]] = stablehlo.not %[[COMPARE_12]] : tensor + // CHECK: %[[AND_2:.*]] = stablehlo.and %[[COMPARE_10]], %[[NOT_0]] : tensor + // CHECK: %[[DIVIDE_8:.*]] = stablehlo.divide %[[ABS_0]], %[[ABS_1]] : tensor +@@ -642,20 +642,20 @@ + // CHECK: %[[SQRT_5:.*]] = stablehlo.sqrt %[[CONSTANT_12]] : tensor + // CHECK: %[[CONSTANT_13:.*]] = stablehlo.constant dense<4.000000e+00> : tensor + // CHECK: %[[MULTIPLY_18:.*]] = stablehlo.multiply %[[SQRT_5]], %[[CONSTANT_13]] : tensor +-// CHECK: %[[COMPARE_13:.*]] = stablehlo.compare LT, %[[ABS_1]], %[[MULTIPLY_18]] : (tensor, tensor) -> tensor +-// CHECK: %[[COMPARE_14:.*]] = stablehlo.compare LT, %[[ABS_0]], %[[CONSTANT_3]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_13:.*]] = stablehlo.compare LT, %[[ABS_1]], %[[MULTIPLY_18]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_14:.*]] = stablehlo.compare LT, %[[ABS_0]], %[[CONSTANT_3]] : (tensor, tensor) -> tensor + // CHECK: %[[AND_3:.*]] = stablehlo.and %[[COMPARE_13]], %[[COMPARE_14]] : tensor + // CHECK: %[[MULTIPLY_19:.*]] = stablehlo.multiply %[[ADD_0]], %[[SUBTRACT_0]] : tensor + // CHECK: %[[ADD_13:.*]] = stablehlo.add %[[MULTIPLY_8]], %[[CONSTANT_3]] : tensor + // CHECK: %[[DIVIDE_9:.*]] = stablehlo.divide %[[MULTIPLY_19]], %[[ADD_13]] : tensor + // CHECK: %[[NEGATE_1:.*]] = stablehlo.negate %[[DIVIDE_9]] : tensor +-// CHECK: %[[COMPARE_15:.*]] = stablehlo.compare GE, %[[ABS_0]], %[[CONSTANT_3]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_15:.*]] = stablehlo.compare GE, %[[ABS_0]], %[[CONSTANT_3]] : (tensor, tensor) -> tensor + // CHECK: %[[MULTIPLY_20:.*]] = stablehlo.multiply %[[CONSTANT_4]], %[[MULTIPLY_10]] : tensor + // CHECK: %[[DIVIDE_10:.*]] = stablehlo.divide %[[MULTIPLY_20]], %[[ADD_7]] : tensor + // CHECK: %[[MULTIPLY_21:.*]] = stablehlo.multiply %[[CONSTANT_4]], %[[ADD_9]] : tensor + // CHECK: %[[ADD_14:.*]] = stablehlo.add %[[DIVIDE_10]], %[[MULTIPLY_21]] : tensor + // CHECK: %[[CONSTANT_14:.*]] = stablehlo.constant dense<1.500000e+00> : tensor +-// CHECK: %[[COMPARE_16:.*]] = stablehlo.compare LE, %[[MULTIPLY_8]], %[[CONSTANT_14]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_16:.*]] = stablehlo.compare LE, %[[MULTIPLY_8]], %[[CONSTANT_14]] : (tensor, tensor) -> tensor + // CHECK: %[[DIVIDE_11:.*]] = stablehlo.divide %[[MULTIPLY_20]], %[[SUBTRACT_1]] : tensor + // CHECK: %[[ADD_15:.*]] = stablehlo.add %[[DIVIDE_10]], %[[DIVIDE_11]] : tensor + // CHECK: %[[SUBTRACT_2:.*]] = stablehlo.subtract %[[MULTIPLY_8]], %[[CONSTANT_3]] : tensor +@@ -754,7 +754,7 @@ + // ----- + + // CHECK-LABEL: func.func @conj_real( +-// CHECK-SAME: %[[ARG0:.*]]: tensor<3xf32>) -> tensor<3xf32> { ++// CHECK-SAME: %[[ARG0:.*]]: tensor<3xf32>) -> tensor<3xf32> { + // CHECK: return %[[ARG0]] : tensor<3xf32> + // CHECK: } + func.func @conj_real(%arg0: tensor<3xf32>) -> tensor<3xf32> { +@@ -893,19 +893,19 @@ + // CHECK: %[[ADD_35:.*]] = stablehlo.add %[[MULTIPLY_40]], %[[CONSTANT_42]] : tensor + // CHECK: %[[DIVIDE_2:.*]] = stablehlo.divide %[[MULTIPLY_34]], %[[ADD_35]] : tensor + // CHECK: %[[CONSTANT_43:.*]] = stablehlo.constant dense<8.000000e+00> : tensor +-// CHECK: %[[COMPARE_0:.*]] = stablehlo.compare LT, %[[ABS_0]], %[[CONSTANT_43]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_0:.*]] = stablehlo.compare LT, %[[ABS_0]], %[[CONSTANT_43]] : (tensor, tensor) -> tensor + // CHECK: %[[SELECT_0:.*]] = stablehlo.select %[[COMPARE_0]], %[[DIVIDE_1]], %[[DIVIDE_2]] : tensor, tensor + // CHECK: %[[CONSTANT_44:.*]] = stablehlo.constant dense<-709.78271289338397> : tensor +-// CHECK: %[[COMPARE_1:.*]] = stablehlo.compare LT, %[[NEGATE_0]], %[[CONSTANT_44]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_1:.*]] = stablehlo.compare LT, %[[NEGATE_0]], %[[CONSTANT_44]] : (tensor, tensor) -> tensor + // CHECK: %[[CONSTANT_45:.*]] = stablehlo.constant dense<0.000000e+00> : tensor + // CHECK: %[[SELECT_1:.*]] = stablehlo.select %[[COMPARE_1]], %[[CONSTANT_45]], %[[SELECT_0]] : tensor, tensor +-// CHECK: %[[COMPARE_2:.*]] = stablehlo.compare LT, %[[ARG0]], %[[CONSTANT_45]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_2:.*]] = stablehlo.compare LT, %[[ARG0]], %[[CONSTANT_45]] : (tensor, tensor) -> tensor + // CHECK: %[[CONSTANT_46:.*]] = stablehlo.constant dense<2.000000e+00> : tensor + // CHECK: %[[SUBTRACT_0:.*]] = stablehlo.subtract %[[CONSTANT_46]], %[[SELECT_1]] : tensor + // CHECK: %[[SELECT_2:.*]] = stablehlo.select %[[COMPARE_2]], %[[SUBTRACT_0]], %[[SELECT_1]] : tensor, tensor + // CHECK: %[[SUBTRACT_1:.*]] = stablehlo.subtract %[[CONSTANT_11]], %[[SELECT_2]] : tensor + // CHECK: %[[ABS_1:.*]] = stablehlo.abs %[[ARG0]] : tensor +-// CHECK: %[[COMPARE_3:.*]] = stablehlo.compare LT, %[[ABS_1]], %[[CONSTANT_11]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_3:.*]] = stablehlo.compare LT, %[[ABS_1]], %[[CONSTANT_11]] : (tensor, tensor) -> tensor + // CHECK: %[[SELECT_3:.*]] = stablehlo.select %[[COMPARE_3]], %[[DIVIDE_0]], %[[SUBTRACT_1]] : tensor, tensor + // CHECK: return %[[SELECT_3]] : tensor + // CHECK: } +@@ -1081,7 +1081,7 @@ + // CHECK: %[[CONSTANT_0:.*]] = stablehlo.constant dense<6.550400e+04> : tensor + // CHECK: %[[CONSTANT_1:.*]] = stablehlo.constant dense<2.000000e+00> : tensor + // CHECK: %[[DIVIDE_0:.*]] = stablehlo.divide %[[CONSTANT_0]], %[[CONSTANT_1]] : tensor +-// CHECK: %[[COMPARE_0:.*]] = stablehlo.compare GE, %[[ARG0]], %[[DIVIDE_0]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_0:.*]] = stablehlo.compare GE, %[[ARG0]], %[[DIVIDE_0]] : (tensor, tensor) -> tensor + // CHECK: %[[LOG_0:.*]] = stablehlo.log %[[CONSTANT_1]] : tensor + // CHECK: %[[LOG_1:.*]] = stablehlo.log %[[ARG0]] : tensor + // CHECK: %[[ADD_0:.*]] = stablehlo.add %[[LOG_0]], %[[LOG_1]] : tensor +@@ -1114,24 +1114,24 @@ + // CHECK: %[[SQRT_0:.*]] = stablehlo.sqrt %[[CONSTANT_0]] : tensor + // CHECK: %[[CONSTANT_1:.*]] = stablehlo.constant dense<8.000000e+00> : tensor + // CHECK: %[[DIVIDE_0:.*]] = stablehlo.divide %[[SQRT_0]], %[[CONSTANT_1]] : tensor +-// CHECK: %[[COMPARE_0:.*]] = stablehlo.compare GE, %[[MAXIMUM_0]], %[[DIVIDE_0]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_0:.*]] = stablehlo.compare GE, %[[MAXIMUM_0]], %[[DIVIDE_0]] : (tensor, tensor) -> tensor + // CHECK: %[[CONSTANT_2:.*]] = stablehlo.constant dense<1.000000e+00> : tensor +-// CHECK: %[[COMPARE_1:.*]] = stablehlo.compare LE, %[[ABS_0]], %[[CONSTANT_2]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_1:.*]] = stablehlo.compare LE, %[[ABS_0]], %[[CONSTANT_2]] : (tensor, tensor) -> tensor + // CHECK: %[[CONSTANT_3:.*]] = stablehlo.constant dense<5.000000e-01> : tensor + // CHECK: %[[ADD_0:.*]] = stablehlo.add %[[ABS_0]], %[[CONSTANT_2]] : tensor + // CHECK: %[[ABS_2:.*]] = stablehlo.abs %[[ADD_0]] : tensor + // CHECK: %[[MAXIMUM_1:.*]] = stablehlo.maximum %[[ABS_2]], %[[ABS_1]] : tensor + // CHECK: %[[MINIMUM_0:.*]] = stablehlo.minimum %[[ABS_2]], %[[ABS_1]] : tensor +-// CHECK: %[[COMPARE_2:.*]] = stablehlo.compare EQ, %[[MAXIMUM_1]], %[[MINIMUM_0]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_2:.*]] = stablehlo.compare EQ, %[[MAXIMUM_1]], %[[MINIMUM_0]] : (tensor, tensor) -> tensor + // CHECK: %[[CONSTANT_4:.*]] = stablehlo.constant dense<1.41421354> : tensor + // CHECK: %[[MULTIPLY_0:.*]] = stablehlo.multiply %[[CONSTANT_4]], %[[MAXIMUM_1]] : tensor + // CHECK: %[[DIVIDE_1:.*]] = stablehlo.divide %[[MINIMUM_0]], %[[MAXIMUM_1]] : tensor + // CHECK: %[[MULTIPLY_1:.*]] = stablehlo.multiply %[[DIVIDE_1]], %[[DIVIDE_1]] : tensor + // CHECK: %[[ADD_1:.*]] = stablehlo.add %[[CONSTANT_2]], %[[MULTIPLY_1]] : tensor + // CHECK: %[[SQRT_1:.*]] = stablehlo.sqrt %[[ADD_1]] : tensor +-// CHECK: %[[COMPARE_3:.*]] = stablehlo.compare EQ, %[[SQRT_1]], %[[CONSTANT_2]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_3:.*]] = stablehlo.compare EQ, %[[SQRT_1]], %[[CONSTANT_2]] : (tensor, tensor) -> tensor + // CHECK: %[[CONSTANT_5:.*]] = stablehlo.constant dense<0.000000e+00> : tensor +-// CHECK: %[[COMPARE_4:.*]] = stablehlo.compare GT, %[[MULTIPLY_1]], %[[CONSTANT_5]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_4:.*]] = stablehlo.compare GT, %[[MULTIPLY_1]], %[[CONSTANT_5]] : (tensor, tensor) -> tensor + // CHECK: %[[AND_0:.*]] = stablehlo.and %[[COMPARE_3]], %[[COMPARE_4]] : tensor + // CHECK: %[[MULTIPLY_2:.*]] = stablehlo.multiply %[[MAXIMUM_1]], %[[MULTIPLY_1]] : tensor + // CHECK: %[[CONSTANT_6:.*]] = stablehlo.constant dense<2.000000e+00> : tensor +@@ -1144,14 +1144,14 @@ + // CHECK: %[[ABS_3:.*]] = stablehlo.abs %[[SUBTRACT_0]] : tensor + // CHECK: %[[MAXIMUM_2:.*]] = stablehlo.maximum %[[ABS_3]], %[[ABS_1]] : tensor + // CHECK: %[[MINIMUM_1:.*]] = stablehlo.minimum %[[ABS_3]], %[[ABS_1]] : tensor +-// CHECK: %[[COMPARE_5:.*]] = stablehlo.compare EQ, %[[MAXIMUM_2]], %[[MINIMUM_1]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_5:.*]] = stablehlo.compare EQ, %[[MAXIMUM_2]], %[[MINIMUM_1]] : (tensor, tensor) -> tensor + // CHECK: %[[MULTIPLY_4:.*]] = stablehlo.multiply %[[CONSTANT_4]], %[[MAXIMUM_2]] : tensor + // CHECK: %[[DIVIDE_3:.*]] = stablehlo.divide %[[MINIMUM_1]], %[[MAXIMUM_2]] : tensor + // CHECK: %[[MULTIPLY_5:.*]] = stablehlo.multiply %[[DIVIDE_3]], %[[DIVIDE_3]] : tensor + // CHECK: %[[ADD_3:.*]] = stablehlo.add %[[CONSTANT_2]], %[[MULTIPLY_5]] : tensor + // CHECK: %[[SQRT_2:.*]] = stablehlo.sqrt %[[ADD_3]] : tensor +-// CHECK: %[[COMPARE_6:.*]] = stablehlo.compare EQ, %[[SQRT_2]], %[[CONSTANT_2]] : (tensor, tensor) -> tensor +-// CHECK: %[[COMPARE_7:.*]] = stablehlo.compare GT, %[[MULTIPLY_5]], %[[CONSTANT_5]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_6:.*]] = stablehlo.compare EQ, %[[SQRT_2]], %[[CONSTANT_2]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_7:.*]] = stablehlo.compare GT, %[[MULTIPLY_5]], %[[CONSTANT_5]] : (tensor, tensor) -> tensor + // CHECK: %[[AND_1:.*]] = stablehlo.and %[[COMPARE_6]], %[[COMPARE_7]] : tensor + // CHECK: %[[MULTIPLY_6:.*]] = stablehlo.multiply %[[MAXIMUM_2]], %[[MULTIPLY_5]] : tensor + // CHECK: %[[DIVIDE_4:.*]] = stablehlo.divide %[[MULTIPLY_6]], %[[CONSTANT_6]] : tensor +@@ -1180,21 +1180,21 @@ + // CHECK: %[[SELECT_5:.*]] = stablehlo.select %[[COMPARE_0]], %[[ABS_1]], %[[SELECT_4]] : tensor, tensor + // CHECK: %[[CONSTANT_7:.*]] = stablehlo.constant dense<9.99999995E+11> : tensor + // CHECK: %[[MULTIPLY_13:.*]] = stablehlo.multiply %[[DIVIDE_0]], %[[CONSTANT_7]] : tensor +-// CHECK: %[[COMPARE_8:.*]] = stablehlo.compare LT, %[[ABS_0]], %[[MULTIPLY_13]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_8:.*]] = stablehlo.compare LT, %[[ABS_0]], %[[MULTIPLY_13]] : (tensor, tensor) -> tensor + // CHECK: %[[CONSTANT_8:.*]] = stablehlo.constant dense<9.99999997E-7> : tensor + // CHECK: %[[MULTIPLY_14:.*]] = stablehlo.multiply %[[DIVIDE_0]], %[[CONSTANT_8]] : tensor + // CHECK: %[[CONSTANT_9:.*]] = stablehlo.constant dense<1.000000e+02> : tensor + // CHECK: %[[MULTIPLY_15:.*]] = stablehlo.multiply %[[DIVIDE_0]], %[[CONSTANT_9]] : tensor + // CHECK: %[[SELECT_6:.*]] = stablehlo.select %[[COMPARE_8]], %[[MULTIPLY_14]], %[[MULTIPLY_15]] : tensor, tensor +-// CHECK: %[[COMPARE_9:.*]] = stablehlo.compare GE, %[[ABS_1]], %[[SELECT_6]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_9:.*]] = stablehlo.compare GE, %[[ABS_1]], %[[SELECT_6]] : (tensor, tensor) -> tensor + // CHECK: %[[SELECT_7:.*]] = stablehlo.select %[[COMPARE_9]], %[[ABS_1]], %[[ABS_0]] : tensor, tensor + // CHECK: %[[SELECT_8:.*]] = stablehlo.select %[[COMPARE_9]], %[[SELECT_6]], %[[DIVIDE_0]] : tensor, tensor +-// CHECK: %[[COMPARE_10:.*]] = stablehlo.compare GE, %[[SELECT_7]], %[[SELECT_8]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_10:.*]] = stablehlo.compare GE, %[[SELECT_7]], %[[SELECT_8]] : (tensor, tensor) -> tensor + // CHECK: %[[LOG_0:.*]] = stablehlo.log %[[CONSTANT_6]] : tensor + // CHECK: %[[LOG_1:.*]] = stablehlo.log %[[SELECT_7]] : tensor + // CHECK: %[[ADD_11:.*]] = stablehlo.add %[[LOG_0]], %[[LOG_1]] : tensor + // CHECK: %[[CONSTANT_10:.*]] = stablehlo.constant dense<0x7F800000> : tensor +-// CHECK: %[[COMPARE_11:.*]] = stablehlo.compare EQ, %[[ABS_1]], %[[CONSTANT_10]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_11:.*]] = stablehlo.compare EQ, %[[ABS_1]], %[[CONSTANT_10]] : (tensor, tensor) -> tensor + // CHECK: %[[NOT_0:.*]] = stablehlo.not %[[COMPARE_11]] : tensor + // CHECK: %[[AND_2:.*]] = stablehlo.and %[[COMPARE_9]], %[[NOT_0]] : tensor + // CHECK: %[[DIVIDE_8:.*]] = stablehlo.divide %[[ABS_0]], %[[ABS_1]] : tensor +@@ -1207,20 +1207,20 @@ + // CHECK: %[[SQRT_5:.*]] = stablehlo.sqrt %[[CONSTANT_11]] : tensor + // CHECK: %[[CONSTANT_12:.*]] = stablehlo.constant dense<4.000000e+00> : tensor + // CHECK: %[[MULTIPLY_18:.*]] = stablehlo.multiply %[[SQRT_5]], %[[CONSTANT_12]] : tensor +-// CHECK: %[[COMPARE_12:.*]] = stablehlo.compare LT, %[[ABS_1]], %[[MULTIPLY_18]] : (tensor, tensor) -> tensor +-// CHECK: %[[COMPARE_13:.*]] = stablehlo.compare LT, %[[ABS_0]], %[[CONSTANT_2]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_12:.*]] = stablehlo.compare LT, %[[ABS_1]], %[[MULTIPLY_18]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_13:.*]] = stablehlo.compare LT, %[[ABS_0]], %[[CONSTANT_2]] : (tensor, tensor) -> tensor + // CHECK: %[[AND_3:.*]] = stablehlo.and %[[COMPARE_12]], %[[COMPARE_13]] : tensor + // CHECK: %[[MULTIPLY_19:.*]] = stablehlo.multiply %[[ADD_0]], %[[SUBTRACT_0]] : tensor + // CHECK: %[[ADD_13:.*]] = stablehlo.add %[[MULTIPLY_8]], %[[CONSTANT_2]] : tensor + // CHECK: %[[DIVIDE_9:.*]] = stablehlo.divide %[[MULTIPLY_19]], %[[ADD_13]] : tensor + // CHECK: %[[NEGATE_0:.*]] = stablehlo.negate %[[DIVIDE_9]] : tensor +-// CHECK: %[[COMPARE_14:.*]] = stablehlo.compare GE, %[[ABS_0]], %[[CONSTANT_2]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_14:.*]] = stablehlo.compare GE, %[[ABS_0]], %[[CONSTANT_2]] : (tensor, tensor) -> tensor + // CHECK: %[[MULTIPLY_20:.*]] = stablehlo.multiply %[[CONSTANT_3]], %[[MULTIPLY_10]] : tensor + // CHECK: %[[DIVIDE_10:.*]] = stablehlo.divide %[[MULTIPLY_20]], %[[ADD_7]] : tensor + // CHECK: %[[MULTIPLY_21:.*]] = stablehlo.multiply %[[CONSTANT_3]], %[[ADD_9]] : tensor + // CHECK: %[[ADD_14:.*]] = stablehlo.add %[[DIVIDE_10]], %[[MULTIPLY_21]] : tensor + // CHECK: %[[CONSTANT_13:.*]] = stablehlo.constant dense<1.500000e+00> : tensor +-// CHECK: %[[COMPARE_15:.*]] = stablehlo.compare LE, %[[MULTIPLY_8]], %[[CONSTANT_13]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_15:.*]] = stablehlo.compare LE, %[[MULTIPLY_8]], %[[CONSTANT_13]] : (tensor, tensor) -> tensor + // CHECK: %[[DIVIDE_11:.*]] = stablehlo.divide %[[MULTIPLY_20]], %[[SUBTRACT_1]] : tensor + // CHECK: %[[ADD_15:.*]] = stablehlo.add %[[DIVIDE_10]], %[[DIVIDE_11]] : tensor + // CHECK: %[[SUBTRACT_2:.*]] = stablehlo.subtract %[[MULTIPLY_8]], %[[CONSTANT_2]] : tensor +@@ -1238,7 +1238,7 @@ + // CHECK: %[[IMAG_1:.*]] = stablehlo.imag %[[COMPLEX_0]] : (tensor>) -> tensor + // CHECK: %[[IMAG_2:.*]] = stablehlo.imag %[[ARG0]] : (tensor>) -> tensor + // CHECK: %[[CONSTANT_14:.*]] = stablehlo.constant dense<0.000000e+00> : tensor +-// CHECK: %[[COMPARE_16:.*]] = stablehlo.compare LT, %[[IMAG_2]], %[[CONSTANT_14]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_16:.*]] = stablehlo.compare LT, %[[IMAG_2]], %[[CONSTANT_14]] : (tensor, tensor) -> tensor + // CHECK: %[[REAL_1:.*]] = stablehlo.real %[[COMPLEX_0]] : (tensor>) -> tensor + // CHECK: %[[REAL_2:.*]] = stablehlo.real %[[ARG0]] : (tensor>) -> tensor + // CHECK: %[[VAL_0:.*]] = stablehlo.atan2 %[[REAL_1]], %[[REAL_2]] : tensor +@@ -1350,13 +1350,13 @@ + // CHECK: %[[ADD_26:.*]] = stablehlo.add %[[MULTIPLY_29]], %[[CONSTANT_30]] : tensor + // CHECK: %[[DIVIDE_1:.*]] = stablehlo.divide %[[MULTIPLY_23]], %[[ADD_26]] : tensor + // CHECK: %[[CONSTANT_31:.*]] = stablehlo.constant dense<8.000000e+00> : tensor +-// CHECK: %[[COMPARE_0:.*]] = stablehlo.compare LT, %[[ABS_0]], %[[CONSTANT_31]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_0:.*]] = stablehlo.compare LT, %[[ABS_0]], %[[CONSTANT_31]] : (tensor, tensor) -> tensor + // CHECK: %[[SELECT_0:.*]] = stablehlo.select %[[COMPARE_0]], %[[DIVIDE_0]], %[[DIVIDE_1]] : tensor, tensor + // CHECK: %[[CONSTANT_32:.*]] = stablehlo.constant dense<-709.78271289338397> : tensor +-// CHECK: %[[COMPARE_1:.*]] = stablehlo.compare LT, %[[NEGATE_0]], %[[CONSTANT_32]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_1:.*]] = stablehlo.compare LT, %[[NEGATE_0]], %[[CONSTANT_32]] : (tensor, tensor) -> tensor + // CHECK: %[[CONSTANT_33:.*]] = stablehlo.constant dense<0.000000e+00> : tensor + // CHECK: %[[SELECT_1:.*]] = stablehlo.select %[[COMPARE_1]], %[[CONSTANT_33]], %[[SELECT_0]] : tensor, tensor +-// CHECK: %[[COMPARE_2:.*]] = stablehlo.compare LT, %[[ARG0]], %[[CONSTANT_33]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_2:.*]] = stablehlo.compare LT, %[[ARG0]], %[[CONSTANT_33]] : (tensor, tensor) -> tensor + // CHECK: %[[CONSTANT_34:.*]] = stablehlo.constant dense<2.000000e+00> : tensor + // CHECK: %[[SUBTRACT_0:.*]] = stablehlo.subtract %[[CONSTANT_34]], %[[SELECT_1]] : tensor + // CHECK: %[[SELECT_2:.*]] = stablehlo.select %[[COMPARE_2]], %[[SUBTRACT_0]], %[[SELECT_1]] : tensor, tensor +@@ -1395,7 +1395,7 @@ + // CHECK: %[[DIVIDE_2:.*]] = stablehlo.divide %[[MULTIPLY_35]], %[[ADD_35]] : tensor + // CHECK: %[[SUBTRACT_1:.*]] = stablehlo.subtract %[[CONSTANT_35]], %[[DIVIDE_2]] : tensor + // CHECK: %[[ABS_1:.*]] = stablehlo.abs %[[ARG0]] : tensor +-// CHECK: %[[COMPARE_3:.*]] = stablehlo.compare LT, %[[ABS_1]], %[[CONSTANT_35]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_3:.*]] = stablehlo.compare LT, %[[ABS_1]], %[[CONSTANT_35]] : (tensor, tensor) -> tensor + // CHECK: %[[SELECT_3:.*]] = stablehlo.select %[[COMPARE_3]], %[[SUBTRACT_1]], %[[SELECT_2]] : tensor, tensor + // CHECK: return %[[SELECT_3]] : tensor + // CHECK: } +@@ -1417,7 +1417,7 @@ + // CHECK: %[[DIVIDE_1:.*]] = stablehlo.divide %[[CONSTANT_0]], %[[ABS_0]] : tensor + // CHECK: %[[MULTIPLY_1:.*]] = stablehlo.multiply %[[EXPONENTIAL_0]], %[[DIVIDE_1]] : tensor + // CHECK: %[[CONSTANT_1:.*]] = stablehlo.constant dense<2.000000e+00> : tensor +-// CHECK: %[[COMPARE_0:.*]] = stablehlo.compare LT, %[[ABS_0]], %[[CONSTANT_1]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_0:.*]] = stablehlo.compare LT, %[[ABS_0]], %[[CONSTANT_1]] : (tensor, tensor) -> tensor + // CHECK: %[[CONSTANT_2:.*]] = stablehlo.constant dense<2.326820e-02> : tensor + // CHECK: %[[MULTIPLY_2:.*]] = stablehlo.multiply %[[CONSTANT_2]], %[[DIVIDE_0]] : tensor + // CHECK: %[[CONSTANT_3:.*]] = stablehlo.constant dense<-0.138703942> : tensor +@@ -1468,10 +1468,10 @@ + // CHECK: %[[SELECT_0:.*]] = stablehlo.select %[[COMPARE_0]], %[[ADD_7]], %[[ADD_14]] : tensor, tensor + // CHECK: %[[MULTIPLY_17:.*]] = stablehlo.multiply %[[MULTIPLY_1]], %[[SELECT_0]] : tensor + // CHECK: %[[CONSTANT_19:.*]] = stablehlo.constant dense<-88.7228394> : tensor +-// CHECK: %[[COMPARE_1:.*]] = stablehlo.compare LT, %[[NEGATE_0]], %[[CONSTANT_19]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_1:.*]] = stablehlo.compare LT, %[[NEGATE_0]], %[[CONSTANT_19]] : (tensor, tensor) -> tensor + // CHECK: %[[CONSTANT_20:.*]] = stablehlo.constant dense<0.000000e+00> : tensor + // CHECK: %[[SELECT_1:.*]] = stablehlo.select %[[COMPARE_1]], %[[CONSTANT_20]], %[[MULTIPLY_17]] : tensor, tensor +-// CHECK: %[[COMPARE_2:.*]] = stablehlo.compare LT, %[[ARG0]], %[[CONSTANT_20]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_2:.*]] = stablehlo.compare LT, %[[ARG0]], %[[CONSTANT_20]] : (tensor, tensor) -> tensor + // CHECK: %[[SUBTRACT_0:.*]] = stablehlo.subtract %[[CONSTANT_1]], %[[SELECT_1]] : tensor + // CHECK: %[[SELECT_2:.*]] = stablehlo.select %[[COMPARE_2]], %[[SUBTRACT_0]], %[[SELECT_1]] : tensor, tensor + // CHECK: %[[CONSTANT_21:.*]] = stablehlo.constant dense<1.000000e+00> : tensor +@@ -1498,7 +1498,7 @@ + // CHECK: %[[MULTIPLY_25:.*]] = stablehlo.multiply %[[ARG0]], %[[ADD_20]] : tensor + // CHECK: %[[SUBTRACT_1:.*]] = stablehlo.subtract %[[CONSTANT_21]], %[[MULTIPLY_25]] : tensor + // CHECK: %[[ABS_1:.*]] = stablehlo.abs %[[ARG0]] : tensor +-// CHECK: %[[COMPARE_3:.*]] = stablehlo.compare LT, %[[ABS_1]], %[[CONSTANT_21]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_3:.*]] = stablehlo.compare LT, %[[ABS_1]], %[[CONSTANT_21]] : (tensor, tensor) -> tensor + // CHECK: %[[SELECT_3:.*]] = stablehlo.select %[[COMPARE_3]], %[[SUBTRACT_1]], %[[SELECT_2]] : tensor, tensor + // CHECK: return %[[SELECT_3]] : tensor + // CHECK: } +@@ -1521,7 +1521,7 @@ + // CHECK: %[[DIVIDE_1:.*]] = stablehlo.divide %[[CONSTANT_0]], %[[ABS_0]] : tensor + // CHECK: %[[MULTIPLY_1:.*]] = stablehlo.multiply %[[EXPONENTIAL_0]], %[[DIVIDE_1]] : tensor + // CHECK: %[[CONSTANT_1:.*]] = stablehlo.constant dense<2.000000e+00> : tensor +-// CHECK: %[[COMPARE_0:.*]] = stablehlo.compare LT, %[[ABS_0]], %[[CONSTANT_1]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_0:.*]] = stablehlo.compare LT, %[[ABS_0]], %[[CONSTANT_1]] : (tensor, tensor) -> tensor + // CHECK: %[[CONSTANT_2:.*]] = stablehlo.constant dense<2.326820e-02> : tensor + // CHECK: %[[MULTIPLY_2:.*]] = stablehlo.multiply %[[CONSTANT_2]], %[[DIVIDE_0]] : tensor + // CHECK: %[[CONSTANT_3:.*]] = stablehlo.constant dense<-0.138703942> : tensor +@@ -1572,10 +1572,10 @@ + // CHECK: %[[SELECT_0:.*]] = stablehlo.select %[[COMPARE_0]], %[[ADD_7]], %[[ADD_14]] : tensor, tensor + // CHECK: %[[MULTIPLY_17:.*]] = stablehlo.multiply %[[MULTIPLY_1]], %[[SELECT_0]] : tensor + // CHECK: %[[CONSTANT_19:.*]] = stablehlo.constant dense<-88.7228394> : tensor +-// CHECK: %[[COMPARE_1:.*]] = stablehlo.compare LT, %[[NEGATE_0]], %[[CONSTANT_19]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_1:.*]] = stablehlo.compare LT, %[[NEGATE_0]], %[[CONSTANT_19]] : (tensor, tensor) -> tensor + // CHECK: %[[CONSTANT_20:.*]] = stablehlo.constant dense<0.000000e+00> : tensor + // CHECK: %[[SELECT_1:.*]] = stablehlo.select %[[COMPARE_1]], %[[CONSTANT_20]], %[[MULTIPLY_17]] : tensor, tensor +-// CHECK: %[[COMPARE_2:.*]] = stablehlo.compare LT, %[[CONVERT_0]], %[[CONSTANT_20]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_2:.*]] = stablehlo.compare LT, %[[CONVERT_0]], %[[CONSTANT_20]] : (tensor, tensor) -> tensor + // CHECK: %[[SUBTRACT_0:.*]] = stablehlo.subtract %[[CONSTANT_1]], %[[SELECT_1]] : tensor + // CHECK: %[[SELECT_2:.*]] = stablehlo.select %[[COMPARE_2]], %[[SUBTRACT_0]], %[[SELECT_1]] : tensor, tensor + // CHECK: %[[CONSTANT_21:.*]] = stablehlo.constant dense<1.000000e+00> : tensor +@@ -1602,7 +1602,7 @@ + // CHECK: %[[MULTIPLY_25:.*]] = stablehlo.multiply %[[CONVERT_0]], %[[ADD_20]] : tensor + // CHECK: %[[SUBTRACT_1:.*]] = stablehlo.subtract %[[CONSTANT_21]], %[[MULTIPLY_25]] : tensor + // CHECK: %[[ABS_1:.*]] = stablehlo.abs %[[CONVERT_0]] : tensor +-// CHECK: %[[COMPARE_3:.*]] = stablehlo.compare LT, %[[ABS_1]], %[[CONSTANT_21]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_3:.*]] = stablehlo.compare LT, %[[ABS_1]], %[[CONSTANT_21]] : (tensor, tensor) -> tensor + // CHECK: %[[SELECT_3:.*]] = stablehlo.select %[[COMPARE_3]], %[[SUBTRACT_1]], %[[SELECT_2]] : tensor, tensor + // CHECK: %[[CONVERT_1:.*]] = stablehlo.convert %[[SELECT_3]] : (tensor) -> tensor + // CHECK: return %[[CONVERT_1]] : tensor +@@ -1626,7 +1626,7 @@ + // CHECK: %[[DIVIDE_1:.*]] = stablehlo.divide %[[CONSTANT_0]], %[[ABS_0]] : tensor + // CHECK: %[[MULTIPLY_1:.*]] = stablehlo.multiply %[[EXPONENTIAL_0]], %[[DIVIDE_1]] : tensor + // CHECK: %[[CONSTANT_1:.*]] = stablehlo.constant dense<2.000000e+00> : tensor +-// CHECK: %[[COMPARE_0:.*]] = stablehlo.compare LT, %[[ABS_0]], %[[CONSTANT_1]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_0:.*]] = stablehlo.compare LT, %[[ABS_0]], %[[CONSTANT_1]] : (tensor, tensor) -> tensor + // CHECK: %[[CONSTANT_2:.*]] = stablehlo.constant dense<2.326820e-02> : tensor + // CHECK: %[[MULTIPLY_2:.*]] = stablehlo.multiply %[[CONSTANT_2]], %[[DIVIDE_0]] : tensor + // CHECK: %[[CONSTANT_3:.*]] = stablehlo.constant dense<-0.138703942> : tensor +@@ -1677,10 +1677,10 @@ + // CHECK: %[[SELECT_0:.*]] = stablehlo.select %[[COMPARE_0]], %[[ADD_7]], %[[ADD_14]] : tensor, tensor + // CHECK: %[[MULTIPLY_17:.*]] = stablehlo.multiply %[[MULTIPLY_1]], %[[SELECT_0]] : tensor + // CHECK: %[[CONSTANT_19:.*]] = stablehlo.constant dense<-88.7228394> : tensor +-// CHECK: %[[COMPARE_1:.*]] = stablehlo.compare LT, %[[NEGATE_0]], %[[CONSTANT_19]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_1:.*]] = stablehlo.compare LT, %[[NEGATE_0]], %[[CONSTANT_19]] : (tensor, tensor) -> tensor + // CHECK: %[[CONSTANT_20:.*]] = stablehlo.constant dense<0.000000e+00> : tensor + // CHECK: %[[SELECT_1:.*]] = stablehlo.select %[[COMPARE_1]], %[[CONSTANT_20]], %[[MULTIPLY_17]] : tensor, tensor +-// CHECK: %[[COMPARE_2:.*]] = stablehlo.compare LT, %[[CONVERT_0]], %[[CONSTANT_20]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_2:.*]] = stablehlo.compare LT, %[[CONVERT_0]], %[[CONSTANT_20]] : (tensor, tensor) -> tensor + // CHECK: %[[SUBTRACT_0:.*]] = stablehlo.subtract %[[CONSTANT_1]], %[[SELECT_1]] : tensor + // CHECK: %[[SELECT_2:.*]] = stablehlo.select %[[COMPARE_2]], %[[SUBTRACT_0]], %[[SELECT_1]] : tensor, tensor + // CHECK: %[[CONSTANT_21:.*]] = stablehlo.constant dense<1.000000e+00> : tensor +@@ -1707,7 +1707,7 @@ + // CHECK: %[[MULTIPLY_25:.*]] = stablehlo.multiply %[[CONVERT_0]], %[[ADD_20]] : tensor + // CHECK: %[[SUBTRACT_1:.*]] = stablehlo.subtract %[[CONSTANT_21]], %[[MULTIPLY_25]] : tensor + // CHECK: %[[ABS_1:.*]] = stablehlo.abs %[[CONVERT_0]] : tensor +-// CHECK: %[[COMPARE_3:.*]] = stablehlo.compare LT, %[[ABS_1]], %[[CONSTANT_21]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_3:.*]] = stablehlo.compare LT, %[[ABS_1]], %[[CONSTANT_21]] : (tensor, tensor) -> tensor + // CHECK: %[[SELECT_3:.*]] = stablehlo.select %[[COMPARE_3]], %[[SUBTRACT_1]], %[[SELECT_2]] : tensor, tensor + // CHECK: %[[CONVERT_1:.*]] = stablehlo.convert %[[SELECT_3]] : (tensor) -> tensor + // CHECK: return %[[CONVERT_1]] : tensor +@@ -1723,7 +1723,7 @@ + // CHECK-SAME: %[[ARG0:.*]]: tensor) -> tensor { + // CHECK: %[[ABS_0:.*]] = stablehlo.abs %[[ARG0]] : tensor + // CHECK: %[[CONSTANT_0:.*]] = stablehlo.constant dense<0x7F800000> : tensor +-// CHECK: %[[COMPARE_0:.*]] = stablehlo.compare EQ, %[[ABS_0]], %[[CONSTANT_0]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_0:.*]] = stablehlo.compare EQ, %[[ABS_0]], %[[CONSTANT_0]] : (tensor, tensor) -> tensor + // CHECK: return %[[COMPARE_0]] : tensor + // CHECK: } + func.func @is_inf_f32(%arg : tensor) -> tensor { +@@ -1736,7 +1736,7 @@ + // CHECK-LABEL: func.func @is_pos_inf_f32( + // CHECK-SAME: %[[ARG0:.*]]: tensor) -> tensor { + // CHECK: %[[CONSTANT_0:.*]] = stablehlo.constant dense<0x7F800000> : tensor +-// CHECK: %[[COMPARE_0:.*]] = stablehlo.compare EQ, %[[ARG0]], %[[CONSTANT_0]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_0:.*]] = stablehlo.compare EQ, %[[ARG0]], %[[CONSTANT_0]] : (tensor, tensor) -> tensor + // CHECK: return %[[COMPARE_0]] : tensor + // CHECK: } + func.func @is_pos_inf_f32(%arg : tensor) -> tensor { +@@ -1749,7 +1749,7 @@ + // CHECK-LABEL: func.func @is_neg_inf_f32( + // CHECK-SAME: %[[ARG0:.*]]: tensor) -> tensor { + // CHECK: %[[CONSTANT_0:.*]] = stablehlo.constant dense<0xFF800000> : tensor +-// CHECK: %[[COMPARE_0:.*]] = stablehlo.compare EQ, %[[ARG0]], %[[CONSTANT_0]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_0:.*]] = stablehlo.compare EQ, %[[ARG0]], %[[CONSTANT_0]] : (tensor, tensor) -> tensor + // CHECK: return %[[COMPARE_0]] : tensor + // CHECK: } + func.func @is_neg_inf_f32(%arg : tensor) -> tensor { +@@ -1762,7 +1762,7 @@ + // CHECK-LABEL: func.func @lgamma_f64( + // CHECK-SAME: %[[ARG0:.*]]: tensor) -> tensor { + // CHECK: %[[CONSTANT_0:.*]] = stablehlo.constant dense<5.000000e-01> : tensor +-// CHECK: %[[COMPARE_0:.*]] = stablehlo.compare LT, %[[ARG0]], %[[CONSTANT_0]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_0:.*]] = stablehlo.compare LT, %[[ARG0]], %[[CONSTANT_0]] : (tensor, tensor) -> tensor + // CHECK: %[[NEGATE_0:.*]] = stablehlo.negate %[[ARG0]] : tensor + // CHECK: %[[CONSTANT_1:.*]] = stablehlo.constant dense<1.000000e+00> : tensor + // CHECK: %[[SUBTRACT_0:.*]] = stablehlo.subtract %[[ARG0]], %[[CONSTANT_1]] : tensor +@@ -1825,7 +1825,7 @@ + // CHECK: %[[ABS_0:.*]] = stablehlo.abs %[[ARG0]] : tensor + // CHECK: %[[FLOOR_0:.*]] = stablehlo.floor %[[ABS_0]] : tensor + // CHECK: %[[SUBTRACT_2:.*]] = stablehlo.subtract %[[ABS_0]], %[[FLOOR_0]] : tensor +-// CHECK: %[[COMPARE_1:.*]] = stablehlo.compare LT, %[[CONSTANT_0]], %[[SUBTRACT_2]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_1:.*]] = stablehlo.compare LT, %[[CONSTANT_0]], %[[SUBTRACT_2]] : (tensor, tensor) -> tensor + // CHECK: %[[SUBTRACT_3:.*]] = stablehlo.subtract %[[CONSTANT_1]], %[[SUBTRACT_2]] : tensor + // CHECK: %[[SELECT_1:.*]] = stablehlo.select %[[COMPARE_1]], %[[SUBTRACT_3]], %[[SUBTRACT_2]] : tensor, tensor + // CHECK: %[[CONSTANT_22:.*]] = stablehlo.constant dense<3.1415926535897931> : tensor +@@ -1841,7 +1841,7 @@ + // CHECK: %[[SELECT_3:.*]] = stablehlo.select %[[COMPARE_0]], %[[SELECT_2]], %[[ADD_20]] : tensor, tensor + // CHECK: %[[ABS_1:.*]] = stablehlo.abs %[[ARG0]] : tensor + // CHECK: %[[CONSTANT_24:.*]] = stablehlo.constant dense<0x7FF0000000000000> : tensor +-// CHECK: %[[COMPARE_2:.*]] = stablehlo.compare EQ, %[[ABS_1]], %[[CONSTANT_24]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_2:.*]] = stablehlo.compare EQ, %[[ABS_1]], %[[CONSTANT_24]] : (tensor, tensor) -> tensor + // CHECK: %[[CONSTANT_25:.*]] = stablehlo.constant dense<0x7FF0000000000000> : tensor + // CHECK: %[[SELECT_4:.*]] = stablehlo.select %[[COMPARE_2]], %[[CONSTANT_25]], %[[SELECT_3]] : tensor, tensor + // CHECK: return %[[SELECT_4]] : tensor +@@ -1856,7 +1856,7 @@ + // CHECK-LABEL: func.func @lgamma_f32( + // CHECK-SAME: %[[ARG0:.*]]: tensor) -> tensor { + // CHECK: %[[CONSTANT_0:.*]] = stablehlo.constant dense<5.000000e-01> : tensor +-// CHECK: %[[COMPARE_0:.*]] = stablehlo.compare LT, %[[ARG0]], %[[CONSTANT_0]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_0:.*]] = stablehlo.compare LT, %[[ARG0]], %[[CONSTANT_0]] : (tensor, tensor) -> tensor + // CHECK: %[[NEGATE_0:.*]] = stablehlo.negate %[[ARG0]] : tensor + // CHECK: %[[CONSTANT_1:.*]] = stablehlo.constant dense<1.000000e+00> : tensor + // CHECK: %[[SUBTRACT_0:.*]] = stablehlo.subtract %[[ARG0]], %[[CONSTANT_1]] : tensor +@@ -1919,7 +1919,7 @@ + // CHECK: %[[ABS_0:.*]] = stablehlo.abs %[[ARG0]] : tensor + // CHECK: %[[FLOOR_0:.*]] = stablehlo.floor %[[ABS_0]] : tensor + // CHECK: %[[SUBTRACT_2:.*]] = stablehlo.subtract %[[ABS_0]], %[[FLOOR_0]] : tensor +-// CHECK: %[[COMPARE_1:.*]] = stablehlo.compare LT, %[[CONSTANT_0]], %[[SUBTRACT_2]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_1:.*]] = stablehlo.compare LT, %[[CONSTANT_0]], %[[SUBTRACT_2]] : (tensor, tensor) -> tensor + // CHECK: %[[SUBTRACT_3:.*]] = stablehlo.subtract %[[CONSTANT_1]], %[[SUBTRACT_2]] : tensor + // CHECK: %[[SELECT_1:.*]] = stablehlo.select %[[COMPARE_1]], %[[SUBTRACT_3]], %[[SUBTRACT_2]] : tensor, tensor + // CHECK: %[[CONSTANT_22:.*]] = stablehlo.constant dense<3.14159274> : tensor +@@ -1935,7 +1935,7 @@ + // CHECK: %[[SELECT_3:.*]] = stablehlo.select %[[COMPARE_0]], %[[SELECT_2]], %[[ADD_20]] : tensor, tensor + // CHECK: %[[ABS_1:.*]] = stablehlo.abs %[[ARG0]] : tensor + // CHECK: %[[CONSTANT_24:.*]] = stablehlo.constant dense<0x7F800000> : tensor +-// CHECK: %[[COMPARE_2:.*]] = stablehlo.compare EQ, %[[ABS_1]], %[[CONSTANT_24]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_2:.*]] = stablehlo.compare EQ, %[[ABS_1]], %[[CONSTANT_24]] : (tensor, tensor) -> tensor + // CHECK: %[[CONSTANT_25:.*]] = stablehlo.constant dense<0x7F800000> : tensor + // CHECK: %[[SELECT_4:.*]] = stablehlo.select %[[COMPARE_2]], %[[CONSTANT_25]], %[[SELECT_3]] : tensor, tensor + // CHECK: return %[[SELECT_4]] : tensor +@@ -1951,7 +1951,7 @@ + // CHECK-SAME: %[[ARG0:.*]]: tensor) -> tensor { + // CHECK: %[[CONVERT_0:.*]] = stablehlo.convert %[[ARG0]] : (tensor) -> tensor + // CHECK: %[[CONSTANT_0:.*]] = stablehlo.constant dense<5.000000e-01> : tensor +-// CHECK: %[[COMPARE_0:.*]] = stablehlo.compare LT, %[[CONVERT_0]], %[[CONSTANT_0]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_0:.*]] = stablehlo.compare LT, %[[CONVERT_0]], %[[CONSTANT_0]] : (tensor, tensor) -> tensor + // CHECK: %[[NEGATE_0:.*]] = stablehlo.negate %[[CONVERT_0]] : tensor + // CHECK: %[[CONSTANT_1:.*]] = stablehlo.constant dense<1.000000e+00> : tensor + // CHECK: %[[SUBTRACT_0:.*]] = stablehlo.subtract %[[CONVERT_0]], %[[CONSTANT_1]] : tensor +@@ -2014,7 +2014,7 @@ + // CHECK: %[[ABS_0:.*]] = stablehlo.abs %[[CONVERT_0]] : tensor + // CHECK: %[[FLOOR_0:.*]] = stablehlo.floor %[[ABS_0]] : tensor + // CHECK: %[[SUBTRACT_2:.*]] = stablehlo.subtract %[[ABS_0]], %[[FLOOR_0]] : tensor +-// CHECK: %[[COMPARE_1:.*]] = stablehlo.compare LT, %[[CONSTANT_0]], %[[SUBTRACT_2]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_1:.*]] = stablehlo.compare LT, %[[CONSTANT_0]], %[[SUBTRACT_2]] : (tensor, tensor) -> tensor + // CHECK: %[[SUBTRACT_3:.*]] = stablehlo.subtract %[[CONSTANT_1]], %[[SUBTRACT_2]] : tensor + // CHECK: %[[SELECT_1:.*]] = stablehlo.select %[[COMPARE_1]], %[[SUBTRACT_3]], %[[SUBTRACT_2]] : tensor, tensor + // CHECK: %[[CONSTANT_22:.*]] = stablehlo.constant dense<3.14159274> : tensor +@@ -2030,7 +2030,7 @@ + // CHECK: %[[SELECT_3:.*]] = stablehlo.select %[[COMPARE_0]], %[[SELECT_2]], %[[ADD_20]] : tensor, tensor + // CHECK: %[[ABS_1:.*]] = stablehlo.abs %[[CONVERT_0]] : tensor + // CHECK: %[[CONSTANT_24:.*]] = stablehlo.constant dense<0x7F800000> : tensor +-// CHECK: %[[COMPARE_2:.*]] = stablehlo.compare EQ, %[[ABS_1]], %[[CONSTANT_24]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_2:.*]] = stablehlo.compare EQ, %[[ABS_1]], %[[CONSTANT_24]] : (tensor, tensor) -> tensor + // CHECK: %[[CONSTANT_25:.*]] = stablehlo.constant dense<0x7F800000> : tensor + // CHECK: %[[SELECT_4:.*]] = stablehlo.select %[[COMPARE_2]], %[[CONSTANT_25]], %[[SELECT_3]] : tensor, tensor + // CHECK: %[[CONVERT_1:.*]] = stablehlo.convert %[[SELECT_4]] : (tensor) -> tensor +@@ -2046,7 +2046,7 @@ + // CHECK-LABEL: func.func @digamma_f64( + // CHECK-SAME: %[[ARG0:.*]]: tensor) -> tensor { + // CHECK: %[[CONSTANT_0:.*]] = stablehlo.constant dense<5.000000e-01> : tensor +-// CHECK: %[[COMPARE_0:.*]] = stablehlo.compare LT, %[[ARG0]], %[[CONSTANT_0]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_0:.*]] = stablehlo.compare LT, %[[ARG0]], %[[CONSTANT_0]] : (tensor, tensor) -> tensor + // CHECK: %[[NEGATE_0:.*]] = stablehlo.negate %[[ARG0]] : tensor + // CHECK: %[[CONSTANT_1:.*]] = stablehlo.constant dense<1.000000e+00> : tensor + // CHECK: %[[SUBTRACT_0:.*]] = stablehlo.subtract %[[ARG0]], %[[CONSTANT_1]] : tensor +@@ -2141,9 +2141,9 @@ + // CHECK: %[[DIVIDE_19:.*]] = stablehlo.divide %[[MULTIPLY_9]], %[[SINE_0]] : tensor + // CHECK: %[[SUBTRACT_10:.*]] = stablehlo.subtract %[[SUBTRACT_9]], %[[DIVIDE_19]] : tensor + // CHECK: %[[SELECT_1:.*]] = stablehlo.select %[[COMPARE_0]], %[[SUBTRACT_10]], %[[SUBTRACT_9]] : tensor, tensor +-// CHECK: %[[COMPARE_1:.*]] = stablehlo.compare LE, %[[ARG0]], %[[CONSTANT_2]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_1:.*]] = stablehlo.compare LE, %[[ARG0]], %[[CONSTANT_2]] : (tensor, tensor) -> tensor + // CHECK: %[[FLOOR_1:.*]] = stablehlo.floor %[[ARG0]] : tensor +-// CHECK: %[[COMPARE_2:.*]] = stablehlo.compare EQ, %[[ARG0]], %[[FLOOR_1]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_2:.*]] = stablehlo.compare EQ, %[[ARG0]], %[[FLOOR_1]] : (tensor, tensor) -> tensor + // CHECK: %[[AND_0:.*]] = stablehlo.and %[[COMPARE_1]], %[[COMPARE_2]] : tensor + // CHECK: %[[CONSTANT_25:.*]] = stablehlo.constant dense<0x7FF8000000000000> : tensor + // CHECK: %[[SELECT_2:.*]] = stablehlo.select %[[AND_0]], %[[CONSTANT_25]], %[[SELECT_1]] : tensor, tensor +@@ -2159,7 +2159,7 @@ + // CHECK-LABEL: func.func @digamma_f32( + // CHECK-SAME: %[[ARG0:.*]]: tensor) -> tensor { + // CHECK: %[[CONSTANT_0:.*]] = stablehlo.constant dense<5.000000e-01> : tensor +-// CHECK: %[[COMPARE_0:.*]] = stablehlo.compare LT, %[[ARG0]], %[[CONSTANT_0]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_0:.*]] = stablehlo.compare LT, %[[ARG0]], %[[CONSTANT_0]] : (tensor, tensor) -> tensor + // CHECK: %[[NEGATE_0:.*]] = stablehlo.negate %[[ARG0]] : tensor + // CHECK: %[[CONSTANT_1:.*]] = stablehlo.constant dense<1.000000e+00> : tensor + // CHECK: %[[SUBTRACT_0:.*]] = stablehlo.subtract %[[ARG0]], %[[CONSTANT_1]] : tensor +@@ -2254,9 +2254,9 @@ + // CHECK: %[[DIVIDE_19:.*]] = stablehlo.divide %[[MULTIPLY_9]], %[[SINE_0]] : tensor + // CHECK: %[[SUBTRACT_10:.*]] = stablehlo.subtract %[[SUBTRACT_9]], %[[DIVIDE_19]] : tensor + // CHECK: %[[SELECT_1:.*]] = stablehlo.select %[[COMPARE_0]], %[[SUBTRACT_10]], %[[SUBTRACT_9]] : tensor, tensor +-// CHECK: %[[COMPARE_1:.*]] = stablehlo.compare LE, %[[ARG0]], %[[CONSTANT_2]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_1:.*]] = stablehlo.compare LE, %[[ARG0]], %[[CONSTANT_2]] : (tensor, tensor) -> tensor + // CHECK: %[[FLOOR_1:.*]] = stablehlo.floor %[[ARG0]] : tensor +-// CHECK: %[[COMPARE_2:.*]] = stablehlo.compare EQ, %[[ARG0]], %[[FLOOR_1]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_2:.*]] = stablehlo.compare EQ, %[[ARG0]], %[[FLOOR_1]] : (tensor, tensor) -> tensor + // CHECK: %[[AND_0:.*]] = stablehlo.and %[[COMPARE_1]], %[[COMPARE_2]] : tensor + // CHECK: %[[CONSTANT_25:.*]] = stablehlo.constant dense<0x7FC00000> : tensor + // CHECK: %[[SELECT_2:.*]] = stablehlo.select %[[AND_0]], %[[CONSTANT_25]], %[[SELECT_1]] : tensor, tensor +@@ -2273,7 +2273,7 @@ + // CHECK-SAME: %[[ARG0:.*]]: tensor) -> tensor { + // CHECK: %[[CONVERT_0:.*]] = stablehlo.convert %[[ARG0]] : (tensor) -> tensor + // CHECK: %[[CONSTANT_0:.*]] = stablehlo.constant dense<5.000000e-01> : tensor +-// CHECK: %[[COMPARE_0:.*]] = stablehlo.compare LT, %[[CONVERT_0]], %[[CONSTANT_0]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_0:.*]] = stablehlo.compare LT, %[[CONVERT_0]], %[[CONSTANT_0]] : (tensor, tensor) -> tensor + // CHECK: %[[NEGATE_0:.*]] = stablehlo.negate %[[CONVERT_0]] : tensor + // CHECK: %[[CONSTANT_1:.*]] = stablehlo.constant dense<1.000000e+00> : tensor + // CHECK: %[[SUBTRACT_0:.*]] = stablehlo.subtract %[[CONVERT_0]], %[[CONSTANT_1]] : tensor +@@ -2368,9 +2368,9 @@ + // CHECK: %[[DIVIDE_19:.*]] = stablehlo.divide %[[MULTIPLY_9]], %[[SINE_0]] : tensor + // CHECK: %[[SUBTRACT_10:.*]] = stablehlo.subtract %[[SUBTRACT_9]], %[[DIVIDE_19]] : tensor + // CHECK: %[[SELECT_1:.*]] = stablehlo.select %[[COMPARE_0]], %[[SUBTRACT_10]], %[[SUBTRACT_9]] : tensor, tensor +-// CHECK: %[[COMPARE_1:.*]] = stablehlo.compare LE, %[[CONVERT_0]], %[[CONSTANT_2]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_1:.*]] = stablehlo.compare LE, %[[CONVERT_0]], %[[CONSTANT_2]] : (tensor, tensor) -> tensor + // CHECK: %[[FLOOR_1:.*]] = stablehlo.floor %[[CONVERT_0]] : tensor +-// CHECK: %[[COMPARE_2:.*]] = stablehlo.compare EQ, %[[CONVERT_0]], %[[FLOOR_1]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_2:.*]] = stablehlo.compare EQ, %[[CONVERT_0]], %[[FLOOR_1]] : (tensor, tensor) -> tensor + // CHECK: %[[AND_0:.*]] = stablehlo.and %[[COMPARE_1]], %[[COMPARE_2]] : tensor + // CHECK: %[[CONSTANT_25:.*]] = stablehlo.constant dense<0x7FC00000> : tensor + // CHECK: %[[SELECT_2:.*]] = stablehlo.select %[[AND_0]], %[[CONSTANT_25]], %[[SELECT_1]] : tensor, tensor +@@ -2540,29 +2540,29 @@ + // CHECK: %[[ABS_1:.*]] = stablehlo.abs %[[ADD_17]] : tensor + // CHECK: %[[CONSTANT_38:.*]] = stablehlo.constant dense<1.401300e-45> : tensor + // CHECK: %[[MULTIPLY_37:.*]] = stablehlo.multiply %[[ABS_1]], %[[CONSTANT_38]] : tensor +-// CHECK: %[[COMPARE_0:.*]] = stablehlo.compare LT, %[[ABS_0]], %[[MULTIPLY_37]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_0:.*]] = stablehlo.compare LT, %[[ABS_0]], %[[MULTIPLY_37]] : (tensor, tensor) -> tensor + // CHECK: %[[SELECT_0:.*]] = stablehlo.select %[[COMPARE_0]], %[[ADD_17]], %[[ADD_55]] : tensor, tensor + // CHECK: %[[CONSTANT_39:.*]] = stablehlo.constant dense<0x7FC00000> : tensor +-// CHECK: %[[COMPARE_1:.*]] = stablehlo.compare LT, %[[CONVERT_0]], %[[CONSTANT_2]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_1:.*]] = stablehlo.compare LT, %[[CONVERT_0]], %[[CONSTANT_2]] : (tensor, tensor) -> tensor + // CHECK: %[[SELECT_1:.*]] = stablehlo.select %[[COMPARE_1]], %[[CONSTANT_39]], %[[SELECT_0]] : tensor, tensor +-// CHECK: %[[COMPARE_2:.*]] = stablehlo.compare LE, %[[CONVERT_1]], %[[CONSTANT_0]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_2:.*]] = stablehlo.compare LE, %[[CONVERT_1]], %[[CONSTANT_0]] : (tensor, tensor) -> tensor + // CHECK: %[[FLOOR_0:.*]] = stablehlo.floor %[[CONVERT_0]] : tensor +-// CHECK: %[[COMPARE_3:.*]] = stablehlo.compare NE, %[[CONVERT_0]], %[[FLOOR_0]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_3:.*]] = stablehlo.compare NE, %[[CONVERT_0]], %[[FLOOR_0]] : (tensor, tensor) -> tensor + // CHECK: %[[AND_0:.*]] = stablehlo.and %[[COMPARE_2]], %[[COMPARE_3]] : tensor + // CHECK: %[[SELECT_2:.*]] = stablehlo.select %[[AND_0]], %[[CONSTANT_39]], %[[SELECT_1]] : tensor, tensor + // CHECK: %[[CONSTANT_40:.*]] = stablehlo.constant dense<0x7F800000> : tensor + // CHECK: %[[FLOOR_1:.*]] = stablehlo.floor %[[CONVERT_1]] : tensor +-// CHECK: %[[COMPARE_4:.*]] = stablehlo.compare EQ, %[[CONVERT_1]], %[[FLOOR_1]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_4:.*]] = stablehlo.compare EQ, %[[CONVERT_1]], %[[FLOOR_1]] : (tensor, tensor) -> tensor + // CHECK: %[[AND_1:.*]] = stablehlo.and %[[COMPARE_2]], %[[COMPARE_4]] : tensor + // CHECK: %[[CONSTANT_41:.*]] = stablehlo.constant dense<2.000000e+00> : tensor + // CHECK: %[[FLOOR_2:.*]] = stablehlo.floor %[[CONVERT_0]] : tensor +-// CHECK: %[[COMPARE_5:.*]] = stablehlo.compare EQ, %[[CONVERT_0]], %[[FLOOR_2]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_5:.*]] = stablehlo.compare EQ, %[[CONVERT_0]], %[[FLOOR_2]] : (tensor, tensor) -> tensor + // CHECK: %[[REMAINDER_0:.*]] = stablehlo.remainder %[[CONVERT_0]], %[[CONSTANT_41]] : tensor +-// CHECK: %[[COMPARE_6:.*]] = stablehlo.compare EQ, %[[REMAINDER_0]], %[[CONSTANT_0]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_6:.*]] = stablehlo.compare EQ, %[[REMAINDER_0]], %[[CONSTANT_0]] : (tensor, tensor) -> tensor + // CHECK: %[[AND_2:.*]] = stablehlo.and %[[COMPARE_5]], %[[COMPARE_6]] : tensor + // CHECK: %[[SELECT_3:.*]] = stablehlo.select %[[AND_2]], %[[CONSTANT_40]], %[[CONSTANT_39]] : tensor, tensor + // CHECK: %[[SELECT_4:.*]] = stablehlo.select %[[AND_1]], %[[SELECT_3]], %[[SELECT_2]] : tensor, tensor +-// CHECK: %[[COMPARE_7:.*]] = stablehlo.compare EQ, %[[CONVERT_0]], %[[CONSTANT_1]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_7:.*]] = stablehlo.compare EQ, %[[CONVERT_0]], %[[CONSTANT_1]] : (tensor, tensor) -> tensor + // CHECK: %[[SELECT_5:.*]] = stablehlo.select %[[COMPARE_7]], %[[CONSTANT_40]], %[[SELECT_4]] : tensor, tensor + // CHECK: %[[CONVERT_2:.*]] = stablehlo.convert %[[SELECT_5]] : (tensor) -> tensor + // CHECK: return %[[CONVERT_2]] : tensor +@@ -2584,7 +2584,7 @@ + // CHECK: %[[SUBTRACT_0:.*]] = stablehlo.subtract %[[MULTIPLY_0]], %[[CONSTANT_0]] : tensor + // CHECK: %[[ADD_0:.*]] = stablehlo.add %[[ARG0]], %[[CONSTANT_0]] : tensor + // CHECK: %[[CONSTANT_2:.*]] = stablehlo.constant dense<5.000000e-01> : tensor +-// CHECK: %[[COMPARE_0:.*]] = stablehlo.compare LT, %[[ADD_0]], %[[CONSTANT_2]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_0:.*]] = stablehlo.compare LT, %[[ADD_0]], %[[CONSTANT_2]] : (tensor, tensor) -> tensor + // CHECK: %[[NEGATE_0:.*]] = stablehlo.negate %[[ADD_0]] : tensor + // CHECK: %[[CONSTANT_3:.*]] = stablehlo.constant dense<1.000000e+00> : tensor + // CHECK: %[[SUBTRACT_1:.*]] = stablehlo.subtract %[[ADD_0]], %[[CONSTANT_3]] : tensor +@@ -2647,7 +2647,7 @@ + // CHECK: %[[ABS_0:.*]] = stablehlo.abs %[[ADD_0]] : tensor + // CHECK: %[[FLOOR_0:.*]] = stablehlo.floor %[[ABS_0]] : tensor + // CHECK: %[[SUBTRACT_3:.*]] = stablehlo.subtract %[[ABS_0]], %[[FLOOR_0]] : tensor +-// CHECK: %[[COMPARE_1:.*]] = stablehlo.compare LT, %[[CONSTANT_2]], %[[SUBTRACT_3]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_1:.*]] = stablehlo.compare LT, %[[CONSTANT_2]], %[[SUBTRACT_3]] : (tensor, tensor) -> tensor + // CHECK: %[[SUBTRACT_4:.*]] = stablehlo.subtract %[[CONSTANT_3]], %[[SUBTRACT_3]] : tensor + // CHECK: %[[SELECT_1:.*]] = stablehlo.select %[[COMPARE_1]], %[[SUBTRACT_4]], %[[SUBTRACT_3]] : tensor, tensor + // CHECK: %[[CONSTANT_24:.*]] = stablehlo.constant dense<3.14159274> : tensor +@@ -2663,7 +2663,7 @@ + // CHECK: %[[SELECT_3:.*]] = stablehlo.select %[[COMPARE_0]], %[[SELECT_2]], %[[ADD_21]] : tensor, tensor + // CHECK: %[[ABS_1:.*]] = stablehlo.abs %[[ADD_0]] : tensor + // CHECK: %[[CONSTANT_26:.*]] = stablehlo.constant dense<0x7F800000> : tensor +-// CHECK: %[[COMPARE_2:.*]] = stablehlo.compare EQ, %[[ABS_1]], %[[CONSTANT_26]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_2:.*]] = stablehlo.compare EQ, %[[ABS_1]], %[[CONSTANT_26]] : (tensor, tensor) -> tensor + // CHECK: %[[CONSTANT_27:.*]] = stablehlo.constant dense<0x7F800000> : tensor + // CHECK: %[[SELECT_4:.*]] = stablehlo.select %[[COMPARE_2]], %[[CONSTANT_27]], %[[SELECT_3]] : tensor, tensor + // CHECK: %[[EXPONENTIAL_0:.*]] = stablehlo.exponential %[[SELECT_4]] : tensor +@@ -2818,36 +2818,36 @@ + // CHECK: %[[ABS_3:.*]] = stablehlo.abs %[[ADD_39]] : tensor + // CHECK: %[[CONSTANT_66:.*]] = stablehlo.constant dense<1.401300e-45> : tensor + // CHECK: %[[MULTIPLY_40:.*]] = stablehlo.multiply %[[ABS_3]], %[[CONSTANT_66]] : tensor +-// CHECK: %[[COMPARE_3:.*]] = stablehlo.compare LT, %[[ABS_2]], %[[MULTIPLY_40]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_3:.*]] = stablehlo.compare LT, %[[ABS_2]], %[[MULTIPLY_40]] : (tensor, tensor) -> tensor + // CHECK: %[[SELECT_5:.*]] = stablehlo.select %[[COMPARE_3]], %[[ADD_39]], %[[ADD_77]] : tensor, tensor + // CHECK: %[[CONSTANT_67:.*]] = stablehlo.constant dense<0x7FC00000> : tensor +-// CHECK: %[[COMPARE_4:.*]] = stablehlo.compare LT, %[[ADD_0]], %[[CONSTANT_30]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_4:.*]] = stablehlo.compare LT, %[[ADD_0]], %[[CONSTANT_30]] : (tensor, tensor) -> tensor + // CHECK: %[[SELECT_6:.*]] = stablehlo.select %[[COMPARE_4]], %[[CONSTANT_67]], %[[SELECT_5]] : tensor, tensor +-// CHECK: %[[COMPARE_5:.*]] = stablehlo.compare LE, %[[ARG1]], %[[CONSTANT_28]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_5:.*]] = stablehlo.compare LE, %[[ARG1]], %[[CONSTANT_28]] : (tensor, tensor) -> tensor + // CHECK: %[[FLOOR_1:.*]] = stablehlo.floor %[[ADD_0]] : tensor +-// CHECK: %[[COMPARE_6:.*]] = stablehlo.compare NE, %[[ADD_0]], %[[FLOOR_1]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_6:.*]] = stablehlo.compare NE, %[[ADD_0]], %[[FLOOR_1]] : (tensor, tensor) -> tensor + // CHECK: %[[AND_0:.*]] = stablehlo.and %[[COMPARE_5]], %[[COMPARE_6]] : tensor + // CHECK: %[[SELECT_7:.*]] = stablehlo.select %[[AND_0]], %[[CONSTANT_67]], %[[SELECT_6]] : tensor, tensor + // CHECK: %[[CONSTANT_68:.*]] = stablehlo.constant dense<0x7F800000> : tensor + // CHECK: %[[FLOOR_2:.*]] = stablehlo.floor %[[ARG1]] : tensor +-// CHECK: %[[COMPARE_7:.*]] = stablehlo.compare EQ, %[[ARG1]], %[[FLOOR_2]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_7:.*]] = stablehlo.compare EQ, %[[ARG1]], %[[FLOOR_2]] : (tensor, tensor) -> tensor + // CHECK: %[[AND_1:.*]] = stablehlo.and %[[COMPARE_5]], %[[COMPARE_7]] : tensor + // CHECK: %[[CONSTANT_69:.*]] = stablehlo.constant dense<2.000000e+00> : tensor + // CHECK: %[[FLOOR_3:.*]] = stablehlo.floor %[[ADD_0]] : tensor +-// CHECK: %[[COMPARE_8:.*]] = stablehlo.compare EQ, %[[ADD_0]], %[[FLOOR_3]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_8:.*]] = stablehlo.compare EQ, %[[ADD_0]], %[[FLOOR_3]] : (tensor, tensor) -> tensor + // CHECK: %[[REMAINDER_1:.*]] = stablehlo.remainder %[[ADD_0]], %[[CONSTANT_69]] : tensor +-// CHECK: %[[COMPARE_9:.*]] = stablehlo.compare EQ, %[[REMAINDER_1]], %[[CONSTANT_28]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_9:.*]] = stablehlo.compare EQ, %[[REMAINDER_1]], %[[CONSTANT_28]] : (tensor, tensor) -> tensor + // CHECK: %[[AND_2:.*]] = stablehlo.and %[[COMPARE_8]], %[[COMPARE_9]] : tensor + // CHECK: %[[SELECT_8:.*]] = stablehlo.select %[[AND_2]], %[[CONSTANT_68]], %[[CONSTANT_67]] : tensor, tensor + // CHECK: %[[SELECT_9:.*]] = stablehlo.select %[[AND_1]], %[[SELECT_8]], %[[SELECT_7]] : tensor, tensor +-// CHECK: %[[COMPARE_10:.*]] = stablehlo.compare EQ, %[[ADD_0]], %[[CONSTANT_29]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_10:.*]] = stablehlo.compare EQ, %[[ADD_0]], %[[CONSTANT_29]] : (tensor, tensor) -> tensor + // CHECK: %[[SELECT_10:.*]] = stablehlo.select %[[COMPARE_10]], %[[CONSTANT_68]], %[[SELECT_9]] : tensor, tensor + // CHECK: %[[MULTIPLY_41:.*]] = stablehlo.multiply %[[SUBTRACT_0]], %[[EXPONENTIAL_0]] : tensor + // CHECK: %[[MULTIPLY_42:.*]] = stablehlo.multiply %[[MULTIPLY_41]], %[[SELECT_10]] : tensor + // CHECK: %[[CONSTANT_70:.*]] = stablehlo.constant dense<0.000000e+00> : tensor +-// CHECK: %[[COMPARE_11:.*]] = stablehlo.compare EQ, %[[ARG0]], %[[CONSTANT_70]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_11:.*]] = stablehlo.compare EQ, %[[ARG0]], %[[CONSTANT_70]] : (tensor, tensor) -> tensor + // CHECK: %[[CONSTANT_71:.*]] = stablehlo.constant dense<5.000000e-01> : tensor +-// CHECK: %[[COMPARE_12:.*]] = stablehlo.compare LT, %[[ARG1]], %[[CONSTANT_71]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_12:.*]] = stablehlo.compare LT, %[[ARG1]], %[[CONSTANT_71]] : (tensor, tensor) -> tensor + // CHECK: %[[NEGATE_3:.*]] = stablehlo.negate %[[ARG1]] : tensor + // CHECK: %[[CONSTANT_72:.*]] = stablehlo.constant dense<1.000000e+00> : tensor + // CHECK: %[[SUBTRACT_8:.*]] = stablehlo.subtract %[[ARG1]], %[[CONSTANT_72]] : tensor +@@ -2942,16 +2942,16 @@ + // CHECK: %[[DIVIDE_32:.*]] = stablehlo.divide %[[MULTIPLY_52]], %[[SINE_1]] : tensor + // CHECK: %[[SUBTRACT_18:.*]] = stablehlo.subtract %[[SUBTRACT_17]], %[[DIVIDE_32]] : tensor + // CHECK: %[[SELECT_12:.*]] = stablehlo.select %[[COMPARE_12]], %[[SUBTRACT_18]], %[[SUBTRACT_17]] : tensor, tensor +-// CHECK: %[[COMPARE_13:.*]] = stablehlo.compare LE, %[[ARG1]], %[[CONSTANT_73]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_13:.*]] = stablehlo.compare LE, %[[ARG1]], %[[CONSTANT_73]] : (tensor, tensor) -> tensor + // CHECK: %[[FLOOR_5:.*]] = stablehlo.floor %[[ARG1]] : tensor +-// CHECK: %[[COMPARE_14:.*]] = stablehlo.compare EQ, %[[ARG1]], %[[FLOOR_5]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_14:.*]] = stablehlo.compare EQ, %[[ARG1]], %[[FLOOR_5]] : (tensor, tensor) -> tensor + // CHECK: %[[AND_3:.*]] = stablehlo.and %[[COMPARE_13]], %[[COMPARE_14]] : tensor + // CHECK: %[[CONSTANT_96:.*]] = stablehlo.constant dense<0x7FC00000> : tensor + // CHECK: %[[SELECT_13:.*]] = stablehlo.select %[[AND_3]], %[[CONSTANT_96]], %[[SELECT_12]] : tensor, tensor + // CHECK: %[[SELECT_14:.*]] = stablehlo.select %[[COMPARE_11]], %[[SELECT_13]], %[[MULTIPLY_42]] : tensor, tensor + // CHECK: %[[FLOOR_6:.*]] = stablehlo.floor %[[ARG0]] : tensor +-// CHECK: %[[COMPARE_15:.*]] = stablehlo.compare NE, %[[ARG0]], %[[FLOOR_6]] : (tensor, tensor) -> tensor +-// CHECK: %[[COMPARE_16:.*]] = stablehlo.compare LT, %[[ARG0]], %[[CONSTANT_70]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_15:.*]] = stablehlo.compare NE, %[[ARG0]], %[[FLOOR_6]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_16:.*]] = stablehlo.compare LT, %[[ARG0]], %[[CONSTANT_70]] : (tensor, tensor) -> tensor + // CHECK: %[[OR_0:.*]] = stablehlo.or %[[COMPARE_15]], %[[COMPARE_16]] : tensor + // CHECK: %[[CONSTANT_97:.*]] = stablehlo.constant dense<0x7FC00000> : tensor + // CHECK: %[[SELECT_15:.*]] = stablehlo.select %[[OR_0]], %[[CONSTANT_97]], %[[SELECT_14]] : tensor, tensor +@@ -2974,7 +2974,7 @@ + // CHECK: %[[SUBTRACT_0:.*]] = stablehlo.subtract %[[MULTIPLY_0]], %[[CONSTANT_0]] : tensor + // CHECK: %[[ADD_0:.*]] = stablehlo.add %[[ARG0]], %[[CONSTANT_0]] : tensor + // CHECK: %[[CONSTANT_2:.*]] = stablehlo.constant dense<5.000000e-01> : tensor +-// CHECK: %[[COMPARE_0:.*]] = stablehlo.compare LT, %[[ADD_0]], %[[CONSTANT_2]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_0:.*]] = stablehlo.compare LT, %[[ADD_0]], %[[CONSTANT_2]] : (tensor, tensor) -> tensor + // CHECK: %[[NEGATE_0:.*]] = stablehlo.negate %[[ADD_0]] : tensor + // CHECK: %[[CONSTANT_3:.*]] = stablehlo.constant dense<1.000000e+00> : tensor + // CHECK: %[[SUBTRACT_1:.*]] = stablehlo.subtract %[[ADD_0]], %[[CONSTANT_3]] : tensor +@@ -3037,7 +3037,7 @@ + // CHECK: %[[ABS_0:.*]] = stablehlo.abs %[[ADD_0]] : tensor + // CHECK: %[[FLOOR_0:.*]] = stablehlo.floor %[[ABS_0]] : tensor + // CHECK: %[[SUBTRACT_3:.*]] = stablehlo.subtract %[[ABS_0]], %[[FLOOR_0]] : tensor +-// CHECK: %[[COMPARE_1:.*]] = stablehlo.compare LT, %[[CONSTANT_2]], %[[SUBTRACT_3]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_1:.*]] = stablehlo.compare LT, %[[CONSTANT_2]], %[[SUBTRACT_3]] : (tensor, tensor) -> tensor + // CHECK: %[[SUBTRACT_4:.*]] = stablehlo.subtract %[[CONSTANT_3]], %[[SUBTRACT_3]] : tensor + // CHECK: %[[SELECT_1:.*]] = stablehlo.select %[[COMPARE_1]], %[[SUBTRACT_4]], %[[SUBTRACT_3]] : tensor, tensor + // CHECK: %[[CONSTANT_24:.*]] = stablehlo.constant dense<3.1415926535897931> : tensor +@@ -3053,7 +3053,7 @@ + // CHECK: %[[SELECT_3:.*]] = stablehlo.select %[[COMPARE_0]], %[[SELECT_2]], %[[ADD_21]] : tensor, tensor + // CHECK: %[[ABS_1:.*]] = stablehlo.abs %[[ADD_0]] : tensor + // CHECK: %[[CONSTANT_26:.*]] = stablehlo.constant dense<0x7FF0000000000000> : tensor +-// CHECK: %[[COMPARE_2:.*]] = stablehlo.compare EQ, %[[ABS_1]], %[[CONSTANT_26]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_2:.*]] = stablehlo.compare EQ, %[[ABS_1]], %[[CONSTANT_26]] : (tensor, tensor) -> tensor + // CHECK: %[[CONSTANT_27:.*]] = stablehlo.constant dense<0x7FF0000000000000> : tensor + // CHECK: %[[SELECT_4:.*]] = stablehlo.select %[[COMPARE_2]], %[[CONSTANT_27]], %[[SELECT_3]] : tensor, tensor + // CHECK: %[[EXPONENTIAL_0:.*]] = stablehlo.exponential %[[SELECT_4]] : tensor +@@ -3208,36 +3208,36 @@ + // CHECK: %[[ABS_3:.*]] = stablehlo.abs %[[ADD_39]] : tensor + // CHECK: %[[CONSTANT_66:.*]] = stablehlo.constant dense<4.940660e-324> : tensor + // CHECK: %[[MULTIPLY_40:.*]] = stablehlo.multiply %[[ABS_3]], %[[CONSTANT_66]] : tensor +-// CHECK: %[[COMPARE_3:.*]] = stablehlo.compare LT, %[[ABS_2]], %[[MULTIPLY_40]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_3:.*]] = stablehlo.compare LT, %[[ABS_2]], %[[MULTIPLY_40]] : (tensor, tensor) -> tensor + // CHECK: %[[SELECT_5:.*]] = stablehlo.select %[[COMPARE_3]], %[[ADD_39]], %[[ADD_77]] : tensor, tensor + // CHECK: %[[CONSTANT_67:.*]] = stablehlo.constant dense<0x7FF8000000000000> : tensor +-// CHECK: %[[COMPARE_4:.*]] = stablehlo.compare LT, %[[ADD_0]], %[[CONSTANT_30]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_4:.*]] = stablehlo.compare LT, %[[ADD_0]], %[[CONSTANT_30]] : (tensor, tensor) -> tensor + // CHECK: %[[SELECT_6:.*]] = stablehlo.select %[[COMPARE_4]], %[[CONSTANT_67]], %[[SELECT_5]] : tensor, tensor +-// CHECK: %[[COMPARE_5:.*]] = stablehlo.compare LE, %[[ARG1]], %[[CONSTANT_28]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_5:.*]] = stablehlo.compare LE, %[[ARG1]], %[[CONSTANT_28]] : (tensor, tensor) -> tensor + // CHECK: %[[FLOOR_1:.*]] = stablehlo.floor %[[ADD_0]] : tensor +-// CHECK: %[[COMPARE_6:.*]] = stablehlo.compare NE, %[[ADD_0]], %[[FLOOR_1]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_6:.*]] = stablehlo.compare NE, %[[ADD_0]], %[[FLOOR_1]] : (tensor, tensor) -> tensor + // CHECK: %[[AND_0:.*]] = stablehlo.and %[[COMPARE_5]], %[[COMPARE_6]] : tensor + // CHECK: %[[SELECT_7:.*]] = stablehlo.select %[[AND_0]], %[[CONSTANT_67]], %[[SELECT_6]] : tensor, tensor + // CHECK: %[[CONSTANT_68:.*]] = stablehlo.constant dense<0x7FF0000000000000> : tensor + // CHECK: %[[FLOOR_2:.*]] = stablehlo.floor %[[ARG1]] : tensor +-// CHECK: %[[COMPARE_7:.*]] = stablehlo.compare EQ, %[[ARG1]], %[[FLOOR_2]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_7:.*]] = stablehlo.compare EQ, %[[ARG1]], %[[FLOOR_2]] : (tensor, tensor) -> tensor + // CHECK: %[[AND_1:.*]] = stablehlo.and %[[COMPARE_5]], %[[COMPARE_7]] : tensor + // CHECK: %[[CONSTANT_69:.*]] = stablehlo.constant dense<2.000000e+00> : tensor + // CHECK: %[[FLOOR_3:.*]] = stablehlo.floor %[[ADD_0]] : tensor +-// CHECK: %[[COMPARE_8:.*]] = stablehlo.compare EQ, %[[ADD_0]], %[[FLOOR_3]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_8:.*]] = stablehlo.compare EQ, %[[ADD_0]], %[[FLOOR_3]] : (tensor, tensor) -> tensor + // CHECK: %[[REMAINDER_1:.*]] = stablehlo.remainder %[[ADD_0]], %[[CONSTANT_69]] : tensor +-// CHECK: %[[COMPARE_9:.*]] = stablehlo.compare EQ, %[[REMAINDER_1]], %[[CONSTANT_28]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_9:.*]] = stablehlo.compare EQ, %[[REMAINDER_1]], %[[CONSTANT_28]] : (tensor, tensor) -> tensor + // CHECK: %[[AND_2:.*]] = stablehlo.and %[[COMPARE_8]], %[[COMPARE_9]] : tensor + // CHECK: %[[SELECT_8:.*]] = stablehlo.select %[[AND_2]], %[[CONSTANT_68]], %[[CONSTANT_67]] : tensor, tensor + // CHECK: %[[SELECT_9:.*]] = stablehlo.select %[[AND_1]], %[[SELECT_8]], %[[SELECT_7]] : tensor, tensor +-// CHECK: %[[COMPARE_10:.*]] = stablehlo.compare EQ, %[[ADD_0]], %[[CONSTANT_29]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_10:.*]] = stablehlo.compare EQ, %[[ADD_0]], %[[CONSTANT_29]] : (tensor, tensor) -> tensor + // CHECK: %[[SELECT_10:.*]] = stablehlo.select %[[COMPARE_10]], %[[CONSTANT_68]], %[[SELECT_9]] : tensor, tensor + // CHECK: %[[MULTIPLY_41:.*]] = stablehlo.multiply %[[SUBTRACT_0]], %[[EXPONENTIAL_0]] : tensor + // CHECK: %[[MULTIPLY_42:.*]] = stablehlo.multiply %[[MULTIPLY_41]], %[[SELECT_10]] : tensor + // CHECK: %[[CONSTANT_70:.*]] = stablehlo.constant dense<0.000000e+00> : tensor +-// CHECK: %[[COMPARE_11:.*]] = stablehlo.compare EQ, %[[ARG0]], %[[CONSTANT_70]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_11:.*]] = stablehlo.compare EQ, %[[ARG0]], %[[CONSTANT_70]] : (tensor, tensor) -> tensor + // CHECK: %[[CONSTANT_71:.*]] = stablehlo.constant dense<5.000000e-01> : tensor +-// CHECK: %[[COMPARE_12:.*]] = stablehlo.compare LT, %[[ARG1]], %[[CONSTANT_71]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_12:.*]] = stablehlo.compare LT, %[[ARG1]], %[[CONSTANT_71]] : (tensor, tensor) -> tensor + // CHECK: %[[NEGATE_3:.*]] = stablehlo.negate %[[ARG1]] : tensor + // CHECK: %[[CONSTANT_72:.*]] = stablehlo.constant dense<1.000000e+00> : tensor + // CHECK: %[[SUBTRACT_8:.*]] = stablehlo.subtract %[[ARG1]], %[[CONSTANT_72]] : tensor +@@ -3332,16 +3332,16 @@ + // CHECK: %[[DIVIDE_32:.*]] = stablehlo.divide %[[MULTIPLY_52]], %[[SINE_1]] : tensor + // CHECK: %[[SUBTRACT_18:.*]] = stablehlo.subtract %[[SUBTRACT_17]], %[[DIVIDE_32]] : tensor + // CHECK: %[[SELECT_12:.*]] = stablehlo.select %[[COMPARE_12]], %[[SUBTRACT_18]], %[[SUBTRACT_17]] : tensor, tensor +-// CHECK: %[[COMPARE_13:.*]] = stablehlo.compare LE, %[[ARG1]], %[[CONSTANT_73]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_13:.*]] = stablehlo.compare LE, %[[ARG1]], %[[CONSTANT_73]] : (tensor, tensor) -> tensor + // CHECK: %[[FLOOR_5:.*]] = stablehlo.floor %[[ARG1]] : tensor +-// CHECK: %[[COMPARE_14:.*]] = stablehlo.compare EQ, %[[ARG1]], %[[FLOOR_5]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_14:.*]] = stablehlo.compare EQ, %[[ARG1]], %[[FLOOR_5]] : (tensor, tensor) -> tensor + // CHECK: %[[AND_3:.*]] = stablehlo.and %[[COMPARE_13]], %[[COMPARE_14]] : tensor + // CHECK: %[[CONSTANT_96:.*]] = stablehlo.constant dense<0x7FF8000000000000> : tensor + // CHECK: %[[SELECT_13:.*]] = stablehlo.select %[[AND_3]], %[[CONSTANT_96]], %[[SELECT_12]] : tensor, tensor + // CHECK: %[[SELECT_14:.*]] = stablehlo.select %[[COMPARE_11]], %[[SELECT_13]], %[[MULTIPLY_42]] : tensor, tensor + // CHECK: %[[FLOOR_6:.*]] = stablehlo.floor %[[ARG0]] : tensor +-// CHECK: %[[COMPARE_15:.*]] = stablehlo.compare NE, %[[ARG0]], %[[FLOOR_6]] : (tensor, tensor) -> tensor +-// CHECK: %[[COMPARE_16:.*]] = stablehlo.compare LT, %[[ARG0]], %[[CONSTANT_70]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_15:.*]] = stablehlo.compare NE, %[[ARG0]], %[[FLOOR_6]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_16:.*]] = stablehlo.compare LT, %[[ARG0]], %[[CONSTANT_70]] : (tensor, tensor) -> tensor + // CHECK: %[[OR_0:.*]] = stablehlo.or %[[COMPARE_15]], %[[COMPARE_16]] : tensor + // CHECK: %[[CONSTANT_97:.*]] = stablehlo.constant dense<0x7FF8000000000000> : tensor + // CHECK: %[[SELECT_15:.*]] = stablehlo.select %[[OR_0]], %[[CONSTANT_97]], %[[SELECT_14]] : tensor, tensor +@@ -3366,7 +3366,7 @@ + // CHECK: %[[SUBTRACT_0:.*]] = stablehlo.subtract %[[MULTIPLY_0]], %[[CONSTANT_0]] : tensor + // CHECK: %[[ADD_0:.*]] = stablehlo.add %[[CONVERT_0]], %[[CONSTANT_0]] : tensor + // CHECK: %[[CONSTANT_2:.*]] = stablehlo.constant dense<5.000000e-01> : tensor +-// CHECK: %[[COMPARE_0:.*]] = stablehlo.compare LT, %[[ADD_0]], %[[CONSTANT_2]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_0:.*]] = stablehlo.compare LT, %[[ADD_0]], %[[CONSTANT_2]] : (tensor, tensor) -> tensor + // CHECK: %[[NEGATE_0:.*]] = stablehlo.negate %[[ADD_0]] : tensor + // CHECK: %[[CONSTANT_3:.*]] = stablehlo.constant dense<1.000000e+00> : tensor + // CHECK: %[[SUBTRACT_1:.*]] = stablehlo.subtract %[[ADD_0]], %[[CONSTANT_3]] : tensor +@@ -3429,7 +3429,7 @@ + // CHECK: %[[ABS_0:.*]] = stablehlo.abs %[[ADD_0]] : tensor + // CHECK: %[[FLOOR_0:.*]] = stablehlo.floor %[[ABS_0]] : tensor + // CHECK: %[[SUBTRACT_3:.*]] = stablehlo.subtract %[[ABS_0]], %[[FLOOR_0]] : tensor +-// CHECK: %[[COMPARE_1:.*]] = stablehlo.compare LT, %[[CONSTANT_2]], %[[SUBTRACT_3]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_1:.*]] = stablehlo.compare LT, %[[CONSTANT_2]], %[[SUBTRACT_3]] : (tensor, tensor) -> tensor + // CHECK: %[[SUBTRACT_4:.*]] = stablehlo.subtract %[[CONSTANT_3]], %[[SUBTRACT_3]] : tensor + // CHECK: %[[SELECT_1:.*]] = stablehlo.select %[[COMPARE_1]], %[[SUBTRACT_4]], %[[SUBTRACT_3]] : tensor, tensor + // CHECK: %[[CONSTANT_24:.*]] = stablehlo.constant dense<3.14159274> : tensor +@@ -3445,7 +3445,7 @@ + // CHECK: %[[SELECT_3:.*]] = stablehlo.select %[[COMPARE_0]], %[[SELECT_2]], %[[ADD_21]] : tensor, tensor + // CHECK: %[[ABS_1:.*]] = stablehlo.abs %[[ADD_0]] : tensor + // CHECK: %[[CONSTANT_26:.*]] = stablehlo.constant dense<0x7F800000> : tensor +-// CHECK: %[[COMPARE_2:.*]] = stablehlo.compare EQ, %[[ABS_1]], %[[CONSTANT_26]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_2:.*]] = stablehlo.compare EQ, %[[ABS_1]], %[[CONSTANT_26]] : (tensor, tensor) -> tensor + // CHECK: %[[CONSTANT_27:.*]] = stablehlo.constant dense<0x7F800000> : tensor + // CHECK: %[[SELECT_4:.*]] = stablehlo.select %[[COMPARE_2]], %[[CONSTANT_27]], %[[SELECT_3]] : tensor, tensor + // CHECK: %[[EXPONENTIAL_0:.*]] = stablehlo.exponential %[[SELECT_4]] : tensor +@@ -3600,36 +3600,36 @@ + // CHECK: %[[ABS_3:.*]] = stablehlo.abs %[[ADD_39]] : tensor + // CHECK: %[[CONSTANT_66:.*]] = stablehlo.constant dense<1.401300e-45> : tensor + // CHECK: %[[MULTIPLY_40:.*]] = stablehlo.multiply %[[ABS_3]], %[[CONSTANT_66]] : tensor +-// CHECK: %[[COMPARE_3:.*]] = stablehlo.compare LT, %[[ABS_2]], %[[MULTIPLY_40]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_3:.*]] = stablehlo.compare LT, %[[ABS_2]], %[[MULTIPLY_40]] : (tensor, tensor) -> tensor + // CHECK: %[[SELECT_5:.*]] = stablehlo.select %[[COMPARE_3]], %[[ADD_39]], %[[ADD_77]] : tensor, tensor + // CHECK: %[[CONSTANT_67:.*]] = stablehlo.constant dense<0x7FC00000> : tensor +-// CHECK: %[[COMPARE_4:.*]] = stablehlo.compare LT, %[[ADD_0]], %[[CONSTANT_30]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_4:.*]] = stablehlo.compare LT, %[[ADD_0]], %[[CONSTANT_30]] : (tensor, tensor) -> tensor + // CHECK: %[[SELECT_6:.*]] = stablehlo.select %[[COMPARE_4]], %[[CONSTANT_67]], %[[SELECT_5]] : tensor, tensor +-// CHECK: %[[COMPARE_5:.*]] = stablehlo.compare LE, %[[CONVERT_1]], %[[CONSTANT_28]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_5:.*]] = stablehlo.compare LE, %[[CONVERT_1]], %[[CONSTANT_28]] : (tensor, tensor) -> tensor + // CHECK: %[[FLOOR_1:.*]] = stablehlo.floor %[[ADD_0]] : tensor +-// CHECK: %[[COMPARE_6:.*]] = stablehlo.compare NE, %[[ADD_0]], %[[FLOOR_1]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_6:.*]] = stablehlo.compare NE, %[[ADD_0]], %[[FLOOR_1]] : (tensor, tensor) -> tensor + // CHECK: %[[AND_0:.*]] = stablehlo.and %[[COMPARE_5]], %[[COMPARE_6]] : tensor + // CHECK: %[[SELECT_7:.*]] = stablehlo.select %[[AND_0]], %[[CONSTANT_67]], %[[SELECT_6]] : tensor, tensor + // CHECK: %[[CONSTANT_68:.*]] = stablehlo.constant dense<0x7F800000> : tensor + // CHECK: %[[FLOOR_2:.*]] = stablehlo.floor %[[CONVERT_1]] : tensor +-// CHECK: %[[COMPARE_7:.*]] = stablehlo.compare EQ, %[[CONVERT_1]], %[[FLOOR_2]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_7:.*]] = stablehlo.compare EQ, %[[CONVERT_1]], %[[FLOOR_2]] : (tensor, tensor) -> tensor + // CHECK: %[[AND_1:.*]] = stablehlo.and %[[COMPARE_5]], %[[COMPARE_7]] : tensor + // CHECK: %[[CONSTANT_69:.*]] = stablehlo.constant dense<2.000000e+00> : tensor + // CHECK: %[[FLOOR_3:.*]] = stablehlo.floor %[[ADD_0]] : tensor +-// CHECK: %[[COMPARE_8:.*]] = stablehlo.compare EQ, %[[ADD_0]], %[[FLOOR_3]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_8:.*]] = stablehlo.compare EQ, %[[ADD_0]], %[[FLOOR_3]] : (tensor, tensor) -> tensor + // CHECK: %[[REMAINDER_1:.*]] = stablehlo.remainder %[[ADD_0]], %[[CONSTANT_69]] : tensor +-// CHECK: %[[COMPARE_9:.*]] = stablehlo.compare EQ, %[[REMAINDER_1]], %[[CONSTANT_28]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_9:.*]] = stablehlo.compare EQ, %[[REMAINDER_1]], %[[CONSTANT_28]] : (tensor, tensor) -> tensor + // CHECK: %[[AND_2:.*]] = stablehlo.and %[[COMPARE_8]], %[[COMPARE_9]] : tensor + // CHECK: %[[SELECT_8:.*]] = stablehlo.select %[[AND_2]], %[[CONSTANT_68]], %[[CONSTANT_67]] : tensor, tensor + // CHECK: %[[SELECT_9:.*]] = stablehlo.select %[[AND_1]], %[[SELECT_8]], %[[SELECT_7]] : tensor, tensor +-// CHECK: %[[COMPARE_10:.*]] = stablehlo.compare EQ, %[[ADD_0]], %[[CONSTANT_29]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_10:.*]] = stablehlo.compare EQ, %[[ADD_0]], %[[CONSTANT_29]] : (tensor, tensor) -> tensor + // CHECK: %[[SELECT_10:.*]] = stablehlo.select %[[COMPARE_10]], %[[CONSTANT_68]], %[[SELECT_9]] : tensor, tensor + // CHECK: %[[MULTIPLY_41:.*]] = stablehlo.multiply %[[SUBTRACT_0]], %[[EXPONENTIAL_0]] : tensor + // CHECK: %[[MULTIPLY_42:.*]] = stablehlo.multiply %[[MULTIPLY_41]], %[[SELECT_10]] : tensor + // CHECK: %[[CONSTANT_70:.*]] = stablehlo.constant dense<0.000000e+00> : tensor +-// CHECK: %[[COMPARE_11:.*]] = stablehlo.compare EQ, %[[CONVERT_0]], %[[CONSTANT_70]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_11:.*]] = stablehlo.compare EQ, %[[CONVERT_0]], %[[CONSTANT_70]] : (tensor, tensor) -> tensor + // CHECK: %[[CONSTANT_71:.*]] = stablehlo.constant dense<5.000000e-01> : tensor +-// CHECK: %[[COMPARE_12:.*]] = stablehlo.compare LT, %[[CONVERT_1]], %[[CONSTANT_71]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_12:.*]] = stablehlo.compare LT, %[[CONVERT_1]], %[[CONSTANT_71]] : (tensor, tensor) -> tensor + // CHECK: %[[NEGATE_3:.*]] = stablehlo.negate %[[CONVERT_1]] : tensor + // CHECK: %[[CONSTANT_72:.*]] = stablehlo.constant dense<1.000000e+00> : tensor + // CHECK: %[[SUBTRACT_8:.*]] = stablehlo.subtract %[[CONVERT_1]], %[[CONSTANT_72]] : tensor +@@ -3724,16 +3724,16 @@ + // CHECK: %[[DIVIDE_32:.*]] = stablehlo.divide %[[MULTIPLY_52]], %[[SINE_1]] : tensor + // CHECK: %[[SUBTRACT_18:.*]] = stablehlo.subtract %[[SUBTRACT_17]], %[[DIVIDE_32]] : tensor + // CHECK: %[[SELECT_12:.*]] = stablehlo.select %[[COMPARE_12]], %[[SUBTRACT_18]], %[[SUBTRACT_17]] : tensor, tensor +-// CHECK: %[[COMPARE_13:.*]] = stablehlo.compare LE, %[[CONVERT_1]], %[[CONSTANT_73]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_13:.*]] = stablehlo.compare LE, %[[CONVERT_1]], %[[CONSTANT_73]] : (tensor, tensor) -> tensor + // CHECK: %[[FLOOR_5:.*]] = stablehlo.floor %[[CONVERT_1]] : tensor +-// CHECK: %[[COMPARE_14:.*]] = stablehlo.compare EQ, %[[CONVERT_1]], %[[FLOOR_5]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_14:.*]] = stablehlo.compare EQ, %[[CONVERT_1]], %[[FLOOR_5]] : (tensor, tensor) -> tensor + // CHECK: %[[AND_3:.*]] = stablehlo.and %[[COMPARE_13]], %[[COMPARE_14]] : tensor + // CHECK: %[[CONSTANT_96:.*]] = stablehlo.constant dense<0x7FC00000> : tensor + // CHECK: %[[SELECT_13:.*]] = stablehlo.select %[[AND_3]], %[[CONSTANT_96]], %[[SELECT_12]] : tensor, tensor + // CHECK: %[[SELECT_14:.*]] = stablehlo.select %[[COMPARE_11]], %[[SELECT_13]], %[[MULTIPLY_42]] : tensor, tensor + // CHECK: %[[FLOOR_6:.*]] = stablehlo.floor %[[CONVERT_0]] : tensor +-// CHECK: %[[COMPARE_15:.*]] = stablehlo.compare NE, %[[CONVERT_0]], %[[FLOOR_6]] : (tensor, tensor) -> tensor +-// CHECK: %[[COMPARE_16:.*]] = stablehlo.compare LT, %[[CONVERT_0]], %[[CONSTANT_70]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_15:.*]] = stablehlo.compare NE, %[[CONVERT_0]], %[[FLOOR_6]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_16:.*]] = stablehlo.compare LT, %[[CONVERT_0]], %[[CONSTANT_70]] : (tensor, tensor) -> tensor + // CHECK: %[[OR_0:.*]] = stablehlo.or %[[COMPARE_15]], %[[COMPARE_16]] : tensor + // CHECK: %[[CONSTANT_97:.*]] = stablehlo.constant dense<0x7FC00000> : tensor + // CHECK: %[[SELECT_15:.*]] = stablehlo.select %[[OR_0]], %[[CONSTANT_97]], %[[SELECT_14]] : tensor, tensor +@@ -3764,7 +3764,7 @@ + // CHECK: %[[ADD_2:.*]] = stablehlo.add %[[EXPONENTIAL_MINUS_ONE_0]], %[[DIVIDE_0]] : tensor + // CHECK: %[[MULTIPLY_0:.*]] = stablehlo.multiply %[[CONSTANT_2]], %[[ADD_2]] : tensor + // CHECK: %[[ABS_0:.*]] = stablehlo.abs %[[ARG0]] : tensor +-// CHECK: %[[COMPARE_0:.*]] = stablehlo.compare LT, %[[ABS_0]], %[[CONSTANT_1]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_0:.*]] = stablehlo.compare LT, %[[ABS_0]], %[[CONSTANT_1]] : (tensor, tensor) -> tensor + // CHECK: %[[SELECT_0:.*]] = stablehlo.select %[[COMPARE_0]], %[[MULTIPLY_0]], %[[SUBTRACT_1]] : tensor, tensor + // CHECK: return %[[SELECT_0]] : tensor + // CHECK: } +@@ -3793,7 +3793,7 @@ + // CHECK: %[[ADD_2:.*]] = stablehlo.add %[[EXPONENTIAL_MINUS_ONE_0]], %[[DIVIDE_0]] : tensor + // CHECK: %[[MULTIPLY_0:.*]] = stablehlo.multiply %[[CONSTANT_2]], %[[ADD_2]] : tensor + // CHECK: %[[ABS_0:.*]] = stablehlo.abs %[[CONVERT_0]] : tensor +-// CHECK: %[[COMPARE_0:.*]] = stablehlo.compare LT, %[[ABS_0]], %[[CONSTANT_1]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_0:.*]] = stablehlo.compare LT, %[[ABS_0]], %[[CONSTANT_1]] : (tensor, tensor) -> tensor + // CHECK: %[[SELECT_0:.*]] = stablehlo.select %[[COMPARE_0]], %[[MULTIPLY_0]], %[[SUBTRACT_1]] : tensor, tensor + // CHECK: %[[CONVERT_1:.*]] = stablehlo.convert %[[SELECT_0]] : (tensor) -> tensor + // CHECK: return %[[CONVERT_1]] : tensor +@@ -3883,7 +3883,7 @@ + // CHECK-SAME: %[[ARG0:.*]]: tensor) -> tensor { + // CHECK: %[[ABS_0:.*]] = stablehlo.abs %[[ARG0]] : tensor + // CHECK: %[[CONSTANT_0:.*]] = stablehlo.constant dense<1.000000e+00> : tensor +-// CHECK: %[[COMPARE_0:.*]] = stablehlo.compare GT, %[[ABS_0]], %[[CONSTANT_0]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_0:.*]] = stablehlo.compare GT, %[[ABS_0]], %[[CONSTANT_0]] : (tensor, tensor) -> tensor + // CHECK: %[[CONSTANT_1:.*]] = stablehlo.constant dense<0x7FC00000> : tensor + // CHECK: %[[LOG_PLUS_ONE_0:.*]] = stablehlo.log_plus_one %[[ARG0]] : tensor + // CHECK: %[[NEGATE_0:.*]] = stablehlo.negate %[[ARG0]] : tensor +@@ -3905,7 +3905,7 @@ + // CHECK-SAME: %[[ARG0:.*]]: tensor>) -> tensor> { + // CHECK: %[[REAL_0:.*]] = stablehlo.real %[[ARG0]] : (tensor>) -> tensor + // CHECK: %[[CONSTANT_0:.*]] = stablehlo.constant dense<0.000000e+00> : tensor +-// CHECK: %[[COMPARE_0:.*]] = stablehlo.compare GE, %[[REAL_0]], %[[CONSTANT_0]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_0:.*]] = stablehlo.compare GE, %[[REAL_0]], %[[CONSTANT_0]] : (tensor, tensor) -> tensor + // CHECK: %[[CONSTANT_1:.*]] = stablehlo.constant dense<1.000000e+00> : tensor + // CHECK: %[[CONSTANT_2:.*]] = stablehlo.constant dense<-1.000000e+00> : tensor + // CHECK: %[[SELECT_0:.*]] = stablehlo.select %[[COMPARE_0]], %[[CONSTANT_1]], %[[CONSTANT_2]] : tensor, tensor +@@ -3913,19 +3913,19 @@ + // CHECK: %[[ABS_0:.*]] = stablehlo.abs %[[REAL_0]] : tensor + // CHECK: %[[CONSTANT_4:.*]] = stablehlo.constant dense<3.40282347E+38> : tensor + // CHECK: %[[CONSTANT_5:.*]] = stablehlo.constant dense<0x7F800000> : tensor +-// CHECK: %[[COMPARE_1:.*]] = stablehlo.compare GT, %[[CONSTANT_4]], %[[CONSTANT_5]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_1:.*]] = stablehlo.compare GT, %[[CONSTANT_4]], %[[CONSTANT_5]] : (tensor, tensor) -> tensor + // CHECK: %[[CONSTANT_6:.*]] = stablehlo.constant dense<9.00719925E+15> : tensor + // CHECK: %[[CONSTANT_7:.*]] = stablehlo.constant dense<9.99999968E+37> : tensor +-// CHECK: %[[COMPARE_2:.*]] = stablehlo.compare GT, %[[CONSTANT_4]], %[[CONSTANT_7]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_2:.*]] = stablehlo.compare GT, %[[CONSTANT_4]], %[[CONSTANT_7]] : (tensor, tensor) -> tensor + // CHECK: %[[CONSTANT_8:.*]] = stablehlo.constant dense<0x4B800001> : tensor + // CHECK: %[[CONSTANT_9:.*]] = stablehlo.constant dense<2.050000e+03> : tensor + // CHECK: %[[SELECT_1:.*]] = stablehlo.select %[[COMPARE_2]], %[[CONSTANT_8]], %[[CONSTANT_9]] : tensor, tensor + // CHECK: %[[SELECT_2:.*]] = stablehlo.select %[[COMPARE_1]], %[[CONSTANT_6]], %[[SELECT_1]] : tensor, tensor + // CHECK: %[[MULTIPLY_0:.*]] = stablehlo.multiply %[[SELECT_2]], %[[SELECT_2]] : tensor +-// CHECK: %[[COMPARE_3:.*]] = stablehlo.compare LT, %[[ABS_0]], %[[MULTIPLY_0]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_3:.*]] = stablehlo.compare LT, %[[ABS_0]], %[[MULTIPLY_0]] : (tensor, tensor) -> tensor + // CHECK: %[[IMAG_0:.*]] = stablehlo.imag %[[ARG0]] : (tensor>) -> tensor + // CHECK: %[[ABS_1:.*]] = stablehlo.abs %[[IMAG_0]] : tensor +-// CHECK: %[[COMPARE_4:.*]] = stablehlo.compare LT, %[[ABS_1]], %[[MULTIPLY_0]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_4:.*]] = stablehlo.compare LT, %[[ABS_1]], %[[MULTIPLY_0]] : (tensor, tensor) -> tensor + // CHECK: %[[AND_0:.*]] = stablehlo.and %[[COMPARE_3]], %[[COMPARE_4]] : tensor + // CHECK: %[[SUBTRACT_0:.*]] = stablehlo.subtract %[[CONSTANT_1]], %[[ABS_0]] : tensor + // CHECK: %[[MULTIPLY_1:.*]] = stablehlo.multiply %[[SUBTRACT_0]], %[[SUBTRACT_0]] : tensor +@@ -3933,15 +3933,15 @@ + // CHECK: %[[ADD_0:.*]] = stablehlo.add %[[MULTIPLY_1]], %[[MULTIPLY_2]] : tensor + // CHECK: %[[DIVIDE_0:.*]] = stablehlo.divide %[[ABS_0]], %[[ADD_0]] : tensor + // CHECK: %[[MULTIPLY_3:.*]] = stablehlo.multiply %[[ABS_1]], %[[SELECT_2]] : tensor +-// CHECK: %[[COMPARE_5:.*]] = stablehlo.compare LT, %[[MULTIPLY_3]], %[[ABS_0]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_5:.*]] = stablehlo.compare LT, %[[MULTIPLY_3]], %[[ABS_0]] : (tensor, tensor) -> tensor + // CHECK: %[[DIVIDE_1:.*]] = stablehlo.divide %[[CONSTANT_1]], %[[ABS_0]] : tensor + // CHECK: %[[CONSTANT_10:.*]] = stablehlo.constant dense<0x7F800000> : tensor +-// CHECK: %[[COMPARE_6:.*]] = stablehlo.compare EQ, %[[REAL_0]], %[[CONSTANT_10]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_6:.*]] = stablehlo.compare EQ, %[[REAL_0]], %[[CONSTANT_10]] : (tensor, tensor) -> tensor + // CHECK: %[[CONSTANT_11:.*]] = stablehlo.constant dense<0xFF800000> : tensor +-// CHECK: %[[COMPARE_7:.*]] = stablehlo.compare EQ, %[[REAL_0]], %[[CONSTANT_11]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_7:.*]] = stablehlo.compare EQ, %[[REAL_0]], %[[CONSTANT_11]] : (tensor, tensor) -> tensor + // CHECK: %[[OR_0:.*]] = stablehlo.or %[[COMPARE_6]], %[[COMPARE_7]] : tensor +-// CHECK: %[[COMPARE_8:.*]] = stablehlo.compare EQ, %[[IMAG_0]], %[[CONSTANT_10]] : (tensor, tensor) -> tensor +-// CHECK: %[[COMPARE_9:.*]] = stablehlo.compare EQ, %[[IMAG_0]], %[[CONSTANT_11]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_8:.*]] = stablehlo.compare EQ, %[[IMAG_0]], %[[CONSTANT_10]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_9:.*]] = stablehlo.compare EQ, %[[IMAG_0]], %[[CONSTANT_11]] : (tensor, tensor) -> tensor + // CHECK: %[[OR_1:.*]] = stablehlo.or %[[COMPARE_8]], %[[COMPARE_9]] : tensor + // CHECK: %[[OR_2:.*]] = stablehlo.or %[[OR_0]], %[[OR_1]] : tensor + // CHECK: %[[DIVIDE_2:.*]] = stablehlo.divide %[[ABS_0]], %[[IMAG_0]] : tensor +@@ -3963,7 +3963,7 @@ + // CHECK: %[[SUBTRACT_1:.*]] = stablehlo.subtract %[[MULTIPLY_7]], %[[MULTIPLY_2]] : tensor + // CHECK: %[[VAL_0:.*]] = stablehlo.atan2 %[[ADD_2]], %[[SUBTRACT_1]] : tensor + // CHECK: %[[CONSTANT_13:.*]] = stablehlo.constant dense<0.000000e+00> : tensor +-// CHECK: %[[COMPARE_10:.*]] = stablehlo.compare GE, %[[IMAG_0]], %[[CONSTANT_13]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_10:.*]] = stablehlo.compare GE, %[[IMAG_0]], %[[CONSTANT_13]] : (tensor, tensor) -> tensor + // CHECK: %[[SELECT_6:.*]] = stablehlo.select %[[COMPARE_10]], %[[CONSTANT_1]], %[[CONSTANT_2]] : tensor, tensor + // CHECK: %[[CONSTANT_14:.*]] = stablehlo.constant dense<3.14159274> : tensor + // CHECK: %[[MULTIPLY_8:.*]] = stablehlo.multiply %[[SELECT_6]], %[[CONSTANT_14]] : tensor +@@ -3985,8 +3985,8 @@ + // CHECK-SAME: %[[ARG1:.*]]: tensor<2xf32>) -> tensor<2xf32> { + // CHECK: %[[BITCAST_CONVERT_0:.*]] = stablehlo.bitcast_convert %[[ARG0]] : (tensor<2xf32>) -> tensor<2xi32> + // CHECK: %[[BITCAST_CONVERT_1:.*]] = stablehlo.bitcast_convert %[[ARG1]] : (tensor<2xf32>) -> tensor<2xi32> +-// CHECK: %[[COMPARE_0:.*]] = stablehlo.compare NE, %[[ARG0]], %[[ARG0]] : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1> +-// CHECK: %[[COMPARE_1:.*]] = stablehlo.compare NE, %[[ARG1]], %[[ARG1]] : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1> ++// CHECK: %[[COMPARE_0:.*]] = stablehlo.compare NE, %[[ARG0]], %[[ARG0]] : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1> ++// CHECK: %[[COMPARE_1:.*]] = stablehlo.compare NE, %[[ARG1]], %[[ARG1]] : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1> + // CHECK: %[[OR_0:.*]] = stablehlo.or %[[COMPARE_0]], %[[COMPARE_1]] : tensor<2xi1> + // CHECK: %[[CONSTANT_0:.*]] = stablehlo.constant dense<0x7FC00000> : tensor<2xf32> + // CHECK: %[[BITCAST_CONVERT_2:.*]] = stablehlo.bitcast_convert %[[CONSTANT_0]] : (tensor<2xf32>) -> tensor<2xi32> +@@ -3994,16 +3994,16 @@ + // CHECK: %[[CONSTANT_2:.*]] = stablehlo.constant dense<2147483647> : tensor<2xi32> + // CHECK: %[[AND_0:.*]] = stablehlo.and %[[BITCAST_CONVERT_0]], %[[CONSTANT_2]] : tensor<2xi32> + // CHECK: %[[AND_1:.*]] = stablehlo.and %[[BITCAST_CONVERT_1]], %[[CONSTANT_2]] : tensor<2xi32> +-// CHECK: %[[COMPARE_2:.*]] = stablehlo.compare EQ, %[[ARG0]], %[[ARG1]] : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1> ++// CHECK: %[[COMPARE_2:.*]] = stablehlo.compare EQ, %[[ARG0]], %[[ARG1]] : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1> + // CHECK: %[[CONSTANT_3:.*]] = stablehlo.constant dense<0> : tensor<2xi32> +-// CHECK: %[[COMPARE_3:.*]] = stablehlo.compare EQ, %[[AND_0]], %[[CONSTANT_3]] : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> +-// CHECK: %[[COMPARE_4:.*]] = stablehlo.compare EQ, %[[AND_1]], %[[CONSTANT_3]] : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> ++// CHECK: %[[COMPARE_3:.*]] = stablehlo.compare EQ, %[[AND_0]], %[[CONSTANT_3]] : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> ++// CHECK: %[[COMPARE_4:.*]] = stablehlo.compare EQ, %[[AND_1]], %[[CONSTANT_3]] : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> + // CHECK: %[[AND_2:.*]] = stablehlo.and %[[BITCAST_CONVERT_0]], %[[CONSTANT_1]] : tensor<2xi32> + // CHECK: %[[AND_3:.*]] = stablehlo.and %[[BITCAST_CONVERT_1]], %[[CONSTANT_1]] : tensor<2xi32> + // CHECK: %[[CONSTANT_4:.*]] = stablehlo.constant dense<1> : tensor<2xi32> + // CHECK: %[[OR_1:.*]] = stablehlo.or %[[AND_3]], %[[CONSTANT_4]] : tensor<2xi32> +-// CHECK: %[[COMPARE_5:.*]] = stablehlo.compare NE, %[[AND_2]], %[[AND_3]] : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> +-// CHECK: %[[COMPARE_6:.*]] = stablehlo.compare GT, %[[AND_0]], %[[AND_1]] : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> ++// CHECK: %[[COMPARE_5:.*]] = stablehlo.compare NE, %[[AND_2]], %[[AND_3]] : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> ++// CHECK: %[[COMPARE_6:.*]] = stablehlo.compare GT, %[[AND_0]], %[[AND_1]] : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> + // CHECK: %[[OR_2:.*]] = stablehlo.or %[[COMPARE_6]], %[[COMPARE_5]] : tensor<2xi1> + // CHECK: %[[CONSTANT_5:.*]] = stablehlo.constant dense<-1> : tensor<2xi32> + // CHECK: %[[SELECT_0:.*]] = stablehlo.select %[[OR_2]], %[[CONSTANT_5]], %[[CONSTANT_4]] : tensor<2xi1>, tensor<2xi32> +@@ -4058,7 +4058,7 @@ + // CHECK: %[[IOTA_0:.*]] = stablehlo.iota dim = 1 : tensor<16x16xi32> + // CHECK: %[[VAL_0:.*]]:2 = "stablehlo.sort"(%[[ARG0]], %[[IOTA_0]]) <{dimension = 1 : i64, is_stable = true}> ({ + // CHECK: ^bb0(%[[VAL_1:.*]]: tensor, %[[VAL_2:.*]]: tensor, %[[VAL_3:.*]]: tensor, %[[VAL_4:.*]]: tensor): +-// CHECK: %[[COMPARE_0:.*]] = stablehlo.compare GT, %[[VAL_1]], %[[VAL_2]], TOTALORDER : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_0:.*]] = stablehlo.compare GT, %[[VAL_1]], %[[VAL_2]], TOTALORDER : (tensor, tensor) -> tensor + // CHECK: stablehlo.return %[[COMPARE_0]] : tensor + // CHECK: }) : (tensor<16x16xf32>, tensor<16x16xi32>) -> (tensor<16x16xf32>, tensor<16x16xi32>) + // CHECK: %[[SLICE_0:.*]] = stablehlo.slice %[[VAL_5:.*]]#0 [0:16, 0:8] : (tensor<16x16xf32>) -> tensor<16x8xf32> +@@ -4077,25 +4077,25 @@ + // CHECK: %[[GET_DIMENSION_SIZE_0:.*]] = stablehlo.get_dimension_size %[[ARG0]], dim = 0 : (tensor) -> tensor + // CHECK: %[[CONVERT_0:.*]] = stablehlo.convert %[[GET_DIMENSION_SIZE_0]] : (tensor) -> tensor + // CHECK: %[[RESHAPE_0:.*]] = stablehlo.reshape %[[CONVERT_0]] : (tensor) -> tensor<1xi64> +-// CHECK: %[[C_1:.*]] = stablehlo.constant dense<5> : tensor +-// CHECK: %[[RESHAPE_1:.*]] = stablehlo.reshape %[[C_1]] : (tensor) -> tensor<1xi64> +-// CHECK: %[[GET_DIMENSION_SIZE_2:.*]] = stablehlo.get_dimension_size %[[ARG0]], dim = 2 : (tensor) -> tensor +-// CHECK: %[[CONVERT_2:.*]] = stablehlo.convert %[[GET_DIMENSION_SIZE_2]] : (tensor) -> tensor +-// CHECK: %[[RESHAPE_2:.*]] = stablehlo.reshape %[[CONVERT_2]] : (tensor) -> tensor<1xi64> ++// CHECK: %[[CONSTANT_0:.*]] = stablehlo.constant dense<5> : tensor ++// CHECK: %[[RESHAPE_1:.*]] = stablehlo.reshape %[[CONSTANT_0]] : (tensor) -> tensor<1xi64> ++// CHECK: %[[GET_DIMENSION_SIZE_1:.*]] = stablehlo.get_dimension_size %[[ARG0]], dim = 2 : (tensor) -> tensor ++// CHECK: %[[CONVERT_1:.*]] = stablehlo.convert %[[GET_DIMENSION_SIZE_1]] : (tensor) -> tensor ++// CHECK: %[[RESHAPE_2:.*]] = stablehlo.reshape %[[CONVERT_1]] : (tensor) -> tensor<1xi64> + // CHECK: %[[CONCATENATE_0:.*]] = stablehlo.concatenate %[[RESHAPE_0]], %[[RESHAPE_1]], %[[RESHAPE_2]], dim = 0 : (tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor<3xi64> +-// CHECK: %[[CONSTANT_0:.*]] = stablehlo.constant dense<2> : tensor +-// CHECK: %[[RESHAPE_3:.*]] = stablehlo.reshape %[[CONSTANT_0]] : (tensor) -> tensor<1xi64> ++// CHECK: %[[CONSTANT_1:.*]] = stablehlo.constant dense<2> : tensor ++// CHECK: %[[RESHAPE_3:.*]] = stablehlo.reshape %[[CONSTANT_1]] : (tensor) -> tensor<1xi64> + // CHECK: %[[CONCATENATE_1:.*]] = stablehlo.concatenate %[[RESHAPE_0]], %[[RESHAPE_1]], %[[RESHAPE_3]], dim = 0 : (tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor<3xi64> + // CHECK: %[[DYNAMIC_IOTA_0:.*]] = stablehlo.dynamic_iota %[[CONCATENATE_0]], dim = 2 : (tensor<3xi64>) -> tensor + // CHECK: %[[VAL_0:.*]]:2 = "stablehlo.sort"(%[[ARG0]], %[[DYNAMIC_IOTA_0]]) <{dimension = 2 : i64, is_stable = true}> ({ + // CHECK: ^bb0(%[[VAL_1:.*]]: tensor, %[[VAL_2:.*]]: tensor, %[[VAL_3:.*]]: tensor, %[[VAL_4:.*]]: tensor): +-// CHECK: %[[COMPARE_0:.*]] = stablehlo.compare GT, %[[VAL_1]], %[[VAL_2]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_0:.*]] = stablehlo.compare GT, %[[VAL_1]], %[[VAL_2]] : (tensor, tensor) -> tensor + // CHECK: stablehlo.return %[[COMPARE_0]] : tensor + // CHECK: }) : (tensor, tensor) -> (tensor, tensor) +-// CHECK: %[[CONSTANT_1:.*]] = stablehlo.constant dense<0> : tensor<3xi64> +-// CHECK: %[[CONSTANT_2:.*]] = stablehlo.constant dense<1> : tensor<3xi64> +-// CHECK: %[[REAL_DYNAMIC_SLICE_0:.*]] = stablehlo.real_dynamic_slice %[[VAL_5:.*]]#0, %[[CONSTANT_1]], %[[CONCATENATE_1]], %[[CONSTANT_2]] : (tensor, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor +-// CHECK: %[[REAL_DYNAMIC_SLICE_1:.*]] = stablehlo.real_dynamic_slice %[[VAL_5]]#1, %[[CONSTANT_1]], %[[CONCATENATE_1]], %[[CONSTANT_2]] : (tensor, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor ++// CHECK: %[[CONSTANT_2:.*]] = stablehlo.constant dense<0> : tensor<3xi64> ++// CHECK: %[[CONSTANT_3:.*]] = stablehlo.constant dense<1> : tensor<3xi64> ++// CHECK: %[[REAL_DYNAMIC_SLICE_0:.*]] = stablehlo.real_dynamic_slice %[[VAL_5:.*]]#0, %[[CONSTANT_2]], %[[CONCATENATE_1]], %[[CONSTANT_3]] : (tensor, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor ++// CHECK: %[[REAL_DYNAMIC_SLICE_1:.*]] = stablehlo.real_dynamic_slice %[[VAL_5]]#1, %[[CONSTANT_2]], %[[CONCATENATE_1]], %[[CONSTANT_3]] : (tensor, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor + // CHECK: return %[[REAL_DYNAMIC_SLICE_0]], %[[REAL_DYNAMIC_SLICE_1]] : tensor, tensor + // CHECK: } + func.func @dyn_top_k(%arg0: tensor) -> (tensor, tensor) { +@@ -4231,7 +4231,7 @@ + // CHECK: %[[MULTIPLY_27:.*]] = stablehlo.multiply %[[SUBTRACT_27]], %[[CONSTANT_35]] : tensor<16x16xf32> + // CHECK: %[[SQRT_0:.*]] = stablehlo.sqrt %[[ABS_0]] : tensor<16x16xf32> + // CHECK: %[[DIVIDE_1:.*]] = stablehlo.divide %[[MULTIPLY_27]], %[[SQRT_0]] : tensor<16x16xf32> +-// CHECK: %[[COMPARE_0:.*]] = stablehlo.compare LE, %[[ABS_0]], %[[CONSTANT_3]] : (tensor<16x16xf32>, tensor<16x16xf32>) -> tensor<16x16xi1> ++// CHECK: %[[COMPARE_0:.*]] = stablehlo.compare LE, %[[ABS_0]], %[[CONSTANT_3]] : (tensor<16x16xf32>, tensor<16x16xf32>) -> tensor<16x16xi1> + // CHECK: %[[SELECT_0:.*]] = stablehlo.select %[[COMPARE_0]], %[[MULTIPLY_19]], %[[DIVIDE_1]] : tensor<16x16xi1>, tensor<16x16xf32> + // CHECK: %[[SIGN_0:.*]] = stablehlo.sign %[[CONVERT_0]] : tensor<16x16xf32> + // CHECK: %[[MULTIPLY_28:.*]] = stablehlo.multiply %[[SIGN_0]], %[[SELECT_0]] : tensor<16x16xf32> +@@ -4367,7 +4367,7 @@ + // CHECK: %[[MULTIPLY_27:.*]] = stablehlo.multiply %[[SUBTRACT_27]], %[[CONSTANT_35]] : tensor<16x16xf32> + // CHECK: %[[SQRT_0:.*]] = stablehlo.sqrt %[[ABS_0]] : tensor<16x16xf32> + // CHECK: %[[DIVIDE_1:.*]] = stablehlo.divide %[[MULTIPLY_27]], %[[SQRT_0]] : tensor<16x16xf32> +-// CHECK: %[[COMPARE_0:.*]] = stablehlo.compare LE, %[[ABS_0]], %[[CONSTANT_3]] : (tensor<16x16xf32>, tensor<16x16xf32>) -> tensor<16x16xi1> ++// CHECK: %[[COMPARE_0:.*]] = stablehlo.compare LE, %[[ABS_0]], %[[CONSTANT_3]] : (tensor<16x16xf32>, tensor<16x16xf32>) -> tensor<16x16xi1> + // CHECK: %[[SELECT_0:.*]] = stablehlo.select %[[COMPARE_0]], %[[MULTIPLY_19]], %[[DIVIDE_1]] : tensor<16x16xi1>, tensor<16x16xf32> + // CHECK: %[[SIGN_0:.*]] = stablehlo.sign %[[ARG0]] : tensor<16x16xf32> + // CHECK: %[[MULTIPLY_28:.*]] = stablehlo.multiply %[[SIGN_0]], %[[SELECT_0]] : tensor<16x16xf32> +@@ -4622,7 +4622,7 @@ + // CHECK: %[[MULTIPLY_57:.*]] = stablehlo.multiply %[[SUBTRACT_57]], %[[CONSTANT_65]] : tensor<16x16xf64> + // CHECK: %[[SQRT_0:.*]] = stablehlo.sqrt %[[ABS_0]] : tensor<16x16xf64> + // CHECK: %[[DIVIDE_1:.*]] = stablehlo.divide %[[MULTIPLY_57]], %[[SQRT_0]] : tensor<16x16xf64> +-// CHECK: %[[COMPARE_0:.*]] = stablehlo.compare LE, %[[ABS_0]], %[[CONSTANT_3]] : (tensor<16x16xf64>, tensor<16x16xf64>) -> tensor<16x16xi1> ++// CHECK: %[[COMPARE_0:.*]] = stablehlo.compare LE, %[[ABS_0]], %[[CONSTANT_3]] : (tensor<16x16xf64>, tensor<16x16xf64>) -> tensor<16x16xi1> + // CHECK: %[[SELECT_0:.*]] = stablehlo.select %[[COMPARE_0]], %[[MULTIPLY_31]], %[[DIVIDE_1]] : tensor<16x16xi1>, tensor<16x16xf64> + // CHECK: %[[SIGN_0:.*]] = stablehlo.sign %[[ARG0]] : tensor<16x16xf64> + // CHECK: %[[MULTIPLY_58:.*]] = stablehlo.multiply %[[SIGN_0]], %[[SELECT_0]] : tensor<16x16xf64> +@@ -4642,7 +4642,7 @@ + // CHECK: %[[LOG_PLUS_ONE_0:.*]] = stablehlo.log_plus_one %[[MULTIPLY_0]] : tensor<16x16xf32> + // CHECK: %[[NEGATE_1:.*]] = stablehlo.negate %[[LOG_PLUS_ONE_0]] : tensor<16x16xf32> + // CHECK: %[[CONSTANT_0:.*]] = stablehlo.constant dense<5.000000e+00> : tensor<16x16xf32> +-// CHECK: %[[COMPARE_0:.*]] = stablehlo.compare LT, %[[NEGATE_1]], %[[CONSTANT_0]] : (tensor<16x16xf32>, tensor<16x16xf32>) -> tensor<16x16xi1> ++// CHECK: %[[COMPARE_0:.*]] = stablehlo.compare LT, %[[NEGATE_1]], %[[CONSTANT_0]] : (tensor<16x16xf32>, tensor<16x16xf32>) -> tensor<16x16xi1> + // CHECK: %[[CONSTANT_1:.*]] = stablehlo.constant dense<2.500000e+00> : tensor<16x16xf32> + // CHECK: %[[SUBTRACT_0:.*]] = stablehlo.subtract %[[NEGATE_1]], %[[CONSTANT_1]] : tensor<16x16xf32> + // CHECK: %[[SQRT_0:.*]] = stablehlo.sqrt %[[NEGATE_1]] : tensor<16x16xf32> +@@ -4695,7 +4695,7 @@ + // CHECK: %[[MULTIPLY_9:.*]] = stablehlo.multiply %[[ADD_7]], %[[ARG0]] : tensor<16x16xf32> + // CHECK: %[[ABS_0:.*]] = stablehlo.abs %[[ARG0]] : tensor<16x16xf32> + // CHECK: %[[CONSTANT_21:.*]] = stablehlo.constant dense<1.000000e+00> : tensor<16x16xf32> +-// CHECK: %[[COMPARE_1:.*]] = stablehlo.compare EQ, %[[ABS_0]], %[[CONSTANT_21]] : (tensor<16x16xf32>, tensor<16x16xf32>) -> tensor<16x16xi1> ++// CHECK: %[[COMPARE_1:.*]] = stablehlo.compare EQ, %[[ABS_0]], %[[CONSTANT_21]] : (tensor<16x16xf32>, tensor<16x16xf32>) -> tensor<16x16xi1> + // CHECK: %[[CONSTANT_22:.*]] = stablehlo.constant dense<0x7F800000> : tensor<16x16xf32> + // CHECK: %[[MULTIPLY_10:.*]] = stablehlo.multiply %[[ARG0]], %[[CONSTANT_22]] : tensor<16x16xf32> + // CHECK: %[[SELECT_10:.*]] = stablehlo.select %[[COMPARE_1]], %[[MULTIPLY_10]], %[[MULTIPLY_9]] : tensor<16x16xi1>, tensor<16x16xf32> +@@ -4715,9 +4715,9 @@ + // CHECK: %[[LOG_PLUS_ONE_0:.*]] = stablehlo.log_plus_one %[[MULTIPLY_0]] : tensor<16x16xf64> + // CHECK: %[[NEGATE_1:.*]] = stablehlo.negate %[[LOG_PLUS_ONE_0]] : tensor<16x16xf64> + // CHECK: %[[CONSTANT_0:.*]] = stablehlo.constant dense<6.250000e+00> : tensor<16x16xf64> +-// CHECK: %[[COMPARE_0:.*]] = stablehlo.compare LT, %[[NEGATE_1]], %[[CONSTANT_0]] : (tensor<16x16xf64>, tensor<16x16xf64>) -> tensor<16x16xi1> ++// CHECK: %[[COMPARE_0:.*]] = stablehlo.compare LT, %[[NEGATE_1]], %[[CONSTANT_0]] : (tensor<16x16xf64>, tensor<16x16xf64>) -> tensor<16x16xi1> + // CHECK: %[[CONSTANT_1:.*]] = stablehlo.constant dense<1.600000e+01> : tensor<16x16xf64> +-// CHECK: %[[COMPARE_1:.*]] = stablehlo.compare LT, %[[NEGATE_1]], %[[CONSTANT_1]] : (tensor<16x16xf64>, tensor<16x16xf64>) -> tensor<16x16xi1> ++// CHECK: %[[COMPARE_1:.*]] = stablehlo.compare LT, %[[NEGATE_1]], %[[CONSTANT_1]] : (tensor<16x16xf64>, tensor<16x16xf64>) -> tensor<16x16xi1> + // CHECK: %[[SQRT_0:.*]] = stablehlo.sqrt %[[NEGATE_1]] : tensor<16x16xf64> + // CHECK: %[[CONSTANT_2:.*]] = stablehlo.constant dense<3.125000e+00> : tensor<16x16xf64> + // CHECK: %[[SUBTRACT_0:.*]] = stablehlo.subtract %[[NEGATE_1]], %[[CONSTANT_2]] : tensor<16x16xf64> +@@ -4874,7 +4874,7 @@ + // CHECK: %[[MULTIPLY_23:.*]] = stablehlo.multiply %[[SELECT_43]], %[[ARG0]] : tensor<16x16xf64> + // CHECK: %[[ABS_0:.*]] = stablehlo.abs %[[ARG0]] : tensor<16x16xf64> + // CHECK: %[[CONSTANT_64:.*]] = stablehlo.constant dense<1.000000e+00> : tensor<16x16xf64> +-// CHECK: %[[COMPARE_2:.*]] = stablehlo.compare EQ, %[[ABS_0]], %[[CONSTANT_64]] : (tensor<16x16xf64>, tensor<16x16xf64>) -> tensor<16x16xi1> ++// CHECK: %[[COMPARE_2:.*]] = stablehlo.compare EQ, %[[ABS_0]], %[[CONSTANT_64]] : (tensor<16x16xf64>, tensor<16x16xf64>) -> tensor<16x16xi1> + // CHECK: %[[CONSTANT_65:.*]] = stablehlo.constant dense<0x7FF0000000000000> : tensor<16x16xf64> + // CHECK: %[[MULTIPLY_24:.*]] = stablehlo.multiply %[[ARG0]], %[[CONSTANT_65]] : tensor<16x16xf64> + // CHECK: %[[SELECT_44:.*]] = stablehlo.select %[[COMPARE_2]], %[[MULTIPLY_24]], %[[MULTIPLY_23]] : tensor<16x16xi1>, tensor<16x16xf64> +@@ -4898,7 +4898,7 @@ + // CHECK: %[[BROADCAST_IN_DIM_0:.*]] = stablehlo.broadcast_in_dim %[[CONSTANT_0]], dims = [] : (tensor) -> tensor<16x16xf32> + // CHECK: %[[GET_DIMENSION_SIZE_0:.*]] = stablehlo.get_dimension_size %[[ARG0]], dim = 0 : (tensor>) -> tensor + // CHECK: %[[SET_DIMENSION_SIZE_0:.*]] = stablehlo.set_dimension_size %[[BROADCAST_IN_DIM_0]], %[[GET_DIMENSION_SIZE_0]], dim = 0 : (tensor<16x16xf32>, tensor) -> tensor> +-// CHECK: %[[COMPARE_0:.*]] = stablehlo.compare LT, %[[NEGATE_1]], %[[SET_DIMENSION_SIZE_0]] : (tensor>, tensor>) -> tensor> ++// CHECK: %[[COMPARE_0:.*]] = stablehlo.compare LT, %[[NEGATE_1]], %[[SET_DIMENSION_SIZE_0]] : (tensor>, tensor>) -> tensor> + // CHECK: %[[CONSTANT_1:.*]] = stablehlo.constant dense<2.500000e+00> : tensor + // CHECK: %[[BROADCAST_IN_DIM_1:.*]] = stablehlo.broadcast_in_dim %[[CONSTANT_1]], dims = [] : (tensor) -> tensor<16x16xf32> + // CHECK: %[[GET_DIMENSION_SIZE_1:.*]] = stablehlo.get_dimension_size %[[ARG0]], dim = 0 : (tensor>) -> tensor +@@ -5014,7 +5014,7 @@ + // CHECK: %[[BROADCAST_IN_DIM_21:.*]] = stablehlo.broadcast_in_dim %[[CONSTANT_21]], dims = [] : (tensor) -> tensor<16x16xf32> + // CHECK: %[[GET_DIMENSION_SIZE_21:.*]] = stablehlo.get_dimension_size %[[ARG0]], dim = 0 : (tensor>) -> tensor + // CHECK: %[[SET_DIMENSION_SIZE_21:.*]] = stablehlo.set_dimension_size %[[BROADCAST_IN_DIM_21]], %[[GET_DIMENSION_SIZE_21]], dim = 0 : (tensor<16x16xf32>, tensor) -> tensor> +-// CHECK: %[[COMPARE_1:.*]] = stablehlo.compare EQ, %[[ABS_0]], %[[SET_DIMENSION_SIZE_21]] : (tensor>, tensor>) -> tensor> ++// CHECK: %[[COMPARE_1:.*]] = stablehlo.compare EQ, %[[ABS_0]], %[[SET_DIMENSION_SIZE_21]] : (tensor>, tensor>) -> tensor> + // CHECK: %[[CONSTANT_22:.*]] = stablehlo.constant dense<0x7F800000> : tensor + // CHECK: %[[BROADCAST_IN_DIM_22:.*]] = stablehlo.broadcast_in_dim %[[CONSTANT_22]], dims = [] : (tensor) -> tensor<16x16xf32> + // CHECK: %[[GET_DIMENSION_SIZE_22:.*]] = stablehlo.get_dimension_size %[[ARG0]], dim = 0 : (tensor>) -> tensor +@@ -5037,7 +5037,7 @@ + // CHECK: %[[ABS_0:.*]] = stablehlo.abs %[[REAL_0]] : tensor + // CHECK: %[[IMAG_0:.*]] = stablehlo.imag %[[ARG0]] : (tensor>) -> tensor + // CHECK: %[[ABS_1:.*]] = stablehlo.abs %[[IMAG_0]] : tensor +-// CHECK: %[[COMPARE_0:.*]] = stablehlo.compare EQ, %[[ABS_0]], %[[ABS_1]] : (tensor, tensor) -> tensor ++// CHECK: %[[COMPARE_0:.*]] = stablehlo.compare EQ, %[[ABS_0]], %[[ABS_1]] : (tensor, tensor) -> tensor + // CHECK: %[[AND_0:.*]] = stablehlo.and %[[IS_FINITE_0]], %[[COMPARE_0]] : tensor + // CHECK: %[[CONSTANT_0:.*]] = stablehlo.constant dense<0.000000e+00> : tensor + // CHECK: %[[SUBTRACT_0:.*]] = stablehlo.subtract %[[REAL_0]], %[[IMAG_0]] : tensor +@@ -5073,8 +5073,8 @@ + // CHECK-SAME: %[[ARG0:.*]]: tensor<2x11x5xf32>, + // CHECK-SAME: %[[ARG1:.*]]: tensor<3x2x5x7xf32>, + // CHECK-SAME: %[[ARG2:.*]]: tensor<3xi64>) -> tensor<2x11x7xf32> { +-// CHECK-DAG: %[[IOTA_0:.*]] = stablehlo.iota dim = 1 : tensor<1x11x1xi64> +-// CHECK-DAG: %[[CONSTANT_0:.*]] = stablehlo.constant dense<0> : tensor ++// CHECK: %[[CONSTANT_0:.*]] = stablehlo.constant dense<0> : tensor ++// CHECK: %[[IOTA_0:.*]] = stablehlo.iota dim = 1 : tensor<1x11x1xi64> + // CHECK: %[[CONSTANT_1:.*]] = stablehlo.constant dense<0.000000e+00> : tensor<2x11x5xf32> + // CHECK: %[[SLICE_0:.*]] = stablehlo.slice %[[ARG2]] [0:1] : (tensor<3xi64>) -> tensor<1xi64> + // CHECK: %[[RESHAPE_0:.*]] = stablehlo.reshape %[[SLICE_0]] : (tensor<1xi64>) -> tensor +@@ -5082,8 +5082,8 @@ + // CHECK: %[[BROADCAST_IN_DIM_0:.*]] = stablehlo.broadcast_in_dim %[[CONSTANT_0]], dims = [] : (tensor) -> tensor<2x11x5xi64> + // CHECK: %[[BROADCAST_IN_DIM_1:.*]] = stablehlo.broadcast_in_dim %[[ADD_0]], dims = [] : (tensor) -> tensor<2x11x5xi64> + // CHECK: %[[BROADCAST_IN_DIM_2:.*]] = stablehlo.broadcast_in_dim %[[IOTA_0]], dims = [0, 1, 2] : (tensor<1x11x1xi64>) -> tensor<2x11x5xi64> +-// CHECK: %[[COMPARE_0:.*]] = stablehlo.compare GE, %[[BROADCAST_IN_DIM_2]], %[[BROADCAST_IN_DIM_0]] : (tensor<2x11x5xi64>, tensor<2x11x5xi64>) -> tensor<2x11x5xi1> +-// CHECK: %[[COMPARE_1:.*]] = stablehlo.compare LT, %[[BROADCAST_IN_DIM_2]], %[[BROADCAST_IN_DIM_1]] : (tensor<2x11x5xi64>, tensor<2x11x5xi64>) -> tensor<2x11x5xi1> ++// CHECK: %[[COMPARE_0:.*]] = stablehlo.compare GE, %[[BROADCAST_IN_DIM_2]], %[[BROADCAST_IN_DIM_0]] : (tensor<2x11x5xi64>, tensor<2x11x5xi64>) -> tensor<2x11x5xi1> ++// CHECK: %[[COMPARE_1:.*]] = stablehlo.compare LT, %[[BROADCAST_IN_DIM_2]], %[[BROADCAST_IN_DIM_1]] : (tensor<2x11x5xi64>, tensor<2x11x5xi64>) -> tensor<2x11x5xi1> + // CHECK: %[[AND_0:.*]] = stablehlo.and %[[COMPARE_0]], %[[COMPARE_1]] : tensor<2x11x5xi1> + // CHECK: %[[SELECT_0:.*]] = stablehlo.select %[[AND_0]], %[[ARG0]], %[[CONSTANT_1]] : tensor<2x11x5xi1>, tensor<2x11x5xf32> + // CHECK: %[[SLICE_1:.*]] = stablehlo.slice %[[ARG1]] [0:1, 0:2, 0:5, 0:7] : (tensor<3x2x5x7xf32>) -> tensor<1x2x5x7xf32> +@@ -5095,8 +5095,8 @@ + // CHECK: %[[BROADCAST_IN_DIM_3:.*]] = stablehlo.broadcast_in_dim %[[ADD_0]], dims = [] : (tensor) -> tensor<2x11x5xi64> + // CHECK: %[[BROADCAST_IN_DIM_4:.*]] = stablehlo.broadcast_in_dim %[[ADD_1]], dims = [] : (tensor) -> tensor<2x11x5xi64> + // CHECK: %[[BROADCAST_IN_DIM_5:.*]] = stablehlo.broadcast_in_dim %[[IOTA_0]], dims = [0, 1, 2] : (tensor<1x11x1xi64>) -> tensor<2x11x5xi64> +-// CHECK: %[[COMPARE_2:.*]] = stablehlo.compare GE, %[[BROADCAST_IN_DIM_5]], %[[BROADCAST_IN_DIM_3]] : (tensor<2x11x5xi64>, tensor<2x11x5xi64>) -> tensor<2x11x5xi1> +-// CHECK: %[[COMPARE_3:.*]] = stablehlo.compare LT, %[[BROADCAST_IN_DIM_5]], %[[BROADCAST_IN_DIM_4]] : (tensor<2x11x5xi64>, tensor<2x11x5xi64>) -> tensor<2x11x5xi1> ++// CHECK: %[[COMPARE_2:.*]] = stablehlo.compare GE, %[[BROADCAST_IN_DIM_5]], %[[BROADCAST_IN_DIM_3]] : (tensor<2x11x5xi64>, tensor<2x11x5xi64>) -> tensor<2x11x5xi1> ++// CHECK: %[[COMPARE_3:.*]] = stablehlo.compare LT, %[[BROADCAST_IN_DIM_5]], %[[BROADCAST_IN_DIM_4]] : (tensor<2x11x5xi64>, tensor<2x11x5xi64>) -> tensor<2x11x5xi1> + // CHECK: %[[AND_1:.*]] = stablehlo.and %[[COMPARE_2]], %[[COMPARE_3]] : tensor<2x11x5xi1> + // CHECK: %[[SELECT_1:.*]] = stablehlo.select %[[AND_1]], %[[ARG0]], %[[CONSTANT_1]] : tensor<2x11x5xi1>, tensor<2x11x5xf32> + // CHECK: %[[SLICE_3:.*]] = stablehlo.slice %[[ARG1]] [1:2, 0:2, 0:5, 0:7] : (tensor<3x2x5x7xf32>) -> tensor<1x2x5x7xf32> +@@ -5109,8 +5109,8 @@ + // CHECK: %[[BROADCAST_IN_DIM_6:.*]] = stablehlo.broadcast_in_dim %[[ADD_1]], dims = [] : (tensor) -> tensor<2x11x5xi64> + // CHECK: %[[BROADCAST_IN_DIM_7:.*]] = stablehlo.broadcast_in_dim %[[ADD_3]], dims = [] : (tensor) -> tensor<2x11x5xi64> + // CHECK: %[[BROADCAST_IN_DIM_8:.*]] = stablehlo.broadcast_in_dim %[[IOTA_0]], dims = [0, 1, 2] : (tensor<1x11x1xi64>) -> tensor<2x11x5xi64> +-// CHECK: %[[COMPARE_4:.*]] = stablehlo.compare GE, %[[BROADCAST_IN_DIM_8]], %[[BROADCAST_IN_DIM_6]] : (tensor<2x11x5xi64>, tensor<2x11x5xi64>) -> tensor<2x11x5xi1> +-// CHECK: %[[COMPARE_5:.*]] = stablehlo.compare LT, %[[BROADCAST_IN_DIM_8]], %[[BROADCAST_IN_DIM_7]] : (tensor<2x11x5xi64>, tensor<2x11x5xi64>) -> tensor<2x11x5xi1> ++// CHECK: %[[COMPARE_4:.*]] = stablehlo.compare GE, %[[BROADCAST_IN_DIM_8]], %[[BROADCAST_IN_DIM_6]] : (tensor<2x11x5xi64>, tensor<2x11x5xi64>) -> tensor<2x11x5xi1> ++// CHECK: %[[COMPARE_5:.*]] = stablehlo.compare LT, %[[BROADCAST_IN_DIM_8]], %[[BROADCAST_IN_DIM_7]] : (tensor<2x11x5xi64>, tensor<2x11x5xi64>) -> tensor<2x11x5xi1> + // CHECK: %[[AND_2:.*]] = stablehlo.and %[[COMPARE_4]], %[[COMPARE_5]] : tensor<2x11x5xi1> + // CHECK: %[[SELECT_2:.*]] = stablehlo.select %[[AND_2]], %[[ARG0]], %[[CONSTANT_1]] : tensor<2x11x5xi1>, tensor<2x11x5xf32> + // CHECK: %[[SLICE_5:.*]] = stablehlo.slice %[[ARG1]] [2:3, 0:2, 0:5, 0:7] : (tensor<3x2x5x7xf32>) -> tensor<1x2x5x7xf32> +@@ -5149,8 +5149,8 @@ + // CHECK: %[[BROADCAST_IN_DIM_0:.*]] = stablehlo.broadcast_in_dim %[[CONSTANT_0]], dims = [0] : (tensor<2xi64>) -> tensor<2x3x5xi64> + // CHECK: %[[BROADCAST_IN_DIM_1:.*]] = stablehlo.broadcast_in_dim %[[ADD_0]], dims = [0] : (tensor<2xi64>) -> tensor<2x3x5xi64> + // CHECK: %[[BROADCAST_IN_DIM_2:.*]] = stablehlo.broadcast_in_dim %[[IOTA_0]], dims = [0, 1, 2] : (tensor<1x1x5xi64>) -> tensor<2x3x5xi64> +-// CHECK: %[[COMPARE_0:.*]] = stablehlo.compare GE, %[[BROADCAST_IN_DIM_2]], %[[BROADCAST_IN_DIM_0]] : (tensor<2x3x5xi64>, tensor<2x3x5xi64>) -> tensor<2x3x5xi1> +-// CHECK: %[[COMPARE_1:.*]] = stablehlo.compare LT, %[[BROADCAST_IN_DIM_2]], %[[BROADCAST_IN_DIM_1]] : (tensor<2x3x5xi64>, tensor<2x3x5xi64>) -> tensor<2x3x5xi1> ++// CHECK: %[[COMPARE_0:.*]] = stablehlo.compare GE, %[[BROADCAST_IN_DIM_2]], %[[BROADCAST_IN_DIM_0]] : (tensor<2x3x5xi64>, tensor<2x3x5xi64>) -> tensor<2x3x5xi1> ++// CHECK: %[[COMPARE_1:.*]] = stablehlo.compare LT, %[[BROADCAST_IN_DIM_2]], %[[BROADCAST_IN_DIM_1]] : (tensor<2x3x5xi64>, tensor<2x3x5xi64>) -> tensor<2x3x5xi1> + // CHECK: %[[AND_0:.*]] = stablehlo.and %[[COMPARE_0]], %[[COMPARE_1]] : tensor<2x3x5xi1> + // CHECK: %[[SELECT_0:.*]] = stablehlo.select %[[AND_0]], %[[ARG0]], %[[CONSTANT_1]] : tensor<2x3x5xi1>, tensor<2x3x5xf32> + // CHECK: %[[DOT_GENERAL_0:.*]] = stablehlo.dot_general %[[SELECT_0]], %[[ARG1]], batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<2x3x5xf32>, tensor<2x5x7xf32>) -> tensor<2x3x7xf32> +@@ -5160,8 +5160,8 @@ + // CHECK: %[[BROADCAST_IN_DIM_3:.*]] = stablehlo.broadcast_in_dim %[[ADD_0]], dims = [0] : (tensor<2xi64>) -> tensor<2x3x5xi64> + // CHECK: %[[BROADCAST_IN_DIM_4:.*]] = stablehlo.broadcast_in_dim %[[ADD_1]], dims = [0] : (tensor<2xi64>) -> tensor<2x3x5xi64> + // CHECK: %[[BROADCAST_IN_DIM_5:.*]] = stablehlo.broadcast_in_dim %[[IOTA_0]], dims = [0, 1, 2] : (tensor<1x1x5xi64>) -> tensor<2x3x5xi64> +-// CHECK: %[[COMPARE_2:.*]] = stablehlo.compare GE, %[[BROADCAST_IN_DIM_5]], %[[BROADCAST_IN_DIM_3]] : (tensor<2x3x5xi64>, tensor<2x3x5xi64>) -> tensor<2x3x5xi1> +-// CHECK: %[[COMPARE_3:.*]] = stablehlo.compare LT, %[[BROADCAST_IN_DIM_5]], %[[BROADCAST_IN_DIM_4]] : (tensor<2x3x5xi64>, tensor<2x3x5xi64>) -> tensor<2x3x5xi1> ++// CHECK: %[[COMPARE_2:.*]] = stablehlo.compare GE, %[[BROADCAST_IN_DIM_5]], %[[BROADCAST_IN_DIM_3]] : (tensor<2x3x5xi64>, tensor<2x3x5xi64>) -> tensor<2x3x5xi1> ++// CHECK: %[[COMPARE_3:.*]] = stablehlo.compare LT, %[[BROADCAST_IN_DIM_5]], %[[BROADCAST_IN_DIM_4]] : (tensor<2x3x5xi64>, tensor<2x3x5xi64>) -> tensor<2x3x5xi1> + // CHECK: %[[AND_1:.*]] = stablehlo.and %[[COMPARE_2]], %[[COMPARE_3]] : tensor<2x3x5xi1> + // CHECK: %[[SELECT_1:.*]] = stablehlo.select %[[AND_1]], %[[ARG0]], %[[CONSTANT_1]] : tensor<2x3x5xi1>, tensor<2x3x5xf32> + // CHECK: %[[DOT_GENERAL_1:.*]] = stablehlo.dot_general %[[SELECT_1]], %[[ARG1]], batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<2x3x5xf32>, tensor<2x5x7xf32>) -> tensor<2x3x7xf32> +@@ -5171,8 +5171,8 @@ + // CHECK: %[[BROADCAST_IN_DIM_6:.*]] = stablehlo.broadcast_in_dim %[[ADD_1]], dims = [0] : (tensor<2xi64>) -> tensor<2x3x5xi64> + // CHECK: %[[BROADCAST_IN_DIM_7:.*]] = stablehlo.broadcast_in_dim %[[ADD_2]], dims = [0] : (tensor<2xi64>) -> tensor<2x3x5xi64> + // CHECK: %[[BROADCAST_IN_DIM_8:.*]] = stablehlo.broadcast_in_dim %[[IOTA_0]], dims = [0, 1, 2] : (tensor<1x1x5xi64>) -> tensor<2x3x5xi64> +-// CHECK: %[[COMPARE_4:.*]] = stablehlo.compare GE, %[[BROADCAST_IN_DIM_8]], %[[BROADCAST_IN_DIM_6]] : (tensor<2x3x5xi64>, tensor<2x3x5xi64>) -> tensor<2x3x5xi1> +-// CHECK: %[[COMPARE_5:.*]] = stablehlo.compare LT, %[[BROADCAST_IN_DIM_8]], %[[BROADCAST_IN_DIM_7]] : (tensor<2x3x5xi64>, tensor<2x3x5xi64>) -> tensor<2x3x5xi1> ++// CHECK: %[[COMPARE_4:.*]] = stablehlo.compare GE, %[[BROADCAST_IN_DIM_8]], %[[BROADCAST_IN_DIM_6]] : (tensor<2x3x5xi64>, tensor<2x3x5xi64>) -> tensor<2x3x5xi1> ++// CHECK: %[[COMPARE_5:.*]] = stablehlo.compare LT, %[[BROADCAST_IN_DIM_8]], %[[BROADCAST_IN_DIM_7]] : (tensor<2x3x5xi64>, tensor<2x3x5xi64>) -> tensor<2x3x5xi1> + // CHECK: %[[AND_2:.*]] = stablehlo.and %[[COMPARE_4]], %[[COMPARE_5]] : tensor<2x3x5xi1> + // CHECK: %[[SELECT_2:.*]] = stablehlo.select %[[AND_2]], %[[ARG0]], %[[CONSTANT_1]] : tensor<2x3x5xi1>, tensor<2x3x5xf32> + // CHECK: %[[DOT_GENERAL_2:.*]] = stablehlo.dot_general %[[SELECT_2]], %[[ARG1]], batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<2x3x5xf32>, tensor<2x5x7xf32>) -> tensor<2x3x7xf32> +@@ -5182,8 +5182,8 @@ + // CHECK: %[[BROADCAST_IN_DIM_9:.*]] = stablehlo.broadcast_in_dim %[[ADD_2]], dims = [0] : (tensor<2xi64>) -> tensor<2x3x5xi64> + // CHECK: %[[BROADCAST_IN_DIM_10:.*]] = stablehlo.broadcast_in_dim %[[ADD_3]], dims = [0] : (tensor<2xi64>) -> tensor<2x3x5xi64> + // CHECK: %[[BROADCAST_IN_DIM_11:.*]] = stablehlo.broadcast_in_dim %[[IOTA_0]], dims = [0, 1, 2] : (tensor<1x1x5xi64>) -> tensor<2x3x5xi64> +-// CHECK: %[[COMPARE_6:.*]] = stablehlo.compare GE, %[[BROADCAST_IN_DIM_11]], %[[BROADCAST_IN_DIM_9]] : (tensor<2x3x5xi64>, tensor<2x3x5xi64>) -> tensor<2x3x5xi1> +-// CHECK: %[[COMPARE_7:.*]] = stablehlo.compare LT, %[[BROADCAST_IN_DIM_11]], %[[BROADCAST_IN_DIM_10]] : (tensor<2x3x5xi64>, tensor<2x3x5xi64>) -> tensor<2x3x5xi1> ++// CHECK: %[[COMPARE_6:.*]] = stablehlo.compare GE, %[[BROADCAST_IN_DIM_11]], %[[BROADCAST_IN_DIM_9]] : (tensor<2x3x5xi64>, tensor<2x3x5xi64>) -> tensor<2x3x5xi1> ++// CHECK: %[[COMPARE_7:.*]] = stablehlo.compare LT, %[[BROADCAST_IN_DIM_11]], %[[BROADCAST_IN_DIM_10]] : (tensor<2x3x5xi64>, tensor<2x3x5xi64>) -> tensor<2x3x5xi1> + // CHECK: %[[AND_3:.*]] = stablehlo.and %[[COMPARE_6]], %[[COMPARE_7]] : tensor<2x3x5xi1> + // CHECK: %[[SELECT_3:.*]] = stablehlo.select %[[AND_3]], %[[ARG0]], %[[CONSTANT_1]] : tensor<2x3x5xi1>, tensor<2x3x5xf32> + // CHECK: %[[DOT_GENERAL_3:.*]] = stablehlo.dot_general %[[SELECT_3]], %[[ARG1]], batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<2x3x5xf32>, tensor<2x5x7xf32>) -> tensor<2x3x7xf32> +@@ -5234,37 +5234,37 @@ + // ----- + + // CHECK-LABEL: func.func @scan( +-// CHECK-SAME: %[[ARG0:.*]]: tensor<2xi32>, +-// CHECK-SAME: %[[ARG1:.*]]: tensor) -> (tensor<2xi32>, tensor) { +-// CHECK-DAG: %[[GET_DIMENSION_SIZE:.*]] = stablehlo.get_dimension_size %[[ARG0]], dim = 0 : (tensor<2xi32>) -> tensor +-// CHECK-DAG: %[[CONVERT:.*]] = stablehlo.convert %[[GET_DIMENSION_SIZE]] : (tensor) -> tensor +-// CHECK-DAG: %[[C0_I64:.*]] = stablehlo.constant dense<0> : tensor +-// CHECK-DAG: %[[C0_I32:.*]] = stablehlo.constant dense<0> : tensor +-// CHECK-DAG: %[[BROADCAST:.*]] = stablehlo.broadcast %[[C0_I32]], sizes = [2] : (tensor) -> tensor<2xi32> +-// CHECK: %[[WHILE:.*]]:3 = stablehlo.while(%[[ITER:.*]] = %[[C0_I64]], %[[ACC:.*]] = %[[ARG1]], %[[OUT:.*]] = %[[BROADCAST]]) : tensor, tensor, tensor<2xi32> +-// CHECK: cond { +-// CHECK: %[[CMP:.*]] = stablehlo.compare LT, %[[ITER]], %[[CONVERT]] : (tensor, tensor) -> tensor +-// CHECK: stablehlo.return %[[CMP]] : tensor +-// CHECK: } do { +-// CHECK-DAG: %[[C0_I64_2:.*]] = stablehlo.constant dense<0> : tensor +-// CHECK-DAG: %[[RESHAPE_ITER:.*]] = stablehlo.reshape %[[ITER]] : (tensor) -> tensor<1xi64> +-// CHECK-DAG: %[[CONCAT_START:.*]] = stablehlo.concatenate %[[RESHAPE_ITER]], dim = 0 : (tensor<1xi64>) -> tensor<1xi64> +-// CHECK-DAG: %[[C1_I64:.*]] = stablehlo.constant dense<1> : tensor +-// CHECK-DAG: %[[ITER_PLUS_1:.*]] = stablehlo.add %[[ITER]], %[[C1_I64]] : tensor +-// CHECK-DAG: %[[RESHAPE_LIMIT:.*]] = stablehlo.reshape %[[ITER_PLUS_1]] : (tensor) -> tensor<1xi64> +-// CHECK-DAG: %[[CONCAT_LIMIT:.*]] = stablehlo.concatenate %[[RESHAPE_LIMIT]], dim = 0 : (tensor<1xi64>) -> tensor<1xi64> +-// CHECK-DAG: %[[STRIDES:.*]] = stablehlo.constant dense<1> : tensor<1xi64> +-// CHECK-DAG: %[[SLICE:.*]] = stablehlo.real_dynamic_slice %[[ARG0]], %[[CONCAT_START]], %[[CONCAT_LIMIT]], %[[STRIDES]] : (tensor<2xi32>, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor<1xi32> +-// CHECK-DAG: %[[INPUT_ELEM:.*]] = stablehlo.reshape %[[SLICE]] : (tensor<1xi32>) -> tensor +-// CHECK-DAG: %[[ADD_RES:.*]] = stablehlo.add %[[INPUT_ELEM]], %[[ACC]] : tensor +-// CHECK-DAG: %[[RESHAPE_RES:.*]] = stablehlo.reshape %[[ADD_RES]] : (tensor) -> tensor<1xi32> +-// CHECK-DAG: %[[C0_I64_3:.*]] = stablehlo.constant dense<0> : tensor +-// CHECK-DAG: %[[UPDATE:.*]] = stablehlo.dynamic_update_slice %[[OUT]], %[[RESHAPE_RES]], %[[ITER]] : (tensor<2xi32>, tensor<1xi32>, tensor) -> tensor<2xi32> +-// CHECK-DAG: %[[C1_I64_2:.*]] = stablehlo.constant dense<1> : tensor +-// CHECK-DAG: %[[NEXT_ITER:.*]] = stablehlo.add %[[ITER]], %[[C1_I64_2]] : tensor +-// CHECK: stablehlo.return %[[NEXT_ITER]], %[[ADD_RES]], %[[UPDATE]] : tensor, tensor, tensor<2xi32> +-// CHECK: } +-// CHECK: return %[[WHILE]]#2, %[[WHILE]]#1 : tensor<2xi32>, tensor ++// CHECK-SAME: %[[ARG0:.*]]: tensor<2xi32>, ++// CHECK-SAME: %[[ARG1:.*]]: tensor) -> (tensor<2xi32>, tensor) { ++// CHECK: %[[GET_DIMENSION_SIZE_0:.*]] = stablehlo.get_dimension_size %[[ARG0]], dim = 0 : (tensor<2xi32>) -> tensor ++// CHECK: %[[CONVERT_0:.*]] = stablehlo.convert %[[GET_DIMENSION_SIZE_0]] : (tensor) -> tensor ++// CHECK: %[[CONSTANT_0:.*]] = stablehlo.constant dense<0> : tensor ++// CHECK: %[[CONSTANT_1:.*]] = stablehlo.constant dense<0> : tensor ++// CHECK: %[[BROADCAST_0:.*]] = stablehlo.broadcast %[[CONSTANT_1]], sizes = [2] : (tensor) -> tensor<2xi32> ++// CHECK: %[[WHILE_0:.*]]:3 = stablehlo.while(%[[VAL_0:.*]] = %[[CONSTANT_0]], %[[VAL_1:.*]] = %[[ARG1]], %[[VAL_2:.*]] = %[[BROADCAST_0]]) : tensor, tensor, tensor<2xi32> ++// CHECK: cond { ++// CHECK: %[[COMPARE_0:.*]] = stablehlo.compare LT, %[[VAL_0]], %[[CONVERT_0]] : (tensor, tensor) -> tensor ++// CHECK: stablehlo.return %[[COMPARE_0]] : tensor ++// CHECK: } do { ++// CHECK: %[[CONSTANT_2:.*]] = stablehlo.constant dense<0> : tensor ++// CHECK: %[[RESHAPE_0:.*]] = stablehlo.reshape %[[VAL_0]] : (tensor) -> tensor<1xi64> ++// CHECK: %[[CONCATENATE_0:.*]] = stablehlo.concatenate %[[RESHAPE_0]], dim = 0 : (tensor<1xi64>) -> tensor<1xi64> ++// CHECK: %[[CONSTANT_3:.*]] = stablehlo.constant dense<1> : tensor ++// CHECK: %[[ADD_0:.*]] = stablehlo.add %[[VAL_0]], %[[CONSTANT_3]] : tensor ++// CHECK: %[[RESHAPE_1:.*]] = stablehlo.reshape %[[ADD_0]] : (tensor) -> tensor<1xi64> ++// CHECK: %[[CONCATENATE_1:.*]] = stablehlo.concatenate %[[RESHAPE_1]], dim = 0 : (tensor<1xi64>) -> tensor<1xi64> ++// CHECK: %[[CONSTANT_4:.*]] = stablehlo.constant dense<1> : tensor<1xi64> ++// CHECK: %[[REAL_DYNAMIC_SLICE_0:.*]] = stablehlo.real_dynamic_slice %[[ARG0]], %[[CONCATENATE_0]], %[[CONCATENATE_1]], %[[CONSTANT_4]] : (tensor<2xi32>, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor<1xi32> ++// CHECK: %[[RESHAPE_2:.*]] = stablehlo.reshape %[[REAL_DYNAMIC_SLICE_0]] : (tensor<1xi32>) -> tensor ++// CHECK: %[[ADD_1:.*]] = stablehlo.add %[[RESHAPE_2]], %[[VAL_1]] : tensor ++// CHECK: %[[RESHAPE_3:.*]] = stablehlo.reshape %[[ADD_1]] : (tensor) -> tensor<1xi32> ++// CHECK: %[[CONSTANT_5:.*]] = stablehlo.constant dense<0> : tensor ++// CHECK: %[[DYNAMIC_UPDATE_SLICE_0:.*]] = stablehlo.dynamic_update_slice %[[VAL_2]], %[[RESHAPE_3]], %[[VAL_0]] : (tensor<2xi32>, tensor<1xi32>, tensor) -> tensor<2xi32> ++// CHECK: %[[CONSTANT_6:.*]] = stablehlo.constant dense<1> : tensor ++// CHECK: %[[ADD_2:.*]] = stablehlo.add %[[VAL_0]], %[[CONSTANT_6]] : tensor ++// CHECK: stablehlo.return %[[ADD_2]], %[[ADD_1]], %[[DYNAMIC_UPDATE_SLICE_0]] : tensor, tensor, tensor<2xi32> ++// CHECK: } ++// CHECK: return %[[WHILE_0]]#2, %[[WHILE_0]]#1 : tensor<2xi32>, tensor + // CHECK: } + func.func @scan(%arg0: tensor<2xi32>, %arg1: tensor) -> (tensor<2xi32>, tensor) { + %0:2 = chlo.scan(%arg0) inits(%arg1) dimension=0 { +@@ -5274,3 +5274,165 @@ + } : (tensor<2xi32>, tensor) -> (tensor<2xi32>, tensor) + func.return %0#0, %0#1 : tensor<2xi32>, tensor + } ++ ++// ----- ++ ++// CHECK-LABEL: func.func @mulhi_i32( ++// CHECK-SAME: %[[ARG0:.*]]: tensor<4xi32>, ++// CHECK-SAME: %[[ARG1:.*]]: tensor<4xi32>) -> tensor<4xi32> { ++// CHECK: %[[CONVERT_0:.*]] = stablehlo.convert %[[ARG0]] : (tensor<4xi32>) -> tensor<4xi64> ++// CHECK: %[[CONVERT_1:.*]] = stablehlo.convert %[[ARG1]] : (tensor<4xi32>) -> tensor<4xi64> ++// CHECK: %[[MULTIPLY_0:.*]] = stablehlo.multiply %[[CONVERT_0]], %[[CONVERT_1]] : tensor<4xi64> ++// CHECK: %[[CONSTANT_0:.*]] = stablehlo.constant dense<32> : tensor<4xi64> ++// CHECK: %[[SHIFT_RIGHT_ARITHMETIC_0:.*]] = stablehlo.shift_right_arithmetic %[[MULTIPLY_0]], %[[CONSTANT_0]] : tensor<4xi64> ++// CHECK: %[[CONVERT_2:.*]] = stablehlo.convert %[[SHIFT_RIGHT_ARITHMETIC_0]] : (tensor<4xi64>) -> tensor<4xi32> ++// CHECK: return %[[CONVERT_2]] : tensor<4xi32> ++// CHECK: } ++func.func @mulhi_i32(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { ++ %result = "chlo.mulhi"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> ++ func.return %result : tensor<4xi32> ++} ++ ++// ----- ++ ++// CHECK-LABEL: func.func @mulhi_ui32( ++// CHECK-SAME: %[[ARG0:.*]]: tensor<4xui32>, ++// CHECK-SAME: %[[ARG1:.*]]: tensor<4xui32>) -> tensor<4xui32> { ++// CHECK: %[[CONVERT_0:.*]] = stablehlo.convert %[[ARG0]] : (tensor<4xui32>) -> tensor<4xui64> ++// CHECK: %[[CONVERT_1:.*]] = stablehlo.convert %[[ARG1]] : (tensor<4xui32>) -> tensor<4xui64> ++// CHECK: %[[MULTIPLY_0:.*]] = stablehlo.multiply %[[CONVERT_0]], %[[CONVERT_1]] : tensor<4xui64> ++// CHECK: %[[CONSTANT_0:.*]] = stablehlo.constant dense<32> : tensor<4xui64> ++// CHECK: %[[SHIFT_RIGHT_LOGICAL_0:.*]] = stablehlo.shift_right_logical %[[MULTIPLY_0]], %[[CONSTANT_0]] : tensor<4xui64> ++// CHECK: %[[CONVERT_2:.*]] = stablehlo.convert %[[SHIFT_RIGHT_LOGICAL_0]] : (tensor<4xui64>) -> tensor<4xui32> ++// CHECK: return %[[CONVERT_2]] : tensor<4xui32> ++// CHECK: } ++func.func @mulhi_ui32(%arg0: tensor<4xui32>, %arg1: tensor<4xui32>) -> tensor<4xui32> { ++ %result = "chlo.mulhi"(%arg0, %arg1) : (tensor<4xui32>, tensor<4xui32>) -> tensor<4xui32> ++ func.return %result : tensor<4xui32> ++} ++ ++// ----- ++ ++// CHECK-LABEL: func.func @mulhi_i16( ++// CHECK-SAME: %[[ARG0:.*]]: tensor<4xi16>, ++// CHECK-SAME: %[[ARG1:.*]]: tensor<4xi16>) -> tensor<4xi16> { ++// CHECK: %[[CONVERT_0:.*]] = stablehlo.convert %[[ARG0]] : (tensor<4xi16>) -> tensor<4xi32> ++// CHECK: %[[CONVERT_1:.*]] = stablehlo.convert %[[ARG1]] : (tensor<4xi16>) -> tensor<4xi32> ++// CHECK: %[[MULTIPLY_0:.*]] = stablehlo.multiply %[[CONVERT_0]], %[[CONVERT_1]] : tensor<4xi32> ++// CHECK: %[[CONSTANT_0:.*]] = stablehlo.constant dense<16> : tensor<4xi32> ++// CHECK: %[[SHIFT_RIGHT_ARITHMETIC_0:.*]] = stablehlo.shift_right_arithmetic %[[MULTIPLY_0]], %[[CONSTANT_0]] : tensor<4xi32> ++// CHECK: %[[CONVERT_2:.*]] = stablehlo.convert %[[SHIFT_RIGHT_ARITHMETIC_0]] : (tensor<4xi32>) -> tensor<4xi16> ++// CHECK: return %[[CONVERT_2]] : tensor<4xi16> ++// CHECK: } ++func.func @mulhi_i16(%arg0: tensor<4xi16>, %arg1: tensor<4xi16>) -> tensor<4xi16> { ++ %result = "chlo.mulhi"(%arg0, %arg1) : (tensor<4xi16>, tensor<4xi16>) -> tensor<4xi16> ++ func.return %result : tensor<4xi16> ++} ++ ++// ----- ++ ++// CHECK-LABEL: func.func @mulhi_ui16( ++// CHECK-SAME: %[[ARG0:.*]]: tensor<4xui16>, ++// CHECK-SAME: %[[ARG1:.*]]: tensor<4xui16>) -> tensor<4xui16> { ++// CHECK: %[[CONVERT_0:.*]] = stablehlo.convert %[[ARG0]] : (tensor<4xui16>) -> tensor<4xui32> ++// CHECK: %[[CONVERT_1:.*]] = stablehlo.convert %[[ARG1]] : (tensor<4xui16>) -> tensor<4xui32> ++// CHECK: %[[MULTIPLY_0:.*]] = stablehlo.multiply %[[CONVERT_0]], %[[CONVERT_1]] : tensor<4xui32> ++// CHECK: %[[CONSTANT_0:.*]] = stablehlo.constant dense<16> : tensor<4xui32> ++// CHECK: %[[SHIFT_RIGHT_LOGICAL_0:.*]] = stablehlo.shift_right_logical %[[MULTIPLY_0]], %[[CONSTANT_0]] : tensor<4xui32> ++// CHECK: %[[CONVERT_2:.*]] = stablehlo.convert %[[SHIFT_RIGHT_LOGICAL_0]] : (tensor<4xui32>) -> tensor<4xui16> ++// CHECK: return %[[CONVERT_2]] : tensor<4xui16> ++// CHECK: } ++func.func @mulhi_ui16(%arg0: tensor<4xui16>, %arg1: tensor<4xui16>) -> tensor<4xui16> { ++ %result = "chlo.mulhi"(%arg0, %arg1) : (tensor<4xui16>, tensor<4xui16>) -> tensor<4xui16> ++ func.return %result : tensor<4xui16> ++} ++ ++// ----- ++ ++// CHECK-LABEL: func.func @mulhi_ui64( ++// CHECK-SAME: %[[ARG0:.*]]: tensor<4xui64>, ++// CHECK-SAME: %[[ARG1:.*]]: tensor<4xui64>) -> tensor<4xui64> { ++// CHECK: %[[CONSTANT_0:.*]] = stablehlo.constant dense<32> : tensor<4xui64> ++// CHECK: %[[SHIFT_RIGHT_LOGICAL_0:.*]] = stablehlo.shift_right_logical %[[ARG0]], %[[CONSTANT_0]] : tensor<4xui64> ++// CHECK: %[[CONSTANT_1:.*]] = stablehlo.constant dense<32> : tensor<4xui64> ++// CHECK: %[[SHIFT_RIGHT_LOGICAL_1:.*]] = stablehlo.shift_right_logical %[[ARG1]], %[[CONSTANT_1]] : tensor<4xui64> ++// CHECK: %[[MULTIPLY_0:.*]] = stablehlo.multiply %[[SHIFT_RIGHT_LOGICAL_0]], %[[SHIFT_RIGHT_LOGICAL_1]] : tensor<4xui64> ++// CHECK: %[[CONSTANT_2:.*]] = stablehlo.constant dense<4294967295> : tensor<4xui64> ++// CHECK: %[[AND_0:.*]] = stablehlo.and %[[ARG0]], %[[CONSTANT_2]] : tensor<4xui64> ++// CHECK: %[[MULTIPLY_1:.*]] = stablehlo.multiply %[[AND_0]], %[[SHIFT_RIGHT_LOGICAL_1]] : tensor<4xui64> ++// CHECK: %[[CONSTANT_3:.*]] = stablehlo.constant dense<4294967295> : tensor<4xui64> ++// CHECK: %[[AND_1:.*]] = stablehlo.and %[[ARG1]], %[[CONSTANT_3]] : tensor<4xui64> ++// CHECK: %[[MULTIPLY_2:.*]] = stablehlo.multiply %[[AND_0]], %[[AND_1]] : tensor<4xui64> ++// CHECK: %[[CONSTANT_4:.*]] = stablehlo.constant dense<32> : tensor<4xui64> ++// CHECK: %[[SHIFT_RIGHT_LOGICAL_2:.*]] = stablehlo.shift_right_logical %[[MULTIPLY_2]], %[[CONSTANT_4]] : tensor<4xui64> ++// CHECK: %[[ADD_0:.*]] = stablehlo.add %[[MULTIPLY_1]], %[[SHIFT_RIGHT_LOGICAL_2]] : tensor<4xui64> ++// CHECK: %[[CONSTANT_5:.*]] = stablehlo.constant dense<32> : tensor<4xui64> ++// CHECK: %[[SHIFT_RIGHT_LOGICAL_3:.*]] = stablehlo.shift_right_logical %[[ADD_0]], %[[CONSTANT_5]] : tensor<4xui64> ++// CHECK: %[[ADD_1:.*]] = stablehlo.add %[[MULTIPLY_0]], %[[SHIFT_RIGHT_LOGICAL_3]] : tensor<4xui64> ++// CHECK: %[[MULTIPLY_3:.*]] = stablehlo.multiply %[[SHIFT_RIGHT_LOGICAL_0]], %[[AND_1]] : tensor<4xui64> ++// CHECK: %[[CONSTANT_6:.*]] = stablehlo.constant dense<4294967295> : tensor<4xui64> ++// CHECK: %[[AND_2:.*]] = stablehlo.and %[[ADD_0]], %[[CONSTANT_6]] : tensor<4xui64> ++// CHECK: %[[ADD_2:.*]] = stablehlo.add %[[MULTIPLY_3]], %[[AND_2]] : tensor<4xui64> ++// CHECK: %[[CONSTANT_7:.*]] = stablehlo.constant dense<32> : tensor<4xui64> ++// CHECK: %[[SHIFT_RIGHT_LOGICAL_4:.*]] = stablehlo.shift_right_logical %[[ADD_2]], %[[CONSTANT_7]] : tensor<4xui64> ++// CHECK: %[[ADD_3:.*]] = stablehlo.add %[[ADD_1]], %[[SHIFT_RIGHT_LOGICAL_4]] : tensor<4xui64> ++// CHECK: return %[[ADD_3]] : tensor<4xui64> ++// CHECK: } ++func.func @mulhi_ui64(%arg0: tensor<4xui64>, %arg1: tensor<4xui64>) -> tensor<4xui64> { ++ %result = "chlo.mulhi"(%arg0, %arg1) : (tensor<4xui64>, tensor<4xui64>) -> tensor<4xui64> ++ func.return %result : tensor<4xui64> ++} ++ ++// ----- ++ ++// CHECK-LABEL: func.func @mulhi_i64( ++// CHECK-SAME: %[[ARG0:.*]]: tensor<4xi64>, ++// CHECK-SAME: %[[ARG1:.*]]: tensor<4xi64>) -> tensor<4xi64> { ++// CHECK: %[[CONVERT_0:.*]] = stablehlo.convert %[[ARG0]] : (tensor<4xi64>) -> tensor<4xui64> ++// CHECK: %[[CONVERT_1:.*]] = stablehlo.convert %[[ARG1]] : (tensor<4xi64>) -> tensor<4xui64> ++// CHECK: %[[CONSTANT_0:.*]] = stablehlo.constant dense<32> : tensor<4xui64> ++// CHECK: %[[SHIFT_RIGHT_LOGICAL_0:.*]] = stablehlo.shift_right_logical %[[CONVERT_0]], %[[CONSTANT_0]] : tensor<4xui64> ++// CHECK: %[[CONSTANT_1:.*]] = stablehlo.constant dense<32> : tensor<4xui64> ++// CHECK: %[[SHIFT_RIGHT_LOGICAL_1:.*]] = stablehlo.shift_right_logical %[[CONVERT_1]], %[[CONSTANT_1]] : tensor<4xui64> ++// CHECK: %[[MULTIPLY_0:.*]] = stablehlo.multiply %[[SHIFT_RIGHT_LOGICAL_0]], %[[SHIFT_RIGHT_LOGICAL_1]] : tensor<4xui64> ++// CHECK: %[[CONSTANT_2:.*]] = stablehlo.constant dense<4294967295> : tensor<4xui64> ++// CHECK: %[[AND_0:.*]] = stablehlo.and %[[CONVERT_0]], %[[CONSTANT_2]] : tensor<4xui64> ++// CHECK: %[[MULTIPLY_1:.*]] = stablehlo.multiply %[[AND_0]], %[[SHIFT_RIGHT_LOGICAL_1]] : tensor<4xui64> ++// CHECK: %[[CONSTANT_3:.*]] = stablehlo.constant dense<4294967295> : tensor<4xui64> ++// CHECK: %[[AND_1:.*]] = stablehlo.and %[[CONVERT_1]], %[[CONSTANT_3]] : tensor<4xui64> ++// CHECK: %[[MULTIPLY_2:.*]] = stablehlo.multiply %[[AND_0]], %[[AND_1]] : tensor<4xui64> ++// CHECK: %[[CONSTANT_4:.*]] = stablehlo.constant dense<32> : tensor<4xui64> ++// CHECK: %[[SHIFT_RIGHT_LOGICAL_2:.*]] = stablehlo.shift_right_logical %[[MULTIPLY_2]], %[[CONSTANT_4]] : tensor<4xui64> ++// CHECK: %[[ADD_0:.*]] = stablehlo.add %[[MULTIPLY_1]], %[[SHIFT_RIGHT_LOGICAL_2]] : tensor<4xui64> ++// CHECK: %[[CONSTANT_5:.*]] = stablehlo.constant dense<32> : tensor<4xui64> ++// CHECK: %[[SHIFT_RIGHT_LOGICAL_3:.*]] = stablehlo.shift_right_logical %[[ADD_0]], %[[CONSTANT_5]] : tensor<4xui64> ++// CHECK: %[[ADD_1:.*]] = stablehlo.add %[[MULTIPLY_0]], %[[SHIFT_RIGHT_LOGICAL_3]] : tensor<4xui64> ++// CHECK: %[[MULTIPLY_3:.*]] = stablehlo.multiply %[[SHIFT_RIGHT_LOGICAL_0]], %[[AND_1]] : tensor<4xui64> ++// CHECK: %[[CONSTANT_6:.*]] = stablehlo.constant dense<4294967295> : tensor<4xui64> ++// CHECK: %[[AND_2:.*]] = stablehlo.and %[[ADD_0]], %[[CONSTANT_6]] : tensor<4xui64> ++// CHECK: %[[ADD_2:.*]] = stablehlo.add %[[MULTIPLY_3]], %[[AND_2]] : tensor<4xui64> ++// CHECK: %[[CONSTANT_7:.*]] = stablehlo.constant dense<32> : tensor<4xui64> ++// CHECK: %[[SHIFT_RIGHT_LOGICAL_4:.*]] = stablehlo.shift_right_logical %[[ADD_2]], %[[CONSTANT_7]] : tensor<4xui64> ++// CHECK: %[[ADD_3:.*]] = stablehlo.add %[[ADD_1]], %[[SHIFT_RIGHT_LOGICAL_4]] : tensor<4xui64> ++// CHECK: %[[CONSTANT_8:.*]] = stablehlo.constant dense<0> : tensor<4xi64> ++// CHECK: %[[COMPARE_0:.*]] = stablehlo.compare LT, %[[ARG0]], %[[CONSTANT_8]] : (tensor<4xi64>, tensor<4xi64>) -> tensor<4xi1> ++// CHECK: %[[CONVERT_2:.*]] = stablehlo.convert %[[ARG1]] : (tensor<4xi64>) -> tensor<4xui64> ++// CHECK: %[[CONVERT_3:.*]] = stablehlo.convert %[[ARG1]] : (tensor<4xi64>) -> tensor<4xui64> ++// CHECK: %[[CONSTANT_9:.*]] = stablehlo.constant dense<0> : tensor<4xui64> ++// CHECK: %[[SELECT_0:.*]] = stablehlo.select %[[COMPARE_0]], %[[CONVERT_2]], %[[CONSTANT_9]] : tensor<4xi1>, tensor<4xui64> ++// CHECK: %[[SUBTRACT_0:.*]] = stablehlo.subtract %[[ADD_3]], %[[SELECT_0]] : tensor<4xui64> ++// CHECK: %[[CONSTANT_10:.*]] = stablehlo.constant dense<0> : tensor<4xi64> ++// CHECK: %[[COMPARE_1:.*]] = stablehlo.compare LT, %[[ARG1]], %[[CONSTANT_10]] : (tensor<4xi64>, tensor<4xi64>) -> tensor<4xi1> ++// CHECK: %[[CONVERT_4:.*]] = stablehlo.convert %[[ARG0]] : (tensor<4xi64>) -> tensor<4xui64> ++// CHECK: %[[CONVERT_5:.*]] = stablehlo.convert %[[ARG0]] : (tensor<4xi64>) -> tensor<4xui64> ++// CHECK: %[[CONSTANT_11:.*]] = stablehlo.constant dense<0> : tensor<4xui64> ++// CHECK: %[[SELECT_1:.*]] = stablehlo.select %[[COMPARE_1]], %[[CONVERT_4]], %[[CONSTANT_11]] : tensor<4xi1>, tensor<4xui64> ++// CHECK: %[[SUBTRACT_1:.*]] = stablehlo.subtract %[[SUBTRACT_0]], %[[SELECT_1]] : tensor<4xui64> ++// CHECK: %[[CONVERT_6:.*]] = stablehlo.convert %[[SUBTRACT_1]] : (tensor<4xui64>) -> tensor<4xi64> ++// CHECK: return %[[CONVERT_6]] : tensor<4xi64> ++// CHECK: } ++func.func @mulhi_i64(%arg0: tensor<4xi64>, %arg1: tensor<4xi64>) -> tensor<4xi64> { ++ %result = "chlo.mulhi"(%arg0, %arg1) : (tensor<4xi64>, tensor<4xi64>) -> tensor<4xi64> ++ func.return %result : tensor<4xi64> ++} ++ +diff --ruN a/stablehlo/stablehlo/tests/interpret/chlo/mulhi.mlir b/stablehlo/stablehlo/tests/interpret/chlo/mulhi.mlir +--- stablehlo/stablehlo/tests/interpret/chlo/mulhi.mlir ++++ stablehlo/stablehlo/tests/interpret/chlo/mulhi.mlir +@@ -0,0 +1,91 @@ ++// RUN: stablehlo-opt --chlo-legalize-to-stablehlo --split-input-file %s > %t.mlir ++// RUN: stablehlo-translate --interpret --split-input-file %t.mlir ++ ++// Test chlo.mulhi operation for signed 8-bit integers. ++// 64*4 = 256 ==> b0000_00001_0000_0000, high bits are 1. ++func.func @mulhi_op_test_si8() { ++ %lhs = stablehlo.constant dense<[64, -64, -64]> : tensor<3xi8> ++ %rhs = stablehlo.constant dense<[4, 4, -4]> : tensor<3xi8> ++ %result = "chlo.mulhi"(%lhs, %rhs) : (tensor<3xi8>, tensor<3xi8>) -> tensor<3xi8> ++ check.expect_eq_const %result, dense<[1, -1, 1]> : tensor<3xi8> ++ func.return ++} ++ ++// ----- ++ ++// Test chlo.mulhi operation for unsigned 8-bit integers. ++// 240 * 16 = 3840 ==> b0000_1111_0000_0000, high bits are 15. ++func.func @mulhi_op_test_ui8() { ++ %lhs = stablehlo.constant dense<[240, 128]> : tensor<2xui8> ++ %rhs = stablehlo.constant dense<[16, 2]> : tensor<2xui8> ++ %result = "chlo.mulhi"(%lhs, %rhs) : (tensor<2xui8>, tensor<2xui8>) -> tensor<2xui8> ++ check.expect_eq_const %result, dense<[15, 1]> : tensor<2xui8> ++ func.return ++} ++ ++// ----- ++ ++// Test chlo.mulhi operation for signed 16-bit integers. ++func.func @mulhi_op_test_si16() { ++ %lhs = stablehlo.constant dense<[16384, -16384]> : tensor<2xi16> ++ %rhs = stablehlo.constant dense<[8, 8]> : tensor<2xi16> ++ %result = "chlo.mulhi"(%lhs, %rhs) : (tensor<2xi16>, tensor<2xi16>) -> tensor<2xi16> ++ check.expect_eq_const %result, dense<[2, -2]> : tensor<2xi16> ++ func.return ++} ++ ++// ----- ++ ++// Test chlo.mulhi operation for unsigned 16-bit integers. ++func.func @mulhi_op_test_ui16() { ++ %lhs = stablehlo.constant dense<[61440, 32768]> : tensor<2xui16> ++ %rhs = stablehlo.constant dense<[16, 2]> : tensor<2xui16> ++ %result = "chlo.mulhi"(%lhs, %rhs) : (tensor<2xui16>, tensor<2xui16>) -> tensor<2xui16> ++ check.expect_eq_const %result, dense<[15, 1]> : tensor<2xui16> ++ func.return ++} ++ ++// ----- ++ ++// Test chlo.mulhi operation for signed 32-bit integers. ++func.func @mulhi_op_test_si32() { ++ %lhs = stablehlo.constant dense<[268435456, -268435456]> : tensor<2xi32> ++ %rhs = stablehlo.constant dense<[16, 16]> : tensor<2xi32> ++ %result = "chlo.mulhi"(%lhs, %rhs) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> ++ check.expect_eq_const %result, dense<[1, -1]> : tensor<2xi32> ++ func.return ++} ++ ++// ----- ++ ++// Test chlo.mulhi operation for unsigned 32-bit integers. ++func.func @mulhi_op_test_ui32() { ++ %lhs = stablehlo.constant dense<[4026531840, 2147483648]> : tensor<2xui32> ++ %rhs = stablehlo.constant dense<[16, 2]> : tensor<2xui32> ++ %result = "chlo.mulhi"(%lhs, %rhs) : (tensor<2xui32>, tensor<2xui32>) -> tensor<2xui32> ++ check.expect_eq_const %result, dense<[15, 1]> : tensor<2xui32> ++ func.return ++} ++ ++// ----- ++ ++// Test chlo.mulhi operation for signed 64-bit integers. ++// 1152921504606846976 * 16 = 18446744073709551616 ==> b1_0{64_zeros} ++func.func @mulhi_op_test_si64() { ++ %lhs = stablehlo.constant dense<[1152921504606846976, -1152921504606846976]> : tensor<2xi64> ++ %rhs = stablehlo.constant dense<[16, 16]> : tensor<2xi64> ++ %result = "chlo.mulhi"(%lhs, %rhs) : (tensor<2xi64>, tensor<2xi64>) -> tensor<2xi64> ++ check.expect_eq_const %result, dense<[1, -1]> : tensor<2xi64> ++ func.return ++} ++ ++// ----- ++ ++// Test chlo.mulhi operation for unsigned 64-bit integers. ++func.func @mulhi_op_test_ui64() { ++ %lhs = stablehlo.constant dense<[17293822569102704640, 9223372036854775808]> : tensor<2xui64> ++ %rhs = stablehlo.constant dense<[16, 2]> : tensor<2xui64> ++ %result = "chlo.mulhi"(%lhs, %rhs) : (tensor<2xui64>, tensor<2xui64>) -> tensor<2xui64> ++ check.expect_eq_const %result, dense<[15, 1]> : tensor<2xui64> ++ func.return ++} +diff --ruN a/stablehlo/stablehlo/tests/ops_chlo.mlir b/stablehlo/stablehlo/tests/ops_chlo.mlir +--- stablehlo/stablehlo/tests/ops_chlo.mlir ++++ stablehlo/stablehlo/tests/ops_chlo.mlir +@@ -558,3 +558,27 @@ + } : (tensor<2x3xf32>, tensor<2xf32>) -> (tensor<2x3xf32>, tensor<2xf32>) + func.return %0#0 : tensor<2x3xf32> + } ++ ++// ----- ++ ++// CHECK-LABEL: func @mulhi_i32 ++func.func @mulhi_i32(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { ++ %0 = "chlo.mulhi"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> ++ func.return %0: tensor<4xi32> ++} ++ ++// ----- ++ ++func.func @mulhi_boolean(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tensor<4xi1> { ++ // expected-error @+1 {{'chlo.mulhi' op operand #0 must be ranked tensor of 2/4/8/16/32/64-bit integer values, but got 'tensor<4xi1>'}} ++ %0 = "chlo.mulhi"(%arg0, %arg1) : (tensor<4xi1>, tensor<4xi1>) -> tensor<4xi1> ++ func.return %0: tensor<4xi1> ++} ++ ++// ----- ++ ++// CHECK-LABEL: func @mulhi_i64 ++func.func @mulhi_i64(%arg0: tensor<4xi64>, %arg1: tensor<4xi64>) -> tensor<4xi64> { ++ %0 = "chlo.mulhi"(%arg0, %arg1) : (tensor<4xi64>, tensor<4xi64>) -> tensor<4xi64> ++ func.return %0: tensor<4xi64> ++} diff --ruN a/stablehlo/stablehlo/tests/ops_stablehlo.mlir b/stablehlo/stablehlo/tests/ops_stablehlo.mlir --- stablehlo/stablehlo/tests/ops_stablehlo.mlir +++ stablehlo/stablehlo/tests/ops_stablehlo.mlir @@ -724,6 +2760,225 @@ diff --ruN a/stablehlo/stablehlo/tests/vhlo/vhlo_to_version_downgrade_invalid.1_ + func.return %0 : !stablehlo.future> + } +} +diff --ruN a/stablehlo/stablehlo/transforms/ChloDecompositionPatterns.td b/stablehlo/stablehlo/transforms/ChloDecompositionPatterns.td +--- stablehlo/stablehlo/transforms/ChloDecompositionPatterns.td ++++ stablehlo/stablehlo/transforms/ChloDecompositionPatterns.td +@@ -20,6 +20,34 @@ + + class StableHLO_ConstantLike : NativeCodeCall< + "::mlir::stablehlo::getConstantLike($_builder, $_loc, " # value # ", $0)">; ++ ++def IsSignedIntegerType : Constraint(llvm::cast($0.getType()).getElementType()).isUnsigned()">>; ++ ++def IsSignedIntegerLessThan64Bit : Constraint>; ++ ++def IsSigned64BitIntegerType : Constraint>; ++ ++def IsUnsignedIntegerType : Constraint(llvm::cast($0.getType()).getElementType()).isUnsigned()">>; ++ ++def IsUnsignedIntegerLessThan64Bit : Constraint>; ++ ++def IsUnsigned64BitIntegerType : Constraint>; ++ ++def StableHLO_ConvertOpUpcast : NativeCodeCall< ++ "::mlir::stablehlo::getConvertOpUpcast($_builder, $_loc, $0)">; ++ ++def StableHLO_ConvertToUnsigned : NativeCodeCall< ++ "::mlir::stablehlo::getConvertToUnsigned($_builder, $_loc, $0)">; ++ ++def GetShiftAmountConstant : NativeCodeCall< ++ "::mlir::stablehlo::getConstantLike($_builder, $_loc," ++ " cast(cast($0.getType()).getElementType()).getWidth(), $1)">; + + def ComplexElementType : Type< + CPred<"isa(cast($_self).getElementType())">, +@@ -115,6 +143,118 @@ + (STABLEHLO_DEFAULT_COMPARISON_TYPE) + )>; + ++// Signed mulhi_s32 for N-bit integers where N < 64, meaning we can upcast. ++// downcast((upcast(x, N*2) * upcast(y,N*2)) >> N, N) ++def : Pat<(CHLO_MulhiOp $lhs, $rhs), ++ (StableHLO_ConvertOp ++ (StableHLO_ShiftRightArithmeticOp ++ (StableHLO_MulOp:$wideMul (StableHLO_ConvertOpUpcast $lhs), (StableHLO_ConvertOpUpcast $rhs)), ++ (GetShiftAmountConstant $lhs, $wideMul) ++ ) ++ ), ++ [(IsSignedIntegerLessThan64Bit $lhs)]>; ++ ++// Signed mulhi_s32 for N-bit integers where N < 64, meaning we can upcast. ++// downcast((upcast(x, N*2) * upcast(y,N*2)) >> N, N) ++// ++// Unsigned ints use logical shifts. ++def : Pat<(CHLO_MulhiOp $lhs, $rhs), ++ (StableHLO_ConvertOp ++ (StableHLO_ShiftRightLogicalOp ++ (StableHLO_MulOp:$wideMul (StableHLO_ConvertOpUpcast $lhs), (StableHLO_ConvertOpUpcast $rhs)), ++ (GetShiftAmountConstant $lhs, $wideMul) ++ ) ++ ), ++ [(IsUnsignedIntegerLessThan64Bit $lhs)]>; ++ ++// mulhi_u64 with decomposition into 32 bit parts. ++// ++// 1. Split operands into lower and upper 32-bit halves ++// uint64_t a_hi = a >> 32; ++// uint64_t b_hi = b >> 32; ++// uint64_t a_lo = a & 0xFFFFFFFF; ++// uint64_t b_lo = b & 0xFFFFFFFF; ++ ++// // 2. Compute the four 64-bit partial products ++// uint64_t p0 = a_lo * b_lo; ++// uint64_t p1 = a_lo * b_hi; ++// uint64_t p2 = a_hi * b_lo; ++// uint64_t p3 = a_hi * b_hi; ++ ++// // 3. Accumulate middle terms and propagate the carry ++// uint32_t carry_p0 = (uint32_t)(p0 >> 32); ++// uint64_t mid1 = p1 + carry_p0; ++// uint64_t mid2 = p2 + (uint32_t)mid1; ++ ++// // 4. Sum the high terms ++// return p3 + (mid1 >> 32) + (mid2 >> 32); ++def : Pat<(CHLO_MulhiOp $lhs, $rhs), ++ (StableHLO_AddOp ++ (StableHLO_AddOp ++ (StableHLO_MulOp:$p3 ++ (StableHLO_ShiftRightLogicalOp:$a_hi $lhs, (StableHLO_ConstantLike<"32"> $lhs)), ++ (StableHLO_ShiftRightLogicalOp:$b_hi $rhs, (StableHLO_ConstantLike<"32"> $rhs)) ++ ), ++ (StableHLO_ShiftRightLogicalOp:$carry_mid1 ++ (StableHLO_AddOp:$mid1 ++ (StableHLO_MulOp:$p1 ++ (StableHLO_AndOp:$a_lo $lhs, (StableHLO_ConstantLike<"0xFFFFFFFFULL"> $lhs)), ++ $b_hi ++ ), ++ (StableHLO_ShiftRightLogicalOp:$carry_p0 ++ (StableHLO_MulOp:$p0 $a_lo, (StableHLO_AndOp:$b_lo $rhs, (StableHLO_ConstantLike<"0xFFFFFFFFULL"> $rhs))), ++ (StableHLO_ConstantLike<"32"> $lhs) ++ ) ++ ), ++ (StableHLO_ConstantLike<"32"> $lhs) ++ ) ++ ), ++ (StableHLO_ShiftRightLogicalOp:$carry_mid2 ++ (StableHLO_AddOp:$mid2 ++ (StableHLO_MulOp:$p2 $a_hi, $b_lo), ++ (StableHLO_AndOp $mid1, (StableHLO_ConstantLike<"0xFFFFFFFFULL"> $lhs)) ++ ), ++ (StableHLO_ConstantLike<"32"> $lhs) ++ ) ++ ), ++ [(IsUnsigned64BitIntegerType $lhs)]>; ++ ++// mulhi_u64 with correction for two's complement sign extension: ++// res = mulhi_u64((uint64_t)a, (uint64_t)b) ++// if (a < 0) hi -= b; ++// if (b < 0) hi -= a; ++def : Pat<(CHLO_MulhiOp $lhs, $rhs), ++ (StableHLO_ConvertOp ++ (StableHLO_SubtractOp ++ (StableHLO_SubtractOp ++ (CHLO_MulhiOp:$hi ++ (StableHLO_ConvertToUnsigned $lhs), ++ (StableHLO_ConvertToUnsigned $rhs) ++ ), ++ (StableHLO_SelectOp ++ (StableHLO_CompareOp ++ $lhs, ++ (StableHLO_ConstantLike<"0"> $lhs), ++ StableHLO_ComparisonDirectionValue<"LT">, ++ (STABLEHLO_DEFAULT_COMPARISON_TYPE) ++ ), ++ (StableHLO_ConvertToUnsigned $rhs), ++ (StableHLO_ConstantLike<"0"> (StableHLO_ConvertToUnsigned $rhs)) ++ ) ++ ), ++ (StableHLO_SelectOp ++ (StableHLO_CompareOp ++ $rhs, ++ (StableHLO_ConstantLike<"0"> $rhs), ++ StableHLO_ComparisonDirectionValue<"LT">, ++ (STABLEHLO_DEFAULT_COMPARISON_TYPE) ++ ), ++ (StableHLO_ConvertToUnsigned $lhs), ++ (StableHLO_ConstantLike<"0"> (StableHLO_ConvertToUnsigned $lhs)) ++ ) ++ ) ++ ), ++ [(IsSigned64BitIntegerType $lhs)]>; + + def : Pat<(CHLO_TanOp $input), + (StableHLO_TanOp $input, ConstDefaultResultAccuracyAttr)>; +diff --ruN a/stablehlo/stablehlo/transforms/ChloLegalizeToStablehlo.cpp b/stablehlo/stablehlo/transforms/ChloLegalizeToStablehlo.cpp +--- stablehlo/stablehlo/transforms/ChloLegalizeToStablehlo.cpp ++++ stablehlo/stablehlo/transforms/ChloLegalizeToStablehlo.cpp +@@ -14,7 +14,6 @@ + #include + #include + #include +-#include + #include + #include + #include +@@ -24,6 +23,7 @@ + #include "llvm/ADT/STLExtras.h" + #include "llvm/ADT/Sequence.h" + #include "llvm/ADT/SmallVector.h" ++#include "llvm/Support/Casting.h" + #include "llvm/Support/Debug.h" + #include "mlir/Dialect/Arith/IR/Arith.h" + #include "mlir/Dialect/Complex/IR/Complex.h" +@@ -199,6 +199,42 @@ + return getConstantLike( + b, loc, llvm::APFloat::getSmallestNormalized(ty.getFloatSemantics()), + val); ++} ++ ++static Value getConvertOpUpcast(OpBuilder& builder, Location loc, Value val) { ++ auto inputType = cast(val.getType()).getElementType(); ++ auto intType = cast(inputType); ++ unsigned width = intType.getWidth(); ++ Type wideElementType = IntegerType::get(builder.getContext(), width * 2, ++ intType.getSignedness()); ++ auto wideType = cast(val.getType()).clone(wideElementType); ++ return ConvertOp::create(builder, loc, wideType, val); ++} ++ ++static Value getConvertToUnsigned(OpBuilder& builder, Location loc, Value val) { ++ auto inputType = cast(val.getType()).getElementType(); ++ auto intType = cast(inputType); ++ Type unsignedElementType = IntegerType::get( ++ builder.getContext(), intType.getWidth(), IntegerType::Unsigned); ++ auto unsignedType = ++ cast(val.getType()).clone(unsignedElementType); ++ return ConvertOp::create(builder, loc, unsignedType, val); ++} ++ ++// Helper to check if the element type of the shaped value is an integer ++// of the specified signedness and with width strictly less than n_bits. ++static bool isIntegerLessThanNBits(Value val, bool isSigned, uint32_t n_bits) { ++ auto shapedType = llvm::cast(val.getType()); ++ auto intType = llvm::cast(shapedType.getElementType()); ++ return (intType.isUnsigned() == !isSigned) && intType.getWidth() < n_bits; ++} ++ ++// Helper to check if the element type of the shaped value is an integer ++// of the specified signedness and with width exactly equal to n_bits. ++static bool isIntegerWithNBits(Value val, bool isSigned, uint32_t n_bits) { ++ auto shapedType = llvm::cast(val.getType()); ++ auto intType = llvm::cast(shapedType.getElementType()); ++ return (intType.isUnsigned() == !isSigned) && intType.getWidth() == n_bits; + } + + //===----------------------------------------------------------------------===// diff --ruN a/stablehlo/stablehlo/transforms/StablehloCompatibilityExpander.cpp b/stablehlo/stablehlo/transforms/StablehloCompatibilityExpander.cpp --- stablehlo/stablehlo/transforms/StablehloCompatibilityExpander.cpp +++ stablehlo/stablehlo/transforms/StablehloCompatibilityExpander.cpp diff --git a/third_party/xla/workspace2.bzl b/third_party/xla/workspace2.bzl index eda48d72684f91..01503992a45f81 100644 --- a/third_party/xla/workspace2.bzl +++ b/third_party/xla/workspace2.bzl @@ -34,6 +34,7 @@ load("//third_party/highway:workspace.bzl", highway = "repo") load("//third_party/highwayhash:workspace.bzl", highwayhash = "repo") load("//third_party/hwloc:workspace.bzl", hwloc = "repo") load("//third_party/implib_so:workspace.bzl", implib_so = "repo") +load("//third_party/libdrm:workspace.bzl", libdrm = "repo") load("//third_party/llvm:workspace.bzl", llvm = "repo") load("//third_party/llvm_openmp:workspace.bzl", llvm_openmp = "repo") load("//third_party/mkl_dnn:workspace.bzl", onednn = "repo") @@ -95,6 +96,7 @@ def _initialize_third_party(): highwayhash() hwloc() implib_so() + libdrm() llvm_openmp() ml_dtypes() mpitrampoline() diff --git a/third_party/xla/workspace3.bzl b/third_party/xla/workspace3.bzl index c8baec092b1ba7..66ebcd8265bfcb 100644 --- a/third_party/xla/workspace3.bzl +++ b/third_party/xla/workspace3.bzl @@ -50,10 +50,10 @@ def workspace(): # Details: https://github.com/google-ml-infra/rules_ml_toolchain tf_http_archive( name = "rules_ml_toolchain", - sha256 = "0b42f693a60c6050d87db1e0a0eaeb84ab3f54191fce094d86334faedc807da0", - strip_prefix = "rules_ml_toolchain-398d613aea7a4c294da49b79a6d6f3f8732bd84c", + sha256 = "60248801e83422a5769d7a29d3bb3a66a484821c846dfc0ff62045d64e7ec982", + strip_prefix = "rules_ml_toolchain-a3bf5be11de756adf6e212b46017769530073766", urls = tf_mirror_urls( - "https://github.com/google-ml-infra/rules_ml_toolchain/archive/398d613aea7a4c294da49b79a6d6f3f8732bd84c.tar.gz", + "https://github.com/google-ml-infra/rules_ml_toolchain/archive/a3bf5be11de756adf6e212b46017769530073766.tar.gz", ), ) diff --git a/third_party/xla/xla/BUILD b/third_party/xla/xla/BUILD index 61d8823a6ab661..9850c818a53453 100644 --- a/third_party/xla/xla/BUILD +++ b/third_party/xla/xla/BUILD @@ -1295,6 +1295,7 @@ cc_library( "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/time", "@com_google_protobuf//:protobuf", "@tsl//tsl/platform:platform_port", diff --git a/third_party/xla/xla/backends/autotuner/BUILD b/third_party/xla/xla/backends/autotuner/BUILD index 8ddcecd7ee5be3..8498832224975f 100644 --- a/third_party/xla/xla/backends/autotuner/BUILD +++ b/third_party/xla/xla/backends/autotuner/BUILD @@ -128,6 +128,7 @@ cc_library( deps = [ "//xla/service:executable", "//xla/service:shaped_buffer", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status:statusor", @@ -152,3 +153,13 @@ tf_proto_library( name = "backends_proto", srcs = ["backends.proto"], ) + +tf_proto_library( + name = "autotuning_proto", + srcs = ["autotuning.proto"], + deps = [ + ":backends_proto", + "@com_google_protobuf//:any_proto", + "@com_google_protobuf//:duration_proto", + ], +) diff --git a/third_party/xla/xla/backends/autotuner/autotuner.cc b/third_party/xla/xla/backends/autotuner/autotuner.cc index b88179a0cb04e1..cce710c9b4267e 100644 --- a/third_party/xla/xla/backends/autotuner/autotuner.cc +++ b/third_party/xla/xla/backends/autotuner/autotuner.cc @@ -185,8 +185,8 @@ absl::Status Autotuner::Autotune(HloModule* module, VLOG(1) << "Finding configs for " << instruction_groups.size() << " unique instructions."; - TF_ASSIGN_OR_RETURN(std::vector configs, - GetConfigsForAll(instruction_groups)); + ASSIGN_OR_RETURN(std::vector configs, + GetConfigsForAll(instruction_groups)); for (int i = 0; i < instruction_groups.size(); i++) { auto& instructions = instruction_groups[i]; @@ -234,8 +234,8 @@ absl::Status Autotuner::Autotune(HloModule* module, VLOG(1) << "Shard " << my_shard_index << "/" << total_shards << ": finding configs for " << instruction_groups.size() << "/" << all_instruction_groups.size() << " unique instructions "; - TF_ASSIGN_OR_RETURN(std::vector configs, - GetConfigsForAll(instruction_groups)); + ASSIGN_OR_RETURN(std::vector configs, + GetConfigsForAll(instruction_groups)); std::vector autotuned_instructions; autotuned_instructions.reserve(instruction_groups.size()); for (int i = 0; i < instruction_groups.size(); ++i) { @@ -249,8 +249,7 @@ absl::Status Autotuner::Autotune(HloModule* module, GetKvStoreKey(module, my_shard_index, codegen_backends_); std::string local_results; if (!autotuned_instructions.empty()) { - TF_ASSIGN_OR_RETURN(local_results, - cache_->Serialize(autotuned_instructions)); + ASSIGN_OR_RETURN(local_results, cache_->Serialize(autotuned_instructions)); } absl::StatusOr stored_result = kv_store.TryGet(local_key); if (stored_result.status().code() == absl::StatusCode::kNotFound) { @@ -282,10 +281,10 @@ absl::Status Autotuner::Autotune(HloModule* module, << i << " / " << total_shards << " at " << remote_key; // TODO(b/361009609): reset to infinite duration once issue with MPI is // fixed. https://github.com/google/jax/issues/22995. - TF_ASSIGN_OR_RETURN(std::string remote_results, - kv_store.Get(remote_key, absl::Hours(24))); + ASSIGN_OR_RETURN(std::string remote_results, + kv_store.Get(remote_key, absl::Hours(24))); if (!remote_results.empty()) { - TF_RETURN_IF_ERROR(cache_->Deserialize(remote_results)); + RETURN_IF_ERROR(cache_->Deserialize(remote_results)); } } @@ -301,11 +300,11 @@ absl::Status Autotuner::Autotune(HloModule* module, "across all shards.")); } if (autotune_config_.dump_hlos) { - TF_RETURN_IF_ERROR(DumpHlo(instruction_group[0], *cached_config)); + RETURN_IF_ERROR(DumpHlo(instruction_group[0], *cached_config)); } CodegenBackend* codegen_backend = cached_config->codegen_backend; for (auto* instr : instruction_group) { - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( codegen_backend->ApplyConfig(*instr, *cached_config->backend_config)); } } @@ -347,7 +346,7 @@ tsl::Future Autotuner::GetConfig(HloInstruction* instr) { } if (autotune_config_.use_default_config) { - TF_ASSIGN_OR_RETURN(Config default_config, GetDefaultConfig(*instr)); + ASSIGN_OR_RETURN(Config default_config, GetDefaultConfig(*instr)); VLOG(1) << "Using default config: " << default_config.ToString(); return default_config; } @@ -391,8 +390,8 @@ absl::Status Autotuner::IsValidExecutable( tsl::Future Autotuner::TuneBestConfig( HloInstruction* instr) { - TF_ASSIGN_OR_RETURN(std::vector supported_configs, - GetSupportedConfigs(instr)); + ASSIGN_OR_RETURN(std::vector supported_configs, + GetSupportedConfigs(instr)); if (supported_configs.empty()) { return absl::InternalError( absl::StrCat("Autotuning failed for HLO: ", instr->ToString(), @@ -825,7 +824,7 @@ absl::Status Autotuner::DumpHlo(HloInstruction* instr, const Config& config) { DumpToFileInDirOrStdout(*parent_module, "", absl::StrCat(id, ".before.txt"), module->ToString()); HloInstruction* root = module->entry_computation()->root_instruction(); - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( config.codegen_backend->ApplyConfig(*root, *config.backend_config)); DumpToFileInDirOrStdout(*parent_module, "", absl::StrCat(id, ".after.txt"), module->ToString()); @@ -909,7 +908,7 @@ absl::Status Autotuner::DumpLogsToFile() { std::string textproto; tsl::protobuf::TextFormat::PrintToString(logs_, &textproto); - TF_RETURN_IF_ERROR(tsl::AppendStringToFile( + RETURN_IF_ERROR(tsl::AppendStringToFile( tsl::Env::Default(), autotune_config_.dump_logs_to, textproto)); VLOG(1) << "Autotune logs appended to file: " << autotune_config_.dump_logs_to; diff --git a/third_party/xla/xla/backends/autotuner/autotuning.proto b/third_party/xla/xla/backends/autotuner/autotuning.proto new file mode 100644 index 00000000000000..514e011edfa6e8 --- /dev/null +++ b/third_party/xla/xla/backends/autotuner/autotuning.proto @@ -0,0 +1,63 @@ +syntax = "proto3"; + +package xla.autotuner; + +import "google/protobuf/any.proto"; +import "google/protobuf/duration.proto"; +import "xla/backends/autotuner/backends.proto"; + +option java_outer_classname = "Tuning"; +option java_multiple_files = true; + +enum FailureKind { + FAILURE_KIND_UNSPECIFIED = 0; + COMPILATION_FAILED = 1; + EXECUTION_FAILED = 2; + REDZONE_CHECK_FAILED = 3; + WRONG_RESULTS = 4; +} + +message Failure { + FailureKind kind = 1; + string message = 2; +} + +message Config { + Backend backend = 1; + google.protobuf.Any backend_config = 2; +} + +message ConfigProfile { + Config config = 1; + google.protobuf.Duration run_time = 2; + int64 scratch_bytes = 3; + Failure failure = 4; + bool is_default_config = 5; +} + +message AutotuningKey { + string op_fingerprint = 1; + string device = 2; + // Optional user-specified version string. If provided, this is used as the + // version string in cache lookups, instead of a string derived from + // backend_versions. This is useful when strict compatibility guarantees are + // required and backend_versions cannot provide it. + string explicit_version = 3; + // map key is string representation of Backend enum. + map backend_versions = 4; +} + +message AutotuningRawProfiles { + AutotuningKey key = 1; + string hlo_text = 2; + repeated ConfigProfile config_profiles = 3; +} + +message ConfigCacheEntry { + AutotuningKey key = 1; + Config optimal_config = 2; +} + +message ConfigCache { + repeated ConfigCacheEntry results = 1; +} diff --git a/third_party/xla/xla/backends/autotuner/backends.proto b/third_party/xla/xla/backends/autotuner/backends.proto index 410087d63d01ef..569605b9ee8324 100644 --- a/third_party/xla/xla/backends/autotuner/backends.proto +++ b/third_party/xla/xla/backends/autotuner/backends.proto @@ -10,20 +10,18 @@ option java_outer_classname = "Backends"; // When adding a fission backend for a backend X, it should be named X_FISSION. enum Backend { UNSPECIFIED_BACKEND = 0; + reserved 3, 5, 12, 15; + reserved CUBLAS, CUBLAS_FISSION, ROCBLAS, ROCBLAS_FISSION; CUDNN = 1; TRITON = 2; - CUBLAS = 3; CUBLASLT = 4; - ROCBLAS = 5; HIPBLASLT = 6; MIOPEN = 7; CUSTOM_KERNEL = 8; BLOCK_LEVEL_EMITTER = 9; NATIVE_EMITTER = 10; LLVM_KERNEL_EMITTER = 11; - CUBLAS_FISSION = 12; CUBLASLT_FISSION = 13; CUSTOM_KERNEL_FISSION = 14; - ROCBLAS_FISSION = 15; HIPBLASLT_FISSION = 16; } diff --git a/third_party/xla/xla/backends/autotuner/profiler.h b/third_party/xla/xla/backends/autotuner/profiler.h index 922ac421ee59fd..5f327e4a8df88d 100644 --- a/third_party/xla/xla/backends/autotuner/profiler.h +++ b/third_party/xla/xla/backends/autotuner/profiler.h @@ -22,6 +22,7 @@ limitations under the License. #include "absl/log/check.h" #include "absl/status/statusor.h" #include "absl/time/time.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/service/executable.h" #include "xla/service/shaped_buffer.h" #include "xla/tsl/platform/statusor.h" @@ -58,8 +59,8 @@ class Profiler { // Profiles a single executable. virtual absl::StatusOr Profile( std::unique_ptr executable) { - TF_ASSIGN_OR_RETURN(std::unique_ptr buffers, - CreateInputBuffers(executable.get())); + ASSIGN_OR_RETURN(std::unique_ptr buffers, + CreateInputBuffers(executable.get())); return Profile(executable.get(), *buffers); } diff --git a/third_party/xla/xla/backends/cpu/BUILD b/third_party/xla/xla/backends/cpu/BUILD index 82305961ca93e4..9d879a9e94c4b1 100644 --- a/third_party/xla/xla/backends/cpu/BUILD +++ b/third_party/xla/xla/backends/cpu/BUILD @@ -86,6 +86,7 @@ onednn_graph_cc_library( "//xla/hlo/ir:hlo", "//xla/tsl/mkl:onednn", "//xla/tsl/platform:logging", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/functional:any_invocable", @@ -143,12 +144,12 @@ cc_library( "//xla/stream_executor:device_address", "//xla/tsl/platform:errors", "//xla/tsl/platform:logging", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@XNNPACK//ynnpack", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/functional:any_invocable", - "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", @@ -174,6 +175,7 @@ cc_library( "//xla/backends/cpu/runtime/ynnpack:ynn_interop", "//xla/hlo/ir:hlo", "//xla/service:pattern_matcher", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@XNNPACK//ynnpack", "@com_google_absl//absl/algorithm:container", @@ -236,6 +238,7 @@ cc_library( "//xla/hlo/ir:hlo", "//xla/service:buffer_assignment", "//xla/stream_executor:device_address", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/log", diff --git a/third_party/xla/xla/backends/cpu/codegen/dot/BUILD b/third_party/xla/xla/backends/cpu/codegen/dot/BUILD index eff4f86b0f8984..550b18da04a1cb 100644 --- a/third_party/xla/xla/backends/cpu/codegen/dot/BUILD +++ b/third_party/xla/xla/backends/cpu/codegen/dot/BUILD @@ -24,6 +24,7 @@ cc_library( "//xla/service:hlo_module_config", "//xla/service/cpu:dot_op_emitter", "//xla/service/llvm_ir:ir_array", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/log", "@com_google_absl//absl/status:statusor", diff --git a/third_party/xla/xla/backends/cpu/codegen/dot/dot_kernel_emitter.cc b/third_party/xla/xla/backends/cpu/codegen/dot/dot_kernel_emitter.cc index db8f67ec6f9bf7..84fb393b0afa15 100644 --- a/third_party/xla/xla/backends/cpu/codegen/dot/dot_kernel_emitter.cc +++ b/third_party/xla/xla/backends/cpu/codegen/dot/dot_kernel_emitter.cc @@ -21,6 +21,7 @@ limitations under the License. #include "absl/log/log.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" +#include "xla/tsl/platform/status_macros.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/LLVMContext.h" #include "xla/backends/cpu/codegen/kernel_api_ir_builder.h" @@ -81,7 +82,7 @@ DotKernelEmitter::EmitKernelDefinition() { std::unique_ptr llvm_module = KernelApiIrBuilder::CreateModule( absl::StrCat(instr_->name(), "_elemental_kernel_module"), *ctx); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( KernelApiIrBuilder::KernelPrototype kernel_prototype, kernel_api_ir_builder.EmitKernelPrototype( *llvm_module, instr_, buffer_assignment_, name(), "_kernel")); @@ -94,15 +95,14 @@ DotKernelEmitter::EmitKernelDefinition() { llvm_ir::IrArray rhs_array = kernel_prototype.arguments[1]; llvm_ir::IrArray target_array = kernel_prototype.results[0]; - TF_ASSIGN_OR_RETURN( - DotOpWorkGroupDim num_workgroups, - EmitDotOperation( - *instr_, target_array, lhs_array, rhs_array, - /*addend_array=*/nullptr, - {kernel_prototype.workgroup_id.x, kernel_prototype.workgroup_id.y}, - /*executable_run_options_value=*/nullptr, &builder, config, - *target_machine_, - /*allow_runtime_calls=*/false)); + ASSIGN_OR_RETURN(DotOpWorkGroupDim num_workgroups, + EmitDotOperation(*instr_, target_array, lhs_array, rhs_array, + /*addend_array=*/nullptr, + {kernel_prototype.workgroup_id.x, + kernel_prototype.workgroup_id.y}, + /*executable_run_options_value=*/nullptr, + &builder, config, *target_machine_, + /*allow_runtime_calls=*/false)); LlvmKernelSource source(std::move(ctx), std::move(llvm_module)); diff --git a/third_party/xla/xla/backends/cpu/codegen/elemental/BUILD b/third_party/xla/xla/backends/cpu/codegen/elemental/BUILD index 17b505749eb983..570027889136bc 100644 --- a/third_party/xla/xla/backends/cpu/codegen/elemental/BUILD +++ b/third_party/xla/xla/backends/cpu/codegen/elemental/BUILD @@ -35,6 +35,7 @@ cc_library( "//xla/service/cpu:backend_config_proto_cc", "//xla/service/cpu:ir_emitter", "//xla/service/llvm_ir:ir_array", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/log", @@ -87,6 +88,7 @@ cc_library( "//xla/service/llvm_ir:llvm_util", "//xla/service/llvm_ir:loop_emitter", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/log", diff --git a/third_party/xla/xla/backends/cpu/codegen/elemental/concatenate_kernel_emitter.cc b/third_party/xla/xla/backends/cpu/codegen/elemental/concatenate_kernel_emitter.cc index f0bef4e1657072..0ffaa9f692aabf 100644 --- a/third_party/xla/xla/backends/cpu/codegen/elemental/concatenate_kernel_emitter.cc +++ b/third_party/xla/xla/backends/cpu/codegen/elemental/concatenate_kernel_emitter.cc @@ -25,6 +25,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" +#include "xla/tsl/platform/status_macros.h" #include "llvm/IR/Analysis.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/LLVMContext.h" @@ -94,7 +95,7 @@ ConcatenateKernelEmitter::EmitKernelDefinition() { std::unique_ptr llvm_module = KernelApiIrBuilder::CreateModule( absl::StrCat(instr_->name(), "_elemental_kernel_module"), *ctx); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( KernelApiIrBuilder::KernelPrototype kernel_prototype, kernel_api_ir_builder.EmitKernelPrototype( *llvm_module, instr_, buffer_assignment_, name(), "_kernel")); @@ -104,7 +105,7 @@ ConcatenateKernelEmitter::EmitKernelDefinition() { kernel_prototype.function->getEntryBlock().getTerminator()); llvm_ir::IrArray output_array = kernel_prototype.results[0]; - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( bool is_parallel, EmitFastConcatenate(instr_, kernel_prototype.arguments, output_array, llvm_module.get(), ir_builder, diff --git a/third_party/xla/xla/backends/cpu/codegen/elemental/elemental_kernel_emitter.cc b/third_party/xla/xla/backends/cpu/codegen/elemental/elemental_kernel_emitter.cc index a311e724be11f1..8c7c9407506ad5 100644 --- a/third_party/xla/xla/backends/cpu/codegen/elemental/elemental_kernel_emitter.cc +++ b/third_party/xla/xla/backends/cpu/codegen/elemental/elemental_kernel_emitter.cc @@ -28,6 +28,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/IRBuilder.h" @@ -186,7 +187,7 @@ ElementalKernelEmitter::EmitKernelDefinition() { std::unique_ptr llvm_module = KernelApiIrBuilder::CreateModule( absl::StrCat(instr_->name(), "_elemental_kernel_module"), *ctx); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( KernelApiIrBuilder::KernelPrototype kernel_prototype, kernel_api_ir_builder.EmitKernelPrototype( *llvm_module, instr_, buffer_assignment_, name(), "_kernel")); @@ -199,7 +200,7 @@ ElementalKernelEmitter::EmitKernelDefinition() { ir_builder.setFastMathFlags( llvm_ir::GetCpuFastMathFlags(hlo_module->config())); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( CpuElementalIrEmitter::ThreadLocalCallCallback thread_local_call_fn, ThreadLocalCallbackFactory(ir_builder, *llvm_module)); @@ -223,9 +224,9 @@ ElementalKernelEmitter::EmitKernelDefinition() { llvm_ir::ElementGenerator element_generator = elemental_ir_emitter.MakeElementGenerator(instr_, operand_to_generator); - TF_ASSIGN_OR_RETURN(NumWorkGroups num_workgroups, - EmitElementalLoops(ir_builder, instr_, kernel_prototype, - element_generator)); + ASSIGN_OR_RETURN(NumWorkGroups num_workgroups, + EmitElementalLoops(ir_builder, instr_, kernel_prototype, + element_generator)); LlvmKernelSource source(std::move(ctx), std::move(llvm_module)); @@ -266,7 +267,7 @@ absl::StatusOr ElementalKernelEmitter::EmitElementalLoops( // TODO(ezhulenev): Support multiple results for parallel loops. if (multiple_results) { - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( llvm_ir::LoopEmitter(element_generator, kernel_prototype.results, &b) .EmitLoop(llvm_ir::IrName(instr))); return NumWorkGroups(); @@ -279,7 +280,7 @@ absl::StatusOr ElementalKernelEmitter::EmitElementalLoops( if (has_parallel_config) { ParallelPartitionBounds parallel_bounds = EmitParallelPartitionBounds( b, kernel_prototype, *parallel_config, instr->shape(), instr->name()); - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( ParallelLoopEmitter(element_generator, result, ¶llel_bounds, &b) .EmitLoop(llvm_ir::IrName(instr))); return NumWorkGroups{ @@ -288,8 +289,8 @@ absl::StatusOr ElementalKernelEmitter::EmitElementalLoops( } // Emit a whole loop for the instruction. - TF_RETURN_IF_ERROR(llvm_ir::LoopEmitter(element_generator, result, &b) - .EmitLoop(llvm_ir::IrName(instr))); + RETURN_IF_ERROR(llvm_ir::LoopEmitter(element_generator, result, &b) + .EmitLoop(llvm_ir::IrName(instr))); return NumWorkGroups(); } @@ -311,13 +312,13 @@ ElementalKernelEmitter::ThreadLocalCallbackFactory(llvm::IRBuilderBase& builder, /*emit_code_for_msan=*/false); IrEmitter::IRBuilderGuard builder_guard = ir_emitter->WithBuilder(builder); - TF_RETURN_IF_ERROR(ir_emitter->EmitSmallConstantGlobals()); + RETURN_IF_ERROR(ir_emitter->EmitSmallConstantGlobals()); if (instr_->has_to_apply()) { HloComputation* nested_computation = instr_->to_apply(); bool is_reducer = instr_->opcode() == HloOpcode::kReduce || instr_->opcode() == HloOpcode::kReduceWindow; - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( ir_emitter ->EmitNestedComputation(*nested_computation, llvm_ir::IrName(nested_computation->name()), diff --git a/third_party/xla/xla/backends/cpu/codegen/emitters/BUILD b/third_party/xla/xla/backends/cpu/codegen/emitters/BUILD index 9d8122e88c85bb..dd7df60dbac744 100644 --- a/third_party/xla/xla/backends/cpu/codegen/emitters/BUILD +++ b/third_party/xla/xla/backends/cpu/codegen/emitters/BUILD @@ -62,6 +62,7 @@ cc_library( "//xla/service/cpu:backend_config_proto_cc", "//xla/service/llvm_ir:llvm_util", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", diff --git a/third_party/xla/xla/backends/cpu/codegen/emitters/cpu_fusion_emitter.cc b/third_party/xla/xla/backends/cpu/codegen/emitters/cpu_fusion_emitter.cc index e20220f84dc5d4..2b52a5f0d9675b 100644 --- a/third_party/xla/xla/backends/cpu/codegen/emitters/cpu_fusion_emitter.cc +++ b/third_party/xla/xla/backends/cpu/codegen/emitters/cpu_fusion_emitter.cc @@ -29,6 +29,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/IR/Instructions.h" @@ -184,13 +185,12 @@ absl::StatusOr EmitEntryFunctionApi( absl::string_view module_name(fusion_module.getName().value()); mlir::OpBuilder builder(context); auto loc = mlir::NameLoc::get(builder.getStringAttr(module_name)); - TF_ASSIGN_OR_RETURN( - std::vector arguments, - KernelApiIrBuilder::GetKernelArgumentsParameters(&fusion, - &buffer_assignment)); - TF_ASSIGN_OR_RETURN(std::vector results, - KernelApiIrBuilder::GetKernelResultsParameters( - &fusion, &buffer_assignment)); + ASSIGN_OR_RETURN(std::vector arguments, + KernelApiIrBuilder::GetKernelArgumentsParameters( + &fusion, &buffer_assignment)); + ASSIGN_OR_RETURN(std::vector results, + KernelApiIrBuilder::GetKernelResultsParameters( + &fusion, &buffer_assignment)); // TBD: Annotate tensors with the buffer indices. This way, the buffer // propagation pass can clean them up later. @@ -215,16 +215,15 @@ absl::StatusOr EmitEntryFunctionApi( for (const auto& [index, arg] : llvm::enumerate(arguments)) { param_types.push_back(emitters::TensorShapeToMlirType(arg.shape, builder)); - TF_ASSIGN_OR_RETURN( - arg_attrs.emplace_back(), - get_arg_attrs(index - 1, arg.slice, /*is_result=*/false)); + ASSIGN_OR_RETURN(arg_attrs.emplace_back(), + get_arg_attrs(index - 1, arg.slice, /*is_result=*/false)); } auto result_types = emitters::ShapeToMlirTypes(fusion.shape(), builder); param_types.append(result_types.begin(), result_types.end()); for (const auto& [index, result] : llvm::enumerate(results)) { - TF_ASSIGN_OR_RETURN(arg_attrs.emplace_back(), - get_arg_attrs(index, result.slice, /*is_result=*/true)); + ASSIGN_OR_RETURN(arg_attrs.emplace_back(), + get_arg_attrs(index, result.slice, /*is_result=*/true)); } builder.setInsertionPointToStart(fusion_module.getBody()); @@ -273,7 +272,7 @@ absl::StatusOr EmitCallTargets( for (const auto& comp : computations.partitioned_computations()) { for (const auto& subgraph : comp.subgraphs()) { if (subgraph_to_mlir_fn.contains(&subgraph)) { - TF_RETURN_IF_ERROR(emitters::SubgraphToMlirFunction( + RETURN_IF_ERROR(emitters::SubgraphToMlirFunction( comp, subgraph, subgraph_to_mlir_fn[&subgraph], call_targets, computations.mlir_context())); } @@ -281,7 +280,7 @@ absl::StatusOr EmitCallTargets( } for (const auto& epilogue : computations.epilogues()) { if (epilogue.roots.empty()) continue; - TF_RETURN_IF_ERROR(emitters::SubgraphToMlirFunction( + RETURN_IF_ERROR(emitters::SubgraphToMlirFunction( computations.FindPartitionedComputation( fusion.fused_instructions_computation()), epilogue, subgraph_to_mlir_fn[&epilogue], call_targets, @@ -295,7 +294,7 @@ int64_t CeilDiv(int64_t a, int64_t b) { return (a + b - 1) / b; } absl::StatusOr> CreateNamedMlirModuleOp( const HloFusionInstruction& fusion, mlir::Builder& builder) { - TF_ASSIGN_OR_RETURN(std::string fusion_name, GetFusionName(fusion)); + ASSIGN_OR_RETURN(std::string fusion_name, GetFusionName(fusion)); auto loc = mlir::NameLoc::get(builder.getStringAttr(fusion_name)); return llvm_ir::CreateMlirModuleOp(loc, fusion_name); } @@ -307,9 +306,9 @@ absl::StatusOr GetFusionName(const HloFusionInstruction& fusion) { ->config() .debug_options() .xla_cpu_generate_unique_c_style_kernel_entry_points()) { - TF_ASSIGN_OR_RETURN(fusion_name, ConvertToCName(absl::StrCat( - fusion.parent()->parent()->name(), "_", - fusion.name()))); + ASSIGN_OR_RETURN(fusion_name, ConvertToCName(absl::StrCat( + fusion.parent()->parent()->name(), "_", + fusion.name()))); } return fusion_name; } diff --git a/third_party/xla/xla/backends/cpu/codegen/emitters/cpu_scatter_emitter.cc b/third_party/xla/xla/backends/cpu/codegen/emitters/cpu_scatter_emitter.cc index 97ccce81cb2336..4fe9567bb66fef 100644 --- a/third_party/xla/xla/backends/cpu/codegen/emitters/cpu_scatter_emitter.cc +++ b/third_party/xla/xla/backends/cpu/codegen/emitters/cpu_scatter_emitter.cc @@ -30,6 +30,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/IR/LLVMContext.h" @@ -255,8 +256,8 @@ IndexingMap GetScatterIndexingMap( absl::StatusOr CpuScatterFusion::EmitKernelDefinition() { mlir::OpBuilder builder(mlir_context_); - TF_ASSIGN_OR_RETURN(mlir::OwningOpRef mlir_module, - CreateNamedMlirModuleOp(*fusion_, builder)); + ASSIGN_OR_RETURN(mlir::OwningOpRef mlir_module, + CreateNamedMlirModuleOp(*fusion_, builder)); absl::string_view module_name(mlir_module->getName().value()); emitters::SetIndexDataLayout(mlir_module.get(), *fusion_); @@ -272,7 +273,7 @@ CpuScatterFusion::EmitKernelDefinition() { xla::CpuMemoryRegionNameAttr::name, builder.getStringAttr(BuildModuleMemoryRegionName(name(), fusion_))); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( mlir::func::FuncOp entry_func, EmitEntryFunctionApi(mlir_module.get(), *fusion_, std::string(module_name), buffer_assignment_)); @@ -281,11 +282,11 @@ CpuScatterFusion::EmitKernelDefinition() { GetEpilogues(*fusion_, mlir_context_); emitters::PartitionedComputations computations( fusion_->fused_instructions_computation(), mlir_context_, epilogues); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( emitters::CallTargetProvider call_targets, EmitCallTargets(mlir_module.get(), *fusion_, computations, epilogues)); - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( EmitEntryFunction(computations, call_targets, entry_func, *fusion_)); // Convert kernel arguments to fake allocations and buffer uses. @@ -293,9 +294,8 @@ CpuScatterFusion::EmitKernelDefinition() { KernelSpec::Buffers result_buffers; for (auto& indexed : ShapeUtil::GetLeafShapes(fusion_->shape())) { - TF_ASSIGN_OR_RETURN( - BufferAllocation::Slice slice, - buffer_assignment_.GetUniqueSlice(fusion_, indexed.index)); + ASSIGN_OR_RETURN(BufferAllocation::Slice slice, + buffer_assignment_.GetUniqueSlice(fusion_, indexed.index)); result_buffers.push_back({slice, indexed.shape}); } @@ -305,7 +305,7 @@ CpuScatterFusion::EmitKernelDefinition() { int64_t operand_index = 0; for (HloInstruction* operand : fusion_->operands()) { for (auto& indexed : ShapeUtil::GetLeafShapes(operand->shape())) { - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( BufferAllocation::Slice slice, buffer_assignment_.GetUniqueSlice(operand, indexed.index)); diff --git a/third_party/xla/xla/backends/cpu/codegen/tools/BUILD b/third_party/xla/xla/backends/cpu/codegen/tools/BUILD index dad63ea5e98be9..3bd2ccf75fc62c 100644 --- a/third_party/xla/xla/backends/cpu/codegen/tools/BUILD +++ b/third_party/xla/xla/backends/cpu/codegen/tools/BUILD @@ -26,6 +26,7 @@ cc_library( "//xla/service/cpu:cpu_compiler_pure", "//xla/tsl/platform:env", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "//xla/tsl/util:command_line_flags", "@com_google_absl//absl/log", @@ -90,6 +91,7 @@ xla_cc_binary( "//xla/codegen:kernel_definition", "//xla/codegen/tools:test_lib", "//xla/hlo/ir:hlo", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", diff --git a/third_party/xla/xla/backends/cpu/codegen/tools/fusion_to_mlir.cc b/third_party/xla/xla/backends/cpu/codegen/tools/fusion_to_mlir.cc index c97f72ecdef576..8902a7db505efe 100644 --- a/third_party/xla/xla/backends/cpu/codegen/tools/fusion_to_mlir.cc +++ b/third_party/xla/xla/backends/cpu/codegen/tools/fusion_to_mlir.cc @@ -18,6 +18,7 @@ limitations under the License. #include "absl/log/check.h" #include "absl/status/status.h" +#include "xla/tsl/platform/status_macros.h" #include "llvm/Support/raw_ostream.h" #include "xla/backends/cpu/codegen/fusion_compiler.h" #include "xla/backends/cpu/codegen/fusion_emitter.h" @@ -33,7 +34,7 @@ namespace xla::cpu { absl::Status Run(const std::string& filename) { auto mlir_context = FusionCompiler::CreateContext(); - TF_ASSIGN_OR_RETURN(auto module, LoadTestModule(filename)); + ASSIGN_OR_RETURN(auto module, LoadTestModule(filename)); auto* inst = module->entry_computation()->root_instruction(); while (inst && (inst->opcode() == HloOpcode::kTuple || inst->opcode() == HloOpcode::kGetTupleElement)) { @@ -41,7 +42,7 @@ absl::Status Run(const std::string& filename) { } auto fusion = DynCast(inst); fusion->SetAndSanitizeName("main"); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( KernelDefinition kernel_definition, EmitFusionKernel(*mlir_context, *fusion, nullptr, false, false)); llvm::outs() << kernel_definition.source().ToString(); diff --git a/third_party/xla/xla/backends/cpu/codegen/tools/ir_compiler_opt_main.cc b/third_party/xla/xla/backends/cpu/codegen/tools/ir_compiler_opt_main.cc index 32d24489c95343..bbd366428d22d4 100644 --- a/third_party/xla/xla/backends/cpu/codegen/tools/ir_compiler_opt_main.cc +++ b/third_party/xla/xla/backends/cpu/codegen/tools/ir_compiler_opt_main.cc @@ -28,6 +28,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" +#include "xla/tsl/platform/status_macros.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Module.h" #include "llvm/IRReader/IRReader.h" @@ -93,7 +94,7 @@ absl::StatusOr GetInputContents(const IrCompilerOptConfig& opts, } std::string data; - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( tsl::ReadFileToString(tsl::Env::Default(), input_path, &data)); return data; } @@ -125,14 +126,13 @@ absl::StatusOr RunIrCompilerPasses(const IrCompilerOptConfig& opts, auto ir_compiler = IrCompiler::Create(target_options, ir_compiler_options, IrCompiler::CompilationHooks()); - TF_ASSIGN_OR_RETURN(std::string ir_content, - GetInputContents(opts, argc, argv)); + ASSIGN_OR_RETURN(std::string ir_content, GetInputContents(opts, argc, argv)); llvm::LLVMContext context; - TF_ASSIGN_OR_RETURN(std::unique_ptr module, - ParseLlvmIr(ir_content, context)); + ASSIGN_OR_RETURN(std::unique_ptr module, + ParseLlvmIr(ir_content, context)); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( std::unique_ptr target_machine, ir_compiler->InferTargetMachine( target_options, static_cast(opts.opt_level), @@ -149,13 +149,12 @@ absl::StatusOr RunIrCompilerPasses(const IrCompilerOptConfig& opts, absl::Status RunIrCompilerOptMain(int argc, char** argv, const IrCompilerOptConfig& opts) { - TF_ASSIGN_OR_RETURN(std::string output, - RunIrCompilerPasses(opts, argc, argv)); + ASSIGN_OR_RETURN(std::string output, RunIrCompilerPasses(opts, argc, argv)); if (opts.output_file == "-") { std::cout << output << std::endl; } else { - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( tsl::WriteStringToFile(tsl::Env::Default(), opts.output_file, output)); } return absl::OkStatus(); diff --git a/third_party/xla/xla/backends/cpu/constant_allocation.cc b/third_party/xla/xla/backends/cpu/constant_allocation.cc index 5eb642f3640839..b412eb4e05009a 100644 --- a/third_party/xla/xla/backends/cpu/constant_allocation.cc +++ b/third_party/xla/xla/backends/cpu/constant_allocation.cc @@ -26,6 +26,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/literal.h" @@ -127,9 +128,9 @@ absl::StatusOr> CreateConstantAllocations( VLOG(3) << "Create constant allocation for index " << allocation.index() << " from constant literal " << const_instr->name() << "; shape=" << const_instr->literal().shape(); - TF_ASSIGN_OR_RETURN(constants.emplace_back(), - LiteralToConstantAllocation(allocation.index(), - const_instr->literal())); + ASSIGN_OR_RETURN(constants.emplace_back(), + LiteralToConstantAllocation(allocation.index(), + const_instr->literal())); } return constants; diff --git a/third_party/xla/xla/backends/cpu/lite_aot/BUILD b/third_party/xla/xla/backends/cpu/lite_aot/BUILD index ab25c341a93bb4..8fc84ef74f1a59 100644 --- a/third_party/xla/xla/backends/cpu/lite_aot/BUILD +++ b/third_party/xla/xla/backends/cpu/lite_aot/BUILD @@ -18,6 +18,7 @@ cc_library( "//xla/service:executable", "//xla/service/cpu:executable_proto_cc", "//xla/tsl/concurrency:async_value", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", diff --git a/third_party/xla/xla/backends/cpu/lite_aot/xla_aot_function.cc b/third_party/xla/xla/backends/cpu/lite_aot/xla_aot_function.cc index e035e8b28e40e3..dafcaf2a36c58a 100644 --- a/third_party/xla/xla/backends/cpu/lite_aot/xla_aot_function.cc +++ b/third_party/xla/xla/backends/cpu/lite_aot/xla_aot_function.cc @@ -32,6 +32,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "absl/strings/string_view.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/backends/cpu/nanort/nanort_executable.h" #include "xla/literal.h" #include "xla/service/cpu/executable.pb.h" @@ -64,34 +65,33 @@ absl::StatusOr GetProgramShape( absl::StatusOr CreateExecutableAndSupportingLiterals( const CompilationResultProto& compilation_result) { - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( ProgramShape program_shape, ProgramShape::FromProto( compilation_result.hlo_module().hlo_module().host_program_shape())); - TF_ASSIGN_OR_RETURN( - std::unique_ptr nanort_executable, - NanoRtExecutable::Create(compilation_result, program_shape)); + ASSIGN_OR_RETURN(std::unique_ptr nanort_executable, + NanoRtExecutable::Create(compilation_result, program_shape)); std::vector results_literals; - TF_ASSIGN_OR_RETURN(auto nanort_program_shape, - GetProgramShape(*nanort_executable)); + ASSIGN_OR_RETURN(auto nanort_program_shape, + GetProgramShape(*nanort_executable)); if (nanort_program_shape.result().IsTuple()) { auto tuple_shapes = nanort_program_shape.result().tuple_shapes(); results_literals.reserve(tuple_shapes.size()); for (const Shape& shape : tuple_shapes) { - TF_ASSIGN_OR_RETURN(results_literals.emplace_back(), - Literal::Make(shape, /*allocate_arrays=*/true)); + ASSIGN_OR_RETURN(results_literals.emplace_back(), + Literal::Make(shape, /*allocate_arrays=*/true)); } } else { - TF_ASSIGN_OR_RETURN(results_literals.emplace_back(), - Literal::Make(nanort_program_shape.result(), - /*allocate_arrays=*/true)); + ASSIGN_OR_RETURN(results_literals.emplace_back(), + Literal::Make(nanort_program_shape.result(), + /*allocate_arrays=*/true)); } - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( Literal temp_literal, Literal::Make( ShapeUtil::MakeShape(U8, {static_cast( @@ -122,11 +122,10 @@ absl::StatusOr> XlaAotFunction::Create( "Result names must be unique. Got ", absl::StrJoin(result_names, ","))); } - TF_ASSIGN_OR_RETURN( - auto executable_and_supporting_literals, - CreateExecutableAndSupportingLiterals(compilation_result)); + ASSIGN_OR_RETURN(auto executable_and_supporting_literals, + CreateExecutableAndSupportingLiterals(compilation_result)); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( auto program_shape, GetProgramShape(*executable_and_supporting_literals.nanort_executable)); @@ -158,9 +157,8 @@ absl::StatusOr> XlaAotFunction::Create( absl::StatusOr> XlaAotFunction::Create( const CompilationResultProto& compilation_result) { - TF_ASSIGN_OR_RETURN( - auto executable_and_supporting_literals, - CreateExecutableAndSupportingLiterals(compilation_result)); + ASSIGN_OR_RETURN(auto executable_and_supporting_literals, + CreateExecutableAndSupportingLiterals(compilation_result)); auto& nanort_executable = executable_and_supporting_literals.nanort_executable; @@ -182,7 +180,7 @@ absl::StatusOr> XlaAotFunction::Create( arg_names.push_back(std::string(instr->name())); } std::vector result_names; - TF_ASSIGN_OR_RETURN(auto program_shape, GetProgramShape(*nanort_executable)); + ASSIGN_OR_RETURN(auto program_shape, GetProgramShape(*nanort_executable)); if (program_shape.result().IsTuple()) { auto tuple_shapes = program_shape.result().tuple_shapes(); absl::string_view root_name = diff --git a/third_party/xla/xla/backends/cpu/onednn_emitter.cc b/third_party/xla/xla/backends/cpu/onednn_emitter.cc index 0e51804ff92743..1b918bff9f50ba 100644 --- a/third_party/xla/xla/backends/cpu/onednn_emitter.cc +++ b/third_party/xla/xla/backends/cpu/onednn_emitter.cc @@ -25,6 +25,7 @@ limitations under the License. #include "absl/log/log.h" #include "absl/status/statusor.h" #include "absl/strings/str_format.h" +#include "xla/tsl/platform/status_macros.h" #include "oneapi/dnnl/dnnl_common.hpp" #include "oneapi/dnnl/dnnl_graph.hpp" // NOLINT #include "xla/backends/cpu/onednn_fusion.h" @@ -126,7 +127,7 @@ static absl::StatusOr FindLogicalTensor( static absl::StatusOr CreateLogicalTensor( size_t tensor_id, const Shape& shape) { - TF_ASSIGN_OR_RETURN(auto type, OneDnnDatatype(shape.element_type())); + ASSIGN_OR_RETURN(auto type, OneDnnDatatype(shape.element_type())); dnnl::graph::logical_tensor::dims dims = OneDnnDimensions(shape); dnnl::graph::logical_tensor::dims strides = OneDnnStrides(shape); @@ -148,14 +149,13 @@ static absl::StatusOr DefineUnaryOp( VLOG(3) << absl::StreamFormat("Define logical tensor value for unary op: %s", instr->ToString()); - TF_ASSIGN_OR_RETURN(auto unary_op, OneDnnUnaryOperator(instr->opcode())); + ASSIGN_OR_RETURN(auto unary_op, OneDnnUnaryOperator(instr->opcode())); - TF_ASSIGN_OR_RETURN(auto input, - FindLogicalTensor(logical_tensors, instr->operand(0))); + ASSIGN_OR_RETURN(auto input, + FindLogicalTensor(logical_tensors, instr->operand(0))); size_t output_id = logical_tensors.size(); - TF_ASSIGN_OR_RETURN(auto output, - CreateLogicalTensor(output_id, instr->shape())); + ASSIGN_OR_RETURN(auto output, CreateLogicalTensor(output_id, instr->shape())); VLOG(3) << absl::StreamFormat(" tensors: input=%d, output=%d", input.get_id(), output.get_id()); @@ -172,16 +172,15 @@ static absl::StatusOr DefineBinaryOp( VLOG(3) << absl::StreamFormat("Define logical tensor value for binary op: %s", instr->ToString()); - TF_ASSIGN_OR_RETURN(auto binary_op, OneDnnBinaryOperator(instr->opcode())); + ASSIGN_OR_RETURN(auto binary_op, OneDnnBinaryOperator(instr->opcode())); - TF_ASSIGN_OR_RETURN(auto lhs, - FindLogicalTensor(logical_tensors, instr->operand(0))); - TF_ASSIGN_OR_RETURN(auto rhs, - FindLogicalTensor(logical_tensors, instr->operand(1))); + ASSIGN_OR_RETURN(auto lhs, + FindLogicalTensor(logical_tensors, instr->operand(0))); + ASSIGN_OR_RETURN(auto rhs, + FindLogicalTensor(logical_tensors, instr->operand(1))); size_t output_id = logical_tensors.size(); - TF_ASSIGN_OR_RETURN(auto output, - CreateLogicalTensor(output_id, instr->shape())); + ASSIGN_OR_RETURN(auto output, CreateLogicalTensor(output_id, instr->shape())); VLOG(3) << absl::StreamFormat(" tensors: lhs=%d, rhs=%d, output=%d", lhs.get_id(), rhs.get_id(), output.get_id()); @@ -199,7 +198,7 @@ static absl::StatusOr DefineMatMul( const DotDimensionNumbers& dnums = instr->dot_dimension_numbers(); const Shape& lhs_shape = instr->operand(0)->shape(); const Shape& rhs_shape = instr->operand(1)->shape(); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( bool is_supported, IsOneDnnDotSupported(dnums, lhs_shape, rhs_shape, instr->shape())); @@ -211,25 +210,24 @@ static absl::StatusOr DefineMatMul( VLOG(3) << absl::StreamFormat("Define logical tensor value for MatMul: %s", instr->ToString()); - TF_ASSIGN_OR_RETURN(auto matmul_op, OneDnnBinaryOperator(instr->opcode())); - TF_ASSIGN_OR_RETURN(auto lhs, - FindLogicalTensor(logical_tensors, instr->operand(0))); - TF_ASSIGN_OR_RETURN(auto rhs, - FindLogicalTensor(logical_tensors, instr->operand(1))); + ASSIGN_OR_RETURN(auto matmul_op, OneDnnBinaryOperator(instr->opcode())); + ASSIGN_OR_RETURN(auto lhs, + FindLogicalTensor(logical_tensors, instr->operand(0))); + ASSIGN_OR_RETURN(auto rhs, + FindLogicalTensor(logical_tensors, instr->operand(1))); size_t output_id = logical_tensors.size(); - TF_ASSIGN_OR_RETURN(auto output, - CreateLogicalTensor(output_id, instr->shape())); + ASSIGN_OR_RETURN(auto output, CreateLogicalTensor(output_id, instr->shape())); VLOG(3) << absl::StreamFormat(" tensors: lhs=%d, rhs=%d, output=%d", lhs.get_id(), rhs.get_id(), output.get_id()); dnnl::graph::op op(op_id, matmul_op, {lhs, rhs}, {output}); - TF_ASSIGN_OR_RETURN(DotShape dot_shape, - GetDotShape(dnums, lhs_shape, rhs_shape, instr->shape())); - TF_ASSIGN_OR_RETURN(DotCanonicalDims dot_canonical_dims, - GetDotCanonicalDims(dnums, dot_shape)); + ASSIGN_OR_RETURN(DotShape dot_shape, + GetDotShape(dnums, lhs_shape, rhs_shape, instr->shape())); + ASSIGN_OR_RETURN(DotCanonicalDims dot_canonical_dims, + GetDotCanonicalDims(dnums, dot_shape)); if (!dot_canonical_dims.lhs_canonical) { op.set_attr(dnnl::graph::op::attr::transpose_a, true); @@ -263,29 +261,27 @@ static absl::StatusOr EmitOneDnnFusion( for (const HloInstruction* instr : instructions) { switch (instr->opcode()) { case HloOpcode::kParameter: { - TF_ASSIGN_OR_RETURN(logical_tensors[instr], - DefineParameter(logical_tensors, instr)); + ASSIGN_OR_RETURN(logical_tensors[instr], + DefineParameter(logical_tensors, instr)); } break; // Unary elementwise ops. case HloOpcode::kExp: { - TF_ASSIGN_OR_RETURN( - logical_tensors[instr], - DefineUnaryOp(graph, op_id++, logical_tensors, instr)); + ASSIGN_OR_RETURN(logical_tensors[instr], + DefineUnaryOp(graph, op_id++, logical_tensors, instr)); } break; // Binary elementwise ops. case HloOpcode::kAdd: case HloOpcode::kMultiply: { - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( logical_tensors[instr], DefineBinaryOp(graph, op_id++, logical_tensors, instr)); } break; case HloOpcode::kDot: { - TF_ASSIGN_OR_RETURN( - logical_tensors[instr], - DefineMatMul(graph, op_id++, logical_tensors, instr)); + ASSIGN_OR_RETURN(logical_tensors[instr], + DefineMatMul(graph, op_id++, logical_tensors, instr)); } break; default: { diff --git a/third_party/xla/xla/backends/cpu/runtime/onednn/BUILD b/third_party/xla/xla/backends/cpu/runtime/onednn/BUILD index c53d029f7b5d55..2c061e0aabca8c 100644 --- a/third_party/xla/xla/backends/cpu/runtime/onednn/BUILD +++ b/third_party/xla/xla/backends/cpu/runtime/onednn/BUILD @@ -60,6 +60,7 @@ onednn_graph_cc_library( "//xla/tsl/mkl:onednn", "//xla/tsl/platform:errors", "//xla/tsl/platform:logging", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_set", @@ -137,6 +138,7 @@ cc_library( "//xla/tsl/concurrency:async_value", "//xla/tsl/mkl:onednn", "//xla/tsl/platform:logging", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/base", "@com_google_absl//absl/base:dynamic_annotations", diff --git a/third_party/xla/xla/backends/cpu/runtime/onednn/onednn_fusion_thunk.cc b/third_party/xla/xla/backends/cpu/runtime/onednn/onednn_fusion_thunk.cc index b42d9c2a14b52f..313ba9dba99839 100644 --- a/third_party/xla/xla/backends/cpu/runtime/onednn/onednn_fusion_thunk.cc +++ b/third_party/xla/xla/backends/cpu/runtime/onednn/onednn_fusion_thunk.cc @@ -25,6 +25,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/str_format.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "oneapi/dnnl/dnnl_common.hpp" #include "oneapi/dnnl/dnnl_graph.hpp" #include "oneapi/dnnl/dnnl_threadpool.hpp" @@ -116,7 +117,7 @@ OneDnnFusionThunk::CreateOneDnnRuntime( info().op_name, onednn_runtime_pool_.num_created()); // Construct oneDNN fusion using user-provided builder function. - TF_ASSIGN_OR_RETURN(OneDnnFusion fusion, builder()); + ASSIGN_OR_RETURN(OneDnnFusion fusion, builder()); OneDnnRuntime runtime(std::move(fusion), thread_pool); @@ -177,7 +178,7 @@ tsl::AsyncValueRef OneDnnFusionThunk::Execute( for (size_t i = 0; i < arguments_.size(); ++i) { Argument& argument = arguments_[i]; - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( arguments_buffers[i], params.buffer_allocations->GetDeviceAddress(argument.slice)); @@ -193,7 +194,7 @@ tsl::AsyncValueRef OneDnnFusionThunk::Execute( for (size_t i = 0; i < results_.size(); ++i) { Result& result = results_[i]; - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( results_buffers[i], params.buffer_allocations->GetDeviceAddress(results_[i].slice)); @@ -207,8 +208,7 @@ tsl::AsyncValueRef OneDnnFusionThunk::Execute( params.intra_op_threadpool->getPool(); // Borrow oneDNN runtime from the pool. - TF_ASSIGN_OR_RETURN(auto runtime, - onednn_runtime_pool_.GetOrCreate(thread_pool)); + ASSIGN_OR_RETURN(auto runtime, onednn_runtime_pool_.GetOrCreate(thread_pool)); auto executed = runtime->Invoke(thread_pool, absl::MakeSpan(arguments_buffers), absl::MakeSpan(results_buffers)); diff --git a/third_party/xla/xla/backends/cpu/runtime/onednn/onednn_op_thunk.cc b/third_party/xla/xla/backends/cpu/runtime/onednn/onednn_op_thunk.cc index 5fa4ce29e4b2a3..d77bb2399112df 100644 --- a/third_party/xla/xla/backends/cpu/runtime/onednn/onednn_op_thunk.cc +++ b/third_party/xla/xla/backends/cpu/runtime/onednn/onednn_op_thunk.cc @@ -31,6 +31,7 @@ limitations under the License. #include "absl/strings/str_format.h" #include "absl/types/span.h" #include "Eigen/ThreadPool" +#include "xla/tsl/platform/status_macros.h" #include "oneapi/dnnl/dnnl_common.hpp" #include "oneapi/dnnl/dnnl_threadpool.hpp" #include "xla/backends/cpu/runtime/onednn/onednn_threadpool.h" @@ -179,9 +180,9 @@ tsl::AsyncValueRef OneDnnOpThunk::Execute( base_resources.arg_memrefs.reserve(num_operands); for (size_t i = 0; i < num_operands; ++i) { const auto& shape = op_buffers_.arguments_shapes[i]; - TF_ASSIGN_OR_RETURN(se::DeviceAddressBase arg, - params.buffer_allocations->GetDeviceAddress( - op_buffers_.arguments_buffers[i])); + ASSIGN_OR_RETURN(se::DeviceAddressBase arg, + params.buffer_allocations->GetDeviceAddress( + op_buffers_.arguments_buffers[i])); ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(arg.opaque(), arg.size()); VLOG(3) << absl::StreamFormat( @@ -197,9 +198,9 @@ tsl::AsyncValueRef OneDnnOpThunk::Execute( base_resources.result_memrefs.reserve(num_results); for (size_t i = 0; i < num_results; ++i) { const auto& shape = op_buffers_.results_shapes[i]; - TF_ASSIGN_OR_RETURN(se::DeviceAddressBase res, - params.buffer_allocations->GetDeviceAddress( - op_buffers_.results_buffers[i])); + ASSIGN_OR_RETURN(se::DeviceAddressBase res, + params.buffer_allocations->GetDeviceAddress( + op_buffers_.results_buffers[i])); ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(res.opaque(), res.size()); VLOG(3) << absl::StreamFormat(" res: %s (%p)", diff --git a/third_party/xla/xla/backends/cpu/runtime/ynnpack/BUILD b/third_party/xla/xla/backends/cpu/runtime/ynnpack/BUILD index 30033295d13cb6..2acefc0dfe2bf7 100644 --- a/third_party/xla/xla/backends/cpu/runtime/ynnpack/BUILD +++ b/third_party/xla/xla/backends/cpu/runtime/ynnpack/BUILD @@ -92,6 +92,7 @@ cc_library( "//xla/tsl/concurrency:async_value", "//xla/tsl/platform:errors", "//xla/tsl/platform:logging", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@XNNPACK//ynnpack:ynnpack_h", "@com_google_absl//absl/algorithm:container", @@ -127,6 +128,7 @@ xla_cc_test( "//xla/service:buffer_assignment", "//xla/tsl/concurrency:async_value", "//xla/tsl/platform:env", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "//xla/tsl/platform:test", "@XNNPACK//ynnpack:ynnpack_h", diff --git a/third_party/xla/xla/backends/cpu/runtime/ynnpack/ynn_fusion_thunk.cc b/third_party/xla/xla/backends/cpu/runtime/ynnpack/ynn_fusion_thunk.cc index d81fde82661aba..98f42c2441b2f0 100644 --- a/third_party/xla/xla/backends/cpu/runtime/ynnpack/ynn_fusion_thunk.cc +++ b/third_party/xla/xla/backends/cpu/runtime/ynnpack/ynn_fusion_thunk.cc @@ -34,6 +34,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/str_format.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/backends/cpu/runtime/thunk.h" #include "xla/backends/cpu/runtime/ynnpack/ynn_interop.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -146,14 +147,14 @@ YnnFusionThunk::CreateYnnExecutable( executable.captured_arguments = CaptureArguments(arguments_buffers); if (builder_) { - TF_ASSIGN_OR_RETURN(executable.subgraph, builder_(arguments_, results_)); + ASSIGN_OR_RETURN(executable.subgraph, builder_(arguments_, results_)); } else { - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( executable.subgraph, capturing_builder_(arguments_, results_, arguments_buffers)); } - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( executable.runtime, CreateYnnRuntime([&](ynn_runtime_t* runtime) { uint32_t ynn_flags = 0; return ynn_create_runtime( @@ -186,16 +187,15 @@ absl::Status YnnFusionThunk::UpdateYnnExecutable( VLOG(3) << absl::StreamFormat("Update YNN executable for `%s` operation", info().op_name); - TF_RETURN_IF_ERROR(executable.Reset()); + RETURN_IF_ERROR(executable.Reset()); // Keep track of the updated arguments captured by value. executable.captured_arguments = std::move(capture_arguments); - TF_ASSIGN_OR_RETURN( - executable.subgraph, - capturing_builder_(arguments_, results_, arguments_buffers)); + ASSIGN_OR_RETURN(executable.subgraph, + capturing_builder_(arguments_, results_, arguments_buffers)); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( executable.runtime, CreateYnnRuntime([&](ynn_runtime_t* runtime) { uint32_t ynn_flags = 0; return ynn_create_runtime( @@ -305,7 +305,7 @@ tsl::AsyncValueRef YnnFusionThunk::Execute( for (size_t i = 0; i < arguments_.size(); ++i) { Argument& argument = arguments_[i]; - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( arguments_buffers[i], params.buffer_allocations->GetDeviceAddress(argument.slice)); @@ -321,7 +321,7 @@ tsl::AsyncValueRef YnnFusionThunk::Execute( for (size_t i = 0; i < results_.size(); ++i) { Result& result = results_[i]; - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( results_buffers[i], params.buffer_allocations->GetDeviceAddress(results_[i].slice)); @@ -346,9 +346,9 @@ tsl::AsyncValueRef YnnFusionThunk::Execute( }; // Borrow YnnExecutable from the pool. - TF_ASSIGN_OR_RETURN(auto executable, - ynn_executable_pool_.GetOrCreate(GetYnnThreadpool(params), - arguments_buffers)); + ASSIGN_OR_RETURN(auto executable, + ynn_executable_pool_.GetOrCreate(GetYnnThreadpool(params), + arguments_buffers)); int concurrency = concurrency_.load(std::memory_order_acquire); if (concurrency == 0) { @@ -374,8 +374,8 @@ tsl::AsyncValueRef YnnFusionThunk::Execute( } // Otherwise reset YnnExecutable to capture new arguments buffers. - TF_RETURN_IF_ERROR(UpdateYnnExecutable(GetYnnThreadpool(params), *executable, - arguments_buffers)); + RETURN_IF_ERROR(UpdateYnnExecutable(GetYnnThreadpool(params), *executable, + arguments_buffers)); return invoke(std::move(executable)); } diff --git a/third_party/xla/xla/backends/cpu/runtime/ynnpack/ynn_fusion_thunk_test.cc b/third_party/xla/xla/backends/cpu/runtime/ynnpack/ynn_fusion_thunk_test.cc index 9ec7cd1e046bc5..0eeccd281980cb 100644 --- a/third_party/xla/xla/backends/cpu/runtime/ynnpack/ynn_fusion_thunk_test.cc +++ b/third_party/xla/xla/backends/cpu/runtime/ynnpack/ynn_fusion_thunk_test.cc @@ -25,6 +25,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/backends/cpu/runtime/buffer_allocations.h" #include "xla/backends/cpu/runtime/thunk.h" #include "xla/backends/cpu/runtime/thunk_testlib.h" @@ -51,12 +52,12 @@ namespace { static absl::StatusOr BuildBinaryAddSubgraph( absl::Span arguments, absl::Span results) { - TF_ASSIGN_OR_RETURN(YnnSubgraph subgraph, - CreateYnnSubgraph([&](ynn_subgraph_t* subgraph) { - return ynn_create_subgraph( - /*external_value_ids=*/3, - /*flags=*/0, subgraph); - })); + ASSIGN_OR_RETURN(YnnSubgraph subgraph, + CreateYnnSubgraph([&](ynn_subgraph_t* subgraph) { + return ynn_create_subgraph( + /*external_value_ids=*/3, + /*flags=*/0, subgraph); + })); auto dims = [](absl::Span dims) -> std::vector { return {dims.begin(), dims.end()}; @@ -91,12 +92,12 @@ static absl::StatusOr BuildBinaryAddSubgraph( static absl::StatusOr BuildIotaSubgraph( absl::Span arguments, absl::Span results) { - TF_ASSIGN_OR_RETURN(YnnSubgraph subgraph, - CreateYnnSubgraph([&](ynn_subgraph_t* subgraph) { - return ynn_create_subgraph( - /*external_value_ids=*/1, - /*flags=*/0, subgraph); - })); + ASSIGN_OR_RETURN(YnnSubgraph subgraph, + CreateYnnSubgraph([&](ynn_subgraph_t* subgraph) { + return ynn_create_subgraph( + /*external_value_ids=*/1, + /*flags=*/0, subgraph); + })); uint32_t out_id = 0; auto out_shape = results[0].shape; diff --git a/third_party/xla/xla/backends/cpu/testlib/BUILD b/third_party/xla/xla/backends/cpu/testlib/BUILD index 6a188e40c2f47c..5dfeb1ed8011a6 100644 --- a/third_party/xla/xla/backends/cpu/testlib/BUILD +++ b/third_party/xla/xla/backends/cpu/testlib/BUILD @@ -46,6 +46,7 @@ cc_library( "//xla/service/cpu:cpu_options", "//xla/service/llvm_ir:llvm_util", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", @@ -126,6 +127,7 @@ cc_library( "//xla/runtime:work_group", "//xla/service:buffer_assignment", "//xla/service:shaped_slice", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", diff --git a/third_party/xla/xla/backends/cpu/testlib/kernel_runner.cc b/third_party/xla/xla/backends/cpu/testlib/kernel_runner.cc index e1597dc72dcde0..518588eedd4262 100644 --- a/third_party/xla/xla/backends/cpu/testlib/kernel_runner.cc +++ b/third_party/xla/xla/backends/cpu/testlib/kernel_runner.cc @@ -23,6 +23,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "llvm/IR/DataLayout.h" #include "llvm/IR/Module.h" #include "llvm/Target/TargetOptions.h" @@ -59,15 +60,15 @@ absl::StatusOr KernelRunner::Create( SetModuleMemoryRegionName(*thread_safe_module.getModuleUnlocked(), "kernel_runner_test"); - TF_RETURN_IF_ERROR(compiler.AddModule(std::move(thread_safe_module))); + RETURN_IF_ERROR(compiler.AddModule(std::move(thread_safe_module))); absl::string_view kernel_name = spec.name(); - TF_ASSIGN_OR_RETURN(std::unique_ptr library, - std::move(compiler).Compile( - {FunctionLibrary::Sym(kernel_name)})); + ASSIGN_OR_RETURN(std::unique_ptr library, + std::move(compiler).Compile( + {FunctionLibrary::Sym(kernel_name)})); - TF_ASSIGN_OR_RETURN(XLA_CPU_Kernel * kernel_fn, - library->ResolveFunction(kernel_name)); + ASSIGN_OR_RETURN(XLA_CPU_Kernel * kernel_fn, + library->ResolveFunction(kernel_name)); return KernelRunner(std::move(library), Kernel(1, kernel_fn), spec.num_workgroups()); @@ -77,7 +78,7 @@ absl::StatusOr KernelRunner::Create( KernelDefinition kernel, JitCompiler compiler) { auto spec = kernel.spec(); auto source = std::move(kernel).TakeSource(); - TF_ASSIGN_OR_RETURN(LlvmKernelSource llvm_kernel_source, LowerToLlvm(source)); + ASSIGN_OR_RETURN(LlvmKernelSource llvm_kernel_source, LowerToLlvm(source)); return Create(KernelDefinition(spec, std::move(llvm_kernel_source)), std::move(compiler)); @@ -148,7 +149,7 @@ absl::StatusOr LowerToLlvm( options.fast_min_max = true; FusionCompiler fusion_compiler(mlir_kernel_source.module().getContext(), options); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( std::unique_ptr llvm_module, fusion_compiler.Compile(*llvm_context, mlir_kernel_source.module())); diff --git a/third_party/xla/xla/backends/cpu/testlib/mlir_kernel_emitter.cc b/third_party/xla/xla/backends/cpu/testlib/mlir_kernel_emitter.cc index 62a2b32ab3d372..b783543dfffaaf 100644 --- a/third_party/xla/xla/backends/cpu/testlib/mlir_kernel_emitter.cc +++ b/third_party/xla/xla/backends/cpu/testlib/mlir_kernel_emitter.cc @@ -22,6 +22,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "llvm/ADT/STLExtras.h" #include "mlir/IR/MLIRContext.h" #include "xla/backends/cpu/codegen/fusion_compiler.h" @@ -55,9 +56,8 @@ absl::StatusOr MlirTestKernelEmitter::EmitKernelDefinition() { std::unique_ptr context = FusionCompiler::CreateContext(); - TF_ASSIGN_OR_RETURN( - MlirKernelSource source, - MlirKernelSource::ParseFromString(mlir_, std::move(context))); + ASSIGN_OR_RETURN(MlirKernelSource source, MlirKernelSource::ParseFromString( + mlir_, std::move(context))); // Convert kernel arguments to fake allocations and buffer uses. KernelSpec::Buffers argument_buffers; diff --git a/third_party/xla/xla/backends/cpu/transforms/BUILD b/third_party/xla/xla/backends/cpu/transforms/BUILD index d3927ea35fdcdc..2f7a30e3681443 100644 --- a/third_party/xla/xla/backends/cpu/transforms/BUILD +++ b/third_party/xla/xla/backends/cpu/transforms/BUILD @@ -35,6 +35,7 @@ cc_library( "//xla/hlo/pass:hlo_pass", "//xla/service/cpu:backend_config_proto_cc", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_set", diff --git a/third_party/xla/xla/backends/cpu/transforms/library_rewriter.cc b/third_party/xla/xla/backends/cpu/transforms/library_rewriter.cc index 13821bf50ea13f..a2ee17ae1cdc5f 100644 --- a/third_party/xla/xla/backends/cpu/transforms/library_rewriter.cc +++ b/third_party/xla/xla/backends/cpu/transforms/library_rewriter.cc @@ -26,6 +26,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/backends/cpu/transforms/library_matcher.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -76,13 +77,13 @@ absl::StatusOr CreateLibraryFusion( BackendConfig backend_config; FusionBackendConfig* fusion_config = backend_config.mutable_fusion_config(); fusion_config->set_kind(fusion_kind); - TF_RETURN_IF_ERROR(fusion->set_backend_config(backend_config)); + RETURN_IF_ERROR(fusion->set_backend_config(backend_config)); // Replace the instruction. - TF_ASSIGN_OR_RETURN(bool changed, - computation->ReplaceInstructionWithDifferentShape( - instr, fusion, /*preserve_sharding=*/false, - /*relay_control_dependency=*/true)); + ASSIGN_OR_RETURN(bool changed, + computation->ReplaceInstructionWithDifferentShape( + instr, fusion, /*preserve_sharding=*/false, + /*relay_control_dependency=*/true)); if (!changed) { return absl::InternalError("Failed to replace instruction with fusion"); } @@ -138,10 +139,10 @@ absl::StatusOr FuseConsumerInstruction( *fusion->mutable_shape() = new_root->shape(); } - TF_ASSIGN_OR_RETURN(bool changed, - fusion->parent()->ReplaceInstructionWithDifferentShape( - to_fuse, fusion, /*preserve_sharding=*/false, - /*relay_control_dependency=*/true)); + ASSIGN_OR_RETURN(bool changed, + fusion->parent()->ReplaceInstructionWithDifferentShape( + to_fuse, fusion, /*preserve_sharding=*/false, + /*relay_control_dependency=*/true)); if (!changed) { return absl::InternalError("Failed to fuse consumer instruction"); } @@ -161,7 +162,7 @@ inline absl::Status InsertConvertIfNecessary( HloComputation* computation = instr->parent(); HloInstruction* convert = computation->AddInstruction( HloInstruction::CreateConvert(instr->shape(), instr)); - TF_RETURN_IF_ERROR(instr->ReplaceAllUsesWith(convert)); + RETURN_IF_ERROR(instr->ReplaceAllUsesWith(convert)); instr->mutable_shape()->set_element_type(new_instr_out_dtype); if (instr == computation->root_instruction()) { @@ -179,7 +180,7 @@ inline bool IsElementwiseAndNotConstant(const HloInstruction* instr) { absl::StatusOr LibraryRewriter::ChooseLibrary( HloInstruction* instr) { for (std::unique_ptr& lib : libs_) { - TF_ASSIGN_OR_RETURN(bool op_supported, lib->IsOpSupported(instr)); + ASSIGN_OR_RETURN(bool op_supported, lib->IsOpSupported(instr)); if (op_supported && lib->ShouldCreateFusion(instr)) { return lib.get(); } @@ -220,12 +221,12 @@ absl::StatusOr LibraryRewriter::MergeFusionInstructions( << ": Fusing with: " << neighbor->ToString(); if (dir == FusionDirection::kUp) { main->MergeFusionInstruction(neighbor); - TF_RETURN_IF_ERROR(main->parent()->RemoveInstruction(neighbor)); + RETURN_IF_ERROR(main->parent()->RemoveInstruction(neighbor)); return main; } if (dir == FusionDirection::kDown) { neighbor->MergeFusionInstruction(main); - TF_RETURN_IF_ERROR(neighbor->parent()->RemoveInstruction(main)); + RETURN_IF_ERROR(neighbor->parent()->RemoveInstruction(main)); return neighbor; } return InvalidArgument("Invalid fusion direction: %s", @@ -241,10 +242,10 @@ absl::StatusOr LibraryRewriter::GrowFusion( if (dir == FusionDirection::kUp) { new_instr = fusion->FuseInstruction(to_fuse); if (to_fuse->user_count() == 0) { - TF_RETURN_IF_ERROR(to_fuse->parent()->RemoveInstruction(to_fuse)); + RETURN_IF_ERROR(to_fuse->parent()->RemoveInstruction(to_fuse)); } } else if (dir == FusionDirection::kDown) { - TF_ASSIGN_OR_RETURN(new_instr, FuseConsumerInstruction(fusion, to_fuse)); + ASSIGN_OR_RETURN(new_instr, FuseConsumerInstruction(fusion, to_fuse)); } return new_instr; } @@ -272,14 +273,14 @@ absl::Status LibraryRewriter::FuseNeighbors(HloFusionInstruction* fusion, // We don't need to add its neighbors to the frontier because anything that // can be fused would have already been fused into `instr`. if (IsCustomFusionWithKind(instr, lib->fusion_kind())) { - TF_ASSIGN_OR_RETURN(fusion, - MergeFusionInstructions( - fusion, Cast(instr), dir)); + ASSIGN_OR_RETURN(fusion, + MergeFusionInstructions( + fusion, Cast(instr), dir)); continue; } // Skip this instruction if it can't be fused. - TF_ASSIGN_OR_RETURN(bool op_supported, lib->IsOpSupported(instr)); + ASSIGN_OR_RETURN(bool op_supported, lib->IsOpSupported(instr)); if (!op_supported) { VLOG(4) << " Skipping unsupported instruction: " << instr->ToString(); continue; @@ -290,9 +291,9 @@ absl::Status LibraryRewriter::FuseNeighbors(HloFusionInstruction* fusion, AddFusionCandidates(fusion, instr, dir, frontier); // Fuse `instr` into `fusion` according to the travel direction. - TF_ASSIGN_OR_RETURN(HloInstruction * new_instr, - GrowFusion(fusion, instr, dir)); - TF_RETURN_IF_ERROR( + ASSIGN_OR_RETURN(HloInstruction * new_instr, + GrowFusion(fusion, instr, dir)); + RETURN_IF_ERROR( InsertConvertIfNecessary(new_instr, lib->LibraryOpOutputType(instr))); } return absl::OkStatus(); @@ -331,7 +332,7 @@ absl::StatusOr LibraryRewriter::ProcessComputation( } // Find the best library to use for the current instruction. - TF_ASSIGN_OR_RETURN(LibraryMatcher * lib, ChooseLibrary(centroid)); + ASSIGN_OR_RETURN(LibraryMatcher * lib, ChooseLibrary(centroid)); if (lib == nullptr) { continue; } @@ -339,15 +340,15 @@ absl::StatusOr LibraryRewriter::ProcessComputation( // Start a fusion node. fused_.insert(centroid); VLOG(3) << "Starting a fusion with: " << centroid->ToString(); - TF_ASSIGN_OR_RETURN(HloFusionInstruction * fusion, - CreateLibraryFusion(centroid, lib->fusion_prefix(), - lib->fusion_kind())); - TF_RETURN_IF_ERROR(InsertConvertIfNecessary( + ASSIGN_OR_RETURN(HloFusionInstruction * fusion, + CreateLibraryFusion(centroid, lib->fusion_prefix(), + lib->fusion_kind())); + RETURN_IF_ERROR(InsertConvertIfNecessary( fusion->fused_expression_root(), lib->LibraryOpOutputType(centroid))); // Fuse as many neighbors as as we can. if (lib->ShouldGrowFusion(centroid)) { - TF_RETURN_IF_ERROR(FuseNeighbors(fusion, lib)); + RETURN_IF_ERROR(FuseNeighbors(fusion, lib)); } } return !fused_.empty(); @@ -366,7 +367,7 @@ absl::StatusOr LibraryRewriter::RunImpl( })) { continue; } - TF_ASSIGN_OR_RETURN(bool comp_changed, ProcessComputation(computation)); + ASSIGN_OR_RETURN(bool comp_changed, ProcessComputation(computation)); module_changed |= comp_changed; } return module_changed; diff --git a/third_party/xla/xla/backends/cpu/ynn_emitter.cc b/third_party/xla/xla/backends/cpu/ynn_emitter.cc index 024f174d1712ca..38ed6512b07969 100644 --- a/third_party/xla/xla/backends/cpu/ynn_emitter.cc +++ b/third_party/xla/xla/backends/cpu/ynn_emitter.cc @@ -31,6 +31,7 @@ limitations under the License. #include "absl/strings/str_format.h" #include "absl/synchronization/mutex.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/backends/cpu/runtime/dot_dims.h" #include "xla/backends/cpu/runtime/ynnpack/ynn_interop.h" #include "xla/backends/cpu/ynn_support.h" @@ -103,7 +104,7 @@ absl::StatusOr DefineTensorValue( } auto dims = YnnDimensions(instr->shape()); - TF_ASSIGN_OR_RETURN(auto type, YnnType(instr->shape().element_type())); + ASSIGN_OR_RETURN(auto type, YnnType(instr->shape().element_type())); if (output_id == YNN_INVALID_VALUE_ID) { // If instruction is a root instruction of the parent computation we assign @@ -161,7 +162,7 @@ absl::StatusOr DefineConstant(ynn_subgraph_t subgraph, } auto dims = YnnDimensions(instr->shape()); - TF_ASSIGN_OR_RETURN(auto type, YnnType(instr->shape().element_type())); + ASSIGN_OR_RETURN(auto type, YnnType(instr->shape().element_type())); uint32_t tensor_id = YNN_INVALID_VALUE_ID; @@ -181,7 +182,7 @@ absl::StatusOr DefineParameter(ynn_subgraph_t subgraph, param->ToString()); auto dims = YnnDimensions(param->shape()); - TF_ASSIGN_OR_RETURN(auto type, YnnType(param->shape().element_type())); + ASSIGN_OR_RETURN(auto type, YnnType(param->shape().element_type())); uint32_t tensor_id = param->parameter_number(); uint32_t flags = (data == nullptr) ? YNN_VALUE_FLAG_EXTERNAL_INPUT : 0; @@ -200,8 +201,8 @@ absl::StatusOr DefineBitcastOp(ynn_subgraph_t subgraph, CHECK_EQ(instr->opcode(), HloOpcode::kBitcast); const HloInstruction* input = instr->operand(0); CHECK_EQ(input->shape().element_type(), instr->shape().element_type()); - TF_ASSIGN_OR_RETURN(auto in, FindTensorValue(tensor_ids, input)); - TF_ASSIGN_OR_RETURN(auto out, DefineTensorValue(subgraph, instr)); + ASSIGN_OR_RETURN(auto in, FindTensorValue(tensor_ids, input)); + ASSIGN_OR_RETURN(auto out, DefineTensorValue(subgraph, instr)); auto dims = YnnDimensions(instr->shape()); YNN_RETURN_IF_ERROR(ynn_define_static_reshape(subgraph, dims.size(), @@ -223,8 +224,8 @@ absl::StatusOr DefineTransposeOp(ynn_subgraph_t subgraph, instr->ToString()); CHECK_EQ(instr->opcode(), HloOpcode::kTranspose); const HloInstruction* input = instr->operand(0); - TF_ASSIGN_OR_RETURN(auto in, FindTensorValue(tensor_ids, input)); - TF_ASSIGN_OR_RETURN(auto out, DefineTensorValue(subgraph, instr)); + ASSIGN_OR_RETURN(auto in, FindTensorValue(tensor_ids, input)); + ASSIGN_OR_RETURN(auto out, DefineTensorValue(subgraph, instr)); auto dimensions = instr->dimensions(); std::vector perm(dimensions.begin(), dimensions.end()); @@ -242,8 +243,8 @@ absl::StatusOr DefineBroadcastOp(ynn_subgraph_t subgraph, instr->ToString()); CHECK_EQ(instr->opcode(), HloOpcode::kBroadcast); const HloInstruction* input = instr->operand(0); - TF_ASSIGN_OR_RETURN(auto in, FindTensorValue(tensor_ids, input)); - TF_ASSIGN_OR_RETURN(auto out, DefineTensorValue(subgraph, instr)); + ASSIGN_OR_RETURN(auto in, FindTensorValue(tensor_ids, input)); + ASSIGN_OR_RETURN(auto out, DefineTensorValue(subgraph, instr)); auto dimensions = instr->dimensions(); auto output_dims = instr->shape().dimensions(); @@ -278,10 +279,10 @@ absl::StatusOr DefineConcatenateOp(ynn_subgraph_t subgraph, std::vector inputs; for (const HloInstruction* operand : instr->operands()) { - TF_ASSIGN_OR_RETURN(auto in, FindTensorValue(tensor_ids, operand)); + ASSIGN_OR_RETURN(auto in, FindTensorValue(tensor_ids, operand)); inputs.push_back(in); } - TF_ASSIGN_OR_RETURN(auto out, DefineTensorValue(subgraph, instr)); + ASSIGN_OR_RETURN(auto out, DefineTensorValue(subgraph, instr)); YNN_RETURN_IF_ERROR( ynn_define_concatenate(subgraph, instr->concatenate_dimension(), @@ -296,8 +297,8 @@ absl::StatusOr DefineSliceOp(ynn_subgraph_t subgraph, instr->ToString()); CHECK_EQ(instr->opcode(), HloOpcode::kSlice); const HloInstruction* input = instr->operand(0); - TF_ASSIGN_OR_RETURN(auto in, FindTensorValue(tensor_ids, input)); - TF_ASSIGN_OR_RETURN(auto out, DefineTensorValue(subgraph, instr)); + ASSIGN_OR_RETURN(auto in, FindTensorValue(tensor_ids, input)); + ASSIGN_OR_RETURN(auto out, DefineTensorValue(subgraph, instr)); const std::vector& starts = instr->slice_starts(); const std::vector& limits = instr->slice_limits(); @@ -321,9 +322,9 @@ absl::StatusOr DefinePadOp(ynn_subgraph_t subgraph, CHECK_EQ(instr->opcode(), HloOpcode::kPad); const HloInstruction* input = instr->operand(0); const HloInstruction* padding_value = instr->operand(1); - TF_ASSIGN_OR_RETURN(auto in, FindTensorValue(tensor_ids, input)); - TF_ASSIGN_OR_RETURN(auto pad_val, FindTensorValue(tensor_ids, padding_value)); - TF_ASSIGN_OR_RETURN(auto out, DefineTensorValue(subgraph, instr)); + ASSIGN_OR_RETURN(auto in, FindTensorValue(tensor_ids, input)); + ASSIGN_OR_RETURN(auto pad_val, FindTensorValue(tensor_ids, padding_value)); + ASSIGN_OR_RETURN(auto out, DefineTensorValue(subgraph, instr)); const PaddingConfig& config = instr->padding_config(); int rank = input->shape().dimensions().size(); @@ -350,24 +351,24 @@ absl::StatusOr DefineIotaOp(ynn_subgraph_t subgraph, CHECK_EQ(instr->opcode(), HloOpcode::kIota); const HloIotaInstruction* iota = Cast(instr); - TF_ASSIGN_OR_RETURN(uint32_t out_id, DefineTensorValue(subgraph, instr)); + ASSIGN_OR_RETURN(uint32_t out_id, DefineTensorValue(subgraph, instr)); const Shape& shape = instr->shape(); int64_t rank = shape.dimensions().size(); int64_t iota_dim = iota->iota_dimension(); PrimitiveType element_type = shape.element_type(); - TF_ASSIGN_OR_RETURN(auto ynn_element_type, YnnType(element_type)); + ASSIGN_OR_RETURN(auto ynn_element_type, YnnType(element_type)); auto stride_shape = ShapeUtil::MakeShape(element_type, {rank}); - TF_ASSIGN_OR_RETURN(auto stride_value, Literal::Make(stride_shape)); + ASSIGN_OR_RETURN(auto stride_value, Literal::Make(stride_shape)); for (int64_t i = 0; i < rank; ++i) { int value = (i == iota_dim) ? 1 : 0; if (primitive_util::IsIntegralType(element_type)) { - TF_RETURN_IF_ERROR(stride_value.SetIntegralAsS64({i}, value)); + RETURN_IF_ERROR(stride_value.SetIntegralAsS64({i}, value)); } else { - TF_RETURN_IF_ERROR(stride_value.SetFromDouble({i}, value)); + RETURN_IF_ERROR(stride_value.SetFromDouble({i}, value)); } } @@ -390,10 +391,10 @@ absl::StatusOr DefineUnaryOp(ynn_subgraph_t subgraph, const HloInstruction* instr) { VLOG(3) << absl::StreamFormat("Define tensor value for unary op: %s", instr->ToString()); - TF_ASSIGN_OR_RETURN(auto unary_op, YnnUnaryOperator(instr->opcode())); + ASSIGN_OR_RETURN(auto unary_op, YnnUnaryOperator(instr->opcode())); - TF_ASSIGN_OR_RETURN(auto in, FindTensorValue(tensor_ids, instr->operand(0))); - TF_ASSIGN_OR_RETURN(auto out, DefineTensorValue(subgraph, instr)); + ASSIGN_OR_RETURN(auto in, FindTensorValue(tensor_ids, instr->operand(0))); + ASSIGN_OR_RETURN(auto out, DefineTensorValue(subgraph, instr)); VLOG(3) << absl::StreamFormat(" tensors: in=%d, out=%d", in, out); @@ -409,11 +410,11 @@ absl::StatusOr DefineBinaryOp(ynn_subgraph_t subgraph, VLOG(3) << absl::StreamFormat("Define tensor value for binary op: %s", instr->ToString()); - TF_ASSIGN_OR_RETURN(auto binary_op, YnnBinaryOperator(instr->opcode())); + ASSIGN_OR_RETURN(auto binary_op, YnnBinaryOperator(instr->opcode())); - TF_ASSIGN_OR_RETURN(auto lhs, FindTensorValue(tensor_ids, instr->operand(0))); - TF_ASSIGN_OR_RETURN(auto rhs, FindTensorValue(tensor_ids, instr->operand(1))); - TF_ASSIGN_OR_RETURN(auto out, DefineTensorValue(subgraph, instr)); + ASSIGN_OR_RETURN(auto lhs, FindTensorValue(tensor_ids, instr->operand(0))); + ASSIGN_OR_RETURN(auto rhs, FindTensorValue(tensor_ids, instr->operand(1))); + ASSIGN_OR_RETURN(auto out, DefineTensorValue(subgraph, instr)); VLOG(3) << absl::StreamFormat(" tensors: lhs=%d, rhs=%d, out=%d", lhs, rhs, out); @@ -439,16 +440,15 @@ absl::StatusOr DefineReduceOp(ynn_subgraph_t subgraph, CHECK_EQ(reduce_instr->to_apply()->num_parameters(), 2); CHECK_EQ(reduce_instr->to_apply()->instruction_count(), 3); - TF_ASSIGN_OR_RETURN( - auto ynn_reduce_op, - YnnReduceOperator( - reduce_instr->to_apply()->root_instruction()->opcode())); + ASSIGN_OR_RETURN(auto ynn_reduce_op, + YnnReduceOperator( + reduce_instr->to_apply()->root_instruction()->opcode())); const absl::Span reduce_dims = reduce_instr->dimensions(); const std::vector dims(reduce_dims.begin(), reduce_dims.end()); - TF_ASSIGN_OR_RETURN(auto in, FindTensorValue(tensor_ids, input)); - TF_ASSIGN_OR_RETURN(auto init_id, FindTensorValue(tensor_ids, init)); - TF_ASSIGN_OR_RETURN(auto out, DefineTensorValue(subgraph, instr)); + ASSIGN_OR_RETURN(auto in, FindTensorValue(tensor_ids, input)); + ASSIGN_OR_RETURN(auto init_id, FindTensorValue(tensor_ids, init)); + ASSIGN_OR_RETURN(auto out, DefineTensorValue(subgraph, instr)); YNN_RETURN_IF_ERROR( ynn_define_reduce(subgraph, ynn_reduce_op, /*num_axes=*/dims.size(), @@ -465,20 +465,20 @@ absl::StatusOr DefineDotOp( const HloInstruction* lhs = instr->operand(0); const HloInstruction* rhs = instr->operand(1); - TF_ASSIGN_OR_RETURN(auto lhs_id, FindTensorValue(tensor_ids, lhs)); - TF_ASSIGN_OR_RETURN(auto rhs_id, FindTensorValue(tensor_ids, rhs)); - TF_ASSIGN_OR_RETURN(output_id, DefineTensorValue(subgraph, instr, output_id)); + ASSIGN_OR_RETURN(auto lhs_id, FindTensorValue(tensor_ids, lhs)); + ASSIGN_OR_RETURN(auto rhs_id, FindTensorValue(tensor_ids, rhs)); + ASSIGN_OR_RETURN(output_id, DefineTensorValue(subgraph, instr, output_id)); const Shape& lhs_shape = lhs->shape(); const Shape& rhs_shape = rhs->shape(); const Shape& out_shape = instr->shape(); DotDimensionNumbers dot_dimensions = instr->dot_dimension_numbers(); - TF_ASSIGN_OR_RETURN(DotShape dot_shape, GetDotShape(dot_dimensions, lhs_shape, - rhs_shape, out_shape)); + ASSIGN_OR_RETURN(DotShape dot_shape, GetDotShape(dot_dimensions, lhs_shape, + rhs_shape, out_shape)); - TF_ASSIGN_OR_RETURN(DotCanonicalDims dot_canonical_dims, - GetDotCanonicalDims(dot_dimensions, dot_shape)); + ASSIGN_OR_RETURN(DotCanonicalDims dot_canonical_dims, + GetDotCanonicalDims(dot_dimensions, dot_shape)); const size_t b_rank = rhs_shape.dimensions().size(); const bool transpose_b = !dot_canonical_dims.rhs_canonical; @@ -516,12 +516,12 @@ absl::StatusOr DefineReduceWindowOp(ynn_subgraph_t subgraph, const HloInstruction* input = instr->operand(0); const HloInstruction* init = instr->operand(1); - TF_ASSIGN_OR_RETURN(auto input_id, FindTensorValue(tensor_ids, input)); - TF_ASSIGN_OR_RETURN(auto init_id, FindTensorValue(tensor_ids, init)); - TF_ASSIGN_OR_RETURN(auto output_id, DefineTensorValue(subgraph, instr)); + ASSIGN_OR_RETURN(auto input_id, FindTensorValue(tensor_ids, input)); + ASSIGN_OR_RETURN(auto init_id, FindTensorValue(tensor_ids, init)); + ASSIGN_OR_RETURN(auto output_id, DefineTensorValue(subgraph, instr)); HloOpcode to_apply_opcode = instr->to_apply()->root_instruction()->opcode(); - TF_ASSIGN_OR_RETURN(auto ynn_reduce_op, YnnReduceOperator(to_apply_opcode)); + ASSIGN_OR_RETURN(auto ynn_reduce_op, YnnReduceOperator(to_apply_opcode)); const Window& window = instr->window(); int rank = window.dimensions().size(); @@ -568,13 +568,13 @@ absl::StatusOr DefineReduceWindowOp(ynn_subgraph_t subgraph, // The padding should be the identity value of the reduction. PrimitiveType input_type = input->shape().element_type(); - TF_ASSIGN_OR_RETURN(double identity_float, ReduceIdentity(to_apply_opcode)); + ASSIGN_OR_RETURN(double identity_float, ReduceIdentity(to_apply_opcode)); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( auto identity_literal, LiteralUtil::CreateR0(identity_float).Convert(input_type)); - TF_ASSIGN_OR_RETURN(ynn_type ynn_type, YnnType(input_type)); + ASSIGN_OR_RETURN(ynn_type ynn_type, YnnType(input_type)); uint32_t identity_id = YNN_INVALID_VALUE_ID; YNN_RETURN_IF_ERROR( @@ -619,14 +619,13 @@ absl::StatusOr DefineConvolutionOp( const HloInstruction* lhs = conv->operand(0); const HloInstruction* rhs = conv->operand(1); - TF_ASSIGN_OR_RETURN(auto lhs_id, FindTensorValue(tensor_ids, lhs)); - TF_ASSIGN_OR_RETURN(auto rhs_id, FindTensorValue(tensor_ids, rhs)); - TF_ASSIGN_OR_RETURN(output_id, DefineTensorValue(subgraph, instr, output_id)); + ASSIGN_OR_RETURN(auto lhs_id, FindTensorValue(tensor_ids, lhs)); + ASSIGN_OR_RETURN(auto rhs_id, FindTensorValue(tensor_ids, rhs)); + ASSIGN_OR_RETURN(output_id, DefineTensorValue(subgraph, instr, output_id)); - TF_ASSIGN_OR_RETURN(ynn_type ynn_lhs_type, - YnnType(lhs->shape().element_type())); - TF_ASSIGN_OR_RETURN(ynn_type ynn_out_type, - YnnType(conv->shape().element_type())); + ASSIGN_OR_RETURN(ynn_type ynn_lhs_type, YnnType(lhs->shape().element_type())); + ASSIGN_OR_RETURN(ynn_type ynn_out_type, + YnnType(conv->shape().element_type())); Window conv_window = conv->window(); ConvolutionDimensionNumbers conv_dims = conv->convolution_dimension_numbers(); @@ -800,7 +799,7 @@ absl::StatusOr EmitYnnSubgraph( absl::Span captured_parameters) { VLOG(3) << "Emit YNNPACK subgraph for computation: " << computation->name(); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( YnnSubgraph subgraph, CreateYnnSubgraph([&](ynn_subgraph_t* subgraph) { return ynn_create_subgraph( /*external_value_ids=*/computation->num_parameters() + 1, @@ -826,8 +825,8 @@ absl::StatusOr EmitYnnSubgraph( "Unsupported constant instruction in YNN fusion: %s", instr->ToString()); } - TF_ASSIGN_OR_RETURN(tensor_ids[instr], - DefineConstant(subgraph.get(), literals, instr)); + ASSIGN_OR_RETURN(tensor_ids[instr], + DefineConstant(subgraph.get(), literals, instr)); continue; } @@ -838,11 +837,11 @@ absl::StatusOr EmitYnnSubgraph( instr->ToString()); } if (instr->operand_count() == 1) { - TF_ASSIGN_OR_RETURN(tensor_ids[instr], - DefineUnaryOp(subgraph.get(), tensor_ids, instr)); + ASSIGN_OR_RETURN(tensor_ids[instr], + DefineUnaryOp(subgraph.get(), tensor_ids, instr)); } else if (instr->operand_count() == 2) { - TF_ASSIGN_OR_RETURN(tensor_ids[instr], - DefineBinaryOp(subgraph.get(), tensor_ids, instr)); + ASSIGN_OR_RETURN(tensor_ids[instr], + DefineBinaryOp(subgraph.get(), tensor_ids, instr)); } else { LOG(FATAL) << "Unexpected operand count " << instr->operand_count(); } @@ -856,8 +855,8 @@ absl::StatusOr EmitYnnSubgraph( instr->parameter_number())) { data = arguments_buffers[instr->parameter_number()].opaque(); } - TF_ASSIGN_OR_RETURN(tensor_ids[instr], - DefineParameter(subgraph.get(), instr, data)); + ASSIGN_OR_RETURN(tensor_ids[instr], + DefineParameter(subgraph.get(), instr, data)); } break; case HloOpcode::kBitcast: { @@ -866,8 +865,8 @@ absl::StatusOr EmitYnnSubgraph( "Unsupported bitcast instruction in YNN fusion: %s", instr->ToString()); } - TF_ASSIGN_OR_RETURN(tensor_ids[instr], - DefineBitcastOp(subgraph.get(), tensor_ids, instr)); + ASSIGN_OR_RETURN(tensor_ids[instr], + DefineBitcastOp(subgraph.get(), tensor_ids, instr)); } break; case HloOpcode::kReshape: { @@ -876,8 +875,8 @@ absl::StatusOr EmitYnnSubgraph( "Unsupported reshape instruction in YNN fusion: %s", instr->ToString()); } - TF_ASSIGN_OR_RETURN(tensor_ids[instr], - DefineReshapeOp(subgraph.get(), tensor_ids, instr)); + ASSIGN_OR_RETURN(tensor_ids[instr], + DefineReshapeOp(subgraph.get(), tensor_ids, instr)); } break; case HloOpcode::kTranspose: { @@ -886,9 +885,8 @@ absl::StatusOr EmitYnnSubgraph( "Unsupported transpose instruction in YNN fusion: %s", instr->ToString()); } - TF_ASSIGN_OR_RETURN( - tensor_ids[instr], - DefineTransposeOp(subgraph.get(), tensor_ids, instr)); + ASSIGN_OR_RETURN(tensor_ids[instr], + DefineTransposeOp(subgraph.get(), tensor_ids, instr)); } break; case HloOpcode::kBroadcast: { @@ -897,9 +895,8 @@ absl::StatusOr EmitYnnSubgraph( "Unsupported broadcast instruction in YNN fusion: %s", instr->ToString()); } - TF_ASSIGN_OR_RETURN( - tensor_ids[instr], - DefineBroadcastOp(subgraph.get(), tensor_ids, instr)); + ASSIGN_OR_RETURN(tensor_ids[instr], + DefineBroadcastOp(subgraph.get(), tensor_ids, instr)); } break; case HloOpcode::kConcatenate: { @@ -908,7 +905,7 @@ absl::StatusOr EmitYnnSubgraph( "Unsupported concatenate instruction in YNN fusion: %s", instr->ToString()); } - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( tensor_ids[instr], DefineConcatenateOp(subgraph.get(), tensor_ids, instr)); } break; @@ -919,8 +916,8 @@ absl::StatusOr EmitYnnSubgraph( "Unsupported slice instruction in YNN fusion: %s", instr->ToString()); } - TF_ASSIGN_OR_RETURN(tensor_ids[instr], - DefineSliceOp(subgraph.get(), tensor_ids, instr)); + ASSIGN_OR_RETURN(tensor_ids[instr], + DefineSliceOp(subgraph.get(), tensor_ids, instr)); } break; case HloOpcode::kPad: { @@ -929,8 +926,8 @@ absl::StatusOr EmitYnnSubgraph( "Unsupported pad instruction in YNN fusion: %s", instr->ToString()); } - TF_ASSIGN_OR_RETURN(tensor_ids[instr], - DefinePadOp(subgraph.get(), tensor_ids, instr)); + ASSIGN_OR_RETURN(tensor_ids[instr], + DefinePadOp(subgraph.get(), tensor_ids, instr)); } break; case HloOpcode::kIota: { @@ -939,8 +936,8 @@ absl::StatusOr EmitYnnSubgraph( "Unsupported iota instruction in YNN fusion: %s", instr->ToString()); } - TF_ASSIGN_OR_RETURN(tensor_ids[instr], - DefineIotaOp(subgraph.get(), instr)); + ASSIGN_OR_RETURN(tensor_ids[instr], + DefineIotaOp(subgraph.get(), instr)); } break; case HloOpcode::kDot: { @@ -949,13 +946,13 @@ absl::StatusOr EmitYnnSubgraph( "Unsupported dot instruction in YNN fusion: %s", instr->ToString()); } - TF_ASSIGN_OR_RETURN(tensor_ids[instr], - DefineDotOp(subgraph.get(), tensor_ids, instr)); + ASSIGN_OR_RETURN(tensor_ids[instr], + DefineDotOp(subgraph.get(), tensor_ids, instr)); } break; case HloOpcode::kReduce: { - TF_ASSIGN_OR_RETURN(tensor_ids[instr], - DefineReduceOp(subgraph.get(), tensor_ids, instr)); + ASSIGN_OR_RETURN(tensor_ids[instr], + DefineReduceOp(subgraph.get(), tensor_ids, instr)); } break; case HloOpcode::kReduceWindow: { @@ -964,7 +961,7 @@ absl::StatusOr EmitYnnSubgraph( "Unsupported reduce window instruction in YNN fusion: %s", instr->ToString()); } - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( tensor_ids[instr], DefineReduceWindowOp(subgraph.get(), tensor_ids, instr)); } break; @@ -975,7 +972,7 @@ absl::StatusOr EmitYnnSubgraph( "Unsupported convolution instruction in YNN fusion: %s", instr->ToString()); } - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( tensor_ids[instr], DefineConvolutionOp(subgraph.get(), tensor_ids, instr)); } break; @@ -989,7 +986,7 @@ absl::StatusOr EmitYnnSubgraph( ynn_status status = ynn_optimize_subgraph( subgraph.get(), /*threadpool=*/nullptr, /*flags=*/0); - TF_RETURN_IF_ERROR(YnnStatusToStatus(status)); + RETURN_IF_ERROR(YnnStatusToStatus(status)); return subgraph; } diff --git a/third_party/xla/xla/backends/cpu/ynn_support.cc b/third_party/xla/xla/backends/cpu/ynn_support.cc index d2c30853e35dd3..84f1796a17c037 100644 --- a/third_party/xla/xla/backends/cpu/ynn_support.cc +++ b/third_party/xla/xla/backends/cpu/ynn_support.cc @@ -27,6 +27,7 @@ limitations under the License. #include "absl/log/check.h" #include "absl/log/log.h" #include "absl/status/statusor.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/backends/cpu/runtime/dot_dims.h" #include "xla/backends/cpu/runtime/ynnpack/ynn_interop.h" #include "xla/hlo/ir/hlo_casting_utils.h" @@ -392,11 +393,11 @@ absl::StatusOr IsDotSupportedByYnn(const HloInstruction* hlo) { } // Check shapes. - TF_ASSIGN_OR_RETURN(DotShape dot_shape, GetDotShape(dot_dimensions, lhs_shape, - rhs_shape, out_shape)); + ASSIGN_OR_RETURN(DotShape dot_shape, GetDotShape(dot_dimensions, lhs_shape, + rhs_shape, out_shape)); - TF_ASSIGN_OR_RETURN(DotCanonicalDims dot_canonical_dims, - GetDotCanonicalDims(dot_dimensions, dot_shape)); + ASSIGN_OR_RETURN(DotCanonicalDims dot_canonical_dims, + GetDotCanonicalDims(dot_dimensions, dot_shape)); if (dot_canonical_dims.m == 1 || dot_canonical_dims.n == 1) { // TODO(b/430079105): YNNPACK does not handle vectors in dots. We could diff --git a/third_party/xla/xla/backends/gpu/autotuner/BUILD b/third_party/xla/xla/backends/gpu/autotuner/BUILD index 5a01f2d189580a..e2132274169feb 100644 --- a/third_party/xla/xla/backends/gpu/autotuner/BUILD +++ b/third_party/xla/xla/backends/gpu/autotuner/BUILD @@ -1,4 +1,3 @@ -load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm_is_configured") load("@rules_cc//cc:cc_library.bzl", "cc_library") load("//xla:xla.default.bzl", "xla_cc_binary", "xla_cc_test") @@ -134,92 +133,6 @@ xla_test( ], ) -cc_library( - name = "cublas", - srcs = ["cublas.cc"], - hdrs = ["cublas.h"], - tags = ["gpu"], - deps = [ - ":gpu_codegen_backend", - "//xla:autotuning_proto_cc", - "//xla:shape_util", - "//xla:xla_proto_cc", - "//xla/backends/autotuner:backends_proto_cc", - "//xla/backends/autotuner:codegen_backend", - "//xla/backends/gpu/transforms:dot_algorithm_rewriter", - "//xla/backends/gpu/transforms:gemm_rewriter", - "//xla/hlo/ir:hlo", - "//xla/hlo/utils:hlo_query", - "//xla/service:compiler", - "//xla/service:hlo_cost_analysis", - "//xla/service/gpu:backend_configs_cc", - "//xla/service/gpu:cublas_cudnn", - "//xla/service/gpu:matmul_utils", - "//xla/service/gpu/autotuning:redzone_buffers", - "//xla/stream_executor:blas", - "//xla/stream_executor:device_address", - "//xla/stream_executor:device_address_allocator", - "//xla/stream_executor:device_description", - "//xla/stream_executor:semantic_version", - "//xla/stream_executor:stream_executor_h", - "//xla/stream_executor:stream_executor_memory_allocator", - "//xla/stream_executor/cuda:cuda_compute_capability", - "//xla/stream_executor/gpu:gpu_blas_lt", - "//xla/stream_executor/gpu:redzone_allocator", - "//xla/tools:hlo_decomposer_lib", - "//xla/tsl/lib/gtl:iterator_range", - "//xla/tsl/platform:errors", - "//xla/tsl/platform:status_macros", - "//xla/tsl/platform:statusor", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/log", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:string_view", - ] + if_cuda([ - "//xla/stream_executor/cuda:repeat_buffer_kernel_cuda", - "//xla/stream_executor/cuda:cublas_plugin", - ]), -) - -xla_test( - name = "cublas_test", - srcs = ["cublas_test.cc"], - backends = [ - "a100", - "h100", - "b200", - ], - tags = [ - "cuda-only", - "no_mac", - ], - use_legacy_runtime = True, - deps = [ - ":cublas", - "//xla:autotuning_proto_cc", - "//xla:shape_util", - "//xla:xla_proto_cc", - "//xla/backends/autotuner:codegen_backend", - "//xla/hlo/ir:hlo", - "//xla/hlo/testlib:filecheck", - "//xla/hlo/testlib:hlo_hardware_independent_test_base", - "//xla/service:compiler", - "//xla/service:executable", - "//xla/service:platform_util", - "//xla/service/gpu:nvptx_compiler_impl", - "//xla/stream_executor:blas", - "//xla/stream_executor:device_description_proto_cc", - "//xla/stream_executor:stream_executor_h", - "//xla/tsl/lib/core:status_test_util", - "//xla/tsl/platform:statusor", - "//xla/tsl/util/proto:proto_matchers", - "@com_google_absl//absl/status:status_matchers", - "@com_google_absl//absl/status:statusor", - "@com_google_googletest//:gtest_main", - ], -) - cc_library( name = "cublaslt", srcs = ["cublaslt.cc"], @@ -241,13 +154,11 @@ cc_library( "//xla/service/gpu:matmul_utils", "//xla/stream_executor:blas", "//xla/stream_executor:device_description", - "//xla/stream_executor:stream", "//xla/stream_executor:stream_executor_h", "//xla/stream_executor/gpu:gpu_blas_lt", "//xla/tsl/platform:errors", "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", - "@com_google_absl//absl/log", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", ], @@ -622,45 +533,6 @@ cc_library( alwayslink = True, ) -cc_library( - name = "rocblas", - srcs = ["rocblas.cc"], - hdrs = ["rocblas.h"], - tags = [ - "gpu", - "rocm-only", - ], - deps = [ - ":gpu_codegen_backend", - "//xla:autotuning_proto_cc", - "//xla:shape_layout", - "//xla:shape_util", - "//xla:xla_data_proto_cc", - "//xla:xla_proto_cc", - "//xla/backends/autotuner:backends_proto_cc", - "//xla/backends/autotuner:codegen_backend", - "//xla/hlo/ir:hlo", - "//xla/service:compiler", - "//xla/service/gpu:backend_configs_cc", - "//xla/service/gpu:cublas_cudnn", - "//xla/service/gpu:matmul_utils", - "//xla/stream_executor:blas", - "//xla/stream_executor:device_address", - "//xla/stream_executor:device_address_allocator", - "//xla/stream_executor:device_description", - "//xla/stream_executor:stream_executor_h", - "//xla/stream_executor:stream_executor_memory_allocator", - "//xla/stream_executor/gpu:gpu_blas_lt", - "//xla/stream_executor/rocm:rocblas_plugin", - "//xla/tsl/platform:errors", - "//xla/tsl/platform:status_macros", - "//xla/tsl/platform:statusor", - "@com_google_absl//absl/log", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - ], -) - cc_library( name = "hipblaslt", srcs = ["hipblaslt.cc"], @@ -688,7 +560,6 @@ cc_library( "//xla/service/gpu:matmul_utils", "//xla/stream_executor:blas", "//xla/stream_executor:device_description", - "//xla/stream_executor:stream", "//xla/stream_executor:stream_executor_h", "//xla/stream_executor/gpu:gpu_blas_lt", "//xla/stream_executor/rocm:amdhipblaslt_plugin", @@ -696,7 +567,6 @@ cc_library( "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/log", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", ], @@ -759,40 +629,6 @@ xla_test( ], ) -xla_test( - name = "rocblas_test", - srcs = ["rocblas_test.cc"], - backends = [ - "amdgpu_any", - ], - tags = [ - "gpu", - "rocm-only", - ], - deps = [ - ":rocblas", - "//xla:autotuning_proto_cc", - "//xla:xla_proto_cc", - "//xla/backends/autotuner:codegen_backend", - "//xla/hlo/ir:hlo", - "//xla/hlo/testlib:filecheck", - "//xla/hlo/testlib:hlo_hardware_independent_test_base", - "//xla/service:compiler", - "//xla/service:executable", - "//xla/service:platform_util", - "//xla/service/gpu:amdgpu_compiler_impl", - "//xla/stream_executor:blas", - "//xla/stream_executor:device_description_proto_cc", - "//xla/stream_executor:stream_executor_h", - "//xla/tsl/lib/core:status_test_util", - "//xla/tsl/platform:statusor", - "//xla/tsl/util/proto:proto_matchers", - "@com_google_absl//absl/status:status_matchers", - "@com_google_absl//absl/status:statusor", - "@com_google_googletest//:gtest_main", - ], -) - cc_library( name = "factory_rocm", srcs = ["factory_rocm.cc"], @@ -917,6 +753,7 @@ cc_library( hdrs = ["fission_backend.h"], deps = [ ":gpu_codegen_backend", + "//xla:shape_util", "//xla:xla_proto_cc", "//xla/backends/autotuner:backends_proto_cc", "//xla/backends/autotuner:codegen_backend", @@ -983,7 +820,6 @@ xla_test( "@com_google_googletest//:gtest_main", "@llvm-project//mlir:IR", ] + if_cuda_is_configured([ - ":cublas", ":cublaslt", ]) + if_rocm_is_configured([ ":rocblas", diff --git a/third_party/xla/xla/backends/gpu/autotuner/cublas.cc b/third_party/xla/xla/backends/gpu/autotuner/cublas.cc deleted file mode 100644 index 3f14da0255c81c..00000000000000 --- a/third_party/xla/xla/backends/gpu/autotuner/cublas.cc +++ /dev/null @@ -1,178 +0,0 @@ -/* Copyright 2025 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/backends/gpu/autotuner/cublas.h" - -#include -#include -#include - -#include "absl/log/log.h" -#include "absl/status/status.h" -#include "absl/status/statusor.h" -#include "xla/tsl/platform/status_macros.h" -#include "xla/autotuning.pb.h" -#include "xla/backends/autotuner/codegen_backend.h" -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/service/gpu/backend_configs.pb.h" -#include "xla/service/gpu/cublas_cudnn.h" -#include "xla/service/gpu/matmul_utils.h" -#include "xla/stream_executor/blas.h" -#include "xla/stream_executor/device_address.h" -#include "xla/stream_executor/device_address_allocator.h" -#include "xla/stream_executor/device_description.h" -#include "xla/stream_executor/gpu/gpu_blas_lt.h" -#include "xla/stream_executor/stream_executor_memory_allocator.h" -#include "xla/tsl/platform/errors.h" -#include "xla/tsl/platform/statusor.h" - -namespace xla { -namespace gpu { - -namespace se = ::stream_executor; - -absl::StatusOr>> -CublasBackend::GetSupportedConfigs(const HloInstruction& instr) { - if (!IsSupported(instr)) { - return std::vector>(); - } - - if (ShouldUseCublasLt(instr)) { - std::vector> configs; - AutotuneResult::GemmKey gemm_key; - gemm_key.set_algorithm(0); - configs.push_back(std::make_unique()); - configs.back()->PackFrom(gemm_key); - return configs; - } - - std::unique_ptr allocator = - std::make_unique( - stream_executor()); - ASSIGN_OR_RETURN(se::Stream * stream, - allocator->GetStream(stream_executor()->device_ordinal())); - - // We use GemmConfig::For with GemmBackendConfig as a fallback because - // Matmul_utils.cc relies on backend config to determine gemm contracting - // dimensions. - GemmBackendConfig backend_config; - backend_config = - instr.backend_config()->gemm_backend_config(); - ASSIGN_OR_RETURN( - GemmConfig gemm_config, - GemmConfig::For( - &instr, backend_config, - target_config().device_description.gpu_compute_capability())); - - auto create_matrix_desc = [](const se::gpu::MatrixLayout& layout) - -> absl::StatusOr { - ASSIGN_OR_RETURN(se::blas::DataType type, - se::gpu::AsBlasDataType(layout.dtype)); - return se::gpu::MatrixDescriptor{ - /*data=*/se::DeviceAddressBase(), layout.leading_dim_stride, - layout.batch_stride, type, - // BLAS is column-major by default. - (layout.order == se::gpu::MatrixLayout::Order::kColumnMajor - ? se::blas::Transpose::kNoTranspose - : se::blas::Transpose::kTranspose)}; - }; - - ASSIGN_OR_RETURN(se::gpu::MatrixDescriptor lhs_desc, - create_matrix_desc(gemm_config.lhs_layout)); - ASSIGN_OR_RETURN(se::gpu::MatrixDescriptor rhs_desc, - create_matrix_desc(gemm_config.rhs_layout)); - ASSIGN_OR_RETURN(se::gpu::MatrixDescriptor output_desc_base, - create_matrix_desc(gemm_config.output_layout)); - - se::gpu::OutputMatrixDescriptor out_desc(std::move(output_desc_base)); - out_desc.batch_size = gemm_config.output_layout.batch_size; - out_desc.m = gemm_config.output_layout.num_rows; - out_desc.n = gemm_config.output_layout.num_cols; - out_desc.k = gemm_config.lhs_layout.num_cols; - ASSIGN_OR_RETURN( - out_desc.compute_type, - se::gpu::GetBlasComputationType( - gemm_config.precision_algorithm, gemm_config.lhs_layout.dtype, - gemm_config.output_layout.dtype, gemm_config.compute_precision, - target_config().device_description.gpu_compute_capability())); - - se::blas::BlasSupport* blas = stream_executor()->AsBlas(); - if (blas == nullptr) { - return absl::InternalError("Failed to getBlas support."); - } - std::vector algorithms; - - blas->GetBlasGemmAlgorithms(stream, lhs_desc, rhs_desc, &out_desc, - &gemm_config.alpha, &gemm_config.beta, - &algorithms); - - std::vector> configs; - configs.reserve(algorithms.size()); - for (se::blas::AlgorithmType algorithm : algorithms) { - AutotuneResult::GemmKey gemm_key; - gemm_key.set_algorithm(algorithm); - auto any = std::make_unique(); - any->PackFrom(gemm_key); - configs.push_back(std::move(any)); - } - return configs; -} - -absl::StatusOr> CublasBackend::GetDefaultConfig( - const HloInstruction& instr) { - if (!IsSupported(instr)) { - return absl::InvalidArgumentError( - "CublasBackend does not support this instruction."); - } - AutotuneResult::GemmKey gemm_key; - gemm_key.set_algorithm(se::blas::kDefaultAlgorithm); - auto any = std::make_unique(); - if (ShouldUseCublasLt(instr)) { - gemm_key.set_algorithm(0); - } - any->PackFrom(gemm_key); - return any; -} - -absl::Status CublasBackend::ApplyConfig(HloInstruction& instr, - const BackendConfig& config) { - AutotuneResult::GemmKey gemm_key; - if (!config.UnpackTo(&gemm_key)) { - return absl::InvalidArgumentError( - "Failed to unpack CublasBackendConfig from Any."); - } - if (ShouldUseCublasLt(instr) && gemm_key.algorithm() == -1) { - gemm_key.set_algorithm(0); - } - ASSIGN_OR_RETURN(GpuBackendConfig gpu_config, - instr.backend_config()); - GemmBackendConfig& backend_config = *gpu_config.mutable_gemm_backend_config(); - backend_config.set_selected_algorithm(gemm_key.algorithm()); - backend_config.set_autotune_workspace_size( - gemm_key.autotune_workspace_size()); - RETURN_IF_ERROR(instr.set_backend_config(std::move(gpu_config))); - return absl::OkStatus(); -} - -bool CublasBackend::IsSupported(const HloInstruction& instr) { - return IsLegacyCublasMatmul(instr) || ShouldUseCublasLt(instr); -} - -bool CublasBackend::ShouldUseCublasLt(const HloInstruction& instr) { - return fp8_lt_fallback_ && IsCublasLtMatmulF8(instr); -} - -} // namespace gpu -} // namespace xla diff --git a/third_party/xla/xla/backends/gpu/autotuner/cublas.h b/third_party/xla/xla/backends/gpu/autotuner/cublas.h deleted file mode 100644 index 5b7ca2ca3ca45c..00000000000000 --- a/third_party/xla/xla/backends/gpu/autotuner/cublas.h +++ /dev/null @@ -1,81 +0,0 @@ -/* Copyright 2025 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_BACKENDS_GPU_AUTOTUNER_CUBLAS_H_ -#define XLA_BACKENDS_GPU_AUTOTUNER_CUBLAS_H_ - -#include -#include -#include - -#include "absl/status/status.h" -#include "absl/status/statusor.h" -#include "xla/backends/autotuner/backends.pb.h" -#include "xla/backends/autotuner/codegen_backend.h" -#include "xla/backends/gpu/autotuner/gpu_codegen_backend.h" -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/service/compiler.h" -#include "xla/stream_executor/stream_executor.h" -#include "xla/xla.pb.h" - -namespace xla { -namespace gpu { - -// A codegen backend for cuBLAS, with configurable fallback to cuBLAS LT for F8 -// matmuls. -// This backend is used to autotune cuBLAS algorithms. -// -// Cublas calls are represented as custom-call instructions, with and -// configurable algorithm: -// ``` -// %custom-call.1 = .. custom-call(...), custom_call_target="__cublas$gemm", -// backend_config={" -// gemm_backend_config":{"selected_algorithm":"18"} -// } -// ``` - -class CublasBackend : public GpuCodegenBackend { - public: - explicit CublasBackend(stream_executor::StreamExecutor* stream_executor, - const DebugOptions* debug_options, Compiler* compiler, - const Compiler::GpuTargetConfig* target_config, - bool fp8_lt_fallback = false) - : GpuCodegenBackend(autotuner::Backend::CUBLAS, debug_options, compiler, - target_config, stream_executor, - /*uses_last_output_for_scratch=*/true), - fp8_lt_fallback_(fp8_lt_fallback) {} - - absl::StatusOr>> - GetSupportedConfigs(const HloInstruction& instr) override; - - absl::StatusOr> GetDefaultConfig( - const HloInstruction& instr) override; - - absl::Status ApplyConfig(HloInstruction& instr, - const BackendConfig& config) override; - - private: - bool ShouldUseCublasLt(const HloInstruction& instr); - - bool IsSupported(const HloInstruction& instr) override; - // TODO(b/514330710): use valid version - std::string version() const override { return "unknown"; } - bool fp8_lt_fallback_; -}; - -} // namespace gpu -} // namespace xla - -#endif // XLA_BACKENDS_GPU_AUTOTUNER_CUBLAS_H_ diff --git a/third_party/xla/xla/backends/gpu/autotuner/cublas_test.cc b/third_party/xla/xla/backends/gpu/autotuner/cublas_test.cc deleted file mode 100644 index 697a40a57b32ba..00000000000000 --- a/third_party/xla/xla/backends/gpu/autotuner/cublas_test.cc +++ /dev/null @@ -1,247 +0,0 @@ -/* Copyright 2025 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/backends/gpu/autotuner/cublas.h" - -#include -#include - -#include -#include -#include "absl/status/status_matchers.h" -#include "absl/status/statusor.h" -#include "xla/autotuning.pb.h" -#include "xla/backends/autotuner/codegen_backend.h" -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/hlo/testlib/filecheck.h" -#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" -#include "xla/service/compiler.h" -#include "xla/service/executable.h" -#include "xla/service/gpu/nvptx_compiler.h" -#include "xla/service/platform_util.h" -#include "xla/shape.h" -#include "xla/stream_executor/blas.h" -#include "xla/stream_executor/device_description.pb.h" -#include "xla/stream_executor/stream_executor.h" -#include "xla/tsl/lib/core/status_test_util.h" -#include "xla/tsl/platform/statusor.h" -#include "xla/tsl/util/proto/proto_matchers.h" -#include "xla/xla.pb.h" - -namespace xla { -namespace gpu { - -using CublasBackendConfig = AutotuneResult::GemmKey; - -using absl_testing::IsOk; -using absl_testing::IsOkAndHolds; -using ::testing::IsEmpty; -using ::testing::Not; -using ::tsl::proto_testing::EqualsProto; - -const char kCublasCustomCallHlo[] = R"( - HloModule module, entry_computation_layout={(f32[100,100]{1,0}, f32[100,100]{1,0})->f32[100,100]{1,0}} - - ENTRY %main (arg0: f32[100,100], arg1: f32[100,100]) -> f32[100,100] { - %arg0 = f32[100,100]{1,0} parameter(0) - %arg1 = f32[100,100]{1,0} parameter(1) - %custom-call.1 = (f32[100,100]{1,0}, s8[80000]{0}) custom-call(%arg0, %arg1), - custom_call_target="__cublas$gemm", - backend_config={ - "gemm_backend_config":{ - "dot_dimension_numbers": - { - "lhs_contracting_dimensions":["1"], - "rhs_contracting_dimensions":["0"], - "lhs_batch_dimensions":[], - "rhs_batch_dimensions":[] - } - } - } - ROOT %get-tuple-element = f32[100,100]{1,0} get-tuple-element(%custom-call.1), index=0 - })"; - -const char kCublasLtCustomCallHlo[] = R"( - HloModule test, entry_computation_layout={(f8e4m3fn[16,32]{1,0}, f8e5m2[32,16]{1,0}, f32[], f32[])->f32[16,16]{1,0}} - - ENTRY %test (x: f8e4m3fn[16,32], y: f8e5m2[32,16], x_scale: f32[], y_scale: f32[]) -> f32[16,16] { - %x = f8e4m3fn[16,32]{1,0} parameter(0) - %y = f8e5m2[32,16]{1,0} parameter(1) - %transpose = f8e5m2[16,32]{1,0} transpose(%y), dimensions={1,0} - %x_scale = f32[] parameter(2) - %y_scale = f32[] parameter(3) - %cublas-gemm.1 = (f32[16,16]{1,0}, s8[33554432]{0}) custom-call(%x, %transpose, %x_scale, %y_scale), - custom_call_target="__cublas$lt$matmul$f8", - backend_config={ - "operation_queue_id":"0", - "gemm_backend_config":{ - "alpha_real":1, - "beta":0, - "dot_dimension_numbers":{ - "lhs_contracting_dimensions":["1"], - "rhs_contracting_dimensions":["1"], - "lhs_batch_dimensions":[], - "rhs_batch_dimensions":[] - }, - "alpha_imag":0, - "precision_config":{ - "operand_precision":["DEFAULT","DEFAULT"], - "algorithm":"ALG_UNSET" - }, - "epilogue":"DEFAULT", - "lhs_stride":"512", - "rhs_stride":"512", - "grad_x":false, - "grad_y":false, - "damax_output":false - }, - "force_earliest_schedule":false, - "reification_cost":[], - "device_type":"DEVICE_TYPE_INVALID" - } - ROOT %get-tuple-element = f32[16,16]{1,0} get-tuple-element(%cublas-gemm.1), index=0 -})"; - -const char kUnsupportedHlo[] = R"( - HloModule module - - computation { - p0 = bf16[1024,1024]{1,0} parameter(0) - convert0 = f32[1024,1024]{1,0} convert(p0) - p1 = s8[1024,1024]{1,0} parameter(1) - convert1 = f32[1024,1024]{1,0} convert(p1) - ROOT dot = f32[1024,1024]{1,0} dot(convert0, convert1), - lhs_contracting_dims={1}, rhs_contracting_dims={0} - } - - ENTRY main { - p0 = bf16[1024,1024]{1,0} parameter(0) - p1 = s8[1024,1024]{1,0} parameter(1) - ROOT fusion = f32[1024,1024]{1,0} fusion(p0, p1), - kind=kCustom, calls=computation, - backend_config={"fusion_backend_config":{"kind":"__triton_gemm"}} - })"; - -class CublasBackendTest : public HloHardwareIndependentTestBase { - protected: - DebugOptions debug_options_; - NVPTXCompiler compiler_; - se::StreamExecutor* stream_executor_; - Compiler::GpuTargetConfig target_config_; - CublasBackend backend_; - - CublasBackendTest() - : stream_executor_(PlatformUtil::GetDefaultPlatform() - .value() - ->ExecutorForDevice(0) - .value()), - target_config_(stream_executor_), - backend_(stream_executor_, &debug_options_, &compiler_, - &target_config_) {} - - CublasBackendConfig ExpectedDefaultAlgorithm() { - auto config = AutotuneResult::GemmKey(); - config.set_algorithm(se::blas::kDefaultAlgorithm); - return config; - } -}; - -TEST_F(CublasBackendTest, CanCreateCublasBackend) { - ASSERT_NE(nullptr, &backend_); -} - -TEST_F(CublasBackendTest, GetSupportedConfigsFromCublasCustomCall) { - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr hlo_module, - ParseAndReturnVerifiedModule(kCublasCustomCallHlo)); - absl::StatusOr>> configs = - backend_.GetSupportedConfigs( - (*hlo_module->entry_computation()->root_instruction()->operand(0))); - EXPECT_THAT(configs, IsOkAndHolds(Not(IsEmpty()))); -} - -TEST_F(CublasBackendTest, CublasLtCustomCall) { - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr hlo_module, - ParseAndReturnVerifiedModule(kCublasLtCustomCallHlo)); - const HloInstruction* instr = - hlo_module->entry_computation()->root_instruction()->operand(0); - CublasBackend backend(stream_executor_, &debug_options_, &compiler_, - &target_config_, /*fp8_lt_fallback=*/true); - absl::StatusOr>> configs = - backend.GetSupportedConfigs(*instr); - EXPECT_THAT(configs, IsOkAndHolds(Not(IsEmpty()))); - - EXPECT_THAT(backend.GetDefaultConfig(*instr), IsOk()); - EXPECT_THAT(backend.Compile(*instr, *configs.value()[0]), IsOk()); -} - -TEST_F(CublasBackendTest, - GetSupportedConfigsReturnsEmptyVectorNonCublasCustomCall) { - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr hlo_module, - ParseAndReturnVerifiedModule(kUnsupportedHlo)); - absl::StatusOr>> configs = - backend_.GetSupportedConfigs( - (*hlo_module->entry_computation()->root_instruction())); - EXPECT_THAT(configs, IsOkAndHolds(testing::SizeIs(0))); -} - -TEST_F(CublasBackendTest, GetDefaultConfigFromCublasCustomCall) { - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr hlo_module, - ParseAndReturnVerifiedModule(kCublasCustomCallHlo)); - - absl::StatusOr> config = - backend_.GetDefaultConfig( - (*hlo_module->entry_computation()->root_instruction()->operand(0))); - CublasBackendConfig config_proto; - ASSERT_TRUE(config.value()->UnpackTo(&config_proto)); - EXPECT_THAT(config_proto, EqualsProto(ExpectedDefaultAlgorithm())); -} - -TEST_F(CublasBackendTest, ApplyConfig) { - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr hlo_module, - ParseAndReturnVerifiedModule(kCublasCustomCallHlo)); - CublasBackendConfig config; - config.set_algorithm(2); - google::protobuf::Any any; - any.PackFrom(config); - TF_EXPECT_OK(backend_.ApplyConfig(*hlo_module->entry_computation() - ->root_instruction() - ->mutable_operands() - .at(0), - any)); - EXPECT_THAT(RunFileCheck(hlo_module->ToString(), - "CHECK: \"selected_algorithm\":\"2\""), - IsOkAndHolds(true)); -} - -TEST_F(CublasBackendTest, Compile) { - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(kCublasCustomCallHlo)); - TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr config, - backend_.GetDefaultConfig( - *(module->entry_computation()->root_instruction()->operand(0)))); - TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr executable, - backend_.Compile( - *(module->entry_computation()->root_instruction()->operand(0)), - *config)); - const ProgramShape& program_shape = - executable->compute_computation_layout().ComputeProgramShape(); - EXPECT_EQ(program_shape.parameters_size(), 2); - EXPECT_FALSE(program_shape.result().IsTuple()); -} - -} // namespace gpu -} // namespace xla diff --git a/third_party/xla/xla/backends/gpu/autotuner/cublaslt.cc b/third_party/xla/xla/backends/gpu/autotuner/cublaslt.cc index 510b35180cc2ab..56d69ea7a6900e 100644 --- a/third_party/xla/xla/backends/gpu/autotuner/cublaslt.cc +++ b/third_party/xla/xla/backends/gpu/autotuner/cublaslt.cc @@ -20,7 +20,6 @@ limitations under the License. #include #include -#include "absl/log/log.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "xla/tsl/platform/status_macros.h" @@ -38,7 +37,6 @@ limitations under the License. #include "xla/stream_executor/blas.h" #include "xla/stream_executor/device_description.h" #include "xla/stream_executor/gpu/gpu_blas_lt.h" -#include "xla/stream_executor/stream.h" #include "xla/tsl/platform/errors.h" #include "xla/tsl/platform/statusor.h" #include "xla/util.h" @@ -102,12 +100,10 @@ CublasLtBackend::GetSupportedConfigs(const HloInstruction& instr) { ASSIGN_OR_RETURN(BlasLt::Epilogue epilogue, AsBlasLtEpilogue(backend_config.epilogue())); - ASSIGN_OR_RETURN(std::unique_ptr stream, - stream_executor()->CreateStream()); + ASSIGN_OR_RETURN(BlasLt * blas_lt, se::gpu::BlasLt::Get(stream_executor())); - ASSIGN_OR_RETURN( - std::unique_ptr plan, - se::gpu::BlasLt::GetMatmulPlan(stream.get(), gemm_config, epilogue)); + ASSIGN_OR_RETURN(std::unique_ptr plan, + blas_lt->GetMatmulPlan(gemm_config, epilogue)); const Shape& output_shape = instr.shape(); if (!output_shape.IsTuple() || output_shape.tuple_shapes().empty()) { @@ -118,9 +114,9 @@ CublasLtBackend::GetSupportedConfigs(const HloInstruction& instr) { const int64_t workspace_size = ShapeUtil::ByteSizeOf(output_shape.tuple_shapes().back()); - ASSIGN_OR_RETURN(std::vector algorithms, - plan->GetAlgorithms(stream.get(), GemmConfig::kNumAlgorithms, - workspace_size)); + ASSIGN_OR_RETURN( + std::vector algorithms, + plan->GetAlgorithms(GemmConfig::kNumAlgorithms, workspace_size)); int num_algorithms = algorithms.size(); std::vector> configs; configs.reserve(num_algorithms); diff --git a/third_party/xla/xla/backends/gpu/autotuner/factory_test.cc b/third_party/xla/xla/backends/gpu/autotuner/factory_test.cc index b2ffbb7e4323da..6883cb33d534d0 100644 --- a/third_party/xla/xla/backends/gpu/autotuner/factory_test.cc +++ b/third_party/xla/xla/backends/gpu/autotuner/factory_test.cc @@ -102,12 +102,12 @@ INSTANTIATE_TEST_SUITE_P( FactoryTestParams{{}, 6, /*run_on_cuda=*/true, /*run_on_rocm=*/false}, FactoryTestParams{{}, 6, /*run_on_cuda=*/false, /*run_on_rocm=*/true}, FactoryTestParams{{Backend::TRITON}, 1}, - FactoryTestParams{{Backend::TRITON, Backend::CUBLAS}, - 1, + FactoryTestParams{{Backend::TRITON, Backend::CUBLASLT}, + 2, /*run_on_cuda=*/true, /*run_on_rocm=*/false}, - FactoryTestParams{{Backend::TRITON, Backend::ROCBLAS}, - 1, + FactoryTestParams{{Backend::TRITON, Backend::HIPBLASLT}, + 2, /*run_on_cuda=*/false, /*run_on_rocm=*/true})); diff --git a/third_party/xla/xla/backends/gpu/autotuner/fission_backend.cc b/third_party/xla/xla/backends/gpu/autotuner/fission_backend.cc index e746f7dc1f39d5..f20236a88977e6 100644 --- a/third_party/xla/xla/backends/gpu/autotuner/fission_backend.cc +++ b/third_party/xla/xla/backends/gpu/autotuner/fission_backend.cc @@ -15,6 +15,7 @@ limitations under the License. #include "xla/backends/gpu/autotuner/fission_backend.h" +#include #include #include #include @@ -23,6 +24,7 @@ limitations under the License. #include "absl/log/log.h" #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" #include "xla/tsl/platform/status_macros.h" #include "xla/backends/autotuner/codegen_backend.h" #include "xla/backends/gpu/transforms/priority_fusion.h" @@ -35,6 +37,7 @@ limitations under the License. #include "xla/hlo/pass/hlo_pass_pipeline.h" #include "xla/service/compiler.h" #include "xla/service/hlo_cost_analysis.h" +#include "xla/shape_util.h" #include "xla/tools/hlo_decomposer.h" #include "xla/tsl/platform/errors.h" #include "xla/tsl/platform/statusor.h" @@ -91,15 +94,15 @@ FissionBackend::GetSupportedConfigs(const HloInstruction& instr) { } ASSIGN_OR_RETURN(std::unique_ptr hlo_module, GetFissionedAndRewrittenModule(instr)); - absl::StatusOr supported_instr = - FindFirstSupportedInstruction(hlo_module.get()); - if (supported_instr.status().code() == absl::StatusCode::kNotFound) { + absl::StatusOr> supported_instrs = + FindSupportedInstructions(hlo_module.get()); + if (supported_instrs.status().code() == absl::StatusCode::kNotFound) { VLOG(3) << "No supported instructions found by " << name() << ": " << instr.ToString(); return std::vector>(); } - RETURN_IF_ERROR(supported_instr.status()); - return codegen_backend_->GetSupportedConfigs(**supported_instr); + RETURN_IF_ERROR(supported_instrs.status()); + return codegen_backend_->GetSupportedConfigs(*(*supported_instrs)[0]); } absl::StatusOr> FissionBackend::GetDefaultConfig( @@ -109,9 +112,9 @@ absl::StatusOr> FissionBackend::GetDefaultConfig( } ASSIGN_OR_RETURN(std::unique_ptr hlo_module, GetFissionedAndRewrittenModule(instr)); - ASSIGN_OR_RETURN(HloInstruction * supported_instr, - FindFirstSupportedInstruction(hlo_module.get())); - return codegen_backend_->GetDefaultConfig(*supported_instr); + ASSIGN_OR_RETURN(std::vector supported_instrs, + FindSupportedInstructions(hlo_module.get())); + return codegen_backend_->GetDefaultConfig(*supported_instrs[0]); } absl::Status FissionBackend::RunPriorityFusion(HloModule* module) { @@ -139,9 +142,28 @@ absl::Status FissionBackend::ApplyConfig(HloInstruction& instr, HloModule* module = instr.GetModule(); ASSIGN_OR_RETURN(std::unique_ptr hlo_module, GetFissionedAndRewrittenModule(instr)); - ASSIGN_OR_RETURN(HloInstruction * supported_instr, - FindFirstSupportedInstruction(hlo_module.get())); - RETURN_IF_ERROR(codegen_backend_->ApplyConfig(*supported_instr, config)); + ASSIGN_OR_RETURN(std::vector supported_instrs, + FindSupportedInstructions(hlo_module.get())); + + for (size_t i = 0; i < supported_instrs.size(); ++i) { + HloInstruction* supported_instr = supported_instrs[i]; + if (i > 0) { + if (supported_instr->opcode() != supported_instrs[0]->opcode()) { + return absl::InternalError(absl::StrCat( + "FissionBackend expected isomorphic supported instructions, but " + "found different opcodes: ", + HloOpcodeString(supported_instrs[0]->opcode()), " vs ", + HloOpcodeString(supported_instr->opcode()))); + } + if (!ShapeUtil::Compatible(supported_instr->shape(), + supported_instrs[0]->shape())) { + return absl::InternalError( + "FissionBackend expected isomorphic supported instructions with " + "compatible shapes, but found incompatible shapes."); + } + } + RETURN_IF_ERROR(codegen_backend_->ApplyConfig(*supported_instr, config)); + } // Given that the autotuner runs post fusion, we have to run priority fusion // again to fuse the epilogue and prologues. @@ -168,8 +190,8 @@ FissionBackend::GetFissionedAndRewrittenModule( return hlo_module; } -absl::StatusOr FissionBackend::FindFirstSupportedInstruction( - const HloModule* module) { +absl::StatusOr> +FissionBackend::FindSupportedInstructions(const HloModule* module) { std::vector supported_instructions; for (HloComputation* computation : module->computations()) { for (HloInstruction* instruction : computation->instructions()) { @@ -181,12 +203,7 @@ absl::StatusOr FissionBackend::FindFirstSupportedInstruction( if (supported_instructions.empty()) { return absl::NotFoundError("No supported instructions found."); } - if (supported_instructions.size() > 1) { - LOG(WARNING) << "Backend " << name() - << " found multiple supported instructions found. Using the " - "first one."; - } - return supported_instructions[0]; + return supported_instructions; } } // namespace gpu diff --git a/third_party/xla/xla/backends/gpu/autotuner/fission_backend.h b/third_party/xla/xla/backends/gpu/autotuner/fission_backend.h index 3e56257872863a..4396c6fe8e6936 100644 --- a/third_party/xla/xla/backends/gpu/autotuner/fission_backend.h +++ b/third_party/xla/xla/backends/gpu/autotuner/fission_backend.h @@ -52,11 +52,10 @@ inline autotuner::Backend GetFissionBackend(autotuner::Backend backend) { // A proxy backend that wraps an actual codegen backend. The `rewriter_pipeline` // is used to transform unfused instructions to retarget them for the underlying // codegen backend. -// For the get/apply config operations, the proxy backend only operates on the -// *first* supported instruction by the underlying backend, found in the unfused -// and transmormed HLO. -// The assumption is that there is only one operation of interest in the fusion -// (e.g., a 'dot' in a gemm fusion). +// If multiple supported instructions are found, the first one is profiled, then +// we use its config for the rest of the supported instructions, provided they +// are identical. E.g. three 'dots' resulting from "_X3" or "_X6" algorithms are +// identical. class FissionBackend : public GpuCodegenBackend { public: FissionBackend(const DebugOptions* debug_options, Compiler* compiler, @@ -92,7 +91,7 @@ class FissionBackend : public GpuCodegenBackend { private: absl::StatusOr> GetFissionedAndRewrittenModule( const HloInstruction& fusion_instr); - absl::StatusOr FindFirstSupportedInstruction( + absl::StatusOr> FindSupportedInstructions( const HloModule* module); // Runs priority fusion to fuse prologues and epilogue after the fissioned // module has been generated. diff --git a/third_party/xla/xla/backends/gpu/autotuner/fission_backend_test.cc b/third_party/xla/xla/backends/gpu/autotuner/fission_backend_test.cc index e1a42d00c68824..26702bc9159dc3 100644 --- a/third_party/xla/xla/backends/gpu/autotuner/fission_backend_test.cc +++ b/third_party/xla/xla/backends/gpu/autotuner/fission_backend_test.cc @@ -313,7 +313,7 @@ TEST_P(FissionTest, ApplyConfig) { INSTANTIATE_TEST_SUITE_P( FissionTests, FissionTest, ::testing::ValuesIn({ - {"TritonFusion_Cublas", kTritonFusionHlo, &GetCublasRewriterPipeline, + {"TritonFusion_CublasLt", kTritonFusionHlo, &GetCublasRewriterPipeline, &CreateCublasLtBackend, /*expected_module_substrings_fn=*/ [](const se::DeviceDescription& device_description) { diff --git a/third_party/xla/xla/backends/gpu/autotuner/hipblaslt.cc b/third_party/xla/xla/backends/gpu/autotuner/hipblaslt.cc index 508b7e807fc531..9249cac082e7bf 100644 --- a/third_party/xla/xla/backends/gpu/autotuner/hipblaslt.cc +++ b/third_party/xla/xla/backends/gpu/autotuner/hipblaslt.cc @@ -22,7 +22,6 @@ limitations under the License. #include #include "absl/algorithm/container.h" -#include "absl/log/log.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "xla/tsl/platform/status_macros.h" @@ -42,7 +41,6 @@ limitations under the License. #include "xla/stream_executor/blas.h" #include "xla/stream_executor/device_description.h" #include "xla/stream_executor/gpu/gpu_blas_lt.h" -#include "xla/stream_executor/stream.h" #include "xla/tsl/platform/errors.h" #include "xla/tsl/platform/statusor.h" #include "xla/util.h" @@ -226,12 +224,10 @@ HipblasLtBackend::GetSupportedConfigs(const HloInstruction& instr) { ASSIGN_OR_RETURN(BlasLt::Epilogue epilogue, AsBlasLtEpilogue(backend_config.epilogue())); - ASSIGN_OR_RETURN(std::unique_ptr stream, - stream_executor()->CreateStream()); + ASSIGN_OR_RETURN(BlasLt * blas_lt, se::gpu::BlasLt::Get(stream_executor())); - ASSIGN_OR_RETURN( - std::unique_ptr plan, - se::gpu::BlasLt::GetMatmulPlan(stream.get(), gemm_config, epilogue)); + ASSIGN_OR_RETURN(BlasLt::MatmulPlanPtr plan, + blas_lt->GetMatmulPlan(gemm_config, epilogue)); const Shape& output_shape = instr.shape(); if (!output_shape.IsTuple() || output_shape.tuple_shapes().empty()) { @@ -244,8 +240,7 @@ HipblasLtBackend::GetSupportedConfigs(const HloInstruction& instr) { ASSIGN_OR_RETURN( std::vector algorithms, - plan->GetAlgorithms(stream.get(), GemmConfig::kNumAlgorithms, - workspace_size)); + plan->GetAlgorithms(GemmConfig::kNumAlgorithms, workspace_size)); int num_algorithms = algorithms.size(); std::vector> configs; configs.reserve(num_algorithms); @@ -288,10 +283,9 @@ HipblasLtBackend::GetSupportedConfigs(const HloInstruction& instr) { return std::vector>(); } - ASSIGN_OR_RETURN(std::unique_ptr stream, - stream_executor()->CreateStream()); - auto plan_or = se::gpu::BlasLt::GetMatmulPlan(stream.get(), *gemm_config_or, - BlasLt::Epilogue::kDefault); + ASSIGN_OR_RETURN(BlasLt * blas_lt, se::gpu::BlasLt::Get(stream_executor())); + auto plan_or = + blas_lt->GetMatmulPlan(*gemm_config_or, BlasLt::Epilogue::kDefault); if (!plan_or.ok()) { VLOG(2) << "hipBLASLt MX: GetMatmulPlan failed: " << plan_or.status(); return std::vector>(); @@ -300,8 +294,7 @@ HipblasLtBackend::GetSupportedConfigs(const HloInstruction& instr) { int64_t workspace_size = GemmConfig::kGFX950Workspace; ASSIGN_OR_RETURN( std::vector algorithms, - (*plan_or)->GetAlgorithms(stream.get(), GemmConfig::kNumAlgorithms, - workspace_size)); + (*plan_or)->GetAlgorithms(GemmConfig::kNumAlgorithms, workspace_size)); if (algorithms.empty()) { VLOG(2) << "hipBLASLt MX: no algorithms found for scaled dot."; return std::vector>(); @@ -322,10 +315,8 @@ HipblasLtBackend::GetSupportedConfigs(const HloInstruction& instr) { ASSIGN_OR_RETURN(GpuBackendConfig gpu_config, instr.backend_config()); - ASSIGN_OR_RETURN(std::unique_ptr stream, - stream_executor()->CreateStream()); + ASSIGN_OR_RETURN(BlasLt * blas_lt, se::gpu::BlasLt::Get(stream_executor())); - std::unique_ptr plan; int64_t workspace_size; const GroupedGemmBackendConfig& grouped_config = gpu_config.grouped_gemm_backend_config(); @@ -341,8 +332,9 @@ HipblasLtBackend::GetSupportedConfigs(const HloInstruction& instr) { ASSIGN_OR_RETURN(BlasLt::Epilogue epilogue, AsBlasLtEpilogue(backend_config.epilogue())); - ASSIGN_OR_RETURN(plan, se::gpu::BlasLt::GetGroupedMatmulPlan( - stream.get(), grouped_gemm_config, epilogue)); + ASSIGN_OR_RETURN( + BlasLt::MatmulPlanPtr plan, + blas_lt->GetGroupedMatmulPlan(grouped_gemm_config, epilogue)); const Shape& output_shape = instr.shape(); if (!output_shape.IsTuple() || output_shape.tuple_shapes().empty()) { @@ -354,8 +346,7 @@ HipblasLtBackend::GetSupportedConfigs(const HloInstruction& instr) { ASSIGN_OR_RETURN( std::vector algorithms, - plan->GetAlgorithms(stream.get(), GemmConfig::kNumAlgorithms, - workspace_size)); + plan->GetAlgorithms(GemmConfig::kNumAlgorithms, workspace_size)); int num_algorithms = algorithms.size(); std::vector> configs; configs.reserve(num_algorithms); diff --git a/third_party/xla/xla/backends/gpu/autotuner/rocblas.cc b/third_party/xla/xla/backends/gpu/autotuner/rocblas.cc deleted file mode 100644 index 45f6b084c1a25d..00000000000000 --- a/third_party/xla/xla/backends/gpu/autotuner/rocblas.cc +++ /dev/null @@ -1,198 +0,0 @@ -/* Copyright 2025 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/backends/gpu/autotuner/rocblas.h" - -#include -#include -#include - -#include "absl/log/log.h" -#include "absl/status/status.h" -#include "absl/status/statusor.h" -#include "xla/tsl/platform/status_macros.h" -#include "xla/autotuning.pb.h" -#include "xla/backends/autotuner/codegen_backend.h" -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/service/gpu/backend_configs.pb.h" -#include "xla/service/gpu/cublas_cudnn.h" -#include "xla/service/gpu/matmul_utils.h" -#include "xla/shape.h" -#include "xla/shape_layout.h" -#include "xla/stream_executor/blas.h" -#include "xla/stream_executor/device_address.h" -#include "xla/stream_executor/device_address_allocator.h" -#include "xla/stream_executor/device_description.h" -#include "xla/stream_executor/gpu/gpu_blas_lt.h" -#include "xla/stream_executor/stream_executor_memory_allocator.h" -#include "xla/tsl/platform/errors.h" -#include "xla/tsl/platform/statusor.h" -#include "xla/xla_data.pb.h" - -namespace xla { -namespace gpu { - -namespace se = ::stream_executor; - -absl::StatusOr>> -RocblasBackend::GetSupportedConfigs(const HloInstruction& instr) { - if (!IsSupported(instr)) { - return std::vector>(); - } - - if (ShouldUseHipblasLt(instr)) { - std::vector> configs; - AutotuneResult::GemmKey gemm_key; - gemm_key.set_algorithm(0); - configs.push_back(std::make_unique()); - configs.back()->PackFrom(gemm_key); - return configs; - } - - std::unique_ptr allocator = - std::make_unique( - stream_executor()); - ASSIGN_OR_RETURN(se::Stream * stream, - allocator->GetStream(stream_executor()->device_ordinal())); - - // We use GemmConfig::For with GemmBackendConfig as a fallback because - // Matmul_utils.cc relies on backend config to determine gemm contracting - // dimensions. - GemmBackendConfig backend_config; - backend_config = - instr.backend_config()->gemm_backend_config(); - ASSIGN_OR_RETURN( - GemmConfig gemm_config, - GemmConfig::For( - &instr, backend_config, - target_config().device_description.gpu_compute_capability())); - - auto create_matrix_desc = [](const se::gpu::MatrixLayout& layout) - -> absl::StatusOr { - ASSIGN_OR_RETURN(se::blas::DataType type, - se::gpu::AsBlasDataType(layout.dtype)); - return se::gpu::MatrixDescriptor{ - /*data=*/se::DeviceAddressBase(), layout.leading_dim_stride, - layout.batch_stride, type, - // BLAS is column-major by default. - (layout.order == se::gpu::MatrixLayout::Order::kColumnMajor - ? se::blas::Transpose::kNoTranspose - : se::blas::Transpose::kTranspose)}; - }; - - ASSIGN_OR_RETURN(se::gpu::MatrixDescriptor lhs_desc, - create_matrix_desc(gemm_config.lhs_layout)); - ASSIGN_OR_RETURN(se::gpu::MatrixDescriptor rhs_desc, - create_matrix_desc(gemm_config.rhs_layout)); - ASSIGN_OR_RETURN(se::gpu::MatrixDescriptor output_desc_base, - create_matrix_desc(gemm_config.output_layout)); - - se::gpu::OutputMatrixDescriptor out_desc(std::move(output_desc_base)); - out_desc.batch_size = gemm_config.output_layout.batch_size; - out_desc.m = gemm_config.output_layout.num_rows; - out_desc.n = gemm_config.output_layout.num_cols; - out_desc.k = gemm_config.lhs_layout.num_cols; - ASSIGN_OR_RETURN( - out_desc.compute_type, - se::gpu::GetBlasComputationType( - gemm_config.precision_algorithm, gemm_config.lhs_layout.dtype, - gemm_config.output_layout.dtype, gemm_config.compute_precision, - target_config().device_description.gpu_compute_capability())); - - se::blas::BlasSupport* blas = stream_executor()->AsBlas(); - if (blas == nullptr) { - return absl::InternalError("Failed to get BLAS support."); - } - std::vector algorithms; - - blas->GetBlasGemmAlgorithms(stream, lhs_desc, rhs_desc, &out_desc, - &gemm_config.alpha, &gemm_config.beta, - &algorithms); - - std::vector> configs; - configs.reserve(algorithms.size()); - for (se::blas::AlgorithmType algorithm : algorithms) { - AutotuneResult::GemmKey gemm_key; - gemm_key.set_algorithm(algorithm); - auto any = std::make_unique(); - any->PackFrom(gemm_key); - configs.push_back(std::move(any)); - } - return configs; -} - -absl::StatusOr> RocblasBackend::GetDefaultConfig( - const HloInstruction& instr) { - if (!IsSupported(instr)) { - return absl::InvalidArgumentError( - "RocblasBackend does not support this instruction."); - } - AutotuneResult::GemmKey gemm_key; - gemm_key.set_algorithm(se::blas::kDefaultAlgorithm); - auto any = std::make_unique(); - if (ShouldUseHipblasLt(instr)) { - gemm_key.set_algorithm(0); - } - any->PackFrom(gemm_key); - return any; -} - -absl::Status RocblasBackend::ApplyConfig(HloInstruction& instr, - const BackendConfig& config) { - AutotuneResult::GemmKey gemm_key; - if (!config.UnpackTo(&gemm_key)) { - return absl::InvalidArgumentError( - "Failed to unpack RocblasBackendConfig from Any."); - } - if (ShouldUseHipblasLt(instr) && gemm_key.algorithm() == -1) { - gemm_key.set_algorithm(0); - } - ASSIGN_OR_RETURN(GpuBackendConfig gpu_config, - instr.backend_config()); - GemmBackendConfig& backend_config = *gpu_config.mutable_gemm_backend_config(); - backend_config.set_selected_algorithm(gemm_key.algorithm()); - backend_config.set_autotune_workspace_size( - gemm_key.autotune_workspace_size()); - RETURN_IF_ERROR(instr.set_backend_config(std::move(gpu_config))); - - if (instr.shape().IsTuple() && !instr.shape().tuple_shapes().empty()) { - Shape* workspace_shape = instr.mutable_shape()->mutable_tuple_shapes( - instr.shape().tuple_shapes().size() - 1); - if (workspace_shape->element_type() == S8 && - workspace_shape->dimensions().size() == 1) { - workspace_shape->set_dimensions(0, gemm_key.autotune_workspace_size()); - if (HloModule* module = instr.GetModule()) { - if (module->entry_computation() && - module->entry_computation()->root_instruction() == &instr) { - *module->mutable_entry_computation_layout()->mutable_result_layout() = - ShapeLayout(instr.shape()); - } - } - } - } - - return absl::OkStatus(); -} - -bool RocblasBackend::IsSupported(const HloInstruction& instr) { - return IsLegacyCublasMatmul(instr) || ShouldUseHipblasLt(instr); -} - -bool RocblasBackend::ShouldUseHipblasLt(const HloInstruction& instr) { - return fp8_lt_fallback_ && IsCublasLtMatmulF8(instr); -} - -} // namespace gpu -} // namespace xla diff --git a/third_party/xla/xla/backends/gpu/autotuner/rocblas.h b/third_party/xla/xla/backends/gpu/autotuner/rocblas.h deleted file mode 100644 index 8a176ffc3611a5..00000000000000 --- a/third_party/xla/xla/backends/gpu/autotuner/rocblas.h +++ /dev/null @@ -1,79 +0,0 @@ -/* Copyright 2025 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_BACKENDS_GPU_AUTOTUNER_ROCBLAS_H_ -#define XLA_BACKENDS_GPU_AUTOTUNER_ROCBLAS_H_ - -#include -#include -#include - -#include "absl/status/status.h" -#include "absl/status/statusor.h" -#include "xla/backends/autotuner/backends.pb.h" -#include "xla/backends/autotuner/codegen_backend.h" -#include "xla/backends/gpu/autotuner/gpu_codegen_backend.h" -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/service/compiler.h" -#include "xla/stream_executor/stream_executor.h" -#include "xla/xla.pb.h" - -namespace xla { -namespace gpu { - -// A codegen backend for rocBLAS GEMM autotuning on ROCm/AMD GPUs. -// -// This backend is used to autotune rocBLAS algorithms for matrix -// multiplications. rocBLAS calls are represented as custom-call instructions: -// ``` -// %custom-call.1 = .. custom-call(...), custom_call_target="__cublas$gemm", -// backend_config={" -// gemm_backend_config":{"selected_algorithm":"18"} -// } -// ``` - -class RocblasBackend : public GpuCodegenBackend { - public: - explicit RocblasBackend(stream_executor::StreamExecutor* stream_executor, - const DebugOptions* debug_options, Compiler* compiler, - const Compiler::GpuTargetConfig* target_config, - bool fp8_lt_fallback = false) - : GpuCodegenBackend(autotuner::Backend::ROCBLAS, debug_options, compiler, - target_config, stream_executor, - /*uses_last_output_for_scratch=*/true), - fp8_lt_fallback_(fp8_lt_fallback) {} - - absl::StatusOr>> - GetSupportedConfigs(const HloInstruction& instr) override; - - absl::StatusOr> GetDefaultConfig( - const HloInstruction& instr) override; - - absl::Status ApplyConfig(HloInstruction& instr, - const BackendConfig& config) override; - - private: - bool ShouldUseHipblasLt(const HloInstruction& instr); - - bool IsSupported(const HloInstruction& instr) override; - // TODO(b/514330710): use valid version - std::string version() const override { return "unknown"; } - bool fp8_lt_fallback_; -}; - -} // namespace gpu -} // namespace xla - -#endif // XLA_BACKENDS_GPU_AUTOTUNER_ROCBLAS_H_ diff --git a/third_party/xla/xla/backends/gpu/autotuner/rocblas_test.cc b/third_party/xla/xla/backends/gpu/autotuner/rocblas_test.cc deleted file mode 100644 index 6e4bfabed63fec..00000000000000 --- a/third_party/xla/xla/backends/gpu/autotuner/rocblas_test.cc +++ /dev/null @@ -1,243 +0,0 @@ -/* Copyright 2025 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/backends/gpu/autotuner/rocblas.h" - -#include -#include - -#include -#include -#include "absl/status/status_matchers.h" -#include "absl/status/statusor.h" -#include "xla/autotuning.pb.h" -#include "xla/backends/autotuner/codegen_backend.h" -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/hlo/testlib/filecheck.h" -#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" -#include "xla/service/compiler.h" -#include "xla/service/executable.h" -#include "xla/service/gpu/amdgpu_compiler.h" -#include "xla/service/platform_util.h" -#include "xla/stream_executor/blas.h" -#include "xla/stream_executor/device_description.pb.h" -#include "xla/stream_executor/stream_executor.h" -#include "xla/tsl/lib/core/status_test_util.h" -#include "xla/tsl/platform/statusor.h" -#include "xla/tsl/util/proto/proto_matchers.h" -#include "xla/xla.pb.h" - -namespace xla { -namespace gpu { - -using RocblasBackendConfig = AutotuneResult::GemmKey; - -using absl_testing::IsOk; -using absl_testing::IsOkAndHolds; -using ::testing::IsEmpty; -using ::testing::Not; -using ::tsl::proto_testing::EqualsProto; - -const char kRocblasCustomCallHlo[] = R"( - HloModule module, entry_computation_layout={(f32[100,100]{1,0}, f32[100,100]{1,0})->f32[100,100]{1,0}} - - ENTRY %main (arg0: f32[100,100], arg1: f32[100,100]) -> f32[100,100] { - %arg0 = f32[100,100]{1,0} parameter(0) - %arg1 = f32[100,100]{1,0} parameter(1) - %custom-call.1 = (f32[100,100]{1,0}, s8[80000]{0}) custom-call(%arg0, %arg1), - custom_call_target="__cublas$gemm", - backend_config={ - "gemm_backend_config":{ - "dot_dimension_numbers": - { - "lhs_contracting_dimensions":["1"], - "rhs_contracting_dimensions":["0"], - "lhs_batch_dimensions":[], - "rhs_batch_dimensions":[] - } - } - } - ROOT %get-tuple-element = f32[100,100]{1,0} get-tuple-element(%custom-call.1), index=0 - })"; - -const char kHipblasLtCustomCallHlo[] = R"( - HloModule test, entry_computation_layout={(f8e4m3fn[16,32]{1,0}, f8e5m2[32,16]{1,0}, f32[], f32[])->f32[16,16]{1,0}} - - ENTRY %test (x: f8e4m3fn[16,32], y: f8e5m2[32,16], x_scale: f32[], y_scale: f32[]) -> f32[16,16] { - %x = f8e4m3fn[16,32]{1,0} parameter(0) - %y = f8e5m2[32,16]{1,0} parameter(1) - %transpose = f8e5m2[16,32]{1,0} transpose(%y), dimensions={1,0} - %x_scale = f32[] parameter(2) - %y_scale = f32[] parameter(3) - %cublas-gemm.1 = (f32[16,16]{1,0}, s8[33554432]{0}) custom-call(%x, %transpose, %x_scale, %y_scale), - custom_call_target="__cublas$lt$matmul$f8", - backend_config={ - "operation_queue_id":"0", - "gemm_backend_config":{ - "alpha_real":1, - "beta":0, - "dot_dimension_numbers":{ - "lhs_contracting_dimensions":["1"], - "rhs_contracting_dimensions":["1"], - "lhs_batch_dimensions":[], - "rhs_batch_dimensions":[] - }, - "alpha_imag":0, - "precision_config":{ - "operand_precision":["DEFAULT","DEFAULT"], - "algorithm":"ALG_UNSET" - }, - "epilogue":"DEFAULT", - "lhs_stride":"512", - "rhs_stride":"512", - "grad_x":false, - "grad_y":false, - "damax_output":false - }, - "force_earliest_schedule":false, - "reification_cost":[], - "device_type":"DEVICE_TYPE_INVALID" - } - ROOT %get-tuple-element = f32[16,16]{1,0} get-tuple-element(%cublas-gemm.1), index=0 -})"; - -const char kUnsupportedHlo[] = R"( - HloModule module - - computation { - p0 = bf16[1024,1024]{1,0} parameter(0) - convert0 = f32[1024,1024]{1,0} convert(p0) - p1 = s8[1024,1024]{1,0} parameter(1) - convert1 = f32[1024,1024]{1,0} convert(p1) - ROOT dot = f32[1024,1024]{1,0} dot(convert0, convert1), - lhs_contracting_dims={1}, rhs_contracting_dims={0} - } - - ENTRY main { - p0 = bf16[1024,1024]{1,0} parameter(0) - p1 = s8[1024,1024]{1,0} parameter(1) - ROOT fusion = f32[1024,1024]{1,0} fusion(p0, p1), - kind=kCustom, calls=computation, - backend_config={"fusion_backend_config":{"kind":"__triton_gemm"}} - })"; - -class RocblasBackendTest : public HloHardwareIndependentTestBase { - protected: - DebugOptions debug_options_; - AMDGPUCompiler compiler_; - se::StreamExecutor* stream_executor_; - Compiler::GpuTargetConfig target_config_; - RocblasBackend backend_; - - RocblasBackendTest() - : stream_executor_(PlatformUtil::GetDefaultPlatform() - .value() - ->ExecutorForDevice(0) - .value()), - target_config_(stream_executor_), - backend_(stream_executor_, &debug_options_, &compiler_, - &target_config_) {} - - RocblasBackendConfig ExpectedDefaultAlgorithm() { - auto config = AutotuneResult::GemmKey(); - config.set_algorithm(se::blas::kDefaultAlgorithm); - return config; - } -}; - -TEST_F(RocblasBackendTest, CanCreateRocblasBackend) { - ASSERT_NE(nullptr, &backend_); -} - -TEST_F(RocblasBackendTest, GetSupportedConfigsFromRocblasCustomCall) { - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr hlo_module, - ParseAndReturnVerifiedModule(kRocblasCustomCallHlo)); - absl::StatusOr>> configs = - backend_.GetSupportedConfigs( - (*hlo_module->entry_computation()->root_instruction()->operand(0))); - EXPECT_THAT(configs, IsOkAndHolds(Not(IsEmpty()))); -} - -TEST_F(RocblasBackendTest, HipblasLtCustomCall) { - TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr hlo_module, - ParseAndReturnVerifiedModule(kHipblasLtCustomCallHlo)); - const HloInstruction* instr = - hlo_module->entry_computation()->root_instruction()->operand(0); - RocblasBackend backend(stream_executor_, &debug_options_, &compiler_, - &target_config_, /*fp8_lt_fallback=*/true); - absl::StatusOr>> configs = - backend.GetSupportedConfigs(*instr); - EXPECT_THAT(configs, IsOkAndHolds(Not(IsEmpty()))); - - EXPECT_THAT(backend.GetDefaultConfig(*instr), IsOk()); - EXPECT_THAT(backend.Compile(*instr, *configs.value()[0]), IsOk()); -} - -TEST_F(RocblasBackendTest, - GetSupportedConfigsReturnsEmptyVectorNonRocblasCustomCall) { - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr hlo_module, - ParseAndReturnVerifiedModule(kUnsupportedHlo)); - absl::StatusOr>> configs = - backend_.GetSupportedConfigs( - (*hlo_module->entry_computation()->root_instruction())); - EXPECT_THAT(configs, IsOkAndHolds(testing::SizeIs(0))); -} - -TEST_F(RocblasBackendTest, GetDefaultConfigFromRocblasCustomCall) { - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr hlo_module, - ParseAndReturnVerifiedModule(kRocblasCustomCallHlo)); - - absl::StatusOr> config = - backend_.GetDefaultConfig( - (*hlo_module->entry_computation()->root_instruction()->operand(0))); - RocblasBackendConfig config_proto; - ASSERT_TRUE(config.value()->UnpackTo(&config_proto)); - EXPECT_THAT(config_proto, EqualsProto(ExpectedDefaultAlgorithm())); -} - -TEST_F(RocblasBackendTest, ApplyConfig) { - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr hlo_module, - ParseAndReturnVerifiedModule(kRocblasCustomCallHlo)); - RocblasBackendConfig config; - config.set_algorithm(2); - config.set_autotune_workspace_size(42); - google::protobuf::Any any; - any.PackFrom(config); - TF_EXPECT_OK(backend_.ApplyConfig(*hlo_module->entry_computation() - ->root_instruction() - ->mutable_operands() - .at(0), - any)); - EXPECT_THAT(RunFileCheck(hlo_module->ToString(), - R"(CHECK: (f32[100,100]{1,0}, s8[42]{0}) custom-call - CHECK: "selected_algorithm":"2")"), - absl_testing::IsOkAndHolds(true)); -} - -TEST_F(RocblasBackendTest, Compile) { - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(kRocblasCustomCallHlo)); - TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr config, - backend_.GetDefaultConfig( - *(module->entry_computation()->root_instruction()->operand(0)))); - absl::StatusOr> executable = backend_.Compile( - *(module->entry_computation()->root_instruction()), *config); - EXPECT_THAT(executable, IsOk()); -} - -} // namespace gpu -} // namespace xla diff --git a/third_party/xla/xla/backends/gpu/autotuner/triton/BUILD b/third_party/xla/xla/backends/gpu/autotuner/triton/BUILD index ab9e6393d1a73c..91cd23fa7378d0 100644 --- a/third_party/xla/xla/backends/gpu/autotuner/triton/BUILD +++ b/third_party/xla/xla/backends/gpu/autotuner/triton/BUILD @@ -139,6 +139,7 @@ cc_library( "//xla/service/gpu/model:tiling_from_block_parameters", "//xla/service/gpu/model:triton_emitter_constraints", "//xla/stream_executor:device_description", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "//xla/tsl/util:sorted_range", "@com_google_absl//absl/container:flat_hash_map", diff --git a/third_party/xla/xla/backends/gpu/autotuner/triton/cost_model_config_optimization.cc b/third_party/xla/xla/backends/gpu/autotuner/triton/cost_model_config_optimization.cc index 5e18d83f00e727..cbf5ab110b85af 100644 --- a/third_party/xla/xla/backends/gpu/autotuner/triton/cost_model_config_optimization.cc +++ b/third_party/xla/xla/backends/gpu/autotuner/triton/cost_model_config_optimization.cc @@ -32,6 +32,7 @@ limitations under the License. #include "absl/strings/numbers.h" #include "absl/strings/str_cat.h" #include "absl/time/time.h" +#include "xla/tsl/platform/status_macros.h" #include "mlir/IR/MLIRContext.h" #include "xla/backends/gpu/transforms/convert_triton_gemm_config.h" #include "xla/codegen/tiling/symbolic_tile_analysis.h" @@ -72,21 +73,20 @@ absl::StatusOr EstimateRunTimeWithConfig( const TritonGemmConfig& config, GpuPerformanceModelWithIndexingAnalysis& cost_model, mlir::MLIRContext* mlir_context) { - TF_ASSIGN_OR_RETURN( - BlockLevelParameters block_params, - FindBlockLevelParameters(context.dot, config, mlir_context, - context.device_description)); + ASSIGN_OR_RETURN(BlockLevelParameters block_params, + FindBlockLevelParameters(context.dot, config, mlir_context, + context.device_description)); Tile dot_tiling; dot_tiling.add_sizes(config.block_k); - TF_ASSIGN_OR_RETURN(Tiling tiling, TilingFromAnnotatedFusion( - analysis, block_params, &dot_tiling)); + ASSIGN_OR_RETURN(Tiling tiling, TilingFromAnnotatedFusion( + analysis, block_params, &dot_tiling)); - TF_ASSIGN_OR_RETURN(TiledHloComputation tiled_hlo_computation, - analysis.ComputeTiledComputation(tiling)); + ASSIGN_OR_RETURN(TiledHloComputation tiled_hlo_computation, + analysis.ComputeTiledComputation(tiling)); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( EstimateRunTimeData estimate, cost_model.EstimateRunTimeForTiledHloComputation( fusion_adaptor, tiled_hlo_computation, block_params.num_warps)); @@ -305,7 +305,7 @@ absl::StatusOr> OptimizeConfigsWithCostModel( detail::EstimationContext context{fusion, dot, device_description}; - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( detail::CostModelGemmTilingOptions options, detail::ParseCostModelGemmTilingOptions( debug_options.xla_gpu_experimental_cost_model_gemm_tiling_options())); @@ -335,7 +335,7 @@ absl::StatusOr> OptimizeConfigsWithCostModel( // Create the base set by either picking the top configs or estimating the // existing set. if (options.top.has_value()) { - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( detail::OrderedEstimatesAndConfigs base_config_set, options.top_from_default ? EstimateConfigs(context, optimized_configs, mlir_context) @@ -351,9 +351,8 @@ absl::StatusOr> OptimizeConfigsWithCostModel( must_keep_original_configs = false; } else { VLOG(1) << "Cost Model: Using default set"; - TF_ASSIGN_OR_RETURN( - detail::OrderedEstimatesAndConfigs base_config_set, - EstimateConfigs(context, optimized_configs, mlir_context)); + ASSIGN_OR_RETURN(detail::OrderedEstimatesAndConfigs base_config_set, + EstimateConfigs(context, optimized_configs, mlir_context)); current_set = std::move(base_config_set); } @@ -361,8 +360,8 @@ absl::StatusOr> OptimizeConfigsWithCostModel( if (options.mixin.has_value()) { VLOG(1) << "Cost Model: Mixing in top " << *options.mixin << " configs"; - TF_ASSIGN_OR_RETURN(const detail::OrderedEstimatesAndConfigs& all, - get_estimated_all_configs()); + ASSIGN_OR_RETURN(const detail::OrderedEstimatesAndConfigs& all, + get_estimated_all_configs()); detail::OrderedEstimatesAndConfigs top_non_present = detail::GetTopEstimatedConfigs(all, *options.mixin, ¤t_set, diff --git a/third_party/xla/xla/backends/gpu/codegen/emitters/BUILD b/third_party/xla/xla/backends/gpu/codegen/emitters/BUILD index 87ca4536980c63..bd424467758e5a 100644 --- a/third_party/xla/xla/backends/gpu/codegen/emitters/BUILD +++ b/third_party/xla/xla/backends/gpu/codegen/emitters/BUILD @@ -31,6 +31,7 @@ cc_library( "//xla/service/gpu:gpu_fusible", "//xla/service/gpu:hlo_fusion_analysis", "//xla/service/gpu:launch_dimensions", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", @@ -200,6 +201,7 @@ cc_library( "//xla/service/gpu:hlo_fusion_analysis", "//xla/service/gpu:ir_emission_utils", "//xla/service/gpu:launch_dimensions", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/log", "@com_google_absl//absl/status", @@ -230,6 +232,7 @@ cc_library( "//xla/service/gpu:gpu_fusible", "//xla/service/gpu:hlo_fusion_analysis", "//xla/service/gpu:launch_dimensions", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", diff --git a/third_party/xla/xla/backends/gpu/codegen/emitters/concatenate.cc b/third_party/xla/xla/backends/gpu/codegen/emitters/concatenate.cc index fac4d5a0488d1a..5db818b608444a 100644 --- a/third_party/xla/xla/backends/gpu/codegen/emitters/concatenate.cc +++ b/third_party/xla/xla/backends/gpu/codegen/emitters/concatenate.cc @@ -24,6 +24,7 @@ limitations under the License. #include "absl/log/check.h" #include "absl/log/log.h" #include "absl/status/status.h" +#include "xla/tsl/platform/status_macros.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/MLIRContext.h" @@ -88,7 +89,7 @@ ConcatenateFusion::CreateMLIRModule( GetDefaultBufferAlignment(), GetWorkDimensions(), entry_function_name, BackendKind::kGpu); - TF_ASSIGN_OR_RETURN(auto kernel_definition, emitter.EmitKernelDefinition()); + ASSIGN_OR_RETURN(auto kernel_definition, emitter.EmitKernelDefinition()); return std::move(kernel_definition).TakeSource().TakeModule(); } diff --git a/third_party/xla/xla/backends/gpu/codegen/emitters/in_place_dynamic_update_slice.cc b/third_party/xla/xla/backends/gpu/codegen/emitters/in_place_dynamic_update_slice.cc index 663bb53792f6fb..0e180d3cc01883 100644 --- a/third_party/xla/xla/backends/gpu/codegen/emitters/in_place_dynamic_update_slice.cc +++ b/third_party/xla/xla/backends/gpu/codegen/emitters/in_place_dynamic_update_slice.cc @@ -22,6 +22,7 @@ limitations under the License. #include "absl/log/log.h" #include "absl/status/status.h" +#include "xla/tsl/platform/status_macros.h" #include "llvm/ADT/STLExtras.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/AffineExpr.h" @@ -104,7 +105,7 @@ InPlaceDynamicUpdateSliceFusion::CreateMLIRModule( GetDefaultBufferAlignment(), GetWorkDimensions(), entry_function_name, BackendKind::kGpu); - TF_ASSIGN_OR_RETURN(auto kernel_definition, emitter.EmitKernelDefinition()); + ASSIGN_OR_RETURN(auto kernel_definition, emitter.EmitKernelDefinition()); return std::move(kernel_definition).TakeSource().TakeModule(); } diff --git a/third_party/xla/xla/backends/gpu/codegen/emitters/loop.cc b/third_party/xla/xla/backends/gpu/codegen/emitters/loop.cc index e73388e4192b8c..fad5d96e20ba33 100644 --- a/third_party/xla/xla/backends/gpu/codegen/emitters/loop.cc +++ b/third_party/xla/xla/backends/gpu/codegen/emitters/loop.cc @@ -23,6 +23,7 @@ limitations under the License. #include "absl/log/check.h" #include "absl/log/log.h" #include "absl/status/status.h" +#include "xla/tsl/platform/status_macros.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" @@ -129,7 +130,7 @@ absl::StatusOr> LoopFusion::CreateMLIRModule( GetDefaultBufferAlignment(), GetWorkDimensions(), entry_function_name, BackendKind::kGpu); - TF_ASSIGN_OR_RETURN(auto kernel_definition, emitter.EmitKernelDefinition()); + ASSIGN_OR_RETURN(auto kernel_definition, emitter.EmitKernelDefinition()); return std::move(kernel_definition).TakeSource().TakeModule(); } diff --git a/third_party/xla/xla/backends/gpu/codegen/emitters/mlir_kernel_emitter.cc b/third_party/xla/xla/backends/gpu/codegen/emitters/mlir_kernel_emitter.cc index b2b76eb0b1a7cd..81d1c8f1fde0e0 100644 --- a/third_party/xla/xla/backends/gpu/codegen/emitters/mlir_kernel_emitter.cc +++ b/third_party/xla/xla/backends/gpu/codegen/emitters/mlir_kernel_emitter.cc @@ -464,13 +464,13 @@ MlirKernelEmitter::CreateMLIRModule( auto loc = mlir::NameLoc::get(builder.getStringAttr(fusion.name())); mlir::OwningOpRef module = llvm_ir::CreateMlirModuleOp(loc); - TF_ASSIGN_OR_RETURN(mlir::func::FuncOp entry_func, - emitters::EmitKernelApi( - *module, fusion, buffer_assignment, - GetDefaultBufferAlignment(), entry_function_name)); + ASSIGN_OR_RETURN(mlir::func::FuncOp entry_func, + emitters::EmitKernelApi(*module, fusion, buffer_assignment, + GetDefaultBufferAlignment(), + entry_function_name)); SetBackendKind(&mlir_context, entry_func, BackendKind::kGpu); - TF_RETURN_IF_ERROR(EmitMlir(module.get(), entry_func, fusion, mlir_context)); + RETURN_IF_ERROR(EmitMlir(module.get(), entry_func, fusion, mlir_context)); return module; } @@ -544,8 +544,8 @@ absl::Status MlirKernelEmitter::EmitMlir(mlir::ModuleOp module, emitters::PartitionedComputations computations( fusion.fused_instructions_computation(), &mlir_context, epilogues); - TF_ASSIGN_OR_RETURN(auto call_targets, emitters::EmitPartitionedComputations( - module, computations)); + ASSIGN_OR_RETURN(auto call_targets, + emitters::EmitPartitionedComputations(module, computations)); emitters::SetIndexDataLayout(module, fusion); diff --git a/third_party/xla/xla/backends/gpu/codegen/kernels/BUILD b/third_party/xla/xla/backends/gpu/codegen/kernels/BUILD index 321caea7e6c559..bb70486577700b 100644 --- a/third_party/xla/xla/backends/gpu/codegen/kernels/BUILD +++ b/third_party/xla/xla/backends/gpu/codegen/kernels/BUILD @@ -36,6 +36,7 @@ cc_library( ":custom_kernel_proto_cc", "//xla/stream_executor:kernel_spec", "//xla/stream_executor:launch_dim", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", diff --git a/third_party/xla/xla/backends/gpu/codegen/kernels/custom_kernel.cc b/third_party/xla/xla/backends/gpu/codegen/kernels/custom_kernel.cc index 14d1870ac4ed9a..5b707b869817b4 100644 --- a/third_party/xla/xla/backends/gpu/codegen/kernels/custom_kernel.cc +++ b/third_party/xla/xla/backends/gpu/codegen/kernels/custom_kernel.cc @@ -23,6 +23,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/str_format.h" #include "absl/strings/string_view.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/backends/gpu/codegen/kernels/custom_kernel.pb.h" #include "xla/stream_executor/kernel_spec.h" #include "xla/stream_executor/launch_dim.h" @@ -85,7 +86,7 @@ std::string CustomKernel::ToString() const { absl::StatusOr CustomKernel::ToProto() const { CustomKernelProto proto; proto.set_name(name_); - TF_ASSIGN_OR_RETURN(*proto.mutable_kernel_spec(), kernel_spec_.ToProto()); + ASSIGN_OR_RETURN(*proto.mutable_kernel_spec(), kernel_spec_.ToProto()); *proto.mutable_block_dims() = block_dims_.ToProto(); *proto.mutable_thread_dims() = thread_dims_.ToProto(); if (cluster_dims_.has_value()) { @@ -99,16 +100,16 @@ absl::StatusOr CustomKernel::FromProto( const CustomKernelProto& proto, const std::optional& symbol_resolver) { - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( se::KernelLoaderSpec kernel_spec, se::KernelLoaderSpec::FromProto(proto.kernel_spec(), symbol_resolver)); - TF_ASSIGN_OR_RETURN(se::BlockDim block_dims, - se::BlockDim::FromProto(proto.block_dims())); - TF_ASSIGN_OR_RETURN(se::ThreadDim thread_dims, - se::ThreadDim::FromProto(proto.thread_dims())); + ASSIGN_OR_RETURN(se::BlockDim block_dims, + se::BlockDim::FromProto(proto.block_dims())); + ASSIGN_OR_RETURN(se::ThreadDim thread_dims, + se::ThreadDim::FromProto(proto.thread_dims())); if (proto.has_cluster_dim()) { - TF_ASSIGN_OR_RETURN(se::ClusterDim cluster_dims, - se::ClusterDim::FromProto(proto.cluster_dim())); + ASSIGN_OR_RETURN(se::ClusterDim cluster_dims, + se::ClusterDim::FromProto(proto.cluster_dim())); return CustomKernel(proto.name(), std::move(kernel_spec), block_dims, thread_dims, cluster_dims, proto.shared_memory_bytes()); } diff --git a/third_party/xla/xla/backends/gpu/codegen/llvm/BUILD b/third_party/xla/xla/backends/gpu/codegen/llvm/BUILD index a6b54b18d5d84a..ff8820d66364f4 100644 --- a/third_party/xla/xla/backends/gpu/codegen/llvm/BUILD +++ b/third_party/xla/xla/backends/gpu/codegen/llvm/BUILD @@ -116,6 +116,7 @@ cc_library( "//xla/service/llvm_ir:llvm_loop", "//xla/service/llvm_ir:llvm_util", "//xla/service/llvm_ir:loop_emitter", + "//xla/tsl/platform:status_macros", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", diff --git a/third_party/xla/xla/backends/gpu/codegen/llvm/parallel_loop_emitter.cc b/third_party/xla/xla/backends/gpu/codegen/llvm/parallel_loop_emitter.cc index f89b1e010d8641..0ad0ff1101f9a1 100644 --- a/third_party/xla/xla/backends/gpu/codegen/llvm/parallel_loop_emitter.cc +++ b/third_party/xla/xla/backends/gpu/codegen/llvm/parallel_loop_emitter.cc @@ -34,6 +34,7 @@ limitations under the License. #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" // IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc" +#include "xla/tsl/platform/status_macros.h" #include "llvm/IR/Intrinsics.h" #include "llvm/IR/Value.h" #include "xla/service/gpu/target_util.h" @@ -214,7 +215,7 @@ absl::Status ParallelLoopEmitter::EmitSerialLoop(absl::string_view loop_name, for (const llvm_ir::IrArray::Index& array_index : EmitIndexAndSetExitBasicBlock(loop_name, index_type, base_indvar)) { if (!check_bounds) { - TF_RETURN_IF_ERROR(body_emitter_(array_index)); + RETURN_IF_ERROR(body_emitter_(array_index)); } else { // If the unroll_factor does not divide the number of elements, we must // check that the index is in bounds, since the last iteration of the last @@ -229,7 +230,7 @@ absl::Status ParallelLoopEmitter::EmitSerialLoop(absl::string_view loop_name, llvm::ConstantInt::get(index_type, num_elements)), llvm_ir::IrName(loop_name, "unrolled_in_bounds"), b_, false); llvm_ir::SetToFirstInsertPoint(if_in_bounds.true_block, b_); - TF_RETURN_IF_ERROR(body_emitter_(array_index)); + RETURN_IF_ERROR(body_emitter_(array_index)); llvm_ir::SetToFirstInsertPoint(if_in_bounds.after_block, b_); } } @@ -247,14 +248,14 @@ absl::Status ParallelLoopEmitter::EmitLoop(absl::string_view loop_name, // to add a loop inside the kernel. if (total_threads * unroll_factor_ >= num_elements) { VLOG(1) << "No loops inside the kernel"; - TF_RETURN_IF_ERROR(EmitSerialLoop(loop_name, index_type)); + RETURN_IF_ERROR(EmitSerialLoop(loop_name, index_type)); } else { KernelSupportLibrary ksl(b_, llvm_ir::UnrollMode::kDefaultUnroll); auto constant = [&](int64_t val) { return llvm::ConstantInt::get(index_type, val); }; - TF_RETURN_IF_ERROR(ksl.ForWithStatus( + RETURN_IF_ERROR(ksl.ForWithStatus( "loop", constant(0), constant(num_elements), constant(total_threads * unroll_factor_), [&](llvm::Value* base_indvar) { diff --git a/third_party/xla/xla/backends/gpu/codegen/tools/BUILD b/third_party/xla/xla/backends/gpu/codegen/tools/BUILD index d198bb5ee39cf4..cf6952f8f075d7 100644 --- a/third_party/xla/xla/backends/gpu/codegen/tools/BUILD +++ b/third_party/xla/xla/backends/gpu/codegen/tools/BUILD @@ -120,6 +120,7 @@ xla_cc_binary( visibility = ["//visibility:private"], deps = [ "//xla/codegen/tools:test_lib", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", diff --git a/third_party/xla/xla/backends/gpu/codegen/tools/fusion_wrapper.cc b/third_party/xla/xla/backends/gpu/codegen/tools/fusion_wrapper.cc index daa8824f29687c..0c2af9967b42e9 100644 --- a/third_party/xla/xla/backends/gpu/codegen/tools/fusion_wrapper.cc +++ b/third_party/xla/xla/backends/gpu/codegen/tools/fusion_wrapper.cc @@ -16,6 +16,7 @@ limitations under the License. #include "absl/log/check.h" #include "absl/status/status.h" +#include "xla/tsl/platform/status_macros.h" #include "llvm/Support/raw_ostream.h" #include "xla/codegen/tools/test_lib.h" #include "xla/tsl/platform/statusor.h" @@ -25,7 +26,7 @@ namespace xla { namespace gpu { absl::Status Run(const std::string& filename) { - TF_ASSIGN_OR_RETURN(auto module, LoadTestModule(filename)); + ASSIGN_OR_RETURN(auto module, LoadTestModule(filename)); llvm::outs() << module->ToString(); return absl::OkStatus(); } diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/BUILD b/third_party/xla/xla/backends/gpu/codegen/triton/BUILD index 8fa3418d095e07..15bec6b4faa682 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/BUILD +++ b/third_party/xla/xla/backends/gpu/codegen/triton/BUILD @@ -507,10 +507,14 @@ cc_library( "//xla:util", "//xla/codegen/tiling:symbolic_tile_analysis", "//xla/codegen/tiling:tiling_specification", + "//xla/codegen/tiling/experimental:tile", + "//xla/codegen/tiling/experimental:tiled_hlo", + "//xla/codegen/xtile/codegen:emitter_helpers", "//xla/codegen/xtile/codegen:fusion_emitter", "//xla/hlo/analysis:symbolic_map", "//xla/hlo/ir:hlo", "//xla/hlo/testlib:filecheck", + "//xla/hlo/utils:hlo_traversal", "//xla/service:instruction_fusion", "//xla/service/gpu:backend_configs_cc", "//xla/service/gpu:gpu_device_info_for_tests", @@ -518,7 +522,7 @@ cc_library( "//xla/service/gpu/model:tiling_from_block_parameters", "//xla/service/gpu/model:triton_emitter_constraints", "//xla/tsl/platform:errors", - "//xla/tsl/platform:statusor", + "//xla/tsl/platform:status_macros", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/tests/BUILD b/third_party/xla/xla/backends/gpu/codegen/triton/tests/BUILD index fd5763b69878ba..70fd5770c02637 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/tests/BUILD +++ b/third_party/xla/xla/backends/gpu/codegen/triton/tests/BUILD @@ -174,6 +174,7 @@ xla_test( "//xla/tsl/lib/core:status_test_util", "//xla/tsl/platform:env", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "//xla/tsl/platform:test", "@com_google_absl//absl/log", @@ -268,8 +269,6 @@ xla_cc_test( "//xla/hlo/testlib:hlo_hardware_independent_test_base", "//xla/service/gpu/model:block_level_parameters", "//xla/tests:xla_internal_test_main", # fixdeps: keep - "//xla/tsl/lib/core:status_test_util", - "//xla/tsl/platform:statusor", "@com_google_absl//absl/strings:string_view", "@com_google_googletest//:gtest", ], diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/tests/fusion_emitter_device_test.cc b/third_party/xla/xla/backends/gpu/codegen/triton/tests/fusion_emitter_device_test.cc index ace7108af4eb7b..3db73f8ce12631 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/tests/fusion_emitter_device_test.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/tests/fusion_emitter_device_test.cc @@ -34,6 +34,7 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/strings/substitute.h" #include "Eigen/Core" +#include "xla/tsl/platform/status_macros.h" #include "llvm/IR/LLVMContext.h" #include "llvm/TargetParser/Triple.h" #include "mlir/IR/BuiltinOps.h" @@ -117,24 +118,24 @@ class TritonEmitterTest CreateXTileIrAndFileCheck(absl::string_view hlo_text, absl::string_view triton_fusion_name, absl::string_view filecheck_pattern) { - TF_ASSIGN_OR_RETURN(std::unique_ptr module, - ParseAndReturnVerifiedModule(hlo_text)); + ASSIGN_OR_RETURN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_text)); return XTileTestBase::CreateXTileIrAndFileCheck( std::move(module), triton_fusion_name, filecheck_pattern); } absl::Status CreateTritonIrFromHloTextAndFileCheck( absl::string_view hlo_text, absl::string_view triton_fusion_name, absl::string_view filecheck_pattern) { - TF_ASSIGN_OR_RETURN(std::unique_ptr module, - ParseAndReturnVerifiedModule(hlo_text)); + ASSIGN_OR_RETURN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_text)); return CreateTritonIrAndFileCheck(module.get(), triton_fusion_name, filecheck_pattern); } absl::Status CreateTritonIrFromHloTextAndFileCheckForDot( absl::string_view hlo_text, absl::string_view triton_fusion_name, absl::string_view filecheck_pattern) { - TF_ASSIGN_OR_RETURN(std::unique_ptr module, - ParseAndReturnVerifiedModule(hlo_text)); + ASSIGN_OR_RETURN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_text)); return CreateTritonIrAndFileCheckForDot(module.get(), triton_fusion_name, filecheck_pattern); } diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/tests/fusion_emitter_shared_dialect_test.cc b/third_party/xla/xla/backends/gpu/codegen/triton/tests/fusion_emitter_shared_dialect_test.cc index 6102dbeaf9023f..bd89182ba6fae4 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/tests/fusion_emitter_shared_dialect_test.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/tests/fusion_emitter_shared_dialect_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include +#include #include #include "absl/strings/string_view.h" #include "xla/backends/gpu/codegen/triton/xtile_test_base.h" @@ -22,8 +23,6 @@ limitations under the License. #include "xla/hlo/parser/hlo_parser.h" #include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" #include "xla/service/gpu/model/block_level_parameters.h" -#include "xla/tsl/lib/core/status_test_util.h" -#include "xla/tsl/platform/statusor.h" namespace xla { namespace gpu { @@ -41,7 +40,19 @@ namespace { class XTileDialectTest : public HloHardwareIndependentTestBase, public XTileTestBase {}; -TEST_F(XTileDialectTest, HloTransposeIsLoweredToStableHloTranspose) { +class XTileDialectTestParameterized + : public XTileDialectTest, + public ::testing::WithParamInterface {}; + +INSTANTIATE_TEST_SUITE_P(XTileDialectTestParameterized, + XTileDialectTestParameterized, testing::Bool(), + [](const ::testing::TestParamInfo& info) { + return info.param ? "ExperimentalEmitter" + : "LegacyEmitter"; + }); + +TEST_P(XTileDialectTestParameterized, + HloTransposeIsLoweredToStableHloTranspose) { constexpr absl::string_view kHloText = R"( HloModule t @@ -56,21 +67,22 @@ ENTRY e { calls=transpose_fusion, backend_config={"fusion_backend_config": {kind: "__triton"}} })"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(kHloText)); + ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHloText)); BlockLevelParameters block_level_parameters; block_level_parameters.output_tile_sizes = {{16, 32}}; - TF_EXPECT_OK(CreateXTileIrAndFileCheck( + EXPECT_OK(CreateXTileIrAndFileCheck( *module->GetComputationWithName("transpose_fusion"), block_level_parameters, R"( CHECK: %[[RES:.*]] = stablehlo.transpose %[[ARG:.*]], dims = [1, 0] : (tensor<32x16xf32>) -> tensor<16x32xf32> -)")); +)", + GetParam())); } -TEST_F(XTileDialectTest, HloBitcastIsLoweredToTensorBitcast) { +TEST_P(XTileDialectTestParameterized, HloBitcastIsLoweredToTensorBitcast) { constexpr absl::string_view kHloText = R"( HloModule t, is_scheduled=true @@ -85,20 +97,21 @@ ENTRY e { calls=bitcast_fusion, backend_config={"fusion_backend_config": {kind: "__triton"}} })"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(kHloText)); + ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHloText)); BlockLevelParameters block_level_parameters; block_level_parameters.output_tile_sizes = {{16, 32}}; - TF_EXPECT_OK(CreateXTileIrAndFileCheck( + EXPECT_OK(CreateXTileIrAndFileCheck( *module->GetComputationWithName("bitcast_fusion"), block_level_parameters, R"( CHECK: %[[RES:.*]] = tensor.bitcast %[[ARG:.*]] : tensor<16x32xf32> to tensor<16x32xi32> -)")); +)", + GetParam())); } -TEST_F(XTileDialectTest, HloIotaIsLoweredToStableHloIota) { +TEST_P(XTileDialectTestParameterized, HloIotaIsLoweredToStableHloIota) { constexpr absl::string_view kHloText = R"( HloModule t, is_scheduled=true @@ -111,20 +124,22 @@ ENTRY e { calls=iota_fusion, backend_config={"fusion_backend_config": {kind: "__triton"}} })"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(kHloText)); + ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHloText)); BlockLevelParameters block_level_parameters; block_level_parameters.output_tile_sizes = {{16}}; - TF_EXPECT_OK(CreateXTileIrAndFileCheck( + EXPECT_OK(CreateXTileIrAndFileCheck( *module->GetComputationWithName("iota_fusion"), block_level_parameters, R"( CHECK: %[[RES:.*]] = stablehlo.iota dim = 0 : tensor<16xi32> -)")); +)", + GetParam())); } -TEST_F(XTileDialectTest, HloBroadcastInDimIsLoweredToStableHloBroadcastInDim) { +TEST_P(XTileDialectTestParameterized, + HloBroadcastInDimIsLoweredToStableHloBroadcastInDim) { constexpr absl::string_view kHloText = R"( HloModule t @@ -139,21 +154,22 @@ ENTRY e { calls=broadcast_in_dim_fusion, backend_config={"fusion_backend_config": {kind: "__triton"}} })"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(kHloText)); + ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHloText)); BlockLevelParameters block_level_parameters; block_level_parameters.output_tile_sizes = {{16, 32, 8}}; - TF_EXPECT_OK(CreateXTileIrAndFileCheck( + EXPECT_OK(CreateXTileIrAndFileCheck( *module->GetComputationWithName("broadcast_in_dim_fusion"), block_level_parameters, R"( CHECK: %[[RES:.*]] = stablehlo.broadcast_in_dim %[[ARG:.*]], dims = [0, 1] : (tensor<16x32xf32>) -> tensor<16x32x8xf32> -)")); +)", + GetParam())); } -TEST_F(XTileDialectTest, +TEST_P(XTileDialectTestParameterized, HloZeroDimensionalBroadcastIsLoweredToStableHloBroadcastInDim) { constexpr absl::string_view kHloText = R"( HloModule t @@ -169,21 +185,22 @@ ENTRY e { calls=broadcast_in_dim_fusion, backend_config={"fusion_backend_config": {kind: "__triton"}} })"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(kHloText)); + ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHloText)); BlockLevelParameters block_level_parameters; block_level_parameters.output_tile_sizes = {{16, 32, 8}}; - TF_EXPECT_OK(CreateXTileIrAndFileCheck( + EXPECT_OK(CreateXTileIrAndFileCheck( *module->GetComputationWithName("broadcast_in_dim_fusion"), block_level_parameters, R"( CHECK: %[[RES:.*]] = stablehlo.broadcast_in_dim %[[ARG:.*]], dims = [] : (tensor) -> tensor<16x32x8xf32> -)")); +)", + GetParam())); } -TEST_F(XTileDialectTest, HloReduceIsLoweredToStableHloReduce) { +TEST_P(XTileDialectTestParameterized, HloReduceIsLoweredToStableHloReduce) { constexpr absl::string_view kHloText = R"( HloModule t @@ -205,22 +222,23 @@ ENTRY e { calls=reduce_fusion, backend_config={"fusion_backend_config": {kind: "__triton"}} })"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(kHloText)); + ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHloText)); BlockLevelParameters block_level_parameters; block_level_parameters.output_tile_sizes = {{16}}; - TF_EXPECT_OK(CreateXTileIrAndFileCheck( + EXPECT_OK(CreateXTileIrAndFileCheck( *module->GetComputationWithName("reduce_fusion"), block_level_parameters, R"( CHECK: %[[INIT:.*]] = arith.constant dense<0.000000e+00> : tensor CHECK: %[[MASKED_INPUT:.*]] = xtile.mask {{.*}} CHECK: %[[RES:.*]] = stablehlo.reduce(%[[MASKED_INPUT]] init: %[[INIT]]) applies stablehlo.add across dimensions = [0] : (tensor<256x16xf32>, tensor) -> tensor<16xf32> -)")); +)", + GetParam())); } -TEST_F(XTileDialectTest, HloReshapeIsLoweredToStableHloReshape) { +TEST_P(XTileDialectTestParameterized, HloReshapeIsLoweredToStableHloReshape) { constexpr absl::string_view kHloText = R"( HloModule t, is_scheduled=true @@ -235,20 +253,21 @@ ENTRY e { calls=reshape_fusion, backend_config={"fusion_backend_config": {kind: "__triton"}} })"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(kHloText)); + ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHloText)); BlockLevelParameters block_level_parameters; block_level_parameters.output_tile_sizes = {{1, 16}}; - TF_EXPECT_OK(CreateXTileIrAndFileCheck( + EXPECT_OK(CreateXTileIrAndFileCheck( *module->GetComputationWithName("reshape_fusion"), block_level_parameters, R"( CHECK: %[[RES:.*]] = stablehlo.reshape %[[ARG:.*]] : (tensor<16xi32>) -> tensor<1x16xi32> -)")); +)", + GetParam())); } -TEST_F(XTileDialectTest, HloDotIsLoweredToStableHloDot) { +TEST_P(XTileDialectTestParameterized, HloDotIsLoweredToStableHloDot) { constexpr absl::string_view kHloText = R"( HloModule t @@ -267,21 +286,22 @@ ENTRY e { calls=dot_fusion, backend_config={"fusion_backend_config": {kind: "__triton"}} })"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(kHloText)); + ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHloText)); BlockLevelParameters block_level_parameters; block_level_parameters.output_tile_sizes = {{32, 8}}; - TF_EXPECT_OK(CreateXTileIrAndFileCheck( + EXPECT_OK(CreateXTileIrAndFileCheck( *module->GetComputationWithName("dot_fusion"), block_level_parameters, R"( CHECK: %[[RES:.*]] = stablehlo.dot_general %[[ARG0:.*]], %[[ARG1:.*]], contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<32x8xf32>, tensor<8x8xf32>) -> tensor<32x8xf32> CHECK: %[[ADD_RES:.*]] = arith.addf %[[ARG2:.*]], %[[RES]] : tensor<32x8xf32> -)")); +)", + GetParam())); } -TEST_F(XTileDialectTest, HloScaledDotIsLoweredToXTileDotScaled) { +TEST_P(XTileDialectTestParameterized, HloScaledDotIsLoweredToXTileDotScaled) { constexpr absl::string_view kHloText = R"( HloModule m @@ -318,8 +338,8 @@ ENTRY e { } )"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(kHloText)); + ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHloText)); auto& debug_options = module->mutable_config().mutable_debug_options(); debug_options.set_xla_gpu_experimental_scaled_dot_with_triton(true); @@ -327,15 +347,17 @@ ENTRY e { BlockLevelParameters block_level_parameters; block_level_parameters.output_tile_sizes = {{128, 256}}; - TF_EXPECT_OK(CreateXTileIrAndFileCheck( + EXPECT_OK(CreateXTileIrAndFileCheck( *module->GetComputationWithName("triton_dot"), block_level_parameters, R"( CHECK: %[[DOT:.*]] = xtile.dot_scaled %[[LHS:.*]] scale %[[LHS_SCALE:.*]], %[[RHS:.*]] scale %[[RHS_SCALE:.*]] {dot_dimension_numbers = #stablehlo.dot, fastMath = true} : tensor<128x128xf8E5M2>, tensor<128x4xi8> * tensor<128x256xf8E5M2>, tensor<256x4xi8> -> tensor<128x256xf32> CHECK: %[[RES:.*]] = arith.addf %{{.*}}, %[[DOT]] : tensor<128x256xf32> - )")); + )", + GetParam())); } -TEST_F(XTileDialectTest, HloAllReduceIsLoweredToStableHloAllReduce) { +TEST_P(XTileDialectTestParameterized, + HloAllReduceIsLoweredToStableHloAllReduce) { constexpr absl::string_view kHloText = R"( HloModule wrapped_module_all-reduce-start @@ -359,22 +381,24 @@ TEST_F(XTileDialectTest, HloAllReduceIsLoweredToStableHloAllReduce) { // The HLO is not valid so we parse and return unverified. This is the same // HLO that gets generated in the collective_emitter_tests. - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr hlo_module, - ParseAndReturnUnverifiedModule(kHloText)); + ASSERT_OK_AND_ASSIGN(std::unique_ptr hlo_module, + ParseAndReturnUnverifiedModule(kHloText)); BlockLevelParameters block_level_parameters; block_level_parameters.output_tile_sizes = {{4096}}; - TF_EXPECT_OK(CreateXTileIrAndFileCheck( + EXPECT_OK(CreateXTileIrAndFileCheck( *hlo_module->GetComputationWithName("wrapped_all-reduce-start"), block_level_parameters, R"( CHECK: stablehlo.all_reduce CHECK: stablehlo.add -)")); +)", + GetParam())); } -TEST_F(XTileDialectTest, HloUnsignedIntIsLoweredToStableHloUnsignedInt) { +TEST_P(XTileDialectTestParameterized, + HloUnsignedIntIsLoweredToStableHloUnsignedInt) { constexpr absl::string_view kHloText = R"( HloModule t, is_scheduled=true @@ -389,17 +413,75 @@ ENTRY e { calls=add_fusion, backend_config={"fusion_backend_config": {kind: "__triton"}} })"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(kHloText)); + ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHloText)); BlockLevelParameters block_level_parameters; block_level_parameters.output_tile_sizes = {{16}}; - TF_EXPECT_OK(CreateXTileIrAndFileCheck( + EXPECT_OK(CreateXTileIrAndFileCheck( *module->GetComputationWithName("add_fusion"), block_level_parameters, R"( CHECK: stablehlo.add{{.*}}: tensor<16xui32> -)")); +)", + GetParam())); +} + +TEST_F(XTileDialectTest, HloAllGatherDotLowering) { + constexpr absl::string_view kHloText = R"( + HloModule nested_all_gather_dot + + %ag_dot { + %param0 = f32[128,128]{1,0} parameter(0) + %param1 = f32[128,128]{1,0} parameter(1) + %all-gather1 = f32[256,128]{1,0} all-gather(%param0), + replica_groups={{0,1}}, dimensions={0} + %all-gather2 = f32[512,128]{1,0} all-gather(%all-gather1), + replica_groups={{0,1}}, dimensions={0} + ROOT %dot = f32[512,128]{1,0} dot(%all-gather2, %param1), + lhs_contracting_dims={1}, rhs_contracting_dims={0}, + backend_config={sizes:[128]} + } + + ENTRY %entry { + %param0 = f32[128,128]{1,0} parameter(0) + %param1 = f32[128,128]{1,0} parameter(1) + ROOT %fusion = f32[512,128]{1,0} fusion(%param0, %param1), + kind=kLoop, calls=%ag_dot, + backend_config={ + "fusion_backend_config": { + "kind": "__triton_collective", + "block_level_fusion_config": { + "num_warps": "4", + "output_tiles": [{"sizes": ["128", "128"]}], + "num_ctas": 1, + "num_stages": 1, + "is_tma_allowed": false, + "is_warp_specialization_allowed": false + } + } + } + } + )"; + + ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnUnverifiedModule(kHloText)); + + BlockLevelParameters block_level_parameters; + block_level_parameters.output_tile_sizes = {{128, 128}}; + + EXPECT_OK(CreateXTileIrAndFileCheck( + *module->GetComputationWithName("ag_dot"), block_level_parameters, R"( + CHECK: xtile.entry_func @xtile_dialect_fn(%arg0: memref<2xi64> + CHECK: %[[SELECT1:.*]] = xtile.select_buffer %arg0[%{{.*}}] + CHECK-SAME: : memref<2xi64> -> memref<2xi64> + CHECK: %[[SELECT2:.*]] = xtile.select_buffer %[[SELECT1]][%{{.*}}] + CHECK-SAME: : memref<2xi64> -> memref<128x128xf32> + CHECK: %[[LHS_TILE:.*]] = xtile.extract %[[SELECT2]] + CHECK: %[[RHS_TILE:.*]] = xtile.extract %arg1 + CHECK: stablehlo.dot_general %[[LHS_TILE]], %[[RHS_TILE]] + )", + /*use_experimental_fusion_emitter=*/true)); } } // namespace diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/xtile_test_base.cc b/third_party/xla/xla/backends/gpu/codegen/triton/xtile_test_base.cc index f3ea21df3782bb..d7cf382660c7a6 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/xtile_test_base.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/xtile_test_base.cc @@ -15,6 +15,7 @@ limitations under the License. #include "xla/backends/gpu/codegen/triton/xtile_test_base.h" +#include #include #include #include @@ -24,19 +25,26 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" +#include "xla/tsl/platform/status_macros.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/Support/raw_ostream.h" #include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/MLIRContext.h" #include "mlir/IR/OwningOpRef.h" #include "xla/backends/gpu/codegen/triton/xtile_compiler.h" +#include "xla/codegen/tiling/experimental/tiled_hlo.h" +#include "xla/codegen/tiling/experimental/tiling_space.h" #include "xla/codegen/tiling/symbolic_tile_analysis.h" #include "xla/codegen/tiling/tiling_specification.h" +#include "xla/codegen/xtile/codegen/emitter_helpers.h" +#include "xla/codegen/xtile/codegen/experimental_fusion_emitter.h" #include "xla/codegen/xtile/codegen/fusion_emitter.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" -#include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/testlib/filecheck.h" +#include "xla/hlo/utils/hlo_traversal.h" #include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/gpu/gpu_device_info_for_tests.h" #include "xla/service/gpu/model/block_level_parameters.h" @@ -45,7 +53,6 @@ limitations under the License. #include "xla/service/instruction_fusion.h" #include "xla/status_macros.h" #include "xla/tsl/platform/errors.h" -#include "xla/tsl/platform/statusor.h" #include "xla/util.h" namespace xla::gpu { @@ -54,7 +61,8 @@ absl::StatusOr< std::pair, std::unique_ptr>> XTileTestBase::CreateXTileIrAndFileCheck(std::unique_ptr hlo_module, absl::string_view triton_fusion_name, - absl::string_view filecheck_pattern) { + absl::string_view filecheck_pattern, + bool use_experimental_fusion_emitter) { auto* comp = hlo_module->GetComputationWithName(triton_fusion_name); TF_RET_CHECK(comp != nullptr) << absl::StrCat( "Computation '", triton_fusion_name, "' is not found in the module"); @@ -64,23 +72,22 @@ XTileTestBase::CreateXTileIrAndFileCheck(std::unique_ptr hlo_module, BlockLevelParameters block_level_parameters = BlockLevelParameters::FromBlockLevelFusionConfig( fusion_backend_config.block_level_fusion_config()); - TF_ASSIGN_OR_RETURN(mlir::OwningOpRef xtile_dialect_module, - CreateXTileIrAndFileCheck(*comp, block_level_parameters, - filecheck_pattern)); + ASSIGN_OR_RETURN(mlir::OwningOpRef xtile_dialect_module, + CreateXTileIrAndFileCheck(*comp, block_level_parameters, + filecheck_pattern, + use_experimental_fusion_emitter)); return std::make_pair(std::move(xtile_dialect_module), std::move(hlo_module)); } absl::StatusOr> -XTileTestBase::CreateXTileIrAndFileCheck( - const HloComputation& computation, +CreateXTileIrAndFileCheckLegacy( + mlir::MLIRContext* mlir_context, const HloComputation& computation, const BlockLevelParameters& block_level_parameters, absl::string_view filecheck_pattern) { auto* fusion = Cast(computation.FusionInstruction()); - LoadMlirDialectsForTriton(*mlir_context()); - SymbolicTileAnalysisOrError symbolic_tile_analysis_or = SymbolicTileAnalysis::AnalyzeComputation( - computation, mlir_context(), + computation, mlir_context, TritonEmitterConstraints::GetBuilder( TestGpuDeviceInfo::RTXA6000DeviceInfo())); @@ -93,19 +100,52 @@ XTileTestBase::CreateXTileIrAndFileCheck( const auto& symbolic_tile_analysis = std::get(symbolic_tile_analysis_or); - TF_ASSIGN_OR_RETURN(Tiling tiling, - TilingFromAnnotatedFusion(symbolic_tile_analysis, - block_level_parameters)); + ASSIGN_OR_RETURN(Tiling tiling, + TilingFromAnnotatedFusion(symbolic_tile_analysis, + block_level_parameters)); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( mlir::OwningOpRef xtile_dialect_module, xtile::EmitXTileModule("xtile_dialect_fn", *fusion, - symbolic_tile_analysis, tiling, *mlir_context())); + symbolic_tile_analysis, tiling, *mlir_context)); + return xtile_dialect_module; +} +absl::StatusOr> +XTileTestBase::CreateXTileIrAndFileCheck( + const HloComputation& computation, + const BlockLevelParameters& block_level_parameters, + absl::string_view filecheck_pattern, bool use_experimental_fusion_emitter) { + mlir::OwningOpRef xtile_dialect_module; + LoadMlirDialectsForTriton(*mlir_context()); + if (use_experimental_fusion_emitter) { + namespace ge = ::xla::gpu::experimental; + auto* fusion = Cast(computation.FusionInstruction()); + auto fusion_adaptor = HloFusionAdaptor::ForInstruction(fusion); + std::unique_ptr tiling_space = + ge::TilingSpace::Create(*fusion_adaptor, mlir_context()); + ASSIGN_OR_RETURN( + llvm::SmallVector concrete_sizes, + GetTilingSpaceConcreteSizes(*tiling_space, block_level_parameters)); + TF_RETURN_IF_ERROR(tiling_space->AssignTileSizes( + xtile::GetPaddedTileSizes(concrete_sizes))); + ASSIGN_OR_RETURN(ge::TiledHloComputation tiled_computation, + ge::TiledHloComputation::Tile(*fusion_adaptor, + std::move(tiling_space))); + ASSIGN_OR_RETURN( + xtile_dialect_module, + xtile::EmitXTileModule("xtile_dialect_fn", *fusion, tiled_computation, + *mlir_context())); + } else { + ASSIGN_OR_RETURN(xtile_dialect_module, + CreateXTileIrAndFileCheckLegacy( + mlir_context(), computation, block_level_parameters, + filecheck_pattern)); + } std::string out; llvm::raw_string_ostream os(out); xtile_dialect_module->print(os); - TF_ASSIGN_OR_RETURN(bool succeeded, RunFileCheck(out, filecheck_pattern)); + ASSIGN_OR_RETURN(bool succeeded, RunFileCheck(out, filecheck_pattern)); if (!succeeded) { return absl::InternalError("FileCheck failed."); } @@ -127,7 +167,7 @@ absl::Status XTileTestBase::LowerXTileIrToTritonAndFileCheck( std::string out; llvm::raw_string_ostream os(out); xtile_dialect_module->print(os); - TF_ASSIGN_OR_RETURN(bool succeeded, RunFileCheck(out, filecheck_pattern)); + ASSIGN_OR_RETURN(bool succeeded, RunFileCheck(out, filecheck_pattern)); if (!succeeded) { return absl::InternalError("FileCheck failed."); } diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/xtile_test_base.h b/third_party/xla/xla/backends/gpu/codegen/triton/xtile_test_base.h index b1a5246da97a9f..63347370187b4c 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/xtile_test_base.h +++ b/third_party/xla/xla/backends/gpu/codegen/triton/xtile_test_base.h @@ -48,7 +48,8 @@ class XTileTestBase { std::pair, std::unique_ptr>> CreateXTileIrAndFileCheck(std::unique_ptr hlo_module, absl::string_view triton_fusion_name, - absl::string_view filecheck_pattern); + absl::string_view filecheck_pattern, + bool use_experimental_fusion_emitter = false); // Creates a shared dialect IR from the given HLO computation and returns it. // This function also checks the generated shared dialect IR against the @@ -56,7 +57,8 @@ class XTileTestBase { absl::StatusOr> CreateXTileIrAndFileCheck( const HloComputation& computation, const BlockLevelParameters& block_level_parameters, - absl::string_view filecheck_pattern); + absl::string_view filecheck_pattern, + bool use_experimental_fusion_emitter = false); // Lowers the given shared dialect IR to Triton IR and checks the result // against the `filecheck_pattern`. diff --git a/third_party/xla/xla/backends/gpu/collectives/nccl_communicator.cc b/third_party/xla/xla/backends/gpu/collectives/nccl_communicator.cc index 7b2f14a71010db..879ded2e2a31c2 100644 --- a/third_party/xla/xla/backends/gpu/collectives/nccl_communicator.cc +++ b/third_party/xla/xla/backends/gpu/collectives/nccl_communicator.cc @@ -166,6 +166,9 @@ absl::StatusOr> NcclCommunicator::Create( se::StreamExecutor* stream_executor, absl::AnyInvocable()> make_comm, std::shared_ptr cancel, bool is_async, tsl::Env& env) { + if (cancel == nullptr) { + cancel = std::make_shared(); + } auto f = [cancel, &make_comm]() -> absl::StatusOr { TF_ASSIGN_OR_RETURN(ncclComm_t comm, make_comm()); if (cancel) { diff --git a/third_party/xla/xla/backends/gpu/profiler/BUILD b/third_party/xla/xla/backends/gpu/profiler/BUILD index db520d46dfe02c..783b31040e4bdc 100644 --- a/third_party/xla/xla/backends/gpu/profiler/BUILD +++ b/third_party/xla/xla/backends/gpu/profiler/BUILD @@ -52,6 +52,7 @@ cc_library( deps = [ "//xla/stream_executor:platform", "//xla/stream_executor/platform:platform_object_registry", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/status:statusor", ] + if_cuda_is_configured([ @@ -98,6 +99,7 @@ xla_test( "//xla/stream_executor/gpu:gpu_test_kernels", "//xla/stream_executor/gpu:gpu_test_kernels_fatbin", "//xla/stream_executor/rocm:rocm_platform_id", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/status", "@com_google_absl//absl/status:status_matchers", diff --git a/third_party/xla/xla/backends/gpu/profiler/kernel_name_tracer.cc b/third_party/xla/xla/backends/gpu/profiler/kernel_name_tracer.cc index 5413859c684bbc..2c6ec3d3a1dd57 100644 --- a/third_party/xla/xla/backends/gpu/profiler/kernel_name_tracer.cc +++ b/third_party/xla/xla/backends/gpu/profiler/kernel_name_tracer.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include "absl/status/statusor.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/backends/gpu/profiler/kernel_name_tracer_factory.h" #include "xla/stream_executor/platform.h" #include "xla/stream_executor/platform/platform_object_registry.h" @@ -28,9 +29,8 @@ namespace xla::gpu { absl::StatusOr> KernelNameTracer::Create( const stream_executor::Platform::Id& platform_id) { auto& registry = stream_executor::PlatformObjectRegistry::GetGlobalRegistry(); - TF_ASSIGN_OR_RETURN( - KernelNameTracerFactory::Type func, - registry.FindObject(platform_id)); + ASSIGN_OR_RETURN(KernelNameTracerFactory::Type func, + registry.FindObject(platform_id)); return func(); } diff --git a/third_party/xla/xla/backends/gpu/profiler/kernel_name_tracer_test.cc b/third_party/xla/xla/backends/gpu/profiler/kernel_name_tracer_test.cc index 5a3344ef09286b..46b4454c3bc1c7 100644 --- a/third_party/xla/xla/backends/gpu/profiler/kernel_name_tracer_test.cc +++ b/third_party/xla/xla/backends/gpu/profiler/kernel_name_tracer_test.cc @@ -29,6 +29,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/ascii.h" #include "absl/strings/string_view.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/backends/gpu/runtime/command.h" #include "xla/backends/gpu/runtime/command_buffer_thunk.h" #include "xla/backends/gpu/runtime/command_executor.h" @@ -64,8 +65,8 @@ using ::testing::ElementsAre; using ::testing::IsEmpty; absl::StatusOr GetPlatform() { - TF_ASSIGN_OR_RETURN(std::string name, - PlatformUtil::CanonicalPlatformName("gpu")); + ASSIGN_OR_RETURN(std::string name, + PlatformUtil::CanonicalPlatformName("gpu")); return stream_executor::PlatformManager::PlatformWithName( absl::AsciiStrToUpper(name)); } diff --git a/third_party/xla/xla/backends/gpu/runtime/BUILD b/third_party/xla/xla/backends/gpu/runtime/BUILD index 04191b9abe33f1..5f1ae062cf959c 100644 --- a/third_party/xla/xla/backends/gpu/runtime/BUILD +++ b/third_party/xla/xla/backends/gpu/runtime/BUILD @@ -125,7 +125,6 @@ cc_library( "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:span", - "@tsl//tsl/platform:casts", ], ) @@ -147,6 +146,7 @@ cc_library( "//xla/tsl/lib/gtl:int_type", "//xla/tsl/platform:env", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", @@ -183,6 +183,7 @@ cc_library( "//xla/stream_executor:stream", "//xla/stream_executor:stream_executor_h", "//xla/stream_executor:trace_command_buffer_factory", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_set", @@ -208,6 +209,7 @@ cc_library( "//xla:util", "//xla/stream_executor:command_buffer", "//xla/stream_executor:stream", + "//xla/tsl/platform:status_macros", "@com_google_absl//absl/functional:function_ref", "@com_google_absl//absl/log", "@com_google_absl//absl/status", @@ -232,7 +234,6 @@ cc_library( "//xla:xla_data_proto_cc", "//xla/runtime:buffer_use", "//xla/service:buffer_assignment", - "//xla/service:shaped_slice", "//xla/service/gpu:buffer_allocations", "//xla/stream_executor:command_buffer", "//xla/stream_executor:device_address", @@ -308,6 +309,7 @@ xla_test( "//xla/stream_executor/gpu:gpu_test_kernels_fatbin", "//xla/tsl/lib/core:status_test_util", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "//xla/tsl/platform:test", "//xla/tsl/platform:test_benchmark", @@ -381,6 +383,7 @@ xla_cc_test( ":command", ":command_buffer_cmd_emitter", ":command_executor", + ":conditional_thunk", ":kernel_thunk", ":thunk", ":thunk_id", @@ -388,6 +391,7 @@ xla_cc_test( "//xla/codegen/emitters:kernel_arguments", "//xla/runtime:execution_graph", "//xla/service:buffer_assignment", + "//xla/service:shaped_slice", "//xla/tsl/platform:statusor", "//xla/tsl/platform:test", "@com_google_absl//absl/log:check", @@ -589,6 +593,7 @@ xla_test( "//xla/stream_executor:stream_executor_h", "//xla/stream_executor/gpu:gpu_blas_lt", "//xla/tsl/lib/core:status_test_util", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "//xla/tsl/util/proto:proto_matchers", "@com_google_absl//absl/status", @@ -684,6 +689,7 @@ cc_library( "//xla/tsl/platform:env", "//xla/tsl/platform:errors", "//xla/tsl/platform:logging", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", @@ -706,8 +712,10 @@ xla_test( deps = [ ":command", ":command_buffer_cmd", + ":command_buffer_cmd_emitter", ":command_buffer_thunk", ":command_executor", + ":conditional_thunk", ":device_to_device_copy_thunk", ":gemm_thunk", ":gpublas_lt_matmul_thunk", @@ -745,6 +753,7 @@ xla_test( "//xla/stream_executor/gpu:gpu_test_kernels", "//xla/stream_executor/gpu:gpu_test_kernels_fatbin", "//xla/tsl/lib/core:status_test_util", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", @@ -806,6 +815,8 @@ cc_library( srcs = ["conditional_thunk.cc"], hdrs = ["conditional_thunk.h"], deps = [ + ":command", + ":command_executor", ":host_memory_pool", ":thunk", ":thunk_executor", @@ -816,6 +827,7 @@ cc_library( "//xla/runtime:buffer_use", "//xla/service:buffer_assignment", "//xla/service:shaped_slice", + "//xla/stream_executor:command_buffer", "//xla/stream_executor:device_address", "//xla/stream_executor:stream_executor_h", "//xla/tsl/platform:status_macros", @@ -837,6 +849,9 @@ xla_cc_test( name = "conditional_thunk_test", srcs = ["conditional_thunk_test.cc"], deps = [ + ":command", + ":command_executor", + ":command_state", ":conditional_thunk", ":thunk", ":thunk_executor", @@ -844,8 +859,14 @@ xla_cc_test( "//xla:shape_util", "//xla:xla_data_proto_cc", "//xla/service:buffer_assignment", + "//xla/service:service_executable_run_options", "//xla/service:shaped_slice", + "//xla/service/gpu:buffer_allocations", + "//xla/stream_executor:command_buffer", + "//xla/stream_executor:device_address", + "//xla/stream_executor:mock_command_buffer", "//xla/tsl/lib/core:status_test_util", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "//xla/tsl/platform:test", "//xla/tsl/util/proto:proto_matchers", @@ -879,6 +900,7 @@ cc_library( "//xla/stream_executor:stream", "//xla/stream_executor:stream_executor_h", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "//xla/tsl/protobuf:dnn_proto_cc", "@com_google_absl//absl/base:core_headers", @@ -924,6 +946,7 @@ cc_library( "//xla/stream_executor:device_address", "//xla/stream_executor:dnn", "//xla/stream_executor:stream_executor_h", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", @@ -1010,6 +1033,7 @@ cc_library( "//xla/stream_executor:stream", "//xla/stream_executor:stream_executor_h", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -1033,6 +1057,7 @@ cc_library( "//xla/stream_executor:stream", "//xla/stream_executor:stream_executor_h", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -1056,6 +1081,7 @@ cc_library( "//xla/service:shaped_slice", "//xla/stream_executor:device_address", "//xla/stream_executor:stream", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:inlined_vector", @@ -1078,6 +1104,7 @@ cc_library( "//xla/runtime:buffer_use", "//xla/service:buffer_assignment", "//xla/service:shaped_slice", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", @@ -1382,6 +1409,7 @@ cc_library( "//xla/stream_executor:fft", "//xla/stream_executor:scratch_allocator", "//xla/stream_executor:stream_executor_h", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", @@ -1426,6 +1454,7 @@ cc_library( "//xla/stream_executor:device_address", "//xla/stream_executor:stream", "//xla/stream_executor/gpu:gpu_blas_lt", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/log", "@com_google_absl//absl/status", @@ -1539,6 +1568,7 @@ xla_test( "//xla/tsl/lib/core:status_test_util", "//xla/tsl/platform:env", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "//xla/tsl/util/proto:proto_matchers", "@com_google_absl//absl/log", @@ -1573,6 +1603,7 @@ cc_library( "//xla/stream_executor:device_address_handle", "//xla/stream_executor:stream", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", @@ -1707,6 +1738,7 @@ cuda_library( "//xla/stream_executor:device_address_allocator", "//xla/stream_executor:scratch_allocator", "//xla/stream_executor:stream", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", @@ -1780,6 +1812,7 @@ cc_library( "//xla/stream_executor:device_address", "//xla/stream_executor:device_address_allocator", "//xla/stream_executor:stream", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/log", @@ -2074,6 +2107,7 @@ xla_test( "//xla/tsl/lib/core:status_test_util", "//xla/tsl/platform:env", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/log", "@com_google_absl//absl/status", @@ -2944,6 +2978,7 @@ cc_library( "//xla/service/gpu:gpu_executable_run_options", "//xla/stream_executor:stream", "//xla/stream_executor:stream_executor_h", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:inlined_vector", @@ -2994,6 +3029,7 @@ cc_library( "//xla/service/gpu:gpu_executable_run_options", "//xla/tsl/platform:env", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "//xla/tsl/util:tied_ref", "@com_google_absl//absl/log", @@ -3019,6 +3055,7 @@ cc_library( "//xla/backends/gpu/collectives:gpu_clique_key", "//xla/runtime:device_id", "//xla/service:collective_ops_utils", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", @@ -3114,6 +3151,7 @@ cc_library( "//xla/service:computation_placer_hdr", "//xla/service:source_target_pairs", "//xla/stream_executor:stream_executor_h", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", @@ -3333,6 +3371,7 @@ cc_library( "//xla/stream_executor:lazy_op_runner", "//xla/stream_executor:stream", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", @@ -3380,6 +3419,7 @@ cc_library( "//xla/stream_executor:device_address", "//xla/stream_executor:stream", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/log", "@com_google_absl//absl/status", @@ -3446,6 +3486,7 @@ cc_library( "//xla/stream_executor:stream_executor_h", "//xla/stream_executor/gpu:gpu_kernel_registry", "//xla/stream_executor/gpu:multi_gpu_barrier_kernel", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/base:no_destructor", "@com_google_absl//absl/container:flat_hash_map", @@ -3465,6 +3506,7 @@ cc_library( ":thunk_executor", ":thunk_proto_cc", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/functional:function_ref", "@com_google_absl//absl/log:check", @@ -3756,6 +3798,7 @@ cc_library( "//xla/stream_executor:device_address", "//xla/stream_executor:stream", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/log", "@com_google_absl//absl/status", @@ -3868,6 +3911,7 @@ xla_test( "//xla/stream_executor:stream_executor_address_allocator", "//xla/tests:hlo_pjrt_test_base", "//xla/tsl/lib/core:status_test_util", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "//xla/tsl/platform:test", "//xla/tsl/util/proto:parse_text_proto", @@ -3975,6 +4019,7 @@ cc_library( "//xla:xla_data_proto_cc", "//xla/stream_executor:memory_allocation", "//xla/stream_executor:stream_executor_h", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/log:check", @@ -4003,6 +4048,7 @@ cc_library( "//xla/stream_executor/gpu:buffer_comparator_kernel", "//xla/stream_executor/gpu:gpu_kernel_registry", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", @@ -4053,6 +4099,7 @@ cc_library( "//xla/stream_executor/gpu:gpu_kernel_registry", "//xla/stream_executor/gpu:make_batch_pointers_kernel", "//xla/stream_executor/rocm:rocm_platform_id", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/status", ], @@ -4070,6 +4117,7 @@ xla_test( "//xla/stream_executor:platform", "//xla/stream_executor:stream", "//xla/stream_executor:stream_executor_h", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/status:status_matchers", "@com_google_absl//absl/status:statusor", @@ -4098,6 +4146,7 @@ cc_library( "//xla/stream_executor/gpu:collective_kernel_metadata", "//xla/stream_executor/gpu:gpu_kernel_registry", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/util:safe_reinterpret_cast", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base", @@ -4128,6 +4177,7 @@ xla_cc_test( "//xla/stream_executor/gpu:all_reduce_kernel", "//xla/stream_executor/host:host_platform", "//xla/tsl/lib/gtl:int_type", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "//xla/tsl/platform:test", "//xla/tsl/platform:test_main", @@ -4184,6 +4234,7 @@ xla_test( "//xla/tests:hlo_pjrt_test_base", "//xla/tests:literal_test_util", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "//xla/tsl/platform:test", "//xla/tsl/platform:test_main", @@ -4213,6 +4264,7 @@ cc_library( "//xla/stream_executor:stream", "//xla/stream_executor/gpu:gpu_kernel_registry", "//xla/stream_executor/gpu:ragged_all_to_all_kernel", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/log", "@com_google_absl//absl/status", @@ -4284,6 +4336,7 @@ cc_library( "//xla/stream_executor:platform_manager", "//xla/stream_executor/gpu:gpu_kernel_registry", "//xla/stream_executor/gpu:topk_kernel", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/numeric:bits", "@com_google_absl//absl/status", @@ -4427,6 +4480,7 @@ cc_library( ":device_to_device_copy_thunk", ":device_to_host_copy_thunk", ":dynamic_memcpy_thunk", + ":dynamic_slice_fusion_v2_thunk", ":dynamic_slice_thunk", ":fft_thunk", ":gemm_thunk", @@ -4440,7 +4494,6 @@ cc_library( ":memset_thunk", ":norm_thunk", ":nvshmem_all_reduce_thunk", - ":nvshmem_collective_permute_thunk", ":nvshmem_collective_thunk", ":nvshmem_recv_thunk", ":nvshmem_send_thunk", @@ -4463,6 +4516,7 @@ cc_library( "//xla/stream_executor:device_description", "//xla/stream_executor:kernel_spec", "//xla/stream_executor:stream_executor_h", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/log:check", @@ -4500,12 +4554,12 @@ xla_cc_test( "//xla:shape_util", "//xla:xla_data_proto_cc", "//xla/ffi", - "//xla/ffi:ffi_api", "//xla/hlo/ir:hlo", "//xla/service:buffer_assignment", "//xla/service:hlo_module_config", "//xla/stream_executor:device_description", "//xla/stream_executor:stream_executor_h", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "//xla/tsl/util:safe_reinterpret_cast", "//xla/tsl/util/proto:parse_text_proto", @@ -4516,7 +4570,7 @@ xla_cc_test( "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:span", - "@com_google_googletest//:gtest_main", + "@com_google_googletest//:gtest_main", # fixdeps: keep ], ) @@ -4543,6 +4597,7 @@ cc_library( "//xla/stream_executor:stream", "//xla/stream_executor:stream_executor_h", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", @@ -4600,6 +4655,7 @@ cc_library( "//xla/stream_executor:device_address", "//xla/stream_executor:stream", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/log", @@ -4640,6 +4696,7 @@ cc_library( "//xla/stream_executor:stream", "//xla/tsl/platform:errors", "//xla/tsl/platform:logging", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/log", @@ -4691,6 +4748,7 @@ cc_library( "//xla/service/gpu:backend_configs_cc", "//xla/stream_executor:stream", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/log", @@ -4726,6 +4784,7 @@ cc_library( "//xla/service/gpu:backend_configs_cc", "//xla/stream_executor:stream", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/log", @@ -4815,6 +4874,7 @@ xla_test( "//xla/tests:literal_test_util", "//xla/tsl/concurrency:async_value", "//xla/tsl/lib/core:status_test_util", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "//xla/tsl/util/proto:proto_matchers", "@com_google_absl//absl/container:inlined_vector", @@ -4837,6 +4897,7 @@ cc_library( "//xla/hlo/ir:hlo", "//xla/service:buffer_assignment", "//xla/stream_executor:device_description", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/log", @@ -5073,6 +5134,7 @@ cc_library( "//xla/tsl/lib/io:record_writer", "//xla/tsl/platform:env", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/log", @@ -5099,6 +5161,7 @@ xla_test( "//xla/tsl/lib/io:record_reader", "//xla/tsl/platform:env", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "//xla/tsl/platform:test", "//xla/tsl/testing:temporary_directory", @@ -5133,6 +5196,7 @@ cc_library( "//xla/stream_executor/gpu:buffer_debug_xor_checksum_kernel", "//xla/stream_executor/gpu:gpu_kernel_registry", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", @@ -5179,6 +5243,7 @@ xla_test( "//xla/stream_executor:stream_executor_memory_allocator", "//xla/stream_executor/gpu:buffer_debug_log", "//xla/tsl/lib/core:status_test_util", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -5208,6 +5273,7 @@ cc_library( "//xla/stream_executor/gpu:buffer_debug_log", "//xla/stream_executor/gpu:gpu_kernel_registry", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", @@ -5258,6 +5324,7 @@ xla_test( "//xla/stream_executor:stream_executor_address_allocator", "//xla/stream_executor/gpu:buffer_debug_log", "//xla/tsl/lib/core:status_test_util", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/status:statusor", "@com_google_googletest//:gtest_main", @@ -5408,6 +5475,7 @@ xla_test( "//xla/stream_executor:stream_executor_h", "//xla/stream_executor/gpu:gpu_test_kernels", "//xla/tsl/lib/core:status_test_util", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "//xla/tsl/util/proto:parse_text_proto", "//xla/tsl/util/proto:proto_matchers", diff --git a/third_party/xla/xla/backends/gpu/runtime/all_reduce.cc b/third_party/xla/xla/backends/gpu/runtime/all_reduce.cc index 0d9d49ed09f42d..9063ed93cda4ca 100644 --- a/third_party/xla/xla/backends/gpu/runtime/all_reduce.cc +++ b/third_party/xla/xla/backends/gpu/runtime/all_reduce.cc @@ -28,6 +28,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/bit.h" #include "xla/core/collectives/rank_id.h" @@ -100,7 +101,7 @@ absl::Status LaunchTypedKernel( static constexpr bool kIsTwoShot = TagType::kAllReduceStrategy == AllReduceStrategy::kTwoShot; - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( auto kernel, (se::gpu::GpuKernelRegistry::GetGlobalRegistry() .LoadKernel< @@ -316,7 +317,7 @@ absl::StatusOr BuildAllReduceInfo( num_elements * primitive_util::ByteWidth(element_type); const AllReduceStrategy strategy = GetAllReduceStrategy(byte_size, is_multimem_enabled); - TF_RETURN_IF_ERROR(IsAllReduceKernelSupported( + RETURN_IF_ERROR(IsAllReduceKernelSupported( is_collective_kernel_enabled, device_info, num_operands, reduction_kind, num_devices, num_elements, element_type, /*is_local=*/true, is_multimem_enabled, all_reduce->replica_groups())); @@ -344,9 +345,9 @@ absl::Status RunAllReduceKernel( se::DeviceAddressBase symmetric_signal_buffer, // uint32_t signal_value, // se::DeviceAddressBase metadata) { - TF_RETURN_IF_ERROR(IsAllReduceKernelSupported(num_ranks, num_elements, - element_type, reduction_kind, - all_reduce_strategy)); + RETURN_IF_ERROR(IsAllReduceKernelSupported(num_ranks, num_elements, + element_type, reduction_kind, + all_reduce_strategy)); const auto launch_kernel_impl = [&](auto tag) -> absl::Status { return LaunchTypedKernel( tag, stream, launch_dimensions, symmetric_input_buffer, diff --git a/third_party/xla/xla/backends/gpu/runtime/all_reduce_build_info_test.cc b/third_party/xla/xla/backends/gpu/runtime/all_reduce_build_info_test.cc index de4c914a71287d..c1463cae41c257 100644 --- a/third_party/xla/xla/backends/gpu/runtime/all_reduce_build_info_test.cc +++ b/third_party/xla/xla/backends/gpu/runtime/all_reduce_build_info_test.cc @@ -26,6 +26,7 @@ limitations under the License. #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" #include "absl/strings/string_view.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/backends/gpu/runtime/all_reduce.h" #include "xla/core/collectives/reduction_kind.h" #include "xla/hlo/ir/hlo_casting_utils.h" @@ -89,7 +90,7 @@ class BuildAllReduceInfoTest : public HloHardwareIndependentTestBase { SCOPED_TRACE(testing::Message() << "module_str: " << module_str); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( std::unique_ptr module, ParseAndReturnVerifiedModule( module_str, replica_groups.empty() ? 1 : replica_groups.size())); diff --git a/third_party/xla/xla/backends/gpu/runtime/all_reduce_test.cc b/third_party/xla/xla/backends/gpu/runtime/all_reduce_test.cc index 85e284c939acd8..76986583265fa8 100644 --- a/third_party/xla/xla/backends/gpu/runtime/all_reduce_test.cc +++ b/third_party/xla/xla/backends/gpu/runtime/all_reduce_test.cc @@ -33,6 +33,7 @@ limitations under the License. #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/array.h" #include "xla/core/collectives/rank_id.h" #include "xla/core/collectives/reduction_kind.h" @@ -101,18 +102,18 @@ class AllReduceKernelTest : public ::testing::Test, int64_t num_elements = input_data[0].num_elements(); - TF_RETURN_IF_ERROR(executors[0]->EnablePeerAccessTo(executors[1])); - TF_RETURN_IF_ERROR(executors[1]->EnablePeerAccessTo(executors[0])); + RETURN_IF_ERROR(executors[0]->EnablePeerAccessTo(executors[1])); + RETURN_IF_ERROR(executors[1]->EnablePeerAccessTo(executors[0])); std::unique_ptr multicast_memory; if (params_.all_reduce_strategy == AllReduceStrategy::kMultimem) { - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( multicast_memory, dynamic_cast(executors[0]) ->CreateMulticastMemory(num_elements * sizeof(T), num_ranks)); for (int i = 0; i < num_ranks; ++i) { - TF_RETURN_IF_ERROR(multicast_memory->SubscribeDevice(i)); + RETURN_IF_ERROR(multicast_memory->SubscribeDevice(i)); } } @@ -153,16 +154,16 @@ class AllReduceKernelTest : public ::testing::Test, output_buffers.emplace_back(allocated_buffers[i].GetByteSlice( 2 * aligned_input_size, aligned_input_size)); TF_RET_CHECK(!output_buffers[i].is_null()); - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( streams[i]->MemZero(&output_buffers[i], aligned_input_size)); signal_flags_buffers.emplace_back(allocated_buffers[i].GetByteSlice( 3 * aligned_input_size, aligned_signal_size)); TF_RET_CHECK(!signal_flags_buffers[i].is_null()); - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( streams[i]->MemZero(&signal_flags_buffers[i], aligned_signal_size)); - TF_RETURN_IF_ERROR(streams[i]->Memcpy(&input_buffers[i], - input_data[i].data(), input_size)); + RETURN_IF_ERROR(streams[i]->Memcpy(&input_buffers[i], + input_data[i].data(), input_size)); XLA_VLOG_DEVICE(1, i) << "Allocated buffer: " << allocated_buffers[i].opaque() << ", Input buffer: " << input_buffers[i].opaque() @@ -192,7 +193,7 @@ class AllReduceKernelTest : public ::testing::Test, se::gpu::GpuExecutor* gpu_executor = dynamic_cast(executors[i]); TF_RET_CHECK(gpu_executor != nullptr); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( void* mapped_memory, multicast_memory->MapMemory(allocated_buffers[i], gpu_executor)); std::vector param_to_multimem_addresses = @@ -213,7 +214,7 @@ class AllReduceKernelTest : public ::testing::Test, param_to_multimem_addresses_byte_size); metadata.param_to_multimem_addresses = reinterpret_cast( param_to_multimem_addresses_buffer.opaque()); - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( streams[i]->Memcpy(¶m_to_multimem_addresses_buffer, param_to_multimem_addresses.data(), param_to_multimem_addresses_byte_size)); @@ -249,20 +250,20 @@ class AllReduceKernelTest : public ::testing::Test, param_to_peers_size_bytes); metadata.param_to_peers = reinterpret_cast(param_to_peers_ptrs_buffer.opaque()); - TF_RETURN_IF_ERROR(streams[i]->Memcpy(¶m_to_peers_ptrs_buffer, - param_to_peers_ptrs.data(), - param_to_peers_size_bytes)); - TF_RETURN_IF_ERROR(streams[i]->Memcpy(&metadata_buffers[i], &metadata, - sizeof(CollectiveKernelMetadata))); + RETURN_IF_ERROR(streams[i]->Memcpy(¶m_to_peers_ptrs_buffer, + param_to_peers_ptrs.data(), + param_to_peers_size_bytes)); + RETURN_IF_ERROR(streams[i]->Memcpy(&metadata_buffers[i], &metadata, + sizeof(CollectiveKernelMetadata))); } for (int i = 0; i < num_ranks; ++i) { - TF_RETURN_IF_ERROR(streams[i]->BlockHostUntilDone()); + RETURN_IF_ERROR(streams[i]->BlockHostUntilDone()); } for (int i = 0; i < num_ranks; ++i) { auto active_context = executors[i]->Activate(); - TF_RETURN_IF_ERROR(RunAllReduceKernel( + RETURN_IF_ERROR(RunAllReduceKernel( streams[i].get(), launch_dimensions, primitive_util::NativeToPrimitiveType(), /*reduction_kind=*/reduction_kind, @@ -280,13 +281,13 @@ class AllReduceKernelTest : public ::testing::Test, } for (int i = 0; i < num_ranks; ++i) { - TF_RETURN_IF_ERROR(streams[i]->BlockHostUntilDone()); + RETURN_IF_ERROR(streams[i]->BlockHostUntilDone()); } std::vector> results; for (int i = 0; i < num_ranks; ++i) { Array output_results({num_elements}); - TF_RETURN_IF_ERROR(streams[i]->Memcpy( + RETURN_IF_ERROR(streams[i]->Memcpy( output_results.data(), output_buffers[i], num_elements * sizeof(T))); results.push_back(std::move(output_results)); diff --git a/third_party/xla/xla/backends/gpu/runtime/all_reduce_thunk.cc b/third_party/xla/xla/backends/gpu/runtime/all_reduce_thunk.cc index 6c0e26c0eb5d8f..ad8201c2ce005a 100644 --- a/third_party/xla/xla/backends/gpu/runtime/all_reduce_thunk.cc +++ b/third_party/xla/xla/backends/gpu/runtime/all_reduce_thunk.cc @@ -60,7 +60,7 @@ namespace { absl::Status CheckImplementableInst(const HloInstruction* inst, Thunk::Kind reduction_op) { for (HloInstruction* operand : inst->operands()) { - TF_RETURN_IF_ERROR(IsValidOperand(operand->shape(), reduction_op)); + RETURN_IF_ERROR(IsValidOperand(operand->shape(), reduction_op)); } if (!MatchReductionComputation(inst->called_computations().front()) @@ -101,14 +101,14 @@ absl::Status RunAllReduce(ReductionKind reduction_kind, gpu_comm->GroupExecute([reduction_kind, &buffers, &stream](GpuCommunicator* comm) -> absl::Status { for (DeviceBufferPair& buffer : buffers) { - TF_RETURN_IF_ERROR(comm->LaunchAllReduce( + RETURN_IF_ERROR(comm->LaunchAllReduce( buffer.source_buffer, buffer.destination_buffer, buffer.element_type, buffer.element_count, reduction_kind, GpuCollectives::On(stream))); } return absl::OkStatus(); }); - TF_RETURN_IF_ERROR(future.Await()); + RETURN_IF_ERROR(future.Await()); XLA_VLOG_DEVICE(3, device_ordinal) << "Done performing all-reduce"; return absl::OkStatus(); } @@ -158,21 +158,21 @@ CollectiveOpGroupMode AllReduceThunk::GetGroupMode( } absl::Status AllReduceThunk::Prepare(const PrepareParams& params) { - TF_RETURN_IF_ERROR(CollectiveThunk::Prepare(params)); + RETURN_IF_ERROR(CollectiveThunk::Prepare(params)); return collective_kernel_thunk_->Prepare(params); } absl::Status AllReduceThunk::Initialize(const InitializeParams& params) { - TF_RETURN_IF_ERROR(CollectiveThunk::Initialize(params)); - TF_ASSIGN_OR_RETURN( + RETURN_IF_ERROR(CollectiveThunk::Initialize(params)); + ASSIGN_OR_RETURN( GpuCliqueKey clique_key, GetCollectiveGpuCliqueKey(*params.collective_params, config())); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( bool use_collective_kernel, collective_kernel_thunk_->IsSupported(clique_key, *params.executor, *params.collective_params)); if (use_collective_kernel) { - TF_RETURN_IF_ERROR(collective_kernel_thunk_->Initialize(params)); + RETURN_IF_ERROR(collective_kernel_thunk_->Initialize(params)); } return absl::OkStatus(); } @@ -181,12 +181,11 @@ absl::Status AllReduceThunk::RunCollective(const ExecuteParams& params, const GpuCliqueKey& clique_key, se::Stream& stream, Communicator& comm) { - TF_ASSIGN_OR_RETURN( - std::vector device_buffers, - ConvertToDeviceBuffers(params.buffer_allocations, buffers(), - config_.config.operand_element_type)); + ASSIGN_OR_RETURN(std::vector device_buffers, + ConvertToDeviceBuffers(params.buffer_allocations, buffers(), + config_.config.operand_element_type)); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( bool use_collective_kernel, collective_kernel_thunk_->IsSupported(clique_key, *stream.parent(), *params.collective_params)); @@ -346,10 +345,9 @@ absl::Status ReduceScatterThunk::RunCollective(const ExecuteParams& params, const GpuCliqueKey& clique_key, se::Stream& stream, Communicator& comm) { - TF_ASSIGN_OR_RETURN( - std::vector device_buffers, - ConvertToDeviceBuffers(params.buffer_allocations, buffers(), - config_.config.operand_element_type)); + ASSIGN_OR_RETURN(std::vector device_buffers, + ConvertToDeviceBuffers(params.buffer_allocations, buffers(), + config_.config.operand_element_type)); return RunReduceScatter(config_.reduction_kind, device_buffers, stream, comm, config_.config.use_symmetric_buffer); } @@ -361,7 +359,7 @@ absl::Status RunReduceScatter(ReductionKind reduction_kind, int device_ordinal = stream.parent()->device_ordinal(); XLA_VLOG_DEVICE(3, device_ordinal) << "Performing reduce-scatter"; - TF_ASSIGN_OR_RETURN(int32_t num_ranks, comm.NumRanks()); + ASSIGN_OR_RETURN(int32_t num_ranks, comm.NumRanks()); auto* gpu_comm = tsl::down_cast(&comm); Future<> future = @@ -374,14 +372,14 @@ absl::Status RunReduceScatter(ReductionKind reduction_kind, << "Source buffer was not an exact multiple of the number of " "participants."; - TF_RETURN_IF_ERROR(comm->LaunchReduceScatter( + RETURN_IF_ERROR(comm->LaunchReduceScatter( buffer.source_buffer, buffer.destination_buffer, buffer.element_type, buffer.element_count / num_ranks, reduction_kind, GpuCollectives::On(stream))); } return absl::OkStatus(); }); - TF_RETURN_IF_ERROR(future.Await()); + RETURN_IF_ERROR(future.Await()); XLA_VLOG_DEVICE(3, device_ordinal) << "Done performing reduce-scatter"; return absl::OkStatus(); } diff --git a/third_party/xla/xla/backends/gpu/runtime/all_to_all_thunk.cc b/third_party/xla/xla/backends/gpu/runtime/all_to_all_thunk.cc index a8d33646e5b554..033ecda53f718f 100644 --- a/third_party/xla/xla/backends/gpu/runtime/all_to_all_thunk.cc +++ b/third_party/xla/xla/backends/gpu/runtime/all_to_all_thunk.cc @@ -104,7 +104,7 @@ AllToAllThunk::AllToAllThunk(ThunkInfo thunk_info, std::optional split_dim = instr->split_dimension(); for (HloInstruction* operand : instr->operands()) { Shape shape = operand->shape(); - TF_RETURN_IF_ERROR(IsValidOperand(shape, Thunk::kAllToAll)); + RETURN_IF_ERROR(IsValidOperand(shape, Thunk::kAllToAll)); if (split_dim && !ShapeUtil::IsEffectivelyMostMajorDimension(shape, *split_dim)) { return absl::UnimplementedError(absl::Substitute( @@ -124,29 +124,29 @@ AllToAllThunk::AllToAllThunk(ThunkInfo thunk_info, } absl::Status AllToAllThunk::Initialize(const InitializeParams& params) { - TF_RETURN_IF_ERROR(CollectiveThunk::Initialize(params)); + RETURN_IF_ERROR(CollectiveThunk::Initialize(params)); CHECK_GT(params.local_device_count, 0); XLA_VLOG_DEVICE(5, params.executor->device_ordinal()) << "Local device count : " << params.local_device_count; if (is_local(params.local_device_count) && p2p_memcpy_enabled_) { - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( GpuCliqueKey clique_key, GetGpuCliqueKey(*params.collective_params, config().replica_groups, config().group_mode, communication_id())); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( Communicator * comm, params.collective_cliques->GetComm( clique_key, params.collective_params->global_device_id)); - TF_ASSIGN_OR_RETURN(int32_t num_ranks, comm->NumRanks()); + ASSIGN_OR_RETURN(int32_t num_ranks, comm->NumRanks()); se::StreamExecutor* executor = params.executor; { absl::MutexLock lock(pointer_maps_mutex_); if (!receive_pointer_maps_.count(executor)) { - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( std::unique_ptr alloc, executor->HostMemoryAllocate(num_ranks * sizeof(uint64_t))); bool inserted = @@ -157,15 +157,15 @@ absl::Status AllToAllThunk::Initialize(const InitializeParams& params) { { absl::MutexLock lock(events_mutex_); if (!events_.count(executor)) { - TF_ASSIGN_OR_RETURN(std::unique_ptr event, - executor->CreateEvent()); + ASSIGN_OR_RETURN(std::unique_ptr event, + executor->CreateEvent()); events_.insert({executor, std::move(event)}); } } std::optional rank = clique_key.rank(params.collective_params->global_device_id); size_t chunk_element_count = buffers()[0].element_count / num_ranks; - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( std::vector device_buffers, ConvertToDeviceBuffers(params.buffer_allocations, buffers(), config_.config.operand_element_type)); @@ -187,7 +187,7 @@ absl::Status AllToAllThunk::Initialize(const InitializeParams& params) { buffer_rendezvous_value.buffer = reinterpret_cast( device_buffers[buffer_idx].destination_buffer.opaque()); } - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( std::shared_ptr> rendezvous_results, Rendezvous>( @@ -222,10 +222,9 @@ absl::Status AllToAllThunk::RunCollective(const ExecuteParams& params, const GpuCliqueKey& clique_key, se::Stream& stream, Communicator& comm) { - TF_ASSIGN_OR_RETURN( - std::vector device_buffers, - ConvertToDeviceBuffers(params.buffer_allocations, buffers(), - config_.config.operand_element_type)); + ASSIGN_OR_RETURN(std::vector device_buffers, + ConvertToDeviceBuffers(params.buffer_allocations, buffers(), + config_.config.operand_element_type)); if (is_local(params.collective_params->local_device_count) && p2p_memcpy_enabled_) { @@ -317,7 +316,7 @@ absl::Status RunAllToAll(bool has_split_dimension, XLA_VLOG_DEVICE(3, device_ordinal) << "Performing all-to-all, has_split_dimension: " << has_split_dimension; - TF_ASSIGN_OR_RETURN(int32_t num_ranks, comm.NumRanks()); + ASSIGN_OR_RETURN(int32_t num_ranks, comm.NumRanks()); PrimitiveType element_type = buffers[0].element_type; int64_t element_count = buffers[0].element_count; @@ -360,7 +359,7 @@ absl::Status RunAllToAll(bool has_split_dimension, auto future = comm.AllToAll( std::move(send_buffers), std::move(recv_buffers), element_type, chunk_element_count, GpuCollectives::On(stream)); - TF_RETURN_IF_ERROR(future.Await()); + RETURN_IF_ERROR(future.Await()); } else { for (const DeviceBufferPair& buffer : buffers) { send_buffers.push_back(buffer.source_buffer); @@ -370,7 +369,7 @@ absl::Status RunAllToAll(bool has_split_dimension, auto future = comm.AllToAll(std::move(send_buffers), std::move(recv_buffers), element_type, element_count, GpuCollectives::On(stream)); - TF_RETURN_IF_ERROR(future.Await()); + RETURN_IF_ERROR(future.Await()); } return absl::OkStatus(); @@ -382,20 +381,20 @@ absl::Status SyncProgress(absl::string_view name, int64_t num_ranks, se::Stream& stream, se::Event* event, std::vector& events) { // Record event for this device. - TF_RETURN_IF_ERROR(stream.RecordEvent(event)); + RETURN_IF_ERROR(stream.RecordEvent(event)); // Rendezvous to make sure that all devices have called RecordEvent before any // device calls WaitFor on another stream. std::string finish_rendezvous_key = absl::StrFormat("finish %s for rank %d, clique %s", name, rank.value(), clique_key.ToString()); - TF_RETURN_IF_ERROR(Rendezvous(/*name=*/finish_rendezvous_key, - /*key=*/clique_key, - /*num_threads=*/num_ranks)); + RETURN_IF_ERROR(Rendezvous(/*name=*/finish_rendezvous_key, + /*key=*/clique_key, + /*num_threads=*/num_ranks)); // Wait for all devices to reach the corresponding events. for (se::Event* e : events) { - TF_RETURN_IF_ERROR(stream.WaitFor(e)); + RETURN_IF_ERROR(stream.WaitFor(e)); } return absl::OkStatus(); } @@ -409,9 +408,9 @@ absl::Status RunMemCpyAllToAll(bool has_split_dimension, std::vector& events) { int device_ordinal = stream.parent()->device_ordinal(); XLA_VLOG_DEVICE(3, device_ordinal) << "Performing mem-copy-all-to-all"; - TF_ASSIGN_OR_RETURN(int32_t num_ranks, comm.NumRanks()); - TF_RETURN_IF_ERROR(SyncProgress("before memcpy all-to-all", clique_key, rank, - num_ranks, stream, event, events)); + ASSIGN_OR_RETURN(int32_t num_ranks, comm.NumRanks()); + RETURN_IF_ERROR(SyncProgress("before memcpy all-to-all", clique_key, rank, + num_ranks, stream, event, events)); // AllToAll can operate in two modes. Either it specifies a split dimension, // in which case inputs are split and outputs concatenated in that dimension @@ -429,7 +428,7 @@ absl::Status RunMemCpyAllToAll(bool has_split_dimension, peer * chunk_element_count, chunk_element_count); se::DeviceAddressBase dst_addr = se::DeviceAddressBase((void*)receive_pointer_map[peer]); - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( stream.MemcpyD2D(&dst_addr, send_slice, send_slice.size())); } } @@ -442,14 +441,14 @@ absl::Status RunMemCpyAllToAll(bool has_split_dimension, // double buffer, exchange data with peer se::DeviceAddressBase dst_addr = se::DeviceAddressBase((void*)receive_pointer_map[peer]); - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( stream.MemcpyD2D(&dst_addr, buffers[buffer_idx].source_buffer, buffers[buffer_idx].source_buffer.size())); } } - TF_RETURN_IF_ERROR(SyncProgress("after memcpy all-to-all", clique_key, rank, - num_ranks, stream, event, events)); + RETURN_IF_ERROR(SyncProgress("after memcpy all-to-all", clique_key, rank, + num_ranks, stream, event, events)); return absl::OkStatus(); } diff --git a/third_party/xla/xla/backends/gpu/runtime/buffer_comparator.cc b/third_party/xla/xla/backends/gpu/runtime/buffer_comparator.cc index 606d4346b7c3b7..d0265581f676a1 100644 --- a/third_party/xla/xla/backends/gpu/runtime/buffer_comparator.cc +++ b/third_party/xla/xla/backends/gpu/runtime/buffer_comparator.cc @@ -27,6 +27,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "Eigen/Core" +#include "xla/tsl/platform/status_macros.h" #include "xla/primitive_util.h" #include "xla/service/gpu/launch_dimensions.h" #include "xla/shape.h" @@ -64,8 +65,7 @@ static absl::StatusOr DeviceCompare(const ComparisonParams& params) { se::DeviceAddressHandle out(executor, executor->AllocateScalar()); - TF_RETURN_IF_ERROR( - params.stream->MemZero(out.address_ptr(), sizeof(uint64_t))); + RETURN_IF_ERROR(params.stream->MemZero(out.address_ptr(), sizeof(uint64_t))); if (params.current.size() != params.expected.size()) { return Internal("Mismatched buffer size: %d bytes vs. %d bytes", params.current.size(), params.expected.size()); @@ -75,7 +75,7 @@ static absl::StatusOr DeviceCompare(const ComparisonParams& params) { se::DeviceAddress expected_typed(params.expected); uint64_t buffer_size = current_typed.ElementCount(); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( auto comparison_kernel, stream_executor::gpu::GpuKernelRegistry::GetGlobalRegistry() .LoadKernel>( @@ -96,16 +96,16 @@ static absl::StatusOr DeviceCompare(const ComparisonParams& params) { dim.thread_counts_per_block()); se::DeviceAddress as_uint64(out.address()); - TF_RETURN_IF_ERROR(comparison_kernel.Launch( + RETURN_IF_ERROR(comparison_kernel.Launch( dim.thread_counts_per_block(), dim.block_counts(), params.stream, current_typed, expected_typed, static_cast(params.relative_tol), buffer_size, as_uint64)); uint64_t result = -1; CHECK_EQ(out.address().size(), sizeof(result)); - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( params.stream->Memcpy(&result, out.address(), sizeof(result))); - TF_RETURN_IF_ERROR(params.stream->BlockHostUntilDone()); + RETURN_IF_ERROR(params.stream->BlockHostUntilDone()); return result == 0; } @@ -117,11 +117,11 @@ template static absl::StatusOr HostCompare(const ComparisonParams& params) { int64_t n = params.current.size() / sizeof(ElementType); std::vector host_current(n), host_expected(n); - TF_RETURN_IF_ERROR(params.stream->Memcpy(host_current.data(), params.current, - params.current.size())); - TF_RETURN_IF_ERROR(params.stream->Memcpy( - host_expected.data(), params.expected, params.expected.size())); - TF_RETURN_IF_ERROR(params.stream->BlockHostUntilDone()); + RETURN_IF_ERROR(params.stream->Memcpy(host_current.data(), params.current, + params.current.size())); + RETURN_IF_ERROR(params.stream->Memcpy(host_expected.data(), params.expected, + params.expected.size())); + RETURN_IF_ERROR(params.stream->BlockHostUntilDone()); const auto canonicalize = [](ComparisonType a) -> ComparisonType { if (std::is_same::value && a) { @@ -171,12 +171,12 @@ template static absl::StatusOr CompareEqualParameterized( const ComparisonParams& params) { XLA_SCOPED_LOGGING_TIMER("BufferComparator::CompareEqual"); - TF_ASSIGN_OR_RETURN(bool result, DeviceCompare(params)); + ASSIGN_OR_RETURN(bool result, DeviceCompare(params)); if (result) return true; if (!params.run_host_compare) return false; - TF_ASSIGN_OR_RETURN(bool host_return, - (HostCompare(params))); + ASSIGN_OR_RETURN(bool host_return, + (HostCompare(params))); CHECK_EQ(host_return, result) << "Host comparison succeeded even though GPU comparison failed."; return false; diff --git a/third_party/xla/xla/backends/gpu/runtime/buffers_checksum_thunk.cc b/third_party/xla/xla/backends/gpu/runtime/buffers_checksum_thunk.cc index d7a4364bc841e3..ced1a760ededd4 100644 --- a/third_party/xla/xla/backends/gpu/runtime/buffers_checksum_thunk.cc +++ b/third_party/xla/xla/backends/gpu/runtime/buffers_checksum_thunk.cc @@ -24,6 +24,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "absl/synchronization/mutex.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/backends/gpu/runtime/buffer_debug_log_entry_metadata_store.h" #include "xla/backends/gpu/runtime/buffer_debug_log_structs.h" #include "xla/backends/gpu/runtime/thunk.h" @@ -63,7 +64,7 @@ absl::Status BuffersDebugChecksumThunk::Initialize( if (!kernels_.contains(params.executor)) { se::gpu::GpuKernelRegistry registry = se::gpu::GpuKernelRegistry::GetGlobalRegistry(); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( auto kernel, registry.LoadKernel( params.executor)); @@ -125,7 +126,7 @@ absl::Status BuffersDebugChecksumThunk::ExecuteOnStream( se::DeviceAddress device_buffer( params.buffer_allocations->GetDeviceAddress(buffer)); - TF_RETURN_IF_ERROR(kernel->Launch( + RETURN_IF_ERROR(kernel->Launch( thread_dim, se::BlockDim(1, 1, 1), params.stream, log_entry_id, device_buffer, device_buffer.size(), buffer_debug_log.GetDeviceHeader(), buffer_debug_log.GetDeviceEntries())); diff --git a/third_party/xla/xla/backends/gpu/runtime/buffers_checksum_thunk_test.cc b/third_party/xla/xla/backends/gpu/runtime/buffers_checksum_thunk_test.cc index 9ae018366cb545..cdffd76fb61364 100644 --- a/third_party/xla/xla/backends/gpu/runtime/buffers_checksum_thunk_test.cc +++ b/third_party/xla/xla/backends/gpu/runtime/buffers_checksum_thunk_test.cc @@ -26,6 +26,7 @@ limitations under the License. #include #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/backends/gpu/runtime/buffer_debug_log_entry_metadata_store.h" #include "xla/backends/gpu/runtime/buffer_debug_log_structs.h" #include "xla/backends/gpu/runtime/collective_clique_requests.h" @@ -240,10 +241,10 @@ TEST_F(BuffersDebugChecksumThunkTest, BufferAllocations allocations; }; auto setup_device = [this](int device_ordinal) -> absl::StatusOr { - TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor, - platform_->ExecutorForDevice(device_ordinal)); - TF_ASSIGN_OR_RETURN(std::unique_ptr stream, - executor->CreateStream()); + ASSIGN_OR_RETURN(se::StreamExecutor * executor, + platform_->ExecutorForDevice(device_ordinal)); + ASSIGN_OR_RETURN(std::unique_ptr stream, + executor->CreateStream()); auto allocator = std::make_unique( executor); diff --git a/third_party/xla/xla/backends/gpu/runtime/buffers_float_check_thunk.cc b/third_party/xla/xla/backends/gpu/runtime/buffers_float_check_thunk.cc index 4ac6c634926e72..f805be631c4aef 100644 --- a/third_party/xla/xla/backends/gpu/runtime/buffers_float_check_thunk.cc +++ b/third_party/xla/xla/backends/gpu/runtime/buffers_float_check_thunk.cc @@ -29,6 +29,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "absl/synchronization/mutex.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/backends/gpu/runtime/buffer_debug_log_entry_metadata_store.h" #include "xla/backends/gpu/runtime/buffer_debug_log_structs.h" #include "xla/backends/gpu/runtime/thunk.h" @@ -97,19 +98,19 @@ absl::Status BuffersDebugFloatCheckThunk::Initialize( if (!kernels_.contains(params.executor)) { se::gpu::GpuKernelRegistry registry = se::gpu::GpuKernelRegistry::GetGlobalRegistry(); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( auto kernel_f32, registry.LoadKernel( params.executor)); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( auto kernel_bf16, registry.LoadKernel( params.executor)); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( auto kernel_f64, registry.LoadKernel( params.executor)); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( auto kernel_reduce, registry.LoadKernel< se::gpu::BufferDebugAppendReducedFloatCheckResultsKernel>( @@ -166,12 +167,12 @@ absl::Status CheckFloatsAndLog( GetBlockDimForBuffer(stream, buffer, tmp_ptr.ElementCount()); const size_t num_blocks = block_dim.x * block_dim.y * block_dim.z; - TF_RETURN_IF_ERROR(map_kernel.Launch(thread_dim, block_dim, stream, buffer, - buffer.ElementCount(), tmp_ptr, - tmp_ptr.ElementCount())); + RETURN_IF_ERROR(map_kernel.Launch(thread_dim, block_dim, stream, buffer, + buffer.ElementCount(), tmp_ptr, + tmp_ptr.ElementCount())); // Operations on the same stream perform in sequence, so at this point the // results of the previous FloatCheck operation are available. - TF_RETURN_IF_ERROR(reduce_append_kernel.Launch( + RETURN_IF_ERROR(reduce_append_kernel.Launch( thread_dim, se::BlockDim(1, 1, 1), stream, tmp_ptr, std::min(tmp_ptr.ElementCount(), num_blocks), entry_id, buffer_debug_log.GetDeviceHeader(), buffer_debug_log.GetDeviceEntries())); @@ -230,15 +231,15 @@ absl::Status BuffersDebugFloatCheckThunk::ExecuteOnStream( params.buffer_allocations->GetDeviceAddress(buffer); if (buffer_type == PrimitiveType::F32) { - TF_RETURN_IF_ERROR(CheckFloatsAndLog( + RETURN_IF_ERROR(CheckFloatsAndLog( params.stream, entry_id, buffer_debug_log, device_buffer, tmp_ptr, kernels->f32, kernels->reduce)); } else if (buffer_type == PrimitiveType::BF16) { - TF_RETURN_IF_ERROR(CheckFloatsAndLog( + RETURN_IF_ERROR(CheckFloatsAndLog( params.stream, entry_id, buffer_debug_log, device_buffer, tmp_ptr, kernels->bf16, kernels->reduce)); } else if (buffer_type == PrimitiveType::F64) { - TF_RETURN_IF_ERROR(CheckFloatsAndLog( + RETURN_IF_ERROR(CheckFloatsAndLog( params.stream, entry_id, buffer_debug_log, device_buffer, tmp_ptr, kernels->f64, kernels->reduce)); } else { diff --git a/third_party/xla/xla/backends/gpu/runtime/buffers_float_check_thunk_test.cc b/third_party/xla/xla/backends/gpu/runtime/buffers_float_check_thunk_test.cc index bd1cd0dbbb2092..acb59c5e34a59f 100644 --- a/third_party/xla/xla/backends/gpu/runtime/buffers_float_check_thunk_test.cc +++ b/third_party/xla/xla/backends/gpu/runtime/buffers_float_check_thunk_test.cc @@ -28,6 +28,7 @@ limitations under the License. #include #include #include "absl/status/statusor.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/backends/gpu/runtime/buffer_debug_log_entry_metadata_store.h" #include "xla/backends/gpu/runtime/buffer_debug_log_structs.h" #include "xla/backends/gpu/runtime/buffer_debug_log_structs_test_matchers.h" @@ -761,10 +762,10 @@ TEST_F(BuffersDebugFloatCheckThunkTest, }; auto setup_device = [this](int device_ordinal) -> absl::StatusOr { - TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor, - platform_->ExecutorForDevice(device_ordinal)); - TF_ASSIGN_OR_RETURN(std::unique_ptr stream, - executor->CreateStream()); + ASSIGN_OR_RETURN(se::StreamExecutor * executor, + platform_->ExecutorForDevice(device_ordinal)); + ASSIGN_OR_RETURN(std::unique_ptr stream, + executor->CreateStream()); auto allocator = std::make_unique( executor); diff --git a/third_party/xla/xla/backends/gpu/runtime/collective_cliques.cc b/third_party/xla/xla/backends/gpu/runtime/collective_cliques.cc index 25fead6dc67516..36c4b19c5e3a17 100644 --- a/third_party/xla/xla/backends/gpu/runtime/collective_cliques.cc +++ b/third_party/xla/xla/backends/gpu/runtime/collective_cliques.cc @@ -28,6 +28,7 @@ limitations under the License. #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" #include "absl/time/time.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/backends/gpu/collectives/gpu_clique.h" #include "xla/backends/gpu/collectives/gpu_clique_key.h" #include "xla/backends/gpu/collectives/gpu_cliques.h" @@ -183,8 +184,8 @@ absl::StatusOr AcquireCollectiveCliques( "For non-local GPU cliques (cliques that span multiple processes) " "clique id callback must be passed via execution params"); } - TF_ASSIGN_OR_RETURN(CliqueId clique_id, - params.collectives->CreateUniqueCliqueId()); + ASSIGN_OR_RETURN(CliqueId clique_id, + params.collectives->CreateUniqueCliqueId()); return CliqueIds(clique_id); }; @@ -192,7 +193,7 @@ absl::StatusOr AcquireCollectiveCliques( ? params.p2p_max_nchannels : params.collective_max_nchannels; - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( std::shared_ptr clique, AcquireGpuClique(params.collectives, params.executor, params.run_id, r.key, r.device_groups, @@ -231,9 +232,9 @@ absl::StatusOr AcquireCollectiveCliques( auto* comm = dynamic_cast(*(*clique)->comm(*rank)); DCHECK(comm) << "Communicator must be in the acquired clique"; - TF_ASSIGN_OR_RETURN(std::unique_ptr dev_comm, - comm->CreateDeviceComm(reqs)); - TF_RETURN_IF_ERROR( + ASSIGN_OR_RETURN(std::unique_ptr dev_comm, + comm->CreateDeviceComm(reqs)); + RETURN_IF_ERROR( (*clique)->AddDeviceComm(*rank, reqs, std::move(dev_comm))); } } diff --git a/third_party/xla/xla/backends/gpu/runtime/collective_execution.cc b/third_party/xla/xla/backends/gpu/runtime/collective_execution.cc index f3ba1a2020dc21..3792bb4752dff0 100644 --- a/third_party/xla/xla/backends/gpu/runtime/collective_execution.cc +++ b/third_party/xla/xla/backends/gpu/runtime/collective_execution.cc @@ -23,6 +23,7 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "absl/status/statusor.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/backends/gpu/collectives/gpu_clique_key.h" #include "xla/backends/gpu/runtime/collective_params.h" #include "xla/runtime/device_id.h" @@ -68,7 +69,7 @@ absl::StatusOr GetGpuCliqueKey( // Get the list of all devices that are participating in the collective // operation. - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( std::vector devices, GetParticipatingDevices(global_device_id, *params.device_assn, replica_groups, group_mode)); diff --git a/third_party/xla/xla/backends/gpu/runtime/collective_kernel_api.cc b/third_party/xla/xla/backends/gpu/runtime/collective_kernel_api.cc index 5bb2b58751b7ed..0fd5b4f668df12 100644 --- a/third_party/xla/xla/backends/gpu/runtime/collective_kernel_api.cc +++ b/third_party/xla/xla/backends/gpu/runtime/collective_kernel_api.cc @@ -28,6 +28,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/synchronization/mutex.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/backends/gpu/collectives/gpu_clique_key.h" #include "xla/backends/gpu/collectives/gpu_clique_rendezvous.h" #include "xla/core/collectives/rank_id.h" @@ -57,7 +58,7 @@ absl::StatusOr GetCachedKernel( absl::MutexLock lock(*kernel_mutex); if (!kernel_per_executor->contains(executor)) { - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( auto new_kernel, (stream_executor::gpu::GpuKernelRegistry::GetGlobalRegistry() .LoadKernel(executor))); @@ -110,8 +111,8 @@ absl::Status LaunchMultiGpuBarrier( signal_buffers[peer] = barrier_addresses[peer].opaque(); } - TF_ASSIGN_OR_RETURN(MultiGpuBarrierKernel::KernelType * kernel, - GetCachedKernel(stream->parent())); + ASSIGN_OR_RETURN(MultiGpuBarrierKernel::KernelType * kernel, + GetCachedKernel(stream->parent())); stream_executor::DeviceAddress typed_sync_counter( local_barrier_signal_value); @@ -133,7 +134,7 @@ absl::Status LaunchMultiGpuBarrierWithNccl( using MultiGpuBarrierWithNcclKernel = stream_executor::gpu::MultiGpuBarrierWithNcclKernel; - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( MultiGpuBarrierWithNcclKernel::KernelType * kernel, GetCachedKernel(stream->parent())); @@ -171,7 +172,7 @@ absl::StatusOr> CollectParamToPeers( size_t num_parameters = parameters.size(); // Exchange device parameters with all ranks in the clique. - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( auto device_parameters, GpuCliqueRendezvous::Join(clique_key, rank, std::move(parameters))); @@ -184,8 +185,8 @@ absl::StatusOr> CollectParamToPeers( using DeviceParameters = std::vector; for (auto peer = RankId(0); peer < RankId(clique_key.num_devices()); ++peer) { - TF_ASSIGN_OR_RETURN(const DeviceParameters& peer_parameters, - device_parameters->at(peer)); + ASSIGN_OR_RETURN(const DeviceParameters& peer_parameters, + device_parameters->at(peer)); peer_to_parameters[peer.value()] = std::move(peer_parameters); } diff --git a/third_party/xla/xla/backends/gpu/runtime/collective_kernel_thunk.cc b/third_party/xla/xla/backends/gpu/runtime/collective_kernel_thunk.cc index bcdbd2565f32da..061c0f43a5d455 100644 --- a/third_party/xla/xla/backends/gpu/runtime/collective_kernel_thunk.cc +++ b/third_party/xla/xla/backends/gpu/runtime/collective_kernel_thunk.cc @@ -140,14 +140,14 @@ absl::Status CopyCollectiveMetadataToDevice( reinterpret_cast(param_to_peers_ptrs_buffer.opaque()); metadata.param_to_multimem_addresses = reinterpret_cast(multimem_addresses_buffer.opaque()); - TF_RETURN_IF_ERROR(stream->Memcpy(&destination, &metadata, - sizeof(CollectiveKernelMetadata))); - TF_RETURN_IF_ERROR(stream->Memcpy(¶m_to_peers_ptrs_buffer, - param_to_peers_ptrs.data(), - param_to_peers_ptrs_size)); - TF_RETURN_IF_ERROR(stream->Memcpy(&multimem_addresses_buffer, - multimem_addresses.data(), - multimem_addresses_size)); + RETURN_IF_ERROR(stream->Memcpy(&destination, &metadata, + sizeof(CollectiveKernelMetadata))); + RETURN_IF_ERROR(stream->Memcpy(¶m_to_peers_ptrs_buffer, + param_to_peers_ptrs.data(), + param_to_peers_ptrs_size)); + RETURN_IF_ERROR(stream->Memcpy(&multimem_addresses_buffer, + multimem_addresses.data(), + multimem_addresses_size)); return absl::OkStatus(); } } // namespace @@ -165,10 +165,10 @@ absl::StatusOr CollectiveKernelThunk::IsSupported( VLOG(3) << "Collective kernel not supported: " << status.message(); return false; } - TF_RETURN_IF_ERROR(status); + RETURN_IF_ERROR(status); for (const GlobalDeviceId& device : clique_key.devices()) { - TF_ASSIGN_OR_RETURN(const int peer_device_id, - GetLocalDeviceId(device, collective_params)); + ASSIGN_OR_RETURN(const int peer_device_id, + GetLocalDeviceId(device, collective_params)); if (!executor.CanEnablePeerAccessTo(peer_device_id)) { XLA_VLOG_DEVICE(3, executor.device_ordinal()) << "Peer access is not supported with device " << peer_device_id; @@ -182,11 +182,11 @@ absl::Status CollectiveKernelThunk::Prepare(const PrepareParams& params) { TF_RET_CHECK(params.collective_params && params.collective_params->device_assn); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( GpuCliqueKey clique_key, GetCollectiveGpuCliqueKey(*params.collective_params, collective_config_)); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( bool use_collective_kernel, IsSupported(clique_key, *params.executor, *params.collective_params)); @@ -194,7 +194,7 @@ absl::Status CollectiveKernelThunk::Prepare(const PrepareParams& params) { return absl::OkStatus(); } - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( std::vector> device_groups, GetParticipatingDevicesGroups(*params.collective_params->device_assn, collective_config_.replica_groups, @@ -204,7 +204,7 @@ absl::Status CollectiveKernelThunk::Prepare(const PrepareParams& params) { absl::c_for_each(device_groups, [](auto& group) { absl::c_sort(group); }); absl::c_sort(device_groups); - TF_RETURN_IF_ERROR(params.collective_clique_requests->RequestClique( + RETURN_IF_ERROR(params.collective_clique_requests->RequestClique( clique_key, device_groups)); absl::MutexLock lock(mutex_); @@ -222,12 +222,12 @@ absl::Status CollectiveKernelThunk::Prepare(const PrepareParams& params) { kNumSignalFlags * sizeof(int32_t), kXlaAllocatedBufferAlignBytes); const int64_t kLocalBufferSize = xla::RoundUpTo( buffers_[0].source_buffer.slice.size(), kXlaAllocatedBufferAlignBytes); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( se::DeviceAddressHandle local_buffers_handle, AllocateMemory(params.executor, kLocalBufferSize * kNumBuffers, "Local buffers")); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( se::DeviceAddressHandle signal_buffers_handle, AllocateMemory(params.executor, kSignalBufferSize * kNumBuffers, "Signal buffers")); @@ -246,11 +246,11 @@ absl::Status CollectiveKernelThunk::Prepare(const PrepareParams& params) { XLA_VLOG_DEVICE(3, params.executor->device_ordinal()) << "Request multicast address for source and destination buffers"; - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( params.collective_memory_requests->RequestMulticastAddress( clique_key, params.buffer_allocations->GetDeviceAddress( buffers_[0].source_buffer.slice))); - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( params.collective_memory_requests->RequestMulticastAddress( clique_key, params.buffer_allocations->GetDeviceAddress( buffers_[0].destination_buffer.slice))); @@ -267,7 +267,7 @@ int64_t CollectiveKernelThunk::GetInputSizeBytes() const { } absl::Status CollectiveKernelThunk::Initialize(const InitializeParams& params) { - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( GpuCliqueKey clique_key, GetCollectiveGpuCliqueKey(*params.collective_params, collective_config_)); const std::optional rank = @@ -287,10 +287,10 @@ absl::Status CollectiveKernelThunk::Initialize(const InitializeParams& params) { // the buffer. The kernel will take care of leaving the buffer in // correct state after use, so we don't need to zero out after // initialization. - TF_RETURN_IF_ERROR(params.stream->MemZero( + RETURN_IF_ERROR(params.stream->MemZero( memory_state->signal_buffers_handle.address_ptr(), memory_state->signal_buffers_handle.address().size())); - TF_RETURN_IF_ERROR(params.stream->BlockHostUntilDone()); + RETURN_IF_ERROR(params.stream->BlockHostUntilDone()); // Create a kernel for execution. std::unique_ptr kernel = nullptr; if (!kernel_name_.empty()) { @@ -372,7 +372,7 @@ absl::Status CollectiveKernelThunk::Initialize(const InitializeParams& params) { multimem_addresses.resize(kNumParameters + 1, nullptr); const size_t multimem_addresses_size_bytes = multimem_addresses.size() * sizeof(void*); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( param_to_peers_ptrs, CollectParamToPeers(clique_key, state->rank, params.stream, std::move(parameters))); @@ -412,13 +412,13 @@ absl::Status CollectiveKernelThunk::Initialize(const InitializeParams& params) { state->metadata = params.executor->Allocate( sizeof(CollectiveKernelMetadata) + param_to_peers_ptrs_size_bytes, 0); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( param_to_peers_ptrs, CollectParamToPeers(clique_key, state->rank, params.stream, std::move(parameters))); } - TF_RETURN_IF_ERROR(CopyCollectiveMetadataToDevice( + RETURN_IF_ERROR(CopyCollectiveMetadataToDevice( params.stream, metadata, param_to_peers_ptrs, multimem_addresses, state->metadata)); return absl::OkStatus(); @@ -433,7 +433,7 @@ absl::Status CollectiveKernelThunk::ExecuteOnStream( TF_RET_CHECK(stream != nullptr); const int device_ordinal = stream->parent()->device_ordinal(); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( GpuCliqueKey clique_key, GetCollectiveGpuCliqueKey(*params.collective_params, collective_config_)); const int32_t num_devices = clique_key.num_devices(); @@ -491,16 +491,16 @@ absl::Status CollectiveKernelThunk::ExecuteOnStream( << launch_dimensions_->num_blocks() << "x" << launch_dimensions_->num_threads_per_block() << "(block x threadsPerBlock)"; - TF_ASSIGN_OR_RETURN(se::DeviceAddressBase remote_buffers, - GetParameterDeviceMemoryBase( - state->metadata, /*num_parameters=*/kNumParameters, - /*num_devices=*/num_devices, - /*parameter_index=*/0)); - TF_ASSIGN_OR_RETURN(se::DeviceAddressBase signal_buffers, - GetParameterDeviceMemoryBase( - state->metadata, /*num_parameters=*/kNumParameters, - /*num_devices=*/num_devices, - /*parameter_index=*/1)); + ASSIGN_OR_RETURN(se::DeviceAddressBase remote_buffers, + GetParameterDeviceMemoryBase( + state->metadata, /*num_parameters=*/kNumParameters, + /*num_devices=*/num_devices, + /*parameter_index=*/0)); + ASSIGN_OR_RETURN(se::DeviceAddressBase signal_buffers, + GetParameterDeviceMemoryBase( + state->metadata, /*num_parameters=*/kNumParameters, + /*num_devices=*/num_devices, + /*parameter_index=*/1)); std::array kernel_args = { source_buffer, destination_buffer, diff --git a/third_party/xla/xla/backends/gpu/runtime/collective_kernel_thunk_test.cc b/third_party/xla/xla/backends/gpu/runtime/collective_kernel_thunk_test.cc index 7d57de6cbab3ae..a9213236830052 100644 --- a/third_party/xla/xla/backends/gpu/runtime/collective_kernel_thunk_test.cc +++ b/third_party/xla/xla/backends/gpu/runtime/collective_kernel_thunk_test.cc @@ -27,6 +27,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/backends/gpu/collectives/gpu_clique_key.h" #include "xla/backends/gpu/runtime/collective_clique_requests.h" #include "xla/backends/gpu/runtime/collective_cliques.h" @@ -232,14 +233,14 @@ absl::StatusOr> CompilePtxToCubin( const DebugOptions& debug_options) { se::cuda::CompilationProviderOptions options = se::cuda::CompilationProviderOptions::FromDebugOptions(debug_options); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( std::unique_ptr compilation_provider, se::cuda::AssembleCompilationProvider(options)); se::CudaComputeCapability cc = *device_description.gpu_compute_capability().cuda_compute_capability(); se::cuda::CompilationOptions compilation_options = PtxCompileOptionsFromDebugOptions(debug_options); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( se::cuda::Assembly assembly, compilation_provider->Compile(cc, ptx_string, compilation_options)); return std::move(assembly.cubin); @@ -255,7 +256,7 @@ absl::StatusOr RunCollectiveKernelThunk( std::make_pair(LocalDeviceId(0), GlobalDeviceId(0)), std::make_pair(LocalDeviceId(1), GlobalDeviceId(1))}); - TF_ASSIGN_OR_RETURN(auto stream, executor->CreateStream()); + ASSIGN_OR_RETURN(auto stream, executor->CreateStream()); ServiceExecutableRunOptions run_options; run_options.mutable_run_options()->set_stream(stream.get()); DeviceAssignment device_assignment(/*replica_count=*/metadata.num_devices, @@ -269,7 +270,7 @@ absl::StatusOr RunCollectiveKernelThunk( run_options.mutable_run_options()->set_gpu_executable_run_options( &gpu_options); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( CollectiveParams collective_params, CollectiveParams::Create(run_options, /*async_streams=*/{}, LocalDeviceId(executor->device_ordinal()))); @@ -293,9 +294,9 @@ absl::StatusOr RunCollectiveKernelThunk( if (!input_data.empty()) { VLOG(3) << "Copying input data to the device"; - TF_RETURN_IF_ERROR(stream->Memcpy(&input_buffer, input_data.data(), - metadata.input_data_size_bytes)); - TF_RETURN_IF_ERROR(stream->BlockHostUntilDone()); + RETURN_IF_ERROR(stream->Memcpy(&input_buffer, input_data.data(), + metadata.input_data_size_bytes)); + RETURN_IF_ERROR(stream->BlockHostUntilDone()); } CollectiveCliqueRequests clique_requests; @@ -305,7 +306,7 @@ absl::StatusOr RunCollectiveKernelThunk( all_replica_groups.add_replica_ids(i); } // Request a clique that covers all devices (this test runs on 2 gpus). - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( GpuCliqueKey clique_key, xla::gpu::GetGpuCliqueKey( collective_params, {all_replica_groups}, @@ -314,7 +315,7 @@ absl::StatusOr RunCollectiveKernelThunk( for (int i = 0; i < metadata.num_devices; ++i) { all_device_groups.push_back(GlobalDeviceId(i)); } - TF_RETURN_IF_ERROR(clique_requests.RequestClique( + RETURN_IF_ERROR(clique_requests.RequestClique( clique_key, /*device_groups=*/{all_device_groups})); CollectiveMemoryRequests memory_requests(buffer_allocations); ScratchMemoryRequests scratch_memory_requests; @@ -322,13 +323,12 @@ absl::StatusOr RunCollectiveKernelThunk( &collective_params, &clique_requests, &memory_requests, &scratch_memory_requests, executor, &buffer_allocations}; - TF_RETURN_IF_ERROR(metadata.thunk->Prepare(prepare_params)); + RETURN_IF_ERROR(metadata.thunk->Prepare(prepare_params)); CollectiveMemoryCache collective_memory_cache; CollectiveCliques collective_cliques; - TF_ASSIGN_OR_RETURN( - collective_cliques, - AcquireCollectiveCliques(collective_params, clique_requests)); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN(collective_cliques, AcquireCollectiveCliques( + collective_params, clique_requests)); + ASSIGN_OR_RETURN( CollectiveMemory collective_memory, AcquireCollectiveMemory(collective_params, collective_cliques, memory_requests, collective_memory_cache)); @@ -350,19 +350,18 @@ absl::StatusOr RunCollectiveKernelThunk( std::vector cubin; if (!metadata.use_ptx) { - TF_ASSIGN_OR_RETURN( - cubin, - CompilePtxToCubin(kKernelSource, executor->GetDeviceDescription(), - DebugOptions())); + ASSIGN_OR_RETURN(cubin, CompilePtxToCubin(kKernelSource, + executor->GetDeviceDescription(), + DebugOptions())); initialize_params.src.binary = cubin; } - TF_RETURN_IF_ERROR(metadata.thunk->Initialize(initialize_params)); + RETURN_IF_ERROR(metadata.thunk->Initialize(initialize_params)); auto execute_params = Thunk::ExecuteParams::Create( run_options, buffer_allocations, stream.get(), /*command_buffer_trace_stream=*/nullptr, &collective_params, /*collective_cliques=*/nullptr, /*collective_memory=*/&collective_memory); - TF_RETURN_IF_ERROR(metadata.thunk->ExecuteOnStream(execute_params)); + RETURN_IF_ERROR(metadata.thunk->ExecuteOnStream(execute_params)); return output_buffer; } diff --git a/third_party/xla/xla/backends/gpu/runtime/collective_params.cc b/third_party/xla/xla/backends/gpu/runtime/collective_params.cc index a3be16a1ed4455..ec7da5a8a13211 100644 --- a/third_party/xla/xla/backends/gpu/runtime/collective_params.cc +++ b/third_party/xla/xla/backends/gpu/runtime/collective_params.cc @@ -24,6 +24,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/backends/gpu/collectives/gpu_collectives.h" #include "xla/core/collectives/collectives.h" #include "xla/core/collectives/collectives_registry.h" @@ -102,8 +103,8 @@ absl::StatusOr CollectiveParams::Create( ? &*gpu_options->incarnations() : nullptr; - TF_ASSIGN_OR_RETURN(GlobalDeviceId global_device_id, - GetGlobalDeviceId(device_id_map, local_device_id)); + ASSIGN_OR_RETURN(GlobalDeviceId global_device_id, + GetGlobalDeviceId(device_id_map, local_device_id)); return CollectiveParams( collectives, run_options.stream()->parent(), diff --git a/third_party/xla/xla/backends/gpu/runtime/collective_permute_thunk.cc b/third_party/xla/xla/backends/gpu/runtime/collective_permute_thunk.cc index 10bf6659677d86..95957f88a37024 100644 --- a/third_party/xla/xla/backends/gpu/runtime/collective_permute_thunk.cc +++ b/third_party/xla/xla/backends/gpu/runtime/collective_permute_thunk.cc @@ -385,7 +385,7 @@ absl::Status RunCollectivePermute(P2PConfig::SourceTargetRanks source_target, auto future = comm.CollectivePermute( src, dst, buf.element_type, buf.element_count, source_target.source, target_ranks, GpuCollectives::On(stream)); - TF_RETURN_IF_ERROR(future.Await()); + RETURN_IF_ERROR(future.Await()); } } else { auto* gpu_comm = tsl::down_cast(&comm); @@ -396,14 +396,14 @@ absl::Status RunCollectivePermute(P2PConfig::SourceTargetRanks source_target, se::DeviceAddressBase src = src_addrs.at(idx); se::DeviceAddressBase dst = dest_addrs.at(idx); const DeviceBufferPair& buf = buffers.at(idx); - TF_RETURN_IF_ERROR(comm->LaunchCollectivePermute( + RETURN_IF_ERROR(comm->LaunchCollectivePermute( src, dst, buf.element_type, buf.element_count, source_target.source, target_ranks, GpuCollectives::On(stream))); } return absl::OkStatus(); }); - TF_RETURN_IF_ERROR(future.Await()); + RETURN_IF_ERROR(future.Await()); } if (!source_target.source) { @@ -412,7 +412,7 @@ absl::Status RunCollectivePermute(P2PConfig::SourceTargetRanks source_target, VLOG(3) << absl::StreamFormat("%s : collective-Permute: Issuing MemZero", device_string); for (se::DeviceAddressBase& dest_addr : dest_addrs) { - TF_RETURN_IF_ERROR(stream.MemZero(&dest_addr, dest_addr.size())); + RETURN_IF_ERROR(stream.MemZero(&dest_addr, dest_addr.size())); } } @@ -440,7 +440,7 @@ static absl::Status RunPeerAccessPermute( // Borrow a "ready" event and record it on our stream to signal that all // prior work on this rank's buffers is complete. ASSIGN_OR_RETURN(EventPool::Event ready, pool->GetOrCreateEvent()); - TF_RETURN_IF_ERROR(stream.RecordEvent(ready->get())); + RETURN_IF_ERROR(stream.RecordEvent(ready->get())); // Create promise/future pair for the "done" event that the sender will // set after completing the memcpy. @@ -459,7 +459,7 @@ static absl::Status RunPeerAccessPermute( // Wait for target's stream to be ready before writing to its buffers. ASSIGN_OR_RETURN(const Events& target_events, rendezvous->at(target)); - TF_RETURN_IF_ERROR(stream.WaitFor(target_events.ready->get())); + RETURN_IF_ERROR(stream.WaitFor(target_events.ready->get())); // Perform D2D copies from our source to target's destination. for (const auto& buf : device_buffers) { @@ -469,14 +469,14 @@ static absl::Status RunPeerAccessPermute( return Internal("Peer address not found for target rank %d", target.value()); } - TF_RETURN_IF_ERROR(stream.MemcpyD2D(&*dst_addr, buf.source_buffer, - buf.source_buffer.size())); + RETURN_IF_ERROR(stream.MemcpyD2D(&*dst_addr, buf.source_buffer, + buf.source_buffer.size())); } // Record a "done" event and fulfill the promise so the target knows // the copy is complete. ASSIGN_OR_RETURN(EventPool::Event done, pool->GetOrCreateEvent()); - TF_RETURN_IF_ERROR(stream.RecordEvent(done->get())); + RETURN_IF_ERROR(stream.RecordEvent(done->get())); done_promise.Set(std::move(done)); } else { // Not a sender — fulfill promise with a dummy event. @@ -488,7 +488,7 @@ static absl::Status RunPeerAccessPermute( if (!source_target.source) { for (const auto& buf : device_buffers) { auto dest = buf.destination_buffer; - TF_RETURN_IF_ERROR(stream.MemZero(&dest, dest.size())); + RETURN_IF_ERROR(stream.MemZero(&dest, dest.size())); } } @@ -502,7 +502,7 @@ static absl::Status RunPeerAccessPermute( const absl::StatusOr& done_result = source_events.done.Await(); if (!done_result.ok()) return done_result.status(); - TF_RETURN_IF_ERROR(stream.WaitFor((*done_result)->get())); + RETURN_IF_ERROR(stream.WaitFor((*done_result)->get())); } return absl::OkStatus(); diff --git a/third_party/xla/xla/backends/gpu/runtime/collective_thunk.cc b/third_party/xla/xla/backends/gpu/runtime/collective_thunk.cc index abc92f34ceaa1f..b7b1955efc8c1e 100644 --- a/third_party/xla/xla/backends/gpu/runtime/collective_thunk.cc +++ b/third_party/xla/xla/backends/gpu/runtime/collective_thunk.cc @@ -311,13 +311,13 @@ absl::Status CollectiveThunk::Prepare(const PrepareParams& params) { if (CanUseSymmetricBuffer() && config().use_symmetric_buffer) { for (const Buffer& buffer : buffers_) { if (buffer.source_memory_space == kCollectiveMemorySpaceColor) { - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( params.collective_memory_requests->RequestSymmetricAllocation( clique_key, buffer.source_buffer.slice.index())); } if (buffer.destination_memory_space == kCollectiveMemorySpaceColor) { - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( params.collective_memory_requests->RequestSymmetricAllocation( clique_key, buffer.destination_buffer.slice.index())); } diff --git a/third_party/xla/xla/backends/gpu/runtime/command.h b/third_party/xla/xla/backends/gpu/runtime/command.h index 7c0b2242a9e746..82427691432d0a 100644 --- a/third_party/xla/xla/backends/gpu/runtime/command.h +++ b/third_party/xla/xla/backends/gpu/runtime/command.h @@ -39,7 +39,6 @@ limitations under the License. #include "xla/stream_executor/command_buffer.h" #include "xla/stream_executor/platform.h" #include "xla/xla.pb.h" -#include "tsl/platform/casts.h" namespace xla::gpu { @@ -238,8 +237,15 @@ class Command : public Thunk { std::invoke_result_t Walk(F&& callback) const; protected: - // WalkNested uses Thunk::Walker = absl::FunctionRef. - // Subclasses that have nested commands must override this. + // Walks all nested commands and calls `callback` for them. This is separate + // from Thunk::WalkNested because a Thunk/Command hybrid can own non-command + // thunks for direct execution and command executors for command-buffer + // recording. + using CommandWalker = absl::FunctionRef; + virtual absl::Status WalkNestedCommands(CommandWalker /*callback*/) { + return absl::OkStatus(); + } + absl::Status WalkNested(Walker callback) override { return absl::OkStatus(); } private: @@ -272,11 +278,8 @@ std::invoke_result_t Command::Walk(F&& callback) { }).IgnoreError(); // Error can never happen here. } else { RETURN_IF_ERROR(callback(this)); - // Adapt Command*-typed callback to Thunk::Walker (Thunk*-typed) for - // WalkNested. The down_cast is safe because WalkNested only visits - // Commands in a Command context. - return WalkNested([&callback](Thunk* thunk) -> absl::Status { - return callback(tsl::down_cast(thunk)); + return WalkNestedCommands([&callback](Command* command) -> absl::Status { + return callback(command); }); } } diff --git a/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.cc b/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.cc index 066a207b51ae85..c16ef2492d2140 100644 --- a/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.cc +++ b/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.cc @@ -41,7 +41,6 @@ limitations under the License. #include "xla/runtime/buffer_use.h" #include "xla/service/buffer_assignment.h" #include "xla/service/gpu/buffer_allocations.h" -#include "xla/service/shaped_slice.h" #include "xla/shape_util.h" #include "xla/status_macros.h" #include "xla/stream_executor/command_buffer.h" @@ -65,19 +64,6 @@ static se::CommandBuffer::CreateCommands CreateCommands( }; } -// Create callbacks to create a command buffer from command sequences. -static std::vector CreateCommands( - absl::Span commands, - const Thunk::ExecuteParams* execute_params, - const Command::RecordParams* record_params) { - std::vector create_commands; - for (const CommandExecutor& cmd : commands) { - create_commands.push_back( - CreateCommands(&cmd, execute_params, record_params)); - } - return create_commands; -} - // Create a callback to update a command buffer with command sequence. static se::CommandBuffer::UpdateCommands UpdateCommands( const CommandExecutor* commands, const Thunk::ExecuteParams* execute_params, @@ -88,19 +74,6 @@ static se::CommandBuffer::UpdateCommands UpdateCommands( }; } -// Create callbacks to update a command buffer with command sequence. -static std::vector UpdateCommands( - absl::Span commands, - const Thunk::ExecuteParams* execute_params, - const Command::RecordParams* record_params) { - std::vector update_commands; - for (const CommandExecutor& cmd : commands) { - update_commands.push_back( - UpdateCommands(&cmd, execute_params, record_params)); - } - return update_commands; -} - //===----------------------------------------------------------------------===// // Command::RecordAction helpers. //===----------------------------------------------------------------------===// @@ -121,79 +94,13 @@ static absl::StatusOr Handle( } if (auto* update = std::get_if(&action)) { - TF_RETURN_IF_ERROR(update_command(update->command)); + RETURN_IF_ERROR(update_command(update->command)); return update->command; } return Internal("Invalid record action"); } -//===----------------------------------------------------------------------===// -// CaseCmd -//===----------------------------------------------------------------------===// - -CaseCmd::CaseCmd(ShapedSlice index, std::vector branches) - : Command(CommandType::kCaseCmd), - index_(index), - index_is_bool_(index.shape.element_type() == PRED), - branches_(std::move(branches)) {} - -absl::Status CaseCmd::Initialize(const Thunk::InitializeParams& params) { - for (auto& branch : branches_) { - TF_RETURN_IF_ERROR(branch.Initialize(params)); - } - return absl::OkStatus(); -} - -absl::StatusOr CaseCmd::Record( - const Thunk::ExecuteParams& execute_params, - const RecordParams& record_params, RecordAction record_action, - se::CommandBuffer* command_buffer) { - se::DeviceAddressBase index = - execute_params.buffer_allocations->GetDeviceAddress(index_.slice); - - VLOG(5) << "CaseCmd:"; - VLOG(5) << " index: " << index_ << " (" << index.opaque() << ")"; - - return Handle( - std::move(record_action), - [&](absl::Span dependencies) { - if (index_is_bool_) { - return command_buffer->CreateCase( - se::DeviceAddress(index), - CreateCommands(branches_, &execute_params, &record_params), - dependencies); - } - return command_buffer->CreateCase( - se::DeviceAddress(index), - CreateCommands(branches_, &execute_params, &record_params), - dependencies); - }, - [&](const se::CommandBuffer::Command* command) { - if (index_is_bool_) { - return command_buffer->UpdateCase( - command, se::DeviceAddress(index), - UpdateCommands(branches_, &execute_params, &record_params)); - } - return command_buffer->UpdateCase( - command, se::DeviceAddress(index), - UpdateCommands(branches_, &execute_params, &record_params)); - }); -} - -Command::BufferUses CaseCmd::buffer_uses() const { - return {BufferUse::Read(index_.slice, index_.shape)}; -} - -absl::Status CaseCmd::WalkNested( - absl::FunctionRef callback) { - for (auto& branch : branches_) { - RETURN_IF_ERROR(branch.Walk( - [&](Command* cmd) -> absl::Status { return callback(cmd); })); - } - return absl::OkStatus(); -} - //===----------------------------------------------------------------------===// // WhileCmd //===----------------------------------------------------------------------===// @@ -209,8 +116,8 @@ WhileCmd::WhileCmd(BufferAllocation::Slice pred, CommandExecutor cond_commands, enable_loop_unroll_(enable_loop_unroll) {} absl::Status WhileCmd::Initialize(const Thunk::InitializeParams& params) { - TF_RETURN_IF_ERROR(cond_commands_.Initialize(params)); - TF_RETURN_IF_ERROR(body_commands_.Initialize(params)); + RETURN_IF_ERROR(cond_commands_.Initialize(params)); + RETURN_IF_ERROR(body_commands_.Initialize(params)); if (enable_loop_unroll_ && body_commands_.support_loop_unroll() && cond_commands_.support_loop_unroll() && trip_count_.has_value()) { is_unrolled_loop_ = true; @@ -224,8 +131,8 @@ absl::Status WhileCmd::Initialize(const Thunk::InitializeParams& params) { } absl::Status WhileCmd::Prepare(const Thunk::PrepareParams& params) { - TF_RETURN_IF_ERROR(cond_commands_.Prepare(params)); - TF_RETURN_IF_ERROR(body_commands_.Prepare(params)); + RETURN_IF_ERROR(cond_commands_.Prepare(params)); + RETURN_IF_ERROR(body_commands_.Prepare(params)); return absl::OkStatus(); } @@ -263,14 +170,14 @@ absl::StatusOr WhileCmd::Record( ScopedWhileLoop loop("record_fn", trip_count_); for (int64_t i = 0; i < *trip_count_; loop.IncLoopIteration(), ++i) { CommandExecutor::RecordId record_id(i); - TF_ASSIGN_OR_RETURN(dependencies, - cond_commands_.RecordCreate( - execute_params, new_record_params, - child_command_buffer, dependencies, record_id)); - TF_ASSIGN_OR_RETURN(dependencies, - body_commands_.RecordCreate( - execute_params, new_record_params, - child_command_buffer, dependencies, record_id)); + ASSIGN_OR_RETURN(dependencies, + cond_commands_.RecordCreate( + execute_params, new_record_params, + child_command_buffer, dependencies, record_id)); + ASSIGN_OR_RETURN(dependencies, + body_commands_.RecordCreate( + execute_params, new_record_params, + child_command_buffer, dependencies, record_id)); } return absl::OkStatus(); @@ -286,10 +193,10 @@ absl::StatusOr WhileCmd::Record( ScopedWhileLoop loop("record_fn", trip_count_); for (int64_t i = 0; i < *trip_count_; loop.IncLoopIteration(), ++i) { CommandExecutor::RecordId record_id(i); - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( cond_commands_.RecordUpdate(execute_params, new_record_params, child_command_buffer, record_id)); - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( body_commands_.RecordUpdate(execute_params, new_record_params, child_command_buffer, record_id)); } @@ -336,4 +243,9 @@ absl::Status WhileCmd::WalkNested( [&](Command* cmd) -> absl::Status { return callback(cmd); }); } +absl::Status WhileCmd::WalkNestedCommands(CommandWalker callback) { + RETURN_IF_ERROR(cond_commands_.Walk(callback)); + return body_commands_.Walk(callback); +} + } // namespace xla::gpu diff --git a/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.h b/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.h index 59c4a89e1a79f2..5feec6515b8256 100644 --- a/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.h +++ b/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.h @@ -28,42 +28,10 @@ limitations under the License. #include "xla/backends/gpu/runtime/command_executor.h" #include "xla/backends/gpu/runtime/thunk.h" #include "xla/service/buffer_assignment.h" -#include "xla/service/shaped_slice.h" #include "xla/stream_executor/command_buffer.h" -#include "xla/xla_data.pb.h" namespace xla::gpu { -//===----------------------------------------------------------------------===// -// CaseCmd -//===----------------------------------------------------------------------===// - -class CaseCmd : public Command { - public: - CaseCmd(ShapedSlice index, std::vector branches); - - absl::Status Initialize(const Thunk::InitializeParams& params) override; - - absl::StatusOr Record( - const Thunk::ExecuteParams& execute_params, - const RecordParams& record_params, RecordAction record_action, - se::CommandBuffer* command_buffer) override; - - BufferUses buffer_uses() const override; - - absl::Status WalkNested( - absl::FunctionRef callback) override; - - private: - ShapedSlice index_; - bool index_is_bool_; - std::vector branches_; -}; - -//===----------------------------------------------------------------------===// -// WhileCmd -//===----------------------------------------------------------------------===// - class WhileCmd : public Command { public: WhileCmd(BufferAllocation::Slice pred, CommandExecutor cond_commands, @@ -86,6 +54,8 @@ class WhileCmd : public Command { absl::FunctionRef callback) override; private: + absl::Status WalkNestedCommands(CommandWalker callback) override; + BufferAllocation::Slice pred_; CommandExecutor cond_commands_; diff --git a/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd_emitter.cc b/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd_emitter.cc index f6eb3940ee8f11..51cca4b999d09d 100644 --- a/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd_emitter.cc +++ b/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd_emitter.cc @@ -103,8 +103,8 @@ static absl::StatusOr> Convert( std::move(body_cmds), thunk.trip_count(), options.enable_loop_unroll); } -static absl::StatusOr> Convert( - const ConditionalThunk& thunk, const ConvertToCommandsOptions& options) { +static absl::Status SetOrUpdateCommandBufferBranchExecutors( + ConditionalThunk& thunk, const ConvertToCommandsOptions& options) { std::vector branch_cmds; branch_cmds.reserve(thunk.branch_executors().size()); if (thunk.branch_index_is_bool()) { @@ -118,14 +118,13 @@ static absl::StatusOr> Convert( branch_cmds.emplace_back(), ConvertToCommands(thunk.branch_executors()[0].thunks(), options)); } else { - for (auto& branch_thunk : thunk.branch_executors()) { + for (const ThunkExecutor& branch_thunk : thunk.branch_executors()) { ASSIGN_OR_RETURN(CommandExecutor cmds, ConvertToCommands(branch_thunk.thunks(), options)); branch_cmds.emplace_back(std::move(cmds)); } } - return std::make_unique(thunk.branch_index_buffer(), - std::move(branch_cmds)); + return thunk.SetOrUpdateCommandBufferBranchExecutors(std::move(branch_cmds)); } //===----------------------------------------------------------------------===// @@ -164,8 +163,13 @@ static absl::Status AppendCommands(ConversionContext& ctx, }; switch (thunk.kind()) { - case Thunk::Kind::kConditional: - return append(Convert(thunk, options)); + case Thunk::Kind::kConditional: { + auto& conditional_thunk = static_cast(thunk); + TF_RETURN_IF_ERROR( + SetOrUpdateCommandBufferBranchExecutors(conditional_thunk, options)); + cmd_sequence.Append(&conditional_thunk); + return absl::OkStatus(); + } case Thunk::Kind::kCopy: cmd_sequence.Append(static_cast(&thunk)); return absl::OkStatus(); @@ -507,7 +511,7 @@ absl::Status AppendCommandsInConcurrentRegions( ConcurrentRegionScheduler& scheduler = concurrent_region_schedules.emplace_back(absl::MakeSpan(thunks)); for (Thunk* thunk : scheduler.scheduled_thunks()) { - TF_RETURN_IF_ERROR(AppendCommands(ctx, cmd_sequence, *thunk, options)); + RETURN_IF_ERROR(AppendCommands(ctx, cmd_sequence, *thunk, options)); int64_t index = cmd_sequence.size() - 1; thunk_to_index[thunk] = index; } diff --git a/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd_emitter_test.cc b/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd_emitter_test.cc index 1b30f15ce89048..3000cc6e6fc490 100644 --- a/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd_emitter_test.cc +++ b/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd_emitter_test.cc @@ -26,18 +26,21 @@ limitations under the License. #include "absl/strings/string_view.h" #include "xla/backends/gpu/runtime/command.h" #include "xla/backends/gpu/runtime/command_executor.h" +#include "xla/backends/gpu/runtime/conditional_thunk.h" #include "xla/backends/gpu/runtime/kernel_thunk.h" #include "xla/backends/gpu/runtime/thunk.h" #include "xla/backends/gpu/runtime/thunk_id.h" #include "xla/codegen/emitters/kernel_arguments.h" #include "xla/runtime/execution_graph.h" #include "xla/service/buffer_assignment.h" +#include "xla/service/shaped_slice.h" #include "xla/shape_util.h" #include "xla/tsl/platform/statusor.h" #include "xla/tsl/platform/test.h" namespace xla::gpu { +using ::testing::ElementsAre; using ::testing::UnorderedElementsAre; MATCHER_P(HasEdgeTo, node_id, "") { return arg.id == node_id; } @@ -405,4 +408,146 @@ TEST_F(CommandBufferCmdEmitterTest, ConcurrentRegionsScheduleHasLaneAffinity) { node_id["e"]); } +TEST_F(CommandBufferCmdEmitterTest, ConvertsConditionalThunkToCommand) { + BufferAllocation branch_index_alloc(/*index=*/0, /*size=*/sizeof(int32_t), + /*color=*/0); + BufferAllocation data_alloc(/*index=*/1, /*size=*/2 * 1024, /*color=*/0); + + BufferAllocation::Slice branch_index_slice(&branch_index_alloc, /*offset=*/0, + /*size=*/sizeof(int32_t)); + BufferAllocation::Slice branch0_slice(&data_alloc, /*offset=*/0, + /*size=*/1024); + BufferAllocation::Slice branch1_slice(&data_alloc, /*offset=*/1024, + /*size=*/1024); + + ThunkSequence branch0; + branch0.push_back(std::make_unique(NextThunkInfo("branch0"), + branch0_slice)); + + ThunkSequence branch1; + branch1.push_back(std::make_unique(NextThunkInfo("branch1"), + branch1_slice)); + + std::vector branches; + branches.push_back(std::move(branch0)); + branches.push_back(std::move(branch1)); + + auto conditional = std::make_unique( + NextThunkInfo("conditional"), + ShapedSlice{branch_index_slice, ShapeUtil::MakeShape(S32, {})}, + std::move(branches)); + ConditionalThunk* conditional_ptr = conditional.get(); + + ThunkSequence thunks; + thunks.push_back(std::move(conditional)); + + ASSERT_OK_AND_ASSIGN(CommandExecutor commands, + ConvertToCommands(thunks, ConvertToCommandsOptions())); + + EXPECT_EQ(conditional_ptr->command_type(), CommandType::kCaseCmd); + + std::vector command_names; + CHECK_OK(commands.Walk([&](Command* command) { + command_names.push_back(std::string(command->profile_annotation())); + return absl::OkStatus(); + })); + EXPECT_THAT(command_names, ElementsAre("conditional", "branch0", "branch1")); +} + +TEST_F(CommandBufferCmdEmitterTest, ConvertsConditionalThunkRepeatedly) { + BufferAllocation branch_index_alloc(/*index=*/0, /*size=*/sizeof(int32_t), + /*color=*/0); + BufferAllocation data_alloc(/*index=*/1, /*size=*/2 * 1024, /*color=*/0); + + BufferAllocation::Slice branch_index_slice(&branch_index_alloc, /*offset=*/0, + /*size=*/sizeof(int32_t)); + BufferAllocation::Slice branch0_slice(&data_alloc, /*offset=*/0, + /*size=*/1024); + BufferAllocation::Slice branch1_slice(&data_alloc, /*offset=*/1024, + /*size=*/1024); + + ThunkSequence branch0; + branch0.push_back(std::make_unique(NextThunkInfo("branch0"), + branch0_slice)); + + ThunkSequence branch1; + branch1.push_back(std::make_unique(NextThunkInfo("branch1"), + branch1_slice)); + + std::vector branches; + branches.push_back(std::move(branch0)); + branches.push_back(std::move(branch1)); + + ThunkSequence thunks; + thunks.push_back(std::make_unique( + NextThunkInfo("conditional"), + ShapedSlice{branch_index_slice, ShapeUtil::MakeShape(S32, {})}, + std::move(branches))); + + auto collect_command_names = [](CommandExecutor& commands) { + std::vector command_names; + CHECK_OK(commands.Walk([&](Command* command) { + command_names.push_back(std::string(command->profile_annotation())); + return absl::OkStatus(); + })); + return command_names; + }; + + ASSERT_OK_AND_ASSIGN(CommandExecutor first_commands, + ConvertToCommands(thunks, ConvertToCommandsOptions())); + EXPECT_THAT(collect_command_names(first_commands), + ElementsAre("conditional", "branch0", "branch1")); + + ConvertToCommandsOptions concurrent_options; + concurrent_options.synchronization_mode = + CommandExecutor::SynchronizationMode::kConcurrent; + ASSERT_OK_AND_ASSIGN(CommandExecutor second_commands, + ConvertToCommands(thunks, concurrent_options)); + ASSERT_TRUE(second_commands.execution_graph().has_value()); + EXPECT_THAT(collect_command_names(second_commands), + ElementsAre("conditional", "branch0", "branch1")); +} + +TEST_F(CommandBufferCmdEmitterTest, + ConvertsBoolConditionalBranchesInCaseOrder) { + BufferAllocation branch_index_alloc(/*index=*/0, /*size=*/sizeof(bool), + /*color=*/0); + BufferAllocation data_alloc(/*index=*/1, /*size=*/2 * 1024, /*color=*/0); + + BufferAllocation::Slice branch_index_slice(&branch_index_alloc, /*offset=*/0, + /*size=*/sizeof(bool)); + BufferAllocation::Slice false_slice(&data_alloc, /*offset=*/0, /*size=*/1024); + BufferAllocation::Slice true_slice(&data_alloc, /*offset=*/1024, + /*size=*/1024); + + ThunkSequence false_branch; + false_branch.push_back(std::make_unique( + NextThunkInfo("false_branch"), false_slice)); + + ThunkSequence true_branch; + true_branch.push_back(std::make_unique( + NextThunkInfo("true_branch"), true_slice)); + + std::vector branches; + branches.push_back(std::move(false_branch)); + branches.push_back(std::move(true_branch)); + + ThunkSequence thunks; + thunks.push_back(std::make_unique( + NextThunkInfo("conditional"), + ShapedSlice{branch_index_slice, ShapeUtil::MakeShape(PRED, {})}, + std::move(branches))); + + ASSERT_OK_AND_ASSIGN(CommandExecutor commands, + ConvertToCommands(thunks, ConvertToCommandsOptions())); + + std::vector command_names; + CHECK_OK(commands.Walk([&](Command* command) { + command_names.push_back(std::string(command->profile_annotation())); + return absl::OkStatus(); + })); + EXPECT_THAT(command_names, + ElementsAre("conditional", "true_branch", "false_branch")); +} + } // namespace xla::gpu diff --git a/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd_test.cc b/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd_test.cc index 5c366f8f845c1f..8cc3142d7eb1c5 100644 --- a/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd_test.cc +++ b/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd_test.cc @@ -30,6 +30,7 @@ limitations under the License. #include "absl/strings/ascii.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/backends/gpu/runtime/command.h" #include "xla/backends/gpu/runtime/command_executor.h" #include "xla/backends/gpu/runtime/command_state.h" @@ -420,7 +421,7 @@ TEST(TracedCommandBuffer, GetOrUpdateCommandBuffer) { // buffer empty. int64_t num_calls = 0; auto trace = [&](se::Stream* stream) -> absl::Status { - TF_RETURN_IF_ERROR(stream->Memset32(&mem, 42, 16)); + RETURN_IF_ERROR(stream->Memset32(&mem, 42, 16)); num_calls++; return absl::OkStatus(); }; diff --git a/third_party/xla/xla/backends/gpu/runtime/command_buffer_thunk.cc b/third_party/xla/xla/backends/gpu/runtime/command_buffer_thunk.cc index 63aa9b83704b8e..91426e6dd168c3 100644 --- a/third_party/xla/xla/backends/gpu/runtime/command_buffer_thunk.cc +++ b/third_party/xla/xla/backends/gpu/runtime/command_buffer_thunk.cc @@ -29,6 +29,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "absl/synchronization/mutex.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/backends/gpu/runtime/command.h" #include "xla/backends/gpu/runtime/command_executor.h" #include "xla/backends/gpu/runtime/sequential_thunk.h" @@ -157,7 +158,7 @@ absl::Status CommandBufferThunk::Prepare(const PrepareParams& params) { // Always prepare thunks if they are present so we are ready to fall back // on them if we detect profiling activity. if (thunks_) { - TF_RETURN_IF_ERROR(thunks_->Prepare(params)); + RETURN_IF_ERROR(thunks_->Prepare(params)); } // TODO(b/290773547): Disabled CUDA graphs when profiling is active because of @@ -170,7 +171,7 @@ absl::Status CommandBufferThunk::Prepare(const PrepareParams& params) { return absl::OkStatus(); } - TF_RETURN_IF_ERROR(commands_.Prepare(params)); + RETURN_IF_ERROR(commands_.Prepare(params)); return absl::OkStatus(); } @@ -183,12 +184,12 @@ absl::Status CommandBufferThunk::Initialize(const InitializeParams& params) { } // Initialize commands. - TF_RETURN_IF_ERROR(commands_.Initialize(params)); + RETURN_IF_ERROR(commands_.Initialize(params)); // Always initialize thunks if they are present so we are ready to fall back // on them if we detect profiling activity. if (thunks_) { - TF_RETURN_IF_ERROR(thunks_->Initialize(params)); + RETURN_IF_ERROR(thunks_->Initialize(params)); } // TODO(b/290773547): Disabled CUDA graphs when profiling is active because of @@ -201,7 +202,7 @@ absl::Status CommandBufferThunk::Initialize(const InitializeParams& params) { return absl::OkStatus(); } - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( std::shared_ptr cmd_buffer, GetOrCreateCommandBuffer(params.executor, *params.buffer_allocations)); absl::MutexLock lock(cmd_buffer->mutex); @@ -265,8 +266,8 @@ absl::Status CommandBufferThunk::Initialize(const InitializeParams& params) { /*is_initialization=*/true, /*command_buffer_update_mode=*/ command_buffer_update_mode_}; - TF_RETURN_IF_ERROR(commands_.Record(execute_params, record_params, - cmd_buffer->command_buffer.get())); + RETURN_IF_ERROR(commands_.Record(execute_params, record_params, + cmd_buffer->command_buffer.get())); uint64_t end_micros = tsl::Env::Default()->NowMicros(); VLOG(3) << "Initialized command buffer on device #" @@ -297,7 +298,7 @@ absl::Status CommandBufferThunk::ExecuteOnStream(const ExecuteParams& params) { } se::StreamExecutor* executor = params.stream->parent(); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( std::shared_ptr cmd_buffer, GetOrCreateCommandBuffer(executor, *params.buffer_allocations)); @@ -306,7 +307,7 @@ absl::Status CommandBufferThunk::ExecuteOnStream(const ExecuteParams& params) { // warm up iteration, run through thunks if they are present. if (!cmd_buffer->warmup_done && thunks_) { VLOG(2) << "Executing warm up iteration of command buffer thunk"; - TF_RETURN_IF_ERROR(thunks_->ExecuteOnStream(params)); + RETURN_IF_ERROR(thunks_->ExecuteOnStream(params)); cmd_buffer->warmup_done = true; return absl::OkStatus(); } @@ -349,8 +350,8 @@ absl::Status CommandBufferThunk::ExecuteOnStream(const ExecuteParams& params) { /*is_initialization=*/is_first_record, /*command_buffer_update_mode=*/ command_buffer_update_mode_}; - TF_RETURN_IF_ERROR(commands_.Record(params, record_params, - cmd_buffer->command_buffer.get())); + RETURN_IF_ERROR(commands_.Record(params, record_params, + cmd_buffer->command_buffer.get())); uint64_t end_micros = tsl::Env::Default()->NowMicros(); XLA_VLOG_DEVICE(3, executor->device_ordinal()) @@ -414,9 +415,8 @@ CommandBufferThunk::GetOrCreateCommandBuffer( } // Create a new empty command buffer. - TF_ASSIGN_OR_RETURN( - auto command_buffer, - executor->CreateCommandBuffer(se::CommandBuffer::Mode::kPrimary)); + ASSIGN_OR_RETURN(auto command_buffer, executor->CreateCommandBuffer( + se::CommandBuffer::Mode::kPrimary)); auto emplaced = state_->command_buffers.emplace( key, std::make_shared(std::move(command_buffer))); // With kNumVaReservationSets=2, at most 2 command buffers should exist per @@ -491,7 +491,7 @@ void CommandBufferThunk::EvictCommandBuffers() { absl::Status CommandBufferThunk::WalkNested(Walker callback) { if (thunks_ != nullptr) { - TF_RETURN_IF_ERROR(thunks_->Walk(callback)); + RETURN_IF_ERROR(thunks_->Walk(callback)); } return absl::OkStatus(); } diff --git a/third_party/xla/xla/backends/gpu/runtime/command_buffer_thunk_test.cc b/third_party/xla/xla/backends/gpu/runtime/command_buffer_thunk_test.cc index e56ee2b6ea603b..9120c8f5e9b7cd 100644 --- a/third_party/xla/xla/backends/gpu/runtime/command_buffer_thunk_test.cc +++ b/third_party/xla/xla/backends/gpu/runtime/command_buffer_thunk_test.cc @@ -30,10 +30,13 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/ascii.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/backends/gpu/codegen/kernels/custom_kernel.h" #include "xla/backends/gpu/runtime/command.h" #include "xla/backends/gpu/runtime/command_buffer_cmd.h" +#include "xla/backends/gpu/runtime/command_buffer_cmd_emitter.h" #include "xla/backends/gpu/runtime/command_executor.h" +#include "xla/backends/gpu/runtime/conditional_thunk.h" #include "xla/backends/gpu/runtime/device_to_device_copy_thunk.h" #include "xla/backends/gpu/runtime/gemm_thunk.h" #include "xla/backends/gpu/runtime/gpublas_lt_matmul_thunk.h" @@ -97,7 +100,7 @@ struct OwningExecutableSource { }; absl::StatusOr ExecutableSource() { - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( std::vector fatbin, se::gpu::GetGpuTestKernelsFatbin(GpuExecutor()->GetPlatform()->Name())); return OwningExecutableSource{/*text=*/{}, @@ -1086,7 +1089,7 @@ TEST(CommandBufferThunkTest, MultipleLaunchCmd) { ASSERT_EQ(dst, std::vector(4, 21 + 21)); } -TEST(CommandBufferThunkTest, CaseCmd) { +TEST(CommandBufferThunkTest, ConditionalThunkCaseCommand) { se::StreamExecutor* stream_executor = GpuExecutor(); if (!IsAtLeastCuda12300(stream_executor)) { @@ -1122,8 +1125,8 @@ TEST(CommandBufferThunkTest, CaseCmd) { BufferAllocation::Slice slice_a(&alloc_a, 0, byte_length); BufferAllocation::Slice slice_b(&alloc_b, 0, byte_length); - // Prepare commands sequence for branches. - std::vector branches_sequence(2); + // Prepare thunk sequences for branches. + std::vector branch_thunks(2); auto args_access = {MemoryAccess::kRead, MemoryAccess::kRead, MemoryAccess::kWrite}; @@ -1131,7 +1134,7 @@ TEST(CommandBufferThunkTest, CaseCmd) { { // Case 0: b = a + a std::vector args{ {slice_a, shape}, {slice_a, shape}, {slice_b, shape}}; - branches_sequence[0].Append(KernelThunk::MakeKernelThunk( + branch_thunks[0].push_back(KernelThunk::MakeKernelThunk( "AddI32", args, args_access, LaunchDimensions(1, 4), /*shmem_bytes=*/0)); } @@ -1139,28 +1142,26 @@ TEST(CommandBufferThunkTest, CaseCmd) { { // Case 1: b = b + b std::vector args{ {slice_b, shape}, {slice_b, shape}, {slice_b, shape}}; - branches_sequence[1].Append(KernelThunk::MakeKernelThunk( + branch_thunks[1].push_back(KernelThunk::MakeKernelThunk( "AddI32", args, args_access, LaunchDimensions(1, 4), /*shmem_bytes=*/0)); } - std::vector branches(2); - TF_ASSERT_OK_AND_ASSIGN( - branches[0], - CommandExecutor::Create(std::move(branches_sequence[0]), serialize)); - TF_ASSERT_OK_AND_ASSIGN( - branches[1], - CommandExecutor::Create(std::move(branches_sequence[1]), serialize)); + // Prepare thunk sequence for command buffer conversion. + ThunkSequence thunks; + thunks.push_back(std::make_unique( + Thunk::ThunkInfo(), ShapedSlice{slice_i, i_shape}, + std::move(branch_thunks))); - // Prepare commands sequence for thunk. - CommandSequence commands; - commands.Emplace(ShapedSlice{slice_i, i_shape}, std::move(branches)); - TF_ASSERT_OK_AND_ASSIGN( - CommandExecutor executor, - CommandExecutor::Create(std::move(commands), serialize)); + ConvertToCommandsOptions options; + options.synchronization_mode = serialize; + ASSERT_OK_AND_ASSIGN(CommandExecutor executor, + ConvertToCommands(thunks, options)); - // Construct a thunk with command sequence. - CommandBufferThunk thunk(std::move(executor), Thunk::ThunkInfo()); + // Construct a command buffer thunk with command sequence and fallback thunks. + CommandBufferThunk thunk( + std::move(executor), Thunk::ThunkInfo(), + std::make_unique(Thunk::ThunkInfo(), std::move(thunks))); ServiceExecutableRunOptions run_options; stream_executor::StreamExecutorAddressAllocator allocator(stream_executor); diff --git a/third_party/xla/xla/backends/gpu/runtime/command_executor.cc b/third_party/xla/xla/backends/gpu/runtime/command_executor.cc index 3d52741e21b24e..27222c1fec7b09 100644 --- a/third_party/xla/xla/backends/gpu/runtime/command_executor.cc +++ b/third_party/xla/xla/backends/gpu/runtime/command_executor.cc @@ -37,6 +37,7 @@ limitations under the License. #include "absl/strings/str_join.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/backends/gpu/runtime/annotation.h" #include "xla/backends/gpu/runtime/command.h" #include "xla/backends/gpu/runtime/command_state.h" @@ -264,11 +265,11 @@ absl::StatusOr CommandExecutor::Create( // sequence of commands and derive the structure of command dependencies // from the buffer use conflicts. if (synchronization_mode != SynchronizationMode::kSerialize) { - TF_ASSIGN_OR_RETURN(auto operations, - CreateCommandOperations(commands, synchronization_mode, - extra_resources)); - TF_ASSIGN_OR_RETURN(execution_graph, - ExecutionGraph::Create(operations)); + ASSIGN_OR_RETURN(auto operations, + CreateCommandOperations(commands, synchronization_mode, + extra_resources)); + ASSIGN_OR_RETURN(execution_graph, + ExecutionGraph::Create(operations)); VLOG(3) << "Execution graph: " << execution_graph->ToString(); } @@ -316,7 +317,7 @@ CommandExecutor::CommandExecutor( absl::Status CommandExecutor::Prepare(const Thunk::PrepareParams& params) { for (auto& command : commands_) { - TF_RETURN_IF_ERROR(command->Prepare(params)); + RETURN_IF_ERROR(command->Prepare(params)); } return absl::OkStatus(); } @@ -324,7 +325,7 @@ absl::Status CommandExecutor::Prepare(const Thunk::PrepareParams& params) { absl::Status CommandExecutor::Initialize( const Thunk::InitializeParams& params) { for (auto& command : commands_) { - TF_RETURN_IF_ERROR(command->Initialize(params)); + RETURN_IF_ERROR(command->Initialize(params)); } return absl::OkStatus(); } @@ -434,17 +435,16 @@ absl::Status CommandExecutor::Record(const Thunk::ExecuteParams& execute_params, se::CommandBuffer* command_buffer, RecordId record_id) { if (command_buffer->state() == se::CommandBuffer::State::kFinalized) { - TF_RETURN_IF_ERROR(command_buffer->Update()); + RETURN_IF_ERROR(command_buffer->Update()); } if (command_buffer->state() == se::CommandBuffer::State::kUpdate) { - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( RecordUpdate(execute_params, record_params, command_buffer, record_id)); } else { - TF_RETURN_IF_ERROR(RecordCreate(execute_params, record_params, - command_buffer, /*dependencies=*/{}, - record_id) - .status()); + RETURN_IF_ERROR(RecordCreate(execute_params, record_params, command_buffer, + /*dependencies=*/{}, record_id) + .status()); } return command_buffer->Finalize(); @@ -458,8 +458,8 @@ CommandExecutor::RecordCreate( absl::Span dependencies, RecordId record_id) const { // Command buffer must be in create state. - TF_RETURN_IF_ERROR(CheckCommandBufferState( - command_buffer, se::CommandBuffer::State::kCreate)); + RETURN_IF_ERROR(CheckCommandBufferState(command_buffer, + se::CommandBuffer::State::kCreate)); VLOG(1) << absl::StreamFormat( "Record create %d commands into command buffer %p: dependencies=%d, " @@ -509,10 +509,9 @@ CommandExecutor::RecordCreate( ? Command::RecordCreate{dependencies} : Command::RecordCreate{command_dependencies}; - TF_ASSIGN_OR_RETURN( - const se::CommandBuffer::Command* recorded_command, - command->Record(execute_params, record_params, std::move(record_action), - command_buffer)); + ASSIGN_OR_RETURN(const se::CommandBuffer::Command* recorded_command, + command->Record(execute_params, record_params, + std::move(record_action), command_buffer)); // Collect sink commands as external dependencies for the next command // sequence recorded into the same command buffer. @@ -546,8 +545,8 @@ absl::Status CommandExecutor::RecordUpdate( uint64_t start_micros = tsl::Env::Default()->NowMicros(); // Command buffer must be already prepared for recording updates. - TF_RETURN_IF_ERROR(CheckCommandBufferState( - command_buffer, se::CommandBuffer::State::kUpdate)); + RETURN_IF_ERROR(CheckCommandBufferState(command_buffer, + se::CommandBuffer::State::kUpdate)); // Short-circuit if there are no commands to update. if (commands_.empty()) { @@ -640,10 +639,9 @@ absl::Status CommandExecutor::RecordUpdate( } Command::RecordUpdate record_action{recorded_commands[id]}; - TF_ASSIGN_OR_RETURN( - recorded_commands[id], - command->Record(execute_params, record_params, std::move(record_action), - command_buffer)); + ASSIGN_OR_RETURN(recorded_commands[id], + command->Record(execute_params, record_params, + std::move(record_action), command_buffer)); } uint64_t end_micros = tsl::Env::Default()->NowMicros(); @@ -740,9 +738,9 @@ absl::StatusOr CommandExecutor::RenderExecutionGraph() const { "concurrent/LHS synchronization mode"); } - TF_ASSIGN_OR_RETURN(auto operations, - CreateCommandOperations(commands_, synchronization_mode_, - extra_resources_)); + ASSIGN_OR_RETURN(auto operations, + CreateCommandOperations(commands_, synchronization_mode_, + extra_resources_)); absl::InlinedVector operations_ptrs; operations_ptrs.reserve(operations.size()); for (const auto& operation : operations) { diff --git a/third_party/xla/xla/backends/gpu/runtime/conditional_thunk.cc b/third_party/xla/xla/backends/gpu/runtime/conditional_thunk.cc index 117d268bc4206f..aa527f711cf294 100644 --- a/third_party/xla/xla/backends/gpu/runtime/conditional_thunk.cc +++ b/third_party/xla/xla/backends/gpu/runtime/conditional_thunk.cc @@ -39,16 +39,92 @@ limitations under the License. #include "xla/service/buffer_assignment.h" #include "xla/service/shaped_slice.h" #include "xla/status_macros.h" +#include "xla/stream_executor/command_buffer.h" #include "xla/stream_executor/device_address.h" #include "xla/util.h" #include "xla/xla_data.pb.h" namespace xla::gpu { +namespace { + +// Create a callback to create a command buffer from a command sequence. +se::CommandBuffer::CreateCommands CreateCommands( + const CommandExecutor* commands, const Thunk::ExecuteParams* execute_params, + const Command::RecordParams* record_params) { + return [=](se::CommandBuffer* command_buffer, + absl::Span dependencies) { + return commands->RecordCreate(*execute_params, *record_params, + command_buffer, dependencies); + }; +} + +// Create callbacks to create command buffers from command sequences. +std::vector CreateCommands( + absl::Span commands, + const Thunk::ExecuteParams* execute_params, + const Command::RecordParams* record_params) { + std::vector create_commands; + create_commands.reserve(commands.size()); + for (const CommandExecutor& cmd : commands) { + create_commands.push_back( + CreateCommands(&cmd, execute_params, record_params)); + } + return create_commands; +} + +// Create a callback to update a command buffer with a command sequence. +se::CommandBuffer::UpdateCommands UpdateCommands( + const CommandExecutor* commands, const Thunk::ExecuteParams* execute_params, + const Command::RecordParams* record_params) { + return [=](se::CommandBuffer* command_buffer) { + return commands->RecordUpdate(*execute_params, *record_params, + command_buffer); + }; +} + +// Create callbacks to update command buffers with command sequences. +std::vector UpdateCommands( + absl::Span commands, + const Thunk::ExecuteParams* execute_params, + const Command::RecordParams* record_params) { + std::vector update_commands; + update_commands.reserve(commands.size()); + for (const CommandExecutor& cmd : commands) { + update_commands.push_back( + UpdateCommands(&cmd, execute_params, record_params)); + } + return update_commands; +} + +using CreateCommand = + absl::FunctionRef( + absl::Span dependencies)>; + +using UpdateCommand = + absl::FunctionRef; + +absl::StatusOr HandleRecordAction( + Command::RecordAction action, CreateCommand create_command, + UpdateCommand update_command) { + if (auto* create = std::get_if(&action)) { + return create_command(create->dependencies); + } + + if (auto* update = std::get_if(&action)) { + TF_RETURN_IF_ERROR(update_command(update->command)); + return update->command; + } + + return Internal("Invalid record action"); +} + +} // namespace + ConditionalThunk::ConditionalThunk(ThunkInfo thunk_info, const ShapedSlice& branch_index_buffer_index, std::vector branch_thunks) - : Thunk(Kind::kConditional, thunk_info), + : Command(CommandType::kCaseCmd, Kind::kConditional, std::move(thunk_info)), branch_index_buffer_index_(branch_index_buffer_index), branch_index_is_bool_(branch_index_buffer_index.shape.element_type() == PRED) { @@ -71,6 +147,11 @@ absl::Status ConditionalThunk::Prepare(const PrepareParams& params) { for (auto& branch_executor : branch_executors_) { RETURN_IF_ERROR(branch_executor.Prepare(params)); } + if (command_branch_executors_.has_value()) { + for (CommandExecutor& branch_executor : *command_branch_executors_) { + RETURN_IF_ERROR(branch_executor.Prepare(params)); + } + } return absl::OkStatus(); } @@ -83,6 +164,11 @@ absl::Status ConditionalThunk::Initialize(const InitializeParams& params) { for (auto& branch_executor : branch_executors_) { RETURN_IF_ERROR(branch_executor.Initialize(params)); } + if (command_branch_executors_.has_value()) { + for (CommandExecutor& branch_executor : *command_branch_executors_) { + RETURN_IF_ERROR(branch_executor.Initialize(params)); + } + } absl::MutexLock lock(mutex_); @@ -97,6 +183,64 @@ absl::Status ConditionalThunk::Initialize(const InitializeParams& params) { return absl::OkStatus(); } +absl::StatusOr ConditionalThunk::Record( + const Thunk::ExecuteParams& execute_params, + const RecordParams& record_params, RecordAction record_action, + se::CommandBuffer* command_buffer) { + if (!command_branch_executors_.has_value()) { + return FailedPrecondition( + "ConditionalThunk command-buffer branches are not initialized"); + } + + se::DeviceAddressBase branch_index = + execute_params.buffer_allocations->GetDeviceAddress( + branch_index_buffer_index_.slice); + + VLOG(5) << "ConditionalThunk::Record:"; + VLOG(5) << " branch_index: " << branch_index_buffer_index_ << " (" + << branch_index.opaque() << ")"; + + absl::Span command_branches = + absl::MakeConstSpan(*command_branch_executors_); + + return HandleRecordAction( + std::move(record_action), + [&](absl::Span dependencies) { + if (branch_index_is_bool_) { + return command_buffer->CreateCase( + se::DeviceAddress(branch_index), + CreateCommands(command_branches, &execute_params, &record_params), + dependencies); + } + return command_buffer->CreateCase( + se::DeviceAddress(branch_index), + CreateCommands(command_branches, &execute_params, &record_params), + dependencies); + }, + [&](const se::CommandBuffer::Command* command) { + if (branch_index_is_bool_) { + return command_buffer->UpdateCase( + command, se::DeviceAddress(branch_index), + UpdateCommands(command_branches, &execute_params, + &record_params)); + } + return command_buffer->UpdateCase( + command, se::DeviceAddress(branch_index), + UpdateCommands(command_branches, &execute_params, &record_params)); + }); +} + +absl::Status ConditionalThunk::SetOrUpdateCommandBufferBranchExecutors( + std::vector branch_executors) { + if (branch_index_is_bool_) { + TF_RET_CHECK(branch_executors.size() == 2); + } else { + TF_RET_CHECK(!branch_executors.empty()); + } + command_branch_executors_ = std::move(branch_executors); + return absl::OkStatus(); +} + absl::Status ConditionalThunk::ExecuteOnStream(const ExecuteParams& params) { auto& stream = *params.stream; @@ -162,6 +306,16 @@ absl::Status ConditionalThunk::WalkNested(Walker callback) { return absl::OkStatus(); } +absl::Status ConditionalThunk::WalkNestedCommands(CommandWalker callback) { + if (!command_branch_executors_.has_value()) { + return absl::OkStatus(); + } + for (CommandExecutor& branch_executor : *command_branch_executors_) { + RETURN_IF_ERROR(branch_executor.Walk(callback)); + } + return absl::OkStatus(); +} + absl::Status ConditionalThunk::TransformNested(Transformer callback) { for (ThunkExecutor& branch_executor : branch_executors_) { RETURN_IF_ERROR(branch_executor.thunks().TransformNested(callback)); diff --git a/third_party/xla/xla/backends/gpu/runtime/conditional_thunk.h b/third_party/xla/xla/backends/gpu/runtime/conditional_thunk.h index e9694ae5fd31fe..3f05bc710e8b36 100644 --- a/third_party/xla/xla/backends/gpu/runtime/conditional_thunk.h +++ b/third_party/xla/xla/backends/gpu/runtime/conditional_thunk.h @@ -17,6 +17,7 @@ limitations under the License. #define XLA_BACKENDS_GPU_RUNTIME_CONDITIONAL_THUNK_H_ #include +#include #include #include @@ -26,6 +27,8 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/synchronization/mutex.h" #include "absl/types/span.h" +#include "xla/backends/gpu/runtime/command.h" +#include "xla/backends/gpu/runtime/command_executor.h" #include "xla/backends/gpu/runtime/host_memory_pool.h" #include "xla/backends/gpu/runtime/thunk.h" #include "xla/backends/gpu/runtime/thunk.pb.h" @@ -47,7 +50,7 @@ namespace xla::gpu { // instruction of the true computation share the same allocation. Similarly, the // buffers of the false operand and that of the parameter instruction of the // false computation share the same allocation. -class ConditionalThunk : public Thunk { +class ConditionalThunk : public Command { public: ConditionalThunk(ThunkInfo thunk_info, const ShapedSlice& branch_index_buffer_index, @@ -59,6 +62,13 @@ class ConditionalThunk : public Thunk { absl::Status Prepare(const PrepareParams& params) override; absl::Status Initialize(const InitializeParams& params) override; absl::Status ExecuteOnStream(const ExecuteParams& params) override; + absl::StatusOr Record( + const Thunk::ExecuteParams& execute_params, + const RecordParams& record_params, RecordAction record_action, + se::CommandBuffer* command_buffer) override; + + absl::Status SetOrUpdateCommandBufferBranchExecutors( + std::vector branch_executors); absl::Span branch_executors() const { return branch_executors_; @@ -102,8 +112,11 @@ class ConditionalThunk : public Thunk { std::string ToString(int indent) const override; private: + absl::Status WalkNestedCommands(CommandWalker callback) override; + const ShapedSlice branch_index_buffer_index_; std::vector branch_executors_; + std::optional> command_branch_executors_; bool branch_index_is_bool_; // Host memory pool for transferring predicate value from device to host. diff --git a/third_party/xla/xla/backends/gpu/runtime/conditional_thunk_test.cc b/third_party/xla/xla/backends/gpu/runtime/conditional_thunk_test.cc index e344463b26bf9c..8e4848fb722c9d 100644 --- a/third_party/xla/xla/backends/gpu/runtime/conditional_thunk_test.cc +++ b/third_party/xla/xla/backends/gpu/runtime/conditional_thunk_test.cc @@ -15,9 +15,12 @@ limitations under the License. #include "xla/backends/gpu/runtime/conditional_thunk.h" +#include +#include #include #include #include +#include #include #include @@ -25,14 +28,23 @@ limitations under the License. #include "absl/log/check.h" #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "xla/tsl/platform/status_macros.h" #include "google/protobuf/text_format.h" +#include "xla/backends/gpu/runtime/command.h" +#include "xla/backends/gpu/runtime/command_executor.h" +#include "xla/backends/gpu/runtime/command_state.h" #include "xla/backends/gpu/runtime/thunk.h" #include "xla/backends/gpu/runtime/thunk.pb.h" #include "xla/backends/gpu/runtime/thunk_executor.h" #include "xla/service/buffer_assignment.h" +#include "xla/service/gpu/buffer_allocations.h" +#include "xla/service/service_executable_run_options.h" #include "xla/service/shaped_slice.h" #include "xla/shape.h" #include "xla/shape_util.h" +#include "xla/stream_executor/command_buffer.h" +#include "xla/stream_executor/device_address.h" +#include "xla/stream_executor/mock_command_buffer.h" #include "xla/tsl/lib/core/status_test_util.h" #include "xla/tsl/platform/statusor.h" #include "xla/tsl/platform/test.h" @@ -57,8 +69,8 @@ struct DummyThunk : public Thunk { } static absl::StatusOr> FromProto( const ThunkProto& thunk_proto, Thunk::Kind kind) { - TF_ASSIGN_OR_RETURN(Thunk::ThunkInfo thunk_info, - Thunk::ThunkInfo::FromProto(thunk_proto.thunk_info())); + ASSIGN_OR_RETURN(Thunk::ThunkInfo thunk_info, + Thunk::ThunkInfo::FromProto(thunk_proto.thunk_info())); return std::make_unique(kind, std::move(thunk_info)); } @@ -77,6 +89,90 @@ std::unique_ptr CreateConditionalThunk( thunk_info, branch_index_buffer_index, std::move(branch_thunk_sequences)); } +struct BranchRecordCounts { + int prepares = 0; + int creates = 0; + int updates = 0; +}; + +struct FakeSeCommand : public se::CommandBuffer::Command {}; + +class BranchRecordingCommand : public Command { + public: + explicit BranchRecordingCommand(BranchRecordCounts* counts) + : Command(CommandType::kUnknownCmd), counts_(counts) {} + + absl::Status Prepare(const PrepareParams&) override { + ++counts_->prepares; + return absl::OkStatus(); + } + + absl::StatusOr Record( + const Thunk::ExecuteParams&, const RecordParams&, RecordAction action, + se::CommandBuffer*) override { + if (std::get_if(&action) != nullptr) { + ++counts_->creates; + return &recorded_command_; + } + + auto* update = std::get_if(&action); + if (update == nullptr) { + return absl::InternalError("unexpected record action"); + } + if (update->command != &recorded_command_) { + return absl::InternalError("unexpected recorded command"); + } + ++counts_->updates; + return &recorded_command_; + } + + BufferUses buffer_uses() const override { return {}; } + + private: + BranchRecordCounts* counts_; + FakeSeCommand recorded_command_; +}; + +absl::StatusOr MakeBranchExecutor(BranchRecordCounts* counts) { + CommandSequence commands; + commands.Append(std::make_unique(counts)); + return CommandExecutor::Create( + std::move(commands), CommandExecutor::SynchronizationMode::kSerialize); +} + +struct BranchCommandBuffer { + std::unique_ptr> command_buffer = + std::make_unique>(); + se::CommandBuffer::State state = se::CommandBuffer::State::kCreate; +}; + +void ConfigureNestedCommandBuffer(BranchCommandBuffer* branch) { + using Mode = se::CommandBuffer::Mode; + using State = se::CommandBuffer::State; + + ON_CALL(*branch->command_buffer, mode()) + .WillByDefault(testing::Return(Mode::kNested)); + ON_CALL(*branch->command_buffer, state()).WillByDefault([branch] { + return branch->state; + }); + ON_CALL(*branch->command_buffer, Finalize()).WillByDefault([branch] { + if (branch->state != State::kCreate && branch->state != State::kUpdate) { + return absl::FailedPreconditionError( + "command buffer is not in create/update state"); + } + branch->state = State::kFinalized; + return absl::OkStatus(); + }); + ON_CALL(*branch->command_buffer, Update()).WillByDefault([branch] { + if (branch->state != State::kFinalized) { + return absl::FailedPreconditionError( + "command buffer is not in finalized state"); + } + branch->state = State::kUpdate; + return absl::OkStatus(); + }); +} + TEST(ConditionalThunkTest, BufferUses) { Thunk::ThunkInfo thunk_info; thunk_info.profile_annotation = "profile_annotation"; @@ -112,6 +208,177 @@ TEST(ConditionalThunkTest, BufferUses) { ElementsAre(branch_matcher, branch_matcher)); } +TEST(ConditionalThunkTest, PreparePropagatesToCommandBufferBranchExecutors) { + BufferAllocation branch_index_alloc(/*index=*/0, /*size=*/sizeof(int32_t), + /*color=*/0); + BufferAllocation::Slice branch_index_slice(&branch_index_alloc, /*offset=*/0, + /*size=*/sizeof(int32_t)); + + std::vector branch_thunks(2); + ConditionalThunk thunk( + Thunk::ThunkInfo(), + ShapedSlice{branch_index_slice, ShapeUtil::MakeShape(S32, {})}, + std::move(branch_thunks)); + + BranchRecordCounts branch0_counts; + BranchRecordCounts branch1_counts; + std::vector branch_executors; + ASSERT_OK_AND_ASSIGN(CommandExecutor branch0_executor, + MakeBranchExecutor(&branch0_counts)); + ASSERT_OK_AND_ASSIGN(CommandExecutor branch1_executor, + MakeBranchExecutor(&branch1_counts)); + branch_executors.push_back(std::move(branch0_executor)); + branch_executors.push_back(std::move(branch1_executor)); + ASSERT_OK(thunk.SetOrUpdateCommandBufferBranchExecutors( + std::move(branch_executors))); + + Thunk::PrepareParams prepare_params; + ASSERT_OK(thunk.Prepare(prepare_params)); + + EXPECT_EQ(branch0_counts.prepares, 1); + EXPECT_EQ(branch1_counts.prepares, 1); +} + +TEST(ConditionalThunkTest, RecordCreatesAndUpdatesCommandBufferCase) { + BufferAllocation branch_index_alloc(/*index=*/0, /*size=*/sizeof(int32_t), + /*color=*/0); + BufferAllocation::Slice branch_index_slice(&branch_index_alloc, /*offset=*/0, + /*size=*/sizeof(int32_t)); + + std::vector branch_thunks(2); + ConditionalThunk thunk( + Thunk::ThunkInfo(), + ShapedSlice{branch_index_slice, ShapeUtil::MakeShape(S32, {})}, + std::move(branch_thunks)); + + BranchRecordCounts branch0_counts; + BranchRecordCounts branch1_counts; + std::vector branch_executors; + ASSERT_OK_AND_ASSIGN(CommandExecutor branch0_executor, + MakeBranchExecutor(&branch0_counts)); + ASSERT_OK_AND_ASSIGN(CommandExecutor branch1_executor, + MakeBranchExecutor(&branch1_counts)); + branch_executors.push_back(std::move(branch0_executor)); + branch_executors.push_back(std::move(branch1_executor)); + ASSERT_OK(thunk.SetOrUpdateCommandBufferBranchExecutors( + std::move(branch_executors))); + + int32_t branch_index = 0; + std::vector buffers = { + se::DeviceAddressBase(&branch_index, sizeof(branch_index))}; + BufferAllocations allocations(buffers, /*device_ordinal=*/0, + /*memory_allocator=*/nullptr); + ServiceExecutableRunOptions run_options; + Thunk::ExecuteParams execute_params = Thunk::ExecuteParams::Create( + run_options, allocations, /*stream=*/nullptr, + /*command_buffer_trace_stream=*/nullptr, /*collective_params=*/nullptr, + /*collective_cliques=*/nullptr, /*collective_memory=*/nullptr); + + CommandStateManager state_manager; + Command::RecordParams record_params{state_manager}; + testing::NiceMock command_buffer; + FakeSeCommand case_se_command; + FakeSeCommand dependency_command; + std::vector> branch_command_buffers; + int create_case_calls = 0; + int update_case_calls = 0; + int create_case_branch_count = 0; + int update_case_branch_count = 0; + int create_case_dependency_count = 0; + + std::vector dependencies = { + &dependency_command}; + EXPECT_CALL(command_buffer, + CreateCase(testing::A>(), testing::_, + testing::_)) + .WillOnce( + [&](se::DeviceAddress, + std::vector create_branches, + absl::Span + create_dependencies) + -> absl::StatusOr { + ++create_case_calls; + create_case_branch_count = create_branches.size(); + create_case_dependency_count = create_dependencies.size(); + if (create_dependencies.size() != dependencies.size()) { + return absl::InternalError("unexpected dependency count"); + } + for (size_t i = 0; i < dependencies.size(); ++i) { + if (create_dependencies[i] != dependencies[i]) { + return absl::InternalError("unexpected dependency"); + } + } + + branch_command_buffers.clear(); + branch_command_buffers.reserve(create_branches.size()); + for (se::CommandBuffer::CreateCommands& create_branch : + create_branches) { + auto branch = std::make_unique(); + ConfigureNestedCommandBuffer(branch.get()); + TF_RETURN_IF_ERROR(create_branch(branch->command_buffer.get(), + /*dependencies=*/{}) + .status()); + TF_RETURN_IF_ERROR(branch->command_buffer->Finalize()); + branch_command_buffers.push_back(std::move(branch)); + } + return &case_se_command; + }); + + ASSERT_OK_AND_ASSIGN( + const se::CommandBuffer::Command* case_command, + thunk.Record(execute_params, record_params, + Command::RecordCreate{dependencies}, &command_buffer)); + + EXPECT_EQ(case_command, &case_se_command); + EXPECT_EQ(create_case_calls, 1); + EXPECT_EQ(create_case_branch_count, 2); + EXPECT_EQ(create_case_dependency_count, 1); + EXPECT_EQ(branch0_counts.creates, 1); + EXPECT_EQ(branch1_counts.creates, 1); + EXPECT_EQ(branch0_counts.updates, 0); + EXPECT_EQ(branch1_counts.updates, 0); + + EXPECT_CALL(command_buffer, + UpdateCase(&case_se_command, + testing::A>(), testing::_)) + .WillOnce( + [&](const se::CommandBuffer::Command* command, + se::DeviceAddress, + std::vector update_branches) + -> absl::Status { + ++update_case_calls; + update_case_branch_count = update_branches.size(); + if (command != &case_se_command) { + return absl::InternalError("unexpected case command"); + } + if (update_branches.size() != branch_command_buffers.size()) { + return absl::InternalError("unexpected branch count"); + } + for (size_t i = 0; i < update_branches.size(); ++i) { + TF_RETURN_IF_ERROR( + branch_command_buffers[i]->command_buffer->Update()); + TF_RETURN_IF_ERROR(update_branches[i]( + branch_command_buffers[i]->command_buffer.get())); + TF_RETURN_IF_ERROR( + branch_command_buffers[i]->command_buffer->Finalize()); + } + return absl::OkStatus(); + }); + + ASSERT_OK_AND_ASSIGN( + const se::CommandBuffer::Command* updated_case_command, + thunk.Record(execute_params, record_params, + Command::RecordUpdate{case_command}, &command_buffer)); + + EXPECT_EQ(updated_case_command, case_command); + EXPECT_EQ(update_case_calls, 1); + EXPECT_EQ(update_case_branch_count, 2); + EXPECT_EQ(branch0_counts.creates, 1); + EXPECT_EQ(branch1_counts.creates, 1); + EXPECT_EQ(branch0_counts.updates, 1); + EXPECT_EQ(branch1_counts.updates, 1); +} + TEST(ConditionalThunkTest, ToProto) { Thunk::ThunkInfo thunk_info; thunk_info.profile_annotation = "profile_annotation"; diff --git a/third_party/xla/xla/backends/gpu/runtime/convolution_reorder_thunk.cc b/third_party/xla/xla/backends/gpu/runtime/convolution_reorder_thunk.cc index 872cce67cfda74..eb204d492c4560 100644 --- a/third_party/xla/xla/backends/gpu/runtime/convolution_reorder_thunk.cc +++ b/third_party/xla/xla/backends/gpu/runtime/convolution_reorder_thunk.cc @@ -23,6 +23,7 @@ limitations under the License. #include "absl/log/check.h" #include "absl/status/status.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/backends/gpu/runtime/convolution_filter_thunk.pb.h" #include "xla/backends/gpu/runtime/shaped_slice.h" #include "xla/backends/gpu/runtime/thunk.h" @@ -112,21 +113,21 @@ absl::StatusOr> ConvolutionReorderThunk::FromProto( ThunkInfo thunk_info, const ConvolutionReorderThunkProto& proto, absl::Span buffer_allocations) { - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( ShapedSlice filter_input, ShapedSlice::FromProto(proto.filter_input(), buffer_allocations)); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( ShapedSlice filter_output, ShapedSlice::FromProto(proto.filter_output(), buffer_allocations)); std::optional biases; if (proto.has_biases()) { - TF_ASSIGN_OR_RETURN(ShapedSlice bias_input, - ShapedSlice::FromProto(proto.biases().bias_input(), - buffer_allocations)); - TF_ASSIGN_OR_RETURN(ShapedSlice bias_output, - ShapedSlice::FromProto(proto.biases().bias_output(), - buffer_allocations)); + ASSIGN_OR_RETURN(ShapedSlice bias_input, + ShapedSlice::FromProto(proto.biases().bias_input(), + buffer_allocations)); + ASSIGN_OR_RETURN(ShapedSlice bias_output, + ShapedSlice::FromProto(proto.biases().bias_output(), + buffer_allocations)); biases = {{bias_input, bias_output}}; } @@ -141,18 +142,18 @@ absl::StatusOr ConvolutionReorderThunk::ToProto() const { ConvolutionReorderThunkProto* reorder_proto = thunk_proto.mutable_convolution_reorder_thunk(); - TF_ASSIGN_OR_RETURN(*reorder_proto->mutable_filter_input(), - filter_input_.ToProto()); - TF_ASSIGN_OR_RETURN(*reorder_proto->mutable_filter_output(), - filter_output_.ToProto()); + ASSIGN_OR_RETURN(*reorder_proto->mutable_filter_input(), + filter_input_.ToProto()); + ASSIGN_OR_RETURN(*reorder_proto->mutable_filter_output(), + filter_output_.ToProto()); if (biases_.has_value()) { ConvolutionReorderBiasBuffers* biases_proto = reorder_proto->mutable_biases(); - TF_ASSIGN_OR_RETURN(*biases_proto->mutable_bias_input(), - biases_->bias_input.ToProto()); - TF_ASSIGN_OR_RETURN(*biases_proto->mutable_bias_output(), - biases_->bias_output.ToProto()); + ASSIGN_OR_RETURN(*biases_proto->mutable_bias_input(), + biases_->bias_input.ToProto()); + ASSIGN_OR_RETURN(*biases_proto->mutable_bias_output(), + biases_->bias_output.ToProto()); } return thunk_proto; diff --git a/third_party/xla/xla/backends/gpu/runtime/convolution_thunk.cc b/third_party/xla/xla/backends/gpu/runtime/convolution_thunk.cc index 14860c0d5f07be..f37bccdbfc4170 100644 --- a/third_party/xla/xla/backends/gpu/runtime/convolution_thunk.cc +++ b/third_party/xla/xla/backends/gpu/runtime/convolution_thunk.cc @@ -24,6 +24,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/synchronization/mutex.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/backends/gpu/runtime/shaped_slice.h" #include "xla/backends/gpu/runtime/thunk.h" #include "xla/service/buffer_assignment.h" @@ -49,8 +50,8 @@ absl::StatusOr> ConvolutionThunk::Create( std::vector operand_slices, std::vector result_slices, BufferAllocation::Slice scratch_slice) { - TF_ASSIGN_OR_RETURN(GpuConvConfig config, - GetGpuConvConfig(descriptor, /*inst_as_string=*/"")); + ASSIGN_OR_RETURN(GpuConvConfig config, + GetGpuConvConfig(descriptor, /*inst_as_string=*/"")); // Can't use std::make_unique because the constructor is private. return absl::WrapUnique(new ConvolutionThunk( @@ -111,9 +112,9 @@ absl::Status ConvolutionThunk::ExecuteOnStream(const ExecuteParams& params) { << " addr: " << scratch.opaque(); auto opts = GetOrCreate(config_, params.stream); - TF_RETURN_IF_ERROR(RunGpuConv(config_, absl::MakeSpan(operand_se_buffers), - absl::MakeSpan(result_se_buffers), scratch, - params.stream, opts)); + RETURN_IF_ERROR(RunGpuConv(config_, absl::MakeSpan(operand_se_buffers), + absl::MakeSpan(result_se_buffers), scratch, + params.stream, opts)); // Note: Convolution has a tuple buffer as an output, but we don't need to // populate it as no one should be reading from the tuple directly. @@ -141,28 +142,26 @@ Thunk::BufferUses ConvolutionThunk::buffer_uses() const { absl::StatusOr> ConvolutionThunk::FromProto( ThunkInfo thunk_info, const ConvolutionThunkProto& proto, absl::Span buffer_allocations) { - TF_ASSIGN_OR_RETURN(GpuConvDescriptor descriptor, - GpuConvDescriptor::FromProto(proto.conv_descriptor())); + ASSIGN_OR_RETURN(GpuConvDescriptor descriptor, + GpuConvDescriptor::FromProto(proto.conv_descriptor())); std::vector operand_slices; operand_slices.reserve(proto.operand_buffers_size()); for (const ShapedSliceProto& slice_proto : proto.operand_buffers()) { - TF_ASSIGN_OR_RETURN( - operand_slices.emplace_back(), - ShapedSlice::FromProto(slice_proto, buffer_allocations)); + ASSIGN_OR_RETURN(operand_slices.emplace_back(), + ShapedSlice::FromProto(slice_proto, buffer_allocations)); } std::vector result_slices; result_slices.reserve(proto.result_buffers_size()); for (const ShapedSliceProto& slice_proto : proto.result_buffers()) { - TF_ASSIGN_OR_RETURN( - result_slices.emplace_back(), - ShapedSlice::FromProto(slice_proto, buffer_allocations)); + ASSIGN_OR_RETURN(result_slices.emplace_back(), + ShapedSlice::FromProto(slice_proto, buffer_allocations)); } - TF_ASSIGN_OR_RETURN(BufferAllocation::Slice scratch_slice, - BufferAllocation::Slice::FromProto(proto.scratch_buffer(), - buffer_allocations)); + ASSIGN_OR_RETURN(BufferAllocation::Slice scratch_slice, + BufferAllocation::Slice::FromProto(proto.scratch_buffer(), + buffer_allocations)); return Create(std::move(thunk_info), std::move(descriptor), std::move(operand_slices), std::move(result_slices), @@ -177,13 +176,13 @@ absl::StatusOr ConvolutionThunk::ToProto() const { *conv_proto->mutable_conv_descriptor() = descriptor_.ToProto(); for (const ShapedSlice& slice : operand_buffers_) { - TF_ASSIGN_OR_RETURN(*conv_proto->add_operand_buffers(), slice.ToProto()); + ASSIGN_OR_RETURN(*conv_proto->add_operand_buffers(), slice.ToProto()); } for (const ShapedSlice& slice : result_buffers_) { - TF_ASSIGN_OR_RETURN(*conv_proto->add_result_buffers(), slice.ToProto()); + ASSIGN_OR_RETURN(*conv_proto->add_result_buffers(), slice.ToProto()); } - TF_ASSIGN_OR_RETURN(*conv_proto->mutable_scratch_buffer(), - scratch_buffer_.ToProto()); + ASSIGN_OR_RETURN(*conv_proto->mutable_scratch_buffer(), + scratch_buffer_.ToProto()); return proto; } diff --git a/third_party/xla/xla/backends/gpu/runtime/copy_thunk.cc b/third_party/xla/xla/backends/gpu/runtime/copy_thunk.cc index 465ae661a31c68..b248a6b7232d23 100644 --- a/third_party/xla/xla/backends/gpu/runtime/copy_thunk.cc +++ b/third_party/xla/xla/backends/gpu/runtime/copy_thunk.cc @@ -23,6 +23,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/backends/gpu/runtime/thunk.h" #include "xla/backends/gpu/runtime/thunk.pb.h" #include "xla/service/buffer_assignment.h" @@ -55,10 +56,10 @@ absl::StatusOr CopyThunk::ToProto() const { *proto.mutable_thunk_info() = thunk_info().ToProto(); CopyThunkProto* copy_thunk_proto = proto.mutable_copy_thunk(); - TF_ASSIGN_OR_RETURN(*copy_thunk_proto->mutable_source_buffer(), - source_buffer_.ToProto()); - TF_ASSIGN_OR_RETURN(*copy_thunk_proto->mutable_destination_buffer(), - destination_buffer_.ToProto()); + ASSIGN_OR_RETURN(*copy_thunk_proto->mutable_source_buffer(), + source_buffer_.ToProto()); + ASSIGN_OR_RETURN(*copy_thunk_proto->mutable_destination_buffer(), + destination_buffer_.ToProto()); copy_thunk_proto->set_mem_size(size_bytes()); return proto; } @@ -66,12 +67,12 @@ absl::StatusOr CopyThunk::ToProto() const { absl::StatusOr> CopyThunk::FromProto( ThunkInfo thunk_info, const CopyThunkProto& thunk_proto, absl::Span buffer_allocations) { - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( ShapedSlice src_slice, ShapedSlice::FromProto(thunk_proto.source_buffer(), buffer_allocations)); - TF_ASSIGN_OR_RETURN(ShapedSlice dst_slice, - ShapedSlice::FromProto(thunk_proto.destination_buffer(), - buffer_allocations)); + ASSIGN_OR_RETURN(ShapedSlice dst_slice, + ShapedSlice::FromProto(thunk_proto.destination_buffer(), + buffer_allocations)); if (ShapeUtil::ByteSizeOfElements(src_slice.shape) != ShapeUtil::ByteSizeOfElements(dst_slice.shape)) { return absl::FailedPreconditionError( diff --git a/third_party/xla/xla/backends/gpu/runtime/custom_kernel_thunk_test.cc b/third_party/xla/xla/backends/gpu/runtime/custom_kernel_thunk_test.cc index c2838b35e1f767..cd92ecfd7d9b55 100644 --- a/third_party/xla/xla/backends/gpu/runtime/custom_kernel_thunk_test.cc +++ b/third_party/xla/xla/backends/gpu/runtime/custom_kernel_thunk_test.cc @@ -26,6 +26,7 @@ limitations under the License. #include "absl/status/status_matchers.h" #include "absl/status/statusor.h" #include "absl/strings/ascii.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/backends/gpu/codegen/kernels/custom_kernel.h" #include "xla/backends/gpu/codegen/kernels/ptx_custom_kernel.h" #include "xla/backends/gpu/runtime/command.h" @@ -204,9 +205,9 @@ TEST(CustomKernelThunkTest, FromProto) { //===----------------------------------------------------------------------===// static absl::StatusOr GpuExecutor() { - TF_ASSIGN_OR_RETURN(std::string name, - PlatformUtil::CanonicalPlatformName("gpu")); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN(std::string name, + PlatformUtil::CanonicalPlatformName("gpu")); + ASSIGN_OR_RETURN( se::Platform * platform, se::PlatformManager::PlatformWithName(absl::AsciiStrToUpper(name))); return platform->ExecutorForDevice(0); @@ -220,7 +221,7 @@ static absl::StatusOr> MakeAddI32CustomKernelThunk(const std::vector& allocs) { absl::string_view ptx = se::gpu::GetAddI32PtxKernelSpec().cuda_ptx_in_memory().value().ptx; - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( CustomKernel kernel, kernel::GetPtxCustomKernel(/*kernel_name=*/"AddI32", ptx, /*num_args=*/3, /*block_dim=*/se::BlockDim(1, 1, 1), diff --git a/third_party/xla/xla/backends/gpu/runtime/device_to_device_copy_thunk.cc b/third_party/xla/xla/backends/gpu/runtime/device_to_device_copy_thunk.cc index 2ad2f1932e14ba..c04d7ba5643ee0 100644 --- a/third_party/xla/xla/backends/gpu/runtime/device_to_device_copy_thunk.cc +++ b/third_party/xla/xla/backends/gpu/runtime/device_to_device_copy_thunk.cc @@ -109,10 +109,10 @@ absl::StatusOr DeviceToDeviceCopyThunk::ToProto() const { DeviceToDeviceCopyThunkProto* d2d_copy_thunk_proto = proto.mutable_device_to_device_copy_thunk(); CopyThunkProto* copy_thunk_proto = d2d_copy_thunk_proto->mutable_copy_thunk(); - TF_ASSIGN_OR_RETURN(*copy_thunk_proto->mutable_source_buffer(), - source_buffer_.ToProto()); - TF_ASSIGN_OR_RETURN(*copy_thunk_proto->mutable_destination_buffer(), - destination_buffer_.ToProto()); + ASSIGN_OR_RETURN(*copy_thunk_proto->mutable_source_buffer(), + source_buffer_.ToProto()); + ASSIGN_OR_RETURN(*copy_thunk_proto->mutable_destination_buffer(), + destination_buffer_.ToProto()); copy_thunk_proto->set_mem_size(size_bytes()); return proto; } @@ -121,11 +121,11 @@ absl::StatusOr> DeviceToDeviceCopyThunk::FromProto( ThunkInfo thunk_info, const DeviceToDeviceCopyThunkProto& thunk_proto, absl::Span buffer_allocations) { - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( ShapedSlice src_slice, ShapedSlice::FromProto(thunk_proto.copy_thunk().source_buffer(), buffer_allocations)); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( ShapedSlice dst_slice, ShapedSlice::FromProto(thunk_proto.copy_thunk().destination_buffer(), buffer_allocations)); diff --git a/third_party/xla/xla/backends/gpu/runtime/device_to_host_copy_thunk.cc b/third_party/xla/xla/backends/gpu/runtime/device_to_host_copy_thunk.cc index 107e93f7c4f666..20ef1e071bec03 100644 --- a/third_party/xla/xla/backends/gpu/runtime/device_to_host_copy_thunk.cc +++ b/third_party/xla/xla/backends/gpu/runtime/device_to_host_copy_thunk.cc @@ -22,6 +22,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/backends/gpu/runtime/copy_thunk.h" #include "xla/backends/gpu/runtime/copy_thunk.pb.h" #include "xla/backends/gpu/runtime/thunk.h" @@ -64,10 +65,10 @@ absl::StatusOr DeviceToHostCopyThunk::ToProto() const { DeviceToHostCopyThunkProto* d2h_copy_thunk_proto = proto.mutable_device_to_host_copy_thunk(); CopyThunkProto* copy_thunk_proto = d2h_copy_thunk_proto->mutable_copy_thunk(); - TF_ASSIGN_OR_RETURN(*copy_thunk_proto->mutable_source_buffer(), - source().ToProto()); - TF_ASSIGN_OR_RETURN(*copy_thunk_proto->mutable_destination_buffer(), - destination().ToProto()); + ASSIGN_OR_RETURN(*copy_thunk_proto->mutable_source_buffer(), + source().ToProto()); + ASSIGN_OR_RETURN(*copy_thunk_proto->mutable_destination_buffer(), + destination().ToProto()); copy_thunk_proto->set_mem_size(size_bytes()); return proto; } @@ -76,11 +77,11 @@ absl::StatusOr> DeviceToHostCopyThunk::FromProto( ThunkInfo thunk_info, const DeviceToHostCopyThunkProto& thunk_proto, absl::Span buffer_allocations) { - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( ShapedSlice src_slice, ShapedSlice::FromProto(thunk_proto.copy_thunk().source_buffer(), buffer_allocations)); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( ShapedSlice dst_slice, ShapedSlice::FromProto(thunk_proto.copy_thunk().destination_buffer(), buffer_allocations)); diff --git a/third_party/xla/xla/backends/gpu/runtime/dynamic_memcpy_thunk.cc b/third_party/xla/xla/backends/gpu/runtime/dynamic_memcpy_thunk.cc index db3c71371219e2..d72b0be8abcf66 100644 --- a/third_party/xla/xla/backends/gpu/runtime/dynamic_memcpy_thunk.cc +++ b/third_party/xla/xla/backends/gpu/runtime/dynamic_memcpy_thunk.cc @@ -24,6 +24,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/backends/gpu/runtime/copy_thunk.pb.h" #include "xla/backends/gpu/runtime/thunk.h" #include "xla/backends/gpu/runtime/thunk.pb.h" @@ -99,10 +100,10 @@ absl::StatusOr DynamicMemcpyThunk::ToProto() const { DynamicMemcpyThunkProto* dynamic_memcpy_thunk_proto = proto.mutable_dynamic_memcpy_thunk(); - TF_ASSIGN_OR_RETURN(*dynamic_memcpy_thunk_proto->mutable_source_buffer(), - source_buffer_.ToProto()); - TF_ASSIGN_OR_RETURN(*dynamic_memcpy_thunk_proto->mutable_destination_buffer(), - destination_buffer_.ToProto()); + ASSIGN_OR_RETURN(*dynamic_memcpy_thunk_proto->mutable_source_buffer(), + source_buffer_.ToProto()); + ASSIGN_OR_RETURN(*dynamic_memcpy_thunk_proto->mutable_destination_buffer(), + destination_buffer_.ToProto()); dynamic_memcpy_thunk_proto->set_mem_size(mem_size_); *dynamic_memcpy_thunk_proto->mutable_offsets() = offsets_.ToProto(); return proto; @@ -112,14 +113,13 @@ absl::StatusOr> DynamicMemcpyThunk::FromProto( ThunkInfo thunk_info, const DynamicMemcpyThunkProto& thunk_proto, absl::Span buffer_allocations) { - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( ShapedSlice src_slice, ShapedSlice::FromProto(thunk_proto.source_buffer(), buffer_allocations)); - TF_ASSIGN_OR_RETURN(ShapedSlice dst_slice, - ShapedSlice::FromProto(thunk_proto.destination_buffer(), - buffer_allocations)); - TF_ASSIGN_OR_RETURN(Offsets offsets, - Offsets::FromProto(thunk_proto.offsets())); + ASSIGN_OR_RETURN(ShapedSlice dst_slice, + ShapedSlice::FromProto(thunk_proto.destination_buffer(), + buffer_allocations)); + ASSIGN_OR_RETURN(Offsets offsets, Offsets::FromProto(thunk_proto.offsets())); return std::make_unique(std::move(thunk_info), src_slice, dst_slice, thunk_proto.mem_size(), std::move(offsets)); diff --git a/third_party/xla/xla/backends/gpu/runtime/dynamic_slice_thunk_test.cc b/third_party/xla/xla/backends/gpu/runtime/dynamic_slice_thunk_test.cc index 9f9dba4e781a34..28902791e945f1 100644 --- a/third_party/xla/xla/backends/gpu/runtime/dynamic_slice_thunk_test.cc +++ b/third_party/xla/xla/backends/gpu/runtime/dynamic_slice_thunk_test.cc @@ -30,6 +30,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/ascii.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/backends/gpu/ffi.h" #include "xla/backends/gpu/runtime/collective_clique_requests.h" #include "xla/backends/gpu/runtime/collective_memory_requests.h" @@ -124,12 +125,12 @@ void CheckProtoRoundTrip(const DynamicSliceThunk& thunk, -> absl::StatusOr> { ThunkSequenceProto thunk_sequence_proto; *thunk_sequence_proto.add_thunks() = thunk_proto; - TF_ASSIGN_OR_RETURN(ThunkSequence sequence, - DeserializeThunkSequenceProto( - thunk_sequence_proto, fake_allocations_span, - /*hlo_module=*/nullptr, - /*platform_name=*/"TEST_PLATFORM", - /*gpu_compute_capability=*/{})); + ASSIGN_OR_RETURN(ThunkSequence sequence, + DeserializeThunkSequenceProto( + thunk_sequence_proto, fake_allocations_span, + /*hlo_module=*/nullptr, + /*platform_name=*/"TEST_PLATFORM", + /*gpu_compute_capability=*/{})); return std::move(sequence.front()); }; @@ -220,7 +221,7 @@ absl::StatusOr> CreateSlicedGemmThunk( backing_allocations.push_back(std::move(alloc_lhs_offset_0)); backing_allocations.push_back(std::move(alloc_lhs_offset_1)); // Preparing config for GEMM thunk. - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( GemmConfig config, GemmConfig::For( ShapeUtil::MakeShape(PrimitiveType::F32, {1, 3}), {}, {1}, @@ -390,7 +391,7 @@ CreateMultipleSlicedOperandsGemmThunk( backing_allocations.push_back(std::move(alloc_rhs_offset_1)); // Preparing config for GEMM thunk. - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( GemmConfig config, GemmConfig::For( ShapeUtil::MakeShape(PrimitiveType::F32, {1, 3}), {}, {1}, @@ -934,7 +935,7 @@ CreateSlicedGemmArbitraryArgumentOrderThunk( backing_allocations.push_back(std::move(alloc_lhs_offset_1)); // Preparing config for GEMM thunk. - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( GemmConfig config, GemmConfig::For( ShapeUtil::MakeShape(PrimitiveType::F32, {1, 3}), {}, {1}, @@ -1109,7 +1110,7 @@ CreateSlicedGemmArbitraryNumberOfArgumentsThunk( backing_allocations.push_back(std::move(alloc_lhs_offset_1)); // Preparing config for GEMM thunk. - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( GemmConfig config, GemmConfig::For( ShapeUtil::MakeShape(PrimitiveType::F32, {1, 3}), {}, {1}, @@ -1275,7 +1276,7 @@ CreateSlicedTupledOperandGemmThunk( backing_allocations.push_back(std::move(alloc_lhs_offset_1)); // Preparing config for GEMM thunk. - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( GemmConfig config, GemmConfig::For( ShapeUtil::MakeShape(PrimitiveType::F32, {1, 3}), {}, {1}, @@ -1655,7 +1656,7 @@ CreateSlicedOperandsSameBufferGemmThunk( backing_allocations.push_back(std::move(alloc_lhs_offset_1)); // Preparing config for GEMM thunk. - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( GemmConfig config, GemmConfig::For( ShapeUtil::MakeShape(PrimitiveType::F32, {1, 3}), {}, {1}, @@ -1792,8 +1793,7 @@ CreateHostInductionVariableAndOffsetEvaluationThunk( ROOT select = s32[] select(compare, add, p0) } )"; - TF_ASSIGN_OR_RETURN(auto offset_module, - ParseAndReturnUnverifiedModule(offset)); + ASSIGN_OR_RETURN(auto offset_module, ParseAndReturnUnverifiedModule(offset)); offset_modules.emplace_back(std::move(offset_module)); HloModule* offset_module_ptr = offset_modules.back().get(); const char* indvar_init = R"( @@ -1802,8 +1802,8 @@ CreateHostInductionVariableAndOffsetEvaluationThunk( ROOT c0 = s32[] constant(0) } )"; - TF_ASSIGN_OR_RETURN(auto indvar_init_module, - ParseAndReturnUnverifiedModule(indvar_init)); + ASSIGN_OR_RETURN(auto indvar_init_module, + ParseAndReturnUnverifiedModule(indvar_init)); const char* indvar_update = R"( HloModule indvar_update ENTRY main { @@ -1812,8 +1812,8 @@ CreateHostInductionVariableAndOffsetEvaluationThunk( ROOT add = s32[] add(p0, c1) } )"; - TF_ASSIGN_OR_RETURN(auto indvar_update_module, - ParseAndReturnUnverifiedModule(indvar_update)); + ASSIGN_OR_RETURN(auto indvar_update_module, + ParseAndReturnUnverifiedModule(indvar_update)); se::StreamExecutor* executor = GpuExecutor(); int64_t lhs_length = sizeof(float) * 2 * 4; @@ -1852,7 +1852,7 @@ CreateHostInductionVariableAndOffsetEvaluationThunk( // Preparing config for GEMM thunk. - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( GemmConfig config, GemmConfig::For( /*lhs_shape=*/ShapeUtil::MakeShape( diff --git a/third_party/xla/xla/backends/gpu/runtime/fft_thunk.cc b/third_party/xla/xla/backends/gpu/runtime/fft_thunk.cc index a2cf3258f202ed..043a9c4b59dac9 100644 --- a/third_party/xla/xla/backends/gpu/runtime/fft_thunk.cc +++ b/third_party/xla/xla/backends/gpu/runtime/fft_thunk.cc @@ -27,6 +27,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/synchronization/mutex.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/backends/gpu/runtime/thunk.h" #include "xla/service/buffer_assignment.h" #include "xla/service/buffer_assignment.pb.h" @@ -166,7 +167,7 @@ absl::Status RunFft(se::DeviceAddressBase input, const Shape& input_shape, // protect each plan with a mutex. absl::MutexLock lock(fft_plan_ptr->mu); std::unique_ptr& fft_plan = fft_plan_ptr->plan; - TF_ASSIGN_OR_RETURN(auto fft, GetFft(stream)); + ASSIGN_OR_RETURN(auto fft, GetFft(stream)); if (fft_plan == nullptr) { const int64_t fft_rank = fft_len.size(); CHECK_LE(fft_rank, 3); @@ -225,7 +226,7 @@ absl::Status RunFft(se::DeviceAddressBase input, const Shape& input_shape, se::DeviceAddress output_data(output); launch_ok = fft->DoFft(stream, fft_plan.get(), input_data, &output_data); if (launch_ok) { - TF_ASSIGN_OR_RETURN(auto blas, GetBlas(stream)); + ASSIGN_OR_RETURN(auto blas, GetBlas(stream)); launch_ok = blas->DoBlasScal(stream, ShapeUtil::ElementsIn(output_shape), complex64(1.0f / scale_factor), &output_data, 1); @@ -237,7 +238,7 @@ absl::Status RunFft(se::DeviceAddressBase input, const Shape& input_shape, se::DeviceAddress output_data(output); launch_ok = fft->DoFft(stream, fft_plan.get(), input_data, &output_data); if (launch_ok) { - TF_ASSIGN_OR_RETURN(auto blas, GetBlas(stream)); + ASSIGN_OR_RETURN(auto blas, GetBlas(stream)); launch_ok = blas->DoBlasScal(stream, ShapeUtil::ElementsIn(output_shape), complex128(1.0 / scale_factor), &output_data, 1); @@ -261,7 +262,7 @@ absl::Status RunFft(se::DeviceAddressBase input, const Shape& input_shape, se::DeviceAddress output_data(output); launch_ok = fft->DoFft(stream, fft_plan.get(), input_data, &output_data); if (launch_ok) { - TF_ASSIGN_OR_RETURN(auto blas, GetBlas(stream)); + ASSIGN_OR_RETURN(auto blas, GetBlas(stream)); launch_ok = blas->DoBlasScal(stream, ShapeUtil::ElementsIn(output_shape), 1.0f / scale_factor, &output_data, 1); @@ -273,7 +274,7 @@ absl::Status RunFft(se::DeviceAddressBase input, const Shape& input_shape, se::DeviceAddress output_data(output); launch_ok = fft->DoFft(stream, fft_plan.get(), input_data, &output_data); if (launch_ok) { - TF_ASSIGN_OR_RETURN(auto blas, GetBlas(stream)); + ASSIGN_OR_RETURN(auto blas, GetBlas(stream)); launch_ok = blas->DoBlasScal(stream, ShapeUtil::ElementsIn(output_shape), 1.0 / scale_factor, &output_data, 1); @@ -293,16 +294,15 @@ absl::Status RunFft(se::DeviceAddressBase input, const Shape& input_shape, absl::StatusOr> FftThunk::FromProto( ThunkInfo thunk_info, const FftThunkProto& proto, absl::Span buffer_allocations) { - TF_ASSIGN_OR_RETURN(BufferAllocation::Slice input_buffer, - BufferAllocation::Slice::FromProto(proto.input_buffer(), - buffer_allocations)); - TF_ASSIGN_OR_RETURN(BufferAllocation::Slice output_buffer, - BufferAllocation::Slice::FromProto(proto.output_buffer(), - buffer_allocations)); + ASSIGN_OR_RETURN(BufferAllocation::Slice input_buffer, + BufferAllocation::Slice::FromProto(proto.input_buffer(), + buffer_allocations)); + ASSIGN_OR_RETURN(BufferAllocation::Slice output_buffer, + BufferAllocation::Slice::FromProto(proto.output_buffer(), + buffer_allocations)); - TF_ASSIGN_OR_RETURN(Shape input_shape, Shape::FromProto(proto.input_shape())); - TF_ASSIGN_OR_RETURN(Shape output_shape, - Shape::FromProto(proto.output_shape())); + ASSIGN_OR_RETURN(Shape input_shape, Shape::FromProto(proto.input_shape())); + ASSIGN_OR_RETURN(Shape output_shape, Shape::FromProto(proto.output_shape())); std::vector fft_length{proto.fft_length().begin(), proto.fft_length().end()}; @@ -317,14 +317,13 @@ absl::StatusOr FftThunk::ToProto() const { *thunk_proto.mutable_thunk_info() = thunk_info().ToProto(); FftThunkProto* proto = thunk_proto.mutable_fft_thunk(); - TF_ASSIGN_OR_RETURN(FftType fft_type, SeTypeToFftType(fft_type_)); + ASSIGN_OR_RETURN(FftType fft_type, SeTypeToFftType(fft_type_)); proto->set_fft_type(fft_type); *proto->mutable_fft_length() = {fft_length_.begin(), fft_length_.end()}; - TF_ASSIGN_OR_RETURN(*proto->mutable_input_buffer(), input_buffer_.ToProto()); - TF_ASSIGN_OR_RETURN(*proto->mutable_output_buffer(), - output_buffer_.ToProto()); + ASSIGN_OR_RETURN(*proto->mutable_input_buffer(), input_buffer_.ToProto()); + ASSIGN_OR_RETURN(*proto->mutable_output_buffer(), output_buffer_.ToProto()); *proto->mutable_input_shape() = input_shape_.ToProto(); *proto->mutable_output_shape() = output_shape_.ToProto(); diff --git a/third_party/xla/xla/backends/gpu/runtime/gemm_thunk.cc b/third_party/xla/xla/backends/gpu/runtime/gemm_thunk.cc index cbe82311594859..1eceb7fcc51930 100644 --- a/third_party/xla/xla/backends/gpu/runtime/gemm_thunk.cc +++ b/third_party/xla/xla/backends/gpu/runtime/gemm_thunk.cc @@ -23,6 +23,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/backends/gpu/runtime/thunk.h" #include "xla/backends/gpu/runtime/thunk.pb.h" #include "xla/backends/gpu/runtime/traced_command.h" @@ -103,15 +104,15 @@ absl::StatusOr GemmThunk::ToProto() const { auto* gemm_thunk_proto = proto.mutable_gemm_thunk(); *gemm_thunk_proto->mutable_gemm_config() = config_.ToProto(); - TF_ASSIGN_OR_RETURN(*gemm_thunk_proto->mutable_lhs_buffer(), - lhs_buffer_.ToProto()); - TF_ASSIGN_OR_RETURN(*gemm_thunk_proto->mutable_rhs_buffer(), - rhs_buffer_.ToProto()); - TF_ASSIGN_OR_RETURN(*gemm_thunk_proto->mutable_output_buffer(), - output_buffer_.ToProto()); + ASSIGN_OR_RETURN(*gemm_thunk_proto->mutable_lhs_buffer(), + lhs_buffer_.ToProto()); + ASSIGN_OR_RETURN(*gemm_thunk_proto->mutable_rhs_buffer(), + rhs_buffer_.ToProto()); + ASSIGN_OR_RETURN(*gemm_thunk_proto->mutable_output_buffer(), + output_buffer_.ToProto()); if (workspace_.has_value()) { - TF_ASSIGN_OR_RETURN(*gemm_thunk_proto->mutable_workspace(), - workspace_.value().ToProto()); + ASSIGN_OR_RETURN(*gemm_thunk_proto->mutable_workspace(), + workspace_.value().ToProto()); } gemm_thunk_proto->set_deterministic(deterministic_); return proto; @@ -120,21 +121,21 @@ absl::StatusOr GemmThunk::ToProto() const { absl::StatusOr> GemmThunk::FromProto( ThunkInfo thunk_info, const GemmThunkProto& proto, absl::Span buffer_allocations) { - TF_ASSIGN_OR_RETURN(stream_executor::gpu::GemmConfig config, - GemmConfig::FromProto(proto.gemm_config())); - TF_ASSIGN_OR_RETURN(BufferAllocation::Slice lhs_buffer, - BufferAllocation::Slice::FromProto(proto.lhs_buffer(), - buffer_allocations)); - TF_ASSIGN_OR_RETURN(BufferAllocation::Slice rhs_buffer, - BufferAllocation::Slice::FromProto(proto.rhs_buffer(), - buffer_allocations)); - TF_ASSIGN_OR_RETURN(BufferAllocation::Slice output_buffer, - BufferAllocation::Slice::FromProto(proto.output_buffer(), - buffer_allocations)); + ASSIGN_OR_RETURN(stream_executor::gpu::GemmConfig config, + GemmConfig::FromProto(proto.gemm_config())); + ASSIGN_OR_RETURN(BufferAllocation::Slice lhs_buffer, + BufferAllocation::Slice::FromProto(proto.lhs_buffer(), + buffer_allocations)); + ASSIGN_OR_RETURN(BufferAllocation::Slice rhs_buffer, + BufferAllocation::Slice::FromProto(proto.rhs_buffer(), + buffer_allocations)); + ASSIGN_OR_RETURN(BufferAllocation::Slice output_buffer, + BufferAllocation::Slice::FromProto(proto.output_buffer(), + buffer_allocations)); std::optional workspace; if (proto.has_workspace()) { - TF_ASSIGN_OR_RETURN(workspace, BufferAllocation::Slice::FromProto( - proto.workspace(), buffer_allocations)); + ASSIGN_OR_RETURN(workspace, BufferAllocation::Slice::FromProto( + proto.workspace(), buffer_allocations)); } return std::make_unique(thunk_info, GemmConfig(config), lhs_buffer, rhs_buffer, output_buffer, workspace, diff --git a/third_party/xla/xla/backends/gpu/runtime/gpublas_lt_matmul_thunk.cc b/third_party/xla/xla/backends/gpu/runtime/gpublas_lt_matmul_thunk.cc index 30f36c1617fbbd..0b525d2e85cb49 100644 --- a/third_party/xla/xla/backends/gpu/runtime/gpublas_lt_matmul_thunk.cc +++ b/third_party/xla/xla/backends/gpu/runtime/gpublas_lt_matmul_thunk.cc @@ -20,6 +20,8 @@ limitations under the License. #include #include #include +#include +#include #include "absl/log/log.h" #include "absl/status/status.h" @@ -166,26 +168,47 @@ absl::Status CublasLtMatmulThunk::ExecuteOnStreamInternal( "GroupedMatmul must have a non-empty group_sizes_"); } // Grouped matmul execution - TF_ASSIGN_OR_RETURN(auto* plan, GetCachedGroupedMatmulPlan(params)); + ASSIGN_OR_RETURN(auto* plan, GetCachedGroupedMatmulPlan(params)); return plan->ExecuteOnStream( - stream, allocs.GetDeviceAddress(a_.slice), - allocs.GetDeviceAddress(b_.slice), allocs.GetDeviceAddress(c_.slice), - allocs.GetDeviceAddress(d_.slice), - allocs.GetDeviceAddress(group_sizes_->slice), bias, aux, a_scale, - b_scale, c_scale, d_scale, d_amax, workspace); + stream, + se::gpu::BlasLt::MemoryArgs{ + allocs.GetDeviceAddress(a_.slice), + allocs.GetDeviceAddress(b_.slice), + allocs.GetDeviceAddress(c_.slice), + allocs.GetDeviceAddress(d_.slice), + bias, + aux, + a_scale, + b_scale, + c_scale, + d_scale, + {allocs.GetDeviceAddress(group_sizes_->slice)}, + workspace}, + /* profile_result */ nullptr); } // Regular matmul execution - TF_ASSIGN_OR_RETURN(auto* plan, GetCachedMatmulPlan(params)); + ASSIGN_OR_RETURN(auto* plan, GetCachedMatmulPlan(params)); return plan->ExecuteOnStream( - stream, allocs.GetDeviceAddress(a_.slice), - allocs.GetDeviceAddress(b_.slice), allocs.GetDeviceAddress(c_.slice), - allocs.GetDeviceAddress(d_.slice), bias, aux, a_scale, b_scale, c_scale, - d_scale, d_amax, workspace); + stream, + se::gpu::BlasLt::MemoryArgs{allocs.GetDeviceAddress(a_.slice), + allocs.GetDeviceAddress(b_.slice), + allocs.GetDeviceAddress(c_.slice), + allocs.GetDeviceAddress(d_.slice), + bias, + aux, + a_scale, + b_scale, + c_scale, + d_scale, + {d_amax}, + workspace}, + /* profile_result */ nullptr); } absl::StatusOr CublasLtMatmulThunk::GetCachedMatmulPlan(const ExecuteParams& params) { - auto* blas_lt = se::gpu::BlasLt::Get(params.stream); + TF_ASSIGN_OR_RETURN(auto* blas_lt, + se::gpu::BlasLt::Get(params.stream->parent())); auto create = [&]() -> absl::StatusOr { VLOG(2) << this << ": Adding new MatmulPlan for stream: " << params.stream << " instr: " << canonical_hlo_; @@ -195,8 +218,8 @@ CublasLtMatmulThunk::GetCachedMatmulPlan(const ExecuteParams& params) { return absl::InternalError( "Expected GemmConfig but gemm_config_ holds a different type"); } - TF_ASSIGN_OR_RETURN(auto plan, - blas_lt->GetMatmulPlan(*gemm_config, epilogue_)); + ASSIGN_OR_RETURN(auto plan, + blas_lt->GetMatmulPlan(*gemm_config, epilogue_)); // Set the workspace size to the size that was used for autotuning, so // algorithm index will be the same as returned by GetAlgorithms called @@ -207,15 +230,14 @@ CublasLtMatmulThunk::GetCachedMatmulPlan(const ExecuteParams& params) { // algorithms, it's enough to get the default one only. int64_t num_algorithms = algorithm_idx_ == 0 ? 1 : GemmConfig::kNumAlgorithms; - TF_ASSIGN_OR_RETURN( - auto algorithms, - plan->GetAlgorithms(params.stream, num_algorithms, max_workspace)); + ASSIGN_OR_RETURN(auto algorithms, + plan->GetAlgorithms(num_algorithms, max_workspace)); if (algorithms.empty()) { return absl::InternalError( "Failed to get a MatmulPlan: no valid algorithm found."); } - TF_RETURN_IF_ERROR(plan->SetAlgorithm(algorithms[algorithm_idx_])); + RETURN_IF_ERROR(plan->SetAlgorithm(algorithms[algorithm_idx_])); return std::move(plan); }; return blas_lt->GetOrCreateMatmulPlan(canonical_hlo_, create); @@ -223,7 +245,8 @@ CublasLtMatmulThunk::GetCachedMatmulPlan(const ExecuteParams& params) { absl::StatusOr CublasLtMatmulThunk::GetCachedGroupedMatmulPlan(const ExecuteParams& params) { - auto* blas_lt = se::gpu::BlasLt::Get(params.stream); + TF_ASSIGN_OR_RETURN(auto* blas_lt, + se::gpu::BlasLt::Get(params.stream->parent())); auto create = [&]() -> absl::StatusOr { VLOG(2) << this << ": Adding new Grouped MatmulPlan for stream: " << params.stream @@ -234,8 +257,8 @@ CublasLtMatmulThunk::GetCachedGroupedMatmulPlan(const ExecuteParams& params) { return absl::InternalError( "Expected GroupedGemmConfig but gemm_config_ holds a different type"); } - TF_ASSIGN_OR_RETURN(auto plan, - blas_lt->GetGroupedMatmulPlan(*gemm_config, epilogue_)); + ASSIGN_OR_RETURN(auto plan, + blas_lt->GetGroupedMatmulPlan(*gemm_config, epilogue_)); // Set the workspace size to the size that was used for autotuning, so // algorithm index will be the same as returned by GetAlgorithms called @@ -246,15 +269,14 @@ CublasLtMatmulThunk::GetCachedGroupedMatmulPlan(const ExecuteParams& params) { // algorithms, it's enough to get the default one only. int64_t num_algorithms = algorithm_idx_ == 0 ? 1 : GemmConfig::kNumAlgorithms; - TF_ASSIGN_OR_RETURN( - auto algorithms, - plan->GetAlgorithms(params.stream, num_algorithms, max_workspace)); + ASSIGN_OR_RETURN(auto algorithms, + plan->GetAlgorithms(num_algorithms, max_workspace)); if (algorithms.empty()) { return absl::InternalError( "Failed to get a GroupedMatmulPlan: no valid algorithm found."); } - TF_RETURN_IF_ERROR(plan->SetAlgorithm(algorithms[algorithm_idx_])); + RETURN_IF_ERROR(plan->SetAlgorithm(algorithms[algorithm_idx_])); return std::move(plan); }; return blas_lt->GetOrCreateGroupedMatmulPlan(canonical_hlo_, create); @@ -315,8 +337,8 @@ absl::StatusOr CublasLtMatmulThunk::ToProto() const { } *cublas_lt_matmul_thunk->mutable_grouped_gemm_config() = gemm_config->ToProto(); - TF_ASSIGN_OR_RETURN(*cublas_lt_matmul_thunk->mutable_group_sizes(), - group_sizes_.value().ToProto()); + ASSIGN_OR_RETURN(*cublas_lt_matmul_thunk->mutable_group_sizes(), + group_sizes_.value().ToProto()); } else { // Serialize regular matmul auto* gemm_config = std::get_if(&gemm_config_); @@ -424,9 +446,9 @@ absl::StatusOr> CublasLtMatmulThunk::FromProto( // Check if this is grouped or regular matmul if (proto.has_grouped_gemm_config()) { // Grouped matmul - TF_ASSIGN_OR_RETURN(stream_executor::gpu::GroupedGemmConfig gemm_config, - stream_executor::gpu::GroupedGemmConfig::FromProto( - proto.grouped_gemm_config())); + ASSIGN_OR_RETURN(stream_executor::gpu::GroupedGemmConfig gemm_config, + stream_executor::gpu::GroupedGemmConfig::FromProto( + proto.grouped_gemm_config())); ASSIGN_OR_RETURN(ShapedSlice group_sizes, ShapedSlice::FromProto(proto.group_sizes(), allocations)); return std::make_unique( diff --git a/third_party/xla/xla/backends/gpu/runtime/gpublas_lt_matmul_thunk.h b/third_party/xla/xla/backends/gpu/runtime/gpublas_lt_matmul_thunk.h index f7cf37483421df..cb4ba45a521365 100644 --- a/third_party/xla/xla/backends/gpu/runtime/gpublas_lt_matmul_thunk.h +++ b/third_party/xla/xla/backends/gpu/runtime/gpublas_lt_matmul_thunk.h @@ -25,8 +25,8 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/types/span.h" -#include "xla/backends/gpu/runtime/command.h" #include "xla/backends/gpu/runtime/thunk.h" +#include "xla/backends/gpu/runtime/thunk.pb.h" #include "xla/backends/gpu/runtime/traced_command.h" #include "xla/service/buffer_assignment.h" #include "xla/service/gpu/matmul_utils.h" diff --git a/third_party/xla/xla/backends/gpu/runtime/gpublas_lt_matmul_thunk_test.cc b/third_party/xla/xla/backends/gpu/runtime/gpublas_lt_matmul_thunk_test.cc index 69760d47e9aa8c..ead4110a78f5e4 100644 --- a/third_party/xla/xla/backends/gpu/runtime/gpublas_lt_matmul_thunk_test.cc +++ b/third_party/xla/xla/backends/gpu/runtime/gpublas_lt_matmul_thunk_test.cc @@ -35,6 +35,7 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/time/clock.h" #include "absl/time/time.h" +#include "xla/tsl/platform/status_macros.h" #include "google/protobuf/text_format.h" #include "xla/backends/gpu/runtime/command.h" #include "xla/backends/gpu/runtime/command_state.h" @@ -121,16 +122,15 @@ class GpuBlasLtThunkBuilder { absl::StatusOr> CreateThunk( HloInstruction* gemm) { - TF_ASSIGN_OR_RETURN(const auto gpu_config, - gemm->backend_config()); + ASSIGN_OR_RETURN(const auto gpu_config, + gemm->backend_config()); const auto& backend_config = gpu_config.gemm_backend_config(); - TF_ASSIGN_OR_RETURN( - bool has_vector_bias, - gpublas_lt::EpilogueAddsVectorBias(backend_config.epilogue())); + ASSIGN_OR_RETURN(bool has_vector_bias, gpublas_lt::EpilogueAddsVectorBias( + backend_config.epilogue())); bool has_matrix_bias = backend_config.beta() != 0; - TF_ASSIGN_OR_RETURN( - auto epilogue, gpublas_lt::AsBlasLtEpilogue(backend_config.epilogue())); + ASSIGN_OR_RETURN(auto epilogue, + gpublas_lt::AsBlasLtEpilogue(backend_config.epilogue())); std::vector buf_shapes; for (auto op : gemm->operands()) { @@ -147,8 +147,8 @@ class GpuBlasLtThunkBuilder { for (const Shape& shape : buf_shapes) { int64_t size = ShapeUtil::ByteSizeOf(shape); mem_buffers_.emplace_back(); - TF_ASSIGN_OR_RETURN(mem_buffers_.back(), - allocator_.Allocate(exec_->device_ordinal(), size)); + ASSIGN_OR_RETURN(mem_buffers_.back(), + allocator_.Allocate(exec_->device_ordinal(), size)); allocs_.emplace_back(/*index=*/idx++, size, /*color=*/0); slices.push_back( {BufferAllocation::Slice{&allocs_.back(), /*offset*/ 0, size}, @@ -157,7 +157,7 @@ class GpuBlasLtThunkBuilder { // we need at least 3 buffers: lhs, rhs and output EXPECT_EQ(slices.size(), 3 + size_t{has_matrix_bias} + size_t{has_vector_bias}); - TF_ASSIGN_OR_RETURN(auto gemm_config, GemmConfig::For(gemm, gpu_comp_)); + ASSIGN_OR_RETURN(auto gemm_config, GemmConfig::For(gemm, gpu_comp_)); std::optional bias; if (has_vector_bias) { @@ -227,11 +227,11 @@ void GpuBlasLtMatmulThunkTest::CreateExecuteThunksFromHLO( Thunk::ExecutableSource source = {/*text=*/"", /*binary=*/{}}; for (auto& thunk : gemm_thunks) { - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( thunk->Initialize({executor, source, allocs.get(), stream, stream})); - TF_RETURN_IF_ERROR(thunk->ExecuteOnStream(thunk_params)); + RETURN_IF_ERROR(thunk->ExecuteOnStream(thunk_params)); } - TF_RETURN_IF_ERROR(stream->BlockHostUntilDone()); + RETURN_IF_ERROR(stream->BlockHostUntilDone()); return absl::OkStatus(); }; @@ -339,7 +339,7 @@ struct MockBlasLt : public se::gpu::BlasLt { } absl::StatusOr GetGroupedMatmulPlan( - se::gpu::GroupedGemmConfig&, Epilogue) const override { + const se::gpu::GroupedGemmConfig&, Epilogue) const override { return MatmulPlanPtr{}; } diff --git a/third_party/xla/xla/backends/gpu/runtime/host_execute_thunk_test.cc b/third_party/xla/xla/backends/gpu/runtime/host_execute_thunk_test.cc index e44e282a2d097b..30c552e03f19bb 100644 --- a/third_party/xla/xla/backends/gpu/runtime/host_execute_thunk_test.cc +++ b/third_party/xla/xla/backends/gpu/runtime/host_execute_thunk_test.cc @@ -28,6 +28,7 @@ limitations under the License. #include "absl/status/status_matchers.h" #include "absl/status/statusor.h" #include "absl/strings/ascii.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/backends/cpu/alignment.h" #include "xla/backends/cpu/nanort/nanort_client.h" #include "xla/backends/cpu/nanort/nanort_executable.h" @@ -81,10 +82,10 @@ CreateHostExecuteStartThunk( XlaComputation host_computation( *host_offloading_executable_proto.mutable_hlo_module()); - TF_ASSIGN_OR_RETURN(std::unique_ptr executable, - client.Compile(host_computation)); - TF_ASSIGN_OR_RETURN(std::unique_ptr aot_compilation_result, - client.Export(executable.get())); + ASSIGN_OR_RETURN(std::unique_ptr executable, + client.Compile(host_computation)); + ASSIGN_OR_RETURN(std::unique_ptr aot_compilation_result, + client.Export(executable.get())); xla::cpu::CpuAotCompilationResult* cpu_aot_compilation_result = tsl::down_cast( diff --git a/third_party/xla/xla/backends/gpu/runtime/host_memory_pool.cc b/third_party/xla/xla/backends/gpu/runtime/host_memory_pool.cc index 674a888cbb64e1..5671c7ca003aff 100644 --- a/third_party/xla/xla/backends/gpu/runtime/host_memory_pool.cc +++ b/third_party/xla/xla/backends/gpu/runtime/host_memory_pool.cc @@ -23,6 +23,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/synchronization/mutex.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/primitive_util.h" #include "xla/stream_executor/memory_allocation.h" #include "xla/stream_executor/stream_executor.h" @@ -34,9 +35,9 @@ namespace gpu { absl::StatusOr> HostMemoryPool::Create( se::StreamExecutor* executor, PrimitiveType type) { - TF_ASSIGN_OR_RETURN(std::unique_ptr allocation, - executor->HostMemoryAllocate( - kNumElems * primitive_util::ByteWidth(type))); + ASSIGN_OR_RETURN(std::unique_ptr allocation, + executor->HostMemoryAllocate( + kNumElems * primitive_util::ByteWidth(type))); return absl::WrapUnique(new HostMemoryPool(std::move(allocation), type)); } diff --git a/third_party/xla/xla/backends/gpu/runtime/host_to_device_copy_thunk.cc b/third_party/xla/xla/backends/gpu/runtime/host_to_device_copy_thunk.cc index 58bbf7751a7da5..d9ec9d77bc77bf 100644 --- a/third_party/xla/xla/backends/gpu/runtime/host_to_device_copy_thunk.cc +++ b/third_party/xla/xla/backends/gpu/runtime/host_to_device_copy_thunk.cc @@ -22,6 +22,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/backends/gpu/runtime/copy_thunk.h" #include "xla/backends/gpu/runtime/copy_thunk.pb.h" #include "xla/backends/gpu/runtime/thunk.h" @@ -64,10 +65,10 @@ absl::StatusOr HostToDeviceCopyThunk::ToProto() const { HostToDeviceCopyThunkProto* h2d_copy_thunk_proto = proto.mutable_host_to_device_copy_thunk(); CopyThunkProto* copy_thunk_proto = h2d_copy_thunk_proto->mutable_copy_thunk(); - TF_ASSIGN_OR_RETURN(*copy_thunk_proto->mutable_source_buffer(), - source().ToProto()); - TF_ASSIGN_OR_RETURN(*copy_thunk_proto->mutable_destination_buffer(), - destination().ToProto()); + ASSIGN_OR_RETURN(*copy_thunk_proto->mutable_source_buffer(), + source().ToProto()); + ASSIGN_OR_RETURN(*copy_thunk_proto->mutable_destination_buffer(), + destination().ToProto()); copy_thunk_proto->set_mem_size(size_bytes()); return proto; } @@ -76,11 +77,11 @@ absl::StatusOr> HostToDeviceCopyThunk::FromProto( ThunkInfo thunk_info, const HostToDeviceCopyThunkProto& thunk_proto, absl::Span buffer_allocations) { - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( ShapedSlice src_slice, ShapedSlice::FromProto(thunk_proto.copy_thunk().source_buffer(), buffer_allocations)); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( ShapedSlice dst_slice, ShapedSlice::FromProto(thunk_proto.copy_thunk().destination_buffer(), buffer_allocations)); diff --git a/third_party/xla/xla/backends/gpu/runtime/infeed_thunk.cc b/third_party/xla/xla/backends/gpu/runtime/infeed_thunk.cc index 0d2ae94255e5ee..f4c4d6b730d012 100644 --- a/third_party/xla/xla/backends/gpu/runtime/infeed_thunk.cc +++ b/third_party/xla/xla/backends/gpu/runtime/infeed_thunk.cc @@ -24,6 +24,7 @@ limitations under the License. #include "absl/log/log.h" #include "absl/status/status.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/backends/gpu/runtime/thunk.h" #include "xla/service/buffer_assignment.h" #include "xla/service/gpu/buffer_allocations.h" @@ -71,8 +72,8 @@ absl::Status InfeedThunk::ExecuteOnStream(const ExecuteParams& params) { << ShapeUtil::HumanStringWithLayout(dest_slices_[index].shape); se::DeviceAddressBase dest_address = buffer_allocations.GetDeviceAddress(dest_slices_[index++].slice); - TF_RETURN_IF_ERROR(stream.Memcpy(&dest_address, buffer.address(), - buffer.address().size())); + RETURN_IF_ERROR(stream.Memcpy(&dest_address, buffer.address(), + buffer.address().size())); } // Make sure that all dest slices have been copied into. @@ -95,7 +96,7 @@ absl::StatusOr> InfeedThunk::FromProto( std::vector dest_slices(thunk_proto.dest_slices_size()); for (int i = 0; i < dest_slices.size(); i++) { - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( dest_slices[i], ShapedSlice::FromProto(thunk_proto.dest_slices(i), buffer_allocations)); } @@ -110,8 +111,8 @@ absl::StatusOr InfeedThunk::ToProto() const { InfeedThunkProto* thunk_proto = proto.mutable_infeed_thunk(); for (int i = 0; i < dest_slices_.size(); i++) { - TF_ASSIGN_OR_RETURN(*thunk_proto->add_dest_slices(), - dest_slices_[i].ToProto()); + ASSIGN_OR_RETURN(*thunk_proto->add_dest_slices(), + dest_slices_[i].ToProto()); } return proto; } diff --git a/third_party/xla/xla/backends/gpu/runtime/make_batch_pointers.cc b/third_party/xla/xla/backends/gpu/runtime/make_batch_pointers.cc index 3af6ad766ccf95..cf1291cfda3221 100644 --- a/third_party/xla/xla/backends/gpu/runtime/make_batch_pointers.cc +++ b/third_party/xla/xla/backends/gpu/runtime/make_batch_pointers.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include "absl/status/status.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/stream_executor/device_address.h" #include "xla/stream_executor/gpu/gpu_kernel_registry.h" #include "xla/stream_executor/gpu/make_batch_pointers_kernel.h" @@ -43,7 +44,7 @@ absl::Status MakeBatchPointers(se::Stream* stream, return 128; }(); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( auto kernel, stream_executor::gpu::GpuKernelRegistry::GetGlobalRegistry() .LoadKernel(executor)); diff --git a/third_party/xla/xla/backends/gpu/runtime/make_batch_pointers_test.cc b/third_party/xla/xla/backends/gpu/runtime/make_batch_pointers_test.cc index 9d6311c625926e..0c043033584d4c 100644 --- a/third_party/xla/xla/backends/gpu/runtime/make_batch_pointers_test.cc +++ b/third_party/xla/xla/backends/gpu/runtime/make_batch_pointers_test.cc @@ -22,6 +22,7 @@ limitations under the License. #include #include "absl/status/status_matchers.h" #include "absl/status/statusor.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/service/platform_util.h" #include "xla/stream_executor/device_address.h" #include "xla/stream_executor/platform.h" @@ -34,8 +35,8 @@ namespace { using ::testing::ElementsAreArray; static absl::StatusOr GpuExecutor() { - TF_ASSIGN_OR_RETURN(stream_executor::Platform * platform, - PlatformUtil::GetDefaultPlatform()); + ASSIGN_OR_RETURN(stream_executor::Platform * platform, + PlatformUtil::GetDefaultPlatform()); return platform->ExecutorForDevice(0); } diff --git a/third_party/xla/xla/backends/gpu/runtime/memset_thunk.cc b/third_party/xla/xla/backends/gpu/runtime/memset_thunk.cc index fbd59662dbd133..c220a9955480f9 100644 --- a/third_party/xla/xla/backends/gpu/runtime/memset_thunk.cc +++ b/third_party/xla/xla/backends/gpu/runtime/memset_thunk.cc @@ -77,7 +77,7 @@ absl::StatusOr MemzeroThunk::Record( absl::StatusOr> MemzeroThunk::FromProto( ThunkInfo thunk_info, const MemzeroThunkProto& thunk_proto, absl::Span buffer_allocations) { - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( ShapedSlice dest, ShapedSlice::FromProto(thunk_proto.dest_buffer(), buffer_allocations)); return std::make_unique(std::move(thunk_info), dest); @@ -88,8 +88,8 @@ absl::StatusOr MemzeroThunk::ToProto() const { *proto.mutable_thunk_info() = thunk_info().ToProto(); MemzeroThunkProto* memzero_thunk_proto = proto.mutable_memzero_thunk(); - TF_ASSIGN_OR_RETURN(*memzero_thunk_proto->mutable_dest_buffer(), - dest_.ToProto()); + ASSIGN_OR_RETURN(*memzero_thunk_proto->mutable_dest_buffer(), + dest_.ToProto()); return proto; } @@ -133,9 +133,9 @@ absl::StatusOr> Memset32BitValueThunk::FromProto( ThunkInfo thunk_info, const Memset32BitValueThunkProto& thunk_proto, absl::Span buffer_allocations) { - TF_ASSIGN_OR_RETURN(BufferAllocation::Slice dest, - BufferAllocation::Slice::FromProto( - thunk_proto.dest_buffer(), buffer_allocations)); + ASSIGN_OR_RETURN(BufferAllocation::Slice dest, + BufferAllocation::Slice::FromProto(thunk_proto.dest_buffer(), + buffer_allocations)); return std::make_unique(std::move(thunk_info), thunk_proto.value(), dest); } @@ -146,8 +146,7 @@ absl::StatusOr Memset32BitValueThunk::ToProto() const { Memset32BitValueThunkProto* memset_thunk_proto = proto.mutable_memset32bit_value_thunk(); - TF_ASSIGN_OR_RETURN(*memset_thunk_proto->mutable_dest_buffer(), - dest_.ToProto()); + ASSIGN_OR_RETURN(*memset_thunk_proto->mutable_dest_buffer(), dest_.ToProto()); memset_thunk_proto->set_value(value_); return proto; } diff --git a/third_party/xla/xla/backends/gpu/runtime/norm_thunk.cc b/third_party/xla/xla/backends/gpu/runtime/norm_thunk.cc index e2e0435830984a..cd7707ada80684 100644 --- a/third_party/xla/xla/backends/gpu/runtime/norm_thunk.cc +++ b/third_party/xla/xla/backends/gpu/runtime/norm_thunk.cc @@ -24,6 +24,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/synchronization/mutex.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/backends/gpu/runtime/thunk.h" #include "xla/backends/gpu/runtime/thunk.pb.h" #include "xla/runtime/buffer_use.h" @@ -51,7 +52,7 @@ absl::StatusOr> NormThunk::Create( std::optional dscale_slice, std::optional dbias_slice, BufferAllocation::Slice scratch_slice) { - TF_ASSIGN_OR_RETURN(GpuNormConfig config, GpuNormConfig::For(descriptor)); + ASSIGN_OR_RETURN(GpuNormConfig config, GpuNormConfig::For(descriptor)); // Can't use make_unique because the constructor is private. go/totw/134 return absl::WrapUnique(new NormThunk( @@ -132,7 +133,7 @@ absl::Status NormThunk::ExecuteOnStream(const ExecuteParams& params) { RunNormOptions opts; opts.norm_runner = &GetOrCreateRunner(params.stream); - TF_RETURN_IF_ERROR(RunGpuNorm( + RETURN_IF_ERROR(RunGpuNorm( config_, x_se_buffer, scale_se_buffer, y_or_dx_se_buffer, bias_se_buffer, dy_se_buffer, expectation_se_buffer, norm_factor_se_buffer, dscale_se_buffer, dbias_se_buffer, scratch, params.stream, opts)); @@ -148,7 +149,7 @@ absl::Status NormThunk::Initialize(const InitializeParams& params) { // the execution plan while a NCCL collective is running. se::dnn::LazyOpRunner* lazy_runner = GetOrCreateRunner(params.stream).AsNormRunner(); - TF_ASSIGN_OR_RETURN(auto ln_config, config_.AsDnnNormOpConfig()); + ASSIGN_OR_RETURN(auto ln_config, config_.AsDnnNormOpConfig()); return lazy_runner->GetOrCreateRunner(ln_config, params.stream).status(); } @@ -185,49 +186,47 @@ Thunk::BufferUses NormThunk::buffer_uses() const { absl::StatusOr> NormThunk::FromProto( ThunkInfo thunk_info, const NormThunkProto& proto, absl::Span buffer_allocations) { - TF_ASSIGN_OR_RETURN(GpuNormDescriptor descriptor, - GpuNormDescriptor::FromProto(proto.norm_descriptor())); + ASSIGN_OR_RETURN(GpuNormDescriptor descriptor, + GpuNormDescriptor::FromProto(proto.norm_descriptor())); - TF_ASSIGN_OR_RETURN(auto x, BufferAllocation::Slice::FromProto( - proto.x(), buffer_allocations)); - TF_ASSIGN_OR_RETURN(auto scale, BufferAllocation::Slice::FromProto( - proto.scale(), buffer_allocations)); - TF_ASSIGN_OR_RETURN(auto y_or_dx, BufferAllocation::Slice::FromProto( - proto.y_or_dx(), buffer_allocations)); + ASSIGN_OR_RETURN(auto x, BufferAllocation::Slice::FromProto( + proto.x(), buffer_allocations)); + ASSIGN_OR_RETURN(auto scale, BufferAllocation::Slice::FromProto( + proto.scale(), buffer_allocations)); + ASSIGN_OR_RETURN(auto y_or_dx, BufferAllocation::Slice::FromProto( + proto.y_or_dx(), buffer_allocations)); std::optional bias; if (proto.has_bias()) { - TF_ASSIGN_OR_RETURN(bias, BufferAllocation::Slice::FromProto( - proto.bias(), buffer_allocations)); + ASSIGN_OR_RETURN(bias, BufferAllocation::Slice::FromProto( + proto.bias(), buffer_allocations)); } std::optional expectation; if (proto.has_expectation()) { - TF_ASSIGN_OR_RETURN(expectation, - BufferAllocation::Slice::FromProto(proto.expectation(), - buffer_allocations)); + ASSIGN_OR_RETURN(expectation, BufferAllocation::Slice::FromProto( + proto.expectation(), buffer_allocations)); } std::optional norm_factor; if (proto.has_norm_factor()) { - TF_ASSIGN_OR_RETURN(norm_factor, - BufferAllocation::Slice::FromProto(proto.norm_factor(), - buffer_allocations)); + ASSIGN_OR_RETURN(norm_factor, BufferAllocation::Slice::FromProto( + proto.norm_factor(), buffer_allocations)); } std::optional dy; if (proto.has_dy()) { - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( dy, BufferAllocation::Slice::FromProto(proto.dy(), buffer_allocations)); } std::optional dscale; if (proto.has_dscale()) { - TF_ASSIGN_OR_RETURN(dscale, BufferAllocation::Slice::FromProto( - proto.dscale(), buffer_allocations)); + ASSIGN_OR_RETURN(dscale, BufferAllocation::Slice::FromProto( + proto.dscale(), buffer_allocations)); } std::optional dbias; if (proto.has_dbias()) { - TF_ASSIGN_OR_RETURN(dbias, BufferAllocation::Slice::FromProto( - proto.dbias(), buffer_allocations)); + ASSIGN_OR_RETURN(dbias, BufferAllocation::Slice::FromProto( + proto.dbias(), buffer_allocations)); } - TF_ASSIGN_OR_RETURN(auto scratch, BufferAllocation::Slice::FromProto( - proto.scratch(), buffer_allocations)); + ASSIGN_OR_RETURN(auto scratch, BufferAllocation::Slice::FromProto( + proto.scratch(), buffer_allocations)); return Create(std::move(thunk_info), descriptor, x, scale, y_or_dx, bias, expectation, norm_factor, dy, dscale, dbias, scratch); @@ -240,33 +239,30 @@ absl::StatusOr NormThunk::ToProto() const { NormThunkProto* norm_proto = proto.mutable_norm_thunk(); *norm_proto->mutable_norm_descriptor() = descriptor_.ToProto(); - TF_ASSIGN_OR_RETURN(*norm_proto->mutable_x(), x_buffer_.ToProto()); - TF_ASSIGN_OR_RETURN(*norm_proto->mutable_scale(), scale_buffer_.ToProto()); - TF_ASSIGN_OR_RETURN(*norm_proto->mutable_y_or_dx(), - y_or_dx_buffer_.ToProto()); + ASSIGN_OR_RETURN(*norm_proto->mutable_x(), x_buffer_.ToProto()); + ASSIGN_OR_RETURN(*norm_proto->mutable_scale(), scale_buffer_.ToProto()); + ASSIGN_OR_RETURN(*norm_proto->mutable_y_or_dx(), y_or_dx_buffer_.ToProto()); if (bias_buffer_.has_value()) { - TF_ASSIGN_OR_RETURN(*norm_proto->mutable_bias(), bias_buffer_->ToProto()); + ASSIGN_OR_RETURN(*norm_proto->mutable_bias(), bias_buffer_->ToProto()); } if (expectation_buffer_.has_value()) { - TF_ASSIGN_OR_RETURN(*norm_proto->mutable_expectation(), - expectation_buffer_->ToProto()); + ASSIGN_OR_RETURN(*norm_proto->mutable_expectation(), + expectation_buffer_->ToProto()); } if (norm_factor_buffer_.has_value()) { - TF_ASSIGN_OR_RETURN(*norm_proto->mutable_norm_factor(), - norm_factor_buffer_->ToProto()); + ASSIGN_OR_RETURN(*norm_proto->mutable_norm_factor(), + norm_factor_buffer_->ToProto()); } if (dy_buffer_.has_value()) { - TF_ASSIGN_OR_RETURN(*norm_proto->mutable_dy(), dy_buffer_->ToProto()); + ASSIGN_OR_RETURN(*norm_proto->mutable_dy(), dy_buffer_->ToProto()); } if (dscale_buffer_.has_value()) { - TF_ASSIGN_OR_RETURN(*norm_proto->mutable_dscale(), - dscale_buffer_->ToProto()); + ASSIGN_OR_RETURN(*norm_proto->mutable_dscale(), dscale_buffer_->ToProto()); } if (dbias_buffer_.has_value()) { - TF_ASSIGN_OR_RETURN(*norm_proto->mutable_dbias(), dbias_buffer_->ToProto()); + ASSIGN_OR_RETURN(*norm_proto->mutable_dbias(), dbias_buffer_->ToProto()); } - TF_ASSIGN_OR_RETURN(*norm_proto->mutable_scratch(), - scratch_buffer_.ToProto()); + ASSIGN_OR_RETURN(*norm_proto->mutable_scratch(), scratch_buffer_.ToProto()); return proto; } diff --git a/third_party/xla/xla/backends/gpu/runtime/nvshmem_all_reduce_thunk.cc b/third_party/xla/xla/backends/gpu/runtime/nvshmem_all_reduce_thunk.cc index 30edc2de390528..f1039487feb367 100644 --- a/third_party/xla/xla/backends/gpu/runtime/nvshmem_all_reduce_thunk.cc +++ b/third_party/xla/xla/backends/gpu/runtime/nvshmem_all_reduce_thunk.cc @@ -24,6 +24,7 @@ limitations under the License. #include "absl/memory/memory.h" #include "absl/status/status.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/backends/gpu/collectives/gpu_collectives.h" #include "xla/backends/gpu/runtime/all_reduce_thunk.h" #include "xla/backends/gpu/runtime/collective_thunk.h" @@ -51,9 +52,9 @@ namespace gpu { absl::Status RunNvshmemAllReduce(ReductionKind reduction_kind, std::vector& buffers, se::Stream& stream) { - TF_ASSIGN_OR_RETURN(auto* collectives, GetNvshmemCollectivesFromRegistry()); - TF_ASSIGN_OR_RETURN(std::unique_ptr nvshmem_comm, - collectives->CreateCommunicator()); + ASSIGN_OR_RETURN(auto* collectives, GetNvshmemCollectivesFromRegistry()); + ASSIGN_OR_RETURN(std::unique_ptr nvshmem_comm, + collectives->CreateCommunicator()); VLOG(3) << "Performing nvshmem all-reduce from device ordinal: " << *nvshmem_comm->CurrentRank(); @@ -61,7 +62,7 @@ absl::Status RunNvshmemAllReduce(ReductionKind reduction_kind, auto future = nvshmem_comm->AllReduce( buffer.source_buffer, buffer.destination_buffer, buffer.element_type, buffer.element_count, reduction_kind, GpuCollectives::On(stream)); - TF_RETURN_IF_ERROR(future.Await()); + RETURN_IF_ERROR(future.Await()); } return absl::OkStatus(); @@ -72,7 +73,7 @@ namespace impl { absl::Status CheckNvshmemImplementableInst(const HloInstruction* inst, Thunk::Kind reduction_op) { for (HloInstruction* operand : inst->operands()) { - TF_RETURN_IF_ERROR(IsValidNvshmemOperand(operand->shape(), reduction_op)); + RETURN_IF_ERROR(IsValidNvshmemOperand(operand->shape(), reduction_op)); } if (!MatchReductionComputation(inst->called_computations().front()) @@ -135,7 +136,7 @@ absl::StatusOr NvshmemAllReduceThunk::ToProto() const { proto.mutable_nvshmem_all_reduce_thunk(); for (const CollectiveThunk::Buffer& buffer : buffers_) { - TF_ASSIGN_OR_RETURN(*thunk_proto->add_buffers(), buffer.ToProto()); + ASSIGN_OR_RETURN(*thunk_proto->add_buffers(), buffer.ToProto()); } *thunk_proto->mutable_collective_config() = config_.config.ToProto(); @@ -151,15 +152,15 @@ NvshmemAllReduceThunk::FromProto( std::vector buffers; buffers.reserve(thunk_proto.buffers_size()); for (const CollectiveBufferProto& buffer_proto : thunk_proto.buffers()) { - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( buffers.emplace_back(), CollectiveThunk::Buffer::FromProto(buffer_proto, buffer_allocations)); } CollectiveConfig config = CollectiveConfig::FromProto(thunk_proto.collective_config()); - TF_ASSIGN_OR_RETURN(ReductionKind reduction_kind, - FromReductionKindProto(thunk_proto.reduction_kind())); + ASSIGN_OR_RETURN(ReductionKind reduction_kind, + FromReductionKindProto(thunk_proto.reduction_kind())); return absl::WrapUnique(new NvshmemAllReduceThunk( std::move(thunk_info), AllReduceConfig{std::move(config), reduction_kind}, @@ -168,10 +169,9 @@ NvshmemAllReduceThunk::FromProto( absl::Status NvshmemAllReduceThunk::RunNvshmemCollective( const ExecuteParams& params, se::Stream& stream) { - TF_ASSIGN_OR_RETURN( - std::vector device_buffers, - ConvertToDeviceBuffers(params.buffer_allocations, buffers_, - config_.config.operand_element_type)); + ASSIGN_OR_RETURN(std::vector device_buffers, + ConvertToDeviceBuffers(params.buffer_allocations, buffers_, + config_.config.operand_element_type)); return ::xla::gpu::RunNvshmemAllReduce(config_.reduction_kind, device_buffers, stream); } diff --git a/third_party/xla/xla/backends/gpu/runtime/nvshmem_collective_permute_thunk.cc b/third_party/xla/xla/backends/gpu/runtime/nvshmem_collective_permute_thunk.cc index 7daa3e1e69348c..dd3ae313b6cfcc 100644 --- a/third_party/xla/xla/backends/gpu/runtime/nvshmem_collective_permute_thunk.cc +++ b/third_party/xla/xla/backends/gpu/runtime/nvshmem_collective_permute_thunk.cc @@ -31,6 +31,7 @@ limitations under the License. #include "absl/strings/str_format.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/backends/gpu/collectives/gpu_collectives.h" #include "xla/backends/gpu/runtime/collective_permute_thunk.h" #include "xla/backends/gpu/runtime/collective_thunk.h" @@ -116,7 +117,7 @@ NvshmemCollectivePermuteThunk::NvshmemCollectivePermuteThunk( absl::Status NvshmemCollectivePermuteThunk::Initialize( const InitializeParams& params) { - TF_RETURN_IF_ERROR(NvshmemCollectiveThunk::Initialize(params)); + RETURN_IF_ERROR(NvshmemCollectiveThunk::Initialize(params)); if (p2p_memcpy_enabled_) { return absl::InvalidArgumentError( @@ -144,7 +145,7 @@ absl::StatusOr NvshmemCollectivePermuteThunk::ToProto() const { *thunk_proto->mutable_p2p_config() = P2PConfigToProto(config_); for (const CollectiveThunk::Buffer& buffer : buffers_) { - TF_ASSIGN_OR_RETURN(*thunk_proto->add_buffers(), buffer.ToProto()); + ASSIGN_OR_RETURN(*thunk_proto->add_buffers(), buffer.ToProto()); } thunk_proto->set_p2p_memcpy_enabled(p2p_memcpy_enabled_); @@ -159,13 +160,13 @@ NvshmemCollectivePermuteThunk::FromProto( std::vector buffers; buffers.reserve(thunk_proto.buffers_size()); for (const CollectiveBufferProto& buffer_proto : thunk_proto.buffers()) { - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( buffers.emplace_back(), CollectiveThunk::Buffer::FromProto(buffer_proto, buffer_allocations)); } - TF_ASSIGN_OR_RETURN(P2PConfig config, - P2PConfigFromProto(thunk_proto.p2p_config())); + ASSIGN_OR_RETURN(P2PConfig config, + P2PConfigFromProto(thunk_proto.p2p_config())); return absl::WrapUnique( new NvshmemCollectivePermuteThunk(std::move(thunk_info), @@ -175,14 +176,13 @@ NvshmemCollectivePermuteThunk::FromProto( absl::Status NvshmemCollectivePermuteThunk::RunNvshmemCollective( const ExecuteParams& params, se::Stream& stream) { - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( std::vector device_buffers, ConvertToDeviceBuffers(params.buffer_allocations, std::vector(buffers_), config_.config.operand_element_type)); - TF_ASSIGN_OR_RETURN( - const int64_t current_id, - GetCollectiveCurrentId(params.collective_params, config_)); + ASSIGN_OR_RETURN(const int64_t current_id, + GetCollectiveCurrentId(params.collective_params, config_)); std::string device_string = CollectiveThunk::GetDeviceString(*params.collective_params); @@ -198,9 +198,9 @@ absl::Status RunCollectivePermute(P2PConfig::SourceTargetMapEntry source_target, se::Stream& stream, absl::string_view device_string, int64_t current_id) { - TF_ASSIGN_OR_RETURN(auto* collectives, GetNvshmemCollectivesFromRegistry()); - TF_ASSIGN_OR_RETURN(std::unique_ptr nvshmem_comm, - collectives->CreateCommunicator()); + ASSIGN_OR_RETURN(auto* collectives, GetNvshmemCollectivesFromRegistry()); + ASSIGN_OR_RETURN(std::unique_ptr nvshmem_comm, + collectives->CreateCommunicator()); int device_ordinal = stream.parent()->device_ordinal(); @@ -234,7 +234,7 @@ absl::Status RunCollectivePermute(P2PConfig::SourceTargetMapEntry source_target, auto send_future = nvshmem_comm->Send( dest_addr, src_addr, buffer.element_type, buffer.element_count, RankId(*target_id), GpuCollectives::On(stream)); - TF_RETURN_IF_ERROR(send_future.Await()); + RETURN_IF_ERROR(send_future.Await()); } } if (source_id) { @@ -244,7 +244,7 @@ absl::Status RunCollectivePermute(P2PConfig::SourceTargetMapEntry source_target, VLOG(1) << "CollectivePermute: rank " << device_ordinal << " receiving data from source " << *source_id; - TF_RETURN_IF_ERROR(nvshmem_comm->Barrier(GpuCollectives::On(stream))); + RETURN_IF_ERROR(nvshmem_comm->Barrier(GpuCollectives::On(stream))); } if (!source_id) { @@ -252,8 +252,8 @@ absl::Status RunCollectivePermute(P2PConfig::SourceTargetMapEntry source_target, VLOG(3) << absl::StreamFormat("%s : collective-Permute: Issuing MemZero", device_string); for (DeviceBufferPair& buffer : buffers) { - TF_RETURN_IF_ERROR(stream.MemZero(&buffer.destination_buffer, - buffer.destination_buffer.size())); + RETURN_IF_ERROR(stream.MemZero(&buffer.destination_buffer, + buffer.destination_buffer.size())); } } @@ -271,8 +271,8 @@ absl::Status RunCollectivePermute(P2PConfig::SourceTargetMapEntry source_target, // Check if the operation is implementable with NVSHMEM for (const auto& operand : inst->operands()) { - TF_RETURN_IF_ERROR(IsValidNvshmemOperand(operand->shape(), - Thunk::kNvshmemCollectivePermute)); + RETURN_IF_ERROR(IsValidNvshmemOperand(operand->shape(), + Thunk::kNvshmemCollectivePermute)); } // Check if all source-target pairs are valid diff --git a/third_party/xla/xla/backends/gpu/runtime/nvshmem_collective_thunk.cc b/third_party/xla/xla/backends/gpu/runtime/nvshmem_collective_thunk.cc index ffecf7458e4b70..2f706c27c590d6 100644 --- a/third_party/xla/xla/backends/gpu/runtime/nvshmem_collective_thunk.cc +++ b/third_party/xla/xla/backends/gpu/runtime/nvshmem_collective_thunk.cc @@ -25,6 +25,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/str_format.h" #include "absl/synchronization/mutex.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/backends/gpu/collectives/gpu_clique_key.h" #include "xla/backends/gpu/collectives/gpu_collectives.h" #include "xla/backends/gpu/runtime/collective_clique_requests.h" @@ -87,8 +88,8 @@ NvshmemCollectiveThunk::NvshmemCollectiveThunk(Kind kind, ThunkInfo thunk_info, : Thunk(kind, thunk_info) {} absl::StatusOr GetNvshmemCollectivesFromRegistry() { - TF_ASSIGN_OR_RETURN(xla::Collectives * collectives, - xla::CollectivesRegistry::Get("gpu", "nvshmem")); + ASSIGN_OR_RETURN(xla::Collectives * collectives, + xla::CollectivesRegistry::Get("gpu", "nvshmem")); return tsl::down_cast(collectives); } @@ -96,15 +97,15 @@ absl::Status NvshmemCollectiveThunk::Prepare(const PrepareParams& params) { TF_RET_CHECK(params.collective_params && params.collective_params->device_assn); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( GpuCliqueKey clique_key, GetGpuCliqueKey(*params.collective_params, config().replica_groups, config().group_mode)); - TF_ASSIGN_OR_RETURN(std::vector> device_groups, - GetParticipatingDevicesGroups( - *params.collective_params->device_assn, - config().replica_groups, config().group_mode)); + ASSIGN_OR_RETURN(std::vector> device_groups, + GetParticipatingDevicesGroups( + *params.collective_params->device_assn, + config().replica_groups, config().group_mode)); // Sort device groups: RequestClique expects pre-sorted groups. absl::c_for_each(device_groups, [](auto& group) { absl::c_sort(group); }); @@ -129,7 +130,7 @@ absl::Status NvshmemCollectiveThunk::ExecuteOnStream( const ExecuteParams& params) { VLOG(1) << absl::StreamFormat("Starting %s.", Thunk::KindToString(kind())); // Launch collective operation on the main stream. - TF_RETURN_IF_ERROR(RunNvshmemCollective(params, *params.stream)); + RETURN_IF_ERROR(RunNvshmemCollective(params, *params.stream)); return absl::OkStatus(); } diff --git a/third_party/xla/xla/backends/gpu/runtime/nvshmem_recv_thunk.cc b/third_party/xla/xla/backends/gpu/runtime/nvshmem_recv_thunk.cc index c5c5846fbfbb07..dc425dd6b6ee95 100644 --- a/third_party/xla/xla/backends/gpu/runtime/nvshmem_recv_thunk.cc +++ b/third_party/xla/xla/backends/gpu/runtime/nvshmem_recv_thunk.cc @@ -29,6 +29,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/str_format.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/backends/gpu/collectives/gpu_collectives.h" #include "xla/backends/gpu/runtime/collective_thunk.h" #include "xla/backends/gpu/runtime/nvshmem_collective_thunk.h" @@ -85,7 +86,7 @@ absl::StatusOr NvshmemRecvThunk::ToProto() const { NvshmemRecvThunkProto* thunk_proto = proto.mutable_nvshmem_recv_thunk(); *thunk_proto->mutable_config() = P2PConfigToProto(config_); - TF_ASSIGN_OR_RETURN(*thunk_proto->mutable_buffer(), buffer_.ToProto()); + ASSIGN_OR_RETURN(*thunk_proto->mutable_buffer(), buffer_.ToProto()); thunk_proto->set_hlo_name(hlo_name_); return proto; @@ -95,11 +96,10 @@ absl::StatusOr> NvshmemRecvThunk::FromProto( ThunkInfo thunk_info, const NvshmemRecvThunkProto& thunk_proto, absl::Span buffer_allocations, std::shared_ptr absl_nonnull buffer_addresses) { - TF_ASSIGN_OR_RETURN(P2PConfig config, - P2PConfigFromProto(thunk_proto.config())); - TF_ASSIGN_OR_RETURN(CollectiveThunk::Buffer buffer, - CollectiveThunk::Buffer::FromProto(thunk_proto.buffer(), - buffer_allocations)); + ASSIGN_OR_RETURN(P2PConfig config, P2PConfigFromProto(thunk_proto.config())); + ASSIGN_OR_RETURN(CollectiveThunk::Buffer buffer, + CollectiveThunk::Buffer::FromProto(thunk_proto.buffer(), + buffer_allocations)); return absl::WrapUnique(new NvshmemRecvThunk( std::move(thunk_info), std::move(config), buffer, @@ -107,23 +107,22 @@ absl::StatusOr> NvshmemRecvThunk::FromProto( } absl::Status NvshmemRecvThunk::Initialize(const InitializeParams& params) { - TF_RETURN_IF_ERROR(NvshmemCollectiveThunk::Initialize(params)); + RETURN_IF_ERROR(NvshmemCollectiveThunk::Initialize(params)); return absl::OkStatus(); } absl::Status NvshmemRecvThunk::RunNvshmemCollective(const ExecuteParams& params, se::Stream& stream) { - TF_ASSIGN_OR_RETURN( - std::vector device_buffers, - ConvertToDeviceBuffers(params.buffer_allocations, {buffer_}, - config_.config.operand_element_type)); + ASSIGN_OR_RETURN(std::vector device_buffers, + ConvertToDeviceBuffers(params.buffer_allocations, {buffer_}, + config_.config.operand_element_type)); TF_RET_CHECK(device_buffers.size() == 1) << "Expected one buffer pair."; GlobalDeviceId global_device_id = params.collective_params->global_device_id; - TF_ASSIGN_OR_RETURN(const DeviceAssignment::LogicalID current_logical_id, - params.collective_params->device_assn->LogicalIdForDevice( - global_device_id)); + ASSIGN_OR_RETURN(const DeviceAssignment::LogicalID current_logical_id, + params.collective_params->device_assn->LogicalIdForDevice( + global_device_id)); const int64_t current_id = config_.config.group_mode == CollectiveOpGroupMode::COLLECTIVE_OP_GROUP_MODE_CROSS_REPLICA @@ -162,9 +161,9 @@ absl::Status NvshmemRecvThunk::RunNvshmemCollective(const ExecuteParams& params, return absl::OkStatus(); } - TF_ASSIGN_OR_RETURN(auto* collectives, GetNvshmemCollectivesFromRegistry()); - TF_ASSIGN_OR_RETURN(std::unique_ptr nvshmem_comm, - collectives->CreateCommunicator()); + ASSIGN_OR_RETURN(auto* collectives, GetNvshmemCollectivesFromRegistry()); + ASSIGN_OR_RETURN(std::unique_ptr nvshmem_comm, + collectives->CreateCommunicator()); VLOG(1) << "Running Recv operation" << " element_type=" << buffer.element_type << " destination_buffer=" << buffer.destination_buffer.opaque() @@ -174,8 +173,8 @@ absl::Status NvshmemRecvThunk::RunNvshmemCollective(const ExecuteParams& params, auto recv_future = nvshmem_comm->Recv( buffer.destination_buffer, buffer.source_buffer, buffer.element_type, buffer.element_count, RankId(*source_id), GpuCollectives::On(stream)); - TF_RETURN_IF_ERROR(recv_future.Await()); - TF_RETURN_IF_ERROR(nvshmem_comm->Quiet(GpuCollectives::On(stream))); + RETURN_IF_ERROR(recv_future.Await()); + RETURN_IF_ERROR(nvshmem_comm->Quiet(GpuCollectives::On(stream))); return absl::OkStatus(); } diff --git a/third_party/xla/xla/backends/gpu/runtime/nvshmem_send_thunk.cc b/third_party/xla/xla/backends/gpu/runtime/nvshmem_send_thunk.cc index 1ee5f8d435d90b..57091ecbc97896 100644 --- a/third_party/xla/xla/backends/gpu/runtime/nvshmem_send_thunk.cc +++ b/third_party/xla/xla/backends/gpu/runtime/nvshmem_send_thunk.cc @@ -28,6 +28,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/str_format.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/backends/gpu/collectives/gpu_collectives.h" #include "xla/backends/gpu/runtime/collective_thunk.h" #include "xla/backends/gpu/runtime/nvshmem_collective_thunk.h" @@ -87,7 +88,7 @@ absl::StatusOr NvshmemSendThunk::ToProto() const { NvshmemSendThunkProto* nvshmem_proto = proto.mutable_nvshmem_send_thunk(); *nvshmem_proto->mutable_p2p_config() = P2PConfigToProto(config_); nvshmem_proto->set_hlo_name(hlo_name_); - TF_ASSIGN_OR_RETURN(*nvshmem_proto->mutable_buffer(), buffer_.ToProto()); + ASSIGN_OR_RETURN(*nvshmem_proto->mutable_buffer(), buffer_.ToProto()); return proto; } @@ -97,10 +98,10 @@ absl::StatusOr> NvshmemSendThunk::FromProto( absl::Span buffer_allocations, std::shared_ptr absl_nonnull buffer_addresses) { TF_RET_CHECK(buffer_addresses != nullptr); - TF_ASSIGN_OR_RETURN(P2PConfig p2p_config, - P2PConfigFromProto(proto.p2p_config())); + ASSIGN_OR_RETURN(P2PConfig p2p_config, + P2PConfigFromProto(proto.p2p_config())); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( CollectiveThunk::Buffer buffer, CollectiveThunk::Buffer::FromProto(proto.buffer(), buffer_allocations)); @@ -111,23 +112,22 @@ absl::StatusOr> NvshmemSendThunk::FromProto( absl::Status NvshmemSendThunk::Initialize(const InitializeParams& params) { VLOG(3) << "Initializing NvshmemSendThunk for: " << hlo_name_; - TF_RETURN_IF_ERROR(NvshmemCollectiveThunk::Initialize(params)); + RETURN_IF_ERROR(NvshmemCollectiveThunk::Initialize(params)); return absl::OkStatus(); } absl::Status NvshmemSendThunk::RunNvshmemCollective(const ExecuteParams& params, se::Stream& stream) { - TF_ASSIGN_OR_RETURN( - std::vector device_buffers, - ConvertToDeviceBuffers(params.buffer_allocations, {buffer_}, - config_.config.operand_element_type)); + ASSIGN_OR_RETURN(std::vector device_buffers, + ConvertToDeviceBuffers(params.buffer_allocations, {buffer_}, + config_.config.operand_element_type)); TF_RET_CHECK(device_buffers.size() == 1) << "Expected one buffer pair."; GlobalDeviceId global_device_id = params.collective_params->global_device_id; - TF_ASSIGN_OR_RETURN(const DeviceAssignment::LogicalID current_logical_id, - params.collective_params->device_assn->LogicalIdForDevice( - global_device_id)); + ASSIGN_OR_RETURN(const DeviceAssignment::LogicalID current_logical_id, + params.collective_params->device_assn->LogicalIdForDevice( + global_device_id)); const int64_t current_id = config_.config.group_mode == CollectiveOpGroupMode::COLLECTIVE_OP_GROUP_MODE_CROSS_REPLICA @@ -171,9 +171,9 @@ absl::Status NvshmemSendThunk::RunNvshmemCollective(const ExecuteParams& params, return absl::OkStatus(); } - TF_ASSIGN_OR_RETURN(auto* collectives, GetNvshmemCollectivesFromRegistry()); - TF_ASSIGN_OR_RETURN(std::unique_ptr nvshmem_comm, - collectives->CreateCommunicator()); + ASSIGN_OR_RETURN(auto* collectives, GetNvshmemCollectivesFromRegistry()); + ASSIGN_OR_RETURN(std::unique_ptr nvshmem_comm, + collectives->CreateCommunicator()); VLOG(1) << "Running Send operation" << " element_type=" << buffer.element_type << " destination_buffer=" << buffer.destination_buffer.opaque() @@ -183,8 +183,8 @@ absl::Status NvshmemSendThunk::RunNvshmemCollective(const ExecuteParams& params, auto send_future = nvshmem_comm->Send( buffer.destination_buffer, buffer.source_buffer, buffer.element_type, buffer.element_count, RankId(*target_id), GpuCollectives::On(stream)); - TF_RETURN_IF_ERROR(send_future.Await()); - TF_RETURN_IF_ERROR(nvshmem_comm->Quiet(GpuCollectives::On(stream))); + RETURN_IF_ERROR(send_future.Await()); + RETURN_IF_ERROR(nvshmem_comm->Quiet(GpuCollectives::On(stream))); return absl::OkStatus(); } diff --git a/third_party/xla/xla/backends/gpu/runtime/outfeed_thunk.cc b/third_party/xla/xla/backends/gpu/runtime/outfeed_thunk.cc index bbdbc92f7b0821..3eb41ffb628e08 100644 --- a/third_party/xla/xla/backends/gpu/runtime/outfeed_thunk.cc +++ b/third_party/xla/xla/backends/gpu/runtime/outfeed_thunk.cc @@ -23,6 +23,7 @@ limitations under the License. #include "absl/log/log.h" #include "absl/status/status.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/backends/gpu/runtime/thunk.h" #include "xla/backends/gpu/runtime/thunk.pb.h" #include "xla/service/buffer_assignment.h" @@ -110,9 +111,9 @@ absl::Status OutfeedThunk::ExecuteOnStream(const ExecuteParams& params) { // TODO(b/111309141): Run this on a separate stream so it doesn't block // the GPU from doing work during the transfer. - TF_RETURN_IF_ERROR(stream.Memcpy(buffer->destination()->untyped_data(), - data_address, buffer->length())); - TF_RETURN_IF_ERROR(stream.DoHostCallback([&buffer]() { buffer->Done(); })); + RETURN_IF_ERROR(stream.Memcpy(buffer->destination()->untyped_data(), + data_address, buffer->length())); + RETURN_IF_ERROR(stream.DoHostCallback([&buffer]() { buffer->Done(); })); } absl::Status block_status = stream.BlockHostUntilDone(); @@ -131,7 +132,7 @@ absl::StatusOr> OutfeedThunk::FromProto( std::vector source_slices; source_slices.reserve(proto.source_slices_size()); for (const ShapedSliceProto& proto_source_slice : proto.source_slices()) { - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( source_slices.emplace_back(), ShapedSlice::FromProto(proto_source_slice, source_allocations)); } @@ -145,9 +146,8 @@ absl::StatusOr OutfeedThunk::ToProto() const { *thunk_proto.mutable_thunk_info() = thunk_info().ToProto(); for (const ShapedSlice& shaped_slice : source_slices_) { - TF_ASSIGN_OR_RETURN( - *thunk_proto.mutable_outfeed_thunk()->add_source_slices(), - shaped_slice.ToProto()); + ASSIGN_OR_RETURN(*thunk_proto.mutable_outfeed_thunk()->add_source_slices(), + shaped_slice.ToProto()); } return thunk_proto; diff --git a/third_party/xla/xla/backends/gpu/runtime/p2p_thunk_common.cc b/third_party/xla/xla/backends/gpu/runtime/p2p_thunk_common.cc index 106985589948d0..8ea09395c9f311 100644 --- a/third_party/xla/xla/backends/gpu/runtime/p2p_thunk_common.cc +++ b/third_party/xla/xla/backends/gpu/runtime/p2p_thunk_common.cc @@ -24,6 +24,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" +#include "xla/tsl/platform/status_macros.h" #include "mlir/IR/BuiltinAttributes.h" #include "xla/backends/gpu/runtime/collective_params.h" #include "xla/backends/gpu/runtime/collective_thunk.h" @@ -53,8 +54,8 @@ absl::StatusOr>> GetSourceTargetPairs( absl::StrCat("expecting send/recv op with string attribute ", kSendRecvSourceTargetPairsAttr)); } - TF_ASSIGN_OR_RETURN(std::vector replica_groups, - ParseReplicaGroupsOnly(src_dst_string.str())); + ASSIGN_OR_RETURN(std::vector replica_groups, + ParseReplicaGroupsOnly(src_dst_string.str())); std::vector> source_target_pairs; source_target_pairs.reserve(replica_groups.size()); for (const ReplicaGroup& replica_group : replica_groups) { @@ -121,7 +122,7 @@ P2PConfig GetP2PConfigForSendRecv(const HloSendRecvInstruction* instr, absl::StatusOr GetCollectiveCurrentId( CollectiveParams* collective_params, const P2PConfig& config) { GlobalDeviceId global_device_id = collective_params->global_device_id; - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( const DeviceAssignment::LogicalID current_logical_id, collective_params->device_assn->LogicalIdForDevice(global_device_id)); const int64_t current_id = diff --git a/third_party/xla/xla/backends/gpu/runtime/ragged_all_to_all.cc b/third_party/xla/xla/backends/gpu/runtime/ragged_all_to_all.cc index 0abd9aed3ae5cd..1949b3890b23a3 100644 --- a/third_party/xla/xla/backends/gpu/runtime/ragged_all_to_all.cc +++ b/third_party/xla/xla/backends/gpu/runtime/ragged_all_to_all.cc @@ -24,6 +24,7 @@ limitations under the License. #include "absl/log/log.h" #include "absl/status/status.h" #include "absl/strings/str_cat.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/core/collectives/symmetric_memory.h" #include "xla/primitive_util.h" #include "xla/stream_executor/device_address.h" @@ -49,10 +50,8 @@ absl::Status LaunchTypedKernel( int64_t num_row_elements) { using KernelTrait = se::gpu::RaggedAllToAllKernel; - TF_ASSIGN_OR_RETURN( - auto kernel, - se::gpu::GpuKernelRegistry::GetGlobalRegistry().LoadKernel( - stream->parent())); + ASSIGN_OR_RETURN(auto kernel, se::gpu::GpuKernelRegistry::GetGlobalRegistry() + .LoadKernel(stream->parent())); return kernel.Launch(thread_dims, block_dims, stream, input_buffer, output_ptrs, input_offsets_buffer, send_sizes_buffer, @@ -72,10 +71,8 @@ absl::Status LaunchTypedKernelWithSymmetricMemory( using KernelTrait = se::gpu::RaggedAllToAllWithSymmetricMemoryKernel; - TF_ASSIGN_OR_RETURN( - auto kernel, - se::gpu::GpuKernelRegistry::GetGlobalRegistry().LoadKernel( - stream->parent())); + ASSIGN_OR_RETURN(auto kernel, se::gpu::GpuKernelRegistry::GetGlobalRegistry() + .LoadKernel(stream->parent())); return kernel.Launch(thread_dims, block_dims, stream, input_buffer, output_ptrs_symmetric_memory, output_sym_offset, diff --git a/third_party/xla/xla/backends/gpu/runtime/ragged_all_to_all_thunk.cc b/third_party/xla/xla/backends/gpu/runtime/ragged_all_to_all_thunk.cc index 736521227fbd8d..5ab7082775e96d 100644 --- a/third_party/xla/xla/backends/gpu/runtime/ragged_all_to_all_thunk.cc +++ b/third_party/xla/xla/backends/gpu/runtime/ragged_all_to_all_thunk.cc @@ -1000,8 +1000,8 @@ absl::Status RunOneShotRaggedAllToAllWithNccl( // by incoming P2P writes from peers is preserved. se::DeviceAddressBase output_temporary_symmetric_memory_addr = output_sym_mem->addr(); - TF_RETURN_IF_ERROR(stream.MemcpyD2D(&output_temporary_symmetric_memory_addr, - output_buffer, output_buffer.size())); + RETURN_IF_ERROR(stream.MemcpyD2D(&output_temporary_symmetric_memory_addr, + output_buffer, output_buffer.size())); } // 1. Barrier (Pre-Kernel) // Global synchronization before P2P writes. @@ -1009,14 +1009,14 @@ absl::Status RunOneShotRaggedAllToAllWithNccl( // are ready to receive data. This prevents the kernel from attempting to // write to a peer's memory before that peer has completed the rendezvous // setup. - TF_RETURN_IF_ERROR(xla::gpu::LaunchMultiGpuBarrierWithNccl( + RETURN_IF_ERROR(xla::gpu::LaunchMultiGpuBarrierWithNccl( &stream, num_ranks, rank, barrier_signal_symmetric_memory.get(), barrier_signal_value)); // 2. Execution of RunRaggedAllToAllKernel const int64_t num_updates_per_replica = num_total_updates / num_ranks; - TF_RETURN_IF_ERROR(RunRaggedAllToAllWithSymmetricMemoryKernel( + RETURN_IF_ERROR(RunRaggedAllToAllWithSymmetricMemoryKernel( &stream, element_type, input_buffer, output_sym_mem, output_sym_offset, buffers[2].source_buffer, buffers[3].source_buffer, buffers[4].source_buffer, num_ranks, num_updates_per_replica, @@ -1027,24 +1027,24 @@ absl::Status RunOneShotRaggedAllToAllWithNccl( // We wait for all peers to signal completion. // This guarantees that all P2P writes to our output buffer are complete and // safe to consume. - TF_RETURN_IF_ERROR(xla::gpu::LaunchMultiGpuBarrierWithNccl( + RETURN_IF_ERROR(xla::gpu::LaunchMultiGpuBarrierWithNccl( &stream, num_ranks, rank, barrier_signal_symmetric_memory.get(), barrier_signal_value)); if (!is_zero_copy) { // TODO: b/482045400 - Remove double-copy approach once testing is done. // 4. Copy from temporary symmetric memory to actual output buffer. - TF_RETURN_IF_ERROR(stream.MemcpyD2D(&output_buffer, output_sym_mem->addr(), - output_buffer.size())); + RETURN_IF_ERROR(stream.MemcpyD2D(&output_buffer, output_sym_mem->addr(), + output_buffer.size())); } if (VLOG_IS_ON(6)) { - TF_RETURN_IF_ERROR(stream.BlockHostUntilDone()); + RETURN_IF_ERROR(stream.BlockHostUntilDone()); se::StreamExecutor* stream_executor = stream.parent(); std::vector input_buffer_host; input_buffer_host.resize(output_buffer.size()); - TF_RETURN_IF_ERROR(stream_executor->SynchronousMemcpyD2H( + RETURN_IF_ERROR(stream_executor->SynchronousMemcpyD2H( input_buffer, input_buffer.size(), input_buffer_host.data())); XLA_VLOG_DEVICE(6, device_ordinal) << "Ragged-all-to-all with NCCL input buffer: " @@ -1052,7 +1052,7 @@ absl::Status RunOneShotRaggedAllToAllWithNccl( std::vector output_buffer_host; output_buffer_host.resize(output_buffer.size()); - TF_RETURN_IF_ERROR(stream_executor->SynchronousMemcpyD2H( + RETURN_IF_ERROR(stream_executor->SynchronousMemcpyD2H( output_buffer, output_buffer.size(), output_buffer_host.data())); XLA_VLOG_DEVICE(6, device_ordinal) << "Ragged-all-to-all with NCCL output before kernel: " @@ -1091,8 +1091,8 @@ absl::Status RunOneShotRaggedAllToAll( // Ensures that all peers have reached this point and their output buffers are // ready to receive data. This prevents the kernel from attempting to write // to a peer's memory before that peer has completed the rendezvous setup. - TF_RETURN_IF_ERROR(LaunchMultiGpuBarrier(&stream, rank, num_ranks, - participants, barrier_signal_value)); + RETURN_IF_ERROR(LaunchMultiGpuBarrier(&stream, rank, num_ranks, participants, + barrier_signal_value)); // 2. Execution of RunRaggedAllToAllKernel const int64_t num_updates_per_replica = num_total_updates / num_ranks; @@ -1102,7 +1102,7 @@ absl::Status RunOneShotRaggedAllToAll( output_ptrs[i] = participants[i].output_buffer.opaque(); } - TF_RETURN_IF_ERROR(RunRaggedAllToAllKernel( + RETURN_IF_ERROR(RunRaggedAllToAllKernel( &stream, element_type, input_buffer, output_ptrs, buffers[2].source_buffer, buffers[3].source_buffer, buffers[4].source_buffer, num_ranks, num_updates_per_replica, @@ -1113,8 +1113,8 @@ absl::Status RunOneShotRaggedAllToAll( // We wait for all peers to signal completion. // This guarantees that all P2P writes to our output buffer are complete and // safe to consume. - TF_RETURN_IF_ERROR(LaunchMultiGpuBarrier(&stream, rank, num_ranks, - participants, barrier_signal_value)); + RETURN_IF_ERROR(LaunchMultiGpuBarrier(&stream, rank, num_ranks, participants, + barrier_signal_value)); return absl::OkStatus(); } diff --git a/third_party/xla/xla/backends/gpu/runtime/recv_thunk.cc b/third_party/xla/xla/backends/gpu/runtime/recv_thunk.cc index 8d0e1a8eed8f55..79fce6ab18e422 100644 --- a/third_party/xla/xla/backends/gpu/runtime/recv_thunk.cc +++ b/third_party/xla/xla/backends/gpu/runtime/recv_thunk.cc @@ -67,7 +67,7 @@ RecvThunk::RecvThunk(ThunkInfo thunk_info, const P2PConfig& config, hlo_name_(instr_name) {} absl::Status RecvThunk::Initialize(const InitializeParams& params) { - TF_RETURN_IF_ERROR(CollectiveThunk::Initialize(params)); + RETURN_IF_ERROR(CollectiveThunk::Initialize(params)); return absl::OkStatus(); } @@ -141,13 +141,13 @@ absl::Status RunRecv(DeviceBufferPair& buffer, se::Stream& stream, auto future = comm.Recv(dest_addr, buffer.element_type, buffer.element_count, RankId(*source_id), GpuCollectives::On(stream)); - TF_RETURN_IF_ERROR(future.Await()); + RETURN_IF_ERROR(future.Await()); } else { // If there is no source peer, i.e. no sender to this instance, zero out // the destination buffer. XLA_VLOG_DEVICE(3, device_ordinal) << absl::StreamFormat("%s : Recv: Issuing MemZero", device_string); - TF_RETURN_IF_ERROR(stream.MemZero(&dest_addr, dest_addr.size())); + RETURN_IF_ERROR(stream.MemZero(&dest_addr, dest_addr.size())); } return absl::OkStatus(); @@ -169,9 +169,9 @@ absl::Status RecvThunk::RunCollective(const ExecuteParams& params, GlobalDeviceId global_device_id = params.collective_params->global_device_id; - TF_ASSIGN_OR_RETURN(const DeviceAssignment::LogicalID current_logical_id, - params.collective_params->device_assn->LogicalIdForDevice( - global_device_id)); + ASSIGN_OR_RETURN(const DeviceAssignment::LogicalID current_logical_id, + params.collective_params->device_assn->LogicalIdForDevice( + global_device_id)); const int64_t current_id = config_.config.group_mode == CollectiveOpGroupMode::COLLECTIVE_OP_GROUP_MODE_CROSS_REPLICA diff --git a/third_party/xla/xla/backends/gpu/runtime/runtime_intrinsics.cc b/third_party/xla/xla/backends/gpu/runtime/runtime_intrinsics.cc index 6ebbcdee3b07b3..7034018c4c88ed 100644 --- a/third_party/xla/xla/backends/gpu/runtime/runtime_intrinsics.cc +++ b/third_party/xla/xla/backends/gpu/runtime/runtime_intrinsics.cc @@ -31,6 +31,7 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/strings/substitute.h" #include "absl/synchronization/mutex.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/backends/gpu/ffi.h" #include "xla/ffi/ffi.h" #include "xla/layout_util.h" @@ -71,9 +72,8 @@ absl::Status AssertionCustomCall( int8_t expected = false; int64_t byte_size = sizeof(int8_t); CHECK_EQ(byte_size, ShapeUtil::ByteSizeOfPrimitiveType(PrimitiveType::PRED)); - TF_RETURN_IF_ERROR( - stream->Memcpy(&expected, buffer.device_memory(), byte_size)); - TF_RETURN_IF_ERROR(stream->BlockHostUntilDone()); + RETURN_IF_ERROR(stream->Memcpy(&expected, buffer.device_memory(), byte_size)); + RETURN_IF_ERROR(stream->BlockHostUntilDone()); if (!static_cast(expected)) { return Internal("%s", error_msg); } @@ -92,14 +92,14 @@ absl::StatusOr ConvertToLiteral(se::Stream* stream, Shape shape = ShapeUtil::MakeShape(arg.element_type(), arg.dimensions()); LayoutUtil::SetToDefaultLayout(&shape); - TF_ASSIGN_OR_RETURN(Literal literal, Literal::Make(shape)); + ASSIGN_OR_RETURN(Literal literal, Literal::Make(shape)); int64_t size_bytes = arg.size_bytes(); - TF_ASSIGN_OR_RETURN(std::unique_ptr host_buffer, - stream->parent()->HostMemoryAllocate(size_bytes)); - TF_RETURN_IF_ERROR( + ASSIGN_OR_RETURN(std::unique_ptr host_buffer, + stream->parent()->HostMemoryAllocate(size_bytes)); + RETURN_IF_ERROR( stream->Memcpy(literal.untyped_data(), arg.device_memory(), size_bytes)); - TF_RETURN_IF_ERROR(stream->BlockHostUntilDone()); + RETURN_IF_ERROR(stream->BlockHostUntilDone()); return literal; } @@ -130,8 +130,8 @@ absl::Status DebugPrintCustomCall(se::Stream* stream, ffi::RemainingArgs args, return absl::FailedPreconditionError(absl::Substitute( "Missing formatter for argument $0 in debug print custom call", i)); } - TF_ASSIGN_OR_RETURN(Literal literal, - ConvertToLiteral(stream, args_buffers[i])); + ASSIGN_OR_RETURN(Literal literal, + ConvertToLiteral(stream, args_buffers[i])); formatted = absl::StrReplaceAll(formatted, {{to_substitute, literal.ToString()}}); @@ -157,15 +157,15 @@ absl::Status AppendToFileCustomCall(se::Stream* stream, ffi::AnyBuffer buffer, } static absl::Mutex host_mutex{absl::kConstInit}; - TF_ASSIGN_OR_RETURN(Literal literal, ConvertToLiteral(stream, buffer)); + ASSIGN_OR_RETURN(Literal literal, ConvertToLiteral(stream, buffer)); auto* env = tsl::Env::Default(); std::string destination{dir}; - TF_RETURN_IF_ERROR(env->RecursivelyCreateDir(destination)); + RETURN_IF_ERROR(env->RecursivelyCreateDir(destination)); std::string path = tsl::io::JoinPath(destination, GetUniqueFilenameForHost()); // Supports tensors 2+GB. Should not be serialized as proto. - TF_ASSIGN_OR_RETURN(std::string serialized, literal.SerializeAsString()); + ASSIGN_OR_RETURN(std::string serialized, literal.SerializeAsString()); std::unique_ptr file; std::string filename(path); @@ -173,13 +173,13 @@ absl::Status AppendToFileCustomCall(se::Stream* stream, ffi::AnyBuffer buffer, { absl::MutexLock lock(host_mutex); - TF_RETURN_IF_ERROR(env->NewAppendableFile(filename, &file)); + RETURN_IF_ERROR(env->NewAppendableFile(filename, &file)); tsl::io::RecordWriter writer(file.get()); - TF_RETURN_IF_ERROR(writer.WriteRecord(metadata)); - TF_RETURN_IF_ERROR(writer.WriteRecord(serialized)); + RETURN_IF_ERROR(writer.WriteRecord(metadata)); + RETURN_IF_ERROR(writer.WriteRecord(serialized)); - TF_RETURN_IF_ERROR(writer.Close()); + RETURN_IF_ERROR(writer.Close()); } return absl::OkStatus(); diff --git a/third_party/xla/xla/backends/gpu/runtime/runtime_intrinsics_test.cc b/third_party/xla/xla/backends/gpu/runtime/runtime_intrinsics_test.cc index cf85588137b10b..b9dbb527dfb709 100644 --- a/third_party/xla/xla/backends/gpu/runtime/runtime_intrinsics_test.cc +++ b/third_party/xla/xla/backends/gpu/runtime/runtime_intrinsics_test.cc @@ -27,6 +27,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/str_format.h" #include "absl/strings/string_view.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/literal.h" #include "xla/literal_util.h" @@ -55,12 +56,12 @@ ReadTFRecordIOLiteral(const std::string& dir) { auto* env = tsl::Env::Default(); std::vector files; - TF_RETURN_IF_ERROR(env->GetChildren(dir, &files)); + RETURN_IF_ERROR(env->GetChildren(dir, &files)); std::vector> result; for (const std::string& path : files) { std::unique_ptr file; - TF_RETURN_IF_ERROR(tsl::Env::Default()->NewRandomAccessFile( + RETURN_IF_ERROR(tsl::Env::Default()->NewRandomAccessFile( tsl::io::JoinPath(dir, path), &file)); tsl::io::RecordReader reader(file.get()); @@ -73,11 +74,10 @@ ReadTFRecordIOLiteral(const std::string& dir) { if (absl::IsOutOfRange(status)) { break; } - TF_RETURN_IF_ERROR(status); + RETURN_IF_ERROR(status); - TF_RETURN_IF_ERROR(reader.ReadRecord(&offset, &record)); - TF_ASSIGN_OR_RETURN(Literal literal, - Literal::DeserializeFromString(record)); + RETURN_IF_ERROR(reader.ReadRecord(&offset, &record)); + ASSIGN_OR_RETURN(Literal literal, Literal::DeserializeFromString(record)); result.emplace_back(metadata, std::move(literal)); } } diff --git a/third_party/xla/xla/backends/gpu/runtime/select_k_exec_raft.cc b/third_party/xla/xla/backends/gpu/runtime/select_k_exec_raft.cc index ee104c089981c8..ca27bf084b64c0 100644 --- a/third_party/xla/xla/backends/gpu/runtime/select_k_exec_raft.cc +++ b/third_party/xla/xla/backends/gpu/runtime/select_k_exec_raft.cc @@ -27,6 +27,7 @@ limitations under the License. #include "absl/log/log.h" #include "absl/status/status.h" #include "absl/strings/str_cat.h" +#include "xla/tsl/platform/status_macros.h" #include "third_party/gpus/cuda/include/cuda_bf16.h" #include "raft/core/device_mdspan.hpp" #include "raft/core/mdspan_types.hpp" @@ -62,9 +63,9 @@ class OwningScratchAllocator { // Allocate memory and track ownership absl::StatusOr> AllocateBytes(int64_t byte_size) { - TF_ASSIGN_OR_RETURN(se::ScopedDeviceAddress buffer, - allocator_->Allocate(device_ordinal_, byte_size, - /*retry_on_failure=*/false)); + ASSIGN_OR_RETURN(se::ScopedDeviceAddress buffer, + allocator_->Allocate(device_ordinal_, byte_size, + /*retry_on_failure=*/false)); se::DeviceAddress res = *buffer; void* raw_ptr = res.opaque(); diff --git a/third_party/xla/xla/backends/gpu/runtime/select_k_thunk.cc b/third_party/xla/xla/backends/gpu/runtime/select_k_thunk.cc index a53066008232a9..b245138ad4dba3 100644 --- a/third_party/xla/xla/backends/gpu/runtime/select_k_thunk.cc +++ b/third_party/xla/xla/backends/gpu/runtime/select_k_thunk.cc @@ -28,6 +28,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/backends/gpu/runtime/select_k_exec.h" #include "xla/backends/gpu/runtime/thunk.h" #include "xla/backends/gpu/runtime/thunk.pb.h" @@ -118,7 +119,7 @@ absl::StatusOr SelectKThunk::ToProto() const { select_k_proto->set_dtype(dtype_); for (const BufferAllocation::Slice& arg : args_) { - TF_ASSIGN_OR_RETURN(*select_k_proto->add_args(), arg.ToProto()); + ASSIGN_OR_RETURN(*select_k_proto->add_args(), arg.ToProto()); } return proto; } @@ -130,7 +131,7 @@ absl::StatusOr> SelectKThunk::FromProto( arguments.reserve(proto.args().size()); for (const xla::buffer_assignment::BufferAllocationSliceProto& arg : proto.args()) { - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( BufferAllocation::Slice slice, BufferAllocation::Slice::FromProto(arg, buffer_allocations)); emitters::KernelArgument argument{Shape{}, slice}; diff --git a/third_party/xla/xla/backends/gpu/runtime/send_thunk.cc b/third_party/xla/xla/backends/gpu/runtime/send_thunk.cc index 5b1843cb6d3a1c..ce6ed330e7aad1 100644 --- a/third_party/xla/xla/backends/gpu/runtime/send_thunk.cc +++ b/third_party/xla/xla/backends/gpu/runtime/send_thunk.cc @@ -65,7 +65,7 @@ SendThunk::SendThunk(ThunkInfo thunk_info, const P2PConfig& config, hlo_name_(instr_name) {} absl::Status SendThunk::Initialize(const InitializeParams& params) { - TF_RETURN_IF_ERROR(CollectiveThunk::Initialize(params)); + RETURN_IF_ERROR(CollectiveThunk::Initialize(params)); return absl::OkStatus(); } @@ -159,7 +159,7 @@ absl::Status RunSend(DeviceBufferPair& buffer, se::Stream& stream, << absl::StreamFormat("target_id = %d, call comm.Send()", target_id); auto future = comm.Send(src_addr, buffer.element_type, buffer.element_count, RankId(target_id), GpuCollectives::On(stream)); - TF_RETURN_IF_ERROR(future.Await()); + RETURN_IF_ERROR(future.Await()); return absl::OkStatus(); } diff --git a/third_party/xla/xla/backends/gpu/runtime/sequential_thunk.cc b/third_party/xla/xla/backends/gpu/runtime/sequential_thunk.cc index 2f0f80c00c54bd..258c7402b5cf0d 100644 --- a/third_party/xla/xla/backends/gpu/runtime/sequential_thunk.cc +++ b/third_party/xla/xla/backends/gpu/runtime/sequential_thunk.cc @@ -24,6 +24,7 @@ limitations under the License. #include "absl/log/check.h" #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/backends/gpu/runtime/annotation.h" #include "xla/backends/gpu/runtime/thunk.h" #include "xla/backends/gpu/runtime/thunk.pb.h" @@ -70,8 +71,8 @@ absl::StatusOr SequentialThunk::ToProto() const { // empty. proto.mutable_sequential_thunk(); for (const auto& thunk : executor_.thunks()) { - TF_ASSIGN_OR_RETURN(*proto.mutable_sequential_thunk()->add_thunks(), - thunk->ToProto()); + ASSIGN_OR_RETURN(*proto.mutable_sequential_thunk()->add_thunks(), + thunk->ToProto()); } return proto; } @@ -81,8 +82,8 @@ absl::StatusOr> SequentialThunk::FromProto( const Deserializer& deserializer) { ThunkSequence thunk_sequence; for (const auto& sub_thunk_proto : thunk_proto.thunks()) { - TF_ASSIGN_OR_RETURN(std::unique_ptr sub_thunk, - deserializer(sub_thunk_proto)); + ASSIGN_OR_RETURN(std::unique_ptr sub_thunk, + deserializer(sub_thunk_proto)); thunk_sequence.push_back(std::move(sub_thunk)); } diff --git a/third_party/xla/xla/backends/gpu/runtime/thunk_buffer_debug_pass.cc b/third_party/xla/xla/backends/gpu/runtime/thunk_buffer_debug_pass.cc index ce1c5ab95dd3de..0631b37216b9d3 100644 --- a/third_party/xla/xla/backends/gpu/runtime/thunk_buffer_debug_pass.cc +++ b/third_party/xla/xla/backends/gpu/runtime/thunk_buffer_debug_pass.cc @@ -18,6 +18,7 @@ limitations under the License. #include "absl/base/nullability.h" #include "absl/log/log.h" #include "absl/status/statusor.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/backends/gpu/runtime/thunk.h" #include "xla/backends/gpu/runtime/thunk_buffer_debug_checksum.h" #include "xla/backends/gpu/runtime/thunk_buffer_debug_float_check.h" @@ -48,15 +49,15 @@ absl::StatusOr ThunkBufferDebugPass::Run( switch (mode_) { case Mode::kChecksum: - TF_RETURN_IF_ERROR(RunChecksumPassInternal(thunk_sequence, debug_options, - hlo_module, allocator)); + RETURN_IF_ERROR(RunChecksumPassInternal(thunk_sequence, debug_options, + hlo_module, allocator)); break; case Mode::kFloatChecker: - TF_RETURN_IF_ERROR(RunFloatCheckPassInternal( - thunk_sequence, debug_options, hlo_module, allocator)); + RETURN_IF_ERROR(RunFloatCheckPassInternal(thunk_sequence, debug_options, + hlo_module, allocator)); break; case Mode::kBufferSaver: - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( RunDebugSaverInserter(thunk_sequence, debug_options, *hlo_module)); break; } diff --git a/third_party/xla/xla/backends/gpu/runtime/thunk_pass_pipeline.cc b/third_party/xla/xla/backends/gpu/runtime/thunk_pass_pipeline.cc index 1ffc568063e728..0deae41b30f195 100644 --- a/third_party/xla/xla/backends/gpu/runtime/thunk_pass_pipeline.cc +++ b/third_party/xla/xla/backends/gpu/runtime/thunk_pass_pipeline.cc @@ -20,6 +20,7 @@ limitations under the License. #include "absl/base/nullability.h" #include "absl/log/log.h" #include "absl/status/statusor.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/backends/gpu/runtime/thunk.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/stream_executor/device_description.h" @@ -37,9 +38,9 @@ absl::StatusOr ThunkPassPipeline::Run( bool changed = false; for (const auto& pass : passes_) { VLOG(1) << "Running ThunkPass: " << pass->name(); - TF_ASSIGN_OR_RETURN(bool pass_changed, - pass->Run(thunk_sequence, debug_options, hlo_module, - device_info, allocator)); + ASSIGN_OR_RETURN(bool pass_changed, + pass->Run(thunk_sequence, debug_options, hlo_module, + device_info, allocator)); changed |= pass_changed; } return changed; diff --git a/third_party/xla/xla/backends/gpu/runtime/thunk_proto_deserialization.cc b/third_party/xla/xla/backends/gpu/runtime/thunk_proto_deserialization.cc index d91cef7c75d0e8..1df8a8bd8e2b4d 100644 --- a/third_party/xla/xla/backends/gpu/runtime/thunk_proto_deserialization.cc +++ b/third_party/xla/xla/backends/gpu/runtime/thunk_proto_deserialization.cc @@ -26,6 +26,7 @@ limitations under the License. #include "absl/strings/str_format.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "google/protobuf/descriptor.h" #include "google/protobuf/message.h" #include "xla/backends/cpu/target_machine_options.h" @@ -48,6 +49,7 @@ limitations under the License. #include "xla/backends/gpu/runtime/device_to_device_copy_thunk.h" #include "xla/backends/gpu/runtime/device_to_host_copy_thunk.h" #include "xla/backends/gpu/runtime/dynamic_memcpy_thunk.h" +#include "xla/backends/gpu/runtime/dynamic_slice_fusion_v2_thunk.h" #include "xla/backends/gpu/runtime/dynamic_slice_thunk.h" #include "xla/backends/gpu/runtime/fft_thunk.h" #include "xla/backends/gpu/runtime/gemm_thunk.h" @@ -61,7 +63,6 @@ limitations under the License. #include "xla/backends/gpu/runtime/memset_thunk.h" #include "xla/backends/gpu/runtime/norm_thunk.h" #include "xla/backends/gpu/runtime/nvshmem_all_reduce_thunk.h" -#include "xla/backends/gpu/runtime/nvshmem_collective_permute_thunk.h" #include "xla/backends/gpu/runtime/nvshmem_collective_thunk.h" #include "xla/backends/gpu/runtime/nvshmem_recv_thunk.h" #include "xla/backends/gpu/runtime/nvshmem_send_thunk.h" @@ -119,8 +120,8 @@ absl::StatusOr> DeserializeThunkProtoImpl( std::shared_ptr nvshmem_buffer_addresses, const std::optional& cpu_target_machine_options) { - TF_ASSIGN_OR_RETURN(Thunk::ThunkInfo thunk_info, - Thunk::ThunkInfo::FromProto(thunk_proto.thunk_info())); + ASSIGN_OR_RETURN(Thunk::ThunkInfo thunk_info, + Thunk::ThunkInfo::FromProto(thunk_proto.thunk_info())); auto deserializer = [&](const ThunkProto& thunk_proto) { return DeserializeThunkProtoImpl( thunk_proto, buffer_allocations, hlo_module, platform_name, @@ -225,6 +226,21 @@ absl::StatusOr> DeserializeThunkProtoImpl( thunk_proto.dynamic_slice_thunk(), buffer_allocations, deserializer); } + case ThunkProto::kDynamicSliceFusionThunk: { + auto deserializer = + [&](const ThunkProto& thunk_proto, + absl::Span custom_allocations) { + return DeserializeThunkProtoImpl( + thunk_proto, custom_allocations, hlo_module, platform_name, + host_executable_async_events_map, + host_send_recv_async_events_map, async_execution_map, + gpu_compute_capability, symbol_resolver, + nvshmem_buffer_addresses, cpu_target_machine_options); + }; + return DynamicSliceFusionV2Thunk::FromProto( + std::move(thunk_info), thunk_proto.dynamic_slice_fusion_thunk(), + buffer_allocations, deserializer); + } case ThunkProto::kCustomCallThunk: { const auto& cc_proto = thunk_proto.custom_call_thunk(); if (cc_proto.api_version() != @@ -373,7 +389,7 @@ absl::StatusOr DeserializeThunkSequenceProto( std::make_shared(); ThunkSequence sequence; for (const ThunkProto& thunk_proto : thunk_sequence_proto.thunks()) { - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( std::unique_ptr thunk, DeserializeThunkProtoImpl( thunk_proto, buffer_allocations, hlo_module, platform_name, diff --git a/third_party/xla/xla/backends/gpu/runtime/thunk_proto_deserialization_test.cc b/third_party/xla/xla/backends/gpu/runtime/thunk_proto_deserialization_test.cc index 39bf9dd3c69055..fddb670b029190 100644 --- a/third_party/xla/xla/backends/gpu/runtime/thunk_proto_deserialization_test.cc +++ b/third_party/xla/xla/backends/gpu/runtime/thunk_proto_deserialization_test.cc @@ -29,6 +29,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/backends/gpu/runtime/async_thunk.h" #include "xla/backends/gpu/runtime/collective_kernel_thunk.h" #include "xla/backends/gpu/runtime/conditional_thunk.h" @@ -47,7 +48,6 @@ limitations under the License. #include "xla/backends/gpu/runtime/thunk.pb.h" #include "xla/backends/gpu/runtime/while_thunk.h" #include "xla/ffi/ffi.h" -#include "xla/ffi/ffi_api.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" @@ -74,11 +74,10 @@ absl::StatusOr> DeserializeThunkProto( symbol_resolver = std::nullopt) { ThunkSequenceProto thunk_sequence_proto; *thunk_sequence_proto.add_thunks() = thunk_proto; - TF_ASSIGN_OR_RETURN( - ThunkSequence sequence, - DeserializeThunkSequenceProto(thunk_sequence_proto, buffer_allocations, - hlo_module, platform_name, - gpu_compute_capability, symbol_resolver)); + ASSIGN_OR_RETURN(ThunkSequence sequence, + DeserializeThunkSequenceProto( + thunk_sequence_proto, buffer_allocations, hlo_module, + platform_name, gpu_compute_capability, symbol_resolver)); return std::move(sequence.front()); } @@ -311,6 +310,98 @@ TEST(ThunkProtoDeserializationTest, DeviceToDeviceCopyThunk) { EXPECT_THAT(round_trip_proto, EqualsProto(proto)); } +TEST(ThunkProtoDeserializationTest, DynamicSliceFusionThunk) { + constexpr int64_t kBufferSize = 1024; + ThunkProto proto = ParseTextProtoOrDie( + R"pb( + thunk_info { profile_annotation: "profile_annotation" } + dynamic_slice_fusion_thunk { + parameters { + parameter_number: 0 + parameter_shape { + dimensions: 256 + element_type: F32 + is_dynamic_dimension: false + layout { minor_to_major: 0 tail_padding_alignment_in_elements: 1 } + } + slice_shape { + dimensions: 256 + element_type: F32 + is_dynamic_dimension: false + layout { minor_to_major: 0 tail_padding_alignment_in_elements: 1 } + } + } + results { + result_number: 0 + result_shape { + dimensions: 256 + element_type: F32 + is_dynamic_dimension: false + layout { minor_to_major: 0 tail_padding_alignment_in_elements: 1 } + } + update_shape { + dimensions: 256 + element_type: F32 + is_dynamic_dimension: false + layout { minor_to_major: 0 tail_padding_alignment_in_elements: 1 } + } + } + parameter_buffers { + buffer_allocation_index: 0 + size: 1024 + element_type: F32 + } + result_buffers { + buffer_allocation_index: 1 + size: 1024 + element_type: F32 + } + slice_allocations { index: 0 size: 1024 color: 0 } + embedded_thunks { + thunks { + thunk_info { profile_annotation: "embedded" } + memzero_thunk { + dest_buffer { + slice { + buffer_allocation_index: 0 + size: 1024 + element_type: F32 + } + shape { + dimensions: 256 + element_type: F32 + is_dynamic_dimension: false + layout { + minor_to_major: 0 + tail_padding_alignment_in_elements: 1 + } + } + } + } + } + } + } + )pb"); + + std::vector buffer_allocations = { + BufferAllocation(/*index=*/0, /*size=*/kBufferSize, /*color=*/0), + BufferAllocation(/*index=*/1, /*size=*/kBufferSize, /*color=*/0)}; + + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr thunk, + DeserializeThunkProto(proto, buffer_allocations, /*hlo_module=*/nullptr, + kTestPlatformName, se::GpuComputeCapability())); + + EXPECT_EQ(thunk->kind(), Kind::kDynamicSliceFusion); + TF_ASSERT_OK_AND_ASSIGN(ThunkProto round_trip_proto, thunk->ToProto()); + EXPECT_TRUE(round_trip_proto.has_dynamic_slice_fusion_thunk()); + EXPECT_EQ(round_trip_proto.dynamic_slice_fusion_thunk() + .embedded_thunks() + .thunks() + .size(), + 1); +} + TEST(ThunkProtoDeserializationTest, WhileThunk) { ThunkProto proto = ParseTextProtoOrDie( R"pb( diff --git a/third_party/xla/xla/backends/gpu/runtime/topk.cc b/third_party/xla/xla/backends/gpu/runtime/topk.cc index 8572f117020319..c51380fd724b71 100644 --- a/third_party/xla/xla/backends/gpu/runtime/topk.cc +++ b/third_party/xla/xla/backends/gpu/runtime/topk.cc @@ -29,6 +29,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/backends/gpu/codegen/kernels/custom_kernel.h" #include "xla/stream_executor/gpu/gpu_kernel_registry.h" #include "xla/stream_executor/gpu/topk_kernel.h" @@ -142,9 +143,9 @@ absl::StatusOr GetTypedTopK(std::string name, size_t num_elements, "TopkSpecializer."); } - TF_ASSIGN_OR_RETURN(se::Platform * platform, - se::PlatformManager::PlatformWithName(platform_name)); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN(se::Platform * platform, + se::PlatformManager::PlatformWithName(platform_name)); + ASSIGN_OR_RETURN( se::KernelLoaderSpec spec, GetTopKKernelForKAndPlatformAndN(k, platform->id(), num_elements)); diff --git a/third_party/xla/xla/backends/gpu/runtime/traced_command.cc b/third_party/xla/xla/backends/gpu/runtime/traced_command.cc index b2caa6bc73159e..e01cd2d6379624 100644 --- a/third_party/xla/xla/backends/gpu/runtime/traced_command.cc +++ b/third_party/xla/xla/backends/gpu/runtime/traced_command.cc @@ -22,6 +22,7 @@ limitations under the License. #include "absl/log/log.h" #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/backends/gpu/runtime/command.h" #include "xla/backends/gpu/runtime/command_state.h" #include "xla/backends/gpu/runtime/traced_command_buffer.h" @@ -69,7 +70,7 @@ TracedCommand::RecordTracedCommand( debug_options.xla_cmd_buffer_trace_cache_size()); }); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( auto nested_cmd, traced_cmd->GetOrTraceCommandBuffer( execute_params.buffer_allocations, execute_params.stream->parent(), @@ -83,7 +84,7 @@ TracedCommand::RecordTracedCommand( } if (auto* update = std::get_if(&record_action)) { - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( command_buffer->UpdateChildCommand(update->command, *nested_cmd)); return update->command; } diff --git a/third_party/xla/xla/backends/gpu/runtime/traced_command_buffer.cc b/third_party/xla/xla/backends/gpu/runtime/traced_command_buffer.cc index b1dc4495ab2d29..bb680df9fa3750 100644 --- a/third_party/xla/xla/backends/gpu/runtime/traced_command_buffer.cc +++ b/third_party/xla/xla/backends/gpu/runtime/traced_command_buffer.cc @@ -28,6 +28,7 @@ limitations under the License. #include "absl/log/log.h" #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/backends/gpu/runtime/command.h" #include "xla/service/buffer_assignment.h" #include "xla/service/gpu/buffer_allocations.h" @@ -94,12 +95,12 @@ absl::StatusOr TracedCommandBuffer::GetOrTraceCommandBuffer( // Create a new entry by calling a user-provided tracing function, move it // to front and return a pointer to cached command buffer. if (entries_[i].command_buffer == nullptr) { - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( entries_[i].command_buffer, se::TraceCommandBufferFactory::Create(executor, stream, trace)); entries_[i].recorded_allocs.assign(allocs.begin(), allocs.end()); if (priority != se::StreamPriority::Default) { - TF_RETURN_IF_ERROR(entries_[i].command_buffer->SetPriority(priority)); + RETURN_IF_ERROR(entries_[i].command_buffer->SetPriority(priority)); } VLOG(6) << "Command buffer trace cache create new item for command " << trace_cmd_->ToString(0); @@ -110,7 +111,7 @@ absl::StatusOr TracedCommandBuffer::GetOrTraceCommandBuffer( // Create a new entry by calling a user-provided tracing function, replace // the last entry with it, move it to front and return a pointer to cached // command buffer. - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( entries_[capacity_ - 1].command_buffer, se::TraceCommandBufferFactory::Create(executor, stream, trace)); entries_[capacity_ - 1].recorded_allocs.assign(allocs.begin(), allocs.end()); diff --git a/third_party/xla/xla/backends/gpu/runtime/triangular_solve_thunk.cc b/third_party/xla/xla/backends/gpu/runtime/triangular_solve_thunk.cc index 16881b25aa35ee..f2e89e89929243 100644 --- a/third_party/xla/xla/backends/gpu/runtime/triangular_solve_thunk.cc +++ b/third_party/xla/xla/backends/gpu/runtime/triangular_solve_thunk.cc @@ -24,6 +24,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/str_format.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/backends/gpu/runtime/make_batch_pointers.h" #include "xla/backends/gpu/runtime/shaped_slice.h" #include "xla/backends/gpu/runtime/thunk.h" @@ -89,12 +90,12 @@ absl::StatusOr> TriangularSolveThunk::FromProto( ThunkInfo thunk_info, const TriangularSolveThunkProto& proto, absl::Span allocations) { - TF_ASSIGN_OR_RETURN(ShapedSlice a_buffer, - ShapedSlice::FromProto(proto.a_buffer(), allocations)); - TF_ASSIGN_OR_RETURN(ShapedSlice b_buffer, - ShapedSlice::FromProto(proto.b_buffer(), allocations)); - TF_ASSIGN_OR_RETURN(ShapedSlice temp_buffer, - ShapedSlice::FromProto(proto.temp_buffer(), allocations)); + ASSIGN_OR_RETURN(ShapedSlice a_buffer, + ShapedSlice::FromProto(proto.a_buffer(), allocations)); + ASSIGN_OR_RETURN(ShapedSlice b_buffer, + ShapedSlice::FromProto(proto.b_buffer(), allocations)); + ASSIGN_OR_RETURN(ShapedSlice temp_buffer, + ShapedSlice::FromProto(proto.temp_buffer(), allocations)); if (b_buffer.shape.dimensions().size() < 2) { return absl::InvalidArgumentError("Unsupported shape for b"); @@ -131,12 +132,12 @@ absl::StatusOr TriangularSolveThunk::ToProto() const { transpose_a_); } - TF_ASSIGN_OR_RETURN(*triangular_solve_thunk_proto->mutable_a_buffer(), - a_buffer_.ToProto()); - TF_ASSIGN_OR_RETURN(*triangular_solve_thunk_proto->mutable_b_buffer(), - b_buffer_.ToProto()); - TF_ASSIGN_OR_RETURN(*triangular_solve_thunk_proto->mutable_temp_buffer(), - temp_buffer_.ToProto()); + ASSIGN_OR_RETURN(*triangular_solve_thunk_proto->mutable_a_buffer(), + a_buffer_.ToProto()); + ASSIGN_OR_RETURN(*triangular_solve_thunk_proto->mutable_b_buffer(), + b_buffer_.ToProto()); + ASSIGN_OR_RETURN(*triangular_solve_thunk_proto->mutable_temp_buffer(), + temp_buffer_.ToProto()); return proto; } @@ -213,10 +214,10 @@ absl::Status RunTriangularSolve(se::DeviceAddressBase a_data, se::DeviceAddressBase b_pointers(temp_base + batch_size, batch_pointers_bytes); - TF_RETURN_IF_ERROR(MakeBatchPointers(stream, a_data, a_batch_stride, - batch_size, a_pointers)); - TF_RETURN_IF_ERROR(MakeBatchPointers(stream, b_data, b_batch_stride, - batch_size, b_pointers)); + RETURN_IF_ERROR(MakeBatchPointers(stream, a_data, a_batch_stride, + batch_size, a_pointers)); + RETURN_IF_ERROR(MakeBatchPointers(stream, b_data, b_batch_stride, + batch_size, b_pointers)); switch (type) { case F32: { diff --git a/third_party/xla/xla/backends/gpu/runtime/while_thunk_test.cc b/third_party/xla/xla/backends/gpu/runtime/while_thunk_test.cc index 659df2f7f3645e..36f4ac9868c722 100644 --- a/third_party/xla/xla/backends/gpu/runtime/while_thunk_test.cc +++ b/third_party/xla/xla/backends/gpu/runtime/while_thunk_test.cc @@ -28,6 +28,7 @@ limitations under the License. #include "absl/status/status_matchers.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/backends/gpu/runtime/thunk.h" #include "xla/backends/gpu/runtime/thunk.pb.h" #include "xla/backends/gpu/runtime/thunk_executor.h" @@ -67,8 +68,8 @@ struct DummyThunk : public Thunk { } static absl::StatusOr> FromProto( const ThunkProto& thunk_proto, Thunk::Kind kind) { - TF_ASSIGN_OR_RETURN(Thunk::ThunkInfo thunk_info, - Thunk::ThunkInfo::FromProto(thunk_proto.thunk_info())); + ASSIGN_OR_RETURN(Thunk::ThunkInfo thunk_info, + Thunk::ThunkInfo::FromProto(thunk_proto.thunk_info())); return std::make_unique(kind, std::move(thunk_info)); } @@ -120,12 +121,12 @@ class IterationLoggerThunk : public Thunk { class KnownTripCountWhileThunkTest : public HloPjRtTestBase { protected: absl::Status ExecuteThunk(Thunk& thunk) { - TF_ASSIGN_OR_RETURN(auto name, PlatformUtil::CanonicalPlatformName("gpu")); - TF_ASSIGN_OR_RETURN(auto* platform, - se::PlatformManager::PlatformWithName(name)); - TF_ASSIGN_OR_RETURN(auto* executor, platform->ExecutorForDevice(0)); - TF_ASSIGN_OR_RETURN(std::unique_ptr stream, - executor->CreateStream()); + ASSIGN_OR_RETURN(auto name, PlatformUtil::CanonicalPlatformName("gpu")); + ASSIGN_OR_RETURN(auto* platform, + se::PlatformManager::PlatformWithName(name)); + ASSIGN_OR_RETURN(auto* executor, platform->ExecutorForDevice(0)); + ASSIGN_OR_RETURN(std::unique_ptr stream, + executor->CreateStream()); stream_executor::StreamExecutorAddressAllocator allocator(executor); Thunk::ExecuteParams params = Thunk::ExecuteParams::Create( ServiceExecutableRunOptions(), BufferAllocations({}, 0, &allocator), diff --git a/third_party/xla/xla/backends/gpu/transforms/BUILD b/third_party/xla/xla/backends/gpu/transforms/BUILD index 5cd902c774dcd1..3580946a877ca1 100644 --- a/third_party/xla/xla/backends/gpu/transforms/BUILD +++ b/third_party/xla/xla/backends/gpu/transforms/BUILD @@ -41,6 +41,7 @@ cc_library( "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", "//xla/service:hlo_creation_utils", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", @@ -82,6 +83,7 @@ cc_library( "//xla/service/gpu:matmul_utils", "//xla/stream_executor:device_description", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log:check", @@ -122,6 +124,7 @@ cc_library( "//xla/service:algorithm_util", "//xla/stream_executor:device_description", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", @@ -141,6 +144,7 @@ cc_library( "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", "@com_google_absl//absl/status:statusor", @@ -173,6 +177,7 @@ cc_library( "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/status:statusor", @@ -222,6 +227,7 @@ cc_library( "//xla/service/gpu/model:gpu_indexing_performance_model", "//xla/stream_executor:device_description", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", @@ -318,6 +324,7 @@ cc_library( "//xla/service/gpu:cublas_cudnn", "//xla/stream_executor:dnn", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/log", @@ -394,6 +401,7 @@ cc_library( "//xla/service:hlo_creation_utils", "//xla/service:shape_inference", "//xla/service/gpu:cublas_cudnn", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", @@ -438,6 +446,7 @@ cc_library( "//xla/stream_executor:dnn", "//xla/stream_executor/cuda:cuda_compute_capability", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_set", @@ -468,6 +477,7 @@ cc_library( "//xla/stream_executor/cuda:cuda_compute_capability", "//xla/tsl/platform:errors", "//xla/tsl/platform:status", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_set", @@ -553,6 +563,7 @@ cc_library( "//xla/service/gpu:reduction_utils", "//xla/stream_executor:device_description", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", @@ -660,6 +671,7 @@ cc_library( "//xla/hlo/pass:hlo_pass", "//xla/service/gpu:ir_emission_utils", "//xla/stream_executor:device_description", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_set", @@ -700,6 +712,7 @@ cc_library( "//xla/service/gpu:backend_configs_cc", "//xla/service/gpu:ir_emission_utils", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -885,6 +898,7 @@ cc_library( "//xla/service:matmul_indexing_utils", "//xla/stream_executor:engine_options", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/strings", "@cudnn_frontend_archive//:cudnn_frontend", @@ -910,6 +924,7 @@ cc_library( "//xla/stream_executor:device_description", "//xla/stream_executor/cuda:cuda_compute_capability", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "//xla/tsl/protobuf:dnn_proto_cc", "@com_google_absl//absl/algorithm:container", @@ -963,6 +978,7 @@ cc_library( "//xla/service/gpu:stream_executor_util", "//xla/stream_executor/cuda:cuda_compute_capability", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/functional:bind_front", @@ -1004,6 +1020,7 @@ cc_library( "//xla/service/gpu:backend_configs_cc", "//xla/service/gpu:cublas_cudnn", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_set", @@ -1033,6 +1050,7 @@ xla_cc_test( "//xla/stream_executor/cuda:cuda_compute_capability", "//xla/tsl/lib/core:status_test_util", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/functional:function_ref", "@com_google_absl//absl/log", @@ -1079,6 +1097,7 @@ cc_library( ]) + [ "//xla:xla_data_proto_cc", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "//xla/tsl/protobuf:dnn_proto_cc", "@com_google_absl//absl/container:flat_hash_map", @@ -1097,6 +1116,7 @@ cc_library( "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", @@ -1133,6 +1153,7 @@ cc_library( "//xla/hlo/ir:hlo", "//xla/hlo/transforms/expanders:op_expander_pass", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", ], @@ -1165,6 +1186,7 @@ cc_library( "//xla/hlo/ir:hlo", "//xla/hlo/transforms/expanders:op_expander_pass", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", @@ -1219,6 +1241,7 @@ cc_library( "//xla/service:matmul_indexing_utils", "//xla/service/gpu:matmul_utils", "//xla/stream_executor:device_description", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/log:check", @@ -1260,6 +1283,7 @@ cc_library( "//xla/hlo/utils:hlo_query", "//xla/service:collective_ops_utils", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", @@ -1304,6 +1328,7 @@ cc_library( "//xla:util", "//xla/hlo/ir:hlo", "//xla/service/gpu:backend_configs_cc", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", @@ -1422,6 +1447,7 @@ cc_library( "//xla/service/gpu:ir_emission_utils", "//xla/stream_executor:platform", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", @@ -1473,6 +1499,62 @@ xla_cc_test( ], ) +cc_library( + name = "dynamic_slice_fusion_rewriter_v2", + srcs = ["dynamic_slice_fusion_rewriter_v2.cc"], + hdrs = ["dynamic_slice_fusion_rewriter_v2.h"], + tags = ["gpu"], + deps = [ + "//xla:shape_util", + "//xla/hlo/ir:hlo", + "//xla/hlo/pass:hlo_pass", + "//xla/service/gpu:backend_configs_cc", + "//xla/service/gpu:gpu_constants", + "//xla/service/gpu:ir_emission_utils", + "//xla/stream_executor:platform_id", + "//xla/tsl/platform:status_macros", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/functional:any_invocable", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + ], +) + +xla_cc_test( + name = "dynamic_slice_fusion_rewriter_v2_test", + srcs = ["dynamic_slice_fusion_rewriter_v2_test.cc"], + tags = [ + "cuda-only", + "gpu", + ], + deps = [ + ":dynamic_slice_annotator", + ":dynamic_slice_fusion", + ":dynamic_slice_fusion_rewriter_v2", + "//xla:shape_util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/hlo/pass:hlo_pass_pipeline", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "//xla/service:platform_util", + "//xla/service/gpu:backend_configs_cc", + "//xla/stream_executor:platform_id", + "//xla/stream_executor/gpu:gpu_init_impl", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest_main", # fixdeps: keep + ], +) + cc_library( name = "explicit_collectives_group_async_wrapper", srcs = ["explicit_collectives_group_async_wrapper.cc"], @@ -1486,6 +1568,7 @@ cc_library( "//xla/hlo/pass:hlo_pass", "//xla/service/gpu:backend_configs_cc", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", @@ -1522,6 +1605,7 @@ cc_library( "//xla/hlo/pass:hlo_pass", "//xla/service/gpu:backend_configs_cc", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", @@ -1569,6 +1653,7 @@ xla_cc_test( "//xla/hlo/testlib:hlo_hardware_independent_test_base", "//xla/stream_executor:device_description", "//xla/stream_executor:device_description_proto_cc", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/status:statusor", "@com_google_googletest//:gtest_main", @@ -1586,6 +1671,7 @@ cc_library( "//xla/service/gpu:backend_configs_cc", "//xla/service/gpu:cublas_cudnn", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_set", @@ -1671,6 +1757,7 @@ cc_library( "//xla/hlo/utils:hlo_query", "//xla/service/gpu:triton_fusion_analysis", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_set", @@ -1750,6 +1837,7 @@ cc_library( "//xla/stream_executor/gpu:gpu_blas_lt", "//xla/stream_executor/rocm:rocm_compute_capability", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "//xla/tsl/protobuf:dnn_proto_cc", "@com_google_absl//absl/algorithm:container", @@ -1908,6 +1996,7 @@ cc_library( "//xla/stream_executor:dnn", "//xla/stream_executor/cuda:cuda_compute_capability", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "//xla/tsl/protobuf:dnn_proto_cc", "//xla/tsl/util:env_var", @@ -1980,6 +2069,7 @@ cc_library( "//xla/hlo/pass:hlo_pass", "//xla/service:hlo_creation_utils", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", @@ -2025,6 +2115,7 @@ cc_library( "//xla/service/gpu/model:gpu_performance_model_base", "//xla/stream_executor:device_description", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_set", @@ -2074,6 +2165,7 @@ cc_library( "//xla:xla_proto_cc", "//xla/hlo/ir:hlo", "//xla/service/gpu:backend_configs_cc", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:inlined_vector", @@ -2195,6 +2287,7 @@ cc_library( "//xla/tsl/platform:env", "//xla/tsl/platform:errors", "//xla/tsl/platform:logging", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", @@ -2281,6 +2374,7 @@ cc_library( "//xla/service/gpu:ir_emission_utils", "//xla/stream_executor:device_description", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", @@ -2341,6 +2435,7 @@ cc_library( "//xla/service:hlo_module_config", "//xla/service/gpu:backend_configs_cc", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", "@com_google_absl//absl/status:statusor", @@ -2362,6 +2457,7 @@ xla_cc_test( "//xla/service:pattern_matcher", "//xla/service/gpu:backend_configs_cc", "//xla/stream_executor:device_description", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/log", @@ -2379,6 +2475,7 @@ cc_library( "//xla:shape_util", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_set", @@ -2410,6 +2507,7 @@ cc_library( "//xla:shape_util", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_set", @@ -2445,6 +2543,7 @@ cc_library( "//xla:util", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_set", @@ -2486,6 +2585,7 @@ cc_library( "//xla/hlo/pass:hlo_pass", "//xla/service/gpu:reduction_utils", "//xla/stream_executor:device_description", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", @@ -2511,6 +2611,7 @@ xla_cc_test( "//xla/service:pattern_matcher", "//xla/stream_executor:device_description", "//xla/stream_executor:device_description_proto_cc", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/status:statusor", "@com_google_googletest//:gtest_main", @@ -2600,6 +2701,7 @@ cc_library( "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/status", @@ -2619,6 +2721,7 @@ cc_library( "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", @@ -2676,6 +2779,7 @@ cc_library( "//xla/hlo/transforms/expanders:op_expander_pass", "//xla/service:hlo_creation_utils", "//xla/service:scatter_utils", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_set", @@ -2713,6 +2817,7 @@ cc_library( "//xla/service:hlo_creation_utils", "//xla/service:scatter_simplifier", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", @@ -2759,6 +2864,7 @@ cc_library( deps = [ "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log:check", @@ -2817,6 +2923,7 @@ cc_library( "//xla/stream_executor:device_description", "//xla/tools:hlo_decomposer_lib", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", @@ -3118,6 +3225,7 @@ xla_cc_test( "//xla/hlo/ir:hlo", "//xla/hlo/testlib:hlo_hardware_independent_test_base", "//xla/stream_executor:semantic_version", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -3183,6 +3291,7 @@ cc_library( "//xla/service/gpu:gpu_fusible", "//xla/stream_executor:device_description", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", @@ -3203,6 +3312,7 @@ xla_cc_test( "//xla/service/gpu:backend_configs_cc", "//xla/stream_executor:device_description", "//xla/stream_executor:device_description_proto_cc", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", @@ -3222,6 +3332,7 @@ cc_library( "//xla/hlo/pass:hlo_pass", "//xla/service/gpu:backend_configs_cc", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", @@ -3301,6 +3412,7 @@ cc_library( "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", "//xla/service:hlo_creation_utils", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", @@ -3346,6 +3458,7 @@ cc_library( "//xla/service:hlo_module_config", "//xla/service/gpu:reduction_utils", "//xla/stream_executor:device_description", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_set", @@ -3388,6 +3501,7 @@ cc_library( "//xla/service:hlo_creation_utils", "//xla/service/gpu:cublas_cudnn", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/status:statusor", @@ -3449,6 +3563,7 @@ cc_library( "//xla/stream_executor:stream_executor_h", "//xla/tools:hlo_decomposer_lib", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", @@ -3509,6 +3624,7 @@ cc_library( "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/status:statusor", @@ -3559,6 +3675,7 @@ cc_library( "//xla/service:while_loop_unroller", "//xla/service/gpu:backend_configs_cc", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_set", @@ -3598,6 +3715,7 @@ cc_library( "//xla/hlo/pass:hlo_pass", "//xla/service:profile_guided_latency_estimator", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", @@ -3640,6 +3758,7 @@ cc_library( "//xla/hlo/pass:hlo_pass", "//xla/service:collective_ops_utils", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_set", @@ -3685,6 +3804,7 @@ cc_library( "//xla/hlo/pass:hlo_pass", "//xla/hlo/utils:hlo_query", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_set", @@ -3726,6 +3846,7 @@ cc_library( "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/status:statusor", @@ -3765,6 +3886,7 @@ cc_library( "//xla/service/gpu:matmul_utils", "//xla/stream_executor:device_description", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_set", diff --git a/third_party/xla/xla/backends/gpu/transforms/algebraic_simplifier.cc b/third_party/xla/xla/backends/gpu/transforms/algebraic_simplifier.cc index df056576098fea..eb110491d47b03 100644 --- a/third_party/xla/xla/backends/gpu/transforms/algebraic_simplifier.cc +++ b/third_party/xla/xla/backends/gpu/transforms/algebraic_simplifier.cc @@ -19,6 +19,7 @@ limitations under the License. #include "absl/log/check.h" #include "absl/status/status.h" #include "absl/strings/string_view.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/transforms/simplifiers/algebraic_simplifier.h" @@ -62,15 +63,14 @@ GpuAlgebraicSimplifierVisitor::TryToSinkBroadcastOperandsOfChainedAdds( HloInstruction* new_bcast = add->AddInstruction(HloInstruction::CreateBroadcast( broadcast_0->shape(), new_constant_add, broadcast_0->dimensions())); - TF_RETURN_IF_ERROR(ReplaceWithNewInstruction( + RETURN_IF_ERROR(ReplaceWithNewInstruction( add, HloInstruction::CreateBinary(add->shape(), HloOpcode::kAdd, new_bcast, conv))); return true; } absl::Status GpuAlgebraicSimplifierVisitor::HandleAdd(HloInstruction* add) { - TF_ASSIGN_OR_RETURN(bool replaced, - TryToSinkBroadcastOperandsOfChainedAdds(add)); + ASSIGN_OR_RETURN(bool replaced, TryToSinkBroadcastOperandsOfChainedAdds(add)); if (replaced) { return absl::OkStatus(); } diff --git a/third_party/xla/xla/backends/gpu/transforms/algorithm_checker.cc b/third_party/xla/xla/backends/gpu/transforms/algorithm_checker.cc index 17fa044210c047..f34dfc25c667cf 100644 --- a/third_party/xla/xla/backends/gpu/transforms/algorithm_checker.cc +++ b/third_party/xla/xla/backends/gpu/transforms/algorithm_checker.cc @@ -24,6 +24,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/str_format.h" #include "absl/strings/string_view.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/service/algorithm_util.h" @@ -54,7 +55,7 @@ class AlgorithmCheckerVisitor : public ConstDfsHloVisitorWithDefault { const absl::flat_hash_set& execution_threads = {}) { for (HloComputation* computation : module->MakeNonfusionComputations(execution_threads)) { - TF_RETURN_IF_ERROR(computation->Accept(this)); + RETURN_IF_ERROR(computation->Accept(this)); } return absl::OkStatus(); } @@ -100,8 +101,8 @@ class AlgorithmCheckerVisitor : public ConstDfsHloVisitorWithDefault { absl::StatusOr AlgorithmChecker::RunImpl( HloModule* module, const absl::flat_hash_set& execution_threads) { - TF_RETURN_IF_ERROR(AlgorithmCheckerVisitor(gpu_compute_capability_) - .RunOnModule(module, execution_threads)); + RETURN_IF_ERROR(AlgorithmCheckerVisitor(gpu_compute_capability_) + .RunOnModule(module, execution_threads)); // No change was made. return false; } diff --git a/third_party/xla/xla/backends/gpu/transforms/alias_passthrough_params.cc b/third_party/xla/xla/backends/gpu/transforms/alias_passthrough_params.cc index 25656aa3be620d..753dd442143d35 100644 --- a/third_party/xla/xla/backends/gpu/transforms/alias_passthrough_params.cc +++ b/third_party/xla/xla/backends/gpu/transforms/alias_passthrough_params.cc @@ -20,6 +20,7 @@ limitations under the License. #include "absl/log/log.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/shape_util.h" @@ -55,7 +56,7 @@ absl::StatusOr AliasPassthroughParams::RunImpl( continue; } - TF_RETURN_IF_ERROR(module->input_output_alias_config().SetUpAlias( + RETURN_IF_ERROR(module->input_output_alias_config().SetUpAlias( /*output_index=*/{i}, /*param_number=*/root->operand(i)->parameter_number(), /*param_index=*/{})); diff --git a/third_party/xla/xla/backends/gpu/transforms/async_wrapper.cc b/third_party/xla/xla/backends/gpu/transforms/async_wrapper.cc index 86cfb6a9e89792..e76cf457155283 100644 --- a/third_party/xla/xla/backends/gpu/transforms/async_wrapper.cc +++ b/third_party/xla/xla/backends/gpu/transforms/async_wrapper.cc @@ -23,6 +23,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" @@ -57,11 +58,10 @@ absl::StatusOr AsyncWrapper::RunImpl( "AsyncWrapper will make the following instruction async:\n", instruction->ToString())); // If the predicate matches, then wrap the instructions in async blocks. - TF_RETURN_IF_ERROR( - computation - ->CreateAsyncInstructions(instruction, - {ShapeUtil::MakeScalarShape(U32)}) - .status()); + RETURN_IF_ERROR(computation + ->CreateAsyncInstructions( + instruction, {ShapeUtil::MakeScalarShape(U32)}) + .status()); changed = true; continue; } diff --git a/third_party/xla/xla/backends/gpu/transforms/bitcast_utils.cc b/third_party/xla/xla/backends/gpu/transforms/bitcast_utils.cc index fb9e815f080371..ecb1383bbab46e 100644 --- a/third_party/xla/xla/backends/gpu/transforms/bitcast_utils.cc +++ b/third_party/xla/xla/backends/gpu/transforms/bitcast_utils.cc @@ -29,6 +29,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "llvm/ADT/SmallVector.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/layout.h" @@ -399,7 +400,7 @@ absl::StatusOr CalculateBitcastOfTransposeImpl( // - transpose does not change layout (checks); absl::StatusOr CalculateBitcastOfTranspose( const HloTransposeInstruction* transpose, const Shape& result_shape) { - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( BitcastParams result, CalculateBitcastOfTransposeImpl( transpose, result_shape, transpose->shape(), diff --git a/third_party/xla/xla/backends/gpu/transforms/block_scaling_rewriter.cc b/third_party/xla/xla/backends/gpu/transforms/block_scaling_rewriter.cc index c9145e060833bd..08ac968324133b 100644 --- a/third_party/xla/xla/backends/gpu/transforms/block_scaling_rewriter.cc +++ b/third_party/xla/xla/backends/gpu/transforms/block_scaling_rewriter.cc @@ -28,6 +28,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/builder/lib/constants.h" #include "xla/hlo/builder/xla_builder.h" #include "xla/hlo/builder/xla_computation.h" @@ -57,11 +58,10 @@ namespace { // Expand builder into a new instruction that will replace the old one. absl::StatusOr ExpandInstructionUsingBuilder( XlaBuilder& builder, HloInstruction* old_instruction) { - TF_ASSIGN_OR_RETURN(XlaComputation xla_computation, builder.Build()); - TF_ASSIGN_OR_RETURN( - HloComputation * computation, - XlaComputationToHloComputation(xla_computation, - old_instruction->parent()->parent())); + ASSIGN_OR_RETURN(XlaComputation xla_computation, builder.Build()); + ASSIGN_OR_RETURN(HloComputation * computation, + XlaComputationToHloComputation( + xla_computation, old_instruction->parent()->parent())); // Fix broadcast layouts (they cannot be inferred correctly). for (HloInstruction* instruction : computation->instructions()) { @@ -114,7 +114,7 @@ absl::StatusOr BuildQuantize(XlaBuilder& builder, // Get block size from output shape. const Shape& quant_shape = output_shape.tuple_shapes(0); const Shape& scale_shape = output_shape.tuple_shapes(1); - TF_ASSIGN_OR_RETURN(int block_size, GetBlockSize(quant_shape, scale_shape)); + ASSIGN_OR_RETURN(int block_size, GetBlockSize(quant_shape, scale_shape)); // Reshape input into blocks. std::vector new_dims(scale_shape.dimensions().begin(), @@ -128,7 +128,7 @@ absl::StatusOr BuildQuantize(XlaBuilder& builder, Shape scalar = ShapeUtil::MakeShape(input_shape.element_type(), {}); XlaOp out = Max(Abs(Parameter(&amax_builder, 0, scalar, "a")), Abs(Parameter(&amax_builder, 1, scalar, "b"))); - TF_ASSIGN_OR_RETURN(XlaComputation amax_comp, amax_builder.Build(out)); + ASSIGN_OR_RETURN(XlaComputation amax_comp, amax_builder.Build(out)); XlaOp amax = Reduce(input_blocks, ConstantLiteral(&builder, Literal(scalar)), amax_comp, {scale_shape.dimensions_size()}); @@ -136,7 +136,7 @@ absl::StatusOr BuildQuantize(XlaBuilder& builder, double emax_value = 1ll << (primitive_util::OverflowExponent(quant_shape.element_type()) - 1); Literal denominator_literal(scalar); - TF_RETURN_IF_ERROR(denominator_literal.SetFromDouble({}, emax_value)); + RETURN_IF_ERROR(denominator_literal.SetFromDouble({}, emax_value)); XlaOp denominator = ConstantLiteral(&builder, denominator_literal); XlaOp amax_norm = Div(amax, denominator); @@ -182,9 +182,9 @@ absl::StatusOr ExpandQuantizeCustomCall( // Build replacement instruction sequence. XlaBuilder builder(std::string(instruction->name())); - TF_RETURN_IF_ERROR(BuildQuantize(builder, instruction->operand(0)->shape(), - instruction->shape()) - .status()); + RETURN_IF_ERROR(BuildQuantize(builder, instruction->operand(0)->shape(), + instruction->shape()) + .status()); return ExpandInstructionUsingBuilder(builder, instruction); } @@ -195,9 +195,9 @@ absl::StatusOr BuildDequantize(XlaOp input_op, XlaOp scale_op, PrimitiveType result_type) { // Get block size from input shapes. XlaBuilder& builder = *input_op.builder(); - TF_ASSIGN_OR_RETURN(Shape input_shape, builder.GetShape(input_op)); - TF_ASSIGN_OR_RETURN(Shape scale_shape, builder.GetShape(scale_op)); - TF_ASSIGN_OR_RETURN(int block_size, GetBlockSize(input_shape, scale_shape)); + ASSIGN_OR_RETURN(Shape input_shape, builder.GetShape(input_op)); + ASSIGN_OR_RETURN(Shape scale_shape, builder.GetShape(scale_op)); + ASSIGN_OR_RETURN(int block_size, GetBlockSize(input_shape, scale_shape)); // Convert input parameters to the same type. input_op = ConvertElementType(input_op, result_type); @@ -241,11 +241,10 @@ absl::StatusOr ExpandDequantizeCustomCall( // Build replacement instruction sequence. XlaBuilder builder(std::string(instruction->name())); - TF_RETURN_IF_ERROR( - BuildDequantize(Parameter(&builder, 0, input_shape, "input"), - Parameter(&builder, 1, scale_shape, "scale"), - instruction->shape().element_type()) - .status()); + RETURN_IF_ERROR(BuildDequantize(Parameter(&builder, 0, input_shape, "input"), + Parameter(&builder, 1, scale_shape, "scale"), + instruction->shape().element_type()) + .status()); return ExpandInstructionUsingBuilder(builder, instruction); } @@ -318,8 +317,8 @@ absl::StatusOr> BuildCudnnScaledDotInput( bool pad_input) { // Get shapes from the inputs. XlaBuilder& builder = *input_op.builder(); - TF_ASSIGN_OR_RETURN(Shape input_shape, builder.GetShape(input_op)); - TF_ASSIGN_OR_RETURN(Shape scale_shape, builder.GetShape(scale_op)); + ASSIGN_OR_RETURN(Shape input_shape, builder.GetShape(input_op)); + ASSIGN_OR_RETURN(Shape scale_shape, builder.GetShape(scale_op)); int64_t rank = input_shape.dimensions().size(); TF_RET_CHECK(rank == 2 || rank == 3); @@ -379,7 +378,7 @@ absl::StatusOr> BuildCudnnScaledDotInput( // TMEM. This transpose can potentially be done in the kernel (at the cost of // using non-vectorized loads or using an extra shared memory buffer). // https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-scale-factor-a-layout-1x - TF_ASSIGN_OR_RETURN(Shape scale_valid_shape, builder.GetShape(scale_op)); + ASSIGN_OR_RETURN(Shape scale_valid_shape, builder.GetShape(scale_op)); int64_t scale_rows = scale_valid_shape.dimensions(rank - 2); int64_t scale_cols = scale_valid_shape.dimensions(rank - 1); scale_op = @@ -404,20 +403,20 @@ absl::StatusOr BuildCudnnScaledDot(XlaOp lhs_input, XlaOp rhs_input, cudnn_version >= kCudnnSupportsBlockScaledDotWithGlobalScale; // Get inputs from parameters. - TF_ASSIGN_OR_RETURN(auto lhs_ops_and_size, - BuildCudnnScaledDotInput(lhs_input, lhs_scale, block_size, - /*pad_input=*/true)); + ASSIGN_OR_RETURN(auto lhs_ops_and_size, + BuildCudnnScaledDotInput(lhs_input, lhs_scale, block_size, + /*pad_input=*/true)); auto [lhs_input_op, lhs_scale_op, lhs_size] = lhs_ops_and_size; - TF_ASSIGN_OR_RETURN(auto rhs_ops_and_size, - BuildCudnnScaledDotInput(rhs_input, rhs_scale, block_size, - /*pad_input=*/true)); + ASSIGN_OR_RETURN(auto rhs_ops_and_size, + BuildCudnnScaledDotInput(rhs_input, rhs_scale, block_size, + /*pad_input=*/true)); auto [rhs_input_op, rhs_scale_op, rhs_size] = rhs_ops_and_size; // Calculate output shape. XlaBuilder& builder = *lhs_input.builder(); - TF_ASSIGN_OR_RETURN(Shape lhs_shape, builder.GetShape(lhs_input_op)); - TF_ASSIGN_OR_RETURN(Shape rhs_shape, builder.GetShape(rhs_input_op)); + ASSIGN_OR_RETURN(Shape lhs_shape, builder.GetShape(lhs_input_op)); + ASSIGN_OR_RETURN(Shape rhs_shape, builder.GetShape(rhs_input_op)); int rank = lhs_shape.dimensions().size(); std::vector result_dims{lhs_shape.dimensions(rank - 2), rhs_shape.dimensions(rank - 2)}; @@ -465,8 +464,8 @@ absl::StatusOr BuildBlockScaledDotInput( std::optional block_size) { // Get shapes of the input and scales. XlaBuilder& builder = *input_op.builder(); - TF_ASSIGN_OR_RETURN(Shape input_shape, builder.GetShape(input_op)); - TF_ASSIGN_OR_RETURN(Shape scale_shape, builder.GetShape(scale_op)); + ASSIGN_OR_RETURN(Shape input_shape, builder.GetShape(input_op)); + ASSIGN_OR_RETURN(Shape scale_shape, builder.GetShape(scale_op)); // Make sure the input and scale shapes are compatible (scales may be padded). int64_t rank = input_shape.dimensions().size(); @@ -546,13 +545,11 @@ absl::StatusOr BuildBlockScaledDot( } // Build general dot op. - TF_ASSIGN_OR_RETURN( - lhs_op, - BuildBlockScaledDotInput(lhs_op, lhs_scale_op, result_type, block_size)); + ASSIGN_OR_RETURN(lhs_op, BuildBlockScaledDotInput(lhs_op, lhs_scale_op, + result_type, block_size)); if (rhs_scale_op.valid()) { - TF_ASSIGN_OR_RETURN( - rhs_op, BuildBlockScaledDotInput(rhs_op, rhs_scale_op, result_type, - block_size)); + ASSIGN_OR_RETURN(rhs_op, BuildBlockScaledDotInput(rhs_op, rhs_scale_op, + result_type, block_size)); } XlaOp result = DotGeneral(lhs_op, rhs_op, dnums, /*precision_config=*/nullptr, /*preferred_element_type=*/result_type); @@ -587,9 +584,9 @@ absl::StatusOr ExpandBlockScaledDotCustomCall( dnums.add_rhs_batch_dimensions(0); } - TF_ASSIGN_OR_RETURN(Shape inferred_shape, - ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, - dnums, result_type)); + ASSIGN_OR_RETURN(Shape inferred_shape, + ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dnums, + result_type)); if (inferred_shape != instruction->shape()) { return InvalidArgument("Incorrect output shape for block scaled dot op"); } @@ -618,7 +615,7 @@ absl::StatusOr ExpandBlockScaledDotCustomCall( // Build replacement instruction sequence. XlaBuilder builder(std::string(instruction->name())); auto operands = absl::MakeSpan(instruction->operands()); - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( BuildBlockScaledDot(builder, operands[0], operands[1], operands[2], operands.size() >= 4 ? operands[3] : nullptr, operands.size() == 5 ? operands[4] : nullptr, dnums, @@ -638,22 +635,22 @@ absl::StatusOr CreateScaleSwizzleComputation( XlaOp scale_op = Parameter(&builder, 1, scale->shape(), "scale"); // Build swizzle computation. - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( auto ops_and_size, BuildCudnnScaledDotInput(input_op, scale_op, /*block_size=*/std::nullopt, /*pad_input=*/false)); auto [result_input_op, result_scale_op, _] = ops_and_size; Tuple(&builder, {result_input_op, result_scale_op}); - TF_ASSIGN_OR_RETURN(XlaComputation xla_computation, builder.Build()); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN(XlaComputation xla_computation, builder.Build()); + ASSIGN_OR_RETURN( HloComputation * computation, XlaComputationToHloComputation(xla_computation, input->GetModule())); for (HloInstruction* instr : computation->instructions()) { // Replace reshapes with bitcasts (post layout assignment). if (instr->opcode() == HloOpcode::kReshape) { - TF_RETURN_IF_ERROR(computation->ReplaceInstruction( + RETURN_IF_ERROR(computation->ReplaceInstruction( instr, computation->AddInstruction(HloInstruction::CreateBitcast( instr->shape(), instr->mutable_operand(0))))); } @@ -763,9 +760,9 @@ absl::StatusOr CudnnScaledDotHelper::AddScaleSwizzle( // Add swizzling to LHS/RHS. std::vector swizzled_operands(4); for (int i = 0; i < 2; ++i) { - TF_ASSIGN_OR_RETURN(HloComputation * swizzle_computation, - CreateScaleSwizzleComputation(fusion->operand(i), - fusion->operand(i + 2))); + ASSIGN_OR_RETURN(HloComputation * swizzle_computation, + CreateScaleSwizzleComputation(fusion->operand(i), + fusion->operand(i + 2))); HloInstruction* call = parent->AddInstruction(HloInstruction::CreateCall( swizzle_computation->root_instruction()->shape(), {fusion->mutable_operand(i), fusion->mutable_operand(i + 2)}, @@ -795,7 +792,7 @@ absl::StatusOr CudnnScaledDotHelper::AddScaleSwizzle( if (need_slicing) { HloInstruction* scaled_dot = computation->parameter_instruction(0)->users()[0]; - TF_RETURN_IF_ERROR(SliceScaledDotOperands(scaled_dot)); + RETURN_IF_ERROR(SliceScaledDotOperands(scaled_dot)); } // Create new fusion with the swizzled operands. @@ -803,7 +800,7 @@ absl::StatusOr CudnnScaledDotHelper::AddScaleSwizzle( parent->AddInstruction(HloInstruction::CreateFusion( computation->root_instruction()->shape(), fusion->fusion_kind(), swizzled_operands, fusion->fused_instructions_computation())); - TF_RETURN_IF_ERROR(parent->ReplaceInstruction(fusion, new_fusion)); + RETURN_IF_ERROR(parent->ReplaceInstruction(fusion, new_fusion)); return new_fusion; } diff --git a/third_party/xla/xla/backends/gpu/transforms/composite_rewriter.cc b/third_party/xla/xla/backends/gpu/transforms/composite_rewriter.cc index 78214f08fbe35d..c9aaece161193a 100644 --- a/third_party/xla/xla/backends/gpu/transforms/composite_rewriter.cc +++ b/third_party/xla/xla/backends/gpu/transforms/composite_rewriter.cc @@ -22,6 +22,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "xla/tsl/platform/status_macros.h" #include "mlir/AsmParser/AsmParser.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/BuiltinAttributes.h" @@ -126,7 +127,7 @@ absl::StatusOr CompositeRewriter::RewriteComputation( return absl::InvalidArgumentError( "composite.attributes is not set for xla.scaled_dot"); } - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( DotDimensionNumbers dot_dimension_numbers, ParseDimensionNumbers(frontend_attrs.at("composite.attributes"))); @@ -198,8 +199,8 @@ absl::StatusOr CompositeRewriter::RewriteComputation( call->shape(), call->mutable_operand(0), call->mutable_operand(1), call->mutable_operand(2), call->mutable_operand(3), dot_dimension_numbers, precision)); - TF_RETURN_IF_ERROR(call->ReplaceAllUsesWith(scaled_dot)); - TF_RETURN_IF_ERROR(computation->RemoveInstruction(call)); + RETURN_IF_ERROR(call->ReplaceAllUsesWith(scaled_dot)); + RETURN_IF_ERROR(computation->RemoveInstruction(call)); changed = true; } return changed; @@ -209,7 +210,7 @@ absl::StatusOr CompositeRewriter::RunImpl( HloModule* module, const absl::flat_hash_set&) { bool changed = false; for (HloComputation* computation : module->computations()) { - TF_ASSIGN_OR_RETURN(bool result, RewriteComputation(computation)); + ASSIGN_OR_RETURN(bool result, RewriteComputation(computation)); changed |= result; } return changed; diff --git a/third_party/xla/xla/backends/gpu/transforms/conv_fusion_rewriter.cc b/third_party/xla/xla/backends/gpu/transforms/conv_fusion_rewriter.cc index 5017cafedc728e..0a557399e981e7 100644 --- a/third_party/xla/xla/backends/gpu/transforms/conv_fusion_rewriter.cc +++ b/third_party/xla/xla/backends/gpu/transforms/conv_fusion_rewriter.cc @@ -430,17 +430,17 @@ absl::StatusOr RunOnInstruction(HloInstruction* conv) { FusionBackendConfig* fusion_config = gpu_backend_config.mutable_fusion_backend_config(); fusion_config->set_kind(kCuDnnFusionKind); - TF_RETURN_IF_ERROR(conv_fusion->set_backend_config(gpu_backend_config)); + RETURN_IF_ERROR(conv_fusion->set_backend_config(gpu_backend_config)); VLOG(1) << "Replacing convolution " << conv->ToString() << " with " << conv_fusion->ToString(); if (fusion_outputs.size() == 1) { - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( conv->parent()->ReplaceInstruction(fusion_outputs[0], conv_fusion)); } else { for (int idx = 0; idx < fusion_outputs.size(); ++idx) { HloInstruction* output = fusion_outputs[idx]; - TF_RETURN_IF_ERROR(conv->parent()->ReplaceInstruction( + RETURN_IF_ERROR(conv->parent()->ReplaceInstruction( output, conv->parent()->AddInstruction( HloInstruction::CreateGetTupleElement(conv_fusion, idx)))); diff --git a/third_party/xla/xla/backends/gpu/transforms/conv_kind_assignment.cc b/third_party/xla/xla/backends/gpu/transforms/conv_kind_assignment.cc index 28a63fbf9ece56..f45c40e00dd2a3 100644 --- a/third_party/xla/xla/backends/gpu/transforms/conv_kind_assignment.cc +++ b/third_party/xla/xla/backends/gpu/transforms/conv_kind_assignment.cc @@ -29,6 +29,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -100,9 +101,9 @@ absl::Status CheckTypes(HloInstruction* conv, const se::GpuComputeCapability cc, return absl::OkStatus(); }; - TF_RETURN_IF_ERROR(valid_shape(conv->shape())); - TF_RETURN_IF_ERROR(valid_shape(conv->operand(0)->shape())); - TF_RETURN_IF_ERROR(valid_shape(conv->operand(1)->shape())); + RETURN_IF_ERROR(valid_shape(conv->shape())); + RETURN_IF_ERROR(valid_shape(conv->operand(0)->shape())); + RETURN_IF_ERROR(valid_shape(conv->operand(1)->shape())); return absl::OkStatus(); } @@ -500,7 +501,7 @@ HloInstruction* ConvertBatchGroupedToFeatureGroupedConvolution( absl::StatusOr AssignConvKind( HloInstruction* conv, const se::GpuComputeCapability& cc, const se::dnn::VersionInfo& dnn_version) { - TF_RETURN_IF_ERROR(CheckTypes(conv, cc, dnn_version)); + RETURN_IF_ERROR(CheckTypes(conv, cc, dnn_version)); if (ConvolutionMatch m = MatchBackwardInput(conv)) { conv = CreateGpuConv(CONVOLUTION_KIND_DGRAD, conv, conv->mutable_operand(0), *m); @@ -523,15 +524,15 @@ absl::StatusOr RunOnInstruction(HloInstruction* conv, const se::GpuComputeCapability& cc, const se::dnn::VersionInfo& dnn_version) { CHECK_EQ(conv->opcode(), HloOpcode::kConvolution); - TF_ASSIGN_OR_RETURN(HloInstruction * conv_with_kind, - AssignConvKind(conv, cc, dnn_version)); + ASSIGN_OR_RETURN(HloInstruction * conv_with_kind, + AssignConvKind(conv, cc, dnn_version)); if (conv == nullptr) { return false; } VLOG(1) << "Replacing convolution " << conv->ToString() << " with " << conv_with_kind->ToString(); - TF_RETURN_IF_ERROR(conv->parent()->ReplaceInstruction(conv, conv_with_kind)); + RETURN_IF_ERROR(conv->parent()->ReplaceInstruction(conv, conv_with_kind)); return true; } @@ -551,7 +552,7 @@ absl::StatusOr RunOnComputation(HloComputation* computation, bool changed = false; for (HloInstruction* conv : convs) { - TF_ASSIGN_OR_RETURN(bool result, RunOnInstruction(conv, cc, dnn_version)); + ASSIGN_OR_RETURN(bool result, RunOnInstruction(conv, cc, dnn_version)); changed |= result; } return changed; @@ -566,7 +567,7 @@ absl::StatusOr ConvKindAssignment::RunImpl( bool changed = false; for (HloComputation* computation : module->MakeNonfusionComputations(execution_threads)) { - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( bool result, RunOnComputation(computation, compute_capability_, dnn_version_)); changed |= result; diff --git a/third_party/xla/xla/backends/gpu/transforms/conv_padding_legalization.cc b/third_party/xla/xla/backends/gpu/transforms/conv_padding_legalization.cc index 70b76d3a3a2ba4..650a543a96035b 100644 --- a/third_party/xla/xla/backends/gpu/transforms/conv_padding_legalization.cc +++ b/third_party/xla/xla/backends/gpu/transforms/conv_padding_legalization.cc @@ -26,6 +26,7 @@ limitations under the License. #include "absl/log/log.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instructions.h" @@ -434,7 +435,7 @@ absl::StatusOr ConvPaddingLegalization::RunOnComputation( } } for (HloCustomCallInstruction* instruction : convs) { - TF_ASSIGN_OR_RETURN(auto kind, GetCudnnConvKind(instruction)); + ASSIGN_OR_RETURN(auto kind, GetCudnnConvKind(instruction)); changed |= [&] { switch (kind) { case CudnnConvKind::kForward: @@ -457,7 +458,7 @@ absl::StatusOr ConvPaddingLegalization::RunImpl( bool changed = false; for (HloComputation* computation : module->MakeNonfusionComputations(execution_threads)) { - TF_ASSIGN_OR_RETURN(bool result, RunOnComputation(computation)); + ASSIGN_OR_RETURN(bool result, RunOnComputation(computation)); changed |= result; } return changed; diff --git a/third_party/xla/xla/backends/gpu/transforms/conv_rewriter.cc b/third_party/xla/xla/backends/gpu/transforms/conv_rewriter.cc index 8d830ec0fe9119..1912ae25d50ce3 100644 --- a/third_party/xla/xla/backends/gpu/transforms/conv_rewriter.cc +++ b/third_party/xla/xla/backends/gpu/transforms/conv_rewriter.cc @@ -31,6 +31,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" @@ -102,9 +103,9 @@ absl::Status CheckTypes(HloInstruction* conv, const se::GpuComputeCapability cc, return absl::OkStatus(); }; - TF_RETURN_IF_ERROR(valid_shape(conv->shape())); - TF_RETURN_IF_ERROR(valid_shape(conv->operand(0)->shape())); - TF_RETURN_IF_ERROR(valid_shape(conv->operand(1)->shape())); + RETURN_IF_ERROR(valid_shape(conv->shape())); + RETURN_IF_ERROR(valid_shape(conv->operand(0)->shape())); + RETURN_IF_ERROR(valid_shape(conv->operand(1)->shape())); return absl::OkStatus(); } @@ -815,7 +816,7 @@ CudnnConvBackendConfig GetDefaultBackendConfig() { static absl::StatusOr CreateCustomCallHelper( HloInstruction* conv, const se::GpuComputeCapability& cc, const se::dnn::VersionInfo& dnn_version) { - TF_RETURN_IF_ERROR(CheckTypes(conv, cc, dnn_version)); + RETURN_IF_ERROR(CheckTypes(conv, cc, dnn_version)); if (ConvolutionMatch m = MatchBackwardInput(conv)) { auto& [window, dnums, rhs] = *m; return CreateGpuConv(kCudnnConvBackwardInputCallTarget, conv->shape(), @@ -854,8 +855,8 @@ absl::StatusOr RunOnInstruction(HloInstruction* conv, const se::dnn::VersionInfo& dnn_version) { CHECK_EQ(conv->opcode(), HloOpcode::kConvolution); - TF_ASSIGN_OR_RETURN(HloInstruction * custom_call, - CreateCustomCallHelper(conv, cc, dnn_version)); + ASSIGN_OR_RETURN(HloInstruction * custom_call, + CreateCustomCallHelper(conv, cc, dnn_version)); if (custom_call == nullptr) { return false; } @@ -863,14 +864,14 @@ absl::StatusOr RunOnInstruction(HloInstruction* conv, GpuBackendConfig gpu_backend_config; *gpu_backend_config.mutable_cudnn_conv_backend_config() = GetDefaultBackendConfig(); - TF_RETURN_IF_ERROR(custom_call->set_backend_config(gpu_backend_config)); + RETURN_IF_ERROR(custom_call->set_backend_config(gpu_backend_config)); VLOG(1) << "Replacing convolution " << conv->ToString() << " with " << custom_call->ToString(); // The CustomCall returns a tuple (conv_result, scratch_memory). Extract // out the conv result and replace `conv` with it. - TF_RETURN_IF_ERROR(conv->parent()->ReplaceWithNewInstruction( + RETURN_IF_ERROR(conv->parent()->ReplaceWithNewInstruction( conv, HloInstruction::CreateGetTupleElement(conv->shape(), custom_call, 0))); return true; @@ -891,7 +892,7 @@ absl::StatusOr RunOnComputation(HloComputation* computation, bool changed = false; for (HloInstruction* conv : convs) { - TF_ASSIGN_OR_RETURN(bool result, RunOnInstruction(conv, cc, dnn_version)); + ASSIGN_OR_RETURN(bool result, RunOnInstruction(conv, cc, dnn_version)); changed |= result; } return changed; @@ -905,7 +906,7 @@ absl::StatusOr ConvRewriter::RunImpl( bool changed = false; for (HloComputation* computation : module->MakeNonfusionComputations(execution_threads)) { - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( bool result, RunOnComputation(computation, compute_capability_, dnn_version_)); changed |= result; diff --git a/third_party/xla/xla/backends/gpu/transforms/copy_fusion.cc b/third_party/xla/xla/backends/gpu/transforms/copy_fusion.cc index efcdd4ad3e7465..42abd472617107 100644 --- a/third_party/xla/xla/backends/gpu/transforms/copy_fusion.cc +++ b/third_party/xla/xla/backends/gpu/transforms/copy_fusion.cc @@ -26,6 +26,7 @@ limitations under the License. #include "absl/log/log.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/codegen/ir_emission_utils.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" @@ -196,17 +197,17 @@ absl::StatusOr CopyFusion::DoCopyFusion( } if (HloPredicateIsOp(root)) { - TF_RETURN_IF_ERROR(fused_computation->RemoveInstruction(root)); + RETURN_IF_ERROR(fused_computation->RemoveInstruction(root)); } else { auto get_tuple_element_root = computation->AddInstruction( HloInstruction::CreateGetTupleElement(hlo, 0)); - TF_RETURN_IF_ERROR(hlo->ReplaceAllUsesWithDifferentShape( + RETURN_IF_ERROR(hlo->ReplaceAllUsesWithDifferentShape( other_users, get_tuple_element_root)); } for (int64_t i = 0; i < copies.size(); ++i) { auto get_tuple_element = computation->AddInstruction( HloInstruction::CreateGetTupleElement(hlo, num_outputs + i)); - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( computation->ReplaceInstruction(copies[i], get_tuple_element)); } } diff --git a/third_party/xla/xla/backends/gpu/transforms/cublas_pad_for_gemms.cc b/third_party/xla/xla/backends/gpu/transforms/cublas_pad_for_gemms.cc index dd650e9ce65dba..4e52b613620cdf 100644 --- a/third_party/xla/xla/backends/gpu/transforms/cublas_pad_for_gemms.cc +++ b/third_party/xla/xla/backends/gpu/transforms/cublas_pad_for_gemms.cc @@ -24,6 +24,7 @@ limitations under the License. #include "absl/log/log.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/backends/gpu/codegen/triton/support_legacy.h" #include "xla/backends/gpu/transforms/gemm_fusion.h" #include "xla/hlo/ir/hlo_casting_utils.h" @@ -172,7 +173,7 @@ absl::StatusOr> GetRelevantDots( std::vector gemms; for (HloInstruction* instr : comp->instructions()) { - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( bool is_matmul, IsCublasSupportedMatMul(*instr, /*allow_matrix_vector_multiplication=*/false)); @@ -195,12 +196,11 @@ absl::StatusOr CublasPadForGemms::RunImpl( bool changed = false; for (HloComputation* comp : module->MakeNonfusionComputations(execution_threads)) { - TF_ASSIGN_OR_RETURN( - std::vector dots, - GetRelevantDots(gpu_compute_capability_, comp, datatype_)); + ASSIGN_OR_RETURN(std::vector dots, + GetRelevantDots(gpu_compute_capability_, comp, datatype_)); for (HloDotInstruction* dot : dots) { - TF_ASSIGN_OR_RETURN(bool result, - PadForGemm(dot, datatype_, pad_to_multiple_of_)); + ASSIGN_OR_RETURN(bool result, + PadForGemm(dot, datatype_, pad_to_multiple_of_)); changed |= result; } } diff --git a/third_party/xla/xla/backends/gpu/transforms/cudnn_custom_call_compiler.cc b/third_party/xla/xla/backends/gpu/transforms/cudnn_custom_call_compiler.cc index 6fafca5f7ae0d7..3e3b9975055958 100644 --- a/third_party/xla/xla/backends/gpu/transforms/cudnn_custom_call_compiler.cc +++ b/third_party/xla/xla/backends/gpu/transforms/cudnn_custom_call_compiler.cc @@ -27,6 +27,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/backends/gpu/transforms/block_scaling_rewriter.h" #include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" #include "xla/hlo/ir/hlo_casting_utils.h" @@ -82,8 +83,8 @@ using se::dnn::MatmulTensorDescriptor; using se::dnn::TensorDescriptor; absl::StatusOr TensorDescriptorFor(const Shape &shape) { - TF_ASSIGN_OR_RETURN(const DataType type, - GetDNNDataTypeFromPrimitiveType(shape.element_type())); + ASSIGN_OR_RETURN(const DataType type, + GetDNNDataTypeFromPrimitiveType(shape.element_type())); return TensorDescriptor::For(type, shape.dimensions(), shape.layout().minor_to_major()); } @@ -92,8 +93,8 @@ enum Side { LHS, RHS }; absl::StatusOr MatmulTensorDescriptorFor( const Shape &shape, const DotDimensionNumbers &dnums, const Side side) { - TF_ASSIGN_OR_RETURN(const DataType type, - GetDNNDataTypeFromPrimitiveType(shape.element_type())); + ASSIGN_OR_RETURN(const DataType type, + GetDNNDataTypeFromPrimitiveType(shape.element_type())); return MatmulTensorDescriptor::For( type, shape.dimensions(), shape.layout().minor_to_major(), (side == LHS) ? dnums.lhs_batch_dimensions() @@ -104,27 +105,26 @@ absl::StatusOr MatmulTensorDescriptorFor( absl::StatusOr BuildGraphForCustomCallToForwardFMHA( se::dnn::DnnSupport &dnn_support, HloCustomCallInstruction *custom_call) { - TF_ASSIGN_OR_RETURN(const xla::gpu::CudnnfMHAKind kind, - xla::gpu::GetCudnnfMHAKind(custom_call)); - TF_ASSIGN_OR_RETURN( - const auto gpu_config, - custom_call->backend_config()); + ASSIGN_OR_RETURN(const xla::gpu::CudnnfMHAKind kind, + xla::gpu::GetCudnnfMHAKind(custom_call)); + ASSIGN_OR_RETURN(const auto gpu_config, + custom_call->backend_config()); const xla::gpu::CudnnfMHABackendConfig &config = gpu_config.cudnn_fmha_backend_config(); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( MatmulTensorDescriptor q, MatmulTensorDescriptorFor(custom_call->operand(0)->shape(), config.bmm1_dot_dimension_numbers(), LHS)); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( MatmulTensorDescriptor k, MatmulTensorDescriptorFor(custom_call->operand(1)->shape(), config.bmm1_dot_dimension_numbers(), RHS)); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( MatmulTensorDescriptor v, MatmulTensorDescriptorFor(custom_call->operand(2)->shape(), config.bmm2_dot_dimension_numbers(), RHS)); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( TensorDescriptor output, TensorDescriptorFor(ShapeUtil::GetSubshape(custom_call->shape(), {0}))); @@ -132,8 +132,8 @@ absl::StatusOr BuildGraphForCustomCallToForwardFMHA( const bool has_activation = xla::ShapeUtil::TupleElementCount(custom_call->shape()) == 3; if (has_activation) { - TF_ASSIGN_OR_RETURN(activation, TensorDescriptorFor(ShapeUtil::GetSubshape( - custom_call->shape(), {1}))); + ASSIGN_OR_RETURN(activation, TensorDescriptorFor(ShapeUtil::GetSubshape( + custom_call->shape(), {1}))); } int input_index = 3; @@ -141,16 +141,16 @@ absl::StatusOr BuildGraphForCustomCallToForwardFMHA( if (kind == CudnnfMHAKind::kScaleBiasSoftmax || kind == CudnnfMHAKind::kScaleBiasSoftmaxDropout) { const HloInstruction &bias_hlo = *custom_call->operand(3); - TF_ASSIGN_OR_RETURN(bias, TensorDescriptorFor(bias_hlo.shape())); + ASSIGN_OR_RETURN(bias, TensorDescriptorFor(bias_hlo.shape())); input_index++; } const double dropout_rate = config.dropout_rate(); - TF_ASSIGN_OR_RETURN(CudnnfMHAMaskKind cudnn_mask_type, - AsCudnnFmhaMaskKind(config.mask_type())); - TF_ASSIGN_OR_RETURN(se::dnn::FMHAMaskKind dnn_mask_type, - GetDNNFmhaMaskKindFromCudnnFmhaMaskKind(cudnn_mask_type)); + ASSIGN_OR_RETURN(CudnnfMHAMaskKind cudnn_mask_type, + AsCudnnFmhaMaskKind(config.mask_type())); + ASSIGN_OR_RETURN(se::dnn::FMHAMaskKind dnn_mask_type, + GetDNNFmhaMaskKindFromCudnnFmhaMaskKind(cudnn_mask_type)); const int sliding_window_length = config.sliding_window_length(); const int max_seg_per_batch = config.max_seg_per_batch(); @@ -171,10 +171,10 @@ absl::StatusOr BuildGraphForCustomCallToForwardFMHA( std::optional page_table_k; std::optional page_table_v; if (is_paged_attention) { - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( page_table_k, TensorDescriptorFor(custom_call->operand(input_index++)->shape())); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( page_table_v, TensorDescriptorFor(custom_call->operand(input_index++)->shape())); } @@ -191,7 +191,7 @@ absl::StatusOr BuildGraphForCustomCallToForwardFMHA( } auto score_mod_ptr = score_mod.has_value() ? &score_mod.value() : nullptr; TF_RET_CHECK(input_index == custom_call->operand_count()); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( se::gpu::CudnnGraph graph, se::gpu::GetCudnnFlashAttentionOperationGraph( dnn_support, q, k, v, output, bias, activation, page_table_k, @@ -203,31 +203,30 @@ absl::StatusOr BuildGraphForCustomCallToForwardFMHA( absl::StatusOr BuildGraphForCustomCallToForwardFMHAF8( se::dnn::DnnSupport &dnn_support, HloCustomCallInstruction *custom_call) { - TF_ASSIGN_OR_RETURN( - const auto gpu_config, - custom_call->backend_config()); + ASSIGN_OR_RETURN(const auto gpu_config, + custom_call->backend_config()); const xla::gpu::CudnnfMHABackendConfig &config = gpu_config.cudnn_fmha_backend_config(); - TF_ASSIGN_OR_RETURN(Shape intermediate_tensor_shape, - Shape::FromProto(config.intermediate_tensor_shape())); - - TF_ASSIGN_OR_RETURN(CudnnfMHAMaskKind cudnn_mask_type, - AsCudnnFmhaMaskKind(config.mask_type())); - TF_ASSIGN_OR_RETURN(se::dnn::FMHAMaskKind dnn_mask_type, - GetDNNFmhaMaskKindFromCudnnFmhaMaskKind(cudnn_mask_type)); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN(Shape intermediate_tensor_shape, + Shape::FromProto(config.intermediate_tensor_shape())); + + ASSIGN_OR_RETURN(CudnnfMHAMaskKind cudnn_mask_type, + AsCudnnFmhaMaskKind(config.mask_type())); + ASSIGN_OR_RETURN(se::dnn::FMHAMaskKind dnn_mask_type, + GetDNNFmhaMaskKindFromCudnnFmhaMaskKind(cudnn_mask_type)); + ASSIGN_OR_RETURN( MatmulTensorDescriptor q, MatmulTensorDescriptorFor(custom_call->operand(0)->shape(), config.bmm1_dot_dimension_numbers(), LHS)); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( MatmulTensorDescriptor k, MatmulTensorDescriptorFor(custom_call->operand(1)->shape(), config.bmm1_dot_dimension_numbers(), RHS)); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( MatmulTensorDescriptor v, MatmulTensorDescriptorFor(custom_call->operand(2)->shape(), config.bmm2_dot_dimension_numbers(), RHS)); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( TensorDescriptor output, TensorDescriptorFor(ShapeUtil::GetSubshape(custom_call->shape(), {0}))); @@ -235,22 +234,20 @@ absl::StatusOr BuildGraphForCustomCallToForwardFMHAF8( bool has_activation = xla::ShapeUtil::TupleElementCount(custom_call->shape()) == 5; if (has_activation) { - TF_ASSIGN_OR_RETURN(activation, TensorDescriptorFor(ShapeUtil::GetSubshape( - custom_call->shape(), {3}))); + ASSIGN_OR_RETURN(activation, TensorDescriptorFor(ShapeUtil::GetSubshape( + custom_call->shape(), {3}))); } - TF_ASSIGN_OR_RETURN( - se::gpu::CudnnGraph graph, - se::gpu::GetCudnnFlashAttentionF8OperationGraph( - dnn_support, q, k, v, output, activation, - static_cast(config.fmha_scale()), dnn_mask_type)); + ASSIGN_OR_RETURN(se::gpu::CudnnGraph graph, + se::gpu::GetCudnnFlashAttentionF8OperationGraph( + dnn_support, q, k, v, output, activation, + static_cast(config.fmha_scale()), dnn_mask_type)); return graph; } absl::StatusOr BuildGraphForCustomCallToBackwardFMHA( se::dnn::DnnSupport &dnn_support, HloCustomCallInstruction *custom_call) { - TF_ASSIGN_OR_RETURN( - auto gpu_config, - custom_call->backend_config()); + ASSIGN_OR_RETURN(auto gpu_config, + custom_call->backend_config()); xla::gpu::CudnnfMHABackendConfig &config = *gpu_config.mutable_cudnn_fmha_backend_config(); @@ -258,12 +255,12 @@ absl::StatusOr BuildGraphForCustomCallToBackwardFMHA( const Shape &q_shape = custom_call->operand(input_index++)->shape(); const Shape &k_shape = custom_call->operand(input_index++)->shape(); const Shape &v_shape = custom_call->operand(input_index++)->shape(); - TF_ASSIGN_OR_RETURN(const Shape p_shape, - Shape::FromProto(config.intermediate_tensor_shape())); + ASSIGN_OR_RETURN(const Shape p_shape, + Shape::FromProto(config.intermediate_tensor_shape())); ++input_index; const Shape &d_output_shape = custom_call->operand(input_index++)->shape(); - TF_ASSIGN_OR_RETURN(const CudnnfMHAKind kind, GetCudnnfMHAKind(custom_call)); + ASSIGN_OR_RETURN(const CudnnfMHAKind kind, GetCudnnfMHAKind(custom_call)); bool has_bias = (kind == CudnnfMHAKind::kBackwardScaleBiasSoftmax || kind == CudnnfMHAKind::kBackwardScaleBiasSoftmaxDropout); @@ -307,49 +304,49 @@ absl::StatusOr BuildGraphForCustomCallToBackwardFMHA( const bool force_deterministic = RequireDeterminism(custom_call->GetModule()->config()); config.set_force_deterministic(force_deterministic); - TF_RETURN_IF_ERROR(custom_call->set_backend_config(gpu_config)); + RETURN_IF_ERROR(custom_call->set_backend_config(gpu_config)); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( MatmulTensorDescriptor q, MatmulTensorDescriptorFor( q_shape, config.bmm1_grad_gemm1_dot_dimension_numbers(), RHS)); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( MatmulTensorDescriptor k, MatmulTensorDescriptorFor( k_shape, config.bmm1_grad_gemm2_dot_dimension_numbers(), RHS)); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( MatmulTensorDescriptor p, MatmulTensorDescriptorFor( p_shape, config.bmm2_grad_gemm1_dot_dimension_numbers(), LHS)); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( MatmulTensorDescriptor v, MatmulTensorDescriptorFor( v_shape, config.bmm2_grad_gemm2_dot_dimension_numbers(), RHS)); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( MatmulTensorDescriptor d_output, MatmulTensorDescriptorFor( d_output_shape, config.bmm2_grad_gemm1_dot_dimension_numbers(), RHS)); - TF_ASSIGN_OR_RETURN(TensorDescriptor dq, TensorDescriptorFor(dq_shape)); - TF_ASSIGN_OR_RETURN(TensorDescriptor dk, TensorDescriptorFor(dk_shape)); - TF_ASSIGN_OR_RETURN(TensorDescriptor dv, TensorDescriptorFor(dv_shape)); + ASSIGN_OR_RETURN(TensorDescriptor dq, TensorDescriptorFor(dq_shape)); + ASSIGN_OR_RETURN(TensorDescriptor dk, TensorDescriptorFor(dk_shape)); + ASSIGN_OR_RETURN(TensorDescriptor dv, TensorDescriptorFor(dv_shape)); std::optional bias; std::optional dbias; if (bias_shape.has_value()) { - TF_ASSIGN_OR_RETURN(bias, TensorDescriptorFor(*bias_shape)); + ASSIGN_OR_RETURN(bias, TensorDescriptorFor(*bias_shape)); } if (dbias_shape.has_value()) { - TF_ASSIGN_OR_RETURN(dbias, TensorDescriptorFor(*dbias_shape)); + ASSIGN_OR_RETURN(dbias, TensorDescriptorFor(*dbias_shape)); } const double dropout_rate = config.dropout_rate(); - TF_ASSIGN_OR_RETURN(CudnnfMHAMaskKind cudnn_mask_type, - AsCudnnFmhaMaskKind(config.mask_type())); - TF_ASSIGN_OR_RETURN(se::dnn::FMHAMaskKind dnn_mask_type, - GetDNNFmhaMaskKindFromCudnnFmhaMaskKind(cudnn_mask_type)); + ASSIGN_OR_RETURN(CudnnfMHAMaskKind cudnn_mask_type, + AsCudnnFmhaMaskKind(config.mask_type())); + ASSIGN_OR_RETURN(se::dnn::FMHAMaskKind dnn_mask_type, + GetDNNFmhaMaskKindFromCudnnFmhaMaskKind(cudnn_mask_type)); const int sliding_window_length = config.sliding_window_length(); auto computations = custom_call->called_computations(); @@ -373,7 +370,7 @@ absl::StatusOr BuildGraphForCustomCallToBackwardFMHA( } TF_RET_CHECK(input_index == custom_call->operand_count()); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( se::gpu::CudnnGraph graph, se::gpu::GetCudnnFlashAttentionBackwardOperationGraph( dnn_support, q, k, p, v, d_output, dq, dk, dv, bias, dbias, @@ -385,9 +382,8 @@ absl::StatusOr BuildGraphForCustomCallToBackwardFMHA( absl::StatusOr BuildGraphForCustomCallToBackwardFMHAF8( se::dnn::DnnSupport &dnn_support, HloCustomCallInstruction *custom_call) { - TF_ASSIGN_OR_RETURN( - auto gpu_config, - custom_call->backend_config()); + ASSIGN_OR_RETURN(auto gpu_config, + custom_call->backend_config()); xla::gpu::CudnnfMHABackendConfig &config = *gpu_config.mutable_cudnn_fmha_backend_config(); @@ -398,50 +394,50 @@ absl::StatusOr BuildGraphForCustomCallToBackwardFMHAF8( Shape fwd_output_shape = custom_call->operand(3)->shape(); Shape d_output_shape = custom_call->operand(4)->shape(); - TF_ASSIGN_OR_RETURN(Shape p_shape, - Shape::FromProto(config.intermediate_tensor_shape())); + ASSIGN_OR_RETURN(Shape p_shape, + Shape::FromProto(config.intermediate_tensor_shape())); Shape dq_shape = ShapeUtil::GetSubshape(custom_call->shape(), {0}); Shape dk_shape = ShapeUtil::GetSubshape(custom_call->shape(), {1}); Shape dv_shape = ShapeUtil::GetSubshape(custom_call->shape(), {2}); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( MatmulTensorDescriptor q, MatmulTensorDescriptorFor( q_shape, config.bmm1_grad_gemm1_dot_dimension_numbers(), RHS)); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( MatmulTensorDescriptor k, MatmulTensorDescriptorFor( k_shape, config.bmm1_grad_gemm2_dot_dimension_numbers(), RHS)); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( MatmulTensorDescriptor p, MatmulTensorDescriptorFor( p_shape, config.bmm2_grad_gemm1_dot_dimension_numbers(), LHS)); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( MatmulTensorDescriptor v, MatmulTensorDescriptorFor( v_shape, config.bmm2_grad_gemm2_dot_dimension_numbers(), RHS)); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( MatmulTensorDescriptor d_output, MatmulTensorDescriptorFor( d_output_shape, config.bmm2_grad_gemm1_dot_dimension_numbers(), RHS)); - TF_ASSIGN_OR_RETURN(TensorDescriptor dq, TensorDescriptorFor(dq_shape)); - TF_ASSIGN_OR_RETURN(TensorDescriptor dk, TensorDescriptorFor(dk_shape)); - TF_ASSIGN_OR_RETURN(TensorDescriptor dv, TensorDescriptorFor(dv_shape)); + ASSIGN_OR_RETURN(TensorDescriptor dq, TensorDescriptorFor(dq_shape)); + ASSIGN_OR_RETURN(TensorDescriptor dk, TensorDescriptorFor(dk_shape)); + ASSIGN_OR_RETURN(TensorDescriptor dv, TensorDescriptorFor(dv_shape)); // 3 gradients, 4 amaxs and one workspace TF_RET_CHECK(8 == custom_call->shape().tuple_shapes().size()); - TF_RETURN_IF_ERROR(custom_call->set_backend_config(gpu_config)); + RETURN_IF_ERROR(custom_call->set_backend_config(gpu_config)); - TF_ASSIGN_OR_RETURN(CudnnfMHAMaskKind cudnn_mask_type, - AsCudnnFmhaMaskKind(config.mask_type())); - TF_ASSIGN_OR_RETURN(se::dnn::FMHAMaskKind dnn_mask_type, - GetDNNFmhaMaskKindFromCudnnFmhaMaskKind(cudnn_mask_type)); - TF_ASSIGN_OR_RETURN(se::gpu::CudnnGraph graph, - se::gpu::GetCudnnFlashAttentionBackwardF8OperationGraph( - dnn_support, q, k, p, v, d_output, dq, dk, dv, - config.fmha_scale(), dnn_mask_type)); + ASSIGN_OR_RETURN(CudnnfMHAMaskKind cudnn_mask_type, + AsCudnnFmhaMaskKind(config.mask_type())); + ASSIGN_OR_RETURN(se::dnn::FMHAMaskKind dnn_mask_type, + GetDNNFmhaMaskKindFromCudnnFmhaMaskKind(cudnn_mask_type)); + ASSIGN_OR_RETURN(se::gpu::CudnnGraph graph, + se::gpu::GetCudnnFlashAttentionBackwardF8OperationGraph( + dnn_support, q, k, p, v, d_output, dq, dk, dv, + config.fmha_scale(), dnn_mask_type)); return graph; } @@ -451,14 +447,14 @@ absl::StatusOr BuildGraphForCustomCallToBlockScaledDot( TF_RET_CHECK(custom_call->operand_count() == 4 || has_global_scale); TF_RET_CHECK(custom_call->shape().tuple_shapes().size() == 2); - TF_ASSIGN_OR_RETURN(TensorDescriptor lhs_data, - TensorDescriptorFor(custom_call->operand(0)->shape())); - TF_ASSIGN_OR_RETURN(TensorDescriptor rhs_data, - TensorDescriptorFor(custom_call->operand(1)->shape())); - TF_ASSIGN_OR_RETURN(TensorDescriptor lhs_scale, - TensorDescriptorFor(custom_call->operand(2)->shape())); - TF_ASSIGN_OR_RETURN(TensorDescriptor rhs_scale, - TensorDescriptorFor(custom_call->operand(3)->shape())); + ASSIGN_OR_RETURN(TensorDescriptor lhs_data, + TensorDescriptorFor(custom_call->operand(0)->shape())); + ASSIGN_OR_RETURN(TensorDescriptor rhs_data, + TensorDescriptorFor(custom_call->operand(1)->shape())); + ASSIGN_OR_RETURN(TensorDescriptor lhs_scale, + TensorDescriptorFor(custom_call->operand(2)->shape())); + ASSIGN_OR_RETURN(TensorDescriptor rhs_scale, + TensorDescriptorFor(custom_call->operand(3)->shape())); DataType result_type; switch (custom_call->shape().tuple_shapes(0).element_type()) { @@ -484,10 +480,10 @@ absl::StatusOr BuildGraphForCustomCallToBlockScaledDot( ? BlockScalingRewriter::kBlockSizeMXFP8 : BlockScalingRewriter::kBlockSizeNVFP4; - TF_ASSIGN_OR_RETURN(se::gpu::CudnnGraph graph, - se::gpu::GetCudnnBlockScaledDotOperationGraph( - dnn_support, lhs_data, lhs_scale, rhs_data, rhs_scale, - result_type, block_size, has_global_scale)); + ASSIGN_OR_RETURN(se::gpu::CudnnGraph graph, + se::gpu::GetCudnnBlockScaledDotOperationGraph( + dnn_support, lhs_data, lhs_scale, rhs_data, rhs_scale, + result_type, block_size, has_global_scale)); return graph; } @@ -531,12 +527,12 @@ class CuDnnCustomCallVisitor : public DfsHloRewriteVisitor { return absl::OkStatus(); } - TF_ASSIGN_OR_RETURN(const std::string fingerprint_without_workspace, - FingerprintWithBackendConfig(*hlo)); + ASSIGN_OR_RETURN(const std::string fingerprint_without_workspace, + FingerprintWithBackendConfig(*hlo)); auto workspace_size_it = workspace_sizes_.find(fingerprint_without_workspace); if (workspace_size_it == workspace_sizes_.cend()) { - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( se::gpu::CudnnGraph graph, HloCustomCallToCuDnnGraph(dnn_support_, DynCast(hlo))); @@ -550,8 +546,8 @@ class CuDnnCustomCallVisitor : public DfsHloRewriteVisitor { RETURN_IF_CUDNN_FRONTEND_ERROR(graph.Graph().serialize(serialized_graph)); // Compute a new fingerprint with a potential workspace for the // compilation results to match a fingerprint computed by the emitter. - TF_ASSIGN_OR_RETURN(const std::string fingerprint_with_workspace, - FingerprintWithBackendConfig(*hlo)); + ASSIGN_OR_RETURN(const std::string fingerprint_with_workspace, + FingerprintWithBackendConfig(*hlo)); compilation_results_[fingerprint_with_workspace] = std::string(reinterpret_cast(serialized_graph.data()), serialized_graph.size()); diff --git a/third_party/xla/xla/backends/gpu/transforms/cudnn_custom_call_converter.cc b/third_party/xla/xla/backends/gpu/transforms/cudnn_custom_call_converter.cc index 8cff27e0a0a6f5..e193c20ec4cfc2 100644 --- a/third_party/xla/xla/backends/gpu/transforms/cudnn_custom_call_converter.cc +++ b/third_party/xla/xla/backends/gpu/transforms/cudnn_custom_call_converter.cc @@ -19,6 +19,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -47,8 +48,8 @@ class CustomCallVisitor : public DfsHloRewriteVisitor { FusionBackendConfig &backend_config = *gpu_config.mutable_fusion_backend_config(); backend_config.set_kind(hlo->custom_call_target()); - TF_RETURN_IF_ERROR(fusion->set_backend_config(gpu_config)); - TF_RETURN_IF_ERROR(ReplaceInstruction(hlo, fusion)); + RETURN_IF_ERROR(fusion->set_backend_config(gpu_config)); + RETURN_IF_ERROR(ReplaceInstruction(hlo, fusion)); return absl::OkStatus(); } }; diff --git a/third_party/xla/xla/backends/gpu/transforms/cudnn_fused_conv_rewriter.cc b/third_party/xla/xla/backends/gpu/transforms/cudnn_fused_conv_rewriter.cc index 43b6df51758be5..9b29d27d9e5f5b 100644 --- a/third_party/xla/xla/backends/gpu/transforms/cudnn_fused_conv_rewriter.cc +++ b/third_party/xla/xla/backends/gpu/transforms/cudnn_fused_conv_rewriter.cc @@ -234,7 +234,7 @@ absl::StatusOr EnsureIsConvBiasActivation( HloInstruction* new_conv = comp->AddInstruction( conv->CloneWithNewOperands(conv->shape(), new_operands)); - TF_RETURN_IF_ERROR(comp->ReplaceInstruction(conv, new_conv)); + RETURN_IF_ERROR(comp->ReplaceInstruction(conv, new_conv)); new_conv->set_custom_call_target(kCudnnConvBiasActivationForwardCallTarget); comp->parent()->SetAndUniquifyInstrName(new_conv, "cudnn-conv-bias-activation"); @@ -271,9 +271,9 @@ absl::StatusOr FuseConvertTypeIntoConv(HloComputation* comp, HloInstruction* new_conv = comp->AddInstruction(conv->CloneWithNewShape(new_shape)); comp->parent()->SetAndUniquifyInstrName(new_conv, conv->name()); - TF_ASSIGN_OR_RETURN(HloInstruction * new_gte, - MakeGetTupleElementHlo(new_conv, 0)); - TF_RETURN_IF_ERROR(comp->ReplaceInstruction(instr, new_gte)); + ASSIGN_OR_RETURN(HloInstruction * new_gte, + MakeGetTupleElementHlo(new_conv, 0)); + RETURN_IF_ERROR(comp->ReplaceInstruction(instr, new_gte)); changed = true; } @@ -299,8 +299,8 @@ absl::StatusOr FuseRemoveConvertInConv(HloComputation* comp) { {F32, S8}, }}; for (auto [conv_type, cvt_type] : types) { - TF_ASSIGN_OR_RETURN(bool curr_change, - FuseConvertTypeIntoConv(comp, conv_type, cvt_type)); + ASSIGN_OR_RETURN(bool curr_change, + FuseConvertTypeIntoConv(comp, conv_type, cvt_type)); changed |= curr_change; } return changed; @@ -335,8 +335,7 @@ absl::StatusOr FuseConvAlpha(HloComputation* comp, continue; } - TF_ASSIGN_OR_RETURN(auto gpu_config, - conv->backend_config()); + ASSIGN_OR_RETURN(auto gpu_config, conv->backend_config()); CudnnConvBackendConfig& config = *gpu_config.mutable_cudnn_conv_backend_config(); @@ -351,12 +350,12 @@ absl::StatusOr FuseConvAlpha(HloComputation* comp, // StreamExecutor doesn't support the alpha parameter on non-bias-activation // convs, so we have to upgrade `conv`. - TF_ASSIGN_OR_RETURN(conv, EnsureIsConvBiasActivation(conv)); + ASSIGN_OR_RETURN(conv, EnsureIsConvBiasActivation(conv)); - TF_ASSIGN_OR_RETURN(Literal alpha_f64, alpha->literal().Convert(F64)); + ASSIGN_OR_RETURN(Literal alpha_f64, alpha->literal().Convert(F64)); config.set_conv_result_scale(alpha_f64.GetFirstElement()); - TF_RETURN_IF_ERROR(conv->set_backend_config(gpu_config)); - TF_RETURN_IF_ERROR(conv->parent()->ReplaceInstruction(instr, gte)); + RETURN_IF_ERROR(conv->set_backend_config(gpu_config)); + RETURN_IF_ERROR(conv->parent()->ReplaceInstruction(instr, gte)); changed = true; } @@ -797,25 +796,25 @@ CaptureConvGraph(HloInstruction* instr, HloInstruction* convolution, // Shift the scaling of the input and filter to the output of the convolution. HloInstruction *input_scaled_conv, *filter_scaled_conv; if (input_scale) { - TF_RETURN_IF_ERROR(convolution->ReplaceOperandWith(0, wide_input)); + RETURN_IF_ERROR(convolution->ReplaceOperandWith(0, wide_input)); HloInstruction* bcast_input_scale = instr->AddInstruction( HloInstruction::CreateBroadcast(instr->shape(), input_scale, {})); input_scaled_conv = instr->AddInstruction(HloInstruction::CreateBinary( instr->shape(), x_mult_scale ? HloOpcode::kMultiply : HloOpcode::kDivide, instr, bcast_input_scale)); - TF_RETURN_IF_ERROR(instr->ReplaceAllUsesWith(input_scaled_conv)); + RETURN_IF_ERROR(instr->ReplaceAllUsesWith(input_scaled_conv)); } if (filter_scale) { - TF_RETURN_IF_ERROR(convolution->ReplaceOperandWith(1, wide_filter)); + RETURN_IF_ERROR(convolution->ReplaceOperandWith(1, wide_filter)); HloInstruction* bcast_filter_scale = instr->AddInstruction( HloInstruction::CreateBroadcast(instr->shape(), filter_scale, {})); filter_scaled_conv = instr->AddInstruction(HloInstruction::CreateBinary( instr->shape(), w_mult_scale ? HloOpcode::kMultiply : HloOpcode::kDivide, input_scale ? input_scaled_conv : instr, bcast_filter_scale)); - TF_RETURN_IF_ERROR((input_scale ? input_scaled_conv : instr) - ->ReplaceAllUsesWith(filter_scaled_conv)); + RETURN_IF_ERROR((input_scale ? input_scaled_conv : instr) + ->ReplaceAllUsesWith(filter_scaled_conv)); } std::vector operands, aux_outputs; @@ -904,7 +903,7 @@ absl::StatusOr F8GraphConv(HloComputation* comp, GraphString graph_string; HloInstruction* final_instr; - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( std::tie(operands, aux_outputs, graph_string, final_instr), CaptureConvGraph( instr, convolution, wide_input, wide_filter, input_scale, @@ -915,8 +914,8 @@ absl::StatusOr F8GraphConv(HloComputation* comp, filter_scale_op ? HloPredicateIsOp(filter_scale_op) : false)); - TF_ASSIGN_OR_RETURN(auto gpu_config, - convolution->backend_config()); + ASSIGN_OR_RETURN(auto gpu_config, + convolution->backend_config()); CudnnConvBackendConfig& config = *gpu_config.mutable_cudnn_conv_backend_config(); @@ -939,15 +938,15 @@ absl::StatusOr F8GraphConv(HloComputation* comp, ShapeUtil::MakeTupleShape(output_shapes), operands)); new_convolution->set_custom_call_target(kCudnnConvForwardGraphCallTarget); - TF_RETURN_IF_ERROR(new_convolution->set_backend_config(gpu_config)); - TF_ASSIGN_OR_RETURN(HloInstruction * new_gte, - MakeGetTupleElementHlo(new_convolution, 0)); - TF_RETURN_IF_ERROR(comp->ReplaceInstruction(final_instr, new_gte)); + RETURN_IF_ERROR(new_convolution->set_backend_config(gpu_config)); + ASSIGN_OR_RETURN(HloInstruction * new_gte, + MakeGetTupleElementHlo(new_convolution, 0)); + RETURN_IF_ERROR(comp->ReplaceInstruction(final_instr, new_gte)); for (int i = 0; i < aux_outputs.size(); ++i) { - TF_ASSIGN_OR_RETURN(HloInstruction * new_gte, - MakeGetTupleElementHlo(new_convolution, i + 1)); - TF_RETURN_IF_ERROR(comp->ReplaceInstruction(aux_outputs[i], new_gte)); + ASSIGN_OR_RETURN(HloInstruction * new_gte, + MakeGetTupleElementHlo(new_convolution, i + 1)); + RETURN_IF_ERROR(comp->ReplaceInstruction(aux_outputs[i], new_gte)); } changed = true; @@ -985,8 +984,7 @@ absl::StatusOr FuseBiasOrSideInput(HloComputation* comp, // Can't fuse bias or side-input if the conv already has a relu (or other // activation), because bias and side-input are added before the activation // is applied. - TF_ASSIGN_OR_RETURN(auto gpu_config, - conv->backend_config()); + ASSIGN_OR_RETURN(auto gpu_config, conv->backend_config()); CudnnConvBackendConfig& config = *gpu_config.mutable_cudnn_conv_backend_config(); if (config.activation_mode() != se::dnn::kNone) { @@ -1051,10 +1049,10 @@ absl::StatusOr FuseBiasOrSideInput(HloComputation* comp, HloInstruction* new_conv = comp->AddInstruction( conv->CloneWithNewOperands(conv->shape(), new_operands)); comp->parent()->SetAndUniquifyInstrName(new_conv, conv->name()); - TF_RETURN_IF_ERROR(new_conv->set_backend_config(gpu_config)); - TF_ASSIGN_OR_RETURN(HloInstruction * new_instr, - MakeGetTupleElementHlo(new_conv, 0)); - TF_RETURN_IF_ERROR(comp->ReplaceInstruction(instr, new_instr)); + RETURN_IF_ERROR(new_conv->set_backend_config(gpu_config)); + ASSIGN_OR_RETURN(HloInstruction * new_instr, + MakeGetTupleElementHlo(new_conv, 0)); + RETURN_IF_ERROR(comp->ReplaceInstruction(instr, new_instr)); changed = true; } return changed; @@ -1085,8 +1083,7 @@ absl::StatusOr FuseSideInputAlpha(HloComputation* comp, if (!Match(instr, pattern)) { continue; } - TF_ASSIGN_OR_RETURN(auto gpu_config, - conv->backend_config()); + ASSIGN_OR_RETURN(auto gpu_config, conv->backend_config()); CudnnConvBackendConfig& config = *gpu_config.mutable_cudnn_conv_backend_config(); if (config.side_input_scale() != 1) { @@ -1171,11 +1168,11 @@ absl::StatusOr FuseSideInputAlpha(HloComputation* comp, conv->CloneWithNewOperands(conv->shape(), new_operands)); comp->parent()->SetAndUniquifyInstrName(new_conv, conv->name()); - TF_ASSIGN_OR_RETURN(Literal alpha_f64, alpha->literal().Convert(F64)); + ASSIGN_OR_RETURN(Literal alpha_f64, alpha->literal().Convert(F64)); config.set_side_input_scale(alpha_f64.GetFirstElement()); - TF_RETURN_IF_ERROR(new_conv->set_backend_config(gpu_config)); + RETURN_IF_ERROR(new_conv->set_backend_config(gpu_config)); - TF_RETURN_IF_ERROR(comp->ReplaceInstruction(conv, new_conv)); + RETURN_IF_ERROR(comp->ReplaceInstruction(conv, new_conv)); changed = true; } return changed; @@ -1223,8 +1220,8 @@ absl::StatusOr FuseElu(HloComputation* comp, continue; } - TF_ASSIGN_OR_RETURN(GpuBackendConfig gpu_config, - conv->backend_config()); + ASSIGN_OR_RETURN(GpuBackendConfig gpu_config, + conv->backend_config()); CudnnConvBackendConfig& config = *gpu_config.mutable_cudnn_conv_backend_config(); if (config.activation_mode() != se::dnn::kNone) { @@ -1236,10 +1233,10 @@ absl::StatusOr FuseElu(HloComputation* comp, })) { continue; } - TF_ASSIGN_OR_RETURN(conv, EnsureIsConvBiasActivation(conv)); + ASSIGN_OR_RETURN(conv, EnsureIsConvBiasActivation(conv)); config.set_activation_mode(se::dnn::kElu); - TF_RETURN_IF_ERROR(conv->set_backend_config(gpu_config)); - TF_RETURN_IF_ERROR(comp->ReplaceInstruction(instr, gte1)); + RETURN_IF_ERROR(conv->set_backend_config(gpu_config)); + RETURN_IF_ERROR(comp->ReplaceInstruction(instr, gte1)); changed = true; } return changed; @@ -1260,8 +1257,8 @@ absl::StatusOr FuseRelu(HloComputation* comp) { .WithOneUse()))) { continue; } - TF_ASSIGN_OR_RETURN(GpuBackendConfig gpu_config, - conv->backend_config()); + ASSIGN_OR_RETURN(GpuBackendConfig gpu_config, + conv->backend_config()); CudnnConvBackendConfig& config = *gpu_config.mutable_cudnn_conv_backend_config(); if (config.activation_mode() != se::dnn::kNone) { @@ -1273,10 +1270,10 @@ absl::StatusOr FuseRelu(HloComputation* comp) { })) { continue; } - TF_ASSIGN_OR_RETURN(conv, EnsureIsConvBiasActivation(conv)); + ASSIGN_OR_RETURN(conv, EnsureIsConvBiasActivation(conv)); config.set_activation_mode(se::dnn::kRelu); - TF_RETURN_IF_ERROR(conv->set_backend_config(gpu_config)); - TF_RETURN_IF_ERROR(comp->ReplaceInstruction(instr, gte)); + RETURN_IF_ERROR(conv->set_backend_config(gpu_config)); + RETURN_IF_ERROR(comp->ReplaceInstruction(instr, gte)); changed = true; } return changed; @@ -1305,8 +1302,8 @@ absl::StatusOr FuseRelu6(HloComputation* comp, m::Broadcast(m::ConstantEffectiveScalar(6))))) { continue; } - TF_ASSIGN_OR_RETURN(GpuBackendConfig gpu_config, - conv->backend_config()); + ASSIGN_OR_RETURN(GpuBackendConfig gpu_config, + conv->backend_config()); CudnnConvBackendConfig& config = *gpu_config.mutable_cudnn_conv_backend_config(); if (config.activation_mode() != se::dnn::kNone) { @@ -1322,10 +1319,10 @@ absl::StatusOr FuseRelu6(HloComputation* comp, })) { continue; } - TF_ASSIGN_OR_RETURN(conv, EnsureIsConvBiasActivation(conv)); + ASSIGN_OR_RETURN(conv, EnsureIsConvBiasActivation(conv)); config.set_activation_mode(se::dnn::kRelu6); - TF_RETURN_IF_ERROR(conv->set_backend_config(gpu_config)); - TF_RETURN_IF_ERROR(comp->ReplaceInstruction(instr, gte)); + RETURN_IF_ERROR(conv->set_backend_config(gpu_config)); + RETURN_IF_ERROR(comp->ReplaceInstruction(instr, gte)); changed = true; } return changed; @@ -1364,8 +1361,8 @@ absl::StatusOr FuseLeakyRelu(HloComputation* comp, continue; } - TF_ASSIGN_OR_RETURN(GpuBackendConfig gpu_config, - conv->backend_config()); + ASSIGN_OR_RETURN(GpuBackendConfig gpu_config, + conv->backend_config()); CudnnConvBackendConfig& config = *gpu_config.mutable_cudnn_conv_backend_config(); if (config.activation_mode() != se::dnn::kNone) { @@ -1381,12 +1378,12 @@ absl::StatusOr FuseLeakyRelu(HloComputation* comp, })) { continue; } - TF_ASSIGN_OR_RETURN(conv, EnsureIsConvBiasActivation(conv)); + ASSIGN_OR_RETURN(conv, EnsureIsConvBiasActivation(conv)); config.set_activation_mode(se::dnn::kLeakyRelu); - TF_ASSIGN_OR_RETURN(Literal alpha_f64, alpha->literal().Convert(F64)); + ASSIGN_OR_RETURN(Literal alpha_f64, alpha->literal().Convert(F64)); config.set_leakyrelu_alpha(alpha_f64.GetFirstElement()); - TF_RETURN_IF_ERROR(conv->set_backend_config(gpu_config)); - TF_RETURN_IF_ERROR(comp->ReplaceInstruction(instr, gte1)); + RETURN_IF_ERROR(conv->set_backend_config(gpu_config)); + RETURN_IF_ERROR(comp->ReplaceInstruction(instr, gte1)); changed = true; } return changed; @@ -1444,9 +1441,9 @@ absl::StatusOr FuseConvertToF16(HloComputation* comp) { HloInstruction* new_conv = comp->AddInstruction( conv->CloneWithNewOperands(new_shape, new_operands)); comp->parent()->SetAndUniquifyInstrName(new_conv, conv->name()); - TF_ASSIGN_OR_RETURN(HloInstruction * new_instr, - MakeGetTupleElementHlo(new_conv, 0)); - TF_RETURN_IF_ERROR(comp->ReplaceInstruction(instr, new_instr)); + ASSIGN_OR_RETURN(HloInstruction * new_instr, + MakeGetTupleElementHlo(new_conv, 0)); + RETURN_IF_ERROR(comp->ReplaceInstruction(instr, new_instr)); changed = true; } return changed; @@ -1543,9 +1540,9 @@ absl::StatusOr FuseConvertToS8(HloComputation* comp, HloInstruction* new_conv = comp->AddInstruction( conv->CloneWithNewOperands(new_shape, new_operands)); comp->parent()->SetAndUniquifyInstrName(new_conv, conv->name()); - TF_ASSIGN_OR_RETURN(HloInstruction * new_instr, - MakeGetTupleElementHlo(new_conv, 0)); - TF_RETURN_IF_ERROR(comp->ReplaceInstruction(instr, new_instr)); + ASSIGN_OR_RETURN(HloInstruction * new_instr, + MakeGetTupleElementHlo(new_conv, 0)); + RETURN_IF_ERROR(comp->ReplaceInstruction(instr, new_instr)); changed = true; } return changed; @@ -1698,14 +1695,14 @@ absl::StatusOr CudnnFusedConvRewriter::RunImpl( // ForwardGraph Custom Call. if (!compute_capability_.IsRocm() && !compute_capability_.IsOneAPI()) { auto* cc = compute_capability_.cuda_compute_capability(); - TF_ASSIGN_OR_RETURN( - changed, F8GraphConv(comp, *cc, dnn_version_, toolkit_version_)); + ASSIGN_OR_RETURN(changed, + F8GraphConv(comp, *cc, dnn_version_, toolkit_version_)); if (changed) { return changed; } } // Fuse "inside out" starting with the operations closest to the conv. - TF_ASSIGN_OR_RETURN(changed, FuseRemoveConvertInConv(comp)); + ASSIGN_OR_RETURN(changed, FuseRemoveConvertInConv(comp)); any_changed |= changed; ASSIGN_OR_RETURN(changed, FuseConvAlpha(comp, compute_capability_)); @@ -1724,19 +1721,19 @@ absl::StatusOr CudnnFusedConvRewriter::RunImpl( // Relu might appear before or after convert-to-f16/s8, so we check in both // cases. - TF_ASSIGN_OR_RETURN(changed, FuseRelu(comp)); + ASSIGN_OR_RETURN(changed, FuseRelu(comp)); any_changed |= changed; - TF_ASSIGN_OR_RETURN(changed, FuseElu(comp, compute_capability_)); + ASSIGN_OR_RETURN(changed, FuseElu(comp, compute_capability_)); any_changed |= changed; - TF_ASSIGN_OR_RETURN(changed, FuseRelu6(comp, compute_capability_)); + ASSIGN_OR_RETURN(changed, FuseRelu6(comp, compute_capability_)); any_changed |= changed; - TF_ASSIGN_OR_RETURN(changed, FuseLeakyRelu(comp, compute_capability_)); + ASSIGN_OR_RETURN(changed, FuseLeakyRelu(comp, compute_capability_)); any_changed |= changed; - TF_ASSIGN_OR_RETURN(changed, FuseConvertToF16(comp)); + ASSIGN_OR_RETURN(changed, FuseConvertToF16(comp)); any_changed |= changed; - TF_ASSIGN_OR_RETURN(changed, FuseConvertToS8(comp, compute_capability_)); + ASSIGN_OR_RETURN(changed, FuseConvertToS8(comp, compute_capability_)); any_changed |= changed; // f16 convs' bias+side-input can appear before or after conversion to f16. @@ -1747,19 +1744,19 @@ absl::StatusOr CudnnFusedConvRewriter::RunImpl( ASSIGN_OR_RETURN(changed, FuseSideInputAlpha(comp, compute_capability_)); any_changed |= changed; - TF_ASSIGN_OR_RETURN(changed, FuseRelu(comp)); + ASSIGN_OR_RETURN(changed, FuseRelu(comp)); any_changed |= changed; - TF_ASSIGN_OR_RETURN(changed, FuseElu(comp, compute_capability_)); + ASSIGN_OR_RETURN(changed, FuseElu(comp, compute_capability_)); any_changed |= changed; - TF_ASSIGN_OR_RETURN(changed, FuseRelu6(comp, compute_capability_)); + ASSIGN_OR_RETURN(changed, FuseRelu6(comp, compute_capability_)); any_changed |= changed; - TF_ASSIGN_OR_RETURN(changed, FuseLeakyRelu(comp, compute_capability_)); + ASSIGN_OR_RETURN(changed, FuseLeakyRelu(comp, compute_capability_)); any_changed |= changed; // Check that we don't have any convs outputting integer types other than // s8 - cudnn does not support these. They should have been transformed to // int8->int8 or int8->float above. - TF_RETURN_IF_ERROR(CheckNoIllegalIntegerConvs(comp)); + RETURN_IF_ERROR(CheckNoIllegalIntegerConvs(comp)); } VlogStats(module); diff --git a/third_party/xla/xla/backends/gpu/transforms/cudnn_fusion_compiler.cc b/third_party/xla/xla/backends/gpu/transforms/cudnn_fusion_compiler.cc index 2e546c956c079f..43165438c639f3 100644 --- a/third_party/xla/xla/backends/gpu/transforms/cudnn_fusion_compiler.cc +++ b/third_party/xla/xla/backends/gpu/transforms/cudnn_fusion_compiler.cc @@ -33,6 +33,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" +#include "xla/tsl/platform/status_macros.h" #include "third_party/gpus/cudnn/cudnn_version.h" #include "xla/backends/gpu/transforms/block_scaling_rewriter.h" #include "xla/codegen/emitters/computation_fingerprint.h" @@ -307,8 +308,7 @@ class GemmDimensionAdapter { VLOG(3) << "Non-default algorithm is not supported."; return std::nullopt; } - TF_ASSIGN_OR_RETURN(auto analysis, - TritonFusionAnalysis::Execute(computation)); + ASSIGN_OR_RETURN(auto analysis, TritonFusionAnalysis::Execute(computation)); return GemmDimensionAdapter{*dot, std::move(analysis)}; } @@ -646,13 +646,12 @@ absl::StatusOr> HloFusionToCuDnnGraph( absl::flat_hash_map> hlo_to_cudnn; - TF_ASSIGN_OR_RETURN(std::optional gemm_adapter, - GemmDimensionAdapter::Create(computation)); - TF_ASSIGN_OR_RETURN(std::optional conv_adapter, - ConvDimensionAdapter::Create(fusion, computation)); - TF_ASSIGN_OR_RETURN( - std::optional ragged_dot_adapter, - RaggedDotDimensionAdapter::Create(fusion, computation)); + ASSIGN_OR_RETURN(std::optional gemm_adapter, + GemmDimensionAdapter::Create(computation)); + ASSIGN_OR_RETURN(std::optional conv_adapter, + ConvDimensionAdapter::Create(fusion, computation)); + ASSIGN_OR_RETURN(std::optional ragged_dot_adapter, + RaggedDotDimensionAdapter::Create(fusion, computation)); if (!gemm_adapter.has_value() && !conv_adapter.has_value() && !ragged_dot_adapter.has_value()) { VLOG(3) << "No dot or conv or ragged_dot found inside cudnn fusion."; @@ -1028,12 +1027,12 @@ absl::StatusOr> HloFusionToCuDnnGraph( // Creates a cuDNN graph, queries cuDNN whether it is supported. absl::StatusOr PrepareGraph( se::dnn::DnnSupport& dnn_support, const HloFusionInstruction& hlo) { - TF_ASSIGN_OR_RETURN(std::optional graph, - HloFusionToCuDnnGraph(hlo)); + ASSIGN_OR_RETURN(std::optional graph, + HloFusionToCuDnnGraph(hlo)); if (!graph.has_value()) { return absl::InternalError("Construction of cuDNN graph failed."); } - TF_RETURN_IF_ERROR(graph->Prepare( + RETURN_IF_ERROR(graph->Prepare( dnn_support, se::EngineOptions{ RequireDeterminism(hlo.GetModule()->config()), /*allow_tf32=*/true, /*require_command_buffer=*/false})); @@ -1058,7 +1057,7 @@ absl::StatusOr AddWorkspace(HloInstruction& fusion, operands.push_back(custom_call); output_tuple = computation->AddInstruction(HloInstruction::CreateTuple(operands)); - TF_RETURN_IF_ERROR(computation->ReplaceInstructionWithDifferentShape( + RETURN_IF_ERROR(computation->ReplaceInstructionWithDifferentShape( computation->root_instruction(), output_tuple)); } else { output_tuple = computation->AddInstruction(HloInstruction::CreateTuple( @@ -1067,16 +1066,15 @@ absl::StatusOr AddWorkspace(HloInstruction& fusion, computation->set_root_instruction(output_tuple, true); HloInstruction* new_fusion = fusion.parent()->AddInstruction( fusion.CloneWithNewShape(output_tuple->shape())); - TF_RETURN_IF_ERROR(new_fusion->CopyAllControlDepsFrom(&fusion)); - TF_RETURN_IF_ERROR(fusion.DropAllControlDeps()); + RETURN_IF_ERROR(new_fusion->CopyAllControlDepsFrom(&fusion)); + RETURN_IF_ERROR(fusion.DropAllControlDeps()); if (is_tuple_output) { - TF_RETURN_IF_ERROR(fusion.parent()->ReplaceInstructionWithDifferentShape( + RETURN_IF_ERROR(fusion.parent()->ReplaceInstructionWithDifferentShape( &fusion, new_fusion)); } else { - TF_RETURN_IF_ERROR( - fusion.ReplaceAllUsesWith(fusion.parent()->AddInstruction( - HloInstruction::CreateGetTupleElement(new_fusion, 0)))); - TF_RETURN_IF_ERROR(fusion.parent()->RemoveInstruction(&fusion)); + RETURN_IF_ERROR(fusion.ReplaceAllUsesWith(fusion.parent()->AddInstruction( + HloInstruction::CreateGetTupleElement(new_fusion, 0)))); + RETURN_IF_ERROR(fusion.parent()->RemoveInstruction(&fusion)); } return new_fusion; } @@ -1088,8 +1086,7 @@ class CuDnnFusionVisitor : public DfsHloRewriteVisitor { : dnn_support_(dnn_support), compilation_results_(compilation_results) {} absl::Status HandleFusion(HloInstruction* hlo) override { - TF_ASSIGN_OR_RETURN(auto gpu_config, - hlo->backend_config()); + ASSIGN_OR_RETURN(auto gpu_config, hlo->backend_config()); const FusionBackendConfig& fusion_backend_config = gpu_config.fusion_backend_config(); if (fusion_backend_config.kind() != kCuDnnFusionKind) { @@ -1114,7 +1111,7 @@ class CuDnnFusionVisitor : public DfsHloRewriteVisitor { gpu_config.fusion_backend_config(); auto compile_graph = [&]() -> absl::StatusOr { - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( se::gpu::CudnnGraph graph, PrepareGraph(dnn_support_, *DynCast(hlo))); @@ -1127,7 +1124,7 @@ class CuDnnFusionVisitor : public DfsHloRewriteVisitor { if (plan_id >= graph.Graph().get_execution_plan_count()) { return absl::InternalError("cuDNN graph plan does not exist."); } - TF_RETURN_IF_ERROR(graph.Build(dnn_support_, plan_id)); + RETURN_IF_ERROR(graph.Build(dnn_support_, plan_id)); } else { // Build plans one by one till first successful when no plan_id was // provided. @@ -1146,7 +1143,7 @@ class CuDnnFusionVisitor : public DfsHloRewriteVisitor { gpu_config.mutable_fusion_backend_config() ->mutable_cudnn_fusion_config(); cudnn_config->set_plan_id(plan_id); - TF_RETURN_IF_ERROR(hlo->set_backend_config(gpu_config)); + RETURN_IF_ERROR(hlo->set_backend_config(gpu_config)); } return graph; }; @@ -1165,9 +1162,8 @@ class CuDnnFusionVisitor : public DfsHloRewriteVisitor { hlo->fused_instructions_computation(), {}); if (auto it = compilation_results_.find(fingerprint); it == compilation_results_.cend()) { - TF_ASSIGN_OR_RETURN(const se::gpu::CudnnGraph graph, compile_graph()); - TF_ASSIGN_OR_RETURN(const std::string serialized, - serialize_graph(graph)); + ASSIGN_OR_RETURN(const se::gpu::CudnnGraph graph, compile_graph()); + ASSIGN_OR_RETURN(const std::string serialized, serialize_graph(graph)); compilation_results_.insert(it, {fingerprint, serialized}); } return absl::OkStatus(); @@ -1175,7 +1171,7 @@ class CuDnnFusionVisitor : public DfsHloRewriteVisitor { auto add_workspace = [&](const int64_t workspace_size) { if (workspace_size > 0) { - TF_ASSIGN_OR_RETURN(hlo, AddWorkspace(*hlo, workspace_size)); + ASSIGN_OR_RETURN(hlo, AddWorkspace(*hlo, workspace_size)); SetVisited(*hlo); } return absl::OkStatus(); @@ -1188,17 +1184,17 @@ class CuDnnFusionVisitor : public DfsHloRewriteVisitor { auto workspace_size_it = workspace_sizes_.find(fingerprint_without_workspace); if (workspace_size_it == workspace_sizes_.cend()) { - TF_ASSIGN_OR_RETURN(const se::gpu::CudnnGraph graph, compile_graph()); + ASSIGN_OR_RETURN(const se::gpu::CudnnGraph graph, compile_graph()); const int64_t workspace_size = graph.Graph().get_workspace_size(); workspace_sizes_.insert(workspace_size_it, {fingerprint_without_workspace, workspace_size}); - TF_RETURN_IF_ERROR(add_workspace(workspace_size)); - TF_ASSIGN_OR_RETURN(const std::string serialized, serialize_graph(graph)); + RETURN_IF_ERROR(add_workspace(workspace_size)); + ASSIGN_OR_RETURN(const std::string serialized, serialize_graph(graph)); compilation_results_[emitters::GetComputationFingerprint( hlo->fused_instructions_computation(), {})] = serialized; } else { VLOG(4) << "Cache hit."; - TF_RETURN_IF_ERROR(add_workspace(workspace_size_it->second)); + RETURN_IF_ERROR(add_workspace(workspace_size_it->second)); } MarkAsChanged(); diff --git a/third_party/xla/xla/backends/gpu/transforms/cudnn_norm_rewriter.cc b/third_party/xla/xla/backends/gpu/transforms/cudnn_norm_rewriter.cc index 4d357b1f9ad193..bc4da1d7ef0fd4 100644 --- a/third_party/xla/xla/backends/gpu/transforms/cudnn_norm_rewriter.cc +++ b/third_party/xla/xla/backends/gpu/transforms/cudnn_norm_rewriter.cc @@ -32,6 +32,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" @@ -827,8 +828,8 @@ class CudnnNormRewriterVisitor : public DfsHloRewriteVisitor { : cuda_compute_capability_(cuda_compute_capability) {} absl::Status HandleAdd(HloInstruction* instr) override { - TF_RETURN_IF_ERROR(MatchLayerNorm(instr)); - TF_RETURN_IF_ERROR(MatchLayerNormGradient(instr)); + RETURN_IF_ERROR(MatchLayerNorm(instr)); + RETURN_IF_ERROR(MatchLayerNormGradient(instr)); return absl::OkStatus(); } @@ -988,8 +989,8 @@ class CudnnNormRewriterVisitor : public DfsHloRewriteVisitor { for (int k = 0; k < x_transpose_order.size(); ++k) { y_transpose_order[x_transpose_order[k]] = k; } - TF_ASSIGN_OR_RETURN(x_transpose, - MakeTransposeHlo(x.instr(), x_transpose_order)); + ASSIGN_OR_RETURN(x_transpose, + MakeTransposeHlo(x.instr(), x_transpose_order)); } // Combine the dimensions not normalized into the first dimension of the @@ -1008,7 +1009,7 @@ class CudnnNormRewriterVisitor : public DfsHloRewriteVisitor { Shape reshaped_shape = ShapeUtil::MakeShape( x.instr()->shape().element_type(), reshaped_dims); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( HloInstruction * x_reshape, MakeReshapeHlo(reshaped_shape, x_transpose.value_or(x.instr()))); @@ -1019,10 +1020,10 @@ class CudnnNormRewriterVisitor : public DfsHloRewriteVisitor { Shape scale_bias_shape = ShapeUtil::MakeShape( scale->shape().element_type(), reshaped_scale_dims); - TF_ASSIGN_OR_RETURN(HloInstruction * scale_reshape, - MakeReshapeHlo(scale_bias_shape, scale)); - TF_ASSIGN_OR_RETURN(HloInstruction * bias_reshape, - MakeReshapeHlo(scale_bias_shape, bias)); + ASSIGN_OR_RETURN(HloInstruction * scale_reshape, + MakeReshapeHlo(scale_bias_shape, scale)); + ASSIGN_OR_RETURN(HloInstruction * bias_reshape, + MakeReshapeHlo(scale_bias_shape, bias)); GpuBackendConfig gpu_backend_config; CudnnNormBackendConfig& backend_config = *gpu_backend_config.mutable_cudnn_norm_backend_config(); @@ -1035,8 +1036,8 @@ class CudnnNormRewriterVisitor : public DfsHloRewriteVisitor { // Set the workspace size to its upper bound. // TODO(philipphack): Consider autotuning the norm kernels. - TF_ASSIGN_OR_RETURN(const int64_t c_constant, - CConstant(cuda_compute_capability_)); + ASSIGN_OR_RETURN(const int64_t c_constant, + CConstant(cuda_compute_capability_)); const int64_t workspace_size = (2 * c_constant * (4 + 256)) + (2 * reshaped_dims[0] * 4) + 64; algorithm->mutable_workspace_size()->set_value(workspace_size); @@ -1050,20 +1051,20 @@ class CudnnNormRewriterVisitor : public DfsHloRewriteVisitor { instr->AddInstruction(HloInstruction::CreateCustomCall( custom_call_shape, {x_reshape, scale_reshape, bias_reshape}, kCudnnNormCallTarget)); - TF_RETURN_IF_ERROR(custom_call->set_backend_config(gpu_backend_config)); + RETURN_IF_ERROR(custom_call->set_backend_config(gpu_backend_config)); - TF_ASSIGN_OR_RETURN(HloInstruction * gte, - MakeGetTupleElementHlo(custom_call, 0)); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN(HloInstruction * gte, + MakeGetTupleElementHlo(custom_call, 0)); + ASSIGN_OR_RETURN( HloInstruction * y_reshape, MakeReshapeHlo(x_transpose.value_or(instr)->shape(), gte)); std::optional y_transpose; if (apply_transpose) { - TF_ASSIGN_OR_RETURN(y_transpose, - MakeTransposeHlo(y_reshape, y_transpose_order)); + ASSIGN_OR_RETURN(y_transpose, + MakeTransposeHlo(y_reshape, y_transpose_order)); } - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( ReplaceInstruction(instr, y_transpose.value_or(y_reshape))); // Store metadata for potential use in the backward graph. @@ -1080,9 +1081,9 @@ class CudnnNormRewriterVisitor : public DfsHloRewriteVisitor { for (HloInstruction* user : norm_factor->users()) { if (HloPredicateIsOp(user) && user->operand_index(norm_factor) == 0) { - TF_ASSIGN_OR_RETURN(bool changed, - MatchNormFactor(user, custom_call, variance, - expectation, epsilon)); + ASSIGN_OR_RETURN(bool changed, + MatchNormFactor(user, custom_call, variance, + expectation, epsilon)); if (changed) { break; } @@ -1153,7 +1154,7 @@ class CudnnNormRewriterVisitor : public DfsHloRewriteVisitor { HloInstruction* new_custom_call = instr->AddInstruction( custom_call->CloneWithNewShape(custom_call_shape)); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( GpuBackendConfig gpu_backend_config, custom_call->backend_config()); CudnnNormBackendConfig& backend_config = @@ -1161,51 +1162,49 @@ class CudnnNormRewriterVisitor : public DfsHloRewriteVisitor { backend_config.set_kind(CudnnNormBackendConfig::LAYER_FWD_TRAIN); // Update the workspace size. - TF_ASSIGN_OR_RETURN(const int64_t c_constant, - CConstant(cuda_compute_capability_)); + ASSIGN_OR_RETURN(const int64_t c_constant, + CConstant(cuda_compute_capability_)); const int64_t workspace_size = (2 * c_constant * (4 + 256)) + 32; backend_config.mutable_algorithm()->mutable_workspace_size()->set_value( workspace_size); - TF_RETURN_IF_ERROR( - new_custom_call->set_backend_config(gpu_backend_config)); + RETURN_IF_ERROR(new_custom_call->set_backend_config(gpu_backend_config)); auto replace_with_new_cc = [new_custom_call, this]( HloInstruction* old_instr, int tuple_index) -> absl::Status { - TF_ASSIGN_OR_RETURN( - HloInstruction * new_gte, - MakeGetTupleElementHlo(new_custom_call, tuple_index)); + ASSIGN_OR_RETURN(HloInstruction * new_gte, + MakeGetTupleElementHlo(new_custom_call, tuple_index)); HloInstruction* new_instr = new_gte; if (!ShapeUtil::Equal(new_gte->shape(), old_instr->shape())) { - TF_ASSIGN_OR_RETURN(new_instr, - MakeReshapeHlo(old_instr->shape(), new_gte)); + ASSIGN_OR_RETURN(new_instr, + MakeReshapeHlo(old_instr->shape(), new_gte)); } if (HloPredicateIsNotOp(old_instr)) { // Replace the result of the layer norm or the expectation. - TF_RETURN_IF_ERROR(ReplaceInstruction(old_instr, new_instr)); + RETURN_IF_ERROR(ReplaceInstruction(old_instr, new_instr)); } else { // Replace the norm factor, (variance + epsilon)^-1/2. - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( ReplaceInstruction(old_instr->mutable_operand(0), new_instr)); // Also replace the norm factor to the power of 3, (variance + // epsilon)^-1/2 / (variance + epsilon) = ((variance + // epsilon)^-1/2)^3. - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( HloInstruction * new_multiply0, MakeBinaryHlo(HloOpcode::kMultiply, new_instr, new_instr)); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( HloInstruction * new_multiply1, MakeBinaryHlo(HloOpcode::kMultiply, new_multiply0, new_instr)); - TF_RETURN_IF_ERROR(ReplaceInstruction(old_instr, new_multiply1)); + RETURN_IF_ERROR(ReplaceInstruction(old_instr, new_multiply1)); } return absl::OkStatus(); }; // Replace the result of the original Custom Call as well as the // expectation and the norm factor with the augmented Custom Call. - TF_RETURN_IF_ERROR(replace_with_new_cc(gte, 0)); - TF_RETURN_IF_ERROR(replace_with_new_cc(expectation.instr(), 1)); - TF_RETURN_IF_ERROR(replace_with_new_cc(instr, 2)); + RETURN_IF_ERROR(replace_with_new_cc(gte, 0)); + RETURN_IF_ERROR(replace_with_new_cc(expectation.instr(), 1)); + RETURN_IF_ERROR(replace_with_new_cc(instr, 2)); // Update the Custom Call associated with the metadata of the forward // norm. @@ -1405,13 +1404,13 @@ class CudnnNormRewriterVisitor : public DfsHloRewriteVisitor { // graph. HloInstruction* transposed_dy = dy.instr(); if (norm_metadata->second.x_transpose) { - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( transposed_dy, MakeTransposeHlo(dy.instr(), norm_metadata->second.x_transpose->dimensions())); } - TF_ASSIGN_OR_RETURN(HloInstruction * reshaped_dy, - MakeReshapeHlo(x.instr()->shape(), transposed_dy)); + ASSIGN_OR_RETURN(HloInstruction * reshaped_dy, + MakeReshapeHlo(x.instr()->shape(), transposed_dy)); Shape dx_shape = ShapeUtil::MakeShape(instr->shape().element_type(), x.instr()->shape().dimensions()); @@ -1429,8 +1428,8 @@ class CudnnNormRewriterVisitor : public DfsHloRewriteVisitor { // Set the workspace size to its upper bound. // TODO(philipphack): Consider autotuning the norm kernels. - TF_ASSIGN_OR_RETURN(const int64_t c_constant, - CConstant(cuda_compute_capability_)); + ASSIGN_OR_RETURN(const int64_t c_constant, + CConstant(cuda_compute_capability_)); const int64_t workspace_size = (2 * c_constant * (4 + 256)) + (2 * x.instr()->shape().dimensions(0) * 4) + 64; @@ -1448,35 +1447,34 @@ class CudnnNormRewriterVisitor : public DfsHloRewriteVisitor { {x.instr(), scale.instr(), reshaped_dy, fused_expectation.instr(), fused_norm_factor.instr()}, kCudnnNormCallTarget)); - TF_RETURN_IF_ERROR(custom_call->set_backend_config(gpu_backend_config)); + RETURN_IF_ERROR(custom_call->set_backend_config(gpu_backend_config)); auto replace_with_cc = [custom_call, norm_metadata, transposed_dy, this]( HloInstruction* old_instr, int tuple_index) -> absl::Status { - TF_ASSIGN_OR_RETURN(HloInstruction * gte, - MakeGetTupleElementHlo(custom_call, tuple_index)); + ASSIGN_OR_RETURN(HloInstruction * gte, + MakeGetTupleElementHlo(custom_call, tuple_index)); HloInstruction* new_instr; // Transpose DX applying the stored transpose order of Y from the // forward graph. if (tuple_index == 0 && norm_metadata->second.y_transpose) { - TF_ASSIGN_OR_RETURN(new_instr, - MakeReshapeHlo(transposed_dy->shape(), gte)); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN(new_instr, + MakeReshapeHlo(transposed_dy->shape(), gte)); + ASSIGN_OR_RETURN( new_instr, MakeTransposeHlo( new_instr, norm_metadata->second.y_transpose->dimensions())); } else { - TF_ASSIGN_OR_RETURN(new_instr, - MakeReshapeHlo(old_instr->shape(), gte)); + ASSIGN_OR_RETURN(new_instr, MakeReshapeHlo(old_instr->shape(), gte)); } - TF_RETURN_IF_ERROR(ReplaceInstruction(old_instr, new_instr)); + RETURN_IF_ERROR(ReplaceInstruction(old_instr, new_instr)); return absl::OkStatus(); }; - TF_RETURN_IF_ERROR(replace_with_cc(instr, 0)); - TF_RETURN_IF_ERROR(replace_with_cc(dscale, 1)); + RETURN_IF_ERROR(replace_with_cc(instr, 0)); + RETURN_IF_ERROR(replace_with_cc(dscale, 1)); if (dbias) { - TF_RETURN_IF_ERROR(replace_with_cc(dbias, 2)); + RETURN_IF_ERROR(replace_with_cc(dbias, 2)); } VLOG(1) << "Gradients w.r.t. x" << (dbias ? ", scale and bias" : " and scale") @@ -1495,7 +1493,7 @@ absl::StatusOr RunOnComputation( HloComputation* computation, se::CudaComputeCapability cuda_compute_capability) { CudnnNormRewriterVisitor visitor(cuda_compute_capability); - TF_RETURN_IF_ERROR(computation->Accept(&visitor)); + RETURN_IF_ERROR(computation->Accept(&visitor)); return visitor.changed(); } @@ -1511,8 +1509,8 @@ absl::StatusOr CudnnNormRewriter::RunImpl( bool changed = false; for (HloComputation* computation : module->MakeNonfusionComputations(execution_threads)) { - TF_ASSIGN_OR_RETURN( - bool result, RunOnComputation(computation, cuda_compute_capability_)); + ASSIGN_OR_RETURN(bool result, + RunOnComputation(computation, cuda_compute_capability_)); changed |= result; } return changed; diff --git a/third_party/xla/xla/backends/gpu/transforms/cudnn_pad_for_convolutions.cc b/third_party/xla/xla/backends/gpu/transforms/cudnn_pad_for_convolutions.cc index 7ea1bdbff42125..506fdc76edc337 100644 --- a/third_party/xla/xla/backends/gpu/transforms/cudnn_pad_for_convolutions.cc +++ b/third_party/xla/xla/backends/gpu/transforms/cudnn_pad_for_convolutions.cc @@ -31,6 +31,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instructions.h" @@ -173,10 +174,10 @@ static absl::StatusOr ResolveAndPad( resolve_pad_shapes) { std::vector new_input_shapes; Shape new_result_shape; - TF_ASSIGN_OR_RETURN(bool result, resolve_pad_shapes(conv, &new_input_shapes, - &new_result_shape)); + ASSIGN_OR_RETURN(bool result, resolve_pad_shapes(conv, &new_input_shapes, + &new_result_shape)); if (result) { - TF_RETURN_IF_ERROR(PadConv(conv, new_input_shapes, new_result_shape)); + RETURN_IF_ERROR(PadConv(conv, new_input_shapes, new_result_shape)); return true; } return false; @@ -198,7 +199,7 @@ static absl::StatusOr ResolveAndPad( static absl::StatusOr TryResolvePaddedShapesForTensorCore( HloCustomCallInstruction* conv, std::vector* new_input_shapes_ptr, Shape* new_result_shape_ptr) { - TF_ASSIGN_OR_RETURN(auto kind, GetCudnnConvKind(conv)); + ASSIGN_OR_RETURN(auto kind, GetCudnnConvKind(conv)); const auto& dnums = conv->convolution_dimension_numbers(); auto* lhs = conv->mutable_operand(0); auto* rhs = conv->mutable_operand(1); @@ -321,7 +322,7 @@ absl::StatusOr TryResolvePaddedShapesForIntegerConvolution( int pad_to, const se::CudaComputeCapability& compute_capability, HloCustomCallInstruction* conv, std::vector* new_input_shapes_ptr, Shape* new_result_shape_ptr) { - TF_ASSIGN_OR_RETURN(auto kind, GetCudnnConvKind(conv)); + ASSIGN_OR_RETURN(auto kind, GetCudnnConvKind(conv)); const Shape& input_shape = conv->operand(0)->shape(); const Shape& kernel_shape = conv->operand(1)->shape(); const Shape& result_shape = conv->shape().tuple_shapes(0); @@ -372,9 +373,9 @@ absl::StatusOr TryResolvePaddedShapesForIntegerConvolution( } // Check that cudnn support our desired integer padding/vectorization. - TF_ASSIGN_OR_RETURN(bool cudnn_supports, - CudnnSupportsOptimizedIntegerConvolution( - compute_capability, *conv, pad_to)); + ASSIGN_OR_RETURN(bool cudnn_supports, + CudnnSupportsOptimizedIntegerConvolution(compute_capability, + *conv, pad_to)); if (!cudnn_supports) { return false; } @@ -500,14 +501,14 @@ absl::StatusOr CudnnPadForConvolutions::RunImpl( // because that lets us use the fast int8x32 data type. bool local_changed = false; if (compute_capability_.IsAtLeast(7, 5)) { - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( local_changed, ResolveAndPad(conv, absl::bind_front( TryResolvePaddedShapesForIntegerConvolution, 32, compute_capability_))); } if (!local_changed) { - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( local_changed, ResolveAndPad(conv, absl::bind_front( TryResolvePaddedShapesForIntegerConvolution, @@ -517,7 +518,7 @@ absl::StatusOr CudnnPadForConvolutions::RunImpl( } if (compute_capability_.IsAtLeast(se::CudaComputeCapability::kVolta)) { for (HloCustomCallInstruction* conv : GetRelevantConvs(comp)) { - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( bool local_changed, ResolveAndPad(conv, TryResolvePaddedShapesForTensorCore)); changed |= local_changed; diff --git a/third_party/xla/xla/backends/gpu/transforms/cudnn_simplify_padding.cc b/third_party/xla/xla/backends/gpu/transforms/cudnn_simplify_padding.cc index ec3142c1b4ef87..7114465d43ad24 100644 --- a/third_party/xla/xla/backends/gpu/transforms/cudnn_simplify_padding.cc +++ b/third_party/xla/xla/backends/gpu/transforms/cudnn_simplify_padding.cc @@ -29,6 +29,7 @@ limitations under the License. #include "absl/strings/str_join.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" @@ -226,10 +227,10 @@ absl::StatusOr TrySimplifyPadding(HloInstruction* instr) { // padding is allowed. new_pad_feature_dim->set_edge_padding_high( new_pad_feature_dim->edge_padding_high() - num_sliced_from_feature_dim); - TF_ASSIGN_OR_RETURN(HloInstruction * new_pad, - MakePadHlo(slice->mutable_operand(0), - pad->mutable_operand(1), new_padding_config)); - TF_RETURN_IF_ERROR(pad->parent()->ReplaceInstruction(pad, new_pad)); + ASSIGN_OR_RETURN(HloInstruction * new_pad, + MakePadHlo(slice->mutable_operand(0), + pad->mutable_operand(1), new_padding_config)); + RETURN_IF_ERROR(pad->parent()->ReplaceInstruction(pad, new_pad)); return true; } @@ -242,7 +243,7 @@ absl::StatusOr CudnnSimplifyPadding::RunImpl( for (HloComputation* comp : module->MakeNonfusionComputations(execution_threads)) { for (HloInstruction* instr : comp->MakeInstructionPostOrder()) { - TF_ASSIGN_OR_RETURN(bool c, TrySimplifyPadding(instr)); + ASSIGN_OR_RETURN(bool c, TrySimplifyPadding(instr)); changed |= c; } } diff --git a/third_party/xla/xla/backends/gpu/transforms/cudnn_simplify_padding_test.cc b/third_party/xla/xla/backends/gpu/transforms/cudnn_simplify_padding_test.cc index f0f184022ed7b5..fff5f994986ad6 100644 --- a/third_party/xla/xla/backends/gpu/transforms/cudnn_simplify_padding_test.cc +++ b/third_party/xla/xla/backends/gpu/transforms/cudnn_simplify_padding_test.cc @@ -26,6 +26,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/pass/hlo_pass_fix.h" #include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" #include "xla/hlo/testlib/pattern_matcher_gmock.h" @@ -48,16 +49,15 @@ namespace m = ::xla::match; class CudnnSimplifyPaddingTest : public HloHardwareIndependentTestBase { protected: absl::StatusOr RunJustThisPass(HloModule* module) { - TF_ASSIGN_OR_RETURN(bool changed, - RunHloPass(CudnnSimplifyPadding(), module)); + ASSIGN_OR_RETURN(bool changed, RunHloPass(CudnnSimplifyPadding(), module)); VLOG(1) << "after simplify_padding:\n" << module->ToString(); // I know the name says "just this pass", but you really want algsimp too, // otherwise the resulting patterns are ugly/hard to match. - TF_RETURN_IF_ERROR(RunHloPass(HloPassFix( - AlgebraicSimplifierOptions()), - module) - .status()); + RETURN_IF_ERROR(RunHloPass(HloPassFix( + AlgebraicSimplifierOptions()), + module) + .status()); return changed; } }; diff --git a/third_party/xla/xla/backends/gpu/transforms/deviceless_estimate_cub_sort_scratch_size_test.cc b/third_party/xla/xla/backends/gpu/transforms/deviceless_estimate_cub_sort_scratch_size_test.cc index acf452b2c0ed56..2e38de180449ae 100644 --- a/third_party/xla/xla/backends/gpu/transforms/deviceless_estimate_cub_sort_scratch_size_test.cc +++ b/third_party/xla/xla/backends/gpu/transforms/deviceless_estimate_cub_sort_scratch_size_test.cc @@ -24,6 +24,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/backends/gpu/libraries/cub/cub_scratch_size_deviceless_lookup.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" @@ -43,9 +44,9 @@ class DevicelessEstimateCubSortScratchSizeTest protected: absl::StatusOr RunPassAndExtractScratchSize( absl::string_view hlo_text, DevicelessEstimateCubSortScratchSize& pass) { - TF_ASSIGN_OR_RETURN(auto module, ParseAndReturnVerifiedModule(hlo_text)); + ASSIGN_OR_RETURN(auto module, ParseAndReturnVerifiedModule(hlo_text)); - TF_ASSIGN_OR_RETURN(bool changed, RunHloPass(&pass, module.get())); + ASSIGN_OR_RETURN(bool changed, RunHloPass(&pass, module.get())); if (!changed) { return absl::InternalError("Pass did not change the module"); } diff --git a/third_party/xla/xla/backends/gpu/transforms/dot_algorithm_rewriter.cc b/third_party/xla/xla/backends/gpu/transforms/dot_algorithm_rewriter.cc index e0d41c7b16471e..9b242f2c048c76 100644 --- a/third_party/xla/xla/backends/gpu/transforms/dot_algorithm_rewriter.cc +++ b/third_party/xla/xla/backends/gpu/transforms/dot_algorithm_rewriter.cc @@ -25,6 +25,7 @@ limitations under the License. #include "absl/log/log.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/comparison_util.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" @@ -396,14 +397,14 @@ DotAlgorithmRewriter::MakeMultiplyForBF16BF16F32(HloInstruction* lhs, << "Algorithm field set to BF16_BF16_F32, but the rhs isn't F32 or BF16."; if (lhs->shape().element_type() == PrimitiveType::BF16 && rhs->shape().element_type() == PrimitiveType::BF16) { - TF_ASSIGN_OR_RETURN(auto result_bf16, - MakeBinaryHlo(HloOpcode::kMultiply, lhs, rhs)); + ASSIGN_OR_RETURN(auto result_bf16, + MakeBinaryHlo(HloOpcode::kMultiply, lhs, rhs)); return UpcastToF32(result_bf16); } auto lhs_bf16 = RoundToBF16(lhs); auto rhs_bf16 = RoundToBF16(rhs); - TF_ASSIGN_OR_RETURN(auto result_bf16, - MakeBinaryHlo(HloOpcode::kMultiply, lhs_bf16, rhs_bf16)); + ASSIGN_OR_RETURN(auto result_bf16, + MakeBinaryHlo(HloOpcode::kMultiply, lhs_bf16, rhs_bf16)); return UpcastToF32(result_bf16); } @@ -417,9 +418,9 @@ DotAlgorithmRewriter::MakeMultiplyForBF16BF16F32X3(HloInstruction* lhs, auto [lhs_high_bf16, lhs_low_bf16] = Split2xToBF16(lhs); auto [rhs_high_bf16, rhs_low_bf16] = Split2xToBF16(rhs); - TF_ASSIGN_OR_RETURN(auto* low_high, Mult(lhs_low_bf16, rhs_high_bf16)); - TF_ASSIGN_OR_RETURN(auto* high_low, Mult(lhs_high_bf16, rhs_low_bf16)); - TF_ASSIGN_OR_RETURN(auto* high_high, Mult(lhs_high_bf16, rhs_high_bf16)); + ASSIGN_OR_RETURN(auto* low_high, Mult(lhs_low_bf16, rhs_high_bf16)); + ASSIGN_OR_RETURN(auto* high_low, Mult(lhs_high_bf16, rhs_low_bf16)); + ASSIGN_OR_RETURN(auto* high_high, Mult(lhs_high_bf16, rhs_high_bf16)); auto* low_sum = SumToF32(low_high, high_low); auto* low = ReplaceNaNWithZeros(low_sum); auto* result = SumToF32(low, high_high); @@ -436,12 +437,12 @@ DotAlgorithmRewriter::MakeMultiplyForBF16BF16F32X6(HloInstruction* lhs, auto [lhs_high_bf16, lhs_mid_bf16, lhs_low_bf16] = Split3xToBF16(lhs); auto [rhs_high_bf16, rhs_mid_bf16, rhs_low_bf16] = Split3xToBF16(rhs); - TF_ASSIGN_OR_RETURN(auto* mid_mid, Mult(lhs_mid_bf16, rhs_mid_bf16)); - TF_ASSIGN_OR_RETURN(auto* high_low, Mult(lhs_high_bf16, rhs_low_bf16)); - TF_ASSIGN_OR_RETURN(auto* low_high, Mult(lhs_low_bf16, rhs_high_bf16)); - TF_ASSIGN_OR_RETURN(auto* high_mid, Mult(lhs_high_bf16, rhs_mid_bf16)); - TF_ASSIGN_OR_RETURN(auto* mid_high, Mult(lhs_mid_bf16, rhs_high_bf16)); - TF_ASSIGN_OR_RETURN(auto* high_high, Mult(lhs_high_bf16, rhs_high_bf16)); + ASSIGN_OR_RETURN(auto* mid_mid, Mult(lhs_mid_bf16, rhs_mid_bf16)); + ASSIGN_OR_RETURN(auto* high_low, Mult(lhs_high_bf16, rhs_low_bf16)); + ASSIGN_OR_RETURN(auto* low_high, Mult(lhs_low_bf16, rhs_high_bf16)); + ASSIGN_OR_RETURN(auto* high_mid, Mult(lhs_high_bf16, rhs_mid_bf16)); + ASSIGN_OR_RETURN(auto* mid_high, Mult(lhs_mid_bf16, rhs_high_bf16)); + ASSIGN_OR_RETURN(auto* high_high, Mult(lhs_high_bf16, rhs_high_bf16)); HloInstruction* result = nullptr; result = SumToF32(mid_mid, high_low); @@ -463,15 +464,15 @@ DotAlgorithmRewriter::MakeMultiplyForBF16BF16F32X9(HloInstruction* lhs, auto [lhs_high_bf16, lhs_mid_bf16, lhs_low_bf16] = Split3xToBF16(lhs); auto [rhs_high_bf16, rhs_mid_bf16, rhs_low_bf16] = Split3xToBF16(rhs); - TF_ASSIGN_OR_RETURN(auto* low_low, Mult(lhs_low_bf16, rhs_low_bf16)); - TF_ASSIGN_OR_RETURN(auto* low_mid, Mult(lhs_low_bf16, rhs_mid_bf16)); - TF_ASSIGN_OR_RETURN(auto* mid_low, Mult(lhs_mid_bf16, rhs_low_bf16)); - TF_ASSIGN_OR_RETURN(auto* mid_mid, Mult(lhs_mid_bf16, rhs_mid_bf16)); - TF_ASSIGN_OR_RETURN(auto* high_low, Mult(lhs_high_bf16, rhs_low_bf16)); - TF_ASSIGN_OR_RETURN(auto* low_high, Mult(lhs_low_bf16, rhs_high_bf16)); - TF_ASSIGN_OR_RETURN(auto* high_mid, Mult(lhs_high_bf16, rhs_mid_bf16)); - TF_ASSIGN_OR_RETURN(auto* mid_high, Mult(lhs_mid_bf16, rhs_high_bf16)); - TF_ASSIGN_OR_RETURN(auto* high_high, Mult(lhs_high_bf16, rhs_high_bf16)); + ASSIGN_OR_RETURN(auto* low_low, Mult(lhs_low_bf16, rhs_low_bf16)); + ASSIGN_OR_RETURN(auto* low_mid, Mult(lhs_low_bf16, rhs_mid_bf16)); + ASSIGN_OR_RETURN(auto* mid_low, Mult(lhs_mid_bf16, rhs_low_bf16)); + ASSIGN_OR_RETURN(auto* mid_mid, Mult(lhs_mid_bf16, rhs_mid_bf16)); + ASSIGN_OR_RETURN(auto* high_low, Mult(lhs_high_bf16, rhs_low_bf16)); + ASSIGN_OR_RETURN(auto* low_high, Mult(lhs_low_bf16, rhs_high_bf16)); + ASSIGN_OR_RETURN(auto* high_mid, Mult(lhs_high_bf16, rhs_mid_bf16)); + ASSIGN_OR_RETURN(auto* mid_high, Mult(lhs_mid_bf16, rhs_high_bf16)); + ASSIGN_OR_RETURN(auto* high_high, Mult(lhs_high_bf16, rhs_high_bf16)); HloInstruction* result = nullptr; result = SumToF32(low_low, low_mid); @@ -495,7 +496,7 @@ DotAlgorithmRewriter::MakeMultiplyForTF32TF32F32(HloInstruction* lhs, << "Algorithm field set to TF32_TF32_F32_X3, but the rhs is not F32."; auto lhs_tf32 = Truncate(lhs, kMaskTF32); auto rhs_tf32 = Truncate(rhs, kMaskTF32); - TF_ASSIGN_OR_RETURN(auto* result, Mult(lhs_tf32, rhs_tf32)); + ASSIGN_OR_RETURN(auto* result, Mult(lhs_tf32, rhs_tf32)); return result; } @@ -508,9 +509,9 @@ DotAlgorithmRewriter::MakeMultiplyForTF32TF32F32X3(HloInstruction* lhs, << "Algorithm field set to TF32_TF32_F32_X3, but the rhs is not F32."; auto [lhs_high_tf32, lhs_low_tf32] = Split2xToTF32(lhs); auto [rhs_high_tf32, rhs_low_tf32] = Split2xToTF32(rhs); - TF_ASSIGN_OR_RETURN(auto* low_high, Mult(lhs_low_tf32, rhs_high_tf32)); - TF_ASSIGN_OR_RETURN(auto* high_low, Mult(lhs_high_tf32, rhs_low_tf32)); - TF_ASSIGN_OR_RETURN(auto* high_high, Mult(lhs_high_tf32, rhs_high_tf32)); + ASSIGN_OR_RETURN(auto* low_high, Mult(lhs_low_tf32, rhs_high_tf32)); + ASSIGN_OR_RETURN(auto* high_low, Mult(lhs_high_tf32, rhs_low_tf32)); + ASSIGN_OR_RETURN(auto* high_high, Mult(lhs_high_tf32, rhs_high_tf32)); auto* low_sum = SumToF32(low_high, high_low); auto* low = ReplaceNaNWithZeros(low_sum); auto* result = SumToF32(low, high_high); diff --git a/third_party/xla/xla/backends/gpu/transforms/dot_dimension_sorter.cc b/third_party/xla/xla/backends/gpu/transforms/dot_dimension_sorter.cc index a97253b8fddcf4..321f322e50e88d 100644 --- a/third_party/xla/xla/backends/gpu/transforms/dot_dimension_sorter.cc +++ b/third_party/xla/xla/backends/gpu/transforms/dot_dimension_sorter.cc @@ -27,6 +27,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -127,7 +128,7 @@ absl::StatusOr DotDimensionSorter::RunImpl( return false; } for (HloInstruction* dot : dots_to_process) { - TF_RETURN_IF_ERROR(SortDotDimensions(Cast(dot))); + RETURN_IF_ERROR(SortDotDimensions(Cast(dot))); } return true; } diff --git a/third_party/xla/xla/backends/gpu/transforms/dot_normalizer.cc b/third_party/xla/xla/backends/gpu/transforms/dot_normalizer.cc index 67ed0650533377..6c2d12621543e8 100644 --- a/third_party/xla/xla/backends/gpu/transforms/dot_normalizer.cc +++ b/third_party/xla/xla/backends/gpu/transforms/dot_normalizer.cc @@ -16,6 +16,7 @@ limitations under the License. #include "xla/backends/gpu/transforms/dot_normalizer.h" #include "absl/status/statusor.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" @@ -44,13 +45,13 @@ absl::StatusOr DotNormalizer::ExpandInstruction( ShapeUtil::AppendMinorDimension(1, &new_lhs_shape); HloInstruction* normalized_lhs = dot->AddInstruction(HloInstruction::CreateBitcast(new_lhs_shape, lhs)); - TF_RETURN_IF_ERROR(dot->ReplaceOperandWithDifferentShape(0, normalized_lhs)); + RETURN_IF_ERROR(dot->ReplaceOperandWithDifferentShape(0, normalized_lhs)); HloInstruction* rhs = dot->mutable_operand(1); Shape new_rhs_shape = rhs->shape(); ShapeUtil::AppendMinorDimension(1, &new_rhs_shape); HloInstruction* normalized_rhs = dot->AddInstruction(HloInstruction::CreateBitcast(new_rhs_shape, rhs)); - TF_RETURN_IF_ERROR(dot->ReplaceOperandWithDifferentShape(1, normalized_rhs)); + RETURN_IF_ERROR(dot->ReplaceOperandWithDifferentShape(1, normalized_rhs)); DotDimensionNumbers* dnums = dot->mutable_dot_dimension_numbers(); dnums->add_lhs_contracting_dimensions(new_lhs_shape.dimensions().size() - 1); dnums->add_rhs_contracting_dimensions(new_rhs_shape.dimensions().size() - 1); diff --git a/third_party/xla/xla/backends/gpu/transforms/dot_operand_converter.cc b/third_party/xla/xla/backends/gpu/transforms/dot_operand_converter.cc index 809ccd45126eb0..b0951b2f644145 100644 --- a/third_party/xla/xla/backends/gpu/transforms/dot_operand_converter.cc +++ b/third_party/xla/xla/backends/gpu/transforms/dot_operand_converter.cc @@ -17,6 +17,7 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "absl/status/statusor.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/shape_util.h" @@ -69,8 +70,8 @@ absl::StatusOr DotOperandConverter::ExpandInstruction( upcast_shape.set_element_type(desired_type); auto* convert_inst = instruction->AddInstruction( HloInstruction::CreateConvert(upcast_shape, inst_to_replace)); - TF_RETURN_IF_ERROR(instruction->ReplaceOperandWithDifferentShape( - operand_index, convert_inst)); + RETURN_IF_ERROR(instruction->ReplaceOperandWithDifferentShape(operand_index, + convert_inst)); return nullptr; } diff --git a/third_party/xla/xla/backends/gpu/transforms/dot_strength_reduction.cc b/third_party/xla/xla/backends/gpu/transforms/dot_strength_reduction.cc index 693d2de831e85e..f3925c3e84266d 100644 --- a/third_party/xla/xla/backends/gpu/transforms/dot_strength_reduction.cc +++ b/third_party/xla/xla/backends/gpu/transforms/dot_strength_reduction.cc @@ -26,6 +26,7 @@ limitations under the License. #include "absl/log/check.h" #include "absl/status/statusor.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/backends/gpu/codegen/triton/support_legacy.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -153,7 +154,7 @@ absl::StatusOr DotStrengthReduction::ExpandInstruction( HloInstruction* instruction) { HloDotInstruction* dot = Cast(instruction); const OpMetadata* metadata = &dot->metadata(); - TF_ASSIGN_OR_RETURN(auto dot_dims, DotOperandDims::FromDot(dot)); + ASSIGN_OR_RETURN(auto dot_dims, DotOperandDims::FromDot(dot)); std::array operands = {dot->mutable_operand(0), dot->mutable_operand(1)}; @@ -181,7 +182,7 @@ absl::StatusOr DotStrengthReduction::ExpandInstruction( // At this point, both operands have the same shape. Elementwise multiply. CHECK(operands[0]->shape().dimensions() == operands[1]->shape().dimensions()); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( HloInstruction * flow, MakeMultiplyForDotPrecisionAlgorithm( operands[0], operands[1], dot->precision_config().algorithm())); diff --git a/third_party/xla/xla/backends/gpu/transforms/double_buffer_loop_unrolling.cc b/third_party/xla/xla/backends/gpu/transforms/double_buffer_loop_unrolling.cc index 33fa4c03a8bae7..2f80e69af35dfb 100644 --- a/third_party/xla/xla/backends/gpu/transforms/double_buffer_loop_unrolling.cc +++ b/third_party/xla/xla/backends/gpu/transforms/double_buffer_loop_unrolling.cc @@ -29,6 +29,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_clone_context.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -148,9 +149,9 @@ absl::Status HandleControlDependencies( new_control_pred.push_back(old_to_new_map.at(pred)); } - TF_RETURN_IF_ERROR(new_instr->DropAllControlDeps()); + RETURN_IF_ERROR(new_instr->DropAllControlDeps()); for (HloInstruction* new_pred : new_control_pred) { - TF_RETURN_IF_ERROR(new_pred->AddControlDependencyTo(new_instr)); + RETURN_IF_ERROR(new_pred->AddControlDependencyTo(new_instr)); VLOG(2) << "Adding " << new_pred->ToString() << " to control dependency of " << new_instr->ToString(); } @@ -164,7 +165,7 @@ absl::Status HandleControlDependencies( skip_control_dep_injection.end() && !IsCollective(old_input)) { for (HloInstruction* old_root : *old_loop_roots) { - TF_RETURN_IF_ERROR(old_root->AddControlDependencyTo(new_input)); + RETURN_IF_ERROR(old_root->AddControlDependencyTo(new_input)); } } } @@ -188,8 +189,8 @@ absl::StatusOr FullyUnroll(HloInstruction* while_instr, absl::flat_hash_set skip_control_dep_injection; std::string clone_suffix = "full_unroll_clone"; - TF_ASSIGN_OR_RETURN(WhileLoopBackendConfig config, - while_instr->backend_config()); + ASSIGN_OR_RETURN(WhileLoopBackendConfig config, + while_instr->backend_config()); std::vector ops_to_clone; ops_to_clone.reserve(while_body->MakeInstructionPostOrder().size()); @@ -239,9 +240,9 @@ absl::StatusOr FullyUnroll(HloInstruction* while_instr, VLOG(2) << "Replaced with new root " << while_body->root_instruction()->ToString(); - TF_RETURN_IF_ERROR(HandleControlDependencies( - while_body, old_to_new_map, &loop_roots, old_input_parameter, - skip_control_dep_injection)); + RETURN_IF_ERROR(HandleControlDependencies(while_body, old_to_new_map, + &loop_roots, old_input_parameter, + skip_control_dep_injection)); // Inductive step update, clean/update necessary buffers to prepare them for // the next unrolling iteration. @@ -255,13 +256,13 @@ absl::StatusOr FullyUnroll(HloInstruction* while_instr, } WhileLoopBackendConfig old_config; - TF_ASSIGN_OR_RETURN(old_config, - while_instr->backend_config()); + ASSIGN_OR_RETURN(old_config, + while_instr->backend_config()); WhileLoopBackendConfig new_config = old_config; new_config.mutable_known_trip_count()->set_n(1); - TF_RETURN_IF_ERROR(while_instr->set_backend_config(new_config)); + RETURN_IF_ERROR(while_instr->set_backend_config(new_config)); return changed; } @@ -301,7 +302,7 @@ absl::Status PeelInstructionsForOddTripCount(HloModule* module, for (HloInstruction* instr : old_loop_roots) { new_roots.push_back(old_to_new_map[instr]); } - TF_RETURN_IF_ERROR(while_instr->ReplaceOperandWith( + RETURN_IF_ERROR(while_instr->ReplaceOperandWith( 0, old_to_new_map[while_body->root_instruction()])); VLOG(2) << "Replaced with new input tuple " << while_instr->operand(0)->ToString(); @@ -318,9 +319,9 @@ absl::Status PeelInstructionsForOddTripCount(HloModule* module, new_control_pred.push_back(old_to_new_map[pred]); } - TF_RETURN_IF_ERROR(new_instr->DropAllControlDeps()); + RETURN_IF_ERROR(new_instr->DropAllControlDeps()); for (HloInstruction* new_pred : new_control_pred) { - TF_RETURN_IF_ERROR(new_pred->AddControlDependencyTo(new_instr)); + RETURN_IF_ERROR(new_pred->AddControlDependencyTo(new_instr)); VLOG(2) << "Adding " << new_pred->ToString() << " to control dependency of peeled instruction: " << new_instr->ToString(); @@ -334,8 +335,8 @@ absl::Status PeelInstructionsForOddTripCount(HloModule* module, // a separate function. absl::StatusOr DoubleBufferingUnroll(HloInstruction* while_instr, HloModule* module) { - TF_ASSIGN_OR_RETURN(auto config, - while_instr->backend_config()); + ASSIGN_OR_RETURN(auto config, + while_instr->backend_config()); CHECK(config.has_known_trip_count()) << "Only loops with known trip count are supported."; @@ -357,7 +358,7 @@ absl::StatusOr DoubleBufferingUnroll(HloInstruction* while_instr, if (peel_one_iteration) { VLOG(2) << "Found loops with odd trip count, 1 iteration will be peeled " "outside of the main body."; - TF_RETURN_IF_ERROR(PeelInstructionsForOddTripCount(module, while_instr)); + RETURN_IF_ERROR(PeelInstructionsForOddTripCount(module, while_instr)); exact_trip_count -= 1; } @@ -393,9 +394,9 @@ absl::StatusOr DoubleBufferingUnroll(HloInstruction* while_instr, << while_body->root_instruction()->ToString(); // Handle existing control dependencies. - TF_RETURN_IF_ERROR(HandleControlDependencies(while_body, old_to_new_map, - &old_loop_roots, input_parameter, - skip_control_dep_injection)); + RETURN_IF_ERROR(HandleControlDependencies(while_body, old_to_new_map, + &old_loop_roots, input_parameter, + skip_control_dep_injection)); WhileLoopBackendConfig new_config = config; new_config.mutable_known_trip_count()->set_n(exact_trip_count / 2); @@ -408,7 +409,7 @@ absl::StatusOr DoubleBufferingUnroll(HloInstruction* while_instr, config.known_init_step().init() + (peel_one_iteration ? step : 0)); } - TF_RETURN_IF_ERROR(while_instr->set_backend_config(new_config)); + RETURN_IF_ERROR(while_instr->set_backend_config(new_config)); return true; // changed } @@ -443,8 +444,8 @@ absl::StatusOr DoubleBufferLoopUnrolling::RunImpl( VLOG(2) << "Processing " << while_instrs.size() << " while loops."; for (HloInstruction* while_instr : while_instrs) { - TF_ASSIGN_OR_RETURN(WhileLoopBackendConfig config, - while_instr->backend_config()); + ASSIGN_OR_RETURN(WhileLoopBackendConfig config, + while_instr->backend_config()); if (!config.has_known_trip_count()) { VLOG(2) << while_instr->ToString() << " doesn't have exact trip count, skipping loop unrolling."; @@ -458,11 +459,11 @@ absl::StatusOr DoubleBufferLoopUnrolling::RunImpl( } if (unroll_strategy_ == UnrollStrategy::kFullUnroll) { - TF_ASSIGN_OR_RETURN(changed, FullyUnroll(while_instr, module)); + ASSIGN_OR_RETURN(changed, FullyUnroll(while_instr, module)); } else if (unroll_strategy_ == UnrollStrategy::kDoubleBuffer) { - TF_ASSIGN_OR_RETURN(changed, DoubleBufferingUnroll(while_instr, module)); + ASSIGN_OR_RETURN(changed, DoubleBufferingUnroll(while_instr, module)); } else if (unroll_strategy_ == UnrollStrategy::kAuto) { - TF_ASSIGN_OR_RETURN(changed, AutoUnroll(while_instr, module)); + ASSIGN_OR_RETURN(changed, AutoUnroll(while_instr, module)); } else { LOG(FATAL) << absl::StrCat("Unhandled unrolling strategy: ", unroll_strategy_); @@ -477,8 +478,7 @@ absl::StatusOr DoubleBufferLoopUnrolling::RunImpl( // The call graph will not be flat if one of the loops that was unrolled // contains any kind of call to another computation---since the call will // be duplicated, thereby adding a second callsite for that computation. - TF_RETURN_IF_ERROR( - FlattenCallGraph().Run(module, execution_threads).status()); + RETURN_IF_ERROR(FlattenCallGraph().Run(module, execution_threads).status()); } return changed; diff --git a/third_party/xla/xla/backends/gpu/transforms/dynamic_slice_fusion.cc b/third_party/xla/xla/backends/gpu/transforms/dynamic_slice_fusion.cc index 1691b8d6d43382..edb7ee1677694f 100644 --- a/third_party/xla/xla/backends/gpu/transforms/dynamic_slice_fusion.cc +++ b/third_party/xla/xla/backends/gpu/transforms/dynamic_slice_fusion.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include "absl/status/statusor.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -236,8 +237,8 @@ DynamicSliceFusion::ResolveResults(const HloInstruction* hero) { } for (const HloInstruction* user : leaf_gte->users()) { - TF_ASSIGN_OR_RETURN(auto rs, - ResolveOneResultChain(user, leaves[i].shape, i)); + ASSIGN_OR_RETURN(auto rs, + ResolveOneResultChain(user, leaves[i].shape, i)); if (rs.has_value()) { results[i] = *std::move(rs); } @@ -248,7 +249,7 @@ DynamicSliceFusion::ResolveResults(const HloInstruction* hero) { // Non-tuple hero: single result. for (const HloInstruction* user : hero->users()) { - TF_ASSIGN_OR_RETURN(auto rs, ResolveOneResultChain(user, hero->shape(), 0)); + ASSIGN_OR_RETURN(auto rs, ResolveOneResultChain(user, hero->shape(), 0)); if (rs.has_value()) { return std::vector{*std::move(rs)}; } diff --git a/third_party/xla/xla/backends/gpu/transforms/dynamic_slice_fusion_rewriter.cc b/third_party/xla/xla/backends/gpu/transforms/dynamic_slice_fusion_rewriter.cc index 560ea3acc6d917..bbb7eae3244e81 100644 --- a/third_party/xla/xla/backends/gpu/transforms/dynamic_slice_fusion_rewriter.cc +++ b/third_party/xla/xla/backends/gpu/transforms/dynamic_slice_fusion_rewriter.cc @@ -33,6 +33,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/ffi/ffi_api.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -153,7 +154,7 @@ absl::Status CreateRootTuple( HloInstruction::CreateGetTupleElement(instr_mapping[hero], i)); if (hero->shape().tuple_shapes(i).IsTuple()) { instr_mapping[gte] = gte; - TF_RETURN_IF_ERROR(CreateRootTuple(gte, builder, {}, instr_mapping)); + RETURN_IF_ERROR(CreateRootTuple(gte, builder, {}, instr_mapping)); elements.push_back(builder.last_added_instruction()); } else { elements.push_back(gte); @@ -210,7 +211,7 @@ absl::StatusOr CreateFusionBody( // Create a tuple if the hero is a tuple to make sure there's a buffer // assigned for each of the elements. Make sure the tuple is not nil first. if (hero->shape().IsTuple() && hero->shape().tuple_shapes().size() > 0) { - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( CreateRootTuple(hero, builder, sliced_user_paths, instr_mapping)); } @@ -241,7 +242,7 @@ absl::StatusOr CreateFusionInstruction( dynamic ? kDynamicSliceFusionWithDynamicAddressComputationConfigName : kDynamicSliceFusionWithStaticAddressComputationConfigName); *backend_config.mutable_custom_fusion_config() = config; - TF_RETURN_IF_ERROR(fusion->set_backend_config(std::move(gpu_config))); + RETURN_IF_ERROR(fusion->set_backend_config(std::move(gpu_config))); return fusion; } @@ -312,7 +313,7 @@ absl::StatusOr DynamicSliceFusionRewriter::RunImpl( auto captures = GetPatternCaptures(matched_instrs); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( HloComputation * fusion_body, CreateFusionBody(module, sliced_operand_paths, DataflowPathsView(sliced_user_paths_view), captures)); @@ -320,14 +321,13 @@ absl::StatusOr DynamicSliceFusionRewriter::RunImpl( bool has_dynamic_slices = absl::c_any_of(matched_instrs, [&](auto* instr) { return DynCast(instr) != nullptr; }); - TF_ASSIGN_OR_RETURN( - HloInstruction * fusion, - CreateFusionInstruction(module, hero, captures, fusion_body, - has_dynamic_slices)); + ASSIGN_OR_RETURN(HloInstruction * fusion, + CreateFusionInstruction(module, hero, captures, + fusion_body, has_dynamic_slices)); HloComputation* parent = hero->parent(); if (fusion->shape().IsTuple()) { - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( parent->ReplaceInstructionWithDifferentShape(hero, fusion)); for (auto& sliced_user_path : sliced_user_paths) { auto old_gte = @@ -335,7 +335,7 @@ absl::StatusOr DynamicSliceFusionRewriter::RunImpl( HloInstruction* gte = parent->AddInstruction(HloInstruction::CreateGetTupleElement( fusion, old_gte->tuple_index())); - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( parent->ReplaceInstruction(sliced_user_path.back(), gte)); } } else { @@ -359,11 +359,10 @@ absl::StatusOr DynamicSliceFusionRewriter::RunImpl( } else { instr_to_be_replaced = sliced_user_paths.front().back(); } - TF_RETURN_IF_ERROR( - parent->ReplaceInstruction(instr_to_be_replaced, fusion)); + RETURN_IF_ERROR(parent->ReplaceInstruction(instr_to_be_replaced, fusion)); // This is required for collective operations which will not be removed. if (hero->parent()) { - TF_RETURN_IF_ERROR(hero->parent()->RemoveInstruction(hero)); + RETURN_IF_ERROR(hero->parent()->RemoveInstruction(hero)); } } } diff --git a/third_party/xla/xla/backends/gpu/transforms/dynamic_slice_fusion_rewriter_v2.cc b/third_party/xla/xla/backends/gpu/transforms/dynamic_slice_fusion_rewriter_v2.cc new file mode 100644 index 00000000000000..378f0852aed740 --- /dev/null +++ b/third_party/xla/xla/backends/gpu/transforms/dynamic_slice_fusion_rewriter_v2.cc @@ -0,0 +1,642 @@ +/* Copyright 2026 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/backends/gpu/transforms/dynamic_slice_fusion_rewriter_v2.h" + +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/container/inlined_vector.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" +#include "xla/hlo/ir/hlo_casting_utils.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/service/gpu/backend_configs.pb.h" +#include "xla/service/gpu/gpu_constants.h" +#include "xla/service/gpu/ir_emission_utils.h" +#include "xla/shape.h" +#include "xla/shape_util.h" + +namespace xla::gpu { +namespace { + +//===----------------------------------------------------------------------===// +// Data structures +//===----------------------------------------------------------------------===// + +// A hero operand path: slice → noops → hero. Describes a Slice or +// DynamicSlice instruction feeding the hero through zero or more no-op +// instructions (bitcasts, GTEs). +struct SlicedParameter { + // The Slice or DynamicSlice instruction at the start of the chain. + HloInstruction* slice; + + // Bitcasts/GTEs between the slice and the hero (topological order). + std::vector noops; +}; + +// A hero result path: hero → [GTE] → [bitcasts] → DUS. Describes how one hero +// output flows through an optional GTE and bitcasts into a DynamicUpdateSlice +// that writes it into a target buffer. The `noops` vector holds the full chain +// (GTE first, then bitcasts). For passthrough results (tuple outputs without +// DUS), update_slice is nullptr and noops contains only the GTE. +// +// For flat tuple-producing heroes, there is one SlicedResult per tuple element. +struct SlicedResult { + // Flat tuple element index within the hero's output shape. Matches + // DynamicSliceFusion::Result::result_number. 0 for non-tuple heroes. + int64_t result_number = 0; + + // Instructions between the hero output and the DUS: GTE (for tuple heroes) + // followed by bitcasts. Empty when the leaf has no users in the original HLO + // (dead output) or for non-tuple heroes without bitcasts. + std::vector noops; + + // The DynamicUpdateSlice at the end of the chain. nullptr for passthrough + // results (tuple outputs that don't flow through a DUS). + HloInstruction* update_slice = nullptr; +}; + +//===----------------------------------------------------------------------===// +// Helpers +//===----------------------------------------------------------------------===// + +bool IsNoOp(const HloInstruction* instr) { + return instr->opcode() == HloOpcode::kBitcast || + instr->opcode() == HloOpcode::kGetTupleElement; +} + +bool HasDynamicSliceConfig(const HloInstruction* instr) { + auto config = instr->backend_config(); + if (!config.ok()) { + return false; + } + return config->has_dynamic_slice_config(); +} + +bool IsAlignedSlice(const HloInstruction* instr) { + if (!IsContiguousSlice(*instr)) { + return false; + } + + const Shape& slice_shape = instr->opcode() == HloOpcode::kDynamicUpdateSlice + ? instr->operand(1)->shape() + : instr->shape(); + + auto byte_strides = ShapeUtil::ByteStrides(slice_shape); + if (!byte_strides.has_value()) { + return false; + } + + int64_t slice_bytes = ShapeUtil::ByteSizeOfElements(slice_shape); + return slice_bytes % kXlaAllocatedBufferAlignBytes == 0; +} + +bool HasSupportedShapes(const HloInstruction* hero) { + for (const HloInstruction* operand : hero->operands()) { + if (operand->shape().IsTuple()) { + LOG(WARNING) << "DynamicSliceFusionRewriterV2: skipping " << hero->name() + << " because operand " << operand->name() << " is a tuple"; + return false; + } + } + if (hero->shape().IsTuple()) { + for (const Shape& tuple_shape : hero->shape().tuple_shapes()) { + if (tuple_shape.IsTuple()) { + LOG(WARNING) << "DynamicSliceFusionRewriterV2: skipping " + << hero->name() + << " because nested tuple results are not supported"; + return false; + } + } + } + return true; +} + +//===----------------------------------------------------------------------===// +// Resolve sliced parameters +//===----------------------------------------------------------------------===// + +using OptLevel = DynamicSliceFusionRewriterV2::OptLevel; +using CaptureSlice = DynamicSliceFusionRewriterV2::CaptureSlice; +using CaptureUpdateSlice = DynamicSliceFusionRewriterV2::CaptureUpdateSlice; + +std::optional ResolveSlicedParameter(HloInstruction* operand, + OptLevel opt_level) { + HloInstruction* current = operand; + + // Walk backward through no-ops to find the slice. In O2 mode, also look + // through tuple→GTE barriers: if a chain of GTEs resolves through + // corresponding tuple instructions, we skip the entire GTE/tuple tower and + // continue from the innermost tuple operand. These are NOT added to noops. + std::vector noops; + while (IsNoOp(current)) { + if (opt_level == OptLevel::kO2 && + current->opcode() == HloOpcode::kGetTupleElement) { + // Collect a chain of GTEs, then try to resolve through + // opt-barriers and tuples back to the original operand. + std::vector indices; + HloInstruction* probe = current; + while (probe->opcode() == HloOpcode::kGetTupleElement) { + indices.push_back( + Cast(probe)->tuple_index()); + probe = probe->mutable_operand(0); + } + // Skip through optimization barriers (they pass tuples through + // unchanged). + while (probe->opcode() == HloOpcode::kOptimizationBarrier) { + probe = probe->mutable_operand(0); + } + // Walk the tuple chain: each index peels one tuple layer. + bool resolved = true; + for (int64_t idx : indices) { + if (probe->opcode() != HloOpcode::kTuple) { + resolved = false; + break; + } + probe = probe->mutable_operand(idx); + } + if (resolved) { + current = probe; + continue; + } + } + noops.push_back(current); + current = current->mutable_operand(0); + } + + if (auto* slice = DynCast(current)) { + if (!IsAlignedSlice(slice)) { + return std::nullopt; + } + absl::c_reverse(noops); + return SlicedParameter{slice, std::move(noops)}; + } + + if (auto* ds = DynCast(current)) { + if (!HasDynamicSliceConfig(ds)) { + return std::nullopt; + } + if (!IsAlignedSlice(ds)) { + return std::nullopt; + } + absl::c_reverse(noops); + return SlicedParameter{ds, std::move(noops)}; + } + + return std::nullopt; +} + +std::vector ResolveSlicedParameters( + HloInstruction* hero, OptLevel opt_level, + const CaptureSlice& capture_slice) { + std::vector result; + for (int64_t i = 0; i < hero->operand_count(); ++i) { + HloInstruction* operand = hero->mutable_operand(i); + auto param = ResolveSlicedParameter(operand, opt_level); + if (!param.has_value()) { + continue; + } + if (!capture_slice(hero, i, param->slice)) { + continue; + } + // In O2 mode, the chain output may differ from the hero's operand when + // we looked through a tuple/GTE barrier. Replace the hero's operand so + // BuildFusionPlan sees a connected graph. Safe because the hero will be + // replaced by the fusion. + HloInstruction* chain_output = + param->noops.empty() ? param->slice : param->noops.back(); + if (chain_output != operand) { + CHECK_OK(hero->ReplaceOperandWith(i, chain_output)); + } + result.push_back(std::move(*param)); + } + return result; +} + +//===----------------------------------------------------------------------===// +// Resolve sliced results +//===----------------------------------------------------------------------===// + +std::optional ResolveSlicedResult(HloInstruction* user) { + HloInstruction* current = user; + + // Walk forward through no-ops to find the DUS. + std::vector noops; + while (IsNoOp(current)) { + if (current->user_count() != 1) { + return std::nullopt; + } + noops.push_back(current); + current = current->users().front(); + } + + auto* dus = DynCast(current); + if (dus == nullptr) { + return std::nullopt; + } + + if (!HasDynamicSliceConfig(dus)) { + return std::nullopt; + } + + if (!IsAlignedSlice(dus)) { + return std::nullopt; + } + + // DUS must flow into the root: either IS the root, or feeds root tuple. + HloComputation* parent = dus->parent(); + HloInstruction* root = parent->root_instruction(); + if (dus != root) { + if (root->opcode() != HloOpcode::kTuple) { + return std::nullopt; + } + bool feeds_root = + absl::c_any_of(root->operands(), + [dus](const HloInstruction* op) { return op == dus; }); + if (!feeds_root) { + return std::nullopt; + } + } + + return SlicedResult{0, std::move(noops), dus}; +} + +// Finds the GTE user of `hero` that extracts the given tuple index. +HloInstruction* FindGte(HloInstruction* hero, int64_t tuple_index) { + for (HloInstruction* user : hero->users()) { + auto* gte = DynCast(user); + if (gte != nullptr && gte->tuple_index() == tuple_index) { + return gte; + } + } + return nullptr; +} + +// Walks forward from `gte` to find a DUS chain. If found, returns a +// SlicedResult with the GTE prepended into noops. +std::optional ResolveLeafDus(HloInstruction* gte) { + for (HloInstruction* user : gte->users()) { + if (auto sliced = ResolveSlicedResult(user)) { + sliced->noops.insert(sliced->noops.begin(), gte); + return sliced; + } + } + return std::nullopt; +} + +// Resolves the sliced result for a non-tuple hero. +std::vector ResolveNonTupleSlicedResult( + HloInstruction* hero, const CaptureUpdateSlice& capture_update_slice) { + for (HloInstruction* user : hero->users()) { + if (auto sliced = ResolveSlicedResult(user)) { + if (!capture_update_slice(hero, std::nullopt, sliced->update_slice)) { + continue; + } + return std::vector{std::move(*sliced)}; + } + } + return {}; +} + +// Resolves sliced results for a flat tuple-producing hero. Returns one +// SlicedResult per tuple element. +std::vector ResolveTupleSlicedResults( + HloInstruction* hero, const CaptureUpdateSlice& capture_update_slice) { + const int64_t tuple_size = hero->shape().tuple_shapes().size(); + + struct LeafInfo { + int64_t result_number; + HloInstruction* gte; + std::optional sliced_update; + }; + + std::vector leaf_infos; + leaf_infos.reserve(tuple_size); + for (int64_t i = 0; i < tuple_size; ++i) { + HloInstruction* gte = FindGte(hero, i); + std::optional sliced_update; + if (gte != nullptr) { + sliced_update = ResolveLeafDus(gte); + } + leaf_infos.push_back({i, gte, std::move(sliced_update)}); + } + + std::vector result; + result.reserve(tuple_size); + for (auto& info : leaf_infos) { + if (info.sliced_update.has_value() && + capture_update_slice(hero, info.result_number, + info.sliced_update->update_slice)) { + info.sliced_update->result_number = info.result_number; + result.push_back(std::move(*info.sliced_update)); + continue; + } + + SlicedResult passthrough; + passthrough.result_number = info.result_number; + if (info.gte != nullptr) { + passthrough.noops.push_back(info.gte); + } + result.push_back(std::move(passthrough)); + } + return result; +} + +std::vector ResolveSlicedResults( + HloInstruction* hero, const CaptureUpdateSlice& capture_update_slice) { + if (hero->shape().IsTuple()) { + return ResolveTupleSlicedResults(hero, capture_update_slice); + } + return ResolveNonTupleSlicedResult(hero, capture_update_slice); +} + +//===----------------------------------------------------------------------===// +// Build fusion plan +//===----------------------------------------------------------------------===// + +// Complete plan for creating a dynamic-slice fusion from a hero instruction. +// Contains all instructions to include in the fusion body and the external +// operands (captures). +struct DynamicSliceFusionPlan { + // Operands of `instructions` that are not in the set — these become fusion + // parameters (e.g., the original buffers, loop induction variable, + // constants). + std::vector captures; + + // All instructions to clone into the fusion body (topological order): + // slices, noops, hero, result noops, DUS instructions. The last instruction + // becomes the fusion root (may be a single DUS, or the hero itself when + // there are no DUS results). When there are multiple results, + // CreateFusionBody appends a tuple instruction as the root. + std::vector instructions; +}; + +std::optional BuildFusionPlan( + HloInstruction* hero, absl::Span sliced_params, + absl::Span sliced_results) { + bool has_any_slice = !sliced_params.empty() || + absl::c_any_of(sliced_results, [](const auto& r) { + return r.update_slice != nullptr; + }); + if (!has_any_slice) { + return std::nullopt; + } + + std::vector instructions; + absl::flat_hash_set instruction_set; + + auto add_instr = [&](HloInstruction* instr) { + if (instruction_set.insert(instr).second) { + instructions.push_back(instr); + } + }; + + // Add sliced parameter paths (topological: slice → noops → hero). + for (const auto& param : sliced_params) { + add_instr(param.slice); + for (HloInstruction* noop : param.noops) { + add_instr(noop); + } + } + + add_instr(hero); + + // Add sliced result paths (topological: hero → noops → DUS). + // For passthrough results (no DUS), add the GTE chain so the hero output + // is accessible inside the fusion. + for (const auto& result : sliced_results) { + for (HloInstruction* noop : result.noops) { + add_instr(noop); + } + if (result.update_slice != nullptr) { + add_instr(result.update_slice); + } + } + + // Sink constants into the fusion body instead of capturing them as + // parameters. Constants have no operands so they go at the front. + std::vector constants; + for (HloInstruction* instr : instructions) { + for (HloInstruction* operand : instr->operands()) { + if (operand->opcode() == HloOpcode::kConstant && + instruction_set.insert(operand).second) { + constants.push_back(operand); + } + } + } + instructions.insert(instructions.begin(), constants.begin(), constants.end()); + + // Collect captures: operands of instructions that are not in the set. + std::vector captures; + absl::flat_hash_set capture_set; + for (HloInstruction* instr : instructions) { + for (HloInstruction* operand : instr->operands()) { + if (!instruction_set.contains(operand) && + capture_set.insert(operand).second) { + captures.push_back(operand); + } + } + } + + return DynamicSliceFusionPlan{std::move(captures), std::move(instructions)}; +} + +//===----------------------------------------------------------------------===// +// Create fusion +//===----------------------------------------------------------------------===// + +absl::StatusOr CreateFusionBody( + HloModule* module, const DynamicSliceFusionPlan& plan, + absl::Span sliced_results, HloInstruction* hero) { + HloComputation::Builder builder("dynamic-slice-fusion"); + + absl::flat_hash_map instr_mapping; + instr_mapping.reserve(plan.captures.size() + plan.instructions.size()); + + // Create parameters for captures. + for (const HloInstruction* capture : plan.captures) { + int64_t index = instr_mapping.size(); + instr_mapping[capture] = + builder.AddInstruction(HloInstruction::CreateParameter( + index, capture->shape(), absl::StrCat("p", index))); + } + + auto mapped_operands = [&](HloInstruction* instr) { + absl::InlinedVector operands; + operands.reserve(instr->operand_count()); + for (HloInstruction* operand : instr->operands()) { + operands.push_back(instr_mapping.at(operand)); + } + return operands; + }; + + // Clone instructions in topological order. + for (HloInstruction* instr : plan.instructions) { + instr_mapping[instr] = builder.AddInstruction( + instr->CloneWithNewOperands(instr->shape(), mapped_operands(instr))); + } + + // Build root tuple when there are multiple results (DUS and/or passthrough). + if (sliced_results.size() > 1) { + HloInstruction* cloned_hero = instr_mapping.at(hero); + std::vector tuple_operands; + tuple_operands.reserve(sliced_results.size()); + for (const auto& result : sliced_results) { + if (result.update_slice != nullptr) { + tuple_operands.push_back(instr_mapping.at(result.update_slice)); + } else if (!result.noops.empty()) { + tuple_operands.push_back(instr_mapping.at(result.noops.back())); + } else { + // Dead output: create a GTE from the cloned hero to extract this tuple + // element so the fusion output has a buffer slot for it. + tuple_operands.push_back( + builder.AddInstruction(HloInstruction::CreateGetTupleElement( + cloned_hero, result.result_number))); + } + } + builder.AddInstruction(HloInstruction::CreateTuple(tuple_operands)); + } + + return module->AddComputationAndUnifyNamesAndIds(builder.Build(), false); +} + +absl::Status SetDynamicSliceFusionBackendConfig(HloInstruction* fusion) { + GpuBackendConfig gpu_config; + FusionBackendConfig& backend_config = + *gpu_config.mutable_fusion_backend_config(); + backend_config.set_kind("__custom_fusion"); + CustomFusionConfig config; + config.set_name(std::string(kDynamicSliceFusionConfigName)); + *backend_config.mutable_custom_fusion_config() = config; + return fusion->set_backend_config(std::move(gpu_config)); +} + +//===----------------------------------------------------------------------===// +// Rewrite sync hero +//===----------------------------------------------------------------------===// + +absl::StatusOr RewriteHero( + HloModule* module, HloInstruction* hero, + absl::Span sliced_params, + absl::Span sliced_results) { + auto plan = BuildFusionPlan(hero, sliced_params, sliced_results); + if (!plan.has_value()) { + return false; + } + + ASSIGN_OR_RETURN(HloComputation * fusion_body, + CreateFusionBody(module, *plan, sliced_results, hero)); + + HloComputation* parent = hero->parent(); + HloInstruction* fusion = parent->AddInstruction(HloInstruction::CreateFusion( + fusion_body->root_instruction()->shape(), + HloInstruction::FusionKind::kCustom, plan->captures, fusion_body)); + module->SetAndUniquifyInstrName(fusion, "dynamic_slice_fusion"); + RETURN_IF_ERROR(SetDynamicSliceFusionBackendConfig(fusion)); + + if (sliced_results.size() > 1) { + bool any_result_replaced = false; + for (int64_t i = 0; i < sliced_results.size(); ++i) { + auto* gte = parent->AddInstruction( + HloInstruction::CreateGetTupleElement(fusion, i)); + if (sliced_results[i].update_slice != nullptr) { + RETURN_IF_ERROR( + parent->ReplaceInstruction(sliced_results[i].update_slice, gte)); + any_result_replaced = true; + } else if (!sliced_results[i].noops.empty()) { + HloInstruction* original_leaf = sliced_results[i].noops.back(); + RETURN_IF_ERROR(parent->ReplaceInstruction(original_leaf, gte)); + any_result_replaced = true; + } + } + if (!any_result_replaced) { + RETURN_IF_ERROR(parent->ReplaceInstruction(hero, fusion)); + } + } else if (sliced_results.size() == 1) { + if (sliced_results[0].update_slice != nullptr) { + RETURN_IF_ERROR( + parent->ReplaceInstruction(sliced_results[0].update_slice, fusion)); + } else { + RETURN_IF_ERROR(parent->ReplaceInstruction(hero, fusion)); + } + } else { + RETURN_IF_ERROR(parent->ReplaceInstruction(hero, fusion)); + } + + return true; +} + +} // namespace + +//===----------------------------------------------------------------------===// +// RunImpl +//===----------------------------------------------------------------------===// + +absl::StatusOr DynamicSliceFusionRewriterV2::RunImpl( + HloModule* module, + const absl::flat_hash_set& execution_threads) { + bool changed = false; + + std::vector computations = + module->MakeNonfusionComputations(execution_threads); + for (HloComputation* computation : computations) { + std::vector heroes; + for (HloInstruction* candidate : computation->instructions()) { + if (!options_.predicate(candidate)) { + continue; + } + if (candidate->opcode() == HloOpcode::kAsyncStart) { + return absl::InvalidArgumentError(absl::StrCat( + "DynamicSliceFusionRewriterV2 predicate must not match " + "async-start instructions, but matched: ", + candidate->name())); + } + if (HasSupportedShapes(candidate)) { + heroes.push_back(candidate); + } + } + + for (HloInstruction* hero : heroes) { + auto sliced_params = ResolveSlicedParameters(hero, options_.opt_level, + options_.capture_slice); + auto sliced_results = + ResolveSlicedResults(hero, options_.capture_update_slice); + ASSIGN_OR_RETURN( + bool hero_changed, + RewriteHero(module, hero, sliced_params, sliced_results)); + changed |= hero_changed; + } + } + + return changed; +} + +} // namespace xla::gpu diff --git a/third_party/xla/xla/backends/gpu/transforms/dynamic_slice_fusion_rewriter_v2.h b/third_party/xla/xla/backends/gpu/transforms/dynamic_slice_fusion_rewriter_v2.h new file mode 100644 index 00000000000000..141aaa1a7b5a49 --- /dev/null +++ b/third_party/xla/xla/backends/gpu/transforms/dynamic_slice_fusion_rewriter_v2.h @@ -0,0 +1,104 @@ +/* Copyright 2026 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_BACKENDS_GPU_TRANSFORMS_DYNAMIC_SLICE_FUSION_REWRITER_V2_H_ +#define XLA_BACKENDS_GPU_TRANSFORMS_DYNAMIC_SLICE_FUSION_REWRITER_V2_H_ + +#include +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/functional/any_invocable.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/pass/hlo_pass_interface.h" +#include "xla/stream_executor/platform_id.h" + +namespace xla::gpu { + +// Dynamic slice fusion rewriter V2 wraps a hero instruction that reads from +// sliced buffers (via Slice or DynamicSlice) and/or writes into sliced buffers +// (via DynamicUpdateSlice) into a custom fusion. +// +// Requires that the DynamicSliceAnnotator pass has run first to annotate +// DynamicSlice/DynamicUpdateSlice instructions with DynamicSliceConfig. +// +// The pass is configured with options that select hero instructions and decide +// which sliced input/output edges to pull into each fusion body. +// +class DynamicSliceFusionRewriterV2 : public HloModulePass { + public: + // Selects instructions that are legal hero candidates for dynamic-slice + // fusion wrapping. It is evaluated before input/output slicing analysis. + using Predicate = absl::AnyInvocable; + + // Called after the pass has resolved an aligned Slice/DynamicSlice chain + // feeding one hero operand. Return true to move that slice chain into the + // fusion body; return false to leave it outside. For O2 tuple/GTE + // look-through, rejecting an input also prevents the temporary hero operand + // rewrite for that input. + using CaptureSlice = + absl::AnyInvocable; + + // Called after the pass has resolved an aligned DynamicUpdateSlice chain + // consuming one hero result. `result_index` is absent for non-tuple hero + // results and contains the flat tuple result number for tuple hero results. + // Return true to move that update chain into the fusion body; return false + // to leave it outside. + using CaptureUpdateSlice = absl::AnyInvocable result_index, + const HloInstruction* dynamic_update_slice) const>; + + enum class OptLevel { + // Follow a sequence of no-op (bitcast, tuple, gte) operations to find the + // the sliced source + kO1, + // Aggressive optimization that passes through optimization barriers to find + // the sliced source. + kO2, + }; + + struct Options { + Predicate predicate = [](auto...) { return false; }; + CaptureSlice capture_slice = [](auto...) { return true; }; + CaptureUpdateSlice capture_update_slice = [](auto...) { return true; }; + OptLevel opt_level = OptLevel::kO1; + }; + + DynamicSliceFusionRewriterV2(stream_executor::PlatformId platform_id, + Options options) + : platform_id_(std::move(platform_id)), options_(std::move(options)) {} + + absl::string_view name() const override { + return "dynamic-slice-fusion-rewriter-v2"; + } + + protected: + absl::StatusOr RunImpl( + HloModule* module, + const absl::flat_hash_set& execution_threads) override; + + private: + stream_executor::PlatformId platform_id_; + Options options_; +}; + +} // namespace xla::gpu + +#endif // XLA_BACKENDS_GPU_TRANSFORMS_DYNAMIC_SLICE_FUSION_REWRITER_V2_H_ diff --git a/third_party/xla/xla/backends/gpu/transforms/dynamic_slice_fusion_rewriter_v2_test.cc b/third_party/xla/xla/backends/gpu/transforms/dynamic_slice_fusion_rewriter_v2_test.cc new file mode 100644 index 00000000000000..2f2cd24ee5b59b --- /dev/null +++ b/third_party/xla/xla/backends/gpu/transforms/dynamic_slice_fusion_rewriter_v2_test.cc @@ -0,0 +1,1940 @@ +/* Copyright 2026 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/backends/gpu/transforms/dynamic_slice_fusion_rewriter_v2.h" + +#include +#include +#include +#include + +#include +#include +#include "absl/log/check.h" +#include "absl/strings/match.h" +#include "absl/strings/string_view.h" +#include "xla/backends/gpu/transforms/dynamic_slice_annotator.h" +#include "xla/backends/gpu/transforms/dynamic_slice_fusion.h" +#include "xla/hlo/ir/hlo_casting_utils.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/pass/hlo_pass_pipeline.h" +#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" +#include "xla/service/gpu/backend_configs.pb.h" +#include "xla/service/platform_util.h" +#include "xla/shape_util.h" +#include "xla/stream_executor/platform_id.h" // IWYU pragma: keep +#include "xla/xla_data.pb.h" + +namespace xla::gpu { +namespace { + +using ::testing::ElementsAre; +using ::testing::IsEmpty; + +using CstOff = DynamicSliceFusion::ConstantOffset; +using RtOff = DynamicSliceFusion::RuntimeOffset; +using Param = DynamicSliceFusion::Parameter; +using Result = DynamicSliceFusion::Result; + +static constexpr absl::string_view kFakeTarget = "fake_target"; + +DynamicSliceConfig MakeConfig(int64_t loop_index, int64_t offset, + int64_t stride) { + DynamicSliceConfig config; + config.set_loop_index(loop_index); + config.set_byte_offset(offset); + config.set_byte_stride(stride); + return config; +} + +DynamicSliceConfig MakeStaticConfig(int64_t offset) { + DynamicSliceConfig config; + config.set_byte_offset(offset); + config.set_byte_stride(0); + return config; +} + +const HloComputation* FindDsfBody(HloModule* module) { + for (HloComputation* comp : module->computations()) { + if (absl::StrContains(comp->name(), "dynamic-slice-fusion")) { + return comp; + } + } + return nullptr; +} + +class DynamicSliceFusionRewriterV2Test : public HloHardwareIndependentTestBase { + void SetUp() override { + auto maybe_name = PlatformUtil::CanonicalPlatformName("gpu"); + CHECK_OK(maybe_name); + auto maybe_platform_id = + PlatformUtil::GetPlatformIdFromCanonicalName(maybe_name.value()); + CHECK_OK(maybe_platform_id); + platform_id_ = maybe_platform_id.value(); + } + + protected: + stream_executor::PlatformId platform_id() const { return platform_id_; } + + using Options = DynamicSliceFusionRewriterV2::Options; + using OptLevel = DynamicSliceFusionRewriterV2::OptLevel; + + static Options DefaultOptions(OptLevel opt_level = OptLevel::kO1) { + Options options; + options.predicate = [](const HloInstruction* instr) { + auto* custom_call = DynCast(instr); + return custom_call != nullptr && + custom_call->custom_call_target() == kFakeTarget; + }; + options.opt_level = opt_level; + return options; + } + + HloPassPipeline MakePipeline(OptLevel opt_level = OptLevel::kO1) { + return MakePipeline(DefaultOptions(opt_level)); + } + + HloPassPipeline MakePipeline(Options options) { + HloPassPipeline pipeline("test-pipeline"); + pipeline.AddPass(); + pipeline.AddPass(platform_id(), + std::move(options)); + return pipeline; + } + + private: + stream_executor::PlatformId platform_id_; +}; + +//===----------------------------------------------------------------------===// +// Sliced operand tests +//===----------------------------------------------------------------------===// + +TEST_F(DynamicSliceFusionRewriterV2Test, SlicedOperands) { + const char* hlo = R"( + HloModule test + + ENTRY main { + %p0 = f32[2,8,8]{2,1,0} parameter(0) + %p1 = f32[2,8,8]{2,1,0} parameter(1) + %slice0 = f32[1,8,8]{2,1,0} slice(%p0), slice={[1:2], [0:8], [0:8]} + %bitcast0 = f32[8,8]{1,0} bitcast(%slice0) + %slice1 = f32[1,8,8]{2,1,0} slice(%p1), slice={[1:2], [0:8], [0:8]} + %bitcast1 = f32[8,8]{1,0} bitcast(%slice1) + ROOT %hero = f32[8,8]{1,0} custom-call(%bitcast0, %bitcast1), + custom_call_target="fake_target" + } + )"; + + const char* expected = R"( + ; CHECK: %dynamic-slice-fusion{{.*}} { + ; CHECK-DAG: {{.*}} f32[2,8,8]{2,1,0} parameter(0) + ; CHECK-DAG: {{.*}} f32[2,8,8]{2,1,0} parameter(1) + ; CHECK-DAG: {{.*}} f32[1,8,8]{2,1,0} slice( + ; CHECK-DAG: {{.*}} f32[8,8]{1,0} bitcast( + ; CHECK-DAG: {{.*}} f32[1,8,8]{2,1,0} slice( + ; CHECK-DAG: {{.*}} f32[8,8]{1,0} bitcast( + ; CHECK: ROOT {{.*}} custom-call( + ; CHECK: custom_call_target="fake_target" + ; CHECK: } + ; CHECK: ENTRY %main{{.*}} { + ; CHECK: ROOT {{.*}} fusion(%p0, %p1), + ; CHECK: kind=kCustom + ; CHECK: "name":"dynamic_slice_fusion" + ; CHECK: } + )"; + + auto f32_288 = ShapeUtil::MakeShape(F32, {2, 8, 8}); + auto f32_188 = ShapeUtil::MakeShape(F32, {1, 8, 8}); + auto f32_88 = ShapeUtil::MakeShape(F32, {8, 8}); + + auto fusion_checks = [&](HloModule* module) { + auto* hero = DynamicSliceFusion::FindHero(FindDsfBody(module)); + + ASSERT_OK_AND_ASSIGN(auto params, + DynamicSliceFusion::ResolveParameters(hero)); + EXPECT_THAT( + params, + ElementsAre( + Param{0, f32_288, f32_188, MakeStaticConfig(256), std::nullopt}, + Param{1, f32_288, f32_188, MakeStaticConfig(256), std::nullopt})); + + ASSERT_OK_AND_ASSIGN(auto results, + DynamicSliceFusion::ResolveResults(hero)); + EXPECT_THAT(results, ElementsAre(Result{std::nullopt, 0, f32_88, f32_88})); + }; + + RunAndFilecheckHloRewrite(hlo, MakePipeline(), expected, fusion_checks); +} + +TEST_F(DynamicSliceFusionRewriterV2Test, SlicedOperandsDuplicateSlice) { + const char* hlo = R"( + HloModule test + + ENTRY main { + %p0 = f32[2,8,8]{2,1,0} parameter(0) + %slice0 = f32[1,8,8]{2,1,0} slice(%p0), slice={[1:2], [0:8], [0:8]} + %bitcast0 = f32[8,8]{1,0} bitcast(%slice0) + %slice1 = f32[1,8,8]{2,1,0} slice(%p0), slice={[0:1], [0:8], [0:8]} + %bitcast1 = f32[8,8]{1,0} bitcast(%slice1) + ROOT %hero = f32[8,8]{1,0} custom-call(%bitcast0, %bitcast1), + custom_call_target="fake_target" + } + )"; + + const char* expected = R"( + ; CHECK: %dynamic-slice-fusion{{.*}} { + ; CHECK: [[P0:%[^ ]+]] = f32[2,8,8]{2,1,0} parameter(0) + ; CHECK: {{.*}} = f32[1,8,8]{2,1,0} slice([[P0]]) + ; CHECK-SAME: slice={[1:2], [0:8], [0:8]} + ; CHECK: {{.*}} = f32[8,8]{1,0} bitcast( + ; CHECK: {{.*}} = f32[1,8,8]{2,1,0} slice([[P0]]) + ; CHECK-SAME: slice={[0:1], [0:8], [0:8]} + ; CHECK: {{.*}} = f32[8,8]{1,0} bitcast( + ; CHECK: ROOT {{.*}} = f32[8,8]{1,0} custom-call( + ; CHECK: custom_call_target="fake_target" + ; CHECK: } + ; CHECK: ENTRY %main{{.*}} { + ; CHECK: ROOT {{.*}} = f32[8,8]{1,0} fusion(%p0), + ; CHECK: kind=kCustom + ; CHECK: } + )"; + + auto f32_288 = ShapeUtil::MakeShape(F32, {2, 8, 8}); + auto f32_188 = ShapeUtil::MakeShape(F32, {1, 8, 8}); + auto f32_88 = ShapeUtil::MakeShape(F32, {8, 8}); + + auto fusion_checks = [&](HloModule* module) { + auto* hero = DynamicSliceFusion::FindHero(FindDsfBody(module)); + + ASSERT_OK_AND_ASSIGN(auto params, + DynamicSliceFusion::ResolveParameters(hero)); + EXPECT_THAT( + params, + ElementsAre( + Param{0, f32_288, f32_188, MakeStaticConfig(256), std::nullopt}, + Param{0, f32_288, f32_188, MakeStaticConfig(0), std::nullopt})); + + ASSERT_OK_AND_ASSIGN(auto results, + DynamicSliceFusion::ResolveResults(hero)); + EXPECT_THAT(results, ElementsAre(Result{std::nullopt, 0, f32_88, f32_88})); + }; + + RunAndFilecheckHloRewrite(hlo, MakePipeline(), expected, fusion_checks); +} + +TEST_F(DynamicSliceFusionRewriterV2Test, NotContiguousSliceNotFused) { + const char* hlo = R"( + HloModule test + + ENTRY main { + %p0 = f32[8,8]{1,0} parameter(0) + %slice0 = f32[4,4]{1,0} slice(%p0), slice={[0:4], [0:4]} + ROOT %hero = f32[4,4]{1,0} custom-call(%slice0), + custom_call_target="fake_target" + } + )"; + + RunAndFilecheckHloRewrite(hlo, MakePipeline(), std::nullopt); +} + +TEST_F(DynamicSliceFusionRewriterV2Test, NonNoOpInSliceChainNotFused) { + const char* hlo = R"( + HloModule test + + ENTRY main { + %p0 = f32[2,8,8]{2,1,0} parameter(0) + %slice0 = f32[1,8,8]{2,1,0} slice(%p0), slice={[1:2], [0:8], [0:8]} + %negate = f32[1,8,8]{2,1,0} negate(%slice0) + %bitcast0 = f32[8,8]{1,0} bitcast(%negate) + ROOT %hero = f32[8,8]{1,0} custom-call(%bitcast0), + custom_call_target="fake_target" + } + )"; + + RunAndFilecheckHloRewrite(hlo, MakePipeline(), std::nullopt); +} + +TEST_F(DynamicSliceFusionRewriterV2Test, DSWithConstantOffset) { + const char* hlo = R"( + HloModule test + + ENTRY main { + %p0 = f32[4,8,8]{2,1,0} parameter(0) + %c0 = s32[] constant(0) + %c1 = s32[] constant(1) + %ds = f32[1,8,8]{2,1,0} dynamic-slice(%p0, %c1, %c0, %c0), + dynamic_slice_sizes={1,8,8} + %bitcast = f32[8,8]{1,0} bitcast(%ds) + ROOT %hero = f32[8,8]{1,0} custom-call(%bitcast), + custom_call_target="fake_target" + } + )"; + + const char* expected = R"( + ; CHECK: %dynamic-slice-fusion{{.*}} { + ; CHECK: {{.*}} dynamic-slice( + ; CHECK: {{.*}} bitcast( + ; CHECK: ROOT {{.*}} custom-call( + ; CHECK: custom_call_target="fake_target" + ; CHECK: } + ; CHECK: ENTRY %main{{.*}} { + ; CHECK: ROOT {{.*}} fusion(%p0), + ; CHECK: kind=kCustom + ; CHECK: "name":"dynamic_slice_fusion" + ; CHECK: } + )"; + + auto f32_488 = ShapeUtil::MakeShape(F32, {4, 8, 8}); + auto f32_188 = ShapeUtil::MakeShape(F32, {1, 8, 8}); + auto f32_88 = ShapeUtil::MakeShape(F32, {8, 8}); + + auto fusion_checks = [&](HloModule* module) { + auto* hero = DynamicSliceFusion::FindHero(FindDsfBody(module)); + + ASSERT_OK_AND_ASSIGN(auto params, + DynamicSliceFusion::ResolveParameters(hero)); + std::vector offsets = { + CstOff{0, 0}, CstOff{0, 1}, CstOff{0, 2}}; + EXPECT_THAT(params, ElementsAre(Param{0, f32_488, f32_188, + MakeStaticConfig(256), offsets})); + + ASSERT_OK_AND_ASSIGN(auto results, + DynamicSliceFusion::ResolveResults(hero)); + EXPECT_THAT(results, ElementsAre(Result{std::nullopt, 0, f32_88, f32_88})); + }; + + RunAndFilecheckHloRewrite(hlo, MakePipeline(), expected, fusion_checks); +} + +TEST_F(DynamicSliceFusionRewriterV2Test, SlicedOperandWithTupleResult) { + const char* hlo = R"( + HloModule test + + ENTRY main { + %p0 = f32[2,8,8]{2,1,0} parameter(0) + %p1 = f32[8,8]{1,0} parameter(1) + %slice0 = f32[1,8,8]{2,1,0} slice(%p0), slice={[1:2], [0:8], [0:8]} + %bitcast0 = f32[8,8]{1,0} bitcast(%slice0) + ROOT %hero = (f32[8,8]{1,0}, f32[8,8]{1,0}) custom-call(%bitcast0, %p1), + custom_call_target="fake_target" + } + )"; + + const char* expected = R"( + ; CHECK: %dynamic-slice-fusion{{.*}} { + ; CHECK: {{.*}} = f32[2,8,8]{2,1,0} parameter({{.*}}) + ; CHECK: {{.*}} = f32[1,8,8]{2,1,0} slice({{.*}}) + ; CHECK-SAME: slice={[1:2], [0:8], [0:8]} + ; CHECK: {{.*}} = f32[8,8]{1,0} bitcast( + ; CHECK: {{.*}} = f32[8,8]{1,0} parameter({{.*}}) + ; CHECK: {{.*}} custom-call( + ; CHECK: custom_call_target="fake_target" + ; CHECK: {{.*}} get-tuple-element( + ; CHECK: {{.*}} get-tuple-element( + ; CHECK: ROOT {{.*}} tuple( + ; CHECK: } + ; CHECK: ENTRY %main{{.*}} { + ; CHECK: ROOT {{.*}} fusion(%p0, %p1), + ; CHECK: kind=kCustom + ; CHECK: } + )"; + + auto f32_288 = ShapeUtil::MakeShape(F32, {2, 8, 8}); + auto f32_188 = ShapeUtil::MakeShape(F32, {1, 8, 8}); + auto f32_88 = ShapeUtil::MakeShape(F32, {8, 8}); + + auto fusion_checks = [&](HloModule* module) { + auto* hero = DynamicSliceFusion::FindHero(FindDsfBody(module)); + + ASSERT_OK_AND_ASSIGN(auto params, + DynamicSliceFusion::ResolveParameters(hero)); + EXPECT_THAT( + params, + ElementsAre( + Param{0, f32_288, f32_188, MakeStaticConfig(256), std::nullopt}, + Param{1, f32_88, f32_88, std::nullopt, std::nullopt})); + + ASSERT_OK_AND_ASSIGN(auto results, + DynamicSliceFusion::ResolveResults(hero)); + EXPECT_THAT(results, ElementsAre(Result{std::nullopt, 0, f32_88, f32_88, + std::nullopt, std::nullopt}, + Result{std::nullopt, 1, f32_88, f32_88, + std::nullopt, std::nullopt})); + }; + + RunAndFilecheckHloRewrite(hlo, MakePipeline(), expected, fusion_checks); +} + +//===----------------------------------------------------------------------===// +// Dynamic slice operand tests +//===----------------------------------------------------------------------===// + +TEST_F(DynamicSliceFusionRewriterV2Test, DynamicSlicedOperands) { + const char* hlo = R"( + HloModule test + + body { + p0 = (s32[], f32[4,8,8], f32[8,8]) parameter(0) + ivar = s32[] get-tuple-element(p0), index=0 + input = f32[4,8,8] get-tuple-element(p0), index=1 + accum = f32[8,8] get-tuple-element(p0), index=2 + c0 = s32[] constant(0) + ds = f32[1,8,8] dynamic-slice(input, ivar, c0, c0), + dynamic_slice_sizes={1,8,8} + bitcast = f32[8,8] bitcast(ds) + hero = f32[8,8] custom-call(bitcast, accum), + custom_call_target="fake_target" + c1 = s32[] constant(1) + next_ivar = s32[] add(ivar, c1) + ROOT result = (s32[], f32[4,8,8], f32[8,8]) tuple(next_ivar, input, hero) + } + + condition { + p0 = (s32[], f32[4,8,8], f32[8,8]) parameter(0) + ivar = s32[] get-tuple-element(p0), index=0 + c4 = s32[] constant(4) + ROOT cmp = pred[] compare(ivar, c4), direction=LT + } + + ENTRY main { + input = f32[4,8,8] parameter(0) + accum = f32[8,8] parameter(1) + c0 = s32[] constant(0) + tuple = (s32[], f32[4,8,8], f32[8,8]) tuple(c0, input, accum) + ROOT while = (s32[], f32[4,8,8], f32[8,8]) while(tuple), + condition=condition, body=body, + backend_config={"known_trip_count":{"n":"4"}, + "known_init_step":{"init":"0","step":"1"}, + "known_induction_variable":{"tuple_index":"0"}} + } + )"; + + const char* expected = R"( + ; CHECK: %dynamic-slice-fusion{{.*}} { + ; CHECK: {{.*}} dynamic-slice( + ; CHECK: {{.*}} bitcast( + ; CHECK: ROOT {{.*}} custom-call( + ; CHECK: custom_call_target="fake_target" + ; CHECK: } + ; CHECK: body + ; CHECK: {{.*}} fusion( + ; CHECK: kind=kCustom + ; CHECK: "name":"dynamic_slice_fusion" + ; CHECK: } + )"; + + // Captures: input(p0), ivar(p1), accum(p2). Constants sunk into fusion. + auto f32_488 = ShapeUtil::MakeShape(F32, {4, 8, 8}); + auto f32_188 = ShapeUtil::MakeShape(F32, {1, 8, 8}); + auto f32_88 = ShapeUtil::MakeShape(F32, {8, 8}); + std::vector offsets = {RtOff{1, 0}, CstOff{0, 1}, + CstOff{0, 2}}; + + auto fusion_checks = [&](HloModule* module) { + auto* hero = DynamicSliceFusion::FindHero(FindDsfBody(module)); + + ASSERT_OK_AND_ASSIGN(auto params, + DynamicSliceFusion::ResolveParameters(hero)); + EXPECT_THAT( + params, + ElementsAre(Param{0, f32_488, f32_188, MakeConfig(0, 0, 256), offsets}, + Param{2, f32_88, f32_88, std::nullopt, std::nullopt})); + }; + + RunAndFilecheckHloRewrite(hlo, MakePipeline(), expected, fusion_checks); +} + +//===----------------------------------------------------------------------===// +// Dynamic update slice result tests +//===----------------------------------------------------------------------===// + +TEST_F(DynamicSliceFusionRewriterV2Test, DynamicUpdateSliceResult) { + const char* hlo = R"( + HloModule test + + body { + p0 = (s32[], f32[4,8,8], f32[4,8,8]) parameter(0) + ivar = s32[] get-tuple-element(p0), index=0 + input = f32[4,8,8] get-tuple-element(p0), index=1 + output = f32[4,8,8] get-tuple-element(p0), index=2 + c0 = s32[] constant(0) + ds = f32[1,8,8] dynamic-slice(input, ivar, c0, c0), + dynamic_slice_sizes={1,8,8} + bitcast_in = f32[8,8] bitcast(ds) + hero = f32[8,8] custom-call(bitcast_in), + custom_call_target="fake_target" + bitcast_out = f32[1,8,8] bitcast(hero) + dus = f32[4,8,8] dynamic-update-slice(output, bitcast_out, ivar, c0, c0) + c1 = s32[] constant(1) + next_ivar = s32[] add(ivar, c1) + ROOT result = (s32[], f32[4,8,8], f32[4,8,8]) tuple(next_ivar, input, dus) + } + + condition { + p0 = (s32[], f32[4,8,8], f32[4,8,8]) parameter(0) + ivar = s32[] get-tuple-element(p0), index=0 + c4 = s32[] constant(4) + ROOT cmp = pred[] compare(ivar, c4), direction=LT + } + + ENTRY main { + input = f32[4,8,8] parameter(0) + output = f32[4,8,8] parameter(1) + c0 = s32[] constant(0) + tuple = (s32[], f32[4,8,8], f32[4,8,8]) tuple(c0, input, output) + ROOT while = (s32[], f32[4,8,8], f32[4,8,8]) while(tuple), + condition=condition, body=body, + backend_config={"known_trip_count":{"n":"4"}, + "known_init_step":{"init":"0","step":"1"}, + "known_induction_variable":{"tuple_index":"0"}} + } + )"; + + const char* expected = R"( + ; CHECK: %dynamic-slice-fusion{{.*}} { + ; CHECK: {{.*}} dynamic-slice( + ; CHECK: {{.*}} bitcast( + ; CHECK: {{.*}} custom-call( + ; CHECK: custom_call_target="fake_target" + ; CHECK: {{.*}} bitcast( + ; CHECK: ROOT {{.*}} dynamic-update-slice( + ; CHECK: } + ; CHECK: body + ; CHECK: {{.*}} fusion( + ; CHECK: kind=kCustom + ; CHECK: "name":"dynamic_slice_fusion" + ; CHECK: } + )"; + + // Captures: input(p0), ivar(p1), output(p2). Constants sunk into fusion. + auto f32_488 = ShapeUtil::MakeShape(F32, {4, 8, 8}); + auto f32_88 = ShapeUtil::MakeShape(F32, {8, 8}); + auto f32_188 = ShapeUtil::MakeShape(F32, {1, 8, 8}); + std::vector offsets = {RtOff{1, 0}, CstOff{0, 1}, + CstOff{0, 2}}; + + auto fusion_checks = [&](HloModule* module) { + auto* hero = DynamicSliceFusion::FindHero(FindDsfBody(module)); + + ASSERT_OK_AND_ASSIGN(auto params, + DynamicSliceFusion::ResolveParameters(hero)); + EXPECT_THAT(params, ElementsAre(Param{0, f32_488, f32_188, + MakeConfig(0, 0, 256), offsets})); + + ASSERT_OK_AND_ASSIGN(auto results, + DynamicSliceFusion::ResolveResults(hero)); + EXPECT_THAT(results, ElementsAre(Result{2, 0, f32_488, f32_188, + MakeConfig(0, 0, 256), offsets})); + }; + + RunAndFilecheckHloRewrite(hlo, MakePipeline(), expected, fusion_checks); +} + +TEST_F(DynamicSliceFusionRewriterV2Test, DUSOnlyNoSlicedInput) { + // Hero takes an unsliced input and writes output via bitcast → DUS into a + // stacked buffer. This mirrors the real-world pattern where a transpose + // fusion produces a gradient slice that is DUS-ed into a stacked buffer. + const char* hlo = R"( + HloModule test + + body { + p0 = (s32[], f32[8,8], f32[4,8,8]) parameter(0) + ivar = s32[] get-tuple-element(p0), index=0 + input = f32[8,8] get-tuple-element(p0), index=1 + output = f32[4,8,8] get-tuple-element(p0), index=2 + c0 = s32[] constant(0) + hero = f32[8,8] custom-call(input), + custom_call_target="fake_target" + bitcast_out = f32[1,8,8] bitcast(hero) + dus = f32[4,8,8] dynamic-update-slice(output, bitcast_out, ivar, c0, c0) + c1 = s32[] constant(1) + next_ivar = s32[] add(ivar, c1) + ROOT result = (s32[], f32[8,8], f32[4,8,8]) tuple(next_ivar, input, dus) + } + + condition { + p0 = (s32[], f32[8,8], f32[4,8,8]) parameter(0) + ivar = s32[] get-tuple-element(p0), index=0 + c4 = s32[] constant(4) + ROOT cmp = pred[] compare(ivar, c4), direction=LT + } + + ENTRY main { + input = f32[8,8] parameter(0) + output = f32[4,8,8] parameter(1) + c0 = s32[] constant(0) + tuple = (s32[], f32[8,8], f32[4,8,8]) tuple(c0, input, output) + ROOT while = (s32[], f32[8,8], f32[4,8,8]) while(tuple), + condition=condition, body=body, + backend_config={"known_trip_count":{"n":"4"}, + "known_init_step":{"init":"0","step":"1"}, + "known_induction_variable":{"tuple_index":"0"}} + } + )"; + + const char* expected = R"( + ; CHECK: %dynamic-slice-fusion{{.*}} { + ; CHECK: {{.*}} custom-call( + ; CHECK: custom_call_target="fake_target" + ; CHECK: {{.*}} bitcast( + ; CHECK: ROOT {{.*}} dynamic-update-slice( + ; CHECK: } + ; CHECK: body + ; CHECK: {{.*}} fusion( + ; CHECK: kind=kCustom + ; CHECK: "name":"dynamic_slice_fusion" + ; CHECK: } + )"; + + // Captures: input(p0), output(p1), ivar(p2). Constants sunk into fusion. + auto f32_488 = ShapeUtil::MakeShape(F32, {4, 8, 8}); + auto f32_188 = ShapeUtil::MakeShape(F32, {1, 8, 8}); + auto f32_88 = ShapeUtil::MakeShape(F32, {8, 8}); + std::vector offsets = {RtOff{2, 0}, CstOff{0, 1}, + CstOff{0, 2}}; + + auto fusion_checks = [&](HloModule* module) { + auto* hero = DynamicSliceFusion::FindHero(FindDsfBody(module)); + + ASSERT_OK_AND_ASSIGN(auto params, + DynamicSliceFusion::ResolveParameters(hero)); + EXPECT_THAT(params, ElementsAre(Param{0, f32_88, f32_88, std::nullopt, + std::nullopt})); + + ASSERT_OK_AND_ASSIGN(auto results, + DynamicSliceFusion::ResolveResults(hero)); + EXPECT_THAT(results, ElementsAre(Result{1, 0, f32_488, f32_188, + MakeConfig(0, 0, 256), offsets})); + }; + + RunAndFilecheckHloRewrite(hlo, MakePipeline(), expected, fusion_checks); +} + +TEST_F(DynamicSliceFusionRewriterV2Test, DUSWithConstantOffset) { + const char* hlo = R"( + HloModule test + + ENTRY main { + %p0 = f32[8,8]{1,0} parameter(0) + %p1 = f32[4,8,8]{2,1,0} parameter(1) + %hero = f32[8,8]{1,0} custom-call(%p0), + custom_call_target="fake_target" + %bitcast = f32[1,8,8]{2,1,0} bitcast(%hero) + %c0 = s32[] constant(0) + ROOT %dus = f32[4,8,8]{2,1,0} dynamic-update-slice( + %p1, %bitcast, %c0, %c0, %c0) + } + )"; + + const char* expected = R"( + ; CHECK: %dynamic-slice-fusion{{.*}} { + ; CHECK: {{.*}} custom-call( + ; CHECK: custom_call_target="fake_target" + ; CHECK: {{.*}} bitcast( + ; CHECK: ROOT {{.*}} dynamic-update-slice( + ; CHECK: } + ; CHECK: ENTRY %main{{.*}} { + ; CHECK: ROOT {{.*}} fusion(%p0, %p1), + ; CHECK: kind=kCustom + ; CHECK: "name":"dynamic_slice_fusion" + ; CHECK: } + )"; + + auto f32_488 = ShapeUtil::MakeShape(F32, {4, 8, 8}); + auto f32_188 = ShapeUtil::MakeShape(F32, {1, 8, 8}); + auto f32_88 = ShapeUtil::MakeShape(F32, {8, 8}); + + auto fusion_checks = [&](HloModule* module) { + auto* hero = DynamicSliceFusion::FindHero(FindDsfBody(module)); + + ASSERT_OK_AND_ASSIGN(auto params, + DynamicSliceFusion::ResolveParameters(hero)); + EXPECT_THAT(params, ElementsAre(Param{0, f32_88, f32_88, std::nullopt, + std::nullopt})); + + ASSERT_OK_AND_ASSIGN(auto results, + DynamicSliceFusion::ResolveResults(hero)); + std::vector offsets = { + CstOff{0, 0}, CstOff{0, 1}, CstOff{0, 2}}; + EXPECT_THAT(results, ElementsAre(Result{1, 0, f32_488, f32_188, + MakeStaticConfig(0), offsets})); + }; + + RunAndFilecheckHloRewrite(hlo, MakePipeline(), expected, fusion_checks); +} + +TEST_F(DynamicSliceFusionRewriterV2Test, DUSNotRootNotFused) { + const char* hlo = R"( + HloModule test + + body { + p0 = (s32[], f32[4,8,8], f32[4,8,8]) parameter(0) + ivar = s32[] get-tuple-element(p0), index=0 + input = f32[4,8,8] get-tuple-element(p0), index=1 + output = f32[4,8,8] get-tuple-element(p0), index=2 + c0 = s32[] constant(0) + ds = f32[1,8,8] dynamic-slice(input, ivar, c0, c0), + dynamic_slice_sizes={1,8,8} + bitcast_in = f32[8,8] bitcast(ds) + hero = f32[8,8] custom-call(bitcast_in), + custom_call_target="fake_target" + bitcast_out = f32[1,8,8] bitcast(hero) + dus = f32[4,8,8] dynamic-update-slice(output, bitcast_out, ivar, c0, c0) + negate = f32[4,8,8] negate(dus) + c1 = s32[] constant(1) + next_ivar = s32[] add(ivar, c1) + ROOT result = (s32[], f32[4,8,8], f32[4,8,8]) tuple( + next_ivar, input, negate) + } + + condition { + p0 = (s32[], f32[4,8,8], f32[4,8,8]) parameter(0) + ivar = s32[] get-tuple-element(p0), index=0 + c4 = s32[] constant(4) + ROOT cmp = pred[] compare(ivar, c4), direction=LT + } + + ENTRY main { + input = f32[4,8,8] parameter(0) + output = f32[4,8,8] parameter(1) + c0 = s32[] constant(0) + tuple = (s32[], f32[4,8,8], f32[4,8,8]) tuple(c0, input, output) + ROOT while = (s32[], f32[4,8,8], f32[4,8,8]) while(tuple), + condition=condition, body=body, + backend_config={"known_trip_count":{"n":"4"}, + "known_init_step":{"init":"0","step":"1"}, + "known_induction_variable":{"tuple_index":"0"}} + } + )"; + + const char* expected = R"( + ; CHECK: %dynamic-slice-fusion{{.*}} { + ; CHECK: {{.*}} dynamic-slice( + ; CHECK: {{.*}} bitcast( + ; CHECK: ROOT {{.*}} custom-call( + ; CHECK: custom_call_target="fake_target" + ; CHECK: } + ; CHECK: body + ; CHECK: {{.*}} fusion( + ; CHECK: kind=kCustom + ; CHECK: "name":"dynamic_slice_fusion" + ; CHECK: } + )"; + + // Only DS fused (DUS not root). Captures: input(p0), ivar(p1). + auto f32_488 = ShapeUtil::MakeShape(F32, {4, 8, 8}); + auto f32_188 = ShapeUtil::MakeShape(F32, {1, 8, 8}); + auto f32_88 = ShapeUtil::MakeShape(F32, {8, 8}); + std::vector offsets = {RtOff{1, 0}, CstOff{0, 1}, + CstOff{0, 2}}; + + auto fusion_checks = [&](HloModule* module) { + auto* hero = DynamicSliceFusion::FindHero(FindDsfBody(module)); + + ASSERT_OK_AND_ASSIGN(auto params, + DynamicSliceFusion::ResolveParameters(hero)); + EXPECT_THAT(params, ElementsAre(Param{0, f32_488, f32_188, + MakeConfig(0, 0, 256), offsets})); + + ASSERT_OK_AND_ASSIGN(auto results, + DynamicSliceFusion::ResolveResults(hero)); + EXPECT_THAT(results, ElementsAre(Result{std::nullopt, 0, f32_88, f32_88})); + }; + + RunAndFilecheckHloRewrite(hlo, MakePipeline(), expected, fusion_checks); +} + +TEST_F(DynamicSliceFusionRewriterV2Test, MixedSlicedAndUnslicedOperands) { + const char* hlo = R"( + HloModule test + + ENTRY main { + %p0 = f32[2,8,8]{2,1,0} parameter(0) + %p1 = f32[8,8]{1,0} parameter(1) + %slice0 = f32[1,8,8]{2,1,0} slice(%p0), slice={[1:2], [0:8], [0:8]} + %bitcast0 = f32[8,8]{1,0} bitcast(%slice0) + ROOT %hero = f32[8,8]{1,0} custom-call(%bitcast0, %p1), + custom_call_target="fake_target" + } + )"; + + const char* expected = R"( + ; CHECK: %dynamic-slice-fusion{{.*}} { + ; CHECK: {{.*}} = f32[2,8,8]{2,1,0} parameter({{.*}}) + ; CHECK: {{.*}} = f32[1,8,8]{2,1,0} slice({{.*}}) + ; CHECK-SAME: slice={[1:2], [0:8], [0:8]} + ; CHECK: {{.*}} = f32[8,8]{1,0} bitcast( + ; CHECK: {{.*}} = f32[8,8]{1,0} parameter({{.*}}) + ; CHECK: ROOT {{.*}} = f32[8,8]{1,0} custom-call( + ; CHECK: custom_call_target="fake_target" + ; CHECK: } + ; CHECK: ENTRY %main{{.*}} { + ; CHECK: ROOT {{.*}} fusion(%p0, %p1), + ; CHECK: kind=kCustom + ; CHECK: } + )"; + + auto f32_288 = ShapeUtil::MakeShape(F32, {2, 8, 8}); + auto f32_188 = ShapeUtil::MakeShape(F32, {1, 8, 8}); + auto f32_88 = ShapeUtil::MakeShape(F32, {8, 8}); + + auto fusion_checks = [&](HloModule* module) { + auto* hero = DynamicSliceFusion::FindHero(FindDsfBody(module)); + + ASSERT_OK_AND_ASSIGN(auto params, + DynamicSliceFusion::ResolveParameters(hero)); + EXPECT_THAT( + params, + ElementsAre( + Param{0, f32_288, f32_188, MakeStaticConfig(256), std::nullopt}, + Param{1, f32_88, f32_88, std::nullopt, std::nullopt})); + }; + + RunAndFilecheckHloRewrite(hlo, MakePipeline(), expected, fusion_checks); +} + +//===----------------------------------------------------------------------===// +// Predicate tests +//===----------------------------------------------------------------------===// + +// The next four tests use this HLO as input. +constexpr char kSlicedInputAndDusOutputHlo[] = R"( + HloModule test + + body { + p0 = (s32[], f32[4,8,8], f32[4,8,8]) parameter(0) + ivar = s32[] get-tuple-element(p0), index=0 + input = f32[4,8,8] get-tuple-element(p0), index=1 + output = f32[4,8,8] get-tuple-element(p0), index=2 + c0 = s32[] constant(0) + ds = f32[1,8,8] dynamic-slice(input, ivar, c0, c0), + dynamic_slice_sizes={1,8,8} + bitcast_in = f32[8,8] bitcast(ds) + hero = f32[8,8] custom-call(bitcast_in), + custom_call_target="fake_target" + bitcast_out = f32[1,8,8] bitcast(hero) + dus = f32[4,8,8] dynamic-update-slice(output, bitcast_out, ivar, c0, c0) + c1 = s32[] constant(1) + next_ivar = s32[] add(ivar, c1) + ROOT result = (s32[], f32[4,8,8], f32[4,8,8]) tuple(next_ivar, input, dus) + } + + condition { + p0 = (s32[], f32[4,8,8], f32[4,8,8]) parameter(0) + ivar = s32[] get-tuple-element(p0), index=0 + c4 = s32[] constant(4) + ROOT cmp = pred[] compare(ivar, c4), direction=LT + } + + ENTRY main { + input = f32[4,8,8] parameter(0) + output = f32[4,8,8] parameter(1) + c0 = s32[] constant(0) + tuple = (s32[], f32[4,8,8], f32[4,8,8]) tuple(c0, input, output) + ROOT while = (s32[], f32[4,8,8], f32[4,8,8]) while(tuple), + condition=condition, body=body, + backend_config={"known_trip_count":{"n":"4"}, + "known_init_step":{"init":"0","step":"1"}, + "known_induction_variable":{"tuple_index":"0"}} + } +)"; + +TEST_F(DynamicSliceFusionRewriterV2Test, PredicateMismatchNotFused) { + Options options = DefaultOptions(); + options.predicate = [](const HloInstruction*) { return false; }; + + const char* expected = R"( + ; CHECK: HloModule + ; CHECK-NOT: dynamic_slice_fusion + )"; + + RunAndFilecheckHloRewrite(kSlicedInputAndDusOutputHlo, + MakePipeline(std::move(options)), expected); +} + +TEST_F(DynamicSliceFusionRewriterV2Test, + InputAndOutputPredicatesRejectAllSkipsHero) { + Options options = DefaultOptions(); + options.capture_slice = [](const HloInstruction*, int64_t, + const HloInstruction*) { return false; }; + options.capture_update_slice = [](const HloInstruction*, + std::optional, + const HloInstruction*) { return false; }; + + const char* expected = R"( + ; CHECK: HloModule + ; CHECK-NOT: dynamic_slice_fusion + )"; + + RunAndFilecheckHloRewrite(kSlicedInputAndDusOutputHlo, + MakePipeline(std::move(options)), expected); +} + +TEST_F(DynamicSliceFusionRewriterV2Test, CanCaptureSlicedInputsOnly) { + Options options = DefaultOptions(); + options.capture_update_slice = [](const HloInstruction*, + std::optional, + const HloInstruction*) { return false; }; + + const char* expected = R"( + ; CHECK: %dynamic-slice-fusion{{.*}} { + ; CHECK: {{.*}} dynamic-slice( + ; CHECK: {{.*}} bitcast( + ; CHECK: ROOT {{.*}} custom-call( + ; CHECK-SAME: custom_call_target="fake_target" + ; CHECK-NOT: dynamic-update-slice( + ; CHECK: } + ; CHECK: body + ; CHECK: {{.*}} fusion( + ; CHECK-SAME: kind=kCustom + ; CHECK-SAME: "name":"dynamic_slice_fusion" + ; CHECK: {{.*}} bitcast( + ; CHECK: {{.*}} dynamic-update-slice( + ; CHECK: } + )"; + + auto f32_488 = ShapeUtil::MakeShape(F32, {4, 8, 8}); + auto f32_188 = ShapeUtil::MakeShape(F32, {1, 8, 8}); + auto f32_88 = ShapeUtil::MakeShape(F32, {8, 8}); + std::vector offsets = {RtOff{1, 0}, CstOff{0, 1}, + CstOff{0, 2}}; + + auto fusion_checks = [&](HloModule* module) { + auto* hero = DynamicSliceFusion::FindHero(FindDsfBody(module)); + + ASSERT_OK_AND_ASSIGN(auto params, + DynamicSliceFusion::ResolveParameters(hero)); + EXPECT_THAT(params, ElementsAre(Param{0, f32_488, f32_188, + MakeConfig(0, 0, 256), offsets})); + + ASSERT_OK_AND_ASSIGN(auto results, + DynamicSliceFusion::ResolveResults(hero)); + EXPECT_THAT(results, ElementsAre(Result{std::nullopt, 0, f32_88, f32_88})); + }; + + RunAndFilecheckHloRewrite(kSlicedInputAndDusOutputHlo, + MakePipeline(std::move(options)), expected, + fusion_checks); +} + +TEST_F(DynamicSliceFusionRewriterV2Test, CanCaptureSlicedOutputsOnly) { + Options options = DefaultOptions(); + options.capture_slice = [](const HloInstruction*, int64_t, + const HloInstruction*) { return false; }; + + const char* expected = R"( + ; CHECK: %dynamic-slice-fusion{{.*}} { + ; CHECK-NOT: dynamic-slice( + ; CHECK: {{.*}} custom-call( + ; CHECK-SAME: custom_call_target="fake_target" + ; CHECK: {{.*}} bitcast( + ; CHECK: ROOT {{.*}} dynamic-update-slice( + ; CHECK: } + ; CHECK: body + ; CHECK: {{.*}} dynamic-slice( + ; CHECK: {{.*}} bitcast( + ; CHECK: {{.*}} fusion( + ; CHECK-SAME: kind=kCustom + ; CHECK-SAME: "name":"dynamic_slice_fusion" + ; CHECK: } + )"; + + auto f32_488 = ShapeUtil::MakeShape(F32, {4, 8, 8}); + auto f32_188 = ShapeUtil::MakeShape(F32, {1, 8, 8}); + auto f32_88 = ShapeUtil::MakeShape(F32, {8, 8}); + std::vector offsets = {RtOff{2, 0}, CstOff{0, 1}, + CstOff{0, 2}}; + + auto fusion_checks = [&](HloModule* module) { + auto* hero = DynamicSliceFusion::FindHero(FindDsfBody(module)); + + ASSERT_OK_AND_ASSIGN(auto params, + DynamicSliceFusion::ResolveParameters(hero)); + EXPECT_THAT(params, ElementsAre(Param{0, f32_88, f32_88})); + + ASSERT_OK_AND_ASSIGN(auto results, + DynamicSliceFusion::ResolveResults(hero)); + EXPECT_THAT(results, ElementsAre(Result{1, 0, f32_488, f32_188, + MakeConfig(0, 0, 256), offsets})); + }; + + RunAndFilecheckHloRewrite(kSlicedInputAndDusOutputHlo, + MakePipeline(std::move(options)), expected, + fusion_checks); +} + +//===----------------------------------------------------------------------===// +// Induction variable offset tests +//===----------------------------------------------------------------------===// + +TEST_F(DynamicSliceFusionRewriterV2Test, OffsetAsLinearFunctionOfInductionVar) { + const char* hlo = R"( + HloModule test + + body { + p0 = (s32[], f32[8,8,8]) parameter(0) + ivar = s32[] get-tuple-element(p0), index=0 + input = f32[8,8,8] get-tuple-element(p0), index=1 + c0 = s32[] constant(0) + c2 = s32[] constant(2) + offset = s32[] multiply(ivar, c2) + ds = f32[1,8,8] dynamic-slice(input, offset, c0, c0), + dynamic_slice_sizes={1,8,8} + bitcast = f32[8,8] bitcast(ds) + hero = f32[8,8] custom-call(bitcast), + custom_call_target="fake_target" + c1 = s32[] constant(1) + next_ivar = s32[] add(ivar, c1) + ROOT result = (s32[], f32[8,8,8]) tuple(next_ivar, input) + } + + condition { + p0 = (s32[], f32[8,8,8]) parameter(0) + ivar = s32[] get-tuple-element(p0), index=0 + c4 = s32[] constant(4) + ROOT cmp = pred[] compare(ivar, c4), direction=LT + } + + ENTRY main { + input = f32[8,8,8] parameter(0) + c0 = s32[] constant(0) + tuple = (s32[], f32[8,8,8]) tuple(c0, input) + ROOT while = (s32[], f32[8,8,8]) while(tuple), + condition=condition, body=body, + backend_config={"known_trip_count":{"n":"4"}, + "known_init_step":{"init":"0","step":"1"}, + "known_induction_variable":{"tuple_index":"0"}} + } + )"; + + const char* expected = R"( + ; CHECK: %dynamic-slice-fusion{{.*}} { + ; CHECK: {{.*}} dynamic-slice( + ; CHECK: {{.*}} bitcast( + ; CHECK: ROOT {{.*}} custom-call( + ; CHECK: custom_call_target="fake_target" + ; CHECK: } + ; CHECK: body + ; CHECK: {{.*}} fusion( + ; CHECK: kind=kCustom + ; CHECK: "name":"dynamic_slice_fusion" + ; CHECK: } + )"; + + // offset = ivar*2, byte stride dim0 = 256. stride = 2*256 = 512. + // Captures: input(p0), offset(p1). Constants sunk into fusion. + auto f32_888 = ShapeUtil::MakeShape(F32, {8, 8, 8}); + auto f32_188 = ShapeUtil::MakeShape(F32, {1, 8, 8}); + auto f32_88 = ShapeUtil::MakeShape(F32, {8, 8}); + std::vector offsets = {RtOff{1, 0}, CstOff{0, 1}, + CstOff{0, 2}}; + + auto fusion_checks = [&](HloModule* module) { + auto* hero = DynamicSliceFusion::FindHero(FindDsfBody(module)); + + ASSERT_OK_AND_ASSIGN(auto params, + DynamicSliceFusion::ResolveParameters(hero)); + EXPECT_THAT(params, ElementsAre(Param{0, f32_888, f32_188, + MakeConfig(0, 0, 512), offsets})); + }; + + RunAndFilecheckHloRewrite(hlo, MakePipeline(), expected, fusion_checks); +} + +//===----------------------------------------------------------------------===// +// Non-parameter source tests +//===----------------------------------------------------------------------===// + +TEST_F(DynamicSliceFusionRewriterV2Test, SliceFromNonParameterSource) { + const char* hlo = R"( + HloModule test + + ENTRY main { + %p0 = f32[2,8,8]{2,1,0} parameter(0) + %negate = f32[2,8,8]{2,1,0} negate(%p0) + %slice0 = f32[1,8,8]{2,1,0} slice(%negate), slice={[1:2], [0:8], [0:8]} + %bitcast0 = f32[8,8]{1,0} bitcast(%slice0) + ROOT %hero = f32[8,8]{1,0} custom-call(%bitcast0), + custom_call_target="fake_target" + } + )"; + + const char* expected = R"( + ; CHECK: %dynamic-slice-fusion{{.*}} { + ; CHECK: [[P0:%[^ ]+]] = f32[2,8,8]{2,1,0} parameter(0) + ; CHECK: [[S0:%[^ ]+]] = f32[1,8,8]{2,1,0} slice([[P0]]) + ; CHECK-SAME: slice={[1:2], [0:8], [0:8]} + ; CHECK: [[B0:%[^ ]+]] = f32[8,8]{1,0} bitcast([[S0]]) + ; CHECK: ROOT {{.*}} custom-call([[B0]]), + ; CHECK: custom_call_target="fake_target" + ; CHECK: } + ; CHECK: ENTRY %main{{.*}} { + ; CHECK: %negate = f32[2,8,8]{2,1,0} negate( + ; CHECK: ROOT {{.*}} = f32[8,8]{1,0} fusion(%negate), + ; CHECK: kind=kCustom + ; CHECK: } + )"; + + auto f32_288 = ShapeUtil::MakeShape(F32, {2, 8, 8}); + auto f32_188 = ShapeUtil::MakeShape(F32, {1, 8, 8}); + auto f32_88 = ShapeUtil::MakeShape(F32, {8, 8}); + + auto fusion_checks = [&](HloModule* module) { + auto* hero = DynamicSliceFusion::FindHero(FindDsfBody(module)); + + ASSERT_OK_AND_ASSIGN(auto params, + DynamicSliceFusion::ResolveParameters(hero)); + EXPECT_THAT(params, + ElementsAre(Param{0, f32_288, f32_188, MakeStaticConfig(256), + std::nullopt})); + + ASSERT_OK_AND_ASSIGN(auto results, + DynamicSliceFusion::ResolveResults(hero)); + EXPECT_THAT(results, ElementsAre(Result{std::nullopt, 0, f32_88, f32_88})); + }; + + RunAndFilecheckHloRewrite(hlo, MakePipeline(), expected, fusion_checks); +} + +//===----------------------------------------------------------------------===// +// Multiple heroes tests +//===----------------------------------------------------------------------===// + +TEST_F(DynamicSliceFusionRewriterV2Test, TwoHeroesSameComputation) { + const char* hlo = R"( + HloModule test + + ENTRY main { + %p0 = f32[2,8,8]{2,1,0} parameter(0) + %p1 = f32[8,8]{1,0} parameter(1) + %slice0 = f32[1,8,8]{2,1,0} slice(%p0), slice={[0:1], [0:8], [0:8]} + %bitcast0 = f32[8,8]{1,0} bitcast(%slice0) + %hero0 = f32[8,8]{1,0} custom-call(%bitcast0), + custom_call_target="fake_target" + %slice1 = f32[1,8,8]{2,1,0} slice(%p0), slice={[1:2], [0:8], [0:8]} + %bitcast1 = f32[8,8]{1,0} bitcast(%slice1) + %hero1 = f32[8,8]{1,0} custom-call(%bitcast1, %p1), + custom_call_target="fake_target" + ROOT %tuple = (f32[8,8]{1,0}, f32[8,8]{1,0}) tuple(%hero0, %hero1) + } + )"; + + const char* expected = R"( + ; CHECK: ENTRY %main{{.*}} { + ; CHECK-DAG: {{.*}} fusion(%p0), kind=kCustom + ; CHECK-DAG: {{.*}} fusion(%p0, %p1), kind=kCustom + ; CHECK: } + )"; + + // Two separate fusions; verify both can be analyzed. + auto fusion_checks = [](HloModule* module) { + int dsf_count = 0; + for (HloComputation* comp : module->computations()) { + if (!absl::StrContains(comp->name(), "dynamic-slice-fusion")) { + continue; + } + dsf_count++; + auto* hero = DynamicSliceFusion::FindHero(comp); + ASSERT_OK_AND_ASSIGN(auto params, + DynamicSliceFusion::ResolveParameters(hero)); + EXPECT_FALSE(params.empty()); + } + EXPECT_EQ(dsf_count, 2); + }; + + RunAndFilecheckHloRewrite(hlo, MakePipeline(), expected, fusion_checks); +} + +TEST_F(DynamicSliceFusionRewriterV2Test, NestedTupleResultNotFused) { + const char* hlo = R"( + HloModule test + + ENTRY main { + %p0 = f32[2,8,8]{2,1,0} parameter(0) + %slice0 = f32[1,8,8]{2,1,0} slice(%p0), slice={[1:2], [0:8], [0:8]} + %bitcast0 = f32[8,8]{1,0} bitcast(%slice0) + ROOT %hero = ((f32[8,8]{1,0}, f32[8,8]{1,0}), f32[8,8]{1,0}) + custom-call(%bitcast0), custom_call_target="fake_target" + } + )"; + + RunAndFilecheckHloRewrite(hlo, MakePipeline(), std::nullopt); +} + +//===----------------------------------------------------------------------===// +// Non-standard layout tests +//===----------------------------------------------------------------------===// + +TEST_F(DynamicSliceFusionRewriterV2Test, SlicedOperandColumnMajorLayout) { + // Layout {0,1,2}: column-major, dim0 most minor, dim2 most major. + // Byte strides for f32[8,8,2]{0,1,2}: dim0=4, dim1=32, dim2=256. + // Slice along most-major dim2: [0:8,0:8,1:2] → byte offset = 1*256 = 256. + const char* hlo = R"( + HloModule test + + ENTRY main { + %p0 = f32[8,8,2]{0,1,2} parameter(0) + %slice0 = f32[8,8,1]{0,1,2} slice(%p0), slice={[0:8], [0:8], [1:2]} + %bitcast0 = f32[8,8]{0,1} bitcast(%slice0) + ROOT %hero = f32[8,8]{0,1} custom-call(%bitcast0), + custom_call_target="fake_target" + } + )"; + + const char* expected = R"( + ; CHECK: %dynamic-slice-fusion{{.*}} { + ; CHECK: [[P0:%[^ ]+]] = f32[8,8,2]{0,1,2} parameter(0) + ; CHECK: [[S0:%[^ ]+]] = f32[8,8,1]{0,1,2} slice([[P0]]) + ; CHECK-SAME: slice={[0:8], [0:8], [1:2]} + ; CHECK: [[B0:%[^ ]+]] = f32[8,8]{0,1} bitcast([[S0]]) + ; CHECK: ROOT {{.*}} = f32[8,8]{0,1} custom-call([[B0]]), + ; CHECK: custom_call_target="fake_target" + ; CHECK: } + ; CHECK: ENTRY %main{{.*}} { + ; CHECK: ROOT {{.*}} = f32[8,8]{0,1} fusion(%p0), + ; CHECK: kind=kCustom + ; CHECK: "name":"dynamic_slice_fusion" + ; CHECK: } + )"; + + auto f32_882 = ShapeUtil::MakeShapeWithDenseLayout(F32, {8, 8, 2}, {0, 1, 2}); + auto f32_881 = ShapeUtil::MakeShapeWithDenseLayout(F32, {8, 8, 1}, {0, 1, 2}); + auto f32_88 = ShapeUtil::MakeShapeWithDenseLayout(F32, {8, 8}, {0, 1}); + + auto fusion_checks = [&](HloModule* module) { + auto* hero = DynamicSliceFusion::FindHero(FindDsfBody(module)); + + ASSERT_OK_AND_ASSIGN(auto params, + DynamicSliceFusion::ResolveParameters(hero)); + EXPECT_THAT(params, + ElementsAre(Param{0, f32_882, f32_881, MakeStaticConfig(256), + std::nullopt})); + + ASSERT_OK_AND_ASSIGN(auto results, + DynamicSliceFusion::ResolveResults(hero)); + EXPECT_THAT(results, ElementsAre(Result{std::nullopt, 0, f32_88, f32_88})); + }; + + RunAndFilecheckHloRewrite(hlo, MakePipeline(), expected, fusion_checks); +} + +TEST_F(DynamicSliceFusionRewriterV2Test, DynamicSliceNonStandardLayoutWithDUS) { + // Layout {1,2,0}: minor_to_major=[1,2,0], so dim1 most minor, dim2 next, + // dim0 most major. Byte strides for f32[4,8,8]{1,2,0}: dim1=4, dim2=32, + // dim0=256. Slicing dim0 is contiguous because dim0 is most major. + // DS/DUS along dim0 → config stride = 256. + const char* hlo = R"( + HloModule test + + body { + p0 = (s32[], f32[4,8,8]{1,2,0}, f32[4,8,8]{1,2,0}) parameter(0) + ivar = s32[] get-tuple-element(p0), index=0 + input = f32[4,8,8]{1,2,0} get-tuple-element(p0), index=1 + output = f32[4,8,8]{1,2,0} get-tuple-element(p0), index=2 + c0 = s32[] constant(0) + ds = f32[1,8,8]{1,2,0} dynamic-slice(input, ivar, c0, c0), + dynamic_slice_sizes={1,8,8} + bitcast_in = f32[8,8]{0,1} bitcast(ds) + hero = f32[8,8]{0,1} custom-call(bitcast_in), + custom_call_target="fake_target" + bitcast_out = f32[1,8,8]{1,2,0} bitcast(hero) + dus = f32[4,8,8]{1,2,0} dynamic-update-slice( + output, bitcast_out, ivar, c0, c0) + c1 = s32[] constant(1) + next_ivar = s32[] add(ivar, c1) + ROOT result = (s32[], f32[4,8,8]{1,2,0}, f32[4,8,8]{1,2,0}) + tuple(next_ivar, input, dus) + } + + condition { + p0 = (s32[], f32[4,8,8]{1,2,0}, f32[4,8,8]{1,2,0}) parameter(0) + ivar = s32[] get-tuple-element(p0), index=0 + c4 = s32[] constant(4) + ROOT cmp = pred[] compare(ivar, c4), direction=LT + } + + ENTRY main { + input = f32[4,8,8]{1,2,0} parameter(0) + output = f32[4,8,8]{1,2,0} parameter(1) + c0 = s32[] constant(0) + tuple = (s32[], f32[4,8,8]{1,2,0}, f32[4,8,8]{1,2,0}) + tuple(c0, input, output) + ROOT while = (s32[], f32[4,8,8]{1,2,0}, f32[4,8,8]{1,2,0}) + while(tuple), condition=condition, body=body, + backend_config={"known_trip_count":{"n":"4"}, + "known_init_step":{"init":"0","step":"1"}, + "known_induction_variable":{"tuple_index":"0"}} + } + )"; + + const char* expected = R"( + ; CHECK: %dynamic-slice-fusion{{.*}} { + ; CHECK: {{.*}} dynamic-slice( + ; CHECK: {{.*}} bitcast( + ; CHECK: {{.*}} custom-call( + ; CHECK: custom_call_target="fake_target" + ; CHECK: {{.*}} bitcast( + ; CHECK: ROOT {{.*}} dynamic-update-slice( + ; CHECK: } + ; CHECK: body + ; CHECK: {{.*}} fusion( + ; CHECK: kind=kCustom + ; CHECK: "name":"dynamic_slice_fusion" + ; CHECK: } + )"; + + auto f32_488 = ShapeUtil::MakeShapeWithDenseLayout(F32, {4, 8, 8}, {1, 2, 0}); + auto f32_188 = ShapeUtil::MakeShapeWithDenseLayout(F32, {1, 8, 8}, {1, 2, 0}); + auto f32_88 = ShapeUtil::MakeShapeWithDenseLayout(F32, {8, 8}, {0, 1}); + std::vector offsets = {RtOff{1, 0}, CstOff{0, 1}, + CstOff{0, 2}}; + + auto fusion_checks = [&](HloModule* module) { + auto* hero = DynamicSliceFusion::FindHero(FindDsfBody(module)); + + ASSERT_OK_AND_ASSIGN(auto params, + DynamicSliceFusion::ResolveParameters(hero)); + EXPECT_THAT(params, ElementsAre(Param{0, f32_488, f32_188, + MakeConfig(0, 0, 256), offsets})); + + ASSERT_OK_AND_ASSIGN(auto results, + DynamicSliceFusion::ResolveResults(hero)); + EXPECT_THAT(results, ElementsAre(Result{2, 0, f32_488, f32_188, + MakeConfig(0, 0, 256), offsets})); + }; + + RunAndFilecheckHloRewrite(hlo, MakePipeline(), expected, fusion_checks); +} + +TEST_F(DynamicSliceFusionRewriterV2Test, + NonContiguousSliceNonStandardLayoutNotFused) { + // Layout {1,0,2}: dim1 most minor, dim0 next, dim2 most major. + // Slicing dim0 [1:2,0:8,0:8] is non-contiguous because dim2 (more major + // than sliced dim0) has extent 8 != 1. + const char* hlo = R"( + HloModule test + + ENTRY main { + %p0 = f32[2,8,8]{1,0,2} parameter(0) + %slice0 = f32[1,8,8]{1,0,2} slice(%p0), slice={[1:2], [0:8], [0:8]} + %bitcast0 = f32[8,8]{0,1} bitcast(%slice0) + ROOT %hero = f32[8,8]{0,1} custom-call(%bitcast0), + custom_call_target="fake_target" + } + )"; + + RunAndFilecheckHloRewrite(hlo, MakePipeline(), std::nullopt); +} + +//===----------------------------------------------------------------------===// +// Mixed DUS/passthrough tuple output tests +//===----------------------------------------------------------------------===// + +TEST_F(DynamicSliceFusionRewriterV2Test, TupleOutputOneDUS) { + // Hero produces (f32[8,8], f32[8,8]). Only the first output flows through + // DUS; the second is returned directly (passthrough). The fusion output + // must be a tuple containing both the DUS result and the passthrough GTE. + const char* hlo = R"( + HloModule test + + body { + p0 = (s32[], f32[4,8,8], f32[8,8]) parameter(0) + ivar = s32[] get-tuple-element(p0), index=0 + buf = f32[4,8,8] get-tuple-element(p0), index=1 + prev = f32[8,8] get-tuple-element(p0), index=2 + c0 = s32[] constant(0) + hero = (f32[8,8], f32[8,8]) custom-call(), + custom_call_target="fake_target" + gte0 = f32[8,8] get-tuple-element(hero), index=0 + gte1 = f32[8,8] get-tuple-element(hero), index=1 + bc0 = f32[1,8,8] bitcast(gte0) + dus = f32[4,8,8] dynamic-update-slice(buf, bc0, ivar, c0, c0) + c1 = s32[] constant(1) + next_ivar = s32[] add(ivar, c1) + ROOT result = (s32[], f32[4,8,8], f32[8,8]) tuple(next_ivar, dus, gte1) + } + + condition { + p0 = (s32[], f32[4,8,8], f32[8,8]) parameter(0) + ivar = s32[] get-tuple-element(p0), index=0 + c4 = s32[] constant(4) + ROOT cmp = pred[] compare(ivar, c4), direction=LT + } + + ENTRY main { + buf = f32[4,8,8] parameter(0) + prev = f32[8,8] parameter(1) + c0 = s32[] constant(0) + tuple = (s32[], f32[4,8,8], f32[8,8]) tuple(c0, buf, prev) + ROOT while = (s32[], f32[4,8,8], f32[8,8]) while(tuple), + condition=condition, body=body, + backend_config={"known_trip_count":{"n":"4"}, + "known_init_step":{"init":"0","step":"1"}, + "known_induction_variable":{"tuple_index":"0"}} + } + )"; + + const char* expected = R"( + ; CHECK: %dynamic-slice-fusion{{.*}} { + ; CHECK: {{.*}} custom-call( + ; CHECK-SAME: custom_call_target="fake_target" + ; CHECK: {{.*}} get-tuple-element( + ; CHECK: {{.*}} bitcast( + ; CHECK: {{.*}} dynamic-update-slice( + ; CHECK: {{.*}} get-tuple-element( + ; CHECK: ROOT {{.*}} tuple( + ; CHECK: } + ; CHECK: body + ; CHECK: {{.*}} fusion( + ; CHECK-SAME: kind=kCustom + ; CHECK-SAME: "name":"dynamic_slice_fusion" + ; CHECK: } + )"; + + auto f32_488 = ShapeUtil::MakeShape(F32, {4, 8, 8}); + auto f32_188 = ShapeUtil::MakeShape(F32, {1, 8, 8}); + auto f32_88 = ShapeUtil::MakeShape(F32, {8, 8}); + std::vector offsets = {RtOff{1, 0}, CstOff{0, 1}, + CstOff{0, 2}}; + + auto fusion_checks = [&](HloModule* module) { + auto* body = FindDsfBody(module); + ASSERT_NE(body, nullptr); + auto* hero = DynamicSliceFusion::FindHero(body); + ASSERT_NE(hero, nullptr); + + ASSERT_OK_AND_ASSIGN(auto params, + DynamicSliceFusion::ResolveParameters(hero)); + EXPECT_THAT(params, IsEmpty()); + + ASSERT_OK_AND_ASSIGN(auto results, + DynamicSliceFusion::ResolveResults(hero)); + EXPECT_THAT(results, ElementsAre(Result{0, 0, f32_488, f32_188, + MakeConfig(0, 0, 256), offsets}, + Result{std::nullopt, 1, f32_88, f32_88, + std::nullopt, std::nullopt})); + }; + + RunAndFilecheckHloRewrite(hlo, MakePipeline(), expected, fusion_checks); +} + +TEST_F(DynamicSliceFusionRewriterV2Test, TupleOutputNoDUS) { + // Hero produces (f32[8,8], f32[8,8]). Neither output has DUS — both are + // passthrough. With a sliced input, the rewriter should still fuse. + const char* hlo = R"( + HloModule test + + body { + p0 = (s32[], f32[4,8,8], f32[8,8], f32[8,8]) parameter(0) + ivar = s32[] get-tuple-element(p0), index=0 + input = f32[4,8,8] get-tuple-element(p0), index=1 + prev0 = f32[8,8] get-tuple-element(p0), index=2 + prev1 = f32[8,8] get-tuple-element(p0), index=3 + c0 = s32[] constant(0) + ds = f32[1,8,8] dynamic-slice(input, ivar, c0, c0), + dynamic_slice_sizes={1,8,8} + bc = f32[8,8] bitcast(ds) + hero = (f32[8,8], f32[8,8]) custom-call(bc), + custom_call_target="fake_target" + gte0 = f32[8,8] get-tuple-element(hero), index=0 + gte1 = f32[8,8] get-tuple-element(hero), index=1 + c1 = s32[] constant(1) + next_ivar = s32[] add(ivar, c1) + ROOT result = (s32[], f32[4,8,8], f32[8,8], f32[8,8]) + tuple(next_ivar, input, gte0, gte1) + } + + condition { + p0 = (s32[], f32[4,8,8], f32[8,8], f32[8,8]) parameter(0) + ivar = s32[] get-tuple-element(p0), index=0 + c4 = s32[] constant(4) + ROOT cmp = pred[] compare(ivar, c4), direction=LT + } + + ENTRY main { + input = f32[4,8,8] parameter(0) + prev0 = f32[8,8] parameter(1) + prev1 = f32[8,8] parameter(2) + c0 = s32[] constant(0) + tuple = (s32[], f32[4,8,8], f32[8,8], f32[8,8]) + tuple(c0, input, prev0, prev1) + ROOT while = (s32[], f32[4,8,8], f32[8,8], f32[8,8]) + while(tuple), condition=condition, body=body, + backend_config={"known_trip_count":{"n":"4"}, + "known_init_step":{"init":"0","step":"1"}, + "known_induction_variable":{"tuple_index":"0"}} + } + )"; + + const char* expected = R"( + ; CHECK: %dynamic-slice-fusion{{.*}} { + ; CHECK: {{.*}} dynamic-slice( + ; CHECK: {{.*}} bitcast( + ; CHECK: {{.*}} custom-call( + ; CHECK-SAME: custom_call_target="fake_target" + ; CHECK: {{.*}} get-tuple-element( + ; CHECK: {{.*}} get-tuple-element( + ; CHECK: ROOT {{.*}} tuple( + ; CHECK: } + ; CHECK: body + ; CHECK: {{.*}} fusion( + ; CHECK-SAME: kind=kCustom + ; CHECK-SAME: "name":"dynamic_slice_fusion" + ; CHECK: } + )"; + + auto f32_488 = ShapeUtil::MakeShape(F32, {4, 8, 8}); + auto f32_188 = ShapeUtil::MakeShape(F32, {1, 8, 8}); + auto f32_88 = ShapeUtil::MakeShape(F32, {8, 8}); + std::vector offsets = {RtOff{1, 0}, CstOff{0, 1}, + CstOff{0, 2}}; + + auto fusion_checks = [&](HloModule* module) { + auto* body = FindDsfBody(module); + ASSERT_NE(body, nullptr); + auto* hero = DynamicSliceFusion::FindHero(body); + ASSERT_NE(hero, nullptr); + + ASSERT_OK_AND_ASSIGN(auto params, + DynamicSliceFusion::ResolveParameters(hero)); + EXPECT_THAT(params, ElementsAre(Param{0, f32_488, f32_188, + MakeConfig(0, 0, 256), offsets})); + + ASSERT_OK_AND_ASSIGN(auto results, + DynamicSliceFusion::ResolveResults(hero)); + EXPECT_THAT(results, ElementsAre(Result{std::nullopt, 0, f32_88, f32_88, + std::nullopt, std::nullopt}, + Result{std::nullopt, 1, f32_88, f32_88, + std::nullopt, std::nullopt})); + }; + + RunAndFilecheckHloRewrite(hlo, MakePipeline(), expected, fusion_checks); +} + +TEST_F(DynamicSliceFusionRewriterV2Test, TupleOutputOneDUSOneDeadOutput) { + // Hero produces (f32[8,8], f32[8,8]). Only the first output is used (via + // DUS). The second output has no GTE users — it is dead. The fusion must + // still include both in its output tuple so that buffer allocation has a + // slot for the second output. + const char* hlo = R"( + HloModule test + + body { + p0 = (s32[], f32[4,8,8]) parameter(0) + ivar = s32[] get-tuple-element(p0), index=0 + buf = f32[4,8,8] get-tuple-element(p0), index=1 + c0 = s32[] constant(0) + hero = (f32[8,8], f32[8,8]) custom-call(), + custom_call_target="fake_target" + gte0 = f32[8,8] get-tuple-element(hero), index=0 + bc0 = f32[1,8,8] bitcast(gte0) + dus = f32[4,8,8] dynamic-update-slice(buf, bc0, ivar, c0, c0) + c1 = s32[] constant(1) + next_ivar = s32[] add(ivar, c1) + ROOT result = (s32[], f32[4,8,8]) tuple(next_ivar, dus) + } + + condition { + p0 = (s32[], f32[4,8,8]) parameter(0) + ivar = s32[] get-tuple-element(p0), index=0 + c4 = s32[] constant(4) + ROOT cmp = pred[] compare(ivar, c4), direction=LT + } + + ENTRY main { + buf = f32[4,8,8] parameter(0) + c0 = s32[] constant(0) + tuple = (s32[], f32[4,8,8]) tuple(c0, buf) + ROOT while = (s32[], f32[4,8,8]) while(tuple), + condition=condition, body=body, + backend_config={"known_trip_count":{"n":"4"}, + "known_init_step":{"init":"0","step":"1"}, + "known_induction_variable":{"tuple_index":"0"}} + } + )"; + + const char* expected = R"( + ; CHECK: %dynamic-slice-fusion{{.*}} { + ; CHECK: {{.*}} custom-call( + ; CHECK-SAME: custom_call_target="fake_target" + ; CHECK: {{.*}} get-tuple-element( + ; CHECK: {{.*}} bitcast( + ; CHECK: {{.*}} dynamic-update-slice( + ; CHECK: {{.*}} get-tuple-element( + ; CHECK: ROOT {{.*}} tuple( + ; CHECK: } + ; CHECK: body + ; CHECK: {{.*}} fusion( + ; CHECK-SAME: kind=kCustom + ; CHECK-SAME: "name":"dynamic_slice_fusion" + ; CHECK: } + )"; + + auto f32_488 = ShapeUtil::MakeShape(F32, {4, 8, 8}); + auto f32_188 = ShapeUtil::MakeShape(F32, {1, 8, 8}); + auto f32_88 = ShapeUtil::MakeShape(F32, {8, 8}); + std::vector offsets = {RtOff{1, 0}, CstOff{0, 1}, + CstOff{0, 2}}; + + auto fusion_checks = [&](HloModule* module) { + auto* body = FindDsfBody(module); + ASSERT_NE(body, nullptr); + auto* hero = DynamicSliceFusion::FindHero(body); + ASSERT_NE(hero, nullptr); + + ASSERT_OK_AND_ASSIGN(auto results, + DynamicSliceFusion::ResolveResults(hero)); + EXPECT_THAT(results, ElementsAre(Result{0, 0, f32_488, f32_188, + MakeConfig(0, 0, 256), offsets}, + Result{std::nullopt, 1, f32_88, f32_88, + std::nullopt, std::nullopt})); + }; + + RunAndFilecheckHloRewrite(hlo, MakePipeline(), expected, fusion_checks); +} + +TEST_F(DynamicSliceFusionRewriterV2Test, DynamicSliceOnlyNoResult) { + const char* hlo = R"( + HloModule test + + %body { + %p = (s32[], f32[4,8,8]{2,1,0}) parameter(0) + %ivar = s32[] get-tuple-element(%p), index=0 + %input = f32[4,8,8]{2,1,0} get-tuple-element(%p), index=1 + %zero = s32[] constant(0) + %ds = f32[1,8,8]{2,1,0} dynamic-slice(%input, %ivar, %zero, %zero), + dynamic_slice_sizes={1,8,8} + %bc = f32[8,8]{1,0} bitcast(%ds) + %hero = f32[8,8]{1,0} custom-call(%bc), + custom_call_target="fake_target" + %one = s32[] constant(1) + %next = s32[] add(%ivar, %one) + ROOT %tuple = (s32[], f32[4,8,8]{2,1,0}) tuple(%next, %input) + } + + %cond { + %p = (s32[], f32[4,8,8]{2,1,0}) parameter(0) + %i = s32[] get-tuple-element(%p), index=0 + %limit = s32[] constant(4) + ROOT %cmp = pred[] compare(%i, %limit), direction=LT + } + + ENTRY main { + %zero = s32[] constant(0) + %buf = f32[4,8,8]{2,1,0} parameter(0) + %init = (s32[], f32[4,8,8]{2,1,0}) tuple(%zero, %buf) + ROOT %while = (s32[], f32[4,8,8]{2,1,0}) while(%init), + body=%body, condition=%cond, + backend_config={"known_trip_count":{"n":"4"}, + "known_init_step":{"init":"0","step":"1"}, + "known_induction_variable":{"tuple_index":"0"}} + } + )"; + + const char* expected = R"( + ; CHECK: %dynamic-slice-fusion{{.*}} { + ; CHECK-DAG: [[P0:%[^ ]+]] = f32[4,8,8]{2,1,0} parameter(0) + ; CHECK: [[DS:%[^ ]+]] = f32[1,8,8]{2,1,0} dynamic-slice([[P0]] + ; CHECK: [[BC:%[^ ]+]] = f32[8,8]{1,0} bitcast([[DS]]) + ; CHECK: ROOT {{.*}} = f32[8,8]{1,0} custom-call([[BC]]), + ; CHECK: custom_call_target="fake_target" + ; CHECK: } + ; CHECK: %body{{.*}} { + ; CHECK: {{.*}} = f32[8,8]{1,0} fusion( + ; CHECK: kind=kCustom + ; CHECK: "name":"dynamic_slice_fusion" + ; CHECK: } + )"; + + auto f32_488 = ShapeUtil::MakeShape(F32, {4, 8, 8}); + auto f32_188 = ShapeUtil::MakeShape(F32, {1, 8, 8}); + auto f32_88 = ShapeUtil::MakeShape(F32, {8, 8}); + std::vector offsets = {RtOff{1, 0}, CstOff{0, 1}, + CstOff{0, 2}}; + + auto fusion_checks = [&](HloModule* module) { + auto* hero = DynamicSliceFusion::FindHero(FindDsfBody(module)); + + ASSERT_OK_AND_ASSIGN(auto params, + DynamicSliceFusion::ResolveParameters(hero)); + EXPECT_THAT(params, ElementsAre(Param{0, f32_488, f32_188, + MakeConfig(0, 0, 256), offsets})); + + ASSERT_OK_AND_ASSIGN(auto results, + DynamicSliceFusion::ResolveResults(hero)); + EXPECT_THAT(results, ElementsAre(Result{std::nullopt, 0, f32_88, f32_88})); + }; + + RunAndFilecheckHloRewrite(hlo, MakePipeline(), expected, fusion_checks); +} + +//===----------------------------------------------------------------------===// +// O2 mode tests — looking through tuples +//===----------------------------------------------------------------------===// + +TEST_F(DynamicSliceFusionRewriterV2Test, O2LooksThroughTupleGte) { + // Pattern: Slice → bitcast → tuple → GTE → hero. In O2 mode the rewriter + // should look through the tuple/GTE barrier and fuse the Slice. The tuple + // and GTE are NOT captured into the fusion. + const char* hlo = R"( + HloModule test + + ENTRY main { + %p0 = f32[2,8,8]{2,1,0} parameter(0) + %p1 = f32[8,8]{1,0} parameter(1) + %slice0 = f32[1,8,8]{2,1,0} slice(%p0), slice={[1:2], [0:8], [0:8]} + %bitcast0 = f32[8,8]{1,0} bitcast(%slice0) + %tuple = (f32[8,8]{1,0}, f32[8,8]{1,0}) tuple(%bitcast0, %p1) + %gte0 = f32[8,8]{1,0} get-tuple-element(%tuple), index=0 + %gte1 = f32[8,8]{1,0} get-tuple-element(%tuple), index=1 + ROOT %hero = f32[8,8]{1,0} custom-call(%gte0, %gte1), + custom_call_target="fake_target" + } + )"; + + // In O1 mode: GTE→tuple blocks the search, no fusion (only Slice has no + // DynamicSliceConfig, so the annotator doesn't change the module either). + RunAndFilecheckHloRewrite(hlo, MakePipeline(OptLevel::kO1), std::nullopt); + + // In O2 mode: look through tuple→GTE, find slice, fuse it. + const char* expected = R"( + ; CHECK: %dynamic-slice-fusion{{.*}} { + ; CHECK-DAG: [[P0:%[^ ]+]] = f32[2,8,8]{2,1,0} parameter({{.*}}) + ; CHECK: [[S0:%[^ ]+]] = f32[1,8,8]{2,1,0} slice([[P0]]) + ; CHECK-SAME: slice={[1:2], [0:8], [0:8]} + ; CHECK: [[B0:%[^ ]+]] = f32[8,8]{1,0} bitcast([[S0]]) + ; CHECK: ROOT {{.*}} custom-call( + ; CHECK: custom_call_target="fake_target" + ; CHECK: } + ; CHECK: ENTRY %main{{.*}} { + ; CHECK: ROOT {{.*}} fusion( + ; CHECK: kind=kCustom + ; CHECK: "name":"dynamic_slice_fusion" + ; CHECK: } + )"; + + auto f32_288 = ShapeUtil::MakeShape(F32, {2, 8, 8}); + auto f32_188 = ShapeUtil::MakeShape(F32, {1, 8, 8}); + auto f32_88 = ShapeUtil::MakeShape(F32, {8, 8}); + + auto fusion_checks = [&](HloModule* module) { + auto* hero = DynamicSliceFusion::FindHero(FindDsfBody(module)); + + ASSERT_OK_AND_ASSIGN(auto params, + DynamicSliceFusion::ResolveParameters(hero)); + EXPECT_THAT( + params, + ElementsAre( + Param{0, f32_288, f32_188, MakeStaticConfig(256), std::nullopt}, + Param{1, f32_88, f32_88, std::nullopt, std::nullopt})); + }; + + RunAndFilecheckHloRewrite(hlo, MakePipeline(OptLevel::kO2), expected, + fusion_checks); +} + +TEST_F(DynamicSliceFusionRewriterV2Test, O2DynamicSliceThroughTupleGte) { + // Pattern in while body: DS → bitcast → tuple → GTE → hero, with dynamic + // offsets. O2 mode looks through the tuple barrier and fuses the DS. + const char* hlo = R"( + HloModule test + + body { + p0 = (s32[], f32[4,8,8], f32[8,8]) parameter(0) + ivar = s32[] get-tuple-element(p0), index=0 + input = f32[4,8,8] get-tuple-element(p0), index=1 + accum = f32[8,8] get-tuple-element(p0), index=2 + c0 = s32[] constant(0) + ds = f32[1,8,8] dynamic-slice(input, ivar, c0, c0), + dynamic_slice_sizes={1,8,8} + bitcast = f32[8,8] bitcast(ds) + barrier = (f32[8,8], f32[8,8]) tuple(bitcast, accum) + gte_sliced = f32[8,8] get-tuple-element(barrier), index=0 + gte_accum = f32[8,8] get-tuple-element(barrier), index=1 + hero = f32[8,8] custom-call(gte_sliced, gte_accum), + custom_call_target="fake_target" + c1 = s32[] constant(1) + next_ivar = s32[] add(ivar, c1) + ROOT result = (s32[], f32[4,8,8], f32[8,8]) tuple(next_ivar, input, hero) + } + + condition { + p0 = (s32[], f32[4,8,8], f32[8,8]) parameter(0) + ivar = s32[] get-tuple-element(p0), index=0 + c4 = s32[] constant(4) + ROOT cmp = pred[] compare(ivar, c4), direction=LT + } + + ENTRY main { + input = f32[4,8,8] parameter(0) + accum = f32[8,8] parameter(1) + c0 = s32[] constant(0) + tuple = (s32[], f32[4,8,8], f32[8,8]) tuple(c0, input, accum) + ROOT while = (s32[], f32[4,8,8], f32[8,8]) while(tuple), + condition=condition, body=body, + backend_config={"known_trip_count":{"n":"4"}, + "known_init_step":{"init":"0","step":"1"}, + "known_induction_variable":{"tuple_index":"0"}} + } + )"; + + // O1 mode: blocked by tuple barrier — annotator still runs and annotates + // the DS, but no fusion is created. + const char* o1_expected = R"( + ; CHECK: body + ; CHECK-NOT: fusion( + ; CHECK: ENTRY + )"; + RunAndFilecheckHloRewrite(hlo, MakePipeline(OptLevel::kO1), o1_expected); + + // O2 mode: looks through tuple, fuses DS. + const char* expected = R"( + ; CHECK: %dynamic-slice-fusion{{.*}} { + ; CHECK: {{.*}} dynamic-slice( + ; CHECK: {{.*}} bitcast( + ; CHECK: ROOT {{.*}} custom-call( + ; CHECK: custom_call_target="fake_target" + ; CHECK: } + ; CHECK: body + ; CHECK: {{.*}} fusion( + ; CHECK: kind=kCustom + ; CHECK: "name":"dynamic_slice_fusion" + ; CHECK: } + )"; + + auto f32_488 = ShapeUtil::MakeShape(F32, {4, 8, 8}); + auto f32_188 = ShapeUtil::MakeShape(F32, {1, 8, 8}); + auto f32_88 = ShapeUtil::MakeShape(F32, {8, 8}); + std::vector offsets = {RtOff{1, 0}, CstOff{0, 1}, + CstOff{0, 2}}; + + auto fusion_checks = [&](HloModule* module) { + auto* hero = DynamicSliceFusion::FindHero(FindDsfBody(module)); + + ASSERT_OK_AND_ASSIGN(auto params, + DynamicSliceFusion::ResolveParameters(hero)); + EXPECT_THAT( + params, + ElementsAre(Param{0, f32_488, f32_188, MakeConfig(0, 0, 256), offsets}, + Param{2, f32_88, f32_88, std::nullopt, std::nullopt})); + }; + + RunAndFilecheckHloRewrite(hlo, MakePipeline(OptLevel::kO2), expected, + fusion_checks); +} + +TEST_F(DynamicSliceFusionRewriterV2Test, O2NestedTupleGteLookthrough) { + // Nested tuple barrier: DS → bitcast → tuple → tuple → GTE → GTE → hero. + // O2 should look through both levels. + const char* hlo = R"( + HloModule test + + ENTRY main { + %p0 = f32[2,8,8]{2,1,0} parameter(0) + %slice0 = f32[1,8,8]{2,1,0} slice(%p0), slice={[1:2], [0:8], [0:8]} + %bitcast0 = f32[8,8]{1,0} bitcast(%slice0) + %inner = (f32[8,8]{1,0}) tuple(%bitcast0) + %outer = ((f32[8,8]{1,0})) tuple(%inner) + %gte_outer = (f32[8,8]{1,0}) get-tuple-element(%outer), index=0 + %gte_inner = f32[8,8]{1,0} get-tuple-element(%gte_outer), index=0 + ROOT %hero = f32[8,8]{1,0} custom-call(%gte_inner), + custom_call_target="fake_target" + } + )"; + + const char* expected = R"( + ; CHECK: %dynamic-slice-fusion{{.*}} { + ; CHECK: [[P0:%[^ ]+]] = f32[2,8,8]{2,1,0} parameter(0) + ; CHECK: [[S0:%[^ ]+]] = f32[1,8,8]{2,1,0} slice([[P0]]) + ; CHECK-SAME: slice={[1:2], [0:8], [0:8]} + ; CHECK: [[B0:%[^ ]+]] = f32[8,8]{1,0} bitcast([[S0]]) + ; CHECK: ROOT {{.*}} custom-call([[B0]]), + ; CHECK: custom_call_target="fake_target" + ; CHECK: } + ; CHECK: ENTRY %main{{.*}} { + ; CHECK: ROOT {{.*}} fusion(%p0), + ; CHECK: kind=kCustom + ; CHECK: } + )"; + + RunAndFilecheckHloRewrite(hlo, MakePipeline(OptLevel::kO2), expected); +} + +TEST_F(DynamicSliceFusionRewriterV2Test, O2LooksThroughOptBarrier) { + // Pattern: Slice → bitcast → tuple → opt-barrier → GTE → hero. + // O2 should look through the optimization barrier and tuple together. + const char* hlo = R"( + HloModule test + + ENTRY main { + %p0 = f32[2,8,8]{2,1,0} parameter(0) + %p1 = f32[8,8]{1,0} parameter(1) + %slice0 = f32[1,8,8]{2,1,0} slice(%p0), slice={[1:2], [0:8], [0:8]} + %bitcast0 = f32[8,8]{1,0} bitcast(%slice0) + %tuple = (f32[8,8]{1,0}, f32[8,8]{1,0}) tuple(%bitcast0, %p1) + %barrier = (f32[8,8]{1,0}, f32[8,8]{1,0}) opt-barrier(%tuple) + %gte0 = f32[8,8]{1,0} get-tuple-element(%barrier), index=0 + %gte1 = f32[8,8]{1,0} get-tuple-element(%barrier), index=1 + ROOT %hero = f32[8,8]{1,0} custom-call(%gte0, %gte1), + custom_call_target="fake_target" + } + )"; + + // O1 mode: blocked by opt-barrier + tuple. + RunAndFilecheckHloRewrite(hlo, MakePipeline(OptLevel::kO1), std::nullopt); + + // O2 mode: looks through opt-barrier and tuple, finds the slice. + const char* expected = R"( + ; CHECK: %dynamic-slice-fusion{{.*}} { + ; CHECK-DAG: [[P0:%[^ ]+]] = f32[2,8,8]{2,1,0} parameter({{.*}}) + ; CHECK: [[S0:%[^ ]+]] = f32[1,8,8]{2,1,0} slice([[P0]]) + ; CHECK-SAME: slice={[1:2], [0:8], [0:8]} + ; CHECK: [[B0:%[^ ]+]] = f32[8,8]{1,0} bitcast([[S0]]) + ; CHECK: ROOT {{.*}} custom-call( + ; CHECK: custom_call_target="fake_target" + ; CHECK: } + ; CHECK: ENTRY %main{{.*}} { + ; CHECK: ROOT {{.*}} fusion( + ; CHECK: kind=kCustom + ; CHECK: "name":"dynamic_slice_fusion" + ; CHECK: } + )"; + + auto f32_288 = ShapeUtil::MakeShape(F32, {2, 8, 8}); + auto f32_188 = ShapeUtil::MakeShape(F32, {1, 8, 8}); + auto f32_88 = ShapeUtil::MakeShape(F32, {8, 8}); + + auto fusion_checks = [&](HloModule* module) { + auto* hero = DynamicSliceFusion::FindHero(FindDsfBody(module)); + + ASSERT_OK_AND_ASSIGN(auto params, + DynamicSliceFusion::ResolveParameters(hero)); + EXPECT_THAT( + params, + ElementsAre( + Param{0, f32_288, f32_188, MakeStaticConfig(256), std::nullopt}, + Param{1, f32_88, f32_88, std::nullopt, std::nullopt})); + }; + + RunAndFilecheckHloRewrite(hlo, MakePipeline(OptLevel::kO2), expected, + fusion_checks); +} + +} // namespace +} // namespace xla::gpu diff --git a/third_party/xla/xla/backends/gpu/transforms/explicit_collectives_group_async_wrapper.cc b/third_party/xla/xla/backends/gpu/transforms/explicit_collectives_group_async_wrapper.cc index 07a80bd6d88832..36d3417f824c98 100644 --- a/third_party/xla/xla/backends/gpu/transforms/explicit_collectives_group_async_wrapper.cc +++ b/third_party/xla/xla/backends/gpu/transforms/explicit_collectives_group_async_wrapper.cc @@ -21,6 +21,7 @@ limitations under the License. #include "absl/log/log.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/ir/hlo_clone_context.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -65,7 +66,7 @@ absl::StatusOr CreateCollectivesGroupAsyncPair(HloInstruction* instr) { // Forward frontend attributes to both async instructions. async_start->set_frontend_attributes(instr->frontend_attributes()); async_done->set_frontend_attributes(instr->frontend_attributes()); - TF_RETURN_IF_ERROR(computation->ReplaceInstruction(instr, async_done)); + RETURN_IF_ERROR(computation->ReplaceInstruction(instr, async_done)); return true; } } // namespace @@ -77,7 +78,7 @@ absl::StatusOr ExplicitCollectivesGroupAsyncWrapper::RunImpl( for (const HloComputation* comp : module->MakeNonfusionComputations(execution_threads)) { for (HloInstruction* instr : comp->instructions()) { - TF_ASSIGN_OR_RETURN(bool result, CreateCollectivesGroupAsyncPair(instr)); + ASSIGN_OR_RETURN(bool result, CreateCollectivesGroupAsyncPair(instr)); changed |= result; } } diff --git a/third_party/xla/xla/backends/gpu/transforms/explicit_stream_annotation_async_wrapper.cc b/third_party/xla/xla/backends/gpu/transforms/explicit_stream_annotation_async_wrapper.cc index a871b825955756..affe11e8ad7c56 100644 --- a/third_party/xla/xla/backends/gpu/transforms/explicit_stream_annotation_async_wrapper.cc +++ b/third_party/xla/xla/backends/gpu/transforms/explicit_stream_annotation_async_wrapper.cc @@ -19,6 +19,7 @@ limitations under the License. #include "absl/log/log.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" @@ -56,7 +57,7 @@ static absl::StatusOr AsynchronizeInstruction(HloInstruction* instr) { } ClearSchedulingAnnotations(instr); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( HloInstruction * done, computation->CreateAsyncInstructions( instr, {}, ExplicitStreamAnnotationAsyncWrapper::kMainExecutionThread, @@ -64,12 +65,12 @@ static absl::StatusOr AsynchronizeInstruction(HloInstruction* instr) { // Replace the original attributes after creating the async pair. done->set_frontend_attributes(original_attributes); done->mutable_operand(0)->set_frontend_attributes(original_attributes); - TF_ASSIGN_OR_RETURN(GpuBackendConfig gpu_config, - done->backend_config()); + ASSIGN_OR_RETURN(GpuBackendConfig gpu_config, + done->backend_config()); // Set earliest schedule of done op to be false so it can be scheduled // far apart from start. gpu_config.set_force_earliest_schedule(false); - TF_RETURN_IF_ERROR(done->set_backend_config(gpu_config)); + RETURN_IF_ERROR(done->set_backend_config(gpu_config)); VLOG(5) << "Created async instruction: " << done->ToString(); return true; } @@ -82,7 +83,7 @@ absl::StatusOr ExplicitStreamAnnotationAsyncWrapper::RunImpl( for (const HloComputation* comp : module->MakeNonfusionComputations(execution_threads)) { for (HloInstruction* instr : comp->instructions()) { - TF_ASSIGN_OR_RETURN(bool result, AsynchronizeInstruction(instr)); + ASSIGN_OR_RETURN(bool result, AsynchronizeInstruction(instr)); changed |= result; } } diff --git a/third_party/xla/xla/backends/gpu/transforms/fusion_block_level_rewriter.cc b/third_party/xla/xla/backends/gpu/transforms/fusion_block_level_rewriter.cc index 6d89ef1b496736..e434991e5333e0 100644 --- a/third_party/xla/xla/backends/gpu/transforms/fusion_block_level_rewriter.cc +++ b/third_party/xla/xla/backends/gpu/transforms/fusion_block_level_rewriter.cc @@ -27,6 +27,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "absl/strings/string_view.h" +#include "xla/tsl/platform/status_macros.h" #include "llvm/Support/MathExtras.h" #include "xla/backends/gpu/codegen/triton/support.h" #include "xla/hlo/ir/hlo_casting_utils.h" @@ -181,8 +182,8 @@ absl::StatusOr ProcessFusionInstruction( const se::DeviceDescription& device_info, HloCostAnalysis::ShapeSizeFunction shape_size, mlir::MLIRContext* mlir_context) { - TF_ASSIGN_OR_RETURN(bool should_try_rewrite, - ShouldTryRewriteFusion(fusion_instruction, device_info)); + ASSIGN_OR_RETURN(bool should_try_rewrite, + ShouldTryRewriteFusion(fusion_instruction, device_info)); if (!should_try_rewrite) { VLOG(2) << "Not rewriting fusion " << fusion_instruction->ToString() << " because it is not supported."; @@ -200,8 +201,8 @@ absl::StatusOr ProcessFusionInstruction( return false; } - TF_ASSIGN_OR_RETURN(auto backend_config, - fusion_instruction->backend_config()); + ASSIGN_OR_RETURN(auto backend_config, + fusion_instruction->backend_config()); if (backend_config.has_fusion_backend_config() && backend_config.fusion_backend_config().has_block_level_fusion_config()) { @@ -216,7 +217,7 @@ absl::StatusOr ProcessFusionInstruction( auto fusion_adaptor = HloFusionAdaptor::ForInstruction( Cast(fusion_instruction)); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( TiledRunTimeDataOrError tiled_runtime_data_or_error, indexing_performance_model.TryFindBestTilingForFusion(*fusion_adaptor)); @@ -246,7 +247,7 @@ absl::StatusOr ProcessFusionInstruction( ->mutable_block_level_fusion_config() = tiled_runtime_data.block_level_parameters.ToBlockLevelFusionConfig(); backend_config.mutable_fusion_backend_config()->set_kind(kTritonFusionKind); - TF_RETURN_IF_ERROR(fusion_instruction->set_backend_config(backend_config)); + RETURN_IF_ERROR(fusion_instruction->set_backend_config(backend_config)); fusion_instruction->set_fusion_kind(HloInstruction::FusionKind::kCustom); return true; } @@ -256,7 +257,7 @@ absl::StatusOr ProcessFusionInstruction( absl::StatusOr FusionBlockLevelRewriter::RunImpl( HloModule* module, const absl::flat_hash_set& execution_threads) { - TF_RETURN_IF_ERROR(EnsureTritonSupportsComputeCapability( + RETURN_IF_ERROR(EnsureTritonSupportsComputeCapability( device_info_.gpu_compute_capability())); bool has_changed = false; @@ -268,9 +269,9 @@ absl::StatusOr FusionBlockLevelRewriter::RunImpl( } HloFusionInstruction* fusion_instruction = ::xla::Cast(computation->FusionInstruction()); - TF_ASSIGN_OR_RETURN( - bool changed, ProcessFusionInstruction(fusion_instruction, device_info_, - shape_size_, mlir_context_)); + ASSIGN_OR_RETURN(bool changed, + ProcessFusionInstruction(fusion_instruction, device_info_, + shape_size_, mlir_context_)); has_changed |= changed; } diff --git a/third_party/xla/xla/backends/gpu/transforms/fusion_wrapper_test.cc b/third_party/xla/xla/backends/gpu/transforms/fusion_wrapper_test.cc index 8d59b55938264a..5adf0c63997c06 100644 --- a/third_party/xla/xla/backends/gpu/transforms/fusion_wrapper_test.cc +++ b/third_party/xla/xla/backends/gpu/transforms/fusion_wrapper_test.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include "absl/status/statusor.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" #include "xla/stream_executor/device_description.h" #include "xla/stream_executor/device_description.pb.h" @@ -28,9 +29,9 @@ namespace gpu { namespace { absl::StatusOr MakeDeviceDescription() { - TF_ASSIGN_OR_RETURN(stream_executor::DeviceDescription device_description, - stream_executor::DeviceDescription::FromProto( - stream_executor::GpuDeviceInfoProto{})); + ASSIGN_OR_RETURN(stream_executor::DeviceDescription device_description, + stream_executor::DeviceDescription::FromProto( + stream_executor::GpuDeviceInfoProto{})); device_description.set_threads_per_warp(32); return device_description; } diff --git a/third_party/xla/xla/backends/gpu/transforms/gemm_broadcast_folding_rewriter.cc b/third_party/xla/xla/backends/gpu/transforms/gemm_broadcast_folding_rewriter.cc index 85300466a9a4b4..99ecc2b236146b 100644 --- a/third_party/xla/xla/backends/gpu/transforms/gemm_broadcast_folding_rewriter.cc +++ b/third_party/xla/xla/backends/gpu/transforms/gemm_broadcast_folding_rewriter.cc @@ -23,6 +23,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -48,8 +49,8 @@ class GemmBroadcastFoldingVisitor : public DfsHloRewriteVisitor { (Match(instr, m::CustomCall(&existing_gemm, {kGemmCallTarget, kCublasLtMatmulCallTarget}) .WithOperand(1, m::Broadcast(&bcast, m::Op()))))) { - TF_ASSIGN_OR_RETURN(auto gpu_config, - existing_gemm->backend_config()); + ASSIGN_OR_RETURN(auto gpu_config, + existing_gemm->backend_config()); GemmBackendConfig &config = *gpu_config.mutable_gemm_backend_config(); DotDimensionNumbers *dim_nums = config.mutable_dot_dimension_numbers(); int bcast_operand_index = instr->operand_index(bcast); @@ -93,9 +94,9 @@ class GemmBroadcastFoldingVisitor : public DfsHloRewriteVisitor { 0, dim_nums->lhs_contracting_dimensions(0) - num_batch_dims); dim_nums->clear_lhs_batch_dimensions(); } - TF_RETURN_IF_ERROR(existing_gemm->ReplaceOperandWithDifferentShape( + RETURN_IF_ERROR(existing_gemm->ReplaceOperandWithDifferentShape( bcast_operand_index, bcast->mutable_operand(0))); - TF_RETURN_IF_ERROR(existing_gemm->set_backend_config(gpu_config)); + RETURN_IF_ERROR(existing_gemm->set_backend_config(gpu_config)); MarkAsChanged(); } return absl::OkStatus(); @@ -104,7 +105,7 @@ class GemmBroadcastFoldingVisitor : public DfsHloRewriteVisitor { static absl::StatusOr RunOnComputation(HloComputation *computation) { GemmBroadcastFoldingVisitor visitor; - TF_RETURN_IF_ERROR(computation->Accept(&visitor)); + RETURN_IF_ERROR(computation->Accept(&visitor)); return visitor.changed(); } @@ -114,7 +115,7 @@ absl::StatusOr GemmBroadcastFoldingRewriter::RunImpl( bool changed = false; for (HloComputation *computation : module->MakeNonfusionComputations(execution_threads)) { - TF_ASSIGN_OR_RETURN(bool result, RunOnComputation(computation)); + ASSIGN_OR_RETURN(bool result, RunOnComputation(computation)); changed |= result; } return changed; diff --git a/third_party/xla/xla/backends/gpu/transforms/gemm_fusion_swap_operands.cc b/third_party/xla/xla/backends/gpu/transforms/gemm_fusion_swap_operands.cc index cf1f4d2320b290..701805888992db 100644 --- a/third_party/xla/xla/backends/gpu/transforms/gemm_fusion_swap_operands.cc +++ b/third_party/xla/xla/backends/gpu/transforms/gemm_fusion_swap_operands.cc @@ -26,6 +26,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/backends/gpu/codegen/triton/support.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_clone_context.h" @@ -103,8 +104,8 @@ absl::Status SwapDotOperandsInFusion(HloComputation* computation) { HloDotInstruction* new_dot = MakeDotWithSwappedOperands(dot); HloInstruction* new_bitcast = computation->AddInstruction( HloInstruction::CreateBitcast(dot->shape(), new_dot), &dot->metadata()); - TF_RETURN_IF_ERROR(dot->ReplaceAllUsesWith(new_bitcast)); - TF_RETURN_IF_ERROR(computation->RemoveInstruction(dot)); + RETURN_IF_ERROR(dot->ReplaceAllUsesWith(new_bitcast)); + RETURN_IF_ERROR(computation->RemoveInstruction(dot)); return absl::OkStatus(); } @@ -162,10 +163,10 @@ absl::StatusOr ShouldSwapOperands(const HloInstruction* instr) { } const bool lhs_has_code = HasCodeGeneratingInstructions(dot->operand(0)); const bool rhs_has_code = HasCodeGeneratingInstructions(dot->operand(1)); - TF_ASSIGN_OR_RETURN(const int64_t lhs_size, GetNonContractingDimsNumElements( - dot, /*operand_index=*/0)); - TF_ASSIGN_OR_RETURN(const int64_t rhs_size, GetNonContractingDimsNumElements( - dot, /*operand_index=*/1)); + ASSIGN_OR_RETURN(const int64_t lhs_size, + GetNonContractingDimsNumElements(dot, /*operand_index=*/0)); + ASSIGN_OR_RETURN(const int64_t rhs_size, + GetNonContractingDimsNumElements(dot, /*operand_index=*/1)); if (lhs_size < 64 && rhs_size >= 64) { return true; } @@ -184,7 +185,7 @@ absl::StatusOr EmitterCanHandleSwappedOperands( HloCloneContext clone_context(&tmp_module); HloComputation* cloned_computation = tmp_module.AddEntryComputation( dot->parent()->CloneInContext(clone_context)); - TF_RETURN_IF_ERROR(SwapDotOperandsInFusion(cloned_computation)); + RETURN_IF_ERROR(SwapDotOperandsInFusion(cloned_computation)); // If we fail to create a TritonFusionAnalysis, then the emitter can't handle // the fusion and choose not to make any changes. return TritonFusionAnalysis::Execute(*cloned_computation).ok(); @@ -196,16 +197,16 @@ absl::StatusOr MaybeSwapOperands(HloComputation* computation) { if (dot == nullptr) { return false; } - TF_ASSIGN_OR_RETURN(const bool should_swap_operands, ShouldSwapOperands(dot)); + ASSIGN_OR_RETURN(const bool should_swap_operands, ShouldSwapOperands(dot)); if (!should_swap_operands) { return false; } - TF_ASSIGN_OR_RETURN(const bool can_handle_swapped_operands, - EmitterCanHandleSwappedOperands(dot)); + ASSIGN_OR_RETURN(const bool can_handle_swapped_operands, + EmitterCanHandleSwappedOperands(dot)); if (!can_handle_swapped_operands) { return false; } - TF_RETURN_IF_ERROR(SwapDotOperandsInFusion(computation)); + RETURN_IF_ERROR(SwapDotOperandsInFusion(computation)); return true; } @@ -220,7 +221,7 @@ absl::StatusOr GemmFusionSwapOperands::RunImpl( if (!IsTritonFusedComputation(*computation)) { continue; } - TF_ASSIGN_OR_RETURN(const bool changed, MaybeSwapOperands(computation)); + ASSIGN_OR_RETURN(const bool changed, MaybeSwapOperands(computation)); any_changed |= changed; } return any_changed; diff --git a/third_party/xla/xla/backends/gpu/transforms/gemm_fusion_test.cc b/third_party/xla/xla/backends/gpu/transforms/gemm_fusion_test.cc index 2eb188a157b2f8..46da1d5ee41702 100644 --- a/third_party/xla/xla/backends/gpu/transforms/gemm_fusion_test.cc +++ b/third_party/xla/xla/backends/gpu/transforms/gemm_fusion_test.cc @@ -1715,6 +1715,9 @@ ENTRY main { module->mutable_config().mutable_debug_options().set_xla_gpu_enable_cublaslt( false); + module->mutable_config() + .mutable_debug_options() + .set_xla_gpu_experimental_use_ragged_dot_fusion(false); EXPECT_THAT(GemmFusion(gpu_version_).Run(module.get()), IsOkAndHolds(true)); EXPECT_THAT( diff --git a/third_party/xla/xla/backends/gpu/transforms/gemm_rewriter.cc b/third_party/xla/xla/backends/gpu/transforms/gemm_rewriter.cc index 25b4cec647eba7..25d8ed6f19adcc 100644 --- a/third_party/xla/xla/backends/gpu/transforms/gemm_rewriter.cc +++ b/third_party/xla/xla/backends/gpu/transforms/gemm_rewriter.cc @@ -40,6 +40,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/evaluator/hlo_evaluator.h" #include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" #include "xla/hlo/ir/hlo_casting_utils.h" @@ -91,8 +92,8 @@ absl::Status SetName(HloModule* module, HloInstruction* gemm) { return absl::OkStatus(); } - TF_ASSIGN_OR_RETURN(GpuBackendConfig gpu_config, - gemm->backend_config()); + ASSIGN_OR_RETURN(GpuBackendConfig gpu_config, + gemm->backend_config()); const GemmBackendConfig& config = gpu_config.gemm_backend_config(); const DotDimensionNumbers& dot_dims = config.dot_dimension_numbers(); bool is_batch_dot = !dot_dims.lhs_batch_dimensions().empty() || @@ -182,8 +183,8 @@ absl::StatusOr InvertAndConvertScalar(HloInstruction* scalar, Literal one_literal = LiteralUtil::One(scalar->shape().element_type()); HloInstruction* one = scalar->parent()->AddInstruction( HloInstruction::CreateConstant(one_literal.Clone())); - TF_ASSIGN_OR_RETURN(scalar, MakeBinaryHlo(HloOpcode::kDivide, one, scalar, - &scalar->metadata())); + ASSIGN_OR_RETURN(scalar, MakeBinaryHlo(HloOpcode::kDivide, one, scalar, + &scalar->metadata())); } if (scalar->shape().element_type() != F32) { scalar = MakeConvertToHlo(scalar, F32, &scalar->metadata()); @@ -592,7 +593,7 @@ absl::StatusOr NormalizeBatchDimensions(HloInstruction* dot) { return dot; } - TF_ASSIGN_OR_RETURN(auto dims, DotOperandDims::FromDot(dot)); + ASSIGN_OR_RETURN(auto dims, DotOperandDims::FromDot(dot)); std::array operands = {dot->mutable_operand(0), dot->mutable_operand(1)}; @@ -607,28 +608,26 @@ absl::StatusOr NormalizeBatchDimensions(HloInstruction* dot) { operands[i]->shape().dimensions().size()); absl::c_iota(permutation, 0); MoveSingleElement(absl::MakeSpan(permutation), b1, b1 < b0 ? b0 : b0 + 1); - TF_ASSIGN_OR_RETURN(operands[i], - MakeTransposeHlo(operands[i], permutation)); + ASSIGN_OR_RETURN(operands[i], MakeTransposeHlo(operands[i], permutation)); LayoutUtil::SetToDefaultLayout(operands[i]->mutable_shape()); dims[i].ApplyPermutation(permutation); } - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( dims[i].Collapse(DotOperandDims::kBatch, /*remove_if_empty=*/false)); - TF_ASSIGN_OR_RETURN(operands[i], - MakeReshapeHlo(dims[i].shape(), operands[i])); + ASSIGN_OR_RETURN(operands[i], MakeReshapeHlo(dims[i].shape(), operands[i])); } - TF_ASSIGN_OR_RETURN(DotDimensionNumbers new_dnums, - DotOperandDims::CreateDotDimensionNumbers(dims)); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN(DotDimensionNumbers new_dnums, + DotOperandDims::CreateDotDimensionNumbers(dims)); + ASSIGN_OR_RETURN( HloInstruction * new_dot, MakeDotHlo(operands[0], operands[1], new_dnums, dot_instr->precision_config(), dot_instr->shape().element_type(), &dot_instr->metadata())); - TF_ASSIGN_OR_RETURN(HloInstruction * reshape, - MakeReshapeHlo(dot->shape(), new_dot)); + ASSIGN_OR_RETURN(HloInstruction * reshape, + MakeReshapeHlo(dot->shape(), new_dot)); return reshape; } @@ -675,7 +674,7 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { options_(options) {} absl::Status HandleDot(HloInstruction* instr) override { - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( bool is_supported_matmul, IsCublasSupportedMatMul(*instr, /*allow_matrix_vector_multiplication=*/true)); @@ -683,10 +682,10 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { return absl::OkStatus(); } - TF_ASSIGN_OR_RETURN(HloInstruction * normalized_instr, - NormalizeBatchDimensions(instr)); + ASSIGN_OR_RETURN(HloInstruction * normalized_instr, + NormalizeBatchDimensions(instr)); if (normalized_instr != instr) { - TF_RETURN_IF_ERROR(ReplaceInstruction(instr, normalized_instr)); + RETURN_IF_ERROR(ReplaceInstruction(instr, normalized_instr)); // After normalization, the dot instruction is followed by a reshape, // taking the operand to get the actual dot instruction. instr = normalized_instr->mutable_operand(0); @@ -697,16 +696,16 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { ->config() .debug_options() .xla_gpu_gemm_rewrite_size_threshold(); - TF_ASSIGN_OR_RETURN(bool is_matmul_tiny, - IsMatrixMultiplicationTooSmallForRewriting( - *instr, gemm_rewrite_size_threshold)); + ASSIGN_OR_RETURN(bool is_matmul_tiny, + IsMatrixMultiplicationTooSmallForRewriting( + *instr, gemm_rewrite_size_threshold)); if (is_matmul_tiny && IsDotSupportedByClassicalEmitters(*instr)) { return absl::OkStatus(); } // Create a GemmBackendConfig based on the instruction. - TF_ASSIGN_OR_RETURN(GpuBackendConfig gpu_backend_config, - instr->backend_config()); + ASSIGN_OR_RETURN(GpuBackendConfig gpu_backend_config, + instr->backend_config()); GemmBackendConfig& gemm_backend_config = *gpu_backend_config.mutable_gemm_backend_config(); gemm_backend_config.set_alpha_real(1.0); @@ -743,7 +742,7 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { switch (options_.dtype) { case GemmRewriterOptions::DType::kFp8Only: { // Rewrite FP8 GEMMs into a type-specific cublasLT Custom Call. - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( bool supported_by_cublaslt, GemmIsSupportedByCublasLt(*instr, gemm_backend_config)); std::optional a, b; @@ -756,12 +755,12 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { toolkit_version_ < stream_executor::SemanticVersion{6, 2, 0} && instr->shape().element_type() != F16 && instr->shape().element_type() != F32) { - TF_ASSIGN_OR_RETURN( - instr, TurnF8DotWithUnsupportedOutputTypeIntoF32(instr)); + ASSIGN_OR_RETURN(instr, + TurnF8DotWithUnsupportedOutputTypeIntoF32(instr)); } - TF_ASSIGN_OR_RETURN(bool created_call, - CreateF8CustomCall(instr, gpu_backend_config, - a.value(), b.value())); + ASSIGN_OR_RETURN(bool created_call, + CreateF8CustomCall(instr, gpu_backend_config, + a.value(), b.value())); if (created_call) { return absl::OkStatus(); } @@ -770,18 +769,18 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { // FP8 rewriter couldn't rewrite dot with FP8 inputs into cublasLt // custom call, so turn into an FP16 dot which may be rewritten as an // FP16 Triton, cublas or cublasLt call. - TF_ASSIGN_OR_RETURN(instr, TurnF8DotIntoF16Dot(instr)); + ASSIGN_OR_RETURN(instr, TurnF8DotIntoF16Dot(instr)); } break; } case GemmRewriterOptions::DType::kNonFp8Only: { if (gemm_backend_config.precision_config().algorithm() == PrecisionConfig::ALG_DOT_BF16_BF16_F32) { - TF_RETURN_IF_ERROR(TurnDotIntoConvertAndDotForBF16BF16F32( + RETURN_IF_ERROR(TurnDotIntoConvertAndDotForBF16BF16F32( instr, gemm_backend_config, gpu_backend_config)); } else { // Rewrite non-FP8 GEMMs into a cublas or cublasLT Custom Call. - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( absl::string_view gemm_custom_call_target, GetNonFp8GemmCustomCallTarget(*instr, gemm_backend_config)); const Shape& output_shape = instr->shape(); @@ -790,8 +789,8 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { output_shape, {instr->mutable_operand(0), instr->mutable_operand(1)}, gemm_custom_call_target)); - TF_RETURN_IF_ERROR(gemm_call->set_backend_config(gpu_backend_config)); - TF_RETURN_IF_ERROR(ReplaceInstruction(instr, gemm_call)); + RETURN_IF_ERROR(gemm_call->set_backend_config(gpu_backend_config)); + RETURN_IF_ERROR(ReplaceInstruction(instr, gemm_call)); } } break; }; @@ -847,7 +846,7 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { gpu::kCublasLtGroupedMatmulCallTarget)); // Create a GroupedGemmBackendConfig based on the instruction. - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( gpu::GpuBackendConfig gpu_backend_config, grouped_gemm_call->backend_config()); GroupedGemmBackendConfig& grouped_gemm_backend_config = @@ -868,10 +867,9 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { gemm_backend_config.set_grad_x(attributes["grad_x"] == "true"); gemm_backend_config.set_grad_y(attributes["grad_y"] == "true"); - TF_RETURN_IF_ERROR( - grouped_gemm_call->set_backend_config(gpu_backend_config)); + RETURN_IF_ERROR(grouped_gemm_call->set_backend_config(gpu_backend_config)); - TF_RETURN_IF_ERROR(ReplaceInstruction(instr, grouped_gemm_call)); + RETURN_IF_ERROR(ReplaceInstruction(instr, grouped_gemm_call)); return absl::OkStatus(); } @@ -887,15 +885,15 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { auto rhs_convert = instr->mutable_operand(1)->AddInstruction( HloInstruction::CreateConvert(rhs_shape, instr->mutable_operand(1))); gemm_backend_config.mutable_precision_config()->clear_algorithm(); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( absl::string_view gemm_custom_call_target, GetNonFp8GemmCustomCallTarget(*instr, gemm_backend_config)); const Shape& output_shape = instr->shape(); HloInstruction* gemm_call = instr->AddInstruction(HloInstruction::CreateCustomCall( output_shape, {lhs_convert, rhs_convert}, gemm_custom_call_target)); - TF_RETURN_IF_ERROR(gemm_call->set_backend_config(gpu_backend_config)); - TF_RETURN_IF_ERROR(ReplaceInstruction(instr, gemm_call)); + RETURN_IF_ERROR(gemm_call->set_backend_config(gpu_backend_config)); + RETURN_IF_ERROR(ReplaceInstruction(instr, gemm_call)); return absl::OkStatus(); } @@ -905,8 +903,8 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { m::MultiplyAnyOrder( CublasLtMatmulMaybeF8(&existing_gemm).WithOneUser(), m::Broadcast(m::ConstantScalar(&alpha)).WithOneUser()))) { - TF_ASSIGN_OR_RETURN(auto gpu_config, - existing_gemm->backend_config()); + ASSIGN_OR_RETURN(auto gpu_config, + existing_gemm->backend_config()); GemmBackendConfig& config = *gpu_config.mutable_gemm_backend_config(); // Do not fuse alpha into S32 GEMM, as they only support fixed values for // alpha/beta. @@ -920,7 +918,7 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { *alpha->literal().GetAsComplex128({}) * prev_alpha; config.set_alpha_real(new_alpha.real()); config.set_alpha_imag(new_alpha.imag()); - TF_RETURN_IF_ERROR(existing_gemm->set_backend_config(gpu_config)); + RETURN_IF_ERROR(existing_gemm->set_backend_config(gpu_config)); return ReplaceInstruction(instr, existing_gemm); } } @@ -1052,7 +1050,7 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { .WithOneUser(), m::Broadcast(&bias, OptionalConvert(&optional_convert, m::Op()))))) { - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( bool was_fused, FuseVectorBiasAdd(instr, bias, existing_gemm, optional_slice, optional_convert, optional_bitcast)); @@ -1073,11 +1071,11 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { .WithOneUser()) .WithOneUser(), m::Broadcast(&bias, m::Op()).WithOneUser()))) { - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( HloInstruction * new_add, MakeBinaryHlo(HloOpcode::kAdd, existing_gemm, MakeBitcastHlo(bias, existing_gemm->shape()))); - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( ReplaceInstruction(instr, MakeBitcastHlo(new_add, instr->shape()))); // Continue below. @@ -1107,10 +1105,10 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { m::Op(&bias).WithPredicate(is_not_broadcast)))) { HloInstruction* new_bitcast = MakeBitcastHlo(bias, existing_gemm->shape(), &bias->metadata()); - TF_ASSIGN_OR_RETURN(HloInstruction * new_add, - MakeBinaryHlo(HloOpcode::kAdd, existing_gemm, - new_bitcast, &bias->metadata())); - TF_RETURN_IF_ERROR( + ASSIGN_OR_RETURN(HloInstruction * new_add, + MakeBinaryHlo(HloOpcode::kAdd, existing_gemm, + new_bitcast, &bias->metadata())); + RETURN_IF_ERROR( ReplaceInstruction(instr, MakeBitcastHlo(new_add, instr->shape()))); // Continue below transforming new_add. @@ -1127,14 +1125,14 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { m::Convert(CublasLtMatmul(&existing_gemm).WithOneUser()) .WithOneUser()), m::Op(&bias).WithPredicate(is_not_broadcast)))) { - TF_ASSIGN_OR_RETURN(GpuBackendConfig gpu_backend_config, - existing_gemm->backend_config()); + ASSIGN_OR_RETURN(GpuBackendConfig gpu_backend_config, + existing_gemm->backend_config()); const GemmBackendConfig& gemm_backend_config = gpu_backend_config.gemm_backend_config(); // check if type combination is supported here - TF_ASSIGN_OR_RETURN(bool types_are_supported, - TypesAreSupportedByCublasLt( - *existing_gemm, gemm_backend_config, instr)); + ASSIGN_OR_RETURN(bool types_are_supported, + TypesAreSupportedByCublasLt(*existing_gemm, + gemm_backend_config, instr)); // for mix type gemm, only fuse add if there is no consumers // ROOT add @@ -1193,8 +1191,8 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { CublasLtMatmulMaybeF8OrGrouped(&existing_gemm)) .WithOneUser(), m::Broadcast(&zeros, m::ConstantScalar(0))))) { - TF_RETURN_IF_ERROR(FuseReluActivation(instr, zeros, existing_gemm, - optional_slice_or_bitcast)); + RETURN_IF_ERROR(FuseReluActivation(instr, zeros, existing_gemm, + optional_slice_or_bitcast)); } return absl::OkStatus(); } @@ -1252,8 +1250,8 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { *gpu_backend_config.mutable_gemm_backend_config(); se::CudaComputeCapability cuda_compute_capability; if (gpu_version_.IsCuda()) { - TF_ASSIGN_OR_RETURN(cuda_compute_capability, - GetCudaComputeCapability(gpu_version_)); + ASSIGN_OR_RETURN(cuda_compute_capability, + GetCudaComputeCapability(gpu_version_)); // FP8 GEMM kernels are only available on Ada, Hopper, and later // architectures. if (!cuda_compute_capability.IsAtLeast(8, 9)) { @@ -1271,8 +1269,8 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { } if (gpu_version_.IsRocm()) { - TF_ASSIGN_OR_RETURN(auto rocm_compute_capability, - GetRocmComputeCapability(gpu_version_)); + ASSIGN_OR_RETURN(auto rocm_compute_capability, + GetRocmComputeCapability(gpu_version_)); if (!rocm_compute_capability.has_fp8_support()) { VLOG(1) << "FP8 Custom Calls require MI300, or later architectures."; return false; @@ -1309,8 +1307,8 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { } if (gpu_version_.IsRocm()) { - TF_ASSIGN_OR_RETURN(auto rocm_compute_capability, - GetRocmComputeCapability(gpu_version_)); + ASSIGN_OR_RETURN(auto rocm_compute_capability, + GetRocmComputeCapability(gpu_version_)); if (rocm_compute_capability.has_ocp_fp8_support()) { if (a_type == F8E5M2 && b_type == F8E5M2) { VLOG(1) << "Failed to rewrite " << instr->ToShortString() @@ -1425,8 +1423,8 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { return false; } } - TF_ASSIGN_OR_RETURN(auto rocm_compute_capability, - GetRocmComputeCapability(gpu_version_)); + ASSIGN_OR_RETURN(auto rocm_compute_capability, + GetRocmComputeCapability(gpu_version_)); if (rocm_compute_capability.has_ocp_fp8_support()) { supported_d_types.insert(F8E4M3FN); supported_d_types.insert(F8E5M2); @@ -1524,14 +1522,12 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { shift_ops(a.fp8_input, a.commutative_ops); shift_ops(b.fp8_input, b.commutative_ops); - TF_ASSIGN_OR_RETURN( - std::vector a_non_contracting_dims, - GetNonContractingDims(a.fp8_input->shape(), a_batch_dims, - a_contracting_dims)); - TF_ASSIGN_OR_RETURN( - std::vector b_non_contracting_dims, - GetNonContractingDims(b.fp8_input->shape(), b_batch_dims, - b_contracting_dims)); + ASSIGN_OR_RETURN(std::vector a_non_contracting_dims, + GetNonContractingDims(a.fp8_input->shape(), a_batch_dims, + a_contracting_dims)); + ASSIGN_OR_RETURN(std::vector b_non_contracting_dims, + GetNonContractingDims(b.fp8_input->shape(), b_batch_dims, + b_contracting_dims)); if (a_non_contracting_dims.size() != 1 || b_non_contracting_dims.size() != 1) { VLOG(1) << "Failed to rewrite " << instr->ToShortString() @@ -1540,9 +1536,8 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { return false; } - TF_ASSIGN_OR_RETURN( - GemmConfig gemm_config, - GemmConfig::For(instr, gemm_backend_config, gpu_version_)); + ASSIGN_OR_RETURN(GemmConfig gemm_config, + GemmConfig::For(instr, gemm_backend_config, gpu_version_)); DotDimensionNumbers* dim_nums = gemm_backend_config.mutable_dot_dimension_numbers(); @@ -1587,8 +1582,8 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { instr->shape().element_type(), new_output_shape.dimensions(), instr->shape().layout().minor_to_major()), operands_list, kCublasLtMatmulF8CallTarget)); - TF_RETURN_IF_ERROR(new_custom_call->set_backend_config(gpu_backend_config)); - TF_RETURN_IF_ERROR(SetName(instr->GetModule(), new_custom_call)); + RETURN_IF_ERROR(new_custom_call->set_backend_config(gpu_backend_config)); + RETURN_IF_ERROR(SetName(instr->GetModule(), new_custom_call)); // Slice the result of the GEMM if the operands were padded. HloInstruction* slice = nullptr; @@ -1600,8 +1595,7 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { instr->shape().dimensions(), strides)); } - TF_RETURN_IF_ERROR( - ReplaceInstruction(instr, slice ? slice : new_custom_call)); + RETURN_IF_ERROR(ReplaceInstruction(instr, slice ? slice : new_custom_call)); VLOG(1) << instr->ToString() << " rewritten into FP8 Custom Call."; return true; } @@ -1627,8 +1621,8 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { // The application of the scaling of the output to the input (see previous // comment) is not valid for epilogues other than ReLU or when a matrix bias // has been fused. - TF_ASSIGN_OR_RETURN(auto gpu_backend_config, - existing_gemm->backend_config()); + ASSIGN_OR_RETURN(auto gpu_backend_config, + existing_gemm->backend_config()); const GemmBackendConfig& config = gpu_backend_config.gemm_backend_config(); if ((config.epilogue() != GemmBackendConfig::DEFAULT && config.epilogue() != GemmBackendConfig::RELU) || @@ -1637,12 +1631,12 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { } // If necessary, invert the scaling factor of D and convert to F32. - TF_ASSIGN_OR_RETURN( - d_scale, InvertAndConvertScalar( - d_scale, HloPredicateIsOp(instr))); + ASSIGN_OR_RETURN(d_scale, + InvertAndConvertScalar( + d_scale, HloPredicateIsOp(instr))); - TF_RETURN_IF_ERROR(existing_gemm->ReplaceOperandWith(2, d_scale)); - TF_RETURN_IF_ERROR(ReplaceInstruction(instr, existing_gemm)); + RETURN_IF_ERROR(existing_gemm->ReplaceOperandWith(2, d_scale)); + RETURN_IF_ERROR(ReplaceInstruction(instr, existing_gemm)); VLOG(1) << "Scaling of FP8 GEMM fused into Custom Call."; return absl::OkStatus(); @@ -1688,8 +1682,8 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { if (gemm_users.size() == 2) { // In the presence of a ReLU activation, the abs instruction is elided // since abs(ReLU(x)) = ReLU(x). - TF_ASSIGN_OR_RETURN(auto gpu_config, - existing_gemm->backend_config()); + ASSIGN_OR_RETURN(auto gpu_config, + existing_gemm->backend_config()); const GemmBackendConfig& config = gpu_config.gemm_backend_config(); for (int i = 0; i < gemm_users.size(); ++i) { HloInstruction* maybe_reduce = nullptr; @@ -1730,8 +1724,8 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { return absl::OkStatus(); } - TF_ASSIGN_OR_RETURN(auto gpu_backend_config, - existing_gemm->backend_config()); + ASSIGN_OR_RETURN(auto gpu_backend_config, + existing_gemm->backend_config()); const GemmBackendConfig& gemm_backend_config = gpu_backend_config.gemm_backend_config(); @@ -1754,8 +1748,7 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { // If necessary, invert the scaling factor of D and convert to F32. When no // scaling factor was captured, set the factor to one. if (d_scale) { - TF_ASSIGN_OR_RETURN(d_scale, - InvertAndConvertScalar(d_scale, !mult_scale)); + ASSIGN_OR_RETURN(d_scale, InvertAndConvertScalar(d_scale, !mult_scale)); } else { d_scale = instr->AddInstruction( HloInstruction::CreateConstant(LiteralUtil::One(F32))); @@ -1771,7 +1764,7 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { std::unique_ptr new_gemm = existing_gemm->CloneWithNewShape(instr->shape()); - TF_RETURN_IF_ERROR(ReplaceWithNewInstruction(instr, std::move(new_gemm))); + RETURN_IF_ERROR(ReplaceWithNewInstruction(instr, std::move(new_gemm))); VLOG(1) << "Conversion" << (reduce_damax ? " and amax calculation" : "") << " fused into FP8 GEMM."; @@ -1788,11 +1781,11 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { HloInstruction* gemm_and_damax = instr->AddInstruction(existing_gemm->CloneWithNewShape(tuple_shape)); - TF_ASSIGN_OR_RETURN(auto gpu_config, - gemm_and_damax->backend_config()); + ASSIGN_OR_RETURN(auto gpu_config, + gemm_and_damax->backend_config()); GemmBackendConfig& config = *gpu_config.mutable_gemm_backend_config(); config.set_damax_output(true); - TF_RETURN_IF_ERROR(gemm_and_damax->set_backend_config(gpu_config)); + RETURN_IF_ERROR(gemm_and_damax->set_backend_config(gpu_config)); // Obtain D and DAmax separately from the output tuple. HloInstruction* d = @@ -1804,8 +1797,8 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { // Convert DAmax from FP32 to the requested type and elide reduce. HloInstruction* damax_converted = instr->AddInstruction( HloInstruction::CreateConvert(reduce_damax->shape(), damax)); - TF_RETURN_IF_ERROR(ReplaceInstruction(reduce_damax, damax_converted)); - TF_RETURN_IF_ERROR(ReplaceInstruction(instr, d)); + RETURN_IF_ERROR(ReplaceInstruction(reduce_damax, damax_converted)); + RETURN_IF_ERROR(ReplaceInstruction(instr, d)); return absl::OkStatus(); } @@ -1929,7 +1922,7 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { gemm->CloneWithNewOperands(gemm->shape(), operands); // set output shape to bias shape if mix type fused_op->mutable_shape()->set_element_type(bias->shape().element_type()); - TF_RETURN_IF_ERROR(fused_op->set_backend_config(gpu_config)); + RETURN_IF_ERROR(fused_op->set_backend_config(gpu_config)); // Choose whether the bias must alias the output. Legacy cublas GEMMs must // operate in place and alias the bias with the output, whereas with @@ -1951,7 +1944,7 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { xla::Cast(fused_op.get()) ->set_output_to_operand_aliasing({{{}, {bias_operand_index, {}}}}); } - TF_RETURN_IF_ERROR(SetName(instr->GetModule(), fused_op.get())); + RETURN_IF_ERROR(SetName(instr->GetModule(), fused_op.get())); if (slice) { fused_op = slice->CloneWithNewOperands( slice->shape(), @@ -1991,8 +1984,7 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { HloInstruction* bias = broadcast->mutable_operand(0); - TF_ASSIGN_OR_RETURN(auto gpu_config, - gemm->backend_config()); + ASSIGN_OR_RETURN(auto gpu_config, gemm->backend_config()); GemmBackendConfig& config = GetMutableGemmBackendConfig(gpu_config); // # output column dims == # non-contracting rhs operand dims. const DotDimensionNumbers& dot_dims = config.dot_dimension_numbers(); @@ -2097,8 +2089,8 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { HloInstruction* result = computation->AddInstruction( gemm->CloneWithNewOperands(gemm->shape(), operands)); - TF_RETURN_IF_ERROR(result->set_backend_config(gpu_config)); - TF_RETURN_IF_ERROR(SetName(gemm->GetModule(), result)); + RETURN_IF_ERROR(result->set_backend_config(gpu_config)); + RETURN_IF_ERROR(SetName(gemm->GetModule(), result)); if (slice) { result = computation->AddInstruction( slice->CloneWithNewOperands(slice->shape(), {result})); @@ -2108,7 +2100,7 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { result = computation->AddInstruction( bitcast->CloneWithNewOperands(bitcast->shape(), {result})); } - TF_RETURN_IF_ERROR(ReplaceInstruction(instr, result)); + RETURN_IF_ERROR(ReplaceInstruction(instr, result)); return true; } @@ -2128,8 +2120,7 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { return absl::OkStatus(); } - TF_ASSIGN_OR_RETURN(auto gpu_config, - gemm->backend_config()); + ASSIGN_OR_RETURN(auto gpu_config, gemm->backend_config()); GemmBackendConfig& config = GetMutableGemmBackendConfig(gpu_config); if (config.epilogue() == GemmBackendConfig::DEFAULT) { config.set_epilogue(GemmBackendConfig::RELU); @@ -2141,8 +2132,8 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { HloComputation* computation = gemm->parent(); HloInstruction* result = computation->AddInstruction(gemm->Clone()); - TF_RETURN_IF_ERROR(result->set_backend_config(gpu_config)); - TF_RETURN_IF_ERROR(SetName(gemm->GetModule(), result)); + RETURN_IF_ERROR(result->set_backend_config(gpu_config)); + RETURN_IF_ERROR(SetName(gemm->GetModule(), result)); if (slice_or_bitcast) { result = @@ -2187,8 +2178,7 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { return absl::OkStatus(); } - TF_ASSIGN_OR_RETURN(auto gpu_config, - gemm->backend_config()); + ASSIGN_OR_RETURN(auto gpu_config, gemm->backend_config()); GemmBackendConfig& config = GetMutableGemmBackendConfig(gpu_config); if (config.epilogue() == GemmBackendConfig::DEFAULT) { @@ -2204,8 +2194,8 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { std::unique_ptr output = gemm->CloneWithNewShape( has_aux ? ShapeUtil::MakeTupleShape({gemm->shape(), gemm->shape()}) : gemm->shape()); - TF_RETURN_IF_ERROR(output->set_backend_config(gpu_config)); - TF_RETURN_IF_ERROR(SetName(multiply->GetModule(), output.get())); + RETURN_IF_ERROR(output->set_backend_config(gpu_config)); + RETURN_IF_ERROR(SetName(multiply->GetModule(), output.get())); if (has_aux) { HloInstruction* tuple_output = gemm->AddInstruction(std::move(output)); @@ -2221,14 +2211,14 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { std::unique_ptr new_dot_slice = slice_or_bitcast->CloneWithNewOperands(slice_or_bitcast->shape(), {new_dot_output}); - TF_RETURN_IF_ERROR(ReplaceWithNewInstruction( - slice_or_bitcast, std::move(new_dot_slice))); + RETURN_IF_ERROR(ReplaceWithNewInstruction(slice_or_bitcast, + std::move(new_dot_slice))); } } if (!gemm->IsDead()) { // `gemm` may already be dead if we replaced `slice_or_bitcast` with the // new dot slice (as we would also delete unused operands). - TF_RETURN_IF_ERROR(ReplaceInstruction(gemm, new_dot_output)); + RETURN_IF_ERROR(ReplaceInstruction(gemm, new_dot_output)); } output = std::move(gelu_output); } else if (slice_or_bitcast) { @@ -2269,8 +2259,7 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { return absl::OkStatus(); } - TF_ASSIGN_OR_RETURN(auto gpu_config, - gemm->backend_config()); + ASSIGN_OR_RETURN(auto gpu_config, gemm->backend_config()); GemmBackendConfig& config = GetMutableGemmBackendConfig(gpu_config); if (config.epilogue() == GemmBackendConfig::DEFAULT) { @@ -2283,8 +2272,8 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { std::unique_ptr output = gemm->CloneWithNewShape(gemm->shape()); - TF_RETURN_IF_ERROR(output->set_backend_config(gpu_config)); - TF_RETURN_IF_ERROR(SetName(multiply->GetModule(), output.get())); + RETURN_IF_ERROR(output->set_backend_config(gpu_config)); + RETURN_IF_ERROR(SetName(multiply->GetModule(), output.get())); if (slice_or_bitcast) { output = slice_or_bitcast->CloneWithNewOperands( @@ -2307,15 +2296,104 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { const GemmBackendConfig& gemm_backend_config) const { // All internal conditions are met, check if we meet the requirements of // cublasLt. - TF_ASSIGN_OR_RETURN(bool gemm_is_supported_by_cublas_lt, - GemmIsSupportedByCublasLt(instr, gemm_backend_config)); + ASSIGN_OR_RETURN(bool gemm_is_supported_by_cublas_lt, + GemmIsSupportedByCublasLt(instr, gemm_backend_config)); if (gemm_is_supported_by_cublas_lt) { return absl::string_view(kCublasLtMatmulCallTarget); } - // This case is not supported by cublasLt. - return absl::InternalError( - "GEMM not supported by cublasLt and legacy cublas is removed."); + // This case is not supported by cublasLt, fallback to legacy cublas. + return absl::string_view(kGemmCallTarget); + } + + absl::StatusOr TypesAreSupportedByLegacyCublas( + const HloInstruction& instr, const GemmBackendConfig& gemm_backend_config, + const HloInstruction* bias = nullptr) const { + // Figure out the Atype/Btype. + const PrimitiveType a_dtype = instr.operand(0)->shape().element_type(); + const PrimitiveType b_dtype = instr.operand(1)->shape().element_type(); + const PrimitiveType output_type = + bias ? bias->shape().element_type() : instr.shape().element_type(); + const std::array supported_type = { + PrimitiveType::S8, PrimitiveType::F16, PrimitiveType::BF16, + PrimitiveType::F32, PrimitiveType::S32, PrimitiveType::F64, + PrimitiveType::C64, PrimitiveType::C128}; + // legacy cublas has a defined set of combinations of types that it + // supports. Figure out the computeType and scaleType. + if (!absl::c_linear_search(supported_type, output_type)) { + return false; + } + TF_ASSIGN_OR_RETURN(const se::blas::DataType output_dtype, + se::gpu::AsBlasDataType(output_type)); + TF_ASSIGN_OR_RETURN( + const se::blas::ComputationType compute_type, + se::gpu::GetBlasComputationType( + instr.precision_config().algorithm(), a_dtype, output_type, + stream_executor::blas::kDefaultComputePrecision, gpu_version_)); + se::blas::DataType scale_type = + se::gpu::GetScaleType(output_dtype, compute_type); + + using se::blas::ComputationType; + using se::blas::DataType; + // This matrix of supported types is taken directly from cublas + // documentation. + // https://docs.nvidia.com/cuda/cublas/index.html#cublasgemmex + const std::array< + std::tuple, + 32> + supported_type_combinations = {{ + {ComputationType::kF16, DataType::kHalf, PrimitiveType::F16, + PrimitiveType::F16, DataType::kHalf}, + + {ComputationType::kI32, DataType::kInt32, PrimitiveType::S8, + PrimitiveType::S8, DataType::kInt32}, + + {ComputationType::kF32, DataType::kFloat, PrimitiveType::BF16, + PrimitiveType::BF16, DataType::kBF16}, + {ComputationType::kF32, DataType::kFloat, PrimitiveType::F16, + PrimitiveType::F16, DataType::kHalf}, + {ComputationType::kF32, DataType::kFloat, PrimitiveType::S8, + PrimitiveType::S8, DataType::kFloat}, + {ComputationType::kF32, DataType::kFloat, PrimitiveType::BF16, + PrimitiveType::BF16, DataType::kFloat}, + {ComputationType::kF32, DataType::kFloat, PrimitiveType::F16, + PrimitiveType::F16, DataType::kFloat}, + {ComputationType::kF32, DataType::kFloat, PrimitiveType::F32, + PrimitiveType::F32, DataType::kFloat}, + + // There would be an entry here for A/BType complex int8, but we do + // not support that type. + {ComputationType::kF32, DataType::kComplexFloat, PrimitiveType::C64, + PrimitiveType::C64, DataType::kComplexFloat}, + + {ComputationType::kF16AsF32, DataType::kFloat, PrimitiveType::F32, + PrimitiveType::F32, DataType::kFloat}, + {ComputationType::kF16AsF32, DataType::kComplexFloat, + PrimitiveType::C64, PrimitiveType::C64, DataType::kComplexFloat}, + + {ComputationType::kBF16AsF32, DataType::kFloat, PrimitiveType::F32, + PrimitiveType::F32, DataType::kFloat}, + {ComputationType::kBF16AsF32, DataType::kComplexFloat, + PrimitiveType::C64, PrimitiveType::C64, DataType::kComplexFloat}, + + {ComputationType::kTF32AsF32, DataType::kFloat, PrimitiveType::F32, + PrimitiveType::F32, DataType::kFloat}, + {ComputationType::kTF32AsF32, DataType::kComplexFloat, + PrimitiveType::C64, PrimitiveType::C64, DataType::kComplexFloat}, + + {ComputationType::kF64, DataType::kDouble, PrimitiveType::F64, + PrimitiveType::F64, DataType::kDouble}, + {ComputationType::kF64, DataType::kComplexDouble, + PrimitiveType::C128, PrimitiveType::C128, + DataType::kComplexDouble}, + }}; + + return absl::c_linear_search( + supported_type_combinations, + std::make_tuple(compute_type, scale_type, a_dtype, b_dtype, + output_dtype)); } absl::StatusOr TypesAreSupportedByCublasLt( @@ -2348,8 +2426,8 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { } // cublasLt has a defined set of combinations of types that it supports. // Figure out the computeType and scaleType. - TF_ASSIGN_OR_RETURN(const se::blas::DataType output_dtype, - se::gpu::AsBlasDataType(output_type)); + ASSIGN_OR_RETURN(const se::blas::DataType output_dtype, + se::gpu::AsBlasDataType(output_type)); const int max_precision = *absl::c_max_element( backend_config.precision_config().operand_precision()); const PrecisionConfig::Algorithm algorithm = @@ -2359,10 +2437,10 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { return false; } - TF_ASSIGN_OR_RETURN(const se::blas::ComputationType compute_type, - se::gpu::GetBlasComputationType( - algorithm, a_dtype, instr.shape().element_type(), - max_precision, gpu_version_)); + ASSIGN_OR_RETURN(const se::blas::ComputationType compute_type, + se::gpu::GetBlasComputationType( + algorithm, a_dtype, instr.shape().element_type(), + max_precision, gpu_version_)); se::blas::DataType scale_type = se::gpu::GetScaleType(output_dtype, compute_type); @@ -2567,9 +2645,8 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { const GemmBackendConfig& gemm_backend_config) const { const Shape& output_shape = instr.shape(); - TF_ASSIGN_OR_RETURN( - bool types_are_supported_by_cublas_lt, - TypesAreSupportedByCublasLt(instr, gemm_backend_config)); + ASSIGN_OR_RETURN(bool types_are_supported_by_cublas_lt, + TypesAreSupportedByCublasLt(instr, gemm_backend_config)); if (!types_are_supported_by_cublas_lt) { return false; } @@ -2598,7 +2675,7 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { } } - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( GemmConfig gemm_config, GemmConfig::For(&instr, gemm_backend_config, gpu_version_)); @@ -2616,7 +2693,7 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { instr->AddInstruction(instr->CloneWithNewShape(output_f32_shape)); HloInstruction* convert = instr->AddInstruction( HloInstruction::CreateConvert(instr->shape(), f32_dot)); - TF_RETURN_IF_ERROR(ReplaceInstruction(instr, convert)); + RETURN_IF_ERROR(ReplaceInstruction(instr, convert)); return f32_dot; } @@ -2637,7 +2714,7 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { HloInstruction* convert = instr->AddInstruction(HloInstruction::CreateConvert( operand_f16_shape, instr->mutable_operand(i))); - TF_RETURN_IF_ERROR(instr->ReplaceOperandWith(i, convert)); + RETURN_IF_ERROR(instr->ReplaceOperandWith(i, convert)); } // If output is F8, change output to F16 and then convert it back to F8 @@ -2648,7 +2725,7 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { instr->AddInstruction(instr->CloneWithNewShape(output_f16_shape)); HloInstruction* convert_to_f8 = instr->AddInstruction( HloInstruction::CreateConvert(instr->shape(), f16_dot)); - TF_RETURN_IF_ERROR(ReplaceInstruction(instr, convert_to_f8)); + RETURN_IF_ERROR(ReplaceInstruction(instr, convert_to_f8)); return f16_dot; } return instr; @@ -2668,12 +2745,12 @@ class GemmWorkspaceRewriteVisitor : public DfsHloRewriteVisitor { bool has_aux_output = false; if (instr->custom_call_target() == kCublasLtMatmulCallTarget || instr->custom_call_target() == kCublasLtMatmulF8CallTarget) { - TF_ASSIGN_OR_RETURN(const auto gpu_config, - instr->backend_config()); + ASSIGN_OR_RETURN(const auto gpu_config, + instr->backend_config()); const xla::gpu::GemmBackendConfig& config = gpu_config.gemm_backend_config(); xla::gpu::GemmBackendConfig_Epilogue epilogue = config.epilogue(); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( has_aux_output, xla::gpu::gpublas_lt::EpilogueHasAuxiliaryOutput(epilogue)); @@ -2765,7 +2842,7 @@ class GemmWorkspaceRewriteVisitor : public DfsHloRewriteVisitor { HloInstruction* get_output = instr->AddInstruction(HloInstruction::CreateGetTupleElement( new_call, user_get_tuple->tuple_index())); - TF_RETURN_IF_ERROR(ReplaceInstruction(user_get_tuple, get_output)); + RETURN_IF_ERROR(ReplaceInstruction(user_get_tuple, get_output)); } return absl::OkStatus(); } @@ -2783,9 +2860,9 @@ absl::StatusOr RunOnComputation(HloComputation* computation, se::SemanticVersion toolkit_version, GemmRewriterOptions options) { GemmRewriterVisitor visitor(gpu_version, toolkit_version, options); - TF_RETURN_IF_ERROR(computation->Accept(&visitor)); + RETURN_IF_ERROR(computation->Accept(&visitor)); GemmWorkspaceRewriteVisitor workspace_visitor(gpu_version); - TF_RETURN_IF_ERROR(computation->Accept(&workspace_visitor)); + RETURN_IF_ERROR(computation->Accept(&workspace_visitor)); return visitor.changed() || workspace_visitor.changed(); } @@ -2804,9 +2881,8 @@ absl::StatusOr GemmRewriter::RunImpl( bool changed = false; for (HloComputation* computation : GetFusibleComputations(*module, execution_threads)) { - TF_ASSIGN_OR_RETURN(bool result, - RunOnComputation(computation, gpu_version_, - toolkit_version_, options_)); + ASSIGN_OR_RETURN(bool result, RunOnComputation(computation, gpu_version_, + toolkit_version_, options_)); changed |= result; } return changed; diff --git a/third_party/xla/xla/backends/gpu/transforms/hoist_fused_bitcasts.cc b/third_party/xla/xla/backends/gpu/transforms/hoist_fused_bitcasts.cc index 8ee0ba79ab5e17..068bb847451e16 100644 --- a/third_party/xla/xla/backends/gpu/transforms/hoist_fused_bitcasts.cc +++ b/third_party/xla/xla/backends/gpu/transforms/hoist_fused_bitcasts.cc @@ -148,7 +148,7 @@ PlanHoistBitcastUpwardsToCallers(const HloInstruction* bitcast) { // It is possible to support more cases by sinking the bitcast from such // producers downward. HloInstructionSetVector producers = GetProducerSet(bitcast); - TF_RETURN_IF_ERROR(VerifyIsClosedProducerSet(producers, bitcast)); + RETURN_IF_ERROR(VerifyIsClosedProducerSet(producers, bitcast)); if (bitcast->shape().element_type() != bitcast->operand(0)->shape().element_type()) { return absl::UnimplementedError( @@ -185,7 +185,7 @@ PlanHoistBitcastUpwardsToCallers(const HloInstruction* bitcast) { } return absl::OkStatus(); }; - TF_RETURN_IF_ERROR(set_result_shape(bitcast->operands(), bitcast->shape())); + RETURN_IF_ERROR(set_result_shape(bitcast->operands(), bitcast->shape())); std::vector> result; // We want to visit instructions in order from consumers to producers: we @@ -219,20 +219,20 @@ PlanHoistBitcastUpwardsToCallers(const HloInstruction* bitcast) { // update its operand. break; case HloOpcode::kBroadcast: { - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( BitcastParams params, CalculateBitcastOfBroadcast( Cast(instruction), result_shape)); - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( set_result_shape(instruction->operands(), params.new_shape)); break; } case HloOpcode::kTranspose: { - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( BitcastParams params, CalculateBitcastOfTranspose( Cast(instruction), result_shape)); - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( set_result_shape(instruction->operands(), params.new_shape)); break; } @@ -241,7 +241,7 @@ PlanHoistBitcastUpwardsToCallers(const HloInstruction* bitcast) { return absl::FailedPreconditionError(absl::StrCat( "Cannot hoist bitcast past ", instruction->ToString())); } - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( set_result_shape(instruction->operands(), result_shape)); break; } @@ -302,7 +302,7 @@ absl::StatusOr ComputeRootShapeAfterHoistingBitcasts( } return absl::OkStatus(); }; - TF_RETURN_IF_ERROR(set_operand_shape(dot->users(), dot->shape())); + RETURN_IF_ERROR(set_operand_shape(dot->users(), dot->shape())); for (HloInstruction* instruction : GetConsumerSet(dot)) { auto it = operand_shapes.find(instruction); @@ -310,7 +310,7 @@ absl::StatusOr ComputeRootShapeAfterHoistingBitcasts( continue; // Not affected. } Shape& operand_shape = it->second; - TF_ASSIGN_OR_RETURN(Shape result_shape, [&]() -> absl::StatusOr { + ASSIGN_OR_RETURN(Shape result_shape, [&]() -> absl::StatusOr { switch (instruction->opcode()) { case HloOpcode::kBroadcast: { auto paramsOr = CalculateBroadcastOfBitcast( @@ -348,7 +348,7 @@ absl::StatusOr ComputeRootShapeAfterHoistingBitcasts( CopyElementType(instruction->shape(), &result_shape); return result_shape; } - TF_RETURN_IF_ERROR(set_operand_shape(instruction->users(), result_shape)); + RETURN_IF_ERROR(set_operand_shape(instruction->users(), result_shape)); } return absl::InternalError("No root found"); } @@ -357,8 +357,8 @@ absl::StatusOr ComputeRootShapeAfterHoistingBitcasts( // each caller. absl::Status HoistBitcastUpwardsToCallers(HloInstruction* bitcast, absl::Span callers) { - TF_ASSIGN_OR_RETURN(auto rewrite_plan, - PlanHoistBitcastUpwardsToCallers(bitcast)); + ASSIGN_OR_RETURN(auto rewrite_plan, + PlanHoistBitcastUpwardsToCallers(bitcast)); for (auto [instruction, result_shape] : rewrite_plan) { VLOG(2) << absl::StrCat("rewriting result shape of ", instruction->ToString(), " to ", @@ -373,7 +373,7 @@ absl::Status HoistBitcastUpwardsToCallers(HloInstruction* bitcast, HloInstruction* new_bitcast = caller->AddInstruction(HloInstruction::CreateBitcast( result_shape, caller->mutable_operand(number))); - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( caller->ReplaceOperandWithDifferentShape(number, new_bitcast)); } break; @@ -404,8 +404,8 @@ absl::Status HoistBitcastUpwardsToCallers(HloInstruction* bitcast, // HloVerifier error. instruction->clear_sharding(); } - TF_RETURN_IF_ERROR(bitcast->ReplaceAllUsesWith(bitcast->mutable_operand(0))); - TF_RETURN_IF_ERROR(bitcast->parent()->RemoveInstruction(bitcast)); + RETURN_IF_ERROR(bitcast->ReplaceAllUsesWith(bitcast->mutable_operand(0))); + RETURN_IF_ERROR(bitcast->parent()->RemoveInstruction(bitcast)); return absl::OkStatus(); } @@ -415,8 +415,8 @@ absl::Status HoistBitcastUpwardsToCallers(HloInstruction* bitcast, // root shape. absl::StatusOr MaybeInsertRootBitcast( HloInstruction* dot, absl::Span callers) { - TF_ASSIGN_OR_RETURN(Shape root_shape, - ComputeRootShapeAfterHoistingBitcasts(dot)); + ASSIGN_OR_RETURN(Shape root_shape, + ComputeRootShapeAfterHoistingBitcasts(dot)); HloComputation* computation = dot->parent(); HloInstruction* root = computation->root_instruction(); @@ -432,7 +432,7 @@ absl::StatusOr MaybeInsertRootBitcast( for (HloInstruction* caller : callers) { HloInstruction* new_bitcast = caller->AddInstruction( HloInstruction::CreateBitcast(caller->shape(), caller)); - TF_RETURN_IF_ERROR(caller->ReplaceAllUsesWith(new_bitcast)); + RETURN_IF_ERROR(caller->ReplaceAllUsesWith(new_bitcast)); *caller->mutable_shape() = root_shape; } @@ -547,7 +547,7 @@ absl::StatusOr HoistFusedBitcasts::RunOnModule( for (HloComputation* computation : module->MakeNonfusionComputations(execution_threads)) { HoistFusedBitcastsVisitor visitor(call_graph.get()); - TF_RETURN_IF_ERROR(computation->Accept(&visitor)); + RETURN_IF_ERROR(computation->Accept(&visitor)); changed |= visitor.changed(); } return changed; diff --git a/third_party/xla/xla/backends/gpu/transforms/layout_assignment.cc b/third_party/xla/xla/backends/gpu/transforms/layout_assignment.cc index 11856e3d6aac40..3209392949c3e7 100644 --- a/third_party/xla/xla/backends/gpu/transforms/layout_assignment.cc +++ b/third_party/xla/xla/backends/gpu/transforms/layout_assignment.cc @@ -34,6 +34,7 @@ limitations under the License. #include "absl/log/log.h" #include "absl/status/status.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -214,7 +215,7 @@ absl::Status GpuLayoutAssignment::AddBackendConstraintsToDnnConvCustomCall( Shape* filter_shape; Shape* output_shape; - TF_ASSIGN_OR_RETURN(auto kind, GetCudnnConvKind(instr)); + ASSIGN_OR_RETURN(auto kind, GetCudnnConvKind(instr)); switch (kind) { case CudnnConvKind::kForward: case CudnnConvKind::kForwardActivation: @@ -242,7 +243,7 @@ absl::Status GpuLayoutAssignment::AddBackendConstraintsToDnnConvCustomCall( std::tie(input, filter, output) = HeuristicLayoutAssignment(instr, gpu_version_); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( std::tie(*input_shape->mutable_layout(), *filter_shape->mutable_layout(), *output_shape->mutable_layout()), @@ -253,21 +254,21 @@ absl::Status GpuLayoutAssignment::AddBackendConstraintsToDnnConvCustomCall( // The custom call returns a tuple of (actual_result, scratch_buffer); // call_result_buf is the logical buffer for actual_result, the thing that // contains the result of the conv call. - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( const LogicalBuffer* call_result_buf, points_to_analysis_->GetBufferDefinedAt(instr, /*index=*/{0})); // Set layouts of the instructions' shapes. - TF_RETURN_IF_ERROR(SetOperandLayout(lhs_shape, instr, 0)); - TF_RETURN_IF_ERROR(SetOperandLayout(rhs_shape, instr, 1)); - TF_RETURN_IF_ERROR(SetBufferLayout(result_shape.layout(), *call_result_buf)); + RETURN_IF_ERROR(SetOperandLayout(lhs_shape, instr, 0)); + RETURN_IF_ERROR(SetOperandLayout(rhs_shape, instr, 1)); + RETURN_IF_ERROR(SetBufferLayout(result_shape.layout(), *call_result_buf)); // For fused convolutions, instr->operand(2), if exists, is the bias buffer. // There is no need to assign layout to it, as it has only one dimension. // instr->operand(3), if exists, is the side input buffer. if (kind == CudnnConvKind::kForwardActivation && instr->operand_count() == 4) { // The side input layout must match the output layout. - TF_RETURN_IF_ERROR(SetOperandLayout(*output_shape, instr, 3)); + RETURN_IF_ERROR(SetOperandLayout(*output_shape, instr, 3)); } // For graph convolutions, align the layouts of the non-scalar inputs to any @@ -275,7 +276,7 @@ absl::Status GpuLayoutAssignment::AddBackendConstraintsToDnnConvCustomCall( if (kind == CudnnConvKind::kForwardGraph) { for (int k = 2; k < instr->operand_count(); ++k) { if (!ShapeUtil::IsScalar(instr->operand(k)->shape())) { - TF_RETURN_IF_ERROR(SetOperandLayout(*output_shape, instr, k)); + RETURN_IF_ERROR(SetOperandLayout(*output_shape, instr, k)); } } } @@ -304,14 +305,14 @@ absl::Status GpuLayoutAssignment::AddBackendConstraintsToConvolution( std::tie(input, filter, output) = HeuristicLayoutAssignment(conv, gpu_version_); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( std::tie(*input_shape.mutable_layout(), *filter_shape.mutable_layout(), *output_shape.mutable_layout()), StreamExecutorConvLayoutsToXlaLayouts(dnums, input, filter, output)); - TF_RETURN_IF_ERROR(SetOperandLayout(input_shape, conv, 0)); - TF_RETURN_IF_ERROR(SetOperandLayout(filter_shape, conv, 1)); - TF_RETURN_IF_ERROR(SetInstructionLayout(output_shape, conv)); + RETURN_IF_ERROR(SetOperandLayout(input_shape, conv, 0)); + RETURN_IF_ERROR(SetOperandLayout(filter_shape, conv, 1)); + RETURN_IF_ERROR(SetInstructionLayout(output_shape, conv)); return absl::OkStatus(); } @@ -439,19 +440,19 @@ absl::Status GpuLayoutAssignment::AddDotBackendConstraints( Side side = {operand_no, instruction->operand(operand_no), batch_dims, contracting_dims}; side.type = side.operand->shape().element_type(); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( side.non_contracting_dims, GetNonContractingDims(side.operand->shape(), side.batch_dims, side.contracting_dims)); return side; }; const DotDimensionNumbers& dot_dims = instruction->dot_dimension_numbers(); - TF_ASSIGN_OR_RETURN(const Side lhs, - make_side(0, dot_dims.lhs_batch_dimensions(), - dot_dims.lhs_contracting_dimensions())); - TF_ASSIGN_OR_RETURN(const Side rhs, - make_side(1, dot_dims.rhs_batch_dimensions(), - dot_dims.rhs_contracting_dimensions())); + ASSIGN_OR_RETURN(const Side lhs, + make_side(0, dot_dims.lhs_batch_dimensions(), + dot_dims.lhs_contracting_dimensions())); + ASSIGN_OR_RETURN(const Side rhs, + make_side(1, dot_dims.rhs_batch_dimensions(), + dot_dims.rhs_contracting_dimensions())); const PrimitiveType& output_type = instruction->shape().element_type(); @@ -481,17 +482,17 @@ absl::Status GpuLayoutAssignment::AddDotBackendConstraints( for (const Side& side : {lhs, rhs}) { if ((IsPackedInstruction(side.operand) && pack_along_contracting_dims) || both_operands_require_minor_contraction_dims) { - TF_RETURN_IF_ERROR(SetDotOperandLayoutToMinorContracting( + RETURN_IF_ERROR(SetDotOperandLayoutToMinorContracting( instruction, side.operand_no, side.batch_dims, side.contracting_dims, side.non_contracting_dims)); } else if (!side.batch_dims.empty() || side.contracting_dims.size() > 1 || side.non_contracting_dims.size() > 1) { - TF_RETURN_IF_ERROR(SetDotOperandLayout( + RETURN_IF_ERROR(SetDotOperandLayout( instruction, side.operand_no, side.batch_dims, side.contracting_dims, side.non_contracting_dims)); } else if (ChainEndsWithAutoLayout(side.operand, saved_entry_computation_layout())) { - TF_RETURN_IF_ERROR(SetDotOperandLayout( + RETURN_IF_ERROR(SetDotOperandLayout( instruction, side.operand_no, side.batch_dims, side.contracting_dims, side.non_contracting_dims, /*mandatory=*/false)); } @@ -502,7 +503,7 @@ absl::Status GpuLayoutAssignment::AddDotBackendConstraints( // the dot output. if (!lhs.batch_dims.empty() || lhs.non_contracting_dims.size() > 1 || rhs.non_contracting_dims.size() > 1) { - TF_RETURN_IF_ERROR(SetDotLayout(instruction, constraints)); + RETURN_IF_ERROR(SetDotLayout(instruction, constraints)); } return absl::OkStatus(); @@ -518,7 +519,7 @@ absl::Status GpuLayoutAssignment::AddBackendConstraints( ++iterator) { HloInstruction* instruction = *iterator; if (IsCustomCallToDnnConvolution(*instruction)) { - TF_RETURN_IF_ERROR(AddBackendConstraintsToDnnConvCustomCall( + RETURN_IF_ERROR(AddBackendConstraintsToDnnConvCustomCall( Cast(instruction), constraints)); } @@ -533,10 +534,10 @@ absl::Status GpuLayoutAssignment::AddBackendConstraints( DynCast(instruction)->convolution_kind() != CONVOLUTION_KIND_UNSET) << "conv-kind-assignment pass should run before this pass."; - TF_RETURN_IF_ERROR(AddBackendConstraintsToConvolution( + RETURN_IF_ERROR(AddBackendConstraintsToConvolution( Cast(instruction), constraints)); } else if (HloPredicateIsOp(instruction)) { - TF_RETURN_IF_ERROR(AddDotBackendConstraints( + RETURN_IF_ERROR(AddDotBackendConstraints( constraints, Cast(instruction))); } else if (HloPredicateIsOp(instruction)) { Shape op0_shape = instruction->operand(0)->shape(); @@ -545,9 +546,9 @@ absl::Status GpuLayoutAssignment::AddBackendConstraints( LayoutUtil::SetToDefaultLayout(&op1_shape); Shape output_shape = instruction->shape(); LayoutUtil::SetToDefaultLayout(&output_shape); - TF_RETURN_IF_ERROR(SetOperandLayout(op0_shape, instruction, 0)); - TF_RETURN_IF_ERROR(SetOperandLayout(op1_shape, instruction, 1)); - TF_RETURN_IF_ERROR(SetInstructionLayout(output_shape, instruction)); + RETURN_IF_ERROR(SetOperandLayout(op0_shape, instruction, 0)); + RETURN_IF_ERROR(SetOperandLayout(op1_shape, instruction, 1)); + RETURN_IF_ERROR(SetInstructionLayout(output_shape, instruction)); } else if (HloPredicateIsOp(instruction)) { const HloInstruction* operand = instruction->operand(0); if ((HloPredicateIsNotOp(operand)) || @@ -562,8 +563,7 @@ absl::Status GpuLayoutAssignment::AddBackendConstraints( LayoutUtil::MakeLayoutFromMajorToMinor(instruction->dimensions()); if (DotCanSupportShapeWithLayout(operand, shape)) { - TF_RETURN_IF_ERROR( - SetOperandLayout(shape, instruction, /*operand_no=*/0)); + RETURN_IF_ERROR(SetOperandLayout(shape, instruction, /*operand_no=*/0)); } } else if (HloPredicateIsOp(instruction)) { // cuFFT requires a dim0 major layout. @@ -571,8 +571,8 @@ absl::Status GpuLayoutAssignment::AddBackendConstraints( LayoutUtil::SetToDefaultLayout(&op0_shape); Shape output_shape = instruction->shape(); LayoutUtil::SetToDefaultLayout(&output_shape); - TF_RETURN_IF_ERROR(SetOperandLayout(op0_shape, instruction, 0)); - TF_RETURN_IF_ERROR(SetInstructionLayout(output_shape, instruction)); + RETURN_IF_ERROR(SetOperandLayout(op0_shape, instruction, 0)); + RETURN_IF_ERROR(SetInstructionLayout(output_shape, instruction)); } else if ((HloPredicateIsOp(instruction) || IsCubDeviceRadixSortNoScratchSize(*instruction)) && instruction->operand(0)->shape().dimensions().size() > 1) { @@ -583,31 +583,31 @@ absl::Status GpuLayoutAssignment::AddBackendConstraints( for (int64_t i = 0; i < instruction->operand_count(); ++i) { Shape shape = instruction->operand(i)->shape(); *shape.mutable_layout() = keys_layout; - TF_RETURN_IF_ERROR(SetOperandLayout(shape, instruction, i)); + RETURN_IF_ERROR(SetOperandLayout(shape, instruction, i)); const LogicalBuffer* output_buffer; if (instruction->shape().IsArray()) { - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( output_buffer, points_to_analysis_->GetBufferDefinedAt(instruction, {})); } else { - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( output_buffer, points_to_analysis_->GetBufferDefinedAt(instruction, {i})); } - TF_RETURN_IF_ERROR(SetBufferLayout(keys_layout, *output_buffer)); + RETURN_IF_ERROR(SetBufferLayout(keys_layout, *output_buffer)); } } else if (IsCustomCallToTopK(*instruction)) { // The output of the TopK custom call needs to have default layout. Layout default_layout = LayoutUtil::GetDefaultLayoutForRank( instruction->operand(0)->shape().dimensions().size()); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( auto values_buffer, points_to_analysis_->GetBufferDefinedAt(instruction, {0})); - TF_RETURN_IF_ERROR(SetBufferLayout(default_layout, *values_buffer)); - TF_ASSIGN_OR_RETURN( + RETURN_IF_ERROR(SetBufferLayout(default_layout, *values_buffer)); + ASSIGN_OR_RETURN( auto indices_buffer, points_to_analysis_->GetBufferDefinedAt(instruction, {1})); - TF_RETURN_IF_ERROR(SetBufferLayout(default_layout, *indices_buffer)); + RETURN_IF_ERROR(SetBufferLayout(default_layout, *indices_buffer)); } else if (HloPredicateIsOp(instruction)) { Shape operand_shape = instruction->operand(0)->shape(); Shape output_shape = instruction->shape(); @@ -636,11 +636,11 @@ absl::Status GpuLayoutAssignment::AddBackendConstraints( } if (ranks_differ) { - TF_RETURN_IF_ERROR(SetOperandLayout(operand_shape, instruction, - /*operand_no=*/0, - /*mandatory=*/true)); - TF_RETURN_IF_ERROR(SetInstructionLayout(output_shape, instruction, - /*mandatory=*/true)); + RETURN_IF_ERROR(SetOperandLayout(operand_shape, instruction, + /*operand_no=*/0, + /*mandatory=*/true)); + RETURN_IF_ERROR(SetInstructionLayout(output_shape, instruction, + /*mandatory=*/true)); } } else if (HloPredicateIsOp(instruction)) { // TODO(phawkins): Ideally we would relax this constraint. What we @@ -655,21 +655,21 @@ absl::Status GpuLayoutAssignment::AddBackendConstraints( SetFortranLayout(&op0_shape); SetFortranLayout(&op1_shape); SetFortranLayout(&output_shape); - TF_RETURN_IF_ERROR(SetOperandLayout(op0_shape, instruction, 0)); - TF_RETURN_IF_ERROR(SetOperandLayout(op1_shape, instruction, 1)); - TF_RETURN_IF_ERROR(SetInstructionLayout(output_shape, instruction)); + RETURN_IF_ERROR(SetOperandLayout(op0_shape, instruction, 0)); + RETURN_IF_ERROR(SetOperandLayout(op1_shape, instruction, 1)); + RETURN_IF_ERROR(SetInstructionLayout(output_shape, instruction)); } else if (HloPredicateIsOp(instruction)) { // XLA:GPU can only support reduce-scatter where the scatter dimension // is the most major dimension in the layout. auto ars = Cast(instruction); - TF_RETURN_IF_ERROR(SetInstructionLayout( + RETURN_IF_ERROR(SetInstructionLayout( ShapeUtil::MoveDimToMajor(ars->shape(), ars->scatter_dimension()), ars)); } else if (HloPredicateIsOp(instruction)) { // XLA:GPU can only support all-gathers where the gather dimension is the // most major dimension in the layout. auto ag = Cast(instruction); - TF_RETURN_IF_ERROR(SetInstructionLayout( + RETURN_IF_ERROR(SetInstructionLayout( ShapeUtil::MoveDimToMajor(ag->shape(), ag->all_gather_dimension()), ag)); } else if (HloPredicateIsOp(instruction) && @@ -677,7 +677,7 @@ absl::Status GpuLayoutAssignment::AddBackendConstraints( // XLA:GPU can only support all-to-all with split dimensions where the // split dimension is the most major dimension in the layout. auto* all_to_all = Cast(instruction); - TF_RETURN_IF_ERROR(SetInstructionLayout( + RETURN_IF_ERROR(SetInstructionLayout( ShapeUtil::MoveDimToMajor(all_to_all->shape(), *all_to_all->split_dimension()), all_to_all)); @@ -685,14 +685,14 @@ absl::Status GpuLayoutAssignment::AddBackendConstraints( auto* ragged_all_to_all = Cast(instruction); // XLA:GPU can only support ragged-all-to-all with the most major ragged // dimension in the layout. - TF_RETURN_IF_ERROR(SetInstructionLayout( + RETURN_IF_ERROR(SetInstructionLayout( ShapeUtil::MoveDimToMajor(ragged_all_to_all->shape(), 0), ragged_all_to_all)); } else if (HloPredicateIsOp(instruction)) { Shape s = instruction->operand(0)->shape(); LayoutUtil::SetToDefaultLayout(&s); - TF_RETURN_IF_ERROR(SetInstructionLayout(s, instruction->operand(0))); - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR(SetInstructionLayout(s, instruction->operand(0))); + RETURN_IF_ERROR( SetArrayOperandLayout(s.layout(), instruction->operand(0), 0)); } else if (HloPredicateIsOp(instruction)) { Shape s = instruction->shape(); @@ -700,14 +700,14 @@ absl::Status GpuLayoutAssignment::AddBackendConstraints( &s, [&](Shape* subshape, const ShapeIndex& index) { LayoutUtil::SetToDefaultLayout(subshape); }); - TF_RETURN_IF_ERROR(SetInstructionLayout(s, instruction)); + RETURN_IF_ERROR(SetInstructionLayout(s, instruction)); } else if (IsCustomCallToMemoryPlacement(instruction)) { // Make sure that host memory buffers use the default layout so that // the compiler does not insert transposes on host memory buffers. Shape operand_shape = instruction->operand(0)->shape(); LayoutUtil::SetToDefaultLayout(&operand_shape); - TF_RETURN_IF_ERROR(SetOperandLayout(operand_shape, instruction, 0)); - TF_RETURN_IF_ERROR(SetInstructionLayout(operand_shape, instruction)); + RETURN_IF_ERROR(SetOperandLayout(operand_shape, instruction, 0)); + RETURN_IF_ERROR(SetInstructionLayout(operand_shape, instruction)); } else if (instruction->opcode() == HloOpcode::kAsyncStart) { HloComputation* called_computation = instruction->async_wrapped_computation(); @@ -722,9 +722,9 @@ absl::Status GpuLayoutAssignment::AddBackendConstraints( called_computation->ComputeProgramShape().parameters()); *new_shape.mutable_tuple_shapes(1) = called_computation->ComputeProgramShape().result(); - TF_RETURN_IF_ERROR(SetInstructionLayout(new_shape, instruction, - /*mandatory=*/true, /*dfs=*/true, - /*allow_alias=*/true)); + RETURN_IF_ERROR(SetInstructionLayout(new_shape, instruction, + /*mandatory=*/true, /*dfs=*/true, + /*allow_alias=*/true)); } else if (instruction->opcode() == HloOpcode::kAsyncDone) { HloComputation* called_computation = instruction->async_wrapped_computation(); @@ -736,9 +736,9 @@ absl::Status GpuLayoutAssignment::AddBackendConstraints( Shape new_shape = called_computation->root_instruction()->shape(); - TF_RETURN_IF_ERROR(SetInstructionLayout(new_shape, instruction, - /*mandatory=*/true, /*dfs=*/true, - /*allow_alias=*/true)); + RETURN_IF_ERROR(SetInstructionLayout(new_shape, instruction, + /*mandatory=*/true, /*dfs=*/true, + /*allow_alias=*/true)); } } return absl::OkStatus(); diff --git a/third_party/xla/xla/backends/gpu/transforms/move_copy_to_users.cc b/third_party/xla/xla/backends/gpu/transforms/move_copy_to_users.cc index f7b98fc94a89b1..f50baa023e6618 100644 --- a/third_party/xla/xla/backends/gpu/transforms/move_copy_to_users.cc +++ b/third_party/xla/xla/backends/gpu/transforms/move_copy_to_users.cc @@ -23,6 +23,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" @@ -42,14 +43,14 @@ class MoveCopyToUsersVisitor : public DfsHloRewriteVisitor { HloInstruction* c = hlo->mutable_operand(1); if (HloPredicateIsOp(operand)) { HloInstruction* copied = operand->mutable_operand(0); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( HloInstruction * earlier_pad, MakePadHlo(copied, c, hlo->padding_config(), &hlo->metadata())); // MakePadHlo fails to propagate layout. *earlier_pad->mutable_shape()->mutable_layout() = copied->shape().layout(); HloInstruction* later_copy = MakeCopyHlo(earlier_pad, hlo->shape()); - TF_RETURN_IF_ERROR(ReplaceInstruction(hlo, later_copy)); + RETURN_IF_ERROR(ReplaceInstruction(hlo, later_copy)); } return absl::OkStatus(); } @@ -59,14 +60,14 @@ class MoveCopyToUsersVisitor : public DfsHloRewriteVisitor { HloInstruction* operand = hlo->mutable_operand(0); if (HloPredicateIsOp(operand)) { HloInstruction* copied = operand->mutable_operand(0); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( HloInstruction * earlier_slice, MakeSliceHlo(copied, hlo->slice_starts(), hlo->slice_limits(), hlo->slice_strides(), &hlo->metadata())); *earlier_slice->mutable_shape()->mutable_layout() = copied->shape().layout(); HloInstruction* later_copy = MakeCopyHlo(earlier_slice, hlo->shape()); - TF_RETURN_IF_ERROR(ReplaceInstruction(hlo, later_copy)); + RETURN_IF_ERROR(ReplaceInstruction(hlo, later_copy)); } return absl::OkStatus(); } @@ -77,7 +78,7 @@ class MoveCopyToUsersVisitor : public DfsHloRewriteVisitor { HloInstruction* operand = hlo->mutable_operand(0); if (HloPredicateIsOp(operand)) { HloInstruction* copied = operand->mutable_operand(0); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( HloInstruction * earlier_slice, MakeDynamicSliceHlo( copied, @@ -86,7 +87,7 @@ class MoveCopyToUsersVisitor : public DfsHloRewriteVisitor { *earlier_slice->mutable_shape()->mutable_layout() = copied->shape().layout(); HloInstruction* later_copy = MakeCopyHlo(earlier_slice, hlo->shape()); - TF_RETURN_IF_ERROR(ReplaceInstruction(hlo, later_copy)); + RETURN_IF_ERROR(ReplaceInstruction(hlo, later_copy)); } return absl::OkStatus(); } @@ -102,7 +103,7 @@ class MoveCopyToUsersVisitor : public DfsHloRewriteVisitor { HloInstruction* operand = hlo->mutable_operand(0); if (HloPredicateIsOp(operand)) { HloInstruction* copied = operand->mutable_operand(0); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( HloInstruction * earlier_reduce_window, MakeReduceWindowHlo(copied, hlo->mutable_operand(1), hlo->window(), hlo->called_computations()[0], &hlo->metadata())); @@ -110,7 +111,7 @@ class MoveCopyToUsersVisitor : public DfsHloRewriteVisitor { copied->shape().layout(); HloInstruction* later_copy = MakeCopyHlo(earlier_reduce_window, hlo->shape()); - TF_RETURN_IF_ERROR(ReplaceInstruction(hlo, later_copy)); + RETURN_IF_ERROR(ReplaceInstruction(hlo, later_copy)); } return absl::OkStatus(); } @@ -123,7 +124,7 @@ class MoveCopyToUsersVisitor : public DfsHloRewriteVisitor { HloInstruction* new_reduce = hlo->AddInstruction( hlo->CloneWithNewOperands(hlo->shape(), {operand->mutable_operand(0), hlo->mutable_operand(1)})); - TF_RETURN_IF_ERROR(ReplaceInstruction(hlo, new_reduce)); + RETURN_IF_ERROR(ReplaceInstruction(hlo, new_reduce)); } return absl::OkStatus(); } @@ -141,12 +142,11 @@ class MoveCopyToUsersVisitor : public DfsHloRewriteVisitor { } if (HloPredicateIsOp(operand)) { HloInstruction* copied = operand->mutable_operand(0); - TF_ASSIGN_OR_RETURN( - HloInstruction * earlier_elementwise, - MakeUnaryHlo(hlo->opcode(), copied, &hlo->metadata())); + ASSIGN_OR_RETURN(HloInstruction * earlier_elementwise, + MakeUnaryHlo(hlo->opcode(), copied, &hlo->metadata())); HloInstruction* later_copy = MakeCopyHlo(earlier_elementwise, hlo->shape()); - TF_RETURN_IF_ERROR(ReplaceInstruction(hlo, later_copy)); + RETURN_IF_ERROR(ReplaceInstruction(hlo, later_copy)); } return absl::OkStatus(); } @@ -156,11 +156,11 @@ class MoveCopyToUsersVisitor : public DfsHloRewriteVisitor { HloInstruction* operand = hlo->mutable_operand(0); if (HloPredicateIsOp(operand)) { HloInstruction* copied = operand->mutable_operand(0); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( HloInstruction * earlier_reverse, MakeReverseHlo(copied, hlo->dimensions(), &hlo->metadata())); HloInstruction* later_copy = MakeCopyHlo(earlier_reverse, hlo->shape()); - TF_RETURN_IF_ERROR(ReplaceInstruction(hlo, later_copy)); + RETURN_IF_ERROR(ReplaceInstruction(hlo, later_copy)); } return absl::OkStatus(); } @@ -179,7 +179,7 @@ class MoveCopyToUsersVisitor : public DfsHloRewriteVisitor { HloInstruction* earlier_convert = MakeConvertToHlo( copied, hlo->shape().element_type(), &hlo->metadata()); HloInstruction* later_copy = MakeCopyHlo(earlier_convert, hlo->shape()); - TF_RETURN_IF_ERROR(ReplaceInstruction(hlo, later_copy)); + RETURN_IF_ERROR(ReplaceInstruction(hlo, later_copy)); } return absl::OkStatus(); } @@ -195,18 +195,17 @@ class MoveCopyToUsersVisitor : public DfsHloRewriteVisitor { if (copied_a->shape() == copied_b->shape()) { HloInstruction* earlier_elementwise; if (HloPredicateIsOp(hlo)) { - TF_ASSIGN_OR_RETURN( - earlier_elementwise, - MakeCompareHlo(hlo->comparison_direction(), copied_a, copied_b, - &hlo->metadata())); + ASSIGN_OR_RETURN(earlier_elementwise, + MakeCompareHlo(hlo->comparison_direction(), copied_a, + copied_b, &hlo->metadata())); } else { - TF_ASSIGN_OR_RETURN(earlier_elementwise, - MakeBinaryHlo(hlo->opcode(), copied_a, copied_b, - &hlo->metadata())); + ASSIGN_OR_RETURN(earlier_elementwise, + MakeBinaryHlo(hlo->opcode(), copied_a, copied_b, + &hlo->metadata())); } HloInstruction* later_copy = MakeCopyHlo(earlier_elementwise, hlo->shape()); - TF_RETURN_IF_ERROR(ReplaceInstruction(hlo, later_copy)); + RETURN_IF_ERROR(ReplaceInstruction(hlo, later_copy)); } } return absl::OkStatus(); @@ -233,13 +232,12 @@ class MoveCopyToUsersVisitor : public DfsHloRewriteVisitor { new_operands.push_back(op->mutable_operand(0)); } - TF_ASSIGN_OR_RETURN( - HloInstruction * new_concat, - MakeConcatHlo(new_operands, hlo->concatenate_dimension())); + ASSIGN_OR_RETURN(HloInstruction * new_concat, + MakeConcatHlo(new_operands, hlo->concatenate_dimension())); *new_concat->mutable_shape()->mutable_layout() = inner_op_layout; HloInstruction* new_copy = MakeCopyHlo(new_concat, hlo->shape()); - TF_RETURN_IF_ERROR(ReplaceInstruction(hlo, new_copy)); + RETURN_IF_ERROR(ReplaceInstruction(hlo, new_copy)); return absl::OkStatus(); } }; diff --git a/third_party/xla/xla/backends/gpu/transforms/multi_output_fusion.cc b/third_party/xla/xla/backends/gpu/transforms/multi_output_fusion.cc index ecb7cc68f5c96f..cca30f594cdd65 100644 --- a/third_party/xla/xla/backends/gpu/transforms/multi_output_fusion.cc +++ b/third_party/xla/xla/backends/gpu/transforms/multi_output_fusion.cc @@ -29,6 +29,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/strings/string_view.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/debug_options_flags.h" #include "xla/hlo/analysis/hlo_dfs_reachability.h" #include "xla/hlo/ir/hlo_casting_utils.h" @@ -430,7 +431,7 @@ absl::StatusOr MultiOutputFusion::DoMultiOutputFusion() { /*min_latencies_seconds=*/{}, /*count_multiple_input_accesses=*/true}, device_info_); - TF_RETURN_IF_ERROR(computation_->Accept(&cost_analysis)); + RETURN_IF_ERROR(computation_->Accept(&cost_analysis)); std::vector defs_before_uses = computation_->MakeInstructionPostOrder(); @@ -488,8 +489,8 @@ absl::StatusOr MultiOutputFusion::DoMultiOutputFusion() { gpu_performance_model_cache.Invalidate(*consumer_for_fusion); fusion_analysis_cache.Invalidate(producer->unique_id()); fusion_analysis_cache.Invalidate(consumer_for_fusion->unique_id()); - TF_RETURN_IF_ERROR(cost_analysis.RemoveInstruction(producer)); - TF_RETURN_IF_ERROR(cost_analysis.RemoveInstruction(consumer_for_fusion)); + RETURN_IF_ERROR(cost_analysis.RemoveInstruction(producer)); + RETURN_IF_ERROR(cost_analysis.RemoveInstruction(consumer_for_fusion)); HloInstruction* input_fusion; if (HloPredicateIsOp(consumer_for_fusion)) { @@ -521,7 +522,7 @@ absl::StatusOr MultiOutputFusion::DoMultiOutputFusion() { CHECK_EQ(0, producer->user_count()); CHECK_OK(computation_->RemoveInstruction(producer)); } - TF_RETURN_IF_ERROR(cost_analysis.RevisitInstruction(input_fusion)); + RETURN_IF_ERROR(cost_analysis.RevisitInstruction(input_fusion)); DumpFusionState(*input_fusion, absl::StrCat("Fused into |", input_fusion->name(), @@ -548,7 +549,7 @@ absl::StatusOr MultiOutputFusion::RunImpl( bool changed = false; for (auto* computation : GetFusibleComputations(*module, execution_threads)) { computation_ = computation; - TF_ASSIGN_OR_RETURN(bool computation_changed, DoMultiOutputFusion()); + ASSIGN_OR_RETURN(bool computation_changed, DoMultiOutputFusion()); changed |= computation_changed; } return changed; diff --git a/third_party/xla/xla/backends/gpu/transforms/pgle_accuracy_checker.cc b/third_party/xla/xla/backends/gpu/transforms/pgle_accuracy_checker.cc index 652b2fcd07cb18..9aae21b7d368c3 100644 --- a/third_party/xla/xla/backends/gpu/transforms/pgle_accuracy_checker.cc +++ b/third_party/xla/xla/backends/gpu/transforms/pgle_accuracy_checker.cc @@ -18,6 +18,7 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/tsl/platform/errors.h" @@ -26,7 +27,7 @@ namespace xla::gpu { absl::StatusOr PGLEAccuracyChecker::RunImpl( HloModule* module, const absl::flat_hash_set& execution_threads) { - TF_RETURN_IF_ERROR(pgle_estimator_.CheckAccuracy(*module)); + RETURN_IF_ERROR(pgle_estimator_.CheckAccuracy(*module)); return false; } diff --git a/third_party/xla/xla/backends/gpu/transforms/priority_fusion.cc b/third_party/xla/xla/backends/gpu/transforms/priority_fusion.cc index 26db4db6182011..4b0ee9cd1b7e32 100644 --- a/third_party/xla/xla/backends/gpu/transforms/priority_fusion.cc +++ b/third_party/xla/xla/backends/gpu/transforms/priority_fusion.cc @@ -39,6 +39,7 @@ limitations under the License. #include "absl/synchronization/mutex.h" #include "absl/time/time.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "llvm/ADT/STLExtras.h" #include "mlir/IR/MLIRContext.h" #include "xla/backends/gpu/codegen/triton/support.h" @@ -309,7 +310,7 @@ class PriorityFusionQueue { EstimateRunTimeData runtime_data; if (IsGenericTritonFusion(*producer)) { - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( runtime_data, gpu_indexing_performance_model_.EstimateRunTimeForTriton(producer)); } else { @@ -327,10 +328,10 @@ class PriorityFusionQueue { // Revisit costs of all updated ops. It's important to update cost analysis // before recalculating priorities. for (auto instruction : to_update_priority_) { - TF_RETURN_IF_ERROR(cost_analysis_.RevisitInstruction(instruction)); + RETURN_IF_ERROR(cost_analysis_.RevisitInstruction(instruction)); } for (auto producer : to_update_priority_) { - TF_RETURN_IF_ERROR(UpdatePerformanceModelCache(producer)); + RETURN_IF_ERROR(UpdatePerformanceModelCache(producer)); } ComputeAndSetPriorities(std::vector{ @@ -1203,7 +1204,7 @@ absl::StatusOr PriorityFusion::RunImpl( Fuse(producer, consumer, use_multi_output_fusion); auto backend_config_it = block_level_parameters_map.find(consumer); if (backend_config_it != block_level_parameters_map.end()) { - TF_RETURN_IF_ERROR(fusion_instruction->set_backend_config( + RETURN_IF_ERROR(fusion_instruction->set_backend_config( GetTritonGpuBackendConfig(backend_config_it->second))); fusion_instruction->set_fusion_kind( HloInstruction::FusionKind::kCustom); @@ -1224,7 +1225,7 @@ absl::StatusOr PriorityFusion::RunImpl( // have been removed already. if (!use_multi_output_fusion) { producer->DetachFromOperandsAndUsers(); - TF_RETURN_IF_ERROR(computation->RemoveInstruction(producer)); + RETURN_IF_ERROR(computation->RemoveInstruction(producer)); } } @@ -1234,7 +1235,7 @@ absl::StatusOr PriorityFusion::RunImpl( for (auto consumer_id : pre_fusion_consumer_ids) { fusion_analysis_cache_.Invalidate(consumer_id); } - TF_RETURN_IF_ERROR(fusion_queue->UpdatePriorities()); + RETURN_IF_ERROR(fusion_queue->UpdatePriorities()); } // Fuse all constants. diff --git a/third_party/xla/xla/backends/gpu/transforms/ragged_all_to_all_canonicalizer.cc b/third_party/xla/xla/backends/gpu/transforms/ragged_all_to_all_canonicalizer.cc index 3946ad8898ac14..40800daf2cf099 100644 --- a/third_party/xla/xla/backends/gpu/transforms/ragged_all_to_all_canonicalizer.cc +++ b/third_party/xla/xla/backends/gpu/transforms/ragged_all_to_all_canonicalizer.cc @@ -20,6 +20,7 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/ir/dfs_hlo_visitor.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -68,9 +69,8 @@ absl::StatusOr CanonicalizeRaggedAllToAll( /*channel_id=*/ragged_all_to_all->channel_id())); new_ragged_all_to_all->set_frontend_attributes( ragged_all_to_all->frontend_attributes()); - TF_RETURN_IF_ERROR( - ragged_all_to_all->ReplaceAllUsesWith(new_ragged_all_to_all)); - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR(ragged_all_to_all->ReplaceAllUsesWith(new_ragged_all_to_all)); + RETURN_IF_ERROR( computation->RemoveInstructionAndUnusedOperands(ragged_all_to_all)); return true; } @@ -82,8 +82,8 @@ absl::StatusOr RaggedAllToAllCanonicalizer::RunImpl( for (auto computation : module->computations(execution_threads)) { for (auto hlo : computation->MakeInstructionPostOrder()) { - TF_ASSIGN_OR_RETURN(bool canonicalized, - CanonicalizeRaggedAllToAll(hlo, computation, module)); + ASSIGN_OR_RETURN(bool canonicalized, + CanonicalizeRaggedAllToAll(hlo, computation, module)); changed |= canonicalized; } } diff --git a/third_party/xla/xla/backends/gpu/transforms/ragged_all_to_all_decomposer.cc b/third_party/xla/xla/backends/gpu/transforms/ragged_all_to_all_decomposer.cc index 942e401196d8d0..f7d1c2f23a6fb8 100644 --- a/third_party/xla/xla/backends/gpu/transforms/ragged_all_to_all_decomposer.cc +++ b/third_party/xla/xla/backends/gpu/transforms/ragged_all_to_all_decomposer.cc @@ -30,6 +30,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/comparison_util.h" #include "xla/hlo/ir/dfs_hlo_visitor.h" #include "xla/hlo/ir/hlo_casting_utils.h" @@ -345,8 +346,8 @@ absl::StatusOr DecomposeRaggedAllToAll(HloInstruction* hlo, HloRaggedAllToAllInstruction* all_to_all = Cast(hlo); - TF_ASSIGN_OR_RETURN(auto replica_group_count_and_size, - GetReplicaGroupCountAndSize(all_to_all)); + ASSIGN_OR_RETURN(auto replica_group_count_and_size, + GetReplicaGroupCountAndSize(all_to_all)); if (!replica_group_count_and_size.has_value()) { return false; } @@ -394,9 +395,8 @@ absl::StatusOr DecomposeRaggedAllToAll(HloInstruction* hlo, DenseToRagged(computation, dense_output, output_operand, output_offsets, recv_sizes, num_updates_per_replica, max_update_size); - TF_RETURN_IF_ERROR(all_to_all->ReplaceAllUsesWith(ragged_output)); - TF_RETURN_IF_ERROR( - computation->RemoveInstructionAndUnusedOperands(all_to_all)); + RETURN_IF_ERROR(all_to_all->ReplaceAllUsesWith(ragged_output)); + RETURN_IF_ERROR(computation->RemoveInstructionAndUnusedOperands(all_to_all)); return true; } @@ -418,8 +418,8 @@ absl::StatusOr RaggedAllToAllDecomposer::RunImpl( "`ragged-all-to-all-canonicalizer` pass executed?"); } - TF_ASSIGN_OR_RETURN(bool result, - DecomposeRaggedAllToAll(hlo, computation, module)); + ASSIGN_OR_RETURN(bool result, + DecomposeRaggedAllToAll(hlo, computation, module)); changed |= result; } } diff --git a/third_party/xla/xla/backends/gpu/transforms/ragged_all_to_all_multi_host_decomposer.cc b/third_party/xla/xla/backends/gpu/transforms/ragged_all_to_all_multi_host_decomposer.cc index 674f509d57e37f..e097e703b81bab 100644 --- a/third_party/xla/xla/backends/gpu/transforms/ragged_all_to_all_multi_host_decomposer.cc +++ b/third_party/xla/xla/backends/gpu/transforms/ragged_all_to_all_multi_host_decomposer.cc @@ -30,6 +30,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/array.h" #include "xla/hlo/ir/dfs_hlo_visitor.h" #include "xla/hlo/ir/hlo_casting_utils.h" @@ -317,8 +318,8 @@ absl::StatusOr DecomposeDispatchRaggedAllToAll( std::make_shared(intra_host_replica_groups), /*channel_id=*/ragged_all_to_all->channel_id())); - TF_RETURN_IF_ERROR(computation->ReplaceInstruction(ragged_all_to_all, - new_ragged_all_to_all)); + RETURN_IF_ERROR(computation->ReplaceInstruction(ragged_all_to_all, + new_ragged_all_to_all)); return true; } @@ -494,8 +495,8 @@ absl::StatusOr DecomposeCombineRaggedAllToAll( std::make_shared(degenerated_replica_groups), /*channel_id=*/ragged_all_to_all->channel_id())); - TF_RETURN_IF_ERROR(computation->ReplaceInstruction(ragged_all_to_all, - local_ragged_all_to_all)); + RETURN_IF_ERROR(computation->ReplaceInstruction(ragged_all_to_all, + local_ragged_all_to_all)); return true; } @@ -635,8 +636,8 @@ absl::StatusOr RaggedAllToAllMultiHostDecomposer::RunImpl( "`ragged-all-to-all-canonicalizer` pass executed?"); } - TF_ASSIGN_OR_RETURN( - bool result, DecomposeRaggedAllToAll(hlo, computation, module, + ASSIGN_OR_RETURN(bool result, + DecomposeRaggedAllToAll(hlo, computation, module, fast_interconnect_slice_size_)); changed |= result; } diff --git a/third_party/xla/xla/backends/gpu/transforms/ragged_dot_fusion_rewriter.cc b/third_party/xla/xla/backends/gpu/transforms/ragged_dot_fusion_rewriter.cc index 5cf087e448a577..218e3179035ef3 100644 --- a/third_party/xla/xla/backends/gpu/transforms/ragged_dot_fusion_rewriter.cc +++ b/third_party/xla/xla/backends/gpu/transforms/ragged_dot_fusion_rewriter.cc @@ -29,6 +29,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "llvm/ADT/SmallVector.h" #include "xla/comparison_util.h" #include "xla/hlo/ir/hlo_casting_utils.h" @@ -157,16 +158,14 @@ absl::StatusOr RaggedDotFusionRewriter::RunImpl( } for (auto* ragged_dot : ragged_dots) { - TF_ASSIGN_OR_RETURN(auto ragged_dot_fusion, - RaggedToCuDNNFusion(ragged_dot)); + ASSIGN_OR_RETURN(auto ragged_dot_fusion, RaggedToCuDNNFusion(ragged_dot)); gpu::GpuBackendConfig gpu_backend_config; gpu::FusionBackendConfig* fusion_config = gpu_backend_config.mutable_fusion_backend_config(); fusion_config->set_kind(gpu::kCuDnnFusionKind); - TF_RETURN_IF_ERROR( - ragged_dot_fusion->set_backend_config(gpu_backend_config)); + RETURN_IF_ERROR(ragged_dot_fusion->set_backend_config(gpu_backend_config)); ragged_dot_fusion->set_metadata(ragged_dot->metadata()); - TF_RETURN_IF_ERROR(ragged_dot->parent()->ReplaceWithNewInstruction( + RETURN_IF_ERROR(ragged_dot->parent()->ReplaceWithNewInstruction( ragged_dot, std::move(ragged_dot_fusion))); } diff --git a/third_party/xla/xla/backends/gpu/transforms/reduce_scatter_creator.cc b/third_party/xla/xla/backends/gpu/transforms/reduce_scatter_creator.cc index 1b156d0f2434d9..84143826fbc1fe 100644 --- a/third_party/xla/xla/backends/gpu/transforms/reduce_scatter_creator.cc +++ b/third_party/xla/xla/backends/gpu/transforms/reduce_scatter_creator.cc @@ -23,6 +23,7 @@ limitations under the License. #include "absl/log/log.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" @@ -115,12 +116,12 @@ absl::StatusOr ReduceScatterCreator::RunImpl( // Note that RemoveInstructionAndUnusedOperands may not always remove the // all-reduce operand of the dynamic-slice, so remove all the dead // instructions manually. - TF_RETURN_IF_ERROR(ds->ReplaceAllUsesWith(result)); - TF_RETURN_IF_ERROR(computation->RemoveInstruction(ds)); + RETURN_IF_ERROR(ds->ReplaceAllUsesWith(result)); + RETURN_IF_ERROR(computation->RemoveInstruction(ds)); if (reshape) { - TF_RETURN_IF_ERROR(computation->RemoveInstruction(reshape)); + RETURN_IF_ERROR(computation->RemoveInstruction(reshape)); } - TF_RETURN_IF_ERROR(computation->RemoveInstructionAndUnusedOperands(ar)); + RETURN_IF_ERROR(computation->RemoveInstructionAndUnusedOperands(ar)); changed = true; } } diff --git a/third_party/xla/xla/backends/gpu/transforms/reduce_scatter_creator_test.cc b/third_party/xla/xla/backends/gpu/transforms/reduce_scatter_creator_test.cc index 2c1314828d9705..c319462a2d9117 100644 --- a/third_party/xla/xla/backends/gpu/transforms/reduce_scatter_creator_test.cc +++ b/third_party/xla/xla/backends/gpu/transforms/reduce_scatter_creator_test.cc @@ -26,6 +26,7 @@ limitations under the License. #include "absl/log/log.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/backends/gpu/transforms/algebraic_simplifier.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -55,8 +56,8 @@ class GpuReduceScatterCreatorTest : public HloHardwareIndependentTestBase { HloModuleConfig config = GetModuleConfigForTest( /*replica_count=*/num_replicas, /*num_partitions=*/num_partitions); config.set_use_spmd_partitioning(use_spmd_partitioning); - TF_ASSIGN_OR_RETURN(auto module, - ParseAndReturnVerifiedModule(hlo_module, config)); + ASSIGN_OR_RETURN(auto module, + ParseAndReturnVerifiedModule(hlo_module, config)); auto changed = ReduceScatterCreator().Run(module.get()); if (!changed.ok()) { return changed.status(); diff --git a/third_party/xla/xla/backends/gpu/transforms/reduction_degenerate_dim_remover.cc b/third_party/xla/xla/backends/gpu/transforms/reduction_degenerate_dim_remover.cc index f9984638ea34fb..d9d68b549c16ac 100644 --- a/third_party/xla/xla/backends/gpu/transforms/reduction_degenerate_dim_remover.cc +++ b/third_party/xla/xla/backends/gpu/transforms/reduction_degenerate_dim_remover.cc @@ -26,6 +26,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -87,7 +88,7 @@ class ReductionDegenerateDimRemoverVisitor : public DfsHloRewriteVisitor { canonical_reduce_shapes.push_back(canonical_reduce_shape); } - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( auto canonical_reduce_shape, ShapeUtil::MakeValidatedMaybeTupleShape(canonical_reduce_shapes)); const Shape &orig_reduce_shape = instr->shape(); @@ -122,9 +123,9 @@ class ReductionDegenerateDimRemoverVisitor : public DfsHloRewriteVisitor { absl::StatusOr ReductionDegenerateDimRemover::RunImpl( HloModule* module, const absl::flat_hash_set& execution_threads) { - TF_ASSIGN_OR_RETURN(bool changed, - ReductionDegenerateDimRemoverVisitor().RunOnModule( - module, execution_threads)); + ASSIGN_OR_RETURN(bool changed, + ReductionDegenerateDimRemoverVisitor().RunOnModule( + module, execution_threads)); return changed; } diff --git a/third_party/xla/xla/backends/gpu/transforms/reduction_dimension_grouper.cc b/third_party/xla/xla/backends/gpu/transforms/reduction_dimension_grouper.cc index 56d284e09b8bf4..4c3f14e1324c90 100644 --- a/third_party/xla/xla/backends/gpu/transforms/reduction_dimension_grouper.cc +++ b/third_party/xla/xla/backends/gpu/transforms/reduction_dimension_grouper.cc @@ -28,6 +28,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -115,8 +116,8 @@ class ReduceDimensionGroupVisitor : public DfsHloRewriteVisitor { absl::StatusOr ReductionDimensionGrouper::RunImpl( HloModule* module, const absl::flat_hash_set& execution_threads) { - TF_ASSIGN_OR_RETURN(bool changed, ReduceDimensionGroupVisitor().RunOnModule( - module, execution_threads)); + ASSIGN_OR_RETURN(bool changed, ReduceDimensionGroupVisitor().RunOnModule( + module, execution_threads)); return changed; } diff --git a/third_party/xla/xla/backends/gpu/transforms/reduction_layout_normalizer.cc b/third_party/xla/xla/backends/gpu/transforms/reduction_layout_normalizer.cc index 64e3ec2ac75015..4c6e9d09843d9a 100644 --- a/third_party/xla/xla/backends/gpu/transforms/reduction_layout_normalizer.cc +++ b/third_party/xla/xla/backends/gpu/transforms/reduction_layout_normalizer.cc @@ -28,6 +28,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -147,7 +148,7 @@ class EnforceMinorToMajorReduceOpVisitor : public DfsHloRewriteVisitor { } } - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( auto new_reduce_shape, ShapeUtil::MakeValidatedMaybeTupleShape(new_reduce_shapes)); @@ -186,9 +187,9 @@ class EnforceMinorToMajorReduceOpVisitor : public DfsHloRewriteVisitor { absl::StatusOr ReductionLayoutNormalizer::RunImpl( HloModule* module, const absl::flat_hash_set& execution_threads) { - TF_ASSIGN_OR_RETURN(bool changed, - EnforceMinorToMajorReduceOpVisitor().RunOnModule( - module, execution_threads)); + ASSIGN_OR_RETURN(bool changed, + EnforceMinorToMajorReduceOpVisitor().RunOnModule( + module, execution_threads)); return changed; } diff --git a/third_party/xla/xla/backends/gpu/transforms/reduction_splitter.cc b/third_party/xla/xla/backends/gpu/transforms/reduction_splitter.cc index c02c142fd88723..96a6ef66f4a3ae 100644 --- a/third_party/xla/xla/backends/gpu/transforms/reduction_splitter.cc +++ b/third_party/xla/xla/backends/gpu/transforms/reduction_splitter.cc @@ -28,6 +28,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" @@ -134,10 +135,9 @@ class ReductionSplitterVisitor : public DfsHloRewriteVisitor { absl::StatusOr ReductionSplitter::RunImpl( HloModule* module, const absl::flat_hash_set& execution_threads) { - TF_ASSIGN_OR_RETURN( - bool changed, - ReductionSplitterVisitor(device_description_, ignore_small_dims_) - .RunOnModule(module, execution_threads)); + ASSIGN_OR_RETURN(bool changed, ReductionSplitterVisitor(device_description_, + ignore_small_dims_) + .RunOnModule(module, execution_threads)); return changed; } diff --git a/third_party/xla/xla/backends/gpu/transforms/reduction_splitter_test.cc b/third_party/xla/xla/backends/gpu/transforms/reduction_splitter_test.cc index 35cbae803a2fac..8bba7261231b84 100644 --- a/third_party/xla/xla/backends/gpu/transforms/reduction_splitter_test.cc +++ b/third_party/xla/xla/backends/gpu/transforms/reduction_splitter_test.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include "absl/status/statusor.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/parser/hlo_parser.h" #include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" @@ -39,9 +40,9 @@ namespace { namespace m = ::xla::match; absl::StatusOr MakeDeviceDescription() { - TF_ASSIGN_OR_RETURN(stream_executor::DeviceDescription device_description, - stream_executor::DeviceDescription::FromProto( - stream_executor::GpuDeviceInfoProto{})); + ASSIGN_OR_RETURN(stream_executor::DeviceDescription device_description, + stream_executor::DeviceDescription::FromProto( + stream_executor::GpuDeviceInfoProto{})); device_description.set_threads_per_warp(32); return device_description; } diff --git a/third_party/xla/xla/backends/gpu/transforms/scaled_dot_rewriter.cc b/third_party/xla/xla/backends/gpu/transforms/scaled_dot_rewriter.cc index fac6c4ef6f988e..dbe0d636b5c936 100644 --- a/third_party/xla/xla/backends/gpu/transforms/scaled_dot_rewriter.cc +++ b/third_party/xla/xla/backends/gpu/transforms/scaled_dot_rewriter.cc @@ -24,6 +24,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -152,7 +153,7 @@ absl::StatusOr Dequantize(HloInstruction* dot, return operand; } std::tie(operand, scale) = UpscaleBoth(operand, scale); - TF_RETURN_IF_ERROR(CheckOperandAndScaleShapes(side, operand, scale)); + RETURN_IF_ERROR(CheckOperandAndScaleShapes(side, operand, scale)); HloInstruction* broadcasted_scale = BroadcastAndReshape(scale, operand->shape(), computation); HloInstruction* dequantized = @@ -171,8 +172,8 @@ absl::StatusOr ScaledDotRewriter::RewriteComputation( } changed = true; HloScaledDotInstruction* dot = Cast(instruction); - TF_ASSIGN_OR_RETURN(HloInstruction * lhs, Dequantize(dot, 0, 2, "LHS")); - TF_ASSIGN_OR_RETURN(HloInstruction * rhs, Dequantize(dot, 1, 3, "RHS")); + ASSIGN_OR_RETURN(HloInstruction * lhs, Dequantize(dot, 0, 2, "LHS")); + ASSIGN_OR_RETURN(HloInstruction * rhs, Dequantize(dot, 1, 3, "RHS")); std::tie(lhs, rhs) = UpscaleBoth(lhs, rhs); @@ -180,12 +181,12 @@ absl::StatusOr ScaledDotRewriter::RewriteComputation( dot_shape.set_element_type(GetTargetType(lhs->shape().element_type(), dot->shape().element_type())); - TF_RETURN_IF_ERROR(dot->ReplaceAllUsesWith( + RETURN_IF_ERROR(dot->ReplaceAllUsesWith( Convert(computation->AddInstruction(HloInstruction::CreateDot( dot_shape, lhs, rhs, dot->dot_dimension_numbers(), dot->precision_config())), dot->shape().element_type()))); - TF_RETURN_IF_ERROR(computation->RemoveInstruction(dot)); + RETURN_IF_ERROR(computation->RemoveInstruction(dot)); } return changed; } @@ -194,7 +195,7 @@ absl::StatusOr ScaledDotRewriter::RunImpl( HloModule* module, const absl::flat_hash_set&) { bool changed = false; for (HloComputation* computation : module->MakeNonfusionComputations()) { - TF_ASSIGN_OR_RETURN(bool result, RewriteComputation(computation)); + ASSIGN_OR_RETURN(bool result, RewriteComputation(computation)); changed |= result; } return changed; diff --git a/third_party/xla/xla/backends/gpu/transforms/scatter_determinism_expander.cc b/third_party/xla/xla/backends/gpu/transforms/scatter_determinism_expander.cc index f2364c7ce622dd..2875bb240423a6 100644 --- a/third_party/xla/xla/backends/gpu/transforms/scatter_determinism_expander.cc +++ b/third_party/xla/xla/backends/gpu/transforms/scatter_determinism_expander.cc @@ -24,6 +24,7 @@ limitations under the License. #include "absl/log/check.h" #include "absl/strings/str_format.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/array.h" #include "xla/array2d.h" #include "xla/comparison_util.h" @@ -53,10 +54,10 @@ static absl::StatusOr> CanonicalizeScatterUpdates( std::vector adjusted_updates; adjusted_updates.reserve(scatter_updates.size()); for (HloInstruction* update : scatter_updates) { - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( HloInstruction * canonical_update, PermuteScatterAndWindowDims(update, dim_numbers.update_window_dims())); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( HloInstruction * adjusted_update, AdjustScatterDims(scatter_indices->shape(), canonical_update, dim_numbers.index_vector_dim())); @@ -393,8 +394,8 @@ static absl::StatusOr CreateScanWithIndices( concatenated_indices, ComparisonDirection::kEq)); std::vector map_operands = {current_updates, concatenated_updates}; - TF_ASSIGN_OR_RETURN(HloInstruction * reduced_updates, - MakeMapHlo(map_operands, to_apply)); + ASSIGN_OR_RETURN(HloInstruction * reduced_updates, + MakeMapHlo(map_operands, to_apply)); current_updates = parent->AddInstruction(HloInstruction::CreateTernary( updates_shape, HloOpcode::kSelect, indices_mask, reduced_updates, current_updates)); @@ -409,10 +410,10 @@ absl::StatusOr> ComputePrefixScan( std::vector prefix_scans(sorted_updates.size()); HloInstruction* prefix_scan_update = nullptr; for (int i = 0; i < sorted_updates.size(); i++) { - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( HloComputation * to_apply, CallComputationAndGetIthOutputWithBinaryParams(scatter->to_apply(), i)); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( prefix_scan_update, CreateScanWithIndices(parent, sorted_updates[i], sorted_scalar_indices, to_apply, operand_dims)); @@ -595,7 +596,7 @@ absl::StatusOr CheckValidIndices( init_reduce_value, {1}, reduce_computation)); // 2. Check last indices <= [bounds...] // Check if the index is OOB w.r.t. the operand dimensions and window sizes. - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( HloInstruction * max_valid_index_constant, CreateBoundTensor(parent, indices, operand_dims, full_index_to_operand_dims, false, window_sizes)); @@ -701,9 +702,9 @@ absl::StatusOr ScatterDeterminismExpander::ExpandInstruction( // Canonicalize the scatter_indices, after which the size of its most-major // dimension must be same as the while loop trip count. HloInstruction* original_scatter_indices = scatter_indices; - TF_ASSIGN_OR_RETURN(scatter_indices, - CanonicalizeScatterIndices( - scatter_indices, dim_numbers.index_vector_dim())); + ASSIGN_OR_RETURN(scatter_indices, + CanonicalizeScatterIndices(scatter_indices, + dim_numbers.index_vector_dim())); CHECK_EQ(scatter_indices_count, scatter_indices->shape().dimensions(0)); // We compromise for maintainability and make the scatter_indices always 2D, // so that the implementation could be easier, as we do not need to maintain @@ -720,7 +721,7 @@ absl::StatusOr ScatterDeterminismExpander::ExpandInstruction( // Canonicalize the updates, after which the size of their most-major // dimensions must be same as the while loop trip count. - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( scatter_updates, CanonicalizeScatterUpdates(scatter_updates, original_scatter_indices, dim_numbers, scatter_indices_count)); @@ -742,7 +743,7 @@ absl::StatusOr ScatterDeterminismExpander::ExpandInstruction( std::vector full_index_to_operand_dims = ComputeFullIndexToOperandDims(scatter_operands[0]->shape(), dim_numbers); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( HloInstruction * out_of_bound_tensor, CreateBoundTensor(parent, scatter_indices, scatter->shape().dimensions(), dim_numbers.scatter_dims_to_operand_dims())); @@ -775,7 +776,7 @@ absl::StatusOr ScatterDeterminismExpander::ExpandInstruction( full_index_to_operand_dims, actual_update_window_dims); // Map scatter_indices into operand space - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( scatter_indices, AddImplicitDimensionsToIndices( scatter_operands[0]->shape().dimensions().size(), @@ -783,7 +784,7 @@ absl::StatusOr ScatterDeterminismExpander::ExpandInstruction( CHECK(scatter_indices->shape().dimensions(0) == scatter_indices_count); // Add implicit dimensions to OOB constant, if needed. - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( out_of_bound_tensor, AddImplicitDimensionsToIndices( scatter_operands[0]->shape().dimensions().size(), @@ -791,7 +792,7 @@ absl::StatusOr ScatterDeterminismExpander::ExpandInstruction( // If any updates are out of bound, we change the corresponding indices to // be oob_tensor values - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( HloInstruction * oob_check_mask, CheckValidIndices(scatter->parent(), scatter_indices, scatter_operands[0]->shape().dimensions(), @@ -861,14 +862,14 @@ absl::StatusOr ScatterDeterminismExpander::ExpandInstruction( sorted_indices = sorted_tensors[sorted_tensors.size() - 1]; } - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( std::vector prefix_scan_updates, ComputePrefixScan(sorted_updates, sorted_scalar_indices, scatter, parent, scatter_operands[0]->shape().dimensions())); if (non_scalar_update) { // As the indices are expanded, we need to recompute out-of-bound tensor // with the same shape - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( out_of_bound_tensor, CreateBoundTensor(parent, sorted_indices, scatter_operands[0]->shape().dimensions(), diff --git a/third_party/xla/xla/backends/gpu/transforms/scatter_slice_simplifier.cc b/third_party/xla/xla/backends/gpu/transforms/scatter_slice_simplifier.cc index 765a2ee44f6855..20a168edbedb0f 100644 --- a/third_party/xla/xla/backends/gpu/transforms/scatter_slice_simplifier.cc +++ b/third_party/xla/xla/backends/gpu/transforms/scatter_slice_simplifier.cc @@ -28,6 +28,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -219,7 +220,7 @@ class ScatterSliceSimplifierVisitor : public DfsHloRewriteVisitor { VLOG(3) << "Skipping user " << user->name() << " (already replaced)"; continue; } - TF_RETURN_IF_ERROR(ReplaceUserRecursive(user, new_instruction)); + RETURN_IF_ERROR(ReplaceUserRecursive(user, new_instruction)); } return absl::OkStatus(); } @@ -253,9 +254,9 @@ class ScatterSliceSimplifierVisitor : public DfsHloRewriteVisitor { user->CloneWithNewOperands(new_shape(user), new_operands)); } else { auto* gte = Cast(user); - TF_ASSIGN_OR_RETURN(new_user, - MakeGetTupleElementHlo(operand, gte->tuple_index(), - &user->metadata())); + ASSIGN_OR_RETURN(new_user, + MakeGetTupleElementHlo(operand, gte->tuple_index(), + &user->metadata())); } // Replace slice user instructions recursively. diff --git a/third_party/xla/xla/backends/gpu/transforms/scheduling_instruction_annotator.cc b/third_party/xla/xla/backends/gpu/transforms/scheduling_instruction_annotator.cc index 8714dd73962ac5..1cd81d4288f7b8 100644 --- a/third_party/xla/xla/backends/gpu/transforms/scheduling_instruction_annotator.cc +++ b/third_party/xla/xla/backends/gpu/transforms/scheduling_instruction_annotator.cc @@ -20,6 +20,7 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "absl/log/check.h" #include "absl/strings/string_view.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" @@ -63,8 +64,8 @@ absl::StatusOr SchedulingInstructionAnnotator::RunImpl( // propagated from calles to callers. for (HloComputation* computation : module->MakeComputationPostOrder(execution_threads)) { - TF_ASSIGN_OR_RETURN(bool result, - AnnotateSchedulingInstructionNames(*computation)); + ASSIGN_OR_RETURN(bool result, + AnnotateSchedulingInstructionNames(*computation)); changed |= result; } diff --git a/third_party/xla/xla/backends/gpu/transforms/softmax_rewriter_triton.cc b/third_party/xla/xla/backends/gpu/transforms/softmax_rewriter_triton.cc index 9cce920c5a66f3..8b795c180bd074 100644 --- a/third_party/xla/xla/backends/gpu/transforms/softmax_rewriter_triton.cc +++ b/third_party/xla/xla/backends/gpu/transforms/softmax_rewriter_triton.cc @@ -30,6 +30,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/time/time.h" +#include "xla/tsl/platform/status_macros.h" #include "mlir/IR/MLIRContext.h" #include "xla/backends/gpu/codegen/triton/support.h" #include "xla/backends/gpu/transforms/reduction_dimension_grouper.h" @@ -253,12 +254,12 @@ absl::StatusOr MakeFusionForDiamond( normalization_fusion->GetModule()->SetAndUniquifyInstrName( normalization_fusion, "triton_softmax"); - TF_ASSIGN_OR_RETURN(auto gpu_config, - normalization_fusion->backend_config()); + ASSIGN_OR_RETURN(auto gpu_config, + normalization_fusion->backend_config()); FusionBackendConfig& backend_config = *gpu_config.mutable_fusion_backend_config(); backend_config.set_kind(kTritonFusionKind); - TF_RETURN_IF_ERROR(normalization_fusion->set_backend_config(gpu_config)); + RETURN_IF_ERROR(normalization_fusion->set_backend_config(gpu_config)); return xla::Cast(normalization_fusion); } @@ -283,7 +284,7 @@ absl::Status RunFusionPipeline( /*ignore_small_reduce_dims=*/false); reduction_pipeline.AddPass>(device_info); - TF_RETURN_IF_ERROR(reduction_pipeline.Run(module).status()); + RETURN_IF_ERROR(reduction_pipeline.Run(module).status()); return FusionPipeline(module->config().debug_options(), shape_size, alias_info, /*thread_pool=*/nullptr, device_info, @@ -312,8 +313,8 @@ EstimateOptimizedHloRunTimeWithoutSoftMaxRewriterTriton( // After this call, the `new_module` will have instruction fused without // SoftmaxRewriterTriton. - TF_RETURN_IF_ERROR(RunFusionPipeline(new_module.get(), device_info, - shape_size, alias_info, mlir_context)); + RETURN_IF_ERROR(RunFusionPipeline(new_module.get(), device_info, shape_size, + alias_info, mlir_context)); VLOG(3) << "priority fusion module: " << new_module->ToString(); @@ -324,7 +325,7 @@ EstimateOptimizedHloRunTimeWithoutSoftMaxRewriterTriton( /*min_latencies_seconds=*/{}, /*count_multiple_input_accesses=*/true}; GpuHloCostAnalysis cost_analysis(cost_analysis_options, device_info); - TF_RETURN_IF_ERROR(entry_computation->Accept(&cost_analysis)); + RETURN_IF_ERROR(entry_computation->Accept(&cost_analysis)); absl::Duration total_run_time = absl::ZeroDuration(); @@ -357,7 +358,7 @@ DecideIfShouldFuseAndMaybeSetBlockLevelParameters( bool use_cost_model_to_evaluate_fusions) { auto fusion_adaptor = HloFusionAdaptor::ForInstruction(normalization_fusion); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( TiledRunTimeDataOrError tiled_runtime_data_or, indexing_performance_model.TryFindBestTilingForFusion(*fusion_adaptor)); @@ -371,10 +372,10 @@ DecideIfShouldFuseAndMaybeSetBlockLevelParameters( std::get(std::move(tiled_runtime_data_or)); if (use_cost_model_to_evaluate_fusions) { - TF_ASSIGN_OR_RETURN(absl::Duration run_time_without_softmax_rewriter, - EstimateOptimizedHloRunTimeWithoutSoftMaxRewriterTriton( - normalization_fusion, device_info, shape_size, - alias_info, mlir_context)); + ASSIGN_OR_RETURN(absl::Duration run_time_without_softmax_rewriter, + EstimateOptimizedHloRunTimeWithoutSoftMaxRewriterTriton( + normalization_fusion, device_info, shape_size, + alias_info, mlir_context)); VLOG(2) << "run time estimate if normalization diamond fused together: " << tiled_runtime_data.runtime_data.exec_time; @@ -390,12 +391,12 @@ DecideIfShouldFuseAndMaybeSetBlockLevelParameters( } } - TF_ASSIGN_OR_RETURN(auto backend_config, - normalization_fusion->backend_config()); + ASSIGN_OR_RETURN(auto backend_config, + normalization_fusion->backend_config()); *backend_config.mutable_fusion_backend_config() ->mutable_block_level_fusion_config() = tiled_runtime_data.block_level_parameters.ToBlockLevelFusionConfig(); - TF_RETURN_IF_ERROR(normalization_fusion->set_backend_config(backend_config)); + RETURN_IF_ERROR(normalization_fusion->set_backend_config(backend_config)); VLOG(2) << "Fusing with backend config: " << backend_config.DebugString(); return FusionDecision::Allow(); @@ -408,32 +409,31 @@ absl::StatusOr MaybeFuseDiamondImpl( const HloCostAnalysis::ShapeSizeFunction& shape_size, const GpuAliasInfo* alias_info, MLIRContext* mlir_context, bool use_cost_model_to_evaluate_fusions) { - TF_ASSIGN_OR_RETURN(HloFusionInstruction * normalization_fusion, - MakeFusionForDiamond(diamond)); + ASSIGN_OR_RETURN(HloFusionInstruction * normalization_fusion, + MakeFusionForDiamond(diamond)); HloInstruction* root = diamond.root; VLOG(2) << "MaybeFuseDiamondImpl: " << normalization_fusion->ToString(); - TF_ASSIGN_OR_RETURN(FusionDecision fusion_decision, - DecideIfShouldFuseAndMaybeSetBlockLevelParameters( - normalization_fusion, indexing_performance_model, - device_info, shape_size, alias_info, mlir_context, - use_cost_model_to_evaluate_fusions)); + ASSIGN_OR_RETURN(FusionDecision fusion_decision, + DecideIfShouldFuseAndMaybeSetBlockLevelParameters( + normalization_fusion, indexing_performance_model, + device_info, shape_size, alias_info, mlir_context, + use_cost_model_to_evaluate_fusions)); if (!fusion_decision.CanFuse()) { VLOG(2) << "Not fusing: " << fusion_decision.Explain(); normalization_fusion->DetachFromOperandsAndUsers(); - TF_RETURN_IF_ERROR(normalization_fusion->parent()->RemoveInstruction( + RETURN_IF_ERROR(normalization_fusion->parent()->RemoveInstruction( normalization_fusion)); return false; } if (root->IsRoot()) { root->parent()->set_root_instruction(normalization_fusion); - TF_RETURN_IF_ERROR( - root->parent()->RemoveInstructionAndUnusedOperands(root)); + RETURN_IF_ERROR(root->parent()->RemoveInstructionAndUnusedOperands(root)); } else { - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( root->parent()->ReplaceInstruction(root, normalization_fusion)); } return true; @@ -444,8 +444,8 @@ absl::StatusOr MaybeFuseDiamondImpl( absl::StatusOr CanSymbolicTileAnalysisTileDiamond( const DiamondDescriptor& diamond, const se::DeviceDescription& device_info) { - TF_ASSIGN_OR_RETURN(HloFusionInstruction * normalization_fusion, - MakeFusionForDiamond(diamond)); + ASSIGN_OR_RETURN(HloFusionInstruction * normalization_fusion, + MakeFusionForDiamond(diamond)); mlir::MLIRContext mlir_context; RegisterSymbolicExprStorage(&mlir_context); @@ -475,9 +475,9 @@ absl::StatusOr CanSymbolicTileAnalysisTileDiamond( symbolic_tile_analysis_or_error); } - TF_RETURN_IF_ERROR(diamond.root->GetModule()->RemoveEmbeddedComputation( + RETURN_IF_ERROR(diamond.root->GetModule()->RemoveEmbeddedComputation( normalization_fusion->called_computation())); - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( diamond.root->parent()->RemoveInstruction(normalization_fusion)); return can_tile; @@ -636,7 +636,7 @@ SoftmaxRewriterTriton::FindAllFusibleNormalizationDiamonds( /*root=*/instr, /*producer=*/std::get(producer)}; // We filter out the diamonds that cannot be tiled correctly using // `SymbolicTileAnalysis`. - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( bool can_tile_diamond, CanSymbolicTileAnalysisTileDiamond(diamond, device_info_)); if (can_tile_diamond) { @@ -673,10 +673,10 @@ absl::StatusOr SoftmaxRewriterTriton::MaybeFuseNormalizationDiamond( absl::StatusOr SoftmaxRewriterTriton::RunImpl( HloModule* module, const absl::flat_hash_set& execution_threads) { - TF_RETURN_IF_ERROR(EnsureTritonSupportsComputeCapability( + RETURN_IF_ERROR(EnsureTritonSupportsComputeCapability( device_info_.gpu_compute_capability())); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( std::vector diamonds, FindAllFusibleNormalizationDiamonds(*module, execution_threads)); @@ -686,7 +686,7 @@ absl::StatusOr SoftmaxRewriterTriton::RunImpl( // the producer of diamond n+1. for (auto diamond = diamonds.rbegin(); diamond != diamonds.rend(); ++diamond) { - TF_ASSIGN_OR_RETURN(bool fused, MaybeFuseNormalizationDiamond(*diamond)); + ASSIGN_OR_RETURN(bool fused, MaybeFuseNormalizationDiamond(*diamond)); changed |= fused; } return changed; diff --git a/third_party/xla/xla/backends/gpu/transforms/sort_rewriter.cc b/third_party/xla/xla/backends/gpu/transforms/sort_rewriter.cc index df0ffff9ce4edb..b3fc75e00bbaf3 100644 --- a/third_party/xla/xla/backends/gpu/transforms/sort_rewriter.cc +++ b/third_party/xla/xla/backends/gpu/transforms/sort_rewriter.cc @@ -857,7 +857,7 @@ absl::StatusOr SortRewriter::RunOnInstruction( // MLIR dictionary attributes when rewriting to the final FFI target. SortOptions sort_options; sort_options.set_descending(sort_analysis.descending); - TF_RETURN_IF_ERROR(custom_call->set_backend_config(sort_options)); + RETURN_IF_ERROR(custom_call->set_backend_config(sort_options)); // Build the replacement instruction. HloInstruction* replacement; @@ -880,8 +880,7 @@ absl::StatusOr SortRewriter::RunOnInstruction( } // Replace sort operation with custom call followed by GTE. - TF_RETURN_IF_ERROR( - sort_op->parent()->ReplaceInstruction(sort_op, replacement)); + RETURN_IF_ERROR(sort_op->parent()->ReplaceInstruction(sort_op, replacement)); return true; } @@ -906,7 +905,7 @@ absl::StatusOr SortRewriter::RunOnComputation( bool changed = false; for (auto* sort : sort_ops) { - TF_ASSIGN_OR_RETURN(bool result, RunOnInstruction(sort)); + ASSIGN_OR_RETURN(bool result, RunOnInstruction(sort)); changed |= result; } return changed; @@ -928,8 +927,8 @@ absl::StatusOr SortRewriter::RunImpl( bool changed = false; for (HloComputation* computation : module->MakeNonfusionComputations(execution_threads)) { - TF_ASSIGN_OR_RETURN(bool result, - RunOnComputation(computation, deviceless_cub_mode)); + ASSIGN_OR_RETURN(bool result, + RunOnComputation(computation, deviceless_cub_mode)); changed |= result; } XLA_VLOG_LINES(3, "SortRewriter::RunImpl(), after:\n" + module->ToString()); diff --git a/third_party/xla/xla/backends/gpu/transforms/splitk_rewriter.cc b/third_party/xla/xla/backends/gpu/transforms/splitk_rewriter.cc index 5e4b2da1c1d1cd..4c322bbea37e2e 100644 --- a/third_party/xla/xla/backends/gpu/transforms/splitk_rewriter.cc +++ b/third_party/xla/xla/backends/gpu/transforms/splitk_rewriter.cc @@ -31,6 +31,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" @@ -332,34 +333,34 @@ absl::StatusOr SplitKDimensionOfDot(HloDotInstruction* src_dot, // Update the dot's dimension numbers accordingly (shifting right all the // dimensions starting from the K dimension and inserting new batch dims). - TF_ASSIGN_OR_RETURN(auto dims, DotOperandDims::FromDot(src_dot)); + ASSIGN_OR_RETURN(auto dims, DotOperandDims::FromDot(src_dot)); // We need to insert the dimension at the same index in both operands. // InsertDimension inserts at "natural" location by default which may be // different for lhs and rhs. Therefore, we take the index from the lhs and // insert at the same index in the rhs. std::optional insertion_idx = std::nullopt; for (size_t i : {0, 1}) { - TF_ASSIGN_OR_RETURN(insertion_idx, dims[i].InsertDimension( - DotOperandDims::kBatch, k_incices[i], - split_k, insertion_idx)); - TF_RETURN_IF_ERROR(dims[i].UpdateShape(operands[i]->shape())); + ASSIGN_OR_RETURN(insertion_idx, dims[i].InsertDimension( + DotOperandDims::kBatch, k_incices[i], + split_k, insertion_idx)); + RETURN_IF_ERROR(dims[i].UpdateShape(operands[i]->shape())); } - TF_ASSIGN_OR_RETURN(DotDimensionNumbers new_dnums, - DotOperandDims::CreateDotDimensionNumbers(dims)); - TF_ASSIGN_OR_RETURN(HloInstruction * new_dot, - MakeDotHlo(operands[0], operands[1], new_dnums, - src_dot->precision_config(), accumulator_type, - &src_dot->metadata())); + ASSIGN_OR_RETURN(DotDimensionNumbers new_dnums, + DotOperandDims::CreateDotDimensionNumbers(dims)); + ASSIGN_OR_RETURN(HloInstruction * new_dot, + MakeDotHlo(operands[0], operands[1], new_dnums, + src_dot->precision_config(), accumulator_type, + &src_dot->metadata())); // Reduce along the new batch dimension. Batch dimensions are first in the dot // result, so we use index within the batch category to get it. - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( int64_t splitk_dim_idx, dims[0].IndexWithinCategory(DotOperandDims::kBatch, k_incices[0])); - TF_ASSIGN_OR_RETURN(HloInstruction * splitk_root, - ReduceDimension(new_dot, splitk_dim_idx)); + ASSIGN_OR_RETURN(HloInstruction * splitk_root, + ReduceDimension(new_dot, splitk_dim_idx)); *splitk_root->mutable_shape()->mutable_layout() = src_dot->shape().layout(); if (output_type != accumulator_type) { splitk_root = MakeConvertToHlo(splitk_root, output_type); @@ -398,9 +399,9 @@ class SplitkRewriterVisitor : public DfsHloRewriteVisitor { if (split_k == 1) { return absl::OkStatus(); } - TF_ASSIGN_OR_RETURN(HloInstruction * new_dot, - SplitKDimensionOfDot(dot, split_k)); - TF_RETURN_IF_ERROR(ReplaceInstruction(instr, new_dot)); + ASSIGN_OR_RETURN(HloInstruction * new_dot, + SplitKDimensionOfDot(dot, split_k)); + RETURN_IF_ERROR(ReplaceInstruction(instr, new_dot)); return absl::OkStatus(); } @@ -416,7 +417,7 @@ absl::StatusOr SplitkRewriter::RunImpl( for (HloComputation* computation : module->MakeNonfusionComputations(execution_threads)) { SplitkRewriterVisitor visitor(device_description_); - TF_RETURN_IF_ERROR(computation->Accept(&visitor)); + RETURN_IF_ERROR(computation->Accept(&visitor)); changed |= visitor.changed(); } return changed; diff --git a/third_party/xla/xla/backends/gpu/transforms/stream_attribute_annotator.cc b/third_party/xla/xla/backends/gpu/transforms/stream_attribute_annotator.cc index 5af214b940812e..bf10c89d5401f4 100644 --- a/third_party/xla/xla/backends/gpu/transforms/stream_attribute_annotator.cc +++ b/third_party/xla/xla/backends/gpu/transforms/stream_attribute_annotator.cc @@ -22,6 +22,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/backends/gpu/runtime/thunk.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -85,7 +86,7 @@ absl::StatusOr AnnotateStreamAttributesForInstruction( instr_gpu_config.set_operation_queue_id( comp_root_gpu_config->operation_queue_id()); - TF_RETURN_IF_ERROR(instr->set_backend_config(instr_gpu_config)); + RETURN_IF_ERROR(instr->set_backend_config(instr_gpu_config)); return true; } @@ -97,7 +98,7 @@ absl::StatusOr AnnotateStreamAttributesForCopyStart( return false; } instr_gpu_config.set_operation_queue_id(channel_id); - TF_RETURN_IF_ERROR(instr->set_backend_config(instr_gpu_config)); + RETURN_IF_ERROR(instr->set_backend_config(instr_gpu_config)); VLOG(3) << "Add copy-start's backend config: " << channel_id; return true; } @@ -130,13 +131,13 @@ absl::StatusOr WrapIntoFusionAndAnnotateStreamAttributes( module->schedule().replace_instruction(computation, instruction, fusion_instruction); } - TF_RETURN_IF_ERROR(fusion_instruction->CopyAllControlDepsFrom(instruction)); - TF_RETURN_IF_ERROR(instruction->DropAllControlDeps()); - TF_RETURN_IF_ERROR(instruction->ReplaceAllUsesWith(fusion_instruction)); - TF_RETURN_IF_ERROR(computation->RemoveInstruction(instruction)); + RETURN_IF_ERROR(fusion_instruction->CopyAllControlDepsFrom(instruction)); + RETURN_IF_ERROR(instruction->DropAllControlDeps()); + RETURN_IF_ERROR(instruction->ReplaceAllUsesWith(fusion_instruction)); + RETURN_IF_ERROR(computation->RemoveInstruction(instruction)); instr_gpu_config.set_operation_queue_id(channel_id); - TF_RETURN_IF_ERROR(fusion_instruction->set_backend_config(instr_gpu_config)); + RETURN_IF_ERROR(fusion_instruction->set_backend_config(instr_gpu_config)); VLOG(3) << "Add async stream " << channel_id << " and wrapped instruction " << instruction->ToString(); VLOG(3) << " Fusion wrapper: " << fusion_instruction->ToString(); @@ -163,25 +164,25 @@ absl::StatusOr StreamAttributeAnnotator::RunImpl( // when the root of fusion is a single instruction // running on non-default stream. if (HloPredicateIsOp(instr)) { - TF_ASSIGN_OR_RETURN(bool comp_result, - AnnotateStreamAttributesForInstruction( - instr, instr_gpu_config.value())); + ASSIGN_OR_RETURN(bool comp_result, + AnnotateStreamAttributesForInstruction( + instr, instr_gpu_config.value())); changed |= comp_result; } else if (instr->opcode() == HloOpcode::kCopyStart && module->has_schedule()) { - TF_ASSIGN_OR_RETURN(bool comp_result, - AnnotateStreamAttributesForCopyStart( - instr, channel_id, instr_gpu_config.value())); + ASSIGN_OR_RETURN(bool comp_result, + AnnotateStreamAttributesForCopyStart( + instr, channel_id, instr_gpu_config.value())); changed |= comp_result; continue; } else if (comp->IsAsyncComputation() && (instr->opcode() == HloOpcode::kDynamicSlice || instr->opcode() == HloOpcode::kDynamicUpdateSlice) && module->has_schedule()) { - TF_ASSIGN_OR_RETURN(bool comp_result, - WrapIntoFusionAndAnnotateStreamAttributes( - instr, channel_id, instr_gpu_config.value(), - device_description_)); + ASSIGN_OR_RETURN(bool comp_result, + WrapIntoFusionAndAnnotateStreamAttributes( + instr, channel_id, instr_gpu_config.value(), + device_description_)); changed |= comp_result; continue; } diff --git a/third_party/xla/xla/backends/gpu/transforms/stream_attribute_annotator_test.cc b/third_party/xla/xla/backends/gpu/transforms/stream_attribute_annotator_test.cc index 3ed14211906b48..04119f6322fbfb 100644 --- a/third_party/xla/xla/backends/gpu/transforms/stream_attribute_annotator_test.cc +++ b/third_party/xla/xla/backends/gpu/transforms/stream_attribute_annotator_test.cc @@ -21,6 +21,7 @@ limitations under the License. #include #include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" @@ -36,9 +37,9 @@ namespace xla::gpu { namespace { absl::StatusOr MakeDeviceDescription() { - TF_ASSIGN_OR_RETURN(stream_executor::DeviceDescription device_description, - stream_executor::DeviceDescription::FromProto( - stream_executor::GpuDeviceInfoProto{})); + ASSIGN_OR_RETURN(stream_executor::DeviceDescription device_description, + stream_executor::DeviceDescription::FromProto( + stream_executor::GpuDeviceInfoProto{})); device_description.set_threads_per_warp(32); return device_description; } diff --git a/third_party/xla/xla/backends/gpu/transforms/stream_attribute_async_wrapper.cc b/third_party/xla/xla/backends/gpu/transforms/stream_attribute_async_wrapper.cc index a442eca4585940..0c259bfd83bbcb 100644 --- a/third_party/xla/xla/backends/gpu/transforms/stream_attribute_async_wrapper.cc +++ b/third_party/xla/xla/backends/gpu/transforms/stream_attribute_async_wrapper.cc @@ -21,6 +21,7 @@ limitations under the License. #include "absl/log/log.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/backends/gpu/runtime/thunk.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -43,17 +44,17 @@ static absl::StatusOr AsynchronizeInstruction(HloInstruction* instr) { return false; } HloComputation* computation = instr->parent(); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( HloInstruction * done, computation->CreateAsyncInstructions( instr, {}, StreamAttributeAsyncWrapper::kParallelExecutionThread, /*replace=*/true)); - TF_ASSIGN_OR_RETURN(GpuBackendConfig gpu_config, - done->backend_config()); + ASSIGN_OR_RETURN(GpuBackendConfig gpu_config, + done->backend_config()); // Set the false delay of done op to be false so it can be scheduled // far apart from start. gpu_config.set_force_earliest_schedule(false); - TF_RETURN_IF_ERROR(done->set_backend_config(gpu_config)); + RETURN_IF_ERROR(done->set_backend_config(gpu_config)); VLOG(5) << "Created async instruction: " << done->ToString(); return true; } @@ -68,7 +69,7 @@ absl::StatusOr StreamAttributeAsyncWrapper::RunImpl( for (const HloComputation* comp : module->MakeNonfusionComputations(execution_threads)) { for (HloInstruction* instr : comp->instructions()) { - TF_ASSIGN_OR_RETURN(bool result, AsynchronizeInstruction(instr)); + ASSIGN_OR_RETURN(bool result, AsynchronizeInstruction(instr)); changed |= result; } } diff --git a/third_party/xla/xla/backends/gpu/transforms/topk_splitter.cc b/third_party/xla/xla/backends/gpu/transforms/topk_splitter.cc index 3cf04f7386d938..a0388d2e6d3f04 100644 --- a/third_party/xla/xla/backends/gpu/transforms/topk_splitter.cc +++ b/third_party/xla/xla/backends/gpu/transforms/topk_splitter.cc @@ -28,6 +28,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" @@ -84,7 +85,7 @@ class TopkSplitterVisitor : public DfsHloRewriteVisitor { // Split the input into B batches and compute TopK over the batched arrays. Shape split_input_shape = ShapeUtil::MakeShape(data_shape.element_type(), {new_batch, new_n}); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( HloInstruction * reshaped, MakeReshapeHlo(split_input_shape, topk->mutable_operand(0))); Shape batch_topk_shape = ShapeUtil::MakeTupleShape( @@ -95,18 +96,18 @@ class TopkSplitterVisitor : public DfsHloRewriteVisitor { batch_topk_shape, {reshaped}, topk->to_apply(), "TopK", /*opaque=*/"")); // Fix indices, adding j*split_N to the j-th batch of indices. - TF_ASSIGN_OR_RETURN(HloInstruction * indices, - MakeGetTupleElementHlo(batch_topk, 1)); - TF_ASSIGN_OR_RETURN(HloInstruction * values, - MakeGetTupleElementHlo(batch_topk, 0)); + ASSIGN_OR_RETURN(HloInstruction * indices, + MakeGetTupleElementHlo(batch_topk, 1)); + ASSIGN_OR_RETURN(HloInstruction * values, + MakeGetTupleElementHlo(batch_topk, 0)); Shape iota_shape = ShapeUtil::MakeShape(S32, {new_batch}); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( HloInstruction * fix, MakeBinaryHlo( HloOpcode::kMultiply, MakeIotaHlo(comp, iota_shape, 0), MakeBroadcastHlo(MakeR0ConstantHlo(comp, new_n), /*broadcast_dimensions=*/{}, iota_shape))); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( indices, MakeBinaryHlo(HloOpcode::kAdd, indices, MakeBroadcastHlo(fix, {0}, indices->shape()))); // With the indices restored, compute a final top-k. Since this topk uses diff --git a/third_party/xla/xla/backends/gpu/transforms/tree_reduction_rewriter.cc b/third_party/xla/xla/backends/gpu/transforms/tree_reduction_rewriter.cc index 7532755368a5f0..81103d321658b6 100644 --- a/third_party/xla/xla/backends/gpu/transforms/tree_reduction_rewriter.cc +++ b/third_party/xla/xla/backends/gpu/transforms/tree_reduction_rewriter.cc @@ -31,6 +31,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -317,9 +318,8 @@ class ReductionRewriterVisitor : public DfsHloRewriteVisitor { } // Inner reduce that reduces [k1, k2] to [k1]. - TF_ASSIGN_OR_RETURN( - auto tuple_shape, - ShapeUtil::MakeValidatedMaybeTupleShape(inner_reduce_shapes)); + ASSIGN_OR_RETURN(auto tuple_shape, ShapeUtil::MakeValidatedMaybeTupleShape( + inner_reduce_shapes)); HloInstruction *inner_reduce = reduce->parent()->AddInstruction( HloInstruction::CreateReduce(tuple_shape, reshaped_padded_inputs, reduce->init_values(), inner_reduce_dims, @@ -353,8 +353,8 @@ class ReductionRewriterVisitor : public DfsHloRewriteVisitor { ShapeUtil::DeleteDimension(minor_reduction_dim, input->shape())); } - TF_ASSIGN_OR_RETURN(auto tuple_shape, - ShapeUtil::MakeValidatedMaybeTupleShape(tuple_shapes)); + ASSIGN_OR_RETURN(auto tuple_shape, + ShapeUtil::MakeValidatedMaybeTupleShape(tuple_shapes)); HloInstruction *inner_reduce = hlo->parent()->AddInstruction(HloInstruction::CreateReduce( tuple_shape, hlo->inputs(), hlo->init_values(), @@ -374,9 +374,8 @@ absl::StatusOr TreeReductionRewriter::RunImpl( HloModule* module, const absl::flat_hash_set& execution_threads) { VLOG(5) << "Rewriter input: " << module->ToString(); - TF_ASSIGN_OR_RETURN(bool changed, - ReductionRewriterVisitor(device_description_) - .RunOnModule(module, execution_threads)); + ASSIGN_OR_RETURN(bool changed, ReductionRewriterVisitor(device_description_) + .RunOnModule(module, execution_threads)); VLOG(5) << "Rewriter output: " << module->ToString(); return changed; } diff --git a/third_party/xla/xla/backends/gpu/transforms/triangular_solve_rewriter.cc b/third_party/xla/xla/backends/gpu/transforms/triangular_solve_rewriter.cc index ecb51be28f52e9..6ab1eda8b667fc 100644 --- a/third_party/xla/xla/backends/gpu/transforms/triangular_solve_rewriter.cc +++ b/third_party/xla/xla/backends/gpu/transforms/triangular_solve_rewriter.cc @@ -22,6 +22,7 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" @@ -68,7 +69,7 @@ absl::StatusOr TriangularSolveRewriter::RunImpl( comp->AddInstruction(HloInstruction::CreateCustomCall( new_shape, instr->operands(), kTriangularSolveCallTarget)); module->SetAndUniquifyInstrName(custom_call, "triangular-solve"); - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( custom_call->set_backend_config(instr->triangular_solve_options())); // Preserve metadata from `instr`. @@ -76,9 +77,9 @@ absl::StatusOr TriangularSolveRewriter::RunImpl( custom_call->set_frontend_attributes(instr->frontend_attributes()); // Get the actual result out of the custom call's tuple. - TF_ASSIGN_OR_RETURN(HloInstruction * gte, - MakeGetTupleElementHlo(custom_call, 0)); - TF_RETURN_IF_ERROR(comp->ReplaceInstruction(instr, gte)); + ASSIGN_OR_RETURN(HloInstruction * gte, + MakeGetTupleElementHlo(custom_call, 0)); + RETURN_IF_ERROR(comp->ReplaceInstruction(instr, gte)); changed = true; } } diff --git a/third_party/xla/xla/backends/gpu/transforms/triton_fusion_numerics_verifier.cc b/third_party/xla/xla/backends/gpu/transforms/triton_fusion_numerics_verifier.cc index 626cada50f700c..18c4d95cdd1353 100644 --- a/third_party/xla/xla/backends/gpu/transforms/triton_fusion_numerics_verifier.cc +++ b/third_party/xla/xla/backends/gpu/transforms/triton_fusion_numerics_verifier.cc @@ -26,6 +26,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" +#include "xla/tsl/platform/status_macros.h" #include "mlir/IR/MLIRContext.h" #include "xla/backends/autotuner/profiler.h" #include "xla/backends/gpu/autotuner/gpu_codegen_backend.h" @@ -77,8 +78,7 @@ absl::StatusOr AsTritonFusion( return nullptr; } const HloFusionInstruction* fusion = Cast(hlo); - TF_ASSIGN_OR_RETURN(auto gpu_config, - fusion->backend_config()); + ASSIGN_OR_RETURN(auto gpu_config, fusion->backend_config()); const FusionBackendConfig& backend_config = gpu_config.fusion_backend_config(); if (backend_config.kind() == kTritonFusionKind || @@ -105,7 +105,7 @@ absl::Status InlineModuleFusions(HloModule* hlo_module) { // Other emitters might not support them, thus we need to inline all fusions. while (true) { FusionToCallVisitor visitor; - TF_RETURN_IF_ERROR(hlo_module->entry_computation()->Accept(&visitor)); + RETURN_IF_ERROR(hlo_module->entry_computation()->Accept(&visitor)); if (!visitor.changed()) { return absl::OkStatus(); } @@ -113,7 +113,7 @@ absl::Status InlineModuleFusions(HloModule* hlo_module) { pipeline.AddPass(); pipeline.AddPass(/*is_layout_sensitive=*/false); pipeline.AddPass(); - TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status()); + RETURN_IF_ERROR(pipeline.Run(hlo_module).status()); VLOG(2) << "After inline call: " << hlo_module->ToString(); } return absl::OkStatus(); @@ -130,18 +130,18 @@ absl::StatusOr> NewHloModuleFromFusionComputation( std::unique_ptr new_module = ExtractComputationIntoNewModule(*fusion.fused_instructions_computation()); new_module->mutable_config().set_debug_options(debug_opts); - TF_RETURN_IF_ERROR(InlineModuleFusions(new_module.get())); + RETURN_IF_ERROR(InlineModuleFusions(new_module.get())); TreeReductionRewriter tree_reduction_rewriter(gpu_device_info); - TF_RETURN_IF_ERROR(tree_reduction_rewriter.Run(new_module.get()).status()); + RETURN_IF_ERROR(tree_reduction_rewriter.Run(new_module.get()).status()); PriorityFusion fusion_pass( /*thread_pool=*/nullptr, gpu_device_info, alias_info, HloCostAnalysis::Options{}, mlir_context); - TF_RETURN_IF_ERROR(fusion_pass.Run(new_module.get()).status()); + RETURN_IF_ERROR(fusion_pass.Run(new_module.get()).status()); // If the priority fusion pass above skipped some instructions, turn them // into fusions. FusionWrapper fusion_wrapper(gpu_device_info); - TF_RETURN_IF_ERROR(fusion_wrapper.Run(new_module.get()).status()); + RETURN_IF_ERROR(fusion_wrapper.Run(new_module.get()).status()); return new_module; } @@ -165,10 +165,10 @@ absl::Status ForAllTritonFusions( for (HloComputation* computation : module.MakeNonfusionComputations(execution_threads)) { for (HloInstruction* instruction : computation->instructions()) { - TF_ASSIGN_OR_RETURN(auto triton_fusion, AsTritonFusion(instruction)); + ASSIGN_OR_RETURN(auto triton_fusion, AsTritonFusion(instruction)); if (triton_fusion != nullptr) { VLOG(2) << "processing fusion " << triton_fusion->name(); - TF_RETURN_IF_ERROR(fn(*triton_fusion)); + RETURN_IF_ERROR(fn(*triton_fusion)); } } } @@ -190,10 +190,10 @@ absl::StatusOr CompileAndRunFusion( }; DebugOptions adjusted_debug_opts = debug_opts; GpuCodegenBackend::AdjustDebugOptionsForAutotuning(adjusted_debug_opts); - TF_ASSIGN_OR_RETURN(std::unique_ptr new_module, - extractor(adjusted_debug_opts)); + ASSIGN_OR_RETURN(std::unique_ptr new_module, + extractor(adjusted_debug_opts)); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( std::unique_ptr compiler, Compiler::GetForPlatform(stream_executor.GetPlatform()->id())); @@ -201,9 +201,9 @@ absl::StatusOr CompileAndRunFusion( compile_options.device_allocator = allocator; compile_options.embed_hlo_module = false; - TF_ASSIGN_OR_RETURN(std::unique_ptr executable, - compiler->RunBackend(std::move(new_module), - &stream_executor, compile_options)); + ASSIGN_OR_RETURN(std::unique_ptr executable, + compiler->RunBackend(std::move(new_module), &stream_executor, + compile_options)); if (executable == nullptr) { return absl::InternalError("Failed to compile Triton fusion."); @@ -220,11 +220,11 @@ absl::StatusOr CompileAndRunFusion( } } - TF_ASSIGN_OR_RETURN(std::unique_ptr input_buffers, - profiler.CreateInputBuffers(executable.get())); + ASSIGN_OR_RETURN(std::unique_ptr input_buffers, + profiler.CreateInputBuffers(executable.get())); - TF_ASSIGN_OR_RETURN(ProfileResult profile_result, - profiler.Profile(executable.get(), *input_buffers)); + ASSIGN_OR_RETURN(ProfileResult profile_result, + profiler.Profile(executable.get(), *input_buffers)); if (!profile_result.output_buffer.has_value()) { return Internal("Profiling did not return output buffer."); @@ -252,16 +252,16 @@ TritonFusionNumericsVerifier::FusionCacheKey CacheKeyForFusion( absl::Status TritonFusionNumericsVerifier::VerifyTritonFusion( GpuProfiler& profiler, const HloFusionInstruction& fusion, const DebugOptions& debug_opts) { - TF_ASSIGN_OR_RETURN(auto triton_result, - triton_fusion_numerics_pass_internal::CompileAndRunFusion( - profiler, fusion, debug_opts, - /*disable_triton=*/false, stream_executor_, - allocator_, alias_info_, mlir_context_)); - TF_ASSIGN_OR_RETURN(auto emitters_result, - triton_fusion_numerics_pass_internal::CompileAndRunFusion( - profiler, fusion, debug_opts, - /*disable_triton=*/true, stream_executor_, allocator_, - alias_info_, mlir_context_)); + ASSIGN_OR_RETURN(auto triton_result, + triton_fusion_numerics_pass_internal::CompileAndRunFusion( + profiler, fusion, debug_opts, + /*disable_triton=*/false, stream_executor_, allocator_, + alias_info_, mlir_context_)); + ASSIGN_OR_RETURN(auto emitters_result, + triton_fusion_numerics_pass_internal::CompileAndRunFusion( + profiler, fusion, debug_opts, + /*disable_triton=*/true, stream_executor_, allocator_, + alias_info_, mlir_context_)); auto status = profiler.CheckOutputBuffer( triton_result, emitters_result, debug_opts.xla_gpu_autotune_gemm_rtol()); @@ -310,7 +310,7 @@ absl::StatusOr TritonFusionNumericsVerifier::RunImpl( return Internal("Failed to create GpuProfiler."); } - TF_RETURN_IF_ERROR(triton_fusion_numerics_pass_internal::ForAllTritonFusions( + RETURN_IF_ERROR(triton_fusion_numerics_pass_internal::ForAllTritonFusions( *module, execution_threads, [&](const HloFusionInstruction& fusion) -> absl::Status { auto key = CacheKeyForFusion(fusion); diff --git a/third_party/xla/xla/backends/gpu/transforms/variadic_op_splitter.cc b/third_party/xla/xla/backends/gpu/transforms/variadic_op_splitter.cc index 52c4b56faa6e81..fdd12580ac7009 100644 --- a/third_party/xla/xla/backends/gpu/transforms/variadic_op_splitter.cc +++ b/third_party/xla/xla/backends/gpu/transforms/variadic_op_splitter.cc @@ -22,6 +22,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" @@ -79,7 +80,7 @@ absl::StatusOr SplitConcatenate(HloInstruction* concat, } operands_to_split = new_operands; } - TF_RETURN_IF_ERROR(comp->ReplaceInstruction(concat, operands_to_split[0])); + RETURN_IF_ERROR(comp->ReplaceInstruction(concat, operands_to_split[0])); return true; } @@ -104,7 +105,7 @@ absl::StatusOr VariadicOpSplitter::RunImpl( module->MakeNonfusionComputations(execution_threads)) { for (HloInstruction* op : GetRelevantVariadicOps(comp)) { // TODO(b/112613927): Handle also other ops than concatenate. - TF_ASSIGN_OR_RETURN(bool result, SplitConcatenate(op, comp)); + ASSIGN_OR_RETURN(bool result, SplitConcatenate(op, comp)); changed |= result; } } diff --git a/third_party/xla/xla/backends/gpu/transforms/windowed_einsum_handler.cc b/third_party/xla/xla/backends/gpu/transforms/windowed_einsum_handler.cc index 982aa35ea1d291..70736168effe37 100644 --- a/third_party/xla/xla/backends/gpu/transforms/windowed_einsum_handler.cc +++ b/third_party/xla/xla/backends/gpu/transforms/windowed_einsum_handler.cc @@ -33,6 +33,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/strings/string_view.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" @@ -203,7 +204,7 @@ absl::StatusOr ShiftDequantizationF8( // Replace the dequantized dot operands in the parameter tuple used by while // with FP8 operands. for (int k = 0; k < 2; ++k) { - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( param_tuple->ReplaceOperandWithDifferentShape(k, operands[k])); ShapeUtil::UpdateTupleShape(operands[k]->shape(), k, param_tuple->mutable_shape()); @@ -225,17 +226,17 @@ absl::StatusOr ShiftDequantizationF8( // instructions retrieving FP8 dot operands from the input tuple. HloInstruction* body_param = while_body->parameter_instruction(0); for (int k = 0; k < 2; ++k) { - TF_ASSIGN_OR_RETURN(HloInstruction * operand_f8, - MakeGetTupleElementHlo(body_param, k)); + ASSIGN_OR_RETURN(HloInstruction * operand_f8, + MakeGetTupleElementHlo(body_param, k)); if (while_root->operand(k) == gtes[k]) { - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( while_root->ReplaceOperandWithDifferentShape(k, operand_f8)); ShapeUtil::UpdateTupleShape(operand_f8->shape(), k, while_root->mutable_shape()); } - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( HloInstruction * operand_scale, MakeGetTupleElementHlo( body_param, body_param->shape().tuple_shapes().size() - 2 + k)); @@ -250,7 +251,7 @@ absl::StatusOr ShiftDequantizationF8( MakeConvertToHlo(operand_f8, gtes[k]->shape().element_type()); HloInstruction* broadcast_scale = MakeBroadcastHlo(operand_scale, {}, operand_f32->shape()); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( HloInstruction * operand_scaled, MakeBinaryHlo(binaries[k]->opcode(), operand_f32, broadcast_scale)); @@ -260,11 +261,10 @@ absl::StatusOr ShiftDequantizationF8( // exchanged in gemm_rewriter.cc. for (int l = 0; l < 2; ++l) { if (dots[l]->operand(k) == gtes[k]) { - TF_RETURN_IF_ERROR(dots[l]->ReplaceOperandWith(k, operand_scaled)); + RETURN_IF_ERROR(dots[l]->ReplaceOperandWith(k, operand_scaled)); } if (dyn_slices[l] && dyn_slices[l]->operand(0) == gtes[k]) { - TF_RETURN_IF_ERROR( - dyn_slices[l]->ReplaceOperandWith(0, operand_scaled)); + RETURN_IF_ERROR(dyn_slices[l]->ReplaceOperandWith(0, operand_scaled)); } } @@ -285,13 +285,13 @@ absl::StatusOr ShiftDequantizationF8( // Insert the dequantization between coll_perms[0] and dots[1]. HloInstruction* coll_perm0_f32 = MakeConvertToHlo(coll_perms_f8[0], gtes[k]->shape().element_type()); - TF_ASSIGN_OR_RETURN(HloInstruction * x_scaled, - MakeBinaryHlo(binaries[k]->opcode(), coll_perm0_f32, - broadcast_scale)); - TF_RETURN_IF_ERROR(dots[1]->ReplaceOperandWith(0, x_scaled)); + ASSIGN_OR_RETURN(HloInstruction * x_scaled, + MakeBinaryHlo(binaries[k]->opcode(), coll_perm0_f32, + broadcast_scale)); + RETURN_IF_ERROR(dots[1]->ReplaceOperandWith(0, x_scaled)); // Update the output tuple. - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( while_root->ReplaceOperandWithDifferentShape(0, coll_perms_f8[1])); ShapeUtil::UpdateTupleShape(coll_perms_f8[1]->shape(), 0, while_root->mutable_shape()); @@ -301,16 +301,16 @@ absl::StatusOr ShiftDequantizationF8( // Update the shape of the while call in the parent computation. HloInstruction* new_while_instr = while_instr->AddInstruction( while_instr->CloneWithNewShape(while_root->shape())); - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( while_instr->ReplaceAllUsesWithDifferentShape(new_while_instr)); - TF_RETURN_IF_ERROR(while_instr->parent()->RemoveInstruction(while_instr)); + RETURN_IF_ERROR(while_instr->parent()->RemoveInstruction(while_instr)); if (coll_perms[0]) { - TF_RETURN_IF_ERROR(while_body->RemoveInstruction(coll_perms[1])); - TF_RETURN_IF_ERROR(while_body->RemoveInstruction(coll_perms[0])); + RETURN_IF_ERROR(while_body->RemoveInstruction(coll_perms[1])); + RETURN_IF_ERROR(while_body->RemoveInstruction(coll_perms[0])); } - TF_RETURN_IF_ERROR(while_body->RemoveInstruction(gtes[0])); - TF_RETURN_IF_ERROR(while_body->RemoveInstruction(gtes[1])); + RETURN_IF_ERROR(while_body->RemoveInstruction(gtes[0])); + RETURN_IF_ERROR(while_body->RemoveInstruction(gtes[1])); VLOG(5) << "FP8 dequantization moved into while loop."; return new_while_instr; @@ -331,7 +331,7 @@ absl::Status UpdateDotAndConsumerConfig(HloInstruction* dot, auto dot_gpu_config = dot->backend_config(); dot_gpu_config->set_operation_queue_id(stream_id); - TF_RETURN_IF_ERROR(dot->set_backend_config(dot_gpu_config.value())); + RETURN_IF_ERROR(dot->set_backend_config(dot_gpu_config.value())); return absl::OkStatus(); } @@ -341,7 +341,7 @@ absl::Status SetForceDelayForInstruction(HloInstruction* instr, gpu_config->set_force_earliest_schedule(force_delay); - TF_RETURN_IF_ERROR(instr->set_backend_config(gpu_config.value())); + RETURN_IF_ERROR(instr->set_backend_config(gpu_config.value())); return absl::OkStatus(); } @@ -518,7 +518,7 @@ absl::Status ProcessWindowedEinsumLoopForActivationCaching( original_operands.push_back(new_full_buffer_output); HloInstruction* new_output_tuple = while_body->AddInstruction( HloInstruction::CreateTuple(original_operands)); - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( while_body->ReplaceInstructionWithDifferentShape(root, new_output_tuple)); return absl::OkStatus(); @@ -640,7 +640,7 @@ absl::Status MoveAccumulationOutsideLoop( original_operands.push_back(concat); HloInstruction* new_output_tuple = while_body->AddInstruction( HloInstruction::CreateTuple(original_operands)); - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( while_body->ReplaceInstructionWithDifferentShape(root, new_output_tuple)); // Update the shape of the while loop instruction. @@ -664,7 +664,7 @@ absl::Status MoveAccumulationOutsideLoop( }); if (it != loop->users().end()) { original_output_gte = *it; - TF_RETURN_IF_ERROR(original_output_gte->ReplaceAllUsesWith(reduced_result)); + RETURN_IF_ERROR(original_output_gte->ReplaceAllUsesWith(reduced_result)); } return absl::OkStatus(); } @@ -685,13 +685,13 @@ absl::Status PostProcessUnrolledLoop(HloInstruction* loop, int64_t stream_id) { m::CollectivePermute( &matched_cp, m::GetTupleElement(m::Parameter(), force_delay_cp_gte_index)))) { - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( SetForceDelayForInstruction(matched_cp, /*force_delay=*/true)); } if (HloPredicateIsOp(inst)) { // Dispatch the dot to additional compute stream. - TF_RETURN_IF_ERROR(UpdateDotAndConsumerConfig(inst, stream_id)); + RETURN_IF_ERROR(UpdateDotAndConsumerConfig(inst, stream_id)); ++stream_id; } // If dot's result is accumulated, this means we found a loop with @@ -705,7 +705,7 @@ absl::Status PostProcessUnrolledLoop(HloInstruction* loop, int64_t stream_id) { if (partial_accumulations.size() > 0 && absl::StrContains(while_body->name(), WindowedEinsumHandler::kWindowedEinsumAgLoopName)) { - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( MoveAccumulationOutsideLoop(partial_accumulations, while_body, loop)); } return absl::OkStatus(); @@ -794,7 +794,7 @@ class WindowedEinsumVisitor : public DfsHloRewriteVisitor { } } } - TF_RETURN_IF_ERROR(allowed_intermediate_ops.back()->ReplaceOperandWith( + RETURN_IF_ERROR(allowed_intermediate_ops.back()->ReplaceOperandWith( 0, matched_a2a->mutable_operand(0))); HloInstruction* new_a2a = matched_a2a->parent()->AddInstruction(HloInstruction::CreateAllToAll( @@ -803,8 +803,8 @@ class WindowedEinsumVisitor : public DfsHloRewriteVisitor { false, hlo_query::NextChannelId(*matched_a2a->GetModule()), split_dimension)); - TF_RETURN_IF_ERROR(dot->ReplaceOperandWith(0, new_a2a)); - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR(dot->ReplaceOperandWith(0, new_a2a)); + RETURN_IF_ERROR( matched_a2a->parent()->RemoveInstructionAndUnusedOperands(matched_a2a)); MarkAsChanged(); *lhs = new_a2a; @@ -926,8 +926,7 @@ class WindowedEinsumVisitor : public DfsHloRewriteVisitor { "windowed einsum loop : " << loop->ToString(); - TF_RETURN_IF_ERROR( - ProcessWindowedEinsumLoopForActivationCaching(ag_loop)); + RETURN_IF_ERROR(ProcessWindowedEinsumLoopForActivationCaching(ag_loop)); ag_loop.consumed = true; } @@ -945,15 +944,15 @@ class WindowedEinsumVisitor : public DfsHloRewriteVisitor { MakeConvertToHlo(new_gte, binary->shape().element_type()); HloInstruction* bcast_scale = MakeBroadcastHlo(scale, {}, new_convert->shape()); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( new_gte_scaled, MakeBinaryHlo(binary->opcode(), new_convert, bcast_scale)); } - TF_RETURN_IF_ERROR(dot->ReplaceOperandWith( + RETURN_IF_ERROR(dot->ReplaceOperandWith( cache_output_index, scale ? new_gte_scaled : new_gte)); if (all_gather->user_count() == 0) { - TF_RETURN_IF_ERROR(comp->RemoveInstruction(all_gather)); + RETURN_IF_ERROR(comp->RemoveInstruction(all_gather)); } } // Rewrites an all-to-all+gemm into multiple independent partial a2a+gemms @@ -968,8 +967,8 @@ class WindowedEinsumVisitor : public DfsHloRewriteVisitor { HloInstruction* lhs; HloInstruction* rhs; std::vector replica_groups; - TF_ASSIGN_OR_RETURN(bool matched, - MatchA2aGemmWithIntermediateReshapes(dot, &lhs, &rhs)); + ASSIGN_OR_RETURN(bool matched, + MatchA2aGemmWithIntermediateReshapes(dot, &lhs, &rhs)); if (matched) { replica_groups = lhs->replica_groups(); // We split the a2a+gemm along the contracting dimension into multiple @@ -1054,7 +1053,7 @@ class WindowedEinsumVisitor : public DfsHloRewriteVisitor { Shape partial_all_to_all_shape = lhs_slice_shape; - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( Shape partial_dot_shape, ShapeInference::InferDotOpShape( partial_all_to_all_shape, rhs_slice_shape, original_dot_dnums, @@ -1092,10 +1091,9 @@ class WindowedEinsumVisitor : public DfsHloRewriteVisitor { HloInstruction::CreateBinary(partial_dot->shape(), HloOpcode::kAdd, partial_dot, partial_result)); a2a->SetupDerivedInstruction(partial_result); - TF_RETURN_IF_ERROR( - UpdateDotAndConsumerConfig(partial_dot, stream_id++)); + RETURN_IF_ERROR(UpdateDotAndConsumerConfig(partial_dot, stream_id++)); } - TF_RETURN_IF_ERROR(ReplaceInstruction(dot, partial_result)); + RETURN_IF_ERROR(ReplaceInstruction(dot, partial_result)); } return absl::OkStatus(); } @@ -1172,12 +1170,11 @@ class WindowedEinsumVisitor : public DfsHloRewriteVisitor { matched_dot->shape(), {matched_dot}, a2a->replica_groups(), false, hlo_query::NextChannelId(*matched_dot->GetModule()), split_dimension)); - TF_RETURN_IF_ERROR(allowed_intermediate_ops.back()->ReplaceOperandWith( + RETURN_IF_ERROR(allowed_intermediate_ops.back()->ReplaceOperandWith( 0, result.a2a_replacement)); inst->SetupDerivedInstruction(result.a2a_replacement); - TF_RETURN_IF_ERROR( - ReplaceInstruction(inst, allowed_intermediate_ops.front())); + RETURN_IF_ERROR(ReplaceInstruction(inst, allowed_intermediate_ops.front())); result.lhs = matched_dot->mutable_operand(0); result.rhs = matched_dot->mutable_operand(1); result.producer_gemm = matched_dot; @@ -1200,8 +1197,8 @@ class WindowedEinsumVisitor : public DfsHloRewriteVisitor { // Rewrites a gemm+alltoall into multiple independent partial gemm+a2as // to minimize communication overhead. std::vector replica_groups; - TF_ASSIGN_OR_RETURN(MatchedGemmA2aResult matched_result, - MatchGemmA2aWithIntermediateReshapes(inst)); + ASSIGN_OR_RETURN(MatchedGemmA2aResult matched_result, + MatchGemmA2aWithIntermediateReshapes(inst)); if (matched_result.matched) { HloInstruction* a2a = inst; if (matched_result.a2a_replacement) { @@ -1294,11 +1291,10 @@ class WindowedEinsumVisitor : public DfsHloRewriteVisitor { HloInstruction* partial_result = output_buffer; Shape partial_all_to_all_shape = all_to_all->shape(); - TF_ASSIGN_OR_RETURN( - Shape partial_dot_shape, - ShapeInference::InferDotOpShape( - lhs_slice_shape, rhs_slice_shape, original_dot_dnums, - /*preferred_element_type=*/std::nullopt)); + ASSIGN_OR_RETURN(Shape partial_dot_shape, + ShapeInference::InferDotOpShape( + lhs_slice_shape, rhs_slice_shape, original_dot_dnums, + /*preferred_element_type=*/std::nullopt)); int64_t stream_id = hlo_query::NextChannelId(*all_to_all->GetModule()); for (int64_t i = 0; i < group_size; ++i) { lhs_slice = comp->AddInstruction(HloInstruction::CreateSlice( @@ -1333,10 +1329,9 @@ class WindowedEinsumVisitor : public DfsHloRewriteVisitor { partial_all_to_all_shape, HloOpcode::kAdd, partial_all_to_all, partial_result)); all_to_all->SetupDerivedInstruction(partial_result); - TF_RETURN_IF_ERROR( - UpdateDotAndConsumerConfig(partial_dot, stream_id++)); + RETURN_IF_ERROR(UpdateDotAndConsumerConfig(partial_dot, stream_id++)); } - TF_RETURN_IF_ERROR(ReplaceInstruction(all_to_all, partial_result)); + RETURN_IF_ERROR(ReplaceInstruction(all_to_all, partial_result)); } return absl::OkStatus(); @@ -1376,7 +1371,7 @@ absl::StatusOr WindowedEinsumHandler::RunImpl( } auto* while_op = *maybe_while_op; - TF_ASSIGN_OR_RETURN(auto maybe_new_op, ShiftDequantizationF8(comp)); + ASSIGN_OR_RETURN(auto maybe_new_op, ShiftDequantizationF8(comp)); if (maybe_new_op) { changed = true; while_op = maybe_new_op; @@ -1392,7 +1387,7 @@ absl::StatusOr WindowedEinsumHandler::RunImpl( for (HloComputation* comp : module->MakeNonfusionComputations(execution_threads)) { WindowedEinsumVisitor visitor(all_ag_loops_); - TF_RETURN_IF_ERROR(comp->Accept(&visitor)); + RETURN_IF_ERROR(comp->Accept(&visitor)); changed |= visitor.changed(); } @@ -1405,11 +1400,11 @@ absl::StatusOr WindowedEinsumHandler::RunImpl( // expects until the passes are applied. AlgebraicSimplifierOptions options; options.set_run_to_fixed_point(false); - TF_ASSIGN_OR_RETURN(bool applied_algsimp, AlgebraicSimplifier(options).Run( - module, execution_threads)); + ASSIGN_OR_RETURN(bool applied_algsimp, AlgebraicSimplifier(options).Run( + module, execution_threads)); changed |= applied_algsimp; - TF_ASSIGN_OR_RETURN(bool applied_cf, - HloConstantFolding().Run(module, execution_threads)); + ASSIGN_OR_RETURN(bool applied_cf, + HloConstantFolding().Run(module, execution_threads)); changed |= applied_cf; } for (HloInstruction* loop : all_windowed_einsum_loops) { @@ -1427,7 +1422,7 @@ absl::StatusOr WindowedEinsumHandler::RunImpl( // We also need to keep the unrolled instructions in an isolated computation // unit such as a trivial loop so instructions here won't be fused with // other instructions later to disrupt the gemm-gemm overlap. - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( UnrollResult result, WhileLoopUnroller::UnrollAndReturnReplacement( loop, /*unroll_factor=*/-1, /*wrap_in_trivial_loop=*/true, @@ -1438,7 +1433,7 @@ absl::StatusOr WindowedEinsumHandler::RunImpl( // unrolled which leaves the call graph non-flat. This is likely not the // optimal way to do things, but it preserves the previous behavior of // UnrollAndReturnReplacement which used to do it internally. - TF_RETURN_IF_ERROR(FlattenCallGraph().Run(module).status()); + RETURN_IF_ERROR(FlattenCallGraph().Run(module).status()); result.new_while_op->while_body()->SetAndSanitizeName( absl::StrCat("unrolled_", original_body_name)); @@ -1449,8 +1444,7 @@ absl::StatusOr WindowedEinsumHandler::RunImpl( // we add this attribute to it. result.new_while_op->set_frontend_attribute( "skip-simplify-while-loops_trip-count-one", "true"); - TF_RETURN_IF_ERROR( - PostProcessUnrolledLoop(result.new_while_op, stream_id)); + RETURN_IF_ERROR(PostProcessUnrolledLoop(result.new_while_op, stream_id)); } changed |= result.unrolled; } diff --git a/third_party/xla/xla/backends/interpreter/BUILD b/third_party/xla/xla/backends/interpreter/BUILD index 21cd43c8a00a73..b94fdf5a0cce4b 100644 --- a/third_party/xla/xla/backends/interpreter/BUILD +++ b/third_party/xla/xla/backends/interpreter/BUILD @@ -52,6 +52,7 @@ cc_library( "//xla/stream_executor:platform", "//xla/stream_executor:stream_executor_h", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/log", "@com_google_absl//absl/status", @@ -94,6 +95,7 @@ cc_library( "//xla/stream_executor:stream_executor_h", "//xla/tsl/platform:env", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/log", "@com_google_absl//absl/log:vlog_is_on", diff --git a/third_party/xla/xla/backends/interpreter/compiler.cc b/third_party/xla/xla/backends/interpreter/compiler.cc index 4eb90e2dc10b06..fa06942a4007fe 100644 --- a/third_party/xla/xla/backends/interpreter/compiler.cc +++ b/third_party/xla/xla/backends/interpreter/compiler.cc @@ -24,6 +24,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/backends/interpreter/executable.h" #include "xla/backends/interpreter/platform_id.h" #include "xla/hlo/evaluator/hlo_evaluator.h" @@ -112,7 +113,7 @@ absl::StatusOr> InterpreterCompiler::RunHloPasses( std::unique_ptr hlo_module, se::StreamExecutor* /*stream_exec*/, const CompileOptions& /*options*/) { VLOG(1) << "Run hlo passes on graph " << hlo_module->name(); - TF_RETURN_IF_ERROR(RunHloOptimization(hlo_module.get())); + RETURN_IF_ERROR(RunHloOptimization(hlo_module.get())); return std::move(hlo_module); } @@ -123,7 +124,7 @@ absl::StatusOr> InterpreterCompiler::RunBackend( VLOG(1) << "Run backend " << hlo_module->name(); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( DynamicDimensionInference dynamic_dimension_inference, DynamicDimensionInference::Run( hlo_module.get(), @@ -149,10 +150,10 @@ absl::StatusOr>> InterpreterCompiler::Compile(std::unique_ptr hlo_module, std::vector stream_exec, const CompileOptions& options) { - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( hlo_module, RunHloPasses(std::move(hlo_module), stream_exec[0], options)); - TF_ASSIGN_OR_RETURN(auto executable, RunBackend(std::move(hlo_module), - stream_exec[0], options)); + ASSIGN_OR_RETURN(auto executable, + RunBackend(std::move(hlo_module), stream_exec[0], options)); std::vector> ret; ret.push_back(std::move(executable)); return std::move(ret); diff --git a/third_party/xla/xla/backends/interpreter/executable_base.cc b/third_party/xla/xla/backends/interpreter/executable_base.cc index b2def3c2dd0755..dea9b658939804 100644 --- a/third_party/xla/xla/backends/interpreter/executable_base.cc +++ b/third_party/xla/xla/backends/interpreter/executable_base.cc @@ -27,6 +27,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_input_output_alias_config.h" #include "xla/layout_util.h" @@ -121,8 +122,8 @@ absl::StatusOr InterpreterExecutableBase::ExecuteAsyncOnStream( } } - TF_ASSIGN_OR_RETURN(TransferManager * transfer_manager, - TransferManager::GetForPlatform(platform)); + ASSIGN_OR_RETURN(TransferManager * transfer_manager, + TransferManager::GetForPlatform(platform)); // Transform the ShapedBuffer arguments into literals which the evaluator // consumes. @@ -130,9 +131,9 @@ absl::StatusOr InterpreterExecutableBase::ExecuteAsyncOnStream( const int64_t num_parameters = computation->num_parameters(); arg_literals.reserve(num_parameters); for (int64_t p = 0; p < num_parameters; ++p) { - TF_ASSIGN_OR_RETURN(Literal arg_literal, - transfer_manager->TransferLiteralFromDevice( - run_options->stream(), argument_buffers[p])); + ASSIGN_OR_RETURN(Literal arg_literal, + transfer_manager->TransferLiteralFromDevice( + run_options->stream(), argument_buffers[p])); const auto& expected_shape = computation->parameter_instruction(p)->shape(); if (expected_shape.is_dynamic()) { // Expand the input literal to expected shape. @@ -141,8 +142,8 @@ absl::StatusOr InterpreterExecutableBase::ExecuteAsyncOnStream( arg_literals.push_back(std::move(arg_literal)); } - TF_ASSIGN_OR_RETURN(Literal result_literal, - Evaluate(run_options, *computation, arg_literals)); + ASSIGN_OR_RETURN(Literal result_literal, + Evaluate(run_options, *computation, arg_literals)); // Shrink the generated dynamic shape into static shape. result_literal = result_literal.ToStatic(); @@ -150,12 +151,12 @@ absl::StatusOr InterpreterExecutableBase::ExecuteAsyncOnStream( const HloInputOutputAliasConfig& alias_config = has_module() ? module().input_output_alias_config() : HloInputOutputAliasConfig(); - TF_ASSIGN_OR_RETURN(ExecutionOutput result, - AllocateOutputMemoryWithInputReuse( - result_literal.shape(), alias_config, - run_options->allocator(), &arguments, stream)); + ASSIGN_OR_RETURN(ExecutionOutput result, + AllocateOutputMemoryWithInputReuse( + result_literal.shape(), alias_config, + run_options->allocator(), &arguments, stream)); - TF_RETURN_IF_ERROR(transfer_manager->TransferLiteralToDevice( + RETURN_IF_ERROR(transfer_manager->TransferLiteralToDevice( run_options->stream(), result_literal, result.Result())); uint64_t end_micros = tsl::Env::Default()->NowMicros(); @@ -174,7 +175,7 @@ InterpreterExecutableBase::AllocateOutputMemoryWithInputReuse( const Shape& shape, const HloInputOutputAliasConfig& alias_config, se::DeviceAddressAllocator* allocator, std::vector* arguments, se::Stream* stream) { - TF_RETURN_IF_ERROR(alias_config.ForEachAliasWithStatus( + RETURN_IF_ERROR(alias_config.ForEachAliasWithStatus( [&](const ShapeIndex& output_index, std::optional alias) -> absl::Status { @@ -195,8 +196,8 @@ InterpreterExecutableBase::AllocateOutputMemoryWithInputReuse( se::StreamExecutor* executor = stream->parent(); const se::Platform* platform = executor->GetPlatform(); - TF_ASSIGN_OR_RETURN(TransferManager * transfer_manager, - TransferManager::GetForPlatform(platform)); + ASSIGN_OR_RETURN(TransferManager * transfer_manager, + TransferManager::GetForPlatform(platform)); ExecutionOutput result(shape, allocator, executor->device_ordinal()); for (auto& pair : result.MutableResult()->buffers()) { @@ -232,7 +233,7 @@ InterpreterExecutableBase::AllocateOutputMemoryWithInputReuse( const Shape& on_device_shape = result.Result().on_device_shape(); const Shape& on_device_subshape = ShapeUtil::GetSubshape(on_device_shape, result_index); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( auto allocated_buffer, allocator->Allocate(executor->device_ordinal(), allocation_bytes, /*retry_on_failure=*/true, @@ -242,7 +243,7 @@ InterpreterExecutableBase::AllocateOutputMemoryWithInputReuse( TF_RET_CHECK(allocation_bytes == 0 || result_buffer != nullptr); } - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( transfer_manager->WriteTupleIndexTables(stream, result.Result())); return std::move(result); } diff --git a/third_party/xla/xla/backends/profiler/subprocess/BUILD b/third_party/xla/xla/backends/profiler/subprocess/BUILD index 58556f36834a66..ded64e4c3f4611 100644 --- a/third_party/xla/xla/backends/profiler/subprocess/BUILD +++ b/third_party/xla/xla/backends/profiler/subprocess/BUILD @@ -64,6 +64,7 @@ cc_library( deps = [ ":subprocess_registry", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "//xla/tsl/profiler/utils:timestamp_utils", "//xla/tsl/profiler/utils:xplane_schema", diff --git a/third_party/xla/xla/backends/profiler/subprocess/subprocess_profiling_session.cc b/third_party/xla/xla/backends/profiler/subprocess/subprocess_profiling_session.cc index deff5952a11f3a..54f89b7bd8b81e 100644 --- a/third_party/xla/xla/backends/profiler/subprocess/subprocess_profiling_session.cc +++ b/third_party/xla/xla/backends/profiler/subprocess/subprocess_profiling_session.cc @@ -29,6 +29,7 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/time/clock.h" #include "absl/time/time.h" +#include "xla/tsl/platform/status_macros.h" #include "grpcpp/client_context.h" #include "grpcpp/support/status.h" #include "xla/backends/profiler/subprocess/subprocess_registry.h" @@ -106,7 +107,7 @@ absl::Status SubprocessProfilingSession::Stop() { terminate_request.set_session_id(request_.session_id()); tensorflow::TerminateResponse terminate_response; grpc::ClientContext context; - TF_RETURN_IF_ERROR(FromGrpcStatus(subprocess_info_.profiler_stub->Terminate( + RETURN_IF_ERROR(FromGrpcStatus(subprocess_info_.profiler_stub->Terminate( &context, terminate_request, &terminate_response))); // Wait for the response from the AsyncProfile+Finish calls. @@ -119,7 +120,7 @@ absl::Status SubprocessProfilingSession::Stop() { if (!success || !ok || got_tag != (void*)1) { return absl::InternalError("Failed to get response from profiler service"); } - TF_RETURN_IF_ERROR(FromGrpcStatus(grpc_status_)); + RETURN_IF_ERROR(FromGrpcStatus(grpc_status_)); return absl::OkStatus(); } diff --git a/third_party/xla/xla/codegen/BUILD b/third_party/xla/xla/codegen/BUILD index 2126454edd4f01..f654260aa3f35b 100644 --- a/third_party/xla/xla/codegen/BUILD +++ b/third_party/xla/xla/codegen/BUILD @@ -29,6 +29,7 @@ cc_library( deps = [ ":kernel_definition", ":kernel_source", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", @@ -125,6 +126,7 @@ cc_library( "//xla/hlo/ir:hlo", "//xla/hlo/utils:hlo_traversal", "//xla/service:buffer_assignment", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_set", diff --git a/third_party/xla/xla/codegen/intrinsic/BUILD b/third_party/xla/xla/codegen/intrinsic/BUILD index 0156b3ad87c33a..061f75c03f7206 100644 --- a/third_party/xla/xla/codegen/intrinsic/BUILD +++ b/third_party/xla/xla/codegen/intrinsic/BUILD @@ -62,6 +62,7 @@ cc_library( "//xla:xla_data_proto_cc", "//xla/service/llvm_ir:llvm_util", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", diff --git a/third_party/xla/xla/codegen/intrinsic/fptrunc.cc b/third_party/xla/xla/codegen/intrinsic/fptrunc.cc index d7430e665857e9..044aaca186bb50 100644 --- a/third_party/xla/xla/codegen/intrinsic/fptrunc.cc +++ b/third_party/xla/xla/codegen/intrinsic/fptrunc.cc @@ -24,6 +24,7 @@ limitations under the License. #include "absl/log/check.h" #include "absl/log/log.h" #include "absl/status/status.h" +#include "xla/tsl/platform/status_macros.h" #include "llvm/ADT/APFloat.h" #include "llvm/ADT/APInt.h" #include "llvm/ADT/FloatingPointMode.h" @@ -278,7 +279,7 @@ absl::StatusOr EmitFxxToF8E(llvm::Module* module, if (!from.is_scalar() && !from.is_vector()) { return absl::InvalidArgumentError("from_type must be a scalar or vector."); } - TF_RETURN_IF_ERROR(Type::VerifySameWidth(from, to)); + RETURN_IF_ERROR(Type::VerifySameWidth(from, to)); llvm::Function* func = CreateFunction(module, from, to); llvm::BasicBlock* entry_bb = llvm::BasicBlock::Create(context, "entry", func); @@ -348,7 +349,7 @@ absl::StatusOr EmitFxxToF8E(llvm::Module* module, // we can delegate all logic to EmitReducePrecisionIR and do a simple shift. if (fx_bias == f8_bias && fx_exp_bits == f8_exp_bits) { LOG(INFO) << "Using fast path for " << from.name() << " -> " << to.name(); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( llvm::Value * reduced_precision, llvm_ir::EmitReducePrecisionIR( /*src_ty=*/fx_type, from_value, @@ -511,7 +512,7 @@ absl::StatusOr EmitFxxToF8E(llvm::Module* module, absl::StatusOr FpTrunc::CreateDefinition(llvm::Module* module, Type from, Type to) { - TF_RETURN_IF_ERROR(Type::VerifySameWidth(from, to)); + RETURN_IF_ERROR(Type::VerifySameWidth(from, to)); if (primitive_util::IsF8Type(to.element_type()) && (from.element_type() == F16 || from.element_type() == F32 || diff --git a/third_party/xla/xla/codegen/ir_emission_utils.cc b/third_party/xla/xla/codegen/ir_emission_utils.cc index 58ed76d4ee1a1d..00b733d5fa3f5e 100644 --- a/third_party/xla/xla/codegen/ir_emission_utils.cc +++ b/third_party/xla/xla/codegen/ir_emission_utils.cc @@ -26,6 +26,7 @@ limitations under the License. #include "absl/functional/any_invocable.h" #include "absl/status/statusor.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/codegen/hlo_fusion_spec.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -274,11 +275,11 @@ absl::StatusOr CanEmitFusedDynamicUpdateSliceInPlace( root_index = {i}; } // Get output buffer for the fusion root. - TF_ASSIGN_OR_RETURN(BufferAllocation::Slice output_buffer, - get_allocation_slice(fusion, root_index)); + ASSIGN_OR_RETURN(BufferAllocation::Slice output_buffer, + get_allocation_slice(fusion, root_index)); - TF_ASSIGN_OR_RETURN(BufferAllocation::Slice lhs_buffer, - get_allocation_slice(&operand.instruction(), {})); + ASSIGN_OR_RETURN(BufferAllocation::Slice lhs_buffer, + get_allocation_slice(&operand.instruction(), {})); if (lhs_buffer != output_buffer) { return false; } diff --git a/third_party/xla/xla/codegen/kernel_emitter.h b/third_party/xla/xla/codegen/kernel_emitter.h index 03c380c9cb5c55..251ecef89b3144 100644 --- a/third_party/xla/xla/codegen/kernel_emitter.h +++ b/third_party/xla/xla/codegen/kernel_emitter.h @@ -20,6 +20,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/codegen/kernel_definition.h" #include "xla/codegen/kernel_source.h" #include "xla/tsl/platform/statusor.h" @@ -64,7 +65,7 @@ class KernelEmitter : public KernelEmitterBase { private: absl::StatusOr> EmitKernelDefinitionBase() final { - TF_ASSIGN_OR_RETURN(auto kernel_definition, EmitKernelDefinition()); + ASSIGN_OR_RETURN(auto kernel_definition, EmitKernelDefinition()); return std::make_unique(std::move(kernel_definition)); } }; diff --git a/third_party/xla/xla/codegen/tiling/experimental/tile.cc b/third_party/xla/xla/codegen/tiling/experimental/tile.cc index c9355c02da2e95..9dd182a2c98806 100644 --- a/third_party/xla/xla/codegen/tiling/experimental/tile.cc +++ b/third_party/xla/xla/codegen/tiling/experimental/tile.cc @@ -16,7 +16,6 @@ limitations under the License. #include "xla/codegen/tiling/experimental/tile.h" #include -#include #include #include diff --git a/third_party/xla/xla/codegen/xtile/codegen/BUILD b/third_party/xla/xla/codegen/xtile/codegen/BUILD index 4f3e33e90ae471..5294dbd399b832 100644 --- a/third_party/xla/xla/codegen/xtile/codegen/BUILD +++ b/third_party/xla/xla/codegen/xtile/codegen/BUILD @@ -51,10 +51,12 @@ cc_library( "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", "@llvm-project//llvm:Support", "@llvm-project//mlir:ArithDialect", "@llvm-project//mlir:ArithUtils", + "@llvm-project//mlir:DialectUtils", "@llvm-project//mlir:IR", "@llvm-project//mlir:MathDialect", "@llvm-project//mlir:Support", @@ -157,6 +159,7 @@ cc_library( "//xla/hlo/ir:hlo", "//xla/hlo/translate/hlo_to_mhlo:attribute_importer", "//xla/service:algorithm_util", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/log", diff --git a/third_party/xla/xla/codegen/xtile/codegen/dot_algorithms.cc b/third_party/xla/xla/codegen/xtile/codegen/dot_algorithms.cc index 66aa882460a072..3cf9df39cf8737 100644 --- a/third_party/xla/xla/codegen/xtile/codegen/dot_algorithms.cc +++ b/third_party/xla/xla/codegen/xtile/codegen/dot_algorithms.cc @@ -27,6 +27,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" +#include "xla/tsl/platform/status_macros.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinTypeInterfaces.h" @@ -143,14 +144,14 @@ Value EmitStableHloDotAndAdd(mlir::ImplicitLocOpBuilder& b, Value lhs, absl::StatusOr GetAlgUnsetAccumulatorType(mlir::ImplicitLocOpBuilder& b, const HloDotInstruction& dot) { - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( Type lhs_type, PrimitiveTypeToMlirType(b, dot.operand(0)->shape().element_type())); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( Type rhs_type, PrimitiveTypeToMlirType(b, dot.operand(1)->shape().element_type())); - TF_ASSIGN_OR_RETURN(Type accumulator_type, - PrimitiveTypeToMlirType(b, dot.shape().element_type())); + ASSIGN_OR_RETURN(Type accumulator_type, + PrimitiveTypeToMlirType(b, dot.shape().element_type())); // The code below assumes that lhs and rhs have the same type. However // this may not always be the case with f8 matmuls, e.g. e4m3×e5m2 is @@ -172,10 +173,10 @@ absl::StatusOr GetAlgUnsetAccumulatorType(mlir::ImplicitLocOpBuilder& b, absl::StatusOr> DotDefaultOperandsType( mlir::ImplicitLocOpBuilder& b, const HloDotInstruction& dot) { - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( Type lhs_type, PrimitiveTypeToMlirType(b, dot.operand(0)->shape().element_type())); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( Type rhs_type, PrimitiveTypeToMlirType(b, dot.operand(1)->shape().element_type())); @@ -204,7 +205,7 @@ absl::StatusOr> GetForceOperandsType( return DotDefaultOperandsType(b, dot); } - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( std::vector allowed_operands_primitive_types, algorithm_util::GetAllowedOperandsTypeForAlgorithm(algorithm)); CHECK(!allowed_operands_primitive_types.empty()); @@ -212,7 +213,7 @@ absl::StatusOr> GetForceOperandsType( std::vector allowed_operands_types; allowed_operands_types.reserve(allowed_operands_primitive_types.size()); for (PrimitiveType primitive_type : allowed_operands_primitive_types) { - TF_ASSIGN_OR_RETURN(Type type, PrimitiveTypeToMlirType(b, primitive_type)); + ASSIGN_OR_RETURN(Type type, PrimitiveTypeToMlirType(b, primitive_type)); allowed_operands_types.push_back(type); } @@ -252,8 +253,8 @@ absl::StatusOr GetDotAccumulatorType(mlir::ImplicitLocOpBuilder& b, return GetAlgUnsetAccumulatorType(b, dot); } - TF_ASSIGN_OR_RETURN(PrimitiveType accumulator_type, - algorithm_util::GetDotAccumulatorType(algorithm)); + ASSIGN_OR_RETURN(PrimitiveType accumulator_type, + algorithm_util::GetDotAccumulatorType(algorithm)); return PrimitiveTypeToMlirType(b, accumulator_type); } @@ -268,11 +269,10 @@ absl::StatusOr EmitSingleTileDot(mlir::ImplicitLocOpBuilder& b, XlaPrecisionToStableHloPrecision( dot.precision_config().operand_precision(1))}; - TF_ASSIGN_OR_RETURN(std::optional force_operands_type, - GetForceOperandsType(b, dot, dot_operands)); + ASSIGN_OR_RETURN(std::optional force_operands_type, + GetForceOperandsType(b, dot, dot_operands)); - TF_ASSIGN_OR_RETURN(Type force_accumulator_type, - GetDotAccumulatorType(b, dot)); + ASSIGN_OR_RETURN(Type force_accumulator_type, GetDotAccumulatorType(b, dot)); if (force_operands_type.has_value()) { if (ElementType(dot_operands.lhs) != *force_operands_type) { diff --git a/third_party/xla/xla/codegen/xtile/codegen/emitter_helpers.cc b/third_party/xla/xla/codegen/xtile/codegen/emitter_helpers.cc index ff4e7b894abdd2..831c06f8f9ec43 100644 --- a/third_party/xla/xla/codegen/xtile/codegen/emitter_helpers.cc +++ b/third_party/xla/xla/codegen/xtile/codegen/emitter_helpers.cc @@ -27,6 +27,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" #include "absl/strings/str_join.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" @@ -40,6 +41,7 @@ limitations under the License. #include "mlir/Dialect/Arith/Utils/Utils.h" #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinOps.h" @@ -608,8 +610,10 @@ Value Bitcast(mlir::ImplicitLocOpBuilder& b, Value value, Type type) { auto tile_strides = tiled_hlo.tile_strides(); auto minor_to_major_layout = llvm::to_vector(LayoutUtil::MinorToMajor(shape)); + // Replica id is only supported for ge::TiledHloInstruction. return TileInfo(offsets, tile_strides, original_shape, padded_tile_sizes, - minor_to_major_layout, storage_type); + minor_to_major_layout, storage_type, + /*replica_id_offsets=*/{}, /*replica_id_bounds=*/{}); } /*static */ absl::StatusOr TileInfo::Construct( @@ -635,19 +639,76 @@ Value Bitcast(mlir::ImplicitLocOpBuilder& b, Value value, Type type) { auto storage_type = StorageType(expected_element_type); auto minor_to_major_layout = llvm::to_vector(LayoutUtil::MinorToMajor(shape)); + SmallVector replica_id_offsets; + SmallVector replica_id_bounds; + if (!tiled_hlo.tile().replica_ids().empty()) { + llvm::SmallVector offset_exprs; + llvm::SmallVector bound_exprs; + offset_exprs.reserve(tiled_hlo.tile().replica_ids().size()); + bound_exprs.reserve(tiled_hlo.tile().replica_ids().size()); + for (const auto& replica_id : tiled_hlo.tile().replica_ids()) { + offset_exprs.push_back(replica_id.offset); + bound_exprs.push_back(replica_id.upper_bound); + } + ASSIGN_OR_RETURN(mlir::SmallVector evaluated_offsets, + emitter_ctx.EvaluateTilingParameters(offset_exprs)); + ASSIGN_OR_RETURN(mlir::SmallVector evaluated_bounds, + emitter_ctx.EvaluateTilingParameters(bound_exprs)); + replica_id_offsets = std::move(evaluated_offsets); + replica_id_bounds = std::move(evaluated_bounds); + } + + return TileInfo(std::move(offsets), std::move(tile_strides), + std::move(original_shape), std::move(tile_sizes), + std::move(minor_to_major_layout), storage_type, + std::move(replica_id_offsets), std::move(replica_id_bounds)); +} - return TileInfo(offsets, tile_strides, original_shape, tile_sizes, - minor_to_major_layout, storage_type); +absl::StatusOr GetConstantIntValue(mlir::Value value) { + if (std::optional int_value = mlir::getConstantIntValue(value); + int_value.has_value()) { + return int_value.value(); + } + return absl::InternalError(absl::StrFormat( + "Expected constant integer value for replica ID bound, but got: %v", + value)); } -TensorValue EmitParameterExtract(mlir::ImplicitLocOpBuilder& b, - const TileInfo& tile_info, Value arg) { +absl::StatusOr EmitParameterExtract(mlir::ImplicitLocOpBuilder& b, + const TileInfo& tile_info, + Value arg) { auto tensor_type = mlir::RankedTensorType::get(tile_info.padded_tile_sizes(), tile_info.storage_type()); - + mlir::Value source_buffer = arg; + if (!tile_info.replica_id_offsets().empty()) { + const auto& replica_id_offsets = tile_info.replica_id_offsets(); + const auto& replica_id_bounds = tile_info.replica_id_bounds(); + CHECK_EQ(replica_id_offsets.size(), replica_id_bounds.size()); + const int num_replica_dims = replica_id_offsets.size(); + for (int i = 0; i < num_replica_dims - 1; ++i) { + mlir::Value replica_id = replica_id_offsets[i]; + ASSIGN_OR_RETURN(int64_t next_bound, + GetConstantIntValue(replica_id_bounds[i + 1])); + mlir::Type next_buffer_type = + mlir::MemRefType::get({next_bound}, b.getI64Type()); + source_buffer = b.create( + next_buffer_type, source_buffer, replica_id); + } + // Final selection to obtain the spatial buffer + mlir::Value replica_id = replica_id_offsets.back(); + ASSIGN_OR_RETURN(PrimitiveType element_type, + GetPrimitiveType(tile_info.storage_type())); + xla::Shape spatial_shape = xla::ShapeUtil::MakeShapeWithDenseLayout( + element_type, tile_info.original_shape(), + tile_info.minor_to_major_layout()); + mlir::Type spatial_memref_type = + GetMemRefType(spatial_shape, tile_info.storage_type()); + source_buffer = b.create(spatial_memref_type, + source_buffer, replica_id); + } return xla::xtile::ExtractTileOp::create( - b, tensor_type, arg, tile_info.offsets(), tile_info.padded_tile_sizes(), - tile_info.tile_strides()); + b, tensor_type, source_buffer, tile_info.offsets(), + tile_info.padded_tile_sizes(), tile_info.tile_strides()); } absl::StatusOr EmitScope( @@ -825,7 +886,8 @@ absl::StatusOr GetMlirType( absl::StatusOr> GetFnArgTypes( mlir::ImplicitLocOpBuilder& b, const HloFusionInstruction& fusion, absl::Span opaque_args_types, - const std::optional& gpu_cc) { + const std::optional& gpu_cc, + const DefaultTileRequirementsVisitor& tile_requirements_visitor) { SmallVector fn_arg_types; auto hlo_computation = fusion.fused_instructions_computation(); @@ -833,7 +895,18 @@ absl::StatusOr> GetFnArgTypes( for (HloInstruction* p : hlo_computation->parameter_instructions()) { ASSIGN_OR_RETURN(Type ir_type, GetMlirType(b, p->shape().element_type(), gpu_cc)); - fn_arg_types.push_back(GetMemRefType(p->shape(), ir_type)); + ASSIGN_OR_RETURN(SmallVector replica_id_bounds, + tile_requirements_visitor.RequiredReplicaIdBounds(*p)); + if (!replica_id_bounds.empty()) { + // Nested pointer schema for replica dimensions. + // R x S x where R is the number of replica dimensions and S is + // the shape on the local device. In total we have R pointers to + // S-dimensional tensors. + fn_arg_types.push_back( + mlir::MemRefType::get({replica_id_bounds.front()}, b.getI64Type())); + } else { + fn_arg_types.push_back(GetMemRefType(p->shape(), ir_type)); + } } // Add result types. diff --git a/third_party/xla/xla/codegen/xtile/codegen/emitter_helpers.h b/third_party/xla/xla/codegen/xtile/codegen/emitter_helpers.h index b8c4a06c17354a..03707d0f839054 100644 --- a/third_party/xla/xla/codegen/xtile/codegen/emitter_helpers.h +++ b/third_party/xla/xla/codegen/xtile/codegen/emitter_helpers.h @@ -50,6 +50,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_opcode.h" #include "xla/literal.h" #include "xla/service/llvm_ir/llvm_util.h" +#include "xla/shape.h" #include "xla/shape_util.h" #include "xla/stream_executor/device_description.h" #include "xla/xla.pb.h" @@ -142,26 +143,34 @@ class TileInfo { const gpu::experimental::TiledHloInstruction& tiled_hlo); // Tile offsets. Its size is equal to the rank of the output shape. - inline mlir::ValueRange offsets() const { return offsets_; } + mlir::ValueRange offsets() const { return offsets_; } // Tile strides. Its size is equal to the rank of the output shape. - inline mlir::ArrayRef tile_strides() const { return tile_strides_; } + mlir::ArrayRef tile_strides() const { return tile_strides_; } // The original shape of the tensor. - inline mlir::ArrayRef original_shape() const { - return original_shape_; - } + mlir::ArrayRef original_shape() const { return original_shape_; } // Tile sizes after padding to a power of 2 (Triton requirement). - inline mlir::ArrayRef padded_tile_sizes() const { + mlir::ArrayRef padded_tile_sizes() const { return padded_tile_sizes_; } // The layout of the tensor in minor-to-major order. - inline const llvm::SmallVector& minor_to_major_layout() const { + const llvm::SmallVector& minor_to_major_layout() const { return minor_to_major_layout_; } + // The replica id offsets if the tensor has a replica dimension. + llvm::ArrayRef replica_id_offsets() const { + return replica_id_offsets_; + } + + // The replica id bounds if the tensor has a replica dimension. + llvm::ArrayRef replica_id_bounds() const { + return replica_id_bounds_; + } + // The storage type of the tensor. This could be different from the element // type. e.g. predicates are stored as i8 instead of i1. mlir::Type storage_type() const { return storage_type_; } @@ -173,19 +182,26 @@ class TileInfo { llvm::SmallVector padded_tile_sizes_; llvm::SmallVector minor_to_major_layout_; mlir::Type storage_type_; - - inline TileInfo(llvm::SmallVector offsets, - llvm::SmallVector tile_strides, - llvm::SmallVector original_shape, - llvm::SmallVector padded_tile_sizes, - llvm::SmallVector minor_to_major_layout, - mlir::Type storage_type) + llvm::SmallVector replica_id_offsets_; + llvm::SmallVector replica_id_bounds_; + + TileInfo(llvm::SmallVector offsets, // + llvm::SmallVector tile_strides, // + llvm::SmallVector original_shape, // + llvm::SmallVector padded_tile_sizes, // + llvm::SmallVector minor_to_major_layout, // + mlir::Type storage_type, // + llvm::SmallVector replica_id_offsets, // + llvm::SmallVector replica_id_bounds // + ) : offsets_(std::move(offsets)), tile_strides_(std::move(tile_strides)), original_shape_(std::move(original_shape)), padded_tile_sizes_(std::move(padded_tile_sizes)), minor_to_major_layout_(std::move(minor_to_major_layout)), - storage_type_(std::move(storage_type)) {} + storage_type_(std::move(storage_type)), + replica_id_offsets_(std::move(replica_id_offsets)), + replica_id_bounds_(std::move(replica_id_bounds)) {} }; // Triton requires that all block dimensions are a power of 2. @@ -290,8 +306,9 @@ mlir::Value Bitcast(mlir::ImplicitLocOpBuilder& b, mlir::Value value, mlir::Type type); // Emits an xtile::ExtractTileOp for the given tile info and argument. -TensorValue EmitParameterExtract(mlir::ImplicitLocOpBuilder& b, - const TileInfo& tile_info, mlir::Value arg); +absl::StatusOr EmitParameterExtract(mlir::ImplicitLocOpBuilder& b, + const TileInfo& tile_info, + mlir::Value arg); // Emits a sequence of HLO instructions within a specific scope. // @@ -343,11 +360,24 @@ absl::StatusOr GetMlirType( mlir::ImplicitLocOpBuilder& b, PrimitiveType type, const std::optional& gpu_cc); +// Visitor to determine tile based requirements while iterating over the fusion +// instructions. +struct DefaultTileRequirementsVisitor { + DefaultTileRequirementsVisitor() = default; + virtual ~DefaultTileRequirementsVisitor() = default; + virtual absl::StatusOr> RequiredReplicaIdBounds( + const HloInstruction& instr) const { + return llvm::SmallVector(); + } +}; + // Function to get the MLIR types from a HloFusionInstruction. absl::StatusOr> GetFnArgTypes( mlir::ImplicitLocOpBuilder& b, const HloFusionInstruction& fusion, absl::Span opaque_args_types, - const std::optional& gpu_cc); + const std::optional& gpu_cc, + const DefaultTileRequirementsVisitor& tile_requirements_visitor = + DefaultTileRequirementsVisitor()); // Function to check if the operands of a concatenation are valid for tiling. absl::Status CheckConcatenateOperands( diff --git a/third_party/xla/xla/codegen/xtile/codegen/experimental_fusion_emitter.cc b/third_party/xla/xla/codegen/xtile/codegen/experimental_fusion_emitter.cc index f290baaf42bc99..d7ea4520688b94 100644 --- a/third_party/xla/xla/codegen/xtile/codegen/experimental_fusion_emitter.cc +++ b/third_party/xla/xla/codegen/xtile/codegen/experimental_fusion_emitter.cc @@ -21,6 +21,7 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_map.h" #include "absl/log/check.h" #include "absl/log/log.h" #include "absl/status/status.h" @@ -60,6 +61,7 @@ limitations under the License. #include "xla/codegen/xtile/ir/xtile_ops.h" #include "xla/hlo/analysis/indexing_map_serialization.h" // IWYU pragma: keep #include "xla/hlo/analysis/interval.h" +#include "xla/hlo/analysis/symbolic_expr.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -198,8 +200,7 @@ absl::StatusOr EmitConcatenate( ASSIGN_OR_RETURN(TileInfo tile_info, TileInfo::Construct(emitter_ctx, tiled_concat)); - TF_RETURN_IF_ERROR( - CheckConcatenateOperands(*hlo_concat, concat_dim_tile_size)); + RETURN_IF_ERROR(CheckConcatenateOperands(*hlo_concat, concat_dim_tile_size)); Type result_type = mlir::RankedTensorType::get(tile_sizes, tile_info.storage_type()); @@ -825,8 +826,10 @@ absl::StatusOr EmitTiledHloInstruction( } ASSIGN_OR_RETURN(TileInfo tile_info, TileInfo::Construct(emitter_ctx, tiled_hlo)); - TensorValue parameter = EmitParameterExtract( - b, tile_info, emitter_ctx.entry_func().getArgument(arg_index)); + ASSIGN_OR_RETURN( + TensorValue parameter, + EmitParameterExtract(b, tile_info, + emitter_ctx.entry_func().getArgument(arg_index))); // Workaround(i1_to_i8_workaround) // Some types are stored using different types, e.g. i1 is stored in memory @@ -867,6 +870,13 @@ absl::StatusOr EmitTiledHloInstruction( } // Please keep the cases in alphabetical order. switch (hlo->opcode()) { + case HloOpcode::kAllGather: + case HloOpcode::kAllGatherStart: + case HloOpcode::kAllGatherDone: { + // AllGatherStart and AllGatherDone are no-ops. + // Tile extraction handles the data movement. + return emitter_ctx.TiledHloToTensorValue(*tiled_hlo.operand(0)); + } case (HloOpcode::kAllReduceStart): { const HloComputation* computation = fusion.fused_instructions_computation(); @@ -1042,6 +1052,68 @@ absl::Status EmitGeneric(ImplicitLocOpBuilder& b, return absl::OkStatus(); } +// Implementation for the experimental tiling space. +class TileRequirementsVisitor : public DefaultTileRequirementsVisitor { + public: + explicit TileRequirementsVisitor(const ge::TiledHloComputation& computation) { + for (const auto& tiled_hlo : computation.tiled_hlo_instructions()) { + PopulateMap(tiled_hlo.get()); + } + } + + absl::StatusOr> RequiredReplicaIdBounds( + const HloInstruction& instr) const override { + ASSIGN_OR_RETURN(auto tiled_hlo, LookupTiledHlo(&instr)); + llvm::SmallVector bounds; + bounds.reserve(tiled_hlo->tile().replica_ids().size()); + for (const auto& replica_id : tiled_hlo->tile().replica_ids()) { + SymbolicExpr upper_bound = replica_id.upper_bound.Canonicalize(); + if (upper_bound.GetType() != SymbolicExprType::kConstant) { + return absl::InternalError( + absl::StrCat("Replica ID bound expression is not a constant: ", + upper_bound.ToString())); + } + bounds.push_back(upper_bound.GetValue()); + } + return bounds; + } + + private: + // Look up the instruction in the tiled HLO map. + // For parameters to nested fusions, we walk up the parameter chain to find + // the outermost operand index. + absl::StatusOr LookupTiledHlo( + const HloInstruction* original_instr) const { + auto it = hlo_to_tiled_.find(original_instr); + if (it != hlo_to_tiled_.end()) { + return it->second; + } + if (original_instr->opcode() == HloOpcode::kParameter) { + if (auto* fusion = original_instr->parent()->FusionInstruction()) { + const HloInstruction* resolved_instr = + fusion->operand(original_instr->parameter_number()); + return LookupTiledHlo(resolved_instr); + } + } + return absl::InternalError(absl::StrCat( + "InternalError: HLO instruction not found in tiled HLO map: ", + original_instr->ToString())); + } + + void PopulateMap(const ge::TiledHloInstruction* tiled_hlo) { + hlo_to_tiled_[tiled_hlo->hlo()] = tiled_hlo; + for (const auto& region : tiled_hlo->hlo_regions()) { + for (const auto& region_instruction : region) { + PopulateMap(region_instruction.get()); + } + } + } + + absl::flat_hash_map + hlo_to_tiled_; +}; + } // namespace // TODO(b/447133106): Contrary to the name, this function still does a lot of @@ -1065,7 +1137,8 @@ absl::StatusOr> EmitXTileModule( // Compute function argument types. ASSIGN_OR_RETURN(SmallVector fn_arg_types, - GetFnArgTypes(b, fusion, opaque_args_types, gpu_cc)); + GetFnArgTypes(b, fusion, opaque_args_types, gpu_cc, + TileRequirementsVisitor(tiled_computation))); // Metadata arguments are opaque to the tiling infra. llvm::SmallVector named_attributes{b.getNamedAttr( "num_opaque_args", b.getI32IntegerAttr(opaque_args_types.size()))}; @@ -1076,7 +1149,7 @@ absl::StatusOr> EmitXTileModule( b.setInsertionPointToStart(&fn.front()); ASSIGN_OR_RETURN(auto schedule, GetSchedule(tiled_computation)); - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( EmitGeneric(b, fusion, tiled_computation, schedule, fn, &mlir_context)); b.create(); @@ -1093,7 +1166,7 @@ absl::StatusOr> EmitXTileModule( mlir::PassManager pm(&mlir_context); pm.addPass(xtile::createVerifyLegalXTileOpsPass()); tsl::StatusScopedDiagnosticHandler diagnostic_handler(&mlir_context); - TF_RETURN_IF_ERROR(diagnostic_handler.consumeStatus(pm.run(*xtile_module))); + RETURN_IF_ERROR(diagnostic_handler.consumeStatus(pm.run(*xtile_module))); } return xtile_module; } diff --git a/third_party/xla/xla/codegen/xtile/codegen/experimental_fusion_emitter.h b/third_party/xla/xla/codegen/xtile/codegen/experimental_fusion_emitter.h index 0031a66c5fa337..d2ac6bdb878e05 100644 --- a/third_party/xla/xla/codegen/xtile/codegen/experimental_fusion_emitter.h +++ b/third_party/xla/xla/codegen/xtile/codegen/experimental_fusion_emitter.h @@ -28,8 +28,6 @@ limitations under the License. #include "mlir/Pass/PassManager.h" #include "xla/autotuning.pb.h" #include "xla/codegen/tiling/experimental/tiled_hlo.h" -#include "xla/codegen/tiling/symbolic_tile_analysis.h" -#include "xla/codegen/tiling/tiling_specification.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/stream_executor/device_description.h" diff --git a/third_party/xla/xla/codegen/xtile/codegen/fusion_emitter.cc b/third_party/xla/xla/codegen/xtile/codegen/fusion_emitter.cc index f3aaea02e4c1ad..21b996cb35578e 100644 --- a/third_party/xla/xla/codegen/xtile/codegen/fusion_emitter.cc +++ b/third_party/xla/xla/codegen/xtile/codegen/fusion_emitter.cc @@ -169,7 +169,7 @@ absl::StatusOr EmitReduce( stablehlo::ReduceOp reduction = stablehlo::ReduceOp::create( b, input, init_value, hlo_reduce.dimensions()); - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( EmitReduceComputation(b, &hlo_reduce, hlo_reduce.to_apply(), reduction)); return mlir::cast(reduction.getResult(0)); @@ -213,8 +213,8 @@ absl::StatusOr EmitTiledIota( // We can treat iota more or less as a parameter load, except that we need to // generate the right values in the right place as opposed to loading them. - TF_ASSIGN_OR_RETURN(IndexingMap tile_offsets_indexing, - tiled_iota.tile_offsets_indexing()); + ASSIGN_OR_RETURN(IndexingMap tile_offsets_indexing, + tiled_iota.tile_offsets_indexing()); auto iota_dim_offset = Cast(b, @@ -235,7 +235,7 @@ absl::StatusOr EmitTiledIota( b, range, xtile::Splat(b, iota_dim_offset, padded_tile_sizes[iota_dim])); // Cast the result to the targeted type. - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( Type iota_element_type, PrimitiveTypeToMlirType(b, hlo_iota->shape().element_type())); @@ -281,9 +281,8 @@ absl::StatusOr EmitTiledBitcast( "Bitcast with different bitwidth for operand and output shape " "element type is not yet supported."); } - TF_ASSIGN_OR_RETURN( - Type output_element_type, - PrimitiveTypeToMlirType(b, output_shape.element_type())); + ASSIGN_OR_RETURN(Type output_element_type, + PrimitiveTypeToMlirType(b, output_shape.element_type())); auto output_type = mlir::RankedTensorType::get( GetPaddedTileSizes(tiled_bitcast.operand(0)->tile_sizes()), output_element_type); @@ -328,9 +327,8 @@ absl::StatusOr EmitTiledBitcast( if (ShapeUtil::Equal(trt->transpose1_shape, trt->reshape_shape)) { normalized_reshape = normalized_input; } else { - TF_ASSIGN_OR_RETURN( - normalized_reshape, - EmitTiledReshape(b, reshape_tile_sizes, normalized_input)); + ASSIGN_OR_RETURN(normalized_reshape, + EmitTiledReshape(b, reshape_tile_sizes, normalized_input)); } // The final transpose simply uses the tile sizes computed for the original @@ -424,9 +422,9 @@ absl::StatusOr MaskDotOperand( mask = xtile::BroadcastInDims(b, mlir::cast(mask), tile_shape, {contraction_dimension_index}); - TF_ASSIGN_OR_RETURN(auto element_type, - PrimitiveTypeToMlirType( - b, dot_operand.hlo()->shape().element_type())); + ASSIGN_OR_RETURN(auto element_type, + PrimitiveTypeToMlirType( + b, dot_operand.hlo()->shape().element_type())); TensorValue zero = CreateConst(b, element_type, 0.0f, tile_shape); @@ -460,7 +458,7 @@ absl::Status EmitTiledInstructionList( for (const std::unique_ptr& tiled_hlo : tiled_instructions) { const HloInstruction* hlo = tiled_hlo->hlo(); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( TensorValue result, EmitTiledHloInstruction(b, fusion, *tiled_hlo, fn, pid, values)); TF_RET_CHECK(values.insert({tiled_hlo.get(), result}).second) @@ -498,13 +496,12 @@ absl::StatusOr EmitDot( // The specific accumulator type to use may not correspond to the output type // of the dot. In particular, that is the case when an algorithm is specified // and the dot's output type does not match its expectations. - TF_ASSIGN_OR_RETURN(Type accumulator_type, - xtile::GetDotAccumulatorType(b, dot)); + ASSIGN_OR_RETURN(Type accumulator_type, xtile::GetDotAccumulatorType(b, dot)); TensorValue accumulator = CreateConst(b, accumulator_type, 0.0f, padded_tile_sizes); - TF_ASSIGN_OR_RETURN(int64_t loop_iteration_count, - GetDotLoopIterationCount(tiled_hlo_dot)); + ASSIGN_OR_RETURN(int64_t loop_iteration_count, + GetDotLoopIterationCount(tiled_hlo_dot)); auto ctx = b.getContext(); auto pid_dim = CreateDimExpr(0, ctx); auto ki_symbol = CreateSymbolExpr(0, /*num_dims=*/1, ctx); @@ -533,7 +530,7 @@ absl::StatusOr EmitDot( Value computation_index = xla::ApplyIndexingOp::create( b, ValueRange{pid, ki}, computation_index_map) .getResult(0); - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( EmitTiledInstructionList(b, fusion, tiled_hlo_dot.hlo_regions().front(), fn, computation_index, values)); SmallVector dot_args; @@ -552,17 +549,15 @@ absl::StatusOr EmitDot( dot.dot_dimension_numbers().rhs_contracting_dimensions(0); Value ki_i32 = Cast(b, ki, b.getI32Type()); - TF_ASSIGN_OR_RETURN( - TensorValue lhs, - MaskDotOperand(b, *tiled_hlo_dot.operand(0), dot_args[0], ki_i32, - lhs_contracting_dim_idx)); + ASSIGN_OR_RETURN(TensorValue lhs, + MaskDotOperand(b, *tiled_hlo_dot.operand(0), dot_args[0], + ki_i32, lhs_contracting_dim_idx)); - TF_ASSIGN_OR_RETURN( - TensorValue rhs, - MaskDotOperand(b, *tiled_hlo_dot.operand(1), dot_args[1], ki_i32, - rhs_contracting_dim_idx)); + ASSIGN_OR_RETURN(TensorValue rhs, + MaskDotOperand(b, *tiled_hlo_dot.operand(1), dot_args[1], + ki_i32, rhs_contracting_dim_idx)); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( Value acc_next, xtile::EmitSingleTileDot(b, dot, xtile::DotOperands{lhs, rhs, acc})); mlir::scf::YieldOp::create(b, acc_next); @@ -570,8 +565,8 @@ absl::StatusOr EmitDot( // The output of the loop may not match the expected output type of the dot. // We make sure to issue a conversion if necessary. - TF_ASSIGN_OR_RETURN(Type dot_output_type, - PrimitiveTypeToMlirType(b, dot.shape().element_type())); + ASSIGN_OR_RETURN(Type dot_output_type, + PrimitiveTypeToMlirType(b, dot.shape().element_type())); Value result = for_op.getResult(0); if (dot_output_type != accumulator_type) { @@ -613,8 +608,8 @@ absl::StatusOr EmitScaledDot( TensorValue accumulator = CreateConst(b, accumulator_type, 0.0f, padded_tile_sizes); - TF_ASSIGN_OR_RETURN(int64_t loop_iteration_count, - GetDotLoopIterationCount(tiled_hlo_dot)); + ASSIGN_OR_RETURN(int64_t loop_iteration_count, + GetDotLoopIterationCount(tiled_hlo_dot)); auto ctx = b.getContext(); auto pid_dim = CreateDimExpr(0, ctx); auto ki_symbol = CreateSymbolExpr(0, /*num_dims=*/1, ctx); @@ -641,7 +636,7 @@ absl::StatusOr EmitScaledDot( Value computation_index = xla::ApplyIndexingOp::create( b, ValueRange{pid, ki}, computation_index_map) .getResult(0); - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( EmitTiledInstructionList(b, fusion, tiled_hlo_dot.hlo_regions().front(), fn, computation_index, values)); SmallVector dot_args; @@ -654,7 +649,7 @@ absl::StatusOr EmitScaledDot( Value acc = for_op.getRegionIterArgs().front(); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( Value acc_next, xtile::EmitSingleTileScaledDot( b, dot, @@ -667,8 +662,8 @@ absl::StatusOr EmitScaledDot( // The output of the loop may not match the expected output type of the dot. // We make sure to issue a conversion if necessary. - TF_ASSIGN_OR_RETURN(Type dot_output_type, - PrimitiveTypeToMlirType(b, dot.shape().element_type())); + ASSIGN_OR_RETURN(Type dot_output_type, + PrimitiveTypeToMlirType(b, dot.shape().element_type())); Value result = for_op.getResult(0); if (dot_output_type != accumulator_type) { @@ -699,18 +694,17 @@ absl::StatusOr EmitConcatenate( GetPaddedTileSizes(tiled_concatenate.tile_sizes()); int64_t concat_dim_tile_size = padded_tile_sizes[concatenate_dimension]; - TF_RETURN_IF_ERROR( - CheckConcatenateOperands(*hlo_concat, concat_dim_tile_size)); - TF_ASSIGN_OR_RETURN(Type element_type, - PrimitiveTypeToMlirType( - b, tiled_concatenate.hlo()->shape().element_type())); + RETURN_IF_ERROR(CheckConcatenateOperands(*hlo_concat, concat_dim_tile_size)); + ASSIGN_OR_RETURN(Type element_type, + PrimitiveTypeToMlirType( + b, tiled_concatenate.hlo()->shape().element_type())); Type result_type = mlir::RankedTensorType::get(padded_tile_sizes, element_type); // We will load and compute from a single operand, so we need to figure out // which one by looking at the offset within the concatenation dimension. - TF_ASSIGN_OR_RETURN(IndexingMap tile_offsets_indexing, - tiled_concatenate.tile_offsets_indexing()); + ASSIGN_OR_RETURN(IndexingMap tile_offsets_indexing, + tiled_concatenate.tile_offsets_indexing()); Value concatenate_dimension_offset = emitters::ApplyIndexing(tile_offsets_indexing, /*dims=*/pid, @@ -776,8 +770,8 @@ absl::StatusOr EmitPad( const auto& pad_input_shape = tiled_operand->hlo()->shape().dimensions(); // Compute tile offsets. - TF_ASSIGN_OR_RETURN(IndexingMap tile_offsets_indexing, - tiled_pad.tile_offsets_indexing()); + ASSIGN_OR_RETURN(IndexingMap tile_offsets_indexing, + tiled_pad.tile_offsets_indexing()); SmallVector tile_offsets = emitters::ApplyIndexing(tile_offsets_indexing, /*dims=*/pid, /*symbols=*/{}, b); @@ -862,7 +856,7 @@ absl::StatusOr EmitAllReduce( std::optional channel_handle = all_reduce.channel_id(); bool use_global_device_ids = all_reduce.use_global_device_ids(); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( auto output_element_type, xtile::PrimitiveTypeToMlirType(b, all_reduce.shape().element_type())); auto output_type = mlir::RankedTensorType::get(tiled_hlo_reduce.tile_sizes(), @@ -884,8 +878,8 @@ absl::StatusOr EmitAllReduce( b, b.getLoc(), output_type, mlir::ValueRange(operands), replica_groups_attr, channel_handle_attr, use_global_device_ids); - TF_RETURN_IF_ERROR(EmitReduceComputation( - b, &all_reduce, all_reduce.to_apply(), all_reduce_op)); + RETURN_IF_ERROR(EmitReduceComputation(b, &all_reduce, all_reduce.to_apply(), + all_reduce_op)); return mlir::cast(all_reduce_op.getResult(0)); } @@ -909,12 +903,13 @@ absl::StatusOr EmitTiledHloInstruction( arg_index = hlo->parameter_number(); // Nested operands are parameters. hlo = instr->operand(arg_index); } - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( TileInfo tile_info, TileInfo::Construct(b, pid, GetRuntimeValues(tiled_hlo, values), tiled_hlo)); - TensorValue parameter = - EmitParameterExtract(b, tile_info, fn.getArgument(arg_index)); + TF_ASSIGN_OR_RETURN( + TensorValue parameter, + EmitParameterExtract(b, tile_info, fn.getArgument(arg_index))); // Workaround(i1_to_i8_workaround) // Some types are stored using different types, e.g. i1 is stored in memory @@ -922,9 +917,8 @@ absl::StatusOr EmitTiledHloInstruction( // loading if the type of the loaded parameter does not match what is // expected. Type loaded_element_type = getElementTypeOrSelf(parameter.getType()); - TF_ASSIGN_OR_RETURN( - Type expected_element_type, - PrimitiveTypeToMlirType(b, hlo->shape().element_type())); + ASSIGN_OR_RETURN(Type expected_element_type, + PrimitiveTypeToMlirType(b, hlo->shape().element_type())); if (expected_element_type != loaded_element_type) { // Ensure that we didn't mess up somewhere else by checking that we @@ -1000,7 +994,7 @@ absl::StatusOr EmitTiledHloInstruction( for (const TiledHloInstruction* operand : tiled_hlo.operands()) { operands.push_back(values[operand]); } - TF_ASSIGN_OR_RETURN(Value result, EmitElementwise(b, *hlo, operands)); + ASSIGN_OR_RETURN(Value result, EmitElementwise(b, *hlo, operands)); return mlir::cast(result); } @@ -1060,7 +1054,7 @@ absl::StatusOr> EmitTiledComputation( VLOG(1) << "Skipping nested fusion: " << hlo->ToString(); continue; } - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( TensorValue result, EmitTiledHloInstruction(b, fusion, *tiled_hlo, fn, pid, values)); TF_RET_CHECK(values.insert({tiled_hlo, result}).second) << hlo->ToString(); @@ -1135,19 +1129,19 @@ absl::Status EmitGeneric(mlir::OpBuilder builder, } } TF_RET_CHECK(root_index < symbolic_tile_analysis.GetRoots().size()); - TF_ASSIGN_OR_RETURN(TiledHloComputation tiled_hlo_computation, - symbolic_tile_analysis.ComputeTiledComputation( - tiling, schedule_builder, - /*constraints_are_known_satisfied=*/false, - /*compute_all_tile_offset_indexing_maps=*/true)); + ASSIGN_OR_RETURN(TiledHloComputation tiled_hlo_computation, + symbolic_tile_analysis.ComputeTiledComputation( + tiling, schedule_builder, + /*constraints_are_known_satisfied=*/false, + /*compute_all_tile_offset_indexing_maps=*/true)); VLOG(3) << "EmitGeneric: tiled HLO computation:\n" << tiled_hlo_computation.ToString(); Value tile_id = fn.getTileId(); absl::flat_hash_map values; - TF_ASSIGN_OR_RETURN(auto results, - EmitTiledComputation(b, fusion, tiled_hlo_computation, fn, - tile_id, values)); + ASSIGN_OR_RETURN(auto results, + EmitTiledComputation(b, fusion, tiled_hlo_computation, fn, + tile_id, values)); const HloComputation* computation = fusion.fused_instructions_computation(); for (auto [root, result, arg] : @@ -1164,7 +1158,7 @@ absl::Status EmitGeneric(mlir::OpBuilder builder, result = mlir::cast(Cast(b, result, result_storage_type)); } - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( auto tile_info, TileInfo::Construct(b, tile_id, /*runtime_values=*/{}, *root)); @@ -1205,8 +1199,8 @@ absl::StatusOr> EmitXTileModule( b.setInsertionPointToEnd(xtile_module->getBody()); // Build Triton kernel. - TF_ASSIGN_OR_RETURN(auto fn_arg_types, - GetFnArgTypes(b, fusion, opaque_args_types, gpu_cc)); + ASSIGN_OR_RETURN(auto fn_arg_types, + GetFnArgTypes(b, fusion, opaque_args_types, gpu_cc)); // Metadata arguments are opaque to the tiling infra. llvm::SmallVector named_attributes{b.getNamedAttr( @@ -1218,8 +1212,8 @@ absl::StatusOr> EmitXTileModule( fn.addEntryBlock(); b.setInsertionPointToStart(&fn.front()); - TF_RETURN_IF_ERROR(EmitGeneric(b, fusion, symbolic_tile_analysis, tiling, fn, - &mlir_context)); + RETURN_IF_ERROR(EmitGeneric(b, fusion, symbolic_tile_analysis, tiling, fn, + &mlir_context)); b.create(); @@ -1229,7 +1223,7 @@ absl::StatusOr> EmitXTileModule( mlir::PassManager pm(&mlir_context); pm.addPass(xtile::createVerifyLegalXTileOpsPass()); tsl::StatusScopedDiagnosticHandler diagnostic_handler(&mlir_context); - TF_RETURN_IF_ERROR(diagnostic_handler.consumeStatus(pm.run(*xtile_module))); + RETURN_IF_ERROR(diagnostic_handler.consumeStatus(pm.run(*xtile_module))); } return xtile_module; } diff --git a/third_party/xla/xla/core/collectives/BUILD b/third_party/xla/xla/core/collectives/BUILD index cb81e262b91f3d..9adecfb8e79871 100644 --- a/third_party/xla/xla/core/collectives/BUILD +++ b/third_party/xla/xla/core/collectives/BUILD @@ -52,6 +52,7 @@ cc_library( ":collectives", "//xla:util", "//xla/service:platform_util", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:btree", diff --git a/third_party/xla/xla/core/collectives/collectives_registry.cc b/third_party/xla/xla/core/collectives/collectives_registry.cc index 618cc3f3d36060..0d7a52ca57ff80 100644 --- a/third_party/xla/xla/core/collectives/collectives_registry.cc +++ b/third_party/xla/xla/core/collectives/collectives_registry.cc @@ -29,6 +29,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/core/collectives/collectives.h" #include "xla/service/platform_util.h" #include "xla/tsl/platform/statusor.h" @@ -67,8 +68,8 @@ static Registry& GetCollectivesRegistry() { absl::Status CollectivesRegistry::Register( absl::string_view platform_name, absl::string_view name, int32_t priority, std::unique_ptr collectives) { - TF_ASSIGN_OR_RETURN(std::string canonical_platform_name, - PlatformUtil::CanonicalPlatformName(platform_name)); + ASSIGN_OR_RETURN(std::string canonical_platform_name, + PlatformUtil::CanonicalPlatformName(platform_name)); auto& registry = GetCollectivesRegistry(); absl::MutexLock lock(registry.mu); @@ -84,8 +85,8 @@ absl::Status CollectivesRegistry::Register( absl::StatusOr CollectivesRegistry::Default( absl::string_view platform_name) { - TF_ASSIGN_OR_RETURN(std::string canonical_platform_name, - PlatformUtil::CanonicalPlatformName(platform_name)); + ASSIGN_OR_RETURN(std::string canonical_platform_name, + PlatformUtil::CanonicalPlatformName(platform_name)); auto& registry = GetCollectivesRegistry(); absl::MutexLock lock(registry.mu); @@ -101,8 +102,8 @@ absl::StatusOr CollectivesRegistry::Default( absl::StatusOr CollectivesRegistry::Get( absl::string_view platform_name, absl::string_view implementation_name) { - TF_ASSIGN_OR_RETURN(std::string canonical_platform_name, - PlatformUtil::CanonicalPlatformName(platform_name)); + ASSIGN_OR_RETURN(std::string canonical_platform_name, + PlatformUtil::CanonicalPlatformName(platform_name)); auto& registry = GetCollectivesRegistry(); absl::MutexLock lock(registry.mu); diff --git a/third_party/xla/xla/debug_options_flags.cc b/third_party/xla/xla/debug_options_flags.cc index 6fd3b6a3195f0e..297cea6679fa07 100644 --- a/third_party/xla/xla/debug_options_flags.cc +++ b/third_party/xla/xla/debug_options_flags.cc @@ -441,8 +441,10 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { DebugOptions::PGLE_STRICTNESS_LEVEL_WARN); opts.set_xla_gpu_executable_embed_debug_info(true); - opts.set_xla_gpu_executable_warn_stuck_timeout_seconds(10); + opts.set_xla_gpu_executable_num_compute_streams(0); + opts.set_xla_gpu_executable_num_communication_streams(0); opts.set_xla_gpu_executable_terminate_timeout_seconds(30); + opts.set_xla_gpu_executable_warn_stuck_timeout_seconds(10); opts.set_xla_gpu_execution_terminate_timeout("inf"); opts.set_xla_gpu_execution_progress_tracking(0); @@ -493,6 +495,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { DebugOptions::COLLECTIVES_PRIVATE_MEMORY); opts.set_xla_gpu_experimental_use_ragged_dot_grouped_gemm(true); opts.set_xla_gpu_native_emitter_tune_unroll_factor_for_loops(false); + opts.set_xla_gpu_experimental_use_ragged_dot_fusion(true); opts.set_xla_cpu_collective_call_warn_stuck_seconds(20); opts.set_xla_cpu_collective_call_terminate_timeout_seconds(40); @@ -2664,18 +2667,32 @@ void MakeDebugOptionsFlags(std::vector* flag_list, debug_options->xla_gpu_executable_embed_debug_info(), "Add debug information to the executable such as HLO module, asm_text " "etc.")); + flag_list->push_back( + tsl::Flag("xla_gpu_executable_num_compute_streams", + int32_setter_for( + &DebugOptions::set_xla_gpu_executable_num_compute_streams), + debug_options->xla_gpu_executable_num_compute_streams(), + "Number of additional compute streams to allocate for a GPU " + "executable.")); flag_list->push_back(tsl::Flag( - "xla_gpu_executable_warn_stuck_timeout", + "xla_gpu_executable_num_communication_streams", int32_setter_for( - &DebugOptions::set_xla_gpu_executable_warn_stuck_timeout_seconds), - debug_options->xla_gpu_executable_warn_stuck_timeout_seconds(), - "Set timeout for Rendezvous stuck warning")); + &DebugOptions::set_xla_gpu_executable_num_communication_streams), + debug_options->xla_gpu_executable_num_communication_streams(), + "Number of additional communication streams to allocate for a GPU " + "executable.")); flag_list->push_back(tsl::Flag( "xla_gpu_executable_terminate_timeout", int32_setter_for( &DebugOptions::set_xla_gpu_executable_terminate_timeout_seconds), debug_options->xla_gpu_executable_terminate_timeout_seconds(), "Set timeout for Rendezvous termination")); + flag_list->push_back(tsl::Flag( + "xla_gpu_executable_warn_stuck_timeout", + int32_setter_for( + &DebugOptions::set_xla_gpu_executable_warn_stuck_timeout_seconds), + debug_options->xla_gpu_executable_warn_stuck_timeout_seconds(), + "Set timeout for Rendezvous stuck warning")); flag_list->push_back(tsl::Flag( "xla_gpu_execution_terminate_timeout", diff --git a/third_party/xla/xla/debug_options_parsers_test.cc b/third_party/xla/xla/debug_options_parsers_test.cc index 8a8c03c9e28f6c..4f21dfe93d43f8 100644 --- a/third_party/xla/xla/debug_options_parsers_test.cc +++ b/third_party/xla/xla/debug_options_parsers_test.cc @@ -524,11 +524,12 @@ TEST(ParseRepeatedEnumFlagsTest, AutotuneBackend) { autotuner::Backend::TRITON)); // Adding / removing options from the existing setting. - SetXlaFlagsEnvVar("--xla_gpu_experimental_autotune_backends=+cublas,-triton"); + SetXlaFlagsEnvVar( + "--xla_gpu_experimental_autotune_backends=+cublaslt,-triton"); ParseFlagsFromEnvAndDieIfUnknown("XLA_FLAGS", flag_objects); EXPECT_EQ(enabled_backends.size(), 2); EXPECT_THAT(enabled_backends, ElementsAre(autotuner::Backend::CUDNN, - autotuner::Backend::CUBLAS)); + autotuner::Backend::CUBLASLT)); } TEST(CollectivesModeParsingTest, CaseInsensitive) { diff --git a/third_party/xla/xla/examples/axpy/BUILD b/third_party/xla/xla/examples/axpy/BUILD index 1e65d2ce568221..fb62a4627940c0 100644 --- a/third_party/xla/xla/examples/axpy/BUILD +++ b/third_party/xla/xla/examples/axpy/BUILD @@ -24,6 +24,7 @@ xla_cc_test( "//xla/tsl/lib/core:status_test_util", "//xla/tsl/platform:env", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "//xla/tsl/platform:test", "@com_google_absl//absl/log:check", diff --git a/third_party/xla/xla/examples/axpy/stablehlo_compile_test.cc b/third_party/xla/xla/examples/axpy/stablehlo_compile_test.cc index 6e93595bb0d7d0..2f2bf7819dfd45 100644 --- a/third_party/xla/xla/examples/axpy/stablehlo_compile_test.cc +++ b/third_party/xla/xla/examples/axpy/stablehlo_compile_test.cc @@ -26,6 +26,7 @@ limitations under the License. #include "absl/status/status_matchers.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "xla/tsl/platform/status_macros.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/DialectRegistry.h" @@ -85,7 +86,7 @@ class StableHloAxpyTest : public ::testing::Test { // Read StableHLO program to string. std::string program_string; - TF_RETURN_IF_ERROR(tsl::ReadFileToString( + RETURN_IF_ERROR(tsl::ReadFileToString( tsl::Env::Default(), std::string(program_path), &program_string)); std::cerr << "Loaded StableHLO program from " << program_path << ":\n" diff --git a/third_party/xla/xla/hlo/ir/BUILD b/third_party/xla/xla/hlo/ir/BUILD index 0820a2cfa752f2..b383c52a4a0b4d 100644 --- a/third_party/xla/xla/hlo/ir/BUILD +++ b/third_party/xla/xla/hlo/ir/BUILD @@ -63,6 +63,7 @@ cc_library( ], deps = [ ":backend_config", + ":hlo_payload_deduplicator", ":hlo_sharding", ":mesh_and_axis", ":named_sharding", @@ -355,6 +356,28 @@ xla_cc_test( ], ) +cc_library( + name = "hlo_payload_deduplicator", + srcs = ["hlo_payload_deduplicator.cc"], + hdrs = ["hlo_payload_deduplicator.h"], + deps = [ + ":backend_config", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/strings", + ], +) + +xla_cc_test( + name = "hlo_payload_deduplicator_test", + srcs = ["hlo_payload_deduplicator_test.cc"], + deps = [ + ":backend_config", + ":hlo_payload_deduplicator", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest_main", + ], +) + cc_library( name = "hlo_instruction_utils", srcs = ["hlo_instruction_utils.cc"], diff --git a/third_party/xla/xla/hlo/ir/hlo_computation.cc b/third_party/xla/xla/hlo/ir/hlo_computation.cc index dde957b45ef367..49c7a36035c6eb 100644 --- a/third_party/xla/xla/hlo/ir/hlo_computation.cc +++ b/third_party/xla/xla/hlo/ir/hlo_computation.cc @@ -55,6 +55,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/ir/hlo_payload_deduplicator.h" #include "xla/hlo/ir/hlo_print_options.h" #include "xla/hlo/ir/ptrvec.h" #include "xla/literal.h" @@ -1274,14 +1275,15 @@ absl::Cord HloComputation::ToCord( return std::move(printer).ToCord(); } -void HloComputation::ToProto(HloComputationProto* proto) const { +void HloComputation::ToProto(HloComputationProto* proto, + HloPayloadDeduplicator* deduplicator) const { CHECK(unique_id_ != -1) << "This computation does not have a valid id. Please make sure the " "computation is inside a module before dumping it."; proto->set_id(unique_id_); proto->set_name(name_); for (const HloInstruction* instruction : MakeInstructionPostOrder()) { - instruction->ToProto(proto->add_instructions()); + instruction->ToProto(proto->add_instructions(), deduplicator); } proto->set_root_id(root_instruction()->unique_id()); ComputeProgramShape().ToProto(*proto->mutable_program_shape()); diff --git a/third_party/xla/xla/hlo/ir/hlo_computation.h b/third_party/xla/xla/hlo/ir/hlo_computation.h index c9d462c8ecc33d..21375542abe67c 100644 --- a/third_party/xla/xla/hlo/ir/hlo_computation.h +++ b/third_party/xla/xla/hlo/ir/hlo_computation.h @@ -66,6 +66,7 @@ limitations under the License. namespace xla { class HloModule; +class HloPayloadDeduplicator; // Describes a computation at the HLO level. // @@ -403,7 +404,8 @@ class HloComputation { absl::Span instruction_order) const; // Serializes this computation to a proto. - void ToProto(HloComputationProto* proto) const; + void ToProto(HloComputationProto* proto, + HloPayloadDeduplicator* deduplicator = nullptr) const; // Creates a computation from the given proto. Arguments: // diff --git a/third_party/xla/xla/hlo/ir/hlo_instruction.cc b/third_party/xla/xla/hlo/ir/hlo_instruction.cc index 35162bb3f0ede9..257f592363f4ce 100644 --- a/third_party/xla/xla/hlo/ir/hlo_instruction.cc +++ b/third_party/xla/xla/hlo/ir/hlo_instruction.cc @@ -65,6 +65,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/ir/hlo_original_value.h" #include "xla/hlo/ir/hlo_original_value_util.h" +#include "xla/hlo/ir/hlo_payload_deduplicator.h" #include "xla/hlo/ir/hlo_print_options.h" #include "xla/hlo/ir/hlo_sharding.h" #include "xla/hlo/ir/hlo_sharding_metadata.h" @@ -4455,6 +4456,19 @@ void HloInstruction::PrintExtraAttributes( [](Printer* printer) { printer->Append("is_composite=true"); }); } break; + case HloOpcode::kCustomCall: + if (!called_computations().empty()) { + printer.Next([this, &new_options](Printer* printer) { + printer->Append("called_computations={\n"); + AppendJoin( + printer, called_computations(), ",\n", + [&](Printer* printer, const HloComputation* computation) { + computation->Print(printer, new_options); + }); + printer->Append("\n}"); + }); + } + break; default: if (!called_computations().empty()) { printer.Next([this, &new_options](Printer* printer) { @@ -4592,6 +4606,8 @@ void HloInstruction::ToProto(HloInstructionProto* proto) const { *proto->mutable_metadata() = metadata(); proto->set_backend_config(backend_config_->GetRawString()); + proto->clear_backend_config_payload(); + if (opcode() != HloOpcode::kFusion) { for (const HloComputation* computation : called_computations()) { proto->add_called_computation_ids(computation->unique_id()); @@ -4616,6 +4632,16 @@ void HloInstruction::ToProto(HloInstructionProto* proto) const { } } +void HloInstruction::ToProto(HloInstructionProto* proto, + HloPayloadDeduplicator* deduplicator) const { + ToProto(proto); + if (deduplicator && !backend_config_->empty()) { + proto->mutable_backend_config_payload()->set_id( + deduplicator->Deduplicate(backend_config_.get())); + proto->clear_backend_config(); + } +} + std::string HloInstruction::ToCategory() const { if (opcode() == HloOpcode::kTranspose || opcode() == HloOpcode::kCopy || opcode() == HloOpcode::kReshape || diff --git a/third_party/xla/xla/hlo/ir/hlo_instruction.h b/third_party/xla/xla/hlo/ir/hlo_instruction.h index 0c81c7459a4721..299db9b642d9d4 100644 --- a/third_party/xla/xla/hlo/ir/hlo_instruction.h +++ b/third_party/xla/xla/hlo/ir/hlo_instruction.h @@ -82,6 +82,8 @@ namespace xla { class HloComputation; class HloModule; class HloInstruction; +class BackendConfigWrapper; +class HloPayloadDeduplicator; // A small holder that is used to keep some immutable info alongside an // instruction pointer in an HloComputation's list of instructions @@ -1680,6 +1682,10 @@ class HloInstruction { virtual void ToProto(HloInstructionProto* proto) const; + // Non-virtual overload that handles interning. + void ToProto(HloInstructionProto* proto, + HloPayloadDeduplicator* deduplicator) const; + // Returns a category for the HLO. This could be something like "convolution" // or "elementwise". virtual std::string ToCategory() const; diff --git a/third_party/xla/xla/hlo/ir/hlo_instruction_test.cc b/third_party/xla/xla/hlo/ir/hlo_instruction_test.cc index 4ac9fc759d6808..a79e3003e04d2a 100644 --- a/third_party/xla/xla/hlo/ir/hlo_instruction_test.cc +++ b/third_party/xla/xla/hlo/ir/hlo_instruction_test.cc @@ -533,6 +533,49 @@ TEST_F(HloInstructionTest, CanonicalPrintingSupportsInt64) { "type=TOTALORDER"); } +TEST_F(HloInstructionTest, CanonicalPrintingSupportsCustomCall) { + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule( + R"( + HloModule custom_call_with_comp + + max_F32 { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT maximum = f32[] maximum(lhs, rhs) + } + + ENTRY CustomCall { + constant = f32[1]{0} constant({12345}) + ROOT custom-call = f32[1,2,3]{0,2,1} custom-call(constant), custom_call_target="foo\"bar", called_computations={max_F32} + } + )")); + + xla::HloPrintOptions hlo_print_options = + xla::HloPrintOptions(xla::HloPrintOptions::Canonical()); + hlo_print_options.set_is_in_nested_computation(true); + + xla::CanonicalNameMap new_map; + xla::StringPrinter printer; + module->entry_computation() + ->root_instruction() + ->operand(0) + ->PrintWithCanonicalNameMap(&printer, hlo_print_options, &new_map); + std::string param1_to_string = std::move(printer).ToString(); + + printer = StringPrinter(); + // CustomCall Root Instruction + module->entry_computation()->root_instruction()->PrintWithCanonicalNameMap( + &printer, hlo_print_options, &new_map); + std::string param2_to_string = std::move(printer).ToString(); + + EXPECT_EQ(param1_to_string, "tmp_0 = f32[1]{0} constant({12345})"); + EXPECT_EQ(param2_to_string, + "tmp_1 = f32[1,2,3]{0,2,1} custom-call(f32[1]{0} tmp_0), " + "custom_call_target=\"foo\\\"bar\", called_computations={\n{\n " + "tmp_0 = f32[] parameter(0)\n tmp_1 = f32[] parameter(1)\n ROOT " + "tmp_2 = f32[] maximum(f32[] tmp_0, f32[] tmp_1)\n}\n}"); +} + TEST_F(HloInstructionTest, MapUnaryOutputDimToOperandDimConvert) { Shape shape = ShapeUtil::MakeShape(F32, {10, 20}); auto param = HloInstruction::CreateParameter(0, shape, "p"); diff --git a/third_party/xla/xla/hlo/ir/hlo_module.cc b/third_party/xla/xla/hlo/ir/hlo_module.cc index af7557fbce91ae..8f3589f3fa3237 100644 --- a/third_party/xla/xla/hlo/ir/hlo_module.cc +++ b/third_party/xla/xla/hlo/ir/hlo_module.cc @@ -57,6 +57,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_module_metadata.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/ir/hlo_original_value_util.h" +#include "xla/hlo/ir/hlo_payload_deduplicator.h" #include "xla/hlo/ir/hlo_print_options.h" #include "xla/hlo/ir/hlo_schedule.h" #include "xla/hlo/ir/hlo_sharding.h" @@ -547,7 +548,8 @@ uint64_t HloModule::ToFingerprint( return printer.ToFingerprint(); } -void HloModule::ToProto(HloModuleProto* proto) const { +void HloModule::ToProto(HloModuleProto* proto, + bool intern_backend_config) const { proto->set_id(unique_id_); proto->set_name(name_); if (entry_computation_) { @@ -557,8 +559,25 @@ void HloModule::ToProto(HloModuleProto* proto) const { *proto->mutable_host_program_shape() = entry_computation_layout().ComputeProgramShape().ToProto(); } + + // Only create a deduplicator if needed. + int64_t base_offset = proto->payloads_size(); + std::optional deduplicator; + if (intern_backend_config) { + deduplicator.emplace(base_offset); + } + HloPayloadDeduplicator* deduplicator_ptr = + deduplicator ? &*deduplicator : nullptr; + for (const HloComputation* computation : MakeComputationPostOrder()) { - computation->ToProto(proto->add_computations()); + computation->ToProto(proto->add_computations(), deduplicator_ptr); + } + + if (deduplicator) { + DCHECK_EQ(proto->payloads_size(), base_offset); + for (std::string& payload : std::move(*deduplicator).TakePayloads()) { + proto->add_payloads(std::move(payload)); + } } if (has_schedule()) { @@ -626,9 +645,10 @@ void HloModule::ToProto(HloModuleProto* proto) const { } } -void HloModule::ToProtoWithConfig(HloModuleProtoWithConfig* proto) const { +void HloModule::ToProtoWithConfig(HloModuleProtoWithConfig* proto, + bool intern_backend_config) const { *proto->mutable_config() = config().ToProto(); - ToProto(proto->mutable_hlo_module()); + ToProto(proto->mutable_hlo_module(), intern_backend_config); } absl::Status HloModule::CheckUniqueNamesAndIdsForComputationsAndInstructions() diff --git a/third_party/xla/xla/hlo/ir/hlo_module.h b/third_party/xla/xla/hlo/ir/hlo_module.h index def75e4fefc1a2..cb488436784978 100644 --- a/third_party/xla/xla/hlo/ir/hlo_module.h +++ b/third_party/xla/xla/hlo/ir/hlo_module.h @@ -542,11 +542,11 @@ class HloModule { computation_id_to_id_remap_map); // Convert an HloModule to a proto. - void ToProto(HloModuleProto* proto) const; + void ToProto(HloModuleProto* proto, bool intern_backend_config = false) const; - HloModuleProto ToProto() const { + HloModuleProto ToProto(bool intern_backend_config = false) const { HloModuleProto proto; - ToProto(&proto); + ToProto(&proto, intern_backend_config); return proto; } @@ -570,11 +570,13 @@ class HloModule { bool preserve_instruction_ids = true); // Convert an HloModule to or from a proto that includes module configuration - void ToProtoWithConfig(HloModuleProtoWithConfig* proto) const; + void ToProtoWithConfig(HloModuleProtoWithConfig* proto, + bool intern_backend_config = false) const; - HloModuleProtoWithConfig ToProtoWithConfig() const { + HloModuleProtoWithConfig ToProtoWithConfig( + bool intern_backend_config = false) const { HloModuleProtoWithConfig proto; - ToProtoWithConfig(&proto); + ToProtoWithConfig(&proto, intern_backend_config); return proto; } static absl::StatusOr> CreateFromProtoWithConfig( diff --git a/third_party/xla/xla/hlo/ir/hlo_module_test.cc b/third_party/xla/xla/hlo/ir/hlo_module_test.cc index 382da4c2f2a1be..7e2f0770609b4a 100644 --- a/third_party/xla/xla/hlo/ir/hlo_module_test.cc +++ b/third_party/xla/xla/hlo/ir/hlo_module_test.cc @@ -1625,5 +1625,122 @@ TEST(HloModuleTest, ModuleLevelCacheAPIs) { EXPECT_EQ(module.GetCacheEntry(key1)->value(), 100); } +TEST(HloModuleTest, BackendConfigDeduplicationAndRoundtrip) { + const char* hlo_text = R"( + HloModule test_module + ENTRY comp { + p0 = f32[] parameter(0) + p1 = f32[] parameter(1) + ROOT add = f32[] add(p0, p1) + })"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnUnverifiedModule(hlo_text)); + HloInstruction* p0 = m->entry_computation()->GetInstructionWithName("p0"); + HloInstruction* p1 = m->entry_computation()->GetInstructionWithName("p1"); + + p0->set_raw_backend_config_string("tokamax:{\"data\": 1}"); + p1->CopyBackendConfigFrom(p0); // Force in-memory sharing to test fast path. + + // Verify in-memory deduplication is active. + EXPECT_EQ(&p0->raw_backend_config_string(), &p1->raw_backend_config_string()); + + HloModuleProto proto = m->ToProto(/*intern_backend_config=*/true); + + // Verify the serialized proto structure partially using proto matchers. + using ::tsl::proto_testing::EqualsProto; + using ::tsl::proto_testing::Partially; + EXPECT_THAT(proto, Partially(EqualsProto(R"pb( + payloads: "tokamax:{\"data\": 1}" + computations { + instructions { + name: "p0" + backend_config_payload { id: 0 } + backend_config: "" + } + instructions { + name: "p1" + backend_config_payload { id: 0 } + backend_config: "" + } + instructions { name: "add" } + } + )pb"))); + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr loaded, + HloModule::CreateFromProto(proto, m->config())); + const HloInstruction* loaded_p0 = + loaded->entry_computation()->GetInstructionWithName("p0"); + const HloInstruction* loaded_p1 = + loaded->entry_computation()->GetInstructionWithName("p1"); + + // Verify identical string object in memory is shared post-deserialization. + EXPECT_EQ(loaded_p0->raw_backend_config_string(), "tokamax:{\"data\": 1}"); + EXPECT_EQ(&loaded_p0->raw_backend_config_string(), + &loaded_p1->raw_backend_config_string()); +} + +TEST(HloModuleTest, BackendConfigNoInternByDefault) { + const char* hlo_text = R"( + HloModule test_module + ENTRY comp { + ROOT p0 = f32[] parameter(0) + })"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnUnverifiedModule(hlo_text)); + HloInstruction* p0 = m->entry_computation()->root_instruction(); + p0->set_raw_backend_config_string("tokamax:{\"data\": 1}"); + + HloModuleProto proto = m->ToProto(); + // Config is NOT interned in payloads. + EXPECT_EQ(proto.payloads_size(), 0); + + using ::tsl::proto_testing::EqualsProto; + using ::tsl::proto_testing::Partially; + EXPECT_THAT( + proto, Partially(EqualsProto(R"pb( + computations { + instructions { name: "p0" backend_config: "tokamax:{\"data\": 1}" } + } + )pb"))); + EXPECT_FALSE( + proto.computations(0).instructions(0).has_backend_config_payload()); + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr loaded, + HloModule::CreateFromProto(proto, m->config())); + const HloInstruction* loaded_p0 = + loaded->entry_computation()->root_instruction(); + EXPECT_EQ(loaded_p0->raw_backend_config_string(), "tokamax:{\"data\": 1}"); +} + +TEST(HloModuleTest, BackendConfigDeduplicationWithBaseOffset) { + const char* hlo_text = R"( + HloModule test_module + ENTRY comp { + ROOT p0 = f32[] parameter(0) + })"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnUnverifiedModule(hlo_text)); + HloInstruction* p0 = m->entry_computation()->root_instruction(); + p0->set_raw_backend_config_string("tokamax:{\"data\": 1}"); + + // 1. Create a proto and pre-fill its payloads with an existing string! + HloModuleProto proto; + proto.add_payloads("pre_existing_metadata"); + + // 2. Serialize to this pre-filled proto with interning! + m->ToProto(&proto, /*intern_backend_config=*/true); + + // Verify shifted ID and combined payloads. + using ::tsl::proto_testing::EqualsProto; + using ::tsl::proto_testing::Partially; + EXPECT_THAT(proto, Partially(EqualsProto(R"pb( + payloads: "pre_existing_metadata" + payloads: "tokamax:{\"data\": 1}" + computations { + instructions { + name: "p0" + backend_config_payload { id: 1 } + } + } + )pb"))); +} + } // namespace } // namespace xla diff --git a/third_party/xla/xla/hlo/ir/hlo_payload_deduplicator.cc b/third_party/xla/xla/hlo/ir/hlo_payload_deduplicator.cc new file mode 100644 index 00000000000000..c94d2974263d27 --- /dev/null +++ b/third_party/xla/xla/hlo/ir/hlo_payload_deduplicator.cc @@ -0,0 +1,63 @@ +/* Copyright 2026 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/hlo/ir/hlo_payload_deduplicator.h" + +#include +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/backend_config.h" + +namespace xla { + +HloPayloadDeduplicator::HloPayloadDeduplicator(int64_t base_offset) + : offset_(base_offset) {} + +int64_t HloPayloadDeduplicator::Deduplicate( + const BackendConfigWrapper* wrapper) { + auto it = pointer_map_.find(wrapper); + if (it != pointer_map_.end()) { + return it->second; + } + + // Fall back to string deduplication. + int64_t id = Deduplicate(wrapper->GetRawString()); + pointer_map_.emplace(wrapper, id); + return id; +} + +int64_t HloPayloadDeduplicator::Deduplicate(absl::string_view value) { + auto it = string_map_.find(value); + if (it != string_map_.end()) { + return it->second; + } + int64_t id = offset_ + payloads_.size(); + payloads_.emplace_back(value); + string_map_.emplace(payloads_.back(), id); + return id; +} + +std::deque HloPayloadDeduplicator::TakePayloads() { + std::deque result = std::move(payloads_); + payloads_.clear(); + string_map_.clear(); + pointer_map_.clear(); + return result; +} + +} // namespace xla diff --git a/third_party/xla/xla/hlo/ir/hlo_payload_deduplicator.h b/third_party/xla/xla/hlo/ir/hlo_payload_deduplicator.h new file mode 100644 index 00000000000000..aa0ea3e3637918 --- /dev/null +++ b/third_party/xla/xla/hlo/ir/hlo_payload_deduplicator.h @@ -0,0 +1,62 @@ +/* Copyright 2026 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_IR_HLO_PAYLOAD_DEDUPLICATOR_H_ +#define XLA_HLO_IR_HLO_PAYLOAD_DEDUPLICATOR_H_ + +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/backend_config.h" + +namespace xla { + +// Helper for deduplicating backend_config payloads into a growing list +// during HLO module serialization. +class HloPayloadDeduplicator { + public: + // Initializes the deduplicator. + // `base_offset` shifts the generated payload IDs on-the-fly. This is useful + // when serializing into a proto that already has pre-existing payloads + // (e.g. when appending to a pre-filled proto). + explicit HloPayloadDeduplicator(int64_t base_offset = 0); + + // Fast path: deduplicates the backend config using the in-memory pointer + // address of the BackendConfigWrapper. Returns the unique index (ID) of + // the payload and stores it if not already stored. + int64_t Deduplicate(const BackendConfigWrapper* wrapper); + + // Fallback path: deduplicates the backend config using raw string comparison. + // Stores the given value if not already stored, and returns a unique index + // (ID) referencing it. + int64_t Deduplicate(absl::string_view value); + + // Returns the collected string payloads and transfers ownership of them + // to the caller (moves the internal payloads list), avoiding copies. + std::deque TakePayloads(); + + private: + int64_t offset_; + std::deque payloads_; + absl::flat_hash_map string_map_; + absl::flat_hash_map pointer_map_; +}; + +} // namespace xla + +#endif // XLA_HLO_IR_HLO_PAYLOAD_DEDUPLICATOR_H_ diff --git a/third_party/xla/xla/hlo/ir/hlo_payload_deduplicator_test.cc b/third_party/xla/xla/hlo/ir/hlo_payload_deduplicator_test.cc new file mode 100644 index 00000000000000..133fc24c01aa4e --- /dev/null +++ b/third_party/xla/xla/hlo/ir/hlo_payload_deduplicator_test.cc @@ -0,0 +1,84 @@ +/* Copyright 2026 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/hlo/ir/hlo_payload_deduplicator.h" + +#include +#include + +#include +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/backend_config.h" + +namespace xla { +namespace { + +TEST(HloPayloadDeduplicatorTest, DeduplicateString) { + HloPayloadDeduplicator deduplicator; + EXPECT_EQ(deduplicator.Deduplicate("config1"), 0); + EXPECT_EQ(deduplicator.Deduplicate("config2"), 1); + // Duplicate string should return the same ID. + EXPECT_EQ(deduplicator.Deduplicate("config1"), 0); + + auto payloads = deduplicator.TakePayloads(); + EXPECT_EQ(payloads.size(), 2); + EXPECT_EQ(payloads[0], "config1"); + EXPECT_EQ(payloads[1], "config2"); +} + +TEST(HloPayloadDeduplicatorTest, DeduplicatePointer) { + HloPayloadDeduplicator deduplicator; + auto wrapper0 = std::make_shared("config1"); + // Shares the same pointer (fast path). + auto wrapper1 = wrapper0; + // Different pointer but same string (fallback). + auto wrapper2 = std::make_shared("config1"); + + EXPECT_EQ(deduplicator.Deduplicate(wrapper0.get()), 0); + EXPECT_EQ(deduplicator.Deduplicate(wrapper1.get()), 0); + EXPECT_EQ(deduplicator.Deduplicate(wrapper2.get()), 0); + + auto payloads = deduplicator.TakePayloads(); + EXPECT_EQ(payloads.size(), 1); + EXPECT_EQ(payloads[0], "config1"); +} + +TEST(HloPayloadDeduplicatorTest, DeduplicateWithBaseOffset) { + HloPayloadDeduplicator deduplicator(5); + EXPECT_EQ(deduplicator.Deduplicate("config1"), 5); + EXPECT_EQ(deduplicator.Deduplicate("config2"), 6); + EXPECT_EQ(deduplicator.Deduplicate("config1"), 5); + + auto payloads = deduplicator.TakePayloads(); + EXPECT_EQ(payloads.size(), 2); + EXPECT_EQ(payloads[0], "config1"); + EXPECT_EQ(payloads[1], "config2"); +} + +TEST(HloPayloadDeduplicatorTest, TakePayloadsMovesMemory) { + HloPayloadDeduplicator deduplicator; + EXPECT_EQ(deduplicator.Deduplicate("config1"), 0); + + auto payloads1 = deduplicator.TakePayloads(); + EXPECT_EQ(payloads1.size(), 1); + EXPECT_EQ(payloads1[0], "config1"); + + // Subsequent take should be empty because it was moved. + auto payloads2 = deduplicator.TakePayloads(); + EXPECT_EQ(payloads2.size(), 0); +} + +} // namespace +} // namespace xla diff --git a/third_party/xla/xla/hlo/parser/BUILD b/third_party/xla/xla/hlo/parser/BUILD index 5b9eba25155c72..63229bcf9a8236 100644 --- a/third_party/xla/xla/hlo/parser/BUILD +++ b/third_party/xla/xla/hlo/parser/BUILD @@ -50,6 +50,7 @@ cc_library( "//xla/tsl/lib/gtl:map_util", "//xla/tsl/platform:errors", "//xla/tsl/platform:logging", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base", @@ -88,6 +89,7 @@ xla_cc_test( "//xla/service:pattern_matcher", "//xla/tsl/lib/core:status_test_util", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "//xla/tsl/platform:test", "//xla/tsl/platform:test_main", diff --git a/third_party/xla/xla/hlo/parser/hlo_parser.cc b/third_party/xla/xla/hlo/parser/hlo_parser.cc index a04884d64e5795..73306b63263e95 100644 --- a/third_party/xla/xla/hlo/parser/hlo_parser.cc +++ b/third_party/xla/xla/hlo/parser/hlo_parser.cc @@ -48,6 +48,7 @@ limitations under the License. #include "absl/strings/strip.h" #include "absl/types/span.h" #include "Eigen/Core" +#include "xla/tsl/platform/status_macros.h" #include "google/protobuf/descriptor.h" #include "xla/array.h" #include "xla/comparison_util.h" @@ -8335,7 +8336,7 @@ absl::StatusOr> ParseAndReturnUnverifiedModule( const HloParserOptions& options) { auto module = std::make_unique(/*name=*/"_", config); HloParserImpl parser(str, options); - TF_RETURN_IF_ERROR(parser.Run(module.get())); + RETURN_IF_ERROR(parser.Run(module.get())); return module; } diff --git a/third_party/xla/xla/hlo/parser/hlo_parser_test.cc b/third_party/xla/xla/hlo/parser/hlo_parser_test.cc index e9936fa619d16a..a30c74456834e4 100644 --- a/third_party/xla/xla/hlo/parser/hlo_parser_test.cc +++ b/third_party/xla/xla/hlo/parser/hlo_parser_test.cc @@ -32,6 +32,7 @@ limitations under the License. #include "absl/strings/str_format.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/array.h" #include "xla/hlo/builder/xla_builder.h" #include "xla/hlo/ir/hlo_casting_utils.h" @@ -2844,7 +2845,7 @@ absl::StatusOr> ParseAndReturnVerifiedModule( /*verifier_layout_sensitive=*/false, /*allow_mixed_precision_in_hlo_verifier=*/true, ShapeUtil::ByteSizeOfElements); - TF_RETURN_IF_ERROR(verified_module->ParseHloStringAndVerifyModule(hlo_text)); + RETURN_IF_ERROR(verified_module->ParseHloStringAndVerifyModule(hlo_text)); return verified_module; } diff --git a/third_party/xla/xla/hlo/pass/BUILD b/third_party/xla/xla/hlo/pass/BUILD index 1adab6d1a040db..33e3d949c8ee60 100644 --- a/third_party/xla/xla/hlo/pass/BUILD +++ b/third_party/xla/xla/hlo/pass/BUILD @@ -104,6 +104,7 @@ cc_library( "@com_google_absl//absl/strings:str_format", "@tsl//tsl/profiler/lib:scoped_annotation", "@tsl//tsl/profiler/lib:traceme", + "@tsl//tsl/profiler/lib:traceme_encode", ], ) diff --git a/third_party/xla/xla/hlo/pass/hlo_pass_pipeline.cc b/third_party/xla/xla/hlo/pass/hlo_pass_pipeline.cc index ab4cc9398968ac..e9a35dd1763be6 100644 --- a/third_party/xla/xla/hlo/pass/hlo_pass_pipeline.cc +++ b/third_party/xla/xla/hlo/pass/hlo_pass_pipeline.cc @@ -39,6 +39,7 @@ limitations under the License. #include "xla/xla.pb.h" #include "tsl/profiler/lib/scoped_annotation.h" #include "tsl/profiler/lib/traceme.h" +#include "tsl/profiler/lib/traceme_encode.h" namespace xla { @@ -301,7 +302,9 @@ absl::StatusOr HloPassPipeline::RunImpl( VLOG(1) << "Running HLO pass pipeline on module " << module->name() << ": " << name(); - tsl::profiler::TraceMe traceme(name()); + tsl::profiler::TraceMe traceme([&] { + return tsl::profiler::TraceMeEncode(name(), {{"module", module->name()}}); + }); // Copy debug options by value as passes may modify module config. DebugOptions debug_options = module->config().debug_options(); return RunPassesInternal(module, debug_options, execution_threads); @@ -315,7 +318,9 @@ absl::StatusOr HloPassPipeline::RunImpl( VLOG(1) << "Running HLO pass pipeline on module " << module->name() << ": " << name(); - tsl::profiler::TraceMe traceme(name()); + tsl::profiler::TraceMe traceme([&] { + return tsl::profiler::TraceMeEncode(name(), {{"module", module->name()}}); + }); // Copy debug options by value as passes may modify module config. DebugOptions debug_options = module->config().debug_options(); return RunPassesInternal&>(module, debug_options, diff --git a/third_party/xla/xla/hlo/separate_compilation/BUILD b/third_party/xla/xla/hlo/separate_compilation/BUILD index fe6b46bfe9cd8f..17acd2d1e149c7 100644 --- a/third_party/xla/xla/hlo/separate_compilation/BUILD +++ b/third_party/xla/xla/hlo/separate_compilation/BUILD @@ -34,6 +34,7 @@ cc_library( "//xla/service:compilation_environments", "//xla/service:hlo_module_config", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", @@ -58,6 +59,7 @@ cc_library( "//xla/hlo/transforms/simplifiers:hlo_dce", "//xla/service:compilation_environments", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/container:flat_hash_map", diff --git a/third_party/xla/xla/hlo/separate_compilation/hlo_module_linking.cc b/third_party/xla/xla/hlo/separate_compilation/hlo_module_linking.cc index f2e1aac4e1b567..94944e35240c73 100644 --- a/third_party/xla/xla/hlo/separate_compilation/hlo_module_linking.cc +++ b/third_party/xla/xla/hlo/separate_compilation/hlo_module_linking.cc @@ -27,6 +27,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/ir/hlo_clone_context.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -77,7 +78,7 @@ class HloLinker { } else if (!current.entered) { VLOG(6) << "First visit to link: " << current.principal->name(); - TF_RETURN_IF_ERROR(HandleFirstVisit(current)); + RETURN_IF_ERROR(HandleFirstVisit(current)); } else { VLOG(6) << "Second visit to link: " << current.principal->name(); @@ -199,13 +200,13 @@ absl::StatusOr> LinkComputation( *linking_manifest.compilation_environment)); HloLinker linker(linked_module.get(), linking_manifest, root_computation); - TF_ASSIGN_OR_RETURN(HloComputation * linked_clone_ptr, linker.Link()); + ASSIGN_OR_RETURN(HloComputation * linked_clone_ptr, linker.Link()); linked_module->ReplaceEntryComputation(linked_clone_ptr); linked_module->mutable_config().SetComputationLayoutIfExists( linked_clone_ptr->ComputeProgramShape()); xla::HloDCE dce_pass; - TF_RETURN_IF_ERROR(dce_pass.Run(linked_module.get()).status()); + RETURN_IF_ERROR(dce_pass.Run(linked_module.get()).status()); if (VLOG_IS_ON(6)) { for (const HloComputation* comp : linked_module->computations()) { diff --git a/third_party/xla/xla/hlo/separate_compilation/hlo_module_splitting.cc b/third_party/xla/xla/hlo/separate_compilation/hlo_module_splitting.cc index 00a0bd2a0acefc..cfcc7faa980150 100644 --- a/third_party/xla/xla/hlo/separate_compilation/hlo_module_splitting.cc +++ b/third_party/xla/xla/hlo/separate_compilation/hlo_module_splitting.cc @@ -32,6 +32,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/ir/hlo_clone_context.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -75,7 +76,7 @@ absl::StatusOr> CreateCalleeStub( std::vector operands; for (const HloInstruction* parameter : callee->parameter_instructions()) { - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( HloInstruction * cloned_parameter, comp_builder.AddParameter(parameter->Clone(/*suffix=*/""))); operands.push_back(cloned_parameter); @@ -235,8 +236,8 @@ absl::StatusOr> CreateHloModuleSplit( callee_replacements[caller] = callee; continue; } - TF_ASSIGN_OR_RETURN(std::unique_ptr stub, - CreateCalleeStub(callee, callee_index)); + ASSIGN_OR_RETURN(std::unique_ptr stub, + CreateCalleeStub(callee, callee_index)); VLOG(4) << "Stubbing " << stub->name() << " --> " << callee->name() << " " << stub->ToString(); HloComputation* stub_raw_ptr = @@ -287,21 +288,21 @@ absl::StatusOr> CreateHloModuleSplitGroup( absl::flat_hash_map global_computation_map; - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( std::vector> splits, GroupComputationsForSplitting(module)); for (const auto& split : splits) { - TF_ASSIGN_OR_RETURN(auto module_split, CreateHloModuleSplit(module, split)); + ASSIGN_OR_RETURN(auto module_split, CreateHloModuleSplit(module, split)); module_splits.push_back(std::move(module_split)); for (const auto* original_comp : split) { computation_address_book.insert( {original_comp, module_splits.back().get()}); } - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( MergeMapInto(global_stub_map, module_splits.back()->stub_map)); - TF_RETURN_IF_ERROR(MergeMapInto(global_computation_map, - module_splits.back()->computation_map)); + RETURN_IF_ERROR(MergeMapInto(global_computation_map, + module_splits.back()->computation_map)); } if (VLOG_IS_ON(5)) { @@ -320,8 +321,8 @@ absl::StatusOr> CreateHloModuleSplitGroup( } // Compose at the end once all planned cloning operations are finished and // we know where each original computation ended up. - TF_ASSIGN_OR_RETURN(auto stub_links, - ComposeMaps(global_stub_map, global_computation_map)); + ASSIGN_OR_RETURN(auto stub_links, + ComposeMaps(global_stub_map, global_computation_map)); HloLinkingManifest linking_manifest{ std::move(stub_links), module.shared_config(), diff --git a/third_party/xla/xla/hlo/testlib/BUILD b/third_party/xla/xla/hlo/testlib/BUILD index ab4f59dd0448ed..af1cdd063674c1 100644 --- a/third_party/xla/xla/hlo/testlib/BUILD +++ b/third_party/xla/xla/hlo/testlib/BUILD @@ -36,6 +36,7 @@ cc_library( "//xla/tsl/platform:errors", "//xla/tsl/platform:logging", "//xla/tsl/platform:status", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:test", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", @@ -66,6 +67,7 @@ cc_library( "//xla/service:hlo_verifier", "//xla/tsl/platform:errors", "//xla/tsl/platform:logging", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "//xla/tsl/platform:test", "@com_google_absl//absl/algorithm:container", @@ -94,6 +96,7 @@ cc_library( "//xla/tsl/platform:env", "//xla/tsl/platform:errors", "//xla/tsl/platform:resource_loader", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:subprocess", "@com_google_absl//absl/log", "@com_google_absl//absl/status", diff --git a/third_party/xla/xla/hlo/testlib/filecheck.cc b/third_party/xla/xla/hlo/testlib/filecheck.cc index 7f351633f231df..a161e4786cf382 100644 --- a/third_party/xla/xla/hlo/testlib/filecheck.cc +++ b/third_party/xla/xla/hlo/testlib/filecheck.cc @@ -23,6 +23,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/tsl/platform/env.h" #include "xla/tsl/platform/errors.h" #include "xla/tsl/platform/resource_loader.h" @@ -41,7 +42,7 @@ absl::StatusOr RunFileCheck( if (!env->LocalTempFilename(&pattern_path)) { return absl::InternalError("couldn't get a pattern file name"); } - TF_RETURN_IF_ERROR(tsl::WriteStringToFile(env, pattern_path, pattern)); + RETURN_IF_ERROR(tsl::WriteStringToFile(env, pattern_path, pattern)); VLOG(3) << "input: " << input; return RunFileCheckWithPatternFile(input, pattern_path, additional_check_prefixes); diff --git a/third_party/xla/xla/hlo/testlib/hlo_hardware_independent_test_base.cc b/third_party/xla/xla/hlo/testlib/hlo_hardware_independent_test_base.cc index ae1db1a7493ec0..adcb693ef811d8 100644 --- a/third_party/xla/xla/hlo/testlib/hlo_hardware_independent_test_base.cc +++ b/third_party/xla/xla/hlo/testlib/hlo_hardware_independent_test_base.cc @@ -35,6 +35,7 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/debug_options_flags.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" @@ -123,13 +124,13 @@ absl::Status HloHardwareIndependentTestBase:: for (int64_t i = 0; i < computation->num_parameters(); ++i) { const Shape& param_shape = computation->parameter_instruction(i)->shape(); - TF_RETURN_IF_ERROR(computation->parent() - ->mutable_entry_computation_layout() - ->mutable_parameter_layout(i) - ->CopyLayoutFromShape(param_shape)); + RETURN_IF_ERROR(computation->parent() + ->mutable_entry_computation_layout() + ->mutable_parameter_layout(i) + ->CopyLayoutFromShape(param_shape)); } - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( computation->parent() ->mutable_entry_computation_layout() ->mutable_result_layout() @@ -165,7 +166,7 @@ HloHardwareIndependentTestBase::ParseAndReturnVerifiedModule( TestName(), config_with_device_assignment, verifier_layout_sensitive_, allow_mixed_precision_in_hlo_verifier_, shape_size_fn, instruction_can_change_layout_func_); - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( module->ParseHloStringAndVerifyModule(hlo_text, parser_options)); return module; } @@ -174,7 +175,7 @@ HloHardwareIndependentTestBase::ParseAndReturnVerifiedModule( absl::StatusOr HloHardwareIndependentTestBase::RunHloPass( HloPassInterface* hlo_pass, HloModule* module) { const std::string before_run = module->ToProto().ShortDebugString(); - TF_ASSIGN_OR_RETURN(bool changed, hlo_pass->Run(module)); + ASSIGN_OR_RETURN(bool changed, hlo_pass->Run(module)); const std::string after_run = module->ToProto().ShortDebugString(); if (changed) { EXPECT_NE(after_run, before_run) << absl::StrFormat( @@ -261,10 +262,10 @@ HloHardwareIndependentTestBase::RunAndCheckHloRewrite( std::string hlo_string = absl::StrReplaceAll(hlo_template, params); SCOPED_TRACE("Input HLO: " + hlo_string); VLOG(7) << "Input HLO: " << hlo_string; - TF_ASSIGN_OR_RETURN(std::unique_ptr module, - ParseAndReturnVerifiedModule(hlo_string)); + ASSIGN_OR_RETURN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); VLOG(7) << "Input HLO parsed. Running the pass: + " << hlo_pass->name(); - TF_ASSIGN_OR_RETURN(bool changed, RunHloPass(hlo_pass, module.get())); + ASSIGN_OR_RETURN(bool changed, RunHloPass(hlo_pass, module.get())); VLOG(7) << "Output HLO: " << module->ToString(HloPrintOptions::ShortParsable() .set_print_control_dependencies(true)); diff --git a/third_party/xla/xla/hlo/testlib/verified_hlo_module.cc b/third_party/xla/xla/hlo/testlib/verified_hlo_module.cc index f566f04fd6e320..049c6944396d0d 100644 --- a/third_party/xla/xla/hlo/testlib/verified_hlo_module.cc +++ b/third_party/xla/xla/hlo/testlib/verified_hlo_module.cc @@ -17,6 +17,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/parser/hlo_parser.h" #include "xla/status_macros.h" #include "xla/tsl/platform/errors.h" @@ -31,7 +32,7 @@ absl::Status VerifiedHloModule::ParseHloStringAndVerifyModule( absl::string_view str, const HloParserOptions& options) { TF_RET_CHECK(computation_count() == 0); auto parser = HloParser::CreateHloParserForTests(str, options); - TF_RETURN_IF_ERROR(parser->Run(this)); + RETURN_IF_ERROR(parser->Run(this)); return Verify(); } diff --git a/third_party/xla/xla/hlo/tools/hlo_diff/BUILD b/third_party/xla/xla/hlo/tools/hlo_diff/BUILD index 74a991fa97e83e..d971d274d4f630 100644 --- a/third_party/xla/xla/hlo/tools/hlo_diff/BUILD +++ b/third_party/xla/xla/hlo/tools/hlo_diff/BUILD @@ -78,6 +78,7 @@ cc_library( "//xla/hlo/tools/hlo_diff/proto:diff_result_proto_cc", "//xla/hlo/tools/hlo_diff/utils:bidirectional_map", "//xla/hlo/tools/hlo_diff/utils:connected_components", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", @@ -131,6 +132,7 @@ cc_library( "//xla/hlo/tools/hlo_diff/matchers:top_down_matcher", "//xla/service:call_graph", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", @@ -199,6 +201,7 @@ xla_cc_binary( "//xla/service:hlo_proto_cc", "//xla/tsl/platform:env", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "//xla/tsl/util:command_line_flags", "@com_google_absl//absl/log", diff --git a/third_party/xla/xla/hlo/tools/hlo_diff/graph/BUILD b/third_party/xla/xla/hlo/tools/hlo_diff/graph/BUILD index d5d8563400ee66..0fca0dc12d40b9 100644 --- a/third_party/xla/xla/hlo/tools/hlo_diff/graph/BUILD +++ b/third_party/xla/xla/hlo/tools/hlo_diff/graph/BUILD @@ -40,6 +40,7 @@ cc_library( "//xla/service:call_graph", "//xla/service:hlo_value", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/container:flat_hash_map", diff --git a/third_party/xla/xla/hlo/tools/hlo_diff/graph/analysis/BUILD b/third_party/xla/xla/hlo/tools/hlo_diff/graph/analysis/BUILD index 233bc847801026..d1bec06eba4186 100644 --- a/third_party/xla/xla/hlo/tools/hlo_diff/graph/analysis/BUILD +++ b/third_party/xla/xla/hlo/tools/hlo_diff/graph/analysis/BUILD @@ -27,6 +27,7 @@ cc_library( "//xla/service:call_graph", "//xla/service:hlo_value", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", diff --git a/third_party/xla/xla/hlo/tools/hlo_diff/graph/analysis/hlo_value_tracing.cc b/third_party/xla/xla/hlo/tools/hlo_diff/graph/analysis/hlo_value_tracing.cc index e68eb907aa37c9..4a8e14900eac30 100644 --- a/third_party/xla/xla/hlo/tools/hlo_diff/graph/analysis/hlo_value_tracing.cc +++ b/third_party/xla/xla/hlo/tools/hlo_diff/graph/analysis/hlo_value_tracing.cc @@ -30,6 +30,7 @@ #include "absl/memory/memory.h" #include "absl/status/status.h" #include "absl/strings/string_view.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" @@ -1169,7 +1170,7 @@ absl::StatusOr> HloValueTracing::Run( auto hlo_value_tracing = absl::WrapUnique(new HloValueTracing(module, execution_threads)); - TF_RETURN_IF_ERROR(hlo_value_tracing->InitializeInstructionValueSets()); + RETURN_IF_ERROR(hlo_value_tracing->InitializeInstructionValueSets()); hlo_value_tracing->Propagate(); // Delete all values marked for deletion. diff --git a/third_party/xla/xla/hlo/tools/hlo_diff/graph/hlo_gumgraph.cc b/third_party/xla/xla/hlo/tools/hlo_diff/graph/hlo_gumgraph.cc index 13275f37e629ec..b97a5acd981760 100644 --- a/third_party/xla/xla/hlo/tools/hlo_diff/graph/hlo_gumgraph.cc +++ b/third_party/xla/xla/hlo/tools/hlo_diff/graph/hlo_gumgraph.cc @@ -29,6 +29,7 @@ #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" @@ -153,7 +154,7 @@ absl::Status HloGumgraph::ConstructGraph(const HloModule& hlo_module) { for (auto* computation : instruction->called_computations()) { if (call_graph_->GetComputationCallers(computation).size() == 1) { inline_called_computations = true; - TF_RETURN_IF_ERROR(ConnectCalledComputation( + RETURN_IF_ERROR(ConnectCalledComputation( instruction->operands(), computation->parameter_instructions())); } @@ -178,7 +179,7 @@ absl::Status HloGumgraph::ConstructGraph(const HloModule& hlo_module) { ->GetComputationCallers(instruction->branch_computation(i)) .size() == 1) { inline_called_computations = true; - TF_RETURN_IF_ERROR(ConnectCalledComputation( + RETURN_IF_ERROR(ConnectCalledComputation( HloInstruction::InstructionVector( {instruction->operands()[i + 1]}), instruction->branch_computation(i) @@ -192,7 +193,7 @@ absl::Status HloGumgraph::ConstructGraph(const HloModule& hlo_module) { } if (!inline_called_computations) { - TF_RETURN_IF_ERROR(ConnectOperands(node)); + RETURN_IF_ERROR(ConnectOperands(node)); } // Connect the root instruction of the called computation with the @@ -299,8 +300,8 @@ void HloGumgraph::PrecomputeSizeAndHeight() { absl::Status HloGumgraph::PrecomputeComputationFingerprint() { LOG(INFO) << "Precomputing computation fingerprint"; - TF_RETURN_IF_ERROR(call_graph_->VisitNodes([&](const CallGraphNode& node) - -> absl::Status { + RETURN_IF_ERROR(call_graph_->VisitNodes([&](const CallGraphNode& node) + -> absl::Status { absl::flat_hash_map subgraph_fingerprint; const HloComputation* computation = node.computation(); for (auto* instruction : computation->MakeInstructionPostOrder()) { @@ -395,14 +396,14 @@ absl::StatusOr> HloGumgraph::Create( new HloGumgraph(*hlo_module, fingerprint_options, std::move(call_graph), std::move(hlo_value_tracing_ptr))); - TF_RETURN_IF_ERROR(graph->ConstructGraph(*hlo_module)); - TF_ASSIGN_OR_RETURN(std::vector zero_indegree_nodes, - graph->PrecomputeGenerations()); + RETURN_IF_ERROR(graph->ConstructGraph(*hlo_module)); + ASSIGN_OR_RETURN(std::vector zero_indegree_nodes, + graph->PrecomputeGenerations()); for (auto* zero_indegree_node : zero_indegree_nodes) { AddEdge(&graph->root_, zero_indegree_node); } graph->PrecomputeSizeAndHeight(); - TF_RETURN_IF_ERROR(graph->PrecomputeComputationFingerprint()); + RETURN_IF_ERROR(graph->PrecomputeComputationFingerprint()); if (precompute_instruction_dependencies) { graph->PrecomputeInstructionDependencies(); } diff --git a/third_party/xla/xla/hlo/tools/hlo_diff/hlo_diff_main.cc b/third_party/xla/xla/hlo/tools/hlo_diff/hlo_diff_main.cc index d7aec6c9ef29c4..e8fbd9ad659fc0 100644 --- a/third_party/xla/xla/hlo/tools/hlo_diff/hlo_diff_main.cc +++ b/third_party/xla/xla/hlo/tools/hlo_diff/hlo_diff_main.cc @@ -24,6 +24,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/debug_options_flags.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" @@ -109,9 +110,9 @@ absl::Status CheckGroupFlags(const Options::HloPath& hlo_path) { // Builds a HloModule from the HloModuleProto. absl::StatusOr> BuildHloModule( const HloModuleProto& hlo_module_proto) { - TF_ASSIGN_OR_RETURN(HloModuleConfig config, - HloModule::CreateModuleConfigFromProto( - hlo_module_proto, xla::GetDebugOptionsFromFlags())); + ASSIGN_OR_RETURN(HloModuleConfig config, + HloModule::CreateModuleConfigFromProto( + hlo_module_proto, xla::GetDebugOptionsFromFlags())); return HloModule::CreateFromProto(hlo_module_proto, config); } @@ -162,12 +163,11 @@ absl::StatusOr> LoadHLOModule( // Runs Gumgraph algorithm based diff and renders the diff results. absl::Status RunGumgraphDiff(HloModule& first_module, HloModule& second_module, const Options& opts) { - TF_RETURN_IF_ERROR(first_module.RemoveUnusedComputations()); - TF_RETURN_IF_ERROR(second_module.RemoveUnusedComputations()); + RETURN_IF_ERROR(first_module.RemoveUnusedComputations()); + RETURN_IF_ERROR(second_module.RemoveUnusedComputations()); - TF_ASSIGN_OR_RETURN( - auto hlo_gumgraph_diff, - ComputeDiff(first_module, second_module, opts.diff_options)); + ASSIGN_OR_RETURN(auto hlo_gumgraph_diff, + ComputeDiff(first_module, second_module, opts.diff_options)); std::cout << "Diffing finished" << '\n'; const DiffResult& diff = *hlo_gumgraph_diff.diff_result; @@ -181,7 +181,7 @@ absl::Status RunGumgraphDiff(HloModule& first_module, HloModule& second_module, if (!text_output.empty()) { std::ostringstream text; RenderText(diff, text); - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( tsl::WriteStringToFile(tsl::Env::Default(), text_output, text.str())); } @@ -189,7 +189,7 @@ absl::Status RunGumgraphDiff(HloModule& first_module, HloModule& second_module, if (!html_output.empty()) { std::ostringstream html; RenderHtml(diff, diff_summary, html); - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( tsl::WriteStringToFile(tsl::Env::Default(), html_output, html.str())); std::cout << "The diff summary is saved to: " << html_output << '\n'; diff --git a/third_party/xla/xla/hlo/tools/hlo_diff/hlo_diff_summary.cc b/third_party/xla/xla/hlo/tools/hlo_diff/hlo_diff_summary.cc index f82768ff6f3b3c..c3a041a2a53b1a 100644 --- a/third_party/xla/xla/hlo/tools/hlo_diff/hlo_diff_summary.cc +++ b/third_party/xla/xla/hlo/tools/hlo_diff/hlo_diff_summary.cc @@ -33,6 +33,7 @@ #include "absl/status/statusor.h" #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" @@ -401,10 +402,10 @@ std::unique_ptr ConstructDiffSummary( absl::StatusOr> ConstructDiffSummary( const HloModule& left_module, const HloModule& right_module, const DiffResult& diff_result) { - TF_ASSIGN_OR_RETURN(std::unique_ptr graph_l, - HloGumgraph::Create(&left_module)); - TF_ASSIGN_OR_RETURN(std::unique_ptr graph_r, - HloGumgraph::Create(&right_module)); + ASSIGN_OR_RETURN(std::unique_ptr graph_l, + HloGumgraph::Create(&left_module)); + ASSIGN_OR_RETURN(std::unique_ptr graph_r, + HloGumgraph::Create(&right_module)); return ConstructDiffSummary(*graph_l, *graph_r, diff_result); } diff --git a/third_party/xla/xla/hlo/tools/hlo_diff/hlo_gumgraph_diff.cc b/third_party/xla/xla/hlo/tools/hlo_diff/hlo_gumgraph_diff.cc index df424541b360f1..1a443cdb831c4e 100644 --- a/third_party/xla/xla/hlo/tools/hlo_diff/hlo_gumgraph_diff.cc +++ b/third_party/xla/xla/hlo/tools/hlo_diff/hlo_gumgraph_diff.cc @@ -23,6 +23,7 @@ #include "absl/log/log.h" #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/tools/hlo_diff/graph/hlo_gumgraph.h" @@ -58,7 +59,7 @@ absl::StatusOr> FindMappings( MatchCallGraphs(left, right, *mappings); - TF_RETURN_IF_ERROR(left.GetCallGraph().VisitNodes( + RETURN_IF_ERROR(left.GetCallGraph().VisitNodes( [&](const CallGraphNode& node) { if (auto right_node = mappings->left_to_right_computation_map.GetRight(&node); @@ -102,7 +103,7 @@ absl::StatusOr ComputeDiff(const HloModule& left, const HloModule& right, const DiffOptions& options) { LOG(INFO) << "Initializing left module graph"; - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( std::unique_ptr left_graph, HloGumgraph::Create(&left, options.fingerprint_options, options.precompute_instruction_dependencies)); @@ -111,7 +112,7 @@ absl::StatusOr ComputeDiff(const HloModule& left, << " and height: " << left_graph->GetRoot().props.height; LOG(INFO) << "Initializing right module graph"; - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( std::unique_ptr right_graph, HloGumgraph::Create(&right, options.fingerprint_options, options.precompute_instruction_dependencies)); @@ -119,7 +120,7 @@ absl::StatusOr ComputeDiff(const HloModule& left, << right_graph->GetNodeCount() << " and height: " << right_graph->GetRoot().props.height; - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( std::unique_ptr mappings, FindMappings(*left_graph, *right_graph, options.manual_mappings, options.match_options)); diff --git a/third_party/xla/xla/hlo/tools/hlo_opt/BUILD b/third_party/xla/xla/hlo/tools/hlo_opt/BUILD index 085a9e835e5509..e9ee189b20a55f 100644 --- a/third_party/xla/xla/hlo/tools/hlo_opt/BUILD +++ b/third_party/xla/xla/hlo/tools/hlo_opt/BUILD @@ -19,6 +19,7 @@ cc_library( "//xla/tools:hlo_module_loader", "//xla/tsl/platform:env", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "//xla/tsl/util:command_line_flags", "@com_google_absl//absl/log", @@ -126,6 +127,7 @@ cc_library( "//xla/service:float_support", "//xla/service:platform_util", "//xla/stream_executor/platform:initialize", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", diff --git a/third_party/xla/xla/hlo/tools/hlo_opt/opt_lib.cc b/third_party/xla/xla/hlo/tools/hlo_opt/opt_lib.cc index c1d94dc9f0a730..f805732c569d36 100644 --- a/third_party/xla/xla/hlo/tools/hlo_opt/opt_lib.cc +++ b/third_party/xla/xla/hlo/tools/hlo_opt/opt_lib.cc @@ -36,6 +36,7 @@ limitations under the License. #include "absl/strings/str_split.h" #include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/pass/hlo_pass_pipeline.h" @@ -154,8 +155,8 @@ static ProviderMap& GetProviderMap() { std::string platform) { absl::MutexLock l(provider_mu); - TF_ASSIGN_OR_RETURN(std::string canonical_name, - xla::PlatformUtil::CanonicalPlatformName(platform)); + ASSIGN_OR_RETURN(std::string canonical_name, + xla::PlatformUtil::CanonicalPlatformName(platform)); auto it = GetProviderMap().find(canonical_name); if (it == GetProviderMap().end()) { return absl::UnimplementedError(absl::StrCat( diff --git a/third_party/xla/xla/hlo/tools/hlo_opt/opt_main.cc b/third_party/xla/xla/hlo/tools/hlo_opt/opt_main.cc index fd13e3bb2289e6..5311adee3141af 100644 --- a/third_party/xla/xla/hlo/tools/hlo_opt/opt_main.cc +++ b/third_party/xla/xla/hlo/tools/hlo_opt/opt_main.cc @@ -32,6 +32,7 @@ limitations under the License. #include "absl/strings/str_join.h" #include "absl/strings/str_split.h" #include "absl/strings/string_view.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/debug_options_flags.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/tools/hlo_opt/opt_lib.h" @@ -102,15 +103,13 @@ absl::StatusOr GetHloContents(const HloOptConfig& opts, int argc, } std::string data; - TF_RETURN_IF_ERROR( - tsl::ReadFileToString(tsl::Env::Default(), hlo_path, &data)); + RETURN_IF_ERROR(tsl::ReadFileToString(tsl::Env::Default(), hlo_path, &data)); return data; } absl::StatusOr>> GetModules( const HloOptConfig& opts, int argc, char** argv) { - TF_ASSIGN_OR_RETURN(std::string module_data, - GetHloContents(opts, argc, argv)); + ASSIGN_OR_RETURN(std::string module_data, GetHloContents(opts, argc, argv)); std::vector hlos; if (opts.split_input_file) { @@ -139,8 +138,8 @@ absl::StatusOr>> GetModules( "specified"); } } - TF_ASSIGN_OR_RETURN(std::unique_ptr module, - LoadModuleFromData(hlo, format)); + ASSIGN_OR_RETURN(std::unique_ptr module, + LoadModuleFromData(hlo, format)); out.push_back(std::move(module)); } return out; @@ -159,8 +158,8 @@ std::unique_ptr GetDummyModule() { absl::StatusOr TranslateToStage(int argc, char** argv, const HloOptConfig& opts) { - TF_ASSIGN_OR_RETURN(OptProvider * provider, - OptProvider::GetProviderForPlatform(opts.platform)); + ASSIGN_OR_RETURN(OptProvider * provider, + OptProvider::GetProviderForPlatform(opts.platform)); if (opts.list_stages) { return absl::StrJoin(provider->SupportedStages(), "\n"); @@ -173,8 +172,8 @@ absl::StatusOr TranslateToStage(int argc, char** argv, return provider->GetRegisteredPassNames(); } - TF_ASSIGN_OR_RETURN(std::vector> modules, - GetModules(opts, argc, argv)); + ASSIGN_OR_RETURN(std::vector> modules, + GetModules(opts, argc, argv)); if (opts.emit_proto) { std::string proto_str_combined; for (const auto& module : modules) { @@ -194,11 +193,10 @@ absl::StatusOr TranslateToStage(int argc, char** argv, for (std::unique_ptr& m : modules) { std::optional out; if (!opts.passes.empty()) { - TF_ASSIGN_OR_RETURN(out, provider->BuildAndRunTransformPipeline( - std::move(m), opts.passes)); + ASSIGN_OR_RETURN(out, provider->BuildAndRunTransformPipeline( + std::move(m), opts.passes)); } else { - TF_ASSIGN_OR_RETURN(out, - provider->GenerateStage(std::move(m), opts.stage)); + ASSIGN_OR_RETURN(out, provider->GenerateStage(std::move(m), opts.stage)); } if (!out.has_value()) { return absl::UnimplementedError("Stage not supported"); @@ -210,11 +208,11 @@ absl::StatusOr TranslateToStage(int argc, char** argv, } absl::Status RunOpt(int argc, char** argv, const HloOptConfig& opts) { - TF_ASSIGN_OR_RETURN(std::string output, TranslateToStage(argc, argv, opts)); + ASSIGN_OR_RETURN(std::string output, TranslateToStage(argc, argv, opts)); if (opts.output_file == "-") { std::cout << output << std::endl; } else { - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( tsl::WriteStringToFile(tsl::Env::Default(), opts.output_file, output)); } return absl::OkStatus(); diff --git a/third_party/xla/xla/hlo/tools/tests/BUILD b/third_party/xla/xla/hlo/tools/tests/BUILD index 86ec14a2ade184..14e0f1cfb4019c 100644 --- a/third_party/xla/xla/hlo/tools/tests/BUILD +++ b/third_party/xla/xla/hlo/tools/tests/BUILD @@ -98,6 +98,7 @@ cc_library( "//xla/hlo/transforms/simplifiers:algebraic_simplifier", "//xla/service:hlo_module_config", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log:check", diff --git a/third_party/xla/xla/hlo/tools/tests/hlo_opt_test_only_passes.cc b/third_party/xla/xla/hlo/tools/tests/hlo_opt_test_only_passes.cc index c9b5de93c5af4d..3e2f101fe31d31 100644 --- a/third_party/xla/xla/hlo/tools/tests/hlo_opt_test_only_passes.cc +++ b/third_party/xla/xla/hlo/tools/tests/hlo_opt_test_only_passes.cc @@ -28,6 +28,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "absl/strings/string_view.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/builder/lib/math.h" #include "xla/hlo/builder/lib/matrix.h" #include "xla/hlo/builder/lib/prng.h" @@ -49,10 +50,10 @@ namespace { absl::StatusOr XlaComputationToHloComputation( XlaComputation& src_comp, HloModule* dest_module) { - TF_ASSIGN_OR_RETURN(ProgramShape program_shape, src_comp.GetProgramShape()); + ASSIGN_OR_RETURN(ProgramShape program_shape, src_comp.GetProgramShape()); HloModuleConfig config(program_shape); - TF_ASSIGN_OR_RETURN(auto new_module, - HloModule::CreateFromProto(src_comp.proto(), config)); + ASSIGN_OR_RETURN(auto new_module, + HloModule::CreateFromProto(src_comp.proto(), config)); HloCloneContext context(dest_module); return dest_module->DeepCloneComputation(new_module->entry_computation(), &context); @@ -103,13 +104,12 @@ absl::StatusOr BuildAndReplace(XlaBuilder& builder, HloComputation* computation = instruction->parent(); HloModule* module = computation->parent(); - TF_ASSIGN_OR_RETURN(XlaComputation called_computation, builder.Build()); - TF_ASSIGN_OR_RETURN( - HloComputation * new_computation, - XlaComputationToHloComputation(called_computation, module)); + ASSIGN_OR_RETURN(XlaComputation called_computation, builder.Build()); + ASSIGN_OR_RETURN(HloComputation * new_computation, + XlaComputationToHloComputation(called_computation, module)); HloInstruction* new_instruction = computation->AddInstruction( CreateCustomCallToBuilderMethod(instruction, new_computation)); - TF_RETURN_IF_ERROR(instruction->ReplaceAllUsesWith(new_instruction)); + RETURN_IF_ERROR(instruction->ReplaceAllUsesWith(new_instruction)); return true; } @@ -124,69 +124,69 @@ absl::StatusOr XlaBuilderTestPass::ReplaceWithExpandedClientHlo( // xla_builder.math if (custom_call_target == "xla_builder.math.Acos") { - TF_RETURN_IF_ERROR(VerifyOperandCount(instruction, 1, custom_call_target)); + RETURN_IF_ERROR(VerifyOperandCount(instruction, 1, custom_call_target)); xla::Acos(parameters[0]); return BuildAndReplace(builder, instruction); } if (custom_call_target == "xla_builder.math.Acosh") { - TF_RETURN_IF_ERROR(VerifyOperandCount(instruction, 1, custom_call_target)); + RETURN_IF_ERROR(VerifyOperandCount(instruction, 1, custom_call_target)); xla::Acosh(parameters[0]); return BuildAndReplace(builder, instruction); } if (custom_call_target == "xla_builder.math.Asin") { - TF_RETURN_IF_ERROR(VerifyOperandCount(instruction, 1, custom_call_target)); + RETURN_IF_ERROR(VerifyOperandCount(instruction, 1, custom_call_target)); xla::Asin(parameters[0]); return BuildAndReplace(builder, instruction); } if (custom_call_target == "xla_builder.math.Asinh") { - TF_RETURN_IF_ERROR(VerifyOperandCount(instruction, 1, custom_call_target)); + RETURN_IF_ERROR(VerifyOperandCount(instruction, 1, custom_call_target)); xla::Asinh(parameters[0]); return BuildAndReplace(builder, instruction); } if (custom_call_target == "xla_builder.math.Atan") { - TF_RETURN_IF_ERROR(VerifyOperandCount(instruction, 1, custom_call_target)); + RETURN_IF_ERROR(VerifyOperandCount(instruction, 1, custom_call_target)); xla::Atan(parameters[0]); return BuildAndReplace(builder, instruction); } if (custom_call_target == "xla_builder.math.Cosh") { - TF_RETURN_IF_ERROR(VerifyOperandCount(instruction, 1, custom_call_target)); + RETURN_IF_ERROR(VerifyOperandCount(instruction, 1, custom_call_target)); xla::Cosh(parameters[0]); return BuildAndReplace(builder, instruction); } if (custom_call_target == "xla_builder.math.IgammaGradA") { - TF_RETURN_IF_ERROR(VerifyOperandCount(instruction, 2, custom_call_target)); + RETURN_IF_ERROR(VerifyOperandCount(instruction, 2, custom_call_target)); xla::IgammaGradA(parameters[0], parameters[1]); return BuildAndReplace(builder, instruction); } if (custom_call_target == "xla_builder.math.NextAfter") { - TF_RETURN_IF_ERROR(VerifyOperandCount(instruction, 2, custom_call_target)); + RETURN_IF_ERROR(VerifyOperandCount(instruction, 2, custom_call_target)); xla::NextAfter(parameters[0], parameters[1]); return BuildAndReplace(builder, instruction); } if (custom_call_target == "xla_builder.math.Polygamma") { - TF_RETURN_IF_ERROR(VerifyOperandCount(instruction, 2, custom_call_target)); + RETURN_IF_ERROR(VerifyOperandCount(instruction, 2, custom_call_target)); xla::Polygamma(parameters[0], parameters[1]); return BuildAndReplace(builder, instruction); } if (custom_call_target == "xla_builder.math.RandomGammaGrad") { - TF_RETURN_IF_ERROR(VerifyOperandCount(instruction, 2, custom_call_target)); + RETURN_IF_ERROR(VerifyOperandCount(instruction, 2, custom_call_target)); xla::RandomGammaGrad(parameters[0], parameters[1]); return BuildAndReplace(builder, instruction); } if (custom_call_target == "xla_builder.math.RegularizedIncompleteBeta") { - TF_RETURN_IF_ERROR(VerifyOperandCount(instruction, 3, custom_call_target)); + RETURN_IF_ERROR(VerifyOperandCount(instruction, 3, custom_call_target)); xla::RegularizedIncompleteBeta(parameters[0], parameters[1], parameters[2]); return BuildAndReplace(builder, instruction); } // xla_builder.matrix if (custom_call_target == "xla_builder.matrix.GetMatrixDiagonalViaGather") { - TF_RETURN_IF_ERROR(VerifyOperandCount(instruction, 1, custom_call_target)); + RETURN_IF_ERROR(VerifyOperandCount(instruction, 1, custom_call_target)); xla::GetMatrixDiagonalViaGather(parameters[0]); return BuildAndReplace(builder, instruction); } if (custom_call_target == "xla_builder.matrix.Einsum") { - TF_RETURN_IF_ERROR(VerifyOperandCount(instruction, 2, custom_call_target)); + RETURN_IF_ERROR(VerifyOperandCount(instruction, 2, custom_call_target)); absl::string_view einsum_config = instruction->raw_backend_config_string(); xla::Einsum(parameters[0], parameters[1], einsum_config); return BuildAndReplace(builder, instruction); @@ -194,24 +194,23 @@ absl::StatusOr XlaBuilderTestPass::ReplaceWithExpandedClientHlo( // xla_builder.prng if (custom_call_target == "xla_builder.prng.ScramblePhiloxKey") { - TF_RETURN_IF_ERROR(VerifyOperandCount(instruction, 1, custom_call_target)); + RETURN_IF_ERROR(VerifyOperandCount(instruction, 1, custom_call_target)); xla::ScramblePhiloxKey(parameters[0]); return BuildAndReplace(builder, instruction); } // xla_builder.tridiagonal if (custom_call_target == "xla_builder.tridiagonal.TridiagonalSolver") { - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( VerifyOperandCounts(instruction, {2, 4}, custom_call_target)); if (parameters.size() == 2) { - TF_ASSIGN_OR_RETURN( - std::ignore, xla::tridiagonal::TridiagonalSolver( - tridiagonal::SolverAlgorithm::kThomas, parameters[0], - parameters[1])); + ASSIGN_OR_RETURN(std::ignore, xla::tridiagonal::TridiagonalSolver( + tridiagonal::SolverAlgorithm::kThomas, + parameters[0], parameters[1])); return BuildAndReplace(builder, instruction); } - TF_ASSIGN_OR_RETURN( - std::ignore, xla::tridiagonal::TridiagonalSolver( + ASSIGN_OR_RETURN(std::ignore, + xla::tridiagonal::TridiagonalSolver( tridiagonal::SolverAlgorithm::kThomas, parameters[0], parameters[1], parameters[2], parameters[3])); return BuildAndReplace(builder, instruction); @@ -230,10 +229,9 @@ absl::StatusOr XlaBuilderTestPass::RunImpl( // Find custom calls that start with "xla_builder." and expand the HLO if (instruction->opcode() == HloOpcode::kCustomCall && instruction->custom_call_target().rfind("xla_builder.", 0) == 0) { - TF_ASSIGN_OR_RETURN( - bool call_changed, - ReplaceWithExpandedClientHlo(instruction, - instruction->custom_call_target())); + ASSIGN_OR_RETURN(bool call_changed, + ReplaceWithExpandedClientHlo( + instruction, instruction->custom_call_target())); changed |= call_changed; } } diff --git a/third_party/xla/xla/hlo/transforms/expanders/BUILD b/third_party/xla/xla/hlo/transforms/expanders/BUILD index c7e841c9981118..102148158a8e8d 100644 --- a/third_party/xla/xla/hlo/transforms/expanders/BUILD +++ b/third_party/xla/xla/hlo/transforms/expanders/BUILD @@ -406,6 +406,8 @@ cc_library( "//xla/service/gpu:backend_configs_cc", "//xla/service/gpu:ir_emission_utils", "//xla/stream_executor:device_description", + "//xla/stream_executor:dnn", + "//xla/stream_executor/cuda:cuda_compute_capability", "//xla/tsl/platform:errors", "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", @@ -428,6 +430,7 @@ xla_cc_test( "//xla/hlo/testlib:hlo_hardware_independent_test_base", "//xla/hlo/utils:hlo_matchers", "//xla/stream_executor:device_description", + "//xla/stream_executor:dnn", "//xla/tsl/platform:statusor", "@com_google_absl//absl/strings:string_view", "@com_google_googletest//:gtest", diff --git a/third_party/xla/xla/hlo/transforms/expanders/ragged_dot_rewriter.cc b/third_party/xla/xla/hlo/transforms/expanders/ragged_dot_rewriter.cc index 52c7f6b6c3eab8..d567ae4b9ad914 100644 --- a/third_party/xla/xla/hlo/transforms/expanders/ragged_dot_rewriter.cc +++ b/third_party/xla/xla/hlo/transforms/expanders/ragged_dot_rewriter.cc @@ -42,6 +42,7 @@ limitations under the License. #include "xla/service/gpu/ir_emission_utils.h" #include "xla/shape.h" #include "xla/shape_util.h" +#include "xla/stream_executor/cuda/cuda_compute_capability.h" #include "xla/tsl/platform/errors.h" #include "xla/tsl/platform/statusor.h" #include "xla/xla_data.pb.h" @@ -406,10 +407,16 @@ absl::StatusOr RaggedDotRewriter::RunImpl( .debug_options() .xla_gpu_experimental_use_ragged_dot_grouped_gemm() && module->config().debug_options().xla_gpu_enable_cublaslt(); + const se::CudaComputeCapability* cuda_cc = + gpu_compute_capability_.has_value() + ? gpu_compute_capability_->cuda_compute_capability() + : nullptr; const bool ragged_dot_fusion_enabled = module->config() .debug_options() - .xla_gpu_experimental_use_ragged_dot_fusion(); + .xla_gpu_experimental_use_ragged_dot_fusion() && + cudnn_version_ >= kMinCudnnVersionForRaggedDotFusion && + cuda_cc != nullptr && cuda_cc->IsAtLeastAmpere(); // Gather all Ragged Dot operations. std::vector ragged_dots; diff --git a/third_party/xla/xla/hlo/transforms/expanders/ragged_dot_rewriter.h b/third_party/xla/xla/hlo/transforms/expanders/ragged_dot_rewriter.h index 9fc072a27fe351..136fbe8950cf6f 100644 --- a/third_party/xla/xla/hlo/transforms/expanders/ragged_dot_rewriter.h +++ b/third_party/xla/xla/hlo/transforms/expanders/ragged_dot_rewriter.h @@ -22,14 +22,21 @@ limitations under the License. #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/pass/hlo_pass_interface.h" #include "xla/stream_executor/device_description.h" +#include "xla/stream_executor/dnn.h" namespace xla { +inline const stream_executor::dnn::VersionInfo + kMinCudnnVersionForRaggedDotFusion(9, 22); + // RaggedDotRewriter converts ragged dots to general dots through expansion. class RaggedDotRewriter : public HloModulePass { public: - explicit RaggedDotRewriter(se::GpuComputeCapability gpu_compute_capability) - : gpu_compute_capability_(gpu_compute_capability) {} + explicit RaggedDotRewriter(se::GpuComputeCapability gpu_compute_capability, + stream_executor::dnn::VersionInfo cudnn_version = + stream_executor::dnn::VersionInfo()) + : gpu_compute_capability_(gpu_compute_capability), + cudnn_version_(cudnn_version) {} absl::string_view name() const override { return "ragged_dot_rewriter"; } @@ -40,6 +47,7 @@ class RaggedDotRewriter : public HloModulePass { private: std::optional gpu_compute_capability_; + stream_executor::dnn::VersionInfo cudnn_version_; }; } // namespace xla diff --git a/third_party/xla/xla/hlo/transforms/expanders/ragged_dot_rewriter_test.cc b/third_party/xla/xla/hlo/transforms/expanders/ragged_dot_rewriter_test.cc index 1afb6c3dbf390a..826ac3c6aea7e9 100644 --- a/third_party/xla/xla/hlo/transforms/expanders/ragged_dot_rewriter_test.cc +++ b/third_party/xla/xla/hlo/transforms/expanders/ragged_dot_rewriter_test.cc @@ -25,6 +25,7 @@ limitations under the License. #include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" #include "xla/hlo/utils/hlo_matchers.h" #include "xla/stream_executor/device_description.h" +#include "xla/stream_executor/dnn.h" #include "xla/tsl/platform/statusor.h" namespace xla { @@ -82,10 +83,65 @@ TEST_F(RaggedDotRewriterTest, DontRewriteIfUsingRaggedDotFusion) { module->mutable_config() .mutable_debug_options() .set_xla_gpu_experimental_use_ragged_dot_fusion(true); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + RaggedDotRewriter(GetCudaComputeCapability(), + kMinCudnnVersionForRaggedDotFusion) + .Run(module.get())); + EXPECT_FALSE(changed); +} + +TEST_F(RaggedDotRewriterTest, RewriteIfPreAmpereForRaggedDotFusion) { + absl::string_view module_string = R"( + HloModule module + + ENTRY main { + p0 = bf16[64,9]{1,0} parameter(0) + p1 = bf16[2,9,8]{2,1,0} parameter(1) + p2 = s64[2]{0} parameter(2) + ROOT ragged-dot = bf16[64,8]{1,0} ragged-dot(p0, p1, p2), + lhs_contracting_dims={1}, rhs_contracting_dims={1}, + lhs_ragged_dims={0}, rhs_group_dims={0} + })"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_string)); + module->mutable_config() + .mutable_debug_options() + .set_xla_gpu_experimental_use_ragged_dot_fusion(true); + // Volta is pre-Ampere, so even with a recent cuDNN the fusion gate must + // fail and the rewriter must fall back to expansion. TF_ASSERT_OK_AND_ASSIGN( bool changed, - RaggedDotRewriter(GetCudaComputeCapability()).Run(module.get())); - EXPECT_FALSE(changed); + RaggedDotRewriter( + se::CudaComputeCapability{se::CudaComputeCapability::kVolta, 0}, + kMinCudnnVersionForRaggedDotFusion) + .Run(module.get())); + EXPECT_TRUE(changed); +} + +TEST_F(RaggedDotRewriterTest, RewriteIfCudnnVersionTooOldForRaggedDotFusion) { + absl::string_view module_string = R"( + HloModule module + + ENTRY main { + p0 = bf16[64,9]{1,0} parameter(0) + p1 = bf16[2,9,8]{2,1,0} parameter(1) + p2 = s64[2]{0} parameter(2) + ROOT ragged-dot = bf16[64,8]{1,0} ragged-dot(p0, p1, p2), + lhs_contracting_dims={1}, rhs_contracting_dims={1}, + lhs_ragged_dims={0}, rhs_group_dims={0} + })"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_string)); + module->mutable_config() + .mutable_debug_options() + .set_xla_gpu_experimental_use_ragged_dot_fusion(true); + TF_ASSERT_OK_AND_ASSIGN( + bool changed, + RaggedDotRewriter(GetCudaComputeCapability(), se::dnn::VersionInfo{9, 21}) + .Run(module.get())); + EXPECT_TRUE(changed); } TEST_F(RaggedDotRewriterTest, DontRewriteIfUsingRaggedDotFusionRocm) { diff --git a/third_party/xla/xla/hlo/transforms/simplifiers/BUILD b/third_party/xla/xla/hlo/transforms/simplifiers/BUILD index acf2dc095e30fc..9a64c9d93c6efe 100644 --- a/third_party/xla/xla/hlo/transforms/simplifiers/BUILD +++ b/third_party/xla/xla/hlo/transforms/simplifiers/BUILD @@ -29,6 +29,7 @@ cc_library( deps = [ "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", + "//xla/tsl/platform:status_macros", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", @@ -68,6 +69,7 @@ cc_library( "//xla/hlo/pass:hlo_pass", "//xla/hlo/utils:hlo_query", "//xla/service:all_reduce_key", + "//xla/tsl/platform:status_macros", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", @@ -121,6 +123,7 @@ cc_library( deps = [ "//xla/hlo/ir:hlo", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/status:statusor", ], @@ -164,6 +167,7 @@ cc_library( "//xla/hlo/pass:hlo_pass", "//xla/service:float_support", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", @@ -204,6 +208,7 @@ cc_library( "//xla/hlo/pass:hlo_pass", "//xla/tsl/platform:errors", "//xla/tsl/platform:logging", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_set", @@ -239,6 +244,7 @@ cc_library( "//xla/hlo/pass:hlo_pass", "//xla/service:call_graph", "//xla/service:float_support", + "//xla/tsl/platform:status_macros", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", @@ -343,6 +349,7 @@ xla_cc_test( "//xla/hlo/testlib:hlo_hardware_independent_test_base", "//xla/hlo/testlib:test", "//xla/service:call_graph", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_googletest//:gtest_main", "@tsl//tsl/platform:statusor", @@ -368,6 +375,7 @@ cc_library( "//xla/service/heap_simulator", "//xla/tsl/platform:errors", "//xla/tsl/platform:logging", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/container:flat_hash_map", @@ -511,6 +519,7 @@ cc_library( "//xla/hlo/pass:hlo_pass", "//xla/service:hlo_creation_utils", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/status:statusor", @@ -529,6 +538,7 @@ xla_cc_test( "//xla/hlo/testlib:hlo_hardware_independent_test_base", "//xla/hlo/utils:hlo_matchers", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", @@ -545,6 +555,7 @@ cc_library( "//xla:util", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", + "//xla/tsl/platform:status_macros", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", @@ -579,6 +590,7 @@ cc_library( "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", "//xla/service:hlo_creation_utils", + "//xla/tsl/platform:status_macros", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", @@ -654,6 +666,7 @@ cc_library( "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", "//xla/service:hlo_creation_utils", + "//xla/tsl/platform:status_macros", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log:check", @@ -692,6 +705,7 @@ cc_library( "//xla/service:shape_inference", "//xla/service/graphcycles", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", @@ -796,6 +810,7 @@ cc_library( "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", "//xla/service:hlo_creation_utils", + "//xla/tsl/platform:status_macros", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", @@ -833,6 +848,7 @@ cc_library( "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/status:statusor", @@ -865,6 +881,7 @@ cc_library( "//xla:status_macros", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", + "//xla/tsl/platform:status_macros", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", @@ -907,6 +924,7 @@ cc_library( "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", "//xla/service:hlo_creation_utils", + "//xla/tsl/platform:status_macros", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", @@ -933,6 +951,7 @@ xla_cc_test( "//xla/service:hlo_verifier", "//xla/service:pattern_matcher", "//xla/tsl/lib/core:status_test_util", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/status", "@com_google_googletest//:gtest", @@ -954,6 +973,7 @@ cc_library( "//xla/service:call_graph", "//xla/service:computation_layout", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_set", @@ -1033,6 +1053,7 @@ cc_library( "//xla/service:hlo_cost_analysis", "//xla/service:logical_buffer", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", @@ -1112,6 +1133,7 @@ cc_library( "//xla/hlo/pass:hlo_pass", "//xla/service:slow_operation_alarm", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", @@ -1180,6 +1202,7 @@ cc_library( "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", "//xla/hlo/utils:hlo_query", + "//xla/tsl/platform:status_macros", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status:statusor", @@ -1212,6 +1235,7 @@ cc_library( "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_set", @@ -1253,6 +1277,7 @@ cc_library( "//xla/hlo/pass:hlo_pass", "//xla/service/spmd/shardy:constants", "//xla/service/spmd/shardy:utils", + "//xla/tsl/platform:status_macros", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", @@ -1288,6 +1313,7 @@ cc_library( "//xla:util", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", + "//xla/tsl/platform:status_macros", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", @@ -1354,6 +1380,7 @@ cc_library( "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", @@ -1377,6 +1404,7 @@ xla_cc_test( "//xla/hlo/testlib:pattern_matcher_gmock", "//xla/service:pattern_matcher", "//xla/tsl/lib/core:status_test_util", + "//xla/tsl/platform:status_macros", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", @@ -1397,6 +1425,7 @@ cc_library( "//xla:util", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", + "//xla/tsl/platform:status_macros", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", @@ -1432,6 +1461,7 @@ cc_library( "//xla:status_macros", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", + "//xla/tsl/platform:status_macros", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", @@ -1475,6 +1505,7 @@ cc_library( "//xla/service:call_graph", "//xla/service:pattern_matcher", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", @@ -1512,6 +1543,7 @@ cc_library( "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", + "//xla/tsl/platform:status_macros", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", @@ -1580,6 +1612,7 @@ cc_library( "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/transforms/expanders:op_expander_pass", + "//xla/tsl/platform:status_macros", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", "@tsl//tsl/platform:errors", @@ -1617,6 +1650,7 @@ cc_library( "//xla/hlo/utils:hlo_query", "//xla/service:hlo_proto_cc", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", @@ -1639,6 +1673,7 @@ cc_library( "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", @@ -1653,6 +1688,7 @@ cc_library( "//xla:shape_util", "//xla/hlo/ir:hlo", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", @@ -1683,6 +1719,7 @@ cc_library( "//xla/hlo/transforms/expanders:op_expander_pass", "//xla/service:gather_scatter_utils", "//xla/service:hlo_creation_utils", + "//xla/tsl/platform:status_macros", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", @@ -1745,6 +1782,7 @@ cc_library( "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", @@ -1765,6 +1803,7 @@ cc_library( "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", "//xla/service:collective_ops_utils", + "//xla/tsl/platform:status_macros", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -1795,6 +1834,7 @@ cc_library( "//xla/service:hlo_module_config", "//xla/service:pattern_matcher", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -1814,6 +1854,7 @@ xla_cc_test( "//xla/hlo/testlib:verified_hlo_module", "//xla/hlo/utils:hlo_matchers", "//xla/service:hlo_module_config", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", @@ -1838,6 +1879,7 @@ cc_library( "//xla/service:collective_opt_utils", "//xla/service:hlo_module_config", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", @@ -1862,6 +1904,7 @@ xla_cc_test( "//xla/hlo/utils:hlo_matchers", "//xla/service:collective_opt_utils", "//xla/service:hlo_module_config", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/log", "@com_google_absl//absl/status:statusor", @@ -1882,6 +1925,7 @@ cc_library( "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", "//xla/service:shape_inference", + "//xla/tsl/platform:status_macros", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", @@ -1903,6 +1947,7 @@ cc_library( "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", @@ -1938,6 +1983,7 @@ cc_library( "//xla/hlo/utils/concurrency:concurrency_utils", "//xla/service:call_graph", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", @@ -1978,6 +2024,7 @@ cc_library( "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", @@ -1998,6 +2045,7 @@ cc_library( "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_set", diff --git a/third_party/xla/xla/hlo/transforms/simplifiers/algebraic_simplifier.cc b/third_party/xla/xla/hlo/transforms/simplifiers/algebraic_simplifier.cc index 0832c4b2fa8241..ab611f87c41fce 100644 --- a/third_party/xla/xla/hlo/transforms/simplifiers/algebraic_simplifier.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/algebraic_simplifier.cc @@ -712,7 +712,7 @@ absl::Status AlgebraicSimplifierVisitor::ScalarMultiplyReduction( // When found a scalar multiply, save its scalar value. values.push_back(*GetConstantValue(multiplier)); // And remove the scalar multiply op. - TF_RETURN_IF_ERROR(user->ReplaceOperandWith(index, operand)); + RETURN_IF_ERROR(user->ReplaceOperandWith(index, operand)); inst = operand; } @@ -739,7 +739,7 @@ absl::Status AlgebraicSimplifierVisitor::ScalarMultiplyReduction( m::Broadcast(m::ConstantScalar(&multiplier))))) { values.push_back(*GetConstantValue(multiplier)); - TF_RETURN_IF_ERROR(inst->ReplaceAllUsesWith(operand)); + RETURN_IF_ERROR(inst->ReplaceAllUsesWith(operand)); inst = operand; } @@ -907,8 +907,8 @@ absl::Status AlgebraicSimplifierVisitor::HandleAdd(HloInstruction* add) { Match(add, m::Add(m::Add(m::NonConstant(&a), m::Broadcast(m::ConstantScalar(&c1))), m::Broadcast(m::ConstantScalar(&c2))))) { - TF_ASSIGN_OR_RETURN(auto* sum_of_constants, - MakeBinaryHlo(HloOpcode::kAdd, c1, c2)); + ASSIGN_OR_RETURN(auto* sum_of_constants, + MakeBinaryHlo(HloOpcode::kAdd, c1, c2)); if (ShapeUtil::IsScalar(sum_of_constants->shape()) && !ShapeUtil::IsScalar(add->shape())) { sum_of_constants = add->AddInstruction( @@ -925,8 +925,8 @@ absl::Status AlgebraicSimplifierVisitor::HandleAdd(HloInstruction* add) { Match(add, m::Add(m::Subtract(m::Broadcast(m::ConstantScalar(&c1)), m::NonConstant(&a)), m::Broadcast(m::ConstantScalar(&c2))))) { - TF_ASSIGN_OR_RETURN(HloInstruction * sum_of_constants, - MakeBinaryHlo(HloOpcode::kAdd, c1, c2)); + ASSIGN_OR_RETURN(HloInstruction * sum_of_constants, + MakeBinaryHlo(HloOpcode::kAdd, c1, c2)); if (ShapeUtil::IsScalar(sum_of_constants->shape()) && !ShapeUtil::IsScalar(add->shape())) { sum_of_constants = add->AddInstruction( @@ -1164,38 +1164,35 @@ absl::Status AlgebraicSimplifierVisitor::HandleAdd(HloInstruction* add) { lhs_scatter_index->shape().element_type() == rhs_scatter_index->shape().element_type() && ShapeUtil::SameDimensions(lhs_update_window, rhs_update_window)) { - TF_ASSIGN_OR_RETURN(HloInstruction * new_operand, - MakeBinaryHlo(HloOpcode::kAdd, lhs_scatter_operand, - rhs_scatter_operand)); - TF_ASSIGN_OR_RETURN(HloInstruction * new_index, - MakeConcatHlo({lhs_scatter_index, rhs_scatter_index}, - *index_concat_dimension)); - TF_ASSIGN_OR_RETURN( - HloInstruction * new_update, - MakeConcatHlo({lhs_scatter_update, rhs_scatter_update}, - *update_concat_dimension)); + ASSIGN_OR_RETURN(HloInstruction * new_operand, + MakeBinaryHlo(HloOpcode::kAdd, lhs_scatter_operand, + rhs_scatter_operand)); + ASSIGN_OR_RETURN(HloInstruction * new_index, + MakeConcatHlo({lhs_scatter_index, rhs_scatter_index}, + *index_concat_dimension)); + ASSIGN_OR_RETURN(HloInstruction * new_update, + MakeConcatHlo({lhs_scatter_update, rhs_scatter_update}, + *update_concat_dimension)); return ReplaceWithNewInstruction( add, HloInstruction::CreateScatter( add->shape(), new_operand, new_index, new_update, lhs->to_apply(), lhs_dnums, false, false)); } - TF_ASSIGN_OR_RETURN(HloInstruction * new_operand, - MakeBinaryHlo(HloOpcode::kAdd, lhs_scatter_operand, - rhs_scatter_operand)); - TF_RETURN_IF_ERROR(rhs->ReplaceOperandWith(0, new_operand)); - TF_RETURN_IF_ERROR(lhs->ReplaceOperandWith(0, rhs)); + ASSIGN_OR_RETURN(HloInstruction * new_operand, + MakeBinaryHlo(HloOpcode::kAdd, lhs_scatter_operand, + rhs_scatter_operand)); + RETURN_IF_ERROR(rhs->ReplaceOperandWith(0, new_operand)); + RETURN_IF_ERROR(lhs->ReplaceOperandWith(0, rhs)); return ReplaceInstruction(add, lhs); } else if (rhs_scatter) { - TF_ASSIGN_OR_RETURN( - HloInstruction * new_operand, - MakeBinaryHlo(HloOpcode::kAdd, lhs, rhs_scatter_operand)); - TF_RETURN_IF_ERROR(rhs->ReplaceOperandWith(0, new_operand)); + ASSIGN_OR_RETURN(HloInstruction * new_operand, + MakeBinaryHlo(HloOpcode::kAdd, lhs, rhs_scatter_operand)); + RETURN_IF_ERROR(rhs->ReplaceOperandWith(0, new_operand)); return ReplaceInstruction(add, rhs); } else if (lhs_scatter) { - TF_ASSIGN_OR_RETURN( - HloInstruction * new_operand, - MakeBinaryHlo(HloOpcode::kAdd, lhs_scatter_operand, rhs)); - TF_RETURN_IF_ERROR(lhs->ReplaceOperandWith(0, new_operand)); + ASSIGN_OR_RETURN(HloInstruction * new_operand, + MakeBinaryHlo(HloOpcode::kAdd, lhs_scatter_operand, rhs)); + RETURN_IF_ERROR(lhs->ReplaceOperandWith(0, new_operand)); return ReplaceInstruction(add, lhs); } return absl::OkStatus(); @@ -1235,7 +1232,7 @@ absl::StatusOr AlgebraicSimplifierVisitor::TrySimplifyTautologicalCompare( std::optional rhs_info = get_compare_info(rhs); if (lhs_info && rhs_info && lhs_info->var == rhs_info->var) { int64_t new_bound = std::min(lhs_info->constant, rhs_info->constant); - TF_RETURN_IF_ERROR(ReplaceWithNewInstruction( + RETURN_IF_ERROR(ReplaceWithNewInstruction( conjunction, HloInstruction::CreateCompare(lhs->shape(), lhs_info->var, MakeScalarLike(lhs_info->var, new_bound), @@ -1294,8 +1291,8 @@ absl::Status AlgebraicSimplifierVisitor::HandleAllReduceOrReduceScatter( HloInstruction::CreateBroadcast(collective->shape(), constant, {})); } case ReductionKind::SUM: { - TF_ASSIGN_OR_RETURN(auto count_and_size, - GetReplicaGroupCountAndSize(collective)); + ASSIGN_OR_RETURN(auto count_and_size, + GetReplicaGroupCountAndSize(collective)); if (!count_and_size.has_value()) { return absl::OkStatus(); } @@ -1379,8 +1376,8 @@ absl::Status AlgebraicSimplifierVisitor::HandleAnd( } // Simplify tautological conjunctions. - TF_ASSIGN_OR_RETURN(bool found_tautological_compare, - TrySimplifyTautologicalCompare(logical_and)); + ASSIGN_OR_RETURN(bool found_tautological_compare, + TrySimplifyTautologicalCompare(logical_and)); if (found_tautological_compare) { return absl::OkStatus(); } @@ -1432,14 +1429,13 @@ absl::Status AlgebraicSimplifierVisitor::HandleBitcast( // If a bitcast feeds a bitcast, make it a single bitcast. // Make sure the whole chain of bitcasts is optimized. if (bitcast->operand(0)->opcode() == HloOpcode::kBitcast) { - TF_RETURN_IF_ERROR(HandleBitcast(bitcast->mutable_operand(0))); + RETURN_IF_ERROR(HandleBitcast(bitcast->mutable_operand(0))); } HloInstruction* op; if (Match(bitcast, m::Bitcast(m::Bitcast(m::Op(&op))))) { auto new_bitcast = HloInstruction::CreateBitcast(bitcast->shape(), op); HloInstruction* new_bitcast_ptr = new_bitcast.get(); - TF_RETURN_IF_ERROR( - ReplaceWithNewInstruction(bitcast, std::move(new_bitcast))); + RETURN_IF_ERROR(ReplaceWithNewInstruction(bitcast, std::move(new_bitcast))); bitcast = new_bitcast_ptr; } @@ -1750,8 +1746,8 @@ absl::Status AlgebraicSimplifierVisitor::HandleBitcastConvert( bitcast->shape(), operand->mutable_operand(0))); } - TF_ASSIGN_OR_RETURN(bool replaced, - TrySimplifyTautologicalBitcastConvert(bitcast)); + ASSIGN_OR_RETURN(bool replaced, + TrySimplifyTautologicalBitcastConvert(bitcast)); if (replaced) { return absl::OkStatus(); } @@ -2142,9 +2138,9 @@ AlgebraicSimplifierVisitor::TrySimplifyTautologicalBitcastConvert( } const int64_t concat_dim = concat->concatenate_dimension(); - TF_ASSIGN_OR_RETURN(HloInstruction * new_concat, - MakeConcatHlo(outer_inputs, concat_dim)); - TF_RETURN_IF_ERROR(ReplaceInstruction(bitcast, new_concat)); + ASSIGN_OR_RETURN(HloInstruction * new_concat, + MakeConcatHlo(outer_inputs, concat_dim)); + RETURN_IF_ERROR(ReplaceInstruction(bitcast, new_concat)); return true; } @@ -2222,7 +2218,7 @@ AlgebraicSimplifierVisitor::TryRemoveUpcastAndDowncastSurroundingBinaryOp( computation->AddInstruction(bin_op_instr->CloneWithNewOperands( ShapeUtil::ChangeElementType(bin_op_instr->shape(), final_type), {arg_1, arg_2})); - TF_RETURN_IF_ERROR(ReplaceInstruction(final_convert_instr, new_bin_op)); + RETURN_IF_ERROR(ReplaceInstruction(final_convert_instr, new_bin_op)); return absl::OkStatus(); } @@ -2531,7 +2527,7 @@ absl::Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) { primitive_util::IsComplexType( primitive_type_constant)) { using NativeT = NativeTypeOf; - TF_RETURN_IF_ERROR(InvertConstant(*c, &new_literal)); + RETURN_IF_ERROR(InvertConstant(*c, &new_literal)); auto inverse = c->AddInstruction(simplifier_->CreateConstantWithLayoutUpdated( @@ -2540,9 +2536,8 @@ absl::Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) { inverse = b->AddInstruction(HloInstruction::CreateBroadcast( b->shape(), inverse, b->dimensions())); } - TF_ASSIGN_OR_RETURN( - auto new_divide, - MakeBinaryHlo(HloOpcode::kMultiply, a, inverse)); + ASSIGN_OR_RETURN(auto new_divide, + MakeBinaryHlo(HloOpcode::kMultiply, a, inverse)); return ReplaceInstruction(divide, new_divide); } return absl::OkStatus(); @@ -2575,31 +2570,27 @@ absl::Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) { // (A / B) / (C / D) => (A / B)*(D / C) => (A * D) / (B * C) if (Match(divide, m::Divide(m::Divide(m::Op(&a), m::Op(&b)), m::Divide(m::Op(&c), m::Op(&d))))) { - TF_ASSIGN_OR_RETURN(auto a_times_d, - MakeBinaryHlo(HloOpcode::kMultiply, a, d)); - TF_ASSIGN_OR_RETURN(auto b_times_c, - MakeBinaryHlo(HloOpcode::kMultiply, b, c)); - TF_ASSIGN_OR_RETURN(auto new_divide, MakeBinaryHlo(HloOpcode::kDivide, - a_times_d, b_times_c)); + ASSIGN_OR_RETURN(auto a_times_d, MakeBinaryHlo(HloOpcode::kMultiply, a, d)); + ASSIGN_OR_RETURN(auto b_times_c, MakeBinaryHlo(HloOpcode::kMultiply, b, c)); + ASSIGN_OR_RETURN(auto new_divide, + MakeBinaryHlo(HloOpcode::kDivide, a_times_d, b_times_c)); return ReplaceInstruction(divide, new_divide); } // (A / B) / C => A / (B * C) if (Match(divide, m::Divide(m::Divide(m::Op(&a), m::Op(&b)), m::Op(&c)))) { - TF_ASSIGN_OR_RETURN(auto b_times_c, - MakeBinaryHlo(HloOpcode::kMultiply, b, c)); - TF_ASSIGN_OR_RETURN(auto new_divide, - MakeBinaryHlo(HloOpcode::kDivide, a, b_times_c)); + ASSIGN_OR_RETURN(auto b_times_c, MakeBinaryHlo(HloOpcode::kMultiply, b, c)); + ASSIGN_OR_RETURN(auto new_divide, + MakeBinaryHlo(HloOpcode::kDivide, a, b_times_c)); return ReplaceInstruction(divide, new_divide); } // A / (B / C) => (A*C) / B if (Match(divide, m::Divide(m::Op(&a), m::Divide(m::Op(&b), m::Op(&c))))) { - TF_ASSIGN_OR_RETURN(auto a_times_c, - MakeBinaryHlo(HloOpcode::kMultiply, a, c)); - TF_ASSIGN_OR_RETURN(auto new_divide, - MakeBinaryHlo(HloOpcode::kDivide, a_times_c, b)); + ASSIGN_OR_RETURN(auto a_times_c, MakeBinaryHlo(HloOpcode::kMultiply, a, c)); + ASSIGN_OR_RETURN(auto new_divide, + MakeBinaryHlo(HloOpcode::kDivide, a_times_c, b)); return ReplaceInstruction(divide, new_divide); } @@ -2610,12 +2601,12 @@ absl::Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) { m::Convert(&a, m::Op().WithShape(m::Shape().WithElementType(PRED))), m::Broadcast(m::Op(&b).WithShape(m::Shape().IsScalar()))))) { - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( auto recip, MakeBinaryHlo(HloOpcode::kDivide, MakeScalarLike(b, 1), b)); auto recip_bcast = divide->mutable_operand(1)->AddInstruction( HloInstruction::CreateBroadcast(divide->shape(), recip, {})); - TF_ASSIGN_OR_RETURN(auto mul, - MakeBinaryHlo(HloOpcode::kMultiply, recip_bcast, a)); + ASSIGN_OR_RETURN(auto mul, + MakeBinaryHlo(HloOpcode::kMultiply, recip_bcast, a)); return ReplaceInstruction(divide, mul); } @@ -2689,15 +2680,15 @@ AlgebraicSimplifierVisitor::RemoveDegenerateDimensionFromDot( ShapeUtil::DropDegenerateDimensions(rhs_shape), dot->mutable_operand(1))) : dot->mutable_operand(1); - TF_ASSIGN_OR_RETURN(auto new_dot, MakeDotHlo(new_lhs, new_rhs, new_dnums, - dot->precision_config(), - dot->shape().element_type())); + ASSIGN_OR_RETURN(auto new_dot, MakeDotHlo(new_lhs, new_rhs, new_dnums, + dot->precision_config(), + dot->shape().element_type())); dot->SetupDerivedInstruction(new_dot); if (ShapeUtil::Compatible(dot->shape(), new_dot->shape())) { - TF_RETURN_IF_ERROR(ReplaceInstruction(dot, new_dot)); + RETURN_IF_ERROR(ReplaceInstruction(dot, new_dot)); } else { - TF_RETURN_IF_ERROR(ReplaceWithNewInstruction( + RETURN_IF_ERROR(ReplaceWithNewInstruction( dot, HloInstruction::CreateReshape(dot->shape(), new_dot))); } return true; @@ -2838,7 +2829,7 @@ AlgebraicSimplifierVisitor::RemoveTransposesFromDotOperands( ? SwapOperandsInDotPrecisionConfig(dot->precision_config()) : dot->precision_config())); dot->SetupDerivedInstruction(new_dot); - TF_RETURN_IF_ERROR(ReplaceWithNewInstruction( + RETURN_IF_ERROR(ReplaceWithNewInstruction( dot, HloInstruction::CreateTranspose(dot->shape(), new_dot, permutation))); return true; @@ -2877,7 +2868,7 @@ absl::StatusOr AlgebraicSimplifierVisitor::MoveDotParamToRhs( std::swap(precision_config.mutable_operand_precision()->at(0), precision_config.mutable_operand_precision()->at(1)); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( HloInstruction * new_inst, MakeDotHlo(dot->mutable_operand(1), dot->mutable_operand(0), dot_dims, precision_config, dot->shape().element_type())); @@ -2901,13 +2892,13 @@ absl::StatusOr AlgebraicSimplifierVisitor::MoveDotParamToRhs( for (int i = 0; i != lhs_non_contracting_batch; ++i) { permutation.push_back(num_batch_dims + i); } - TF_ASSIGN_OR_RETURN(HloInstruction * new_transpose, - MakeTransposeHlo(new_dot, permutation)); + ASSIGN_OR_RETURN(HloInstruction * new_transpose, + MakeTransposeHlo(new_dot, permutation)); SetupDerivedInstruction(dot, new_dot, /*preserve_user_fusion_attr=*/true); SetupDerivedInstruction(dot, new_transpose, /*preserve_user_fusion_attr=*/false); - TF_RETURN_IF_ERROR(ReplaceInstruction( - dot, new_transpose, /*preserve_frontend_attributes=*/false)); + RETURN_IF_ERROR(ReplaceInstruction(dot, new_transpose, + /*preserve_frontend_attributes=*/false)); return true; } @@ -2958,7 +2949,7 @@ absl::StatusOr AlgebraicSimplifierVisitor::OptimizeDotOfConcat( HloInstruction *lhs, *rhs; CHECK(Match(dot, m::Dot(m::Op(&lhs), m::Op(&rhs)))); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( HloInstruction * optimized_lhs_concat, OptimizeDotOfConcatHelper(dot, lhs, lhs_contracting_dim, rhs, rhs_contracting_dim, /*swapped=*/false)); @@ -3513,14 +3504,13 @@ AlgebraicSimplifierVisitor::AssociativeReorderDotOperator( return nullptr; } if (!reordered_dims.empty()) { - TF_ASSIGN_OR_RETURN(reordered, - MakeReverseHlo(reorder_to, reordered_dims)); + ASSIGN_OR_RETURN(reordered, MakeReverseHlo(reorder_to, reordered_dims)); } if (!unreordered_dims.empty()) { // Want to use a greater threshold if reordering means increasing the // number of Hlos threshold_multiplier = 2.0; - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( unreordered, MakeReverseHlo(reorder_from->mutable_operand(0), unreordered_dims)); } @@ -3572,8 +3562,8 @@ AlgebraicSimplifierVisitor::AssociativeReorderDotOperator( if (!make_hlo) { return nullptr; } - TF_ASSIGN_OR_RETURN(reordered, MakeSliceHlo(reorder_to, start_indices, - limit_indices, strides)); + ASSIGN_OR_RETURN(reordered, MakeSliceHlo(reorder_to, start_indices, + limit_indices, strides)); // Check if we still need a padding instruction, and create Hlo if so for (auto& dim : new_padding_config.dimensions()) { @@ -3581,7 +3571,7 @@ AlgebraicSimplifierVisitor::AssociativeReorderDotOperator( // Want to use a greater threshold if reordering means increasing // the number of Hlos threshold_multiplier = 2.0; - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( unreordered, MakePadHlo(reorder_from->mutable_operand(0), reorder_from->mutable_operand(1), new_padding_config)); @@ -3649,9 +3639,8 @@ AlgebraicSimplifierVisitor::AssociativeReorderDotOperator( MakeBroadcastHlo(reorder_from->mutable_operand(0), reorder_from->dimensions(), new_broadcast_shape); } - TF_ASSIGN_OR_RETURN( - reordered, - MakeReduceHlo(reorder_to, zero, reduce_dims, HloOpcode::kAdd)); + ASSIGN_OR_RETURN(reordered, MakeReduceHlo(reorder_to, zero, reduce_dims, + HloOpcode::kAdd)); } if (!make_hlo) { @@ -3671,9 +3660,9 @@ AlgebraicSimplifierVisitor::AssociativeReorderDotOperator( // Create Hlo for new dot HloInstruction* new_dot; - TF_ASSIGN_OR_RETURN(new_dot, MakeDotHlo(new_lhs, new_rhs, new_dnums, - dot->precision_config(), - dot->shape().element_type())); + ASSIGN_OR_RETURN(new_dot, MakeDotHlo(new_lhs, new_rhs, new_dnums, + dot->precision_config(), + dot->shape().element_type())); // Do cost analysis to determine whether we should reorder. Reverse uses // the ratio of the two shapes a heuristic, while the others use the @@ -3697,17 +3686,17 @@ absl::Status AlgebraicSimplifierVisitor::RewriteAsMultiplyDotWithZeroLhsContractingDim( HloInstruction* dot, HloInstruction* lhs, HloInstruction* rhs, const DotDimensionNumbers& dnums) { - TF_ASSIGN_OR_RETURN(HloInstruction * new_lhs, - NormalizeDotOperandToBatchMajorAndContractingMinor( - lhs, dnums.lhs_batch_dimensions(), - dnums.lhs_contracting_dimensions())); + ASSIGN_OR_RETURN(HloInstruction * new_lhs, + NormalizeDotOperandToBatchMajorAndContractingMinor( + lhs, dnums.lhs_batch_dimensions(), + dnums.lhs_contracting_dimensions())); if (!ShapeUtil::SameElementType(dot->shape(), new_lhs->shape())) { new_lhs = MakeConvertToHlo(new_lhs, dot->shape().element_type()); } - TF_ASSIGN_OR_RETURN(HloInstruction * new_rhs, - NormalizeDotOperandToBatchMajorAndContractingMinor( - rhs, dnums.rhs_batch_dimensions(), - dnums.rhs_contracting_dimensions())); + ASSIGN_OR_RETURN(HloInstruction * new_rhs, + NormalizeDotOperandToBatchMajorAndContractingMinor( + rhs, dnums.rhs_batch_dimensions(), + dnums.rhs_contracting_dimensions())); if (!ShapeUtil::SameElementType(dot->shape(), new_rhs->shape())) { new_rhs = MakeConvertToHlo(new_rhs, dot->shape().element_type()); } @@ -3973,11 +3962,10 @@ AlgebraicSimplifierVisitor::AssociativeReorderNestedDot(HloDotInstruction* dot, } // Get Shape for new_inner - TF_ASSIGN_OR_RETURN( - Shape new_inner_shape, - ShapeInference::InferDotOpShape(new_inner_lhs->shape(), - new_inner_rhs->shape(), new_inner_dnums, - new_inner_lhs->shape().element_type())); + ASSIGN_OR_RETURN(Shape new_inner_shape, + ShapeInference::InferDotOpShape( + new_inner_lhs->shape(), new_inner_rhs->shape(), + new_inner_dnums, new_inner_lhs->shape().element_type())); Shape new_outer_lhs_shape = outer_lhs_dot ? inner->operand(0)->shape() : new_inner_shape; @@ -3998,7 +3986,7 @@ AlgebraicSimplifierVisitor::AssociativeReorderNestedDot(HloDotInstruction* dot, if (old_flops / static_cast(new_flops) > options_.associative_reordering_threshold()) { // We can now make the Hlo for new_inner and new_outer - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( new_inner, MakeDotHlo(new_inner_lhs, new_inner_rhs, new_inner_dnums, dot->precision_config(), dot->shape().element_type())); @@ -4010,7 +3998,7 @@ AlgebraicSimplifierVisitor::AssociativeReorderNestedDot(HloDotInstruction* dot, new_outer_lhs = new_inner; new_outer_rhs = inner->mutable_operand(1); } - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( new_outer, MakeDotHlo(new_outer_lhs, new_outer_rhs, new_outer_dnums, dot->precision_config(), dot->shape().element_type())); @@ -4079,13 +4067,13 @@ AlgebraicSimplifierVisitor::AssociativeReorderNestedDot(HloDotInstruction* dot, if (add_transpose) { HloInstruction* transposed_new_outer; - TF_ASSIGN_OR_RETURN(transposed_new_outer, - MakeTransposeHlo(new_outer, permutation)); + ASSIGN_OR_RETURN(transposed_new_outer, + MakeTransposeHlo(new_outer, permutation)); VLOG(10) << "Reordering with associativity and transpose"; - TF_RETURN_IF_ERROR(ReplaceInstruction(dot, transposed_new_outer)); + RETURN_IF_ERROR(ReplaceInstruction(dot, transposed_new_outer)); } else { VLOG(10) << "Reordering with associativity"; - TF_RETURN_IF_ERROR(ReplaceInstruction(dot, new_outer)); + RETURN_IF_ERROR(ReplaceInstruction(dot, new_outer)); } return RewriteResult::kRewritten; } @@ -4095,18 +4083,18 @@ AlgebraicSimplifierVisitor::AssociativeReorderNestedDot(HloDotInstruction* dot, absl::Status AlgebraicSimplifierVisitor::RewriteBatchPlusContractingAsReduce( HloDotInstruction* dot, HloInstruction* lhs, HloInstruction* rhs, const DotDimensionNumbers& dnums) { - TF_ASSIGN_OR_RETURN(HloInstruction * new_lhs, - NormalizeDotOperandToBatchMajorAndContractingMinor( - lhs, dnums.lhs_batch_dimensions(), - dnums.lhs_contracting_dimensions())); + ASSIGN_OR_RETURN(HloInstruction * new_lhs, + NormalizeDotOperandToBatchMajorAndContractingMinor( + lhs, dnums.lhs_batch_dimensions(), + dnums.lhs_contracting_dimensions())); if (!ShapeUtil::SameElementType(dot->shape(), new_lhs->shape())) { new_lhs = MakeConvertToHlo(new_lhs, dot->shape().element_type()); } - TF_ASSIGN_OR_RETURN(HloInstruction * new_rhs, - NormalizeDotOperandToBatchMajorAndContractingMinor( - rhs, dnums.rhs_batch_dimensions(), - dnums.rhs_contracting_dimensions())); + ASSIGN_OR_RETURN(HloInstruction * new_rhs, + NormalizeDotOperandToBatchMajorAndContractingMinor( + rhs, dnums.rhs_batch_dimensions(), + dnums.rhs_contracting_dimensions())); if (!ShapeUtil::SameElementType(dot->shape(), new_rhs->shape())) { new_rhs = MakeConvertToHlo(new_rhs, dot->shape().element_type()); } @@ -4138,8 +4126,8 @@ absl::Status AlgebraicSimplifierVisitor::RewriteBatchPlusContractingAsReduce( new_lhs->shape(), new_rhs, rhs_broadcast_dims)); } - TF_ASSIGN_OR_RETURN(HloInstruction * new_dot, - MakeMultiplyForPrecisionAlgorithm(dot, new_lhs, new_rhs)); + ASSIGN_OR_RETURN(HloInstruction * new_dot, + MakeMultiplyForPrecisionAlgorithm(dot, new_lhs, new_rhs)); std::vector reduce_dims(dnums.lhs_contracting_dimensions_size()); PrimitiveType dot_type = @@ -4212,8 +4200,8 @@ absl::Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) { // Reorder nested dots with associativity using flops as a heuristic if (options_.use_associative_reordering()) { - TF_ASSIGN_OR_RETURN(RewriteResult result, - AssociativeReorderNestedDot(dot_cast, lhs, rhs)); + ASSIGN_OR_RETURN(RewriteResult result, + AssociativeReorderNestedDot(dot_cast, lhs, rhs)); if (result == RewriteResult::kRewritten || result == RewriteResult::kStopRewrites) { return absl::OkStatus(); @@ -4221,8 +4209,8 @@ absl::Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) { } if (options_.use_associative_reordering()) { - TF_ASSIGN_OR_RETURN(HloInstruction * dot_operator_reordered, - AssociativeReorderDotOperator(dot_cast)); + ASSIGN_OR_RETURN(HloInstruction * dot_operator_reordered, + AssociativeReorderDotOperator(dot_cast)); if (dot_operator_reordered) { VLOG(10) << "Reordering dot operand to its mirror"; return ReplaceInstruction(dot, dot_operator_reordered); @@ -4244,8 +4232,8 @@ absl::Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) { // Simplify dot(reshape(transpose(A)), Const) to: // dot(reshape(A), reshape(transpose(reshape(Const)))), so that the reshape // and transpose on the Const side can be constant folded. - TF_ASSIGN_OR_RETURN(HloInstruction * dot_of_reorder_optimized, - OptimizeDotOfReorderContractingDims(dot)); + ASSIGN_OR_RETURN(HloInstruction * dot_of_reorder_optimized, + OptimizeDotOfReorderContractingDims(dot)); if (dot_of_reorder_optimized) { VLOG(10) << " Replaced dot " << dot->ToString() << " with new dot operation: " @@ -4253,8 +4241,8 @@ absl::Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) { return ReplaceInstruction(dot, dot_of_reorder_optimized); } - TF_ASSIGN_OR_RETURN(HloInstruction * dot_of_concat_optimized, - OptimizeDotOfConcat(dot)); + ASSIGN_OR_RETURN(HloInstruction * dot_of_concat_optimized, + OptimizeDotOfConcat(dot)); if (dot_of_concat_optimized) { VLOG(10) << "Replaced dot(concat(...), constant) with add(dot(..., " "constant)...)"; @@ -4264,27 +4252,27 @@ absl::Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) { // Simplify dot(ConstA, Gather(Index, ConstB)) to: // Gather(Index, dot*(ConstA, ConstB)), where dot* is an appropriately // batched version of dot. - TF_ASSIGN_OR_RETURN(HloInstruction * dot_of_gather_optimized, - OptimizeDotOfGather(dot)); + ASSIGN_OR_RETURN(HloInstruction * dot_of_gather_optimized, + OptimizeDotOfGather(dot)); if (dot_of_gather_optimized) { VLOG(10) << "Replaced dot(constA, gather(i, constB)) with " "gather(i, dot*(constA, constB))"; return ReplaceInstruction(dot, dot_of_gather_optimized); } - TF_ASSIGN_OR_RETURN(bool removed_degenerate_dimensions, - RemoveDegenerateDimensionFromDot(dot_cast)); + ASSIGN_OR_RETURN(bool removed_degenerate_dimensions, + RemoveDegenerateDimensionFromDot(dot_cast)); if (removed_degenerate_dimensions) { return absl::OkStatus(); } - TF_ASSIGN_OR_RETURN(bool removed_transposes, - RemoveTransposesFromDotOperands(dot_cast)); + ASSIGN_OR_RETURN(bool removed_transposes, + RemoveTransposesFromDotOperands(dot_cast)); if (removed_transposes) { return absl::OkStatus(); } - TF_ASSIGN_OR_RETURN(bool moved_param_to_rhs, MoveDotParamToRhs(dot_cast)); + ASSIGN_OR_RETURN(bool moved_param_to_rhs, MoveDotParamToRhs(dot_cast)); if (moved_param_to_rhs) { return absl::OkStatus(); } @@ -4424,10 +4412,10 @@ absl::Status AlgebraicSimplifierVisitor::HandleGather(HloInstruction* gather) { if (ShapeUtil::IsEffectiveScalar(operand_shape)) { HloInstruction* new_operand = gather->mutable_operand(0); if (!operand_shape.dimensions().empty()) { - TF_ASSIGN_OR_RETURN(new_operand, - MakeReshapeHlo(ShapeUtil::MakeScalarShape( - operand_shape.element_type()), - new_operand)); + ASSIGN_OR_RETURN(new_operand, + MakeReshapeHlo(ShapeUtil::MakeScalarShape( + operand_shape.element_type()), + new_operand)); } HloInstruction* new_gather = MakeBroadcastHlo(new_operand, {}, gather->shape()); @@ -4659,10 +4647,10 @@ absl::StatusOr> MinMaxToClamp( const Literal& upper_bound = Cast(clamp_upper_bound)->literal(); - TF_ASSIGN_OR_RETURN(Literal lower_bound_literal_reshaped, - lower_bound.Reshape({})); - TF_ASSIGN_OR_RETURN(Literal upper_bound_literal_reshaped, - upper_bound.Reshape({})); + ASSIGN_OR_RETURN(Literal lower_bound_literal_reshaped, + lower_bound.Reshape({})); + ASSIGN_OR_RETURN(Literal upper_bound_literal_reshaped, + upper_bound.Reshape({})); std::unique_ptr lower_bound_instr = HloInstruction::CreateConstant(std::move(lower_bound_literal_reshaped)); std::unique_ptr upper_bound_instr = @@ -4677,8 +4665,7 @@ absl::StatusOr> MinMaxToClamp( ComparisonDirection::kLt); HloEvaluator evaluator; - TF_ASSIGN_OR_RETURN(auto result, - evaluator.Evaluate(cloned_instruction.get())); + ASSIGN_OR_RETURN(auto result, evaluator.Evaluate(cloned_instruction.get())); if (result.IsAll(true)) { return HloInstruction::CreateTernary(to_clamp->shape(), HloOpcode::kClamp, clamp_lower_bound_bcast, to_clamp, @@ -4743,7 +4730,7 @@ absl::Status AlgebraicSimplifierVisitor::HandleMaximum( // Note that we cannot simplify to max(x, y) here, as for the case that x and // y are NaN but with different sign, it will make a difference. if (Match(rhs, m::Maximum(m::Op(), m::Op().Is(lhs)))) { - TF_RETURN_IF_ERROR(maximum->ReplaceOperandWith(1, rhs->mutable_operand(0))); + RETURN_IF_ERROR(maximum->ReplaceOperandWith(1, rhs->mutable_operand(0))); MarkAsChanged(); return absl::OkStatus(); } @@ -4758,9 +4745,9 @@ absl::Status AlgebraicSimplifierVisitor::HandleMaximum( m::Op(&to_clamp), m::Broadcast(&clamp_upper_bound_bcast, m::ConstantEffectiveScalar()))))) { - TF_ASSIGN_OR_RETURN(auto clamp, - MinMaxToClamp(clamp_lower_bound_bcast, to_clamp, - clamp_upper_bound_bcast, simplifier_)); + ASSIGN_OR_RETURN(auto clamp, + MinMaxToClamp(clamp_lower_bound_bcast, to_clamp, + clamp_upper_bound_bcast, simplifier_)); if (clamp) { return ReplaceWithNewInstruction(maximum, std::move(clamp)); } @@ -4798,10 +4785,9 @@ absl::Status AlgebraicSimplifierVisitor::HandleMaximum( (check-sat) */ if (lhs->opcode() == rhs->opcode() && IsNondecreasingSublinear(lhs)) { - TF_ASSIGN_OR_RETURN( - auto new_maximum, - MakeBinaryHlo(HloOpcode::kMaximum, lhs->mutable_operand(0), - rhs->mutable_operand(0))); + ASSIGN_OR_RETURN(auto new_maximum, + MakeBinaryHlo(HloOpcode::kMaximum, lhs->mutable_operand(0), + rhs->mutable_operand(0))); VLOG(10) << "Sinking nondecreasing op through max"; return ReplaceWithNewInstruction( maximum, HloInstruction::CreateUnary(maximum->shape(), lhs->opcode(), @@ -4853,7 +4839,7 @@ absl::Status AlgebraicSimplifierVisitor::HandleMinimum( // Note that we cannot simplify to min(x, y) here, as for the case that x and // y are NaN but with different sign, it will make a difference. if (Match(rhs, m::Minimum(m::Op(), m::Op().Is(lhs)))) { - TF_RETURN_IF_ERROR(minimum->ReplaceOperandWith(1, rhs->mutable_operand(0))); + RETURN_IF_ERROR(minimum->ReplaceOperandWith(1, rhs->mutable_operand(0))); MarkAsChanged(); return absl::OkStatus(); } @@ -4868,9 +4854,9 @@ absl::Status AlgebraicSimplifierVisitor::HandleMinimum( m::Op(&to_clamp), m::Broadcast(&clamp_lower_bound_bcast, m::ConstantEffectiveScalar()))))) { - TF_ASSIGN_OR_RETURN(auto clamp, - MinMaxToClamp(clamp_lower_bound_bcast, to_clamp, - clamp_upper_bound_bcast, simplifier_)); + ASSIGN_OR_RETURN(auto clamp, + MinMaxToClamp(clamp_lower_bound_bcast, to_clamp, + clamp_upper_bound_bcast, simplifier_)); if (clamp) { return ReplaceWithNewInstruction(minimum, std::move(clamp)); } @@ -4996,8 +4982,8 @@ absl::Status AlgebraicSimplifierVisitor::HandleMultiply( HloInstruction *a, *b; if (Match(multiply, m::Multiply(m::Negate(m::Op(&a)), m::Negate(m::Op(&b))))) { - TF_RETURN_IF_ERROR(multiply->ReplaceOperandWith(0, a)); - TF_RETURN_IF_ERROR(multiply->ReplaceOperandWith(1, b)); + RETURN_IF_ERROR(multiply->ReplaceOperandWith(0, a)); + RETURN_IF_ERROR(multiply->ReplaceOperandWith(1, b)); MarkAsChanged(); return absl::OkStatus(); } @@ -5007,8 +4993,8 @@ absl::Status AlgebraicSimplifierVisitor::HandleMultiply( HloInstruction* abs_operand; if (lhs == rhs && Match(lhs, m::Abs(m::Op(&abs_operand))) && !ShapeUtil::ElementIsComplex(abs_operand->shape())) { - TF_RETURN_IF_ERROR(multiply->ReplaceOperandWith(0, abs_operand)); - TF_RETURN_IF_ERROR(multiply->ReplaceOperandWith(1, abs_operand)); + RETURN_IF_ERROR(multiply->ReplaceOperandWith(0, abs_operand)); + RETURN_IF_ERROR(multiply->ReplaceOperandWith(1, abs_operand)); MarkAsChanged(); return absl::OkStatus(); } @@ -5040,8 +5026,8 @@ absl::Status AlgebraicSimplifierVisitor::HandleMultiply( m::MultiplyAnyOrder( m::MultiplyAnyOrder(m::NonConstant(&a), m::Constant(&c1)), m::MultiplyAnyOrder(m::NonConstant(&b), m::Constant(&c2))))) { - TF_ASSIGN_OR_RETURN(auto* product_of_constants, - MakeBinaryHlo(HloOpcode::kMultiply, c1, c2)); + ASSIGN_OR_RETURN(auto* product_of_constants, + MakeBinaryHlo(HloOpcode::kMultiply, c1, c2)); if (ShapeUtil::IsScalar(product_of_constants->shape()) && !ShapeUtil::IsScalar(multiply->shape())) { product_of_constants = @@ -5065,8 +5051,8 @@ absl::Status AlgebraicSimplifierVisitor::HandleMultiply( m::MultiplyAnyOrder( m::MultiplyAnyOrder(m::NonConstant(&a), m::Constant(&c1)), m::Constant(&c2)))) { - TF_ASSIGN_OR_RETURN(auto* product_of_constants, - MakeBinaryHlo(HloOpcode::kMultiply, c1, c2)); + ASSIGN_OR_RETURN(auto* product_of_constants, + MakeBinaryHlo(HloOpcode::kMultiply, c1, c2)); if (ShapeUtil::IsScalar(product_of_constants->shape()) && !ShapeUtil::IsScalar(multiply->shape())) { product_of_constants = @@ -5126,8 +5112,8 @@ absl::Status AlgebraicSimplifierVisitor::HandleMultiply( m::Multiply( m::Multiply(m::Op(&a), m::Broadcast(m::ConstantScalar(&c1))), m::Broadcast(m::ConstantScalar(&c2))))) { - TF_ASSIGN_OR_RETURN(auto* product_of_constants, - MakeBinaryHlo(HloOpcode::kMultiply, c1, c2)); + ASSIGN_OR_RETURN(auto* product_of_constants, + MakeBinaryHlo(HloOpcode::kMultiply, c1, c2)); if (ShapeUtil::IsScalar(product_of_constants->shape()) && !ShapeUtil::IsScalar(multiply->shape())) { product_of_constants = @@ -5258,9 +5244,8 @@ absl::Status AlgebraicSimplifierVisitor::HandleLog(HloInstruction* log) { auto non_zero_b = log->mutable_operand(0)->AddInstruction(HloInstruction::CreateBinary( log->shape(), HloOpcode::kMultiply, new_log, b)); - TF_ASSIGN_OR_RETURN( - auto b_is_zero, - MakeCompareHlo(Comparison::Direction::kEq, b, MakeScalarLike(b, 0.0))); + ASSIGN_OR_RETURN(auto b_is_zero, MakeCompareHlo(Comparison::Direction::kEq, + b, MakeScalarLike(b, 0.0))); simplifier_->UpdateLayout(b_is_zero->mutable_shape()); return ReplaceWithNewInstruction( log, HloInstruction::CreateTernary(log->shape(), HloOpcode::kSelect, @@ -5364,7 +5349,7 @@ absl::Status AlgebraicSimplifierVisitor::HandleOptimizationBarrier( HloInstruction* new_operand = operand->AddInstruction(HloInstruction::CreateTuple(operands)); - TF_RETURN_IF_ERROR(barrier->ReplaceOperandWithDifferentShape(0, new_operand)); + RETURN_IF_ERROR(barrier->ReplaceOperandWithDifferentShape(0, new_operand)); *barrier->mutable_shape() = new_operand->shape(); for (auto use : barrier->users()) { CHECK_EQ(use->opcode(), HloOpcode::kGetTupleElement); @@ -5551,8 +5536,8 @@ absl::Status AlgebraicSimplifierVisitor::HandleBroadcast( } if (options_.enable_sink_broadcast()) { - TF_ASSIGN_OR_RETURN(bool sink_succeeded, - TryToSinkBroadcastAfterElementwiseOps(broadcast)); + ASSIGN_OR_RETURN(bool sink_succeeded, + TryToSinkBroadcastAfterElementwiseOps(broadcast)); if (sink_succeeded) { MarkAsChanged(); return absl::OkStatus(); @@ -5751,21 +5736,21 @@ absl::Status AlgebraicSimplifierVisitor::HandleCompare( HloInstruction* b; if (Match(lhs, m::Maximum(m::Op(&a), m::Op(&b)))) { if (rhs == a) { // Gt(Max(a,b), a) -> Gt(b,a) - TF_RETURN_IF_ERROR(compare->ReplaceOperandWith(0, b)); + RETURN_IF_ERROR(compare->ReplaceOperandWith(0, b)); MarkAsChanged(); return absl::OkStatus(); } else if (rhs == b) { // Gt(Max(a,b), b) -> Gt(a,b) - TF_RETURN_IF_ERROR(compare->ReplaceOperandWith(0, a)); + RETURN_IF_ERROR(compare->ReplaceOperandWith(0, a)); MarkAsChanged(); return absl::OkStatus(); } } else if (Match(rhs, m::Minimum(m::Op(&a), m::Op(&b)))) { if (lhs == a) { // Gt(a, Min(a,b)) -> Gt(a,b) - TF_RETURN_IF_ERROR(compare->ReplaceOperandWith(1, b)); + RETURN_IF_ERROR(compare->ReplaceOperandWith(1, b)); MarkAsChanged(); return absl::OkStatus(); } else if (lhs == b) { // Gt(b, Min(a,b)) -> Gt(b,a) - TF_RETURN_IF_ERROR(compare->ReplaceOperandWith(1, a)); + RETURN_IF_ERROR(compare->ReplaceOperandWith(1, a)); MarkAsChanged(); return absl::OkStatus(); } @@ -5811,8 +5796,8 @@ absl::Status AlgebraicSimplifierVisitor::HandleConvert( primitive_util::BitWidth(dest_type) <= primitive_util::BitWidth(src_type) && constant->user_count() == 1 && primitive_util::BitWidth(dest_type) >= 8) { - TF_ASSIGN_OR_RETURN(Literal dest_literal, - constant->literal().Convert(dest_type)); + ASSIGN_OR_RETURN(Literal dest_literal, + constant->literal().Convert(dest_type)); VLOG(10) << "Replacing convert(constant) with constant"; return ReplaceWithNewInstruction( convert, HloInstruction::CreateConstant(std::move(dest_literal))); @@ -6064,16 +6049,16 @@ absl::Status AlgebraicSimplifierVisitor::HandlePad(HloInstruction* pad) { } } - TF_ASSIGN_OR_RETURN(HloInstruction * nonzero_pad, - MakePadHlo(pad->mutable_operand(0), - pad->mutable_operand(1), nonzero_padding)); + ASSIGN_OR_RETURN(HloInstruction * nonzero_pad, + MakePadHlo(pad->mutable_operand(0), + pad->mutable_operand(1), nonzero_padding)); // MakePadHlo assumes that the return type matches the type of the operand, // but that's not required. Use the type from the original pad instruction. nonzero_pad->mutable_shape()->set_element_type(pad->shape().element_type()); // Copy the layout from the original pad instructions. The new pad and the // slice instruction should all have the same layout. - TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes( + RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes( pad->shape(), nonzero_pad->mutable_shape())); simplifier_->UpdateLayout(nonzero_pad->mutable_shape()); @@ -6098,10 +6083,10 @@ absl::Status AlgebraicSimplifierVisitor::HandlePad(HloInstruction* pad) { strides.push_back(1); } - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( HloInstruction * slice, MakeSliceHlo(nonzero_pad, start_indices, end_indices, strides)); - TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes( + RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes( pad->shape(), slice->mutable_shape())); simplifier_->UpdateLayout(slice->mutable_shape()); @@ -6256,7 +6241,7 @@ AlgebraicSimplifierVisitor::TryToSinkBroadcastAfterElementwiseOps( broadcast->AddInstruction(HloInstruction::CreateBroadcast( user->shape(), new_user, broadcast->dimensions())); VLOG(4) << " new broadcast: " << new_broadcast->ToString(); - TF_RETURN_IF_ERROR(user->ReplaceAllUsesWith(new_broadcast)); + RETURN_IF_ERROR(user->ReplaceAllUsesWith(new_broadcast)); changed = true; } return changed; @@ -7063,7 +7048,7 @@ absl::StatusOr AlgebraicSimplifierVisitor::TrySimplifyScalarSlice( VLOG(10) << "Folding scalar slice of concat into concat operand"; } else { VLOG(10) << "Folding scalar slice of concat into slice of concat operand"; - TF_RETURN_IF_ERROR(ReplaceWithNewInstruction( + RETURN_IF_ERROR(ReplaceWithNewInstruction( slice, HloInstruction::CreateSlice( slice->shape(), concat->mutable_operand(operand_num), {slice->slice_starts(0) - operand_start}, @@ -7123,7 +7108,7 @@ absl::StatusOr AlgebraicSimplifierVisitor::TryToReorderSliceAndReshape( new_slice_operand, new_slice_starts, new_slice_limits, new_slice_stides)); simplifier_->UpdateLayout(new_slice->mutable_shape()); - TF_RETURN_IF_ERROR(ReplaceWithNewInstruction( + RETURN_IF_ERROR(ReplaceWithNewInstruction( slice, HloInstruction::CreateReshape(slice->shape(), new_slice))); return true; } @@ -7165,7 +7150,7 @@ absl::StatusOr AlgebraicSimplifierVisitor::TryToReorderSliceAndReverse( HloInstruction::CreateSlice(slice->shape(), reverse_operand, new_starts, new_limits, new_strides)); simplifier_->UpdateLayout(new_slice->mutable_shape()); - TF_RETURN_IF_ERROR(ReplaceWithNewInstruction( + RETURN_IF_ERROR(ReplaceWithNewInstruction( slice, HloInstruction::CreateReverse(new_slice->shape(), new_slice, reverse->dimensions()))); // We do not delete the old reverse, since there might be another @@ -7208,7 +7193,7 @@ absl::StatusOr AlgebraicSimplifierVisitor::RemoveRedundantStride( } HloInstruction* slice_operand = slice->mutable_operand(0); - TF_RETURN_IF_ERROR(ReplaceWithNewInstruction( + RETURN_IF_ERROR(ReplaceWithNewInstruction( slice, HloInstruction::CreateSlice(slice->shape(), slice_operand, slice->slice_starts(), new_slice_limits, new_slice_strides))); @@ -7270,9 +7255,9 @@ absl::Status AlgebraicSimplifierVisitor::HandleSlice(HloInstruction* slice) { return absl::OkStatus(); } if (slice_inside_pad) { - TF_ASSIGN_OR_RETURN(HloInstruction * new_slice, - MakeSliceHlo(pad_operand, new_starts, new_limits, - slice->slice_strides())); + ASSIGN_OR_RETURN(HloInstruction * new_slice, + MakeSliceHlo(pad_operand, new_starts, new_limits, + slice->slice_strides())); *(new_slice->mutable_shape()) = slice->shape(); return ReplaceInstruction(slice, new_slice); } @@ -7315,7 +7300,7 @@ absl::Status AlgebraicSimplifierVisitor::HandleSlice(HloInstruction* slice) { slice->mutable_operand(0)->dimensions())); } - TF_ASSIGN_OR_RETURN(bool replaced, TrySimplifyScalarSlice(slice)); + ASSIGN_OR_RETURN(bool replaced, TrySimplifyScalarSlice(slice)); if (replaced) { return absl::OkStatus(); } @@ -7419,18 +7404,16 @@ absl::Status AlgebraicSimplifierVisitor::HandleSlice(HloInstruction* slice) { HloInstruction* new_lhs = lhs; HloInstruction* new_rhs = rhs; if (slice_lhs) { - TF_ASSIGN_OR_RETURN( - new_lhs, - MakeSliceHlo(lhs, lhs_start_indices, lhs_limit_indices, lhs_strides)); + ASSIGN_OR_RETURN(new_lhs, MakeSliceHlo(lhs, lhs_start_indices, + lhs_limit_indices, lhs_strides)); } if (slice_rhs) { - TF_ASSIGN_OR_RETURN( - new_rhs, - MakeSliceHlo(rhs, rhs_start_indices, rhs_limit_indices, rhs_strides)); + ASSIGN_OR_RETURN(new_rhs, MakeSliceHlo(rhs, rhs_start_indices, + rhs_limit_indices, rhs_strides)); } // Finally, create Hlo for the new dot and reorder - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( HloInstruction * new_dot, MakeDotHlo(new_lhs, new_rhs, dnums, dot->precision_config(), dot->shape().element_type())); @@ -7488,7 +7471,7 @@ absl::Status AlgebraicSimplifierVisitor::HandleSlice(HloInstruction* slice) { starts[concat_dim] + slice->shape().dimensions(concat_dim); HloInstruction* operand = concat->mutable_operand(*start_operand); if (*start_operand + 1 != *limit_operand) { - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( HloInstruction * new_concat, MakeConcatHlo( absl::MakeSpan(concat->operands()) @@ -7591,7 +7574,7 @@ absl::Status AlgebraicSimplifierVisitor::HandleSlice(HloInstruction* slice) { // Do not try to reorder slices and reshapes after layout assignment as it may // be invalid. if (!options_.is_layout_sensitive()) { - TF_ASSIGN_OR_RETURN(replaced, TryToReorderSliceAndReshape(slice)); + ASSIGN_OR_RETURN(replaced, TryToReorderSliceAndReshape(slice)); } if (replaced) { return absl::OkStatus(); @@ -7599,14 +7582,13 @@ absl::Status AlgebraicSimplifierVisitor::HandleSlice(HloInstruction* slice) { bool reversed = false; if (Match(slice, m::Slice(m::Reverse(m::Op())))) { - TF_ASSIGN_OR_RETURN(reversed, TryToReorderSliceAndReverse(slice)); + ASSIGN_OR_RETURN(reversed, TryToReorderSliceAndReverse(slice)); } if (reversed) { return absl::OkStatus(); } - TF_ASSIGN_OR_RETURN(bool removed_redundant_stride, - RemoveRedundantStride(slice)); + ASSIGN_OR_RETURN(bool removed_redundant_stride, RemoveRedundantStride(slice)); if (removed_redundant_stride) { VLOG(10) << "Removed redundant stride for slice op."; } @@ -7875,7 +7857,7 @@ absl::Status AlgebraicSimplifierVisitor::HandleDynamicSlice( // ds(ds(x,id),inner_id) -> ds(x, id + inner_id) if (operand->opcode() == HloOpcode::kDynamicSlice) { - TF_RETURN_IF_ERROR(dynamic_slice->ReplaceOperandWithDifferentShape( + RETURN_IF_ERROR(dynamic_slice->ReplaceOperandWithDifferentShape( 0, operand->mutable_operand(0))); for (int64_t i = 1; i < dynamic_slice->operand_count(); ++i) { HloInstruction* index = dynamic_slice->mutable_operand(i); @@ -7894,7 +7876,7 @@ absl::Status AlgebraicSimplifierVisitor::HandleDynamicSlice( HloInstruction* combined_index = operand->AddInstruction(HloInstruction::CreateBinary( index->shape(), HloOpcode::kAdd, index, inner_index)); - TF_RETURN_IF_ERROR(dynamic_slice->ReplaceOperandWith(i, combined_index)); + RETURN_IF_ERROR(dynamic_slice->ReplaceOperandWith(i, combined_index)); } MarkAsChanged(); } @@ -8054,7 +8036,7 @@ absl::Status AlgebraicSimplifierVisitor::HandleDynamicUpdateSlice( absl::MakeConstSpan(dynamic_update_slice->operands()).subspan(2), absl::MakeConstSpan(dus_update->operand(0)->operands()) .subspan(1)))) { - TF_RETURN_IF_ERROR(dynamic_update_slice->ReplaceOperandWithDifferentShape( + RETURN_IF_ERROR(dynamic_update_slice->ReplaceOperandWithDifferentShape( 1, dus_update->mutable_operand(1))); for (int64_t i = 2; i < dynamic_update_slice->operand_count(); ++i) { HloInstruction* index = dynamic_update_slice->mutable_operand(i); @@ -8074,7 +8056,7 @@ absl::Status AlgebraicSimplifierVisitor::HandleDynamicUpdateSlice( HloInstruction* combined_index = dus_update->AddInstruction(HloInstruction::CreateBinary( index->shape(), HloOpcode::kAdd, index, inner_index)); - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( dynamic_update_slice->ReplaceOperandWith(i, combined_index)); } MarkAsChanged(); @@ -8284,12 +8266,10 @@ AlgebraicSimplifierVisitor::ReorderReduceDotToDotReduce( } // Create Hlo for reducing a and b - TF_ASSIGN_OR_RETURN( - HloInstruction * reduce_a, - MakeReduceHlo(a, init_value, reduce_a_dims, function)); - TF_ASSIGN_OR_RETURN( - HloInstruction * reduce_b, - MakeReduceHlo(b, init_value, reduce_b_dims, function)); + ASSIGN_OR_RETURN(HloInstruction * reduce_a, + MakeReduceHlo(a, init_value, reduce_a_dims, function)); + ASSIGN_OR_RETURN(HloInstruction * reduce_b, + MakeReduceHlo(b, init_value, reduce_b_dims, function)); // Construct maps from reduce_a and reduce_b to a and b std::vector map_reduce_a_a(reduce_a->shape().dimensions().size(), @@ -8338,7 +8318,7 @@ AlgebraicSimplifierVisitor::ReorderReduceDotToDotReduce( } // Create Hlo for new dot - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( HloInstruction * new_dot, MakeDotHlo(reduce_a, reduce_b, new_dot_dnums, arg->precision_config(), reduce->shape().element_type())); @@ -8452,7 +8432,7 @@ absl::Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* hlo) { IsScalarConstantZero(init_value) && Match(reduce->to_apply()->root_instruction(), m::AddAnyOrder(m::Parameter(0), m::Parameter(1)))) { - TF_RETURN_IF_ERROR(reduce->ReplaceOperandWith(0, negate_arg)); + RETURN_IF_ERROR(reduce->ReplaceOperandWith(0, negate_arg)); auto users = reduce->users(); auto* negated_reduce = arg->AddInstruction(HloInstruction::CreateUnary( reduce->shape(), HloOpcode::kNegate, reduce)); @@ -8542,8 +8522,8 @@ absl::Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* hlo) { } } } - TF_ASSIGN_OR_RETURN(HloInstruction * new_transpose, - MakeTransposeHlo(new_reduce, new_transpose_dimensions)); + ASSIGN_OR_RETURN(HloInstruction * new_transpose, + MakeTransposeHlo(new_reduce, new_transpose_dimensions)); return ReplaceInstruction(reduce, new_transpose); } @@ -8629,12 +8609,12 @@ absl::Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* hlo) { new_reduce_dims.push_back(matching_dim_it->first); } - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( HloInstruction * new_reduce, MakeReduceHlo(arg->mutable_operand(0), init_value, new_reduce_dims, reduce->to_apply(), &reduce->metadata())); - TF_ASSIGN_OR_RETURN(HloInstruction * new_reshape, - MakeReshapeHlo(reduce->shape(), new_reduce)); + ASSIGN_OR_RETURN(HloInstruction * new_reshape, + MakeReshapeHlo(reduce->shape(), new_reduce)); return ReplaceInstruction(reduce, new_reshape); } } @@ -8713,7 +8693,7 @@ absl::Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* hlo) { reduce_dims.push_back(dim - removed_dims); } } - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( auto new_dot, MakeDotHlo(lhs, rhs, new_dnums, dot->precision_config(), /*preferred_element_type=*/dot->shape().element_type())); @@ -8721,7 +8701,7 @@ absl::Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* hlo) { if (reduce_dims.empty()) { return ReplaceInstruction(hlo, new_dot); } - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( auto new_reduce, MakeReduceHlo(new_dot, init_value, reduce_dims, HloOpcode::kAdd)); reduce->SetupDerivedInstruction(new_reduce); @@ -8881,10 +8861,10 @@ absl::Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* hlo) { // instructions. HloInstruction* multiplier = MakeScalarLike(arg->mutable_operand(0), common_dims_prod); - TF_ASSIGN_OR_RETURN(HloInstruction * multiplied_scalar, - MakeBinaryHlo(HloOpcode::kMultiply, - arg->mutable_operand(0), multiplier)); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN(HloInstruction * multiplied_scalar, + MakeBinaryHlo(HloOpcode::kMultiply, + arg->mutable_operand(0), multiplier)); + ASSIGN_OR_RETURN( HloInstruction * add, MakeBinaryHlo( HloOpcode::kAdd, @@ -9065,9 +9045,9 @@ absl::Status AlgebraicSimplifierVisitor::HandleReduceWindow( if (valid_pattern && reduction_dim != -1) { if (val_const->shape().element_type() != reduce_window->shape().element_type()) { - TF_ASSIGN_OR_RETURN(Literal dest_literal, - val_const->literal().Convert( - reduce_window->shape().element_type())); + ASSIGN_OR_RETURN(Literal dest_literal, + val_const->literal().Convert( + reduce_window->shape().element_type())); val_const = reduce_window->AddInstruction( HloInstruction::CreateConstant(std::move(dest_literal))); } @@ -9134,8 +9114,8 @@ absl::Status AlgebraicSimplifierVisitor::HandleReduceWindow( const HloInstruction* nested_root = function->root_instruction(); DimensionVector broadcast_dims(nested_root->shape().dimensions().size()); absl::c_iota(broadcast_dims, 0); - TF_ASSIGN_OR_RETURN( - auto new_op, MakeBinaryHlo(nested_root->opcode(), operand, + ASSIGN_OR_RETURN(auto new_op, + MakeBinaryHlo(nested_root->opcode(), operand, MakeBroadcastHlo(init_value, broadcast_dims, operand->shape()))); @@ -9149,8 +9129,7 @@ absl::Status AlgebraicSimplifierVisitor::HandleReduceWindow( padding_dim.set_edge_padding_high(window_dim.padding_high()); padding_dim.set_interior_padding(0); } - TF_ASSIGN_OR_RETURN(new_op, - MakePadHlo(new_op, init_value, padding_config)); + ASSIGN_OR_RETURN(new_op, MakePadHlo(new_op, init_value, padding_config)); } return ReplaceInstruction(reduce_window, new_op); @@ -9683,14 +9662,14 @@ absl::Status AlgebraicSimplifierVisitor::HandleTranspose( DotDimensionNumbers new_dnums = dnums; std::swap(*new_dnums.mutable_lhs_contracting_dimensions(), *new_dnums.mutable_rhs_contracting_dimensions()); - TF_RETURN_IF_ERROR(ReplaceWithNewInstruction( + RETURN_IF_ERROR(ReplaceWithNewInstruction( transpose, HloInstruction::CreateDot( transpose->shape(), /*lhs=*/rhs, /*rhs=*/lhs, new_dnums, SwapOperandsInDotPrecisionConfig(dot->precision_config())))); return true; }; - TF_ASSIGN_OR_RETURN(bool did_transpose_of_dot, do_transpose_of_dot()); + ASSIGN_OR_RETURN(bool did_transpose_of_dot, do_transpose_of_dot()); if (did_transpose_of_dot) { return absl::OkStatus(); } @@ -9700,7 +9679,7 @@ absl::Status AlgebraicSimplifierVisitor::HandleTranspose( if (options_.supports_non_canonical_dots() && Match(operand, m::Dot(&dot, m::Op(&lhs), m::Op(&rhs))) && dot->user_count() == 1) { - TF_ASSIGN_OR_RETURN(bool did_transform, [&]() -> absl::StatusOr { + ASSIGN_OR_RETURN(bool did_transform, [&]() -> absl::StatusOr { if (!consider_swapping_dot_operands(operand)) { return false; } @@ -9789,7 +9768,7 @@ absl::Status AlgebraicSimplifierVisitor::HandleTranspose( VLOG(2) << "trying depth-to-space transform"; HloInstruction* reshape_operand = operand->mutable_operand(0); HloInstruction* outer_reshape = transpose->users()[0]; - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( bool did_transform, ([&]() -> absl::StatusOr { if (operand->shape().dimensions().size() != reshape_operand->shape().dimensions().size() + 1) { @@ -9911,17 +9890,17 @@ absl::Status AlgebraicSimplifierVisitor::HandleTranspose( } strides.push_back(1); } - TF_ASSIGN_OR_RETURN(HloInstruction* const slice, - MakeSliceHlo(reshape_operand, start_indices, - end_indices, strides)); + ASSIGN_OR_RETURN(HloInstruction* const slice, + MakeSliceHlo(reshape_operand, start_indices, + end_indices, strides)); slices.push_back(slice); VLOG(2) << "slice " << i << " " << slice->ToString(); } - TF_ASSIGN_OR_RETURN(HloInstruction* const concat, - MakeConcatHlo(slices, transpose_dim)); + ASSIGN_OR_RETURN(HloInstruction* const concat, + MakeConcatHlo(slices, transpose_dim)); VLOG(2) << "concat " << concat->ToString(); - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( outer_reshape->ReplaceOperandWithDifferentShape(0, concat)); return true; @@ -9932,7 +9911,7 @@ absl::Status AlgebraicSimplifierVisitor::HandleTranspose( } } - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( SimplifyTransposeOfBroadcast(transpose, transpose->dimensions())); return absl::OkStatus(); @@ -9992,7 +9971,7 @@ absl::StatusOr AlgebraicSimplifierVisitor::FoldConvInputPad( auto new_conv = convolution->CloneWithNewOperands(convolution->shape(), {a, b}); new_conv->set_window(new_window); - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( ReplaceWithNewInstruction(convolution, std::move(new_conv))); return true; } @@ -10060,8 +10039,7 @@ absl::StatusOr AlgebraicSimplifierVisitor::FoldConvFilterPad( auto new_conv = convolution->CloneWithNewOperands( convolution->shape(), {lhs, rhs->mutable_operand(0)}); new_conv->set_window(new_window); - TF_RETURN_IF_ERROR( - ReplaceWithNewInstruction(convolution, std::move(new_conv))); + RETURN_IF_ERROR(ReplaceWithNewInstruction(convolution, std::move(new_conv))); return true; } @@ -10070,10 +10048,10 @@ absl::StatusOr AlgebraicSimplifierVisitor::SwapConvOperands( if (!options_.enable_conv_operand_swap() || options_.is_layout_sensitive()) { return false; } - TF_ASSIGN_OR_RETURN(bool changed, - SwapConvolutionOperandsIfBeneficial( - DynCast(convolution), - options_.conv_is_lowerable_callback())); + ASSIGN_OR_RETURN(bool changed, + SwapConvolutionOperandsIfBeneficial( + DynCast(convolution), + options_.conv_is_lowerable_callback())); if (changed) { MarkAsChanged(); } @@ -10151,7 +10129,7 @@ AlgebraicSimplifierVisitor::PromoteConvolutionToF32IfNotOnednnCompatible( to_conv->AddInstruction(HloInstruction::CreateConvert( ShapeUtil::ChangeElementType(to_conv->shape(), from_dtype), to_conv)); - TF_RETURN_IF_ERROR(ReplaceInstruction(*convolution, from_conv)); + RETURN_IF_ERROR(ReplaceInstruction(*convolution, from_conv)); *convolution = to_conv; return false; } @@ -10272,7 +10250,7 @@ absl::StatusOr AlgebraicSimplifierVisitor::SimplifyConvToDot( dot_output_shape, new_lhs, new_rhs, dot_dimension_numbers, convolution->precision_config())); - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( ReplaceInstruction(convolution, add_bitcast(convolution_shape, dot))); return true; } @@ -10352,14 +10330,14 @@ absl::StatusOr AlgebraicSimplifierVisitor::SimplifyConvToMultiply( // Update shapes of the operands, if necessary. if (!absl::c_is_sorted(input_permutation)) { - TF_ASSIGN_OR_RETURN(input, MakeTransposeHlo(input, input_permutation)); + ASSIGN_OR_RETURN(input, MakeTransposeHlo(input, input_permutation)); } if (!ShapeUtil::SameElementType(input_shape, convolution_shape)) { input = MakeConvertToHlo(input, convolution_shape.element_type()); } if (!absl::c_is_sorted(kernel_permutation)) { - TF_ASSIGN_OR_RETURN(kernel, MakeTransposeHlo(kernel, kernel_permutation)); + ASSIGN_OR_RETURN(kernel, MakeTransposeHlo(kernel, kernel_permutation)); } if (!ShapeUtil::SameElementType(kernel_shape, convolution_shape)) { kernel = MakeConvertToHlo(kernel, convolution_shape.element_type()); @@ -10371,27 +10349,27 @@ absl::StatusOr AlgebraicSimplifierVisitor::SimplifyConvToMultiply( input->shape(), kernel, [&](std::unique_ptr added) { return convolution->parent()->AddInstruction(std::move(added)); })); - TF_ASSIGN_OR_RETURN(HloInstruction * result, - MakeBinaryHlo(HloOpcode::kMultiply, input, kernel)); + ASSIGN_OR_RETURN(HloInstruction * result, + MakeBinaryHlo(HloOpcode::kMultiply, input, kernel)); if (!reduction_dimensions.empty()) { - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( HloInstruction * sum, MakeReduceHlo( result, MakeConvertToHlo(MakeR0ConstantHlo(convolution->parent(), 0), convolution_shape.element_type()), reduction_dimensions, HloOpcode::kAdd)); - TF_ASSIGN_OR_RETURN(result, MakeReshapeHlo(convolution_shape, sum)); + ASSIGN_OR_RETURN(result, MakeReshapeHlo(convolution_shape, sum)); } - TF_RETURN_IF_ERROR(ReplaceInstruction(convolution, result)); + RETURN_IF_ERROR(ReplaceInstruction(convolution, result)); return true; } absl::Status AlgebraicSimplifierVisitor::HandleConvolution( HloInstruction* convolution) { if (options_.enable_scalar_multiply_reduction()) { - TF_RETURN_IF_ERROR(ScalarMultiplyReduction(convolution)); + RETURN_IF_ERROR(ScalarMultiplyReduction(convolution)); } // Zero-sized input or filter. @@ -10401,19 +10379,19 @@ absl::Status AlgebraicSimplifierVisitor::HandleConvolution( } // Try to merge padding/dilation of the input with the convolution's window. - TF_ASSIGN_OR_RETURN(bool folded_input_pad, FoldConvInputPad(convolution)); + ASSIGN_OR_RETURN(bool folded_input_pad, FoldConvInputPad(convolution)); if (folded_input_pad) { return absl::OkStatus(); } // Try to merge dilation of the filter with the convolution's window. - TF_ASSIGN_OR_RETURN(bool folded_filter_pad, FoldConvFilterPad(convolution)); + ASSIGN_OR_RETURN(bool folded_filter_pad, FoldConvFilterPad(convolution)); if (folded_filter_pad) { return absl::OkStatus(); } // Try to swap convolution operands. - TF_ASSIGN_OR_RETURN(bool swapped, SwapConvOperands(convolution)); + ASSIGN_OR_RETURN(bool swapped, SwapConvOperands(convolution)); if (swapped) { return absl::OkStatus(); } @@ -10421,7 +10399,7 @@ absl::Status AlgebraicSimplifierVisitor::HandleConvolution( if (options_.enable_onednn_support()) { // Convert the data type back to F32 if we can't rewrite BF16 convolution to // oneDNN custom call. - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( bool can_rewrite_bf16_conv_to_onednn, PromoteConvolutionToF32IfNotOnednnCompatible(&convolution)); if (can_rewrite_bf16_conv_to_onednn) { @@ -10430,12 +10408,12 @@ absl::Status AlgebraicSimplifierVisitor::HandleConvolution( } // Try to replace the convolution with a kDot or a kMultiply instruction. - TF_ASSIGN_OR_RETURN(bool replaced_with_dot, SimplifyConvToDot(convolution)); + ASSIGN_OR_RETURN(bool replaced_with_dot, SimplifyConvToDot(convolution)); if (replaced_with_dot) { return absl::OkStatus(); } - TF_ASSIGN_OR_RETURN(bool replaced_with_multiply, - SimplifyConvToMultiply(convolution)); + ASSIGN_OR_RETURN(bool replaced_with_multiply, + SimplifyConvToMultiply(convolution)); if (replaced_with_multiply) { return absl::OkStatus(); } diff --git a/third_party/xla/xla/hlo/transforms/simplifiers/all_gather_pad_ds_simplifier.cc b/third_party/xla/xla/hlo/transforms/simplifiers/all_gather_pad_ds_simplifier.cc index 7c0908f6e6250b..cd68fcc1dd577c 100644 --- a/third_party/xla/xla/hlo/transforms/simplifiers/all_gather_pad_ds_simplifier.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/all_gather_pad_ds_simplifier.cc @@ -28,6 +28,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/comparison_util.h" #include "xla/hlo/ir/collective_op_group_mode.h" #include "xla/hlo/ir/hlo_casting_utils.h" @@ -967,7 +968,7 @@ absl::StatusOr AllGatherPadDsSimplifier::RunImpl( for (HloComputation* computation : module->MakeNonfusionComputations(execution_threads)) { AllGatherPadDsSimplifierVisitor visitor; - TF_RETURN_IF_ERROR(computation->Accept(&visitor)); + RETURN_IF_ERROR(computation->Accept(&visitor)); changed |= visitor.changed(); } return changed; diff --git a/third_party/xla/xla/hlo/transforms/simplifiers/all_gather_pad_ds_simplifier_test.cc b/third_party/xla/xla/hlo/transforms/simplifiers/all_gather_pad_ds_simplifier_test.cc index ef01cba23ce99c..f584b46d4642d2 100644 --- a/third_party/xla/xla/hlo/transforms/simplifiers/all_gather_pad_ds_simplifier_test.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/all_gather_pad_ds_simplifier_test.cc @@ -26,6 +26,7 @@ limitations under the License. #include "absl/log/log.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -56,8 +57,8 @@ class AllGatherPadDsSimplifierTest : public HloHardwareIndependentTestBase { /*replica_count=*/num_replicas, /*num_partitions=*/num_partitions); config.set_use_spmd_partitioning(num_partitions > 1); - TF_ASSIGN_OR_RETURN(auto module, - ParseAndReturnVerifiedModule(hlo_module, config)); + ASSIGN_OR_RETURN(auto module, + ParseAndReturnVerifiedModule(hlo_module, config)); auto changed = AllGatherPadDsSimplifier().Run(module.get(), {}); if (!changed.ok()) { return changed.status(); diff --git a/third_party/xla/xla/hlo/transforms/simplifiers/all_gather_permuted_ds_simplifier.cc b/third_party/xla/xla/hlo/transforms/simplifiers/all_gather_permuted_ds_simplifier.cc index 5463c1cfaf950b..0f34f43cefa439 100644 --- a/third_party/xla/xla/hlo/transforms/simplifiers/all_gather_permuted_ds_simplifier.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/all_gather_permuted_ds_simplifier.cc @@ -24,6 +24,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" @@ -95,7 +96,7 @@ absl::StatusOr AllGatherDynamicSlicePermutedOffsetSimplifier::RunImpl( for (HloComputation* computation : module->MakeNonfusionComputations(execution_threads)) { AllGatherDynamicSlicePermutedOffsetSimplifierVisitor visitor; - TF_RETURN_IF_ERROR(computation->Accept(&visitor)); + RETURN_IF_ERROR(computation->Accept(&visitor)); changed |= visitor.changed(); } return changed; diff --git a/third_party/xla/xla/hlo/transforms/simplifiers/all_gather_permuted_ds_simplifier_test.cc b/third_party/xla/xla/hlo/transforms/simplifiers/all_gather_permuted_ds_simplifier_test.cc index a566378e5334da..f1f4010b0103d6 100644 --- a/third_party/xla/xla/hlo/transforms/simplifiers/all_gather_permuted_ds_simplifier_test.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/all_gather_permuted_ds_simplifier_test.cc @@ -23,6 +23,7 @@ limitations under the License. #include #include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" @@ -49,8 +50,8 @@ class AllGatherPermutedDsSimplifierTest HloModuleConfig config = GetModuleConfigForTest(num_replicas, num_partitions); config.set_use_spmd_partitioning(num_partitions > 1); - TF_ASSIGN_OR_RETURN(std::unique_ptr module, - ParseAndReturnVerifiedModule(hlo_module, config)); + ASSIGN_OR_RETURN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_module, config)); absl::StatusOr changed = AllGatherDynamicSlicePermutedOffsetSimplifier().Run(module.get(), {}); if (!changed.ok()) { diff --git a/third_party/xla/xla/hlo/transforms/simplifiers/all_reduce_folder.cc b/third_party/xla/xla/hlo/transforms/simplifiers/all_reduce_folder.cc index ea8f5733f37701..7dfd51e992a733 100644 --- a/third_party/xla/xla/hlo/transforms/simplifiers/all_reduce_folder.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/all_reduce_folder.cc @@ -27,6 +27,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" @@ -215,9 +216,9 @@ absl::StatusOr AllReduceFolder::RunImpl( std::make_shared(*new_replica_groups), /*constrain_layout=*/false, channel_id, ar0->use_global_device_ids())); - TF_RETURN_IF_ERROR(ar1->ReplaceAllUsesWith(new_ar)); - TF_RETURN_IF_ERROR(computation->RemoveInstruction(ar1)); - TF_RETURN_IF_ERROR(computation->RemoveInstruction(ar0)); + RETURN_IF_ERROR(ar1->ReplaceAllUsesWith(new_ar)); + RETURN_IF_ERROR(computation->RemoveInstruction(ar1)); + RETURN_IF_ERROR(computation->RemoveInstruction(ar0)); changed = true; } } diff --git a/third_party/xla/xla/hlo/transforms/simplifiers/ar_crs_combiner.cc b/third_party/xla/xla/hlo/transforms/simplifiers/ar_crs_combiner.cc index 5a1945abeecda0..f0d77e103e19af 100644 --- a/third_party/xla/xla/hlo/transforms/simplifiers/ar_crs_combiner.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/ar_crs_combiner.cc @@ -30,6 +30,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/analysis/hlo_replication_analysis.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" @@ -58,7 +59,7 @@ namespace { // performance. absl::StatusOr ReplaceReplicatedAllReduce(HloModule* module, int64_t partition_count) { - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( auto replication_analysis, HloReplicationAnalysis::Run(module, /*cross_partition_spmd=*/true)); @@ -92,7 +93,7 @@ absl::StatusOr ReplaceReplicatedAllReduce(HloModule* module, HloInstruction::CreateBroadcast(shape, divisor, {})); auto div = computation->AddInstruction(HloInstruction::CreateBinary( ar->shape(), HloOpcode::kDivide, ar, bcast)); - TF_RETURN_IF_ERROR(ar->ReplaceAllUsesWith(div)); + RETURN_IF_ERROR(ar->ReplaceAllUsesWith(div)); changed = true; } } @@ -537,7 +538,7 @@ absl::Status ArCrsCombiner::KeepProvablyEqualInstructionGroupsSPMD( HloModule* module) { // For SPMD mode, use HloReplicationAnalysis to figure out HLO value // equivalence across partitions. - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( auto replication_analysis, HloReplicationAnalysis::Run(module, /*cross_partition_spmd=*/true)); @@ -645,16 +646,16 @@ absl::StatusOr ArCrsCombiner::RunImpl( GroupAllReducesById(module); if (spmd_partition_) { - TF_RETURN_IF_ERROR(KeepProvablyEqualInstructionGroupsSPMD(module)); + RETURN_IF_ERROR(KeepProvablyEqualInstructionGroupsSPMD(module)); } else { - TF_RETURN_IF_ERROR(KeepProvablyEqualInstructionGroupsMPMD()); + RETURN_IF_ERROR(KeepProvablyEqualInstructionGroupsMPMD()); } - TF_ASSIGN_OR_RETURN(auto changed, RewriteGraph()); + ASSIGN_OR_RETURN(auto changed, RewriteGraph()); if (module->config().replica_count() > 1 && spmd_partition_) { - TF_ASSIGN_OR_RETURN(auto replaced, ReplaceReplicatedAllReduce( - module, num_spatial_partitions_)); + ASSIGN_OR_RETURN(auto replaced, ReplaceReplicatedAllReduce( + module, num_spatial_partitions_)); changed |= replaced; } diff --git a/third_party/xla/xla/hlo/transforms/simplifiers/batch_dot_simplification.cc b/third_party/xla/xla/hlo/transforms/simplifiers/batch_dot_simplification.cc index 6aed9f6e007067..f450d3cc5f7463 100644 --- a/third_party/xla/xla/hlo/transforms/simplifiers/batch_dot_simplification.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/batch_dot_simplification.cc @@ -25,6 +25,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -80,10 +81,10 @@ BatchDotSimplification::ElideDegenerateBatchDimensionFromBatchDot( return false; } - TF_ASSIGN_OR_RETURN(HloInstruction * new_lhs, - ElideDegenerateDims(lhs, degenerate_dims)); - TF_ASSIGN_OR_RETURN(HloInstruction * new_rhs, - ElideDegenerateDims(rhs, degenerate_dims)); + ASSIGN_OR_RETURN(HloInstruction * new_lhs, + ElideDegenerateDims(lhs, degenerate_dims)); + ASSIGN_OR_RETURN(HloInstruction * new_rhs, + ElideDegenerateDims(rhs, degenerate_dims)); DotDimensionNumbers new_dim_numbers = dim_numbers; new_dim_numbers.clear_lhs_batch_dimensions(); @@ -103,19 +104,19 @@ BatchDotSimplification::ElideDegenerateBatchDimensionFromBatchDot( 0, new_dim_numbers.rhs_contracting_dimensions(0) - degenerate_dims.size()); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( HloInstruction * new_dot, MakeDotHlo(new_lhs, new_rhs, new_dim_numbers, batch_dot->precision_config(), /*preferred_element_type=*/batch_dot->shape().element_type())); - TF_ASSIGN_OR_RETURN(HloInstruction * new_dot_reshaped, - MakeReshapeHlo(batch_dot->shape(), new_dot)); + ASSIGN_OR_RETURN(HloInstruction * new_dot_reshaped, + MakeReshapeHlo(batch_dot->shape(), new_dot)); VLOG(2) << "Replaced " << batch_dot->ToString() << " with " << new_dot->ToString(); - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( batch_dot->parent()->ReplaceInstruction(batch_dot, new_dot_reshaped)); return true; @@ -134,8 +135,8 @@ absl::StatusOr BatchDotSimplification::RunImpl( }); } for (HloInstruction* dot_instr : dot_instrs) { - TF_ASSIGN_OR_RETURN(bool elided_batch_dim_from_one, - ElideDegenerateBatchDimensionFromBatchDot(dot_instr)); + ASSIGN_OR_RETURN(bool elided_batch_dim_from_one, + ElideDegenerateBatchDimensionFromBatchDot(dot_instr)); changed |= elided_batch_dim_from_one; } return changed; diff --git a/third_party/xla/xla/hlo/transforms/simplifiers/bfloat16_conversion_folding.cc b/third_party/xla/xla/hlo/transforms/simplifiers/bfloat16_conversion_folding.cc index 427618cd82c3cf..5f159006ba9cce 100644 --- a/third_party/xla/xla/hlo/transforms/simplifiers/bfloat16_conversion_folding.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/bfloat16_conversion_folding.cc @@ -23,6 +23,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/analysis/alias_info.h" #include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" #include "xla/hlo/ir/hlo_computation.h" @@ -92,7 +93,7 @@ absl::Status BFloat16ConversionFoldingVisitor::FoldOutputConversions( bfloat16_conversion_folding_->UpdateLayout(hlo->mutable_shape()); for (auto user : materialized_users) { CHECK_EQ(user->opcode(), HloOpcode::kConvert); - TF_RETURN_IF_ERROR(user->ReplaceAllUsesWith(hlo)); + RETURN_IF_ERROR(user->ReplaceAllUsesWith(hlo)); changed_ = true; } return absl::OkStatus(); @@ -103,7 +104,7 @@ absl::Status BFloat16ConversionFoldingVisitor::FoldOperandConversion( // The operand is a convert from BF16 to F32. auto operand = hlo->mutable_operand(operand_index); CHECK_EQ(operand->opcode(), HloOpcode::kConvert); - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( hlo->ReplaceOperandWith(operand_index, operand->mutable_operand(0))); changed_ = true; return absl::OkStatus(); @@ -162,11 +163,11 @@ absl::Status BFloat16ConversionFoldingVisitor::TryFoldBF16Conversions( } if (fold_output_conversion) { - TF_RETURN_IF_ERROR(FoldOutputConversions(hlo)); + RETURN_IF_ERROR(FoldOutputConversions(hlo)); } for (int64_t i : bf16_to_f32_operands) { - TF_RETURN_IF_ERROR(FoldOperandConversion(hlo, i)); + RETURN_IF_ERROR(FoldOperandConversion(hlo, i)); } return absl::OkStatus(); } @@ -213,7 +214,7 @@ absl::Status BFloat16ConversionFoldingVisitor::HandleAllReduce( } // First use DefaultAction() to handle the operands. It can't handle // tuple-shaped output. - TF_RETURN_IF_ERROR(DefaultAction(crs)); + RETURN_IF_ERROR(DefaultAction(crs)); if (!bfloat16_support_->SupportsMixedPrecisions(*crs)) { return absl::OkStatus(); @@ -265,7 +266,7 @@ absl::Status BFloat16ConversionFoldingVisitor::HandleAllReduce( bfloat16_conversion_folding_->UpdateLayout( ShapeUtil::GetMutableSubshape(crs->mutable_shape(), {i})); for (auto gte : per_tuple_element_gtes[i]) { - TF_RETURN_IF_ERROR(FoldOutputConversions(gte)); + RETURN_IF_ERROR(FoldOutputConversions(gte)); } } diff --git a/third_party/xla/xla/hlo/transforms/simplifiers/call_parameter_cleanup.cc b/third_party/xla/xla/hlo/transforms/simplifiers/call_parameter_cleanup.cc index 8b6a935799ec60..6d6206de894f34 100644 --- a/third_party/xla/xla/hlo/transforms/simplifiers/call_parameter_cleanup.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/call_parameter_cleanup.cc @@ -26,6 +26,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/ir/hlo_clone_context.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -213,9 +214,9 @@ absl::StatusOr RemoveDeadParameters(HloComputation* computation) { // The new call computation is ready, now make all the call sites use it. for (HloInstruction* old_call : computation->caller_instructions()) { - TF_RETURN_IF_ERROR(ReplaceCallSite(old_call, new_computation, - old_to_new_parameter_number, - old_to_new_output_number, adjust_root)); + RETURN_IF_ERROR(ReplaceCallSite(old_call, new_computation, + old_to_new_parameter_number, + old_to_new_output_number, adjust_root)); } return true; @@ -252,7 +253,7 @@ absl::StatusOr CallParameterCleanup::RunImpl( bool changed = false; for (HloComputation* computation : computations_to_process) { - TF_ASSIGN_OR_RETURN(bool removed, RemoveDeadParameters(computation)); + ASSIGN_OR_RETURN(bool removed, RemoveDeadParameters(computation)); changed |= removed; } return changed; diff --git a/third_party/xla/xla/hlo/transforms/simplifiers/computation_canonicalizers.cc b/third_party/xla/xla/hlo/transforms/simplifiers/computation_canonicalizers.cc index e75d6c5300c2b6..0deb6563ce8d24 100644 --- a/third_party/xla/xla/hlo/transforms/simplifiers/computation_canonicalizers.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/computation_canonicalizers.cc @@ -17,6 +17,7 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "absl/status/statusor.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" @@ -85,10 +86,10 @@ absl::StatusOr MoveParametersAndConstantsToFront( // we forward control predecessors to all users. for (HloInstruction* control_predecessor : inst->control_predecessors()) { for (HloInstruction* user : inst->users()) { - TF_RETURN_IF_ERROR(control_predecessor->AddControlDependencyTo(user)); + RETURN_IF_ERROR(control_predecessor->AddControlDependencyTo(user)); } } - TF_RETURN_IF_ERROR(inst->DropAllControlDeps()); + RETURN_IF_ERROR(inst->DropAllControlDeps()); } } diff --git a/third_party/xla/xla/hlo/transforms/simplifiers/conditional_canonicalizer.cc b/third_party/xla/xla/hlo/transforms/simplifiers/conditional_canonicalizer.cc index 76fceddba9efd2..6d384ac3589fcd 100644 --- a/third_party/xla/xla/hlo/transforms/simplifiers/conditional_canonicalizer.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/conditional_canonicalizer.cc @@ -21,6 +21,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/shape.h" @@ -74,7 +75,7 @@ absl::StatusOr CanonicalizeNonTupleConditional( 0, HloInstruction::CreateParameter(0, shape, param->name())); HloInstruction* const gte = branch->AddInstruction( HloInstruction::CreateGetTupleElement(new_param, 0)); - TF_RETURN_IF_ERROR(new_param->ReplaceAllUsesWithDifferentShape(gte)); + RETURN_IF_ERROR(new_param->ReplaceAllUsesWithDifferentShape(gte)); changed = true; } @@ -96,8 +97,7 @@ absl::StatusOr CanonicalizeNonTupleConditional( if (!operand->shape().IsTuple()) { auto tuple = parent->AddInstruction(HloInstruction::CreateTuple({operand})); - TF_RETURN_IF_ERROR( - conditional->ReplaceOperandWithDifferentShape(i, tuple)); + RETURN_IF_ERROR(conditional->ReplaceOperandWithDifferentShape(i, tuple)); changed = true; } } @@ -110,7 +110,7 @@ absl::StatusOr CanonicalizeNonTupleConditional( parent->AddInstruction(conditional->CloneWithNewShape(new_shape)); auto gte = parent->AddInstruction( HloInstruction::CreateGetTupleElement(root_shape, new_conditional, 0)); - TF_RETURN_IF_ERROR(parent->ReplaceInstruction(conditional, gte)); + RETURN_IF_ERROR(parent->ReplaceInstruction(conditional, gte)); changed = true; } @@ -129,7 +129,7 @@ absl::StatusOr ConditionalCanonicalizer::RunImpl( for (auto* inst : comp->MakeInstructionPostOrder()) { if (inst->opcode() == HloOpcode::kConditional) { bool result; - TF_ASSIGN_OR_RETURN(result, CanonicalizeNonTupleConditional(inst)); + ASSIGN_OR_RETURN(result, CanonicalizeNonTupleConditional(inst)); changed |= result; } } diff --git a/third_party/xla/xla/hlo/transforms/simplifiers/conv_operand_swapper.cc b/third_party/xla/xla/hlo/transforms/simplifiers/conv_operand_swapper.cc index 7458dc9bcab732..3f607cd7277953 100644 --- a/third_party/xla/xla/hlo/transforms/simplifiers/conv_operand_swapper.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/conv_operand_swapper.cc @@ -20,6 +20,7 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" @@ -139,10 +140,10 @@ absl::StatusOr SwapConvolutionOperandsIfBeneficial( convolution->precision_config().operand_precision(0)); if (!reverse_dimensions.empty()) { - TF_ASSIGN_OR_RETURN(kernel, MakeReverseHlo(kernel, reverse_dimensions)); + ASSIGN_OR_RETURN(kernel, MakeReverseHlo(kernel, reverse_dimensions)); } - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( HloInstruction * new_convolution, MakeConvolveHlo( kernel, input, /*feature_group_count=*/1, @@ -153,7 +154,7 @@ absl::StatusOr SwapConvolutionOperandsIfBeneficial( if (conv_is_lowerable_callback && !conv_is_lowerable_callback(new_convolution)) { - TF_RETURN_IF_ERROR(kernel->parent()->RemoveInstruction(new_convolution)); + RETURN_IF_ERROR(kernel->parent()->RemoveInstruction(new_convolution)); return false; } @@ -171,9 +172,9 @@ absl::StatusOr ConvOperandSwapper::RunImpl( for (auto comp : module->computations(execution_threads)) { for (HloInstruction* hlo : comp->MakeInstructionPostOrder()) { if (auto* convolution = DynCast(hlo)) { - TF_ASSIGN_OR_RETURN(bool convolution_changed, - SwapConvolutionOperandsIfBeneficial( - convolution, conv_is_lowerable_callback_)); + ASSIGN_OR_RETURN(bool convolution_changed, + SwapConvolutionOperandsIfBeneficial( + convolution, conv_is_lowerable_callback_)); changed |= convolution_changed; } } diff --git a/third_party/xla/xla/hlo/transforms/simplifiers/conv_operand_swapper_test.cc b/third_party/xla/xla/hlo/transforms/simplifiers/conv_operand_swapper_test.cc index c832502ef83e45..915fa25f9fae68 100644 --- a/third_party/xla/xla/hlo/transforms/simplifiers/conv_operand_swapper_test.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/conv_operand_swapper_test.cc @@ -23,6 +23,7 @@ limitations under the License. #include #include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" @@ -41,9 +42,9 @@ class ConvOperandSwapperTest : public HloHardwareIndependentTestBase { public: absl::StatusOr> RunPass( absl::string_view hlo_module, int64_t distance_threshold = 100) { - TF_ASSIGN_OR_RETURN(auto module, ParseAndReturnVerifiedModule( - hlo_module, GetModuleConfigForTest())); - TF_RETURN_IF_ERROR(ConvOperandSwapper().Run(module.get()).status()); + ASSIGN_OR_RETURN(auto module, ParseAndReturnVerifiedModule( + hlo_module, GetModuleConfigForTest())); + RETURN_IF_ERROR(ConvOperandSwapper().Run(module.get()).status()); return absl::StatusOr>(std::move(module)); } }; diff --git a/third_party/xla/xla/hlo/transforms/simplifiers/convert_mover.cc b/third_party/xla/xla/hlo/transforms/simplifiers/convert_mover.cc index 3f60e885eeb501..e4fe6d61bfe7c5 100644 --- a/third_party/xla/xla/hlo/transforms/simplifiers/convert_mover.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/convert_mover.cc @@ -24,6 +24,7 @@ limitations under the License. #include "absl/log/log.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/literal.h" @@ -159,7 +160,7 @@ absl::StatusOr MoveConvertPrecisionOps(HloComputation* comp) { new_shape.set_element_type(src_ty); HloInstruction* new_instr = comp->AddInstruction( instr->CloneWithNewOperands(new_shape, new_operands)); - TF_RETURN_IF_ERROR(comp->ReplaceWithNewInstruction( + RETURN_IF_ERROR(comp->ReplaceWithNewInstruction( instr, HloInstruction::CreateConvert(instr->shape(), new_instr))); changed = true; } @@ -201,7 +202,7 @@ absl::StatusOr MoveConvertPrecisionOps(HloComputation* comp) { } Shape new_shape = to_convert->shape(); new_shape.set_element_type(dst_ty); - TF_RETURN_IF_ERROR(comp->ReplaceWithNewInstruction( + RETURN_IF_ERROR(comp->ReplaceWithNewInstruction( instr, to_convert->CloneWithNewOperands(new_shape, new_operands))); changed = true; } @@ -217,8 +218,7 @@ absl::StatusOr ConvertMover::RunImpl( bool changed = false; for (HloComputation* comp : module->MakeNonfusionComputations(execution_threads)) { - TF_ASSIGN_OR_RETURN(bool changed_computation, - MoveConvertPrecisionOps(comp)); + ASSIGN_OR_RETURN(bool changed_computation, MoveConvertPrecisionOps(comp)); changed |= changed_computation; } return changed; diff --git a/third_party/xla/xla/hlo/transforms/simplifiers/convert_operand_folder.cc b/third_party/xla/xla/hlo/transforms/simplifiers/convert_operand_folder.cc index bee24f707f3b5c..057edf090ca45b 100644 --- a/third_party/xla/xla/hlo/transforms/simplifiers/convert_operand_folder.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/convert_operand_folder.cc @@ -16,6 +16,7 @@ limitations under the License. #include "xla/hlo/transforms/simplifiers/convert_operand_folder.h" #include "absl/status/statusor.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/primitive_util.h" @@ -104,7 +105,7 @@ absl::StatusOr ConvertOperandFolding::ExpandInstruction( for (int i = 0; i < instruction->operand_count(); ++i) { auto* operand = instruction->mutable_operand(i); if (IsUpcastConvert(operand)) { - TF_RETURN_IF_ERROR(instruction->ReplaceOperandWithDifferentShape( + RETURN_IF_ERROR(instruction->ReplaceOperandWithDifferentShape( i, EffectiveOperand(operand))); } } diff --git a/third_party/xla/xla/hlo/transforms/simplifiers/dead_dynamic_update_slice_elimination.cc b/third_party/xla/xla/hlo/transforms/simplifiers/dead_dynamic_update_slice_elimination.cc index ef7058daa766f3..f270344f9f6392 100644 --- a/third_party/xla/xla/hlo/transforms/simplifiers/dead_dynamic_update_slice_elimination.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/dead_dynamic_update_slice_elimination.cc @@ -24,6 +24,7 @@ limitations under the License. #include "absl/log/log.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -137,8 +138,8 @@ absl::StatusOr ProcessDynamicUpdateSlice(HloInstruction* dus, }); VLOG(2) << " is_dus_update_unused: " << is_dus_update_unused; if (is_dus_update_unused) { - TF_RETURN_IF_ERROR(dus->ReplaceAllUsesWith(dus->mutable_operand(0))); - TF_RETURN_IF_ERROR(comp->RemoveInstruction(dus)); + RETURN_IF_ERROR(dus->ReplaceAllUsesWith(dus->mutable_operand(0))); + RETURN_IF_ERROR(comp->RemoveInstruction(dus)); return true; // Changed } return false; // Not changed @@ -163,8 +164,8 @@ absl::StatusOr DeadDynamicUpdateSliceElimination::RunImpl( continue; } VLOG(2) << "Processing DUS: " << instruction->ToString(); - TF_ASSIGN_OR_RETURN(bool dus_changed, - ProcessDynamicUpdateSlice(instruction, computation)); + ASSIGN_OR_RETURN(bool dus_changed, + ProcessDynamicUpdateSlice(instruction, computation)); if (dus_changed) { changed = true; } diff --git a/third_party/xla/xla/hlo/transforms/simplifiers/dot_dimension_merger.cc b/third_party/xla/xla/hlo/transforms/simplifiers/dot_dimension_merger.cc index c2f6d31379d679..d0a3f312f22e94 100644 --- a/third_party/xla/xla/hlo/transforms/simplifiers/dot_dimension_merger.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/dot_dimension_merger.cc @@ -27,6 +27,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" @@ -129,11 +130,11 @@ class BatchDimensionMerger : public DfsHloRewriteVisitor { shifted_contracting_dimensions.end()); } - TF_ASSIGN_OR_RETURN(HloInstruction * reshaped_lhs, - MakeReshapeHlo(new_lhs_shape, dot->mutable_operand(0))); + ASSIGN_OR_RETURN(HloInstruction * reshaped_lhs, + MakeReshapeHlo(new_lhs_shape, dot->mutable_operand(0))); - TF_ASSIGN_OR_RETURN(HloInstruction * reshaped_rhs, - MakeReshapeHlo(new_rhs_shape, dot->mutable_operand(1))); + ASSIGN_OR_RETURN(HloInstruction * reshaped_rhs, + MakeReshapeHlo(new_rhs_shape, dot->mutable_operand(1))); Shape new_dot_shape = merge_batch_dims(dot->shape(), /*batch_dim=*/0); HloInstruction* new_dot = dot->parent()->AddInstruction( diff --git a/third_party/xla/xla/hlo/transforms/simplifiers/dot_merger.cc b/third_party/xla/xla/hlo/transforms/simplifiers/dot_merger.cc index 24924659f1e446..45a02e3e692313 100644 --- a/third_party/xla/xla/hlo/transforms/simplifiers/dot_merger.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/dot_merger.cc @@ -31,6 +31,7 @@ limitations under the License. #include "absl/log/log.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" @@ -181,10 +182,9 @@ absl::StatusOr TryMergeSameOperand(HloInstruction* a, ++outer_dim; } - TF_ASSIGN_OR_RETURN( - Shape concat_shape, - ShapeInference::InferConcatOpShape( - {&diff_op_a->shape(), &diff_op_b->shape()}, outer_dim)); + ASSIGN_OR_RETURN(Shape concat_shape, + ShapeInference::InferConcatOpShape( + {&diff_op_a->shape(), &diff_op_b->shape()}, outer_dim)); *concat_shape.mutable_layout() = diff_op_a->shape().layout(); HloInstruction* concat_op = diff_op_a->AddInstruction(HloInstruction::CreateConcatenate( @@ -192,11 +192,10 @@ absl::StatusOr TryMergeSameOperand(HloInstruction* a, HloInstruction* dot_lhs = lhs_same ? shared_op : concat_op; HloInstruction* dot_rhs = lhs_same ? concat_op : shared_op; - TF_ASSIGN_OR_RETURN( - Shape new_dot_shape, - ShapeInference::InferDotOpShape( - dot_lhs->shape(), dot_rhs->shape(), dnums, - /*preferred_element_type=*/a->shape().element_type())); + ASSIGN_OR_RETURN(Shape new_dot_shape, + ShapeInference::InferDotOpShape( + dot_lhs->shape(), dot_rhs->shape(), dnums, + /*preferred_element_type=*/a->shape().element_type())); *new_dot_shape.mutable_layout() = a->shape().layout(); HloInstruction* new_dot = a->AddInstruction(HloInstruction::CreateDot( new_dot_shape, dot_lhs, dot_rhs, dnums, a->precision_config())); @@ -221,13 +220,13 @@ absl::StatusOr TryMergeSameOperand(HloInstruction* a, // must live until the end of the pass. HloInstruction* new_a = a->AddInstruction(HloInstruction::CreateSlice( a->shape(), new_dot, start_indices, limit_indices, strides)); - TF_RETURN_IF_ERROR(a->ReplaceAllUsesWith(new_a)); + RETURN_IF_ERROR(a->ReplaceAllUsesWith(new_a)); start_indices[slice_dim] = limit_indices[slice_dim]; limit_indices[slice_dim] = new_dot_shape.dimensions(slice_dim); HloInstruction* new_b = b->AddInstruction(HloInstruction::CreateSlice( b->shape(), new_dot, start_indices, limit_indices, strides)); - TF_RETURN_IF_ERROR(b->ReplaceAllUsesWith(new_b)); + RETURN_IF_ERROR(b->ReplaceAllUsesWith(new_b)); return new_dot; } @@ -341,18 +340,17 @@ absl::StatusOr TryMergeLHSWithRHSOperand(HloInstruction* a, HloInstruction* b_lhs_transposed = b_lhs->AddInstruction(HloInstruction::CreateTranspose( ShapeUtil::PermuteDimensions({1, 0}, b_lhs->shape()), b_lhs, {1, 0})); - TF_ASSIGN_OR_RETURN(Shape concat_shape, - ShapeInference::InferConcatOpShape( - {&a_rhs->shape(), &b_lhs_transposed->shape()}, 1)); + ASSIGN_OR_RETURN(Shape concat_shape, + ShapeInference::InferConcatOpShape( + {&a_rhs->shape(), &b_lhs_transposed->shape()}, 1)); HloInstruction* new_rhs = a_rhs->AddInstruction(HloInstruction::CreateConcatenate( concat_shape, {a_rhs, b_lhs_transposed}, 1)); - TF_ASSIGN_OR_RETURN( - Shape new_dot_shape, - ShapeInference::InferDotOpShape( - a_lhs->shape(), // The new LHS is the LHS of a. - new_rhs->shape(), dnums_a, - /*preferred_element_type=*/a->shape().element_type())); + ASSIGN_OR_RETURN(Shape new_dot_shape, + ShapeInference::InferDotOpShape( + a_lhs->shape(), // The new LHS is the LHS of a. + new_rhs->shape(), dnums_a, + /*preferred_element_type=*/a->shape().element_type())); *new_dot_shape.mutable_layout() = a->shape().layout(); HloInstruction* new_dot = a->AddInstruction(HloInstruction::CreateDot( new_dot_shape, a_lhs, new_rhs, dnums_a, a->precision_config())); @@ -380,8 +378,8 @@ absl::StatusOr TryMergeLHSWithRHSOperand(HloInstruction* a, HloInstruction::CreateTranspose(b->shape(), new_b_slice, {1, 0})); // Important: We do RAUW, not ReplaceInstruction, because the old // instruction must live until the end of the pass. - TF_RETURN_IF_ERROR(a->ReplaceAllUsesWith(new_a)); - TF_RETURN_IF_ERROR(b->ReplaceAllUsesWith(new_b)); + RETURN_IF_ERROR(a->ReplaceAllUsesWith(new_a)); + RETURN_IF_ERROR(b->ReplaceAllUsesWith(new_b)); return new_dot; } @@ -559,7 +557,7 @@ absl::StatusOr MergeDots( continue; } - TF_ASSIGN_OR_RETURN(HloInstruction * merged, TryMergeOperand(a, b)); + ASSIGN_OR_RETURN(HloInstruction * merged, TryMergeOperand(a, b)); if (merged != nullptr) { int32_t merged_id = graph_id(merged); graph.InsertEdge(a_id, merged_id); @@ -592,7 +590,7 @@ absl::StatusOr MergeDots( return a->unique_id() < b->unique_id(); }); for (HloInstruction* instr : sorted_dead_instrs) { - TF_RETURN_IF_ERROR(comp->RemoveInstruction(instr)); + RETURN_IF_ERROR(comp->RemoveInstruction(instr)); } return !dead_instrs.empty(); @@ -606,8 +604,8 @@ absl::StatusOr DotMerger::RunImpl( bool changed = false; for (HloComputation* comp : module->MakeNonfusionComputations(execution_threads)) { - TF_ASSIGN_OR_RETURN(bool changed_computation, - MergeDots(comp, max_size_to_merge_, queue_id_)); + ASSIGN_OR_RETURN(bool changed_computation, + MergeDots(comp, max_size_to_merge_, queue_id_)); changed |= changed_computation; } return changed; diff --git a/third_party/xla/xla/hlo/transforms/simplifiers/dynamic_dimension_simplifier.cc b/third_party/xla/xla/hlo/transforms/simplifiers/dynamic_dimension_simplifier.cc index 00bf452d035a1e..14cfbdeb11b06c 100644 --- a/third_party/xla/xla/hlo/transforms/simplifiers/dynamic_dimension_simplifier.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/dynamic_dimension_simplifier.cc @@ -21,6 +21,7 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/status_macros.h" @@ -51,7 +52,7 @@ absl::StatusOr ConcatForwarding(HloInstruction* concat) { if (changed) { auto new_concat = parent->AddInstruction(HloInstruction::CreateConcatenate( concat->shape(), new_operands, concat->concatenate_dimension())); - TF_RETURN_IF_ERROR(parent->ReplaceInstruction(concat, new_concat)); + RETURN_IF_ERROR(parent->ReplaceInstruction(concat, new_concat)); } return changed; } @@ -86,7 +87,7 @@ absl::StatusOr SliceConcatForwarding(HloInstruction* slice) { if (size_so_far == slice->slice_starts(0) && operand->shape().dimensions(0) == slice_size) { // Found an operand that can be forwarded. - TF_RETURN_IF_ERROR(slice->ReplaceAllUsesWith(operand)); + RETURN_IF_ERROR(slice->ReplaceAllUsesWith(operand)); return true; } size_so_far += operand->shape().dimensions(concat_dim); @@ -117,8 +118,7 @@ absl::StatusOr ReshapeBroadcastForwarding(HloInstruction* reshape) { return false; } - TF_RETURN_IF_ERROR( - reshape->ReplaceAllUsesWith(broadcast->mutable_operand(0))); + RETURN_IF_ERROR(reshape->ReplaceAllUsesWith(broadcast->mutable_operand(0))); return true; } @@ -136,8 +136,7 @@ absl::StatusOr ReshapeReshapeForwarding(HloInstruction* reshape) { if (!Shape::Equal()(reshape->shape(), reshape_2->operand(0)->shape())) { return false; } - TF_RETURN_IF_ERROR( - reshape->ReplaceAllUsesWith(reshape_2->mutable_operand(0))); + RETURN_IF_ERROR(reshape->ReplaceAllUsesWith(reshape_2->mutable_operand(0))); return true; } @@ -149,7 +148,7 @@ absl::StatusOr IdentityConvertRemoving(HloInstruction* convert) { } auto operand = convert->mutable_operand(0); if (Shape::Equal()(convert->shape(), operand->shape())) { - TF_RETURN_IF_ERROR(convert->ReplaceAllUsesWith(operand)); + RETURN_IF_ERROR(convert->ReplaceAllUsesWith(operand)); return true; } return false; @@ -162,7 +161,7 @@ absl::StatusOr IdentityReshapeRemoving(HloInstruction* reshape) { } auto operand = reshape->mutable_operand(0); if (Shape::Equal()(reshape->shape(), operand->shape())) { - TF_RETURN_IF_ERROR(reshape->ReplaceAllUsesWith(operand)); + RETURN_IF_ERROR(reshape->ReplaceAllUsesWith(operand)); return true; } return false; @@ -179,39 +178,39 @@ absl::StatusOr DynamicDimensionSimplifier::RunImpl( for (auto* comp : module->MakeNonfusionComputations(execution_threads)) { for (auto* inst : comp->MakeInstructionPostOrder()) { - TF_ASSIGN_OR_RETURN(bool local_changed, ConcatForwarding(inst)); + ASSIGN_OR_RETURN(bool local_changed, ConcatForwarding(inst)); changed |= local_changed; } } for (auto* comp : module->MakeNonfusionComputations(execution_threads)) { for (auto* inst : comp->MakeInstructionPostOrder()) { - TF_ASSIGN_OR_RETURN(bool local_changed, SliceConcatForwarding(inst)); + ASSIGN_OR_RETURN(bool local_changed, SliceConcatForwarding(inst)); changed |= local_changed; } } for (auto* comp : module->MakeNonfusionComputations(execution_threads)) { for (auto* inst : comp->MakeInstructionPostOrder()) { - TF_ASSIGN_OR_RETURN(bool local_changed, ReshapeBroadcastForwarding(inst)); + ASSIGN_OR_RETURN(bool local_changed, ReshapeBroadcastForwarding(inst)); changed |= local_changed; } } for (auto* comp : module->MakeNonfusionComputations(execution_threads)) { for (auto* inst : comp->MakeInstructionPostOrder()) { - TF_ASSIGN_OR_RETURN(bool local_changed, ReshapeReshapeForwarding(inst)); + ASSIGN_OR_RETURN(bool local_changed, ReshapeReshapeForwarding(inst)); changed |= local_changed; } } for (auto* comp : module->MakeNonfusionComputations(execution_threads)) { for (auto* inst : comp->MakeInstructionPostOrder()) { - TF_ASSIGN_OR_RETURN(bool local_changed, IdentityConvertRemoving(inst)); + ASSIGN_OR_RETURN(bool local_changed, IdentityConvertRemoving(inst)); changed |= local_changed; } } for (auto* comp : module->MakeNonfusionComputations(execution_threads)) { for (auto* inst : comp->MakeInstructionPostOrder()) { - TF_ASSIGN_OR_RETURN(bool local_changed, IdentityReshapeRemoving(inst)); + ASSIGN_OR_RETURN(bool local_changed, IdentityReshapeRemoving(inst)); changed |= local_changed; } } diff --git a/third_party/xla/xla/hlo/transforms/simplifiers/flatten_call_graph_test.cc b/third_party/xla/xla/hlo/transforms/simplifiers/flatten_call_graph_test.cc index 914a63bdc50beb..af740cd11f855a 100644 --- a/third_party/xla/xla/hlo/transforms/simplifiers/flatten_call_graph_test.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/flatten_call_graph_test.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "xla/tsl/platform/status_macros.h" #include "xla/comparison_util.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_opcode.h" @@ -95,7 +96,7 @@ class FlattenCallGraphTest : public HloHardwareIndependentTestBase { absl::StatusOr RunFlattenCallGraph(HloModule* module) { FlattenCallGraph flatten; - TF_ASSIGN_OR_RETURN(bool result, flatten.Run(module)); + ASSIGN_OR_RETURN(bool result, flatten.Run(module)); return result; } diff --git a/third_party/xla/xla/hlo/transforms/simplifiers/float_normalization.cc b/third_party/xla/xla/hlo/transforms/simplifiers/float_normalization.cc index e0ec47713d19f0..a01b733bf8a853 100644 --- a/third_party/xla/xla/hlo/transforms/simplifiers/float_normalization.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/float_normalization.cc @@ -26,6 +26,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_input_output_alias_config.h" @@ -148,7 +149,7 @@ absl::StatusOr FloatNormalizationVisitor::ConvertType( to == LowPrecisionType() && from == HighPrecisionType()) { return hlo->mutable_operand(0); } - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( auto new_hlo, computation->DeepCopyInstructionWithCustomCopier( hlo, [&](HloInstruction* leaf, const ShapeIndex& leaf_index, @@ -173,13 +174,13 @@ absl::Status FloatNormalizationVisitor::InsertConvertAfterOutput( bool is_root = computation->root_instruction() == hlo; std::vector materialized_users = hlo->users(); - TF_ASSIGN_OR_RETURN(auto new_hlo, ConvertType(hlo, from, to, computation)); + ASSIGN_OR_RETURN(auto new_hlo, ConvertType(hlo, from, to, computation)); if (new_hlo == hlo) { return absl::OkStatus(); } for (auto* user : materialized_users) { - TF_RETURN_IF_ERROR(hlo->ReplaceUseWithDifferentShape(user, new_hlo)); + RETURN_IF_ERROR(hlo->ReplaceUseWithDifferentShape(user, new_hlo)); } if (is_root) { computation->set_root_instruction(new_hlo, /*accept_different_shape=*/true); @@ -222,7 +223,7 @@ absl::Status FloatNormalizationVisitor::ChangeOutputTypeThenInsertConvertBack( }); float_normalization_->UpdateLayout(hlo->mutable_shape()); std::vector materialized_users = hlo->users(); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( auto new_hlo, computation->DeepCopyInstructionWithCustomCopier( hlo, [&](HloInstruction* leaf, const ShapeIndex& leaf_index, @@ -262,11 +263,11 @@ absl::Status FloatNormalizationVisitor::ChangeOutputTypeThenInsertConvertBack( from == LowPrecisionType()) { conversions_to_simplify.push_back(user); } else { - TF_RETURN_IF_ERROR(hlo->ReplaceUseWithDifferentShape(user, new_hlo)); + RETURN_IF_ERROR(hlo->ReplaceUseWithDifferentShape(user, new_hlo)); } } for (auto* convert : conversions_to_simplify) { - TF_RETURN_IF_ERROR(convert->ReplaceAllUsesWith(hlo)); + RETURN_IF_ERROR(convert->ReplaceAllUsesWith(hlo)); } if (is_root) { computation->set_root_instruction(new_hlo, /*accept_different_shape=*/true); @@ -282,12 +283,12 @@ absl::Status FloatNormalizationVisitor::InsertConvertBeforeOperand( HloInstruction* hlo, int64_t operand_idx, PrimitiveType from, PrimitiveType to, HloComputation* computation) { auto operand = hlo->mutable_operand(operand_idx); - TF_ASSIGN_OR_RETURN(auto new_operand, - ConvertType(operand, from, to, computation)); + ASSIGN_OR_RETURN(auto new_operand, + ConvertType(operand, from, to, computation)); if (new_operand == operand) { return absl::OkStatus(); } - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( hlo->ReplaceOperandWithDifferentShape(operand_idx, new_operand)); changed_ = true; return absl::OkStatus(); @@ -311,13 +312,13 @@ absl::Status FloatNormalizationVisitor::ConvertCalledComputations( }); for (auto& comp_pair : cloned_computations) { auto comp = comp_pair.second; - TF_RETURN_IF_ERROR(InsertConvertAfterOutput(comp->root_instruction(), - LowPrecisionType(), - HighPrecisionType(), comp)); + RETURN_IF_ERROR(InsertConvertAfterOutput(comp->root_instruction(), + LowPrecisionType(), + HighPrecisionType(), comp)); for (auto* param : comp->parameter_instructions()) { // This changes the parameter to high-precision then inserts a convert // after it. - TF_RETURN_IF_ERROR(ChangeOutputTypeThenInsertConvertBack( + RETURN_IF_ERROR(ChangeOutputTypeThenInsertConvertBack( param, LowPrecisionType(), HighPrecisionType(), comp)); } } @@ -385,7 +386,7 @@ absl::Status FloatNormalizationVisitor::HandleMultipleOutputs( for (int64_t i = 0; i < hlo->operand_count(); ++i) { if (should_convert_operand(i)) { - TF_RETURN_IF_ERROR(InsertConvertBeforeOperand( + RETURN_IF_ERROR(InsertConvertBeforeOperand( hlo, i, LowPrecisionType(), HighPrecisionType(), computation_)); high_prec_count += 1; low_prec_count -= 1; @@ -451,7 +452,7 @@ absl::Status FloatNormalizationVisitor::HandleMultipleOutputs( // ReplaceUseWith. *tuple->mutable_shape() = hlo->shape(); for (auto* user : materialized_users) { - TF_RETURN_IF_ERROR(hlo->ReplaceUseWith(user, tuple)); + RETURN_IF_ERROR(hlo->ReplaceUseWith(user, tuple)); } bool is_root = computation_->root_instruction() == hlo; if (is_root) { @@ -516,7 +517,7 @@ absl::Status FloatNormalizationVisitor::HandleInstruction(HloInstruction* hlo) { hlo->operand(i)->shape(), LowPrecisionType()); if (low_prec_count_in_operand > 0 && !float_support_->SupportsLowPrecisionOperand(*hlo, i)) { - TF_RETURN_IF_ERROR(InsertConvertBeforeOperand( + RETURN_IF_ERROR(InsertConvertBeforeOperand( hlo, i, LowPrecisionType(), HighPrecisionType(), computation_)); low_prec_count -= low_prec_count_in_operand; high_prec_count += low_prec_count_in_operand; @@ -528,7 +529,7 @@ absl::Status FloatNormalizationVisitor::HandleInstruction(HloInstruction* hlo) { int64_t low_prec_count_in_hlo = CountSubshapesWithMatchingType(hlo->shape(), LowPrecisionType()); if (low_prec_count_in_hlo > 0) { - TF_RETURN_IF_ERROR(ChangeOutputTypeThenInsertConvertBack( + RETURN_IF_ERROR(ChangeOutputTypeThenInsertConvertBack( hlo, LowPrecisionType(), HighPrecisionType(), computation_)); low_prec_count -= low_prec_count_in_hlo; high_prec_count += low_prec_count_in_hlo; @@ -564,16 +565,16 @@ absl::Status FloatNormalizationVisitor::HandleInstruction(HloInstruction* hlo) { } if (can_use_low_prec) { for (int i = 0; i < hlo->operand_count(); ++i) { - TF_RETURN_IF_ERROR(InsertConvertBeforeOperand( + RETURN_IF_ERROR(InsertConvertBeforeOperand( hlo, i, HighPrecisionType(), LowPrecisionType(), computation_)); } return absl::OkStatus(); } } - TF_RETURN_IF_ERROR(ChangeOutputTypeThenInsertConvertBack( + RETURN_IF_ERROR(ChangeOutputTypeThenInsertConvertBack( hlo, LowPrecisionType(), HighPrecisionType(), computation_)); for (int i = 0; i < hlo->operand_count(); ++i) { - TF_RETURN_IF_ERROR(InsertConvertBeforeOperand( + RETURN_IF_ERROR(InsertConvertBeforeOperand( hlo, i, LowPrecisionType(), HighPrecisionType(), computation_)); } return ConvertCalledComputations(hlo, low_precision_called_comps); @@ -684,7 +685,7 @@ absl::StatusOr FloatNormalization::RunImpl( FloatNormalizationVisitor visitor(float_support_, this); for (auto* comp : computations_to_visit) { if (computations_to_skip.contains(comp)) continue; - TF_RETURN_IF_ERROR(comp->Accept(&visitor)); + RETURN_IF_ERROR(comp->Accept(&visitor)); } XLA_VLOG_LINES(2, "FloatNormalization::RunImpl() for " + primitive_util::LowercasePrimitiveTypeName( @@ -692,9 +693,9 @@ absl::StatusOr FloatNormalization::RunImpl( ", after:\n" + module->ToString()); if (visitor.changed()) { TupleSimplifier tuple_simplifier; - TF_RETURN_IF_ERROR(tuple_simplifier.Run(module).status()); + RETURN_IF_ERROR(tuple_simplifier.Run(module).status()); HloDCE dce; - TF_RETURN_IF_ERROR(dce.Run(module).status()); + RETURN_IF_ERROR(dce.Run(module).status()); } return visitor.changed(); } diff --git a/third_party/xla/xla/hlo/transforms/simplifiers/fusion_constant_sinking.cc b/third_party/xla/xla/hlo/transforms/simplifiers/fusion_constant_sinking.cc index d8bb9ff322c104..c3312d7ca34903 100644 --- a/third_party/xla/xla/hlo/transforms/simplifiers/fusion_constant_sinking.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/fusion_constant_sinking.cc @@ -25,6 +25,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/numbers.h" #include "absl/strings/string_view.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" @@ -146,7 +147,7 @@ absl::StatusOr FusionConstantSinking::RunImpl( } if (changed) { - TF_ASSIGN_OR_RETURN(bool dce, HloDCE{}.Run(module, execution_threads)); + ASSIGN_OR_RETURN(bool dce, HloDCE{}.Run(module, execution_threads)); changed |= dce; } diff --git a/third_party/xla/xla/hlo/transforms/simplifiers/gather_simplifier.cc b/third_party/xla/xla/hlo/transforms/simplifiers/gather_simplifier.cc index 49e83965bee11a..be6fd0729aecf7 100644 --- a/third_party/xla/xla/hlo/transforms/simplifiers/gather_simplifier.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/gather_simplifier.cc @@ -21,6 +21,7 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/status/statusor.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/literal_util.h" @@ -49,9 +50,8 @@ absl::StatusOr GatherSimplifier::ExpandInstruction( auto* start_indices = gather->operands()[1]; // Make the start_indices a two-dimensional tensor. - TF_ASSIGN_OR_RETURN( - start_indices, - TransformStartIndices(start_indices, dims.index_vector_dim())); + ASSIGN_OR_RETURN(start_indices, TransformStartIndices( + start_indices, dims.index_vector_dim())); // Permute the slice sizes according to start_index_map and compute the new // output shape for the Gather op. @@ -78,8 +78,7 @@ absl::StatusOr GatherSimplifier::ExpandInstruction( dims.collapsed_slice_dims().size()); absl::c_transform(dims.collapsed_slice_dims(), collapsed_slice_dims.begin(), [](int64_t dim) { return dim + 1; }); - TF_ASSIGN_OR_RETURN(result, - ElideDegenerateDims(result, collapsed_slice_dims)); + ASSIGN_OR_RETURN(result, ElideDegenerateDims(result, collapsed_slice_dims)); } // Expand the start index dimensions. @@ -91,10 +90,10 @@ absl::StatusOr GatherSimplifier::ExpandInstruction( } } if (start_indices_dims.size() > 1) { - TF_ASSIGN_OR_RETURN(result, - ExpandFirstDimIntoNDims(result, start_indices_dims)); + ASSIGN_OR_RETURN(result, + ExpandFirstDimIntoNDims(result, start_indices_dims)); } else if (start_indices_dims.empty()) { - TF_ASSIGN_OR_RETURN(result, ElideDegenerateDims(result, {0})); + ASSIGN_OR_RETURN(result, ElideDegenerateDims(result, {0})); } // Move the offset dims to the final locations. diff --git a/third_party/xla/xla/hlo/transforms/simplifiers/hlo_constant_folding.cc b/third_party/xla/xla/hlo/transforms/simplifiers/hlo_constant_folding.cc index 7642adb8d67776..a64323bc7862b9 100644 --- a/third_party/xla/xla/hlo/transforms/simplifiers/hlo_constant_folding.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/hlo_constant_folding.cc @@ -31,6 +31,7 @@ limitations under the License. #include "absl/strings/str_format.h" #include "absl/strings/string_view.h" #include "absl/time/time.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/evaluator/hlo_evaluator.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -130,7 +131,7 @@ absl::Status RecursivelyRemoveDeadInstructionAndDeadOperands( auto operands = dead_instruction->operands(); // First remove the instruction itself. - TF_RETURN_IF_ERROR(computation.RemoveInstruction(dead_instruction)); + RETURN_IF_ERROR(computation.RemoveInstruction(dead_instruction)); // Now check if some of its operands are dead as a result of the removal. for (auto operand : operands) { @@ -280,7 +281,7 @@ absl::StatusOr PropagateIdenticalConstantArguments( }); } const HloInstruction* constant = caller_instructions[0]->operand(i); - TF_RETURN_IF_ERROR(parameter->ReplaceAllUsesWith( + RETURN_IF_ERROR(parameter->ReplaceAllUsesWith( computation->AddInstruction(constant->Clone()))); changed = true; } @@ -323,8 +324,8 @@ absl::StatusOr HloConstantFolding::RunImpl( [](HloInstruction* instruction) { return instruction->opcode() == HloOpcode::kCall; })) { - TF_ASSIGN_OR_RETURN(bool did_change, - PropagateIdenticalConstantArguments(computation)); + ASSIGN_OR_RETURN(bool did_change, + PropagateIdenticalConstantArguments(computation)); changed |= did_change; } for (auto* instruction : computation->MakeInstructionPostOrder()) { @@ -432,8 +433,8 @@ absl::StatusOr HloConstantFolding::RunImpl( ->set_element_size_in_bits( instruction->shape().layout().element_size_in_bits()); } - TF_RETURN_IF_ERROR(instruction->ReplaceAllUsesWith(new_constant)); - TF_RETURN_IF_ERROR(RecursivelyRemoveDeadInstructionAndDeadOperands( + RETURN_IF_ERROR(instruction->ReplaceAllUsesWith(new_constant)); + RETURN_IF_ERROR(RecursivelyRemoveDeadInstructionAndDeadOperands( *computation, instruction)); } } diff --git a/third_party/xla/xla/hlo/transforms/simplifiers/hlo_constant_splitter.cc b/third_party/xla/xla/hlo/transforms/simplifiers/hlo_constant_splitter.cc index 4b52476ef9793d..1b53049c32c326 100644 --- a/third_party/xla/xla/hlo/transforms/simplifiers/hlo_constant_splitter.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/hlo_constant_splitter.cc @@ -25,6 +25,7 @@ limitations under the License. #include "absl/log/log.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" #include "tsl/platform/errors.h" @@ -106,7 +107,7 @@ absl::StatusOr DuplicateConstantExpressionPerUser( cloned_instr->clear_sharding(); cloned_instructions_map[i] = cloned_instr; if (i == to_clone) { - TF_RETURN_IF_ERROR(to_clone->ReplaceUseWith(user, cloned_instr)); + RETURN_IF_ERROR(to_clone->ReplaceUseWith(user, cloned_instr)); changed = true; } } @@ -207,8 +208,8 @@ absl::StatusOr HloConstantSplitter::RunImpl( } } for (auto* u : users) { - TF_ASSIGN_OR_RETURN(bool duplicated, DuplicateConstantExpressionPerUser( - computation, instruction, u)); + ASSIGN_OR_RETURN(bool duplicated, DuplicateConstantExpressionPerUser( + computation, instruction, u)); changed |= duplicated; } } diff --git a/third_party/xla/xla/hlo/transforms/simplifiers/hlo_dce.cc b/third_party/xla/xla/hlo/transforms/simplifiers/hlo_dce.cc index 2263b9f0ed791f..b33b8057b88e00 100644 --- a/third_party/xla/xla/hlo/transforms/simplifiers/hlo_dce.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/hlo_dce.cc @@ -34,6 +34,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/str_join.h" #include "absl/strings/string_view.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -94,10 +95,10 @@ absl::Status UpdateFusionUsers(HloInstruction* fusion_instruction, for (HloInstruction* gte : users) { // Replace and change control successors to be dependent on the fusion // instruction itself. - TF_ASSIGN_OR_RETURN(std::ignore, gte->parent()->ReplaceInstruction( - gte, fusion_instruction, - /*preserve_sharding=*/true, - /*relay_control_dependency=*/true)); + ASSIGN_OR_RETURN(std::ignore, gte->parent()->ReplaceInstruction( + gte, fusion_instruction, + /*preserve_sharding=*/true, + /*relay_control_dependency=*/true)); } } return absl::OkStatus(); @@ -154,7 +155,7 @@ absl::StatusOr RemoveMultiOutputFusionsUnusedOutputs( *fusion_instruction->mutable_shape() = std::move(new_shape); // Update the users of the old fusion instruction. - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( UpdateFusionUsers(fusion_instruction, used_tuple_elements, tuple_shapes)); // Update the root of the fusion computation. @@ -167,10 +168,10 @@ absl::StatusOr RemoveMultiOutputFusionsUnusedOutputs( } auto new_tuple = computation->AddInstruction(HloInstruction::CreateTuple(new_operands)); - TF_RETURN_IF_ERROR(computation->ReplaceInstructionWithDifferentShape( + RETURN_IF_ERROR(computation->ReplaceInstructionWithDifferentShape( computation->root_instruction(), new_tuple)); } else { - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( computation->root_instruction()->ReplaceAllUsesWithDifferentShape( computation->root_instruction()->mutable_operand( *used_tuple_elements.begin()))); @@ -243,7 +244,7 @@ absl::StatusOr RemoveDeadRoots( for (HloInstruction* dead_root : dead_roots) { VLOG(1) << "Removing dead root " << dead_root->ToString() << " and its unused operands"; - TF_RETURN_IF_ERROR(computation->RemoveInstructionAndUnusedOperands( + RETURN_IF_ERROR(computation->RemoveInstructionAndUnusedOperands( dead_root, /*cleanup=*/std::nullopt, /*ignore_control_dependencies=*/false, /*computation_callers=*/computation_callers)); @@ -294,7 +295,7 @@ absl::StatusOr RemoveDeadParameters( << " and its unused operands"; int64_t num_parameters = computation->num_parameters(); int64_t parameter_number = parameter->parameter_number(); - TF_RETURN_IF_ERROR(computation->RemoveInstructionAndUnusedOperands( + RETURN_IF_ERROR(computation->RemoveInstructionAndUnusedOperands( parameter, /*cleanup=*/std::nullopt, /*ignore_control_dependencies=*/false, /*computation_callers=*/computation_callers, @@ -308,7 +309,7 @@ absl::StatusOr RemoveDeadParameters( } } if (update_entry_computation_layout) { - TF_RETURN_IF_ERROR(RemoveDeadParametersFromEntryComputationLayout( + RETURN_IF_ERROR(RemoveDeadParametersFromEntryComputationLayout( computation->parent(), dead_parameters)); } return changed; @@ -343,7 +344,7 @@ absl::StatusOr ProcessAgenda( if (execution_threads.empty() || execution_threads.contains(computation->execution_thread())) { - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( bool computation_changed, xla::HloDCE::RunOnComputation( computation, remove_cross_partition_collective_ops, call_graph, @@ -387,7 +388,7 @@ absl::StatusOr RemoveDanglingComputations( if (to_remove.contains(computation)) { if (execution_threads.empty() || execution_threads.contains(computation->execution_thread())) { - TF_RETURN_IF_ERROR(module->RemoveEmbeddedComputation( + RETURN_IF_ERROR(module->RemoveEmbeddedComputation( iterator.underlying_iterator().underlying_iterator())); changed = true; } @@ -411,17 +412,17 @@ absl::StatusOr RemoveDanglingComputations( }; bool changed = false; - TF_ASSIGN_OR_RETURN(bool fusion_changed, - RemoveMultiOutputFusionsUnusedOutputs(computation)); + ASSIGN_OR_RETURN(bool fusion_changed, + RemoveMultiOutputFusionsUnusedOutputs(computation)); changed |= fusion_changed; - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( bool dead_roots_changed, RemoveDeadRoots(computation, remove_cross_partition_collective_ops, computation_callers)); changed |= dead_roots_changed; - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( bool dead_parameters_changed, RemoveDeadParameters(computation, computation_callers, remove_dead_parameters_from_entry_computation)); @@ -447,14 +448,14 @@ absl::StatusOr HloDCE::RunImpl( absl::flat_hash_set to_remove; PopulateAgenda(module, agenda, to_remove); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( bool agenda_changed, ProcessAgenda(module, agenda, to_remove, execution_threads, remove_cross_partition_collective_ops_, call_graph.get(), remove_dead_parameters_from_entry_computation_)); changed |= agenda_changed; - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( bool dangling_computations_removed, RemoveDanglingComputations(module, to_remove, execution_threads, use_call_analysis_, call_graph)); @@ -463,7 +464,7 @@ absl::StatusOr HloDCE::RunImpl( if (changed) { // Update the schedule to reflect the removed instructions. if (module->has_schedule()) { - TF_RETURN_IF_ERROR(module->schedule().Update(execution_threads)); + RETURN_IF_ERROR(module->schedule().Update(execution_threads)); } VLOG(2) << "After dce:"; XLA_VLOG_LINES(2, module->ToString()); diff --git a/third_party/xla/xla/hlo/transforms/simplifiers/hlo_element_type_converter.cc b/third_party/xla/xla/hlo/transforms/simplifiers/hlo_element_type_converter.cc index 8552bd446ef1c9..0e162d9e2f850a 100644 --- a/third_party/xla/xla/hlo/transforms/simplifiers/hlo_element_type_converter.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/hlo_element_type_converter.cc @@ -22,6 +22,7 @@ limitations under the License. #include "absl/log/check.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/evaluator/hlo_evaluator.h" #include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" #include "xla/hlo/ir/hlo_computation.h" @@ -200,7 +201,7 @@ absl::StatusOr HloElementTypeConverter::RunImpl( new_hlo = computation->AddInstruction( hlo->CloneWithNewOperands(shape, new_operands, &context)); - TF_RETURN_IF_ERROR(new_hlo->CopyAllControlDepsFrom(hlo)); + RETURN_IF_ERROR(new_hlo->CopyAllControlDepsFrom(hlo)); new_hlo = ToElementType(new_hlo, eliminate_type_); } else if (hlo->shape().IsTuple()) { @@ -210,7 +211,7 @@ absl::StatusOr HloElementTypeConverter::RunImpl( new_hlo = computation->AddInstruction( hlo->CloneWithNewOperands(new_shape, new_operands, &context)); - TF_RETURN_IF_ERROR(new_hlo->CopyAllControlDepsFrom(hlo)); + RETURN_IF_ERROR(new_hlo->CopyAllControlDepsFrom(hlo)); // Convert the elements of the result of `new_hlo` to produce a new // tuple with shape `old_shape`. @@ -218,16 +219,16 @@ absl::StatusOr HloElementTypeConverter::RunImpl( } else { new_hlo = computation->AddInstruction( hlo->CloneWithNewOperands(hlo->shape(), new_operands, &context)); - TF_RETURN_IF_ERROR(new_hlo->CopyAllControlDepsFrom(hlo)); + RETURN_IF_ERROR(new_hlo->CopyAllControlDepsFrom(hlo)); } - TF_RETURN_IF_ERROR(hlo->ReplaceAllUsesWith(new_hlo)); - TF_RETURN_IF_ERROR(hlo->DropAllControlDeps()); + RETURN_IF_ERROR(hlo->ReplaceAllUsesWith(new_hlo)); + RETURN_IF_ERROR(hlo->DropAllControlDeps()); // NB! We want to replace and remove side effecting instructions like Rng // as well so we can't rely HloComputation::ReplaceInstruction to reliably // remove the replaced instruction. - TF_RETURN_IF_ERROR(computation->RemoveInstruction(hlo)); + RETURN_IF_ERROR(computation->RemoveInstruction(hlo)); changed = true; } } diff --git a/third_party/xla/xla/hlo/transforms/simplifiers/hlo_memory_scheduler.cc b/third_party/xla/xla/hlo/transforms/simplifiers/hlo_memory_scheduler.cc index b3e3889385c09f..2ff601a138801f 100644 --- a/third_party/xla/xla/hlo/transforms/simplifiers/hlo_memory_scheduler.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/hlo_memory_scheduler.cc @@ -36,6 +36,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/str_format.h" #include "absl/strings/string_view.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/analysis/alias_info.h" #include "xla/hlo/analysis/hlo_alias_analysis.h" #include "xla/hlo/analysis/tuple_points_to_analysis.h" @@ -406,8 +407,8 @@ absl::StatusOr ComputationSchedulerAlgorithm::Run( for (HloComputation* computation : module->MakeComputationPostOrder(execution_threads)) { if (!computation->IsFusionComputation()) { - TF_ASSIGN_OR_RETURN(HloInstructionSequence computation_sequence, - Run(computation, points_to_analysis, alias_analysis)); + ASSIGN_OR_RETURN(HloInstructionSequence computation_sequence, + Run(computation, points_to_analysis, alias_analysis)); if (postprocessor_) { computation_sequence = postprocessor_(computation_sequence); } @@ -415,9 +416,9 @@ absl::StatusOr ComputationSchedulerAlgorithm::Run( } } if (peak_memory) { - TF_ASSIGN_OR_RETURN(*peak_memory, HeapSimulator::MinimumMemoryForModule( - schedule, alias_analysis, alias_info_, - size_function_)); + ASSIGN_OR_RETURN(*peak_memory, HeapSimulator::MinimumMemoryForModule( + schedule, alias_analysis, alias_info_, + size_function_)); } return schedule; } @@ -485,7 +486,7 @@ absl::StatusOr DFSMemoryScheduler::Run( return absl::OkStatus(); }); visitor.ReserveVisitStates(computation->instruction_count()); - TF_RETURN_IF_ERROR(computation->AcceptWithOperandOrder( + RETURN_IF_ERROR(computation->AcceptWithOperandOrder( &visitor, [&stats_map](const HloInstruction* a, const HloInstruction* b) { auto& stats_a = stats_map.at(a); auto& stats_b = stats_map.at(b); @@ -583,21 +584,21 @@ absl::StatusOr DefaultMemoryScheduler::Run( // List wins for most of our benchmarks; postorder-based schedulers win for // some RNNs. int64_t list_memory; - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( HloSchedule list_sequence, list_scheduler_.Run(module, points_to_analysis, alias_analysis, execution_threads, &list_memory)); VLOG(2) << "Min-memory list sequence: " << HumanReadableNumBytes(list_memory); int64_t dfs_memory; - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( HloSchedule dfs_sequence, dfs_scheduler_.Run(module, points_to_analysis, alias_analysis, execution_threads, &dfs_memory)); VLOG(2) << "Min-memory dfs sequence: " << HumanReadableNumBytes(dfs_memory); int64_t post_order_memory; - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( HloSchedule post_order_sequence, post_order_scheduler_.Run(module, points_to_analysis, alias_analysis, execution_threads, &post_order_memory)); @@ -632,17 +633,16 @@ absl::StatusOr ScheduleModule( return absl::StrFormat("XlaMemoryScheduler:#module=%s,program_id=%d#", module->name(), module->unique_id()); }); - TF_ASSIGN_OR_RETURN(std::unique_ptr points_to_analysis, - TuplePointsToAnalysis::Run(module)); - TF_ASSIGN_OR_RETURN(std::unique_ptr alias_analysis, - HloAliasAnalysis::Run(module, algorithm.alias_info())); + ASSIGN_OR_RETURN(std::unique_ptr points_to_analysis, + TuplePointsToAnalysis::Run(module)); + ASSIGN_OR_RETURN(std::unique_ptr alias_analysis, + HloAliasAnalysis::Run(module, algorithm.alias_info())); - TF_ASSIGN_OR_RETURN( - HloSchedule schedule, - algorithm.Run(module, *points_to_analysis, *alias_analysis, - execution_threads, peak_memory)); + ASSIGN_OR_RETURN(HloSchedule schedule, + algorithm.Run(module, *points_to_analysis, *alias_analysis, + execution_threads, peak_memory)); - TF_RETURN_IF_ERROR(schedule.Verify()); + RETURN_IF_ERROR(schedule.Verify()); return schedule; } @@ -660,9 +660,9 @@ absl::StatusOr ScheduleModule( absl::StatusOr HloMemoryScheduler::RunImpl( HloModule* module, const absl::flat_hash_set& execution_threads) { - TF_ASSIGN_OR_RETURN(HloSchedule schedule, - ScheduleModule(module, *algorithm_, execution_threads)); - TF_RETURN_IF_ERROR(module->set_schedule(std::move(schedule))); + ASSIGN_OR_RETURN(HloSchedule schedule, + ScheduleModule(module, *algorithm_, execution_threads)); + RETURN_IF_ERROR(module->set_schedule(std::move(schedule))); return true; } @@ -681,10 +681,10 @@ absl::StatusOr HloTrivialScheduler::RunImpl( return absl::OkStatus(); }); visitor.ReserveVisitStates(computation->instruction_count()); - TF_RETURN_IF_ERROR(computation->Accept(&visitor)); + RETURN_IF_ERROR(computation->Accept(&visitor)); } } - TF_RETURN_IF_ERROR(module->set_schedule(std::move(schedule))); + RETURN_IF_ERROR(module->set_schedule(std::move(schedule))); return true; } diff --git a/third_party/xla/xla/hlo/transforms/simplifiers/hlo_rematerialization.cc b/third_party/xla/xla/hlo/transforms/simplifiers/hlo_rematerialization.cc index 26c3acbf195160..de5c341cf081b1 100644 --- a/third_party/xla/xla/hlo/transforms/simplifiers/hlo_rematerialization.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/hlo_rematerialization.cc @@ -43,6 +43,7 @@ limitations under the License. #include "absl/strings/str_join.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/analysis/hlo_dataflow_analysis.h" #include "xla/hlo/analysis/tuple_points_to_analysis.h" #include "xla/hlo/ir/hlo_casting_utils.h" @@ -698,7 +699,7 @@ absl::Status MemoryUsageTracker::EndInstruction() { TF_RET_CHECK(in_progress_item_ != nullptr); VLOG(3) << "EndInstruction " << in_progress_item_->instruction->name(); - TF_RETURN_IF_ERROR(CountFreedMemory(in_progress_item_)); + RETURN_IF_ERROR(CountFreedMemory(in_progress_item_)); in_progress_item_ = nullptr; @@ -1163,21 +1164,21 @@ absl::Status MemoryUsageTracker::AddHostOffloadCopyInstructions( // ones won't be either. if (copy_start_to_host_item->placed) { CountAllocatedMemory(copy_start_to_host_item); - TF_RETURN_IF_ERROR(CountFreedMemory(copy_start_to_host_item)); + RETURN_IF_ERROR(CountFreedMemory(copy_start_to_host_item)); // This will account for the freed memory that is defined by the original // item. if (copy_done_to_host_item->placed) { CountAllocatedMemory(copy_done_to_host_item); - TF_RETURN_IF_ERROR(CountFreedMemory(copy_done_to_host_item)); + RETURN_IF_ERROR(CountFreedMemory(copy_done_to_host_item)); if (copy_start_to_device_item->placed) { CountAllocatedMemory(copy_start_to_device_item); - TF_RETURN_IF_ERROR(CountFreedMemory(copy_start_to_device_item)); + RETURN_IF_ERROR(CountFreedMemory(copy_start_to_device_item)); if (copy_done_to_device_item->placed) { CountAllocatedMemory(copy_done_to_device_item); - TF_RETURN_IF_ERROR(CountFreedMemory(copy_done_to_device_item)); + RETURN_IF_ERROR(CountFreedMemory(copy_done_to_device_item)); } } } @@ -1225,8 +1226,8 @@ absl::StatusOr MemoryUsageTracker::GetCompactShape( return &it->second; } const Shape& original_shape = hlo->shape(); - TF_ASSIGN_OR_RETURN(Shape min_shape, - options_.compact_shape_function(original_shape)); + ASSIGN_OR_RETURN(Shape min_shape, + options_.compact_shape_function(original_shape)); return &compact_shape_.emplace(hlo, min_shape).first->second; } @@ -1740,7 +1741,7 @@ absl::StatusOr RematerializeInstructions( HloInstruction* remat = computation->AddInstruction(best->Clone(/*suffix=*/"remat", &context)); // Call the callback on the original and rematerialized instruction. - TF_RETURN_IF_ERROR(rematerialization->on_rematerialized(best, remat)); + RETURN_IF_ERROR(rematerialization->on_rematerialized(best, remat)); for (auto& cloned_computation_pair : context.cloned_computations()) { if (!schedule->is_computation_scheduled(cloned_computation_pair.first)) { continue; @@ -1760,7 +1761,7 @@ absl::StatusOr RematerializeInstructions( } // Add control dependencies to the new operation. - TF_RETURN_IF_ERROR(remat->CopyAllControlDepsFrom(best)); + RETURN_IF_ERROR(remat->CopyAllControlDepsFrom(best)); HloRematItem* remat_item = instruction_list->CreateItem(remat); // Peak priority specific optimization. Any recomputed instruction @@ -1806,7 +1807,7 @@ absl::StatusOr RematerializeInstructions( /*new_name=*/"bitcast.remat"); indirect_users.push_back(instruction_list->CreateItem(remat_use)); } - TF_RETURN_IF_ERROR(user.user->instruction->ReplaceOperandWith( + RETURN_IF_ERROR(user.user->instruction->ReplaceOperandWith( user.operand_number, remat_use)); // Peak priority specific optimization. Any recomputed instruction // should not be rematerialized again. @@ -1814,13 +1815,13 @@ absl::StatusOr RematerializeInstructions( RematAlgorithm::kPeakPriority) { (*rematerializable_map)[remat_use] = false; } - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( rematerialization->on_rematerialized(user_operand, remat_use)); } } // Account for the rematerialization in the memory tracker. - TF_RETURN_IF_ERROR(memory_tracker->AddRematerializedInstruction( + RETURN_IF_ERROR(memory_tracker->AddRematerializedInstruction( best_item, remat_item, absl::MakeSpan(indirect_users))); // Insert rematerialized instruction right before the earliest unplaced @@ -1924,7 +1925,7 @@ absl::StatusOr RematerializeInstructions( // TODO(b/486858124): Generalize this to all strategies. if (rematerialization->remat_algorithm() == RematAlgorithm::kPeakPriority) { - TF_RETURN_IF_ERROR(best->DropAllControlDeps()); + RETURN_IF_ERROR(best->DropAllControlDeps()); // Removes all leftover uses of best. These uses are inactive as the // instruction has been rendered effectively dead by rematerialization. while (!best->users().empty()) { @@ -1937,10 +1938,10 @@ absl::StatusOr RematerializeInstructions( << "Deleting user " << user->name() << " of instruction " << best->name() << " because the instruction was killed by rematerialization."; - TF_RETURN_IF_ERROR(user->DropAllControlDeps()); - TF_RETURN_IF_ERROR(computation->RemoveInstruction(user)); + RETURN_IF_ERROR(user->DropAllControlDeps()); + RETURN_IF_ERROR(computation->RemoveInstruction(user)); } - TF_RETURN_IF_ERROR(computation->RemoveInstruction(best)); + RETURN_IF_ERROR(computation->RemoveInstruction(best)); } remat_move_instructions->insert(remat); net_instructions_added += indirect_users.size(); @@ -1958,8 +1959,8 @@ absl::StatusOr RematerializeInstructions( // We need to remove all control dependencies from best before removing it // from the computation. Its control dependencies were previously copied // to the remat instruction. - TF_RETURN_IF_ERROR(best->DropAllControlDeps()); - TF_RETURN_IF_ERROR(computation->RemoveInstruction(best)); + RETURN_IF_ERROR(best->DropAllControlDeps()); + RETURN_IF_ERROR(computation->RemoveInstruction(best)); } } return net_instructions_added; @@ -1996,12 +1997,12 @@ absl::StatusOr CompressInstruction( if (!memory_tracker->IsPlaced(user)) { VLOG(5) << " Replacing use of " << best->name() << " in " << user->name() << " with " << uncompressed->name(); - TF_RETURN_IF_ERROR(best->ReplaceUseWith(user, uncompressed)); + RETURN_IF_ERROR(best->ReplaceUseWith(user, uncompressed)); } } // Account for the rematerialization in the memory tracker. - TF_RETURN_IF_ERROR(memory_tracker->AddCompressInstructions( + RETURN_IF_ERROR(memory_tracker->AddCompressInstructions( best_item, compressed_item, uncompressed_item)); // Insert rematerialized instruction right before the earliest unplaced @@ -2068,13 +2069,13 @@ absl::StatusOr OffloadInstruction( << copy_done_to_device->ToString(); // Update the HloCostAnalysis with the new instructions. - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( copy_start_to_host->Visit(&memory_tracker->options().hlo_cost_analysis)); - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( copy_done_to_host->Visit(&memory_tracker->options().hlo_cost_analysis)); - TF_RETURN_IF_ERROR(copy_start_to_device->Visit( + RETURN_IF_ERROR(copy_start_to_device->Visit( &memory_tracker->options().hlo_cost_analysis)); - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( copy_done_to_device->Visit(&memory_tracker->options().hlo_cost_analysis)); // Create an HloRematItem for each instruction. These items will be inserted @@ -2273,7 +2274,7 @@ absl::StatusOr OffloadInstruction( if (!memory_tracker->IsPlaced(user)) { VLOG(3) << " Replacing use of " << best_instruction->name() << " in " << user->name() << " with " << copy_done_to_device->name(); - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( best_instruction->ReplaceUseWith(user, copy_done_to_device)); } else { VLOG(3) << user->name() << " is placed, not going to update"; @@ -2282,7 +2283,7 @@ absl::StatusOr OffloadInstruction( // Finally, update the MemoryUsageTracker. This will update the tracking of // buffer creations and uses. - TF_RETURN_IF_ERROR(memory_tracker->AddHostOffloadCopyInstructions( + RETURN_IF_ERROR(memory_tracker->AddHostOffloadCopyInstructions( best_item, copy_start_to_host_item, copy_done_to_host_item, copy_start_to_device_item, copy_done_to_device_item)); @@ -2341,7 +2342,7 @@ absl::StatusOr RematerializeBestBlock( best_items[0], best_strategy.compact_shape)) << ")"; - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( num_instructions_added.net_instructions_added, CompressInstruction(memory_tracker, best_items[0], best_strategy.compact_shape, instruction_list)); @@ -2350,7 +2351,7 @@ absl::StatusOr RematerializeBestBlock( CHECK_EQ(best_items.size(), 1) << "More than one buffer offloaded simultaneously."; VLOG(1) << "Remat via offload: " << best_items[0]->instruction->name(); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( num_instructions_added.net_instructions_added, OffloadInstruction(memory_tracker, best_items[0], instruction_list)); VLOG(4) << "Offload done, hlo computation:\n" @@ -2365,11 +2366,11 @@ absl::StatusOr RematerializeBestBlock( absl::StrAppend(out, item->instruction->name()); }) << '}'; - TF_ASSIGN_OR_RETURN(num_instructions_added.net_instructions_added, - RematerializeInstructions( - memory_tracker, &best_items, - remat_move_instructions, instruction_list, schedule, - rematerialization, rematerializable_map)); + ASSIGN_OR_RETURN(num_instructions_added.net_instructions_added, + RematerializeInstructions( + memory_tracker, &best_items, remat_move_instructions, + instruction_list, schedule, rematerialization, + rematerializable_map)); } return num_instructions_added; } @@ -2378,7 +2379,7 @@ absl::StatusOr RematerializeBestBlock( absl::StatusOr HloRematerialization::ComputePeakMemory( const HloComputation* computation, const HloInstructionSequence& order, const absl::flat_hash_set& execution_threads) const { - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( auto peak_memory_result, ComputePeakMemoryAndInstruction(computation, order, execution_threads)); return peak_memory_result.memory_usage; @@ -2399,16 +2400,15 @@ HloRematerialization::ComputePeakMemoryAndInstruction( for (auto* item = instruction_list.first(); item != nullptr; item = instruction_list.next(item)) { const HloInstruction* instruction = item->instruction; - TF_RETURN_IF_ERROR(tracker.BeginInstruction(item)); - TF_ASSIGN_OR_RETURN( - int64_t callee_usage, - CalledComputationsMemoryUsage(instruction, execution_threads)); + RETURN_IF_ERROR(tracker.BeginInstruction(item)); + ASSIGN_OR_RETURN(int64_t callee_usage, CalledComputationsMemoryUsage( + instruction, execution_threads)); int64_t memory_at_instruction = tracker.memory_usage() + callee_usage; if (memory_at_instruction > peak_memory) { peak_memory = memory_at_instruction; peak_instruction = instruction; } - TF_RETURN_IF_ERROR(tracker.EndInstruction()); + RETURN_IF_ERROR(tracker.EndInstruction()); } VLOG(1) << "Peak memory for " << computation->name() << ": " << HumanReadableNumBytes(peak_memory); @@ -2446,9 +2446,9 @@ absl::Status HloRematerialization::UpdateScheduleFromSequence( // the points_to_analysis_ after rematerializing each computation. Recompute // points_to_analysis_ since the older analysis does not include // rematerialized instructions. - TF_ASSIGN_OR_RETURN(points_to_analysis_, - TuplePointsToAnalysis::Run(computation->parent())); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN(points_to_analysis_, + TuplePointsToAnalysis::Run(computation->parent())); + ASSIGN_OR_RETURN( computation_peak_memory_[computation], ComputePeakMemory(computation, schedule->sequence(computation), execution_threads)); @@ -2472,7 +2472,7 @@ HloRematerialization::RematerializeCalledComputationsPeakPriority( int64_t subcomputation_memory_limit_bytes = std::max( 0, memory_limit_bytes - memory_tracker_memory_usage); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( bool subcomputation_changed, RematerializeComputationPeakPriority( called_computation, schedule, subcomputation_memory_limit_bytes, @@ -2510,7 +2510,7 @@ RematPeakAggressively( << HumanReadableNumBytes(peak_memory_during_remat) << ") " << ", limit is " << HumanReadableNumBytes(memory_limit_bytes); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( InstructionsAdded instructions_added, RematerializeBestBlock( /*min_block_size=*/1, max_block_size, &memory_tracker, @@ -2565,8 +2565,8 @@ HloRematerialization::PeakPrioritySubPass( VLOG(3) << "Creating memory tracker for rematerialization on " << computation->name() << " instruction list size " << state.instruction_list->size(); - TF_ASSIGN_OR_RETURN(points_to_analysis_, - TuplePointsToAnalysis::Run(computation->parent())); + ASSIGN_OR_RETURN(points_to_analysis_, + TuplePointsToAnalysis::Run(computation->parent())); MemoryUsageTracker memory_tracker(options_, computation, *points_to_analysis_, *state.instruction_list); state.instruction_list->PromoteNodesToSkip([&](HloRematItem* item) { @@ -2596,10 +2596,9 @@ HloRematerialization::PeakPrioritySubPass( VLOG(3) << "Instruction is dead: " << instruction->name(); continue; } - TF_ASSIGN_OR_RETURN( - int64_t callee_usage, - CalledComputationsMemoryUsage(instruction, execution_threads)); - TF_RETURN_IF_ERROR(memory_tracker.BeginInstruction(item)); + ASSIGN_OR_RETURN(int64_t callee_usage, CalledComputationsMemoryUsage( + instruction, execution_threads)); + RETURN_IF_ERROR(memory_tracker.BeginInstruction(item)); VLOG(2) << "Program point at " << instruction->name() << ", memory usage = " << memory_tracker.memory_usage() @@ -2610,7 +2609,7 @@ HloRematerialization::PeakPrioritySubPass( // Only trigger rematerialization at peak memory instruction if (peak_memory_instruction == instruction) { // Rematerialize until the peak usage is brought down. - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( HloRematerialization::RematerializationStepResult remat_step_result, RematPeakAggressively(instruction, state, this, memory_tracker, callee_usage, peak_memory_during_remat, @@ -2638,7 +2637,7 @@ HloRematerialization::PeakPrioritySubPass( << "). Rematerializing computations called by " << instruction->name(); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( bool callee_usage_changed_sub_module, RematerializeCalledComputationsPeakPriority( callsite, memory_tracker.memory_usage(), state.schedule, @@ -2648,15 +2647,15 @@ HloRematerialization::PeakPrioritySubPass( // Recompute callee usage to account for any rematerialization performed // in the callee computations. - TF_ASSIGN_OR_RETURN(callee_usage, CalledComputationsMemoryUsage( - instruction, execution_threads)); + ASSIGN_OR_RETURN(callee_usage, CalledComputationsMemoryUsage( + instruction, execution_threads)); } peak_memory = std::max( peak_memory, memory_tracker.memory_usage() + callee_usage); VLOG(2) << "peak memory usage = " << HumanReadableNumBytes(peak_memory); - TF_RETURN_IF_ERROR(memory_tracker.EndInstruction()); + RETURN_IF_ERROR(memory_tracker.EndInstruction()); if (module_changed_in_this_subpass) { VLOG(2) << "Rematerialization successful, breaking."; break; @@ -2667,7 +2666,7 @@ HloRematerialization::PeakPrioritySubPass( VLOG(2) << "Module changed in this pass, updating peak memory stats."; HloRematerialization::MemoryUsageAndInstruction new_peak_memory_and_instruction; - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( new_peak_memory_and_instruction, PeakPriorityUpdateVariables(*state.instruction_list, computation, state.schedule, execution_threads)); @@ -2730,7 +2729,7 @@ absl::StatusOr HloRematerialization::RematerializeComputationPeakPriority( int64_t cost_estimate_memory_limit_bytes = std::max(kMinimumCostEstimateMemoryLimitBytes, memory_limit_bytes); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( HloRematerialization::MemoryUsageAndInstruction peak_memory_result, ComputePeakMemoryAndInstruction( computation, schedule->sequence(computation), execution_threads)); @@ -2776,7 +2775,7 @@ absl::StatusOr HloRematerialization::RematerializeComputationPeakPriority( RematSubpassStatus remat_subpass_status; bool changed = false; do { - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( RematSubpassResult remat_subpass_result, PeakPrioritySubPass(peak_memory_instruction, rematerialization_state, computation, call_graph_node, min_remat_size, @@ -2807,11 +2806,11 @@ HloRematerialization::PeakPriorityUpdateVariables( } sequence_from_list.push_back(item->instruction); } - TF_RETURN_IF_ERROR(HloRematerialization::UpdateScheduleFromSequence( + RETURN_IF_ERROR(HloRematerialization::UpdateScheduleFromSequence( computation, schedule, sequence_from_list, execution_threads)); VLOG(2) << "Schedule updated"; // Update instruction list to reflect the new instruction in computation. - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( instruction_list.UpdateFromSequence(schedule->sequence(computation))); // Update peak memory. return ComputePeakMemoryAndInstruction( @@ -2877,10 +2876,9 @@ absl::StatusOr HloRematerialization::RematerializeComputation( for (auto* item = instruction_list.first(); item != nullptr; item = instruction_list.next(item)) { const HloInstruction* instruction = item->instruction; - TF_ASSIGN_OR_RETURN( - int64_t callee_usage, - CalledComputationsMemoryUsage(instruction, execution_threads)); - TF_RETURN_IF_ERROR(memory_tracker.BeginInstruction(item)); + ASSIGN_OR_RETURN(int64_t callee_usage, CalledComputationsMemoryUsage( + instruction, execution_threads)); + RETURN_IF_ERROR(memory_tracker.BeginInstruction(item)); VLOG(2) << "Program point at " << instruction->name() << ", memory usage = " << memory_tracker.memory_usage() @@ -2910,7 +2908,7 @@ absl::StatusOr HloRematerialization::RematerializeComputation( callee_usage) << ", limit is " << HumanReadableNumBytes(memory_limit_bytes); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( InstructionsAdded instructions_added, RematerializeBestBlock(min_block_size, max_block_size, &memory_tracker, &instruction_list, schedule, @@ -2971,7 +2969,7 @@ absl::StatusOr HloRematerialization::RematerializeComputation( // amount of memory used at this point in the computation. int64_t subcomputation_memory_limit_bytes = std::max( 0, memory_limit_bytes - memory_tracker.memory_usage()); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( bool subcomputation_changed, RematerializeComputation(called_computation, schedule, subcomputation_memory_limit_bytes, @@ -2980,15 +2978,15 @@ absl::StatusOr HloRematerialization::RematerializeComputation( } } - TF_ASSIGN_OR_RETURN(callee_usage, CalledComputationsMemoryUsage( - instruction, execution_threads)); + ASSIGN_OR_RETURN(callee_usage, CalledComputationsMemoryUsage( + instruction, execution_threads)); } peak_memory = std::max( peak_memory, memory_tracker.memory_usage() + callee_usage); VLOG(3) << "peak memory usage = " << HumanReadableNumBytes(peak_memory); - TF_RETURN_IF_ERROR(memory_tracker.EndInstruction()); + RETURN_IF_ERROR(memory_tracker.EndInstruction()); } // Verify some invariants on the memory tracker. @@ -3074,7 +3072,7 @@ absl::StatusOr HloRematerialization::RunImpl( net_instructions_added_ = 0; TF_RET_CHECK(module->has_schedule()); - TF_ASSIGN_OR_RETURN(points_to_analysis_, TuplePointsToAnalysis::Run(module)); + ASSIGN_OR_RETURN(points_to_analysis_, TuplePointsToAnalysis::Run(module)); next_channel_id_ = hlo_query::NextChannelId(*module); // Adjust memory limit to account for the output of the entry @@ -3111,17 +3109,17 @@ absl::StatusOr HloRematerialization::RunImpl( options_.async_computation_parallelism) { async_threads.insert(computation->execution_thread()); } - TF_RETURN_IF_ERROR(call_graph_->VisitNodes( + RETURN_IF_ERROR(call_graph_->VisitNodes( [this, module, &async_threads](const CallGraphNode& node) -> absl::Status { auto callee_thread = node.computation()->execution_thread(); if (node.context() == CallContext::kControlFlow && HloInstruction::IsThreadIncluded(callee_thread, async_threads)) { - TF_ASSIGN_OR_RETURN(computation_peak_memory_[node.computation()], - ComputePeakMemory(node.computation(), - module->schedule().sequence( - node.computation()), - {callee_thread})); + ASSIGN_OR_RETURN(computation_peak_memory_[node.computation()], + ComputePeakMemory(node.computation(), + module->schedule().sequence( + node.computation()), + {callee_thread})); } return absl::OkStatus(); }, @@ -3151,13 +3149,13 @@ absl::StatusOr HloRematerialization::RunImpl( } // Compute peak memory usage of all computations in the module called in a // sequential context. - TF_RETURN_IF_ERROR(call_graph_->VisitNodes( + RETURN_IF_ERROR(call_graph_->VisitNodes( [this, module, &execution_threads](const CallGraphNode& node) -> absl::Status { if (node.context() == CallContext::kControlFlow && HloInstruction::IsThreadIncluded( node.computation()->execution_thread(), execution_threads)) { - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( computation_peak_memory_[node.computation()], ComputePeakMemory(node.computation(), module->schedule().sequence(node.computation()), @@ -3181,15 +3179,15 @@ absl::StatusOr HloRematerialization::RunImpl( // Initialize the HloCostAnalysis on this computation. for (auto* computation : module->MakeComputationPostOrder(execution_threads)) { - TF_RETURN_IF_ERROR(computation->Accept(&options_.hlo_cost_analysis)); + RETURN_IF_ERROR(computation->Accept(&options_.hlo_cost_analysis)); } - TF_ASSIGN_OR_RETURN(RematAlgorithmFunction remat_algorithm_func, - GetRematAlgorithmFunction(options_.remat_algorithm)); + ASSIGN_OR_RETURN(RematAlgorithmFunction remat_algorithm_func, + GetRematAlgorithmFunction(options_.remat_algorithm)); // Subcomputations called by the entry computation will also be // rematerialized. - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( bool changed, remat_algorithm_func(module->entry_computation(), &module->schedule(), adjusted_memory_limit_bytes, options_.min_remat_size, @@ -3201,13 +3199,13 @@ absl::StatusOr HloRematerialization::RunImpl( // while the module is in flux. HloSchedule saved_schedule = module->schedule(); module->clear_schedule(); - TF_ASSIGN_OR_RETURN(bool dead_code_removed, HloPassFix().Run(module)); + ASSIGN_OR_RETURN(bool dead_code_removed, HloPassFix().Run(module)); changed |= dead_code_removed; // After DCE, the module sequence may include instructions which no longer // exist. Update the schedule and restore it. - TF_RETURN_IF_ERROR(saved_schedule.Update(execution_threads)); - TF_RETURN_IF_ERROR(module->set_schedule(std::move(saved_schedule))); + RETURN_IF_ERROR(saved_schedule.Update(execution_threads)); + RETURN_IF_ERROR(module->set_schedule(std::move(saved_schedule))); VLOG(1) << "Rematerialized " << instructions_rematerialized_ << " instructions in module " << module->name() << "; " << net_instructions_added_ << " net instructions added"; diff --git a/third_party/xla/xla/hlo/transforms/simplifiers/host_memory_transfer_asyncifier.cc b/third_party/xla/xla/hlo/transforms/simplifiers/host_memory_transfer_asyncifier.cc index 332d18e864476c..12bace9d63b814 100644 --- a/third_party/xla/xla/hlo/transforms/simplifiers/host_memory_transfer_asyncifier.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/host_memory_transfer_asyncifier.cc @@ -23,6 +23,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/str_format.h" #include "absl/strings/string_view.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" @@ -83,9 +84,9 @@ class HostMemoryTransferAsyncifierVisitor : public DfsHloVisitorWithDefault { // Everything is as expected. Replace this dynamic-slice with the async // equivalent. const Shape context_shape = ShapeUtil::MakeScalarShape(U32); - TF_ASSIGN_OR_RETURN(HloInstruction * async_done, - dynamic_slice->parent()->CreateAsyncInstructions( - dynamic_slice, {context_shape})); + ASSIGN_OR_RETURN(HloInstruction * async_done, + dynamic_slice->parent()->CreateAsyncInstructions( + dynamic_slice, {context_shape})); VLOG(1) << "DynamicSlice \"" << dynamic_slice->ToString() << "\" is slicing from host memory. Converting to async " << async_done->ToString(); @@ -140,9 +141,9 @@ class HostMemoryTransferAsyncifierVisitor : public DfsHloVisitorWithDefault { // Everything is as expected. Replace this dynamic-update-slice with the // async equivalent. const Shape context_shape = ShapeUtil::MakeScalarShape(U32); - TF_ASSIGN_OR_RETURN(HloInstruction * async_done, - dynamic_update_slice->parent()->CreateAsyncInstructions( - dynamic_update_slice, {context_shape})); + ASSIGN_OR_RETURN(HloInstruction * async_done, + dynamic_update_slice->parent()->CreateAsyncInstructions( + dynamic_update_slice, {context_shape})); VLOG(1) << "DynamicUpdateSlice \"" << dynamic_update_slice->ToString() << "\" is slicing into host memory space. Converting to async " << async_done->ToString(); @@ -177,7 +178,7 @@ class HostMemoryTransferAsyncifierVisitor : public DfsHloVisitorWithDefault { // Everything is as expected. Replace this copy with the async equivalent. const Shape context_shape = ShapeUtil::MakeScalarShape(U32); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( HloInstruction * async_done, copy->parent()->CreateAsyncInstructions(copy, {context_shape})); VLOG(1) @@ -202,7 +203,7 @@ absl::StatusOr HostMemoryTransferAsyncifier::RunImpl( const absl::flat_hash_set& execution_threads) { HostMemoryTransferAsyncifierVisitor visitor(kHostMemorySpaceColor); for (HloComputation* computation : module->MakeNonfusionComputations()) { - TF_RETURN_IF_ERROR(computation->Accept(&visitor)); + RETURN_IF_ERROR(computation->Accept(&visitor)); } return visitor.Changed(); } diff --git a/third_party/xla/xla/hlo/transforms/simplifiers/host_memory_transfer_asyncifier_test.cc b/third_party/xla/xla/hlo/transforms/simplifiers/host_memory_transfer_asyncifier_test.cc index 882561b2672d95..b477840fd5587d 100644 --- a/third_party/xla/xla/hlo/transforms/simplifiers/host_memory_transfer_asyncifier_test.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/host_memory_transfer_asyncifier_test.cc @@ -23,6 +23,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" @@ -41,8 +42,8 @@ namespace m = ::xla::match; class HostMemoryTransferAsyncifierTest : public HloHardwareIndependentTestBase { protected: absl::StatusOr RunAsyncifier(absl::string_view hlo_string) { - TF_ASSIGN_OR_RETURN(auto module, ParseAndReturnVerifiedModule(hlo_string)); - TF_ASSIGN_OR_RETURN(bool changed, RunAsyncifier(module.get())); + ASSIGN_OR_RETURN(auto module, ParseAndReturnVerifiedModule(hlo_string)); + ASSIGN_OR_RETURN(bool changed, RunAsyncifier(module.get())); return changed; } diff --git a/third_party/xla/xla/hlo/transforms/simplifiers/optimize_input_output_buffer_alias.cc b/third_party/xla/xla/hlo/transforms/simplifiers/optimize_input_output_buffer_alias.cc index 3da5222cf01825..5c8b8a865fcb4e 100644 --- a/third_party/xla/xla/hlo/transforms/simplifiers/optimize_input_output_buffer_alias.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/optimize_input_output_buffer_alias.cc @@ -24,6 +24,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/ir/hlo_input_output_alias_config.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/layout_util.h" @@ -128,9 +129,9 @@ absl::StatusOr OptimizeInputOutputBufferAlias::Build( donee_vector_index += 1; } else { // The current donor and donee match. - TF_RETURN_IF_ERROR(alias_config->SetUpAlias( + RETURN_IF_ERROR(alias_config->SetUpAlias( donee.index, donor.param_number, donor.index)); - TF_RETURN_IF_ERROR(buffer_donor_config->RemoveBufferDonor( + RETURN_IF_ERROR(buffer_donor_config->RemoveBufferDonor( donor.param_number, donor.index)); donor_vector_index += 1; donee_vector_index += 1; @@ -159,9 +160,9 @@ absl::StatusOr OptimizeInputOutputBufferAlias::RunImpl( &module->input_output_alias_config(); HloBufferDonorConfig* buffer_donor_config = &module->buffer_donor_config(); - TF_ASSIGN_OR_RETURN(bool changed, Build(input_shapes, output_shape, - alias_config, buffer_donor_config)); - TF_RETURN_IF_ERROR(alias_config->Verify(*module, shape_size_fn_)); + ASSIGN_OR_RETURN(bool changed, Build(input_shapes, output_shape, alias_config, + buffer_donor_config)); + RETURN_IF_ERROR(alias_config->Verify(*module, shape_size_fn_)); return changed; } diff --git a/third_party/xla/xla/hlo/transforms/simplifiers/recognize_reduce_window.cc b/third_party/xla/xla/hlo/transforms/simplifiers/recognize_reduce_window.cc index b734a4c00f8056..f8e4e61372eeec 100644 --- a/third_party/xla/xla/hlo/transforms/simplifiers/recognize_reduce_window.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/recognize_reduce_window.cc @@ -29,6 +29,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/evaluator/hlo_evaluator.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -381,12 +382,12 @@ absl::StatusOr RunOnComputation(HloComputation* computation) { HloInstruction* static_slice = computation->AddInstruction( HloInstruction::CreateSlice(inst->shape(), inst->mutable_operand(0), starts, limits, strides)); - TF_RETURN_IF_ERROR(inst->ReplaceAllUsesWith(static_slice)); + RETURN_IF_ERROR(inst->ReplaceAllUsesWith(static_slice)); changed = true; } } else if (inst->opcode() == HloOpcode::kGetTupleElement) { if (inst->operand(0)->opcode() == HloOpcode::kTuple) { - TF_RETURN_IF_ERROR(inst->ReplaceAllUsesWith( + RETURN_IF_ERROR(inst->ReplaceAllUsesWith( inst->mutable_operand(0)->mutable_operand(inst->tuple_index()))); changed = true; } @@ -558,12 +559,12 @@ absl::StatusOr RunOnComputation(HloComputation* computation) { ? lhs->to_apply()->root_instruction()->opcode() : inst->opcode(); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( auto rw, CreateReduceWindow(computation, inst->shape(), new_base_op, opcode, dim, current_window_size, window_dilation)); - TF_RETURN_IF_ERROR(inst->ReplaceAllUsesWith(rw)); + RETURN_IF_ERROR(inst->ReplaceAllUsesWith(rw)); changed = true; continue; } @@ -662,7 +663,7 @@ absl::StatusOr RunOnComputation(HloComputation* computation) { auto broadcasted_zero = computation->AddInstruction(HloInstruction::CreateBroadcast( inst->shape(), zero_replacement, {})); - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( computation->ReplaceInstruction(inst, broadcasted_zero)); changed = true; continue; @@ -670,7 +671,7 @@ absl::StatusOr RunOnComputation(HloComputation* computation) { if (folded.size() == 1) { if (folded[0].weight == one && !any_array_weights) { - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( computation->ReplaceInstruction(inst, folded[0].slice)); changed = true; continue; @@ -693,8 +694,7 @@ absl::StatusOr RunOnComputation(HloComputation* computation) { inst->shape(), HloOpcode::kMultiply, replacement, broadcasted)); } - TF_RETURN_IF_ERROR( - computation->ReplaceInstruction(inst, replacement)); + RETURN_IF_ERROR(computation->ReplaceInstruction(inst, replacement)); changed = true; continue; } @@ -706,8 +706,7 @@ absl::StatusOr RunOnComputation(HloComputation* computation) { HloInstruction* replacement = computation->AddInstruction( HloInstruction::CreateBinary(inst->shape(), HloOpcode::kMultiply, folded[0].slice, broadcasted)); - TF_RETURN_IF_ERROR( - computation->ReplaceInstruction(inst, replacement)); + RETURN_IF_ERROR(computation->ReplaceInstruction(inst, replacement)); changed = true; continue; } @@ -832,13 +831,12 @@ absl::StatusOr RunOnComputation(HloComputation* computation) { folded[0].slice->mutable_operand(1), new_pad_config)); } int64_t current_window_size = folded.size(); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( HloInstruction * replacement, CreateReduceWindow(computation, inst->shape(), new_base_op, HloOpcode::kAdd, dim, current_window_size, window_dilation)); - TF_RETURN_IF_ERROR( - computation->ReplaceInstruction(inst, replacement)); + RETURN_IF_ERROR(computation->ReplaceInstruction(inst, replacement)); changed = true; continue; } @@ -864,8 +862,7 @@ absl::StatusOr RunOnComputation(HloComputation* computation) { HloInstruction* replacement = computation->AddInstruction( HloInstruction::CreateBinary(inst->shape(), HloOpcode::kAdd, folded[0].slice, folded[1].slice)); - TF_RETURN_IF_ERROR( - computation->ReplaceInstruction(inst, replacement)); + RETURN_IF_ERROR(computation->ReplaceInstruction(inst, replacement)); changed = true; continue; } @@ -880,8 +877,7 @@ absl::StatusOr RunOnComputation(HloComputation* computation) { computation->AddInstruction(HloInstruction::CreateBinary( inst->shape(), HloOpcode::kSubtract, folded[0].slice, folded[1].slice)); - TF_RETURN_IF_ERROR( - computation->ReplaceInstruction(inst, replacement)); + RETURN_IF_ERROR(computation->ReplaceInstruction(inst, replacement)); changed = true; continue; } @@ -897,8 +893,7 @@ absl::StatusOr RunOnComputation(HloComputation* computation) { computation->AddInstruction(HloInstruction::CreateBinary( inst->shape(), HloOpcode::kSubtract, folded[1].slice, folded[0].slice)); - TF_RETURN_IF_ERROR( - computation->ReplaceInstruction(inst, replacement)); + RETURN_IF_ERROR(computation->ReplaceInstruction(inst, replacement)); changed = true; continue; } @@ -987,8 +982,7 @@ absl::StatusOr RunOnComputation(HloComputation* computation) { replacement = computation->AddInstruction( HloInstruction::CreateDot(inst->shape(), concat, concat_weights, dnums, precision_config)); - TF_RETURN_IF_ERROR( - computation->ReplaceInstruction(inst, replacement)); + RETURN_IF_ERROR(computation->ReplaceInstruction(inst, replacement)); changed = true; continue; } else { @@ -1031,8 +1025,7 @@ absl::StatusOr RunOnComputation(HloComputation* computation) { replacement = computation->AddInstruction(HloInstruction::CreateDot( inst->shape(), concat, weights, dnums, precision_config)); - TF_RETURN_IF_ERROR( - computation->ReplaceInstruction(inst, replacement)); + RETURN_IF_ERROR(computation->ReplaceInstruction(inst, replacement)); changed = true; continue; } @@ -1053,8 +1046,7 @@ absl::StatusOr RecognizeReduceWindow::RunImpl( std::vector computations = module->MakeNonfusionComputations(execution_threads); for (HloComputation* computation : computations) { - TF_ASSIGN_OR_RETURN(bool computation_changed, - RunOnComputation(computation)); + ASSIGN_OR_RETURN(bool computation_changed, RunOnComputation(computation)); changed |= computation_changed; // Post-order traversal for DCE @@ -1070,7 +1062,7 @@ absl::StatusOr RecognizeReduceWindow::RunImpl( if (dce_inst->user_count() == 0 && !dce_inst->HasSideEffect() && computation->root_instruction() != dce_inst && computation->IsSafelyRemovable(dce_inst)) { - TF_RETURN_IF_ERROR(computation->RemoveInstruction(dce_inst)); + RETURN_IF_ERROR(computation->RemoveInstruction(dce_inst)); dce_changed = true; } } diff --git a/third_party/xla/xla/hlo/transforms/simplifiers/reduce_window_resizer.cc b/third_party/xla/xla/hlo/transforms/simplifiers/reduce_window_resizer.cc index e50dcf0904eb0a..a864e17b55918b 100644 --- a/third_party/xla/xla/hlo/transforms/simplifiers/reduce_window_resizer.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/reduce_window_resizer.cc @@ -18,6 +18,7 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" @@ -43,7 +44,7 @@ absl::StatusOr ReduceWindowResizer::RunImpl( if (reduce_window->inputs().front()->shape().dimensions().size() != 1) { continue; } - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( reduce_window_util::Replace1DReduceWindowWithReshape(reduce_window)); changed = true; diff --git a/third_party/xla/xla/hlo/transforms/simplifiers/reduce_window_rewriter.cc b/third_party/xla/xla/hlo/transforms/simplifiers/reduce_window_rewriter.cc index 560dbf546fd9c3..f613a66cf80a27 100644 --- a/third_party/xla/xla/hlo/transforms/simplifiers/reduce_window_rewriter.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/reduce_window_rewriter.cc @@ -31,6 +31,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -104,7 +105,7 @@ static absl::StatusOr ScalarizeComputation( HloInstruction* new_inst = nullptr; switch (inst->opcode()) { case HloOpcode::kParameter: { - TF_ASSIGN_OR_RETURN(Shape shape, get_scalar_shape(inst->shape())); + ASSIGN_OR_RETURN(Shape shape, get_scalar_shape(inst->shape())); new_inst = builder.AddInstruction(HloInstruction::CreateParameter( inst->parameter_number(), shape, inst->name())); break; @@ -114,7 +115,7 @@ static absl::StatusOr ScalarizeComputation( HloInstruction::CreateTuple(get_mapped_operands(inst))); break; case HloOpcode::kGetTupleElement: { - TF_ASSIGN_OR_RETURN(Shape shape, get_scalar_shape(inst->shape())); + ASSIGN_OR_RETURN(Shape shape, get_scalar_shape(inst->shape())); new_inst = builder.AddInstruction(HloInstruction::CreateGetTupleElement( shape, replacements[inst->operand(0)], inst->tuple_index())); break; @@ -156,7 +157,7 @@ static absl::StatusOr ScalarizeComputation( absl::StrCat("Instruction is not elementwise: ", HloOpcodeString(inst->opcode()))); } - TF_ASSIGN_OR_RETURN(Shape shape, get_scalar_shape(inst->shape())); + ASSIGN_OR_RETURN(Shape shape, get_scalar_shape(inst->shape())); new_inst = builder.AddInstruction( inst->CloneWithNewOperands(shape, get_mapped_operands(inst))); break; @@ -571,7 +572,7 @@ ReduceWindowRewriter::RewriteScanAsTreeReduction( scans.push_back(scan); return absl::OkStatus(); }); - TF_RETURN_IF_ERROR(status); + RETURN_IF_ERROR(status); HloInstruction* scan; if (result_shape.IsTuple()) { @@ -642,14 +643,13 @@ absl::StatusOr ReduceWindowRewriter::TryOptimizeCumSumOrProd( // We don't actually need to match the computation - this transformation will // work for a commutative/associative reducer, which is what we assume for // ReduceWindow anyway. - TF_ASSIGN_OR_RETURN( - HloInstruction * scan, - RewriteScanAsTreeReduction(parent, sources, reduce_window->init_values(), - reduce_window->to_apply(), - reduce_window->shape(), rank, scan_dim, - scan_length, forward_scan, is_exclusive)); - TF_RETURN_IF_ERROR(reduce_window->ReplaceAllUsesWith(scan)); - TF_RETURN_IF_ERROR(parent->RemoveInstruction(reduce_window)); + ASSIGN_OR_RETURN(HloInstruction * scan, + RewriteScanAsTreeReduction( + parent, sources, reduce_window->init_values(), + reduce_window->to_apply(), reduce_window->shape(), rank, + scan_dim, scan_length, forward_scan, is_exclusive)); + RETURN_IF_ERROR(reduce_window->ReplaceAllUsesWith(scan)); + RETURN_IF_ERROR(parent->RemoveInstruction(reduce_window)); return true; } @@ -667,10 +667,10 @@ absl::StatusOr ReduceWindowRewriter::TryOptimizeAssociativeScan( VLOG(2) << "Rewriting associative scan: " << scan->ToString(); HloComputation* parent = scan->parent(); - TF_ASSIGN_OR_RETURN(HloInstruction * init, - GetScalarInitValue(scan->inits()[0], parent)); - TF_ASSIGN_OR_RETURN(HloComputation * scan_to_apply, - ScalarizeComputation(scan->to_apply(), parent)); + ASSIGN_OR_RETURN(HloInstruction * init, + GetScalarInitValue(scan->inits()[0], parent)); + ASSIGN_OR_RETURN(HloComputation * scan_to_apply, + ScalarizeComputation(scan->to_apply(), parent)); HloComputation::Builder builder( absl::StrCat(scan_to_apply->name(), "_rw_wrapper")); @@ -698,17 +698,17 @@ absl::StatusOr ReduceWindowRewriter::TryOptimizeAssociativeScan( input->shape(), input, init, window, rw_to_apply)); } else { Shape outputs_shape = scan->shape().tuple_shapes(0); - TF_ASSIGN_OR_RETURN(result, RewriteScanAsTreeReduction( - parent, {input}, {init}, rw_to_apply, - outputs_shape, rank, scan_dim, scan_length, - /*forward_scan=*/true, - /*is_exclusive=*/false)); + ASSIGN_OR_RETURN(result, RewriteScanAsTreeReduction( + parent, {input}, {init}, rw_to_apply, + outputs_shape, rank, scan_dim, scan_length, + /*forward_scan=*/true, + /*is_exclusive=*/false)); } // Replace carry with init value, users are guaranteed to be dead. HloInstruction* tuple = parent->AddInstruction( HloInstruction::CreateTuple({result, scan->inits()[0]})); - TF_RETURN_IF_ERROR(parent->ReplaceInstruction(scan, tuple)); + RETURN_IF_ERROR(parent->ReplaceInstruction(scan, tuple)); return true; } @@ -723,7 +723,7 @@ absl::StatusOr ReduceWindowRewriter::RunImpl( computation->MakeInstructionPostOrder()) { if (auto* scan = DynCast(instruction)) { if (decompose_assoc_scan) { - TF_ASSIGN_OR_RETURN(bool result, TryOptimizeAssociativeScan(scan)); + ASSIGN_OR_RETURN(bool result, TryOptimizeAssociativeScan(scan)); changed |= result; } continue; @@ -732,15 +732,14 @@ absl::StatusOr ReduceWindowRewriter::RunImpl( if (auto* reduce_window = DynCast(instruction)) { auto result = TryOptimizeCumSumOrProd(reduce_window); - TF_RETURN_IF_ERROR(result.status()); + RETURN_IF_ERROR(result.status()); if (*result) { changed = true; continue; } if (reduce_window->inputs().front()->shape().dimensions().size() == 1) { - TF_RETURN_IF_ERROR( - reduce_window_util::Replace1DReduceWindowWithReshape( - reduce_window)); + RETURN_IF_ERROR(reduce_window_util::Replace1DReduceWindowWithReshape( + reduce_window)); changed = true; } continue; diff --git a/third_party/xla/xla/hlo/transforms/simplifiers/reduce_window_util.cc b/third_party/xla/xla/hlo/transforms/simplifiers/reduce_window_util.cc index 9303e2007f9249..9665d964d831b1 100644 --- a/third_party/xla/xla/hlo/transforms/simplifiers/reduce_window_util.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/reduce_window_util.cc @@ -20,6 +20,7 @@ limitations under the License. #include "absl/log/check.h" #include "absl/log/log.h" #include "absl/status/status.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/shape.h" @@ -119,8 +120,8 @@ absl::Status Replace1DReduceWindowWithReshape( CHECK_EQ(final_reshapes.size(), 1); result = final_reshapes[0]; } - TF_RETURN_IF_ERROR(reduce_window->ReplaceAllUsesWith(result)); - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR(reduce_window->ReplaceAllUsesWith(result)); + RETURN_IF_ERROR( new_reduce_window->parent()->RemoveInstruction(reduce_window)); return absl::OkStatus(); diff --git a/third_party/xla/xla/hlo/transforms/simplifiers/reshape_mover.cc b/third_party/xla/xla/hlo/transforms/simplifiers/reshape_mover.cc index f74bf40c913990..a771b77614a075 100644 --- a/third_party/xla/xla/hlo/transforms/simplifiers/reshape_mover.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/reshape_mover.cc @@ -30,6 +30,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/permutation_util.h" #include "xla/service/hlo_creation_utils.h" @@ -294,8 +295,8 @@ absl::StatusOr ReshapeMover::SinkRearrangeOperands( for (size_t i = 0; i < operands.size(); ++i) { VLOG(3) << "Updating operand #" << i << ": " << operands[i]->ToString(print_no_metadata); - TF_ASSIGN_OR_RETURN(operands[i], - ApplyInverseRearrange(rearrange, operands[i])); + ASSIGN_OR_RETURN(operands[i], + ApplyInverseRearrange(rearrange, operands[i])); VLOG(3) << "Updated operand #" << i << " to: " << operands[i]->ToString(print_no_metadata); } @@ -330,7 +331,7 @@ absl::StatusOr ReshapeMover::SinkRearrangeOperands( new_elementwise->clear_sharding(); } - TF_RETURN_IF_ERROR(computation->ReplaceWithNewInstruction( + RETURN_IF_ERROR(computation->ReplaceWithNewInstruction( instruction, std::move(new_rearrange))); return true; } @@ -389,7 +390,7 @@ absl::StatusOr ReshapeMover::TryReshapeMoveOnCandidates( })) { break; } - TF_ASSIGN_OR_RETURN(bool did_change, SinkRearrangeOperands(instruction)); + ASSIGN_OR_RETURN(bool did_change, SinkRearrangeOperands(instruction)); CHECK(did_change); } return true; @@ -406,8 +407,7 @@ absl::StatusOr ReshapeMover::RunImpl( candidates.insert(instruction); } } - TF_ASSIGN_OR_RETURN(bool did_change, - TryReshapeMoveOnCandidates(&candidates)); + ASSIGN_OR_RETURN(bool did_change, TryReshapeMoveOnCandidates(&candidates)); changed |= did_change; } return changed; diff --git a/third_party/xla/xla/hlo/transforms/simplifiers/reshape_mover_test.cc b/third_party/xla/xla/hlo/transforms/simplifiers/reshape_mover_test.cc index ed624d197eef97..655f7039325ebe 100644 --- a/third_party/xla/xla/hlo/transforms/simplifiers/reshape_mover_test.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/reshape_mover_test.cc @@ -21,6 +21,7 @@ limitations under the License. #include #include #include "absl/status/status.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" @@ -44,8 +45,7 @@ class ReshapeMoverTest : public HloHardwareIndependentTestBase { absl::Status RunPass(HloModule* module, bool change_expected, ReshapeMoverOptions options = ReshapeMoverOptions{}, bool run_algsimp = false) { - TF_ASSIGN_OR_RETURN(bool changed, - RunHloPass(ReshapeMover(options), module)); + ASSIGN_OR_RETURN(bool changed, RunHloPass(ReshapeMover(options), module)); SCOPED_TRACE(module->ToString()); EXPECT_EQ(changed, change_expected); TF_EXPECT_OK(RunHloPass(HloVerifier(HloVerifierOpts()), module).status()); diff --git a/third_party/xla/xla/hlo/transforms/simplifiers/simplify_fp_conversions.cc b/third_party/xla/xla/hlo/transforms/simplifiers/simplify_fp_conversions.cc index c9eea633cb00ea..ded2c27833b253 100644 --- a/third_party/xla/xla/hlo/transforms/simplifiers/simplify_fp_conversions.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/simplify_fp_conversions.cc @@ -21,6 +21,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/str_format.h" #include "absl/strings/string_view.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" @@ -50,10 +51,10 @@ absl::StatusOr RunOnComputation(HloComputation& computation) { } if (instruction->shape().element_type() == input->shape().element_type()) { - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( instruction->parent()->ReplaceInstruction(instruction, input)); } else { - TF_RETURN_IF_ERROR(instruction->parent()->ReplaceWithNewInstruction( + RETURN_IF_ERROR(instruction->parent()->ReplaceWithNewInstruction( instruction, HloInstruction::CreateConvert(instruction->shape(), input))); } @@ -73,7 +74,7 @@ absl::StatusOr SimplifyFPConversions::RunImpl( bool changed = false; for (HloComputation* computation : module->MakeComputationPostOrder(execution_threads)) { - TF_ASSIGN_OR_RETURN(bool comp_changed, RunOnComputation(*computation)); + ASSIGN_OR_RETURN(bool comp_changed, RunOnComputation(*computation)); changed |= comp_changed; } XLA_VLOG_LINES( diff --git a/third_party/xla/xla/hlo/transforms/simplifiers/slice_hoister.cc b/third_party/xla/xla/hlo/transforms/simplifiers/slice_hoister.cc index 23e26621917ce8..a21fa3a786e268 100644 --- a/third_party/xla/xla/hlo/transforms/simplifiers/slice_hoister.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/slice_hoister.cc @@ -22,6 +22,7 @@ limitations under the License. #include "absl/log/log.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -70,7 +71,7 @@ absl::StatusOr TryHoistSliceThroughTranspose( HloInstruction* new_transpose = computation->AddInstruction(HloInstruction::CreateTranspose( slice_instruction->shape(), new_slice, dimensions_permutation)); - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( computation->ReplaceInstruction(slice_instruction, new_transpose)); return true; } @@ -111,7 +112,7 @@ absl::StatusOr TryHoistSliceThroughElementwiseBinaryOperation( slice_instruction->shape(), rhs, slice_instruction->slice_starts(), slice_instruction->slice_limits(), slice_instruction->slice_strides())); - TF_RETURN_IF_ERROR(computation->ReplaceWithNewInstruction( + RETURN_IF_ERROR(computation->ReplaceWithNewInstruction( slice_instruction, HloInstruction::CreateBinary( slice_instruction->shape(), slice_operand->opcode(), lhs_slice, rhs_slice))); @@ -130,8 +131,8 @@ absl::StatusOr TryHoistingSlice(HloInstruction* instruction, auto hoisting_functions = {TryHoistSliceThroughElementwiseBinaryOperation, TryHoistSliceThroughTranspose}; for (auto hoisting_function : hoisting_functions) { - TF_ASSIGN_OR_RETURN(bool changed, - hoisting_function(slice_instruction, computation)); + ASSIGN_OR_RETURN(bool changed, + hoisting_function(slice_instruction, computation)); if (changed) { return true; } @@ -161,8 +162,8 @@ absl::StatusOr HoistSliceOperations(HloComputation* computation) { std::vector instructions = computation->MakeInstructionPostOrder(); for (HloInstruction* instruction : instructions) { - TF_ASSIGN_OR_RETURN(bool instruction_changed, - TryHoistingSlice(instruction, computation)); + ASSIGN_OR_RETURN(bool instruction_changed, + TryHoistingSlice(instruction, computation)); if (instruction_changed) { changed_on_last_iteration = true; break; @@ -180,8 +181,8 @@ absl::StatusOr SliceHoister::RunImpl( bool changed = false; for (HloComputation* computation : module->MakeNonfusionComputations(execution_threads)) { - TF_ASSIGN_OR_RETURN(bool changed_computation, - HoistSliceOperations(computation)); + ASSIGN_OR_RETURN(bool changed_computation, + HoistSliceOperations(computation)); changed |= changed_computation; } return changed; diff --git a/third_party/xla/xla/hlo/transforms/simplifiers/slice_sinker.cc b/third_party/xla/xla/hlo/transforms/simplifiers/slice_sinker.cc index 5780e5d463017e..9058c01abaf743 100644 --- a/third_party/xla/xla/hlo/transforms/simplifiers/slice_sinker.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/slice_sinker.cc @@ -28,6 +28,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" @@ -199,7 +200,7 @@ absl::Status SinkSlices( user->shape(), {operation_on_slice_sources})); VLOG(10) << "Adding new slice: " << user_slice->ToString() << " to replace: " << user->ToString(); - TF_RETURN_IF_ERROR(user->ReplaceAllUsesWith(user_slice)); + RETURN_IF_ERROR(user->ReplaceAllUsesWith(user_slice)); } return absl::OkStatus(); } @@ -288,7 +289,7 @@ absl::StatusOr SliceSinker::RunImpl( instruction->operands(), std::back_inserter(slice_sources), [](HloInstruction* slice) { return slice->mutable_operand(0); }); - TF_RETURN_IF_ERROR(SinkSlices(slice_sources, similar_operations.value())); + RETURN_IF_ERROR(SinkSlices(slice_sources, similar_operations.value())); changed = true; } } diff --git a/third_party/xla/xla/hlo/transforms/simplifiers/sort_simplifier.cc b/third_party/xla/xla/hlo/transforms/simplifiers/sort_simplifier.cc index 4b5bdbde4aabe8..e38ffe4d1e385a 100644 --- a/third_party/xla/xla/hlo/transforms/simplifiers/sort_simplifier.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/sort_simplifier.cc @@ -28,6 +28,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" @@ -140,9 +141,9 @@ absl::StatusOr RemoveUnusedOperandFromSort(HloInstruction* sort) { std::vector users(sort->users().begin(), sort->users().end()); for (HloInstruction* user : users) { - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( user->ReplaceAllUsesWith(result_map.at(user->tuple_index()))); - TF_RETURN_IF_ERROR(computation->RemoveInstructionAndUnusedOperands(user)); + RETURN_IF_ERROR(computation->RemoveInstructionAndUnusedOperands(user)); } return true; } @@ -162,7 +163,7 @@ absl::StatusOr SortSimplifier::RunImpl( } for (HloInstruction* sort_instr : sort_instrs) { - TF_ASSIGN_OR_RETURN(bool result, RemoveUnusedOperandFromSort(sort_instr)); + ASSIGN_OR_RETURN(bool result, RemoveUnusedOperandFromSort(sort_instr)); changed |= result; } diff --git a/third_party/xla/xla/hlo/transforms/simplifiers/sub_byte_collective_normalization.cc b/third_party/xla/xla/hlo/transforms/simplifiers/sub_byte_collective_normalization.cc index be89794d7c6f71..b32ff85ef871ca 100644 --- a/third_party/xla/xla/hlo/transforms/simplifiers/sub_byte_collective_normalization.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/sub_byte_collective_normalization.cc @@ -22,6 +22,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/ir/collective_op_group_mode.h" #include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" #include "xla/hlo/ir/hlo_casting_utils.h" @@ -137,8 +138,8 @@ absl::Status SubByteCollectiveNormalizationVisitor::HandleAllToAll( primitive_util::BitWidth(hlo->shape().element_type()); const auto* all_to_all = Cast(hlo); if (all_to_all->split_dimension()) { - TF_ASSIGN_OR_RETURN(const CollectiveOpGroupMode group_mode, - GetCollectiveOpGroupMode(all_to_all)); + ASSIGN_OR_RETURN(const CollectiveOpGroupMode group_mode, + GetCollectiveOpGroupMode(all_to_all)); const int64_t split_dimension_size = hlo->shape().dimensions(*all_to_all->split_dimension()); if (split_dimension_size % @@ -183,7 +184,7 @@ SubByteCollectiveNormalizationVisitor::ProcessCollectiveInstruction( hlo.parent()->AddInstruction(hlo.CloneWithNewOperands( new_collective_shape, {ReshapeAndCastToWiderType(hlo.mutable_operand(0), casted_type_)})); - TF_RETURN_IF_ERROR(hlo.parent()->ReplaceInstructionWithDifferentShape( + RETURN_IF_ERROR(hlo.parent()->ReplaceInstructionWithDifferentShape( &hlo, CastToNarrowerTypeAndReshape(new_collective, hlo.shape()))); MarkAsChanged(); @@ -198,7 +199,7 @@ absl::StatusOr SubByteCollectiveNormalization::RunImpl( SubByteCollectiveNormalizationVisitor visitor; for (HloComputation* computation : module->MakeComputationPostOrder(execution_threads)) { - TF_RETURN_IF_ERROR(computation->Accept(&visitor)); + RETURN_IF_ERROR(computation->Accept(&visitor)); } return visitor.changed(); diff --git a/third_party/xla/xla/hlo/transforms/simplifiers/sub_byte_normalization.cc b/third_party/xla/xla/hlo/transforms/simplifiers/sub_byte_normalization.cc index f00a844e69a988..b17f7335db7009 100644 --- a/third_party/xla/xla/hlo/transforms/simplifiers/sub_byte_normalization.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/sub_byte_normalization.cc @@ -22,6 +22,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/layout.h" @@ -104,7 +105,7 @@ absl::StatusOr SubByteNormalization::RunImpl( // element_size_in_bits within fusions being meaningless, because HloVerfier // checks for the correct use of element_size_in_bits even in fusion // computations. - TF_RETURN_IF_ERROR(computation->Accept(&visitor)); + RETURN_IF_ERROR(computation->Accept(&visitor)); } auto* computation_layout = module->mutable_entry_computation_layout(); for (int param_no = 0; param_no < computation_layout->parameter_count(); diff --git a/third_party/xla/xla/hlo/transforms/simplifiers/tree_reduction_rewriter.cc b/third_party/xla/xla/hlo/transforms/simplifiers/tree_reduction_rewriter.cc index a9056d12675350..f72ebdd40f0d53 100644 --- a/third_party/xla/xla/hlo/transforms/simplifiers/tree_reduction_rewriter.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/tree_reduction_rewriter.cc @@ -28,6 +28,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/builder/padding.h" #include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -109,13 +110,13 @@ class ReductionRewriterVisitor : public DfsHloRewriteVisitor { MakePadding(input_shape.dimensions(), window_dimensions, window_strides, Padding::kSame); - TF_ASSIGN_OR_RETURN( - Window window, ShapeInference::InferWindowFromDimensions( - window_dimensions, window_strides, padding, {}, {})); + ASSIGN_OR_RETURN(Window window, + ShapeInference::InferWindowFromDimensions( + window_dimensions, window_strides, padding, {}, {})); - TF_ASSIGN_OR_RETURN(Shape intermediate_shape, - ShapeInference::InferReduceWindowShape( - input_shape, initial_value->shape(), window)); + ASSIGN_OR_RETURN(Shape intermediate_shape, + ShapeInference::InferReduceWindowShape( + input_shape, initial_value->shape(), window)); HloInstruction *reduce_window = hlo->parent()->AddInstruction(HloInstruction::CreateReduceWindow( @@ -143,7 +144,7 @@ absl::StatusOr TreeReductionRewriter::RunImpl( bool changed = false; for (const auto &computation : module->MakeNonfusionComputations(execution_threads)) { - TF_RETURN_IF_ERROR(computation->Accept(&visitor)); + RETURN_IF_ERROR(computation->Accept(&visitor)); changed |= visitor.changed(); } diff --git a/third_party/xla/xla/hlo/transforms/simplifiers/tuple_simplifier.cc b/third_party/xla/xla/hlo/transforms/simplifiers/tuple_simplifier.cc index a86f4c8aff75c9..0b81328beb914c 100644 --- a/third_party/xla/xla/hlo/transforms/simplifiers/tuple_simplifier.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/tuple_simplifier.cc @@ -21,6 +21,7 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" @@ -55,9 +56,9 @@ absl::StatusOr TupleSimplifier::RemoveWholeTuple( if (top_tuple == nullptr) { return nullptr; } - TF_ASSIGN_OR_RETURN(bool changed, - tuple->parent()->ReplaceInstruction( - tuple, top_tuple, /*preserve_sharding=*/true)); + ASSIGN_OR_RETURN(bool changed, + tuple->parent()->ReplaceInstruction( + tuple, top_tuple, /*preserve_sharding=*/true)); if (changed) { return top_tuple; } @@ -76,8 +77,7 @@ absl::StatusOr TupleSimplifier::RunImpl( } for (auto* instruction : computation->MakeInstructionPostOrder()) { if (instruction->opcode() == HloOpcode::kTuple) { - TF_ASSIGN_OR_RETURN(HloInstruction * instr, - RemoveWholeTuple(instruction)); + ASSIGN_OR_RETURN(HloInstruction * instr, RemoveWholeTuple(instruction)); if (instr != nullptr) { changed = true; } @@ -115,11 +115,11 @@ absl::StatusOr TupleSimplifier::RunImpl( } if (replacement) { - TF_ASSIGN_OR_RETURN(bool replaced, - computation->ReplaceInstruction( - instruction, replacement, - /*preserve_sharding=*/true, - /*relay_control_dependency=*/true)); + ASSIGN_OR_RETURN(bool replaced, + computation->ReplaceInstruction( + instruction, replacement, + /*preserve_sharding=*/true, + /*relay_control_dependency=*/true)); changed |= replaced; } } @@ -127,7 +127,7 @@ absl::StatusOr TupleSimplifier::RunImpl( } if (changed && module->has_schedule()) { - TF_RETURN_IF_ERROR(module->schedule().Update()); + RETURN_IF_ERROR(module->schedule().Update()); } return changed; diff --git a/third_party/xla/xla/hlo/transforms/simplifiers/unflatten_call_graph.cc b/third_party/xla/xla/hlo/transforms/simplifiers/unflatten_call_graph.cc index 39530a71a2d790..2fda4054814d6e 100644 --- a/third_party/xla/xla/hlo/transforms/simplifiers/unflatten_call_graph.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/unflatten_call_graph.cc @@ -29,6 +29,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" +#include "xla/tsl/platform/status_macros.h" #include "highwayhash/arch_specific.h" #include "highwayhash/hh_types.h" #include "highwayhash/highwayhash.h" @@ -149,7 +150,7 @@ absl::Status UnflattenCallGraph::ValidateComputationHashes( }; // Validate all computations against their canonical versions in parallel. - TF_RETURN_IF_ERROR((xla::concurrency::ForEach( + RETURN_IF_ERROR((xla::concurrency::ForEach( hash_results.begin(), hash_results.end(), validate_against_canonical, concurrency::DefaultExecutor()))); @@ -169,8 +170,8 @@ absl::StatusOr UnflattenCallGraph::RunImpl( if (calls.targets.empty()) { return false; } - TF_ASSIGN_OR_RETURN(const std::vector hash_results, - HashComputations(calls.targets)); + ASSIGN_OR_RETURN(const std::vector hash_results, + HashComputations(calls.targets)); // Map computations to their hashes. // The HloComputation* keys are owned by the HloModule and are guaranteed to @@ -213,7 +214,7 @@ absl::StatusOr UnflattenCallGraph::RunImpl( } if (changed) { - TF_RETURN_IF_ERROR(module->RemoveUnusedComputations()); + RETURN_IF_ERROR(module->RemoveUnusedComputations()); module->CleanupComputations(); } diff --git a/third_party/xla/xla/hlo/transforms/simplifiers/zero_sized_hlo_elimination.cc b/third_party/xla/xla/hlo/transforms/simplifiers/zero_sized_hlo_elimination.cc index b87e57699aa3a4..5310bb61f0654d 100644 --- a/third_party/xla/xla/hlo/transforms/simplifiers/zero_sized_hlo_elimination.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/zero_sized_hlo_elimination.cc @@ -18,6 +18,7 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" @@ -71,7 +72,7 @@ absl::StatusOr ZeroSizedHloElimination::RunImpl( } if (comp->IsSafelyRemovable(instruction)) { - TF_RETURN_IF_ERROR(comp->ReplaceWithNewInstruction( + RETURN_IF_ERROR(comp->ReplaceWithNewInstruction( instruction, HloInstruction::CreateConstant(Literal::CreateFromShape(shape)))); changed = true; @@ -81,7 +82,7 @@ absl::StatusOr ZeroSizedHloElimination::RunImpl( HloInstruction* constant = comp->AddInstruction(HloInstruction::CreateConstant( Literal::CreateFromShape(instruction->shape()))); - TF_RETURN_IF_ERROR(instruction->ReplaceAllUsesWith(constant)); + RETURN_IF_ERROR(instruction->ReplaceAllUsesWith(constant)); changed = true; } } diff --git a/third_party/xla/xla/hlo/translate/BUILD b/third_party/xla/xla/hlo/translate/BUILD index 487619dd878366..d209786ca82f36 100644 --- a/third_party/xla/xla/hlo/translate/BUILD +++ b/third_party/xla/xla/hlo/translate/BUILD @@ -23,6 +23,7 @@ cc_library( ":stablehlo", "//xla/hlo/ir:hlo", "//xla/hlo/translate/hlo_to_mhlo:hlo_to_mlir_hlo", + "//xla/tsl/platform:status_macros", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@llvm-project//llvm:Support", @@ -118,6 +119,7 @@ cc_library( "//xla/service:hlo_proto_cc", "//xla/service/llvm_ir:llvm_util", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "@com_google_absl//absl/log", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", diff --git a/third_party/xla/xla/hlo/translate/mhlo_to_hlo/gen_hlo_op_writer.td b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/gen_hlo_op_writer.td index 3c2bf0361cffdc..75e742c217a8fd 100644 --- a/third_party/xla/xla/hlo/translate/mhlo_to_hlo/gen_hlo_op_writer.td +++ b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/gen_hlo_op_writer.td @@ -38,6 +38,7 @@ defvar HloConversionAllowedOps = [ MHLO_ErfOp, MHLO_FusionOp, MHLO_MinimumBroadcastShapesOp, + MHLO_MulhiOp, MHLO_RaggedDotOp, MHLO_ScanOp, MHLO_SinhOp, @@ -128,6 +129,7 @@ defvar CustomHloConverterOps = [ MHLO_DomainOp, MHLO_FusionOp, MHLO_MinimumBroadcastShapesOp, + MHLO_MulhiOp, MHLO_RaggedDotOp, MHLO_ScanOp, MHLO_SinhOp, diff --git a/third_party/xla/xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc index cb35ef169ec66d..afdcdb60af245b 100644 --- a/third_party/xla/xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc +++ b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc @@ -4363,6 +4363,9 @@ LogicalResult ExportXlaOp(AsinhOp op, OpLoweringContext ctx) { xla::Asinh(operand, /*result_accuracy=*/std::nullopt, /*expand=*/false); return success(); } +LogicalResult ExportXlaOp(MulhiOp op, OpLoweringContext ctx) { + return failure(); +} LogicalResult ExportXlaOp(AcosOp op, OpLoweringContext ctx) { return ExportElementwiseXlaOp(op, ctx); @@ -4955,6 +4958,11 @@ LogicalResult ConvertToHloModule::Lower( *return_value = xla::XlaOp(); + if (auto mulhi_op = dyn_cast(inst)) { + // Add lowering support once HLO op exists. + return failure(); + } + if (succeeded(ExportXlaOperator(inst, {value_lowering, this, builder, &stack_frame_indexes_builder_}))) { if (inst->getNumResults() == 1) { diff --git a/third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/BUILD b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/BUILD index ee0bd17b68b44c..2f07332109eed0 100644 --- a/third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/BUILD +++ b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/BUILD @@ -45,6 +45,7 @@ lit_test_suite( "missing_main.mlir", "module_config.mlir", "module_attributes.mlir", + "mulhi.mlir", "multiple_return_tuple.mlir", "opaque_elements_attr.mlir", "ragged_dot.mlir", diff --git a/third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/mulhi.mlir b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/mulhi.mlir new file mode 100644 index 00000000000000..32a5df98fb5de0 --- /dev/null +++ b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/mulhi.mlir @@ -0,0 +1,8 @@ +// RUN-DISABLED: xla-translate -mlir-hlo-to-hlo-text %s | FileCheck %s +// RUN: echo 'Test filtered, unfilter once mulhi lowering lands.' + +func.func @main(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { + // CHECK: s32[4] mulhi + %0 = "mhlo.mulhi"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + func.return %0 : tensor<4xi32> +} diff --git a/third_party/xla/xla/hlo/translate/portable_api.cc b/third_party/xla/xla/hlo/translate/portable_api.cc index f2f9525abb7d26..4eb457d1d26e23 100644 --- a/third_party/xla/xla/hlo/translate/portable_api.cc +++ b/third_party/xla/xla/hlo/translate/portable_api.cc @@ -19,6 +19,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "xla/tsl/platform/status_macros.h" #include "llvm/Support/raw_ostream.h" #include "mlir/Bytecode/BytecodeWriter.h" #include "mlir/IR/BuiltinOps.h" @@ -61,7 +62,7 @@ absl::StatusOr ConvertHloToStablehlo( xla::HloModule const& hlo_module, bool emit_bytecode) { mlir::MLIRContext context; LoadHloDialects(context); - TF_ASSIGN_OR_RETURN(auto module, ConvertHloToStablehlo(context, &hlo_module)); + ASSIGN_OR_RETURN(auto module, ConvertHloToStablehlo(context, &hlo_module)); if (emit_bytecode) return SerializeUsingBytecode(*module); return PrintModule(*module); } diff --git a/third_party/xla/xla/hlo/translate/stablehlo.cc b/third_party/xla/xla/hlo/translate/stablehlo.cc index 3479fe34329acb..4ef2017e1addf4 100644 --- a/third_party/xla/xla/hlo/translate/stablehlo.cc +++ b/third_party/xla/xla/hlo/translate/stablehlo.cc @@ -21,6 +21,7 @@ limitations under the License. #include "absl/log/log.h" #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "xla/tsl/platform/status_macros.h" #include "llvm/Support/LogicalResult.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/Extensions/AllExtensions.h" @@ -106,20 +107,20 @@ absl::Status ConvertStablehloToHloProtoInternal(mlir::ModuleOp module, bool run_canonicalizer) { if (!module) return absl::InvalidArgumentError("Module is null"); - TF_RETURN_IF_ERROR(StablehloToMhlo(module, run_canonicalizer)); + RETURN_IF_ERROR(StablehloToMhlo(module, run_canonicalizer)); mlir::MlirToHloConversionOptions options; options.return_tuple = return_tuple; options.use_tuple_args = use_tuple_args; options.direct_stablehlo_to_hlo = true; - TF_RETURN_IF_ERROR(mlir::ConvertMlirHloToHlo(module, hlo_proto, options)); + RETURN_IF_ERROR(mlir::ConvertMlirHloToHlo(module, hlo_proto, options)); return absl::OkStatus(); } absl::StatusOr> ConvertStablehloToHloInternal( mlir::ModuleOp module, bool use_tuple_args, bool return_tuple) { xla::HloProto hlo_proto; - TF_RETURN_IF_ERROR(ConvertStablehloToHloProtoInternal( + RETURN_IF_ERROR(ConvertStablehloToHloProtoInternal( module, &hlo_proto, use_tuple_args, return_tuple, /*run_canonicalizer=*/true)); @@ -142,11 +143,11 @@ absl::StatusOr> ConvertHloToStablehlo( mlir::MLIRContext& ctx, const xla::HloModule* hlo_module) { mlir::OwningOpRef mlir_module = llvm_ir::CreateMlirModuleOp(mlir::UnknownLoc::get(&ctx)); - TF_RETURN_IF_ERROR(HloModuleImporter(mlir_module.get(), - /*import_all_computation=*/true, - /*flatten_computation_args_result=*/true, - /*emit_stablehlo=*/true) - .Import(*hlo_module)); + RETURN_IF_ERROR(HloModuleImporter(mlir_module.get(), + /*import_all_computation=*/true, + /*flatten_computation_args_result=*/true, + /*emit_stablehlo=*/true) + .Import(*hlo_module)); return mlir_module; } @@ -154,11 +155,11 @@ absl::StatusOr> ConvertHloToStablehlo( mlir::MLIRContext& ctx, const xla::HloModuleProto* hlo_module_proto) { mlir::OwningOpRef mlir_module = llvm_ir::CreateMlirModuleOp(mlir::UnknownLoc::get(&ctx)); - TF_RETURN_IF_ERROR(HloModuleImporter(mlir_module.get(), - /*import_all_computation=*/true, - /*flatten_computation_args_result=*/true, - /*emit_stablehlo=*/true) - .Import(*hlo_module_proto)); + RETURN_IF_ERROR(HloModuleImporter(mlir_module.get(), + /*import_all_computation=*/true, + /*flatten_computation_args_result=*/true, + /*emit_stablehlo=*/true) + .Import(*hlo_module_proto)); return mlir_module; } @@ -168,11 +169,10 @@ ConvertHloToStablehloWithOptions(mlir::MLIRContext& ctx, bool import_all_computations) { mlir::OwningOpRef mlir_module = llvm_ir::CreateMlirModuleOp(mlir::UnknownLoc::get(&ctx)); - TF_RETURN_IF_ERROR(HloModuleImporter(mlir_module.get(), - import_all_computations, - /*flatten_computation_args_result=*/true, - /*emit_stablehlo=*/true) - .Import(*hlo_module_proto)); + RETURN_IF_ERROR(HloModuleImporter(mlir_module.get(), import_all_computations, + /*flatten_computation_args_result=*/true, + /*emit_stablehlo=*/true) + .Import(*hlo_module_proto)); return mlir_module; } diff --git a/third_party/xla/xla/hlo/utils/BUILD b/third_party/xla/xla/hlo/utils/BUILD index efdccd0f0c4212..27b7de241c5075 100644 --- a/third_party/xla/xla/hlo/utils/BUILD +++ b/third_party/xla/xla/hlo/utils/BUILD @@ -40,6 +40,7 @@ cc_library( "//xla:util", "//xla/hlo/ir:hlo", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "@com_google_absl//absl/log", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", @@ -158,6 +159,7 @@ cc_library( "//xla/service:dot_as_convolution_util", "//xla/service:gather_scatter_utils", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:btree", @@ -303,6 +305,7 @@ cc_library( "//xla/hlo/ir:mesh_and_axis", "//xla/hlo/ir:named_sharding", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status", diff --git a/third_party/xla/xla/hlo/utils/hlo_longest_prefix.h b/third_party/xla/xla/hlo/utils/hlo_longest_prefix.h index 5ab989bfe22f25..763775971d720a 100644 --- a/third_party/xla/xla/hlo/utils/hlo_longest_prefix.h +++ b/third_party/xla/xla/hlo/utils/hlo_longest_prefix.h @@ -18,6 +18,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/strings/string_view.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/tsl/platform/errors.h" @@ -29,12 +30,12 @@ template absl::Status VisitInstAndCalledButNotOperands(Visitor& visitor, const HloInstruction& inst) { // Visit the given instruction, and the things it calls, but not its operands. - TF_RETURN_IF_ERROR(visitor.DefaultAction(&inst)); + RETURN_IF_ERROR(visitor.DefaultAction(&inst)); for (const HloComputation* called : inst.called_computations()) { const HloInstruction* const root = called->root_instruction(); - TF_RETURN_IF_ERROR(root->Accept(&visitor, /*call_finish_visit=*/false, - /*ignore_control_predecessors=*/true, - /*cross_computation=*/true)); + RETURN_IF_ERROR(root->Accept(&visitor, /*call_finish_visit=*/false, + /*ignore_control_predecessors=*/true, + /*cross_computation=*/true)); } return absl::OkStatus(); } diff --git a/third_party/xla/xla/hlo/utils/hlo_sharding_reconstruction_util.cc b/third_party/xla/xla/hlo/utils/hlo_sharding_reconstruction_util.cc index 0f94b87790b4de..2db75325fb212b 100644 --- a/third_party/xla/xla/hlo/utils/hlo_sharding_reconstruction_util.cc +++ b/third_party/xla/xla/hlo/utils/hlo_sharding_reconstruction_util.cc @@ -25,6 +25,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/ir/hlo_sharding.h" #include "xla/hlo/ir/mesh_and_axis.h" #include "xla/hlo/ir/named_sharding.h" @@ -210,11 +211,10 @@ absl::StatusOr UnshardLiteral( } // We can do a dynamic slice and dynamic update slice or just CopySliceFrom - TF_RETURN_IF_ERROR( - unsharded_literal.CopySliceFrom(*shard.data, - /*src_base=*/zero_start, - /*dest_base=*/start_indices, - /*copy_size=*/copy_dims)); + RETURN_IF_ERROR(unsharded_literal.CopySliceFrom(*shard.data, + /*src_base=*/zero_start, + /*dest_base=*/start_indices, + /*copy_size=*/copy_dims)); } return unsharded_literal; } diff --git a/third_party/xla/xla/hlo/utils/hlo_sharding_util.cc b/third_party/xla/xla/hlo/utils/hlo_sharding_util.cc index 73a9a554e7f991..b9dd211bb47230 100644 --- a/third_party/xla/xla/hlo/utils/hlo_sharding_util.cc +++ b/third_party/xla/xla/hlo/utils/hlo_sharding_util.cc @@ -40,6 +40,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/array.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" @@ -3285,15 +3286,15 @@ absl::Status CanonicalizeLayoutAfterShardingPropagation( VLOG(4) << "There is no registered layout_canonicalization_callback."; return absl::OkStatus(); } - TF_ASSIGN_OR_RETURN(auto shapes_with_layout, - module->layout_canonicalization_callback()(*module)); + ASSIGN_OR_RETURN(auto shapes_with_layout, + module->layout_canonicalization_callback()(*module)); if (module->entry_computation_layout().result_layout().LayoutIsSet() && absl::c_any_of(update_output_layout, [](bool v) { return v; })) { if (absl::c_all_of(update_output_layout, [](bool v) { return v; })) { - TF_RETURN_IF_ERROR(module->mutable_entry_computation_layout() - ->mutable_result_layout() - ->CopyLayoutFromShape(shapes_with_layout.second)); + RETURN_IF_ERROR(module->mutable_entry_computation_layout() + ->mutable_result_layout() + ->CopyLayoutFromShape(shapes_with_layout.second)); } else { Shape result_shape = module->mutable_entry_computation_layout() ->mutable_result_layout() @@ -3306,9 +3307,9 @@ absl::Status CanonicalizeLayoutAfterShardingPropagation( shapes_with_layout.second.tuple_shapes(i); } } - TF_RETURN_IF_ERROR(module->mutable_entry_computation_layout() - ->mutable_result_layout() - ->CopyLayoutFromShape(result_shape)); + RETURN_IF_ERROR(module->mutable_entry_computation_layout() + ->mutable_result_layout() + ->CopyLayoutFromShape(result_shape)); } } @@ -3321,10 +3322,9 @@ absl::Status CanonicalizeLayoutAfterShardingPropagation( bool parameter_layout_is_set = module->entry_computation_layout().parameter_layout(i).LayoutIsSet(); if (update_parameter_layout && parameter_layout_is_set) { - TF_RETURN_IF_ERROR( - module->mutable_entry_computation_layout() - ->mutable_parameter_layout(i) - ->CopyLayoutFromShape(shapes_with_layout.first[i])); + RETURN_IF_ERROR(module->mutable_entry_computation_layout() + ->mutable_parameter_layout(i) + ->CopyLayoutFromShape(shapes_with_layout.first[i])); } } } diff --git a/third_party/xla/xla/mlir/tools/mlir_replay/BUILD b/third_party/xla/xla/mlir/tools/mlir_replay/BUILD index 8eb72e4ef47444..e9e1741e95c616 100644 --- a/third_party/xla/xla/mlir/tools/mlir_replay/BUILD +++ b/third_party/xla/xla/mlir/tools/mlir_replay/BUILD @@ -50,6 +50,7 @@ cc_library( "//xla/mlir/tools/mlir_replay/public:execution_trace_proto_cc", "//xla/mlir/tools/mlir_replay/public:execution_trace_utils", "//xla/service:hlo_proto_cc", + "//xla/tsl/platform:status_macros", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/random", "@com_google_absl//absl/random:bit_gen_ref", diff --git a/third_party/xla/xla/mlir/tools/mlir_replay/mlir_replay_lib.cc b/third_party/xla/xla/mlir/tools/mlir_replay/mlir_replay_lib.cc index 74e72759f621e0..85c73118b79e7c 100644 --- a/third_party/xla/xla/mlir/tools/mlir_replay/mlir_replay_lib.cc +++ b/third_party/xla/xla/mlir/tools/mlir_replay/mlir_replay_lib.cc @@ -30,6 +30,7 @@ limitations under the License. #include "absl/random/random.h" #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "xla/tsl/platform/status_macros.h" #include "llvm/ADT/APInt.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/Sequence.h" @@ -67,7 +68,7 @@ absl::StatusOr> LoadArgs( const xla::HloSnapshot& snapshot, TypeRange types) { SmallVector result; for (const auto& [arg, type] : llvm::zip(snapshot.arguments(), types)) { - TF_ASSIGN_OR_RETURN(auto converted, LiteralToValue(arg, type)); + ASSIGN_OR_RETURN(auto converted, LiteralToValue(arg, type)); result.push_back(std::move(converted)); } return result; @@ -213,8 +214,8 @@ absl::StatusOr> Run( } auto args_to_buffers = ExtractXlaBufferAssignment(main); - TF_ASSIGN_OR_RETURN(auto args, - LoadArgs(snapshot, main.getBody().getArgumentTypes())); + ASSIGN_OR_RETURN(auto args, + LoadArgs(snapshot, main.getBody().getArgumentTypes())); auto out_args = main.getBody().getBlocks().front().getArguments().drop_front(args.size()); @@ -254,8 +255,7 @@ absl::StatusOr> Run( if (trace) { options.listener = &tracer; } - TF_ASSIGN_OR_RETURN(auto results, - RunInterpreter(symbols, main, args, options)); + ASSIGN_OR_RETURN(auto results, RunInterpreter(symbols, main, args, options)); if (results.empty()) { return out_buffers; diff --git a/third_party/xla/xla/mlir/tools/mlir_replay/public/BUILD b/third_party/xla/xla/mlir/tools/mlir_replay/public/BUILD index 525701ba15c389..6c6f541f88e084 100644 --- a/third_party/xla/xla/mlir/tools/mlir_replay/public/BUILD +++ b/third_party/xla/xla/mlir/tools/mlir_replay/public/BUILD @@ -35,6 +35,7 @@ cc_library( "//xla:shape_util", "//xla:xla_data_proto_cc", "//xla/mlir/tools/mlir_interpreter/framework", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", diff --git a/third_party/xla/xla/mlir/tools/mlir_replay/public/execution_trace_utils.cc b/third_party/xla/xla/mlir/tools/mlir_replay/public/execution_trace_utils.cc index 8b7e0ec0c22d4d..a7240bfe4cf6bc 100644 --- a/third_party/xla/xla/mlir/tools/mlir_replay/public/execution_trace_utils.cc +++ b/third_party/xla/xla/mlir/tools/mlir_replay/public/execution_trace_utils.cc @@ -28,6 +28,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" +#include "xla/tsl/platform/status_macros.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Support/Casting.h" @@ -242,7 +243,7 @@ absl::StatusOr LiteralToValue(const xla::Literal& literal) { auto elements = literal.Clone().DecomposeTuple(); Tuple result; for (auto& element : elements) { - TF_ASSIGN_OR_RETURN(auto converted, LiteralToValue(element)); + ASSIGN_OR_RETURN(auto converted, LiteralToValue(element)); result.values.push_back( std::make_shared(std::move(converted))); } @@ -302,14 +303,13 @@ absl::StatusOr LiteralToValue(const xla::Literal& literal) { absl::StatusOr LiteralToValue( const xla::LiteralProto& literal) { - TF_ASSIGN_OR_RETURN(auto deserialized, - xla::Literal::CreateFromProto(literal)); + ASSIGN_OR_RETURN(auto deserialized, xla::Literal::CreateFromProto(literal)); return LiteralToValue(deserialized); } absl::StatusOr LiteralToValue( const xla::LiteralProto& literal, mlir::Type type) { - TF_ASSIGN_OR_RETURN(auto result, LiteralToValue(literal)); + ASSIGN_OR_RETURN(auto result, LiteralToValue(literal)); return {DispatchScalarType(type, [&](auto dummy) -> InterpreterValue { TensorOrMemref cast; cast.view = result.View(); @@ -400,7 +400,7 @@ absl::StatusOr TracedValueToValue( case TracedValue::TUPLE: Tuple result; for (const auto& elem : traced_value.tuple_elements()) { - TF_ASSIGN_OR_RETURN(auto converted, TracedValueToValue(elem)); + ASSIGN_OR_RETURN(auto converted, TracedValueToValue(elem)); result.values.push_back( std::make_shared(std::move(converted))); } diff --git a/third_party/xla/xla/mlir/utils/BUILD b/third_party/xla/xla/mlir/utils/BUILD index 673c3d9df1fdfc..4902e31c6f6af9 100644 --- a/third_party/xla/xla/mlir/utils/BUILD +++ b/third_party/xla/xla/mlir/utils/BUILD @@ -53,6 +53,7 @@ cc_library( "//xla:shape_util", "//xla:util", "//xla:xla_data_proto_cc", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/status:statusor", "@llvm-project//mlir:IR", diff --git a/third_party/xla/xla/mlir/utils/type_util.cc b/third_party/xla/xla/mlir/utils/type_util.cc index cf194c242f8930..b5574cf2c8b9e9 100644 --- a/third_party/xla/xla/mlir/utils/type_util.cc +++ b/third_party/xla/xla/mlir/utils/type_util.cc @@ -16,6 +16,7 @@ limitations under the License. #include "xla/mlir/utils/type_util.h" #include "absl/status/statusor.h" +#include "xla/tsl/platform/status_macros.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Types.h" @@ -70,7 +71,7 @@ absl::StatusOr ConvertPrimitiveTypeToMlirType( : mlir::IntegerType::Signless); } if (xla::primitive_util::IsComplexType(type)) { - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( mlir::Type component_type, xla::ConvertPrimitiveTypeToMlirType( xla::primitive_util::ComplexComponentType(type), b)); diff --git a/third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_ops.cc b/third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_ops.cc index 240701a891989a..c55937e5094a38 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_ops.cc +++ b/third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_ops.cc @@ -383,6 +383,7 @@ INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(LogisticOp) INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(MaxOp) INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(MinOp) INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(MulOp) +INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(MulhiOp) INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(NegOp) INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(NotOp) INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(OrOp) diff --git a/third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_ops.td b/third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_ops.td index 316427aa92d047..ab273c7a6f0e23 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_ops.td +++ b/third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_ops.td @@ -1029,6 +1029,16 @@ def MHLO_MulOp : MHLO_BinaryElementwiseOp<"multiply", let hasFolder = 1; } +def MHLO_MulhiOp : MHLO_BinaryElementwiseOp<"mulhi", + [Commutative, Pure, HLO_CompatibleOperandsAndResultType], MHLO_IntTensor> { + let summary = "Mulhi operation"; + let description = [{ + Performs element-wise multiplication of two integer tensors `lhs` and `rhs`, + returning the most significant bits of the product. + }]; + let hasCustomHLOConverter = 1; +} + def MHLO_PowOp : MHLO_BinaryElementwiseOp<"power", [Pure, HLO_CompatibleOperandsAndResultType], MHLO_IntFpOrComplexOrQuantizedIntTensor> { let summary = "Pow operation"; diff --git a/third_party/xla/xla/mlir_hlo/mhlo/transforms/chlo_legalize_to_hlo/chlo_legalize_to_hlo_pass.cc b/third_party/xla/xla/mlir_hlo/mhlo/transforms/chlo_legalize_to_hlo/chlo_legalize_to_hlo_pass.cc index 1cb2b99164d2ae..e86546934ac23f 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/transforms/chlo_legalize_to_hlo/chlo_legalize_to_hlo_pass.cc +++ b/third_party/xla/xla/mlir_hlo/mhlo/transforms/chlo_legalize_to_hlo/chlo_legalize_to_hlo_pass.cc @@ -153,7 +153,7 @@ struct ChloLegalizeToHighLevelMhloPass }); } conversionTarget.addIllegalOp(); + chlo::ScanOp, chlo::MulhiOp>(); if (failed(applyPartialConversion(getOperation(), conversionTarget, std::move(conversionPatterns)))) { @@ -334,6 +334,11 @@ LogicalResult convertAsinhChloToMhlo(chlo::AsinhOp op, rewriter.replaceOpWithNewOp(op, op->getOperands()); return success(); } +LogicalResult convertMulhiChloToMhlo(chlo::MulhiOp op, + PatternRewriter& rewriter) { + rewriter.replaceOpWithNewOp(op, op->getOperands()); + return success(); +} } // namespace @@ -387,6 +392,7 @@ void populateChloToHighLevelMhloOpPatterns( } patterns->add(mhlo::convertRaggedDotChloToMhlo, kBenefit); patterns->add(mhlo::convertScanChloToMhlo, kBenefit); + patterns->add(mhlo::convertMulhiChloToMhlo, kBenefit); populateWithGenerated(*patterns); } diff --git a/third_party/xla/xla/mlir_hlo/mhlo/transforms/hlo_legalize_to_stablehlo/hlo_legalize_to_stablehlo.cc b/third_party/xla/xla/mlir_hlo/mhlo/transforms/hlo_legalize_to_stablehlo/hlo_legalize_to_stablehlo.cc index d4bb378d762263..f7addbd5c9ecb3 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/transforms/hlo_legalize_to_stablehlo/hlo_legalize_to_stablehlo.cc +++ b/third_party/xla/xla/mlir_hlo/mhlo/transforms/hlo_legalize_to_stablehlo/hlo_legalize_to_stablehlo.cc @@ -147,6 +147,11 @@ std::optional getPublicFeaturesNotInStablehlo(HloOpTy hloOp) { // Version 1: Initial version for AsinhOp. return 1; } + // StableHLO doesn't support Mulhi yet. + if constexpr (std::is_same::value) { + // Version 1: Initial version for MulhiOp. + return 1; + } return std::nullopt; } @@ -520,7 +525,8 @@ LogicalResult convertAttributes(ConversionPatternRewriter& rewriter, !std::is_same::value && !std::is_same::value && !std::is_same::value && - !std::is_same::value) { + !std::is_same::value && + !std::is_same::value) { if (!stablehloAttr) { stablehloAttr = convertDenseArray>( hloAttr.getName(), hloAttr.getValue()); @@ -814,7 +820,7 @@ void populateHloToStablehloPatterns(RewritePatternSet* patterns, populateHloToStablehloCustomCallPatterns< mhlo::AcosOp, mhlo::AcoshOp, mhlo::AsinOp, mhlo::AsinhOp, mhlo::AtanhOp, - mhlo::CoshOp, mhlo::SinhOp, mhlo::ErfOp, mhlo::TopKOp>( + mhlo::CoshOp, mhlo::SinhOp, mhlo::ErfOp, mhlo::TopKOp, mhlo::MulhiOp>( patterns, converter, context, allowExperimentalFeatures); } diff --git a/third_party/xla/xla/mlir_hlo/mhlo/transforms/hlo_legalize_to_stablehlo/hlo_legalize_to_stablehlo_pass.cc b/third_party/xla/xla/mlir_hlo/mhlo/transforms/hlo_legalize_to_stablehlo/hlo_legalize_to_stablehlo_pass.cc index 098b387a9c9415..c5a2fc0d1201da 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/transforms/hlo_legalize_to_stablehlo/hlo_legalize_to_stablehlo_pass.cc +++ b/third_party/xla/xla/mlir_hlo/mhlo/transforms/hlo_legalize_to_stablehlo/hlo_legalize_to_stablehlo_pass.cc @@ -111,8 +111,8 @@ struct HloLegalizeToStablehloPass // since we're specifically legalizing to StableHLO, not to CHLO.) target.addLegalOp< // mhlo::AcosOp, mhlo::AcoshOp, mhlo::AsinOp, mhlo::AsinhOp, - mhlo::AtanhOp, mhlo::CoshOp, mhlo::ErfOp, mhlo::RaggedDotOp, - mhlo::ScanOp, mhlo::SinhOp, mhlo::TopKOp>(); + mhlo::AtanhOp, mhlo::CoshOp, mhlo::ErfOp, mhlo::MulhiOp, + mhlo::RaggedDotOp, mhlo::ScanOp, mhlo::SinhOp, mhlo::TopKOp>(); // These ops do not exist in StableHLO. (They don't exist in CHLO, either; // MHLO is the appropriate dialect for expressing XLA-specific features diff --git a/third_party/xla/xla/mlir_hlo/stablehlo_ext/transforms/chlo_preserve_high_level_ops.cpp b/third_party/xla/xla/mlir_hlo/stablehlo_ext/transforms/chlo_preserve_high_level_ops.cpp index e50995fab79a38..2bb82cc8321a71 100644 --- a/third_party/xla/xla/mlir_hlo/stablehlo_ext/transforms/chlo_preserve_high_level_ops.cpp +++ b/third_party/xla/xla/mlir_hlo/stablehlo_ext/transforms/chlo_preserve_high_level_ops.cpp @@ -235,6 +235,7 @@ struct ChloPreserveHighLevelOpsPass ChloOpToCompositePattern, ChloOpToCompositePattern, ChloOpToCompositePattern, + ChloOpToCompositePattern, ChloOpToCompositePattern, ChloOpToCompositePattern, ChloOpToCompositePattern>( diff --git a/third_party/xla/xla/mlir_hlo/stablehlo_ext/transforms/chlo_recompose_ops.cpp b/third_party/xla/xla/mlir_hlo/stablehlo_ext/transforms/chlo_recompose_ops.cpp index a2cb70cb289914..70520737c5f66e 100644 --- a/third_party/xla/xla/mlir_hlo/stablehlo_ext/transforms/chlo_recompose_ops.cpp +++ b/third_party/xla/xla/mlir_hlo/stablehlo_ext/transforms/chlo_recompose_ops.cpp @@ -389,6 +389,15 @@ struct AsinhOpCustomCallRecomposePattern op, {"mhlo.asinh", "chlo.asinh"}, rewriter); } }; +struct MulhiOpCustomCallRecomposePattern + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(stablehlo::CustomCallOp op, + PatternRewriter& rewriter) const override { + return recomposeChloOpFromCustomCall( + op, {"mhlo.mulhi", "chlo.mulhi"}, rewriter); + } +}; struct ScanOpCustomCallRecomposePattern : public OpRewritePattern { @@ -453,6 +462,7 @@ struct ChloRecomposeOpsPass CoshOpCustomCallRecomposePattern, SinhOpCustomCallRecomposePattern, ErfOpCustomCallRecomposePattern, + MulhiOpCustomCallRecomposePattern, RaggedDotOpCustomCallRecomposePattern, ScanOpCustomCallRecomposePattern, TanOpCustomCallRecomposePattern, @@ -468,6 +478,7 @@ struct ChloRecomposeOpsPass ChloOpRecomposePattern, ChloOpRecomposePattern, ChloOpRecomposePattern, + ChloOpRecomposePattern, ChloOpRecomposePattern, ChloOpRecomposePattern, ChloOpRecomposePattern>(ctx); diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/chlo/chlo_legalize_to_mhlo.mlir b/third_party/xla/xla/mlir_hlo/tests/Dialect/chlo/chlo_legalize_to_mhlo.mlir index 9ffde8bc84f170..5fccf361b9b9db 100644 --- a/third_party/xla/xla/mlir_hlo/tests/Dialect/chlo/chlo_legalize_to_mhlo.mlir +++ b/third_party/xla/xla/mlir_hlo/tests/Dialect/chlo/chlo_legalize_to_mhlo.mlir @@ -3649,3 +3649,35 @@ func.func @scan_with_size(%arg0: tensor<2x3xf32>, %arg1: tensor<3xf32>) -> tenso } : (tensor<2x3xf32>, tensor<3xf32>) -> (tensor<2x3xf32>, tensor<3xf32>) func.return %0 : tensor<2x3xf32> } + +// ----- + +// CHECK-LABEL: func.func @mulhi_s32( +// CHECK-SAME: %[[ARG0:.*]]: tensor<4xi32>, %[[ARG1:.*]]: tensor<4xi32>) -> tensor<4xi32> +// CHECK: %[[RESULT:.*]] = mhlo.mulhi %[[ARG0]], %[[ARG1]] : tensor<4xi32> +// CHECK: return %[[RESULT]] : tensor<4xi32> +func.func @mulhi_s32(%arg0 : tensor<4xi32>, %arg1 : tensor<4xi32>) -> tensor<4xi32> { + %result = "chlo.mulhi"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + func.return %result : tensor<4xi32> +} + +// ----- + +// CHECK-LABEL: func.func @mulhi_u32( +// CHECK-SAME: %[[ARG0:.*]]: tensor<4xui32>, %[[ARG1:.*]]: tensor<4xui32>) -> tensor<4xui32> +// CHECK: %[[RESULT:.*]] = mhlo.mulhi %[[ARG0]], %[[ARG1]] : tensor<4xui32> +// CHECK: return %[[RESULT]] : tensor<4xui32> +func.func @mulhi_u32(%arg0 : tensor<4xui32>, %arg1 : tensor<4xui32>) -> tensor<4xui32> { + %result = "chlo.mulhi"(%arg0, %arg1) : (tensor<4xui32>, tensor<4xui32>) -> tensor<4xui32> + func.return %result : tensor<4xui32> +} + +// CHECK-LABEL: func.func @mulhi_i16( +// CHECK-SAME: %[[ARG0:.*]]: tensor<4xi16>, %[[ARG1:.*]]: tensor<4xi16>) -> tensor<4xi16> +// CHECK: %[[RESULT:.*]] = mhlo.mulhi %[[ARG0]], %[[ARG1]] : tensor<4xi16> +// CHECK: return %[[RESULT]] : tensor<4xi16> +func.func @mulhi_i16(%arg0 : tensor<4xi16>, %arg1 : tensor<4xi16>) -> tensor<4xi16> { + %result = "chlo.mulhi"(%arg0, %arg1) : (tensor<4xi16>, tensor<4xi16>) -> tensor<4xi16> + func.return %result : tensor<4xi16> +} + diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/ops.mlir b/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/ops.mlir index e9031d779ebf35..98904e972fd134 100644 --- a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/ops.mlir +++ b/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/ops.mlir @@ -7403,3 +7403,17 @@ func.func @ragged_dot_zero_rhs_group_dims_for_ragged_noncontracting(%lhs : tenso } : (tensor<11x5xf32>, tensor<5x7xf32>, tensor<3xi64>) -> tensor<11x7xf32> func.return %0 : tensor<11x7xf32> } + +// ----- + +// CHECK-LABEL: func @mulhi_i32 +func.func @mulhi_i32(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { + %0 = "mhlo.mulhi"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + func.return %0: tensor<4xi32> +} + +// CHECK-LABEL: func @mulhi_i16 +func.func @mulhi_i16(%arg0: tensor<4xi16>, %arg1: tensor<4xi16>) -> tensor<4xi16> { + %0 = "mhlo.mulhi"(%arg0, %arg1) : (tensor<4xi16>, tensor<4xi16>) -> tensor<4xi16> + func.return %0: tensor<4xi16> +} diff --git a/third_party/xla/xla/mlir_hlo/tests/stablehlo_ext/chlo_preserve_high_level_ops.mlir b/third_party/xla/xla/mlir_hlo/tests/stablehlo_ext/chlo_preserve_high_level_ops.mlir index 81f31a49477581..b9437dc0a03e61 100644 --- a/third_party/xla/xla/mlir_hlo/tests/stablehlo_ext/chlo_preserve_high_level_ops.mlir +++ b/third_party/xla/xla/mlir_hlo/tests/stablehlo_ext/chlo_preserve_high_level_ops.mlir @@ -129,6 +129,16 @@ func.func @asinh_preserve(%arg0: tensor<3x20x20xbf16>) -> tensor { // ----- +// CHECK-LABEL: func @mulhi_preserve +func.func @mulhi_preserve(%arg0: tensor<3x20x20xi32>, %arg1: tensor<3x20x20xi32>) -> tensor { + // CHECK: stablehlo.composite "chlo.mulhi" %arg0, %arg1 {decomposition = @chlo.mulhi.impl, version = 1 : i32} + %0 = chlo.mulhi %arg0, %arg1 : tensor<3x20x20xi32>, tensor<3x20x20xi32> -> tensor + return %0 : tensor +} + +// ----- + + // CHECK-LABEL: func @tan_no_preserve func.func @tan_no_preserve(%arg0: tensor<16xf32>) -> tensor { // CHECK: chlo.tan diff --git a/third_party/xla/xla/mlir_hlo/tests/stablehlo_ext/chlo_recompose_ops.mlir b/third_party/xla/xla/mlir_hlo/tests/stablehlo_ext/chlo_recompose_ops.mlir index 7eef8279ba793b..7c1b99c8da3218 100644 --- a/third_party/xla/xla/mlir_hlo/tests/stablehlo_ext/chlo_recompose_ops.mlir +++ b/third_party/xla/xla/mlir_hlo/tests/stablehlo_ext/chlo_recompose_ops.mlir @@ -123,6 +123,21 @@ func.func private @chlo.asinh.impl(%arg0: tensor<3x20x20xbf16>) -> tensor, %arg1: tensor<3x20x20xi32>) -> tensor { + // CHECK-NEXT: chlo.mulhi + // CHECK-NOT: stablehlo.composite + %0 = stablehlo.composite "chlo.mulhi" %arg0, %arg1 {decomposition = @chlo.mulhi.impl, version = 1 : i32} : (tensor<3x20x20xi32>, tensor<3x20x20xi32>) -> tensor + return %0 : tensor +} +// CHECK-NOT: @chlo.mulhi.impl +func.func private @chlo.mulhi.impl(%arg0: tensor<3x20x20xi32>, %arg1: tensor<3x20x20xi32>) -> tensor { + %0 = chlo.mulhi %arg0, %arg1 : tensor<3x20x20xi32>, tensor<3x20x20xi32> -> tensor + return %0 : tensor +} + +// ----- + // CHECK-LABEL: func @ragged_dot_recompose_composite func.func @ragged_dot_recompose_composite(%arg0: tensor<2x11x5xf32>, %arg1: tensor<3x2x5x7xf32>, %arg2: tensor<3xi64>) -> tensor<2x11x7xf32> { // CHECK: "chlo.ragged_dot"(%arg0, %arg1, %arg2) <{precision_config = [#chlo, #chlo], ragged_dot_dimension_numbers = #chlo.ragged_dot}> : (tensor<2x11x5xf32>, tensor<3x2x5x7xf32>, tensor<3xi64>) -> tensor<2x11x7xf32> @@ -289,6 +304,20 @@ func.func @asinh_recompose_cc(%arg0: tensor<3x20x20xbf16>) -> tensor, %arg1: tensor<3x20x20xi32>) -> tensor { + // CHECK: %0 = chlo.mulhi %arg0, %arg1 : tensor<3x20x20xi32>, tensor<3x20x20xi32> -> tensor + %0 = "stablehlo.custom_call"(%arg0, %arg1) { + backend_config = "", + call_target_name = "mhlo.mulhi", + mhlo.attributes = {}, + mhlo.version = 1 : i64 + } : (tensor<3x20x20xi32>, tensor<3x20x20xi32>) -> tensor + func.return %0 : tensor +} + +// ----- + // CHECK-LABEL: func @ragged_dot_recompose_cc func.func @ragged_dot_recompose_cc(%arg0: tensor<2x11x5xf32>, %arg1: tensor<3x2x5x7xf32>, %arg2: tensor<3xi64>) -> tensor<2x11x7xf32> { // CHECK: "chlo.ragged_dot"(%arg0, %arg1, %arg2) <{precision_config = [#chlo, #chlo], ragged_dot_dimension_numbers = #chlo.ragged_dot}> : (tensor<2x11x5xf32>, tensor<3x2x5x7xf32>, tensor<3xi64>) -> tensor<2x11x7xf32> diff --git a/third_party/xla/xla/pjrt/cpu/cpu_client.cc b/third_party/xla/xla/pjrt/cpu/cpu_client.cc index f605fbe999a12e..87f2175ddd77ec 100644 --- a/third_party/xla/xla/pjrt/cpu/cpu_client.cc +++ b/third_party/xla/xla/pjrt/cpu/cpu_client.cc @@ -1605,6 +1605,12 @@ PjRtRawLoadedExecutable::RawExecuteResult CpuPjRtRawLoadedExecutable::Execute( absl::MakeConstSpan(leaf_buffers), [buffers = leaf_buffers, tuple_index_table, allocator = client->allocator()]() mutable { + for (int i = 0; i < buffers.size(); ++i) { + if (buffers[i].IsError()) { + tuple_index_table.SetError(buffers[i].GetError()); + return; + } + } size_t index_table_byte_size = buffers.size() * sizeof(void*); // We assume tuple table allocations will not fail. CHECK_OK(CpuDeviceMemory::AllocateInto( diff --git a/third_party/xla/xla/pjrt/cpu/cpu_client_test.cc b/third_party/xla/xla/pjrt/cpu/cpu_client_test.cc index 52fcdc024ef1bf..8b6d70d4b07442 100644 --- a/third_party/xla/xla/pjrt/cpu/cpu_client_test.cc +++ b/third_party/xla/xla/pjrt/cpu/cpu_client_test.cc @@ -1248,6 +1248,57 @@ TEST(PjRtCpuClientTest, SerializeYnnFusions) { LiteralUtil::CreateR1(literal_data_x2_squared), *result_literal)); } +TEST(PjRtCpuClientTest, TupleInputWithErrorBuffer) { + static constexpr char kProgram[] = R"( + HloModule TupleInput + ENTRY TupleInput { + t = (f32[2], f32[2]) parameter(0) + p0 = f32[2] get-tuple-element(t), index=0 + p1 = f32[2] get-tuple-element(t), index=1 + ROOT add = f32[2] add(p0, p1) + })"; + + ASSERT_OK_AND_ASSIGN(auto client, GetPjRtCpuClient(CpuClientOptions())); + ASSERT_OK_AND_ASSIGN(auto hlo_module, + ParseAndReturnUnverifiedModule(kProgram, {})); + + XlaComputation xla_computation(hlo_module->ToProto()); + CompileOptions compile_options; + compile_options.parameter_is_tupled_arguments = true; + ASSERT_OK_AND_ASSIGN( + auto pjrt_executable, + client->CompileAndLoad(xla_computation, compile_options)); + + std::vector data(2, 1.0f); + Shape shape = ShapeUtil::MakeShape(F32, {2}); + ASSERT_OK_AND_ASSIGN( + auto normal_buffer, + client->BufferFromHostBuffer( + data.data(), shape.element_type(), shape.dimensions(), + /*byte_strides=*/std::nullopt, + PjRtClient::HostBufferSemantics::kImmutableOnlyDuringCall, nullptr, + client->memory_spaces()[0], /*device_layout=*/nullptr)); + + ASSERT_OK_AND_ASSIGN( + auto error_buffer, + client->CreateErrorBuffer(absl::InternalError("forced error"), shape, + client->memory_spaces()[0])); + + // The executable expects a single tuple parameter which we supply as + // independent leaf buffers. + // One of the leaf buffers is an error buffer. + auto result = pjrt_executable->Execute( + /*argument_handles=*/{{normal_buffer.get(), error_buffer.get()}}, + /*options=*/{}); + + ASSERT_THAT(result, absl_testing::StatusIs(tsl::error::OK)); + ASSERT_EQ(result->size(), 1); + ASSERT_EQ(result->at(0).size(), 1); + EXPECT_THAT( + result->at(0).at(0)->ToLiteral().Await(), + absl_testing::StatusIs(tsl::error::INTERNAL, HasSubstr("forced error"))); +} + } // namespace //===----------------------------------------------------------------------===// diff --git a/third_party/xla/xla/python/BUILD b/third_party/xla/xla/python/BUILD index 2200357b0add21..04f568597d950e 100644 --- a/third_party/xla/xla/python/BUILD +++ b/third_party/xla/xla/python/BUILD @@ -135,6 +135,7 @@ cc_library( "//xla/python/ifrt", "//xla/python/pjrt_ifrt:pjrt_dtype", "//xla/tsl/platform:logging", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "//xla/tsl/python/lib/core:numpy", "@com_google_absl//absl/container:flat_hash_map", @@ -217,6 +218,7 @@ cc_library( "//xla/service:custom_call_sharding_helper", "//xla/service/spmd:spmd_partitioner", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/cleanup", "@com_google_absl//absl/container:flat_hash_map", @@ -384,6 +386,7 @@ cc_library( "//xla/service/spmd/shardy/sdy_round_trip:pipelines", "//xla/tsl/platform:errors", "//xla/tsl/platform:logging", + "//xla/tsl/platform:status_macros", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@llvm-project//llvm:Support", @@ -549,6 +552,7 @@ cc_library( "//xla/hlo/ir:hlo", "//xla/service:hlo_proto_cc", "//xla/tsl/platform:env", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:types", "//xla/tsl/profiler/convert:xla_op_utils", "//xla/tsl/profiler/utils:file_system_utils", diff --git a/third_party/xla/xla/python/compile_only_ifrt/BUILD b/third_party/xla/xla/python/compile_only_ifrt/BUILD index d65e9d7ede32f6..e4abd460b502e7 100644 --- a/third_party/xla/xla/python/compile_only_ifrt/BUILD +++ b/third_party/xla/xla/python/compile_only_ifrt/BUILD @@ -33,6 +33,7 @@ cc_library( "//xla/service:computation_placer_hdr", "//xla/tsl/concurrency:future", "//xla/tsl/concurrency:ref_count", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", diff --git a/third_party/xla/xla/python/compile_only_ifrt/client.h b/third_party/xla/xla/python/compile_only_ifrt/client.h index 96f39e67bf06ef..c8197aa06dd634 100644 --- a/third_party/xla/xla/python/compile_only_ifrt/client.h +++ b/third_party/xla/xla/python/compile_only_ifrt/client.h @@ -29,6 +29,7 @@ limitations under the License. #include "absl/strings/str_format.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "llvm/Support/ExtensibleRTTI.h" #include "xla/layout.h" #include "xla/layout_util.h" @@ -337,17 +338,17 @@ class CompileOnlyIfRtClient final return std::make_shared( LayoutUtil::MakeDescendingLayout(dims.size())); } - TF_ASSIGN_OR_RETURN(PrimitiveType element_type, ToPrimitiveType(dtype)); - TF_ASSIGN_OR_RETURN(xla::Layout layout, - topology_->GetDefaultLayout(element_type, dims)); + ASSIGN_OR_RETURN(PrimitiveType element_type, ToPrimitiveType(dtype)); + ASSIGN_OR_RETURN(xla::Layout layout, + topology_->GetDefaultLayout(element_type, dims)); return std::make_shared(std::move(layout)); } absl::StatusOr GetDefaultLayout( ifrt::DType dtype, const ifrt::Shape& shape, const ifrt::ShardingRef& sharding) const override { - TF_ASSIGN_OR_RETURN(const ifrt::Shape shard_shape, - sharding->GetShardShape(shape)); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN(const ifrt::Shape shard_shape, + sharding->GetShardShape(shape)); + ASSIGN_OR_RETURN( std::shared_ptr layout, GetDefaultPjRtLayout(dtype, shard_shape.dims(), sharding->devices()->devices().front(), diff --git a/third_party/xla/xla/python/custom_partition_callback.cc b/third_party/xla/xla/python/custom_partition_callback.cc index f848aa99082f83..bf2b08c3bf2217 100644 --- a/third_party/xla/xla/python/custom_partition_callback.cc +++ b/third_party/xla/xla/python/custom_partition_callback.cc @@ -32,6 +32,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/debug_options_flags.h" #include "xla/hlo/builder/xla_computation.h" #include "xla/hlo/ir/hlo_casting_utils.h" @@ -79,7 +80,7 @@ absl::StatusOr InlineHloComputation( std::vector new_operands; new_operands.reserve(inst->operand_count()); for (HloInstruction* operand : inst->mutable_operands()) { - TF_ASSIGN_OR_RETURN(auto* new_operand, resolve(operand)); + ASSIGN_OR_RETURN(auto* new_operand, resolve(operand)); new_operands.push_back(new_operand); } auto* new_inst = builder->AddInstruction( @@ -114,18 +115,18 @@ class CApiCustomCallPartitioner : public xla::CustomCallPartitioner { std::vector arg_shardings; std::optional result_sharding; std::string mlir_module; - TF_ASSIGN_OR_RETURN(std::tie(mlir_module, arg_shardings, result_sharding), - jax::ConsumeResults(&args)); - TF_RETURN_IF_ERROR(ParseMlirModuleStringAndConvertToXlaComputation( + ASSIGN_OR_RETURN(std::tie(mlir_module, arg_shardings, result_sharding), + jax::ConsumeResults(&args)); + RETURN_IF_ERROR(ParseMlirModuleStringAndConvertToXlaComputation( mlir_module, computation, /*use_tuple_args=*/false, /*return_tuple=*/false)); auto hlo_module_config = xla::HloModule::CreateModuleConfigFromProto( computation.proto(), xla::DefaultDebugOptionsIgnoringFlags()) .value(); - TF_ASSIGN_OR_RETURN(std::unique_ptr hlo_module, - xla::HloModule::CreateFromProto(computation.proto(), - hlo_module_config)); + ASSIGN_OR_RETURN(std::unique_ptr hlo_module, + xla::HloModule::CreateFromProto(computation.proto(), + hlo_module_config)); std::vector operands; operands.reserve(instruction->operand_count()); if (arg_shardings.size() != instruction->operand_count()) { @@ -145,15 +146,14 @@ class CApiCustomCallPartitioner : public xla::CustomCallPartitioner { // so inline all calls here explicitly, since some targets require it. HloPassPipeline pipeline("custom-call-inliner"); pipeline.AddPass(); - TF_RETURN_IF_ERROR(pipeline.Run(hlo_module.get(), {}).status()); - - TF_ASSIGN_OR_RETURN( - auto* partitioned_hlo, - InlineHloComputation( - instruction, hlo_module->entry_computation(), - partitioner->builder(), operands, - [partitioner]() { return partitioner->NewChannel(); }, - "_custom_call_lowering_rule")); + RETURN_IF_ERROR(pipeline.Run(hlo_module.get(), {}).status()); + + ASSIGN_OR_RETURN(auto* partitioned_hlo, + InlineHloComputation( + instruction, hlo_module->entry_computation(), + partitioner->builder(), operands, + [partitioner]() { return partitioner->NewChannel(); }, + "_custom_call_lowering_rule")); partitioned_hlo->set_sharding(result_sharding.value()); spmd::PartitionedHlo result_partitioned = @@ -240,8 +240,8 @@ absl::StatusOr ReadHloSharding( return absl::InternalError( "custom_call_sharding.cc: error parsing OpShardingProto"); } - TF_ASSIGN_OR_RETURN(xla::HloSharding sharding, - xla::HloSharding::FromProto(std::move(proto))); + ASSIGN_OR_RETURN(xla::HloSharding sharding, + xla::HloSharding::FromProto(std::move(proto))); if (sharding.UseNamedShardingLeaf()) { sharding = xla::HloSharding::V3ToV2Sharding(sharding.named_sharding()); } @@ -314,14 +314,14 @@ ConsumeResults(JAX_CustomCallPartitioner_Partition_Args* args) { absl::Cleanup cleanup = [args] { args->header.cleanup_fn(args->header.data); }; - TF_RETURN_IF_ERROR(ConsumeHeader(args->header)); - TF_ASSIGN_OR_RETURN(auto result_sharding, - ReadHloSharding(args->result_sharding)); + RETURN_IF_ERROR(ConsumeHeader(args->header)); + ASSIGN_OR_RETURN(auto result_sharding, + ReadHloSharding(args->result_sharding)); std::vector arg_shardings; arg_shardings.reserve(args->num_args); for (size_t i = 0; i < args->num_args; ++i) { - TF_ASSIGN_OR_RETURN(auto arg_sharding, - ReadHloSharding(args->args_sharding[i])); + ASSIGN_OR_RETURN(auto arg_sharding, + ReadHloSharding(args->args_sharding[i])); arg_shardings.push_back(std::move(arg_sharding)); } return std::tuple, @@ -360,22 +360,22 @@ ReadArgs(JAX_CustomCallPartitioner_Partition_Args* args) { shapes.reserve(args->num_args); shardings.reserve(args->num_args); for (size_t i = 0; i < args->num_args; ++i) { - TF_ASSIGN_OR_RETURN(auto shape, ReadHloShape(args->op_args[i].shape)); + ASSIGN_OR_RETURN(auto shape, ReadHloShape(args->op_args[i].shape)); shapes.push_back(shape); if (args->op_args[i].has_sharding) { - TF_ASSIGN_OR_RETURN(auto sharding, - ReadHloSharding(args->op_args[i].sharding)); + ASSIGN_OR_RETURN(auto sharding, + ReadHloSharding(args->op_args[i].sharding)); shardings.push_back(std::move(sharding)); } else { shardings.push_back(std::nullopt); } } - TF_ASSIGN_OR_RETURN(auto result_shape, ReadHloShape(args->op_result.shape)); + ASSIGN_OR_RETURN(auto result_shape, ReadHloShape(args->op_result.shape)); std::optional result_sharding; if (args->op_result.has_sharding) { - TF_ASSIGN_OR_RETURN(result_sharding, - ReadHloSharding(args->op_result.sharding)); + ASSIGN_OR_RETURN(result_sharding, + ReadHloSharding(args->op_result.sharding)); } return std::tuple, std::vector>, xla::Shape, @@ -393,18 +393,18 @@ ReadArgs(JAX_CustomCallPartitioner_InferShardingFromOperands_Args* args) { shapes.reserve(args->num_args); shardings.reserve(args->num_args); for (size_t i = 0; i < args->num_args; ++i) { - TF_ASSIGN_OR_RETURN(auto shape, ReadHloShape(args->op_args[i].shape)); + ASSIGN_OR_RETURN(auto shape, ReadHloShape(args->op_args[i].shape)); shapes.push_back(shape); if (args->op_args[i].has_sharding) { - TF_ASSIGN_OR_RETURN(auto sharding, - ReadHloSharding(args->op_args[i].sharding)); + ASSIGN_OR_RETURN(auto sharding, + ReadHloSharding(args->op_args[i].sharding)); shardings.push_back(std::move(sharding)); } else { shardings.push_back(std::nullopt); } } - TF_ASSIGN_OR_RETURN(auto result_shape, ReadHloShape(args->result_shape)); + ASSIGN_OR_RETURN(auto result_shape, ReadHloShape(args->result_shape)); return std::tuple, std::vector>, xla::Shape, absl::string_view>(std::move(shapes), std::move(shardings), @@ -459,7 +459,7 @@ absl::StatusOr> ConsumeResults( absl::Cleanup cleanup = [args] { args->header.cleanup_fn(args->header.data); }; - TF_RETURN_IF_ERROR(ConsumeHeader(args->header)); + RETURN_IF_ERROR(ConsumeHeader(args->header)); if (!args->has_result_sharding) { return std::nullopt; } @@ -468,8 +468,8 @@ absl::StatusOr> ConsumeResults( absl::StatusOr> ReadArgs(JAX_CustomCallPartitioner_PropagateUserSharding_Args* args) { - TF_ASSIGN_OR_RETURN(auto shape, ReadHloShape(args->result_shape)); - TF_ASSIGN_OR_RETURN(auto sharding, ReadHloSharding(args->result_sharding)); + ASSIGN_OR_RETURN(auto shape, ReadHloShape(args->result_shape)); + ASSIGN_OR_RETURN(auto sharding, ReadHloSharding(args->result_sharding)); return std::tuple( std::move(sharding), std::move(shape), ToStringView(args->backend_config)); @@ -510,7 +510,7 @@ absl::StatusOr ConsumeResults( absl::Cleanup cleanup = [args] { args->header.cleanup_fn(args->header.data); }; - TF_RETURN_IF_ERROR(ConsumeHeader(args->header)); + RETURN_IF_ERROR(ConsumeHeader(args->header)); return ReadHloSharding(args->result_sharding); } diff --git a/third_party/xla/xla/python/ifrt/ir/BUILD b/third_party/xla/xla/python/ifrt/ir/BUILD index 1208a946510b04..5c3953e9fb63a9 100644 --- a/third_party/xla/xla/python/ifrt/ir/BUILD +++ b/third_party/xla/xla/python/ifrt/ir/BUILD @@ -764,6 +764,7 @@ cc_library( "//xla/tsl/concurrency:ref_count", "//xla/tsl/platform:errors", "//xla/tsl/platform:status_macros", + "//xla/tsl/platform:statusor", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/functional:any_invocable", diff --git a/third_party/xla/xla/python/ifrt/ir/compiled_ifrt_ir_program.cc b/third_party/xla/xla/python/ifrt/ir/compiled_ifrt_ir_program.cc index 9a28a865317a31..6e68146f6a282e 100644 --- a/third_party/xla/xla/python/ifrt/ir/compiled_ifrt_ir_program.cc +++ b/third_party/xla/xla/python/ifrt/ir/compiled_ifrt_ir_program.cc @@ -311,9 +311,22 @@ CompiledIfrtIrProgram::Create( } } + // If `ifrt_ir_program` exclusively owns the MLIR context, create a new + // context and clone the compiled IFRT IR program into it. This reduces the + // host memory usage since the new context does not need to store the interned + // attributes from the deserialized StableHLO programs. + if (ifrt_ir_program->OwnsMlirContext()) { + auto context = std::make_unique( + mlir::MLIRContext::Threading::DISABLED); + ASSIGN_OR_RETURN(mlir::OwningOpRef cloned_module, + CloneModuleIntoContext(mlir_module, *context)); + ifrt_ir_program = std::make_unique( + std::move(context), std::move(cloned_module)); + } + // Extract input and output specs from the modified `mlir_module`, which has // all array shardings specified. - mlir::func::FuncOp main_func = GetMainFunction(mlir_module); + mlir::func::FuncOp main_func = GetMainFunction(ifrt_ir_program->mlir_module); std::vector in_specs; in_specs.reserve(main_func.getNumArguments()); for (const mlir::Type arg_type : main_func.getArgumentTypes()) { @@ -345,8 +358,7 @@ CompiledIfrtIrProgram::Create( auto create_program = [program_name = std::move(program_name), atom_executable_future_map = std::move(atom_executable_future_map), - mlir_module, client, in_specs = std::move(in_specs), - out_specs = std::move(out_specs), + client, in_specs = std::move(in_specs), out_specs = std::move(out_specs), donatable_input_indices = std::move(donatable_input_indices), device_list = std::move(device_list), ifrt_ir_program = std::move(ifrt_ir_program), @@ -359,9 +371,9 @@ CompiledIfrtIrProgram::Create( atom_executable_map->insert({key, std::move(executable)}); } - absl::Status layout_status = - PopulateLayouts(mlir_module, client, *atom_executable_map, - absl::MakeSpan(in_specs), absl::MakeSpan(out_specs)); + absl::Status layout_status = PopulateLayouts( + ifrt_ir_program->mlir_module, client, *atom_executable_map, + absl::MakeSpan(in_specs), absl::MakeSpan(out_specs)); if (!layout_status.ok()) { for (auto& spec : in_specs) { spec.layout = nullptr; @@ -371,9 +383,10 @@ CompiledIfrtIrProgram::Create( } } - ASSIGN_OR_RETURN(auto interpreter, ProgramInterpreter::Create( - client, program_name, mlir_module, - atom_executable_map, device_list)); + ASSIGN_OR_RETURN(auto interpreter, + ProgramInterpreter::Create( + client, program_name, ifrt_ir_program->mlir_module, + atom_executable_map, device_list)); ASSIGN_OR_RETURN(auto execute_fn, interpreter->BuildExecuteFn()); return std::make_shared(CompiledIfrtIrProgram{ diff --git a/third_party/xla/xla/python/ifrt/ir/ifrt_ops.cc b/third_party/xla/xla/python/ifrt/ir/ifrt_ops.cc index db11c1f763b49a..3fe86ffc590c15 100644 --- a/third_party/xla/xla/python/ifrt/ir/ifrt_ops.cc +++ b/third_party/xla/xla/python/ifrt/ir/ifrt_ops.cc @@ -433,11 +433,6 @@ mlir::LogicalResult CopyArraysOp::verify() { return emitOpError() << "requires the same number of input and output arrays"; } - if (getDonated() && getReuse()) { - return emitOpError() - << "requires at most one of `donated` or `reuse` to be " - "set to true"; - } IfrtArrayType first_input = GetArrayType(getInputs().front()); auto src_devices = first_input.getDevicesAttr(); auto src_memory_kind = first_input.MemoryKind(); diff --git a/third_party/xla/xla/python/ifrt/ir/ifrt_ops.td b/third_party/xla/xla/python/ifrt/ir/ifrt_ops.td index 495a36453de841..7d4a3189499cf7 100644 --- a/third_party/xla/xla/python/ifrt/ir/ifrt_ops.td +++ b/third_party/xla/xla/python/ifrt/ir/ifrt_ops.td @@ -73,7 +73,6 @@ def Ifrt_CopyArraysOp let arguments = (ins Variadic:$inputs, DefaultValuedOptionalAttr:$donated, - DefaultValuedOptionalAttr:$reuse, Variadic:$control_inputs); let results = (outs Variadic:$outputs, diff --git a/third_party/xla/xla/python/ifrt/ir/program_interpreter.cc b/third_party/xla/xla/python/ifrt/ir/program_interpreter.cc index b868a7a10e4e97..b1de65ae3ac2fc 100644 --- a/third_party/xla/xla/python/ifrt/ir/program_interpreter.cc +++ b/third_party/xla/xla/python/ifrt/ir/program_interpreter.cc @@ -66,6 +66,7 @@ limitations under the License. #include "xla/tsl/concurrency/future.h" #include "xla/tsl/concurrency/ref_count.h" #include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/statusor.h" #include "tsl/profiler/lib/traceme.h" namespace xla { @@ -803,7 +804,7 @@ struct CopyArraysOpState { std::vector input_handles; absl::flat_hash_set dead_inputs; - ArrayCopySemantics copy_semantics; + bool copy_is_donated; std::vector output_handles; ShardingRef new_sharding; @@ -838,13 +839,11 @@ struct CopyArraysOpState { } inputs.push_back(array_it->second.array); - if (copy_semantics == ArrayCopySemantics::kDonateInput && - !array_it->second.can_be_donated) { + if (copy_is_donated && !array_it->second.can_be_donated) { array_idxs_to_copy.push_back(idx); arrays_to_copy.push_back(array_it->second.array); } - if ((copy_semantics == ArrayCopySemantics::kDonateInput && - array_it->second.can_be_donated) || + if ((copy_is_donated && array_it->second.can_be_donated) || dead_inputs.contains(handle)) { arrays_to_remove.push_back(handle); } @@ -878,10 +877,12 @@ struct CopyArraysOpState { // It is safe to get the devices and memory kind from the first output // because all outputs use the same devices and have the same memory kind. - ASSIGN_OR_RETURN( - auto copied_arrays, - env.client->CopyArrays(absl::MakeSpan(inputs), new_sharding->devices(), - new_sharding->memory_kind(), copy_semantics)); + ASSIGN_OR_RETURN(auto copied_arrays, + env.client->CopyArrays( + absl::MakeSpan(inputs), new_sharding->devices(), + new_sharding->memory_kind(), + copy_is_donated ? ArrayCopySemantics::kDonateInput + : ArrayCopySemantics::kAlwaysCopy)); for (const auto handle : arrays_to_remove) { if (env.deletable_program_arguments.erase(handle)) { @@ -921,14 +922,7 @@ absl::StatusOr ProgramInterpreter::HandleOp( state.dead_inputs.insert(ToArrayHandle(input)); } } - - if (copy_arrays_op.getDonated()) { - state.copy_semantics = ArrayCopySemantics::kDonateInput; - } else if (copy_arrays_op.getReuse()) { - state.copy_semantics = ArrayCopySemantics::kReuseInput; - } else { - state.copy_semantics = ArrayCopySemantics::kAlwaysCopy; - } + state.copy_is_donated = copy_arrays_op.getDonated(); ASSIGN_OR_RETURN(state.new_sharding, ShardingFromIfrtArrayType( diff --git a/third_party/xla/xla/python/ifrt/ir/tests/verify_copy_arrays.mlir b/third_party/xla/xla/python/ifrt/ir/tests/verify_copy_arrays.mlir index 3077e5eaa17f6e..0c3d2a2597dc1e 100644 --- a/third_party/xla/xla/python/ifrt/ir/tests/verify_copy_arrays.mlir +++ b/third_party/xla/xla/python/ifrt/ir/tests/verify_copy_arrays.mlir @@ -28,17 +28,6 @@ func.func @copy_donated_array(%arg0: !array0) return } -// ----- - -!array = !ifrt.array, - #ifrt.sharding_param<1x2 to [0] on 2>, [0,1]> -func.func @copy_donated_array(%arg0: !array) - attributes {ifrt.function} { - %0, %ctrl = ifrt.CopyArrays(%arg0) {reuse=true} : (!array) -> (!array) - return -} - - // ----- !array0 = !ifrt.array, #ifrt.sharding_unspecified, [0,1]> @@ -206,17 +195,3 @@ func.func @no_auto_layout(%arg0: !array0) : (!array0) -> (!array1) return } - -// ----- - -!array0 = !ifrt.array, - #ifrt.sharding_param<1x2 to [0] on 2>, [0,1]> -!array1 = !ifrt.array, - #ifrt.sharding_param<1x2 to [0] on 2>, [2,3]> -func.func @array_cannot_be_donated_and_reused(%arg0: !array0) - attributes {ifrt.function} { - // expected-error@+1 {{'ifrt.CopyArrays' op requires at most one of `donated` or `reuse` to be set to true}} - %0, %ctrl = ifrt.CopyArrays(%arg0) {donated=true, reuse=true} - : (!array0) -> (!array1) - return -} diff --git a/third_party/xla/xla/python/ifrt/ir/tests/vifrt/BUILD b/third_party/xla/xla/python/ifrt/ir/tests/vifrt/BUILD index c5ca5e6b94510c..d31b017cf6f9a3 100644 --- a/third_party/xla/xla/python/ifrt/ir/tests/vifrt/BUILD +++ b/third_party/xla/xla/python/ifrt/ir/tests/vifrt/BUILD @@ -15,7 +15,6 @@ lit_test_suite( "ifrt_legalize_to_vifrt.0_1_0.mlir.bytes", "ifrt_legalize_to_vifrt.0_2_0.mlir.bytes", "ifrt_legalize_to_vifrt.0_3_0.mlir.bytes", - "ifrt_legalize_to_vifrt.0_4_0.mlir.bytes", ], tools = [ "//xla/python/ifrt/ir/tests:ifrt-opt", diff --git a/third_party/xla/xla/python/ifrt/ir/tests/vifrt/ifrt_legalize_to_vifrt.0_1_0.mlir b/third_party/xla/xla/python/ifrt/ir/tests/vifrt/ifrt_legalize_to_vifrt.0_1_0.mlir index 0c9013969444eb..76ee66857e4f5f 100644 --- a/third_party/xla/xla/python/ifrt/ir/tests/vifrt/ifrt_legalize_to_vifrt.0_1_0.mlir +++ b/third_party/xla/xla/python/ifrt/ir/tests/vifrt/ifrt_legalize_to_vifrt.0_1_0.mlir @@ -1,4 +1,4 @@ -// RUN: ifrt-opt %s --ifrt-legalize-to-vifrt --vifrt-to-version='target_version=0.1.0' --symbol-dce --mlir-print-op-generic -split-input-file | FileCheck %s +// RUN: ifrt-opt %s --ifrt-legalize-to-vifrt --symbol-dce --mlir-print-op-generic -split-input-file | FileCheck %s // RUN: ifrt-translate --serialize --ifrt_version=0.1.0 --atom_program_version=1.13.1 --strip_debuginfo %s | ifrt-translate --deserialize --strip_debuginfo | ifrt-opt > %t.0 // RUN: ifrt-opt %s > %t.1 // RUN: diff %t.0 %t.1 @@ -23,7 +23,7 @@ func.func @type_array_and_control(%arg0: !array_t0) attributes {ifrt.function} { // CHECK-DAG: donated = false // CHECK-DAG: operandSegmentSizes = array // CHECK-SAME: }> - // CHECK-SAME: (!vifrt.array_v1, #vifrt.sharding_param_v1<1x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">) -> (!vifrt.array_v1, #vifrt.sharding_param_v1<1x1 to [0] on 2>, [2, 3], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.control_v1) + // CHECK-SAME: (!vifrt.array_v1, #vifrt.sharding_param_v2<1x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">) -> (!vifrt.array_v1, #vifrt.sharding_param_v2<1x1 to [0] on 2>, [2, 3], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.control_v1) %0, %ctrl = ifrt.CopyArrays(%arg0) : (!array_t0) -> !array_t1 return } @@ -63,7 +63,7 @@ func.func @remap_attributes( // CHECK-DAG: mappings = [#vifrt.array_mapping_v1<0, 0, [#vifrt.mapping_v1<[0 : 1 : 1] to [0 : 1 : 1]>]>, #vifrt.array_mapping_v1<1, 0, [#vifrt.mapping_v1<[0 : 1 : 1] to [1 : 2 : 1]>]>] // CHECK-DAG: operandSegmentSizes = array // CHECK-SAME: }> - // CHECK-SAME: (!vifrt.array_v1, #vifrt.sharding_param_v1<1x1 to [0] on 1>, [0], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.array_v1, #vifrt.sharding_param_v1<1x1 to [0] on 1>, [1], memory_kind = "vifrt.default", layout = "vifrt.default">) -> (!vifrt.array_v1, #vifrt.sharding_param_v1<1x2 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.control_v1) + // CHECK-SAME: (!vifrt.array_v1, #vifrt.sharding_param_v2<1x1 to [0] on 1>, [0], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.array_v1, #vifrt.sharding_param_v2<1x1 to [0] on 1>, [1], memory_kind = "vifrt.default", layout = "vifrt.default">) -> (!vifrt.array_v1, #vifrt.sharding_param_v2<1x2 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.control_v1) %0, %ctrl_0 = ifrt.RemapArrays(%arg0, %arg1) mappings=[#ifrt.array_mapping<0, 0, [#ifrt.mapping<[0:1:1] to [0:1:1]>]>, #ifrt.array_mapping<1, 0, [#ifrt.mapping<[0:1:1] to [1:2:1]>]>] @@ -94,7 +94,7 @@ func.func @op_copy_arrays(%arg0: !array_cp0) -> !array_cp1 // CHECK-DAG: donated = false // CHECK-DAG: operandSegmentSizes = array // CHECK-SAME: }> - // CHECK-SAME: (!vifrt.array_v1, #vifrt.sharding_param_v1<1x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">) -> (!vifrt.array_v1, #vifrt.sharding_param_v1<1x1 to [0] on 2>, [2, 3], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.control_v1) + // CHECK-SAME: (!vifrt.array_v1, #vifrt.sharding_param_v2<1x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">) -> (!vifrt.array_v1, #vifrt.sharding_param_v2<1x1 to [0] on 2>, [2, 3], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.control_v1) %0, %ctrl = ifrt.CopyArrays(%arg0) : (!array_cp0) -> !array_cp1 return %0: !array_cp1 } @@ -113,7 +113,7 @@ func.func @op_assemble(%arg0: !array_ad0, %arg1: !array_ad1) // CHECK-SAME: <{ // CHECK-DAG: operandSegmentSizes = array // CHECK-SAME: }> - // CHECK-SAME: (!vifrt.array_v1, #vifrt.sharding_param_v1<1x1 to [0] on 1>, [0], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.array_v1, #vifrt.sharding_param_v1<1x1 to [0] on 1>, [1], memory_kind = "vifrt.default", layout = "vifrt.default">) -> (!vifrt.array_v1, #vifrt.sharding_param_v1<1x2 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.control_v1) + // CHECK-SAME: (!vifrt.array_v1, #vifrt.sharding_param_v2<1x1 to [0] on 1>, [0], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.array_v1, #vifrt.sharding_param_v2<1x1 to [0] on 1>, [1], memory_kind = "vifrt.default", layout = "vifrt.default">) -> (!vifrt.array_v1, #vifrt.sharding_param_v2<1x2 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.control_v1) %0, %ctrl_0 = ifrt.Assemble(%arg0, %arg1) : (!array_ad0, !array_ad1) -> !array_ad2 return @@ -123,7 +123,7 @@ func.func @op_assemble(%arg0: !array_ad0, %arg1: !array_ad1) // CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}): func.func @op_disassemble(%arg0: !array_ad2) attributes {ifrt.function} { // CHECK: "vifrt.DisassembleV1"(%[[ARG0]]) - // CHECK-SAME: (!vifrt.array_v1, #vifrt.sharding_param_v1<1x2 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">) -> (!vifrt.array_v1, #vifrt.sharding_param_v1<1x1 to [0] on 1>, [0], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.array_v1, #vifrt.sharding_param_v1<1x1 to [0] on 1>, [1], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.control_v1) + // CHECK-SAME: (!vifrt.array_v1, #vifrt.sharding_param_v2<1x2 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">) -> (!vifrt.array_v1, #vifrt.sharding_param_v2<1x1 to [0] on 1>, [0], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.array_v1, #vifrt.sharding_param_v2<1x1 to [0] on 1>, [1], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.control_v1) %0, %1, %ctrl_0 = ifrt.Disassemble(%arg0) : (!array_ad2) -> (!array_ad0, !array_ad1) return @@ -138,14 +138,14 @@ func.func @op_after(%arg0: !array_cp0, %arg1: !array_cp1) // CHECK-DAG: donated = false // CHECK-DAG: operandSegmentSizes = array // CHECK-SAME: }> - // CHECK-SAME: (!vifrt.array_v1, #vifrt.sharding_param_v1<1x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">) -> (!vifrt.array_v1, #vifrt.sharding_param_v1<1x1 to [0] on 2>, [2, 3], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.control_v1) + // CHECK-SAME: (!vifrt.array_v1, #vifrt.sharding_param_v2<1x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">) -> (!vifrt.array_v1, #vifrt.sharding_param_v2<1x1 to [0] on 2>, [2, 3], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.control_v1) %0, %ctrl_0 = ifrt.CopyArrays(%arg0) : (!array_cp0) -> !array_cp1 // CHECK: "vifrt.CopyArraysV1"(%[[OUT]]#0, %[[ARG1]], %[[OUT]]#1) // CHECK-SAME: <{ // CHECK-DAG: donated = false // CHECK-DAG: operandSegmentSizes = array // CHECK-SAME: }> - // CHECK-SAME: (!vifrt.array_v1, #vifrt.sharding_param_v1<1x1 to [0] on 2>, [2, 3], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.array_v1, #vifrt.sharding_param_v1<1x1 to [0] on 2>, [2, 3], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.control_v1) -> (!vifrt.array_v1, #vifrt.sharding_param_v1<1x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.array_v1, #vifrt.sharding_param_v1<1x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.control_v1) + // CHECK-SAME: (!vifrt.array_v1, #vifrt.sharding_param_v2<1x1 to [0] on 2>, [2, 3], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.array_v1, #vifrt.sharding_param_v2<1x1 to [0] on 2>, [2, 3], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.control_v1) -> (!vifrt.array_v1, #vifrt.sharding_param_v2<1x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.array_v1, #vifrt.sharding_param_v2<1x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.control_v1) %1, %2, %ctrl_1 = ifrt.CopyArrays(%0, %arg1) after %ctrl_0 : (!array_cp1, !array_cp1) -> (!array_cp0, !array_cp0) return %1, %2: !array_cp0, !array_cp0 @@ -167,7 +167,7 @@ func.func @op_call_loaded_executable( // CHECK-DAG: io_aliases = [], // CHECK-DAG: operandSegmentSizes = array // CHECK-SAME: }> - // CHECK-SAME: (!vifrt.array_v1, #vifrt.sharding_param_v1<1x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">) -> (!vifrt.array_v1, #vifrt.sharding_param_v1<1x2 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.control_v1) + // CHECK-SAME: (!vifrt.array_v1, #vifrt.sharding_param_v2<1x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">) -> (!vifrt.array_v1, #vifrt.sharding_param_v2<1x2 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.control_v1) %0, %ctrl_0 = ifrt.CallLoadedExecutable @test_loaded_executable1(%arg0) : (!array_le_in) -> !array_le_out // CHECK: "vifrt.CallLoadedExecutableV1"(%[[ARG0]]) @@ -177,7 +177,7 @@ func.func @op_call_loaded_executable( // CHECK-DAG: io_aliases = [] // CHECK-DAG: operandSegmentSizes = array // CHECK-SAME: }> - // CHECK-SAME: (!vifrt.array_v1, #vifrt.sharding_param_v1<1x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">) -> (!vifrt.array_v1, #vifrt.sharding_param_v1<1x2 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.control_v1) + // CHECK-SAME: (!vifrt.array_v1, #vifrt.sharding_param_v2<1x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">) -> (!vifrt.array_v1, #vifrt.sharding_param_v2<1x2 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.control_v1) %1, %ctrl_1 = ifrt.CallLoadedExecutable @test_loaded_executable1(%arg0) {donated_input_indices=array} : (!array_le_in) -> !array_le_out // CHECK: "vifrt.CallLoadedExecutableV1"(%[[ARG1]]) @@ -187,7 +187,7 @@ func.func @op_call_loaded_executable( // CHECK-DAG: io_aliases = [array] // CHECK-DAG: operandSegmentSizes = array // CHECK-SAME: }> - // CHECK-SAME: (!vifrt.array_v1, #vifrt.sharding_param_v1<1x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">) -> (!vifrt.array_v1, #vifrt.sharding_param_v1<1x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.control_v1) + // CHECK-SAME: (!vifrt.array_v1, #vifrt.sharding_param_v2<1x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">) -> (!vifrt.array_v1, #vifrt.sharding_param_v2<1x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.control_v1) %2, %ctrl_2 = ifrt.CallLoadedExecutable @test_loaded_executable2(%arg1) {io_aliases=[array]} : (!array_le_in) -> !array_le_in return @@ -196,7 +196,7 @@ func.func @op_call_loaded_executable( // CHECK: "vifrt.LoadedExecutableV1"() // CHECK-SAME: <{ // CHECK-DAG: devices = #vifrt -// CHECK-DAG: function_type = #vifrt.type_v1, #vifrt.sharding_param_v1<1x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">) -> !vifrt.array_v1, #vifrt.sharding_param_v1<1x2 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">>> +// CHECK-DAG: function_type = #vifrt.type_v1, #vifrt.sharding_param_v2<1x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">) -> !vifrt.array_v1, #vifrt.sharding_param_v2<1x2 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">>> // CHECK-DAG: sym_name = "test_loaded_executable1" // CHECK-SAME: }> ifrt.LoadedExecutable @test_loaded_executable1 on devices [0,1] @@ -205,7 +205,7 @@ ifrt.LoadedExecutable @test_loaded_executable1 on devices [0,1] // CHECK: "vifrt.LoadedExecutableV1"() // CHECK-SAME: <{ // CHECK-DAG: devices = #vifrt -// CHECK-DAG: function_type = #vifrt.type_v1, #vifrt.sharding_param_v1<1x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">) -> !vifrt.array_v1, #vifrt.sharding_param_v1<1x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">>> +// CHECK-DAG: function_type = #vifrt.type_v1, #vifrt.sharding_param_v2<1x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">) -> !vifrt.array_v1, #vifrt.sharding_param_v2<1x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">>> // CHECK-DAG: sym_name = "test_loaded_executable2" // CHECK-SAME: }> ifrt.LoadedExecutable @test_loaded_executable2 on devices [0,1] @@ -228,7 +228,7 @@ func.func @op_remap_arrays( // CHECK-DAG: mappings = [#vifrt.array_mapping_v1<0, 0, [#vifrt.mapping_v1<[0 : 1 : 1] to [0 : 1 : 1]>]>, #vifrt.array_mapping_v1<1, 0, [#vifrt.mapping_v1<[0 : 1 : 1] to [1 : 2 : 1]>]>] // CHECK-DAG: operandSegmentSizes = array // CHECK-SAME: }> - // CHECK-SAME: (!vifrt.array_v1, #vifrt.sharding_param_v1<1x1 to [0] on 1>, [0], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.array_v1, #vifrt.sharding_param_v1<1x1 to [0] on 1>, [1], memory_kind = "vifrt.default", layout = "vifrt.default">) -> (!vifrt.array_v1, #vifrt.sharding_param_v1<1x2 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.control_v1) + // CHECK-SAME: (!vifrt.array_v1, #vifrt.sharding_param_v2<1x1 to [0] on 1>, [0], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.array_v1, #vifrt.sharding_param_v2<1x1 to [0] on 1>, [1], memory_kind = "vifrt.default", layout = "vifrt.default">) -> (!vifrt.array_v1, #vifrt.sharding_param_v2<1x2 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.control_v1) %0, %ctrl_0 = ifrt.RemapArrays(%arg0, %arg1) mappings=[#ifrt.array_mapping<0, 0, [#ifrt.mapping<[0:1:1] to [0:1:1]>]>, #ifrt.array_mapping<1, 0, [#ifrt.mapping<[0:1:1] to [1:2:1]>]>] @@ -250,7 +250,7 @@ func.func @op_bitcast_arrays(%arg0: !array_bc0 {ifrt.donated}) -> !array_bc1 // CHECK-DAG: donated = true // CHECK-DAG: operandSegmentSizes = array // CHECK-SAME: }> - // CHECK-SAME: (!vifrt.array_v1, #vifrt.sharding_param_v1<1x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">) -> (!vifrt.array_v1, #vifrt.sharding_param_v1<1x1x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.control_v1) + // CHECK-SAME: (!vifrt.array_v1, #vifrt.sharding_param_v2<1x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">) -> (!vifrt.array_v1, #vifrt.sharding_param_v2<1x1x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.control_v1) %0, %ctrl_0 = ifrt.BitcastArrays(%arg0) {donated=true} : (!array_bc0) -> !array_bc1 return %0: !array_bc1 } @@ -270,7 +270,7 @@ func.func @op_reshard(%arg0: !array_r0, %arg1: !array_r0) // CHECK-DAG: donated = false // CHECK-DAG: operandSegmentSizes = array // CHECK-SAME: }> - // CHECK-SAME: (!vifrt.array_v1, #vifrt.sharding_param_v1<2 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.array_v1, #vifrt.sharding_param_v1<2 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">) -> (!vifrt.array_v1, #vifrt.sharding_param_v1<1 to [0] on 1>, [2], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.array_v1, #vifrt.sharding_param_v1<1 to [0] on 1>, [3], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.control_v1) + // CHECK-SAME: (!vifrt.array_v1, #vifrt.sharding_param_v2<2 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.array_v1, #vifrt.sharding_param_v2<2 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">) -> (!vifrt.array_v1, #vifrt.sharding_param_v2<1 to [0] on 1>, [2], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.array_v1, #vifrt.sharding_param_v2<1 to [0] on 1>, [3], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.control_v1) %0, %1, %ctrl_1 = ifrt.Reshard(%arg0, %arg1) : (!array_r0, !array_r0) -> (!array_r1, !array_r2) return %0, %1 : !array_r1, !array_r2 @@ -282,7 +282,7 @@ func.func @op_reshard(%arg0: !array_r0, %arg1: !array_r0) // CHECK: "vifrt.FuncV1"() // CHECK-SAME: <{ // CHECK-DAG: arg_attrs = [{vifrt.donated}, {vifrt.donated}] -// CHECK-DAG: function_type = #vifrt.type_v1, #vifrt.sharding_param_v1<2 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.array_v1, #vifrt.sharding_param_v1<2 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">) -> !vifrt.array_v1, #vifrt.sharding_param_v1<1 to [0] on 1>, [2], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.array_v1, #vifrt.sharding_param_v1<1 to [0] on 1>, [3], memory_kind = "vifrt.default", layout = "vifrt.default">>> +// CHECK-DAG: function_type = #vifrt.type_v1, #vifrt.sharding_param_v2<2 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.array_v1, #vifrt.sharding_param_v2<2 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">) -> !vifrt.array_v1, #vifrt.sharding_param_v2<1 to [0] on 1>, [2], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.array_v1, #vifrt.sharding_param_v2<1 to [0] on 1>, [3], memory_kind = "vifrt.default", layout = "vifrt.default">>> // CHECK-DAG: sym_name = "donated_arguments" // CHECK-DAG: res_attrs = [] // CHECK-SAME: }> @@ -295,10 +295,10 @@ func.func @donated_arguments( // CHECK-DAG: donated = true // CHECK-DAG: operandSegmentSizes = array // CHECK-SAME: }> - // CHECK-SAME: (!vifrt.array_v1, #vifrt.sharding_param_v1<2 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.array_v1, #vifrt.sharding_param_v1<2 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">) -> (!vifrt.array_v1, #vifrt.sharding_param_v1<1 to [0] on 1>, [2], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.array_v1, #vifrt.sharding_param_v1<1 to [0] on 1>, [3], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.control_v1) + // CHECK-SAME: (!vifrt.array_v1, #vifrt.sharding_param_v2<2 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.array_v1, #vifrt.sharding_param_v2<2 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">) -> (!vifrt.array_v1, #vifrt.sharding_param_v2<1 to [0] on 1>, [2], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.array_v1, #vifrt.sharding_param_v2<1 to [0] on 1>, [3], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.control_v1) %0, %1, %ctrl_1 = ifrt.Reshard(%arg0, %arg1) {donated=true} : (!array_r0, !array_r0) -> (!array_r1, !array_r2) - // CHECK: "vifrt.ReturnV1"(%[[OUT]]#0, %[[OUT]]#1) : (!vifrt.array_v1, #vifrt.sharding_param_v1<1 to [0] on 1>, [2], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.array_v1, #vifrt.sharding_param_v1<1 to [0] on 1>, [3], memory_kind = "vifrt.default", layout = "vifrt.default">) + // CHECK: "vifrt.ReturnV1"(%[[OUT]]#0, %[[OUT]]#1) : (!vifrt.array_v1, #vifrt.sharding_param_v2<1 to [0] on 1>, [2], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.array_v1, #vifrt.sharding_param_v2<1 to [0] on 1>, [3], memory_kind = "vifrt.default", layout = "vifrt.default">) return %0, %1 : !array_r1, !array_r2 } @@ -311,11 +311,11 @@ func.func @op_func_call(%arg0: !array_cp0) -> !array_cp1 // CHECK-DAG: donated = false // CHECK-DAG: operandSegmentSizes = array // CHECK-SAME: }> - // CHECK-SAME: (!vifrt.array_v1, #vifrt.sharding_param_v1<1x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">) -> (!vifrt.array_v1, #vifrt.sharding_param_v1<1x1 to [0] on 2>, [2, 3], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.control_v1) + // CHECK-SAME: (!vifrt.array_v1, #vifrt.sharding_param_v2<1x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">) -> (!vifrt.array_v1, #vifrt.sharding_param_v2<1x1 to [0] on 2>, [2, 3], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.control_v1) %0, %ctrl = ifrt.CopyArrays(%arg0) : (!array_cp0) -> !array_cp1 // CHECK: %[[OUT1:.+]] = "vifrt.CallFuncV1"(%[[OUT0]]#0) // CHECK-SAME: <{callee = @copy_back}> - // CHECK-SAME: (!vifrt.array_v1, #vifrt.sharding_param_v1<1x1 to [0] on 2>, [2, 3], memory_kind = "vifrt.default", layout = "vifrt.default">) -> !vifrt.array_v1, #vifrt.sharding_param_v1<1x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default"> + // CHECK-SAME: (!vifrt.array_v1, #vifrt.sharding_param_v2<1x1 to [0] on 2>, [2, 3], memory_kind = "vifrt.default", layout = "vifrt.default">) -> !vifrt.array_v1, #vifrt.sharding_param_v2<1x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default"> %1 = func.call @copy_back(%0) : (!array_cp1) -> !array_cp0 return %0: !array_cp1 } @@ -323,7 +323,7 @@ func.func @op_func_call(%arg0: !array_cp0) -> !array_cp1 // CHECK: "vifrt.FuncV1"() // CHECK-SAME: <{ // CHECK-DAG: arg_attrs = [] -// CHECK-DAG: function_type = #vifrt.type_v1, #vifrt.sharding_param_v1<1x1 to [0] on 2>, [2, 3], memory_kind = "vifrt.default", layout = "vifrt.default">) -> !vifrt.array_v1, #vifrt.sharding_param_v1<1x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">>> +// CHECK-DAG: function_type = #vifrt.type_v1, #vifrt.sharding_param_v2<1x1 to [0] on 2>, [2, 3], memory_kind = "vifrt.default", layout = "vifrt.default">) -> !vifrt.array_v1, #vifrt.sharding_param_v2<1x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">>> // CHECK-DAG: res_attrs = [] // CHECK-DAG: sym_name = "copy_back" // CHECK-DAG: sym_visibility = "vifrt.default" @@ -336,7 +336,7 @@ func.func @copy_back(%arg1: !array_cp1) -> !array_cp0 // CHECK-DAG: donated = false // CHECK-DAG: operandSegmentSizes = array // CHECK-SAME: }> - // CHECK-SAME: (!vifrt.array_v1, #vifrt.sharding_param_v1<1x1 to [0] on 2>, [2, 3], memory_kind = "vifrt.default", layout = "vifrt.default">) -> (!vifrt.array_v1, #vifrt.sharding_param_v1<1x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.control_v1) + // CHECK-SAME: (!vifrt.array_v1, #vifrt.sharding_param_v2<1x1 to [0] on 2>, [2, 3], memory_kind = "vifrt.default", layout = "vifrt.default">) -> (!vifrt.array_v1, #vifrt.sharding_param_v2<1x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.control_v1) %0, %ctrl = ifrt.CopyArrays(%arg1) : (!array_cp1) -> !array_cp0 return %0: !array_cp0 } @@ -361,7 +361,7 @@ func.func @op_call( // CHECK-DAG: io_aliases = [] // CHECK-DAG: operandSegmentSizes = array // CHECK-SAME: }> - // CHECK-SAME: (!vifrt.array_v1, #vifrt.sharding_param_v1<2x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">) -> (!vifrt.array_v1, #vifrt.sharding_param_v1<2x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.control_v1) + // CHECK-SAME: (!vifrt.array_v1, #vifrt.sharding_param_v2<2x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">) -> (!vifrt.array_v1, #vifrt.sharding_param_v2<2x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.control_v1) %0, %ctrl_0 = ifrt.Call @add_one::@main(%arg0) on devices [0,1] : (!array_op_call) -> !array_op_call @@ -375,7 +375,7 @@ func.func @op_call( // CHECK-DAG: io_aliases = [] // CHECK-DAG: operandSegmentSizes = array // CHECK-SAME: }> - // CHECK-SAME: (!vifrt.array_v1, #vifrt.sharding_param_v1<2x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.control_v1) -> (!vifrt.array_v1, #vifrt.sharding_param_v1<2x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.control_v1) + // CHECK-SAME: (!vifrt.array_v1, #vifrt.sharding_param_v2<2x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.control_v1) -> (!vifrt.array_v1, #vifrt.sharding_param_v2<2x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.control_v1) %1, %ctrl_1 = ifrt.Call @add_one::@main(%0) after %ctrl_0 on devices [0,1] : (!array_op_call) -> !array_op_call @@ -388,7 +388,7 @@ func.func @op_call( // CHECK-DAG: io_aliases = [] // CHECK-DAG: operandSegmentSizes = array // CHECK-SAME: }> - // CHECK-SAME: (!vifrt.array_v1, #vifrt.sharding_param_v1<2x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">) -> (!vifrt.array_v1, #vifrt.sharding_param_v1<2x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.control_v1) + // CHECK-SAME: (!vifrt.array_v1, #vifrt.sharding_param_v2<2x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">) -> (!vifrt.array_v1, #vifrt.sharding_param_v2<2x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.control_v1) %2, %ctrl_2 = ifrt.Call @"escaped-module"::@main(%arg0) on devices [0,1] : (!array_op_call) -> !array_op_call @@ -402,7 +402,7 @@ func.func @op_call( // CHECK-DAG: io_aliases = [] // CHECK-DAG: operandSegmentSizes = array // CHECK-SAME: }> - // CHECK-SAME: (!vifrt.array_v1, #vifrt.sharding_param_v1<2x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">) -> (!vifrt.array_v1, #vifrt.sharding_param_v1<2x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.control_v1) + // CHECK-SAME: (!vifrt.array_v1, #vifrt.sharding_param_v2<2x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">) -> (!vifrt.array_v1, #vifrt.sharding_param_v2<2x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.control_v1) %3, %ctrl_3 = ifrt.Call @add_one::@main(%arg0) on devices [0,1] {donated_input_indices=array} : (!array_op_call) -> !array_op_call @@ -416,7 +416,7 @@ func.func @op_call( // CHECK-DAG: io_aliases = [array] // CHECK-DAG: operandSegmentSizes = array // CHECK-SAME: }> - // CHECK-SAME: (!vifrt.array_v1, #vifrt.sharding_param_v1<2x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">) -> (!vifrt.array_v1, #vifrt.sharding_param_v1<2x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.control_v1) + // CHECK-SAME: (!vifrt.array_v1, #vifrt.sharding_param_v2<2x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">) -> (!vifrt.array_v1, #vifrt.sharding_param_v2<2x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.control_v1) %4, %ctrl_4 = ifrt.Call @add_two::@main(%arg1) on devices [0,1] {io_aliases=[array]} : (!array_op_call) -> !array_op_call diff --git a/third_party/xla/xla/python/ifrt/ir/tests/vifrt/ifrt_legalize_to_vifrt.0_2_0.mlir b/third_party/xla/xla/python/ifrt/ir/tests/vifrt/ifrt_legalize_to_vifrt.0_2_0.mlir index 91c0ae2b906f2b..0254974e3928f9 100644 --- a/third_party/xla/xla/python/ifrt/ir/tests/vifrt/ifrt_legalize_to_vifrt.0_2_0.mlir +++ b/third_party/xla/xla/python/ifrt/ir/tests/vifrt/ifrt_legalize_to_vifrt.0_2_0.mlir @@ -1,4 +1,4 @@ -// RUN: ifrt-opt %s --ifrt-legalize-to-vifrt --vifrt-to-version='target_version=0.2.0' --symbol-dce --mlir-print-op-generic -split-input-file | FileCheck %s +// RUN: ifrt-opt %s --ifrt-legalize-to-vifrt --symbol-dce --mlir-print-op-generic -split-input-file | FileCheck %s // RUN: ifrt-translate --serialize --ifrt_version=0.2.0 --atom_program_version=1.13.1 --strip_debuginfo %s | ifrt-translate --deserialize --strip_debuginfo | ifrt-opt > %t.0 // RUN: ifrt-opt %s > %t.1 // RUN: diff %t.0 %t.1 diff --git a/third_party/xla/xla/python/ifrt/ir/tests/vifrt/ifrt_legalize_to_vifrt.0_3_0.mlir b/third_party/xla/xla/python/ifrt/ir/tests/vifrt/ifrt_legalize_to_vifrt.0_3_0.mlir index dd28fd566aadb3..7c4644bbbe9253 100644 --- a/third_party/xla/xla/python/ifrt/ir/tests/vifrt/ifrt_legalize_to_vifrt.0_3_0.mlir +++ b/third_party/xla/xla/python/ifrt/ir/tests/vifrt/ifrt_legalize_to_vifrt.0_3_0.mlir @@ -1,4 +1,4 @@ -// RUN: ifrt-opt %s --ifrt-legalize-to-vifrt --vifrt-to-version='target_version=0.3.0' --symbol-dce --mlir-print-op-generic -split-input-file | FileCheck %s +// RUN: ifrt-opt %s --ifrt-legalize-to-vifrt --symbol-dce --mlir-print-op-generic -split-input-file | FileCheck %s // RUN: ifrt-translate --serialize --ifrt_version=0.3.0 --atom_program_version=1.13.1 --strip_debuginfo %s | ifrt-translate --deserialize --strip_debuginfo | ifrt-opt > %t.0 // RUN: ifrt-opt %s > %t.1 // RUN: diff %t.0 %t.1 diff --git a/third_party/xla/xla/python/ifrt/ir/tests/vifrt/ifrt_legalize_to_vifrt.0_4_0.mlir b/third_party/xla/xla/python/ifrt/ir/tests/vifrt/ifrt_legalize_to_vifrt.0_4_0.mlir deleted file mode 100644 index 4e9e90b828877a..00000000000000 --- a/third_party/xla/xla/python/ifrt/ir/tests/vifrt/ifrt_legalize_to_vifrt.0_4_0.mlir +++ /dev/null @@ -1,483 +0,0 @@ -// RUN: ifrt-opt %s --ifrt-legalize-to-vifrt --vifrt-to-version='target_version=0.4.0' --symbol-dce --mlir-print-op-generic -split-input-file | FileCheck %s -// RUN: ifrt-translate --serialize --ifrt_version=0.4.0 --atom_program_version=1.13.1 --strip_debuginfo %s | ifrt-translate --deserialize --strip_debuginfo | ifrt-opt > %t.0 -// RUN: ifrt-opt %s > %t.1 -// RUN: diff %t.0 %t.1 - -// RUN: ifrt-translate --deserialize --strip_debuginfo %s.bytes | ifrt-opt > %t.2 -// RUN: ifrt-opt %s > %t.3 -// RUN: diff %t.2 %t.3 - -// ============ Types and attributes ============ - -// Verifies conversion of the array and control types, and devices and sharding -// param attributes. -!array_t0 = !ifrt.array, - #ifrt.sharding_param<1x1 to [0] on 2>, [0,1]> -!array_t1 = !ifrt.array, - #ifrt.sharding_param<1x1 to [0] on 2>, [2,3]> -// CHECK-LABEL: "type_array_and_control" -// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}): -func.func @type_array_and_control(%arg0: !array_t0) attributes {ifrt.function} { - // CHECK: "vifrt.CopyArraysV2"(%[[ARG0]]) - // CHECK-SAME: <{ - // CHECK-DAG: donated = false - // CHECK-DAG: reuse = false - // CHECK-DAG: operandSegmentSizes = array - // CHECK-SAME: }> - // CHECK-SAME: (!vifrt.array_v1, #vifrt.sharding_param_v2<1x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">) -> (!vifrt.array_v1, #vifrt.sharding_param_v2<1x1 to [0] on 2>, [2, 3], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.control_v1) - %0, %ctrl = ifrt.CopyArrays(%arg0) : (!array_t0) -> !array_t1 - return -} - -!array_us0 = !ifrt.array, #ifrt.sharding_unspecified, [0,1]> -!array_us1 = !ifrt.array, #ifrt.sharding_unspecified, [2,3]> -// CHECK-LABEL: "attr_unspecified_sharding" -// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}): -func.func @attr_unspecified_sharding(%arg0: !array_us0) attributes {ifrt.function} { - // CHECK: "vifrt.CopyArraysV2"(%[[ARG0]]) - // CHECK-SAME: <{ - // CHECK-DAG: donated = false - // CHECK-DAG: reuse = false - // CHECK-DAG: operandSegmentSizes = array - // CHECK-SAME: }> - // CHECK-SAME: (!vifrt.array_v1, #vifrt.sharding_unspecified_v1, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">) -> (!vifrt.array_v1, #vifrt.sharding_unspecified_v1, [2, 3], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.control_v1) - %0, %ctrl = ifrt.CopyArrays(%arg0) : (!array_us0) -> !array_us1 - return -} - - -// Verify conversion of IntervalAttr, Mapping, ArrayMapping, -// MappingAttrArrayAttr and ArrayMappingAttrArrayAttr. -!array_rattr_in0 = !ifrt.array, - #ifrt.sharding_param<1x1 to [0] on 1>, [0]> -!array_rattr_in1 = !ifrt.array, - #ifrt.sharding_param<1x1 to [0] on 1>, [1]> -!array_rattr_out = !ifrt.array, - #ifrt.sharding_param<1x2 to [0] on 2>, [0,1]> -// CHECK-LABEL: "remap_attributes" -// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}): -func.func @remap_attributes( - %arg0: !array_rattr_in0 {ifrt.donated}, %arg1: !array_rattr_in1 {ifrt.donated}) - attributes {ifrt.function} { - // CHECK: "vifrt.RemapArraysV1"(%[[ARG0]], %[[ARG1]]) - // CHECK-SAME: <{ - // CHECK-DAG: donated = true - // CHECK-DAG: mappings = [#vifrt.array_mapping_v1<0, 0, [#vifrt.mapping_v1<[0 : 1 : 1] to [0 : 1 : 1]>]>, #vifrt.array_mapping_v1<1, 0, [#vifrt.mapping_v1<[0 : 1 : 1] to [1 : 2 : 1]>]>] - // CHECK-DAG: operandSegmentSizes = array - // CHECK-SAME: }> - // CHECK-SAME: (!vifrt.array_v1, #vifrt.sharding_param_v2<1x1 to [0] on 1>, [0], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.array_v1, #vifrt.sharding_param_v2<1x1 to [0] on 1>, [1], memory_kind = "vifrt.default", layout = "vifrt.default">) -> (!vifrt.array_v1, #vifrt.sharding_param_v2<1x2 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.control_v1) - %0, %ctrl_0 = ifrt.RemapArrays(%arg0, %arg1) - mappings=[#ifrt.array_mapping<0, 0, [#ifrt.mapping<[0:1:1] to [0:1:1]>]>, - #ifrt.array_mapping<1, 0, [#ifrt.mapping<[0:1:1] to [1:2:1]>]>] - {donated=true} - : (!array_rattr_in0, !array_rattr_in1) -> (!array_rattr_out) - return -} - -// CHECK-LABEL: "ifrt_function_attribute" -// CHECK-NOT: {ifrt.function} -// CHECK: {vifrt.function} -func.func @ifrt_function_attribute() attributes {ifrt.function} { - return -} - -// ============ Ops ============ - -!array_cp0 = !ifrt.array, - #ifrt.sharding_param<1x1 to [0] on 2>, [0,1]> -!array_cp1 = !ifrt.array, - #ifrt.sharding_param<1x1 to [0] on 2>, [2,3]> -// CHECK-LABEL: "op_copy_arrays" -// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}): -func.func @op_copy_arrays(%arg0: !array_cp0) -> !array_cp1 - attributes {ifrt.function} { - // CHECK: "vifrt.CopyArraysV2"(%[[ARG0]]) - // CHECK-SAME: <{ - // CHECK-DAG: donated = false - // CHECK-DAG: reuse = false - // CHECK-DAG: operandSegmentSizes = array - // CHECK-SAME: }> - // CHECK-SAME: (!vifrt.array_v1, #vifrt.sharding_param_v2<1x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">) -> (!vifrt.array_v1, #vifrt.sharding_param_v2<1x1 to [0] on 2>, [2, 3], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.control_v1) - %0, %ctrl = ifrt.CopyArrays(%arg0) : (!array_cp0) -> !array_cp1 - return %0: !array_cp1 -} - -!array_ad0 = !ifrt.array, - #ifrt.sharding_param<1x1 to [0] on 1>, [0]> -!array_ad1 = !ifrt.array, - #ifrt.sharding_param<1x1 to [0] on 1>, [1]> -!array_ad2 = !ifrt.array, - #ifrt.sharding_param<1x2 to [0] on 2>, [0,1]> -// CHECK-LABEL: "op_assemble" -// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}): -func.func @op_assemble(%arg0: !array_ad0, %arg1: !array_ad1) - attributes {ifrt.function} { - // CHECK: "vifrt.AssembleV1"(%[[ARG0]], %[[ARG1]]) - // CHECK-SAME: <{ - // CHECK-DAG: operandSegmentSizes = array - // CHECK-SAME: }> - // CHECK-SAME: (!vifrt.array_v1, #vifrt.sharding_param_v2<1x1 to [0] on 1>, [0], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.array_v1, #vifrt.sharding_param_v2<1x1 to [0] on 1>, [1], memory_kind = "vifrt.default", layout = "vifrt.default">) -> (!vifrt.array_v1, #vifrt.sharding_param_v2<1x2 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.control_v1) - %0, %ctrl_0 = ifrt.Assemble(%arg0, %arg1) - : (!array_ad0, !array_ad1) -> !array_ad2 - return -} - -// CHECK-LABEL: "op_disassemble" -// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}): -func.func @op_disassemble(%arg0: !array_ad2) attributes {ifrt.function} { - // CHECK: "vifrt.DisassembleV1"(%[[ARG0]]) - // CHECK-SAME: (!vifrt.array_v1, #vifrt.sharding_param_v2<1x2 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">) -> (!vifrt.array_v1, #vifrt.sharding_param_v2<1x1 to [0] on 1>, [0], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.array_v1, #vifrt.sharding_param_v2<1x1 to [0] on 1>, [1], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.control_v1) - %0, %1, %ctrl_0 = ifrt.Disassemble(%arg0) - : (!array_ad2) -> (!array_ad0, !array_ad1) - return -} - -// CHECK-LABEL: "op_after" -// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}): -func.func @op_after(%arg0: !array_cp0, %arg1: !array_cp1) - -> (!array_cp0, !array_cp0) attributes {ifrt.function} { - // CHECK: %[[OUT:.+]]:2 = "vifrt.CopyArraysV2"(%[[ARG0]]) - // CHECK-SAME: <{ - // CHECK-DAG: donated = false - // CHECK-DAG: reuse = false - // CHECK-DAG: operandSegmentSizes = array - // CHECK-SAME: }> - // CHECK-SAME: (!vifrt.array_v1, #vifrt.sharding_param_v2<1x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">) -> (!vifrt.array_v1, #vifrt.sharding_param_v2<1x1 to [0] on 2>, [2, 3], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.control_v1) - %0, %ctrl_0 = ifrt.CopyArrays(%arg0) : (!array_cp0) -> !array_cp1 - // CHECK: "vifrt.CopyArraysV2"(%[[OUT]]#0, %[[ARG1]], %[[OUT]]#1) - // CHECK-SAME: <{ - // CHECK-DAG: donated = false - // CHECK-DAG: reuse = false - // CHECK-DAG: operandSegmentSizes = array - // CHECK-SAME: }> - // CHECK-SAME: (!vifrt.array_v1, #vifrt.sharding_param_v2<1x1 to [0] on 2>, [2, 3], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.array_v1, #vifrt.sharding_param_v2<1x1 to [0] on 2>, [2, 3], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.control_v1) -> (!vifrt.array_v1, #vifrt.sharding_param_v2<1x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.array_v1, #vifrt.sharding_param_v2<1x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.control_v1) - %1, %2, %ctrl_1 = ifrt.CopyArrays(%0, %arg1) after %ctrl_0 - : (!array_cp1, !array_cp1) -> (!array_cp0, !array_cp0) - return %1, %2: !array_cp0, !array_cp0 -} - -!array_le_in = !ifrt.array, - #ifrt.sharding_param<1x1 to [0] on 2>, [0,1]> -!array_le_out = !ifrt.array, - #ifrt.sharding_param<1x2 to [0] on 2>, [0,1]> -// CHECK-LABEL: "op_call_loaded_executable" -// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}): -func.func @op_call_loaded_executable( - %arg0: !array_le_in {ifrt.donated}, %arg1: !array_le_in {ifrt.donated}) - attributes {ifrt.function} { - // CHECK: %[[OUT:.+]]:2 = "vifrt.CallLoadedExecutableV1"(%[[ARG0]]) - // CHECK-SAME: <{ - // CHECK-DAG: callee = @test_loaded_executable1 - // CHECK-DAG: donated_input_indices = array - // CHECK-DAG: io_aliases = [], - // CHECK-DAG: operandSegmentSizes = array - // CHECK-SAME: }> - // CHECK-SAME: (!vifrt.array_v1, #vifrt.sharding_param_v2<1x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">) -> (!vifrt.array_v1, #vifrt.sharding_param_v2<1x2 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.control_v1) - %0, %ctrl_0 = ifrt.CallLoadedExecutable @test_loaded_executable1(%arg0) - : (!array_le_in) -> !array_le_out - // CHECK: "vifrt.CallLoadedExecutableV1"(%[[ARG0]]) - // CHECK-SAME: <{ - // CHECK-DAG: callee = @test_loaded_executable1 - // CHECK-DAG: donated_input_indices = array - // CHECK-DAG: io_aliases = [] - // CHECK-DAG: operandSegmentSizes = array - // CHECK-SAME: }> - // CHECK-SAME: (!vifrt.array_v1, #vifrt.sharding_param_v2<1x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">) -> (!vifrt.array_v1, #vifrt.sharding_param_v2<1x2 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.control_v1) - %1, %ctrl_1 = ifrt.CallLoadedExecutable @test_loaded_executable1(%arg0) - {donated_input_indices=array} : (!array_le_in) -> !array_le_out - // CHECK: "vifrt.CallLoadedExecutableV1"(%[[ARG1]]) - // CHECK-SAME: <{ - // CHECK-DAG: callee = @test_loaded_executable2 - // CHECK-DAG: donated_input_indices = array - // CHECK-DAG: io_aliases = [array] - // CHECK-DAG: operandSegmentSizes = array - // CHECK-SAME: }> - // CHECK-SAME: (!vifrt.array_v1, #vifrt.sharding_param_v2<1x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">) -> (!vifrt.array_v1, #vifrt.sharding_param_v2<1x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.control_v1) - %2, %ctrl_2 = ifrt.CallLoadedExecutable @test_loaded_executable2(%arg1) - {io_aliases=[array]} : (!array_le_in) -> !array_le_in - return -} - -// CHECK: "vifrt.LoadedExecutableV1"() -// CHECK-SAME: <{ -// CHECK-DAG: devices = #vifrt -// CHECK-DAG: function_type = #vifrt.type_v1, #vifrt.sharding_param_v2<1x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">) -> !vifrt.array_v1, #vifrt.sharding_param_v2<1x2 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">>> -// CHECK-DAG: sym_name = "test_loaded_executable1" -// CHECK-SAME: }> -ifrt.LoadedExecutable @test_loaded_executable1 on devices [0,1] - : (!array_le_in) -> !array_le_out - -// CHECK: "vifrt.LoadedExecutableV1"() -// CHECK-SAME: <{ -// CHECK-DAG: devices = #vifrt -// CHECK-DAG: function_type = #vifrt.type_v1, #vifrt.sharding_param_v2<1x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">) -> !vifrt.array_v1, #vifrt.sharding_param_v2<1x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">>> -// CHECK-DAG: sym_name = "test_loaded_executable2" -// CHECK-SAME: }> -ifrt.LoadedExecutable @test_loaded_executable2 on devices [0,1] - : (!array_le_in) -> !array_le_in - -!array_ra_in0 = !ifrt.array, - #ifrt.sharding_param<1x1 to [0] on 1>, [0]> -!array_ra_in1 = !ifrt.array, - #ifrt.sharding_param<1x1 to [0] on 1>, [1]> -!array_ra_out = !ifrt.array, - #ifrt.sharding_param<1x2 to [0] on 2>, [0,1]> -// CHECK-LABEL: "op_remap_arrays" -// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}): -func.func @op_remap_arrays( - %arg0: !array_ra_in0 {ifrt.donated}, %arg1: !array_ra_in1 {ifrt.donated}) - attributes {ifrt.function} { - // CHECK: "vifrt.RemapArraysV1"(%[[ARG0]], %[[ARG1]]) - // CHECK-SAME: <{ - // CHECK-DAG: donated = true - // CHECK-DAG: mappings = [#vifrt.array_mapping_v1<0, 0, [#vifrt.mapping_v1<[0 : 1 : 1] to [0 : 1 : 1]>]>, #vifrt.array_mapping_v1<1, 0, [#vifrt.mapping_v1<[0 : 1 : 1] to [1 : 2 : 1]>]>] - // CHECK-DAG: operandSegmentSizes = array - // CHECK-SAME: }> - // CHECK-SAME: (!vifrt.array_v1, #vifrt.sharding_param_v2<1x1 to [0] on 1>, [0], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.array_v1, #vifrt.sharding_param_v2<1x1 to [0] on 1>, [1], memory_kind = "vifrt.default", layout = "vifrt.default">) -> (!vifrt.array_v1, #vifrt.sharding_param_v2<1x2 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.control_v1) - %0, %ctrl_0 = ifrt.RemapArrays(%arg0, %arg1) - mappings=[#ifrt.array_mapping<0, 0, [#ifrt.mapping<[0:1:1] to [0:1:1]>]>, - #ifrt.array_mapping<1, 0, [#ifrt.mapping<[0:1:1] to [1:2:1]>]>] - {donated=true} - : (!array_ra_in0, !array_ra_in1) -> (!array_ra_out) - return -} - -!array_bc0 = !ifrt.array, - #ifrt.sharding_param<1x1 to [0] on 2>, [0,1]> -!array_bc1 = !ifrt.array, - #ifrt.sharding_param<1x1x1 to [0] on 2>, [0,1]> -// CHECK-LABEL: "op_bitcast_arrays" -// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}): -func.func @op_bitcast_arrays(%arg0: !array_bc0 {ifrt.donated}) -> !array_bc1 - attributes {ifrt.function} { - // CHECK: "vifrt.BitcastArraysV1"(%[[ARG0]]) - // CHECK-SAME: <{ - // CHECK-DAG: donated = true - // CHECK-DAG: operandSegmentSizes = array - // CHECK-SAME: }> - // CHECK-SAME: (!vifrt.array_v1, #vifrt.sharding_param_v2<1x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">) -> (!vifrt.array_v1, #vifrt.sharding_param_v2<1x1x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.control_v1) - %0, %ctrl_0 = ifrt.BitcastArrays(%arg0) {donated=true} : (!array_bc0) -> !array_bc1 - return %0: !array_bc1 -} - -!array_r0 = !ifrt.array, - #ifrt.sharding_param<2 to [0] on 2>, [0,1]> -!array_r1 = !ifrt.array, - #ifrt.sharding_param<1 to [0] on 1>, [2]> -!array_r2 = !ifrt.array, - #ifrt.sharding_param<1 to [0] on 1>, [3]> -// CHECK-LABEL: "op_reshard" -// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}): -func.func @op_reshard(%arg0: !array_r0, %arg1: !array_r0) - -> (!array_r1, !array_r2) attributes {ifrt.function} { - // CHECK: "vifrt.ReshardV1"(%[[ARG0]], %[[ARG1]]) - // CHECK-SAME: <{ - // CHECK-DAG: donated = false - // CHECK-DAG: operandSegmentSizes = array - // CHECK-SAME: }> - // CHECK-SAME: (!vifrt.array_v1, #vifrt.sharding_param_v2<2 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.array_v1, #vifrt.sharding_param_v2<2 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">) -> (!vifrt.array_v1, #vifrt.sharding_param_v2<1 to [0] on 1>, [2], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.array_v1, #vifrt.sharding_param_v2<1 to [0] on 1>, [3], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.control_v1) - %0, %1, %ctrl_1 = ifrt.Reshard(%arg0, %arg1) - : (!array_r0, !array_r0) -> (!array_r1, !array_r2) - return %0, %1 : !array_r1, !array_r2 -} - -// Verifies that the FuncOp, ReturnOp and donated arguments are converted to -// VIFRT. - -// CHECK: "vifrt.FuncV1"() -// CHECK-SAME: <{ -// CHECK-DAG: arg_attrs = [{vifrt.donated}, {vifrt.donated}] -// CHECK-DAG: function_type = #vifrt.type_v1, #vifrt.sharding_param_v2<2 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.array_v1, #vifrt.sharding_param_v2<2 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">) -> !vifrt.array_v1, #vifrt.sharding_param_v2<1 to [0] on 1>, [2], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.array_v1, #vifrt.sharding_param_v2<1 to [0] on 1>, [3], memory_kind = "vifrt.default", layout = "vifrt.default">>> -// CHECK-DAG: sym_name = "donated_arguments" -// CHECK-DAG: res_attrs = [] -// CHECK-SAME: }> -// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}): -func.func @donated_arguments( - %arg0: !array_r0 {ifrt.donated}, %arg1: !array_r0 {ifrt.donated}) - -> (!array_r1, !array_r2) attributes {ifrt.function} { - // CHECK: %[[OUT:.+]]:3 = "vifrt.ReshardV1"(%[[ARG0]], %[[ARG1]]) - // CHECK-SAME: <{ - // CHECK-DAG: donated = true - // CHECK-DAG: operandSegmentSizes = array - // CHECK-SAME: }> - // CHECK-SAME: (!vifrt.array_v1, #vifrt.sharding_param_v2<2 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.array_v1, #vifrt.sharding_param_v2<2 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">) -> (!vifrt.array_v1, #vifrt.sharding_param_v2<1 to [0] on 1>, [2], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.array_v1, #vifrt.sharding_param_v2<1 to [0] on 1>, [3], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.control_v1) - %0, %1, %ctrl_1 = ifrt.Reshard(%arg0, %arg1) {donated=true} - : (!array_r0, !array_r0) -> (!array_r1, !array_r2) - // CHECK: "vifrt.ReturnV1"(%[[OUT]]#0, %[[OUT]]#1) : (!vifrt.array_v1, #vifrt.sharding_param_v2<1 to [0] on 1>, [2], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.array_v1, #vifrt.sharding_param_v2<1 to [0] on 1>, [3], memory_kind = "vifrt.default", layout = "vifrt.default">) - return %0, %1 : !array_r1, !array_r2 -} - -// CHECK-LABEL: "op_func_call" -// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}): -func.func @op_func_call(%arg0: !array_cp0) -> !array_cp1 - attributes {ifrt.function} { - // CHECK: %[[OUT0:.+]]:2 = "vifrt.CopyArraysV2"(%[[ARG0]]) - // CHECK-SAME: <{ - // CHECK-DAG: donated = false - // CHECK-DAG: reuse = false - // CHECK-DAG: operandSegmentSizes = array - // CHECK-SAME: }> - // CHECK-SAME: (!vifrt.array_v1, #vifrt.sharding_param_v2<1x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">) -> (!vifrt.array_v1, #vifrt.sharding_param_v2<1x1 to [0] on 2>, [2, 3], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.control_v1) - %0, %ctrl = ifrt.CopyArrays(%arg0) : (!array_cp0) -> !array_cp1 - // CHECK: %[[OUT1:.+]] = "vifrt.CallFuncV1"(%[[OUT0]]#0) - // CHECK-SAME: <{callee = @copy_back}> - // CHECK-SAME: (!vifrt.array_v1, #vifrt.sharding_param_v2<1x1 to [0] on 2>, [2, 3], memory_kind = "vifrt.default", layout = "vifrt.default">) -> !vifrt.array_v1, #vifrt.sharding_param_v2<1x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default"> - %1 = func.call @copy_back(%0) : (!array_cp1) -> !array_cp0 - return %0: !array_cp1 -} - -// CHECK: "vifrt.FuncV1"() -// CHECK-SAME: <{ -// CHECK-DAG: arg_attrs = [] -// CHECK-DAG: function_type = #vifrt.type_v1, #vifrt.sharding_param_v2<1x1 to [0] on 2>, [2, 3], memory_kind = "vifrt.default", layout = "vifrt.default">) -> !vifrt.array_v1, #vifrt.sharding_param_v2<1x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">>> -// CHECK-DAG: res_attrs = [] -// CHECK-DAG: sym_name = "copy_back" -// CHECK-DAG: sym_visibility = "vifrt.default" -// CHECK-SAME: }> -// CHECK-NEXT: (%[[ARG1:.*]]: {{.*}}): -func.func @copy_back(%arg1: !array_cp1) -> !array_cp0 - attributes {ifrt.function} { - // CHECK: "vifrt.CopyArraysV2"(%[[ARG1]]) - // CHECK-SAME: <{ - // CHECK-DAG: donated = false - // CHECK-DAG: reuse = false - // CHECK-DAG: operandSegmentSizes = array - // CHECK-SAME: }> - // CHECK-SAME: (!vifrt.array_v1, #vifrt.sharding_param_v2<1x1 to [0] on 2>, [2, 3], memory_kind = "vifrt.default", layout = "vifrt.default">) -> (!vifrt.array_v1, #vifrt.sharding_param_v2<1x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.control_v1) - %0, %ctrl = ifrt.CopyArrays(%arg1) : (!array_cp1) -> !array_cp0 - return %0: !array_cp0 -} - -!token0 = !ifrt.array, - #ifrt.sharding_param< to [0] on 2>, [0, 1]> -!token1 = !ifrt.array, - #ifrt.sharding_param< to [0] on 2>, [2, 3]> -// CHECK: "vifrt.FuncV1"() -// CHECK-SAME: <{ -// CHECK-DAG: arg_attrs = [] -// CHECK-DAG: function_type = #vifrt.type_v1, #vifrt.sharding_param_v2< to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">) -> !vifrt.array_v1, #vifrt.sharding_param_v2< to [0] on 2>, [2, 3], memory_kind = "vifrt.default", layout = "vifrt.default">>> -// CHECK-DAG: res_attrs = [] -// CHECK-DAG: sym_name = "token_type" -// CHECK-DAG: sym_visibility = "vifrt.default" -// CHECK-SAME: }> -// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}): -func.func @token_type(%arg0: !token0) -> !token1 attributes {ifrt.function} { - // CHECK: "vifrt.CopyArraysV2"(%[[ARG0]]) - // CHECK-SAME: <{ - // CHECK-DAG: donated = false - // CHECK-DAG: reuse = false - // CHECK-DAG: operandSegmentSizes = array - // CHECK-SAME: }> - // CHECK-SAME: (!vifrt.array_v1, #vifrt.sharding_param_v2< to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">) -> (!vifrt.array_v1, #vifrt.sharding_param_v2< to [0] on 2>, [2, 3], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.control_v1) - %0, %ctrl = ifrt.CopyArrays(%arg0) : (!token0) -> !token1 - return %0: !token1 -} - -// Important: The test verifying CallOps must be last. This is necessary because -// in order to test serialization roundtrip the tests in this file are not split -// into per file tests. However, during deserialization we do not know where to -// re-introduce the atom program modules within the module, and thus we append -// them at the end. -!array_op_call = !ifrt.array, - #ifrt.sharding_param<2x1 to [0] on 2>, [0,1]> -// CHECK-LABEL: "op_call" -// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}): -func.func @op_call( - %arg0: !array_op_call {ifrt.donated}, %arg1: !array_op_call {ifrt.donated}) - -> !array_op_call attributes {ifrt.function} { - // CHECK: %[[OUT0:.+]]:2 = "vifrt.CallV1"(%[[ARG0]]) - // CHECK-SAME: <{ - // CHECK-DAG: callee = "@add_one::@main" - // CHECK-DAG: devices = #vifrt - // CHECK-DAG: donated_input_indices = array - // CHECK-DAG: io_aliases = [] - // CHECK-DAG: operandSegmentSizes = array - // CHECK-SAME: }> - // CHECK-SAME: (!vifrt.array_v1, #vifrt.sharding_param_v2<2x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">) -> (!vifrt.array_v1, #vifrt.sharding_param_v2<2x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.control_v1) - %0, %ctrl_0 = ifrt.Call @add_one::@main(%arg0) on devices [0,1] - : (!array_op_call) -> !array_op_call - - // Verifies that the control value is passed to the next call. - - // CHECK: %[[OUT1:.+]]:2 = "vifrt.CallV1"(%[[OUT0]]#0, %[[OUT0]]#1) - // CHECK-SAME: <{ - // CHECK-DAG: callee = "@add_one::@main" - // CHECK-DAG: devices = #vifrt - // CHECK-DAG: donated_input_indices = array - // CHECK-DAG: io_aliases = [] - // CHECK-DAG: operandSegmentSizes = array - // CHECK-SAME: }> - // CHECK-SAME: (!vifrt.array_v1, #vifrt.sharding_param_v2<2x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.control_v1) -> (!vifrt.array_v1, #vifrt.sharding_param_v2<2x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.control_v1) - %1, %ctrl_1 = ifrt.Call @add_one::@main(%0) after %ctrl_0 on devices [0,1] - : (!array_op_call) -> !array_op_call - - // Verifies that escaped symbol attr is correctly handled. - // CHECK: %[[OUT2:.+]]:2 = "vifrt.CallV1"(%[[ARG0]]) - // CHECK-SAME: <{ - // CHECK-DAG: callee = "@escaped-module::@main" - // CHECK-DAG: devices = #vifrt - // CHECK-DAG: donated_input_indices = array - // CHECK-DAG: io_aliases = [] - // CHECK-DAG: operandSegmentSizes = array - // CHECK-SAME: }> - // CHECK-SAME: (!vifrt.array_v1, #vifrt.sharding_param_v2<2x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">) -> (!vifrt.array_v1, #vifrt.sharding_param_v2<2x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.control_v1) - %2, %ctrl_2 = ifrt.Call @"escaped-module"::@main(%arg0) on devices [0,1] - : (!array_op_call) -> !array_op_call - - // Verifies that the donated input indices attribute is converted. - - // CHECK: "vifrt.CallV1"(%[[ARG0]]) - // CHECK-SAME: <{ - // CHECK-DAG: callee = "@add_one::@main" - // CHECK-DAG: devices = #vifrt - // CHECK-DAG: donated_input_indices = array - // CHECK-DAG: io_aliases = [] - // CHECK-DAG: operandSegmentSizes = array - // CHECK-SAME: }> - // CHECK-SAME: (!vifrt.array_v1, #vifrt.sharding_param_v2<2x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">) -> (!vifrt.array_v1, #vifrt.sharding_param_v2<2x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.control_v1) - %3, %ctrl_3 = ifrt.Call @add_one::@main(%arg0) on devices [0,1] - {donated_input_indices=array} : (!array_op_call) -> !array_op_call - - // Verifies that the io_aliases attribute is converted. - - // CHECK: "vifrt.CallV1"(%[[ARG1]]) - // CHECK-SAME: <{ - // CHECK-DAG: callee = "@add_two::@main" - // CHECK-DAG: devices = #vifrt - // CHECK-DAG: donated_input_indices = array, - // CHECK-DAG: io_aliases = [array] - // CHECK-DAG: operandSegmentSizes = array - // CHECK-SAME: }> - // CHECK-SAME: (!vifrt.array_v1, #vifrt.sharding_param_v2<2x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">) -> (!vifrt.array_v1, #vifrt.sharding_param_v2<2x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.control_v1) - %4, %ctrl_4 = ifrt.Call @add_two::@main(%arg1) on devices [0,1] - {io_aliases=[array]} : (!array_op_call) -> !array_op_call - - return %1 : !array_op_call -} - -// CHECK-NOT @add_one -module @add_one attributes {sym_visibility = "private"} { - func.func @main(%arg0: tensor<2x2xi32>) -> tensor<2x2xi32> { - %0 = stablehlo.constant dense<1> : tensor<2x2xi32> - %1 = stablehlo.add %arg0, %0 : tensor<2x2xi32> - return %1 : tensor<2x2xi32> - } -} - -// CHECK-NOT @"escaped-module" -module @"escaped-module" attributes {sym_visibility = "private"} { - func.func @main(%arg0: tensor<2x2xi32>) -> tensor<2x2xi32> { - %0 = stablehlo.constant dense<2> : tensor<2x2xi32> - %1 = stablehlo.add %arg0, %0 : tensor<2x2xi32> - return %1 : tensor<2x2xi32> - } -} - -// CHECK-NOT @add_two -module @add_two attributes {sym_visibility = "private"} { - func.func @main(%arg0: tensor<2x2xi32>) -> tensor<2x2xi32> { - %0 = stablehlo.constant dense<2> : tensor<2x2xi32> - %1 = stablehlo.add %arg0, %0 : tensor<2x2xi32> - return %1 : tensor<2x2xi32> - } -} diff --git a/third_party/xla/xla/python/ifrt/ir/tests/vifrt/ifrt_legalize_to_vifrt.0_4_0.mlir.bytes b/third_party/xla/xla/python/ifrt/ir/tests/vifrt/ifrt_legalize_to_vifrt.0_4_0.mlir.bytes deleted file mode 100644 index 8a06e422ca8a6b..00000000000000 Binary files a/third_party/xla/xla/python/ifrt/ir/tests/vifrt/ifrt_legalize_to_vifrt.0_4_0.mlir.bytes and /dev/null differ diff --git a/third_party/xla/xla/python/ifrt/ir/tests/vifrt/ifrt_legalize_to_vifrt.mlir b/third_party/xla/xla/python/ifrt/ir/tests/vifrt/ifrt_legalize_to_vifrt.mlir index c08ab30503cfb3..0e1299e464b653 100644 --- a/third_party/xla/xla/python/ifrt/ir/tests/vifrt/ifrt_legalize_to_vifrt.mlir +++ b/third_party/xla/xla/python/ifrt/ir/tests/vifrt/ifrt_legalize_to_vifrt.mlir @@ -14,10 +14,9 @@ // CHECK-LABEL: "type_array_and_control" // CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}): func.func @type_array_and_control(%arg0: !array_t0) attributes {ifrt.function} { - // CHECK: "vifrt.CopyArraysV2"(%[[ARG0]]) + // CHECK: "vifrt.CopyArraysV1"(%[[ARG0]]) // CHECK-SAME: <{ // CHECK-DAG: donated = false - // CHECK-DAG: reuse = false // CHECK-DAG: operandSegmentSizes = array // CHECK-SAME: }> // CHECK-SAME: (!vifrt.array_v1, #vifrt.sharding_param_v2<1x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">) -> (!vifrt.array_v1, #vifrt.sharding_param_v2<1x1 to [0] on 2>, [2, 3], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.control_v1) @@ -30,10 +29,9 @@ func.func @type_array_and_control(%arg0: !array_t0) attributes {ifrt.function} { // CHECK-LABEL: "attr_unspecified_sharding" // CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}): func.func @attr_unspecified_sharding(%arg0: !array_us0) attributes {ifrt.function} { - // CHECK: "vifrt.CopyArraysV2"(%[[ARG0]]) + // CHECK: "vifrt.CopyArraysV1"(%[[ARG0]]) // CHECK-SAME: <{ // CHECK-DAG: donated = false - // CHECK-DAG: reuse = false // CHECK-DAG: operandSegmentSizes = array // CHECK-SAME: }> // CHECK-SAME: (!vifrt.array_v1, #vifrt.sharding_unspecified_v1, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">) -> (!vifrt.array_v1, #vifrt.sharding_unspecified_v1, [2, 3], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.control_v1) @@ -87,10 +85,9 @@ func.func @ifrt_function_attribute() attributes {ifrt.function} { // CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}): func.func @op_copy_arrays(%arg0: !array_cp0) -> !array_cp1 attributes {ifrt.function} { - // CHECK: "vifrt.CopyArraysV2"(%[[ARG0]]) + // CHECK: "vifrt.CopyArraysV1"(%[[ARG0]]) // CHECK-SAME: <{ // CHECK-DAG: donated = false - // CHECK-DAG: reuse = false // CHECK-DAG: operandSegmentSizes = array // CHECK-SAME: }> // CHECK-SAME: (!vifrt.array_v1, #vifrt.sharding_param_v2<1x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">) -> (!vifrt.array_v1, #vifrt.sharding_param_v2<1x1 to [0] on 2>, [2, 3], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.control_v1) @@ -132,18 +129,16 @@ func.func @op_disassemble(%arg0: !array_ad2) attributes {ifrt.function} { // CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}): func.func @op_after(%arg0: !array_cp0, %arg1: !array_cp1) -> (!array_cp0, !array_cp0) attributes {ifrt.function} { - // CHECK: %[[OUT:.+]]:2 = "vifrt.CopyArraysV2"(%[[ARG0]]) + // CHECK: %[[OUT:.+]]:2 = "vifrt.CopyArraysV1"(%[[ARG0]]) // CHECK-SAME: <{ // CHECK-DAG: donated = false - // CHECK-DAG: reuse = false // CHECK-DAG: operandSegmentSizes = array // CHECK-SAME: }> // CHECK-SAME: (!vifrt.array_v1, #vifrt.sharding_param_v2<1x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">) -> (!vifrt.array_v1, #vifrt.sharding_param_v2<1x1 to [0] on 2>, [2, 3], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.control_v1) %0, %ctrl_0 = ifrt.CopyArrays(%arg0) : (!array_cp0) -> !array_cp1 - // CHECK: "vifrt.CopyArraysV2"(%[[OUT]]#0, %[[ARG1]], %[[OUT]]#1) + // CHECK: "vifrt.CopyArraysV1"(%[[OUT]]#0, %[[ARG1]], %[[OUT]]#1) // CHECK-SAME: <{ // CHECK-DAG: donated = false - // CHECK-DAG: reuse = false // CHECK-DAG: operandSegmentSizes = array // CHECK-SAME: }> // CHECK-SAME: (!vifrt.array_v1, #vifrt.sharding_param_v2<1x1 to [0] on 2>, [2, 3], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.array_v1, #vifrt.sharding_param_v2<1x1 to [0] on 2>, [2, 3], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.control_v1) -> (!vifrt.array_v1, #vifrt.sharding_param_v2<1x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.array_v1, #vifrt.sharding_param_v2<1x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.control_v1) @@ -307,10 +302,9 @@ func.func @donated_arguments( // CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}): func.func @op_func_call(%arg0: !array_cp0) -> !array_cp1 attributes {ifrt.function} { - // CHECK: %[[OUT0:.+]]:2 = "vifrt.CopyArraysV2"(%[[ARG0]]) + // CHECK: %[[OUT0:.+]]:2 = "vifrt.CopyArraysV1"(%[[ARG0]]) // CHECK-SAME: <{ // CHECK-DAG: donated = false - // CHECK-DAG: reuse = false // CHECK-DAG: operandSegmentSizes = array // CHECK-SAME: }> // CHECK-SAME: (!vifrt.array_v1, #vifrt.sharding_param_v2<1x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">) -> (!vifrt.array_v1, #vifrt.sharding_param_v2<1x1 to [0] on 2>, [2, 3], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.control_v1) @@ -333,10 +327,9 @@ func.func @op_func_call(%arg0: !array_cp0) -> !array_cp1 // CHECK-NEXT: (%[[ARG1:.*]]: {{.*}}): func.func @copy_back(%arg1: !array_cp1) -> !array_cp0 attributes {ifrt.function} { - // CHECK: "vifrt.CopyArraysV2"(%[[ARG1]]) + // CHECK: "vifrt.CopyArraysV1"(%[[ARG1]]) // CHECK-SAME: <{ // CHECK-DAG: donated = false - // CHECK-DAG: reuse = false // CHECK-DAG: operandSegmentSizes = array // CHECK-SAME: }> // CHECK-SAME: (!vifrt.array_v1, #vifrt.sharding_param_v2<1x1 to [0] on 2>, [2, 3], memory_kind = "vifrt.default", layout = "vifrt.default">) -> (!vifrt.array_v1, #vifrt.sharding_param_v2<1x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.control_v1) @@ -358,10 +351,9 @@ func.func @copy_back(%arg1: !array_cp1) -> !array_cp0 // CHECK-SAME: }> // CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}): func.func @token_type(%arg0: !token0) -> !token1 attributes {ifrt.function} { - // CHECK: "vifrt.CopyArraysV2"(%[[ARG0]]) + // CHECK: "vifrt.CopyArraysV1"(%[[ARG0]]) // CHECK-SAME: <{ // CHECK-DAG: donated = false - // CHECK-DAG: reuse = false // CHECK-DAG: operandSegmentSizes = array // CHECK-SAME: }> // CHECK-SAME: (!vifrt.array_v1, #vifrt.sharding_param_v2< to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">) -> (!vifrt.array_v1, #vifrt.sharding_param_v2< to [0] on 2>, [2, 3], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.control_v1) diff --git a/third_party/xla/xla/python/ifrt/ir/tests/vifrt/vifrt_to_version_downgrade.0_3_0.mlir b/third_party/xla/xla/python/ifrt/ir/tests/vifrt/vifrt_to_version_downgrade.0_3_0.mlir deleted file mode 100644 index abc264eb9aa01d..00000000000000 --- a/third_party/xla/xla/python/ifrt/ir/tests/vifrt/vifrt_to_version_downgrade.0_3_0.mlir +++ /dev/null @@ -1,457 +0,0 @@ -// RUN: ifrt-opt %s --ifrt-legalize-to-vifrt --vifrt-to-version='target_version=0.3.0' --symbol-dce --mlir-print-op-generic -split-input-file | FileCheck %s - -// ============ Types and attributes ============ - -// Verifies conversion of the array and control types, and devices and sharding -// param attributes. -!array_t0 = !ifrt.array, - #ifrt.sharding_param<1x1 to [0] on 2>, [0,1]> -!array_t1 = !ifrt.array, - #ifrt.sharding_param<1x1 to [0] on 2>, [2,3]> -// CHECK-LABEL: "type_array_and_control" -// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}): -func.func @type_array_and_control(%arg0: !array_t0) attributes {ifrt.function} { - // CHECK: "vifrt.CopyArraysV1"(%[[ARG0]]) - // CHECK-SAME: <{ - // CHECK-DAG: donated = false - // CHECK-DAG: operandSegmentSizes = array - // CHECK-SAME: }> - // CHECK-SAME: (!vifrt.array_v1, #vifrt.sharding_param_v2<1x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">) -> (!vifrt.array_v1, #vifrt.sharding_param_v2<1x1 to [0] on 2>, [2, 3], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.control_v1) - %0, %ctrl = ifrt.CopyArrays(%arg0) : (!array_t0) -> !array_t1 - return -} - -!array_us0 = !ifrt.array, #ifrt.sharding_unspecified, [0,1]> -!array_us1 = !ifrt.array, #ifrt.sharding_unspecified, [2,3]> -// CHECK-LABEL: "attr_unspecified_sharding" -// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}): -func.func @attr_unspecified_sharding(%arg0: !array_us0) attributes {ifrt.function} { - // CHECK: "vifrt.CopyArraysV1"(%[[ARG0]]) - // CHECK-SAME: <{ - // CHECK-DAG: donated = false - // CHECK-DAG: operandSegmentSizes = array - // CHECK-SAME: }> - // CHECK-SAME: (!vifrt.array_v1, #vifrt.sharding_unspecified_v1, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">) -> (!vifrt.array_v1, #vifrt.sharding_unspecified_v1, [2, 3], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.control_v1) - %0, %ctrl = ifrt.CopyArrays(%arg0) : (!array_us0) -> !array_us1 - return -} - -!array_unreduced = !ifrt.array, - #ifrt.sharding_param<2 to [0,1] on 2x2 unreduced [1]>, - [0, 1, 2, 3]> -// CHECK-LABEL: "array_with_unreduced_axes" -// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}): -func.func @array_with_unreduced_axes(%arg0: !array_unreduced) attributes {ifrt.function} { - // CHECK: "vifrt.CopyArraysV1"(%[[ARG0]]) - // CHECK-SAME: <{ - // CHECK-DAG: donated = false - // CHECK-DAG: operandSegmentSizes = array - // CHECK-SAME: }> - // CHECK-SAME: (!vifrt.array_v1, #vifrt.sharding_param_v2<2 to [0, 1] on 2x2 unreduced [1]>, [0, 1, 2, 3], memory_kind = "vifrt.default", layout = "vifrt.default">) -> (!vifrt.array_v1, #vifrt.sharding_param_v2<2 to [0, 1] on 2x2 unreduced [1]>, [0, 1, 2, 3], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.control_v1) - %0, %ctrl = ifrt.CopyArrays(%arg0) : (!array_unreduced) -> !array_unreduced - return -} - -// Verify conversion of IntervalAttr, Mapping, ArrayMapping, -// MappingAttrArrayAttr and ArrayMappingAttrArrayAttr. -!array_rattr_in0 = !ifrt.array, - #ifrt.sharding_param<1x1 to [0] on 1>, [0]> -!array_rattr_in1 = !ifrt.array, - #ifrt.sharding_param<1x1 to [0] on 1>, [1]> -!array_rattr_out = !ifrt.array, - #ifrt.sharding_param<1x2 to [0] on 2>, [0,1]> -// CHECK-LABEL: "remap_attributes" -// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}): -func.func @remap_attributes( - %arg0: !array_rattr_in0 {ifrt.donated}, %arg1: !array_rattr_in1 {ifrt.donated}) - attributes {ifrt.function} { - // CHECK: "vifrt.RemapArraysV1"(%[[ARG0]], %[[ARG1]]) - // CHECK-SAME: <{ - // CHECK-DAG: donated = true - // CHECK-DAG: mappings = [#vifrt.array_mapping_v1<0, 0, [#vifrt.mapping_v1<[0 : 1 : 1] to [0 : 1 : 1]>]>, #vifrt.array_mapping_v1<1, 0, [#vifrt.mapping_v1<[0 : 1 : 1] to [1 : 2 : 1]>]>] - // CHECK-DAG: operandSegmentSizes = array - // CHECK-SAME: }> - // CHECK-SAME: (!vifrt.array_v1, #vifrt.sharding_param_v2<1x1 to [0] on 1>, [0], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.array_v1, #vifrt.sharding_param_v2<1x1 to [0] on 1>, [1], memory_kind = "vifrt.default", layout = "vifrt.default">) -> (!vifrt.array_v1, #vifrt.sharding_param_v2<1x2 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.control_v1) - %0, %ctrl = ifrt.RemapArrays(%arg0, %arg1) - mappings=[#ifrt.array_mapping<0, 0, [#ifrt.mapping<[0:1:1] to [0:1:1]>]>, - #ifrt.array_mapping<1, 0, [#ifrt.mapping<[0:1:1] to [1:2:1]>]>] - {donated=true} - : (!array_rattr_in0, !array_rattr_in1) -> (!array_rattr_out) - return -} - -// CHECK-LABEL: "ifrt_function_attribute" -// CHECK-NOT: {ifrt.function} -// CHECK: {vifrt.function} -func.func @ifrt_function_attribute() attributes {ifrt.function} { - return -} - -// ============ Ops ============ - -!array_cp0 = !ifrt.array, - #ifrt.sharding_param<1x1 to [0] on 2>, [0,1]> -!array_cp1 = !ifrt.array, - #ifrt.sharding_param<1x1 to [0] on 2>, [2,3]> -// CHECK-LABEL: "op_copy_arrays" -// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}): -func.func @op_copy_arrays(%arg0: !array_cp0) -> !array_cp1 - attributes {ifrt.function} { - // CHECK: "vifrt.CopyArraysV1"(%[[ARG0]]) - // CHECK-SAME: <{ - // CHECK-DAG: donated = false - // CHECK-DAG: operandSegmentSizes = array - // CHECK-SAME: }> - // CHECK-SAME: (!vifrt.array_v1, #vifrt.sharding_param_v2<1x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">) -> (!vifrt.array_v1, #vifrt.sharding_param_v2<1x1 to [0] on 2>, [2, 3], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.control_v1) - %0, %ctrl = ifrt.CopyArrays(%arg0) : (!array_cp0) -> !array_cp1 - return %0: !array_cp1 -} - -!array_ad0 = !ifrt.array, - #ifrt.sharding_param<1x1 to [0] on 1>, [0]> -!array_ad1 = !ifrt.array, - #ifrt.sharding_param<1x1 to [0] on 1>, [1]> -!array_ad2 = !ifrt.array, - #ifrt.sharding_param<1x2 to [0] on 2>, [0,1]> -// CHECK-LABEL: "op_assemble" -// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}): -func.func @op_assemble(%arg0: !array_ad0, %arg1: !array_ad1) - attributes {ifrt.function} { - // CHECK: "vifrt.AssembleV1"(%[[ARG0]], %[[ARG1]]) - // CHECK-SAME: <{ - // CHECK-DAG: operandSegmentSizes = array - // CHECK-SAME: }> - // CHECK-SAME: (!vifrt.array_v1, #vifrt.sharding_param_v2<1x1 to [0] on 1>, [0], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.array_v1, #vifrt.sharding_param_v2<1x1 to [0] on 1>, [1], memory_kind = "vifrt.default", layout = "vifrt.default">) -> (!vifrt.array_v1, #vifrt.sharding_param_v2<1x2 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.control_v1) - %0, %ctrl = ifrt.Assemble(%arg0, %arg1) : (!array_ad0, !array_ad1) -> !array_ad2 - return -} - -// CHECK-LABEL: "op_disassemble" -// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}): -func.func @op_disassemble(%arg0: !array_ad2) attributes {ifrt.function} { - // CHECK: "vifrt.DisassembleV1"(%[[ARG0]]) - // CHECK-SAME: (!vifrt.array_v1, #vifrt.sharding_param_v2<1x2 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">) -> (!vifrt.array_v1, #vifrt.sharding_param_v2<1x1 to [0] on 1>, [0], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.array_v1, #vifrt.sharding_param_v2<1x1 to [0] on 1>, [1], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.control_v1) - %0, %1, %ctrl = ifrt.Disassemble(%arg0) : (!array_ad2) -> (!array_ad0, !array_ad1) - return -} - -// CHECK-LABEL: "op_after" -// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}): -func.func @op_after(%arg0: !array_cp0, %arg1: !array_cp1) - -> (!array_cp0, !array_cp0) attributes {ifrt.function} { - // CHECK: %[[OUT:.+]]:2 = "vifrt.CopyArraysV1"(%[[ARG0]]) - // CHECK-SAME: <{ - // CHECK-DAG: donated = false - // CHECK-DAG: operandSegmentSizes = array - // CHECK-SAME: }> - // CHECK-SAME: (!vifrt.array_v1, #vifrt.sharding_param_v2<1x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">) -> (!vifrt.array_v1, #vifrt.sharding_param_v2<1x1 to [0] on 2>, [2, 3], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.control_v1) - %0, %ctrl_0 = ifrt.CopyArrays(%arg0) : (!array_cp0) -> !array_cp1 - // CHECK: "vifrt.CopyArraysV1"(%[[OUT]]#0, %[[ARG1]], %[[OUT]]#1) - // CHECK-SAME: <{ - // CHECK-DAG: donated = false - // CHECK-DAG: operandSegmentSizes = array - // CHECK-SAME: }> - // CHECK-SAME: (!vifrt.array_v1, #vifrt.sharding_param_v2<1x1 to [0] on 2>, [2, 3], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.array_v1, #vifrt.sharding_param_v2<1x1 to [0] on 2>, [2, 3], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.control_v1) -> (!vifrt.array_v1, #vifrt.sharding_param_v2<1x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.array_v1, #vifrt.sharding_param_v2<1x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.control_v1) - %1, %2, %ctrl_1 = ifrt.CopyArrays(%0, %arg1) after %ctrl_0 - : (!array_cp1, !array_cp1) -> (!array_cp0, !array_cp0) - return %1, %2: !array_cp0, !array_cp0 -} - -!array_le_in = !ifrt.array, - #ifrt.sharding_param<1x1 to [0] on 2>, [0,1]> -!array_le_out = !ifrt.array, - #ifrt.sharding_param<1x2 to [0] on 2>, [0,1]> -// CHECK-LABEL: "op_call_loaded_executable" -// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}): -func.func @op_call_loaded_executable( - %arg0: !array_le_in {ifrt.donated}, %arg1: !array_le_in {ifrt.donated}) - attributes {ifrt.function} { - // CHECK: %[[OUT:.+]]:2 = "vifrt.CallLoadedExecutableV1"(%[[ARG0]]) - // CHECK-SAME: <{ - // CHECK-DAG: callee = @test_loaded_executable1 - // CHECK-DAG: donated_input_indices = array - // CHECK-DAG: io_aliases = [], - // CHECK-DAG: operandSegmentSizes = array - // CHECK-SAME: }> - // CHECK-SAME: (!vifrt.array_v1, #vifrt.sharding_param_v2<1x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">) -> (!vifrt.array_v1, #vifrt.sharding_param_v2<1x2 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.control_v1) - %0, %ctrl_0 = ifrt.CallLoadedExecutable @test_loaded_executable1(%arg0) - : (!array_le_in) -> !array_le_out - // CHECK: "vifrt.CallLoadedExecutableV1"(%[[ARG0]]) - // CHECK-SAME: <{ - // CHECK-DAG: callee = @test_loaded_executable1 - // CHECK-DAG: donated_input_indices = array - // CHECK-DAG: io_aliases = [] - // CHECK-DAG: operandSegmentSizes = array - // CHECK-SAME: }> - // CHECK-SAME: (!vifrt.array_v1, #vifrt.sharding_param_v2<1x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">) -> (!vifrt.array_v1, #vifrt.sharding_param_v2<1x2 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.control_v1) - %1, %ctrl_1 = ifrt.CallLoadedExecutable @test_loaded_executable1(%arg0) - {donated_input_indices=array} : (!array_le_in) -> !array_le_out - // CHECK: "vifrt.CallLoadedExecutableV1"(%[[ARG1]]) - // CHECK-SAME: <{ - // CHECK-DAG: callee = @test_loaded_executable2 - // CHECK-DAG: donated_input_indices = array - // CHECK-DAG: io_aliases = [array] - // CHECK-DAG: operandSegmentSizes = array - // CHECK-SAME: }> - // CHECK-SAME: (!vifrt.array_v1, #vifrt.sharding_param_v2<1x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">) -> (!vifrt.array_v1, #vifrt.sharding_param_v2<1x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.control_v1) - %2, %ctrl_2 = ifrt.CallLoadedExecutable @test_loaded_executable2(%arg1) - {io_aliases=[array]} : (!array_le_in) -> !array_le_in - return -} - -// CHECK: "vifrt.LoadedExecutableV1"() -// CHECK-SAME: <{ -// CHECK-DAG: devices = #vifrt -// CHECK-DAG: function_type = #vifrt.type_v1, #vifrt.sharding_param_v2<1x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">) -> !vifrt.array_v1, #vifrt.sharding_param_v2<1x2 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">>> -// CHECK-DAG: sym_name = "test_loaded_executable1" -// CHECK-SAME: }> -ifrt.LoadedExecutable @test_loaded_executable1 on devices [0,1] - : (!array_le_in) -> !array_le_out - -// CHECK: "vifrt.LoadedExecutableV1"() -// CHECK-SAME: <{ -// CHECK-DAG: devices = #vifrt -// CHECK-DAG: function_type = #vifrt.type_v1, #vifrt.sharding_param_v2<1x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">) -> !vifrt.array_v1, #vifrt.sharding_param_v2<1x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">>> -// CHECK-DAG: sym_name = "test_loaded_executable2" -// CHECK-SAME: }> -ifrt.LoadedExecutable @test_loaded_executable2 on devices [0,1] - : (!array_le_in) -> !array_le_in - -!array_ra_in0 = !ifrt.array, - #ifrt.sharding_param<1x1 to [0] on 1>, [0]> -!array_ra_in1 = !ifrt.array, - #ifrt.sharding_param<1x1 to [0] on 1>, [1]> -!array_ra_out = !ifrt.array, - #ifrt.sharding_param<1x2 to [0] on 2>, [0,1]> -// CHECK-LABEL: "op_remap_arrays" -// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}): -func.func @op_remap_arrays( - %arg0: !array_ra_in0 {ifrt.donated}, %arg1: !array_ra_in1 {ifrt.donated}) - attributes {ifrt.function} { - // CHECK: "vifrt.RemapArraysV1"(%[[ARG0]], %[[ARG1]]) - // CHECK-SAME: <{ - // CHECK-DAG: donated = true - // CHECK-DAG: mappings = [#vifrt.array_mapping_v1<0, 0, [#vifrt.mapping_v1<[0 : 1 : 1] to [0 : 1 : 1]>]>, #vifrt.array_mapping_v1<1, 0, [#vifrt.mapping_v1<[0 : 1 : 1] to [1 : 2 : 1]>]>] - // CHECK-DAG: operandSegmentSizes = array - // CHECK-SAME: }> - // CHECK-SAME: (!vifrt.array_v1, #vifrt.sharding_param_v2<1x1 to [0] on 1>, [0], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.array_v1, #vifrt.sharding_param_v2<1x1 to [0] on 1>, [1], memory_kind = "vifrt.default", layout = "vifrt.default">) -> (!vifrt.array_v1, #vifrt.sharding_param_v2<1x2 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.control_v1) - %0, %ctrl = ifrt.RemapArrays(%arg0, %arg1) - mappings=[#ifrt.array_mapping<0, 0, [#ifrt.mapping<[0:1:1] to [0:1:1]>]>, - #ifrt.array_mapping<1, 0, [#ifrt.mapping<[0:1:1] to [1:2:1]>]>] - {donated=true} - : (!array_ra_in0, !array_ra_in1) -> (!array_ra_out) - return -} - -!array_bc0 = !ifrt.array, - #ifrt.sharding_param<1x1 to [0] on 2>, [0,1]> -!array_bc1 = !ifrt.array, - #ifrt.sharding_param<1x1x1 to [0] on 2>, [0,1]> -// CHECK-LABEL: "op_bitcast_arrays" -// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}): -func.func @op_bitcast_arrays(%arg0: !array_bc0 {ifrt.donated}) -> !array_bc1 - attributes {ifrt.function} { - // CHECK: "vifrt.BitcastArraysV1"(%[[ARG0]]) - // CHECK-SAME: <{ - // CHECK-DAG: donated = true - // CHECK-DAG: operandSegmentSizes = array - // CHECK-SAME: }> - // CHECK-SAME: (!vifrt.array_v1, #vifrt.sharding_param_v2<1x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">) -> (!vifrt.array_v1, #vifrt.sharding_param_v2<1x1x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.control_v1) - %0, %ctrl = ifrt.BitcastArrays(%arg0) {donated=true} : (!array_bc0) -> !array_bc1 - return %0: !array_bc1 -} - -!array_r0 = !ifrt.array, - #ifrt.sharding_param<2 to [0] on 2>, [0,1]> -!array_r1 = !ifrt.array, - #ifrt.sharding_param<1 to [0] on 1>, [2]> -!array_r2 = !ifrt.array, - #ifrt.sharding_param<1 to [0] on 1>, [3]> -// CHECK-LABEL: "op_reshard" -// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}): -func.func @op_reshard(%arg0: !array_r0, %arg1: !array_r0) - -> (!array_r1, !array_r2) attributes {ifrt.function} { - // CHECK: "vifrt.ReshardV1"(%[[ARG0]], %[[ARG1]]) - // CHECK-SAME: <{ - // CHECK-DAG: donated = false - // CHECK-DAG: operandSegmentSizes = array - // CHECK-SAME: }> - // CHECK-SAME: (!vifrt.array_v1, #vifrt.sharding_param_v2<2 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.array_v1, #vifrt.sharding_param_v2<2 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">) -> (!vifrt.array_v1, #vifrt.sharding_param_v2<1 to [0] on 1>, [2], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.array_v1, #vifrt.sharding_param_v2<1 to [0] on 1>, [3], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.control_v1) - %0, %1, %ctrl_1 = ifrt.Reshard(%arg0, %arg1) - : (!array_r0, !array_r0) -> (!array_r1, !array_r2) - return %0, %1 : !array_r1, !array_r2 -} - -// Verifies that the FuncOp, ReturnOp and donated arguments are converted to -// VIFRT. - -// CHECK: "vifrt.FuncV1"() -// CHECK-SAME: <{ -// CHECK-DAG: arg_attrs = [{vifrt.donated}, {vifrt.donated}] -// CHECK-DAG: function_type = #vifrt.type_v1, #vifrt.sharding_param_v2<2 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.array_v1, #vifrt.sharding_param_v2<2 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">) -> !vifrt.array_v1, #vifrt.sharding_param_v2<1 to [0] on 1>, [2], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.array_v1, #vifrt.sharding_param_v2<1 to [0] on 1>, [3], memory_kind = "vifrt.default", layout = "vifrt.default">>> -// CHECK-DAG: sym_name = "donated_arguments" -// CHECK-DAG: res_attrs = [] -// CHECK-SAME: }> -// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}): -func.func @donated_arguments( - %arg0: !array_r0 {ifrt.donated}, %arg1: !array_r0 {ifrt.donated}) - -> (!array_r1, !array_r2) attributes {ifrt.function} { - // CHECK: %[[OUT:.+]]:3 = "vifrt.ReshardV1"(%[[ARG0]], %[[ARG1]]) - // CHECK-SAME: <{ - // CHECK-DAG: donated = true - // CHECK-DAG: operandSegmentSizes = array - // CHECK-SAME: }> - // CHECK-SAME: (!vifrt.array_v1, #vifrt.sharding_param_v2<2 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.array_v1, #vifrt.sharding_param_v2<2 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">) -> (!vifrt.array_v1, #vifrt.sharding_param_v2<1 to [0] on 1>, [2], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.array_v1, #vifrt.sharding_param_v2<1 to [0] on 1>, [3], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.control_v1) - %0, %1, %ctrl_1 = ifrt.Reshard(%arg0, %arg1) {donated=true} - : (!array_r0, !array_r0) -> (!array_r1, !array_r2) - // CHECK: "vifrt.ReturnV1"(%[[OUT]]#0, %[[OUT]]#1) : (!vifrt.array_v1, #vifrt.sharding_param_v2<1 to [0] on 1>, [2], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.array_v1, #vifrt.sharding_param_v2<1 to [0] on 1>, [3], memory_kind = "vifrt.default", layout = "vifrt.default">) - return %0, %1 : !array_r1, !array_r2 -} - -// CHECK-LABEL: "op_func_call" -// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}): -func.func @op_func_call(%arg0: !array_cp0) -> !array_cp1 - attributes {ifrt.function} { - // CHECK: %[[OUT0:.+]]:2 = "vifrt.CopyArraysV1"(%[[ARG0]]) - // CHECK-SAME: <{ - // CHECK-DAG: donated = false - // CHECK-DAG: operandSegmentSizes = array - // CHECK-SAME: }> - // CHECK-SAME: (!vifrt.array_v1, #vifrt.sharding_param_v2<1x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">) -> (!vifrt.array_v1, #vifrt.sharding_param_v2<1x1 to [0] on 2>, [2, 3], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.control_v1) - %0, %ctrl = ifrt.CopyArrays(%arg0) : (!array_cp0) -> !array_cp1 - // CHECK: %[[OUT1:.+]] = "vifrt.CallFuncV1"(%[[OUT0]]#0) - // CHECK-SAME: <{callee = @copy_back}> - // CHECK-SAME: (!vifrt.array_v1, #vifrt.sharding_param_v2<1x1 to [0] on 2>, [2, 3], memory_kind = "vifrt.default", layout = "vifrt.default">) -> !vifrt.array_v1, #vifrt.sharding_param_v2<1x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default"> - %1 = func.call @copy_back(%0) : (!array_cp1) -> !array_cp0 - return %0: !array_cp1 -} - -// CHECK: "vifrt.FuncV1"() -// CHECK-SAME: <{ -// CHECK-DAG: arg_attrs = [] -// CHECK-DAG: function_type = #vifrt.type_v1, #vifrt.sharding_param_v2<1x1 to [0] on 2>, [2, 3], memory_kind = "vifrt.default", layout = "vifrt.default">) -> !vifrt.array_v1, #vifrt.sharding_param_v2<1x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">>> -// CHECK-DAG: res_attrs = [] -// CHECK-DAG: sym_name = "copy_back" -// CHECK-DAG: sym_visibility = "vifrt.default" -// CHECK-SAME: }> -// CHECK-NEXT: (%[[ARG1:.*]]: {{.*}}): -func.func @copy_back(%arg1: !array_cp1) -> !array_cp0 - attributes {ifrt.function} { - // CHECK: "vifrt.CopyArraysV1"(%[[ARG1]]) - // CHECK-SAME: <{ - // CHECK-DAG: donated = false - // CHECK-DAG: operandSegmentSizes = array - // CHECK-SAME: }> - // CHECK-SAME: (!vifrt.array_v1, #vifrt.sharding_param_v2<1x1 to [0] on 2>, [2, 3], memory_kind = "vifrt.default", layout = "vifrt.default">) -> (!vifrt.array_v1, #vifrt.sharding_param_v2<1x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.control_v1) - %0, %ctrl = ifrt.CopyArrays(%arg1) : (!array_cp1) -> !array_cp0 - return %0: !array_cp0 -} - -// Important: The test verifying CallOps must be last. This is necessary because -// in order to test serialization roundtrip the tests in this file are not split -// into per file tests. However, during deserialization we do not know where to -// re-introduce the atom program modules within the module, and thus we append -// them at the end. -!array_op_call = !ifrt.array, - #ifrt.sharding_param<2x1 to [0] on 2>, [0,1]> -// CHECK-LABEL: "op_call" -// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}): -func.func @op_call( - %arg0: !array_op_call {ifrt.donated}, %arg1: !array_op_call {ifrt.donated}) - -> !array_op_call attributes {ifrt.function} { - // CHECK: %[[OUT0:.+]]:2 = "vifrt.CallV1"(%[[ARG0]]) - // CHECK-SAME: <{ - // CHECK-DAG: callee = "@add_one::@main" - // CHECK-DAG: devices = #vifrt - // CHECK-DAG: donated_input_indices = array - // CHECK-DAG: io_aliases = [] - // CHECK-DAG: operandSegmentSizes = array - // CHECK-SAME: }> - // CHECK-SAME: (!vifrt.array_v1, #vifrt.sharding_param_v2<2x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">) -> (!vifrt.array_v1, #vifrt.sharding_param_v2<2x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.control_v1) - %0, %ctrl_0 = ifrt.Call @add_one::@main(%arg0) on devices [0,1] - : (!array_op_call) -> !array_op_call - - // Verifies that the control value is passed to the next call. - - // CHECK: %[[OUT1:.+]]:2 = "vifrt.CallV1"(%[[OUT0]]#0, %[[OUT0]]#1) - // CHECK-SAME: <{ - // CHECK-DAG: callee = "@add_one::@main" - // CHECK-DAG: devices = #vifrt - // CHECK-DAG: donated_input_indices = array - // CHECK-DAG: io_aliases = [] - // CHECK-DAG: operandSegmentSizes = array - // CHECK-SAME: }> - // CHECK-SAME: (!vifrt.array_v1, #vifrt.sharding_param_v2<2x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.control_v1) -> (!vifrt.array_v1, #vifrt.sharding_param_v2<2x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.control_v1) - %1, %ctrl_1 = ifrt.Call @add_one::@main(%0) after %ctrl_0 on devices [0,1] - : (!array_op_call) -> !array_op_call - - // Verifies that escaped symbol attr is correctly handled. - // CHECK: %[[OUT2:.+]]:2 = "vifrt.CallV1"(%[[ARG0]]) - // CHECK-SAME: <{ - // CHECK-DAG: callee = "@escaped-module::@main" - // CHECK-DAG: devices = #vifrt - // CHECK-DAG: donated_input_indices = array - // CHECK-DAG: io_aliases = [] - // CHECK-DAG: operandSegmentSizes = array - // CHECK-SAME: }> - // CHECK-SAME: (!vifrt.array_v1, #vifrt.sharding_param_v2<2x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">) -> (!vifrt.array_v1, #vifrt.sharding_param_v2<2x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.control_v1) - %2, %ctrl_2 = ifrt.Call @"escaped-module"::@main(%arg0) on devices [0,1] - : (!array_op_call) -> !array_op_call - - // Verifies that the donated input indices attribute is converted. - - // CHECK: "vifrt.CallV1"(%[[ARG0]]) - // CHECK-SAME: <{ - // CHECK-DAG: callee = "@add_one::@main" - // CHECK-DAG: devices = #vifrt - // CHECK-DAG: donated_input_indices = array - // CHECK-DAG: io_aliases = [] - // CHECK-DAG: operandSegmentSizes = array - // CHECK-SAME: }> - // CHECK-SAME: (!vifrt.array_v1, #vifrt.sharding_param_v2<2x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">) -> (!vifrt.array_v1, #vifrt.sharding_param_v2<2x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.control_v1) - %3, %ctrl_3 = ifrt.Call @add_one::@main(%arg0) on devices [0,1] - {donated_input_indices=array} : (!array_op_call) -> !array_op_call - - // Verifies that the io_aliases attribute is converted. - - // CHECK: "vifrt.CallV1"(%[[ARG1]]) - // CHECK-SAME: <{ - // CHECK-DAG: callee = "@add_two::@main" - // CHECK-DAG: devices = #vifrt - // CHECK-DAG: donated_input_indices = array, - // CHECK-DAG: io_aliases = [array] - // CHECK-DAG: operandSegmentSizes = array - // CHECK-SAME: }> - // CHECK-SAME: (!vifrt.array_v1, #vifrt.sharding_param_v2<2x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">) -> (!vifrt.array_v1, #vifrt.sharding_param_v2<2x1 to [0] on 2>, [0, 1], memory_kind = "vifrt.default", layout = "vifrt.default">, !vifrt.control_v1) - %4, %ctrl_4 = ifrt.Call @add_two::@main(%arg1) on devices [0,1] - {io_aliases=[array]} : (!array_op_call) -> !array_op_call - - return %1 : !array_op_call -} - -// CHECK-NOT @add_one -module @add_one attributes {sym_visibility = "private"} { - func.func @main(%arg0: tensor<2x2xi32>) -> tensor<2x2xi32> { - %0 = stablehlo.constant dense<1> : tensor<2x2xi32> - %1 = stablehlo.add %arg0, %0 : tensor<2x2xi32> - return %1 : tensor<2x2xi32> - } -} - -// CHECK-NOT @"escaped-module" -module @"escaped-module" attributes {sym_visibility = "private"} { - func.func @main(%arg0: tensor<2x2xi32>) -> tensor<2x2xi32> { - %0 = stablehlo.constant dense<2> : tensor<2x2xi32> - %1 = stablehlo.add %arg0, %0 : tensor<2x2xi32> - return %1 : tensor<2x2xi32> - } -} - -// CHECK-NOT @add_two -module @add_two attributes {sym_visibility = "private"} { - func.func @main(%arg0: tensor<2x2xi32>) -> tensor<2x2xi32> { - %0 = stablehlo.constant dense<2> : tensor<2x2xi32> - %1 = stablehlo.add %arg0, %0 : tensor<2x2xi32> - return %1 : tensor<2x2xi32> - } -} diff --git a/third_party/xla/xla/python/ifrt/ir/tests/vifrt/vifrt_to_version_downgrade_invalid.0_3_0.mlir b/third_party/xla/xla/python/ifrt/ir/tests/vifrt/vifrt_to_version_downgrade_invalid.0_3_0.mlir deleted file mode 100644 index bfe76c40d1c44d..00000000000000 --- a/third_party/xla/xla/python/ifrt/ir/tests/vifrt/vifrt_to_version_downgrade_invalid.0_3_0.mlir +++ /dev/null @@ -1,11 +0,0 @@ -// RUN: ifrt-opt %s --ifrt-legalize-to-vifrt --vifrt-to-version='target_version=0.3.0' --symbol-dce --verify-diagnostics --split-input-file - -!array = !ifrt.array, - #ifrt.sharding_param<1x1 to [0] on 2>, [0,1]> - -// expected-error@-6 {{failed to convert to VIFRT version 0.3.0}} -// expected-error@+2 {{failed to legalize operation 'vifrt.CopyArraysV2' that was explicitly marked illegal}} -func.func @copy_array_with_reuse(%arg0: !array) attributes {ifrt.function} { - %0, %ctrl = ifrt.CopyArrays(%arg0) {reuse = true} : (!array) -> !array - return -} diff --git a/third_party/xla/xla/python/ifrt/ir/transforms/ifrt_legalize_to_vifrt_pass.cc b/third_party/xla/xla/python/ifrt/ir/transforms/ifrt_legalize_to_vifrt_pass.cc index 49e9df1c74a5ce..9e2c31af14b47e 100644 --- a/third_party/xla/xla/python/ifrt/ir/transforms/ifrt_legalize_to_vifrt_pass.cc +++ b/third_party/xla/xla/python/ifrt/ir/transforms/ifrt_legalize_to_vifrt_pass.cc @@ -291,16 +291,10 @@ mlir::LogicalResult addDefaultAttrs( convertGeneric(ifrt_attr, pattern.getTypeConverter())); }; - if constexpr (std::is_same::value) { - if (!ifrt_op.getDonatedAttr()) { - add_default_attr("donated", builder.getBoolAttr(false)); - } - if (!ifrt_op.getReuseAttr()) { - add_default_attr("reuse", builder.getBoolAttr(false)); - } - } else if constexpr (std::is_same::value || - std::is_same::value || - std::is_same::value) { + if constexpr (std::is_same::value || + std::is_same::value || + std::is_same::value || + std::is_same::value) { if (!ifrt_op.getDonatedAttr()) { add_default_attr("donated", builder.getBoolAttr(false)); } diff --git a/third_party/xla/xla/python/ifrt/ir/transforms/ifrt_merge_copies_and_reshards_pass.cc b/third_party/xla/xla/python/ifrt/ir/transforms/ifrt_merge_copies_and_reshards_pass.cc index ec5b20b881db0c..6fce976381a46b 100644 --- a/third_party/xla/xla/python/ifrt/ir/transforms/ifrt_merge_copies_and_reshards_pass.cc +++ b/third_party/xla/xla/python/ifrt/ir/transforms/ifrt_merge_copies_and_reshards_pass.cc @@ -79,8 +79,6 @@ llvm::hash_code GetMergeKey(CopyArraysOp op) { // explicitly convert to BoolAttr. hash = llvm::hash_combine( hash, mlir::BoolAttr::get(op.getContext(), op.getDonated())); - hash = llvm::hash_combine( - hash, mlir::BoolAttr::get(op.getContext(), op.getReuse())); hash = llvm::hash_combine( hash, StringAttr::get(op.getContext(), absl::StrCat(input_type.MemoryKind()))); @@ -127,21 +125,14 @@ void RewriteCopyArraysGroup(mlir::IRRewriter& rewriter, locs.reserve(to_merge.size()); bool donated = false; - bool reuse = false; - for (auto [index, group_op] : llvm::enumerate(to_merge)) { + for (mlir::Operation* group_op : to_merge) { auto copy_arrays_op = mlir::cast(group_op); inputs.append(copy_arrays_op.getInputs().begin(), copy_arrays_op.getInputs().end()); for (mlir::Value output : copy_arrays_op.getOutputs()) { output_types.push_back(output.getType()); } - if (index == 0) { - donated = copy_arrays_op.getDonated(); - reuse = copy_arrays_op.getReuse(); - } else { - CHECK_EQ(donated, copy_arrays_op.getDonated()); - CHECK_EQ(reuse, copy_arrays_op.getReuse()); - } + donated = copy_arrays_op.getDonated(); locs.push_back(group_op->getLoc()); } @@ -156,7 +147,6 @@ void RewriteCopyArraysGroup(mlir::IRRewriter& rewriter, IfrtControlType::get(rewriter.getContext()), /*inputs=*/inputs, /*donated=*/donated, - /*reuse=*/reuse, /*control_inputs=*/mlir::ValueRange()); // Replace the original group with the new merged CopyArrays. diff --git a/third_party/xla/xla/python/ifrt/ir/transforms/ifrt_reshard_to_copy_arrays_pass.cc b/third_party/xla/xla/python/ifrt/ir/transforms/ifrt_reshard_to_copy_arrays_pass.cc index 6c4337d0f74f48..cb28478947db74 100644 --- a/third_party/xla/xla/python/ifrt/ir/transforms/ifrt_reshard_to_copy_arrays_pass.cc +++ b/third_party/xla/xla/python/ifrt/ir/transforms/ifrt_reshard_to_copy_arrays_pass.cc @@ -118,7 +118,6 @@ class ReshardToCopyArraysOpPattern : public mlir::OpRewritePattern { /*control_output=*/op.getControlOutput().getType(), /*inputs=*/{op.getInputs()[array_idx]}, /*donated=*/op.getDonated(), - /*reuse=*/false, /*control_inputs=*/op.getControlInputs()); outputs[array_idx] = copy_arrays_op.getOutputs().front(); if (reshard_inputs_left.empty()) { diff --git a/third_party/xla/xla/python/ifrt/ir/transforms/map_ifrt_to_vifrt.h b/third_party/xla/xla/python/ifrt/ir/transforms/map_ifrt_to_vifrt.h index 8fdb88d52e2834..015263b8e1e201 100644 --- a/third_party/xla/xla/python/ifrt/ir/transforms/map_ifrt_to_vifrt.h +++ b/third_party/xla/xla/python/ifrt/ir/transforms/map_ifrt_to_vifrt.h @@ -54,7 +54,7 @@ using IfrtToVifrtOp = typename IfrtToVifrtOpImpl::Type; // Mappings between IFRT and current VIFRT ops. MAP_IFRT_TO_VIFRT(CallOp, V1) MAP_IFRT_TO_VIFRT(ReshardOp, V1) -MAP_IFRT_TO_VIFRT(CopyArraysOp, V2) +MAP_IFRT_TO_VIFRT(CopyArraysOp, V1) MAP_IFRT_TO_VIFRT(AssembleOp, V1) MAP_IFRT_TO_VIFRT(DisassembleOp, V1) MAP_IFRT_TO_VIFRT(RemapArraysOp, V1) diff --git a/third_party/xla/xla/python/ifrt/ir/transforms/vifrt_legalize_to_ifrt_pass.cc b/third_party/xla/xla/python/ifrt/ir/transforms/vifrt_legalize_to_ifrt_pass.cc index 428b5b64e61f27..0d45f5d51038cb 100644 --- a/third_party/xla/xla/python/ifrt/ir/transforms/vifrt_legalize_to_ifrt_pass.cc +++ b/third_party/xla/xla/python/ifrt/ir/transforms/vifrt_legalize_to_ifrt_pass.cc @@ -365,7 +365,7 @@ class VifrtToIfrtOpConverter : public mlir::OpConversionPattern { return mlir::failure(); } - // Convert the VIFRT attributes to IFRT attributes. + // Convert the IFRT attributes to VIFRT attributes. llvm::SmallVector ifrt_attrs; llvm::DenseSet already_converted_attrs; // Special case operations. diff --git a/third_party/xla/xla/python/ifrt/ir/transforms/vifrt_to_version_pass.cc b/third_party/xla/xla/python/ifrt/ir/transforms/vifrt_to_version_pass.cc index 233efdd737e9f4..13531151ade31b 100644 --- a/third_party/xla/xla/python/ifrt/ir/transforms/vifrt_to_version_pass.cc +++ b/third_party/xla/xla/python/ifrt/ir/transforms/vifrt_to_version_pass.cc @@ -352,44 +352,6 @@ struct VifrtTypeConversionPattern : public mlir::ConversionPattern { std::optional version; }; -void copyDiscardableAttrs(mlir::Operation* src, mlir::Operation* dst) { - dst->setDiscardableAttrs(src->getDiscardableAttrDictionary()); -} - -struct CopyArraysOpV1ToV2 : public mlir::OpRewritePattern { - using mlir::OpRewritePattern::OpRewritePattern; - - mlir::LogicalResult matchAndRewrite( - CopyArraysOpV1 op, mlir::PatternRewriter& rewriter) const override { - CopyArraysOpV2 new_op = rewriter.replaceOpWithNewOp( - op, op->getResultTypes(), op.getInputs(), op.getDonated(), - /*reuse=*/mlir::BoolAttr::get(op.getContext(), false), - op.getControlInputs()); - copyDiscardableAttrs(op, new_op); - return mlir::success(); - } -}; - -struct CopyArraysOpV2ToV1 : public mlir::OpRewritePattern { - using mlir::OpRewritePattern::OpRewritePattern; - - mlir::LogicalResult matchAndRewrite( - CopyArraysOpV2 op, mlir::PatternRewriter& rewriter) const override { - if (auto reuse_attr = - llvm::dyn_cast_or_null(op.getReuse())) { - if (reuse_attr.getValue()) { - return rewriter.notifyMatchFailure( - op, "reuse is not supported in CopyArraysOpV1"); - } - } - CopyArraysOpV1 new_op = rewriter.replaceOpWithNewOp( - op, op->getResultTypes(), op.getInputs(), op.getDonated(), - op.getControlInputs()); - copyDiscardableAttrs(op, new_op); - return mlir::success(); - } -}; - } // namespace void populateVifrtToVersionPatterns(mlir::RewritePatternSet* patterns, @@ -448,9 +410,6 @@ void populateVifrtToVersionPatterns(mlir::RewritePatternSet* patterns, // 3) and 4) Convert the typed attributes and attributes of operations. patterns->add(*converter, context, version); - - // Upgrade/Downgrade conversion patterns between ops. - patterns->add(context); } } // namespace ifrt diff --git a/third_party/xla/xla/python/ifrt/ir/version.h b/third_party/xla/xla/python/ifrt/ir/version.h index 7196d3666a4ee3..f5a1b76a3f3ae1 100644 --- a/third_party/xla/xla/python/ifrt/ir/version.h +++ b/third_party/xla/xla/python/ifrt/ir/version.h @@ -36,7 +36,7 @@ class Version { static mlir::FailureOr fromString(llvm::StringRef version_ref); // Returns a Version representing the current IFRT IR version. - static Version getCurrentVersion() { return Version(0, 4, 0); } + static Version getCurrentVersion() { return Version(0, 3, 0); } /// Returns a Version representing the minimum supported IFRT IR version. static Version getMinimumVersion() { return Version(0, 1, 0); } diff --git a/third_party/xla/xla/python/ifrt/ir/vifrt_dialect.td b/third_party/xla/xla/python/ifrt/ir/vifrt_dialect.td index 23d61ad0857ad4..d2f0a690239400 100644 --- a/third_party/xla/xla/python/ifrt/ir/vifrt_dialect.td +++ b/third_party/xla/xla/python/ifrt/ir/vifrt_dialect.td @@ -37,7 +37,6 @@ def Vifrt_Dialect : Dialect { 0.1.0: Initial IFRT IR stability guarantees. 0.2.0: Added unreduced to sharding parameter attribute. 0.3.0: Added token type. - 0.4.0: Added CopyArrays with reuse semantics. }]; let useDefaultAttributePrinterParser = 0; diff --git a/third_party/xla/xla/python/ifrt/ir/vifrt_ops.td b/third_party/xla/xla/python/ifrt/ir/vifrt_ops.td index 3e413182a6343a..fe7543fc8d3fb8 100644 --- a/third_party/xla/xla/python/ifrt/ir/vifrt_ops.td +++ b/third_party/xla/xla/python/ifrt/ir/vifrt_ops.td @@ -57,7 +57,7 @@ def Vifrt_ReshardOpV1 : Vifrt_Op<"ReshardV1", "0.1.0", "current", Vifrt_AnyType:$control_output); } -def Vifrt_CopyArraysOpV1 : Vifrt_Op<"CopyArraysV1", "0.1.0", "0.3.0", +def Vifrt_CopyArraysOpV1 : Vifrt_Op<"CopyArraysV1", "0.1.0", "current", [AttrSizedOperandSegments]> { let arguments = (ins Variadic:$inputs, @@ -68,18 +68,6 @@ def Vifrt_CopyArraysOpV1 : Vifrt_Op<"CopyArraysV1", "0.1.0", "0.3.0", Vifrt_AnyType:$control_output); } -def Vifrt_CopyArraysOpV2 : Vifrt_Op<"CopyArraysV2", "0.4.0", "current", - [AttrSizedOperandSegments]> { - let arguments = (ins - Variadic:$inputs, - Vifrt_AnyAttr:$donated, - Vifrt_AnyAttr:$reuse, - Variadic:$control_inputs); - let results = (outs - Variadic:$outputs, - Vifrt_AnyType:$control_output); -} - def Vifrt_AssembleOpV1 : Vifrt_Op<"AssembleV1", "0.1.0", "current", [AttrSizedOperandSegments]> { let arguments = (ins diff --git a/third_party/xla/xla/python/ifrt/sharding.cc b/third_party/xla/xla/python/ifrt/sharding.cc index fb2f3607946900..a79a3314fc0cbe 100644 --- a/third_party/xla/xla/python/ifrt/sharding.cc +++ b/third_party/xla/xla/python/ifrt/sharding.cc @@ -724,14 +724,14 @@ absl::StatusOr> ConcreteEvenSharding::IndexDomains( const Shape& shape, SingleDeviceShardSemantics single_device_shard_semantics) const { DCHECK(this); - if (devices_->devices().size() == 1 && is_fully_replicated_ && - shape_ == shard_shape_ && shape_ == shape) { + if (is_fully_replicated_ && shape_ == shard_shape_ && shape_ == shape) { std::vector result; if (single_device_shard_semantics == - SingleDeviceShardSemantics::kAllShards || - devices_->devices().front()->IsAddressable()) { - result.reserve(1); - result.push_back(IndexDomain(shape)); + SingleDeviceShardSemantics::kAllShards) { + result.resize(devices_->size(), IndexDomain(shape)); + } else { + result.resize(devices_->AddressableDeviceList()->size(), + IndexDomain(shape)); } return result; } diff --git a/third_party/xla/xla/python/ifrt/sharding_test.cc b/third_party/xla/xla/python/ifrt/sharding_test.cc index 67aa8f4ec87ec9..dc8b74f8a664f6 100644 --- a/third_party/xla/xla/python/ifrt/sharding_test.cc +++ b/third_party/xla/xla/python/ifrt/sharding_test.cc @@ -807,7 +807,27 @@ TEST_P(ConcreteEvenShardingTest, DisassembleFailsForUnexpectedShape) { HasSubstr("ConcreteEvenSharding can only disassemble"))); } -TEST_P(ConcreteEvenShardingTest, IndexDomainsFails) { +TEST_P(ConcreteEvenShardingTest, IndexDomainsForFullyReplicated) { + Shape shape({10, 20}); + + auto device_list = GetDevices({0, 4}); + ASSERT_TRUE(device_list->devices()[0]->IsAddressable()); + ASSERT_FALSE(device_list->devices()[1]->IsAddressable()); + + ShardingRef sharding = ConcreteEvenSharding::Create( + device_list, MemoryKind(), /*shape=*/shape, /*shard_shape=*/shape, + /*is_fully_replicated=*/true); + + EXPECT_THAT( + sharding->IndexDomains(shape, SingleDeviceShardSemantics::kAllShards), + IsOkAndHolds(ElementsAre(IndexDomain(shape), IndexDomain(shape)))); + + EXPECT_THAT(sharding->IndexDomains( + shape, SingleDeviceShardSemantics::kAddressableShards), + IsOkAndHolds(ElementsAre(IndexDomain(shape)))); +} + +TEST_P(ConcreteEvenShardingTest, IndexDomainsFailsForNonFullyReplicated) { auto device_list = GetDevices({0, 1}); std::vector shard_shapes; ShardingRef sharding = diff --git a/third_party/xla/xla/python/refine_polymorphic_shapes.cc b/third_party/xla/xla/python/refine_polymorphic_shapes.cc index 89a2933dd9be32..459e65b100eb2a 100644 --- a/third_party/xla/xla/python/refine_polymorphic_shapes.cc +++ b/third_party/xla/xla/python/refine_polymorphic_shapes.cc @@ -21,6 +21,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/strings/str_cat.h" +#include "xla/tsl/platform/status_macros.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/CommandLine.h" @@ -320,8 +321,8 @@ absl::Status RefinePolymorphicShapes(llvm::StringRef module_str, return absl::InvalidArgumentError("Cannot parse module."); } - TF_RETURN_IF_ERROR(RefinePolymorphicShapes(*module, enable_shape_assertions)); - if (validate_static_shapes) TF_RETURN_IF_ERROR(ValidateStaticShapes(*module)); + RETURN_IF_ERROR(RefinePolymorphicShapes(*module, enable_shape_assertions)); + if (validate_static_shapes) RETURN_IF_ERROR(ValidateStaticShapes(*module)); if (mlir::failed(mlir::writeBytecodeToFile(*module, os))) { return absl::InternalError("Cannot serialize module."); } diff --git a/third_party/xla/xla/python/transfer/BUILD b/third_party/xla/xla/python/transfer/BUILD index 600dce1397635e..907877a807b6fd 100644 --- a/third_party/xla/xla/python/transfer/BUILD +++ b/third_party/xla/xla/python/transfer/BUILD @@ -122,6 +122,7 @@ cc_library( "//xla/tsl/concurrency:future", "//xla/tsl/concurrency:ref_count", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/cleanup", @@ -169,6 +170,7 @@ xla_cc_test( "//xla/tsl/lib/core:status_test_util", "//xla/tsl/platform:env", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/base", "@com_google_absl//absl/functional:any_invocable", @@ -284,6 +286,7 @@ cc_library( "//xla/python/pjrt_ifrt:transfer_server_interface", "//xla/tsl/concurrency:ref_count", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/base", "@com_google_absl//absl/base:core_headers", diff --git a/third_party/xla/xla/python/transfer/pjrt_transfer_server.cc b/third_party/xla/xla/python/transfer/pjrt_transfer_server.cc index 4da96b8fa12946..5fc103335be895 100644 --- a/third_party/xla/xla/python/transfer/pjrt_transfer_server.cc +++ b/third_party/xla/xla/python/transfer/pjrt_transfer_server.cc @@ -34,6 +34,7 @@ limitations under the License. #include "absl/synchronization/mutex.h" #include "absl/time/time.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "llvm/Support/Casting.h" #include "xla/layout.h" #include "xla/pjrt/distributed/key_value_store_interface.h" @@ -126,12 +127,12 @@ PjRtTransferServer::MakePjRtTransferServerFactory( std::shared_ptr kv_store, const std::string& socket_address, const std::vector& transport_addresses) { - TF_ASSIGN_OR_RETURN(aux::SocketAddress address, - aux::SocketAddress::Parse(socket_address)); + ASSIGN_OR_RETURN(aux::SocketAddress address, + aux::SocketAddress::Parse(socket_address)); std::vector transport_socket_addresses; if (transport_addresses.empty()) { - TF_ASSIGN_OR_RETURN(aux::SocketAddress transport_address, - aux::SocketAddress::Parse("0.0.0.0:0")); + ASSIGN_OR_RETURN(aux::SocketAddress transport_address, + aux::SocketAddress::Parse("0.0.0.0:0")); // TODO(emilyaf, parkers): Remove this once defaults are set per device // platform. transport_socket_addresses.reserve(4); @@ -141,8 +142,8 @@ PjRtTransferServer::MakePjRtTransferServerFactory( } else { transport_socket_addresses.reserve(transport_addresses.size()); for (const std::string& transport_address : transport_addresses) { - TF_ASSIGN_OR_RETURN(aux::SocketAddress socket_address, - aux::SocketAddress::Parse(transport_address)); + ASSIGN_OR_RETURN(aux::SocketAddress socket_address, + aux::SocketAddress::Parse(transport_address)); transport_socket_addresses.push_back(socket_address); } } @@ -154,7 +155,7 @@ PjRtTransferServer::MakePjRtTransferServerFactory( client, client->addressable_device_count() * 2, transfer_size, cross_host_transfer_timeout, kv_store, address, transport_socket_addresses); - TF_RETURN_IF_ERROR(transfer_server->StartTransferServer()); + RETURN_IF_ERROR(transfer_server->StartTransferServer()); return transfer_server; }; return factory; @@ -162,22 +163,22 @@ PjRtTransferServer::MakePjRtTransferServerFactory( absl::Status PjRtTransferServer::StartTransferServer() { // Populate the KV store with this process's socket address. - TF_RETURN_IF_ERROR(kv_store_->Set( + RETURN_IF_ERROR(kv_store_->Set( absl::StrCat(kKeyPrefixSocketAddress, pjrt_client_->process_index()), socket_address_.ToString())); size_t total_size = transfer_size_ * max_num_parallel_copies_; - TF_ASSIGN_OR_RETURN(auto tmp, aux::AllocateAlignedMemory(total_size)); - TF_ASSIGN_OR_RETURN(auto map, aux::MapPjrtMemory(pjrt_client_, tmp->data(), - tmp->size(), tmp)); + ASSIGN_OR_RETURN(auto tmp, aux::AllocateAlignedMemory(total_size)); + ASSIGN_OR_RETURN(auto map, aux::MapPjrtMemory(pjrt_client_, tmp->data(), + tmp->size(), tmp)); aux::SlabAllocator uallocator(map, transfer_size_); - TF_ASSIGN_OR_RETURN(auto factory, aux::CreateSocketBulkTransportFactory( - transport_addresses_, std::nullopt, - std::move(uallocator))); + ASSIGN_OR_RETURN(auto factory, aux::CreateSocketBulkTransportFactory( + transport_addresses_, std::nullopt, + std::move(uallocator))); socket_server_ = std::make_shared(); - TF_ASSIGN_OR_RETURN( - auto mem, aux::AllocateAndMapPjrtMemory(pjrt_client_, total_size * 2)); + ASSIGN_OR_RETURN(auto mem, + aux::AllocateAndMapPjrtMemory(pjrt_client_, total_size * 2)); premapped_copier_ = std::make_shared( mem, max_num_parallel_copies_, transfer_size_); return (*socket_server_)->Start(socket_address_, factory); @@ -204,8 +205,8 @@ absl::Status PjRtTransferServer::CrossHostAwaitPull( if (pjrt_arr->pjrt_buffers().empty()) { return absl::InvalidArgumentError("PjRtArray has no buffers."); } - TF_ASSIGN_OR_RETURN(size_t buf_size, - pjrt_arr->pjrt_buffers()[0]->GetOnDeviceSizeInBytes()); + ASSIGN_OR_RETURN(size_t buf_size, + pjrt_arr->pjrt_buffers()[0]->GetOnDeviceSizeInBytes()); for (int j : buffer_idxs) { auto& pjrt_buf = pjrt_arr->pjrt_buffers()[j]; refs.push_back({pjrt_buf, buf_size}); @@ -220,11 +221,11 @@ absl::Status PjRtTransferServer::CrossHostAwaitPull( absl::StatusOr> PjRtTransferServer::GetConnection(int remote_pid) { if (!connections_.contains(remote_pid)) { - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( std::string address, kv_store_->Get(absl::StrCat(kKeyPrefixSocketAddress, remote_pid), cross_host_transfer_timeout_)); - TF_ASSIGN_OR_RETURN(auto addr, aux::SocketAddress::Parse(address)); + ASSIGN_OR_RETURN(auto addr, aux::SocketAddress::Parse(address)); connections_[remote_pid] = (*socket_server_)->Connect(addr); } return connections_[remote_pid]; @@ -241,7 +242,7 @@ absl::Status PjRtTransferServer::CrossHostPull( tsl::RCReference connection; { absl::MutexLock lock(connections_mu_); - TF_ASSIGN_OR_RETURN(connection, GetConnection(remote_pid)); + ASSIGN_OR_RETURN(connection, GetConnection(remote_pid)); } std::vector shape_specs; @@ -249,10 +250,10 @@ absl::Status PjRtTransferServer::CrossHostPull( shape_specs.reserve(arrays.size()); layouts.reserve(arrays.size()); for (int i = 0; i < arrays.size(); ++i) { - TF_ASSIGN_OR_RETURN(xla::PrimitiveType prim_type, - xla::ifrt::ToPrimitiveType(arrays[i]->dtype())); - TF_ASSIGN_OR_RETURN( - Shape shape, arrays[i]->sharding().GetShardShape(arrays[i]->shape())); + ASSIGN_OR_RETURN(xla::PrimitiveType prim_type, + xla::ifrt::ToPrimitiveType(arrays[i]->dtype())); + ASSIGN_OR_RETURN(Shape shape, + arrays[i]->sharding().GetShardShape(arrays[i]->shape())); xla::PjRtClient::ShapeSpec shape_spec = { prim_type, xla::DimensionVector(shape.dims().begin(), shape.dims().end())}; @@ -262,11 +263,10 @@ absl::Status PjRtTransferServer::CrossHostPull( arrays[i]->pjrt_layout(); std::optional layout; if (pjrt_layout.ok() && *pjrt_layout == nullptr) { - TF_ASSIGN_OR_RETURN( - xla::ifrt::Shape shard_shape, - arrays[i]->sharding().GetShardShape(arrays[i]->shape())); - TF_ASSIGN_OR_RETURN( - pjrt_layout, arrays[i]->client()->GetDefaultPjRtLayout( + ASSIGN_OR_RETURN(xla::ifrt::Shape shard_shape, + arrays[i]->sharding().GetShardShape(arrays[i]->shape())); + ASSIGN_OR_RETURN(pjrt_layout, + arrays[i]->client()->GetDefaultPjRtLayout( arrays[i]->dtype(), shard_shape.dims(), arrays[i]->sharding().devices()->devices().front(), arrays[i]->sharding().memory_kind())); @@ -279,12 +279,12 @@ absl::Status PjRtTransferServer::CrossHostPull( for (int j = 0; j < dst_device_idxs.size(); ++j) { int device_index = dst_device_idxs[j]; - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( xla::PjRtMemorySpace * mem_space, GetMemorySpace(memory_kind, dst_devices->devices()[device_index])); // TODO(emilyaf, parkers): Pass `layouts` instead of `std::nullopt` once // ASAN failure is debugged. - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( std::shared_ptr atm, pjrt_client_->CreateBuffersForAsyncHostToDevice( shape_specs, std::nullopt, mem_space)); @@ -349,10 +349,10 @@ PjRtTransferServer::CopyArraysForCrossHost( continue; } if (src_process_index == pjrt_client_->process_index()) { - TF_RETURN_IF_ERROR(CrossHostAwaitPull( + RETURN_IF_ERROR(CrossHostAwaitPull( uuid, arrays, await_pull_buffer_idxs[dst_process_index])); } else { - TF_RETURN_IF_ERROR(CrossHostPull( + RETURN_IF_ERROR(CrossHostPull( uuid, arrays, pull_to_device_idxs[src_process_index], dst_devices, memory_kind, src_process_index, buffers_by_device)); } @@ -362,26 +362,25 @@ PjRtTransferServer::CopyArraysForCrossHost( std::vector new_arrays; new_arrays.reserve(arrays.size()); for (size_t i = 0; i < arrays.size(); ++i) { - TF_ASSIGN_OR_RETURN(auto new_sharding, - arrays[i]->shared_ptr_sharding()->WithDeviceAssignment( - dst_devices, memory_kind)); - TF_ASSIGN_OR_RETURN(auto new_layout, arrays[i]->pjrt_layout()); + ASSIGN_OR_RETURN(auto new_sharding, + arrays[i]->shared_ptr_sharding()->WithDeviceAssignment( + dst_devices, memory_kind)); + ASSIGN_OR_RETURN(auto new_layout, arrays[i]->pjrt_layout()); if (new_layout == nullptr) { - TF_ASSIGN_OR_RETURN( - xla::ifrt::Shape shard_shape, - arrays[i]->sharding().GetShardShape(arrays[i]->shape())); - TF_ASSIGN_OR_RETURN( - new_layout, arrays[i]->client()->GetDefaultPjRtLayout( - arrays[i]->dtype(), shard_shape.dims(), - arrays[i]->sharding().devices()->devices().front(), - arrays[i]->sharding().memory_kind())); + ASSIGN_OR_RETURN(xla::ifrt::Shape shard_shape, + arrays[i]->sharding().GetShardShape(arrays[i]->shape())); + ASSIGN_OR_RETURN(new_layout, + arrays[i]->client()->GetDefaultPjRtLayout( + arrays[i]->dtype(), shard_shape.dims(), + arrays[i]->sharding().devices()->devices().front(), + arrays[i]->sharding().memory_kind())); } PjRtArray::PjRtBuffers array_buffers; array_buffers.reserve(buffers_by_device.size()); for (auto& [_, bufs] : buffers_by_device) { array_buffers.push_back(std::move(bufs[i])); } - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( auto arr, PjRtArray::Create(client, arrays[i]->dtype(), arrays[i]->shape(), std::move(new_sharding), std::move(array_buffers), diff --git a/third_party/xla/xla/python/transfer/streaming_ifrt.cc b/third_party/xla/xla/python/transfer/streaming_ifrt.cc index e00729c53fe3aa..c8d672c753f52b 100644 --- a/third_party/xla/xla/python/transfer/streaming_ifrt.cc +++ b/third_party/xla/xla/python/transfer/streaming_ifrt.cc @@ -30,6 +30,7 @@ limitations under the License. #include "absl/strings/str_format.h" #include "absl/synchronization/mutex.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/pjrt/pjrt_client.h" #include "xla/pjrt/pjrt_compiler.h" #include "xla/pjrt/raw_buffer.h" @@ -93,7 +94,7 @@ absl::StatusOr>> AllocateAndMapPjrtMemory( absl::StatusOr> DmaCopyChunk::DivideBufferCopiesEvenly(std::shared_ptr buffer, size_t xfer_size, size_t buffer_id) { - TF_ASSIGN_OR_RETURN(size_t copy_size, buffer->GetOnDeviceSizeInBytes()); + ASSIGN_OR_RETURN(size_t copy_size, buffer->GetOnDeviceSizeInBytes()); size_t total_num_copies = (copy_size + xfer_size - 1) / xfer_size; std::vector work_units; work_units.reserve(total_num_copies); @@ -288,7 +289,7 @@ class SlicedRawBufferChunkDestination : public ChunkDestination { } { absl::MutexLock l(mu_); - TF_RETURN_IF_ERROR(saved_status_); + RETURN_IF_ERROR(saved_status_); sent_bytes_ += size; } auto future = diff --git a/third_party/xla/xla/python/transfer/streaming_ifrt_test.cc b/third_party/xla/xla/python/transfer/streaming_ifrt_test.cc index f87549295e5f27..24ed0d40d5a3e6 100644 --- a/third_party/xla/xla/python/transfer/streaming_ifrt_test.cc +++ b/third_party/xla/xla/python/transfer/streaming_ifrt_test.cc @@ -31,6 +31,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/synchronization/mutex.h" #include "absl/synchronization/notification.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/future.h" #include "xla/pjrt/pjrt_compiler.h" #include "xla/pjrt/raw_buffer.h" @@ -81,11 +82,11 @@ absl::StatusOr SetupTransferDestList( xla::ifrt::PjRtClient* ifrt_client, size_t xfer_size) { auto* pjrt_client = ifrt_client->pjrt_client(); // CHECK_EQ(pjrt_client->platform_id(), xla::TpuId()); - TF_ASSIGN_OR_RETURN(auto* pjrt_memory_space, - device->pjrt_device()->default_memory_space()); - TF_ASSIGN_OR_RETURN(auto atm_owned, - pjrt_client->CreateBuffersForAsyncHostToDevice( - {shape}, pjrt_memory_space)); + ASSIGN_OR_RETURN(auto* pjrt_memory_space, + device->pjrt_device()->default_memory_space()); + ASSIGN_OR_RETURN(auto atm_owned, + pjrt_client->CreateBuffersForAsyncHostToDevice( + {shape}, pjrt_memory_space)); auto atm = std::shared_ptr( std::move(atm_owned)); SingleBufferCopyPlan results; @@ -93,9 +94,9 @@ absl::StatusOr SetupTransferDestList( results.dests.push_back(MakeDmaDestination(atm, 0, copy_size)); // `CreateBuffersForAsyncHostToDevice` uses a default layout. - TF_ASSIGN_OR_RETURN( - auto arr, ifrt_client->CreatePjRtArray(atm->RetrieveBuffer(0), - /*has_custom_layout=*/false)); + ASSIGN_OR_RETURN(auto arr, + ifrt_client->CreatePjRtArray(atm->RetrieveBuffer(0), + /*has_custom_layout=*/false)); results.arrays.push_back(std::move(arr)); return results; } @@ -150,7 +151,7 @@ absl::StatusOr> FetchResult( tsl::RCReference arr, size_t result_size) { std::vector result; result.resize(result_size); - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( arr->CopyToHostBuffer(result.data(), std::nullopt, xla::ifrt::ArrayCopySemantics::kReuseInput) .Await()); diff --git a/third_party/xla/xla/python/types.cc b/third_party/xla/xla/python/types.cc index 51b19f6c204b86..97b55189e25392 100644 --- a/third_party/xla/xla/python/types.cc +++ b/third_party/xla/xla/python/types.cc @@ -30,6 +30,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "nanobind/nanobind.h" #include "nanobind/ndarray.h" // IWYU pragma: keep #include "nanobind/stl/shared_ptr.h" // IWYU pragma: keep @@ -365,7 +366,7 @@ absl::StatusOr DtypeToIfRtDType(const nb_dtype& dtype) { if (dtype.kind() == 'T') { return ifrt::DType(ifrt::DType::kString); } - TF_ASSIGN_OR_RETURN(auto primitive_type, DtypeToPrimitiveType(dtype)); + ASSIGN_OR_RETURN(auto primitive_type, DtypeToPrimitiveType(dtype)); return ifrt::ToDType(primitive_type); } @@ -518,7 +519,7 @@ absl::StatusOr LiteralToPython( std::vector elems = m.DecomposeTuple(); std::vector arrays(elems.size()); for (int i = 0; i < elems.size(); ++i) { - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( arrays[i], LiteralToPython(std::make_unique(std::move(elems[i])))); } @@ -531,8 +532,8 @@ absl::StatusOr LiteralToPython( TF_RET_CHECK(m.shape().IsArray()); nb::object literal_object = nb::cast(literal); - TF_ASSIGN_OR_RETURN(nb_dtype dtype, - PrimitiveTypeToNbDtype(m.shape().element_type())); + ASSIGN_OR_RETURN(nb_dtype dtype, + PrimitiveTypeToNbDtype(m.shape().element_type())); return nb_numpy_ndarray(dtype, m.shape().dimensions(), ByteStridesForShape(m.shape()), m.untyped_data(), literal_object); diff --git a/third_party/xla/xla/python/xplane_to_profile_instructions.cc b/third_party/xla/xla/python/xplane_to_profile_instructions.cc index d8a86c9412849c..1ee2236ebea524 100644 --- a/third_party/xla/xla/python/xplane_to_profile_instructions.cc +++ b/third_party/xla/xla/python/xplane_to_profile_instructions.cc @@ -26,6 +26,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/strings/match.h" #include "absl/strings/str_cat.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/service/hlo.pb.h" #include "xla/tsl/platform/env.h" @@ -182,7 +183,7 @@ absl::Status ConvertXplaneUnderLogdirToProfiledInstructionsProto( profiled_instructions_proto) { // Find the xplane files for each host under logdir. std::vector children_path; - TF_RETURN_IF_ERROR(tsl::Env::Default()->GetChildren(logdir, &children_path)); + RETURN_IF_ERROR(tsl::Env::Default()->GetChildren(logdir, &children_path)); if (children_path.empty()) { return absl::NotFoundError( absl::StrCat("Could not find file under: ", logdir)); @@ -192,7 +193,7 @@ absl::Status ConvertXplaneUnderLogdirToProfiledInstructionsProto( if (absl::StrContains(child_path, kXPlanePb)) { std::string xspace_path = ProfilerJoinPath(logdir, child_path); tensorflow::profiler::XSpace xspace; - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( ReadBinaryProto(tsl::Env::Default(), xspace_path, &xspace)); xspaces.push_back(xspace); } diff --git a/third_party/xla/xla/runtime/BUILD b/third_party/xla/xla/runtime/BUILD index 566ad08739e4d3..33614deed1b74d 100644 --- a/third_party/xla/xla/runtime/BUILD +++ b/third_party/xla/xla/runtime/BUILD @@ -139,6 +139,7 @@ cc_library( name = "object_pool", hdrs = ["object_pool.h"], deps = [ + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "//xla/tsl/util:safe_reinterpret_cast", "@com_google_absl//absl/base:core_headers", diff --git a/third_party/xla/xla/runtime/large_hlo_snapshot_serialization/BUILD b/third_party/xla/xla/runtime/large_hlo_snapshot_serialization/BUILD index 08da26f2e22842..2bb2f25087d0e2 100644 --- a/third_party/xla/xla/runtime/large_hlo_snapshot_serialization/BUILD +++ b/third_party/xla/xla/runtime/large_hlo_snapshot_serialization/BUILD @@ -43,6 +43,7 @@ cc_library( "//xla:shape_util", "//xla/service:hlo_proto_cc", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -63,6 +64,7 @@ xla_cc_test( "//xla/service:hlo_proto_cc", "//xla/tsl/lib/core:status_test_util", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "//xla/tsl/util/proto:proto_matchers", "@com_google_absl//absl/status:statusor", diff --git a/third_party/xla/xla/runtime/large_hlo_snapshot_serialization/serialization.cc b/third_party/xla/xla/runtime/large_hlo_snapshot_serialization/serialization.cc index 2020e97443e00a..7132af959ac45e 100644 --- a/third_party/xla/xla/runtime/large_hlo_snapshot_serialization/serialization.cc +++ b/third_party/xla/xla/runtime/large_hlo_snapshot_serialization/serialization.cc @@ -21,6 +21,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/literal.h" #include "xla/runtime/large_hlo_snapshot_serialization/coded_stream_iterators.h" #include "xla/service/hlo.pb.h" @@ -46,9 +47,9 @@ absl::Status SerializeHloUnoptimizedSnapshot( for (const auto& partition : snapshot.partitions()) { HloInputs* partition_metadata = metadata_proto.add_partitions(); for (const auto& argument : partition.arguments()) { - TF_ASSIGN_OR_RETURN(auto shape, Shape::FromProto(argument.shape())); - TF_ASSIGN_OR_RETURN(int64_t serialized_size, - ShapeUtil::SerializedSize(shape)); + ASSIGN_OR_RETURN(auto shape, Shape::FromProto(argument.shape())); + ASSIGN_OR_RETURN(int64_t serialized_size, + ShapeUtil::SerializedSize(shape)); partition_metadata->add_arguments_descriptors()->set_argument_size_bytes( serialized_size); } @@ -63,9 +64,9 @@ absl::Status SerializeHloUnoptimizedSnapshot( // Serialize literals for (const auto& hlo_input : snapshot.partitions()) { for (const auto& literal_proto : hlo_input.arguments()) { - TF_ASSIGN_OR_RETURN(Literal literal, - xla::Literal::CreateFromProto(literal_proto)); - TF_RETURN_IF_ERROR(literal.Serialize(output_it)); + ASSIGN_OR_RETURN(Literal literal, + xla::Literal::CreateFromProto(literal_proto)); + RETURN_IF_ERROR(literal.Serialize(output_it)); } } return absl::OkStatus(); diff --git a/third_party/xla/xla/runtime/large_hlo_snapshot_serialization/serialization_test.cc b/third_party/xla/xla/runtime/large_hlo_snapshot_serialization/serialization_test.cc index a407bc38f662bb..0cbfdd6ade4d2e 100644 --- a/third_party/xla/xla/runtime/large_hlo_snapshot_serialization/serialization_test.cc +++ b/third_party/xla/xla/runtime/large_hlo_snapshot_serialization/serialization_test.cc @@ -22,6 +22,7 @@ limitations under the License. #include #include "absl/status/statusor.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/literal.h" #include "xla/literal_util.h" #include "xla/service/hlo.pb.h" @@ -59,7 +60,7 @@ absl::StatusOr SerializeAndDeserialize( const HloUnoptimizedSnapshot& snapshot) { std::string serialized_snapshot; tsl::protobuf::io::StringOutputStream output_stream(&serialized_snapshot); - TF_RETURN_IF_ERROR(SerializeHloUnoptimizedSnapshot(snapshot, &output_stream)); + RETURN_IF_ERROR(SerializeHloUnoptimizedSnapshot(snapshot, &output_stream)); tsl::protobuf::io::ArrayInputStream input_stream(serialized_snapshot.data(), serialized_snapshot.size()); diff --git a/third_party/xla/xla/runtime/object_pool.h b/third_party/xla/xla/runtime/object_pool.h index ed771198d7499c..ce212c70ef8328 100644 --- a/third_party/xla/xla/runtime/object_pool.h +++ b/third_party/xla/xla/runtime/object_pool.h @@ -26,6 +26,7 @@ limitations under the License. #include "absl/functional/any_invocable.h" #include "absl/log/check.h" #include "absl/status/statusor.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/tsl/platform/statusor.h" #include "xla/tsl/util/safe_reinterpret_cast.h" @@ -119,7 +120,7 @@ auto ObjectPool::CreateEntry(Args... args) -> absl::StatusOr> { DCHECK(builder_) << "ObjectPool builder is not initialized"; auto entry = std::make_unique(); - TF_ASSIGN_OR_RETURN(entry->object, builder_(std::forward(args)...)); + ASSIGN_OR_RETURN(entry->object, builder_(std::forward(args)...)); num_created_.fetch_add(1, std::memory_order_relaxed); return entry; } @@ -181,7 +182,7 @@ auto ObjectPool::GetOrCreate(Args... args) if (std::unique_ptr entry = PopEntry(); ABSL_PREDICT_TRUE(entry)) { return BorrowedObject(this, std::move(entry)); } - TF_ASSIGN_OR_RETURN(auto entry, CreateEntry(std::forward(args)...)); + ASSIGN_OR_RETURN(auto entry, CreateEntry(std::forward(args)...)); return BorrowedObject(this, std::move(entry)); } diff --git a/third_party/xla/xla/service/BUILD b/third_party/xla/xla/service/BUILD index b410e9a9b3e071..04751e0dca809c 100644 --- a/third_party/xla/xla/service/BUILD +++ b/third_party/xla/xla/service/BUILD @@ -191,6 +191,7 @@ cc_library( "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", "//xla/hlo/utils:hlo_query", + "//xla/tsl/platform:status_macros", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", @@ -214,6 +215,7 @@ xla_cc_test( "//xla/hlo/testlib:pattern_matcher_gmock", "//xla/hlo/utils:hlo_matchers", "//xla/tests:xla_internal_test_main", + "//xla/tsl/platform:status_macros", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", @@ -231,6 +233,7 @@ cc_library( "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -269,6 +272,7 @@ cc_library( "//xla:util", "//xla:xla_data_proto_cc", "//xla/hlo/parser:hlo_parser", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/log:check", @@ -327,6 +331,7 @@ cc_library( "//xla/hlo/pass:hlo_pass", "//xla/service/gpu:backend_configs_cc", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", @@ -475,6 +480,7 @@ cc_library( "//xla/hlo/transforms/simplifiers:hlo_dce", "//xla/hlo/utils:hlo_query", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", @@ -551,6 +557,7 @@ cc_library( "//xla/tsl/lib/strings:proto_serialization", "//xla/tsl/platform:env", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/base:nullability", @@ -612,6 +619,7 @@ cc_library( "//xla/hlo/ir:hlo", "//xla/tsl/platform:env", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -735,6 +743,7 @@ cc_library( "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_set", @@ -801,6 +810,7 @@ cc_library( "//xla/hlo/utils:hlo_sharding_util", "//xla/service/spmd:shard_barrier_partitioner", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base", @@ -857,6 +867,7 @@ cc_library( "//xla/hlo/pass:hlo_pass", "//xla/service/spmd:shard_barrier_partitioner", "//xla/service/spmd/shardy:constants", + "//xla/tsl/platform:status_macros", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status:statusor", @@ -896,6 +907,7 @@ cc_library( "//xla:status_macros", "//xla:util", "//xla/hlo/ir:hlo", + "//xla/tsl/platform:status_macros", ], ) @@ -1085,6 +1097,7 @@ cc_library( "//xla:util", "//xla/hlo/ir:hlo", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", @@ -1143,6 +1156,7 @@ cc_library( "//xla/hlo/utils:hlo_query", "//xla/service/spmd/shardy:constants", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", @@ -1210,6 +1224,7 @@ cc_library( "//xla/stream_executor/sycl:sycl_platform_id", "//xla/tsl/platform:env", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/log", "@com_google_absl//absl/status", @@ -1260,6 +1275,7 @@ cc_library( "//xla/tsl/framework:bfc_allocator", "//xla/tsl/framework:device_id", "//xla/tsl/platform:env", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", @@ -1317,6 +1333,7 @@ cc_library( "//xla/stream_executor:semantic_version", "//xla/stream_executor:stream_executor_h", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", @@ -1357,6 +1374,7 @@ cc_library( "//xla/stream_executor:device_address_allocator", "//xla/stream_executor:platform", "//xla/stream_executor:stream_executor_h", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/log", "@com_google_absl//absl/status:statusor", @@ -1385,6 +1403,7 @@ cc_library( "//xla/client:executable_build_options", "//xla/hlo/builder:xla_computation", "//xla/hlo/ir:hlo", + "//xla/tsl/platform:status_macros", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", @@ -1417,6 +1436,7 @@ cc_library( "//xla/service/heap_simulator", "//xla/tsl/platform:errors", "//xla/tsl/platform:logging", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", @@ -1457,6 +1477,7 @@ xla_cc_test( "//xla/tests:xla_internal_test_main", "//xla/tsl/lib/core:status_test_util", "//xla/tsl/platform:logging", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "//xla/tsl/platform:test_benchmark", "@com_google_absl//absl/algorithm:container", @@ -1483,6 +1504,7 @@ cc_library( "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", "//xla/hlo/utils:hlo_query", + "//xla/tsl/platform:status_macros", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", @@ -1541,6 +1563,7 @@ xla_cc_test( "//xla/hlo/testlib:hlo_hardware_independent_test_base", "//xla/tests:xla_internal_test_main", "//xla/tsl/lib/core:status_test_util", + "//xla/tsl/platform:status_macros", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/status", "@com_google_absl//absl/status:status_matchers", @@ -1570,6 +1593,7 @@ cc_library( "//xla:util", "//xla:xla_data_proto_cc", "//xla/stream_executor:stream_executor_h", + "//xla/tsl/platform:status_macros", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@tsl//tsl/platform:logging", @@ -1650,6 +1674,7 @@ cc_library( "//xla:xla_data_proto_cc", "//xla/stream_executor:device_address", "//xla/stream_executor:device_address_allocator", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log:check", @@ -1689,6 +1714,7 @@ cc_library( ":stream_pool", "//xla:executable_run_options", "//xla/stream_executor:platform", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -1726,6 +1752,7 @@ cc_library( "//xla/stream_executor:stream", "//xla/stream_executor/abi:executable_abi_version", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/base:nullability", @@ -1918,6 +1945,7 @@ cc_library( ":stream_pool", "//xla/hlo/ir:hlo", "//xla/stream_executor:stream_executor_h", + "//xla/tsl/platform:status_macros", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status:statusor", @@ -1948,6 +1976,7 @@ cc_library( "//xla/stream_executor:platform", "//xla/stream_executor:stream", "//xla/stream_executor:stream_executor_h", + "//xla/tsl/platform:status_macros", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/cleanup", "@com_google_absl//absl/container:flat_hash_map", @@ -1978,6 +2007,7 @@ cc_library( "//xla:xla_data_proto_cc", "//xla/stream_executor:device_address", "//xla/stream_executor:device_address_allocator", + "//xla/tsl/platform:status_macros", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status", @@ -2001,6 +2031,7 @@ cc_library( "//xla:util", "//xla:xla_data_proto_cc", "//xla/stream_executor:stream_executor_h", + "//xla/tsl/platform:status_macros", "@com_google_absl//absl/status:statusor", "@tsl//tsl/platform:logging", ], @@ -2070,6 +2101,7 @@ cc_library( "//xla/service/memory_space_assignment", "//xla/tsl/platform:errors", "//xla/tsl/platform:logging", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:btree", @@ -2157,6 +2189,7 @@ cc_library( "//xla/hlo/parser:hlo_parser", "//xla/tsl/platform:env", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", @@ -2236,6 +2269,7 @@ cc_library( "//xla/hlo/pass:hlo_pass", "//xla/tsl/platform:errors", "//xla/tsl/platform:macros", + "//xla/tsl/platform:status_macros", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", @@ -2289,6 +2323,7 @@ cc_library( "//xla/hlo/pass:hlo_pass", "//xla/hlo/transforms/simplifiers:hlo_dce", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/functional:function_ref", @@ -2322,6 +2357,7 @@ cc_library( "//xla/hlo/builder/lib:comparators", "//xla/hlo/ir:hlo", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/log", @@ -2402,6 +2438,7 @@ cc_library( "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/functional:function_ref", @@ -2425,6 +2462,7 @@ cc_library( "//xla:util", "//xla/hlo/ir:hlo", "//xla/hlo/transforms/expanders:op_expander_pass", + "//xla/tsl/platform:status_macros", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status:statusor", @@ -2442,6 +2480,7 @@ cc_library( "//xla:shape_util", "//xla/hlo/ir:hlo", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/log:check", @@ -2467,6 +2506,7 @@ cc_library( "//xla:util", "//xla/hlo/ir:hlo", "//xla/hlo/transforms/expanders:op_expander_pass", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:inlined_vector", @@ -2489,6 +2529,7 @@ cc_library( "//xla/hlo/ir:hlo", "//xla/hlo/transforms/expanders:op_expander_pass", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", @@ -2545,6 +2586,7 @@ cc_library( "//xla/hlo/builder/lib:slicing", "//xla/hlo/ir:hlo", "//xla/hlo/transforms/expanders:op_expander_pass", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", @@ -2625,6 +2667,7 @@ cc_library( "//xla/hlo/analysis:hlo_reachability", "//xla/hlo/ir:hlo", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "@com_google_absl//absl/container:btree", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", @@ -2679,6 +2722,7 @@ cc_library( "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", "//xla/hlo/utils:hlo_query", + "//xla/tsl/platform:status_macros", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/functional:function_ref", "@com_google_absl//absl/log", @@ -2699,6 +2743,7 @@ xla_cc_test( "//xla/hlo/ir:hlo", "//xla/hlo/testlib:hlo_hardware_independent_test_base", "//xla/hlo/utils:hlo_matchers", + "//xla/tsl/platform:status_macros", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/log", "@com_google_absl//absl/strings:string_view", @@ -2721,6 +2766,7 @@ cc_library( "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", @@ -2762,6 +2808,7 @@ cc_library( "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", @@ -2800,6 +2847,7 @@ cc_library( "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", "//xla/hlo/utils:hlo_query", + "//xla/tsl/platform:status_macros", "@com_google_absl//absl/log", "@com_google_absl//absl/status:statusor", ], @@ -2833,6 +2881,7 @@ cc_library( "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", "//xla/hlo/utils:hlo_query", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/status:statusor", "@tsl//tsl/platform:errors", @@ -2850,6 +2899,7 @@ xla_cc_test( "//xla/hlo/testlib:hlo_hardware_independent_test_base", "//xla/hlo/utils:hlo_matchers", "//xla/tests:xla_internal_test_main", + "//xla/tsl/platform:status_macros", ], ) @@ -2882,6 +2932,7 @@ cc_library( "//xla:util", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", + "//xla/tsl/platform:status_macros", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", @@ -2929,6 +2980,7 @@ cc_library( "//xla/hlo/transforms/simplifiers:hlo_dce", "//xla/hlo/transforms/simplifiers:tuple_simplifier", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", @@ -2986,6 +3038,7 @@ cc_library( "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", "//xla/tsl/lib/core:bitmap", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", @@ -3030,6 +3083,7 @@ cc_library( "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", "//xla/hlo/transforms/simplifiers:tuple_simplifier", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", @@ -3088,6 +3142,7 @@ cc_library( "//xla/hlo/transforms/simplifiers:tuple_simplifier", "//xla/hlo/utils:hlo_query", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", @@ -3146,6 +3201,7 @@ cc_library( "//xla/hlo/transforms/simplifiers:hlo_dce", "//xla/hlo/utils:hlo_query", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", @@ -3212,6 +3268,7 @@ cc_library( "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", + "//xla/tsl/platform:status_macros", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -3275,6 +3332,7 @@ cc_library( "//xla/hlo/analysis:hlo_dataflow_analysis", "//xla/hlo/ir:hlo", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", @@ -3317,6 +3375,7 @@ cc_library( "//xla/hlo/transforms/simplifiers:flatten_call_graph", "//xla/hlo/transforms/simplifiers:hlo_dce", "//xla/tsl/lib/monitoring:gauge", + "//xla/tsl/platform:status_macros", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/functional:function_ref", @@ -3365,6 +3424,7 @@ xla_test( "//xla/tests:xla_test_backend_predicates", "//xla/tsl/lib/core:status_test_util", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "//xla/tsl/protobuf:error_codes_proto_impl_cc", "@com_google_absl//absl/log", @@ -3405,6 +3465,7 @@ xla_cc_test( "//xla/hlo/testlib:test", "//xla/tests:xla_internal_test_main", "//xla/tsl/lib/core:status_test_util", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", @@ -3431,6 +3492,7 @@ cc_library( "//xla/stream_executor/rocm:rocm_platform_id", "//xla/stream_executor/sycl:sycl_platform_id", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", @@ -3529,6 +3591,7 @@ cc_library( "//xla/stream_executor:memory_allocation", "//xla/stream_executor:platform", "//xla/stream_executor:stream_executor_h", + "//xla/tsl/platform:status_macros", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:node_hash_map", "@com_google_absl//absl/status", @@ -3580,6 +3643,7 @@ cc_library( "//xla:window_util", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", @@ -3904,6 +3968,7 @@ cc_library( "//xla/hlo/transforms/simplifiers:hlo_dce", "//xla/hlo/transforms/simplifiers:tuple_simplifier", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", @@ -3956,6 +4021,7 @@ cc_library( "//xla/hlo/transforms/simplifiers:hlo_dce", "//xla/hlo/transforms/simplifiers:tuple_simplifier", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", @@ -3988,6 +4054,7 @@ cc_library( "//xla/hlo/utils:hlo_query", "//xla/service/graphcycles", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", @@ -4063,6 +4130,7 @@ cc_library( "//xla/hlo/pass:hlo_pass", "//xla/hlo/transforms/simplifiers:hlo_dce", "//xla/hlo/transforms/simplifiers:tuple_simplifier", + "//xla/tsl/platform:status_macros", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@tsl//tsl/platform:errors", @@ -4078,6 +4146,7 @@ cc_library( "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", @@ -4109,6 +4178,7 @@ cc_library( "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", @@ -4163,6 +4233,7 @@ cc_library( ":hlo_verifier", "//xla:shape_util", "//xla/hlo/ir:hlo", + "//xla/tsl/platform:status_macros", "@com_google_absl//absl/status", "@com_google_absl//absl/strings:str_format", "@tsl//tsl/platform:errors", @@ -4228,6 +4299,7 @@ xla_cc_test( "//xla/tests:xla_internal_test_main", "//xla/tsl/lib/core:status_test_util", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", @@ -4248,6 +4320,7 @@ cc_library( "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", @@ -4322,6 +4395,7 @@ cc_library( "//xla:status_macros", "//xla:util", "//xla/hlo/ir:hlo", + "//xla/tsl/platform:status_macros", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", @@ -4344,6 +4418,7 @@ cc_library( "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", @@ -4362,6 +4437,7 @@ cc_library( ":hlo_domain_remover", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", + "//xla/tsl/platform:status_macros", "@tsl//tsl/platform:statusor", ], ) @@ -4376,6 +4452,7 @@ cc_library( "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", @@ -4455,6 +4532,7 @@ cc_library( "//xla:status_macros", "//xla:util", "//xla:xla_data_proto_cc", + "//xla/tsl/platform:status_macros", "@com_google_absl//absl/status:statusor", "@llvm-project//llvm:Core", "@llvm-project//llvm:Support", @@ -4485,6 +4563,7 @@ cc_library( "//xla/service/llvm_ir:llvm_loop", "//xla/service/llvm_ir:llvm_util", "//xla/service/llvm_ir:loop_emitter", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", @@ -4546,6 +4625,7 @@ cc_library( "//xla:shape_util", "//xla:xla_data_proto_cc", "//xla:xla_proto_cc", + "//xla/tsl/platform:status_macros", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/log", @@ -4581,6 +4661,7 @@ cc_library( "//xla:shape_util", "//xla:types", "//xla:xla_data_proto_cc", + "//xla/tsl/platform:status_macros", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/strings", ], @@ -4606,6 +4687,7 @@ cc_library( "//xla/tsl/lib/io:zlib_outputbuffer", "//xla/tsl/platform:env", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", @@ -4654,6 +4736,7 @@ cc_library( "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_set", @@ -4765,6 +4848,7 @@ cc_library( "//xla:util", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/functional:any_invocable", @@ -4888,6 +4972,7 @@ cc_library( "//xla:shape_tree", "//xla:shape_util", "//xla/hlo/ir:hlo", + "//xla/tsl/platform:status_macros", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/log", @@ -4937,6 +5022,7 @@ cc_library( "//xla:util", "//xla/hlo/ir:hlo", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_set", @@ -4978,6 +5064,7 @@ cc_library( "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", @@ -5005,6 +5092,7 @@ xla_cc_test( "//xla/hlo/utils:hlo_matchers", "//xla/tests:xla_internal_test_main", "//xla/tsl/lib/core:status_test_util", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/status:statusor", @@ -5037,6 +5125,7 @@ cc_library( "//xla/hlo/utils:hlo_query", "//xla/service/gpu:dynamic_slicing_utils", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", @@ -5088,6 +5177,7 @@ cc_library( "//xla/hlo/pass:hlo_pass_pipeline", "//xla/hlo/transforms/simplifiers:hlo_dce", "//xla/hlo/transforms/simplifiers:tuple_simplifier", + "//xla/tsl/platform:status_macros", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", @@ -5139,6 +5229,7 @@ cc_library( "//xla/hlo/transforms/simplifiers:hlo_dce", "//xla/hlo/transforms/simplifiers:tuple_simplifier", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", @@ -5182,6 +5273,7 @@ cc_library( "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", @@ -5220,6 +5312,7 @@ cc_library( "//xla:util", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", + "//xla/tsl/platform:status_macros", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", @@ -5265,6 +5358,7 @@ cc_library( "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", "//xla/hlo/utils:hlo_query", + "//xla/tsl/platform:status_macros", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", @@ -5310,6 +5404,7 @@ cc_library( "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", "//xla/hlo/utils:hlo_query", + "//xla/tsl/platform:status_macros", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", "@com_google_absl//absl/status", @@ -5378,6 +5473,7 @@ cc_library( "//xla:types", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", + "//xla/tsl/platform:status_macros", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", @@ -5550,6 +5646,7 @@ cc_library( "//xla/hlo/ir:collective_op_group_mode", "//xla/hlo/ir:hlo", "//xla/runtime:device_id", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", @@ -5583,6 +5680,7 @@ xla_cc_test( "//xla/runtime:device_id", "//xla/tests:xla_internal_test_main", "//xla/tsl/lib/core:status_test_util", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/log", @@ -5606,6 +5704,7 @@ cc_library( "//xla/hlo/builder/lib:comparators", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", + "//xla/tsl/platform:status_macros", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", @@ -5724,6 +5823,7 @@ cc_library( hdrs = ["mapped_ptr_container_sorter.h"], deps = [ "//xla:util", + "//xla/tsl/platform:status_macros", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/functional:function_ref", @@ -5825,6 +5925,7 @@ cc_library( deps = [ "//xla:xla_proto_cc", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", @@ -5898,6 +5999,7 @@ cc_library( "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_set", @@ -6006,6 +6108,7 @@ cc_library( "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/status:statusor", @@ -6027,6 +6130,7 @@ cc_library( "//xla:shape_util", "//xla:util", "//xla/hlo/ir:hlo", + "//xla/tsl/platform:status_macros", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/log:check", @@ -6306,6 +6410,7 @@ cc_library( "//xla:parse_flags_from_env", "//xla:util", "//xla:xla_proto_cc", + "//xla/tsl/platform:status_macros", "//xla/tsl/util:command_line_flags", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -6463,6 +6568,7 @@ cc_library( "//xla/hlo/ir:ptrvec", "//xla/hlo/pass:hlo_pass", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/container:btree", "@com_google_absl//absl/container:flat_hash_map", @@ -6548,6 +6654,7 @@ cc_library( "//xla:side_effect_util", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", + "//xla/tsl/platform:status_macros", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", @@ -6637,6 +6744,7 @@ cc_library( "//xla:util", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/log", @@ -6676,6 +6784,7 @@ cc_library( ":buffer_assignment", ":shaped_slice_proto_cc", "//xla:shape_util", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", diff --git a/third_party/xla/xla/service/all_gather_decomposer.cc b/third_party/xla/xla/service/all_gather_decomposer.cc index c21b03c7834f23..46393ae7a329dd 100644 --- a/third_party/xla/xla/service/all_gather_decomposer.cc +++ b/third_party/xla/xla/service/all_gather_decomposer.cc @@ -22,6 +22,7 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "absl/status/status.h" #include "absl/strings/string_view.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -87,9 +88,9 @@ HloInstruction* AllGatherDecomposer::TranslateAllGatherToAllReducePerOperand( absl::Status AllGatherDecomposer::DecomposeAllGather( HloAllGatherInstruction* ag, HloComputation* comp) { - TF_ASSIGN_OR_RETURN(CollectiveOpGroupMode group_mode, - GetCollectiveOpGroupMode(ag->channel_id().has_value(), - ag->use_global_device_ids())); + ASSIGN_OR_RETURN(CollectiveOpGroupMode group_mode, + GetCollectiveOpGroupMode(ag->channel_id().has_value(), + ag->use_global_device_ids())); if (ag->operand_count() > 1) { std::vector tuple_inputs; for (int i = 0; i < ag->operand_count(); ++i) { @@ -101,14 +102,14 @@ absl::Status AllGatherDecomposer::DecomposeAllGather( tuple_inputs.push_back(ar); } auto tup = comp->AddInstruction(HloInstruction::CreateTuple(tuple_inputs)); - TF_RETURN_IF_ERROR(ag->ReplaceAllUsesWith(tup)); + RETURN_IF_ERROR(ag->ReplaceAllUsesWith(tup)); } else { auto* ar = TranslateAllGatherToAllReducePerOperand( group_mode, *ag, ag->shape(), ag->mutable_operand(0), comp, ag->all_gather_dimension()); - TF_RETURN_IF_ERROR(ag->ReplaceAllUsesWith(ar)); + RETURN_IF_ERROR(ag->ReplaceAllUsesWith(ar)); } - TF_RETURN_IF_ERROR(comp->RemoveInstructionAndUnusedOperands(ag)); + RETURN_IF_ERROR(comp->RemoveInstructionAndUnusedOperands(ag)); return absl::OkStatus(); } @@ -123,7 +124,7 @@ absl::StatusOr AllGatherDecomposer::RunImpl( } auto ag = Cast(hlo); if (ShouldDecompose(*ag)) { - TF_RETURN_IF_ERROR(DecomposeAllGather(ag, comp)); + RETURN_IF_ERROR(DecomposeAllGather(ag, comp)); changed = true; } } diff --git a/third_party/xla/xla/service/all_gather_simplifier.cc b/third_party/xla/xla/service/all_gather_simplifier.cc index 858473e5854b31..2db438a3e0c9e4 100644 --- a/third_party/xla/xla/service/all_gather_simplifier.cc +++ b/third_party/xla/xla/service/all_gather_simplifier.cc @@ -22,6 +22,7 @@ limitations under the License. #include "absl/log/log.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -46,7 +47,7 @@ absl::StatusOr AllGatherSimplifier::CancelSingleDynamicSliceFromAllGather( HloComputation* computation = inst->parent(); if (ShapeUtil::Compatible(inst->shape(), inst->operand(0)->shape())) { - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( computation->ReplaceInstruction(inst, inst->mutable_operand(0))); return true; } @@ -68,8 +69,8 @@ absl::StatusOr AllGatherSimplifier::CancelSingleDynamicSliceFromAllGather( if (!ShapeUtil::Compatible(ds->shape(), ag_operand->shape())) { return false; } - TF_RETURN_IF_ERROR(ds->ReplaceAllUsesWith(ag_operand)); - TF_RETURN_IF_ERROR(computation->RemoveInstructionAndUnusedOperands(ds)); + RETURN_IF_ERROR(ds->ReplaceAllUsesWith(ag_operand)); + RETURN_IF_ERROR(computation->RemoveInstructionAndUnusedOperands(ds)); return true; } @@ -82,8 +83,8 @@ absl::StatusOr AllGatherSimplifier::RunImpl( bool changed = false; for (auto computation : module->computations(execution_threads)) { for (HloInstruction* inst : computation->MakeInstructionPostOrder()) { - TF_ASSIGN_OR_RETURN(bool local_changed, - CancelSingleDynamicSliceFromAllGather(module, inst)); + ASSIGN_OR_RETURN(bool local_changed, + CancelSingleDynamicSliceFromAllGather(module, inst)); changed |= local_changed; } } diff --git a/third_party/xla/xla/service/all_reduce_reassociate.cc b/third_party/xla/xla/service/all_reduce_reassociate.cc index 3b3c55257f6949..2e9af248e6a91b 100644 --- a/third_party/xla/xla/service/all_reduce_reassociate.cc +++ b/third_party/xla/xla/service/all_reduce_reassociate.cc @@ -25,6 +25,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -234,12 +235,12 @@ absl::StatusOr AllReduceReassociate::RunImpl( // Check Dynamic-slice pattern is identical if (lhs->opcode() == HloOpcode::kDynamicSlice) { HloInstruction* original_rhs_operand = rhs->mutable_operand(0); - TF_RETURN_IF_ERROR(rhs->ReplaceOperandWith(0, lhs->mutable_operand(0))); + RETURN_IF_ERROR(rhs->ReplaceOperandWith(0, lhs->mutable_operand(0))); if (!lhs->Identical(*rhs)) { - TF_RETURN_IF_ERROR(rhs->ReplaceOperandWith(0, original_rhs_operand)); + RETURN_IF_ERROR(rhs->ReplaceOperandWith(0, original_rhs_operand)); continue; } - TF_RETURN_IF_ERROR(rhs->ReplaceOperandWith(0, original_rhs_operand)); + RETURN_IF_ERROR(rhs->ReplaceOperandWith(0, original_rhs_operand)); ar0 = Cast(lhs->mutable_operand(0)); ar1 = Cast(rhs->mutable_operand(0)); reduce_scatter_pattern_match = true; @@ -298,12 +299,12 @@ absl::StatusOr AllReduceReassociate::RunImpl( HloInstruction* new_op_operand1 = ar1->mutable_operand(0); if (convert0) { HloInstruction* ar0_operand = ar0->mutable_operand(0); - TF_RETURN_IF_ERROR(convert0->ReplaceOperandWith(0, ar0_operand)); + RETURN_IF_ERROR(convert0->ReplaceOperandWith(0, ar0_operand)); new_op_operand0 = convert0; } if (convert1) { HloInstruction* ar1_operand = ar1->mutable_operand(0); - TF_RETURN_IF_ERROR(convert1->ReplaceOperandWith(0, ar1_operand)); + RETURN_IF_ERROR(convert1->ReplaceOperandWith(0, ar1_operand)); new_op_operand1 = convert1; } @@ -324,8 +325,8 @@ absl::StatusOr AllReduceReassociate::RunImpl( } else if (reduce_scatter_pattern_match) { new_ar_out_shape = ar0->shape(); } else { - TF_RETURN_IF_ERROR(ar0->ReplaceAllUsesWith(ar0->mutable_operand(0))); - TF_RETURN_IF_ERROR(ar1->ReplaceAllUsesWith(ar1->mutable_operand(0))); + RETURN_IF_ERROR(ar0->ReplaceAllUsesWith(ar0->mutable_operand(0))); + RETURN_IF_ERROR(ar1->ReplaceAllUsesWith(ar1->mutable_operand(0))); } HloInstruction* new_ar = computation->AddInstruction( @@ -351,32 +352,32 @@ absl::StatusOr AllReduceReassociate::RunImpl( HloComputation* to_apply_promoted = inst->GetModule()->AddEmbeddedComputation(promoted.Build()); new_ar->set_to_apply(to_apply_promoted); - TF_RETURN_IF_ERROR(inst->ReplaceAllUsesWith(new_ar)); + RETURN_IF_ERROR(inst->ReplaceAllUsesWith(new_ar)); } else if (reduce_scatter_pattern_match) { auto dyn_slice_operands = lhs->mutable_operands(); dyn_slice_operands[0] = new_ar; HloInstruction* new_dyn_slice = inst->parent()->AddInstruction( lhs->CloneWithNewOperands(inst->shape(), dyn_slice_operands)); - TF_RETURN_IF_ERROR(inst->ReplaceUsesWith(op_users, new_dyn_slice)); + RETURN_IF_ERROR(inst->ReplaceUsesWith(op_users, new_dyn_slice)); } else { - TF_RETURN_IF_ERROR(inst->ReplaceUsesWith(op_users, new_ar)); + RETURN_IF_ERROR(inst->ReplaceUsesWith(op_users, new_ar)); } // Note that RemoveInstructionAndUnusedOperands may not remove the 2 // all-reduce operands of `inst` if they are not safe to remove otherwise, // so manually these instructions. if (should_promote_ar || reduce_scatter_pattern_match) { - TF_RETURN_IF_ERROR(computation->RemoveInstruction(inst)); + RETURN_IF_ERROR(computation->RemoveInstruction(inst)); } if (reduce_scatter_pattern_match) { - TF_RETURN_IF_ERROR(computation->RemoveInstruction(lhs)); + RETURN_IF_ERROR(computation->RemoveInstruction(lhs)); if (lhs != rhs) { - TF_RETURN_IF_ERROR(computation->RemoveInstruction(rhs)); + RETURN_IF_ERROR(computation->RemoveInstruction(rhs)); } } - TF_RETURN_IF_ERROR(computation->RemoveInstruction(ar0)); + RETURN_IF_ERROR(computation->RemoveInstruction(ar0)); if (ar0 != ar1) { - TF_RETURN_IF_ERROR(computation->RemoveInstruction(ar1)); + RETURN_IF_ERROR(computation->RemoveInstruction(ar1)); } changed = true; } diff --git a/third_party/xla/xla/service/all_reduce_reassociate_test.cc b/third_party/xla/xla/service/all_reduce_reassociate_test.cc index 809a12eef21b24..cd8a98e13ab24a 100644 --- a/third_party/xla/xla/service/all_reduce_reassociate_test.cc +++ b/third_party/xla/xla/service/all_reduce_reassociate_test.cc @@ -24,6 +24,7 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" @@ -46,7 +47,7 @@ class AllReduceSimplifierTest : public HloHardwareIndependentTestBase { absl::StatusOr> RunPass( absl::string_view hlo_module, bool expect_change, bool reassociate_converted_ar = false) { - TF_ASSIGN_OR_RETURN(auto module, ParseAndReturnVerifiedModule(hlo_module)); + ASSIGN_OR_RETURN(auto module, ParseAndReturnVerifiedModule(hlo_module)); auto changed = AllReduceReassociate(reassociate_converted_ar).Run(module.get()); if (!changed.ok()) { diff --git a/third_party/xla/xla/service/all_reduce_reduce_scatter_reorder.cc b/third_party/xla/xla/service/all_reduce_reduce_scatter_reorder.cc index e738bc918a9256..ac66f34e42f7cc 100644 --- a/third_party/xla/xla/service/all_reduce_reduce_scatter_reorder.cc +++ b/third_party/xla/xla/service/all_reduce_reduce_scatter_reorder.cc @@ -19,6 +19,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" @@ -64,10 +65,10 @@ absl::Status ReorderAllReduceReduceScatter(HloInstruction* old_ar) { old_rs->CloneWithNewOperands(old_rs->shape(), old_ar->operands())); HloInstruction* new_ar = computation->AddInstruction( old_ar->CloneWithNewOperands(old_rs->shape(), {new_rs})); - TF_RETURN_IF_ERROR(old_rs->ReplaceUsesWith(old_rs->users(), new_ar)); + RETURN_IF_ERROR(old_rs->ReplaceUsesWith(old_rs->users(), new_ar)); - TF_RETURN_IF_ERROR(computation->RemoveInstruction(old_rs)); - TF_RETURN_IF_ERROR(computation->RemoveInstruction(old_ar)); + RETURN_IF_ERROR(computation->RemoveInstruction(old_rs)); + RETURN_IF_ERROR(computation->RemoveInstruction(old_ar)); return absl::OkStatus(); } } // namespace @@ -79,7 +80,7 @@ absl::StatusOr AllReduceReduceScatterReorder::RunImpl( for (auto computation : module->computations(execution_threads)) { for (HloInstruction* inst : computation->MakeInstructionPostOrder()) { if (IsAllReduceReduceScatter(inst)) { - TF_RETURN_IF_ERROR(ReorderAllReduceReduceScatter(inst)); + RETURN_IF_ERROR(ReorderAllReduceReduceScatter(inst)); changed = true; } } diff --git a/third_party/xla/xla/service/all_reduce_simplifier.cc b/third_party/xla/xla/service/all_reduce_simplifier.cc index ec8a870d14999a..c23b04b98cd1cd 100644 --- a/third_party/xla/xla/service/all_reduce_simplifier.cc +++ b/third_party/xla/xla/service/all_reduce_simplifier.cc @@ -24,6 +24,7 @@ limitations under the License. #include "absl/log/log.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/analysis/hlo_replication_analysis.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" @@ -45,7 +46,7 @@ namespace xla { absl::StatusOr AllReduceSimplifier::RunImpl( HloModule* module, const absl::flat_hash_set& execution_threads) { - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( auto replication, HloReplicationAnalysis::Run(module, /*cross_partition_spmd=*/false)); std::vector> all_reduces_to_replace; @@ -56,7 +57,7 @@ absl::StatusOr AllReduceSimplifier::RunImpl( auto get_participant_counts_for_replica_group = [](const HloInstruction* all_reduce) -> absl::StatusOr { const HloModuleConfig& config = all_reduce->GetModule()->config(); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( CollectiveOpGroupMode group_mode, GetCollectiveOpGroupMode(all_reduce->channel_id().has_value(), Cast(all_reduce) @@ -64,10 +65,10 @@ absl::StatusOr AllReduceSimplifier::RunImpl( int64_t num_devices = config.num_partitions(); int64_t num_replicas = config.replica_count(); - TF_ASSIGN_OR_RETURN(std::vector participant_counts, - GetPariticipantCountsForReplicaGroups( - num_replicas, num_devices, - all_reduce->replica_groups(), group_mode)); + ASSIGN_OR_RETURN(std::vector participant_counts, + GetPariticipantCountsForReplicaGroups( + num_replicas, num_devices, + all_reduce->replica_groups(), group_mode)); if (participant_counts.empty()) { return -1; } @@ -96,7 +97,7 @@ absl::StatusOr AllReduceSimplifier::RunImpl( inst->opcode() == HloOpcode::kReduceScatter) && ShapeUtil::Compatible(inst->shape(), inst->operand(0)->shape())) { changed = true; - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( computation->ReplaceInstruction(inst, inst->mutable_operand(0))); } } @@ -119,9 +120,9 @@ absl::StatusOr AllReduceSimplifier::RunImpl( } TF_RET_CHECK(async_done != nullptr) << "Expected async-done for async-start " << async_start->name(); - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( async_done->parent()->ReplaceInstruction(async_done, input)); - TF_RETURN_IF_ERROR(module->RemoveEmbeddedComputation(computation)); + RETURN_IF_ERROR(module->RemoveEmbeddedComputation(computation)); changed = true; } @@ -142,8 +143,8 @@ absl::StatusOr AllReduceSimplifier::RunImpl( // TODO: b/501070020 - Support asynchronous all-reduce. continue; } - TF_ASSIGN_OR_RETURN(int64_t group_size, - get_participant_counts_for_replica_group(inst)); + ASSIGN_OR_RETURN(int64_t group_size, + get_participant_counts_for_replica_group(inst)); // We will not simplify this all reduce if any of the following is true: // 1. All group do not have the same size. @@ -170,7 +171,7 @@ absl::StatusOr AllReduceSimplifier::RunImpl( auto all_reduce = all_reduce_and_group_size.first; const int64_t replica_group_size = all_reduce_and_group_size.second; if (replica_group_size == 1) { - TF_RETURN_IF_ERROR(all_reduce->parent()->ReplaceInstruction( + RETURN_IF_ERROR(all_reduce->parent()->ReplaceInstruction( all_reduce, all_reduce->mutable_operand(0))); changed = true; continue; @@ -216,7 +217,7 @@ absl::StatusOr AllReduceSimplifier::RunImpl( } VLOG(2) << "Replacing " << all_reduce->ToString() << " with " << replacement->ToString(); - TF_RETURN_IF_ERROR(all_reduce->ReplaceAllUsesWith(replacement)); + RETURN_IF_ERROR(all_reduce->ReplaceAllUsesWith(replacement)); changed = true; } return changed; diff --git a/third_party/xla/xla/service/allocation_tracker.cc b/third_party/xla/xla/service/allocation_tracker.cc index 4479c811531400..b8eefd0da83252 100644 --- a/third_party/xla/xla/service/allocation_tracker.cc +++ b/third_party/xla/xla/service/allocation_tracker.cc @@ -25,6 +25,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/service/shaped_buffer.h" #include "xla/shape.h" #include "xla/shape_util.h" @@ -105,8 +106,8 @@ absl::Status AllocationTracker::Unregister(const GlobalDataHandle& data) { absl::MutexLock lock(mutex_); VLOG(2) << "Unregister(" << "handle: " << data.handle() << ")"; - TF_ASSIGN_OR_RETURN(std::vector replicated_buffers, - ResolveInternal(data)); + ASSIGN_OR_RETURN(std::vector replicated_buffers, + ResolveInternal(data)); for (const auto& shaped_buffer : replicated_buffers) { std::vector shape_indices; ShapeUtil::ForEachSubshape( @@ -115,8 +116,8 @@ absl::Status AllocationTracker::Unregister(const GlobalDataHandle& data) { shape_indices.push_back(index); }); for (const ShapeIndex& index : shape_indices) { - TF_RETURN_IF_ERROR(DecrementRefCount(shaped_buffer->buffer(index), - shaped_buffer->device_ordinal())); + RETURN_IF_ERROR(DecrementRefCount(shaped_buffer->buffer(index), + shaped_buffer->device_ordinal())); } } // Keep a nullptr as a tombstone for unregistered handles. This enables @@ -137,8 +138,8 @@ absl::StatusOr> AllocationTracker::DeconstructTuple(const GlobalDataHandle& data) { absl::MutexLock lock(mutex_); - TF_ASSIGN_OR_RETURN(std::vector replicated_buffers, - ResolveInternal(data)); + ASSIGN_OR_RETURN(std::vector replicated_buffers, + ResolveInternal(data)); // We only need to care about replica id 0 here, since the GlobalDataHandle is // the same for all buffers across replicas. const ShapedBuffer* shaped_buffer = replicated_buffers[0]; @@ -162,7 +163,7 @@ AllocationTracker::DeconstructTuple(const GlobalDataHandle& data) { /*index=*/{}); std::vector replicated_buffers; replicated_buffers.push_back(std::move(element_buffer)); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( GlobalDataHandle element_handle, RegisterInternal(std::move(replicated_buffers), "deconstructed tuple")); @@ -180,8 +181,8 @@ absl::StatusOr> AllocationTracker::Resolve( absl::StatusOr AllocationTracker::ResolveForReplica( const GlobalDataHandle& data, int replica_id) const { absl::MutexLock lock(mutex_); - TF_ASSIGN_OR_RETURN(std::vector replicated_buffers, - ResolveInternal(data)); + ASSIGN_OR_RETURN(std::vector replicated_buffers, + ResolveInternal(data)); if (replica_id >= replicated_buffers.size()) { return InvalidArgument( "Requesting buffer for replica %d, but found buffers only for %lu " @@ -233,7 +234,7 @@ absl::Status AllocationTracker::DecrementRefCount( Allocation& allocation = it->second; TF_RET_CHECK(allocation.ref_count >= 1); if (allocation.ref_count == 1) { - TF_RETURN_IF_ERROR(allocation.device_memory.Free()); + RETURN_IF_ERROR(allocation.device_memory.Free()); allocation_map.erase(it); } else { allocation.ref_count--; diff --git a/third_party/xla/xla/service/backend.cc b/third_party/xla/xla/service/backend.cc index 4f4528a268ac1e..9719dbe0b5693b 100644 --- a/third_party/xla/xla/service/backend.cc +++ b/third_party/xla/xla/service/backend.cc @@ -31,6 +31,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/synchronization/mutex.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/service/compiler.h" #include "xla/service/computation_placer.h" #include "xla/service/platform_util.h" @@ -183,14 +184,14 @@ CreateGpuAllocators(const se::Platform* platform, /* static */ absl::StatusOr> Backend::CreateBackend( const BackendOptions& options) { se::Platform* platform = options.platform(); - TF_ASSIGN_OR_RETURN(auto compiler, Compiler::GetForPlatform(platform->id())); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN(auto compiler, Compiler::GetForPlatform(platform->id())); + ASSIGN_OR_RETURN( auto stream_executors, PlatformUtil::GetStreamExecutors(platform, options.allowed_devices())); - TF_ASSIGN_OR_RETURN(auto transfer_manager, - TransferManager::GetForPlatform(platform)); - TF_ASSIGN_OR_RETURN(auto computation_placer, - ComputationPlacer::GetForPlatform(platform->id())); + ASSIGN_OR_RETURN(auto transfer_manager, + TransferManager::GetForPlatform(platform)); + ASSIGN_OR_RETURN(auto computation_placer, + ComputationPlacer::GetForPlatform(platform->id())); std::unique_ptr backend(new Backend( platform, std::move(compiler), stream_executors, transfer_manager, computation_placer, options.intra_op_parallelism_threads())); @@ -199,8 +200,7 @@ CreateGpuAllocators(const se::Platform* platform, /* static */ absl::StatusOr> Backend::CreateDefaultBackend() { - TF_ASSIGN_OR_RETURN(se::Platform * platform, - PlatformUtil::GetDefaultPlatform()); + ASSIGN_OR_RETURN(se::Platform * platform, PlatformUtil::GetDefaultPlatform()); BackendOptions backend_options; backend_options.set_platform(platform); return CreateBackend(backend_options); @@ -208,7 +208,7 @@ Backend::CreateDefaultBackend() { absl::StatusOr Backend::BorrowStream( int device_ordinal, se::StreamPriority priority) { - TF_ASSIGN_OR_RETURN(auto executor, stream_executor(device_ordinal)); + ASSIGN_OR_RETURN(auto executor, stream_executor(device_ordinal)); return BorrowStream(executor, priority); } @@ -224,7 +224,7 @@ absl::StatusOr Backend::BorrowStream( absl::StatusOr> Backend::BorrowStreams( int device_ordinal, int num_streams, se::StreamPriority priority) { absl::MutexLock l(mu_); - TF_ASSIGN_OR_RETURN(auto executor, stream_executor(device_ordinal)); + ASSIGN_OR_RETURN(auto executor, stream_executor(device_ordinal)); if (!stream_pools_.contains(executor)) { stream_pools_.emplace(executor, std::make_unique(executor)); } @@ -312,10 +312,10 @@ absl::StatusOr Backend::devices_equivalent(int device_ordinal_a, // bit crude but works for GPUs which is the important case where we compile // an executable for one GPU and want to know if it will run (well) on // another. - TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor_a, - stream_executor(device_ordinal_a)); - TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor_b, - stream_executor(device_ordinal_b)); + ASSIGN_OR_RETURN(se::StreamExecutor * executor_a, + stream_executor(device_ordinal_a)); + ASSIGN_OR_RETURN(se::StreamExecutor * executor_b, + stream_executor(device_ordinal_b)); return (executor_a->GetDeviceDescription().name() == executor_b->GetDeviceDescription().name()); } diff --git a/third_party/xla/xla/service/batchnorm_expander.cc b/third_party/xla/xla/service/batchnorm_expander.cc index 2bc3c3630fcb27..84274056346d46 100644 --- a/third_party/xla/xla/service/batchnorm_expander.cc +++ b/third_party/xla/xla/service/batchnorm_expander.cc @@ -28,6 +28,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -179,11 +180,11 @@ absl::Status BatchNormExpanderVisitor::HandleBatchNormTraining( const Shape feature_shape = scale->shape(); auto zero_literal = LiteralUtil::CreateR0(0.0f); - TF_ASSIGN_OR_RETURN(zero_literal, zero_literal.Convert(ptype)); + ASSIGN_OR_RETURN(zero_literal, zero_literal.Convert(ptype)); auto zero = add(HloInstruction::CreateConstant(std::move(zero_literal))); auto epsilon_literal = LiteralUtil::CreateR0(batch_norm->epsilon()); - TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal.Convert(ptype)); + ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal.Convert(ptype)); Shape scalar_broadcast_shape = ShapeUtil::MakeStaticShape(operand_shape); auto epsilon = add(HloInstruction::CreateBroadcast( scalar_broadcast_shape, @@ -339,7 +340,7 @@ absl::Status BatchNormExpanderVisitor::HandleBatchNormInference( }; auto epsilon_literal = LiteralUtil::CreateR0(batch_norm->epsilon()); - TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal.Convert(ptype)); + ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal.Convert(ptype)); auto epsilon = add(HloInstruction::CreateBroadcast( scalar_broadcast_shape, add(HloInstruction::CreateConstant(std::move(epsilon_literal))), {})); @@ -433,11 +434,11 @@ absl::Status BatchNormExpanderVisitor::HandleBatchNormGrad( add(DynamicElementCountPerFeature(activation, feature_index, add)); auto zero_literal = LiteralUtil::CreateR0(0.0f); - TF_ASSIGN_OR_RETURN(zero_literal, zero_literal.Convert(ptype)); + ASSIGN_OR_RETURN(zero_literal, zero_literal.Convert(ptype)); auto zero = add(HloInstruction::CreateConstant(std::move(zero_literal))); auto epsilon_literal = LiteralUtil::CreateR0(batch_norm->epsilon()); - TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal.Convert(ptype)); + ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal.Convert(ptype)); auto epsilon_scalar = add(HloInstruction::CreateConstant(std::move(epsilon_literal))); auto epsilon_activation = add(HloInstruction::CreateBroadcast( diff --git a/third_party/xla/xla/service/buffer_assignment.cc b/third_party/xla/xla/service/buffer_assignment.cc index e2bb9adbae8d15..dbfb60114e49e3 100644 --- a/third_party/xla/xla/service/buffer_assignment.cc +++ b/third_party/xla/xla/service/buffer_assignment.cc @@ -45,6 +45,7 @@ limitations under the License. #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/analysis/alias_info.h" #include "xla/hlo/analysis/hlo_alias_analysis.h" #include "xla/hlo/analysis/hlo_dataflow_analysis.h" @@ -735,9 +736,9 @@ BufferAllocation* BufferAssignment::NewEmptyAllocation( absl::StatusOr BufferAssignment::NewAllocation( const HloBuffer& buffer, int64_t size) { - TF_ASSIGN_OR_RETURN(auto color, buffer.color()); + ASSIGN_OR_RETURN(auto color, buffer.color()); BufferAllocation* allocation = NewEmptyAllocation(size, color); - TF_RETURN_IF_ERROR(AddAssignment(allocation, buffer, /*offset=*/0, size)); + RETURN_IF_ERROR(AddAssignment(allocation, buffer, /*offset=*/0, size)); allocation->peak_buffers_.push_back(buffer.values()[0]); return allocation; } @@ -752,7 +753,7 @@ absl::Status BufferAssignment::AddAssignment(BufferAllocation* allocation, for (const HloValue* buffer_value : buffer.values()) { CHECK(!allocation_index_for_value_.contains(buffer_value)) << "BufferValue " << buffer_value << " already has an allocation."; - TF_RETURN_IF_ERROR(allocation->AddAssignment(*buffer_value, offset, size)); + RETURN_IF_ERROR(allocation->AddAssignment(*buffer_value, offset, size)); allocation_index_for_value_[buffer_value] = allocation->index(); } @@ -767,7 +768,7 @@ absl::Status BufferAssignment::AddAssignment(BufferAllocation* allocation, absl::Status BufferAssignment::AddAssignment(BufferAllocation* allocation, const HloValue& value, int64_t offset, int64_t size) { - TF_RETURN_IF_ERROR(allocation->AddAssignment(value, offset, size)); + RETURN_IF_ERROR(allocation->AddAssignment(value, offset, size)); allocation_index_for_value_[&value] = allocation->index(); const HloValue& hlo_value = *CHECK_NOTNULL(dynamic_cast(&value)); @@ -853,7 +854,7 @@ absl::Status BufferAssignment::CombineTempAllocations( const HloValue* value = buffer_offset_size.first; const int64_t offset = buffer_offset_size.second.offset; const int64_t size = buffer_offset_size.second.size; - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( combined_allocation->AddAssignment(*value, base + offset, size)); } if (!temp_allocation.HeapTraces().empty()) { @@ -946,8 +947,8 @@ absl::StatusOr BufferAssignment::ComputeTotalFragmentationBytes( } } if (schedule_complete) { - TF_RETURN_IF_ERROR(schedule.Verify()); - TF_ASSIGN_OR_RETURN( + RETURN_IF_ERROR(schedule.Verify()); + ASSIGN_OR_RETURN( const int64_t min_size, HeapSimulator::MinimumMemoryForModule(schedule, alias_analysis(), alias_info, &buffer_size_)); @@ -1408,8 +1409,8 @@ absl::StatusOr> BufferAssignment::FromProto( const BufferAssignmentProto& proto, const HloModule* module, BufferValue::SizeFunction buffer_size, const AliasInfo* alias_info) { // Create alias and dataflow analysis. - TF_ASSIGN_OR_RETURN(std::unique_ptr alias_analysis, - HloAliasAnalysis::Run(module, alias_info)); + ASSIGN_OR_RETURN(std::unique_ptr alias_analysis, + HloAliasAnalysis::Run(module, alias_info)); // Build a map from a unique_id to corresponding HloInstruction in the module. auto id_to_hlo_instruction = BuildIdToHloInstructionMap(module); @@ -1417,7 +1418,7 @@ absl::StatusOr> BufferAssignment::FromProto( // Build a map from logical buffer id in the proto to hlo value in the // existing dataflow analysis. absl::flat_hash_map id_to_logical_buffer; - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( id_to_logical_buffer, BuildIdToLogicalBufferMap(proto, id_to_hlo_instruction, alias_analysis->dataflow_analysis())); @@ -1463,7 +1464,7 @@ absl::StatusOr> BufferAssignment::FromProto( for (const auto& assignee : alloc_proto.assigned()) { HloValue::Id logical_buffer_id = assignee.logical_buffer_id(); const auto& buffer_val = id_to_logical_buffer[logical_buffer_id]; - TF_RETURN_IF_ERROR(buffer_assignment->AddAssignment( + RETURN_IF_ERROR(buffer_assignment->AddAssignment( allocation, *buffer_val, assignee.offset(), assignee.size())); } @@ -1584,7 +1585,7 @@ absl::StatusOr BufferAssigner::MaybeAssignBuffer( << assignment->HloBufferSize(hlo_buffer) << " to allocation: " << *allocation; - TF_ASSIGN_OR_RETURN(auto buffer_color, hlo_buffer.color()); + ASSIGN_OR_RETURN(auto buffer_color, hlo_buffer.color()); if (buffer_color != allocation->color()) { VLOG(4) << "Can't assign: buffer has color " << buffer_color << " and allocation has color " << allocation->color() << "."; @@ -1722,7 +1723,7 @@ absl::StatusOr BufferAssigner::MaybeAssignBuffer( return false; } - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( assignment->AddAssignment(allocation, hlo_buffer, /*offset=*/0, assignment->HloBufferSize(hlo_buffer))); return true; @@ -1739,9 +1740,8 @@ absl::Status BufferAssigner::AssignSingleHloBuffer( for (const HloValue* value : hlo_buffer->values()) { if (value->instruction()->opcode() == HloOpcode::kConstant) { if (opts_.allocate_buffers_for_constants) { - TF_ASSIGN_OR_RETURN( - BufferAllocation * allocation, - assignment->NewAllocation(*hlo_buffer, buffer_size)); + ASSIGN_OR_RETURN(BufferAllocation * allocation, + assignment->NewAllocation(*hlo_buffer, buffer_size)); allocation->set_constant(true); VLOG(3) << "New allocation #" << allocation->index() << " for constant " << *hlo_buffer << " value ptr: " << value; @@ -1763,8 +1763,8 @@ absl::Status BufferAssigner::AssignSingleHloBuffer( // allocation and sets its parameter number. Parameters of non-entry // computations do not need special allocations because they live inside // callers. - TF_ASSIGN_OR_RETURN(BufferAllocation * allocation, - assignment->NewAllocation(*hlo_buffer, buffer_size)); + ASSIGN_OR_RETURN(BufferAllocation * allocation, + assignment->NewAllocation(*hlo_buffer, buffer_size)); allocation->set_entry_computation_parameter( instruction->parameter_number(), value->index(), parameter_has_alias); @@ -1778,8 +1778,8 @@ absl::Status BufferAssigner::AssignSingleHloBuffer( } if (is_thread_local) { - TF_ASSIGN_OR_RETURN(BufferAllocation * allocation, - assignment->NewAllocation(*hlo_buffer, buffer_size)); + ASSIGN_OR_RETURN(BufferAllocation * allocation, + assignment->NewAllocation(*hlo_buffer, buffer_size)); allocation->set_is_thread_local(true); VLOG(3) << "New allocation #" << allocation->index() << " for thread-local: " << *hlo_buffer; @@ -1788,8 +1788,8 @@ absl::Status BufferAssigner::AssignSingleHloBuffer( for (const HloValue* value : hlo_buffer->values()) { if (value->shape().IsTuple()) { - TF_ASSIGN_OR_RETURN(BufferAllocation * allocation, - assignment->NewAllocation(*hlo_buffer, buffer_size)); + ASSIGN_OR_RETURN(BufferAllocation * allocation, + assignment->NewAllocation(*hlo_buffer, buffer_size)); allocation->set_is_tuple(true); VLOG(3) << "New allocation #" << allocation->index() << " for tuple-shaped buffer: " << *hlo_buffer; @@ -1803,7 +1803,7 @@ absl::Status BufferAssigner::AssignSingleHloBuffer( assignment->GetAllSlices(operand, /*index=*/{})) { BufferAllocation* allocation = assignment->GetMutableAllocation(operand_slice.index()); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( bool buffer_assigned, MaybeAssignBuffer(allocation, *hlo_buffer, assignment)); if (buffer_assigned) { @@ -1822,8 +1822,8 @@ absl::Status BufferAssigner::AssignSingleHloBuffer( allocation_index >= 0; allocation_index--) { BufferAllocation* allocation = assignment->GetMutableAllocation( allocation_indices->at(allocation_index)); - TF_ASSIGN_OR_RETURN(bool buffer_assigned, - MaybeAssignBuffer(allocation, *hlo_buffer, assignment)); + ASSIGN_OR_RETURN(bool buffer_assigned, + MaybeAssignBuffer(allocation, *hlo_buffer, assignment)); if (buffer_assigned) { VLOG(3) << "Reusing allocation #" << allocation->index() << " for: " << *hlo_buffer; @@ -1861,8 +1861,8 @@ absl::Status BufferAssigner::AssignSingleHloBuffer( } if (!assignment->HasAllocation(*hlo_buffer)) { - TF_ASSIGN_OR_RETURN(BufferAllocation * allocation, - assignment->NewAllocation(*hlo_buffer, buffer_size)); + ASSIGN_OR_RETURN(BufferAllocation * allocation, + assignment->NewAllocation(*hlo_buffer, buffer_size)); allocation_indices->push_back(allocation->index()); VLOG(3) << "New allocation #" << allocation->index() << " for: " << *hlo_buffer; @@ -1887,7 +1887,7 @@ absl::Status BufferAssigner::AssignBuffersForComputations( // First assign the preset allocations. absl::flat_hash_set preset_assigned_buffers; - TF_RETURN_IF_ERROR(AssignPresetBuffers(&preset_assigned_buffers, assignment)); + RETURN_IF_ERROR(AssignPresetBuffers(&preset_assigned_buffers, assignment)); const HloAliasAnalysis& alias_analysis = assignment->alias_analysis(); @@ -1914,7 +1914,7 @@ absl::Status BufferAssigner::AssignBuffersForComputations( std::vector reverse_post_order_computations; std::unique_ptr call_graph = CallGraph::Build(computations[0]->parent()); - TF_RETURN_IF_ERROR(call_graph->VisitNodes([&](const CallGraphNode& node) { + RETURN_IF_ERROR(call_graph->VisitNodes([&](const CallGraphNode& node) { if (computations_set.contains(node.computation())) { reverse_post_order_computations.push_back(node.computation()); } @@ -1979,9 +1979,9 @@ absl::Status BufferAssigner::AssignBuffersForComputations( for (const HloBuffer* buffer : sorted_buffers) { VLOG(3) << "================================================="; VLOG(3) << "Assigning buffer for " << *buffer; - TF_RETURN_IF_ERROR(AssignSingleHloBuffer(buffer, is_thread_local, - buffers_to_assign_sequentially, - &allocation_indices, assignment)); + RETURN_IF_ERROR(AssignSingleHloBuffer(buffer, is_thread_local, + buffers_to_assign_sequentially, + &allocation_indices, assignment)); } return absl::OkStatus(); } @@ -2056,7 +2056,7 @@ absl::Status BufferAssigner::AssignPresetBuffers( CHECK(preset_allocations_iter != preset_allocations.end()) << "No preset value allocation for color " << value->color() << " for " << value->ToShortString() << " found."; - TF_RETURN_IF_ERROR(preset_allocations_iter->second->AddAssignment( + RETURN_IF_ERROR(preset_allocations_iter->second->AddAssignment( *value, chunk.offset, chunk.size)); } @@ -2257,25 +2257,24 @@ absl::Status BufferAssigner::AssignBuffersWithSequentialOrdering( const HloInstructionSequence* instruction_sequence = hlo_ordering.SequentialOrder(*private_stack_computation); HeapSimulator::Result result; - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( result, HeapSimulator::Run( get_heap_algorithm(alignment, color), *private_stack_computation, *instruction_sequence, assignment->alias_analysis(), alias_info_, &assignment->buffer_size_, &schedule, options)); - TF_RETURN_IF_ERROR(AssignBuffersFromHeapSimulator( + RETURN_IF_ERROR(AssignBuffersFromHeapSimulator( result, assignment, color, isolation_options)); } } else { options.buffers_to_assign = &color_map[color]; HeapSimulator::Result result; - TF_ASSIGN_OR_RETURN( - result, - HeapSimulator::Run(get_heap_algorithm(alignment, color), - assignment->module(), schedule, - assignment->alias_analysis(), alias_info_, - &assignment->buffer_size_, options)); - TF_RETURN_IF_ERROR(AssignBuffersFromHeapSimulator( + ASSIGN_OR_RETURN(result, HeapSimulator::Run( + get_heap_algorithm(alignment, color), + assignment->module(), schedule, + assignment->alias_analysis(), alias_info_, + &assignment->buffer_size_, options)); + RETURN_IF_ERROR(AssignBuffersFromHeapSimulator( result, assignment, color, isolation_options)); } } @@ -2304,12 +2303,12 @@ absl::Status BufferAssigner::AssignBuffersWithSequentialOrdering( HeapSimulator::Options options; options.buffers_to_assign = &color_map[color]; HeapSimulator::Result result; - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( result, HeapSimulator::Run( get_heap_algorithm(alignment, color), *computation, *instruction_sequence, assignment->alias_analysis(), alias_info_, &assignment->buffer_size_, options)); - TF_RETURN_IF_ERROR(AssignBuffersFromHeapSimulator( + RETURN_IF_ERROR(AssignBuffersFromHeapSimulator( result, assignment, color, isolation_options)); } } @@ -2526,8 +2525,8 @@ absl::Status BufferAssigner::AssignBuffersFromHeapSimulator( assignment->NewEmptyAllocation(heap_result.heap_size, color); for (const auto& [value, chunk] : heap_result.chunk_map) { - TF_RETURN_IF_ERROR(assignment->AddAssignment(allocation, *value, - chunk.offset, chunk.size)); + RETURN_IF_ERROR(assignment->AddAssignment(allocation, *value, + chunk.offset, chunk.size)); } allocation->peak_buffers_ = ComputePeakMemoryLogicalBuffers(*allocation, result.debug_trace); @@ -2544,8 +2543,8 @@ BufferAssigner::CreateAssignment( const HloModule* module, std::unique_ptr hlo_ordering, BufferValue::SizeFunction buffer_size, LogicalBuffer::AlignmentFunction color_alignment) { - TF_ASSIGN_OR_RETURN(std::unique_ptr alias_analysis, - HloAliasAnalysis::Run(module, alias_info_)); + ASSIGN_OR_RETURN(std::unique_ptr alias_analysis, + HloAliasAnalysis::Run(module, alias_info_)); // Set up a schedule for each computation. HloSchedule schedule(module); @@ -2558,9 +2557,9 @@ BufferAssigner::CreateAssignment( } } - TF_ASSIGN_OR_RETURN(std::unique_ptr hlo_live_range, - HloLiveRange::Run(schedule, *alias_analysis, - module->entry_computation(), true)); + ASSIGN_OR_RETURN(std::unique_ptr hlo_live_range, + HloLiveRange::Run(schedule, *alias_analysis, + module->entry_computation(), true)); VLOG(1) << "Assigning buffers to module " << module->name(); XLA_VLOG_LINES(3, module->ToString()); @@ -2576,7 +2575,7 @@ BufferAssigner::CreateAssignment( std::move(color_alignment), std::move(alias_analysis), std::move(hlo_live_range))); - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( opts_.colorer(&assignment->alias_analysis(), assignment->hlo_ordering())); VLOG(3) << "After coloring:"; XLA_VLOG_LINES(3, @@ -2584,7 +2583,7 @@ BufferAssigner::CreateAssignment( std::vector thread_local_computations; std::vector global_computations; - TF_RETURN_IF_ERROR(GatherComputationsByAllocationType( + RETURN_IF_ERROR(GatherComputationsByAllocationType( module, &thread_local_computations, &global_computations)); // First assign buffers for global computations. Temporary buffers for @@ -2592,10 +2591,10 @@ BufferAssigner::CreateAssignment( // 'buffers_to_assign_sequentially'. flat_hash_map> buffers_to_assign_sequentially; - TF_RETURN_IF_ERROR(AssignBuffersForComputations( - global_computations, - /*is_thread_local=*/false, &buffers_to_assign_sequentially, - assignment.get())); + RETURN_IF_ERROR(AssignBuffersForComputations(global_computations, + /*is_thread_local=*/false, + &buffers_to_assign_sequentially, + assignment.get())); // Assign buffers with sequential ordering, if any. If all global // computations are sequential, we can run heap simulation on the whole // module, which reduces memory usage. @@ -2608,7 +2607,7 @@ BufferAssigner::CreateAssignment( VLOG(2) << "Multiheap per heap size limit: " << multiheap_size_constraint_per_heap; const PrivateStacks private_stacks; - TF_RETURN_IF_ERROR(AssignBuffersWithSequentialOrdering( + RETURN_IF_ERROR(AssignBuffersWithSequentialOrdering( buffers_to_assign_sequentially, run_whole_module_heap_simulation, assignment.get(), opts_.buffer_assignment_algorithm, opts_.private_stacks ? *opts_.private_stacks : private_stacks, @@ -2626,7 +2625,7 @@ BufferAssigner::CreateAssignment( thread_local_computations_no_fusion.push_back(computation); } - TF_RETURN_IF_ERROR(AssignBuffersForComputations( + RETURN_IF_ERROR(AssignBuffersForComputations( thread_local_computations_no_fusion, /*is_thread_local=*/true, /*buffers_to_assign_sequentially=*/nullptr, assignment.get())); @@ -2656,8 +2655,8 @@ BufferAssigner::CreateAssignment( } } - TF_RETURN_IF_ERROR(assignment->CombineTempAllocations( - private_stack_colors, opts_.temp_buffer_color)); + RETURN_IF_ERROR(assignment->CombineTempAllocations(private_stack_colors, + opts_.temp_buffer_color)); XLA_VLOG_LINES(2, assignment->ToString()); assignment->ComputeSummaryStats(); @@ -2756,7 +2755,7 @@ ComputeLogicalBufferUnpaddedSizes( } } - TF_ASSIGN_OR_RETURN(Shape subshape, Shape::FromProto(*subshape_proto)); + ASSIGN_OR_RETURN(Shape subshape, Shape::FromProto(*subshape_proto)); // Same logic as tensorflow::profiler::ShapeUnpaddedSize. LayoutUtil::SetToDefaultLayout(&subshape); @@ -2851,8 +2850,8 @@ absl::StatusOr ComputePeakMemoryImpl( absl::StatusOr ComputePeakMemorySizes( const BufferAssignmentProto& proto, const HloModuleProto& hlo) { - TF_ASSIGN_OR_RETURN(auto logical_buffer_unpadded_sizes, - ComputeLogicalBufferUnpaddedSizes(hlo, proto)); + ASSIGN_OR_RETURN(auto logical_buffer_unpadded_sizes, + ComputeLogicalBufferUnpaddedSizes(hlo, proto)); return ComputePeakMemoryImpl(proto, logical_buffer_unpadded_sizes); } diff --git a/third_party/xla/xla/service/call_graph.cc b/third_party/xla/xla/service/call_graph.cc index 2ee5b15a2aa538..a7ed1b1e34cbe8 100644 --- a/third_party/xla/xla/service/call_graph.cc +++ b/third_party/xla/xla/service/call_graph.cc @@ -32,6 +32,7 @@ limitations under the License. #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" #include "absl/strings/string_view.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" @@ -336,7 +337,7 @@ absl::Status CallGraph::VisitNodesInternal( } for (const HloComputation* computation : node.callees()) { - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( VisitNodesInternal(visitor_func, GetNode(computation), visited)); } @@ -354,13 +355,13 @@ absl::StatusOr CallGraph::VisitNodesInternal( bool changed = false; for (const HloComputation* computation : node.callees()) { - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( bool node_changed, VisitNodesInternal(visitor_func, GetNode(computation), visited)); changed |= node_changed; } - TF_ASSIGN_OR_RETURN(bool node_changed, visitor_func(node)); + ASSIGN_OR_RETURN(bool node_changed, visitor_func(node)); changed |= node_changed; return changed; } @@ -372,12 +373,12 @@ absl::Status CallGraph::VisitNodes(VisitorFunction visitor_func, // Traverse from all roots in the call graph. for (const CallGraphNode& node : nodes()) { if (node.callers().empty()) { - TF_RETURN_IF_ERROR(VisitNodesInternal(visitor_func, node, &visited)); + RETURN_IF_ERROR(VisitNodesInternal(visitor_func, node, &visited)); } } } else { // Traverse only from the entry computation. - TF_RETURN_IF_ERROR(VisitNodesInternal( + RETURN_IF_ERROR(VisitNodesInternal( visitor_func, GetNode(module_->entry_computation()), &visited)); } @@ -392,14 +393,14 @@ absl::StatusOr CallGraph::VisitNodesWithReturn( // Traverse from all roots in the call graph. for (const CallGraphNode& node : nodes()) { if (node.callers().empty()) { - TF_ASSIGN_OR_RETURN(bool node_changed, - VisitNodesInternal(visitor_func, node, &visited)); + ASSIGN_OR_RETURN(bool node_changed, + VisitNodesInternal(visitor_func, node, &visited)); changed |= node_changed; } } } else { // Traverse only from the entry computation. - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( changed, VisitNodesInternal(visitor_func, GetNode(module_->entry_computation()), &visited)); diff --git a/third_party/xla/xla/service/call_inliner.cc b/third_party/xla/xla/service/call_inliner.cc index 6bdd4fb29f292e..6a2feed5cdaac7 100644 --- a/third_party/xla/xla/service/call_inliner.cc +++ b/third_party/xla/xla/service/call_inliner.cc @@ -31,6 +31,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -82,7 +83,7 @@ class SubcomputationInsertionVisitor : public DfsHloVisitorWithDefault { absl::Status DefaultAction(HloInstruction* hlo) override { std::vector new_operands; for (HloInstruction* operand : hlo->operands()) { - TF_ASSIGN_OR_RETURN(HloInstruction * new_operand, Resolve(operand)); + ASSIGN_OR_RETURN(HloInstruction * new_operand, Resolve(operand)); new_operands.push_back(new_operand); } VLOG(1) << "Cloning HLO and adding to caller: " << hlo->ToString(); @@ -93,22 +94,22 @@ class SubcomputationInsertionVisitor : public DfsHloVisitorWithDefault { PropagateCallMetadata::PropagateMetadataToInstruction( new_hlo_pointer, call_op_name_, call_stack_frame_id_); } - TF_RETURN_IF_ERROR(NoteMapping(hlo, new_hlo_pointer)); + RETURN_IF_ERROR(NoteMapping(hlo, new_hlo_pointer)); PropagateOriginalValue(new_hlo_pointer, hlo); // Account for control edges. for (HloInstruction* control_predecessor : hlo->control_predecessors()) { - TF_ASSIGN_OR_RETURN(HloInstruction * new_control_predecessor, - Resolve(control_predecessor)); - TF_RETURN_IF_ERROR( + ASSIGN_OR_RETURN(HloInstruction * new_control_predecessor, + Resolve(control_predecessor)); + RETURN_IF_ERROR( new_control_predecessor->AddControlDependencyTo(new_hlo_pointer)); } // The newly inlined instructions should honor the control predecessors of // the previous call instruction. for (HloInstruction* control_predecessor : call_->control_predecessors()) { - TF_RETURN_IF_ERROR(control_predecessor->AddControlDependencyTo( + RETURN_IF_ERROR(control_predecessor->AddControlDependencyTo( /*instruction=*/new_hlo_pointer)); } @@ -119,7 +120,7 @@ class SubcomputationInsertionVisitor : public DfsHloVisitorWithDefault { // from the subcomputation parameter node to the call operands in the caller // computation. absl::Status HandleParameter(HloInstruction* parameter) override { - TF_RETURN_IF_ERROR(NoteMapping( + RETURN_IF_ERROR(NoteMapping( parameter, call_->mutable_operand(parameter->parameter_number()))); return absl::OkStatus(); } @@ -127,7 +128,7 @@ class SubcomputationInsertionVisitor : public DfsHloVisitorWithDefault { // Wires the consumers of the call to instead point at the newly created // root, replacing the call operation in the caller computation. absl::Status FinishVisit(HloInstruction* root) override { - TF_ASSIGN_OR_RETURN(HloInstruction * new_root, Resolve(root)); + ASSIGN_OR_RETURN(HloInstruction * new_root, Resolve(root)); VLOG(1) << "Replacing all uses of " << call_->ToString() << " with new root " << new_root->ToString(); auto original_value = new_root->original_value(); @@ -302,7 +303,7 @@ CallInliner::Inline(HloInstruction* call, bool propagate_metadata) { SubcomputationInsertionVisitor visitor( call, call->metadata().op_name(), StackFrameId{call->metadata().stack_frame_id()}, propagate_metadata); - TF_RETURN_IF_ERROR(callee->Accept(&visitor)); + RETURN_IF_ERROR(callee->Accept(&visitor)); return visitor.ConsumeInstructionMap(); } @@ -391,8 +392,8 @@ absl::StatusOr CallInliner::InlineAndLegalize( // The caller instruction will get removed after inlining. Record the // callee computation beforehand, so we can find its schedule. HloComputation* callee = instruction->to_apply(); - TF_ASSIGN_OR_RETURN(InlinedInstructionMap inline_map_cur_call, - Inline(instruction, propagate_metadata_)); + ASSIGN_OR_RETURN(InlinedInstructionMap inline_map_cur_call, + Inline(instruction, propagate_metadata_)); if (module->has_schedule()) { for (HloInstruction* inlined_instruction : module->schedule().sequence(callee).instructions()) { @@ -407,7 +408,7 @@ absl::StatusOr CallInliner::InlineAndLegalize( if (update_domain_) { HloDomainIsolator isolator([]() { return ShardingDomainCreator{}; }); for (const auto& [call_inst, inlined_inst] : inline_map_cur_call) { - TF_RETURN_IF_ERROR(isolator.UpdateDomains(inlined_inst).status()); + RETURN_IF_ERROR(isolator.UpdateDomains(inlined_inst).status()); } } if (inline_map.has_value()) { @@ -433,7 +434,7 @@ absl::StatusOr CallInliner::RunWithInlineMap( // Because call graph nodes are visited in post-order (callees before callers) // we'll always inline kCalls into their callers in the appropriate order. - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( bool did_mutate, call_graph->VisitNodesWithReturn( [&](const CallGraphNode& node) -> absl::StatusOr { @@ -459,9 +460,9 @@ absl::StatusOr CallInliner::RunWithInlineMap( // were send/recv instructions, which the module group verifier will flag as // error finding the same channel ID used for multiple send/recv // instructions. - TF_RETURN_IF_ERROR(HloDCE().Run(module, execution_threads).status()); + RETURN_IF_ERROR(HloDCE().Run(module, execution_threads).status()); if (module->has_schedule()) { - TF_RETURN_IF_ERROR(module->schedule().Update(execution_threads)); + RETURN_IF_ERROR(module->schedule().Update(execution_threads)); } } return did_mutate; @@ -504,7 +505,7 @@ absl::StatusOr GetInlinedModule(const HloModule* module) { module->CloneWithContext("inline", module->config()); CallInliner::InlinedInstructionMap clone_inlined_map; CallInliner inliner; - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( inliner.RunWithInlineMap(cloned_module.get(), &clone_inlined_map, {}) .status()); return InlinedModule{std::move(cloned_module), std::move(clone_context), diff --git a/third_party/xla/xla/service/change_op_data_type.cc b/third_party/xla/xla/service/change_op_data_type.cc index 3bcd57d7a8038e..0f7875039fb5b7 100644 --- a/third_party/xla/xla/service/change_op_data_type.cc +++ b/third_party/xla/xla/service/change_op_data_type.cc @@ -22,6 +22,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" @@ -90,7 +91,7 @@ absl::StatusOr ChangeOpDataType::RunImpl( if (new_instr->shape().element_type() != instr->shape().element_type()) { new_instr = MakeConvertToHlo(new_instr, instr->shape().element_type()); } - TF_RETURN_IF_ERROR(comp->ReplaceInstruction(instr, new_instr)); + RETURN_IF_ERROR(comp->ReplaceInstruction(instr, new_instr)); changed = true; } } diff --git a/third_party/xla/xla/service/collective_combiner_utils.h b/third_party/xla/xla/service/collective_combiner_utils.h index 2e4bde02ff7fc9..6b975449435b2a 100644 --- a/third_party/xla/xla/service/collective_combiner_utils.h +++ b/third_party/xla/xla/service/collective_combiner_utils.h @@ -28,6 +28,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/analysis/hlo_reachability.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -183,7 +184,7 @@ absl::StatusOr CombineInstructionsByKey( } if (to_combine.size() > 1) { - TF_RETURN_IF_ERROR(combine_fn(to_combine)); + RETURN_IF_ERROR(combine_fn(to_combine)); changed = true; } } diff --git a/third_party/xla/xla/service/collective_ops_utils.cc b/third_party/xla/xla/service/collective_ops_utils.cc index a241a5057e8bcc..6a9e681e71b704 100644 --- a/third_party/xla/xla/service/collective_ops_utils.cc +++ b/third_party/xla/xla/service/collective_ops_utils.cc @@ -31,6 +31,7 @@ limitations under the License. #include "absl/strings/str_join.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/core/collectives/reduction_kind.h" #include "xla/hlo/ir/collective_op_group_mode.h" #include "xla/hlo/ir/hlo_casting_utils.h" @@ -361,8 +362,8 @@ GetParticipatingDevicesGroups(const HloInstruction* collective) { CHECK(collective->GetModule()->config().has_static_device_assignment()); const DeviceAssignment& device_assignment = collective->GetModule()->config().static_device_assignment(); - TF_ASSIGN_OR_RETURN(CollectiveOpGroupMode mode, - GetCollectiveOpGroupMode(collective)); + ASSIGN_OR_RETURN(CollectiveOpGroupMode mode, + GetCollectiveOpGroupMode(collective)); return GetParticipatingDevicesGroups(device_assignment, collective->replica_groups(), mode); } @@ -456,9 +457,8 @@ GetParticipatingFlattenedIdGroups( absl::StatusOr> GetParticipatingFlattenedIdGroups(const HloInstruction* hlo, const DeviceAssignment& device_assignment) { - TF_ASSIGN_OR_RETURN(CollectiveOpGroupMode mode, - GetCollectiveOpGroupMode(hlo)); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN(CollectiveOpGroupMode mode, GetCollectiveOpGroupMode(hlo)); + ASSIGN_OR_RETURN( std::unique_ptr collective_device_list, GetParticipatingFlattenedIdGroups(device_assignment, *hlo->device_list(), mode)); @@ -472,8 +472,8 @@ absl::StatusOr> GetParticipatingDevices( int replica_count = device_assignment.replica_count(); int partition_count = device_assignment.computation_count(); - TF_ASSIGN_OR_RETURN(const DeviceAssignment::LogicalID logical_id, - device_assignment.LogicalIdForDevice(device_id)); + ASSIGN_OR_RETURN(const DeviceAssignment::LogicalID logical_id, + device_assignment.LogicalIdForDevice(device_id)); int current_replica_id = logical_id.replica_id; int current_partition_id = logical_id.computation_id; TF_RET_CHECK(0 <= current_replica_id && current_replica_id < replica_count) @@ -488,9 +488,9 @@ absl::StatusOr> GetParticipatingDevices( // This is a cross replica operation. replica group contains replica id. // use current replica id to find the set of participating replicas. If // replica groups are empty, assume a group with all replicas. - TF_ASSIGN_OR_RETURN(std::vector participating_replicas, - GetParticipatingIDs(group_mode, current_replica_id, - replica_count, replica_groups)); + ASSIGN_OR_RETURN(std::vector participating_replicas, + GetParticipatingIDs(group_mode, current_replica_id, + replica_count, replica_groups)); // The set of participating devices is the replicas from the current // partition. @@ -507,9 +507,9 @@ absl::StatusOr> GetParticipatingDevices( case CollectiveOpGroupMode::COLLECTIVE_OP_GROUP_MODE_CROSS_PARTITION: { // replica_groups contain partition_id, group contains all partitions for // the current replica. - TF_ASSIGN_OR_RETURN(std::vector participating_partitions, - GetParticipatingIDs(group_mode, current_partition_id, - partition_count, replica_groups)); + ASSIGN_OR_RETURN(std::vector participating_partitions, + GetParticipatingIDs(group_mode, current_partition_id, + partition_count, replica_groups)); participants.reserve(participating_partitions.size()); for (int partition_id : participating_partitions) { TF_RET_CHECK(0 <= partition_id && partition_id < partition_count) @@ -524,9 +524,9 @@ absl::StatusOr> GetParticipatingDevices( COLLECTIVE_OP_GROUP_MODE_CROSS_REPLICA_AND_PARTITION: { // replica_groups contain replica_ids. Group contains replicas for all // partitions. - TF_ASSIGN_OR_RETURN(std::vector participating_replicas, - GetParticipatingIDs(group_mode, current_replica_id, - replica_count, replica_groups)); + ASSIGN_OR_RETURN(std::vector participating_replicas, + GetParticipatingIDs(group_mode, current_replica_id, + replica_count, replica_groups)); participants.reserve(participating_replicas.size() * partition_count); for (int replica_id : participating_replicas) { TF_RET_CHECK(0 <= replica_id && replica_id < replica_count) @@ -550,7 +550,7 @@ absl::StatusOr> GetParticipatingDevices( // Find participants based on flattened id. replica_groups cannot be empty // so no need to pass in total_participant_count. - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( std::vector participating_flattened_ids, GetParticipatingIDs(group_mode, current_flattened_id, /*total_participant_count=*/std::nullopt, @@ -654,12 +654,12 @@ GetReplicaGroupCountAndSize(const HloInstruction* hlo) { return std::make_pair(device_list->num_replica_groups(), device_list->num_devices_per_group()); } - TF_ASSIGN_OR_RETURN(CollectiveOpGroupMode group_mode, - GetCollectiveOpGroupMode(hlo)); - TF_ASSIGN_OR_RETURN(std::vector participant_counts, - GetPariticipantCountsForReplicaGroups( - config.replica_count(), config.num_partitions(), - device_list->replica_groups(), group_mode)); + ASSIGN_OR_RETURN(CollectiveOpGroupMode group_mode, + GetCollectiveOpGroupMode(hlo)); + ASSIGN_OR_RETURN(std::vector participant_counts, + GetPariticipantCountsForReplicaGroups( + config.replica_count(), config.num_partitions(), + device_list->replica_groups(), group_mode)); int64_t replica_group_size = participant_counts[0]; for (int64_t participant_count : participant_counts) { if (participant_count != replica_group_size) { diff --git a/third_party/xla/xla/service/collective_ops_utils_test.cc b/third_party/xla/xla/service/collective_ops_utils_test.cc index c5eaeab1b3282e..6b3c53235b330f 100644 --- a/third_party/xla/xla/service/collective_ops_utils_test.cc +++ b/third_party/xla/xla/service/collective_ops_utils_test.cc @@ -31,6 +31,7 @@ limitations under the License. #include "absl/log/log.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/array2d.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -656,12 +657,12 @@ class GetCollectOpGroupModeTestForInstruction absl::StatusOr> CreateMaxComputation() { Shape scalar = ShapeUtil::MakeScalarShape(F32); auto builder_max = HloComputation::Builder("max"); - TF_ASSIGN_OR_RETURN(HloInstruction * a, - builder_max.AddParameter( - HloInstruction::CreateParameter(0, scalar, "a"))); - TF_ASSIGN_OR_RETURN(HloInstruction * b, - builder_max.AddParameter( - HloInstruction::CreateParameter(1, scalar, "b"))); + ASSIGN_OR_RETURN(HloInstruction * a, + builder_max.AddParameter( + HloInstruction::CreateParameter(0, scalar, "a"))); + ASSIGN_OR_RETURN(HloInstruction * b, + builder_max.AddParameter( + HloInstruction::CreateParameter(1, scalar, "b"))); HloInstruction* max = builder_max.AddInstruction( HloInstruction::CreateBinary(scalar, HloOpcode::kMaximum, a, b), "max"); return builder_max.Build(max); diff --git a/third_party/xla/xla/service/collective_permute_decomposer.cc b/third_party/xla/xla/service/collective_permute_decomposer.cc index 7d1f735df02342..0258f9f4437135 100644 --- a/third_party/xla/xla/service/collective_permute_decomposer.cc +++ b/third_party/xla/xla/service/collective_permute_decomposer.cc @@ -31,6 +31,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -174,8 +175,8 @@ static absl::StatusOr DecomposeCollectivePermute( HloInstruction::CreateGetTupleElement(recv_done, 0), absl::StrCat(cp_name, "-recv-data")); - TF_RETURN_IF_ERROR(cp->ReplaceAllUsesWith(recv_data)); - TF_RETURN_IF_ERROR(computation->RemoveInstructionAndUnusedOperands(cp)); + RETURN_IF_ERROR(cp->ReplaceAllUsesWith(recv_data)); + RETURN_IF_ERROR(computation->RemoveInstructionAndUnusedOperands(cp)); // We choose to run recv before send as an invariant, which helps avoid // deadlocks. At the same time, running recv before send allows for pipelining @@ -183,11 +184,11 @@ static absl::StatusOr DecomposeCollectivePermute( // pipeline parallelism. switch (pipeline_parallelism_opt_level) { case DebugOptions::PIPELINE_PARALLELISM_OPT_LEVEL_DISABLE: - TF_RETURN_IF_ERROR(recv->AddControlDependencyTo(send)); - TF_RETURN_IF_ERROR(send->AddControlDependencyTo(recv_done)); + RETURN_IF_ERROR(recv->AddControlDependencyTo(send)); + RETURN_IF_ERROR(send->AddControlDependencyTo(recv_done)); break; case DebugOptions::PIPELINE_PARALLELISM_OPT_LEVEL_ENABLE: - TF_RETURN_IF_ERROR(recv_done->AddControlDependencyTo(send)); + RETURN_IF_ERROR(recv_done->AddControlDependencyTo(send)); break; default: return absl::InvalidArgumentError( @@ -270,8 +271,8 @@ static absl::Status EnforceOrderOfSendRecvChain( for (size_t i = 1; i < deco_post_order.size(); ++i) { DecomposedCp& cur = deco_post_order[i]; DecomposedCp& prev = deco_post_order[i - 1]; - TF_RETURN_IF_ERROR(prev.send->AddControlDependencyTo(cur.recv)); - TF_RETURN_IF_ERROR(prev.send_done->AddControlDependencyTo(cur.recv_done)); + RETURN_IF_ERROR(prev.send->AddControlDependencyTo(cur.recv)); + RETURN_IF_ERROR(prev.send_done->AddControlDependencyTo(cur.recv_done)); } return absl::OkStatus(); } @@ -285,7 +286,7 @@ static absl::Status EnforceOrderOfSendRecvChainRelativeToConflictingCollectives( // Add control dependencies from chain to all conflicting collectives. for (HloInstruction* instr : conflicting_collectives) { - TF_RETURN_IF_ERROR(last_in_chain->AddControlDependencyTo(instr)); + RETURN_IF_ERROR(last_in_chain->AddControlDependencyTo(instr)); } return absl::OkStatus(); @@ -416,7 +417,7 @@ absl::StatusOr CollectivePermuteDecomposer::RunImpl( } else if (cp1_to_pipeline == cp) { pipeline_decision = "1"; } - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( DecomposedCp decomposed_ops, DecomposeCollectivePermute(cp, computation, pipeline_decision, pipeline_parallelism_opt_level_)); @@ -437,10 +438,9 @@ absl::StatusOr CollectivePermuteDecomposer::RunImpl( // enforce all other conflicting collectives to follow the send/recv chain // so that these cannot be scheduled in between the send/recv, which would // also lead to deadlocks. - TF_RETURN_IF_ERROR(EnforceOrderOfSendRecvChain(deco_post_order)); - TF_RETURN_IF_ERROR( - EnforceOrderOfSendRecvChainRelativeToConflictingCollectives( - deco_post_order, conflicing_collectives)); + RETURN_IF_ERROR(EnforceOrderOfSendRecvChain(deco_post_order)); + RETURN_IF_ERROR(EnforceOrderOfSendRecvChainRelativeToConflictingCollectives( + deco_post_order, conflicing_collectives)); if (!cps_to_decompose.empty()) { changed = true; diff --git a/third_party/xla/xla/service/collective_pipeliner.cc b/third_party/xla/xla/service/collective_pipeliner.cc index 475d346a9076b8..c385fe18511fce 100644 --- a/third_party/xla/xla/service/collective_pipeliner.cc +++ b/third_party/xla/xla/service/collective_pipeliner.cc @@ -40,6 +40,7 @@ limitations under the License. #include "absl/strings/str_join.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/comparison_util.h" #include "xla/hlo/analysis/hlo_dataflow_analysis.h" #include "xla/hlo/analysis/tuple_points_to_analysis.h" @@ -106,7 +107,7 @@ absl::Status UpdateControlDependencies(HloInstruction* original, // in the same computation eventually and we do want to add the control // dependency here. if (it->second->parent() == new_instr->parent()) { - TF_RETURN_IF_ERROR(it->second->AddControlDependencyTo(new_instr)); + RETURN_IF_ERROR(it->second->AddControlDependencyTo(new_instr)); } } return absl::OkStatus(); @@ -774,13 +775,13 @@ void UpdateInstructionChannelId(HloInstruction* cloned_instr, absl::Status UpdateInstructionSchedulingAnnotation( HloInstruction* cloned_instr, int64_t& scheduling_id, absl::flat_hash_map& annotation_map) { - TF_ASSIGN_OR_RETURN(std::optional annotation_idx, - GetSchedulingAnnotationGroupId(cloned_instr)); + ASSIGN_OR_RETURN(std::optional annotation_idx, + GetSchedulingAnnotationGroupId(cloned_instr)); if (annotation_idx) { if (!annotation_map.contains(*annotation_idx)) { annotation_map[*annotation_idx] = scheduling_id++; } - TF_RETURN_IF_ERROR(SetSchedulingAnnotationGroupId( + RETURN_IF_ERROR(SetSchedulingAnnotationGroupId( cloned_instr, annotation_map[*annotation_idx])); } return absl::OkStatus(); @@ -812,7 +813,7 @@ absl::StatusOr CloneBackwardChain( auto new_operands = MapNewOperands(chain_op->operands(), clone_map); HloInstruction* cloned = target_computation.AddInstruction( chain_op->CloneWithNewOperands(chain_op->shape(), new_operands)); - TF_RETURN_IF_ERROR(UpdateControlDependencies(chain_op, cloned, clone_map)); + RETURN_IF_ERROR(UpdateControlDependencies(chain_op, cloned, clone_map)); UpdateInstructionChannelId(cloned, next_channel_id, update_collective_channel_id); if (next_scheduling_id != -1) { @@ -820,7 +821,7 @@ absl::StatusOr CloneBackwardChain( } clone_map[chain_op] = cloned; if (postprocess_pipelined_ops) { - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( postprocess_pipelined_ops(cloned, /*new_while_instr=*/nullptr)); } last_cloned = cloned; @@ -1947,8 +1948,8 @@ absl::StatusOr TransformLoopForward( // Duplicate the loop body into the loop parent computation, so that the first // iteration happens there. - TF_ASSIGN_OR_RETURN(int64_t next_scheduling_id, - NextSchedulingGroupId(*while_loop->GetModule())); + ASSIGN_OR_RETURN(int64_t next_scheduling_id, + NextSchedulingGroupId(*while_loop->GetModule())); absl::flat_hash_map annotation_map; for (auto* instr : while_body->MakeInstructionPostOrder()) { if (instr == loop_parameter) { @@ -1981,7 +1982,7 @@ absl::StatusOr TransformLoopForward( loop_computation->parent()->AddEmbeddedComputation( instr->while_body()->CloneWithReplacements(nullptr))); } - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( UpdateControlDependencies(instr, cloned_instr, while_body_to_peeled)); UpdateInstructionChannelId(cloned_instr, next_channel_id, update_collective_channel_id); @@ -1991,7 +1992,7 @@ absl::StatusOr TransformLoopForward( .iteration_id.has_value()) { RemoveSchedulingAnnotation(cloned_instr); } else { - TF_RETURN_IF_ERROR(UpdateInstructionSchedulingAnnotation( + RETURN_IF_ERROR(UpdateInstructionSchedulingAnnotation( cloned_instr, next_scheduling_id, annotation_map)); } while_body_to_peeled[instr] = cloned_instr; @@ -2062,24 +2063,22 @@ absl::StatusOr TransformLoopForward( HloInstruction* new_init = loop_computation->AddInstruction( HloInstruction::CreateTuple(new_init_operands)); while_body_to_peeled[while_body->root_instruction()] = new_init; - TF_RETURN_IF_ERROR(UpdateControlDependencies(while_body->root_instruction(), - new_init, while_body_to_peeled)); + RETURN_IF_ERROR(UpdateControlDependencies(while_body->root_instruction(), + new_init, while_body_to_peeled)); HloInstruction* new_while_loop = loop_computation->AddInstruction(HloInstruction::CreateWhile( loop_state_shape, new_while_condition, new_while_body, new_init)); - TF_RETURN_IF_ERROR( - while_loop->ReplaceAllUsesWithDifferentShape(new_while_loop)); - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR(while_loop->ReplaceAllUsesWithDifferentShape(new_while_loop)); + RETURN_IF_ERROR( loop_computation->RemoveInstructionAndUnusedOperands(while_loop)); - TF_RETURN_IF_ERROR(new_while_loop->GetModule()->RemoveUnusedComputations()); + RETURN_IF_ERROR(new_while_loop->GetModule()->RemoveUnusedComputations()); // Run WhileLoopAnalysis again on the new loop to collect the position of the // all-reduces in the new cloned loop as they aren't the same of the old. // Loop analysis should result exactly the same, because the loop is the same // except some new scalar unused parameters added at the end. - TF_ASSIGN_OR_RETURN( - std::unique_ptr new_dataflow_analysis, - HloDataflowAnalysis::Run(*(new_while_loop->GetModule()), - /*ssa_form=*/true)); + ASSIGN_OR_RETURN(std::unique_ptr new_dataflow_analysis, + HloDataflowAnalysis::Run(*(new_while_loop->GetModule()), + /*ssa_form=*/true)); WhileLoopAnalysis new_loop_analysis( new_while_loop, loop_analysis.GetMaxPipeliningPerLoop(), pipeline_use_tree, process_different_sized_ops, @@ -2143,7 +2142,7 @@ absl::StatusOr TransformLoopForward( } if (post_processing_fn) { - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( post_processing_fn(processed, /*new_while_instr=*/nullptr)); } InstructionMap new_cloned_map; @@ -2163,7 +2162,7 @@ absl::StatusOr TransformLoopForward( } (*formatting_cloned_map)[formatting_op] = processed; if (post_processing_fn) { - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( post_processing_fn(processed, /*new_while_instr=*/nullptr)); } } @@ -2204,7 +2203,7 @@ absl::StatusOr TransformLoopForward( computation->AddInstruction(HloInstruction::CreateDynamicSlice( slice_target_shape, data_to_slice, indices, dynamic_slice_sizes)); } - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( sliced_data, process_slice(sliced_data, pipelined_values_map, move_info, /*formatting_cloned_map=*/nullptr, update_annotations)); @@ -2224,15 +2223,15 @@ absl::StatusOr TransformLoopForward( } auto replace_instructions_with = [](absl::Span to_replace_instrs, - HloInstruction* new_instr) { - for (auto* to_replace : to_replace_instrs) { - HloComputation* computation = to_replace->parent(); - TF_RETURN_IF_ERROR(to_replace->ReplaceAllUsesWith(new_instr)); - TF_RETURN_IF_ERROR( - computation->RemoveInstructionAndUnusedOperands(to_replace)); - } - return absl::OkStatus(); - }; + HloInstruction* new_instr) -> absl::Status { + for (auto* to_replace : to_replace_instrs) { + HloComputation* computation = to_replace->parent(); + RETURN_IF_ERROR(to_replace->ReplaceAllUsesWith(new_instr)); + RETURN_IF_ERROR( + computation->RemoveInstructionAndUnusedOperands(to_replace)); + } + return absl::OkStatus(); + }; if (move_info.dynamic_update_slices.empty()) { auto it = moves_requiring_special_output_to_idx.find(i); CHECK(it != moves_requiring_special_output_to_idx.end()); @@ -2248,7 +2247,7 @@ absl::StatusOr TransformLoopForward( new_while_loop, it->second)); // This is the case where the instruction is a sink. InstructionMap updated_pipelined_values_map; - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( auto* in_sink_instr, process_slice(input_data_to_sink, pipelined_values_map_inloop, move_info, @@ -2263,7 +2262,7 @@ absl::StatusOr TransformLoopForward( move_info.sink_instruction->CloneWithNewOperands( move_info.sink_instruction->shape(), new_input_sink_operands)); updated_pipelined_values_map.clear(); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( auto* out_sink_instr, process_slice(output_data_to_sink, pipelined_values_map_outloop, move_info, @@ -2280,11 +2279,11 @@ absl::StatusOr TransformLoopForward( move_info.sink_instruction->CloneWithNewOperands( move_info.sink_instruction->shape(), new_output_sink_operands)); RemoveSchedulingAnnotation(out_sink_instr); - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( move_info.sink_instruction->ReplaceAllUsesWith(in_sink_instr)); - TF_RETURN_IF_ERROR(new_while_body->RemoveInstructionAndUnusedOperands( + RETURN_IF_ERROR(new_while_body->RemoveInstructionAndUnusedOperands( move_info.sink_instruction)); - TF_RETURN_IF_ERROR(replace_instructions_with( + RETURN_IF_ERROR(replace_instructions_with( absl::MakeSpan(loop_output_to_replace), out_sink_instr)); continue; } @@ -2328,12 +2327,12 @@ absl::StatusOr TransformLoopForward( move_info.collectives_to_move.front()->operand(0)->shape(), new_while_loop, it->second)); } - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( input_stacked_data, extract_and_process_slice(input_stacked_data, input_data_to_slice, move_info, pipelined_values_map_inloop, input_dus_idx, /*update_annotations=*/false)); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( output_stacked_data, extract_and_process_slice(output_stacked_data, output_data_to_slice, move_info, pipelined_values_map_outloop, @@ -2349,13 +2348,13 @@ absl::StatusOr TransformLoopForward( move_info.sliced_idx), input_stacked_data); } - TF_RETURN_IF_ERROR(dyn_update->ReplaceAllUsesWith(new_peeled_dus)); - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR(dyn_update->ReplaceAllUsesWith(new_peeled_dus)); + RETURN_IF_ERROR( new_while_body->RemoveInstructionAndUnusedOperands(dyn_update)); - TF_RETURN_IF_ERROR(replace_instructions_with( + RETURN_IF_ERROR(replace_instructions_with( absl::MakeSpan(loop_output_to_replace), output_stacked_data)); } - TF_RETURN_IF_ERROR(loop_computation->parent()->RemoveUnusedComputations()); + RETURN_IF_ERROR(loop_computation->parent()->RemoveUnusedComputations()); return new_while_loop; } @@ -2710,8 +2709,8 @@ absl::StatusOr TransformLoopForwardSink( inst->operand(1)->IsCustomCall( CollectivePipeliner::kSunkByPreviousStep)) { HloInstruction* cc = inst->mutable_operand(1); - TF_RETURN_IF_ERROR(inst->ReplaceOperandWith(1, cc->mutable_operand(0))); - TF_RETURN_IF_ERROR(cc->parent()->RemoveInstruction(cc)); + RETURN_IF_ERROR(inst->ReplaceOperandWith(1, cc->mutable_operand(0))); + RETURN_IF_ERROR(cc->parent()->RemoveInstruction(cc)); } } CHECK_EQ(while_body->root_instruction()->opcode(), HloOpcode::kTuple); @@ -2931,8 +2930,8 @@ absl::StatusOr TransformLoopForwardSink( cloned_body->AddInstruction(HloInstruction::CreateGetTupleElement( output->shape(), cloned_body->parameter_instruction(0), *idx)); HloInstruction* old_operand_param = output->mutable_operand(0); - TF_RETURN_IF_ERROR(output->ReplaceOperandWith(0, new_param)); - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR(output->ReplaceOperandWith(0, new_param)); + RETURN_IF_ERROR( old_operand_param->parent()->RemoveInstruction(old_operand_param)); if (insert_non_alias_custom_call && indices_to_insert.contains(i)) { auto* old_operand = output->mutable_operand(1); @@ -2940,7 +2939,7 @@ absl::StatusOr TransformLoopForwardSink( cloned_body->AddInstruction(HloInstruction::CreateCustomCall( old_operand->shape(), {old_operand}, /*custom_call_target=*/CollectivePipeliner::kSunkByPreviousStep)); - TF_RETURN_IF_ERROR(output->ReplaceOperandWith(1, custom_call)); + RETURN_IF_ERROR(output->ReplaceOperandWith(1, custom_call)); } } HloInstruction* new_while = @@ -3008,7 +3007,7 @@ absl::StatusOr TransformLoopForwardSink( // an effect on the instruction itself (like say broadcast, slices ... // etc). for (HloInstruction* formatting_op : to_move.formatting_ops) { - TF_RETURN_IF_ERROR(TransformFormattingOp( + RETURN_IF_ERROR(TransformFormattingOp( formatting_op, to_move, loop_computation, pipelined_map, to_add_batch_set, next_channel_id, update_collective_channel_id)); } @@ -3038,10 +3037,10 @@ absl::StatusOr TransformLoopForwardSink( } HloInstruction* new_tuple = loop_computation->AddInstruction( HloInstruction::CreateTuple(new_output_tuple)); - TF_RETURN_IF_ERROR(while_loop->ReplaceAllUsesWithDifferentShape(new_tuple)); - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR(while_loop->ReplaceAllUsesWithDifferentShape(new_tuple)); + RETURN_IF_ERROR( loop_computation->RemoveInstructionAndUnusedOperands(while_loop)); - TF_RETURN_IF_ERROR(loop_computation->parent()->RemoveUnusedComputations()); + RETURN_IF_ERROR(loop_computation->parent()->RemoveUnusedComputations()); return new_while; } @@ -3163,8 +3162,8 @@ static absl::StatusOr TransformLoopBackward( // Add to the rewritten loop the new parameter/output data that is going to be // pipelined. Clone chains of pipelined data in the parent computation in the // process (they will endup being executed before the loop). - TF_ASSIGN_OR_RETURN(int64_t next_scheduling_id, - NextSchedulingGroupId(*while_loop->GetModule())); + ASSIGN_OR_RETURN(int64_t next_scheduling_id, + NextSchedulingGroupId(*while_loop->GetModule())); absl::flat_hash_map annotation_map; for (int i = 0; i < loop_analysis.GetMoveInfos().size(); ++i) { const int64_t idx = i + loop_parameter->shape().tuple_shapes().size(); @@ -3172,7 +3171,7 @@ static absl::StatusOr TransformLoopBackward( loop_analysis.GetMoveInfos()[i].collectives_to_move[0]->shape(); new_root_operands[idx] = loop_analysis.GetMoveInfos()[i].collectives_to_move[0]; - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( new_init_operands[idx], CloneBackwardChain( *while_loop->parent(), loop_analysis.GetMoveInfos()[i], @@ -3182,12 +3181,12 @@ static absl::StatusOr TransformLoopBackward( /*loop_variant_parameter_info=*/nullptr, post_processing_fn)); if (post_processing_fn) { - TF_RETURN_IF_ERROR(post_processing_fn(new_init_operands[idx], - /*new_while_instr=*/nullptr)); + RETURN_IF_ERROR(post_processing_fn(new_init_operands[idx], + /*new_while_instr=*/nullptr)); } if (postprocess_peeled) { - TF_RETURN_IF_ERROR(postprocess_peeled(new_init_operands[idx], - /*new_while_instr=*/nullptr)); + RETURN_IF_ERROR(postprocess_peeled(new_init_operands[idx], + /*new_while_instr=*/nullptr)); } } ConstantValue next_loop_iteration = @@ -3233,7 +3232,7 @@ static absl::StatusOr TransformLoopBackward( // Passing -1 as the next scheduling id to indicate that we should not // update the scheduling id of the cloned instructions. int64_t next_scheduling_id = -1; - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( cloned_instr, CloneBackwardChain( body_builder, loop_analysis.GetMoveInfos()[it->second], @@ -3243,11 +3242,11 @@ static absl::StatusOr TransformLoopBackward( &loop_variant_parameter_info, post_processing_fn)); if (post_processing_fn) { - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( post_processing_fn(cloned_instr, /*new_while_instr=*/nullptr)); } if (postprocess_rotated) { - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( postprocess_rotated(cloned_instr, /*new_while_instr=*/nullptr)); } } else { @@ -3271,8 +3270,8 @@ static absl::StatusOr TransformLoopBackward( while_loop->GetModule()->AddEmbeddedComputation( instr->while_body()->CloneWithReplacements(nullptr))); } - TF_RETURN_IF_ERROR(UpdateControlDependencies( - instr, cloned_instr, while_body_replacement_map)); + RETURN_IF_ERROR(UpdateControlDependencies(instr, cloned_instr, + while_body_replacement_map)); UpdateInstructionChannelId(cloned_instr, next_channel_id, update_collective_channel_id); } @@ -3297,7 +3296,7 @@ static absl::StatusOr TransformLoopBackward( auto it = while_body_replacement_map.find(new_root_operands[idx]); CHECK(it != while_body_replacement_map.end()) << new_root_operands[idx]->ToString() << " not present in map"; - TF_RETURN_IF_ERROR(value->ReplaceAllUsesWith(it->second)); + RETURN_IF_ERROR(value->ReplaceAllUsesWith(it->second)); } new_root_operands.back() = @@ -3318,9 +3317,9 @@ static absl::StatusOr TransformLoopBackward( HloComputation* new_while_body = while_loop->GetModule()->AddEmbeddedComputation( body_builder.Build(new_loop_root)); - TF_RETURN_IF_ERROR(UpdateControlDependencies(while_body->root_instruction(), - new_loop_root, - while_body_replacement_map)); + RETURN_IF_ERROR(UpdateControlDependencies(while_body->root_instruction(), + new_loop_root, + while_body_replacement_map)); auto cond_builder = HloComputation::Builder(while_loop->while_condition()->name()); HloInstruction* new_cond_param = @@ -3357,8 +3356,8 @@ static absl::StatusOr TransformLoopBackward( cond_builder.Build(comparison)); HloInstruction* new_loop_init = while_loop->parent()->AddInstruction( HloInstruction::CreateTuple(new_init_operands)); - TF_RETURN_IF_ERROR(UpdateControlDependencies(while_body->root_instruction(), - new_loop_init, chain_clone_map)); + RETURN_IF_ERROR(UpdateControlDependencies(while_body->root_instruction(), + new_loop_init, chain_clone_map)); // Create the new loop. HloInstruction* new_while_loop = while_loop->parent()->AddInstruction(HloInstruction::CreateWhile( @@ -3397,15 +3396,15 @@ static absl::StatusOr TransformLoopBackward( instr->CloneWithNewOperands(instr->shape(), new_operands)); if (postprocess_peeled_trailing_op) { CHECK_NE(new_while_loop, nullptr); - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( postprocess_peeled_trailing_op(cloned_instr, new_while_loop)); } - TF_RETURN_IF_ERROR(UpdateControlDependencies(instr, cloned_instr, - while_body_replacement_map)); + RETURN_IF_ERROR(UpdateControlDependencies(instr, cloned_instr, + while_body_replacement_map)); UpdateInstructionChannelId(cloned_instr, next_channel_id, update_collective_channel_id); - TF_RETURN_IF_ERROR(UpdateInstructionSchedulingAnnotation( + RETURN_IF_ERROR(UpdateInstructionSchedulingAnnotation( cloned_instr, next_scheduling_id, annotation_map)); while_body_replacement_map[instr] = cloned_instr; if (instruction_is_output_it != is_output_instruction.end()) { @@ -3418,11 +3417,11 @@ static absl::StatusOr TransformLoopBackward( HloInstruction* final_loop_output = while_loop->parent()->AddInstruction( HloInstruction::CreateTuple(output_tuple_instructions)); HloComputation* loop_computation = while_loop->parent(); - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( while_loop->ReplaceAllUsesWithDifferentShape(final_loop_output)); - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( loop_computation->RemoveInstructionAndUnusedOperands(while_loop)); - TF_RETURN_IF_ERROR(loop_computation->parent()->RemoveUnusedComputations()); + RETURN_IF_ERROR(loop_computation->parent()->RemoveUnusedComputations()); return new_while_loop; } @@ -3434,11 +3433,11 @@ absl::StatusOr CollectivePipeliner::RunPipeliner( // Precompute module-scoped analyses. Because we are running a while-loop // analysis over all while instructions in the module, computing them here and // passing them in avoids recomputing them once for each while instruction. - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( std::unique_ptr tuple_points_to_analysis, TuplePointsToAnalysis::Run(module)); - TF_ASSIGN_OR_RETURN(std::unique_ptr dataflow_analysis, - HloDataflowAnalysis::Run(*module, /*ssa_form=*/true)); + ASSIGN_OR_RETURN(std::unique_ptr dataflow_analysis, + HloDataflowAnalysis::Run(*module, /*ssa_form=*/true)); std::vector>> loop_analyses; @@ -3499,7 +3498,7 @@ absl::StatusOr CollectivePipeliner::RunPipeliner( if (config_.pipelining_direction == collective_pipeliner_utils::PipeliningDirection::kForward) { CHECK(config_.reuse_pipelined_op_buffer); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( transformed_while_loop, TransformLoopForward( *loop_analysis, !config_.last_run, config_.level_to_operate_on, @@ -3509,7 +3508,7 @@ absl::StatusOr CollectivePipeliner::RunPipeliner( config_.unique_channel_id, config_.postprocess_pipelined_ops)); } else if (config_.pipelining_direction == collective_pipeliner_utils::PipeliningDirection::kForwardSink) { - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( transformed_while_loop, TransformLoopForwardSink( *loop_analysis, !config_.last_run, config_.level_to_operate_on, @@ -3519,7 +3518,7 @@ absl::StatusOr CollectivePipeliner::RunPipeliner( } else { CHECK_EQ(config_.pipelining_direction, collective_pipeliner_utils::PipeliningDirection::kBackward); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( transformed_while_loop, TransformLoopBackward( *loop_analysis, !config_.last_run, config_.level_to_operate_on, @@ -3531,7 +3530,7 @@ absl::StatusOr CollectivePipeliner::RunPipeliner( config_.unique_channel_id, config_.postprocess_pipelined_ops)); } if (config_.postprocess_transformed_while_loop) { - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( config_.postprocess_transformed_while_loop(transformed_while_loop)); } ++transformed_loops; @@ -3546,16 +3545,15 @@ absl::StatusOr CollectivePipeliner::RunPipeliner( if (instruction->IsCustomCall( CollectivePipeliner::kInsertedByPreviousStep)) { to_remove.push_back(instruction); - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( instruction->ReplaceAllUsesWith(instruction->mutable_operand(0))); changed = true; } } } for (auto* instruction : to_remove) { - TF_RETURN_IF_ERROR( - instruction->parent()->RemoveInstructionAndUnusedOperands( - instruction)); + RETURN_IF_ERROR(instruction->parent()->RemoveInstructionAndUnusedOperands( + instruction)); } } VLOG(1) << "Transformed loops: " << transformed_loops @@ -3564,7 +3562,7 @@ absl::StatusOr CollectivePipeliner::RunPipeliner( << GetPipelineDirectionString(config_.pipelining_direction); // Run necessary cleanup to make sure unused code doesn't trigger HloVerifier. if (changed) { - TF_RETURN_IF_ERROR(HloDCE().Run(module, execution_threads).status()); + RETURN_IF_ERROR(HloDCE().Run(module, execution_threads).status()); } return changed; @@ -3589,13 +3587,13 @@ absl::StatusOr CollectivePipeliner::RunImpl( bool changed = true; int64_t iter = 0; while (changed) { - TF_ASSIGN_OR_RETURN(changed, RunPipeliner(module, execution_threads)); + ASSIGN_OR_RETURN(changed, RunPipeliner(module, execution_threads)); VLOG(1) << "Finished running pipeliner's iteration for small collectives: " << iter; iter++; } config_.delay_sinking_large_collectives = false; - TF_ASSIGN_OR_RETURN(changed, RunPipeliner(module, execution_threads)); + ASSIGN_OR_RETURN(changed, RunPipeliner(module, execution_threads)); VLOG(1) << "Finished running pipeliner's iteration for large collectives: " << iter; return iter > 1 || changed; diff --git a/third_party/xla/xla/service/compilation_environments.cc b/third_party/xla/xla/service/compilation_environments.cc index 8d38b7ebd9652f..2309a5c80ec5fa 100644 --- a/third_party/xla/xla/service/compilation_environments.cc +++ b/third_party/xla/xla/service/compilation_environments.cc @@ -34,6 +34,7 @@ limitations under the License. #include "absl/strings/str_join.h" #include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" +#include "xla/tsl/platform/status_macros.h" #include "google/protobuf/descriptor.h" #include "google/protobuf/message.h" #include "google/protobuf/unknown_field_set.h" @@ -162,7 +163,7 @@ CompilationEnvironments::CreateFromProto( "'")); } - TF_RETURN_IF_ERROR(envs->AddEnv(std::move(env))); + RETURN_IF_ERROR(envs->AddEnv(std::move(env))); } return envs; @@ -211,7 +212,7 @@ absl::Status CompilationEnvironments::InitializeAllKnownEnvs() { for (const auto& descriptor : descriptors) { auto it = environments_.find(descriptor); if (it == environments_.end()) { - TF_RETURN_IF_ERROR(AddEnvImpl(*descriptor, nullptr)); + RETURN_IF_ERROR(AddEnvImpl(*descriptor, nullptr)); DefaultEnvCreatedByCompilationEnvironments(descriptor->full_name()); } } @@ -288,8 +289,8 @@ absl::Status CompilationEnvironments::AddEnvImpl( return absl::InvalidArgumentError(absl::StrCat( "Unknown CompilationEnvironment type ", descriptor.full_name())); } - TF_ASSIGN_OR_RETURN(std::unique_ptr processed_env, - process_new_env(std::move(env))); + ASSIGN_OR_RETURN(std::unique_ptr processed_env, + process_new_env(std::move(env))); // Check for unknown fields const google::protobuf::UnknownFieldSet& unknown_fields = diff --git a/third_party/xla/xla/service/compile_only_service.cc b/third_party/xla/xla/service/compile_only_service.cc index a5a8bbcdc1b0b0..7133733ee18e94 100644 --- a/third_party/xla/xla/service/compile_only_service.cc +++ b/third_party/xla/xla/service/compile_only_service.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include "absl/strings/str_cat.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/debug_options_flags.h" #include "xla/service/backend.h" #include "xla/service/compiler.h" @@ -46,10 +47,10 @@ CompileOnlyService::NewService(se::Platform* platform) { CompileOnlyService::NewService(const ServiceOptions& options) { se::Platform* platform = options.platform(); if (platform == nullptr) { - TF_ASSIGN_OR_RETURN(platform, PlatformUtil::GetDefaultPlatform()); + ASSIGN_OR_RETURN(platform, PlatformUtil::GetDefaultPlatform()); } - TF_ASSIGN_OR_RETURN(auto compiler, Compiler::GetForPlatform(platform->id())); + ASSIGN_OR_RETURN(auto compiler, Compiler::GetForPlatform(platform->id())); std::unique_ptr service( new CompileOnlyService(options, std::move(compiler))); @@ -119,15 +120,15 @@ CompileOnlyService::CompileAheadOfTime( update_shape_with_empty_tiles); } - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( ProgramShape program_shape, ProgramShape::FromProto(computation.computation.host_program_shape())); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( std::unique_ptr module_config, CreateModuleConfig(program_shape, computation.argument_layouts, &execution_options, &options)); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( std::unique_ptr hlo_module, HloModule::CreateFromProto(computation.computation, *module_config)); DumpHloModuleIfEnabled(*hlo_module, "before_optimizations"); diff --git a/third_party/xla/xla/service/computation_layout.cc b/third_party/xla/xla/service/computation_layout.cc index 1145102a78cfec..f3b1be6f43e9a6 100644 --- a/third_party/xla/xla/service/computation_layout.cc +++ b/third_party/xla/xla/service/computation_layout.cc @@ -21,6 +21,7 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/strings/str_cat.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/layout.h" #include "xla/printer.h" #include "xla/shape.h" @@ -63,7 +64,7 @@ absl::StatusOr> ComputationLayout::FlattenedParameterLayouts() const { std::vector result; for (int i = 0; i < parameter_count(); ++i) { - TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( + RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( parameter_shape(i), [this, &result](const Shape& subshape, const ShapeIndex& index) -> absl::Status { @@ -92,7 +93,7 @@ ComputationLayout::FlattenedParameterLayouts() const { absl::StatusOr> ComputationLayout::FlattenedResultLayouts() const { std::vector result; - TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( + RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( result_shape(), [this, &result](const Shape& subshape, const ShapeIndex& index) -> absl::Status { diff --git a/third_party/xla/xla/service/computation_placer.cc b/third_party/xla/xla/service/computation_placer.cc index fc650cc0bdaaad..177b51ce45c900 100644 --- a/third_party/xla/xla/service/computation_placer.cc +++ b/third_party/xla/xla/service/computation_placer.cc @@ -32,6 +32,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/synchronization/mutex.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/runtime/device_id.h" #include "xla/stream_executor/cuda/cuda_platform_id.h" #include "xla/stream_executor/host/host_platform_id.h" @@ -67,15 +68,13 @@ DeviceAssignment::LogicalIdForDevice(GlobalDeviceId device_id) const { absl::StatusOr DeviceAssignment::ReplicaIdForDevice( GlobalDeviceId device_id) const { - TF_ASSIGN_OR_RETURN(const LogicalID logical_id, - LogicalIdForDevice(device_id)); + ASSIGN_OR_RETURN(const LogicalID logical_id, LogicalIdForDevice(device_id)); return logical_id.replica_id; } absl::StatusOr DeviceAssignment::PartitionIdForDevice( GlobalDeviceId device_id) const { - TF_ASSIGN_OR_RETURN(const LogicalID logical_id, - LogicalIdForDevice(device_id)); + ASSIGN_OR_RETURN(const LogicalID logical_id, LogicalIdForDevice(device_id)); return logical_id.computation_id; } diff --git a/third_party/xla/xla/service/conditional_code_motion.cc b/third_party/xla/xla/service/conditional_code_motion.cc index 4e628a83354aa9..972dbe27ec31d7 100644 --- a/third_party/xla/xla/service/conditional_code_motion.cc +++ b/third_party/xla/xla/service/conditional_code_motion.cc @@ -35,6 +35,7 @@ limitations under the License. #include "absl/strings/str_split.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/debug_options_flags.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_clone_context.h" @@ -535,8 +536,7 @@ absl::Status RestructureConditionalInstruction(HloComputation* computation, } } for (auto new_tuple_user : new_tuple_users) { - TF_RETURN_IF_ERROR( - conditional->ReplaceUseWith(new_tuple_user, new_tuple)); + RETURN_IF_ERROR(conditional->ReplaceUseWith(new_tuple_user, new_tuple)); } } VLOG(2) << "computation after root restructure:\n" << computation->ToString(); @@ -585,7 +585,7 @@ absl::StatusOr ConvertSpecialMove(HloInstruction* conditional, return false; } - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( RestructureConditionalInstruction(conditional->parent(), conditional)); for (int branch = 0; branch < branch_count; branch++) { @@ -648,7 +648,7 @@ absl::StatusOr ConvertSpecialMove(HloInstruction* conditional, absl::MakeSpan(conditional->branch_computations()), absl::MakeSpan(conditional->operands()).subspan(1))); // Ensure that all the users of conditional refer to the new one. - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( conditional->ReplaceAllUsesWithDifferentShape(newconditional)); CHECK_OK(conditional_parent->RemoveInstruction(conditional)); conditional = newconditional; @@ -677,7 +677,7 @@ absl::StatusOr ConvertSpecialMove(HloInstruction* conditional, HloInstruction* hoisted = conditional_parent->AddInstruction( hoist->CloneWithNewOperands(hoist->shape(), new_operands)); VLOG(2) << "Hoisted instruction in parent:" << hoisted->ToString(); - TF_RETURN_IF_ERROR(gte_hoist->ReplaceAllUsesWith(hoisted)); + RETURN_IF_ERROR(gte_hoist->ReplaceAllUsesWith(hoisted)); CHECK_OK(conditional_parent->RemoveInstruction(gte_hoist)); } // No need to explicitly delete a hoisted instruction since if its dead @@ -749,8 +749,8 @@ absl::StatusOr ConditionalCodeMotion::MoveInstructionOut( CHECK(new_opd != nullptr); VLOG(2) << "Try replace all uses of :" << old_user_boundary.ToString() << "\n"; - TF_RETURN_IF_ERROR(user_instr->ReplaceAllUsesWith(new_opd)); - TF_RETURN_IF_ERROR(conditional_parent->RemoveInstruction(user_instr)); + RETURN_IF_ERROR(user_instr->ReplaceAllUsesWith(new_opd)); + RETURN_IF_ERROR(conditional_parent->RemoveInstruction(user_instr)); } VLOG(2) << "Done changing conditional users\n" << conditional_parent->ToString() << "\n"; @@ -780,7 +780,7 @@ absl::StatusOr ConditionalCodeMotion::MoveInstructionOut( instr_to_remove->IsDead()) { VLOG(2) << "Removing boundary:" << b2.ToString() << "\n"; VLOG(2) << "computation: " << computation->ToString() << "\n"; - TF_RETURN_IF_ERROR(computation->RemoveInstruction(instr_to_remove)); + RETURN_IF_ERROR(computation->RemoveInstruction(instr_to_remove)); } } } @@ -876,7 +876,7 @@ absl::StatusOr ConditionalCodeMotion::MoveUserInstructionsIn( computation->set_root_instruction(new_root, /*accept_different_shape*/ true); if (old_root->opcode() == HloOpcode::kTuple) { - TF_RETURN_IF_ERROR(computation->RemoveInstruction(old_root)); + RETURN_IF_ERROR(computation->RemoveInstruction(old_root)); } VLOG(2) << "new branch computation: " << computation->ToString() << "\n"; } @@ -927,7 +927,7 @@ absl::StatusOr ConditionalCodeMotion::MoveUserInstructionsIn( HloInstruction* gtr = conditional->parent()->AddInstruction( HloInstruction::CreateGetTupleElement(op->shape(), conditional, op_index++)); - TF_RETURN_IF_ERROR(op->ReplaceAllUsesWith(gtr)); + RETURN_IF_ERROR(op->ReplaceAllUsesWith(gtr)); if (conditional->parent()->root_instruction() == op) { conditional->parent()->set_root_instruction(gtr); } @@ -969,12 +969,12 @@ class MoveOperandIntoBranch { auto new_operands = inst->unique_operands(); // Mapping from operands to their new locations in branch entry. std::vector> matching_tuple_indices; - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( ReplaceInputInUser(inst, user, new_operands, matching_tuple_indices)); - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( MoveInputIntoBranch(inst, user, new_operands, matching_tuple_indices)); if (inst->user_count() == 0) { - TF_RETURN_IF_ERROR(inst->parent()->RemoveInstruction(inst)); + RETURN_IF_ERROR(inst->parent()->RemoveInstruction(inst)); } return absl::OkStatus(); } @@ -1113,7 +1113,7 @@ class MoveOperandIntoBranch { op_map_[new_input] = opd_index; VLOG(2) << "Mapping operand " << repl_count << " = " << new_input->ToString() << " to " << opd_index; - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( user->ReplaceOperandWithDifferentShape(opd_index, new_input)); *user->mutable_shape()->mutable_tuple_shapes(opd_index) = new_input->shape(); @@ -1165,7 +1165,7 @@ class MoveOperandIntoBranch { op_map_[new_operands[i]] = i; } user = new_input; - TF_RETURN_IF_ERROR(input->ReplaceUseWithDifferentShape(cond, new_input)); + RETURN_IF_ERROR(input->ReplaceUseWithDifferentShape(cond, new_input)); } TF_RET_CHECK(cond->opcode() == HloOpcode::kConditional) << "User has non-conditional users"; @@ -1212,7 +1212,7 @@ class MoveOperandIntoBranch { param_user != branch_comp->root_instruction())) { continue; } - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( branch_param->ReplaceUseWithDifferentShape(param_user, inserted)); // We can create invalid get-tuple-element() instructions when the // output is not a tuple. Clean them away here. @@ -1282,7 +1282,7 @@ absl::StatusOr ConditionalCodeMotion::MoveOperandInstructionsIn( users.push_back(std::make_pair( user_now->users()[0], user_now->users()[0]->operand_index(user_now))); } - TF_RETURN_IF_ERROR(move_into_branch(op, user)); + RETURN_IF_ERROR(move_into_branch(op, user)); // Update the user chain of the original op to find the new user. for (int64_t i = users.size() - 1; i > 0; --i) { CHECK_NE(users[i].first, nullptr); @@ -2173,9 +2173,9 @@ absl::StatusOr ConditionalCodeMotion::RunImpl( if (final_d == Decision::Direction::kMoveOutOfBranch) { CHECK(to_move_out.size() == new_boundaries_for_moveout.size()); for (int i = 0; i < to_move_out.size(); ++i) { - TF_ASSIGN_OR_RETURN(bool result, - MoveInstructionOut(conditional, to_move_out[i], - new_boundaries_for_moveout[i])); + ASSIGN_OR_RETURN(bool result, + MoveInstructionOut(conditional, to_move_out[i], + new_boundaries_for_moveout[i])); changed |= result; } VLOG(2) << "Done moving out of branches " << to_move_out.size() @@ -2198,15 +2198,15 @@ absl::StatusOr ConditionalCodeMotion::RunImpl( if (to_move_in[i][0].IsOutsideBranchOperand()) { VLOG(1) << "Modifying code---number of operand boundaries to move in:" << to_move_in[i].size() << "\n"; - TF_ASSIGN_OR_RETURN(bool result, MoveOperandInstructionsIn( - conditional, to_move_in[i])); + ASSIGN_OR_RETURN(bool result, MoveOperandInstructionsIn( + conditional, to_move_in[i])); changed |= result; } else { VLOG(1) << "Modifying code---number of user boundaries to move in:" << to_move_in[i].size() << "\n"; CHECK(to_move_in[i][0].IsOutsideBranchUser()); - TF_ASSIGN_OR_RETURN( - bool result, MoveUserInstructionsIn(conditional, to_move_in[i])); + ASSIGN_OR_RETURN(bool result, + MoveUserInstructionsIn(conditional, to_move_in[i])); changed |= result; } VLOG(2) << "Before removing instructions:" @@ -2218,7 +2218,7 @@ absl::StatusOr ConditionalCodeMotion::RunImpl( if (op->user_count() == 0 && op->parent() != nullptr) { VLOG(2) << "Removing boundary:" << boundary_to_move_in.ToString() << "\n"; - TF_RETURN_IF_ERROR(conditional->parent()->RemoveInstruction(op)); + RETURN_IF_ERROR(conditional->parent()->RemoveInstruction(op)); VLOG(2) << "Done removing boundary.\n"; } } @@ -2242,9 +2242,8 @@ absl::StatusOr ConditionalCodeMotion::RunImpl( // cloning has been done by the earlier analysis. // TOOD[b/165848866]: extend solution to handle cloning for special // move. - TF_ASSIGN_OR_RETURN( - bool convert_result, - ConvertSpecialMove(conditional, is_layout_sensitive_)); + ASSIGN_OR_RETURN(bool convert_result, + ConvertSpecialMove(conditional, is_layout_sensitive_)); if (convert_result) { VLOG(2) << "Done special moving of convert\n"; if (!ConsumeFuel("conditional_code_motion", [&] { @@ -2264,7 +2263,7 @@ absl::StatusOr ConditionalCodeMotion::RunImpl( subpipeline.AddPass(); subpipeline.AddPass(); subpipeline.AddPass(); - TF_ASSIGN_OR_RETURN(bool cleanup_changed, subpipeline.Run(module)); + ASSIGN_OR_RETURN(bool cleanup_changed, subpipeline.Run(module)); if (cleanup_changed) { VLOG(2) << "subpipeline cleanup have modified code\n"; } diff --git a/third_party/xla/xla/service/conditional_simplifier.cc b/third_party/xla/xla/service/conditional_simplifier.cc index a5162a9ad0ce67..af9352bf4c650a 100644 --- a/third_party/xla/xla/service/conditional_simplifier.cc +++ b/third_party/xla/xla/service/conditional_simplifier.cc @@ -27,6 +27,7 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -151,7 +152,7 @@ absl::StatusOr TryRemoveUnusedConditionalOperands( } HloInstruction* new_tuple = conditional->parent()->AddInstruction( HloInstruction::CreateTuple(new_tuple_operands)); - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( conditional->ReplaceOperandWithDifferentShape(branch + 1, new_tuple)); CHECK(ShapeUtil::Compatible(conditional->operand(branch + 1)->shape(), conditional->branch_computation(branch) @@ -482,9 +483,9 @@ absl::StatusOr ConditionalSimplifier::TryRemoveConditional( if (conditional->branch_count() == 1) { HloInstruction* call_op = create_call(0); call_op->set_original_value(conditional->original_value()); - TF_RETURN_IF_ERROR(computation->ReplaceInstruction(conditional, call_op)); + RETURN_IF_ERROR(computation->ReplaceInstruction(conditional, call_op)); if (CallInliner::InlineInstructionAllowed(call_op)) { - TF_RETURN_IF_ERROR(CallInliner::Inline(call_op).status()); + RETURN_IF_ERROR(CallInliner::Inline(call_op).status()); } return true; } @@ -501,9 +502,9 @@ absl::StatusOr ConditionalSimplifier::TryRemoveConditional( } HloInstruction* call_op = create_call(branch_index); call_op->set_original_value(conditional->original_value()); - TF_RETURN_IF_ERROR(computation->ReplaceInstruction(conditional, call_op)); + RETURN_IF_ERROR(computation->ReplaceInstruction(conditional, call_op)); if (CallInliner::InlineInstructionAllowed(call_op)) { - TF_RETURN_IF_ERROR(CallInliner::Inline(call_op).status()); + RETURN_IF_ERROR(CallInliner::Inline(call_op).status()); } return true; @@ -589,14 +590,14 @@ absl::StatusOr ConditionalSimplifier::TryRemoveConditional( HloInstruction::CreateTuple(selects)); }; - TF_RETURN_IF_ERROR(computation->ReplaceInstruction( + RETURN_IF_ERROR(computation->ReplaceInstruction( conditional, select(true_call_op, false_call_op))); if (CallInliner::InlineInstructionAllowed(false_call_op)) { - TF_RETURN_IF_ERROR(CallInliner::Inline(false_call_op).status()); + RETURN_IF_ERROR(CallInliner::Inline(false_call_op).status()); } if (CallInliner::InlineInstructionAllowed(true_call_op)) { - TF_RETURN_IF_ERROR(CallInliner::Inline(true_call_op).status()); + RETURN_IF_ERROR(CallInliner::Inline(true_call_op).status()); } return true; } @@ -663,7 +664,7 @@ absl::StatusOr ConditionalSimplifier::RunImpl( changed |= MergeDuplicateTupleElements(conditional_op); changed |= RemoveUnusedTupleElements(conditional_op); changed |= ReplaceRootWithEmptyTupleIfNoUsers(conditional_op); - TF_ASSIGN_OR_RETURN(bool result, TryRemoveConditional(conditional_op)); + ASSIGN_OR_RETURN(bool result, TryRemoveConditional(conditional_op)); if (result) { removed_conditionals.insert(conditional_op); changed = true; @@ -693,8 +694,8 @@ absl::StatusOr ConditionalSimplifier::RunImpl( for (auto* comp : calling_computationals_vector) { auto entry = calling_conditionals.find(comp); CHECK(entry != calling_conditionals.end()); - TF_ASSIGN_OR_RETURN(bool result, TryRemoveUnusedConditionalOperands( - entry->first, entry->second)); + ASSIGN_OR_RETURN(bool result, TryRemoveUnusedConditionalOperands( + entry->first, entry->second)); changed |= result; } diff --git a/third_party/xla/xla/service/conditional_to_select.cc b/third_party/xla/xla/service/conditional_to_select.cc index e750fec2e2cf8f..c166d75509fd23 100644 --- a/third_party/xla/xla/service/conditional_to_select.cc +++ b/third_party/xla/xla/service/conditional_to_select.cc @@ -15,6 +15,7 @@ limitations under the License. #include "xla/service/conditional_to_select.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" @@ -95,12 +96,12 @@ static absl::StatusOr DoConditionalToSelect(HloInstruction* conditional) { ShapeUtil::ChangeElementType(condition->shape(), PrimitiveType::PRED), condition)); } - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( HloInstruction * select_op, MakeSelectHlo(condition, if_call_op, else_call_op, conditional)); - TF_RETURN_IF_ERROR(computation->ReplaceInstruction(conditional, select_op)); - TF_RETURN_IF_ERROR(CallInliner::Inline(if_call_op).status()); - TF_RETURN_IF_ERROR(CallInliner::Inline(else_call_op).status()); + RETURN_IF_ERROR(computation->ReplaceInstruction(conditional, select_op)); + RETURN_IF_ERROR(CallInliner::Inline(if_call_op).status()); + RETURN_IF_ERROR(CallInliner::Inline(else_call_op).status()); return true; } @@ -110,7 +111,7 @@ absl::StatusOr ConditionalToSelect::RunImpl( std::unique_ptr call_graph = CallGraph::Build(module); bool did_mutate = false; VLOG(1) << "Running conditional-to-select pass"; - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( call_graph->VisitNodes([&](const CallGraphNode& node) -> absl::Status { std::vector ToInline; if (node.context() != CallContext::kEmbedded) { @@ -120,8 +121,7 @@ absl::StatusOr ConditionalToSelect::RunImpl( if (callsite.instruction()->opcode() == HloOpcode::kConditional) { VLOG(1) << "Visiting conditional: " << callsite.ToString(); HloInstruction* conditional = callsite.instruction(); - TF_ASSIGN_OR_RETURN(bool result, - DoConditionalToSelect(conditional)); + ASSIGN_OR_RETURN(bool result, DoConditionalToSelect(conditional)); did_mutate |= result; } } diff --git a/third_party/xla/xla/service/copy_insertion.cc b/third_party/xla/xla/service/copy_insertion.cc index d1b6d280cd4723..8907f96ae02853 100644 --- a/third_party/xla/xla/service/copy_insertion.cc +++ b/third_party/xla/xla/service/copy_insertion.cc @@ -34,6 +34,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/frontend_attributes.h" #include "xla/hlo/analysis/hlo_alias_analysis.h" #include "xla/hlo/analysis/hlo_dataflow_analysis.h" @@ -164,12 +165,12 @@ DeepCopyAndAddControlEdges(HloInstruction* from, HloInstruction* to, // false) have nullptr at that index. ShapeTree from_copy_tree(from->shape(), /*init_value=*/nullptr); - TF_ASSIGN_OR_RETURN(HloInstruction * from_deep_copy, - from->parent()->DeepCopyInstruction( - from, &indices_to_copy, &from_copy_tree)); + ASSIGN_OR_RETURN(HloInstruction * from_deep_copy, + from->parent()->DeepCopyInstruction(from, &indices_to_copy, + &from_copy_tree)); ShapeTree to_copy_tree(to->shape(), /*init_value=*/nullptr); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( HloInstruction * to_deep_copy, to->parent()->DeepCopyInstruction(to, &indices_to_copy, &to_copy_tree)); @@ -183,7 +184,7 @@ DeepCopyAndAddControlEdges(HloInstruction* from, HloInstruction* to, continue; } TF_RET_CHECK(to_copy != nullptr); - TF_RETURN_IF_ERROR(from_copy->AddControlDependencyTo(to_copy)); + RETURN_IF_ERROR(from_copy->AddControlDependencyTo(to_copy)); } return std::make_pair(from_deep_copy, to_deep_copy); @@ -410,10 +411,10 @@ absl::Status AddCopiesForWhile(const HloAliasAnalysis& alias_analysis, // Deep copy init. HloInstruction* while_init = xla_while->mutable_operand(0); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( HloInstruction * while_init_copy, xla_while->parent()->DeepCopyInstruction(while_init, &indices_to_copy)); - TF_RETURN_IF_ERROR(while_init->ReplaceUseWith(xla_while, while_init_copy)); + RETURN_IF_ERROR(while_init->ReplaceUseWith(xla_while, while_init_copy)); // Deep copy the parameter and the root. Extend a control edge from the copy // of the parameter value to the corresponding copy value of the root. @@ -430,14 +431,14 @@ absl::Status AddCopiesForWhile(const HloAliasAnalysis& alias_analysis, // deep copy). std::vector param_users = param->users(); - TF_ASSIGN_OR_RETURN(auto pair, - DeepCopyAndAddControlEdges(param, root, indices_to_copy)); + ASSIGN_OR_RETURN(auto pair, + DeepCopyAndAddControlEdges(param, root, indices_to_copy)); HloInstruction* param_copy = pair.first; HloInstruction* root_copy = pair.second; for (HloInstruction* user : param_users) { - TF_RETURN_IF_ERROR(param->ReplaceUseWith(user, param_copy)); + RETURN_IF_ERROR(param->ReplaceUseWith(user, param_copy)); } body->set_root_instruction(root_copy); @@ -451,9 +452,9 @@ absl::Status AddCopiesForInPlaceOperation( int64_t operand_number) { VLOG(2) << "Adding copies for in-place operation " << in_place_op->name(); HloInstruction* operand = in_place_op->mutable_operand(operand_number); - TF_ASSIGN_OR_RETURN(HloInstruction * deep_copy, - in_place_op->parent()->DeepCopyInstruction(operand)); - TF_RETURN_IF_ERROR( + ASSIGN_OR_RETURN(HloInstruction * deep_copy, + in_place_op->parent()->DeepCopyInstruction(operand)); + RETURN_IF_ERROR( operand->ReplaceUseWith(in_place_op, operand_number, deep_copy)); return absl::OkStatus(); } @@ -504,15 +505,15 @@ absl::Status AddCopiesForAliasedInputOutputs( std::vector users = param->users(); ShapeTree param_copy_tree(param->shape(), /*init_value=*/nullptr); - TF_ASSIGN_OR_RETURN(HloInstruction * copied, - entry->DeepCopyInstruction( - param, ¶m_indices_to_copy, ¶m_copy_tree)); + ASSIGN_OR_RETURN(HloInstruction * copied, + entry->DeepCopyInstruction(param, ¶m_indices_to_copy, + ¶m_copy_tree)); if (param == root) { entry->set_root_instruction(copied); root = copied; } for (HloInstruction* user : users) { - TF_RETURN_IF_ERROR(param->ReplaceUseWith(user, copied)); + RETURN_IF_ERROR(param->ReplaceUseWith(user, copied)); } copied_parameters[param->parameter_number()] = param_copy_tree; @@ -526,12 +527,12 @@ absl::Status AddCopiesForAliasedInputOutputs( ShapeTree output_copy_tree(root->shape(), /*init_value=*/nullptr); - TF_ASSIGN_OR_RETURN(HloInstruction * root_copied, - root->parent()->DeepCopyInstruction( - root, &output_indices_to_copy, &output_copy_tree)); + ASSIGN_OR_RETURN(HloInstruction * root_copied, + root->parent()->DeepCopyInstruction( + root, &output_indices_to_copy, &output_copy_tree)); // Add control dependencies between the input/output copies. - TF_RETURN_IF_ERROR(module->input_output_alias_config().ForEachAliasWithStatus( + RETURN_IF_ERROR(module->input_output_alias_config().ForEachAliasWithStatus( [&](const ShapeIndex& output_index, const HloInputOutputAliasConfig::Alias& alias) -> absl::Status { if (!copied_parameters[alias.parameter_number]) { @@ -544,7 +545,7 @@ absl::Status AddCopiesForAliasedInputOutputs( TF_RET_CHECK(from != nullptr); TF_RET_CHECK(to != nullptr); - TF_RETURN_IF_ERROR(from->AddControlDependencyTo(to)); + RETURN_IF_ERROR(from->AddControlDependencyTo(to)); return absl::OkStatus(); })); @@ -556,12 +557,12 @@ absl::Status AddCopiesForAliasedInputOutputs( // Removes any control dependencies to or from the given instruction. absl::Status StripControlDependenciesFrom(HloInstruction* instruction) { while (!instruction->control_successors().empty()) { - TF_RETURN_IF_ERROR(instruction->RemoveControlDependencyTo( + RETURN_IF_ERROR(instruction->RemoveControlDependencyTo( instruction->control_successors().front())); } while (!instruction->control_predecessors().empty()) { - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( instruction->control_predecessors().front()->RemoveControlDependencyTo( instruction)); } @@ -589,11 +590,10 @@ absl::Status CopyInsertion::AddCopiesForConditional( for (HloComputation* computation : conditional->branch_computations()) { HloInstruction* root = computation->root_instruction(); std::vector users = root->users(); - TF_ASSIGN_OR_RETURN( - HloInstruction * deep_copy, - computation->DeepCopyInstruction(root, &indices_to_copy)); + ASSIGN_OR_RETURN(HloInstruction * deep_copy, + computation->DeepCopyInstruction(root, &indices_to_copy)); for (HloInstruction* user : users) { - TF_RETURN_IF_ERROR(root->ReplaceUseWith(user, deep_copy)); + RETURN_IF_ERROR(root->ReplaceUseWith(user, deep_copy)); } computation->set_root_instruction(deep_copy); } @@ -891,13 +891,13 @@ absl::Status AddCopiesForNonCopyableTransitionsRotatedCase( VLOG(2) << "Transition from copyable to non-copyable: copy " << operand->ToString() << " for " << start_op->ToString() << " output_index "; - TF_RETURN_IF_ERROR(operand->ReplaceUseWith(start_op, copied_operand)); - TF_RETURN_IF_ERROR(end_op->AddControlDependencyTo(copied_operand)); + RETURN_IF_ERROR(operand->ReplaceUseWith(start_op, copied_operand)); + RETURN_IF_ERROR(end_op->AddControlDependencyTo(copied_operand)); } // Add a control dependency from the rotated end_op of the chain to the // start_op of the chain guarantee disjoint live times of the buffer. - TF_RETURN_IF_ERROR(end_op->AddControlDependencyTo(start_op)); + RETURN_IF_ERROR(end_op->AddControlDependencyTo(start_op)); // Insert copies for the result produced by the end_op of the chain where we // transition from non-copyable to copyable. @@ -910,7 +910,7 @@ absl::Status AddCopiesForNonCopyableTransitionsRotatedCase( if (!end_op->shape().IsTuple()) { HloInstruction* copy = while_body->AddInstruction( HloInstruction::CreateUnary(end_op->shape(), HloOpcode::kCopy, end_op)); - TF_RETURN_IF_ERROR(copy->AddControlDependencyTo(start_op)); + RETURN_IF_ERROR(copy->AddControlDependencyTo(start_op)); return end_op->ReplaceAllUsesWith(copy); } @@ -924,8 +924,8 @@ absl::Status AddCopiesForNonCopyableTransitionsRotatedCase( << user->ToString() << " for all users"; HloInstruction* copy = while_body->AddInstruction( HloInstruction::CreateUnary(user->shape(), HloOpcode::kCopy, user)); - TF_RETURN_IF_ERROR(copy->AddControlDependencyTo(start_op)); - TF_RETURN_IF_ERROR(user->ReplaceAllUsesWith(copy)); + RETURN_IF_ERROR(copy->AddControlDependencyTo(start_op)); + RETURN_IF_ERROR(user->ReplaceAllUsesWith(copy)); } } @@ -959,7 +959,7 @@ absl::Status CopyInsertion::AddCopiesForExplicitNonCopyableTransitions( HloInstruction* copied_operand = parent->AddInstruction(HloInstruction::CreateUnary( operand->shape(), HloOpcode::kCopy, operand)); - TF_RETURN_IF_ERROR(operand->ReplaceUseWith(chain_start, copied_operand)); + RETURN_IF_ERROR(operand->ReplaceUseWith(chain_start, copied_operand)); } return absl::OkStatus(); @@ -1016,15 +1016,15 @@ absl::Status AddCopiesForNonCopyableTransitionsRotatedCase( HloInstruction* copied_operand = while_body->AddInstruction(HloInstruction::CreateUnary( operand->shape(), HloOpcode::kCopy, operand)); - TF_RETURN_IF_ERROR(operand->ReplaceUseWith(chain_start, copied_operand)); - TF_RETURN_IF_ERROR(chain_end->AddControlDependencyTo(copied_operand)); + RETURN_IF_ERROR(operand->ReplaceUseWith(chain_start, copied_operand)); + RETURN_IF_ERROR(chain_end->AddControlDependencyTo(copied_operand)); } // The chain_end is rotated and semantically paired with the chain_start of // the previous iteration. We add a control dependency from the chain_end to // the chain_start to in the same lexical iteration guarantee disjoint live // times of the buffers involved. - TF_RETURN_IF_ERROR(chain_end->AddControlDependencyTo(chain_start)); + RETURN_IF_ERROR(chain_end->AddControlDependencyTo(chain_start)); // If chain_end has users, insert copies for the result produced by the // chain_end with aliasing input and output buffers, where we transition from @@ -1034,17 +1034,16 @@ absl::Status AddCopiesForNonCopyableTransitionsRotatedCase( return absl::OkStatus(); } ShapeTree copies_added(chain_end->shape()); - TF_ASSIGN_OR_RETURN( - HloInstruction * copy, - while_body->DeepCopyInstruction(chain_end, /*indices_to_copy=*/nullptr, - &copies_added)); + ASSIGN_OR_RETURN(HloInstruction * copy, + while_body->DeepCopyInstruction( + chain_end, /*indices_to_copy=*/nullptr, &copies_added)); for (auto [shape_index, instr] : copies_added) { if (instr != nullptr) { - TF_RETURN_IF_ERROR(instr->AddControlDependencyTo(chain_start)); + RETURN_IF_ERROR(instr->AddControlDependencyTo(chain_start)); } } for (HloInstruction* it : users) { - TF_RETURN_IF_ERROR(chain_end->ReplaceUseWith(it, copy)); + RETURN_IF_ERROR(chain_end->ReplaceUseWith(it, copy)); } return absl::OkStatus(); } @@ -1097,8 +1096,8 @@ absl::Status CopyInsertion::AddCopiesForNonCopyableTransitions( if (!IsImplicitNonCopyable(chain_start)) { if (chain_start->IsCustomCall(kPinCustomCallTarget) || chain_start->IsCustomCall(kCreateBufferCustomCallTarget)) { - TF_RETURN_IF_ERROR(AddCopiesForExplicitNonCopyableTransitions( - alias_analysis, chain_start)); + RETURN_IF_ERROR(AddCopiesForExplicitNonCopyableTransitions(alias_analysis, + chain_start)); } return absl::OkStatus(); } @@ -1132,7 +1131,7 @@ absl::Status CopyInsertion::AddCopiesForNonCopyableTransitions( HloInstruction* copied_operand = parent->AddInstruction(HloInstruction::CreateUnary( operand->shape(), HloOpcode::kCopy, operand)); - TF_RETURN_IF_ERROR(operand->ReplaceUseWith(chain_start, copied_operand)); + RETURN_IF_ERROR(operand->ReplaceUseWith(chain_start, copied_operand)); return absl::OkStatus(); } @@ -1158,8 +1157,8 @@ absl::Status CopyInsertion::AddCopiesForNonCopyableTransitions( absl::Status CopyInsertion::AddCopiesToResolveInterference( HloModule* module, const absl::flat_hash_set& execution_threads) { - TF_ASSIGN_OR_RETURN(std::unique_ptr alias_analysis, - HloAliasAnalysis::Run(module, alias_info_)); + ASSIGN_OR_RETURN(std::unique_ptr alias_analysis, + HloAliasAnalysis::Run(module, alias_info_)); for (HloComputation* computation : module->MakeNonfusionComputations(execution_threads)) { if (computation->IsAsyncComputation()) { @@ -1168,14 +1167,13 @@ absl::Status CopyInsertion::AddCopiesToResolveInterference( for (HloInstruction* instruction : computation->MakeInstructionPostOrder()) { if (instruction->opcode() == HloOpcode::kWhile) { - TF_RETURN_IF_ERROR(AddCopiesForWhile(*alias_analysis, instruction)); + RETURN_IF_ERROR(AddCopiesForWhile(*alias_analysis, instruction)); } else if (instruction->opcode() == HloOpcode::kConditional) { - TF_RETURN_IF_ERROR( - AddCopiesForConditional(*alias_analysis, instruction)); + RETURN_IF_ERROR(AddCopiesForConditional(*alias_analysis, instruction)); } else if (IsNonCopyable(instruction)) { // We currently assume that we don't have a custom-call with // output-to-operand aliases for both buffers and non-buffers. - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( AddCopiesForNonCopyableTransitions(*alias_analysis, instruction)); } else { // When an operand is a tuple, we avoid copying the operand multiple @@ -1210,15 +1208,14 @@ absl::Status CopyInsertion::AddCopiesToResolveInterference( continue; } copied_operands.insert(operand_index.operand_number); - TF_RETURN_IF_ERROR(AddCopiesForInPlaceOperation( + RETURN_IF_ERROR(AddCopiesForInPlaceOperation( *alias_analysis, instruction, operand_index.operand_number)); } } } } - TF_RETURN_IF_ERROR( - AddCopiesForAliasedInputOutputs(module, execution_threads)); + RETURN_IF_ERROR(AddCopiesForAliasedInputOutputs(module, execution_threads)); return absl::OkStatus(); } @@ -1241,8 +1238,8 @@ absl::Status CopyInsertion::AddSpecialCaseCopies( std::function should_add_target_specific_copies, CustomBufferAnalysisFn custom_buffer_analysis) { - TF_ASSIGN_OR_RETURN(std::unique_ptr alias_analysis, - HloAliasAnalysis::Run(module, alias_info_)); + ASSIGN_OR_RETURN(std::unique_ptr alias_analysis, + HloAliasAnalysis::Run(module, alias_info_)); // Identify which shape indices of which instructions need to be copied. Store // these results in 'instructions_to_copy'. @@ -1400,11 +1397,11 @@ absl::Status CopyInsertion::AddSpecialCaseCopies( ShapeTree copies_added(indices_to_copy.shape()); std::vector users = instruction->users(); - TF_ASSIGN_OR_RETURN(HloInstruction * deep_copy, - instruction->parent()->DeepCopyInstruction( - instruction, &indices_to_copy, &copies_added)); + ASSIGN_OR_RETURN(HloInstruction * deep_copy, + instruction->parent()->DeepCopyInstruction( + instruction, &indices_to_copy, &copies_added)); for (HloInstruction* user : users) { - TF_RETURN_IF_ERROR(instruction->ReplaceUseWith(user, deep_copy)); + RETURN_IF_ERROR(instruction->ReplaceUseWith(user, deep_copy)); } if (instruction == instruction->parent()->root_instruction()) { instruction->parent()->set_root_instruction(deep_copy); @@ -1444,8 +1441,8 @@ absl::Status CopyInsertion::RemoveUnnecessaryCopies( ordering = std::make_unique(module); } - TF_ASSIGN_OR_RETURN(std::unique_ptr alias_analysis, - HloAliasAnalysis::Run(module, alias_info_)); + ASSIGN_OR_RETURN(std::unique_ptr alias_analysis, + HloAliasAnalysis::Run(module, alias_info_)); CopyRemover copy_remover(*module, *alias_analysis, alias_info_, ordering.get(), execution_threads); if (VLOG_IS_ON(3)) { @@ -1488,8 +1485,8 @@ absl::Status CopyInsertion::RemoveUnnecessaryCopies( insert_post_scheduling_control_dependencies, should_skip_removal_)) { changed = true; - TF_RETURN_IF_ERROR(StripControlDependenciesFrom(instruction)); - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR(StripControlDependenciesFrom(instruction)); + RETURN_IF_ERROR( instruction->ReplaceAllUsesWith(instruction->mutable_operand(0))); VLOG(3) << "Copy removed successfully: " << instruction->ToString(); XLA_VLOG_LINES( @@ -1543,7 +1540,7 @@ absl::StatusOr CopyInsertion::RunImpl( int64_t num_copies_before = GetNumExistingCopies(module, execution_threads); - TF_RETURN_IF_ERROR(AddCopiesToResolveInterference(module, execution_threads)); + RETURN_IF_ERROR(AddCopiesToResolveInterference(module, execution_threads)); // Simplify the tuple structures introduced by the deep copies. This should be // done before removing copies (RemoveUnnecessaryCopies) because tuple @@ -1552,22 +1549,22 @@ absl::StatusOr CopyInsertion::RunImpl( // instructions introduced by tuple simplification. TupleSimplifier tuple_simplifier; HloDCE dce; - TF_RETURN_IF_ERROR(tuple_simplifier.Run(module, execution_threads).status()); - TF_RETURN_IF_ERROR(dce.Run(module, execution_threads).status()); + RETURN_IF_ERROR(tuple_simplifier.Run(module, execution_threads).status()); + RETURN_IF_ERROR(dce.Run(module, execution_threads).status()); DumpHloModuleDuringPassIfEnabled( name(), "after adding copies to resolve interference", *module); - TF_RETURN_IF_ERROR(RemoveUnnecessaryCopies(module, execution_threads)); + RETURN_IF_ERROR(RemoveUnnecessaryCopies(module, execution_threads)); DumpHloModuleDuringPassIfEnabled(name(), "after removing unnecessary copies", *module); - TF_RETURN_IF_ERROR(AddSpecialCaseCopies(*call_graph, execution_threads, - module, nullptr, - /*custom_buffer_analysis=*/nullptr)); + RETURN_IF_ERROR(AddSpecialCaseCopies(*call_graph, execution_threads, module, + nullptr, + /*custom_buffer_analysis=*/nullptr)); DumpHloModuleDuringPassIfEnabled(name(), "after adding special-case copies", *module); - TF_RETURN_IF_ERROR(tuple_simplifier.Run(module, execution_threads).status()); - TF_RETURN_IF_ERROR(dce.Run(module, execution_threads).status()); + RETURN_IF_ERROR(tuple_simplifier.Run(module, execution_threads).status()); + RETURN_IF_ERROR(dce.Run(module, execution_threads).status()); VLOG(1) << "Num copies before copy-insertion: " << num_copies_before; VLOG(1) << "Num copies after copy-insertion: " diff --git a/third_party/xla/xla/service/cpu/cpu_compiler.cc b/third_party/xla/xla/service/cpu/cpu_compiler.cc index 7637077868d91a..6a7393ac2a8d3d 100644 --- a/third_party/xla/xla/service/cpu/cpu_compiler.cc +++ b/third_party/xla/xla/service/cpu/cpu_compiler.cc @@ -1266,6 +1266,8 @@ absl::Status CreateHloProfilingArtifacts( absl::StatusOr> CpuCompiler::RunHloPasses( std::unique_ptr module, se::StreamExecutor* stream_exec, const CompileOptions& options) { + XLA_SCOPED_LOGGING_TIMER( + absl::StrFormat("Running HLO passes on [%s] for CPU", module->name())); if (MultiModuleDriver::ShouldProcess(*module)) { VLOG(1) << "Triggering HLO module splitting for module: " << module->name(); { diff --git a/third_party/xla/xla/service/cpu_gpu_shape_verifier.cc b/third_party/xla/xla/service/cpu_gpu_shape_verifier.cc index 8d5d0ee74d77f3..207d22b9d797c1 100644 --- a/third_party/xla/xla/service/cpu_gpu_shape_verifier.cc +++ b/third_party/xla/xla/service/cpu_gpu_shape_verifier.cc @@ -17,6 +17,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/strings/str_format.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/layout_util.h" #include "xla/primitive_util.h" @@ -29,7 +30,7 @@ limitations under the License. namespace xla { absl::Status CpuGpuShapeVerifier::Preprocess(HloInstruction* hlo) { - TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( + RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( hlo->shape(), [&](const Shape& shape, const ShapeIndex&) { if (shape.has_layout()) { if (!primitive_util::IsSubByteNonPredType(shape.element_type()) && diff --git a/third_party/xla/xla/service/dot_as_convolution_util.cc b/third_party/xla/xla/service/dot_as_convolution_util.cc index 619b000df532d0..0808b760b8b772 100644 --- a/third_party/xla/xla/service/dot_as_convolution_util.cc +++ b/third_party/xla/xla/service/dot_as_convolution_util.cc @@ -17,6 +17,7 @@ limitations under the License. #include +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/service/shape_inference.h" @@ -168,14 +169,13 @@ CreateShardedConvForDotGeneralConvolution( wd->set_padding_high(wd->size() - 1); wd->set_padding_low(wd->size() - 1); } - TF_ASSIGN_OR_RETURN( - Shape sharded_conv_shape, - ShapeInference::InferConvolveShape( - sharded_lhs_hlo->shape(), sharded_rhs_hlo->shape(), - /*feature_group_count=*/conv.feature_group_count(), - /*batch_group_count=*/conv.batch_group_count(), window, conv_dnums, - conv.sparsity_config(), - /*preferred_element_type=*/conv.shape().element_type())); + ASSIGN_OR_RETURN(Shape sharded_conv_shape, + ShapeInference::InferConvolveShape( + sharded_lhs_hlo->shape(), sharded_rhs_hlo->shape(), + /*feature_group_count=*/conv.feature_group_count(), + /*batch_group_count=*/conv.batch_group_count(), window, + conv_dnums, conv.sparsity_config(), + /*preferred_element_type=*/conv.shape().element_type())); *sharded_conv_shape.mutable_layout() = conv.shape().layout(); CHECK(!conv.sparsity_config().has_lhs() && !conv.sparsity_config().has_rhs()); return HloInstruction::CreateConvolve( diff --git a/third_party/xla/xla/service/dump.cc b/third_party/xla/xla/service/dump.cc index 9dfae97c08aed0..7bb5ab301c834d 100644 --- a/third_party/xla/xla/service/dump.cc +++ b/third_party/xla/xla/service/dump.cc @@ -41,6 +41,7 @@ limitations under the License. #include "absl/strings/str_format.h" #include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" +#include "xla/tsl/platform/status_macros.h" #include "llvm/Support/raw_ostream.h" #include "mlir/IR/OperationSupport.h" #include "mlir/Transforms/LocationSnapshot.h" @@ -184,19 +185,19 @@ static absl::Status WriteStringToFile(tsl::Env* env, const std::string& fname, DataProducer& data_producer, bool compressed) { std::unique_ptr file; - TF_RETURN_IF_ERROR(env->NewWritableFile(fname, &file)); + RETURN_IF_ERROR(env->NewWritableFile(fname, &file)); if (compressed) { auto gz_opts = tsl::io::ZlibCompressionOptions::GZIP(); tsl::io::ZlibOutputBuffer gz_file(file.get(), gz_opts.input_buffer_size, gz_opts.output_buffer_size, gz_opts); - TF_RETURN_IF_ERROR(gz_file.Init()); + RETURN_IF_ERROR(gz_file.Init()); while (auto next_producer = data_producer.Next()) { - TF_RETURN_IF_ERROR(gz_file.Append(next_producer())); + RETURN_IF_ERROR(gz_file.Append(next_producer())); } return gz_file.Close(); } while (auto next_producer = data_producer.Next()) { - TF_RETURN_IF_ERROR(file->Append(next_producer())); + RETURN_IF_ERROR(file->Append(next_producer())); } return file->Close(); } @@ -207,12 +208,12 @@ static absl::Status WriteStringToFile(tsl::Env* env, const std::string& fname, return tsl::WriteStringToFile(env, fname, data); } std::unique_ptr file; - TF_RETURN_IF_ERROR(env->NewWritableFile(fname, &file)); + RETURN_IF_ERROR(env->NewWritableFile(fname, &file)); auto gz_opts = tsl::io::ZlibCompressionOptions::GZIP(); tsl::io::ZlibOutputBuffer gz_file(file.get(), gz_opts.input_buffer_size, gz_opts.output_buffer_size, gz_opts); - TF_RETURN_IF_ERROR(gz_file.Init()); - TF_RETURN_IF_ERROR(gz_file.Append(data)); + RETURN_IF_ERROR(gz_file.Init()); + RETURN_IF_ERROR(gz_file.Append(data)); return gz_file.Close(); } @@ -1121,8 +1122,8 @@ absl::Status DumpProtoToDirectory(const tsl::protobuf::Message& message, absl::string_view file_name, std::string* full_path) { tsl::Env* env = tsl::Env::Default(); - TF_RETURN_IF_ERROR(env->RecursivelyCreateDir(directory)); - TF_RETURN_IF_ERROR(CreateDirIfNeeded(directory, env)); + RETURN_IF_ERROR(env->RecursivelyCreateDir(directory)); + RETURN_IF_ERROR(CreateDirIfNeeded(directory, env)); std::string safe_file_name = SanitizeFileName(std::string(file_name)) + ".pb"; std::string full_path_impl; if (!full_path) { diff --git a/third_party/xla/xla/service/dynamic_dimension_inference.cc b/third_party/xla/xla/service/dynamic_dimension_inference.cc index 5d606be110b90e..4384cf9be5c908 100644 --- a/third_party/xla/xla/service/dynamic_dimension_inference.cc +++ b/third_party/xla/xla/service/dynamic_dimension_inference.cc @@ -38,6 +38,7 @@ limitations under the License. #include "absl/strings/str_join.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/comparison_util.h" #include "xla/hlo/analysis/hlo_dataflow_analysis.h" #include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" @@ -99,7 +100,7 @@ WidenComputation(HloComputation* narrow_comp, const Shape& wide_shape) { std::make_shared(OriginalValue::SyntheticCall())); wide_comp->set_root_instruction(call_narrow_comp, /*accept_different_shape=*/true); - TF_ASSIGN_OR_RETURN(auto inline_map, CallInliner::Inline(call_narrow_comp)); + ASSIGN_OR_RETURN(auto inline_map, CallInliner::Inline(call_narrow_comp)); return std::make_pair(wide_comp, std::move(inline_map)); } } // namespace @@ -139,7 +140,7 @@ class DynamicDimensionInferenceVisitor : public DfsHloRewriteVisitor { param_bindings, dataflow_analysis, parent, std::move(custom_call_handler), shape_check_mode, assertion_generator); - TF_RETURN_IF_ERROR(computation->Accept(&visitor)); + RETURN_IF_ERROR(computation->Accept(&visitor)); if (visitor.shape_assertion_ != nullptr) { CHECK(assertion_generator); assertion_generator(visitor.shape_assertion_); @@ -371,7 +372,7 @@ absl::Status DynamicDimensionInferenceVisitor::HandleTuple( if (!CanInfer(hlo)) { return absl::OkStatus(); } - TF_RETURN_IF_ERROR(ForEachOperandDynamicDimension( + RETURN_IF_ERROR(ForEachOperandDynamicDimension( hlo, [&](HloInstruction*, ShapeIndex index, int64_t dimension, int64_t operand_index, HloInstruction* dynamic_size) { index.push_front(operand_index); @@ -404,13 +405,13 @@ absl::Status DynamicDimensionInferenceVisitor::HandleConstant( ShapeTree do_pad(constant->shape(), false); Shape padded_shape = constant->shape(); bool pad_any = false; - TF_RETURN_IF_ERROR(ShapeUtil::ForEachMutableSubshapeWithStatus( + RETURN_IF_ERROR(ShapeUtil::ForEachMutableSubshapeWithStatus( &padded_shape, [&](Shape* subshape, const ShapeIndex& index) -> absl::Status { if (!subshape->IsArray()) { return absl::OkStatus(); } - TF_ASSIGN_OR_RETURN(bool requires_pad, RequiresPadToStatic(hlo, index)); + ASSIGN_OR_RETURN(bool requires_pad, RequiresPadToStatic(hlo, index)); if (requires_pad) { pad_any = *do_pad.mutable_element(index) = true; *subshape = ShapeUtil::MakeStaticShape(*subshape); @@ -421,14 +422,14 @@ absl::Status DynamicDimensionInferenceVisitor::HandleConstant( return absl::OkStatus(); } Literal padded_literal(padded_shape); - do_pad.ForEachElement([&](const ShapeIndex& index, bool requires_pad) { + do_pad.ForEachElement([&](const ShapeIndex& index, + bool requires_pad) -> absl::Status { const Shape& subshape = ShapeUtil::GetSubshape(padded_shape, index); if (!subshape.IsArray()) { return absl::OkStatus(); } - TF_RETURN_IF_ERROR(padded_literal.CopyFrom(constant->literal(), index, - index, - /*only_dynamic_bound=*/true)); + RETURN_IF_ERROR(padded_literal.CopyFrom(constant->literal(), index, index, + /*only_dynamic_bound=*/true)); if (!requires_pad) { for (int64_t dimension = 0; dimension < subshape.dimensions().size(); ++dimension) { @@ -443,9 +444,9 @@ absl::Status DynamicDimensionInferenceVisitor::HandleConstant( }); auto* padded_constant = hlo->AddInstruction( HloInstruction::CreateConstant(std::move(padded_literal))); - TF_RETURN_IF_ERROR(constant->ReplaceAllUsesWith(padded_constant)); + RETURN_IF_ERROR(constant->ReplaceAllUsesWith(padded_constant)); SetVisited(*padded_constant); - TF_RETURN_IF_ERROR(do_pad.ForEachElementWithStatus( + RETURN_IF_ERROR(do_pad.ForEachElementWithStatus( [&](const ShapeIndex& index, bool requires_pad) -> absl::Status { if (!requires_pad) { return absl::OkStatus(); @@ -497,7 +498,7 @@ absl::Status DynamicDimensionInferenceVisitor::HandleCustomCall( handled = custom_call_handler_(hlo, parent_); } if (!handled) { - TF_RETURN_IF_ERROR(ForEachOperandDynamicDimension( + RETURN_IF_ERROR(ForEachOperandDynamicDimension( hlo, [&](HloInstruction* operand, ShapeIndex index, int64_t dimension, int64_t operand_index, @@ -640,7 +641,7 @@ absl::Status DynamicDimensionInferenceVisitor::HandleReduce( } auto* reduce = Cast(hlo); int64_t rank = -1; - TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( + RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( reduce->shape(), [&](const Shape& subshape, const ShapeIndex& index) -> absl::Status { if (!subshape.IsArray()) { @@ -656,7 +657,7 @@ absl::Status DynamicDimensionInferenceVisitor::HandleReduce( TF_RET_CHECK(rank >= 0); absl::InlinedVector dynamic_sizes(rank, nullptr); - TF_RETURN_IF_ERROR(ForEachOperandDynamicDimension( + RETURN_IF_ERROR(ForEachOperandDynamicDimension( hlo, [&](HloInstruction* operand, ShapeIndex index, int64_t dimension, int64_t operand_index, HloInstruction* dynamic_size) { int64_t operand_count = reduce->operand_count(); @@ -706,7 +707,7 @@ absl::Status DynamicDimensionInferenceVisitor::HandleDot(HloInstruction* hlo) { } absl::InlinedVector dynamic_sizes( hlo->shape().dimensions().size(), nullptr); - TF_RETURN_IF_ERROR(ForEachOperandDynamicDimension( + RETURN_IF_ERROR(ForEachOperandDynamicDimension( hlo, [&](HloInstruction* operand, ShapeIndex operand_shape_index, int64_t operand_dimension, int64_t operand_index, @@ -898,7 +899,7 @@ absl::Status DynamicDimensionInferenceVisitor::HandleConcatenate( } // Simply pass through non-concat dynamic dimensions. - TF_RETURN_IF_ERROR(ForEachOperandDynamicDimension( + RETURN_IF_ERROR(ForEachOperandDynamicDimension( hlo, [&](HloInstruction* operand, ShapeIndex index, int64_t dimension, int64_t operand_index, HloInstruction* dynamic_size) -> absl::Status { @@ -944,7 +945,7 @@ absl::Status DynamicDimensionInferenceVisitor::HandleGetDimensionSize( } if (replacement != nullptr) { - TF_RETURN_IF_ERROR(gds->ReplaceAllUsesWith(replacement)); + RETURN_IF_ERROR(gds->ReplaceAllUsesWith(replacement)); // The dependency between an instruction and its dynamic dimensions is not // modeled in the IR. As instr is being replaced by dynamic_size, also tell // dynamic dimension inference that the instruction is being replaced. @@ -982,7 +983,7 @@ absl::Status DynamicDimensionInferenceVisitor::HandleSetDimensionSize( } // Also Propagate dynamic dimension already set by operands. - TF_RETURN_IF_ERROR(ForEachOperandDynamicDimension( + RETURN_IF_ERROR(ForEachOperandDynamicDimension( hlo, [&](HloInstruction* operand, ShapeIndex index, int64_t dimension, int64_t operand_index, HloInstruction* dynamic_size) -> absl::Status { @@ -1098,7 +1099,7 @@ absl::Status DynamicDimensionInferenceVisitor::PassThroughDynamicDimension( // the dynamic size. ShapeTree> dynamic_sizes( hlo->shape()); - TF_RETURN_IF_ERROR(ForEachOperandDynamicDimension( + RETURN_IF_ERROR(ForEachOperandDynamicDimension( hlo, [&](HloInstruction* operand, ShapeIndex index, int64_t dimension, int64_t operand_index, HloInstruction* dynamic_size) { const Shape& subshape = ShapeUtil::GetSubshape(hlo->shape(), index); @@ -1165,7 +1166,7 @@ absl::Status DynamicDimensionInferenceVisitor::HandleElementwiseNary( absl::InlinedVector, 2> operand_sizes( hlo->shape().dimensions().size(), absl::InlinedVector(hlo->operand_count(), nullptr)); - TF_RETURN_IF_ERROR(ForEachOperandDynamicDimension( + RETURN_IF_ERROR(ForEachOperandDynamicDimension( hlo, [&](HloInstruction* operand, ShapeIndex index, int64_t dimension, int64_t operand_index, HloInstruction* dynamic_size) -> absl::Status { @@ -1188,9 +1189,8 @@ absl::Status DynamicDimensionInferenceVisitor::HandleElementwiseNary( if (existing_size == nullptr) { existing_sizes[dimension] = dynamic_size; } else if (existing_sizes[dimension] != dynamic_size) { - TF_RETURN_IF_ERROR( - InsertShapeCheck(existing_size, dynamic_size, - /*support_implicit_broadcast=*/true)); + RETURN_IF_ERROR(InsertShapeCheck(existing_size, dynamic_size, + /*support_implicit_broadcast=*/true)); auto one = comp->AddInstruction( HloInstruction::CreateConstant(LiteralUtil::One(S32))); @@ -1393,7 +1393,7 @@ absl::Status DynamicDimensionInferenceVisitor::HandleReshape( bool need_flatten_unflatten = hlo->inferred_dimension() != -1 && hlo->shape().dimensions(hlo->inferred_dimension()) == 1; - TF_RETURN_IF_ERROR(ForEachOperandDynamicDimension( + RETURN_IF_ERROR(ForEachOperandDynamicDimension( hlo, [&](HloInstruction* operand, ShapeIndex index, int64_t input_dynamic_dimension, int64_t operand_index, @@ -1454,7 +1454,7 @@ absl::Status DynamicDimensionInferenceVisitor::HandleReshape( hlo->ToString()); } - TF_RETURN_IF_ERROR(ForEachOperandDynamicDimension( + RETURN_IF_ERROR(ForEachOperandDynamicDimension( hlo, [&](HloInstruction* operand, ShapeIndex index, int64_t input_dynamic_dimension, int64_t operand_index, @@ -1656,7 +1656,7 @@ absl::Status DynamicDimensionInferenceVisitor::HandleReduceWindow( } ShapeTree> dynamic_sizes( hlo->shape()); - TF_RETURN_IF_ERROR(ForEachOperandDynamicDimension( + RETURN_IF_ERROR(ForEachOperandDynamicDimension( hlo, [&](HloInstruction* operand, ShapeIndex index, int64_t dimension, int64_t operand_index, HloInstruction* dynamic_size) { auto* reduce_window = Cast(hlo); @@ -1805,7 +1805,7 @@ absl::Status DynamicDimensionInferenceVisitor::HandleDynamicUpdateSlice( } absl::InlinedVector output_dynamic_sizes( hlo->shape().dimensions().size(), nullptr); - TF_RETURN_IF_ERROR(ForEachOperandDynamicDimension( + RETURN_IF_ERROR(ForEachOperandDynamicDimension( hlo, [&](HloInstruction* operand, ShapeIndex index, int64_t dimension, int64_t operand_index, HloInstruction* dynamic_size) -> absl::Status { @@ -1852,7 +1852,7 @@ absl::Status DynamicDimensionInferenceVisitor::HandleGather( } absl::InlinedVector output_dynamic_sizes( hlo->shape().dimensions().size(), nullptr); - TF_RETURN_IF_ERROR(ForEachOperandDynamicDimension( + RETURN_IF_ERROR(ForEachOperandDynamicDimension( hlo, [&](HloInstruction* operand, ShapeIndex /*index*/, int64_t input_dynamic_dimension, int64_t operand_index, @@ -1955,7 +1955,7 @@ absl::Status DynamicDimensionInferenceVisitor::HandleConditional( operand_shape.IsTuple() ? operand_shape.tuple_shapes().size() : 0; // Prepare to pass dynamic dimension into the new computation and add // dynamic dimension sizes as parameters to the new tuple. - TF_RETURN_IF_ERROR(ForEachDynamicDimensionInOperand( + RETURN_IF_ERROR(ForEachDynamicDimensionInOperand( hlo, operand_index, [&](HloInstruction*, ShapeIndex, int64_t, int64_t, HloInstruction* dynamic_size) -> absl::Status { @@ -1995,24 +1995,24 @@ absl::Status DynamicDimensionInferenceVisitor::HandleConditional( for (HloInstruction* operand : operands_to_add) { ShapeUtil::AppendShapeToTuple(operand->shape(), &new_param_shape); } - TF_ASSIGN_OR_RETURN( - std::tie(new_computation, inline_map), - WidenComputation(branch_computation, new_param_shape)); + ASSIGN_OR_RETURN(std::tie(new_computation, inline_map), + WidenComputation(branch_computation, new_param_shape)); } // Set the dynamic dimensions for the newly created branch computation's // parameters so that the hlos inside the computation can see dynamic // dimensions. DynamicParameterBinding dynamic_parameter_binding; - TF_RETURN_IF_ERROR(ForEachDynamicDimensionInOperand( + RETURN_IF_ERROR(ForEachDynamicDimensionInOperand( hlo, operand_index, [&](HloInstruction*, ShapeIndex index, int64_t dimension, - int64_t operand_index, HloInstruction* dynamic_size) { + int64_t operand_index, + HloInstruction* dynamic_size) -> absl::Status { DynamicParameterBinding::DynamicSizeParameter dynamic_parameter{ 0, {dynamic_size_to_operand_id_index_map[dynamic_size]}}; DynamicParameterBinding::DynamicDimension dynamic_dimension{ 0, {index}, dimension}; - TF_RETURN_IF_ERROR(dynamic_parameter_binding.Bind(dynamic_parameter, - dynamic_dimension)); + RETURN_IF_ERROR(dynamic_parameter_binding.Bind(dynamic_parameter, + dynamic_dimension)); return absl::OkStatus(); })); @@ -2026,12 +2026,11 @@ absl::Status DynamicDimensionInferenceVisitor::HandleConditional( /*dynamic_size_map=*/&inline_map); } - TF_ASSIGN_OR_RETURN( - bool changed, - DynamicDimensionInferenceVisitor::Run( - new_computation, dataflow_analysis_, dynamic_parameter_binding, - parent_, custom_call_handler_, shape_check_mode_, - assertion_generator_)); + ASSIGN_OR_RETURN(bool changed, DynamicDimensionInferenceVisitor::Run( + new_computation, dataflow_analysis_, + dynamic_parameter_binding, parent_, + custom_call_handler_, shape_check_mode_, + assertion_generator_)); if (changed) { MarkAsChanged(); } @@ -2138,9 +2137,9 @@ absl::Status DynamicDimensionInferenceVisitor::HandleConditional( } }); - TF_RETURN_IF_ERROR(hlo->ReplaceAllUsesWith(new_conditional_extracted)); + RETURN_IF_ERROR(hlo->ReplaceAllUsesWith(new_conditional_extracted)); // Remove the original instruction even if has side-effects. - TF_RETURN_IF_ERROR(hlo->parent()->RemoveInstruction(hlo)); + RETURN_IF_ERROR(hlo->parent()->RemoveInstruction(hlo)); SetVisited(*new_conditional); SetVisited(*new_conditional_extracted); MarkAsChanged(); @@ -2245,7 +2244,7 @@ absl::Status DynamicDimensionInferenceVisitor::HandleWhile( int operand_count = original_tuple_count; // Clean up the result shape DynamicParameterBinding binding_for_while; - TF_RETURN_IF_ERROR(ForEachOperandDynamicDimension( + RETURN_IF_ERROR(ForEachOperandDynamicDimension( hlo, [&](HloInstruction* operand, ShapeIndex index, int64_t dim, int64_t operand_num, HloInstruction* dynamic_size) -> absl::Status { @@ -2262,7 +2261,7 @@ absl::Status DynamicDimensionInferenceVisitor::HandleWhile( /*parameter_num=*/0, /*parameter_index=*/{operand_count}, }; - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( binding_for_while.Bind(dynamic_size_param, dynamic_dimension)); ++operand_count; return absl::OkStatus(); @@ -2277,8 +2276,8 @@ absl::Status DynamicDimensionInferenceVisitor::HandleWhile( // hlo->while_body()->parameter_instruction(0); // HloInstruction* old_condition_parameter = // hlo->while_condition()->parameter_instruction(0); - TF_ASSIGN_OR_RETURN(WhileUtil::MakeInstructionsLiveInResult result, - WhileUtil::MakeInstructionsLiveIn(hlo, operands_to_add)); + ASSIGN_OR_RETURN(WhileUtil::MakeInstructionsLiveInResult result, + WhileUtil::MakeInstructionsLiveIn(hlo, operands_to_add)); TF_RET_CHECK(result.replacement_instr->opcode() == HloOpcode::kTuple); // WhileUtil creates a new while hlo and tuple. Update the dynamic size // mapping for the newly created tuple. @@ -2312,16 +2311,16 @@ absl::Status DynamicDimensionInferenceVisitor::HandleWhile( // Rerun inference on the body and condition now that we have added dynamic // size parameters. - TF_RETURN_IF_ERROR(DynamicDimensionInferenceVisitor::Run( - hlo->while_body(), dataflow_analysis_, - binding_for_while, parent_, custom_call_handler_, - shape_check_mode_, assertion_generator_) - .status()); - TF_RETURN_IF_ERROR(DynamicDimensionInferenceVisitor::Run( - hlo->while_condition(), dataflow_analysis_, - binding_for_while, parent_, custom_call_handler_, - shape_check_mode_, assertion_generator_) - .status()); + RETURN_IF_ERROR(DynamicDimensionInferenceVisitor::Run( + hlo->while_body(), dataflow_analysis_, binding_for_while, + parent_, custom_call_handler_, shape_check_mode_, + assertion_generator_) + .status()); + RETURN_IF_ERROR(DynamicDimensionInferenceVisitor::Run( + hlo->while_condition(), dataflow_analysis_, + binding_for_while, parent_, custom_call_handler_, + shape_check_mode_, assertion_generator_) + .status()); // The dynamic dimension size could have been changed in the loop body (e.g, A // loop that inserts items in a stack, the stack size increases with each @@ -2337,7 +2336,7 @@ absl::Status DynamicDimensionInferenceVisitor::HandleWhile( body_root->shape().tuple_shapes(i), body_root, i)); } // Add dynamic dimension size as new outputs of the while loop body. - TF_RETURN_IF_ERROR(dynamic_output_mapping.ForEachElementWithStatus( + RETURN_IF_ERROR(dynamic_output_mapping.ForEachElementWithStatus( [&](const ShapeIndex& index, const absl::flat_hash_map& dim_to_size) -> absl::Status { @@ -2356,7 +2355,7 @@ absl::Status DynamicDimensionInferenceVisitor::HandleWhile( HloInstruction* new_body_root = hlo->while_body()->AddInstruction( HloInstruction::CreateTuple(new_root_operands)); for (int i = 0; i < original_tuple_count; ++i) { - TF_RETURN_IF_ERROR(ForEachDynamicDimension( + RETURN_IF_ERROR(ForEachDynamicDimension( body_root, [&](ShapeIndex index, int64_t dimension, HloInstruction* dynamic_size) -> absl::Status { @@ -2435,7 +2434,7 @@ absl::Status DynamicDimensionInferenceVisitor::ForEachDynamicDimension( HloInstruction* dynamic_size = parent_->GetDynamicSize( dynamic_dimension.inst, dynamic_dimension.index, dynamic_dimension.dim); - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( fn(dynamic_dimension.index, dynamic_dimension.dim, dynamic_size)); } } @@ -2510,13 +2509,12 @@ absl::Status DynamicDimensionInferenceVisitor::InsertPadToStaticOnInstruction( // Decide while leaf arrays need to be padded. ShapeTree needs_pad(inst->shape(), false); bool any_needs_pad = false; - TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( + RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( inst->shape(), [&](const Shape& subshape, const ShapeIndex& shape_index) { if (subshape.IsTuple()) { return absl::OkStatus(); } - TF_ASSIGN_OR_RETURN(bool do_pad, - RequiresPadToStatic(inst, shape_index)); + ASSIGN_OR_RETURN(bool do_pad, RequiresPadToStatic(inst, shape_index)); if (do_pad) { *needs_pad.mutable_element(shape_index) = true; any_needs_pad = true; @@ -2535,7 +2533,7 @@ absl::Status DynamicDimensionInferenceVisitor::InsertPadToStaticOnInstruction( // Add PadToStatic to the leaf arrays and record the dynamic dimensions. ShapeTree padded(inst->shape(), nullptr); - TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapePostOrderWithStatus( + RETURN_IF_ERROR(ShapeUtil::ForEachSubshapePostOrderWithStatus( inst->shape(), [&](const Shape& subshape, const ShapeIndex& shape_index) -> absl::Status { @@ -2610,7 +2608,7 @@ absl::Status DynamicDimensionInferenceVisitor::InsertPadToStaticOnInstruction( // Replace all uses of the original instruction with the padded outputs. for (auto user : users) { for (int64_t i : user->OperandIndices(inst)) { - TF_RETURN_IF_ERROR(user->ReplaceOperandWith(i, result)); + RETURN_IF_ERROR(user->ReplaceOperandWith(i, result)); } } if (inst->IsRoot()) { @@ -2634,13 +2632,12 @@ absl::Status DynamicDimensionInferenceVisitor::InsertShapeCheck( "%s vs %s", dim1->ToString(), dim2->ToString()); case DynamicDimensionInference::kRuntime: { - TF_ASSIGN_OR_RETURN( - HloInstruction * assertion, - MakeCompareHlo(Comparison::Direction::kEq, dim1, dim2)); + ASSIGN_OR_RETURN(HloInstruction * assertion, + MakeCompareHlo(Comparison::Direction::kEq, dim1, dim2)); if (shape_assertion_ == nullptr) { shape_assertion_ = assertion; } else { - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( shape_assertion_, MakeBinaryHlo(HloOpcode::kAnd, shape_assertion_, assertion)); } @@ -2660,9 +2657,8 @@ absl::Status DynamicDimensionInferenceVisitor::ForEachDynamicDimensionInOperand( HloInstruction* dynamic_size = parent_->GetDynamicSize( dynamic_dimension.inst, dynamic_dimension.index, dynamic_dimension.dim); - TF_RETURN_IF_ERROR(fn(dynamic_dimension.inst, dynamic_dimension.index, - dynamic_dimension.dim, operand_index, - dynamic_size)); + RETURN_IF_ERROR(fn(dynamic_dimension.inst, dynamic_dimension.index, + dynamic_dimension.dim, operand_index, dynamic_size)); } } return absl::OkStatus(); @@ -2672,8 +2668,7 @@ absl::Status DynamicDimensionInferenceVisitor::ForEachOperandDynamicDimension( HloInstruction* inst, OperandDynamicDimensionFn fn) { for (int64_t operand_index = 0; operand_index < inst->operand_count(); ++operand_index) { - TF_RETURN_IF_ERROR( - ForEachDynamicDimensionInOperand(inst, operand_index, fn)); + RETURN_IF_ERROR(ForEachDynamicDimensionInOperand(inst, operand_index, fn)); } return absl::OkStatus(); } @@ -2732,7 +2727,7 @@ absl::StatusOr DynamicDimensionInference::Run( module, std::move(op_supports_dynamism_handler), std::move(custom_call_handler), shape_check_mode, assertion_generator, execution_threads); - TF_RETURN_IF_ERROR(inference.AnalyzeDynamicDimensions()); + RETURN_IF_ERROR(inference.AnalyzeDynamicDimensions()); return std::move(inference); } @@ -2763,16 +2758,16 @@ DynamicDimensionInference::DynamicDimensionInference( execution_threads_(execution_threads) {} absl::Status DynamicDimensionInference::AnalyzeDynamicDimensions() { - TF_ASSIGN_OR_RETURN(std::unique_ptr dataflow_analysis, - HloDataflowAnalysis::Run(*module_, /*ssa_form=*/false, - /*bitcast_defines_value=*/true, - execution_threads_)); + ASSIGN_OR_RETURN(std::unique_ptr dataflow_analysis, + HloDataflowAnalysis::Run(*module_, /*ssa_form=*/false, + /*bitcast_defines_value=*/true, + execution_threads_)); for (HloComputation* computation : module_->MakeComputationPostOrder()) { if (!HloInstruction::IsThreadIncluded(computation->execution_thread(), execution_threads_)) { continue; } - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( bool changed, DynamicDimensionInferenceVisitor::Run( computation, *dataflow_analysis, {}, this, custom_call_handler_, diff --git a/third_party/xla/xla/service/dynamic_dimension_inference_test.cc b/third_party/xla/xla/service/dynamic_dimension_inference_test.cc index ca9c4a19b4e8ed..b78a407b73e591 100644 --- a/third_party/xla/xla/service/dynamic_dimension_inference_test.cc +++ b/third_party/xla/xla/service/dynamic_dimension_inference_test.cc @@ -23,6 +23,7 @@ limitations under the License. #include "absl/log/check.h" #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/comparison_util.h" #include "xla/hlo/builder/xla_builder.h" #include "xla/hlo/ir/hlo_casting_utils.h" @@ -57,10 +58,10 @@ class DynamicDimensionInferenceTest : public HloHardwareIndependentTestBase { DynamicDimensionInference::ShapeCheckMode::kIgnore, const DynamicDimensionInference::AssertionGenerator& assertion_generator = nullptr) { - TF_ASSIGN_OR_RETURN(DynamicDimensionInference inference, - DynamicDimensionInference::Run( - module_.get(), op_supports_dynamism_handler, - handler, shape_check_mode, assertion_generator)); + ASSIGN_OR_RETURN(DynamicDimensionInference inference, + DynamicDimensionInference::Run( + module_.get(), op_supports_dynamism_handler, handler, + shape_check_mode, assertion_generator)); inference_ = std::make_unique(inference); return absl::OkStatus(); diff --git a/third_party/xla/xla/service/dynamic_padder.cc b/third_party/xla/xla/service/dynamic_padder.cc index b0cf5cd49a2862..45a935748a2764 100644 --- a/third_party/xla/xla/service/dynamic_padder.cc +++ b/third_party/xla/xla/service/dynamic_padder.cc @@ -32,6 +32,7 @@ limitations under the License. #include "absl/strings/str_format.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/comparison_util.h" #include "xla/hlo/builder/xla_builder.h" #include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" @@ -183,9 +184,9 @@ absl::StatusOr ReplaceGetSize( } HloComputation* computation = instr->parent(); - TF_ASSIGN_OR_RETURN(auto legal_shape, - ShapeInference::InferGetDimensionSizeShape( - instr->operand(0)->shape(), instr->dimension())); + ASSIGN_OR_RETURN(auto legal_shape, + ShapeInference::InferGetDimensionSizeShape( + instr->operand(0)->shape(), instr->dimension())); TF_RET_CHECK(ShapeUtil::Equal(instr->shape(), legal_shape)) << "instr->shape() " << instr->shape().ToString() << " , " << "legal_shape " << legal_shape.ToString(); @@ -195,7 +196,7 @@ absl::StatusOr ReplaceGetSize( HloInstruction* dynamic_size = dynamic_dimension_inference->GetDynamicSize(operand, {}, dim); if (dynamic_size != nullptr) { - TF_RETURN_IF_ERROR(instr->ReplaceAllUsesWith(dynamic_size)); + RETURN_IF_ERROR(instr->ReplaceAllUsesWith(dynamic_size)); // The dependency between a instruction and its dynamic dimensions is not // modeled in the IR. As instr is being replaced by dynamic_size, also tell // dynamic dimension inference that the instruction is being replaced. @@ -205,7 +206,7 @@ absl::StatusOr ReplaceGetSize( int32_t size = instr->operand(0)->shape().dimensions(dim); HloInstruction* new_instr = computation->AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(size))); - TF_RETURN_IF_ERROR(instr->ReplaceAllUsesWith(new_instr)); + RETURN_IF_ERROR(instr->ReplaceAllUsesWith(new_instr)); dynamic_dimension_inference->ReplaceAllDynamicDimensionUsesWith(instr, new_instr); } @@ -223,7 +224,7 @@ absl::StatusOr ReplaceSetSize(HloInstruction* instr) { << "instruction operand shape " << instr->operand(0)->shape(); HloInstruction* operand = instr->mutable_operand(0); - TF_RETURN_IF_ERROR(instr->ReplaceAllUsesWith(operand)); + RETURN_IF_ERROR(instr->ReplaceAllUsesWith(operand)); return true; } @@ -239,7 +240,7 @@ absl::StatusOr ReplaceSetBound(HloInstruction* instr) { << "instruction operand shape " << instr->operand(0)->shape(); HloInstruction* operand = instr->mutable_operand(0); - TF_RETURN_IF_ERROR(instr->ReplaceAllUsesWith(operand)); + RETURN_IF_ERROR(instr->ReplaceAllUsesWith(operand)); return true; } @@ -587,7 +588,7 @@ absl::StatusOr RewriteDynamicReshapeSplitInput( // Step 4: Feed gather input to original reshape. - TF_RETURN_IF_ERROR(reshape->ReplaceOperandWith(0, gather)); + RETURN_IF_ERROR(reshape->ReplaceOperandWith(0, gather)); HloInstruction* reshape_dynamic = reshape; @@ -606,9 +607,9 @@ absl::StatusOr RewriteDynamicReshapeSplitInput( } for (auto* user : users) { - TF_RETURN_IF_ERROR(reshape->ReplaceUseWith(user, reshape_dynamic)); + RETURN_IF_ERROR(reshape->ReplaceUseWith(user, reshape_dynamic)); } - TF_RETURN_IF_ERROR(dynamic_dimension_inference->ForwardDynamicSize( + RETURN_IF_ERROR(dynamic_dimension_inference->ForwardDynamicSize( reshape, reshape_dynamic, {})); return true; @@ -778,7 +779,7 @@ absl::StatusOr RewriteDynamicReshapeCombineInput( for (auto* user : users) { // Avoid cycles by not replacing the static reshape and get_dimension_size. if (user != reshape_static && user != output_dynamic_size) { - TF_RETURN_IF_ERROR(reshape->ReplaceUseWith(user, gather)); + RETURN_IF_ERROR(reshape->ReplaceUseWith(user, gather)); } } @@ -786,7 +787,7 @@ absl::StatusOr RewriteDynamicReshapeCombineInput( reshape->parent()->set_root_instruction(gather); } - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( dynamic_dimension_inference->ForwardDynamicSize(reshape, gather, {})); return true; @@ -914,9 +915,9 @@ absl::StatusOr RewriteReverse( HloInstruction* dynamic_reverse = reverse->AddInstruction(HloInstruction::CreateDynamicSlice( reverse_shape, pad, start_indices, reverse_shape.dimensions())); - TF_RETURN_IF_ERROR(dynamic_dimension_inference->ForwardDynamicSize( + RETURN_IF_ERROR(dynamic_dimension_inference->ForwardDynamicSize( reverse, dynamic_reverse, {})); - TF_RETURN_IF_ERROR(reverse->ReplaceAllUsesWith(dynamic_reverse)); + RETURN_IF_ERROR(reverse->ReplaceAllUsesWith(dynamic_reverse)); return true; } @@ -1047,8 +1048,8 @@ absl::StatusOr RewriteDynamicConvolutionInputGrad( custom_call_conv->batch_group_count(), window, custom_call_conv->convolution_dimension_numbers(), custom_call_conv->precision_config())); - TF_RETURN_IF_ERROR(custom_call_conv->ReplaceAllUsesWith(static_conv)); - TF_RETURN_IF_ERROR(dynamic_dimension_inference->ForwardDynamicSize( + RETURN_IF_ERROR(custom_call_conv->ReplaceAllUsesWith(static_conv)); + RETURN_IF_ERROR(dynamic_dimension_inference->ForwardDynamicSize( custom_call_conv, static_conv, {})); return true; } @@ -1106,8 +1107,8 @@ absl::StatusOr RewriteDynamicConvolutionForward( custom_call_conv->batch_group_count(), window, custom_call_conv->convolution_dimension_numbers(), custom_call_conv->precision_config())); - TF_RETURN_IF_ERROR(custom_call_conv->ReplaceAllUsesWith(static_conv)); - TF_RETURN_IF_ERROR(dynamic_dimension_inference->ForwardDynamicSize( + RETURN_IF_ERROR(custom_call_conv->ReplaceAllUsesWith(static_conv)); + RETURN_IF_ERROR(dynamic_dimension_inference->ForwardDynamicSize( custom_call_conv, static_conv, {})); return true; } @@ -1190,8 +1191,8 @@ absl::StatusOr RewriteDynamicConvolutionKernelGrad( custom_call_conv->batch_group_count(), window, custom_call_conv->convolution_dimension_numbers(), custom_call_conv->precision_config())); - TF_RETURN_IF_ERROR(custom_call_conv->ReplaceAllUsesWith(static_conv)); - TF_RETURN_IF_ERROR(dynamic_dimension_inference->ForwardDynamicSize( + RETURN_IF_ERROR(custom_call_conv->ReplaceAllUsesWith(static_conv)); + RETURN_IF_ERROR(dynamic_dimension_inference->ForwardDynamicSize( custom_call_conv, static_conv, {})); return true; } @@ -1235,8 +1236,8 @@ absl::StatusOr RewriteDynamicReduceWindowSamePadding( HloInstruction* rewritten = hlo->AddInstruction(HloInstruction::CreateReduceWindow( hlo->shape(), input, init, window, hlo->called_computations()[0])); - TF_RETURN_IF_ERROR(hlo->ReplaceAllUsesWith(rewritten)); - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR(hlo->ReplaceAllUsesWith(rewritten)); + RETURN_IF_ERROR( dynamic_dimension_inference->ForwardDynamicSize(hlo, rewritten, {})); return true; } @@ -1247,8 +1248,8 @@ absl::StatusOr RewriteDynamicSelectAndScatterSamePadding( HloInstruction* input = hlo->mutable_operand(0); HloInstruction* source = hlo->mutable_operand(1); HloInstruction* init = hlo->mutable_operand(2); - TF_ASSIGN_OR_RETURN(HloInstruction * input_padding_value, - ChooseIdentityValue(hlo, /*operand_number=*/0)); + ASSIGN_OR_RETURN(HloInstruction * input_padding_value, + ChooseIdentityValue(hlo, /*operand_number=*/0)); int64_t rank = hlo->shape().dimensions().size(); Window window = hlo->window(); std::vector padding_before(hlo->shape().dimensions().size(), @@ -1313,8 +1314,8 @@ absl::StatusOr RewriteDynamicSelectAndScatterSamePadding( HloInstruction* padded = MakePadHlo(rewritten, init, padding_configs).value(); rewritten = hlo->AddInstruction(HloInstruction::CreateDynamicSlice( hlo->shape(), padded, start_indices, hlo->shape().dimensions())); - TF_RETURN_IF_ERROR(hlo->ReplaceAllUsesWith(rewritten)); - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR(hlo->ReplaceAllUsesWith(rewritten)); + RETURN_IF_ERROR( dynamic_dimension_inference->ForwardDynamicSize(hlo, rewritten, {})); return true; } @@ -1361,8 +1362,8 @@ absl::StatusOr RewriteDynamicConcat( dynamic_size)); } } - TF_RETURN_IF_ERROR(concat->ReplaceUsesWith(prev_users, rewritten_concat)); - TF_RETURN_IF_ERROR(dynamic_dimension_inference->ForwardDynamicSize( + RETURN_IF_ERROR(concat->ReplaceUsesWith(prev_users, rewritten_concat)); + RETURN_IF_ERROR(dynamic_dimension_inference->ForwardDynamicSize( concat, rewritten_concat, {})); return true; } @@ -1444,9 +1445,9 @@ absl::StatusOr RewriteDynamicSort( auto rewritten_sort = hlo->AddInstruction( HloInstruction::CreateGetTupleElement(sort->shape(), sort_clone, 0)); for (HloInstruction* user : sort_users) { - TF_RETURN_IF_ERROR(sort->ReplaceUseWith(user, rewritten_sort)); + RETURN_IF_ERROR(sort->ReplaceUseWith(user, rewritten_sort)); } - TF_RETURN_IF_ERROR(dynamic_dimension_inference->ForwardDynamicSize( + RETURN_IF_ERROR(dynamic_dimension_inference->ForwardDynamicSize( sort, rewritten_sort, {})); if (hlo->parent()->root_instruction() == sort) { hlo->parent()->set_root_instruction(rewritten_sort); @@ -1562,8 +1563,8 @@ absl::StatusOr RewriteDynamicBinaryOp( } } if (changed) { - TF_RETURN_IF_ERROR(binary->ReplaceOperandWith(0, operand_0)); - TF_RETURN_IF_ERROR(binary->ReplaceOperandWith(1, operand_1)); + RETURN_IF_ERROR(binary->ReplaceOperandWith(0, operand_0)); + RETURN_IF_ERROR(binary->ReplaceOperandWith(1, operand_1)); } return changed; } @@ -1665,7 +1666,7 @@ absl::StatusOr RewriteDynamicUpdateSlice( update->shape(), HloOpcode::kSelect, pred, update, base_slice)); } } - TF_RETURN_IF_ERROR(dus->ReplaceOperandWith(1, update)); + RETURN_IF_ERROR(dus->ReplaceOperandWith(1, update)); return true; } @@ -1773,19 +1774,19 @@ absl::StatusOr RewriteDynamicReshape( HloInstruction* unflatten = reshape->parent()->AddInstruction( HloInstruction::CreateReshape(unflattened_shape, flatten), absl::StrCat(reshape->name(), ".unflatten")); - TF_RETURN_IF_ERROR(dynamic_dimension_inference->ForwardDynamicSize( + RETURN_IF_ERROR(dynamic_dimension_inference->ForwardDynamicSize( reshape, unflatten, {})); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( bool changed_unused, RewriteDynamicReshape(flatten, dynamic_dimension_inference)); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( changed_unused, RewriteDynamicReshape(unflatten, dynamic_dimension_inference)); - TF_RETURN_IF_ERROR(dynamic_dimension_inference->ForwardDynamicSize( + RETURN_IF_ERROR(dynamic_dimension_inference->ForwardDynamicSize( reshape, unflatten, {})); - TF_RETURN_IF_ERROR(reshape->ReplaceAllUsesWith(unflatten)); + RETURN_IF_ERROR(reshape->ReplaceAllUsesWith(unflatten)); return true; } @@ -1816,11 +1817,11 @@ absl::StatusOr RewriteDynamicReshape( reshape->ToString()); } - TF_ASSIGN_OR_RETURN(bool c, RewriteDynamicReshapeSingleGroup( - reshape, input_dims, output_dims, - absl::MakeSpan(input_dynamic_dims), - absl::MakeSpan(output_dynamic_dims), - dynamic_dimension_inference)); + ASSIGN_OR_RETURN(bool c, RewriteDynamicReshapeSingleGroup( + reshape, input_dims, output_dims, + absl::MakeSpan(input_dynamic_dims), + absl::MakeSpan(output_dynamic_dims), + dynamic_dimension_inference)); changed |= c; } @@ -1828,8 +1829,8 @@ absl::StatusOr RewriteDynamicReshape( auto* static_reshape = reshape->AddInstruction(HloInstruction::CreateReshape( reshape->shape(), reshape->mutable_operand(0))); - TF_RETURN_IF_ERROR(reshape->ReplaceAllUsesWith(static_reshape)); - TF_RETURN_IF_ERROR(dynamic_dimension_inference->ForwardDynamicSize( + RETURN_IF_ERROR(reshape->ReplaceAllUsesWith(static_reshape)); + RETURN_IF_ERROR(dynamic_dimension_inference->ForwardDynamicSize( reshape, static_reshape, {})); changed = true; } @@ -1890,14 +1891,14 @@ class DynamicShapeRemovingVisitor : public DfsHloRewriteVisitor { DynamicShapeRemovingVisitor visitor(op_supports_dynamism_handler, dynamic_shape_inference, execution_threads); - TF_RETURN_IF_ERROR(computation->Accept(&visitor)); + RETURN_IF_ERROR(computation->Accept(&visitor)); // If the outputs is required to be dynamic form, insert static to dynamic // conversion as root. if (require_dynamic_output) { HloInstruction* root = computation->root_instruction(); if (dynamic_shape_inference->HasDynamicDimension(root)) { - TF_ASSIGN_OR_RETURN(HloInstruction * new_root, - visitor.ConvertToDynamic(root)); + ASSIGN_OR_RETURN(HloInstruction * new_root, + visitor.ConvertToDynamic(root)); computation->set_root_instruction(new_root); } } @@ -1963,9 +1964,9 @@ absl::Status DynamicShapeRemovingVisitor::ConvertOperandsToDynamic( for (int64_t i = 0; i < inst->operand_count(); ++i) { auto operand = inst->mutable_operand(i); if (dynamic_dimension_inference_->HasDynamicDimension(operand)) { - TF_ASSIGN_OR_RETURN(auto dynamic_operand, - ConvertToDynamic(inst->mutable_operand(i))); - TF_RETURN_IF_ERROR(inst->ReplaceOperandWith(i, dynamic_operand)); + ASSIGN_OR_RETURN(auto dynamic_operand, + ConvertToDynamic(inst->mutable_operand(i))); + RETURN_IF_ERROR(inst->ReplaceOperandWith(i, dynamic_operand)); MarkAsChanged(); } } @@ -2108,27 +2109,26 @@ absl::StatusOr DynamicPadder::RunImpl( // TODO(b/419842730): Support dynamic padder for graphs with complex CFGs. FlattenCallGraph flatten_call_graph; - TF_ASSIGN_OR_RETURN(bool changed, - flatten_call_graph.Run(module, execution_threads)); + ASSIGN_OR_RETURN(bool changed, + flatten_call_graph.Run(module, execution_threads)); CallInliner call_inliner( /*single_call_site=*/false, /*update_domain=*/false); - TF_ASSIGN_OR_RETURN(bool inliner_changed, - call_inliner.Run(module, execution_threads)); + ASSIGN_OR_RETURN(bool inliner_changed, + call_inliner.Run(module, execution_threads)); changed |= inliner_changed; // Run DCE before inference, in case earlier passes left dead instructions // that could cause us to insert PadToStatic when it isn't desired. HloDCE dce; - TF_ASSIGN_OR_RETURN(bool dce_changed, dce.Run(module, execution_threads)); + ASSIGN_OR_RETURN(bool dce_changed, dce.Run(module, execution_threads)); changed |= dce_changed; - TF_ASSIGN_OR_RETURN( - DynamicDimensionInference dynamic_dimension_inference, - DynamicDimensionInference::Run( - module, options_.op_supports_dynamism_handler, - options_.custom_call_handler, options_.shape_check_mode, - options_.assertion_generator, execution_threads)); + ASSIGN_OR_RETURN(DynamicDimensionInference dynamic_dimension_inference, + DynamicDimensionInference::Run( + module, options_.op_supports_dynamism_handler, + options_.custom_call_handler, options_.shape_check_mode, + options_.assertion_generator, execution_threads)); changed |= dynamic_dimension_inference.changed(); std::vector computations = @@ -2145,26 +2145,26 @@ absl::StatusOr DynamicPadder::RunImpl( continue; } if (inst->opcode() == HloOpcode::kConcatenate) { - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( bool c, RewriteDynamicConcat(inst, &dynamic_dimension_inference)); changed |= c; continue; } if (inst->opcode() == HloOpcode::kReverse) { - TF_ASSIGN_OR_RETURN(bool c, - RewriteReverse(inst, &dynamic_dimension_inference)); + ASSIGN_OR_RETURN(bool c, + RewriteReverse(inst, &dynamic_dimension_inference)); changed |= c; continue; } if (inst->opcode() == HloOpcode::kSort) { - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( bool c, RewriteDynamicSort(inst, &dynamic_dimension_inference)); changed |= c; continue; } if (inst->opcode() == HloOpcode::kReshape || inst->opcode() == HloOpcode::kDynamicReshape) { - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( bool c, RewriteDynamicReshape(inst, &dynamic_dimension_inference)); changed |= c; continue; @@ -2173,50 +2173,50 @@ absl::StatusOr DynamicPadder::RunImpl( // Elementwise binary with dynamic shapes have implicit broadcast // semantics. if (inst->IsElementwiseBinary()) { - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( bool c, RewriteDynamicBinaryOp(inst, &dynamic_dimension_inference)); changed |= c; continue; } if (inst->opcode() == HloOpcode::kDynamicUpdateSlice) { - TF_ASSIGN_OR_RETURN(bool c, RewriteDynamicUpdateSlice( - inst, &dynamic_dimension_inference)); + ASSIGN_OR_RETURN(bool c, RewriteDynamicUpdateSlice( + inst, &dynamic_dimension_inference)); changed |= c; continue; } if (inst->IsCustomCall("DynamicConvolutionInputGrad")) { - TF_ASSIGN_OR_RETURN(bool c, RewriteDynamicConvolutionInputGrad( - inst, &dynamic_dimension_inference)); + ASSIGN_OR_RETURN(bool c, RewriteDynamicConvolutionInputGrad( + inst, &dynamic_dimension_inference)); changed |= c; continue; } if (inst->IsCustomCall("DynamicConvolutionForward")) { - TF_ASSIGN_OR_RETURN(bool c, RewriteDynamicConvolutionForward( - inst, &dynamic_dimension_inference)); + ASSIGN_OR_RETURN(bool c, RewriteDynamicConvolutionForward( + inst, &dynamic_dimension_inference)); changed |= c; continue; } if (inst->IsCustomCall("DynamicConvolutionKernelGrad")) { - TF_ASSIGN_OR_RETURN(bool c, RewriteDynamicConvolutionKernelGrad( - inst, &dynamic_dimension_inference)); + ASSIGN_OR_RETURN(bool c, RewriteDynamicConvolutionKernelGrad( + inst, &dynamic_dimension_inference)); changed |= c; continue; } if (inst->IsCustomCall("DynamicReduceWindowSamePadding")) { - TF_ASSIGN_OR_RETURN(bool c, RewriteDynamicReduceWindowSamePadding( - inst, &dynamic_dimension_inference)); + ASSIGN_OR_RETURN(bool c, RewriteDynamicReduceWindowSamePadding( + inst, &dynamic_dimension_inference)); changed |= c; continue; } if (inst->IsCustomCall("DynamicSelectAndScatterSamePadding")) { - TF_ASSIGN_OR_RETURN(bool c, RewriteDynamicSelectAndScatterSamePadding( - inst, &dynamic_dimension_inference)); + ASSIGN_OR_RETURN(bool c, RewriteDynamicSelectAndScatterSamePadding( + inst, &dynamic_dimension_inference)); changed |= c; continue; } @@ -2245,15 +2245,15 @@ absl::StatusOr DynamicPadder::RunImpl( continue; } - TF_ASSIGN_OR_RETURN(HloInstruction * identity_value, - ChooseIdentityValue(inst, operand_num)); + ASSIGN_OR_RETURN(HloInstruction * identity_value, + ChooseIdentityValue(inst, operand_num)); if (identity_value == nullptr) { continue; } HloInstruction* padded = PadWithScalar( operand, input_dim, operand_dynamic_size, identity_value); - TF_RETURN_IF_ERROR(inst->ReplaceOperandWith(operand_num, padded)); + RETURN_IF_ERROR(inst->ReplaceOperandWith(operand_num, padded)); operand = inst->mutable_operand(operand_num); changed = true; } @@ -2281,11 +2281,11 @@ absl::StatusOr DynamicPadder::RunImpl( // the output tensor to be in dynamic form. bool require_dynamic_output = options_.slice_dynamic_output && computation == module->entry_computation(); - TF_ASSIGN_OR_RETURN(bool c, - DynamicShapeRemovingVisitor::Run( - computation, options_.op_supports_dynamism_handler, - &dynamic_dimension_inference, execution_threads, - /*require_dynamic_output=*/require_dynamic_output)); + ASSIGN_OR_RETURN(bool c, + DynamicShapeRemovingVisitor::Run( + computation, options_.op_supports_dynamism_handler, + &dynamic_dimension_inference, execution_threads, + /*require_dynamic_output=*/require_dynamic_output)); changed |= c; } @@ -2299,7 +2299,7 @@ absl::StatusOr DynamicPadder::RunImpl( continue; } for (auto instruction : computation->MakeInstructionPostOrder()) { - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( bool c, ReplaceGetSize(instruction, &dynamic_dimension_inference)); changed |= c; } @@ -2310,17 +2310,17 @@ absl::StatusOr DynamicPadder::RunImpl( continue; } for (auto instruction : computation->MakeInstructionPostOrder()) { - TF_ASSIGN_OR_RETURN(bool c, ReplaceSetSize(instruction)); + ASSIGN_OR_RETURN(bool c, ReplaceSetSize(instruction)); changed |= c; - TF_ASSIGN_OR_RETURN(c, ReplaceSetBound(instruction)); + ASSIGN_OR_RETURN(c, ReplaceSetBound(instruction)); changed |= c; } } if (changed) { HloDCE dce; - TF_ASSIGN_OR_RETURN(bool c, dce.Run(module, execution_threads)); + ASSIGN_OR_RETURN(bool c, dce.Run(module, execution_threads)); changed |= c; } diff --git a/third_party/xla/xla/service/dynamic_padder_test.cc b/third_party/xla/xla/service/dynamic_padder_test.cc index f60720ff9683e7..0f3d29af675aa6 100644 --- a/third_party/xla/xla/service/dynamic_padder_test.cc +++ b/third_party/xla/xla/service/dynamic_padder_test.cc @@ -28,6 +28,7 @@ limitations under the License. #include "absl/strings/str_replace.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/error_spec.h" #include "xla/hlo/builder/xla_builder.h" #include "xla/hlo/ir/hlo_computation.h" @@ -118,14 +119,14 @@ class DynamicPadderTest : public HloPjRtTestBase { std::move(op_supports_dynamism_handler); options.custom_call_handler = std::move(custom_call_handler); DynamicPadder padder(std::move(options)); - TF_ASSIGN_OR_RETURN(bool changed, RunHloPass(&padder, module_.get())); + ASSIGN_OR_RETURN(bool changed, RunHloPass(&padder, module_.get())); if (!changed) return false; // Dynamic padder can add redundant tuple/get-tuple-element and copy // instructions. TupleSimplifier tuple_simplifier; - TF_RETURN_IF_ERROR(RunHloPass(&tuple_simplifier, module_.get()).status()); + RETURN_IF_ERROR(RunHloPass(&tuple_simplifier, module_.get()).status()); AlgebraicSimplifier alg_simplifier(AlgebraicSimplifierOptions{}); - TF_RETURN_IF_ERROR(RunHloPass(&alg_simplifier, module_.get()).status()); + RETURN_IF_ERROR(RunHloPass(&alg_simplifier, module_.get()).status()); return true; } diff --git a/third_party/xla/xla/service/elemental_ir_emitter.cc b/third_party/xla/xla/service/elemental_ir_emitter.cc index 4b26288f06f37d..6653239243fddc 100644 --- a/third_party/xla/xla/service/elemental_ir_emitter.cc +++ b/third_party/xla/xla/service/elemental_ir_emitter.cc @@ -37,6 +37,7 @@ limitations under the License. #include "absl/strings/str_format.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "llvm/ADT/APInt.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Constant.h" @@ -320,7 +321,7 @@ absl::StatusOr EmitF16ToF4e2m1fn(llvm::Value* f16_value, // Truncate the mantissa to 1 bit and the exponent to 3 bits (not 2 bits, as // the type doesn't have Inf/NaN and can represent unbiased exponent 2). // This case, as well as the denormal, is handled below. - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( llvm::Value * reduced_precision, EmitReducePrecisionIR( /*src_ty=*/F16, f16_value, @@ -863,9 +864,9 @@ absl::StatusOr ElementalIrEmitter::EmitFloatUnaryOp( if (to_type == F8E8M0FNU || to_type == F4E2M1FN) { cast_type = F32; } - TF_ASSIGN_OR_RETURN(operand_value, - EmitF8fnuzToFloating(from_type, operand_value, - cast_type, b_, module_)); + ASSIGN_OR_RETURN(operand_value, + EmitF8fnuzToFloating(from_type, operand_value, + cast_type, b_, module_)); from_type = cast_type; if (from_type == to_type) { return operand_value; @@ -1152,10 +1153,10 @@ absl::StatusOr ElementalIrEmitter::EmitComplexUnaryOp( Select(FCmpOGT(a1, abs_b), a, FSub(max_abs_of_a1_and_b, one)); auto min_max_ratio = FDiv(min_abs_of_a1_and_b, max_abs_of_a1_and_b); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( auto log_of_max_abs_of_a1_and_b, EmitLog1p(component_type, max_abs_of_a1_and_b_minus_one)); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( auto log_of_sqrt_part, EmitLog1p(component_type, FMul(min_max_ratio, min_max_ratio))); @@ -1163,7 +1164,7 @@ absl::StatusOr ElementalIrEmitter::EmitComplexUnaryOp( auto real_part = Select(FCmpUNO(r, r), min_abs_of_a1_and_b, r); // handles nan and inf values correctly - TF_ASSIGN_OR_RETURN(auto imag_part, EmitAtan2(component_type, b, a1, "")); + ASSIGN_OR_RETURN(auto imag_part, EmitAtan2(component_type, b, a1, "")); return EmitComposeComplex(op, real_part, imag_part, module_, b_); } case HloOpcode::kConvert: { @@ -1195,11 +1196,11 @@ absl::StatusOr ElementalIrEmitter::EmitComplexUnaryOp( auto half = llvm::ConstantFP::get(type, 0.5); auto pos_inf = llvm::ConstantFP::getInfinity(type); - TF_ASSIGN_OR_RETURN(auto exp_a, EmitExp(component_type, a, "")); + ASSIGN_OR_RETURN(auto exp_a, EmitExp(component_type, a, "")); auto a_half = FMul(a, half); - TF_ASSIGN_OR_RETURN(auto exp_a_half, EmitExp(component_type, a_half, "")); - TF_ASSIGN_OR_RETURN(auto cos_b, EmitCos(component_type, b)); - TF_ASSIGN_OR_RETURN(auto sin_b, EmitSin(component_type, b)); + ASSIGN_OR_RETURN(auto exp_a_half, EmitExp(component_type, a_half, "")); + ASSIGN_OR_RETURN(auto cos_b, EmitCos(component_type, b)); + ASSIGN_OR_RETURN(auto sin_b, EmitSin(component_type, b)); auto exp_a_is_inf = FCmpOEQ(exp_a, pos_inf); auto b_is_zero = FCmpOEQ(b, zero); @@ -1228,10 +1229,10 @@ absl::StatusOr ElementalIrEmitter::EmitComplexUnaryOp( auto zero = llvm::ConstantFP::get(b->getType(), 0.0); auto one = llvm::ConstantFP::get(b->getType(), 1.0); auto b_is_zero = FCmpOEQ(b, zero); - TF_ASSIGN_OR_RETURN(auto expm1_a, EmitExpm1(component_type, a)); + ASSIGN_OR_RETURN(auto expm1_a, EmitExpm1(component_type, a)); auto exp_a = FAdd(expm1_a, one); - TF_ASSIGN_OR_RETURN(auto sin_b, EmitSin(component_type, b)); - TF_ASSIGN_OR_RETURN(auto cos_b_minus_one, EmitCosm1(component_type, b)); + ASSIGN_OR_RETURN(auto sin_b, EmitSin(component_type, b)); + ASSIGN_OR_RETURN(auto cos_b_minus_one, EmitCosm1(component_type, b)); auto cos_b = FAdd(cos_b_minus_one, one); auto real_result = FAdd(FMul(expm1_a, cos_b), cos_b_minus_one); auto imag_result = Select(b_is_zero, zero, FMul(exp_a, sin_b)); @@ -1248,11 +1249,11 @@ absl::StatusOr ElementalIrEmitter::EmitComplexUnaryOp( auto x = EmitExtractReal(operand_value); auto y = EmitExtractImag(operand_value); auto type = y->getType(); - TF_ASSIGN_OR_RETURN(auto exp_y, EmitExp(component_type, y, "")); + ASSIGN_OR_RETURN(auto exp_y, EmitExp(component_type, y, "")); auto half_exp_y = FMul(llvm::ConstantFP::get(type, 0.5), exp_y); auto half_exp_neg_y = FDiv(llvm::ConstantFP::get(type, 0.5), exp_y); - TF_ASSIGN_OR_RETURN(auto sin_x, EmitSin(component_type, x)); - TF_ASSIGN_OR_RETURN(auto cos_x, EmitCos(component_type, x)); + ASSIGN_OR_RETURN(auto sin_x, EmitSin(component_type, x)); + ASSIGN_OR_RETURN(auto cos_x, EmitCos(component_type, x)); auto sinh_y = FSub(half_exp_y, half_exp_neg_y); auto cosh_y = FAdd(half_exp_y, half_exp_neg_y); llvm::Value* real_result = nullptr; @@ -1319,23 +1320,23 @@ absl::StatusOr ElementalIrEmitter::EmitComplexUnaryOp( // ULP to be arbitrarily small. For larger values of `a`, calculating the // numerator as Exp(2a)-Exp(-2a) vs Expm1(2a)-Expm1(-2a) return virtually // identical results. - TF_ASSIGN_OR_RETURN(llvm::Value * exp_2a_m1, - EmitExpm1(component_type, two_a)); - TF_ASSIGN_OR_RETURN(llvm::Value * exp_neg_2a_m1, - EmitExpm1(component_type, neg_2a)); + ASSIGN_OR_RETURN(llvm::Value * exp_2a_m1, + EmitExpm1(component_type, two_a)); + ASSIGN_OR_RETURN(llvm::Value * exp_neg_2a_m1, + EmitExpm1(component_type, neg_2a)); llvm::Value* real_numerator = FSub(exp_2a_m1, exp_neg_2a_m1); // We can use the identity cos(2b)+1 = cos(b)^2-sin(b)^2+cos(b)^2+sin(b)^2 // = 2cos(b)^2. This gives us the ability to be more precise when the // denominator is close to zero. - TF_ASSIGN_OR_RETURN(llvm::Value * cos_b, EmitCos(component_type, b)); + ASSIGN_OR_RETURN(llvm::Value * cos_b, EmitCos(component_type, b)); llvm::Value* four = llvm::ConstantFP::get(type, 4.0); llvm::Value* cos_b_sq = FMul(cos_b, cos_b); llvm::Value* two_cos_2b_p2 = FMul(cos_b_sq, four); // Similarly we can compute sin(2b) with the formula sin(2b) = // 2*sin(b)*cos(b). - TF_ASSIGN_OR_RETURN(llvm::Value * sin_b, EmitSin(component_type, b)); + ASSIGN_OR_RETURN(llvm::Value * sin_b, EmitSin(component_type, b)); llvm::Value* imag_numerator = FMul(four, FMul(cos_b, sin_b)); // About "x^2 is a better approximation than Expm1(x) + Expm1(x) @@ -1421,8 +1422,8 @@ absl::StatusOr ElementalIrEmitter::EmitComplexUnaryOp( return EmitComplexAbs(component_type, operand_value); } case HloOpcode::kSign: { // Sign(c) = c / |c| - TF_ASSIGN_OR_RETURN(auto cplx_abs, - EmitComplexAbs(component_type, operand_value)); + ASSIGN_OR_RETURN(auto cplx_abs, + EmitComplexAbs(component_type, operand_value)); auto type = cplx_abs->getType(); auto zero = llvm::ConstantFP::get(type, 0.0); auto oeq = FCmpOEQ(cplx_abs, zero); @@ -1508,10 +1509,10 @@ absl::StatusOr ElementalIrEmitter::EmitFloatBinaryOp( lhs_value = EmitF8e8m0fnuToF32(lhs_value, b_); rhs_value = EmitF8e8m0fnuToF32(rhs_value, b_); } else if (operand_type == F8E5M2FNUZ || operand_type == F8E4M3FNUZ) { - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( lhs_value, EmitF8fnuzToFloating(operand_type, lhs_value, F16, b_, module_)); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( rhs_value, EmitF8fnuzToFloating(operand_type, rhs_value, F16, b_, module_)); } else if (operand_type == F8E3M4) { @@ -1588,7 +1589,7 @@ ElementalIrEmitter::EmitComplexAbsHelper(PrimitiveType prim_type, llvm::Value* div_sq = FMul(div, div); llvm::Value* one = llvm::ConstantFP::get(max->getType(), 1); llvm::Value* one_p_div_sq = FAdd(one, div_sq); - TF_ASSIGN_OR_RETURN(llvm::Value * sqrt, EmitSqrt(prim_type, one_p_div_sq)); + ASSIGN_OR_RETURN(llvm::Value * sqrt, EmitSqrt(prim_type, one_p_div_sq)); return std::make_tuple(min, max, return_sqrt ? sqrt : one_p_div_sq); } @@ -1599,7 +1600,7 @@ absl::StatusOr ElementalIrEmitter::EmitComplexAbs( llvm::Value* sqrt; llvm::Value* real = EmitExtractReal(operand_value); llvm::Value* imag = EmitExtractImag(operand_value); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( std::tie(min, max, sqrt), EmitComplexAbsHelper(prim_type, real, imag, /*return_sqrt=*/true)); llvm::Value* result = FMul(max, sqrt); @@ -1617,13 +1618,13 @@ absl::StatusOr ElementalIrEmitter::EmitSqrtComplexAbs( llvm::Value* one_p_div_sq; llvm::Value* real = EmitExtractReal(operand_value); llvm::Value* imag = EmitExtractImag(operand_value); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( std::tie(min, max, one_p_div_sq), EmitComplexAbsHelper(prim_type, real, imag, /*return_sqrt=*/false)); - TF_ASSIGN_OR_RETURN(llvm::Value * sqrt_max, EmitSqrt(prim_type, max)); - TF_ASSIGN_OR_RETURN(llvm::Value * pow, - EmitPow(prim_type, one_p_div_sq, - llvm::ConstantFP::get(max->getType(), .25), "")); + ASSIGN_OR_RETURN(llvm::Value * sqrt_max, EmitSqrt(prim_type, max)); + ASSIGN_OR_RETURN(llvm::Value * pow, + EmitPow(prim_type, one_p_div_sq, + llvm::ConstantFP::get(max->getType(), .25), "")); llvm::Value* result = FMul(sqrt_max, pow); // When (min, max) are (0, 0), (inf, inf), or (NaN, ...), `result` is NaN. // In such cases, we return `min` instead of `result`. @@ -1639,13 +1640,13 @@ absl::StatusOr ElementalIrEmitter::EmitRsqrtComplexAbs( llvm::Value* sqrt; llvm::Value* real = EmitExtractReal(operand_value); llvm::Value* imag = EmitExtractImag(operand_value); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( std::tie(min, max, sqrt), EmitComplexAbsHelper(prim_type, real, imag, /*return_sqrt=*/true)); - TF_ASSIGN_OR_RETURN(llvm::Value * rsqrt_max, EmitRsqrt(prim_type, max)); - TF_ASSIGN_OR_RETURN(llvm::Value * rsqrt_sqrt, EmitRsqrt(prim_type, sqrt)); + ASSIGN_OR_RETURN(llvm::Value * rsqrt_max, EmitRsqrt(prim_type, max)); + ASSIGN_OR_RETURN(llvm::Value * rsqrt_sqrt, EmitRsqrt(prim_type, sqrt)); llvm::Value* result = FMul(rsqrt_max, rsqrt_sqrt); - TF_ASSIGN_OR_RETURN(llvm::Value * rsqrt_min, EmitRsqrt(prim_type, min)); + ASSIGN_OR_RETURN(llvm::Value * rsqrt_min, EmitRsqrt(prim_type, min)); // When (min, max) are (0, 0), (inf, inf), or (NaN, ...), `result` is NaN. // In such cases, we return rsqrt(min) instead of `result`. return Select(FCmpUNO(result, result), rsqrt_min, result); @@ -1805,10 +1806,10 @@ absl::StatusOr ElementalIrEmitter::EmitComplexLog( primitive_util::ComplexComponentType(op->shape().element_type()); auto a = EmitExtractReal(operand_value); auto b = EmitExtractImag(operand_value); - TF_ASSIGN_OR_RETURN(llvm::Value * angle, EmitAtan2(component_type, b, a, "")); - TF_ASSIGN_OR_RETURN(llvm::Value * abs, - EmitComplexAbs(component_type, operand_value)); - TF_ASSIGN_OR_RETURN(llvm::Value * log_abs, EmitLog(component_type, abs)); + ASSIGN_OR_RETURN(llvm::Value * angle, EmitAtan2(component_type, b, a, "")); + ASSIGN_OR_RETURN(llvm::Value * abs, + EmitComplexAbs(component_type, operand_value)); + ASSIGN_OR_RETURN(llvm::Value * log_abs, EmitLog(component_type, abs)); return EmitComposeComplex(op, log_abs, angle, module_, b_); } @@ -1825,17 +1826,17 @@ absl::StatusOr ElementalIrEmitter::EmitComplexSqrt( llvm::Type* type = static_cast(operand_value->getType()) ->getElementType(0); - TF_ASSIGN_OR_RETURN(llvm::Value * r, - EmitSqrtComplexAbs(prim_type, operand_value)); + ASSIGN_OR_RETURN(llvm::Value * r, + EmitSqrtComplexAbs(prim_type, operand_value)); llvm::Value* a = EmitExtractReal(operand_value); llvm::Value* b = EmitExtractImag(operand_value); - TF_ASSIGN_OR_RETURN(llvm::Value * t, EmitAtan2(prim_type, b, a, "")); + ASSIGN_OR_RETURN(llvm::Value * t, EmitAtan2(prim_type, b, a, "")); llvm::Value* c = llvm::ConstantFP::get(type, 0.5); llvm::Value* angle = FMul(t, c); - TF_ASSIGN_OR_RETURN(llvm::Value * cos, EmitCos(prim_type, angle)); - TF_ASSIGN_OR_RETURN(llvm::Value * sin, EmitSin(prim_type, angle)); + ASSIGN_OR_RETURN(llvm::Value * cos, EmitCos(prim_type, angle)); + ASSIGN_OR_RETURN(llvm::Value * sin, EmitSin(prim_type, angle)); llvm::Value* real_part; llvm::Value* imag_part; @@ -1882,17 +1883,17 @@ absl::StatusOr ElementalIrEmitter::EmitComplexRsqrt( llvm::Type* type = static_cast(operand_value->getType()) ->getElementType(0); - TF_ASSIGN_OR_RETURN(llvm::Value * r, - EmitRsqrtComplexAbs(prim_type, operand_value)); + ASSIGN_OR_RETURN(llvm::Value * r, + EmitRsqrtComplexAbs(prim_type, operand_value)); llvm::Value* a = EmitExtractReal(operand_value); llvm::Value* b = EmitExtractImag(operand_value); - TF_ASSIGN_OR_RETURN(llvm::Value * t, EmitAtan2(prim_type, b, a, "")); + ASSIGN_OR_RETURN(llvm::Value * t, EmitAtan2(prim_type, b, a, "")); llvm::Value* c = llvm::ConstantFP::get(type, -0.5); llvm::Value* angle = FMul(t, c); - TF_ASSIGN_OR_RETURN(llvm::Value * cos, EmitCos(prim_type, angle)); - TF_ASSIGN_OR_RETURN(llvm::Value * sin, EmitSin(prim_type, angle)); + ASSIGN_OR_RETURN(llvm::Value * cos, EmitCos(prim_type, angle)); + ASSIGN_OR_RETURN(llvm::Value * sin, EmitSin(prim_type, angle)); llvm::Value* real_part = FMul(r, cos); llvm::Value* imag_part = FMul(r, sin); @@ -1950,19 +1951,19 @@ absl::StatusOr ElementalIrEmitter::EmitComplexPower( auto c = EmitExtractReal(rhs_value); auto d = EmitExtractImag(rhs_value); - TF_ASSIGN_OR_RETURN(auto abs, EmitComplexAbs(component_type, lhs_value)); - TF_ASSIGN_OR_RETURN(auto abs_to_c, EmitPow(component_type, abs, c, "")); + ASSIGN_OR_RETURN(auto abs, EmitComplexAbs(component_type, lhs_value)); + ASSIGN_OR_RETURN(auto abs_to_c, EmitPow(component_type, abs, c, "")); auto neg_d = FNeg(d); - TF_ASSIGN_OR_RETURN(auto arg_lhs, EmitAtan2(component_type, b, a, "")); + ASSIGN_OR_RETURN(auto arg_lhs, EmitAtan2(component_type, b, a, "")); auto neg_d_arg_lhs = FMul(neg_d, arg_lhs); - TF_ASSIGN_OR_RETURN(auto e_to_neg_d_arg_lhs, - EmitExp(component_type, neg_d_arg_lhs, "")); + ASSIGN_OR_RETURN(auto e_to_neg_d_arg_lhs, + EmitExp(component_type, neg_d_arg_lhs, "")); auto coeff = FMul(abs_to_c, e_to_neg_d_arg_lhs); - TF_ASSIGN_OR_RETURN(auto ln_abs, EmitLog(component_type, abs)); + ASSIGN_OR_RETURN(auto ln_abs, EmitLog(component_type, abs)); auto q = FAdd(FMul(c, arg_lhs), FMul(d, ln_abs)); - TF_ASSIGN_OR_RETURN(auto cos_q, EmitCos(component_type, q)); - TF_ASSIGN_OR_RETURN(auto sin_q, EmitSin(component_type, q)); + ASSIGN_OR_RETURN(auto cos_q, EmitCos(component_type, q)); + ASSIGN_OR_RETURN(auto sin_q, EmitSin(component_type, q)); llvm::Value* inf = llvm::ConstantFP::getInfinity(a->getType()); auto zero = llvm::ConstantFP::get(a->getType(), 0); @@ -2058,13 +2059,13 @@ absl::StatusOr ElementalIrEmitter::EmitComplexBinaryOp( // atan2(y,x) = -i * log((x + i * y)/sqrt(x**2+y**2)) auto y = lhs_value; auto x = rhs_value; - TF_ASSIGN_OR_RETURN(auto x_squared, EmitComplexMultiply(op, x, x)); - TF_ASSIGN_OR_RETURN(auto y_squared, EmitComplexMultiply(op, y, y)); - TF_ASSIGN_OR_RETURN(auto x_squared_plus_y_squared, - EmitComplexAdd(op, x_squared, y_squared)); + ASSIGN_OR_RETURN(auto x_squared, EmitComplexMultiply(op, x, x)); + ASSIGN_OR_RETURN(auto y_squared, EmitComplexMultiply(op, y, y)); + ASSIGN_OR_RETURN(auto x_squared_plus_y_squared, + EmitComplexAdd(op, x_squared, y_squared)); auto component_type = primitive_util::ComplexComponentType(op->shape().element_type()); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( auto sqrt_x_squared_plus_y_squared, EmitComplexSqrt(op, component_type, x_squared_plus_y_squared)); auto type = @@ -2072,12 +2073,12 @@ absl::StatusOr ElementalIrEmitter::EmitComplexBinaryOp( auto zero = llvm::ConstantFP::get(type, 0.0); auto one = llvm::ConstantFP::get(type, 1.0); auto i = EmitComposeComplex(op, zero, one, module_, b_); - TF_ASSIGN_OR_RETURN(auto i_times_y, EmitComplexMultiply(op, i, y)); - TF_ASSIGN_OR_RETURN(auto x_plus_iy, EmitComplexAdd(op, x, i_times_y)); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN(auto i_times_y, EmitComplexMultiply(op, i, y)); + ASSIGN_OR_RETURN(auto x_plus_iy, EmitComplexAdd(op, x, i_times_y)); + ASSIGN_OR_RETURN( auto div_result, EmitComplexDivide(op, x_plus_iy, sqrt_x_squared_plus_y_squared)); - TF_ASSIGN_OR_RETURN(auto log_result, EmitComplexLog(op, div_result)); + ASSIGN_OR_RETURN(auto log_result, EmitComplexLog(op, div_result)); auto negative_one = llvm::ConstantFP::get(type, -1.0); auto negative_i = EmitComposeComplex(op, zero, negative_one, module_, b_); return EmitComplexMultiply(op, negative_i, log_result); @@ -2121,7 +2122,7 @@ absl::StatusOr ElementalIrEmitter::EmitSqrt(PrimitiveType, absl::StatusOr ElementalIrEmitter::EmitRsqrt( PrimitiveType prim_type, llvm::Value* value) { - TF_ASSIGN_OR_RETURN(auto sqrt, EmitSqrt(prim_type, value)); + ASSIGN_OR_RETURN(auto sqrt, EmitSqrt(prim_type, value)); return FDiv(llvm::ConstantFP::get(sqrt->getType(), 1.0), sqrt); } @@ -2163,12 +2164,12 @@ absl::StatusOr ElementalIrEmitter::EmitCosm1( 2.4801587301570552304991E-5, -1.3888888888888872993737E-3, 4.1666666666666666609054E-2, }; - TF_ASSIGN_OR_RETURN(auto cos_x, EmitCos(prim_type, x)); + ASSIGN_OR_RETURN(auto cos_x, EmitCos(prim_type, x)); auto for_large_x = FAdd(cos_x, negative_one); auto xx = FMul(x, x); auto xxxx = FMul(xx, xx); - TF_ASSIGN_OR_RETURN(auto poly, EvaluatePolynomial(type, xx, kCoeffs)); + ASSIGN_OR_RETURN(auto poly, EvaluatePolynomial(type, xx, kCoeffs)); auto for_small_x = FAdd(FMul(xxxx, poly), FMul(negative_half, xx)); // (pi/4)^2 is approximately 0.61685 @@ -2198,8 +2199,8 @@ absl::StatusOr ElementalIrEmitter::EmitExpm1( llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs, {x}, {type}, b_); // Use a naive exp(x)-1 calculation if |x| is > 0.5 auto x_magnitude_is_large = FCmpOGT(abs_x, half); - TF_ASSIGN_OR_RETURN(auto tanh_of_x_over_two, EmitTanh(prim_type, x_over_two)); - TF_ASSIGN_OR_RETURN(auto exp_of_x, EmitExp(prim_type, x, "")); + ASSIGN_OR_RETURN(auto tanh_of_x_over_two, EmitTanh(prim_type, x_over_two)); + ASSIGN_OR_RETURN(auto exp_of_x, EmitExp(prim_type, x, "")); auto exp_of_x_plus_one = FAdd(exp_of_x, one); auto exp_of_x_minus_one = FSub(exp_of_x, one); auto expm1_of_x = FMul(tanh_of_x_over_two, exp_of_x_plus_one); @@ -2221,8 +2222,8 @@ absl::StatusOr ElementalIrEmitter::EmitCbrt( auto third = llvm::ConstantFP::get(type, 1.0 / 3.0); auto abs_value = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs, {value}, {type}, b_); - TF_ASSIGN_OR_RETURN(llvm::Value * abs_res, - EmitPow(prim_type, abs_value, third, "")); + ASSIGN_OR_RETURN(llvm::Value * abs_res, + EmitPow(prim_type, abs_value, third, "")); auto signed_res = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::copysign, {abs_res, value}, {type}, b_); return signed_res; @@ -2579,12 +2580,12 @@ absl::StatusOr ElementalIrEmitter::EmitElementalSelect( const HloInstruction* hlo, const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator, const llvm_ir::IrArray::Index& index) { - TF_ASSIGN_OR_RETURN(llvm::Value * pred_value, - operand_to_generator.at(hlo->operand(0))(index)); - TF_ASSIGN_OR_RETURN(llvm::Value * on_true_value, - operand_to_generator.at(hlo->operand(1))(index)); - TF_ASSIGN_OR_RETURN(llvm::Value * on_false_value, - operand_to_generator.at(hlo->operand(2))(index)); + ASSIGN_OR_RETURN(llvm::Value * pred_value, + operand_to_generator.at(hlo->operand(0))(index)); + ASSIGN_OR_RETURN(llvm::Value * on_true_value, + operand_to_generator.at(hlo->operand(1))(index)); + ASSIGN_OR_RETURN(llvm::Value * on_false_value, + operand_to_generator.at(hlo->operand(2))(index)); return Select(Trunc(pred_value, b_->getInt1Ty()), on_true_value, on_false_value); } @@ -2593,12 +2594,12 @@ absl::StatusOr ElementalIrEmitter::EmitElementalClamp( const HloInstruction* hlo, const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator, const llvm_ir::IrArray::Index& index) { - TF_ASSIGN_OR_RETURN(llvm::Value * min_value, - operand_to_generator.at(hlo->operand(0))(index)); - TF_ASSIGN_OR_RETURN(llvm::Value * arg_value, - operand_to_generator.at(hlo->operand(1))(index)); - TF_ASSIGN_OR_RETURN(llvm::Value * max_value, - operand_to_generator.at(hlo->operand(2))(index)); + ASSIGN_OR_RETURN(llvm::Value * min_value, + operand_to_generator.at(hlo->operand(0))(index)); + ASSIGN_OR_RETURN(llvm::Value * arg_value, + operand_to_generator.at(hlo->operand(1))(index)); + ASSIGN_OR_RETURN(llvm::Value * max_value, + operand_to_generator.at(hlo->operand(2))(index)); PrimitiveType prim_type = hlo->shape().element_type(); if (primitive_util::IsFloatingPointType(prim_type)) { return EmitFloatMin(max_value, EmitFloatMax(min_value, arg_value, ""), ""); @@ -2706,8 +2707,8 @@ absl::StatusOr ElementalIrEmitter::EmitElementalConcatenate( operand_multi_index, operand->shape(), source_index.GetType()); } - TF_ASSIGN_OR_RETURN(llvm::Value * value, - operand_to_generator.at(operand)(operand_index)); + ASSIGN_OR_RETURN(llvm::Value * value, + operand_to_generator.at(operand)(operand_index)); output->addIncoming(value, b_->GetInsertBlock()); b_->SetInsertPoint(init_block, saved_insert_point); } @@ -2787,9 +2788,8 @@ absl::StatusOr ElementalIrEmitter::EmitElementalDynamicSlice( return llvm::ConstantInt::get(index_type, c); }; llvm_ir::IrArray::Index zero_index(index_type); - TF_ASSIGN_OR_RETURN( - llvm::Value * start_index_value, - operand_to_generator.at(hlo->operand(1 + i))(zero_index)); + ASSIGN_OR_RETURN(llvm::Value * start_index_value, + operand_to_generator.at(hlo->operand(1 + i))(zero_index)); // Clamp the start index so that the sliced portion fits in the operand: // start_index = clamp(start_index, 0, operand_dim_size - output_dim_size) @@ -2919,8 +2919,8 @@ absl::StatusOr ElementalIrEmitter::EmitElementalGather( if (indices_shape.dimensions().size() == dim_numbers.index_vector_dim()) { IrArray::Index gather_index_index(gather_index_index_components, indices_shape, index_type); - TF_ASSIGN_OR_RETURN(llvm::Value * gather_dim_component, - indices_generator(gather_index_index)); + ASSIGN_OR_RETURN(llvm::Value * gather_dim_component, + indices_generator(gather_index_index)); add_to_operand_index(gather_dim_component, 0); } else { int64_t index_vector_size = @@ -2930,8 +2930,8 @@ absl::StatusOr ElementalIrEmitter::EmitElementalGather( index.GetConstantWithIndexType(i); IrArray::Index gather_index_index(gather_index_index_components, indices_shape, index_type); - TF_ASSIGN_OR_RETURN(llvm::Value * gather_dim_component, - indices_generator(gather_index_index)); + ASSIGN_OR_RETURN(llvm::Value * gather_dim_component, + indices_generator(gather_index_index)); add_to_operand_index(gather_dim_component, i); } } @@ -2962,9 +2962,8 @@ ElementalIrEmitter::EmitElementalDynamicUpdateSlice( }; llvm_ir::IrArray::Index zero_index(index_type); - TF_ASSIGN_OR_RETURN( - llvm::Value * start_index_value, - operand_to_generator.at(hlo->operand(2 + i))(zero_index)); + ASSIGN_OR_RETURN(llvm::Value * start_index_value, + operand_to_generator.at(hlo->operand(2 + i))(zero_index)); // Clamp the start index so that the update region fits in the operand. // start_index = clamp(start_index, 0, input_dim_size - update_dim_size) @@ -3014,14 +3013,14 @@ ElementalIrEmitter::EmitElementalDynamicUpdateSlice( } llvm_ir::IrArray::Index update_index(update_multi_index, update_hlo->shape(), index.GetType()); - TF_ASSIGN_OR_RETURN(llvm::Value * true_value, - operand_to_generator.at(update_hlo)(update_index)); + ASSIGN_OR_RETURN(llvm::Value * true_value, + operand_to_generator.at(update_hlo)(update_index)); Store(true_value, ret_value_addr); // Handle false BB (return data from 'input') SetToFirstInsertPoint(if_data.false_block, b_); - TF_ASSIGN_OR_RETURN(llvm::Value * false_value, - operand_to_generator.at(input_hlo)(index)); + ASSIGN_OR_RETURN(llvm::Value * false_value, + operand_to_generator.at(input_hlo)(index)); Store(false_value, ret_value_addr); SetToFirstInsertPoint(if_data.after_block, b_); @@ -3072,14 +3071,14 @@ absl::StatusOr ElementalIrEmitter::EmitElementalPad( SetToFirstInsertPoint(if_data.true_block, b_); llvm_ir::IrArray::Index index(multi_index, hlo->operand(0)->shape(), padded_index.GetType()); - TF_ASSIGN_OR_RETURN(llvm::Value * operand_value, - operand_to_generator.at(hlo->operand(0))(index)); + ASSIGN_OR_RETURN(llvm::Value * operand_value, + operand_to_generator.at(hlo->operand(0))(index)); Store(operand_value, ret_value_addr); SetToFirstInsertPoint(if_data.false_block, b_); - TF_ASSIGN_OR_RETURN(llvm::Value * padding_value, - operand_to_generator.at(hlo->operand(1))( - IrArray::Index(index.GetType()))); + ASSIGN_OR_RETURN(llvm::Value * padding_value, + operand_to_generator.at(hlo->operand(1))( + IrArray::Index(index.GetType()))); Store(padding_value, ret_value_addr); SetToFirstInsertPoint(if_data.after_block, b_); @@ -3192,8 +3191,8 @@ absl::StatusOr ElementalIrEmitter::EmitElementalDot( llvm::Value* current_accumulator = Load(accumulator_alloca->getAllocatedType(), accumulator_alloca); - TF_ASSIGN_OR_RETURN(llvm::Value * lhs_value, lhs_generator(lhs_index)); - TF_ASSIGN_OR_RETURN(llvm::Value * rhs_value, rhs_generator(rhs_index)); + ASSIGN_OR_RETURN(llvm::Value * lhs_value, lhs_generator(lhs_index)); + ASSIGN_OR_RETURN(llvm::Value * rhs_value, rhs_generator(rhs_index)); if (primitive_type == BF16) { lhs_value = b_->CreateFPExt(lhs_value, b_->getFloatTy()); @@ -3259,8 +3258,8 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( case HloOpcode::kTanh: return [this, hlo, &operand_to_generator]( const IrArray::Index& index) -> absl::StatusOr { - TF_ASSIGN_OR_RETURN(llvm::Value * operand_value, - operand_to_generator.at(hlo->operand(0))(index)); + ASSIGN_OR_RETURN(llvm::Value * operand_value, + operand_to_generator.at(hlo->operand(0))(index)); return EmitUnaryOp(hlo, operand_value); }; case HloOpcode::kAdd: @@ -3284,10 +3283,10 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( const IrArray::Index& index) -> absl::StatusOr { const HloInstruction* lhs = hlo->operand(0); const HloInstruction* rhs = hlo->operand(1); - TF_ASSIGN_OR_RETURN(llvm::Value * lhs_value, - operand_to_generator.at(lhs)(index)); - TF_ASSIGN_OR_RETURN(llvm::Value * rhs_value, - operand_to_generator.at(rhs)(index)); + ASSIGN_OR_RETURN(llvm::Value * lhs_value, + operand_to_generator.at(lhs)(index)); + ASSIGN_OR_RETURN(llvm::Value * rhs_value, + operand_to_generator.at(rhs)(index)); return EmitBinaryOp(hlo, lhs_value, rhs_value); }; case HloOpcode::kSelect: @@ -3303,8 +3302,8 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( case HloOpcode::kReducePrecision: return [this, hlo, &operand_to_generator]( const IrArray::Index& index) -> absl::StatusOr { - TF_ASSIGN_OR_RETURN(llvm::Value * operand_value, - operand_to_generator.at(hlo->operand(0))(index)); + ASSIGN_OR_RETURN(llvm::Value * operand_value, + operand_to_generator.at(hlo->operand(0))(index)); return EmitReducePrecision(hlo, operand_value); }; case HloOpcode::kConcatenate: @@ -3393,7 +3392,7 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( IrArray::Index source_index(target_index.multidim(), hlo->operand(0)->shape(), target_index.GetType()); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( llvm::Value * operand_value, operand_to_generator.at(hlo->operand(0))(source_index)); return operand_value; @@ -3422,8 +3421,8 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( const IrArray::Index& index) -> absl::StatusOr { std::vector operands; for (int i = 0; i < hlo->operand_count(); i++) { - TF_ASSIGN_OR_RETURN(llvm::Value * operand_value, - operand_to_generator.at(hlo->operand(i))(index)); + ASSIGN_OR_RETURN(llvm::Value * operand_value, + operand_to_generator.at(hlo->operand(i))(index)); operands.push_back(operand_value); } return EmitElementalMap(Cast(hlo), operands); @@ -3508,7 +3507,7 @@ llvm::Value* ElementalIrEmitter::EmitMulAdd(llvm::Value* lhs, llvm::Value* rhs, absl::StatusOr ElementalIrEmitter::EmitElementalMap( const HloMapInstruction* map_instr, absl::Span elemental_operands) { - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( std::vector values, EmitThreadLocalCall(*map_instr->to_apply(), elemental_operands, llvm_ir::IrName(map_instr), /*is_reducer=*/false)); @@ -3549,7 +3548,7 @@ absl::StatusOr ElementalIrEmitter::EmitElementalReduceWindow( accum_ptrs.push_back(accum_ptr); { auto initial_value_generator = initial_value_generators[operand_index]; - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( llvm::Value* const init_value, initial_value_generator(llvm_ir::IrArray::Index(index.GetType()))); Store(init_value, accum_ptr); @@ -3620,18 +3619,18 @@ absl::StatusOr ElementalIrEmitter::EmitElementalReduceWindow( IrArray::Index input_index(input_multi_index, reduce_window->inputs()[0]->shape(), index_type); for (int64_t operand_idx = 0; operand_idx < input_count; ++operand_idx) { - TF_ASSIGN_OR_RETURN(llvm::Value * input_value, - input_generators[operand_idx](input_index)); + ASSIGN_OR_RETURN(llvm::Value * input_value, + input_generators[operand_idx](input_index)); input_values[input_count + operand_idx] = input_value; input_values[operand_idx] = Load(llvm::cast(accum_ptrs[operand_idx]) ->getAllocatedType(), accum_ptrs[operand_idx]); } - TF_ASSIGN_OR_RETURN(std::vector accum_values, - EmitThreadLocalCall(*reduce_window->to_apply(), - input_values, "reducer_function", - /*is_reducer=*/true)); + ASSIGN_OR_RETURN(std::vector accum_values, + EmitThreadLocalCall(*reduce_window->to_apply(), input_values, + "reducer_function", + /*is_reducer=*/true)); for (int64_t operand_idx = 0; operand_idx < accum_values.size(); ++operand_idx) { @@ -3672,7 +3671,7 @@ absl::StatusOr ElementalIrEmitter::EmitElementalReduce( // Initialize an accumulator with init_value. llvm::AllocaInst* accumulator_addr = llvm_ir::EmitAllocaAtFunctionEntry( accumulator_llvm_type, "accumulator_" + std::to_string(i), b()); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( llvm::Value* const init_value, initial_value_generators[i](llvm_ir::IrArray::Index(index_type))); Store(init_value, accumulator_addr); @@ -3716,15 +3715,14 @@ absl::StatusOr ElementalIrEmitter::EmitElementalReduce( } for (int i = 0; i < accumulators_count; i++) { - TF_ASSIGN_OR_RETURN(llvm::Value* const input_element, - input_generators[i](input_index)); + ASSIGN_OR_RETURN(llvm::Value* const input_element, + input_generators[i](input_index)); reduction_operands.push_back(input_element); } - TF_ASSIGN_OR_RETURN( - std::vector results, - EmitThreadLocalCall(*reduce->to_apply(), reduction_operands, - "reduce_function", /*is_reducer=*/true)); + ASSIGN_OR_RETURN(std::vector results, + EmitThreadLocalCall(*reduce->to_apply(), reduction_operands, + "reduce_function", /*is_reducer=*/true)); CHECK(results.size() == accumulators_count); for (int i = 0; i < accumulators_count; i++) { @@ -3896,12 +3894,12 @@ absl::StatusOr ElementalIrEmitter::EmitConvolution( llvm_ir::IrArray::Index input_index(input_multi_index, lhs->shape(), b_->getInt64Ty()); - TF_ASSIGN_OR_RETURN(llvm::Value* const input_value, - input_generator(input_index)); + ASSIGN_OR_RETURN(llvm::Value* const input_value, + input_generator(input_index)); llvm_ir::IrArray::Index kernel_index(kernel_multi_index, rhs->shape(), b_->getInt64Ty()); - TF_ASSIGN_OR_RETURN(llvm::Value* const kernel_value, - kernel_generator(kernel_index)); + ASSIGN_OR_RETURN(llvm::Value* const kernel_value, + kernel_generator(kernel_index)); llvm::Value* sum = EmitMulAdd(input_value, kernel_value, Load(sum_address->getAllocatedType(), sum_address), diff --git a/third_party/xla/xla/service/executable.cc b/third_party/xla/xla/service/executable.cc index a9f8da25d12d1c..d68851da5e762b 100644 --- a/third_party/xla/xla/service/executable.cc +++ b/third_party/xla/xla/service/executable.cc @@ -27,6 +27,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/service/maybe_owning_device_address.h" #include "xla/service/service_executable_run_options.h" #include "xla/service/shaped_buffer.h" @@ -81,8 +82,8 @@ absl::StatusOr Executable::ExecuteOnStream( absl::StatusOr result = ExecuteAsyncOnStream(run_options, arguments); absl::Status blocking_status = run_options->stream()->BlockHostUntilDone(); - TF_RETURN_IF_ERROR(result.status()); - TF_RETURN_IF_ERROR(blocking_status); + RETURN_IF_ERROR(result.status()); + RETURN_IF_ERROR(blocking_status); return result; } @@ -104,8 +105,8 @@ absl::StatusOr Executable::ExecuteAsyncOnStream( for (const ShapedBuffer* arg : arguments) { args.emplace_back(MakeMaybeOwningDeviceAddressTree(*arg)); } - TF_ASSIGN_OR_RETURN(ExecutionOutput out, - ExecuteAsyncOnStream(run_options, std::move(args))); + ASSIGN_OR_RETURN(ExecutionOutput out, + ExecuteAsyncOnStream(run_options, std::move(args))); return out.ConsumeResult(); } @@ -115,8 +116,8 @@ absl::StatusOr Executable::ExecuteOnStream( absl::StatusOr result = ExecuteAsyncOnStream(run_options, std::move(arguments)); absl::Status blocking_status = run_options->stream()->BlockHostUntilDone(); - TF_RETURN_IF_ERROR(result.status()); - TF_RETURN_IF_ERROR(blocking_status); + RETURN_IF_ERROR(result.status()); + RETURN_IF_ERROR(blocking_status); return result; } @@ -129,8 +130,7 @@ absl::StatusOr> Executable::ExecuteOnStreams( return_values.reserve(run_options.size()); if (run_options.size() == 1) { - TF_ASSIGN_OR_RETURN(auto rv, - ExecuteOnStream(&run_options[0], arguments[0])); + ASSIGN_OR_RETURN(auto rv, ExecuteOnStream(&run_options[0], arguments[0])); return_values.push_back(std::move(rv)); return std::move(return_values); } @@ -139,13 +139,13 @@ absl::StatusOr> Executable::ExecuteOnStreams( // We cannot BlockHostUntilDone() on the already-launched executions in case // of error, since if the executions communicate, the initially launched // executions may never complete if not all executions are running. - TF_ASSIGN_OR_RETURN(auto rv, - ExecuteAsyncOnStream(&run_options[i], arguments[i])); + ASSIGN_OR_RETURN(auto rv, + ExecuteAsyncOnStream(&run_options[i], arguments[i])); return_values.push_back(std::move(rv)); } for (const auto& options : run_options) { TF_RET_CHECK(options.stream() != nullptr); - TF_RETURN_IF_ERROR(options.stream()->BlockHostUntilDone()); + RETURN_IF_ERROR(options.stream()->BlockHostUntilDone()); } return std::move(return_values); } @@ -156,8 +156,8 @@ absl::StatusOr Executable::ExecuteOnStreamWrapper( absl::StatusOr result = ExecuteAsyncOnStreamWrapper(run_options, arguments); absl::Status block_status = run_options->stream()->BlockHostUntilDone(); - TF_RETURN_IF_ERROR(result.status()); - TF_RETURN_IF_ERROR(block_status); + RETURN_IF_ERROR(result.status()); + RETURN_IF_ERROR(block_status); return result; } @@ -167,8 +167,8 @@ absl::StatusOr Executable::ExecuteOnStreamWrapper( absl::StatusOr result = ExecuteAsyncOnStreamWrapper(run_options, std::move(arguments)); absl::Status block_status = run_options->stream()->BlockHostUntilDone(); - TF_RETURN_IF_ERROR(result.status()); - TF_RETURN_IF_ERROR(block_status); + RETURN_IF_ERROR(result.status()); + RETURN_IF_ERROR(block_status); return result; } @@ -202,7 +202,7 @@ absl::Status ExecuteWrapperAfterExecution( if (state.profile != nullptr) { // We block instead of using an async callback because reading the timer // value may call back into the driver on GPU, which is not allowed. - TF_RETURN_IF_ERROR(stream->BlockHostUntilDone()); + RETURN_IF_ERROR(stream->BlockHostUntilDone()); const int64_t executable_size_in_bytes = executable->SizeOfGeneratedCodeInBytes(); @@ -231,7 +231,7 @@ absl::StatusOr Executable::ExecuteAsyncOnStreamWrapper( auto state = ExecuteWrapperBeforeExecution(*this, run_options); absl::StatusOr return_value = ExecuteAsyncOnStream(run_options, arguments); - TF_RETURN_IF_ERROR(ExecuteWrapperAfterExecution( + RETURN_IF_ERROR(ExecuteWrapperAfterExecution( this, state, return_value.status(), run_options->stream())); return return_value; } @@ -242,7 +242,7 @@ absl::StatusOr Executable::ExecuteAsyncOnStreamWrapper( auto state = ExecuteWrapperBeforeExecution(*this, run_options); absl::StatusOr return_value = ExecuteAsyncOnStream(run_options, std::move(arguments)); - TF_RETURN_IF_ERROR(ExecuteWrapperAfterExecution( + RETURN_IF_ERROR(ExecuteWrapperAfterExecution( this, state, return_value.status(), run_options->stream())); return return_value; } diff --git a/third_party/xla/xla/service/execution_tracker.cc b/third_party/xla/xla/service/execution_tracker.cc index 58534fd751547b..d8d990bdf52881 100644 --- a/third_party/xla/xla/service/execution_tracker.cc +++ b/third_party/xla/xla/service/execution_tracker.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include "xla/tsl/platform/status_macros.h" #include "xla/stream_executor/stream_executor.h" #include "xla/util.h" #include "tsl/platform/logging.h" @@ -39,7 +40,7 @@ AsyncExecution::AsyncExecution(Backend* backend, absl::Status AsyncExecution::BlockUntilDone() const { for (auto& stream : streams_) { - TF_RETURN_IF_ERROR(stream->BlockHostUntilDone()); + RETURN_IF_ERROR(stream->BlockHostUntilDone()); } return absl::OkStatus(); } diff --git a/third_party/xla/xla/service/float8_fnuz_ir_emitter.cc b/third_party/xla/xla/service/float8_fnuz_ir_emitter.cc index e335b9ec67037d..9afa201db955c0 100644 --- a/third_party/xla/xla/service/float8_fnuz_ir_emitter.cc +++ b/third_party/xla/xla/service/float8_fnuz_ir_emitter.cc @@ -17,6 +17,7 @@ limitations under the License. #include +#include "xla/tsl/platform/status_macros.h" #include "llvm/ADT/APFloat.h" #include "llvm/IR/Constants.h" #include "llvm/IR/Intrinsics.h" @@ -116,11 +117,11 @@ absl::StatusOr ComputeMaximumValue(PrimitiveType input_type, TF_RET_CHECK(primitive_util::IsFloatingPointType(output_type)); TF_RET_CHECK(BitWidth(input_type) > BitWidth(output_type)); - TF_ASSIGN_OR_RETURN(auto output_semantics, - PrimitiveTypeToAPFloatSemantics(output_type)); + ASSIGN_OR_RETURN(auto output_semantics, + PrimitiveTypeToAPFloatSemantics(output_type)); - TF_ASSIGN_OR_RETURN(auto input_semantics, - PrimitiveTypeToAPFloatSemantics(input_type)); + ASSIGN_OR_RETURN(auto input_semantics, + PrimitiveTypeToAPFloatSemantics(input_type)); // Compute the largest number of the output type and convert it to the input // type. @@ -160,8 +161,8 @@ absl::StatusOr IsInputOutsideOutputRange( // Ignore the sign bit. llvm::Value* non_sign_bits = b->CreateAnd(value, bit_mask); - TF_ASSIGN_OR_RETURN(uint64_t maximum_value, - ComputeMaximumValue(input_type, output_type, b)); + ASSIGN_OR_RETURN(uint64_t maximum_value, + ComputeMaximumValue(input_type, output_type, b)); // Compare against the maximum value. llvm::Type* uint_type = b->getIntNTy(BitWidth(input_type)); @@ -389,8 +390,8 @@ absl::StatusOr DynamicRoundingBias(PrimitiveType input_type, llvm::Type* int_type = b->getIntNTy(BitWidth(input_type)); // Find the bit position of the last mantissa bit. - TF_ASSIGN_OR_RETURN(llvm::Value * shift, - LastMantissaBit(input_type, value, output_type, b)); + ASSIGN_OR_RETURN(llvm::Value * shift, + LastMantissaBit(input_type, value, output_type, b)); // Compute the mask to select that bit. llvm::Value* last_mantissa_bit_mask = @@ -511,7 +512,7 @@ llvm::Value* BuildOutputSign(llvm::Value* sign, PrimitiveType output_type, } absl::StatusOr GetQNaN(PrimitiveType type) { - TF_ASSIGN_OR_RETURN(auto semantics, PrimitiveTypeToAPFloatSemantics(type)); + ASSIGN_OR_RETURN(auto semantics, PrimitiveTypeToAPFloatSemantics(type)); return llvm::APFloat::getQNaN(*semantics).bitcastToAPInt().getZExtValue(); } @@ -534,10 +535,10 @@ absl::StatusOr EmitF8fnuzToFloating(PrimitiveType input_type, const std::string lut_name = PrimitiveType_Name(input_type) + "To" + PrimitiveType_Name(output_type) + "LUT"; - TF_ASSIGN_OR_RETURN(auto input_semantics, - PrimitiveTypeToAPFloatSemantics(input_type)); - TF_ASSIGN_OR_RETURN(auto output_semantics, - PrimitiveTypeToAPFloatSemantics(output_type)); + ASSIGN_OR_RETURN(auto input_semantics, + PrimitiveTypeToAPFloatSemantics(input_type)); + ASSIGN_OR_RETURN(auto output_semantics, + PrimitiveTypeToAPFloatSemantics(output_type)); llvm::Constant* global_result_lut_array = module->getOrInsertGlobal( lut_name, result_lut_array_type, [&]() -> llvm::GlobalVariable* { @@ -570,7 +571,7 @@ absl::StatusOr EmitF8fnuzToFloating(PrimitiveType input_type, }); // Check for NaN, since it's a special case. - TF_ASSIGN_OR_RETURN(const uint64_t input_qnan, GetQNaN(input_type)); + ASSIGN_OR_RETURN(const uint64_t input_qnan, GetQNaN(input_type)); llvm::Value* nan_pred = b->CreateICmpEQ( f8_value, llvm::ConstantInt::get(b->getInt8Ty(), input_qnan)); @@ -603,8 +604,8 @@ absl::StatusOr EmitF8fnuzToFloating(PrimitiveType input_type, llvm::Value* result = b->CreateOr(sign, result_abs); // Bitcast to the output type. - TF_ASSIGN_OR_RETURN(auto type, PrimitiveTypeToLLVMType(b, output_type)); - TF_ASSIGN_OR_RETURN(const uint64_t output_qnan, GetQNaN(output_type)); + ASSIGN_OR_RETURN(auto type, PrimitiveTypeToLLVMType(b, output_type)); + ASSIGN_OR_RETURN(const uint64_t output_qnan, GetQNaN(output_type)); return b->CreateBitCast( b->CreateSelect(nan_pred, llvm::ConstantInt::get(output_int_type, output_qnan), diff --git a/third_party/xla/xla/service/gather_expander.cc b/third_party/xla/xla/service/gather_expander.cc index af17d739e2decc..2218c454ff6b12 100644 --- a/third_party/xla/xla/service/gather_expander.cc +++ b/third_party/xla/xla/service/gather_expander.cc @@ -23,6 +23,7 @@ limitations under the License. #include "absl/log/check.h" #include "absl/status/statusor.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/literal_util.h" @@ -75,7 +76,7 @@ absl::StatusOr TransposeIndexVectorDimToLast( absl::StatusOr CanonicalizeGatherIndices( HloInstruction* start_indices, int64_t index_vector_dim) { // Transpose the non-index-vector dimensions to the front. - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( HloInstruction * transposed_start_indices, TransposeIndexVectorDimToLast(start_indices, index_vector_dim)); bool indices_are_scalar = @@ -164,29 +165,28 @@ absl::StatusOr> GatherLoopBody( if (has_scalar_indices) { // In this case start_indices has rank 1 and induction_var_as_vector (of // shape {1}) is an index into this rank 1 tensor. - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( index_vector, MakeDynamicSliceHlo(start_indices, induction_var_as_vector, {1})); } else { // In this case start_indices has rank 2 and induction_var_as_vector (of // shape {1}) is an index into just the first dimension of this rank 2 // tensor. - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( HloInstruction * index_into_start_indices, PadVectorWithZeros(induction_var_as_vector, /*zeros_to_prepend=*/0, /*zeros_to_append=*/1)); int64_t index_vector_size = start_indices->shape().dimensions(1); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( HloInstruction * index_vector_2d, MakeDynamicSliceHlo(start_indices, index_into_start_indices, {1, index_vector_size})); - TF_ASSIGN_OR_RETURN(index_vector, - ElideDegenerateDims(index_vector_2d, {0})); + ASSIGN_OR_RETURN(index_vector, ElideDegenerateDims(index_vector_2d, {0})); } - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( HloInstruction * gathered_slice_start, ExpandIndexVectorIntoOperandSpace( orig_start_indices_shape, operand->shape().dimensions().size(), @@ -194,27 +194,26 @@ absl::StatusOr> GatherLoopBody( dim_numbers.start_indices_batching_dims(), dim_numbers.operand_batching_dims(), index_vector, induction_var)); - TF_ASSIGN_OR_RETURN(HloInstruction * gathered_slice, - MakeDynamicSliceHlo(operand, gathered_slice_start, - gather.gather_slice_sizes())); + ASSIGN_OR_RETURN(HloInstruction * gathered_slice, + MakeDynamicSliceHlo(operand, gathered_slice_start, + gather.gather_slice_sizes())); - TF_ASSIGN_OR_RETURN( - HloInstruction* const gathered_slice_with_dims_collapsed, - ElideDegenerateDims(gathered_slice, - GetDegeneratedSliceDims(dim_numbers))); + ASSIGN_OR_RETURN(HloInstruction* const gathered_slice_with_dims_collapsed, + ElideDegenerateDims(gathered_slice, + GetDegeneratedSliceDims(dim_numbers))); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( HloInstruction* const gathered_slice_for_update, PrependDegenerateDims(gathered_slice_with_dims_collapsed, 1)); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( HloInstruction* const index_vector_into_accumulator, PadVectorWithZeros( induction_var_as_vector, /*zeros_to_prepend=*/0, /*zeros_to_append=*/ gathered_slice_with_dims_collapsed->shape().dimensions().size())); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( HloInstruction* const updated_accumulator, MakeDynamicUpdateSliceHlo(output_accumulator, gathered_slice_for_update, index_vector_into_accumulator)); @@ -335,9 +334,9 @@ absl::StatusOr GatherExpander::ExpandInstruction( Shape broadcast_operand_shape = ShapeUtil::DeleteDimensions( GetDegeneratedSliceDims(gather_instr->gather_dimension_numbers()), gather_instr->operand(0)->shape()); - TF_ASSIGN_OR_RETURN(HloInstruction * broadcast_operand, - MakeReshapeHlo(broadcast_operand_shape, - gather_instr->mutable_operand(0))); + ASSIGN_OR_RETURN(HloInstruction * broadcast_operand, + MakeReshapeHlo(broadcast_operand_shape, + gather_instr->mutable_operand(0))); gather_instr->SetupDerivedInstruction(broadcast_operand); HloInstruction* broadcast = MakeBroadcastHlo(broadcast_operand, @@ -364,7 +363,7 @@ absl::StatusOr GatherExpander::ExpandInstruction( gather_instr->ToString()); } - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( HloInstruction * canonical_start_indices, CanonicalizeGatherIndices(start_indices, dim_numbers.index_vector_dim())); @@ -386,12 +385,12 @@ absl::StatusOr GatherExpander::ExpandInstruction( }, gather_instr->metadata()); - TF_ASSIGN_OR_RETURN(std::vector gather_loop_result, - gather_loop_result_or_error); + ASSIGN_OR_RETURN(std::vector gather_loop_result, + gather_loop_result_or_error); HloInstruction* accumulator_result = gather_loop_result.back(); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( HloInstruction* const accumulator_with_batch_dims_decanonicalized, AdjustBatchDimsInAccumulator(start_indices->shape(), accumulator_result, dim_numbers.index_vector_dim())); diff --git a/third_party/xla/xla/service/gather_scatter_utils.cc b/third_party/xla/xla/service/gather_scatter_utils.cc index 024483703b5b31..c4a684c919ac6f 100644 --- a/third_party/xla/xla/service/gather_scatter_utils.cc +++ b/third_party/xla/xla/service/gather_scatter_utils.cc @@ -26,6 +26,7 @@ limitations under the License. #include "absl/log/check.h" #include "absl/status/statusor.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/layout_util.h" @@ -99,20 +100,20 @@ absl::StatusOr TransformStartIndices( if (index_vector_dim == rank) { // Add a size 1 dimension to the indices if the index_vector_dim is // implicit. - TF_ASSIGN_OR_RETURN(indices, - InsertDegenerateDims(indices, {index_vector_dim})); + ASSIGN_OR_RETURN(indices, + InsertDegenerateDims(indices, {index_vector_dim})); ++rank; } else if (index_vector_dim < rank - 1) { // Ensure index_vector_dim is the last dimension in scatter_indices. - TF_ASSIGN_OR_RETURN(indices, - MoveDimensionToEnd(indices, index_vector_dim, rank)); + ASSIGN_OR_RETURN(indices, + MoveDimensionToEnd(indices, index_vector_dim, rank)); } // Flatten indices, making it two-dimensional. if (rank > 2) { - TF_ASSIGN_OR_RETURN(indices, CollapseFirstNDims(indices, rank - 1)); + ASSIGN_OR_RETURN(indices, CollapseFirstNDims(indices, rank - 1)); } else if (rank == 1) { - TF_ASSIGN_OR_RETURN(indices, InsertDegenerateDims(indices, {0})); + ASSIGN_OR_RETURN(indices, InsertDegenerateDims(indices, {0})); } return indices; } @@ -136,7 +137,7 @@ absl::StatusOr MaybeTranspose( if (IsIdentityPermutation(permutation)) { return operand; } - TF_ASSIGN_OR_RETURN(auto* result, MakeTransposeHlo(operand, permutation)); + ASSIGN_OR_RETURN(auto* result, MakeTransposeHlo(operand, permutation)); // Assign the default layout to the transpose. This method is also used after // layout normalization, and before, we don't care about the layout. *result->mutable_shape()->mutable_layout() = @@ -158,8 +159,8 @@ absl::StatusOr> MaybeTranspose( std::vector result; result.reserve(operands.size()); for (auto* operand : operands) { - TF_ASSIGN_OR_RETURN(result.emplace_back(), - MaybeTranspose(operand, operand_permutation)); + ASSIGN_OR_RETURN(result.emplace_back(), + MaybeTranspose(operand, operand_permutation)); } return result; } @@ -213,7 +214,7 @@ absl::StatusOr ExpandIndexVectorIntoOperandSpace( for (int i = 0; i < operand_rank; i++) { int64_t index_vector_dim_index = FindIndex(start_index_map, i); if (index_vector_dim_index != start_index_map.size()) { - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( HloInstruction * component_to_concat, MakeSliceHlo(index_vector, /*start_indices=*/{index_vector_dim_index}, /*limit_indices=*/{index_vector_dim_index + 1}, diff --git a/third_party/xla/xla/service/generic_transfer_manager.cc b/third_party/xla/xla/service/generic_transfer_manager.cc index 9cdca24bd1685c..c8cdf8155351ac 100644 --- a/third_party/xla/xla/service/generic_transfer_manager.cc +++ b/third_party/xla/xla/service/generic_transfer_manager.cc @@ -26,6 +26,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/strings/str_format.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/layout_util.h" #include "xla/literal.h" #include "xla/primitive_util.h" @@ -63,10 +64,10 @@ absl::Status GenericTransferManager::WriteSingleTupleIndexTable( for (const se::DeviceAddressBase& element : elements) { element_pointers->push_back(element.opaque()); } - TF_RETURN_IF_ERROR(TransferBufferToDevice( - stream, GetByteSizeRequirement(shape), element_pointers->data(), region)); + RETURN_IF_ERROR(TransferBufferToDevice(stream, GetByteSizeRequirement(shape), + element_pointers->data(), region)); // Ensure the buffer is transferred before we destroy element_pointers. - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( stream->DoHostCallback([element_pointers{std::move(element_pointers)}]() { /* holds reference to element_pointers in closure */ })); @@ -85,7 +86,7 @@ void GenericTransferManager::TransferLiteralFromDevice( TF_RET_CHECK(stream->parent()->device_ordinal() == device_buffer.physical_device_ordinal()); - TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( + RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( device_buffer.on_device_shape(), [&](const Shape& subshape, const ShapeIndex& index) -> absl::Status { if (subshape.IsArray()) { @@ -102,7 +103,7 @@ void GenericTransferManager::TransferLiteralFromDevice( /*num_elements=*/ShapeUtil::ElementsIn(subshape), /*destination=*/literal.untyped_data(index)); } else { - TF_RETURN_IF_ERROR(TransferBufferFromDevice( + RETURN_IF_ERROR(TransferBufferFromDevice( stream, /*source=*/device_buffer.buffer(index), // With bounded dynamic shapes, the shape of the device buffer @@ -157,7 +158,7 @@ absl::Status GenericTransferManager::TransferLiteralToDeviceAsync( TF_RET_CHECK(stream->parent()->device_ordinal() == device_buffer.physical_device_ordinal()); - TF_RETURN_IF_ERROR(WriteTupleIndexTablesAsync(stream, device_buffer)); + RETURN_IF_ERROR(WriteTupleIndexTablesAsync(stream, device_buffer)); return ShapeUtil::ForEachSubshapeWithStatus( device_buffer.on_device_shape(), @@ -196,9 +197,9 @@ absl::Status GenericTransferManager::TransferLiteralToDeviceAsync( // Relayout data before transferring. auto relaid_out = std::make_shared( subliteral.Relayout(device_subshape.layout())); - TF_RETURN_IF_ERROR(TransferBuffer(relaid_out->untyped_data())); + RETURN_IF_ERROR(TransferBuffer(relaid_out->untyped_data())); // Ensure the buffer is transferred before we destroy it. - TF_RETURN_IF_ERROR(stream->DoHostCallback( + RETURN_IF_ERROR(stream->DoHostCallback( [keep_alive = std::move(relaid_out)] {})); } } @@ -254,9 +255,9 @@ absl::Status GenericTransferManager::TransferIntNArrayFromDevice( int64_t elements_per_byte = 8 / bit_width; int64_t packed_size = CeilOfRatio(num_elements, elements_per_byte); auto packed_dst_data = std::make_unique>(packed_size); - TF_RETURN_IF_ERROR(TransferBufferFromDevice(stream, source, packed_size, - packed_dst_data->data())); - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR(TransferBufferFromDevice(stream, source, packed_size, + packed_dst_data->data())); + RETURN_IF_ERROR( stream->DoHostCallback([destination, bit_width, num_elements, packed_dst_data = std::move(packed_dst_data)]() { UnpackIntN( @@ -277,8 +278,8 @@ absl::Status GenericTransferManager::TransferIntNArrayToDevice( absl::MakeSpan(static_cast(source), num_elements), absl::MakeSpan(*packed_src_data)); TF_RET_CHECK(packed_src_data->size() == destination->size()); - TF_RETURN_IF_ERROR(TransferBufferToDevice( - stream, packed_src_data->size(), packed_src_data->data(), destination)); + RETURN_IF_ERROR(TransferBufferToDevice(stream, packed_src_data->size(), + packed_src_data->data(), destination)); return stream->DoHostCallback([keep_alive = std::move(packed_src_data)] {}); } diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index b8b5e57eba3539..d70c8986e85c95 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -736,6 +736,7 @@ cc_library( "//xla/backends/gpu/runtime:command", "//xla/backends/gpu/runtime:command_buffer_conversion_pass", "//xla/backends/gpu/runtime:command_buffer_thunk", + "//xla/backends/gpu/runtime:execution_stream_id", "//xla/backends/gpu/runtime:nvshmem_collective_thunk", "//xla/backends/gpu/runtime:scratch_memory", "//xla/backends/gpu/runtime:scratch_memory_requests", @@ -773,11 +774,14 @@ cc_library( "//xla/stream_executor:device_address", "//xla/stream_executor:device_address_allocator", "//xla/stream_executor:device_description", + "//xla/stream_executor:event", "//xla/stream_executor:event_based_timer", "//xla/stream_executor:kernel_stats", + "//xla/stream_executor:memory_allocation", "//xla/stream_executor:memory_reservation", "//xla/stream_executor:module_spec", "//xla/stream_executor:platform", + "//xla/stream_executor:platform_id", "//xla/stream_executor:scoped_module_handle", "//xla/stream_executor:stream", "//xla/stream_executor:stream_executor_h", @@ -787,11 +791,9 @@ cc_library( "//xla/stream_executor/cuda:cuda_platform_id", "//xla/stream_executor/rocm:rocm_platform_id", "//xla/stream_executor/sycl:sycl_platform_id", - "//xla/tsl/platform:env", "//xla/tsl/platform:env_time", "//xla/tsl/platform:logging", "//xla/tsl/platform:status_macros", - "//xla/tsl/platform:statusor", "//xla/tsl/util:sorted_range", "//xla/util/split_proto:split_executable_and_options_writer", "//xla/util/split_proto:split_gpu_executable_writer", @@ -802,7 +804,6 @@ cc_library( "@com_google_absl//absl/container:btree", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/container:node_hash_map", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", @@ -816,7 +817,6 @@ cc_library( "@com_google_absl//absl/types:span", "@riegeli//riegeli/bytes:string_writer", "@riegeli//riegeli/bytes:writer", - "@tsl//tsl/platform:casts", "@tsl//tsl/platform:random", "@tsl//tsl/profiler/lib:scoped_annotation", "@tsl//tsl/profiler/lib:traceme", diff --git a/third_party/xla/xla/service/gpu/autotuning/BUILD b/third_party/xla/xla/service/gpu/autotuning/BUILD index 17c1f30b4aa2aa..6fc889ed5a881e 100644 --- a/third_party/xla/xla/service/gpu/autotuning/BUILD +++ b/third_party/xla/xla/service/gpu/autotuning/BUILD @@ -35,7 +35,7 @@ xla_test( "//xla/backends/autotuner", "//xla/backends/autotuner:codegen_backend", "//xla/backends/autotuner:profiler", - "//xla/backends/gpu/autotuner:cublas", + "//xla/backends/gpu/autotuner:cublaslt", "//xla/hlo/ir:hlo", "//xla/hlo/testlib:hlo_hardware_independent_test_base", "//xla/service:platform_util", @@ -123,6 +123,7 @@ cc_library( "//xla/tsl/platform:env", "//xla/tsl/platform:errors", "//xla/tsl/platform:logging", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", @@ -158,6 +159,7 @@ cc_library( "//xla/stream_executor:stream", "//xla/stream_executor/gpu:redzone_allocator", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", @@ -229,6 +231,7 @@ xla_cc_test( "//xla/tsl/platform:env", "//xla/tsl/platform:errors", "//xla/tsl/platform:logging", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "//xla/tsl/platform:test", "//xla/tsl/util/proto:proto_matchers", @@ -277,6 +280,7 @@ cc_library( "//xla/stream_executor/sycl:sycl_platform_id", "//xla/tsl/platform:env", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_set", diff --git a/third_party/xla/xla/service/gpu/autotuning/autotuner_cache.cc b/third_party/xla/xla/service/gpu/autotuning/autotuner_cache.cc index 92271fe274b1d2..52fe9d367e74f2 100644 --- a/third_party/xla/xla/service/gpu/autotuning/autotuner_cache.cc +++ b/third_party/xla/xla/service/gpu/autotuning/autotuner_cache.cc @@ -34,6 +34,7 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" #include "absl/time/clock.h" +#include "xla/tsl/platform/status_macros.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/SHA256.h" #include "google/protobuf/text_format.h" @@ -78,7 +79,7 @@ absl::StatusOr GetBase64EncodedSha256Hash(absl::string_view s) { absl::string_view hash_view(reinterpret_cast(hash.data()), hash.size()); std::string base64_encoded_hash; - TF_RETURN_IF_ERROR(tsl::Base64Encode(hash_view, &base64_encoded_hash)); + RETURN_IF_ERROR(tsl::Base64Encode(hash_view, &base64_encoded_hash)); return base64_encoded_hash; } @@ -113,13 +114,13 @@ absl::Status AddResultToFileBasedCacheIfEnabled( } tsl::Env* default_env = tsl::Env::Default(); - TF_RETURN_IF_ERROR(CreateDirIfNeeded(std::string(cache_dir), default_env)); + RETURN_IF_ERROR(CreateDirIfNeeded(std::string(cache_dir), default_env)); - TF_ASSIGN_OR_RETURN(std::string key_hash, - GetBase64EncodedSha256Hash(key.ToString())); + ASSIGN_OR_RETURN(std::string key_hash, + GetBase64EncodedSha256Hash(key.ToString())); - TF_ASSIGN_OR_RETURN(const std::string file_path, - GetCacheFilePath(cache_dir, key_hash)); + ASSIGN_OR_RETURN(const std::string file_path, + GetCacheFilePath(cache_dir, key_hash)); VLOG(1) << "Writing autotune result to file: " << file_path; @@ -133,14 +134,14 @@ absl::Status AddResultToFileBasedCacheIfEnabled( // file. Also avoids reading incomplete files. (This may not work on all file // systems.) std::string tmp_dir = tsl::io::JoinPath(cache_dir, "tmp"); - TF_RETURN_IF_ERROR(CreateDirIfNeeded(tmp_dir, default_env)); + RETURN_IF_ERROR(CreateDirIfNeeded(tmp_dir, default_env)); int64_t time_stamp = absl::GetCurrentTimeNanos(); std::string temp_file_path = tsl::io::JoinPath( tmp_dir, absl::StrCat("tmp_per_fusion_cache_", key_hash, "_", std::to_string(time_stamp), ".textproto")); - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( tsl::WriteStringToFile(default_env, temp_file_path, result_str)); return default_env->RenameFile(temp_file_path, file_path); } @@ -163,11 +164,11 @@ TryToFindInFileBasedCacheIfEnabled(const AutotuneCacheKey& key, return std::nullopt; } - TF_ASSIGN_OR_RETURN(std::string key_hash, - GetBase64EncodedSha256Hash(key.ToString())); + ASSIGN_OR_RETURN(std::string key_hash, + GetBase64EncodedSha256Hash(key.ToString())); - TF_ASSIGN_OR_RETURN(const std::string file_path, - GetCacheFilePath(cache_dir, key_hash)); + ASSIGN_OR_RETURN(const std::string file_path, + GetCacheFilePath(cache_dir, key_hash)); if (!tsl::Env::Default()->FileExists(file_path).ok()) { VLOG(1) << "Autotune result file not found: " << file_path; return std::nullopt; @@ -175,8 +176,8 @@ TryToFindInFileBasedCacheIfEnabled(const AutotuneCacheKey& key, VLOG(1) << "Autotune result file found: " << file_path; std::string autotune_result_str; - TF_RETURN_IF_ERROR(tsl::ReadFileToString(tsl::Env::Default(), file_path, - &autotune_result_str)); + RETURN_IF_ERROR(tsl::ReadFileToString(tsl::Env::Default(), file_path, + &autotune_result_str)); AutotuneResult result; if (!tsl::protobuf::TextFormat::ParseFromString(autotune_result_str, &result)) { @@ -278,8 +279,8 @@ TryFindInAllCacheTypes(const AutotuneCacheKey& key, absl::string_view cache_dir) return std::make_pair(CacheType::kInMemory, opt_result); } - TF_ASSIGN_OR_RETURN(opt_result, - TryToFindInFileBasedCacheIfEnabled(key, cache_dir)); + ASSIGN_OR_RETURN(opt_result, + TryToFindInFileBasedCacheIfEnabled(key, cache_dir)); if (opt_result.has_value()) { AddResultToInMemoryCache(key, opt_result.value()); return std::make_pair(CacheType::kOnDisk, opt_result); @@ -292,7 +293,7 @@ TryFindInAllCacheTypes(const AutotuneCacheKey& key, absl::string_view cache_dir) absl::StatusOr> AutotunerCache::TryFindInCache( const AutotuneCacheKey& key, absl::string_view cache_dir) ABSL_LOCKS_EXCLUDED(autotune_cache_mu) { - TF_ASSIGN_OR_RETURN(auto cached, TryFindInAllCacheTypes(key, cache_dir)); + ASSIGN_OR_RETURN(auto cached, TryFindInAllCacheTypes(key, cache_dir)); if (VLOG_IS_ON(1)) { std::string logged_key = @@ -320,7 +321,7 @@ absl::StatusOr AutotunerCache::AddResultToCaches( ABSL_LOCKS_EXCLUDED(autotune_cache_mu) { ResultAndInserted result_and_inserted = AddResultToInMemoryCache(key, result); if (result_and_inserted.inserted) { - TF_RETURN_IF_ERROR(AddResultToFileBasedCacheIfEnabled( + RETURN_IF_ERROR(AddResultToFileBasedCacheIfEnabled( key, result_and_inserted.result, cache_dir, autotune_cache_mode)); } return result_and_inserted; @@ -357,14 +358,14 @@ bool IsTextProtoPath(absl::string_view file_path) { } AddVersionToAutotuneResults(results); - TF_RETURN_IF_ERROR(LoadAutotuneResults(results, allow_override)); + RETURN_IF_ERROR(LoadAutotuneResults(results, allow_override)); return absl::OkStatus(); } /*static*/ absl::StatusOr AutotunerCache::SerializeAutotuneResults( bool as_textproto) { AutotuneResults results; - TF_RETURN_IF_ERROR(SerializeAutotuneResults(&results)); + RETURN_IF_ERROR(SerializeAutotuneResults(&results)); return AutotuneResultsToString(results, as_textproto); } @@ -379,11 +380,11 @@ bool IsTextProtoPath(absl::string_view file_path) { return FailedPrecondition("File path can not be resolved: %s", file_path); } - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( std::string autotune_results_str, AutotuneResultsToString(results, IsTextProtoPath(resolved_path))); - TF_RETURN_IF_ERROR(tsl::WriteStringToFile(tsl::Env::Default(), resolved_path, - autotune_results_str)); + RETURN_IF_ERROR(tsl::WriteStringToFile(tsl::Env::Default(), resolved_path, + autotune_results_str)); LOG(INFO) << "Autotune results serialized to file: " << resolved_path; return absl::OkStatus(); @@ -392,7 +393,7 @@ bool IsTextProtoPath(absl::string_view file_path) { /*static*/ absl::Status AutotunerCache::SerializeAutotuneResultsToFile( absl::string_view file_path) { AutotuneResults results; - TF_RETURN_IF_ERROR(SerializeAutotuneResults(&results)); + RETURN_IF_ERROR(SerializeAutotuneResults(&results)); return SerializeAutotuneResultsToFile(results, file_path); } @@ -410,11 +411,11 @@ bool IsTextProtoPath(absl::string_view file_path) { resolved_path); } std::string autotune_results_str; - TF_RETURN_IF_ERROR(tsl::ReadFileToString(tsl::Env::Default(), resolved_path, - &autotune_results_str)); + RETURN_IF_ERROR(tsl::ReadFileToString(tsl::Env::Default(), resolved_path, + &autotune_results_str)); - TF_RETURN_IF_ERROR(LoadAutotuneResults(autotune_results_str, - IsTextProtoPath(resolved_path))); + RETURN_IF_ERROR(LoadAutotuneResults(autotune_results_str, + IsTextProtoPath(resolved_path))); LOG(INFO) << "Autotune results loaded from file: " << resolved_path; diff --git a/third_party/xla/xla/service/gpu/autotuning/autotuner_cache_test.cc b/third_party/xla/xla/service/gpu/autotuning/autotuner_cache_test.cc index f715b9b709679e..a0052992aaa0e2 100644 --- a/third_party/xla/xla/service/gpu/autotuning/autotuner_cache_test.cc +++ b/third_party/xla/xla/service/gpu/autotuning/autotuner_cache_test.cc @@ -27,6 +27,7 @@ limitations under the License. #include "absl/status/status_matchers.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" +#include "xla/tsl/platform/status_macros.h" #include "google/protobuf/text_format.h" #include "xla/autotune_results.pb.h" #include "xla/autotuning.pb.h" @@ -151,7 +152,7 @@ ENTRY e { absl::Status PopulateResultCache() { EXPECT_TRUE(AutotunerCache::ResultCacheIsEmpty()); - TF_RETURN_IF_ERROR(AutotunerCache::LoadAutotuneResults(kResultText, true)); + RETURN_IF_ERROR(AutotunerCache::LoadAutotuneResults(kResultText, true)); EXPECT_FALSE(AutotunerCache::ResultCacheIsEmpty()); return absl::OkStatus(); } diff --git a/third_party/xla/xla/service/gpu/autotuning/autotuner_pass.cc b/third_party/xla/xla/service/gpu/autotuning/autotuner_pass.cc index b96bd5160f31c1..f9ab88ae8886e8 100644 --- a/third_party/xla/xla/service/gpu/autotuning/autotuner_pass.cc +++ b/third_party/xla/xla/service/gpu/autotuning/autotuner_pass.cc @@ -28,6 +28,7 @@ limitations under the License. #include "absl/memory/memory.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "xla/tsl/platform/status_macros.h" #include "mlir/IR/MLIRContext.h" #include "xla/backends/autotuner/autotuner.h" #include "xla/backends/autotuner/autotuner_cache_interface.h" @@ -290,13 +291,10 @@ AutotunerPass::GetGpuAutotunerBackends( std::vector disabled_autotune_backends; if (debug_options.xla_gpu_experimental_disable_binary_libraries()) { - disabled_autotune_backends.push_back(autotuner::Backend::CUBLAS); disabled_autotune_backends.push_back(autotuner::Backend::CUBLASLT); disabled_autotune_backends.push_back(autotuner::Backend::CUDNN); - disabled_autotune_backends.push_back(autotuner::Backend::ROCBLAS); disabled_autotune_backends.push_back(autotuner::Backend::HIPBLASLT); disabled_autotune_backends.push_back(autotuner::Backend::MIOPEN); - disabled_autotune_backends.push_back(autotuner::Backend::ROCBLAS_FISSION); disabled_autotune_backends.push_back(autotuner::Backend::HIPBLASLT_FISSION); } @@ -397,7 +395,7 @@ absl::StatusOr> AutotunerPass::Create( cache_dir, debug_options.xla_gpu_experimental_autotune_cache_mode(), target_config->device_description); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( std::unique_ptr autotuner, Autotuner::Create(std::move(backends), std::move(profiler), autotune_config, std::move(cache), thread_pool)); @@ -415,10 +413,10 @@ absl::StatusOr AutotunerPass::RunImpl( bool shard_autotuning = enable_sharding_ && key_value_store_.process_count > 1; if (shard_autotuning) { - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( autotuner_->Autotune(module, should_autotune_, key_value_store_)); } else { - TF_RETURN_IF_ERROR(autotuner_->Autotune(module, should_autotune_)); + RETURN_IF_ERROR(autotuner_->Autotune(module, should_autotune_)); } VLOG(1) << "Autotuner cache stats: hits=" << autotuner_->GetCacheStats().hits << ", misses=" << autotuner_->GetCacheStats().misses; diff --git a/third_party/xla/xla/service/gpu/autotuning/autotuner_pass_test.cc b/third_party/xla/xla/service/gpu/autotuning/autotuner_pass_test.cc index 3d20e1269c8284..17d338b053aecc 100644 --- a/third_party/xla/xla/service/gpu/autotuning/autotuner_pass_test.cc +++ b/third_party/xla/xla/service/gpu/autotuning/autotuner_pass_test.cc @@ -30,7 +30,7 @@ limitations under the License. #include "xla/backends/autotuner/autotuner.h" #include "xla/backends/autotuner/codegen_backend.h" #include "xla/backends/autotuner/profiler.h" -#include "xla/backends/gpu/autotuner/cublas.h" +#include "xla/backends/gpu/autotuner/cublaslt.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" @@ -81,7 +81,7 @@ ENTRY %main (arg0: f32[100,100], arg1: f32[100,100]) -> f32[100,100] { %arg0 = f32[100,100]{1,0} parameter(0) %arg1 = f32[100,100]{1,0} parameter(1) %custom-call.1 = (f32[100,100]{1,0}, s8[80000]{0}) custom-call(%arg0, %arg1), - custom_call_target="__cublas$gemm", + custom_call_target="__cublas$lt$matmul", backend_config={ "gemm_backend_config":{ "dot_dimension_numbers": @@ -104,7 +104,7 @@ TEST_F(AutotunerPassTest, CublasGemmIsAutotuned) { /*num_threads=*/4); std::vector> backends; GpuCompiler::GpuTargetConfig target_config(stream_executor_); - backends.push_back(std::make_unique( + backends.push_back(std::make_unique( stream_executor_, &module->config().debug_options(), &compiler_, &target_config)); @@ -149,7 +149,7 @@ TEST_F(AutotunerPassTest, CublasGemmIsAutotunedAndCached) { // Run the pass for the first time, this should populate the cache. { std::vector> backends; - backends.push_back(std::make_unique( + backends.push_back(std::make_unique( stream_executor_, &module->config().debug_options(), &compiler_, &target_config)); @@ -197,7 +197,7 @@ TEST_F(AutotunerPassTest, CublasGemmIsAutotunedAndCached) { .set_xla_gpu_require_complete_aot_autotune_results(true); { std::vector> backends2; - backends2.push_back(std::make_unique( + backends2.push_back(std::make_unique( stream_executor_, &module_2->config().debug_options(), &compiler_, &target_config)); @@ -250,7 +250,7 @@ TEST_F(AutotunerPassTest, CublasGemmIsAutotunedWithCacheOnly) { // Run the pass for the first time, this should populate the cache. { std::vector> backends; - backends.push_back(std::make_unique( + backends.push_back(std::make_unique( stream_executor_, &module->config().debug_options(), &compiler_, &target_config)); @@ -283,7 +283,7 @@ TEST_F(AutotunerPassTest, CublasGemmIsAutotunedWithCacheOnly) { { std::vector> backends2; - backends2.push_back(std::make_unique( + backends2.push_back(std::make_unique( stream_executor_, &module_2->config().debug_options(), &compiler_, &target_config)); @@ -330,7 +330,7 @@ TEST_F(AutotunerPassTest, DevicelessUsesDefaultConfigIfNoCache) { GpuCompiler::GpuTargetConfig target_config(stream_executor_); std::vector> backends; - backends.push_back(std::make_unique( + backends.push_back(std::make_unique( stream_executor_, &module->config().debug_options(), &compiler_, &target_config)); @@ -367,7 +367,7 @@ ENTRY %main (arg0: f32[100,100], arg1: f32[100,100]) -> f32[100,100] { %arg0 = f32[100,100]{1,0} parameter(0) %arg1 = f32[100,100]{1,0} parameter(1) %custom-call.1 = (f32[100,100]{1,0}, s8[80000]{0}) custom-call(%arg0, %arg1), - custom_call_target="__cublas$gemm", + custom_call_target="__cublas$lt$matmul", backend_config={ "operation_queue_id":"109", "gemm_backend_config":{ @@ -392,7 +392,7 @@ ENTRY %main (arg0: f32[100,100], arg1: f32[100,100]) -> f32[100,100] { std::vector> backends; GpuCompiler::GpuTargetConfig target_config(stream_executor_); - backends.push_back(std::make_unique( + backends.push_back(std::make_unique( stream_executor_, &module->config().debug_options(), &compiler_, &target_config)); diff --git a/third_party/xla/xla/service/gpu/autotuning/redzone_buffers.cc b/third_party/xla/xla/service/gpu/autotuning/redzone_buffers.cc index 8c1666480d0739..5fbb9801c33488 100644 --- a/third_party/xla/xla/service/gpu/autotuning/redzone_buffers.cc +++ b/third_party/xla/xla/service/gpu/autotuning/redzone_buffers.cc @@ -25,6 +25,7 @@ limitations under the License. #include "absl/log/check.h" #include "absl/status/status.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/executable_run_options.h" #include "xla/hlo/ir/hlo_clone_context.h" #include "xla/hlo/ir/hlo_computation.h" @@ -84,14 +85,14 @@ absl::StatusOr RedzoneBuffers::FromProgramShape( int64_t rng_state = 0; - TF_RETURN_IF_ERROR(buffers.CreateInputs(program_shape.parameters(), - should_init_buffers, rng_state)); + RETURN_IF_ERROR(buffers.CreateInputs(program_shape.parameters(), + should_init_buffers, rng_state)); if (buffers_to_create == BuffersToCreate::kAllInputsAllOutputs || buffers_to_create == BuffersToCreate::kAllInputsOutputsNoScratch) { - TF_RETURN_IF_ERROR(buffers.CreateOutputs(program_shape.result(), - buffers_to_create, - should_init_buffers, rng_state)); + RETURN_IF_ERROR(buffers.CreateOutputs(program_shape.result(), + buffers_to_create, + should_init_buffers, rng_state)); } return buffers; } @@ -101,9 +102,9 @@ absl::Status RedzoneBuffers::CreateInputs(absl::Span input_shapes, int64_t& rng_state) { tsl::profiler::TraceMe traceme("create inputs"); for (const auto& input_shape : input_shapes) { - TF_ASSIGN_OR_RETURN(se::DeviceAddressBase buf, - redzone_allocator_->CreateBuffer( - input_shape, should_init_buffers, rng_state)); + ASSIGN_OR_RETURN(se::DeviceAddressBase buf, + redzone_allocator_->CreateBuffer( + input_shape, should_init_buffers, rng_state)); input_buffers_.push_back(buf); input_shapes_.push_back(input_shape); } @@ -116,9 +117,9 @@ absl::Status RedzoneBuffers::CreateOutputs(const Shape& output_shape, int64_t& rng_state) { tsl::profiler::TraceMe traceme("create outputs"); if (!output_shape.IsTuple()) { - TF_ASSIGN_OR_RETURN(se::DeviceAddressBase buf, - redzone_allocator_->CreateBuffer( - output_shape, should_init_buffers, rng_state)); + ASSIGN_OR_RETURN(se::DeviceAddressBase buf, + redzone_allocator_->CreateBuffer( + output_shape, should_init_buffers, rng_state)); output_buffers_.push_back(buf); output_shape_ = output_shape; return absl::OkStatus(); @@ -139,9 +140,9 @@ absl::Status RedzoneBuffers::CreateOutputs(const Shape& output_shape, if (current_shape_it->IsTuple()) { return Unimplemented("Nested tuples are unsupported by RedzoneBuffers."); } - TF_ASSIGN_OR_RETURN(se::DeviceAddressBase buf, - redzone_allocator_->CreateBuffer( - *current_shape_it, should_init_buffers, rng_state)); + ASSIGN_OR_RETURN(se::DeviceAddressBase buf, + redzone_allocator_->CreateBuffer( + *current_shape_it, should_init_buffers, rng_state)); output_buffers_.push_back(buf); } return absl::OkStatus(); diff --git a/third_party/xla/xla/service/gpu/determinism_test.cc b/third_party/xla/xla/service/gpu/determinism_test.cc index 2a57d3a33aaa4e..6b23827346f02c 100644 --- a/third_party/xla/xla/service/gpu/determinism_test.cc +++ b/third_party/xla/xla/service/gpu/determinism_test.cc @@ -173,14 +173,6 @@ class DeterminismTest : public HloPjRtGpuTestBase { }; TEST_F(DeterminismTest, CublasLtDot) { - if (IsRocm()) { - if (!HasHipblasLt()) { - GTEST_SKIP() << "No hipblas-lt support on this architecture!"; - } - } - - // This test expects to use CublasLt. Disable other backends, including - // Triton. debug_options_.clear_xla_gpu_experimental_autotune_backends(); debug_options_.add_xla_gpu_experimental_autotune_backends( autotuner::Backend::CUBLASLT); @@ -195,7 +187,6 @@ ENTRY e { if (!HasHipblasLt()) { GTEST_SKIP() << "No hipblas-lt support on this architecture!"; } - debug_options_.clear_xla_gpu_experimental_autotune_backends(); } debug_options_.set_xla_gpu_enable_triton_gemm(false); diff --git a/third_party/xla/xla/service/gpu/gpu_compiler.cc b/third_party/xla/xla/service/gpu/gpu_compiler.cc index 8695cfbcd8b31b..2a901c937f1218 100644 --- a/third_party/xla/xla/service/gpu/gpu_compiler.cc +++ b/third_party/xla/xla/service/gpu/gpu_compiler.cc @@ -698,7 +698,8 @@ absl::Status RunOptimizationPasses( DebugOptions::DETECTION_MODE_NONE) { pipeline.AddPass(); } - pipeline.AddPass(gpu_version); + pipeline.AddPass(gpu_version, + gpu_target_config.dnn_version_info); if (!debug_options.xla_gpu_experimental_scaled_dot_with_triton()) { pipeline.AddPass(); } @@ -2251,6 +2252,43 @@ bool UsesCollectiveMemorySpaceFrontendAttr(const HloUse& use) { return false; } +bool DefinesCollectiveMemorySpaceFrontendAttr(const HloValue* value) { + const HloInstruction* def = value->defining_instruction(); + if (def->opcode() != HloOpcode::kCustomCall) { + return false; + } + + auto attr = def->get_frontend_attribute(kResultsMemorySpacesAttr); + if (!attr.has_value()) { + return false; + } + + auto pairs = ParseIndexMemorySpacePairs(*attr); + if (!pairs.ok()) { + return false; + } + + // Determine the logical result index. If the custom call returns a tuple, + // we look at the top-level index (e.g., element 0 or 1 of the tuple). + int64_t result_index = 0; + if (def->shape().IsTuple()) { + if (value->defining_index().empty()) { + // The buffer for the tuple pointer array itself is not S1. + return false; + } + result_index = value->defining_index()[0]; + } + + for (auto [index, memory_space] : *pairs) { + if (index == result_index && + memory_space == MemorySpaceColor::kCollective) { + return true; + } + } + + return false; +} + bool ShouldAddCopyForCollectiveMemorySpace(const HloValue* value, const GpuTopology& gpu_topology) { const HloInstruction* inst = value->defining_instruction(); @@ -2265,8 +2303,7 @@ bool ShouldAddCopyForCollectiveMemorySpace(const HloValue* value, if (IsCollectiveMosaicGpuInstruction(*use.instruction) || (gpu_topology.num_partitions() > gpu_topology.num_devices_per_host() && - IsMosaicWithCollectiveMetadata(*use.instruction)) || - UsesCollectiveMemorySpaceFrontendAttr(use)) { + IsMosaicWithCollectiveMetadata(*use.instruction))) { return true; } } @@ -2300,11 +2337,16 @@ bool RequiresCollectiveInput(const HloUse& use, const DebugOptions& opts) { return true; } + // Check custom calls with operands_memory_spaces attribute + if (UsesCollectiveMemorySpaceFrontendAttr(use)) { + return true; + } + return false; } -bool RequiresCollectiveOutput(const HloInstruction* def, - const DebugOptions& opts) { +bool RequiresCollectiveOutput(const HloValue* value, const DebugOptions& opts) { + HloInstruction* def = value->defining_instruction(); const bool is_nccl_buffers_used = opts.xla_gpu_enable_nccl_user_buffers() || opts.xla_gpu_experimental_enable_nccl_symmetric_buffers(); @@ -2327,12 +2369,17 @@ bool RequiresCollectiveOutput(const HloInstruction* def, return true; } + // Check custom calls with results_memory_spaces attribute + if (DefinesCollectiveMemorySpaceFrontendAttr(value)) { + return true; + } + return false; } // TODO: b/482045400: Migrate remaining cases from // ShouldAddCopyForCollectiveMemorySpace -// (IsCollectiveMosaicGpuInstruction, UsesCollectiveMemorySpaceFrontendAttr) +// (IsCollectiveMosaicGpuInstruction) // to this function from ShouldAddCopyForCollectiveMemorySpace void GpuCollectiveBufferAnalysis( HloModule* module, const HloAliasAnalysis& alias_analysis, @@ -2374,7 +2421,7 @@ void GpuCollectiveBufferAnalysis( live_out_values.push_back(value); } - if (RequiresCollectiveOutput(def, opts)) { + if (RequiresCollectiveOutput(value, opts)) { defined_by_collective = true; } } diff --git a/third_party/xla/xla/service/gpu/gpu_compiler_test.cc b/third_party/xla/xla/service/gpu/gpu_compiler_test.cc index 1505af3cf41748..1a0d732e92de82 100644 --- a/third_party/xla/xla/service/gpu/gpu_compiler_test.cc +++ b/third_party/xla/xla/service/gpu/gpu_compiler_test.cc @@ -614,11 +614,11 @@ ENTRY main { EXPECT_EQ(operand_1->operand(0)->opcode(), HloOpcode::kAllGatherDone); } -// This test ensures that the pathway for using the cuBLAS fallback (forming a -// Triton fusion and falling back to cuBLAS in the autotuner) is exactly the -// same as using cuBLAS directly (with Triton disabled). +// This test ensures that the pathway for using the cuBlasLt fallback (forming a +// Triton fusion and falling back to cuBlasLt in the autotuner) is exactly the +// same as using cuBLasLt directly (with Triton disabled). TEST_F(GpuCompilerTest, - GemmFusionIsNoOpWhenGemmFusionAutotunerFallsBackToCublas) { + GemmFusionIsNoOpWhenGemmFusionAutotunerFallsBackToCublasLt) { if (!get_cuda_cc().IsAtLeastAmpere()) { GTEST_SKIP() << "Autotuning results have only been generated for Ampere " << "and later GPUs"; @@ -655,8 +655,6 @@ ENTRY main { // Triton enabled, but forced to fallback to cuBLAS (no Triton backend). DebugOptions triton_enabled_debug_options = GetDebugOptionsForTest(); triton_enabled_debug_options.clear_xla_gpu_experimental_autotune_backends(); - triton_enabled_debug_options.add_xla_gpu_experimental_autotune_backends( - autotuner::Backend::CUBLAS_FISSION); triton_enabled_debug_options.add_xla_gpu_experimental_autotune_backends( autotuner::Backend::CUBLASLT_FISSION); triton_enabled_debug_options.add_xla_gpu_experimental_autotune_backends( @@ -671,9 +669,8 @@ ENTRY main { AutotuneResults results; ASSERT_OK(AutotunerCache::SerializeAutotuneResults(&results)); EXPECT_FALSE(results.results().empty()); - EXPECT_TRUE(absl::StrContains(results.DebugString(), "CUBLAS_FISSION") || - // CUBLASLT_FISSION is dumped as GemmKey in the AutotunerResult. - absl::StrContains(results.DebugString(), "gemm")); + // CUBLASLT_FISSION is dumped as GemmKey in the AutotunerResult. + EXPECT_TRUE(absl::StrContains(results.DebugString(), "gemm")); // Triton disabled - this will skip the GemmFusion pass and use cuBLAS. DebugOptions triton_disabled_debug_options = GetDebugOptionsForTest(); @@ -688,10 +685,8 @@ ENTRY main { const HloInstruction* root = triton_enabled_module->entry_computation()->root_instruction(); const HloInstruction* custom_op = root->operand(0)->operand(0); - bool is_cublas_gemm = GetDebugOptionsForTest().xla_gpu_enable_cublaslt() - ? custom_op->IsCustomCall("__cublas$lt$matmul") - : custom_op->IsCustomCall("__cublas$gemm"); - EXPECT_TRUE(is_cublas_gemm) << custom_op->ToString(); + EXPECT_TRUE(custom_op->IsCustomCall("__cublas$lt$matmul")) + << custom_op->ToString(); // Make sure that the module has the same number of computations with/without // enabling triton gemm EXPECT_EQ(triton_enabled_module->computation_count(), @@ -720,8 +715,6 @@ ENTRY main { // Triton enabled, but forced to fallback to cuBLAS (no Triton backend). DebugOptions triton_enabled_debug_options = GetDebugOptionsForTest(); triton_enabled_debug_options.clear_xla_gpu_experimental_autotune_backends(); - triton_enabled_debug_options.add_xla_gpu_experimental_autotune_backends( - autotuner::Backend::CUBLAS_FISSION); triton_enabled_debug_options.add_xla_gpu_experimental_autotune_backends( autotuner::Backend::CUBLASLT_FISSION); triton_enabled_debug_options.add_xla_gpu_experimental_autotune_backends( @@ -732,7 +725,7 @@ ENTRY main { auto triton_enabled_executable = triton_enabled_module_and_executable.second.get(); - // Triton disabled - this will skip the GemmFusion pass and use cuBLAS. + // Triton disabled - this will skip the GemmFusion pass and use cuBlasLt. DebugOptions triton_disabled_debug_options = GetDebugOptionsForTest(); triton_disabled_debug_options.set_xla_gpu_enable_triton_gemm(false); config.set_debug_options(triton_disabled_debug_options); @@ -815,7 +808,7 @@ ENTRY main { R"(CHECK: custom-call($0{{[^)]*}}, $1{{[^)]*}}){{.*}}custom_call_target="__cublas$$lt$$matmul$$f8")", lhs_name, rhs_name); const std::string cublas_convert_to_f16 = - R"(CHECK: custom-call(f16{{.*}}, f16{{.*}}){{.*}}custom_call_target="{{__cublas\$gemm|__cublas\$lt\$matmul}}")"; + R"(CHECK: custom-call(f16{{.*}}, f16{{.*}}){{.*}}custom_call_target="{{__cublas\$lt\$matmul}}")"; const std::string fallback_convert_to_f16 = R"(CHECK: dot(f16{{[^)]*}}, f16{{[^)]*}}))"; @@ -1029,7 +1022,7 @@ HloModule composite async_call { p0 = f32[32,32] parameter(0) p1 = f32[32,32] parameter(1) - gemm = (f32[32,32], s8[8192]) custom-call(p0, p1), custom_call_target="__cublas$gemm", + gemm = (f32[32,32], s8[8192]) custom-call(p0, p1), custom_call_target="__cublas$lt$matmul", backend_config={ "gemm_backend_config":{"alpha_real":1,"alpha_imag":0,"beta":0, "dot_dimension_numbers": @@ -1071,7 +1064,7 @@ ENTRY main { gpu_exec->thunk_executor().thunks()[0].get()); EXPECT_EQ(async_start_thunk->thunks().size(), 1); EXPECT_THAT(async_start_thunk->thunks(), - ::testing::ElementsAre(ThunkKindIs(Thunk::kGemm))); + ::testing::ElementsAre(ThunkKindIs(Thunk::kCublasLtMatmul))); } TEST_F(GpuCompilerTest, StreamAnnotationThunkTestFDO) { @@ -1081,7 +1074,7 @@ HloModule composite async_call { p0 = f32[32,32] parameter(0) p1 = f32[32,32] parameter(1) - gemm = (f32[32,32], s8[8192]) custom-call(p0, p1), custom_call_target="__cublas$gemm", + gemm = (f32[32,32], s8[8192]) custom-call(p0, p1), custom_call_target="__cublas$lt$matmul", backend_config={ "gemm_backend_config":{"alpha_real":1,"alpha_imag":0,"beta":0, "dot_dimension_numbers": @@ -1131,7 +1124,7 @@ ENTRY main { gpu_exec->thunk_executor().thunks()[0].get()); EXPECT_EQ(async_start_thunk->thunks().size(), 1); EXPECT_THAT(async_start_thunk->thunks(), - ::testing::ElementsAre(ThunkKindIs(Thunk::kGemm))); + ::testing::ElementsAre(ThunkKindIs(Thunk::kCublasLtMatmul))); } using GpuCompilerPassTest = GpuCompilerTest; @@ -2799,5 +2792,151 @@ ENTRY entry { EXPECT_GT(MultiModuleDriver::GetCompileCount(), 0); } + +static absl::Status MockCustomCallExecuteF32( + ffi::BufferR1 src, ffi::Result> dst) { + return absl::OkStatus(); +} + +XLA_FFI_DEFINE_HANDLER( + kMockCustomCallExecuteF32, MockCustomCallExecuteF32, + ffi::Ffi::Bind().Arg>().Ret>()); + +XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "__xla_test_mock_custom_call_f32", + "gpu", + XLA_FFI_Handler_Bundle{ + /*instantiate=*/nullptr, + /*prepare=*/nullptr, + /*initialize=*/nullptr, + /*execute=*/kMockCustomCallExecuteF32, + }); + +class FrontendAttributesMemorySpaceTest + : public GpuCompilerTest, + public ::testing::WithParamInterface {}; + +TEST_P(FrontendAttributesMemorySpaceTest, DirectUsage) { + constexpr absl::string_view kHloTemplate = R"( + HloModule test$0 + + ENTRY test_computation { + p = f32[16] parameter(0) + ROOT cc = f32[16] custom-call(p), + custom_call_target="__xla_test_mock_custom_call_f32", + api_version=API_VERSION_TYPED_FFI, + frontend_attributes={ + operands_memory_spaces="{0:1}", + results_memory_spaces="{0:1}" + } + } + )"; + + bool use_input_output_alias = GetParam(); + // Inject the alias layout configuration into the HLO + std::string hlo_text = absl::StrReplaceAll( + kHloTemplate, + {{"$0", + use_input_output_alias ? ", input_output_alias={ {}: (0, {}) }" : ""}}); + + HloModuleConfig config = GetModuleConfigForTest(); + + std::pair> + optimized_module_and_executable; + ASSERT_OK_AND_ASSIGN(optimized_module_and_executable, + GetOptimizedModuleForExecutable(hlo_text, config)); + + const HloModule* optimized_module = optimized_module_and_executable.first; + + // Regardless of the input_output_alias it should be two copies + constexpr absl::string_view expected_check = R"( + // CHECK: %p = f32[16]{0} parameter(0) + // CHECK: [[COPY0:%copy[0-9.]*]] = f32[16]{0:S(1)} copy(%p) + // CHECK: %cc = f32[16]{0:S(1)} custom-call([[COPY0]]) + // CHECK: ROOT %copy{{.*}} = f32[16]{0} copy(%cc) + )"; + + EXPECT_THAT(RunFileCheck( + optimized_module->ToString(HloPrintOptions{} + .set_print_operand_shape(false) + .set_print_metadata(false)), + expected_check), + absl_testing::IsOkAndHolds(true)); +} + +TEST_P(FrontendAttributesMemorySpaceTest, LoopUsage) { + constexpr absl::string_view kHloTemplate = R"( + HloModule test$0 + + while_condition { + params = (s32[], f32[16]) parameter(0) + loop_counter = s32[] get-tuple-element(params), index=0 + limit = s32[] constant(3) + ROOT result = pred[] compare(loop_counter, limit), direction=LT + } + + while_body { + params = (s32[], f32[16]) parameter(0) + loop_counter = s32[] get-tuple-element(params), index=0 + cc_input = f32[16] get-tuple-element(params), index=1 + cc = f32[16] custom-call(cc_input), + custom_call_target="__xla_test_mock_custom_call_f32", + api_version=API_VERSION_TYPED_FFI, + frontend_attributes={ + operands_memory_spaces="{0:1}", + results_memory_spaces="{0:1}" + } + new_loop_counter = s32[] add(loop_counter, s32[] constant(1)) + ROOT result = (s32[], f32[16]) tuple(new_loop_counter, cc) + } + + ENTRY entry_computation { + init_loop_counter = s32[] constant(0) + input = f32[16] parameter(0) + while_init = tuple(init_loop_counter, input) + while = (s32[], f32[16]) while(while_init), condition=while_condition, body=while_body + ROOT result = get-tuple-element(while), index=1 + } + )"; + + bool use_input_output_alias = GetParam(); + // Inject the alias layout configuration into the HLO + std::string hlo_text = absl::StrReplaceAll( + kHloTemplate, + {{"$0", + use_input_output_alias ? ", input_output_alias={ {}: (0, {}) }" : ""}}); + + HloModuleConfig config = GetModuleConfigForTest(); + + std::pair> + optimized_module_and_executable; + ASSERT_OK_AND_ASSIGN(optimized_module_and_executable, + GetOptimizedModuleForExecutable(hlo_text, config)); + + const HloModule* optimized_module = optimized_module_and_executable.first; + + // Regardless of the input_output_alias it should be two copies + constexpr absl::string_view expected_check = R"( + // CHECK: %input = f32[16]{0} parameter(0) + // CHECK: [[COPY0:%copy[0-9.]*]] = f32[16]{0:S(1)} copy(%input) + // CHECK: %tuple = (s32[], f32[16]{0:S(1)}) tuple(%copy{{.*}}, [[COPY0]]) + // CHECK: %while = (s32[], f32[16]{0:S(1)}) while(%tuple) + // CHECK: [[RESULT:%result[0-9.]*]] = f32[16]{0:S(1)} get-tuple-element(%while), index=1 + // CHECK: ROOT %copy{{.*}} = f32[16]{0} copy([[RESULT]]) + )"; + + EXPECT_THAT(RunFileCheck( + optimized_module->ToString(HloPrintOptions{} + .set_print_operand_shape(false) + .set_print_metadata(false)), + expected_check), + absl_testing::IsOkAndHolds(true)); +} + +INSTANTIATE_TEST_SUITE_P(FrontendAttributesMemorySpace, + FrontendAttributesMemorySpaceTest, ::testing::Bool(), + [](const ::testing::TestParamInfo& info) { + return info.param ? "with_alias" : "no_alias"; + }); + } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/gpu_executable.cc b/third_party/xla/xla/service/gpu/gpu_executable.cc index 67b84c828a3cf3..007c4bd23ad4a5 100644 --- a/third_party/xla/xla/service/gpu/gpu_executable.cc +++ b/third_party/xla/xla/service/gpu/gpu_executable.cc @@ -32,7 +32,6 @@ limitations under the License. #include "absl/base/nullability.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" -#include "absl/container/inlined_vector.h" #include "absl/log/check.h" #include "absl/log/log.h" #include "absl/status/status.h" @@ -60,6 +59,7 @@ limitations under the License. #include "xla/backends/gpu/runtime/command.h" #include "xla/backends/gpu/runtime/command_buffer_conversion_pass.h" #include "xla/backends/gpu/runtime/command_buffer_thunk.h" +#include "xla/backends/gpu/runtime/execution_stream_id.h" #include "xla/backends/gpu/runtime/nvshmem_collective_thunk.h" #include "xla/backends/gpu/runtime/scratch_memory.h" #include "xla/backends/gpu/runtime/scratch_memory_requests.h" @@ -100,6 +100,7 @@ limitations under the License. #include "xla/service/hlo_value.h" #include "xla/service/logical_buffer.h" #include "xla/service/maybe_owning_device_address.h" +#include "xla/service/rendezvous.h" #include "xla/service/riegeli_dump_writer.h" #include "xla/service/service_executable_run_options.h" #include "xla/service/shaped_buffer.h" @@ -117,25 +118,25 @@ limitations under the License. #include "xla/stream_executor/device_description.h" #include "xla/stream_executor/event_based_timer.h" #include "xla/stream_executor/kernel_stats.h" +#include "xla/stream_executor/memory_allocation.h" +#include "xla/stream_executor/memory_reservation.h" #include "xla/stream_executor/module_spec.h" #include "xla/stream_executor/platform.h" +#include "xla/stream_executor/platform_id.h" #include "xla/stream_executor/rocm/rocm_platform_id.h" #include "xla/stream_executor/scoped_module_handle.h" #include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_executor.h" #include "xla/stream_executor/sycl/sycl_platform_id.h" #include "xla/stream_executor/vmm_device_address_allocator.h" -#include "xla/tsl/platform/env.h" #include "xla/tsl/platform/env_time.h" #include "xla/tsl/platform/logging.h" -#include "xla/tsl/platform/statusor.h" #include "xla/tsl/util/sorted_range.h" #include "xla/util.h" #include "xla/util/split_proto/split_executable_and_options_writer.h" #include "xla/util/split_proto/split_gpu_executable_writer.h" #include "xla/xla.pb.h" #include "xla/xla_data.pb.h" -#include "tsl/platform/casts.h" #include "tsl/platform/random.h" #include "tsl/profiler/lib/scoped_annotation.h" #include "tsl/profiler/lib/traceme.h" @@ -235,21 +236,56 @@ absl::StatusOr ShouldCollectiveUseMinimalResource( using ::tsl::profiler::ScopedAnnotation; -// Returns the number of additional compute streams needed by all -// `AsyncStartThunks` in the `GpuExecutable`. -static int64_t GetNumAdditionalComputeStreams(ThunkExecutor& executor) { - int64_t num_streams = 0; - CHECK_OK(executor.thunks().WalkNested([&](Thunk* thunk) -> absl::Status { - if (auto* async_start = dynamic_cast(thunk)) { - auto stream_id = async_start->execution_stream_id(); - if (stream_id.is_computation()) { - num_streams = std::max(num_streams, - stream_id.computation_id().value() + 1); +constexpr int kAsyncStreamTotal = + static_cast(AsyncStreamKind::ASYNC_STREAM_KIND_MEMCPYP2P) + 1; + +// Returns the number of additional streams to allocate for a `GpuExecutable`. +static GpuExecutable::NumAdditionalStreams GetNumAdditionalStreams( + ThunkExecutor& executor, const DebugOptions& opts) { + // First initialize based on what was requested via the DebugOptions. + int compute = opts.xla_gpu_executable_num_compute_streams(); + int comm = opts.xla_gpu_executable_num_communication_streams(); + + // Clamp it to minimum number of required streams. + compute = std::max(0, compute); + comm = std::max(kAsyncStreamTotal, comm); + + // Then traverse all thunks to see if anyone requested more streams. + for (const auto& thunk : executor.thunks()) { + thunk->Walk([&](Thunk* nested) { + if (auto* async_start = dynamic_cast(nested)) { + ExecutionStreamId id = async_start->execution_stream_id(); + if (id.is_computation()) { + compute = std::max(compute, id.computation_id().value() + 1); + } else { + comm = std::max(comm, id.communication_id().value() + 1); + } } - } - return absl::OkStatus(); - })); - return num_streams; + }); + } + + return {compute, comm}; +} + +GpuExecutable::BorrowedStreams GpuExecutable::BorrowedStreams::Assign( + se::Stream* stream, int num_streams) { + return BorrowedStreams{std::vector(num_streams, stream), {}}; +} + +absl::StatusOr GpuExecutable::BorrowStreams( + const ServiceExecutableRunOptions& run_options, int device_ordinal, + int num_streams, se::StreamPriority priority) { + ASSIGN_OR_RETURN( + std::vector owners, + run_options.BorrowStreams(device_ordinal, num_streams, priority)); + + std::vector streams; + streams.reserve(num_streams); + for (auto& stream : owners) { + streams.push_back(stream.get()); + } + + return BorrowedStreams{std::move(streams), std::move(owners)}; } static absl::Status RunThunkPasses(const DebugOptions& debug_options, @@ -373,7 +409,7 @@ GpuExecutable::GpuExecutable( absl::flat_hash_map output_info, bool enable_debug_info_manager, ModuleStats module_stats, absl::StatusOr> thunk_sequence_proto, - stream_executor::ExecutableAbiVersion executable_abi_version, + se::ExecutableAbiVersion executable_abi_version, std::optional cpu_target_machine_options, std::optional buffer_assignment_proto) : Executable(std::move(debug_module)), @@ -382,8 +418,8 @@ GpuExecutable::GpuExecutable( dnn_compiled_graphs_(std::move(dnn_compiled_graphs)), gpu_version_(device_description.gpu_compute_capability()), thunk_executor_(std::move(executable)), - num_additional_compute_streams_( - GetNumAdditionalComputeStreams(*thunk_executor_)), + num_additional_streams_( + GetNumAdditionalStreams(*thunk_executor_, debug_options)), module_name_(std::move(module_name)), program_shape_(std::move(program_shape)), allocation_ptrs_(GatherAllocationPtrs( @@ -472,9 +508,8 @@ absl::Status GpuExecutable::CheckCompatibilityWithServiceExecutableRunOptions( const ServiceExecutableRunOptions* run_options) { se::Stream* main_stream = run_options->stream(); - stream_executor::Platform::Id platform_id = - main_stream->parent()->GetPlatform()->id(); - if (platform_id == stream_executor::rocm::kROCmPlatformId) { + se::PlatformId platform_id = main_stream->parent()->GetPlatform()->id(); + if (platform_id == se::rocm::kROCmPlatformId) { auto cc = main_stream->GetRocmComputeCapability(); std::string stream_arch = cc.gcn_arch_name(); std::string gpu_exec_arch = @@ -482,12 +517,12 @@ absl::Status GpuExecutable::CheckCompatibilityWithServiceExecutableRunOptions( TF_RET_CHECK(stream_arch == gpu_exec_arch) << "AMDGPU GCN ISA version mismatch; expected {" << gpu_exec_arch << ", but was " << stream_arch; - } else if (platform_id == stream_executor::cuda::kCudaPlatformId) { + } else if (platform_id == se::cuda::kCudaPlatformId) { se::CudaComputeCapability cc = main_stream->GetCudaComputeCapability(); TF_RET_CHECK(cc == *gpu_version_.cuda_compute_capability()) << "Compute capability mismatch; expected {" << gpu_version_.ToString() << "}, but was {" << cc.ToString() << "}"; - } else if (platform_id == stream_executor::sycl::kSyclPlatformId) { + } else if (platform_id == se::sycl::kSyclPlatformId) { // TODO: Add check. } else { return Internal("Unknown platform"); @@ -496,8 +531,6 @@ absl::Status GpuExecutable::CheckCompatibilityWithServiceExecutableRunOptions( return absl::OkStatus(); } -namespace { - absl::Status MaybeSyncAndProfile(const ServiceExecutableRunOptions* run_options, se::EventBasedTimer* execution_timer, se::Stream* stream_to_sync); @@ -512,18 +545,16 @@ absl::Status BarrierAfterExecutable( const DebugOptions* absl_nullable debug_options, se::Stream& stream_to_sync, size_t num_participants); -absl::Status ExecuteThunksImpl(const DebugOptions* debug_options, - const std::string& module_name, - ModuleIdentifier module_id, - ThunkExecutor& thunk_executor, - Thunk::ExecutableSource executable_source, - const ServiceExecutableRunOptions* run_options, - const BufferAllocations& buffer_allocations, - bool block_host_until_done, - int64_t num_additional_compute_streams, - CollectiveMemoryCache& collective_memory_cache, - bool collective_use_minimal_resource, - RendezvousFlag& post_init_rendezvous_flag) { +absl::Status GpuExecutable::ExecuteThunksImpl( + const DebugOptions* debug_options, const std::string& module_name, + ModuleIdentifier module_id, ThunkExecutor& thunk_executor, + Thunk::ExecutableSource executable_source, + const ServiceExecutableRunOptions* run_options, + const BufferAllocations& buffer_allocations, bool block_host_until_done, + GpuExecutable::NumAdditionalStreams num_additional_streams, + CollectiveMemoryCache& collective_memory_cache, + bool collective_use_minimal_resource, + RendezvousFlag& post_init_rendezvous_flag) { bool mock_collectives = run_options->run_options().gpu_executable_run_options() ? run_options->run_options() @@ -543,14 +574,15 @@ absl::Status ExecuteThunksImpl(const DebugOptions* debug_options, se::Stream* main_stream = run_options->stream(); se::StreamExecutor* executor = main_stream->parent(); - stream_executor::StreamPriority stream_priority = - stream_executor::StreamPriority::Default; + se::StreamPriority communication_stream_priority = + se::StreamPriority::Default; + // TODO(intel-tf): Enable stream priorities for sycl backend. - if (executor->GetPlatform()->id() == stream_executor::sycl::kSyclPlatformId) { + if (executor->GetPlatform()->id() == se::sycl::kSyclPlatformId) { use_highest_priority_for_async_stream = false; } if (use_highest_priority_for_async_stream) { - stream_priority = stream_executor::StreamPriority::Highest; + communication_stream_priority = se::StreamPriority::Highest; } // Maybe install progress tracker for this execution. @@ -632,53 +664,38 @@ absl::Status ExecuteThunksImpl(const DebugOptions* debug_options, std::move(pre_abort))); } - constexpr int64_t kAsyncStreamTotal = - static_cast(AsyncStreamKind::ASYNC_STREAM_KIND_MEMCPYP2P) + 1; - - // Borrow streams required for CollectiveThunk. - absl::InlinedVector async_comms_streams( - kAsyncStreamTotal, nullptr); + // Borrow stream for tracing command buffers. se::Stream* command_buffer_trace_stream = nullptr; - std::vector async_comms_streams_owner; StreamPool::Ptr borrowed_command_buffer_trace_stream; if (run_options->HasStreamBorrower()) { - ASSIGN_OR_RETURN( - async_comms_streams_owner, - run_options->BorrowStreams(executor->device_ordinal(), - kAsyncStreamTotal, stream_priority)); - for (int64_t i = 0; i < kAsyncStreamTotal; ++i) { - async_comms_streams[i] = async_comms_streams_owner[i].get(); - } - - // Borrow stream for tracing command buffers. ASSIGN_OR_RETURN(borrowed_command_buffer_trace_stream, run_options->BorrowStream(executor->device_ordinal())); command_buffer_trace_stream = borrowed_command_buffer_trace_stream.get(); } - // Borrow streams for additional compute streams. - std::vector additional_compute_streams; - std::vector borrowed_compute_streams; - if (num_additional_compute_streams > 0) { - if (run_options->HasStreamBorrower()) { - ASSIGN_OR_RETURN( - borrowed_compute_streams, - run_options->BorrowStreams(executor->device_ordinal(), - num_additional_compute_streams)); - additional_compute_streams.reserve(num_additional_compute_streams); - for (auto& stream : borrowed_compute_streams) { - additional_compute_streams.push_back(stream.get()); - } - XLA_VLOG_DEVICE(2, run_options->device_ordinal()) - << absl::StreamFormat("Using %d additional compute streams.", - num_additional_compute_streams); - } else { - XLA_VLOG_DEVICE(2, run_options->device_ordinal()) - << "No stream borrower created. " - << "Assigning the default stream to all parallel computes."; - additional_compute_streams.assign(num_additional_compute_streams, - main_stream); - } + // Borrow streams for communication. + BorrowedStreams communication_streams = BorrowedStreams::Assign( + main_stream, num_additional_streams.communication); + if (run_options->HasStreamBorrower()) { + ASSIGN_OR_RETURN(communication_streams, + BorrowStreams(*run_options, executor->device_ordinal(), + num_additional_streams.communication, + communication_stream_priority)); + XLA_VLOG_DEVICE(2, run_options->device_ordinal()) + << absl::StreamFormat("Using %d additional communication streams.", + num_additional_streams.communication); + } + + // Borrow streams for computations. + BorrowedStreams compute_streams = + BorrowedStreams::Assign(main_stream, num_additional_streams.compute); + if (run_options->HasStreamBorrower()) { + ASSIGN_OR_RETURN(compute_streams, + BorrowStreams(*run_options, executor->device_ordinal(), + num_additional_streams.compute, + se::StreamPriority::Default)); + XLA_VLOG_DEVICE(2, run_options->device_ordinal()) << absl::StreamFormat( + "Using %d additional compute streams.", num_additional_streams.compute); } tsl::profiler::TraceMe hlo_module_activity( @@ -706,7 +723,7 @@ absl::Status ExecuteThunksImpl(const DebugOptions* debug_options, ASSIGN_OR_RETURN( CollectiveParams collective_params, CollectiveParams::Create( - *run_options, async_comms_streams, + *run_options, communication_streams.streams, LocalDeviceId(main_stream->parent()->device_ordinal()), std::move(collectives_impl_name), collective_max_nchannels, p2p_max_nchannels, collective_use_minimal_resource)); @@ -793,7 +810,7 @@ absl::Status ExecuteThunksImpl(const DebugOptions* debug_options, Thunk::ExecuteParams execute_params = Thunk::ExecuteParams::Create( *run_options, buffer_allocations, main_stream, command_buffer_trace_stream, &collective_params, &collective_cliques, - &collective_memory, std::move(additional_compute_streams), + &collective_memory, std::move(compute_streams.streams), &execution_scoped_state); XLA_VLOG_DEVICE(1, run_options->device_ordinal()) @@ -947,6 +964,7 @@ absl::Status BarrierAfterExecutable( const size_t num_participants) { if (debug_options != nullptr && debug_options->xla_gpu_experimental_enable_nvshmem()) { + // NOLINTNEXTLINE(clang-diagnostic-deprecated-declarations) ASSIGN_OR_RETURN(auto* collectives, GetNvshmemCollectivesFromRegistry()); ASSIGN_OR_RETURN(std::unique_ptr nvshmem_comm, collectives->CreateCommunicator()); @@ -989,8 +1007,6 @@ absl::Status BarrierAfterExecutable( return absl::OkStatus(); } -} // namespace - absl::StatusOr GpuExecutable::ResolveConstantGlobals(se::Stream* stream) { se::StreamExecutor* executor = stream->parent(); @@ -1012,8 +1028,7 @@ GpuExecutable::ResolveConstantGlobals(se::Stream* stream) { // The CUDA driver isn't able to load a PTX and a binary which are both empty. // It's okay if we skip loading in this case; if the module isn't loaded, all // symbol lookups will fail, just as they should for an empty module. - if (!(executor->GetPlatform()->id() == - stream_executor::cuda::kCudaPlatformId && + if (!(executor->GetPlatform()->id() == se::cuda::kCudaPlatformId && binary().empty() && text().empty())) { ASSIGN_OR_RETURN(module_handle, executor->LoadModule(module_spec)); } @@ -1023,7 +1038,7 @@ GpuExecutable::ResolveConstantGlobals(se::Stream* stream) { int submitted_mem_copies = 0; for (const ConstantInfo& info : constants_) { - absl::StatusOr global_status; + absl::StatusOr global_status; if (static_cast(module_handle)) { global_status = executor->GetSymbol(info.symbol_name, module_handle); } @@ -1073,18 +1088,18 @@ absl::StatusOr GpuExecutable::BufferForAllocation( allocate_granularity) { if (allocation.is_thread_local()) { return se::DeviceAddressBase{}; - } else if (allocation.is_entry_computation_parameter()) { + } + if (allocation.is_entry_computation_parameter()) { int64_t param_no = allocation.parameter_number(); se::DeviceAddressBase registered_buffer = [&] { if (auto unowned_shapedbuffers = std::get_if>(&arguments)) { return (*unowned_shapedbuffers)[param_no]->buffers().element( allocation.param_shape_index()); - } else { - return std::get>(arguments)[param_no] - .Buffer(allocation.param_shape_index()) - .AsDeviceAddress(); } + return std::get>(arguments)[param_no] + .Buffer(allocation.param_shape_index()) + .AsDeviceAddress(); }(); if (registered_buffer.is_null() && registered_buffer.size() > 0) { return FailedPrecondition( @@ -1095,32 +1110,33 @@ absl::StatusOr GpuExecutable::BufferForAllocation( allocation.param_shape_index().ToString(), param_no); } return registered_buffer; - } else if (allocation.is_constant()) { + } + if (allocation.is_constant()) { auto it = globals->find(arg_idx); if (it == globals->end()) { return se::DeviceAddressBase(); } return it->second; - } else { - // Allocate each allocation that might escape, or is the temp buffer. - CHECK(allocation.maybe_live_out() || allocation.IsPreallocatedTempBuffer()); - int64_t buffer_size = allocation.size(); - se::DeviceAddressBase buffer_address; - if (buffer_size > 0) { - // Maybe round up buffer allocation size to the requested granulariy. - if (auto it = allocate_granularity.find(allocation.color()); - it != allocate_granularity.end()) { - buffer_size = RoundUpTo(buffer_size, it->second); - } - ASSIGN_OR_RETURN( - se::ScopedDeviceAddress buffer, - memory_allocator->Allocate(device_ordinal, buffer_size, - /*retry_on_failure=*/true, - /*memory_space=*/allocation.color())); - buffer_address = buffer.Release(); + } + + // Allocate each allocation that might escape, or is the temp buffer. + CHECK(allocation.maybe_live_out() || allocation.IsPreallocatedTempBuffer()); + int64_t buffer_size = allocation.size(); + se::DeviceAddressBase buffer_address; + if (buffer_size > 0) { + // Maybe round up buffer allocation size to the requested granularity. + if (auto it = allocate_granularity.find(allocation.color()); + it != allocate_granularity.end()) { + buffer_size = RoundUpTo(buffer_size, it->second); } - return buffer_address; + ASSIGN_OR_RETURN( + se::ScopedDeviceAddress buffer, + memory_allocator->Allocate(device_ordinal, buffer_size, + /*retry_on_failure=*/true, + /*memory_space=*/allocation.color())); + buffer_address = buffer.Release(); } + return buffer_address; } static absl::Status CheckAlignment(const BufferAllocation& allocation, @@ -1128,11 +1144,11 @@ static absl::Status CheckAlignment(const BufferAllocation& allocation, const int64_t expected_alignment = [&] { if (allocation.is_entry_computation_parameter()) { return kEntryParameterAlignBytes; - } else if (allocation.is_constant()) { + } + if (allocation.is_constant()) { return kConstantBufferAlignBytes; - } else { - return kXlaAllocatedBufferAlignBytes; } + return kXlaAllocatedBufferAlignBytes; }(); if (!buffer.is_null() && reinterpret_cast(buffer.opaque()) % expected_alignment != 0) { @@ -1244,12 +1260,12 @@ absl::StatusOr GpuExecutable::ExecuteAsyncOnStreamImpl( // computations to use the same GPU simultaneously. We do not add locking for // "recursive" invocations, which are done when holding a lock already. std::variant gpu_lock( - std::in_place_index_t<0>{}, &GetGpuMutex(executor)); + std::in_place_index_t<0>{}, GetGpuMutex(executor)); // Maybe update to a writer lock to get exclusive access to underlying GPU. if (auto* gpu_opts = run_options->run_options().gpu_executable_run_options(); gpu_opts && gpu_opts->requires_exclusive_lock_on_gpu()) { - gpu_lock.emplace<1>(&GetGpuMutex(executor)); + gpu_lock.emplace<1>(GetGpuMutex(executor)); } const GpuExecutable::BufferAllocToDeviceMemoryMap* globals; @@ -1313,13 +1329,12 @@ absl::StatusOr GpuExecutable::ExecuteAsyncOnStreamImpl( if (std::holds_alternative>( arguments)) { return nullptr; - } else { - auto unowned_execution_input = - std::get>(arguments); - ExecutionInput& input = - unowned_execution_input[allocation->parameter_number()]; - return input.MutableBuffer(allocation->param_shape_index()); } + auto unowned_execution_input = + std::get>(arguments); + ExecutionInput& input = + unowned_execution_input[allocation->parameter_number()]; + return input.MutableBuffer(allocation->param_shape_index()); }(); if (output_info.alias_config->must_alias() && maybe_owning_memory && !maybe_owning_memory->HasOwnership()) { @@ -1632,7 +1647,7 @@ absl::Status GpuExecutable::ExecuteThunksWithVaRemapping( has_module() ? &module_config().debug_options() : nullptr, module_name_, unique_id, *thunk_executor_, executable_source, run_options, remapped_buffer_allocations, block_host_until_done, - num_additional_compute_streams_, collective_memory_cache_, + num_additional_streams_, collective_memory_cache_, collective_use_minimal_resource, post_init_rendezvous_flag_)); // Record event so VA range can be reclaimed after GPU finishes. @@ -1723,7 +1738,7 @@ absl::Status GpuExecutable::ExecuteThunks( // Check if command buffer VA remapping is active. bool use_command_buffer_va_remapping = - (command_buffer_allocation_indexes_.size() > 0) && has_module() && + !command_buffer_allocation_indexes_.empty() && has_module() && module_config().debug_options().xla_gpu_command_buffer_update_mode() != DebugOptions::ALWAYS_UPDATE && dynamic_cast(memory_allocator) != nullptr; @@ -1747,9 +1762,9 @@ absl::Status GpuExecutable::ExecuteThunks( RETURN_IF_ERROR(ExecuteThunksImpl( has_module() ? &module_config().debug_options() : nullptr, module_name_, unique_id, *thunk_executor_, executable_source, run_options, - buffer_allocations, block_host_until_done, - num_additional_compute_streams_, collective_memory_cache_, - collective_use_minimal_resource, post_init_rendezvous_flag_)); + buffer_allocations, block_host_until_done, num_additional_streams_, + collective_memory_cache_, collective_use_minimal_resource, + post_init_rendezvous_flag_)); } return absl::OkStatus(); } @@ -1915,7 +1930,8 @@ absl::StatusOr GpuExecutable::ToProto() const { } if (has_module()) { - *proto.mutable_hlo_module_with_config() = module().ToProtoWithConfig(); + *proto.mutable_hlo_module_with_config() = + module().ToProtoWithConfig(/*intern_backend_config=*/true); } proto.mutable_output_info_map()->Reserve(output_info_.size()); @@ -1945,7 +1961,7 @@ absl::StatusOr> GpuExecutable::FromProto( const GpuExecutableProto& proto, const se::DeviceDescription& device_description, absl::string_view platform_name, DebugOptions debug_options, - const std::optional& + const std::optional& symbol_resolver) { Params params; params.debug_options = std::move(debug_options); @@ -1984,9 +2000,9 @@ absl::StatusOr> GpuExecutable::FromProto( params.dnn_compiled_graphs.emplace(key, value); } - ASSIGN_OR_RETURN(stream_executor::GpuComputeCapability gpu_compute_capability, - stream_executor::GpuComputeCapability::FromProto( - proto.gpu_compute_capability())); + ASSIGN_OR_RETURN( + se::GpuComputeCapability gpu_compute_capability, + se::GpuComputeCapability::FromProto(proto.gpu_compute_capability())); if (gpu_compute_capability != device_description.gpu_compute_capability()) { return absl::InvalidArgumentError(absl::StrFormat( @@ -2034,9 +2050,9 @@ absl::StatusOr> GpuExecutable::FromProto( ASSIGN_OR_RETURN(params.program_shape, ProgramShape::FromProto(proto.program_shape())); - ASSIGN_OR_RETURN(params.executable_abi_version, - stream_executor::ExecutableAbiVersion::FromProto( - proto.executable_abi_version())); + ASSIGN_OR_RETURN( + params.executable_abi_version, + se::ExecutableAbiVersion::FromProto(proto.executable_abi_version())); return Create(std::move(params)); } diff --git a/third_party/xla/xla/service/gpu/gpu_executable.h b/third_party/xla/xla/service/gpu/gpu_executable.h index ced3bbf048f224..d09899cd2a35e4 100644 --- a/third_party/xla/xla/service/gpu/gpu_executable.h +++ b/third_party/xla/xla/service/gpu/gpu_executable.h @@ -22,6 +22,7 @@ limitations under the License. #include #include #include +#include #include #include @@ -29,7 +30,6 @@ limitations under the License. #include "absl/base/thread_annotations.h" #include "absl/container/btree_set.h" #include "absl/container/flat_hash_map.h" -#include "absl/container/flat_hash_set.h" #include "absl/container/node_hash_map.h" #include "absl/status/status.h" #include "absl/status/statusor.h" @@ -52,20 +52,25 @@ limitations under the License. #include "xla/service/gpu/buffer_allocations.h" #include "xla/service/gpu/gpu_executable.pb.h" #include "xla/service/gpu/ir_emission_utils.h" +#include "xla/service/hlo.pb.h" #include "xla/service/logical_buffer.h" #include "xla/service/rendezvous.h" #include "xla/service/service_executable_run_options.h" #include "xla/service/shaped_buffer.h" +#include "xla/service/stream_pool.h" +#include "xla/service/xla_debug_info_manager.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/stream_executor/abi/executable_abi_version.h" #include "xla/stream_executor/device_address.h" #include "xla/stream_executor/device_address_allocator.h" #include "xla/stream_executor/device_description.h" +#include "xla/stream_executor/event.h" #include "xla/stream_executor/kernel_stats.h" #include "xla/stream_executor/memory_reservation.h" #include "xla/stream_executor/platform.h" #include "xla/stream_executor/scoped_module_handle.h" +#include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_executor.h" #include "xla/xla.pb.h" @@ -79,6 +84,11 @@ namespace gpu { // This is an immutable data type after initialization, and thus thread safe. class GpuExecutable : public Executable { public: + struct NumAdditionalStreams { + int compute = 0; + int communication = 0; + }; + struct ConstantInfo { std::string symbol_name; DenseDataIntermediate content; @@ -133,7 +143,7 @@ class GpuExecutable : public Executable { std::unique_ptr debug_module = nullptr; bool enable_debug_info_manager = true; ModuleStats module_stats; - stream_executor::ExecutableAbiVersion executable_abi_version; + se::ExecutableAbiVersion executable_abi_version; std::optional cpu_target_machine_options; std::optional buffer_assignment_proto; }; @@ -237,7 +247,7 @@ class GpuExecutable : public Executable { // the given stream, it is skipped and the cached map is immediately returned // instead. absl::StatusOr ResolveConstantGlobals( - stream_executor::Stream* stream); + se::Stream* stream); absl::Status VerboseAllocationError(absl::Status s); @@ -245,7 +255,7 @@ class GpuExecutable : public Executable { const GpuExecutableProto&, const se::DeviceDescription& device_description, absl::string_view platform, DebugOptions debug_options, - const std::optional& + const std::optional& symbol_resolver = std::nullopt); absl::StatusOr ToProto() const; @@ -254,8 +264,8 @@ class GpuExecutable : public Executable { const ExecutableBuildOptions& options, const DebugOptions& debug_options) const final; - absl::StatusOr - GetExecutableAbiVersion() const override { + absl::StatusOr GetExecutableAbiVersion() + const override { return executable_abi_version_; } @@ -285,6 +295,15 @@ class GpuExecutable : public Executable { std::optional scoped_mapping; }; + // Additional streams borrowed at run time for the execution. + struct BorrowedStreams { + std::vector streams; + std::vector owners; + + // Assigns `stream` to all requested stream slots. + static BorrowedStreams Assign(se::Stream* stream, int num_streams); + }; + // Use GpuExecutable::Create() to create an instance. explicit GpuExecutable( std::unique_ptr debug_module, std::string asm_text, @@ -300,7 +319,7 @@ class GpuExecutable : public Executable { absl::flat_hash_map output_info, bool enable_debug_info_manager, ModuleStats module_stats, absl::StatusOr> thunk_sequence_proto, - stream_executor::ExecutableAbiVersion executable_abi_version, + se::ExecutableAbiVersion executable_abi_version, std::optional cpu_target_machine_options, std::optional buffer_assignment_proto); @@ -324,16 +343,31 @@ class GpuExecutable : public Executable { const absl::flat_hash_map& allocate_granularity); + static absl::StatusOr BorrowStreams( + const ServiceExecutableRunOptions& run_options, int device_ordinal, + int num_streams, se::StreamPriority priority); + // Handles the VA remapping path of ExecuteThunks: reserves or remaps the // virtual address range for command buffer allocations, then delegates to // ExecuteThunksImpl with the remapped BufferAllocations. absl::Status ExecuteThunksWithVaRemapping( const BufferAllocations& buffer_allocations, const ServiceExecutableRunOptions* run_options, - stream_executor::StreamExecutor* executor, int64_t unique_id, + se::StreamExecutor* executor, int64_t unique_id, Thunk::ExecutableSource executable_source, bool block_host_until_done, bool collective_use_minimal_resource); + static absl::Status ExecuteThunksImpl( + const DebugOptions* debug_options, const std::string& module_name, + ModuleIdentifier module_id, ThunkExecutor& thunk_executor, + Thunk::ExecutableSource executable_source, + const ServiceExecutableRunOptions* run_options, + const BufferAllocations& buffer_allocations, bool block_host_until_done, + NumAdditionalStreams num_additional_streams, + CollectiveMemoryCache& collective_memory_cache, + bool collective_use_minimal_resource, + RendezvousFlag& post_init_rendezvous_flag); + // The LLVM IR, in string format, of the unoptimized module generated for // this GpuExecutable. We save a string instead of an llvm::Module* because // leaving llvm::Module* in a singleton can cause the heap checker to emit @@ -360,8 +394,8 @@ class GpuExecutable : public Executable { // ThunkEmitter. std::unique_ptr thunk_executor_; - // Number of additional compute streams requested by `AsyncStartThunks`. - int64_t num_additional_compute_streams_; + // Number of additional streams available at run time. + NumAdditionalStreams num_additional_streams_; std::string module_name_; @@ -421,17 +455,16 @@ class GpuExecutable : public Executable { absl::Mutex module_handle_mutex_; // Cache of module handles. Required to keep loaded modules alive until this // executable is destroyed. - absl::flat_hash_map + absl::flat_hash_map module_handles_ ABSL_GUARDED_BY(module_handle_mutex_); // Cache of constant buffer allocation maps used by `ResolveConstantGlobals`. - absl::flat_hash_map> module_globals_ ABSL_GUARDED_BY(module_handle_mutex_); // Cache previous memory allocations for current module, this is used to help // identify if user's model have unstable pointers by turning on VLOG(5). - absl::flat_hash_map> + absl::flat_hash_map> module_allocations_ ABSL_GUARDED_BY(module_handle_mutex_); std::vector constants_; @@ -445,8 +478,7 @@ class GpuExecutable : public Executable { // Separate mutex for VA ranges to avoid contention with module_handle_mutex_ // during VA remapping operations which may involve GPU synchronization. absl::Mutex va_ranges_mutex_; - absl::node_hash_map, - VaRanges> + absl::node_hash_map, VaRanges> module_va_ranges_ ABSL_GUARDED_BY(va_ranges_mutex_); RendezvousFlag post_init_rendezvous_flag_; GpuExecutable(const GpuExecutable&) = delete; @@ -456,7 +488,7 @@ class GpuExecutable : public Executable { // Might contain an error if the given thunk graph is not serializable. absl::StatusOr> thunk_sequence_proto_; - stream_executor::ExecutableAbiVersion executable_abi_version_; + se::ExecutableAbiVersion executable_abi_version_; std::optional cpu_target_machine_options_; diff --git a/third_party/xla/xla/service/gpu/gpu_executable_test.cc b/third_party/xla/xla/service/gpu/gpu_executable_test.cc index 54e6525ce789af..85819f5923787b 100644 --- a/third_party/xla/xla/service/gpu/gpu_executable_test.cc +++ b/third_party/xla/xla/service/gpu/gpu_executable_test.cc @@ -519,6 +519,65 @@ TEST_F(GpuExecutableTest, ProtoConversion) { "+test_features")); } +TEST_F(GpuExecutableTest, ProtoConversionWithBackendConfigInterning) { + se::DeviceDescription device_description; + device_description.set_gpu_compute_capability( + se::GpuComputeCapability{se::CudaComputeCapability::Volta()}); + + GpuExecutable::Params params; + params.module_name = "test_module"; + params.executable = std::make_unique(ThunkSequence{}); + params.device_description = device_description; + params.enable_debug_info_manager = false; + + // Create a debug module with some instructions sharing backend config using + // parser. + const char* hlo_text = R"( + HloModule test_module + ENTRY comp { + p0 = f32[] parameter(0) + p1 = f32[] parameter(1) + ROOT add = f32[] add(p0, p1) + })"; + TF_ASSERT_OK_AND_ASSIGN(auto debug_module, + ParseAndReturnUnverifiedModule(hlo_text)); + HloInstruction* p0 = + debug_module->entry_computation()->GetInstructionWithName("p0"); + HloInstruction* p1 = + debug_module->entry_computation()->GetInstructionWithName("p1"); + + p0->set_raw_backend_config_string("tokamax:{\"data\": 1}"); + p1->CopyBackendConfigFrom(p0); // Force in-memory sharing to test fast path. + params.debug_module = std::move(debug_module); + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr executable, + GpuExecutable::Create(std::move(params))); + TF_ASSERT_OK_AND_ASSIGN(GpuExecutableProto proto, executable->ToProto()); + + // Verify that the serialized HLO module has interned backend configs using + // Partially(EqualsProto). + EXPECT_THAT(proto, Partially(EqualsProto(R"pb( + hlo_module_with_config { + hlo_module { + payloads: "tokamax:{\"data\": 1}" + computations { + instructions { + name: "p0" + backend_config_payload { id: 0 } + backend_config: "" + } + instructions { + name: "p1" + backend_config_payload { id: 0 } + backend_config: "" + } + instructions { name: "add" } + } + } + } + )pb"))); +} + TEST_F(GpuExecutableTest, GpuExecutableDump) { tsl::Env* env = tsl::Env::Default(); diff --git a/third_party/xla/xla/service/gpu/llvm_gpu_backend/BUILD b/third_party/xla/xla/service/gpu/llvm_gpu_backend/BUILD index 559c970a679981..838870c20c4f99 100644 --- a/third_party/xla/xla/service/gpu/llvm_gpu_backend/BUILD +++ b/third_party/xla/xla/service/gpu/llvm_gpu_backend/BUILD @@ -37,6 +37,7 @@ cc_library( "//xla/stream_executor/cuda:cuda_compute_capability", "//xla/tsl/platform:env", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "@com_google_absl//absl/log", "@com_google_absl//absl/memory", "@com_google_absl//absl/status", @@ -91,6 +92,7 @@ cc_library( "//xla/stream_executor/cuda:subprocess_compilation", "//xla/tsl/platform:env", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "@com_google_absl//absl/base", "@com_google_absl//absl/log", "@com_google_absl//absl/status", @@ -189,6 +191,7 @@ cc_library( "//xla/tsl/platform:errors", "//xla/tsl/platform:rocm_rocdl_path", "//xla/tsl/platform:status", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "//xla/tsl/util:env_var", "@com_google_absl//absl/algorithm:container", @@ -409,6 +412,7 @@ cc_library( "//xla/service/llvm_ir:llvm_command_line_options", "//xla/stream_executor:device_description", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "@com_google_absl//absl/base:no_destructor", "@com_google_absl//absl/status:statusor", "@llvm-project//llvm:Analysis", diff --git a/third_party/xla/xla/service/gpu/llvm_gpu_backend/amdgpu_backend.cc b/third_party/xla/xla/service/gpu/llvm_gpu_backend/amdgpu_backend.cc index 9ca134642cb3a1..19b927b625dbd8 100644 --- a/third_party/xla/xla/service/gpu/llvm_gpu_backend/amdgpu_backend.cc +++ b/third_party/xla/xla/service/gpu/llvm_gpu_backend/amdgpu_backend.cc @@ -45,6 +45,7 @@ limitations under the License. #include "absl/strings/str_split.h" #include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" +#include "xla/tsl/platform/status_macros.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringSet.h" #include "llvm/Analysis/CGSCCPassManager.h" @@ -655,7 +656,7 @@ absl::Status AMDGPUTargetModuleLinker( return xla::Internal("Incompatible compute capability was specified."); } - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( amdgpu::LinkROCDLIfNecessary(module, compute_capability->gfx_version(), debug_options, device_bitcode_dir_path)); @@ -785,13 +786,13 @@ absl::StatusOr CompileToHsacoInternal( GetTargetMachine(default_target_triple, gfx, debug_options, feature_str); // Link with ROCm-Device-Libs, and optimize the LLVM module. - TF_RETURN_IF_ERROR(gpu::LinkAndOptimizeModule( + RETURN_IF_ERROR(gpu::LinkAndOptimizeModule( module, gpu_version, debug_options, rocdl_dir_path, AMDGPUTargetModuleLinker, default_target_triple, target_machine.get(), kAMDGPUInlineThreshold)); // Lower optimized LLVM module to HSA code object. - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( std::string hsaco_path, EmitModuleToHsaco(module, target_machine.get(), debug_options)); @@ -948,8 +949,7 @@ absl::Status LinkROCDLIfNecessary(llvm::Module* module, return absl::OkStatus(); } - TF_RETURN_IF_ERROR( - LinkWithBitcodeVector(module, GetROCDLPaths(rocdl_dir_path))); + RETURN_IF_ERROR(LinkWithBitcodeVector(module, GetROCDLPaths(rocdl_dir_path))); // Sanitize stray metadata from the bitcode files if (auto* opencl_version = module->getNamedMetadata("opencl.ocl.version")) { diff --git a/third_party/xla/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc b/third_party/xla/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc index 875b90d463a928..9e2cd683935adc 100644 --- a/third_party/xla/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc +++ b/third_party/xla/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc @@ -29,6 +29,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/strings/str_format.h" #include "absl/strings/string_view.h" +#include "xla/tsl/platform/status_macros.h" #include "llvm/ADT/Any.h" #include "llvm/ADT/StringSet.h" #include "llvm/Analysis/CGSCCPassManager.h" @@ -257,7 +258,7 @@ absl::Status LinkAndOptimizeModule( return absl::StrFormat("XlaOptimizeLlvmIr:#module=%s#", module->getName().str()); }); - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( module_linker(module, gpu_version, debug_options, device_bitcode_path)); llvm::LoopAnalysisManager lam; diff --git a/third_party/xla/xla/service/gpu/llvm_gpu_backend/nvptx_backend.cc b/third_party/xla/xla/service/gpu/llvm_gpu_backend/nvptx_backend.cc index 7aae06eae96e32..81130c6bb23f91 100644 --- a/third_party/xla/xla/service/gpu/llvm_gpu_backend/nvptx_backend.cc +++ b/third_party/xla/xla/service/gpu/llvm_gpu_backend/nvptx_backend.cc @@ -29,6 +29,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" +#include "xla/tsl/platform/status_macros.h" #include "third_party/gpus/cuda/include/cuda.h" #include "llvm/ADT/FloatingPointMode.h" #include "llvm/Analysis/CGSCCPassManager.h" @@ -126,7 +127,7 @@ absl::Status NVPTXTargetModuleLinker(llvm::Module* module, const std::string& device_bitcode_path) { // Link the input module with libdevice, to pull in implementations of some // builtins. - TF_RETURN_IF_ERROR(LinkLibdeviceIfNecessary(module, device_bitcode_path)); + RETURN_IF_ERROR(LinkLibdeviceIfNecessary(module, device_bitcode_path)); // Set the flush-denormals-to-zero flag on the module so the NVVM reflect pass // can access it. @@ -343,7 +344,7 @@ absl::StatusOr CompileToPtx( uint64_t start_usecs = tsl::Env::Default()->NowMicros(); // Link with libdevice, and optimize the LLVM module. - TF_RETURN_IF_ERROR(LinkAndOptimizeModule( + RETURN_IF_ERROR(LinkAndOptimizeModule( module, gpu_version, debug_options, LibDevicePath(debug_options.xla_gpu_cuda_data_dir()), NVPTXTargetModuleLinker, default_target_triple, target_machine.get(), diff --git a/third_party/xla/xla/service/gpu/llvm_gpu_backend/spirv_backend.cc b/third_party/xla/xla/service/gpu/llvm_gpu_backend/spirv_backend.cc index cde59b020ecdc9..4c1be7eaa2dbaa 100644 --- a/third_party/xla/xla/service/gpu/llvm_gpu_backend/spirv_backend.cc +++ b/third_party/xla/xla/service/gpu/llvm_gpu_backend/spirv_backend.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include "absl/base/no_destructor.h" +#include "xla/tsl/platform/status_macros.h" #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/CodeGen/MachineModuleInfo.h" #include "llvm/IR/Constants.h" @@ -191,7 +192,7 @@ absl::StatusOr CompileToSPIRV( const_cast(sub_target->getSubtargetImpl()) ->initAvailableExtensions(common_spirv_extensions); - TF_RETURN_IF_ERROR(LinkAndOptimizeModule( + RETURN_IF_ERROR(LinkAndOptimizeModule( module, gpu_version, debug_options, "", SPIRVTargetModuleLinker, default_target_triple, target_machine.get(), kDefaultInlineThreshold)); diff --git a/third_party/xla/xla/service/gpu/model/BUILD b/third_party/xla/xla/service/gpu/model/BUILD index 942e706236bd0c..c6945d1be5cb4c 100644 --- a/third_party/xla/xla/service/gpu/model/BUILD +++ b/third_party/xla/xla/service/gpu/model/BUILD @@ -106,9 +106,7 @@ xla_cc_test( "//xla/stream_executor/cuda:cuda_compute_capability", "//xla/tests:xla_internal_test_main", "//xla/tests/restricted:hlo_test_base_legacy", - "//xla/tsl/lib/core:status_test_util", "//xla/tsl/platform:status_macros", - "//xla/tsl/platform:statusor", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/time", @@ -146,7 +144,6 @@ xla_cc_test( "//xla/hlo/parser:hlo_parser", "//xla/stream_executor:device_description", "//xla/tests:xla_internal_test_main", - "//xla/tsl/platform:statusor", "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/time", "@com_google_googletest//:gtest", @@ -369,7 +366,6 @@ xla_cc_test( "//xla/service/gpu:hlo_fusion_analysis", "//xla/stream_executor:device_description", "//xla/tests:xla_internal_test_main", - "//xla/tsl/platform:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/time", "@com_google_googletest//:gtest", @@ -638,8 +634,6 @@ xla_cc_test( "//xla/service/gpu:ir_emission_utils", "//xla/stream_executor:device_description", "//xla/tests:xla_internal_test_main", - "//xla/tsl/lib/core:status_test_util", - "//xla/tsl/platform:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/time", "@com_google_googletest//:gtest", @@ -709,7 +703,6 @@ xla_cc_test( "//xla/hlo/utils:hlo_traversal", "//xla/service/gpu:backend_configs_cc", "//xla/tsl/platform:status_macros", - "//xla/tsl/platform:statusor", "//xla/tsl/platform:test", "@com_google_absl//absl/status:status_matchers", "@com_google_absl//absl/status:statusor", @@ -754,15 +747,14 @@ xla_cc_test( deps = [ ":triton_emitter_constraints", "//xla/codegen/tiling:symbolic_tile_analysis", + "//xla/codegen/tiling:tiling_specification", "//xla/hlo/analysis:symbolic_map", "//xla/hlo/ir:hlo", "//xla/hlo/testlib:hlo_hardware_independent_test_base", "//xla/hlo/testlib:verified_hlo_module", - "//xla/hlo/utils:hlo_traversal", "//xla/service:instruction_fusion", "//xla/service/gpu:gpu_device_info_for_tests", "//xla/stream_executor:device_description", - "//xla/tsl/platform:statusor", "//xla/tsl/platform:test", "@com_google_absl//absl/log", "@com_google_absl//absl/status:status_matchers", @@ -1153,7 +1145,6 @@ xla_cc_test( "//xla/service/gpu:gpu_device_info_for_tests", "//xla/stream_executor:device_description", "//xla/stream_executor/cuda:cuda_compute_capability", - "//xla/tsl/platform:statusor", "@com_google_absl//absl/log", "@com_google_absl//absl/strings:string_view", "@com_google_googletest//:gtest_main", @@ -1239,7 +1230,6 @@ xla_cc_test( "//xla/service/gpu:gpu_device_info_for_tests", "//xla/stream_executor:device_description", "//xla/stream_executor/cuda:cuda_compute_capability", - "//xla/tsl/platform:statusor", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", @@ -1353,7 +1343,6 @@ xla_cc_test( "//xla/stream_executor:device_description", "//xla/stream_executor/cuda:cuda_compute_capability", "//xla/tsl/platform:status_macros", - "//xla/tsl/platform:statusor", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", diff --git a/third_party/xla/xla/service/gpu/model/coalescing_analysis_test.cc b/third_party/xla/xla/service/gpu/model/coalescing_analysis_test.cc index 6fbec379b95b7b..6d070a822ca864 100644 --- a/third_party/xla/xla/service/gpu/model/coalescing_analysis_test.cc +++ b/third_party/xla/xla/service/gpu/model/coalescing_analysis_test.cc @@ -645,7 +645,7 @@ INSTANTIATE_TEST_SUITE_P(CoalescingForTiledHloTest, CoalescingForTiledHloTest, TEST_P( CoalescingForTiledHloTest, EffectiveBandwidthUtilizationRateIsComputedCorrectlyForTiledMemoryAccess) { // NOLINT(whitespace/line_length) - TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"( + ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"( HloModule m ENTRY main { diff --git a/third_party/xla/xla/service/gpu/model/collective_interpolator_test.cc b/third_party/xla/xla/service/gpu/model/collective_interpolator_test.cc index 1ec971cc809550..16161d534c4369 100644 --- a/third_party/xla/xla/service/gpu/model/collective_interpolator_test.cc +++ b/third_party/xla/xla/service/gpu/model/collective_interpolator_test.cc @@ -21,6 +21,7 @@ limitations under the License. #include #include +#include #include #include "absl/container/flat_hash_map.h" #include "absl/log/check.h" @@ -41,7 +42,6 @@ limitations under the License. #include "xla/shape_util.h" #include "xla/stream_executor/cuda/cuda_compute_capability.h" #include "xla/stream_executor/device_description.h" -#include "xla/tsl/platform/statusor.h" #include "xla/xla_data.pb.h" namespace xla::gpu { @@ -1034,7 +1034,7 @@ TEST_P(CollectiveInterpolationWithDefaultProfileTest, LoadsDefaultProfile) { auto device_info = test_name == "B200" ? TestGpuDeviceInfo::B200SXMDeviceInfo(cc) : TestGpuDeviceInfo::RTXA6000DeviceInfo(cc); - TF_ASSERT_OK_AND_ASSIGN( + ASSERT_OK_AND_ASSIGN( std::unique_ptr interpolator, CollectiveInterpolator::Create(kNumGpusPerHost, device_info)); absl::string_view kHlo = R"( @@ -1052,7 +1052,7 @@ TEST_P(CollectiveInterpolationWithDefaultProfileTest, LoadsDefaultProfile) { replica_groups=[1,8]<=[8], use_global_device_ids=true, channel_id=1 } )"; - TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(kHlo)); + ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(kHlo)); HloCollectiveInstruction* instr = Cast( module->entry_computation()->root_instruction()); diff --git a/third_party/xla/xla/service/gpu/model/collective_ptable_stats_collection_test.cc b/third_party/xla/xla/service/gpu/model/collective_ptable_stats_collection_test.cc index 38388fe250b377..7025412ba445d1 100644 --- a/third_party/xla/xla/service/gpu/model/collective_ptable_stats_collection_test.cc +++ b/third_party/xla/xla/service/gpu/model/collective_ptable_stats_collection_test.cc @@ -84,8 +84,8 @@ class CollectivePerfTableStatsCollectionTest profiles_path_(tsl::io::JoinPath(tsl::testing::TmpDir(), kFile)) {} void SetUp() override { - TF_ASSERT_OK(tsl::WriteTextProto(tsl::Env::Default(), profiles_path_, - TestProfiles(device_info_))); + ASSERT_OK(tsl::WriteTextProto(tsl::Env::Default(), profiles_path_, + TestProfiles(device_info_))); } protected: @@ -112,10 +112,10 @@ TEST_F(CollectivePerfTableStatsCollectionTest, ROOT ar-done = f32[1024] all-reduce-done(ar-start) })"; - TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText)); - TF_ASSERT_OK_AND_ASSIGN(bool changed, CollectivePerfTableStatsCollection( - profiles_path_, device_info_) - .Run(module.get())); + ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText)); + ASSERT_OK_AND_ASSIGN(bool changed, CollectivePerfTableStatsCollection( + profiles_path_, device_info_) + .Run(module.get())); VLOG(1) << module->ToString(); diff --git a/third_party/xla/xla/service/gpu/model/fusion_analysis_cache_test.cc b/third_party/xla/xla/service/gpu/model/fusion_analysis_cache_test.cc index 473b7ea2cdb463..ef87ec9b1d9018 100644 --- a/third_party/xla/xla/service/gpu/model/fusion_analysis_cache_test.cc +++ b/third_party/xla/xla/service/gpu/model/fusion_analysis_cache_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "xla/service/gpu/model/fusion_analysis_cache.h" +#include #include #include "absl/strings/string_view.h" #include "xla/hlo/parser/hlo_parser.h" @@ -47,8 +48,7 @@ TEST_F(FusionAnalysisCacheTest, CachesAndInvalidates) { ENTRY e { ROOT r.1 = f32[1000] fusion(), kind=kLoop, calls=f })"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnUnverifiedModule(hlo_string)); + ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(hlo_string)); auto* computation = module->GetComputationWithName("f"); auto* broadcast = computation->GetInstructionWithName("b0"); @@ -87,8 +87,7 @@ TEST_F(FusionAnalysisCacheTest, CachesAndInvalidatesProducerConsumerFusions) { f0 = f32[] fusion(), kind=kInput, calls=f ROOT n0 = f32[] negate(f0) })"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnUnverifiedModule(hlo_string)); + ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(hlo_string)); auto* fusion = module->entry_computation()->GetInstructionWithName("f0"); auto* neg = module->entry_computation()->GetInstructionWithName("n0"); diff --git a/third_party/xla/xla/service/gpu/model/gpu_cost_model_stats_collection_test.cc b/third_party/xla/xla/service/gpu/model/gpu_cost_model_stats_collection_test.cc index 6030bd5be645e4..45f3ac312877f5 100644 --- a/third_party/xla/xla/service/gpu/model/gpu_cost_model_stats_collection_test.cc +++ b/third_party/xla/xla/service/gpu/model/gpu_cost_model_stats_collection_test.cc @@ -57,7 +57,7 @@ class GpuCostModelStatsCollectionTest : public HloHardwareIndependentTestBase { }; TEST_F(GpuCostModelStatsCollectionTest, FusionInEntryComputation) { - TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"hlo( + ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"hlo( HloModule test_module log { @@ -74,15 +74,15 @@ TEST_F(GpuCostModelStatsCollectionTest, FusionInEntryComputation) { EXPECT_THAT(cost_model_stats_.Run(module.get()), IsOkAndHolds(false)); HloInstruction* root = module->entry_computation()->root_instruction(); - TF_ASSERT_OK_AND_ASSIGN(auto gpu_config, - root->backend_config()); + ASSERT_OK_AND_ASSIGN(auto gpu_config, + root->backend_config()); EXPECT_EQ(gpu_config.reification_cost_size(), 1); EXPECT_GT(gpu_config.reification_cost()[0].end_to_end_cycles(), 0); } TEST_F(GpuCostModelStatsCollectionTest, FusionInWhileComputation) { - TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"hlo( + ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"hlo( HloModule test_module cond { @@ -111,15 +111,15 @@ TEST_F(GpuCostModelStatsCollectionTest, FusionInWhileComputation) { ->root_instruction() ->while_body() ->root_instruction(); - TF_ASSERT_OK_AND_ASSIGN(auto gpu_config, - root->backend_config()); + ASSERT_OK_AND_ASSIGN(auto gpu_config, + root->backend_config()); EXPECT_EQ(gpu_config.reification_cost_size(), 1); EXPECT_GT(gpu_config.reification_cost()[0].end_to_end_cycles(), 0); } TEST_F(GpuCostModelStatsCollectionTest, GemmCostModelAddedToGemmFusion) { - TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"hlo( + ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"hlo( HloModule test_module gemm_fusion_dot_computation { @@ -152,8 +152,8 @@ TEST_F(GpuCostModelStatsCollectionTest, GemmCostModelAddedToGemmFusion) { EXPECT_THAT(cost_model_stats_.Run(module.get()), IsOkAndHolds(false)); HloInstruction* root = module->entry_computation()->root_instruction(); - TF_ASSERT_OK_AND_ASSIGN(auto gpu_config, - root->backend_config()); + ASSERT_OK_AND_ASSIGN(auto gpu_config, + root->backend_config()); EXPECT_THAT(gpu_config.reification_cost(), Contains(Truly([](const ReificationCost& cost) { diff --git a/third_party/xla/xla/service/gpu/model/gpu_dot_fusion_cost_model_test.cc b/third_party/xla/xla/service/gpu/model/gpu_dot_fusion_cost_model_test.cc index 9d4004d06a8b95..0667d288c56ace 100644 --- a/third_party/xla/xla/service/gpu/model/gpu_dot_fusion_cost_model_test.cc +++ b/third_party/xla/xla/service/gpu/model/gpu_dot_fusion_cost_model_test.cc @@ -44,8 +44,8 @@ class GpuDotFusionCostModelTest : public HloHardwareIndependentTestBase { }; TEST_F(GpuDotFusionCostModelTest, GpuDotComputeBoundBf16) { - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(R"( + ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"( ENTRY e { p0 = bf16[8192,8192] parameter(0) p1 = bf16[8192,8192] parameter(1) @@ -65,11 +65,11 @@ backend_config={"sizes":["32"]} auto* dot = Cast(module->entry_computation()->root_instruction()); ASSERT_IS_OK(gpu_dot_fusion_cost_model::IsSupported(dot)); - TF_ASSERT_OK_AND_ASSIGN( + ASSERT_OK_AND_ASSIGN( EstimateRunTimeData runtime_h100, gpu_dot_fusion_cost_model::EstimateRunTimeForDotOpWithBlockParameters( dot, block_params, ddh100_)); - TF_ASSERT_OK_AND_ASSIGN( + ASSERT_OK_AND_ASSIGN( auto expected_compute_and_flops_h100, gpu_dot_fusion_cost_model::detail:: CalculateComputeTimeWithTileAndWaveQuantization( @@ -85,8 +85,8 @@ backend_config={"sizes":["32"]} TEST_F(GpuDotFusionCostModelTest, GpuDotMemoryBoundBf16) { // TODO: b/510666436 - Backend config tuned to minimize L2 loads replication // so the operation remains strictly HBM bounded. - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(R"( + ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"( ENTRY e { p0 = bf16[4,4096] parameter(0) p1 = bf16[4096,4096] parameter(1) @@ -119,8 +119,8 @@ backend_config={"sizes":["512"]} } TEST_F(GpuDotFusionCostModelTest, DifferentContractingDimsHaveSameRuntime) { - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module_1_0, - ParseAndReturnVerifiedModule(R"( + ASSERT_OK_AND_ASSIGN(std::unique_ptr module_1_0, + ParseAndReturnVerifiedModule(R"( ENTRY e { p0 = bf16[8192,1024] parameter(0) p1 = bf16[1024,4096] parameter(1) @@ -129,8 +129,8 @@ lhs_contracting_dims={1}, rhs_contracting_dims={0}, algorithm=dot_bf16_bf16_bf16 backend_config={"sizes":["32"]} })")); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module_0_1, - ParseAndReturnVerifiedModule(R"( + ASSERT_OK_AND_ASSIGN(std::unique_ptr module_0_1, + ParseAndReturnVerifiedModule(R"( ENTRY e { p0 = bf16[1024,8192] parameter(0) p1 = bf16[4096,1024] parameter(1) @@ -148,7 +148,7 @@ backend_config={"sizes":["32"]} auto* dot_1_0 = Cast( module_1_0->entry_computation()->root_instruction()); ASSERT_IS_OK(gpu_dot_fusion_cost_model::IsSupported(dot_1_0)); - TF_ASSERT_OK_AND_ASSIGN( + ASSERT_OK_AND_ASSIGN( EstimateRunTimeData runtime_h100_1_0, gpu_dot_fusion_cost_model::EstimateRunTimeForDotOpWithBlockParameters( dot_1_0, block_params, ddh100_)); @@ -156,7 +156,7 @@ backend_config={"sizes":["32"]} auto* dot_0_1 = Cast( module_0_1->entry_computation()->root_instruction()); ASSERT_IS_OK(gpu_dot_fusion_cost_model::IsSupported(dot_0_1)); - TF_ASSERT_OK_AND_ASSIGN( + ASSERT_OK_AND_ASSIGN( EstimateRunTimeData runtime_h100_0_1, gpu_dot_fusion_cost_model::EstimateRunTimeForDotOpWithBlockParameters( dot_0_1, block_params, ddh100_)); @@ -166,8 +166,8 @@ backend_config={"sizes":["32"]} } TEST_F(GpuDotFusionCostModelTest, ExtractBlockKFromTileConfig) { - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(R"( + ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"( ENTRY e { p0 = bf16[1024,2048] parameter(0) p1 = bf16[2048,1024] parameter(1) @@ -178,14 +178,14 @@ backend_config={"sizes":["32"]} auto* dot = Cast(module->entry_computation()->root_instruction()); - TF_ASSERT_OK_AND_ASSIGN(int64_t block_k, - gpu_dot_fusion_cost_model::ExtractBlockK(dot)); + ASSERT_OK_AND_ASSIGN(int64_t block_k, + gpu_dot_fusion_cost_model::ExtractBlockK(dot)); EXPECT_EQ(block_k, 32); } TEST_F(GpuDotFusionCostModelTest, ExtractBlockKNoBackendConfig) { - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(R"( + ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"( ENTRY e { p0 = bf16[1024,2048] parameter(0) p1 = bf16[2048,1024] parameter(1) @@ -200,8 +200,8 @@ lhs_contracting_dims={1}, rhs_contracting_dims={0}, algorithm=dot_bf16_bf16_bf16 } TEST_F(GpuDotFusionCostModelTest, GpuDot3DGemmIsSupported) { - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(R"( + ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"( ENTRY e { p0 = bf16[16,1024,2048] parameter(0) p1 = bf16[16,2048,1024] parameter(1) @@ -218,7 +218,7 @@ backend_config={"sizes":["32"]} auto* dot = Cast(module->entry_computation()->root_instruction()); ASSERT_IS_OK(gpu_dot_fusion_cost_model::IsSupported(dot)); - TF_ASSERT_OK_AND_ASSIGN( + ASSERT_OK_AND_ASSIGN( EstimateRunTimeData runtime_h100, gpu_dot_fusion_cost_model::EstimateRunTimeForDotOpWithBlockParameters( dot, block_params, ddh100_)); @@ -229,8 +229,8 @@ backend_config={"sizes":["32"]} // (such as having independent head and batch dimensions in multi-head // attention workloads) without requiring explicit reshape or flattening ops. TEST_F(GpuDotFusionCostModelTest, GpuDot4DGemm) { - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(R"( + ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"( ENTRY e { p0 = bf16[2,8,1024,2048] parameter(0) p1 = bf16[2,8,2048,1024] parameter(1) @@ -247,7 +247,7 @@ backend_config={"sizes":["32"]} auto* dot = Cast(module->entry_computation()->root_instruction()); ASSERT_IS_OK(gpu_dot_fusion_cost_model::IsSupported(dot)); - TF_ASSERT_OK_AND_ASSIGN( + ASSERT_OK_AND_ASSIGN( EstimateRunTimeData runtime_h100, gpu_dot_fusion_cost_model::EstimateRunTimeForDotOpWithBlockParameters( dot, block_params, ddh100_)); @@ -257,8 +257,8 @@ backend_config={"sizes":["32"]} // TODO: b/501002656 - Remove this test once we support transposes in the dot // fusion cost model. TEST_F(GpuDotFusionCostModelTest, GpuDotWithDownstreamTransposeIsRejected) { - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(R"( + ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"( ENTRY e { p0 = bf16[1024,2048] parameter(0) p1 = bf16[2048,1024] parameter(1) diff --git a/third_party/xla/xla/service/gpu/model/gpu_hlo_cost_analysis_test.cc b/third_party/xla/xla/service/gpu/model/gpu_hlo_cost_analysis_test.cc index 4532c5182dd65f..23d62fc1d3862e 100644 --- a/third_party/xla/xla/service/gpu/model/gpu_hlo_cost_analysis_test.cc +++ b/third_party/xla/xla/service/gpu/model/gpu_hlo_cost_analysis_test.cc @@ -17,6 +17,7 @@ limitations under the License. #include +#include #include #include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_computation.h" @@ -53,8 +54,7 @@ ENTRY entry { ROOT tuple = tuple(conv1) } )"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); + ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_string)); ASSERT_IS_OK(module->entry_computation()->Accept(&analysis_)); HloComputation* comp = module->entry_computation(); const HloInstruction* conv1 = comp->GetInstructionWithName("conv1"); @@ -93,8 +93,7 @@ TEST_F(GpuHloCostAnalysisTest, CublasCustomCall) { } ROOT %get-tuple-element = f32[100,100]{1,0} get-tuple-element(%custom-call.1), index=0 })"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); + ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_string)); ASSERT_IS_OK(module->entry_computation()->Accept(&analysis_)); HloComputation* comp = module->entry_computation(); const HloInstruction* instr = comp->GetInstructionWithName("custom-call.1"); @@ -123,8 +122,7 @@ ENTRY entry { ROOT _ = f32[3,4] reduce-window(p0, c0), window={size=4x5 stride=2x1}, to_apply=add } )"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); + ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_string)); HloInstruction* root = module->entry_computation()->root_instruction(); int n_output_elements = 3 * 4; @@ -165,8 +163,7 @@ ENTRY e { ROOT r0 = s8[10000] fusion(p0), kind=kInput, calls=f } )"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); + ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_string)); HloInstruction* root = module->entry_computation()->root_instruction(); ASSERT_IS_OK(module->entry_computation()->Accept(&analysis_)); @@ -201,8 +198,7 @@ ENTRY e { ROOT r0 = s8[8000] fusion(p0), kind=kInput, calls=f } )"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); + ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_string)); HloInstruction* root = module->entry_computation()->root_instruction(); options_.count_multiple_input_accesses = false; GpuHloCostAnalysis analysis{options_}; @@ -231,8 +227,7 @@ ENTRY e { ROOT r = f32[1024,1024] fusion(), kind=kInput, calls=f } )"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); + ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_string)); HloInstruction* root = module->entry_computation()->root_instruction(); ASSERT_IS_OK(module->entry_computation()->Accept(&analysis_)); @@ -260,8 +255,7 @@ ENTRY e { ROOT r0 = s8[1] fusion(p0), kind=kInput, calls=f } )"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); + ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_string)); const HloInstruction* root = module->entry_computation()->root_instruction(); ASSERT_IS_OK(module->entry_computation()->Accept(&analysis_)); @@ -291,8 +285,7 @@ ENTRY e { } )"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); + ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_string)); const HloInstruction* root = module->entry_computation()->root_instruction(); ASSERT_IS_OK(module->entry_computation()->Accept(&analysis_)); @@ -319,8 +312,7 @@ ENTRY e { ROOT r0 = s8[] fusion(param0), kind=kInput, calls=f } )"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); + ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_string)); HloInstruction* root = module->entry_computation()->root_instruction(); ASSERT_IS_OK(module->entry_computation()->Accept(&analysis_)); @@ -349,8 +341,7 @@ ENTRY e { ROOT fusion = (s8[10], u8[10]) fusion(param0), kind=kLoop, calls=f } )"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); + ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_string)); HloInstruction* root = module->entry_computation()->root_instruction(); ASSERT_IS_OK(module->entry_computation()->Accept(&analysis_)); @@ -381,8 +372,7 @@ ENTRY e { ROOT r0 = s8[17] fusion(param0), kind=kInput, calls=f } )"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); + ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_string)); HloInstruction* root = module->entry_computation()->root_instruction(); ASSERT_IS_OK(module->entry_computation()->Accept(&analysis_)); @@ -430,8 +420,7 @@ ENTRY e { ROOT r = s8[2] fusion(p0, p1), kind=kInput, calls=f } )"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); + ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_string)); HloInstruction* root = module->entry_computation()->root_instruction(); ASSERT_IS_OK(module->entry_computation()->Accept(&analysis_)); @@ -457,8 +446,7 @@ ENTRY e { ROOT r = s8[1000] fusion(p0), kind=kInput, calls=f } )"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); + ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_string)); HloInstruction* root = module->entry_computation()->root_instruction(); ASSERT_IS_OK(module->entry_computation()->Accept(&analysis_)); @@ -488,8 +476,8 @@ TEST_F(GpuHloCostAnalysisTest, DynUpdateSliceUsingOperandData) { } )"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_fusion_module_str)); + ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_fusion_module_str)); ASSERT_IS_OK(module->entry_computation()->Accept(&analysis_)); HloInstruction* fusion = module->entry_computation()->root_instruction(); @@ -517,8 +505,8 @@ TEST_F(GpuHloCostAnalysisTest, DynUpdateSliceNotUsingOperandData) { } )"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_fusion_module_str)); + ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_fusion_module_str)); ASSERT_IS_OK(module->entry_computation()->Accept(&analysis_)); HloInstruction* fusion = module->entry_computation()->root_instruction(); ASSERT_EQ(fusion->opcode(), HloOpcode::kFusion); @@ -554,8 +542,8 @@ TEST_F(GpuHloCostAnalysisTest, CommonElementwiseUseTwoParameters) { ROOT _ = s8[] fusion(p0, p1), kind=kLoop, calls=f })"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_fusion_module_str)); + ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_fusion_module_str)); ASSERT_IS_OK(module->entry_computation()->Accept(&analysis_)); HloInstruction* fusion = module->entry_computation()->root_instruction(); @@ -582,8 +570,8 @@ TEST_F(GpuHloCostAnalysisTest, CommonElementwiseUseParameterAndRoot) { ROOT _ = s8[10] fusion(p0, p1), kind=kLoop, calls=f })"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_fusion_module_str)); + ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_fusion_module_str)); ASSERT_IS_OK(module->entry_computation()->Accept(&analysis_)); HloInstruction* fusion = module->entry_computation()->root_instruction(); @@ -615,8 +603,8 @@ TEST_F(GpuHloCostAnalysisTest, ROOT _ = (s8[10], s8[10]) fusion(p0, p1), kind=kLoop, calls=f })"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_fusion_module_str)); + ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_fusion_module_str)); ASSERT_IS_OK(module->entry_computation()->Accept(&analysis_)); HloInstruction* fusion = module->entry_computation()->root_instruction(); @@ -644,8 +632,7 @@ ENTRY entry_computation { ROOT reduce = f32[32]{0} reduce(param_0.3, constant), dimensions={1}, to_apply=add } )"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); + ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_string)); ASSERT_IS_OK(module->entry_computation()->Accept(&analysis_)); const HloInstruction* reduce = module->entry_computation()->root_instruction(); @@ -684,8 +671,7 @@ ENTRY entry_computation { ROOT reduce = (f32[32]{0}, f32[32]{0}) reduce(param_0.3, param_1.3, param_2.2, constant), dimensions={1}, to_apply=add } )"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); + ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_string)); ASSERT_IS_OK(module->entry_computation()->Accept(&analysis_)); const HloInstruction* reduce = module->entry_computation()->root_instruction(); @@ -714,8 +700,7 @@ ENTRY entry { ROOT cp = f32[4096] collective-permute(p0), source_target_pairs={{0,1},{1,0}} } )"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); + ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_string)); ASSERT_IS_OK(module->entry_computation()->Accept(&analysis_)); const HloInstruction* cp = module->entry_computation()->root_instruction(); EXPECT_EQ(analysis_.BytesTransferred(*cp), 4096 * 4); @@ -731,8 +716,7 @@ ENTRY entry { ROOT r = f32[4096] collective-permute-done(cps) } )"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); + ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_string)); ASSERT_IS_OK(module->entry_computation()->Accept(&analysis_)); const HloInstruction* cps = module->entry_computation()->root_instruction()->operand(0); @@ -755,8 +739,7 @@ ENTRY entry_computation { ROOT _ = f32[4096] all-reduce-done(ar-start) } )"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); + ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_string)); ASSERT_IS_OK(module->entry_computation()->Accept(&analysis_)); @@ -776,8 +759,7 @@ ENTRY entry_computation { replica_groups={{0,1,2,3}}, channel_id=1 } )"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); + ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_string)); ASSERT_IS_OK(module->entry_computation()->Accept(&analysis_)); @@ -802,8 +784,7 @@ ENTRY entry_computation { ROOT _ = (f32[4096],f32[2048]) all-gather-done(ag-start) } )"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); + ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_string)); ASSERT_IS_OK(module->entry_computation()->Accept(&analysis_)); @@ -832,8 +813,7 @@ ENTRY entry_computation { use_global_device_ids=true, replica_groups={{0,1,2,3}}, channel_id=1 } )"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); + ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_string)); ASSERT_IS_OK(module->entry_computation()->Accept(&analysis_)); @@ -871,8 +851,7 @@ ENTRY entry_computation { ROOT _ = (f32[1024],f32[512]) async-done(rs-start) } )"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); + ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_string)); ASSERT_IS_OK(module->entry_computation()->Accept(&analysis_)); @@ -899,8 +878,7 @@ ENTRY entry_computation { ROOT clamp = f32[10] clamp(mul, param_2, param_3) } )"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); + ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_string)); HloOpProfiles::HloOpProfile hlo_op_profile; diff --git a/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model_test.cc b/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model_test.cc index 0735e40bfc6bf9..796c013ac68f9e 100644 --- a/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model_test.cc +++ b/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model_test.cc @@ -45,8 +45,6 @@ limitations under the License. #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/stream_executor/device_description.h" -#include "xla/tsl/lib/core/status_test_util.h" -#include "xla/tsl/platform/statusor.h" #include "xla/util.h" namespace xla { @@ -79,7 +77,7 @@ INSTANTIATE_TEST_SUITE_P(GpuIndexingPerformanceModelTest, GpuIndexingPerformanceModelTest, ::testing::Bool()); TEST_P(GpuIndexingPerformanceModelTest, TritonGemm) { - TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"( + ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"( HloModule m triton_dot { @@ -110,9 +108,9 @@ ENTRY e { } } )")); - TF_ASSERT_OK_AND_ASSIGN(auto runtime_data, - indexing_cost_model_.EstimateRunTimeForTriton( - module->entry_computation()->root_instruction())); + ASSERT_OK_AND_ASSIGN(auto runtime_data, + indexing_cost_model_.EstimateRunTimeForTriton( + module->entry_computation()->root_instruction())); EXPECT_EQ(runtime_data.flops, 8388608); EXPECT_EQ(runtime_data.bytes_written, 65536); @@ -122,7 +120,7 @@ ENTRY e { TEST_P(GpuIndexingPerformanceModelTest, TritonSoftmaxFusionInstructionIsSupported) { - TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"( + ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"( HloModule m add { @@ -148,9 +146,9 @@ ENTRY main { ROOT triton_softmax = f32[512,911]{1,0} fusion(param_0, param_1), kind=kCustom, calls=triton_softmax_computation, backend_config={"fusion_backend_config": {"kind":"__triton","block_level_fusion_config":{"output_tiles":[{"sizes":["1","911"]}],"num_warps":"2"}}} } )")); - TF_ASSERT_OK_AND_ASSIGN(auto runtime_data, - indexing_cost_model_.EstimateRunTimeForTriton( - module->entry_computation()->root_instruction())); + ASSERT_OK_AND_ASSIGN(auto runtime_data, + indexing_cost_model_.EstimateRunTimeForTriton( + module->entry_computation()->root_instruction())); constexpr int64_t kParam0SizeBytes = 512 * 911 * 4; constexpr int64_t kParam1SizeBytes = 911 * 4; @@ -168,7 +166,7 @@ ENTRY main { // Example from b/383162692. TEST_P(GpuIndexingPerformanceModelTest, EstimateBestTiling_CombinedFusion) { - TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"( + ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"( HloModule m add { @@ -259,7 +257,7 @@ ENTRY entry_computation { auto fusion_adaptor = HloFusionAdaptor::ForInstruction( module->entry_computation()->root_instruction()); - TF_ASSERT_OK_AND_ASSIGN( + ASSERT_OK_AND_ASSIGN( auto tiling_result, indexing_cost_model_.TryFindBestTilingForFusion(*fusion_adaptor)); @@ -279,7 +277,7 @@ ENTRY entry_computation { } TEST_P(GpuIndexingPerformanceModelTest, EstimateBestTiling_MultioutputFusion) { - TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"( + ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"( HloModule m region { @@ -305,7 +303,7 @@ ENTRY entry_computation { auto fusion_adaptor = HloFusionAdaptor::ForInstruction( module->entry_computation()->root_instruction()); - TF_ASSERT_OK_AND_ASSIGN( + ASSERT_OK_AND_ASSIGN( auto tiling_result, indexing_cost_model_.TryFindBestTilingForFusion(*fusion_adaptor)); @@ -323,7 +321,7 @@ ENTRY entry_computation { TEST_P(GpuIndexingPerformanceModelTest, EstimateBestTiling_TritonSoftmax_IsSupported) { - TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"( + ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"( HloModule m add { @@ -352,7 +350,7 @@ ENTRY main { auto fusion_adaptor = HloFusionAdaptor::ForInstruction( module->entry_computation()->root_instruction()); - TF_ASSERT_OK_AND_ASSIGN( + ASSERT_OK_AND_ASSIGN( auto tiling_result, indexing_cost_model_.TryFindBestTilingForFusion(*fusion_adaptor)); @@ -388,7 +386,7 @@ ENTRY main { TEST_P( GpuIndexingPerformanceModelTest, EstimateRunTimeForTiledFusion_NumberOfTilesLargerThanInt32Max_IsSupported) { - TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"( + ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"( HloModule softmax max_computation { @@ -412,7 +410,7 @@ ENTRY main { auto fusion_adaptor = HloFusionAdaptor::ForInstruction( module->entry_computation()->root_instruction()); - TF_ASSERT_OK_AND_ASSIGN( + ASSERT_OK_AND_ASSIGN( auto runtime_data, indexing_cost_model_.EstimateRunTimeForTiledFusion( *fusion_adaptor, BlockLevelParameters{/*output_tile_sizes=*/{{1, 1}}, @@ -425,7 +423,7 @@ ENTRY main { TEST_P(GpuIndexingPerformanceModelTest, EstimateRunTimeForTiledFusion_Concatenate) { - TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"( + ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"( HloModule m concatenate_fusion { @@ -448,7 +446,7 @@ ENTRY main { *fusion_adaptor, BlockLevelParameters{/*output_tile_sizes=*/{{1, 128}}, /*num_warps=*/3}); - TF_ASSERT_OK(result.status()); + ASSERT_OK(result.status()); // The flops contribution for a single instruction is calculated as: // flops_per_element * padded_tile_size * num_blocks_cur_hlo EXPECT_EQ(result->flops, @@ -465,7 +463,7 @@ ENTRY main { TEST_P(GpuIndexingPerformanceModelTest, EstimateRunTimeForTiledFusion_DotWithReductionLoop) { - TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"( + ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"( HloModule m dot_fusion { @@ -490,7 +488,7 @@ ENTRY main { int64_t num_blocks = 8; - TF_ASSERT_OK(result.status()); + ASSERT_OK(result.status()); // The flops contribution for a single instruction is calculated as: // flops_per_element * padded_tile_size * num_blocks_cur_hlo EXPECT_EQ(result->flops, @@ -508,7 +506,7 @@ ENTRY main { TEST_P(GpuIndexingPerformanceModelTest, EstimateRunTimeForTiledFusion_Softmax_RegisterSpill_ReturnsInfinite) { - TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"( + ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"( HloModule m add { @@ -533,13 +531,13 @@ ENTRY main { auto fusion_adaptor = HloFusionAdaptor::ForInstruction( module->entry_computation()->root_instruction()); - TF_ASSERT_OK_AND_ASSIGN( + ASSERT_OK_AND_ASSIGN( auto res1, indexing_cost_model_.EstimateRunTimeForTiledFusion( *fusion_adaptor, BlockLevelParameters{/*output_tile_sizes=*/{{1, 16000}}})); EXPECT_NEAR(absl::ToDoubleMicroseconds(res1.exec_time), 3, 1); - TF_ASSERT_OK_AND_ASSIGN( + ASSERT_OK_AND_ASSIGN( auto res2, indexing_cost_model_.EstimateRunTimeForTiledFusion( *fusion_adaptor, BlockLevelParameters{/*output_tile_sizes=*/{{2, 16000}}})); @@ -549,7 +547,7 @@ ENTRY main { TEST_P( GpuIndexingPerformanceModelTest, EstimateRunTimeForTiledFusion_BroadcastReduce_RegisterSpill_ReturnsInfinite) { // NOLINT(whitespace/line_length) - TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"( + ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"( HloModule m add { @@ -580,21 +578,21 @@ ENTRY main { auto fusion_adaptor = HloFusionAdaptor::ForInstruction( module->entry_computation()->root_instruction()); - TF_ASSERT_OK_AND_ASSIGN( + ASSERT_OK_AND_ASSIGN( auto res1, indexing_cost_model_.EstimateRunTimeForTiledFusion( *fusion_adaptor, BlockLevelParameters{/*output_tile_sizes=*/{{4, 4}}, /*num_warps=*/8})); EXPECT_NEAR(absl::ToDoubleMicroseconds(res1.exec_time), 147, 1); - TF_ASSERT_OK_AND_ASSIGN( + ASSERT_OK_AND_ASSIGN( auto res2, indexing_cost_model_.EstimateRunTimeForTiledFusion( *fusion_adaptor, BlockLevelParameters{/*output_tile_sizes=*/{{8, 4}}, /*num_warps=*/8})); EXPECT_TRUE(res2.IsInfinite()); - TF_ASSERT_OK_AND_ASSIGN( + ASSERT_OK_AND_ASSIGN( auto res3, indexing_cost_model_.EstimateRunTimeForTiledFusion( *fusion_adaptor, BlockLevelParameters{/*output_tile_sizes=*/{{4, 8}}, @@ -605,7 +603,7 @@ ENTRY main { TEST_P( GpuIndexingPerformanceModelTest, EstimateRunTimeForTiledFusion_UsesHloDimensionSizeWhenTileCoversFullDimensionForMemoryAccessTime) { // NOLINT(whitespace/line_length) - TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"( + ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"( HloModule m triton_softmax_computation { @@ -623,7 +621,7 @@ ENTRY main { auto fusion_adaptor = HloFusionAdaptor::ForInstruction( module->entry_computation()->root_instruction()); - TF_ASSERT_OK_AND_ASSIGN( + ASSERT_OK_AND_ASSIGN( auto res, indexing_cost_model_.EstimateRunTimeForTiledFusion( *fusion_adaptor, BlockLevelParameters{/*output_tile_sizes=*/{{65, 65}}, @@ -644,7 +642,7 @@ ENTRY main { TEST_P( GpuIndexingPerformanceModelTest, EstimateRunTimeForTiledFusion_UncoalescedReadsAreScaledBasedOnWasteTransactionPercentage) { // NOLINT(whitespace/line_length) - TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"( + ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"( HloModule m triton_softmax_computation { @@ -662,14 +660,14 @@ ENTRY main { auto fusion_adaptor = HloFusionAdaptor::ForInstruction( module->entry_computation()->root_instruction()); - TF_ASSERT_OK_AND_ASSIGN( + ASSERT_OK_AND_ASSIGN( auto res_coalesced, indexing_cost_model_.EstimateRunTimeForTiledFusion( *fusion_adaptor, BlockLevelParameters{/*output_tile_sizes=*/{{2, 128}}, /*num_warps=*/2})); - TF_ASSERT_OK_AND_ASSIGN( + ASSERT_OK_AND_ASSIGN( auto res_uncoalesced, indexing_cost_model_.EstimateRunTimeForTiledFusion( *fusion_adaptor, @@ -691,7 +689,7 @@ ENTRY main { TEST_P( GpuIndexingPerformanceModelTest, EstimateRunTimeForTiledFusion_UncoalescedWritesAreScaledBasedOnWasteTransactionPercentage) { // NOLINT(whitespace/line_length) - TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"( + ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"( HloModule m add { @@ -710,13 +708,13 @@ ENTRY main { auto fusion_adaptor = HloFusionAdaptor::ForInstruction( module->entry_computation()->root_instruction()); - TF_ASSERT_OK_AND_ASSIGN( + ASSERT_OK_AND_ASSIGN( auto res_coalesced, indexing_cost_model_.EstimateRunTimeForTiledFusion( *fusion_adaptor, BlockLevelParameters{/*output_tile_sizes=*/{{16, 128}}})); - TF_ASSERT_OK_AND_ASSIGN( + ASSERT_OK_AND_ASSIGN( auto res_uncoalesced, indexing_cost_model_.EstimateRunTimeForTiledFusion( *fusion_adaptor, @@ -736,7 +734,7 @@ ENTRY main { TEST_P(GpuIndexingPerformanceModelTest, GetLaunchDimensionsForTiledFusion_IsSupported) { - TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"( + ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"( HloModule m triton_softmax_computation { @@ -762,10 +760,10 @@ ENTRY main { /*emitter_specific_constraints_builder=*/nullptr); ASSERT_TRUE(std::holds_alternative(analysis_or_error)); - TF_ASSERT_OK_AND_ASSIGN(TiledHloComputation tiled_hlo_computation, - std::get(analysis_or_error) - .ComputeTiledComputation(Tiling( - {{fusion_root, FlatTiling({9, 9, 9})}}))); + ASSERT_OK_AND_ASSIGN(TiledHloComputation tiled_hlo_computation, + std::get(analysis_or_error) + .ComputeTiledComputation( + Tiling({{fusion_root, FlatTiling({9, 9, 9})}}))); int64_t num_warps = GpuPerformanceModelWithIndexingAnalysis::EstimateNumWarps( tiled_hlo_computation); @@ -778,7 +776,7 @@ ENTRY main { TEST_P(GpuIndexingPerformanceModelTest, NumberOfWarpsDependsOnLargestLiveTileSize) { - TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"( + ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"( HloModule m add { @@ -811,7 +809,7 @@ ENTRY main { /*emitter_specific_constraints_builder=*/nullptr); ASSERT_TRUE(std::holds_alternative(analysis_or_error)); - TF_ASSERT_OK_AND_ASSIGN( + ASSERT_OK_AND_ASSIGN( TiledHloComputation tiled_hlo_computation, std::get(analysis_or_error) .ComputeTiledComputation(Tiling({{fusion_root, FlatTiling({1})}}))); @@ -827,8 +825,8 @@ ENTRY main { class FlopsPerElementTest : public GpuIndexingPerformanceModelTest { public: void CompareFlopsModels(absl::string_view hlo_module_string) { - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_module_string)); + ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_module_string)); GpuHloCostAnalysis cost_analysis( GpuHloCostAnalysis::Options{.count_multiple_input_accesses = true}, diff --git a/third_party/xla/xla/service/gpu/model/gpu_performance_model_base_test.cc b/third_party/xla/xla/service/gpu/model/gpu_performance_model_base_test.cc index 3f1dd3b0b50d6f..4d5bf529b73e14 100644 --- a/third_party/xla/xla/service/gpu/model/gpu_performance_model_base_test.cc +++ b/third_party/xla/xla/service/gpu/model/gpu_performance_model_base_test.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include #include #include "absl/strings/string_view.h" #include "absl/time/time.h" @@ -31,7 +32,6 @@ limitations under the License. #include "xla/service/gpu/hlo_fusion_analysis.h" #include "xla/service/gpu/model/gpu_hlo_cost_analysis.h" #include "xla/stream_executor/device_description.h" -#include "xla/tsl/platform/statusor.h" #include "xla/xla_data.pb.h" namespace xla { @@ -67,8 +67,7 @@ ENTRY entry_computation { ROOT dynamic-update-slice = f32[8,16] dynamic-update-slice(param_0, log, c_0, c_0) } )"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); + ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_string)); auto computation = module->entry_computation(); ASSERT_IS_OK(computation->Accept(analysis_.get())); @@ -96,8 +95,7 @@ ENTRY entry_computation { ROOT dynamic-update-slice = f32[8,16] dynamic-update-slice(log, param_1, c_0, c_0) } )"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); + ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_string)); auto computation = module->entry_computation(); ASSERT_IS_OK(computation->Accept(analysis_.get())); @@ -139,8 +137,7 @@ ENTRY entry_computation { } )"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); + ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_string)); auto computation = module->entry_computation(); ASSERT_IS_OK(computation->Accept(analysis_.get())); @@ -172,8 +169,7 @@ ENTRY entry_computation { } )"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); + ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_string)); HloComputation* computation = module->entry_computation(); ASSERT_IS_OK(computation->Accept(analysis_.get())); @@ -206,8 +202,7 @@ ENTRY entry_computation { } )"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); + ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_string)); auto computation = module->entry_computation(); ASSERT_IS_OK(computation->Accept(analysis_.get())); @@ -238,8 +233,7 @@ ENTRY entry_computation { } )"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); + ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_string)); auto fusion_analysis = HloFusionAnalysis::Create( *module->entry_computation()->root_instruction(), device_info_); @@ -275,8 +269,7 @@ ENTRY e { backend_config={"fusion_backend_config": {kind: "__triton","block_level_fusion_config":{"output_tiles":[{"sizes":["1","970"]}],"num_warps":"2"}}} })"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); + ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_string)); auto fusion_analysis = HloFusionAnalysis::Create( *module->entry_computation()->root_instruction(), device_info_); @@ -305,8 +298,7 @@ ENTRY e { backend_config={"fusion_backend_config": {kind: "__cudnn$fusion"}} })"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); + ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_string)); auto fusion_analysis = HloFusionAnalysis::Create( *module->entry_computation()->root_instruction(), device_info_); diff --git a/third_party/xla/xla/service/gpu/model/gpu_performance_model_test.cc b/third_party/xla/xla/service/gpu/model/gpu_performance_model_test.cc index d3207c86fb7d1b..f25e377598e2a1 100644 --- a/third_party/xla/xla/service/gpu/model/gpu_performance_model_test.cc +++ b/third_party/xla/xla/service/gpu/model/gpu_performance_model_test.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include #include #include "absl/status/statusor.h" #include "absl/strings/string_view.h" @@ -92,8 +93,7 @@ ENTRY e { ROOT r.1 = f32[10000000] fusion(), kind=kLoop, calls=f } )"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); + ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_string)); HloInstruction* root = module->entry_computation()->root_instruction(); ASSERT_IS_OK(module->entry_computation()->Accept(&analysis_)); @@ -118,8 +118,7 @@ ENTRY e { ROOT r.1 = f32[1000] fusion(p0, p1), kind=kLoop, calls=f } )"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); + ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_string)); HloInstruction* root = module->entry_computation()->root_instruction(); ASSERT_IS_OK(root->Accept(&analysis_)); @@ -150,8 +149,7 @@ ENTRY e { ROOT r.1 = f32[10000000] fusion(p0, p1), kind=kLoop, calls=f } )"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); + ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_string)); HloInstruction* root = module->entry_computation()->root_instruction(); ASSERT_IS_OK(root->Accept(&analysis_)); @@ -186,8 +184,7 @@ ENTRY e { ROOT r.1 = f32[10000000] fusion(p0, p1), kind=kLoop, calls=f } )"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); + ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_string)); HloInstruction* root = module->entry_computation()->root_instruction(); ASSERT_IS_OK(root->Accept(&analysis_)); @@ -215,8 +212,7 @@ ENTRY e { ROOT r.1 = f32[10000000] fusion(p0, p1), kind=kLoop, calls=f } )"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); + ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_string)); HloInstruction* root = module->entry_computation()->root_instruction(); ASSERT_IS_OK(root->Accept(&analysis_)); @@ -290,7 +286,7 @@ ENTRY fusion { } )"; - TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo)); + ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo)); ASSERT_IS_OK(module->entry_computation()->Accept(&analysis_)); auto run = [&](absl::string_view reduce_name_1, @@ -304,10 +300,10 @@ ENTRY fusion { return EstimateRunTimes(producer, consumers); }; - TF_ASSERT_OK_AND_ASSIGN(auto large_small_reduce_runtime, - run("reduce.1", "reduce.2")); - TF_ASSERT_OK_AND_ASSIGN(auto small_large_reduce_runtime, - run("reduce.3", "reduce.4")); + ASSERT_OK_AND_ASSIGN(auto large_small_reduce_runtime, + run("reduce.1", "reduce.2")); + ASSERT_OK_AND_ASSIGN(auto small_large_reduce_runtime, + run("reduce.3", "reduce.4")); // Ignoring memory access patterns and occupancy, the runtime should be about // the same. @@ -334,7 +330,7 @@ ENTRY fusion { } )"; - TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo)); + ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo)); ASSERT_IS_OK(module->entry_computation()->Accept(&analysis_)); auto* producer = @@ -376,7 +372,7 @@ ENTRY fusion { } )"; - TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo)); + ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo)); ASSERT_IS_OK(module->entry_computation()->Accept(&analysis_)); auto* producer = @@ -407,7 +403,7 @@ ENTRY fusion { } )"; - TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo)); + ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo)); ASSERT_IS_OK(module->entry_computation()->Accept(&analysis_)); auto* producer = @@ -461,7 +457,7 @@ ENTRY main { ROOT tuple = (f32[1073741824], f32[1024]) tuple(dus1, dus2) })"; - TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo)); + ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo)); ASSERT_IS_OK(module->entry_computation()->Accept(&analysis_)); auto* operand0 = module->entry_computation()->root_instruction()->operand(0); @@ -508,8 +504,7 @@ ENTRY e2 { } )"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); + ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_string)); HloComputation* computation_without_fusion = module->GetComputationWithName("e1"); @@ -548,7 +543,7 @@ ENTRY fusion { ROOT divide = f32[] divide(reduce, p1) })"; - TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo)); + ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo)); ASSERT_IS_OK(module->entry_computation()->Accept(&analysis_)); auto* producer = @@ -595,7 +590,7 @@ ENTRY fusion { ROOT fusion.1 = f32[4,8,8] fusion(exp), kind=kInput, calls=fused_computation.1 })"; - TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo)); + ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo)); ASSERT_IS_OK(module->entry_computation()->Accept(&analysis_)); auto* fusion_0 = @@ -637,7 +632,7 @@ ENTRY fusion { ROOT reduce = f32[4,32] reduce(fusion, c0), to_apply=add, dimensions={1} })"; - TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo)); + ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo)); ASSERT_IS_OK(module->entry_computation()->Accept(&analysis_)); auto* fusion = module->entry_computation()->GetInstructionWithName("fusion"); @@ -650,7 +645,7 @@ ENTRY fusion { TEST_F(GpuPerformanceModelTest, EstimateRunTimeForFusion_InfiniteProducer_ReturnsInfinite) { - TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"( + ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"( HloModule testmodule ENTRY fusion { @@ -677,7 +672,7 @@ ENTRY fusion { TEST_F(GpuPerformanceModelTest, EstimateRunTimeForFusion_InfiniteConsumer_ReturnsInfinite) { - TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"( + ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"( HloModule testmodule ENTRY fusion { @@ -704,7 +699,7 @@ ENTRY fusion { TEST_F(GpuPerformanceModelTest, EstimateRunTimeForFusion_MultiOutputWrite_ReturnsCorrectTime) { - TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"( + ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"( HloModule m fused_power { diff --git a/third_party/xla/xla/service/gpu/model/hlo_op_profiler_test.cc b/third_party/xla/xla/service/gpu/model/hlo_op_profiler_test.cc index 0b5fd878a91dbf..966ff15154db2a 100644 --- a/third_party/xla/xla/service/gpu/model/hlo_op_profiler_test.cc +++ b/third_party/xla/xla/service/gpu/model/hlo_op_profiler_test.cc @@ -120,7 +120,7 @@ TEST_F(HloOpProfilerTest, AllSupportedCombinationsAreMeasurable) { !HloOpProfiler::Unsupported().count(op) && !(is_rocm && skip_on_rocm.count(op))) { auto Type = FloatTypes.count(op) ? F32 : S32; - TF_EXPECT_OK(profiler.MeasureClockCyclesPerOp(op, Type)); + EXPECT_OK(profiler.MeasureClockCyclesPerOp(op, Type)); } } } diff --git a/third_party/xla/xla/service/gpu/model/matmul_interpolator_test.cc b/third_party/xla/xla/service/gpu/model/matmul_interpolator_test.cc index 2bc1ec22a67cb1..60aa6d50f1889b 100644 --- a/third_party/xla/xla/service/gpu/model/matmul_interpolator_test.cc +++ b/third_party/xla/xla/service/gpu/model/matmul_interpolator_test.cc @@ -23,6 +23,7 @@ limitations under the License. #include #include +#include #include #include "absl/log/check.h" #include "absl/status/statusor.h" @@ -39,7 +40,6 @@ limitations under the License. #include "xla/service/gpu/model/hlo_op_profile.pb.h" #include "xla/stream_executor/cuda/cuda_compute_capability.h" #include "xla/stream_executor/device_description.h" -#include "xla/tsl/platform/statusor.h" #include "xla/xla_data.pb.h" namespace xla::gpu { @@ -212,8 +212,8 @@ class MatmulInterpolatorParamTest : public TestWithParam { TEST_P(MatmulInterpolatorParamTest, MatmulInteprolatorNextNeighbourInterpolation) { const auto& [_, spec, expected_duration] = GetParam(); - TF_ASSERT_OK_AND_ASSIGN(DotContext context, Dot(spec.b, spec.m, spec.n, - spec.k, "f32", "f32", "f32")); + ASSERT_OK_AND_ASSIGN(DotContext context, Dot(spec.b, spec.m, spec.n, spec.k, + "f32", "f32", "f32")); EXPECT_EQ(absl::Trunc(*interpolator().EstimatedRuntime(*context.dot), absl::Milliseconds(1)), expected_duration); @@ -337,8 +337,8 @@ using H100BF16Test = MatmulInterpolatorDefaultTableTest; TEST_P(H100BF16Test, EstimatesRuntimeForBF16) { const auto& [_, spec, expected_duration] = GetParam(); - TF_ASSERT_OK_AND_ASSIGN(DotContext context, - DotBF16(spec.b, spec.m, spec.n, spec.k)); + ASSERT_OK_AND_ASSIGN(DotContext context, + DotBF16(spec.b, spec.m, spec.n, spec.k)); // Compare with nanosecond precision. EXPECT_EQ( absl::Trunc(*GetMatmulInterpolatorH100()->EstimatedRuntime(*context.dot), @@ -423,8 +423,8 @@ using B200BF16Test = MatmulInterpolatorDefaultTableTest; TEST_P(B200BF16Test, EstimatesRuntimeForBF16) { const auto& [_, spec, expected_duration] = GetParam(); - TF_ASSERT_OK_AND_ASSIGN(DotContext context, - DotBF16(spec.b, spec.m, spec.n, spec.k)); + ASSERT_OK_AND_ASSIGN(DotContext context, + DotBF16(spec.b, spec.m, spec.n, spec.k)); // Compare with nanosecond precision. EXPECT_EQ( absl::Trunc(*GetMatmulInterpolatorB200()->EstimatedRuntime(*context.dot), @@ -463,8 +463,8 @@ using H100S8Test = MatmulInterpolatorDefaultTableTest; TEST_P(H100S8Test, EstimatesRuntimeForS8) { const auto& [_, spec, expected_duration] = GetParam(); - TF_ASSERT_OK_AND_ASSIGN(DotContext context, - DotS8(spec.b, spec.m, spec.n, spec.k)); + ASSERT_OK_AND_ASSIGN(DotContext context, + DotS8(spec.b, spec.m, spec.n, spec.k)); // Compare with nanosecond precision. EXPECT_EQ( absl::Trunc(*GetMatmulInterpolatorH100()->EstimatedRuntime(*context.dot), @@ -519,8 +519,8 @@ using B200S8Test = MatmulInterpolatorDefaultTableTest; TEST_P(B200S8Test, EstimatesRuntimeForS8) { const auto& [_, spec, expected_duration] = GetParam(); - TF_ASSERT_OK_AND_ASSIGN(DotContext context, - DotS8(spec.b, spec.m, spec.n, spec.k)); + ASSERT_OK_AND_ASSIGN(DotContext context, + DotS8(spec.b, spec.m, spec.n, spec.k)); // Compare with nanosecond precision. EXPECT_EQ( absl::Trunc(*GetMatmulInterpolatorB200()->EstimatedRuntime(*context.dot), @@ -559,9 +559,9 @@ using H100F8Test = MatmulInterpolatorDefaultTableTest; TEST_P(H100F8Test, EstimatesRuntimeForF8) { const auto& [_, spec, expected_duration] = GetParam(); - TF_ASSERT_OK_AND_ASSIGN(DotContext context, - Dot(spec.b, spec.m, spec.n, spec.k, spec.lhs_type, - spec.rhs_type, spec.result_type)); + ASSERT_OK_AND_ASSIGN(DotContext context, + Dot(spec.b, spec.m, spec.n, spec.k, spec.lhs_type, + spec.rhs_type, spec.result_type)); // Compare with nanosecond precision. EXPECT_EQ( absl::Trunc(*GetMatmulInterpolatorH100()->EstimatedRuntime(*context.dot), @@ -604,9 +604,9 @@ using B200F8Test = MatmulInterpolatorDefaultTableTest; TEST_P(B200F8Test, EstimatesRuntimeForF8) { const auto& [_, spec, expected_duration] = GetParam(); - TF_ASSERT_OK_AND_ASSIGN(DotContext context, - Dot(spec.b, spec.m, spec.n, spec.k, spec.lhs_type, - spec.rhs_type, spec.result_type)); + ASSERT_OK_AND_ASSIGN(DotContext context, + Dot(spec.b, spec.m, spec.n, spec.k, spec.lhs_type, + spec.rhs_type, spec.result_type)); // Compare with nanosecond precision. EXPECT_EQ( absl::Trunc(*GetMatmulInterpolatorB200()->EstimatedRuntime(*context.dot), @@ -725,7 +725,7 @@ TEST_F(MatmulInterpolatorTest, SupportsCublasCustomCalls) { } } )"; - TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(hlo)); + ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(hlo)); const HloInstruction& custom_call = *module->entry_computation()->root_instruction(); EXPECT_EQ(*interpolator().EstimatedRuntime(custom_call), absl::Seconds(1)); @@ -762,7 +762,7 @@ TEST_F(MatmulInterpolatorTest, SupportsDotTritonFusion) { } } )"; - TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(hlo)); + ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(hlo)); const HloInstruction& custom_call = *module->entry_computation()->root_instruction(); EXPECT_EQ(*interpolator().EstimatedRuntime(custom_call), absl::Seconds(1)); @@ -797,7 +797,7 @@ TEST_F(MatmulInterpolatorTest, SupportsDotTritonNestedGemmFusion) { } } )"; - TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(hlo)); + ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(hlo)); const HloInstruction& custom_call = *module->entry_computation()->root_instruction(); EXPECT_EQ(*interpolator().EstimatedRuntime(custom_call), absl::Seconds(1)); diff --git a/third_party/xla/xla/service/gpu/model/matmul_ptable_stats_collection_test.cc b/third_party/xla/xla/service/gpu/model/matmul_ptable_stats_collection_test.cc index fd55943d866be2..cf13cd51aef77e 100644 --- a/third_party/xla/xla/service/gpu/model/matmul_ptable_stats_collection_test.cc +++ b/third_party/xla/xla/service/gpu/model/matmul_ptable_stats_collection_test.cc @@ -141,9 +141,9 @@ TEST_F(MatmulStatsCollectionTest, } } )"; - TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(hlo)); + ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(hlo)); - TF_ASSERT_OK_AND_ASSIGN( + ASSERT_OK_AND_ASSIGN( bool changed, MatmulPerfTableStatsCollection(profiles_path_, device_info_) .Run(module.get())); @@ -194,8 +194,8 @@ TEST_F(MatmulStatsCollectionTest, } } )"; - TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(hlo)); - TF_ASSERT_OK_AND_ASSIGN( + ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(hlo)); + ASSERT_OK_AND_ASSIGN( bool changed, MatmulPerfTableStatsCollection(profiles_path_, device_info_) .Run(module.get())); diff --git a/third_party/xla/xla/service/gpu/model/sol_gpu_cost_model_stats_collection_test.cc b/third_party/xla/xla/service/gpu/model/sol_gpu_cost_model_stats_collection_test.cc index 13cf1dfa711985..3d20bc2d43fb54 100644 --- a/third_party/xla/xla/service/gpu/model/sol_gpu_cost_model_stats_collection_test.cc +++ b/third_party/xla/xla/service/gpu/model/sol_gpu_cost_model_stats_collection_test.cc @@ -29,10 +29,8 @@ limitations under the License. #include "xla/service/gpu/gpu_device_info_for_tests.h" #include "xla/service/hlo_cost_analysis.h" #include "xla/shape.h" -#include "xla/shape_util.h" #include "xla/stream_executor/cuda/cuda_compute_capability.h" #include "xla/stream_executor/device_description.h" -#include "xla/tsl/platform/statusor.h" namespace xla::gpu { namespace { @@ -71,13 +69,13 @@ TEST_F(SolGpuCostModelStatsCollectionTest, ROOT ar-done = f32[8192,4096] all-reduce-done(ar-start) })"; - TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText)); + ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText)); - TF_ASSERT_OK_AND_ASSIGN(bool changed, - SolGpuCostModelStatsCollection( - device_info_, HloCostAnalysis::DefaultShapeSize, - pointer_size_, &mlir_context_) - .Run(module.get())); + ASSERT_OK_AND_ASSIGN(bool changed, + SolGpuCostModelStatsCollection( + device_info_, HloCostAnalysis::DefaultShapeSize, + pointer_size_, &mlir_context_) + .Run(module.get())); VLOG(1) << module->ToString(); @@ -109,12 +107,12 @@ TEST_F(SolGpuCostModelStatsCollectionTest, async-start(%param), calls=%async_rs ROOT %rs_done = f32[512,128256] async-done(%rs_start) })"; - TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText)); - TF_ASSERT_OK_AND_ASSIGN(bool changed, - SolGpuCostModelStatsCollection( - device_info_, HloCostAnalysis::DefaultShapeSize, - pointer_size_, &mlir_context_) - .Run(module.get())); + ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText)); + ASSERT_OK_AND_ASSIGN(bool changed, + SolGpuCostModelStatsCollection( + device_info_, HloCostAnalysis::DefaultShapeSize, + pointer_size_, &mlir_context_) + .Run(module.get())); VLOG(1) << module->ToString(); EXPECT_FALSE(changed); HloInstruction* rs_start = FindInstruction(module.get(), "rs_start"); diff --git a/third_party/xla/xla/service/gpu/model/sol_gpu_cost_model_test.cc b/third_party/xla/xla/service/gpu/model/sol_gpu_cost_model_test.cc index 7b778e0f682b61..93cecb354ebdfe 100644 --- a/third_party/xla/xla/service/gpu/model/sol_gpu_cost_model_test.cc +++ b/third_party/xla/xla/service/gpu/model/sol_gpu_cost_model_test.cc @@ -17,12 +17,12 @@ limitations under the License. #include +#include #include #include "absl/strings/string_view.h" #include "absl/time/time.h" #include "xla/hlo/parser/hlo_parser.h" #include "xla/stream_executor/device_description.h" -#include "xla/tsl/platform/statusor.h" namespace xla::gpu { namespace { @@ -110,8 +110,8 @@ TEST(SolGPUCostModelGetConfigTest, ConfigForHopper) { ROOT constant = f32[] constant(0) })"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnUnverifiedModule(kDummyModule)); + ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnUnverifiedModule(kDummyModule)); se::DeviceDescription device_info; device_info.set_name("NVIDIA H100 80GB HBM3"); @@ -128,8 +128,8 @@ TEST(SolGPUCostModelGetConfigTest, ConfigForBlackwell) { ROOT constant = f32[] constant(0) })"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnUnverifiedModule(kDummyModule)); + ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnUnverifiedModule(kDummyModule)); se::DeviceDescription device_info; device_info.set_name("NVIDIA B200"); @@ -149,8 +149,8 @@ TEST(SolGPUCostModelGetConfigTest, ConfigForDefaultGPU) { ROOT constant = f32[] constant(0) })"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnUnverifiedModule(kDummyModule)); + ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnUnverifiedModule(kDummyModule)); se::DeviceDescription device_info; device_info.set_name("NVIDIA H200"); diff --git a/third_party/xla/xla/service/gpu/model/sol_latency_estimator_test.cc b/third_party/xla/xla/service/gpu/model/sol_latency_estimator_test.cc index aeac29bda25510..7910d9052b0f57 100644 --- a/third_party/xla/xla/service/gpu/model/sol_latency_estimator_test.cc +++ b/third_party/xla/xla/service/gpu/model/sol_latency_estimator_test.cc @@ -21,6 +21,7 @@ limitations under the License. #include #include +#include #include #include "absl/log/check.h" #include "absl/log/log.h" @@ -45,8 +46,6 @@ limitations under the License. #include "xla/stream_executor/cuda/cuda_compute_capability.h" #include "xla/stream_executor/device_description.h" #include "xla/tests/restricted/hlo_test_base_legacy.h" -#include "xla/tsl/lib/core/status_test_util.h" -#include "xla/tsl/platform/statusor.h" #include "xla/xla_data.pb.h" namespace xla::gpu { @@ -136,16 +135,15 @@ class SolLatencyEstimatorTest : public HloHardwareIndependentTestBase, TEST_P(SolLatencyEstimatorTest, TestLatencyEstimation) { EstimatorTestCase test_case = GetParam(); - TF_ASSERT_OK_AND_ASSIGN( - auto module, ParseAndReturnVerifiedModule(test_case.module_string)); + ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(test_case.module_string)); HloInstruction* instr = hlo_query::FindInstruction( module->entry_computation(), test_case.opcode_to_find); ASSERT_NE(instr, nullptr); absl::Duration actual_time_us; if (test_case.cost_type == CostType::kCollectiveTime) { - TF_ASSERT_OK_AND_ASSIGN(absl::Duration time_us, - ComputeCollectiveTime(*instr)); + ASSERT_OK_AND_ASSIGN(absl::Duration time_us, ComputeCollectiveTime(*instr)); actual_time_us = absl::Trunc(time_us, absl::Microseconds(1)); } else if (test_case.cost_type == CostType::kNodeCost) { actual_time_us = ComputeNodeCost(*instr, module->entry_computation()); @@ -667,7 +665,7 @@ TEST_F(HloHardwareIndependentTestBase, CollectiveCostModelDispatching) { /*analysis=*/nullptr); // NVLink domain collective should use CollectiveInterpolator. - TF_ASSERT_OK_AND_ASSIGN(auto nvl_module, ParseAndReturnVerifiedModule(R"( + ASSERT_OK_AND_ASSIGN(auto nvl_module, ParseAndReturnVerifiedModule(R"( HloModule m, num_partitions=16 ENTRY main { p = bf16[8,16000,1000] parameter(0) @@ -688,7 +686,7 @@ ENTRY main { // Cross-partition collective should use S-curve model (world-level across 2 // hosts). - TF_ASSERT_OK_AND_ASSIGN(auto ib_module, ParseAndReturnVerifiedModule(R"( + ASSERT_OK_AND_ASSIGN(auto ib_module, ParseAndReturnVerifiedModule(R"( HloModule m, num_partitions=16 ENTRY main { p = bf16[16,16000,1000] parameter(0) @@ -878,7 +876,7 @@ TEST_F(IsSolLatencyEstimatorEnabledTest, DisabledForHopperWithHostOffloaded) { stream_executor::CudaComputeCapability::Hopper()); auto module = CreateTestModule(config); - TF_ASSERT_OK(AddHostOffloaded(module.get())); + ASSERT_OK(AddHostOffloaded(module.get())); EXPECT_FALSE( SolLatencyEstimator::IsSupportedForModule(*module, gpu_device_info_)); diff --git a/third_party/xla/xla/service/gpu/model/tiling_from_block_parameters_test.cc b/third_party/xla/xla/service/gpu/model/tiling_from_block_parameters_test.cc index 1c5d101610e0e8..3f75bcf04a69de 100644 --- a/third_party/xla/xla/service/gpu/model/tiling_from_block_parameters_test.cc +++ b/third_party/xla/xla/service/gpu/model/tiling_from_block_parameters_test.cc @@ -42,7 +42,6 @@ limitations under the License. #include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/gpu/model/block_level_parameters.h" #include "xla/status_macros.h" -#include "xla/tsl/platform/statusor.h" #include "xla/tsl/platform/test.h" namespace xla { @@ -76,8 +75,8 @@ class TilingFromBlockParametersTest : public HloHardwareIndependentTestBase { }; TEST_F(TilingFromBlockParametersTest, GeneratesTilingForSimpleMap) { - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(R"hlo( + ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"hlo( HloModule m fused_computation { @@ -98,9 +97,8 @@ ENTRY entry_computation { BlockLevelParameters block_level_parameters; block_level_parameters.output_tile_sizes = {{16, 32}}; - TF_ASSERT_OK_AND_ASSIGN( - Tiling tiling, - TilingFromAnnotatedFusion(*analysis, block_level_parameters)); + ASSERT_OK_AND_ASSIGN(Tiling tiling, TilingFromAnnotatedFusion( + *analysis, block_level_parameters)); const HloInstruction* log = module->entry_computation() ->root_instruction() @@ -112,8 +110,8 @@ ENTRY entry_computation { } TEST_F(TilingFromBlockParametersTest, GeneratesTilingForDot) { - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(R"hlo( + ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"hlo( HloModule m fused_computation { @@ -137,9 +135,8 @@ ENTRY entry_computation { BlockLevelParameters block_level_parameters; block_level_parameters.output_tile_sizes = {{16, 16}}; - TF_ASSERT_OK_AND_ASSIGN( - Tiling tiling, - TilingFromAnnotatedFusion(*analysis, block_level_parameters)); + ASSERT_OK_AND_ASSIGN(Tiling tiling, TilingFromAnnotatedFusion( + *analysis, block_level_parameters)); const HloInstruction* dot = module->entry_computation() ->root_instruction() @@ -151,8 +148,8 @@ ENTRY entry_computation { } TEST_F(TilingFromBlockParametersTest, GeneratesTilingForDotWithTilingOverride) { - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(R"hlo( + ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"hlo( HloModule m fused_computation { @@ -177,9 +174,9 @@ ENTRY entry_computation { Tile tile_override; tile_override.add_sizes(64); - TF_ASSERT_OK_AND_ASSIGN( - Tiling tiling, TilingFromAnnotatedFusion( - *analysis, block_level_parameters, &tile_override)); + ASSERT_OK_AND_ASSIGN(Tiling tiling, + TilingFromAnnotatedFusion( + *analysis, block_level_parameters, &tile_override)); const HloInstruction* dot = module->entry_computation() ->root_instruction() @@ -215,8 +212,8 @@ class GetTileTilingSpaceConcreteSizesTest }; TEST_F(GetTileTilingSpaceConcreteSizesTest, DotWithBackendConfig) { - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(R"hlo( + ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"hlo( f { p0 = f32[64,128] parameter(0) p1 = f32[128,256] parameter(1) @@ -242,8 +239,8 @@ ENTRY entry { } TEST_F(GetTileTilingSpaceConcreteSizesTest, DotWithoutBackendConfig) { - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(R"hlo( + ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"hlo( f { p0 = f32[64,128] parameter(0) p1 = f32[128,256] parameter(1) @@ -268,8 +265,8 @@ ENTRY entry { } TEST_F(GetTileTilingSpaceConcreteSizesTest, ReductionWithTwoReductionDims) { - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(R"hlo( + ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"hlo( add { lhs = f32[] parameter(0) rhs = f32[] parameter(1) diff --git a/third_party/xla/xla/service/gpu/model/triton_emitter_constraints_test.cc b/third_party/xla/xla/service/gpu/model/triton_emitter_constraints_test.cc index 4f61cf65823834..f8d2a773ed88a3 100644 --- a/third_party/xla/xla/service/gpu/model/triton_emitter_constraints_test.cc +++ b/third_party/xla/xla/service/gpu/model/triton_emitter_constraints_test.cc @@ -26,17 +26,16 @@ limitations under the License. #include "absl/status/status_matchers.h" #include "mlir/IR/MLIRContext.h" #include "xla/codegen/tiling/symbolic_tile_analysis.h" +#include "xla/codegen/tiling/tiling_specification.h" #include "xla/hlo/analysis/symbolic_expr.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" #include "xla/hlo/testlib/verified_hlo_module.h" -#include "xla/hlo/utils/hlo_traversal.h" #include "xla/service/gpu/gpu_device_info_for_tests.h" #include "xla/service/instruction_fusion.h" #include "xla/stream_executor/device_description.h" -#include "xla/tsl/platform/statusor.h" #include "xla/tsl/platform/test.h" namespace xla { @@ -78,8 +77,8 @@ class TritonEmitterConstraintsTest : public HloHardwareIndependentTestBase { }; TEST_F(TritonEmitterConstraintsTest, TooBigTileSizesConstraintIsEnforced) { - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(R"( + ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"( HloModule m max_computation { @@ -128,8 +127,8 @@ ENTRY entry_computation { } TEST_F(TritonEmitterConstraintsTest, DotOperandSizeConstraintIsEnforced) { - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(R"( + ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"( HloModule m fused_computation { @@ -161,8 +160,8 @@ ENTRY entry_computation { } TEST_F(TritonEmitterConstraintsTest, TooManyBlocksConstraintIsEnforced) { - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(R"( + ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"( HloModule m max_computation { @@ -200,8 +199,8 @@ ENTRY entry_computation { } TEST_F(TritonEmitterConstraintsTest, FusionHasValidTileSizes) { - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(R"( + ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"( HloModule m fused_computation { @@ -259,8 +258,8 @@ ENTRY entry_computation { } TEST_F(TritonEmitterConstraintsTest, MultiOutputFusionHasPowerOfTwoTileSizes) { - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(R"( + ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"( HloModule m add { diff --git a/third_party/xla/xla/service/gpu_compilation_environment.cc b/third_party/xla/xla/service/gpu_compilation_environment.cc index 574f6c735b98c6..b44ee0f18cae5c 100644 --- a/third_party/xla/xla/service/gpu_compilation_environment.cc +++ b/third_party/xla/xla/service/gpu_compilation_environment.cc @@ -23,6 +23,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_join.h" +#include "xla/tsl/platform/status_macros.h" #include "google/protobuf/descriptor.h" #include "xla/parse_flags_from_env.h" #include "xla/service/compilation_environments.h" @@ -79,8 +80,8 @@ GpuCompilationEnvironment CreateGpuCompEnvWithDefaultValues() { absl::Status InitializeMissingFieldsFromXLAFlags( GpuCompilationEnvironment& env) { - TF_ASSIGN_OR_RETURN(GpuCompilationEnvironment from_env, - CreateGpuCompEnvFromEnvVar()); + ASSIGN_OR_RETURN(GpuCompilationEnvironment from_env, + CreateGpuCompEnvFromEnvVar()); auto default_env = CreateGpuCompEnvWithDefaultValues(); diff --git a/third_party/xla/xla/service/heap_simulator/BUILD b/third_party/xla/xla/service/heap_simulator/BUILD index 313050c752a1b8..1c955810560677 100644 --- a/third_party/xla/xla/service/heap_simulator/BUILD +++ b/third_party/xla/xla/service/heap_simulator/BUILD @@ -53,6 +53,7 @@ cc_library( "//xla/service:logical_buffer", "//xla/service:time_utils", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:nullability", diff --git a/third_party/xla/xla/service/heap_simulator/heap_simulator.cc b/third_party/xla/xla/service/heap_simulator/heap_simulator.cc index 66d1a0b0729b75..d8072fc325be61 100644 --- a/third_party/xla/xla/service/heap_simulator/heap_simulator.cc +++ b/third_party/xla/xla/service/heap_simulator/heap_simulator.cc @@ -47,6 +47,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/comparison_util.h" #include "xla/hlo/analysis/alias_info.h" #include "xla/hlo/analysis/hlo_alias_analysis.h" @@ -223,7 +224,7 @@ absl::StatusOr HeapSimulator::MinimumMemoryForModule( // ignoring fragmentation. We run the heap simulation on the whole module, // rather than summing each computation, since it gives us a better lower // bound, by minimizing the liveness of sub-computations. - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( HeapSimulator::Result result, HeapSimulator::Run(std::make_unique>(), *module, schedule, alias_analysis, alias_info, @@ -236,7 +237,7 @@ absl::StatusOr HeapSimulator::MinimumMemoryForComputation( const HloComputation& computation, const HloInstructionSequence& sequence, const HloAliasAnalysis& alias_analysis, const AliasInfo* alias_info, const LogicalBuffer::SizeFunction* absl_nonnull size_function) { - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( HeapSimulator::Result result, HeapSimulator::Run(std::make_unique>(), computation, sequence, alias_analysis, alias_info, @@ -255,12 +256,12 @@ absl::StatusOr> HeapSimulator::Run( const HloComputation* entry_computation = module.entry_computation(); const HloInstructionSequence& instruction_sequence = schedule.sequence(entry_computation); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( std::unique_ptr hlo_live_range, HloLiveRange::Run(schedule, alias_analysis, entry_computation)); - TF_RETURN_IF_ERROR(heap.RunComputation(*entry_computation, - instruction_sequence, alias_analysis, - alias_info, hlo_live_range.get())); + RETURN_IF_ERROR(heap.RunComputation(*entry_computation, instruction_sequence, + alias_analysis, alias_info, + hlo_live_range.get())); return heap.Finish(); } @@ -276,12 +277,12 @@ absl::StatusOr> HeapSimulator::Run( /*schedule=*/nullptr); HloSchedule schedule(computation.parent()); schedule.set_sequence(&computation, instruction_sequence); - TF_ASSIGN_OR_RETURN(std::unique_ptr hlo_live_range, - HloLiveRange::Run(schedule, alias_analysis, &computation, - /*module_scoped_analysis=*/false)); - TF_RETURN_IF_ERROR(heap.RunComputation(computation, instruction_sequence, - alias_analysis, alias_info, - hlo_live_range.get())); + ASSIGN_OR_RETURN(std::unique_ptr hlo_live_range, + HloLiveRange::Run(schedule, alias_analysis, &computation, + /*module_scoped_analysis=*/false)); + RETURN_IF_ERROR(heap.RunComputation(computation, instruction_sequence, + alias_analysis, alias_info, + hlo_live_range.get())); return heap.Finish(); } @@ -295,12 +296,11 @@ absl::StatusOr> HeapSimulator::Run( const HloSchedule* schedule, const Options& options) { HeapSimulator heap(std::move(algorithm), size_fn, options, /*schedule=*/schedule); - TF_ASSIGN_OR_RETURN( - std::unique_ptr hlo_live_range, - HloLiveRange::Run(*schedule, alias_analysis, &computation)); - TF_RETURN_IF_ERROR(heap.RunComputation(computation, instruction_sequence, - alias_analysis, alias_info, - hlo_live_range.get())); + ASSIGN_OR_RETURN(std::unique_ptr hlo_live_range, + HloLiveRange::Run(*schedule, alias_analysis, &computation)); + RETURN_IF_ERROR(heap.RunComputation(computation, instruction_sequence, + alias_analysis, alias_info, + hlo_live_range.get())); return heap.Finish(); } @@ -584,7 +584,7 @@ int64_t HeapSimulator::GetBufferSize(const HloValue* buffer) const { } absl::StatusOr> HeapSimulator::Finish() { - TF_ASSIGN_OR_RETURN(Result result, algorithm_->Finish()); + ASSIGN_OR_RETURN(Result result, algorithm_->Finish()); // Post-process the result to add chunks for shared buffers. An empty chunk // map means that either no buffers were allocated, or the heap was only @@ -603,8 +603,8 @@ absl::StatusOr> HeapSimulator::Finish() { } // Fragmentation is the difference between the actual and ideal sizes. - TF_ASSIGN_OR_RETURN(const Result no_frag_result, - no_fragmentation_stats_->Finish()); + ASSIGN_OR_RETURN(const Result no_frag_result, + no_fragmentation_stats_->Finish()); result.fragmentation_size = result.heap_size - no_frag_result.heap_size; // Copy the debug trace we collected to the final result. @@ -2804,7 +2804,7 @@ ConstrainedGlobalDecreasingSizeBestFitHeap::FinishFastMerge() { do { FreeChunksManager chunks_manager( [this](int64_t addr) { return ComputeAlignedChunkEnd(addr); }); - TF_RETURN_IF_ERROR(AllocateBuffersSortedByTimeInSingleHeap( + RETURN_IF_ERROR(AllocateBuffersSortedByTimeInSingleHeap( remaining_sorted_buffers, chunks_manager)); // Collect the result from the currently processed heap and reset the heap // states. @@ -2875,7 +2875,7 @@ ConstrainedGlobalDecreasingSizeBestFitHeap::FinishFastSplit() { chunks_manager.Allocate(0, max_end_in_phase_one); // Second phase: process the rest of the buffers. - TF_RETURN_IF_ERROR(AllocateBuffersSortedByTimeInSingleHeap( + RETURN_IF_ERROR(AllocateBuffersSortedByTimeInSingleHeap( remaining_fast_pass_sorted_buffers, chunks_manager)); // Collect the result from the currently processed heap and reset the heap @@ -3001,7 +3001,7 @@ ChooseBestHeapAlgorithm::Finish() { int64_t min_size = INT64_MAX; int min_size_index = -1; for (int i = 0; i < algorithms_.size(); ++i) { - TF_ASSIGN_OR_RETURN(results[i], algorithms_[i]->Finish()); + ASSIGN_OR_RETURN(results[i], algorithms_[i]->Finish()); if (results[i].heap_size < min_size) { min_size = results[i].heap_size; min_size_index = i; diff --git a/third_party/xla/xla/service/hlo_cost_analysis.cc b/third_party/xla/xla/service/hlo_cost_analysis.cc index 3d83be049b693a..1ae09057b26714 100644 --- a/third_party/xla/xla/service/hlo_cost_analysis.cc +++ b/third_party/xla/xla/service/hlo_cost_analysis.cc @@ -29,6 +29,7 @@ limitations under the License. #include "absl/strings/match.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -134,12 +135,12 @@ absl::Status HloCostAnalysis::RemoveInstruction( absl::Status HloCostAnalysis::RevisitInstruction( const HloInstruction* instruction) { - TF_RETURN_IF_ERROR(RemoveInstruction(instruction)); + RETURN_IF_ERROR(RemoveInstruction(instruction)); // Now do Preprocess() -> Visit() -> Postprocess() for the instruction same // way it is done during the complete analysis. - TF_RETURN_IF_ERROR(Preprocess(instruction)); - TF_RETURN_IF_ERROR(instruction->Visit(this)); - TF_RETURN_IF_ERROR(Postprocess(instruction)); + RETURN_IF_ERROR(Preprocess(instruction)); + RETURN_IF_ERROR(instruction->Visit(this)); + RETURN_IF_ERROR(Postprocess(instruction)); return absl::OkStatus(); } @@ -554,8 +555,8 @@ absl::Status HloCostAnalysis::HandleOutfeed(const HloInstruction* outfeed) { absl::Status HloCostAnalysis::HandleMap(const HloInstruction* map) { // Compute properties of the mapped function. - TF_ASSIGN_OR_RETURN(const Properties sub_properties, - ProcessSubcomputation(map->to_apply())); + ASSIGN_OR_RETURN(const Properties sub_properties, + ProcessSubcomputation(map->to_apply())); // Compute the cost of all elements for this Map operation. const int64_t element_count = ShapeUtil::ElementsIn(map->shape()); @@ -570,8 +571,8 @@ absl::Status HloCostAnalysis::HandleMap(const HloInstruction* map) { absl::Status HloCostAnalysis::HandleReduce(const HloInstruction* reduce) { HloComputation* function = reduce->to_apply(); // Compute the cost of the user function. - TF_ASSIGN_OR_RETURN(const Properties sub_properties, - ProcessSubcomputation(function)); + ASSIGN_OR_RETURN(const Properties sub_properties, + ProcessSubcomputation(function)); // Compute the cost of all elements for this Reduce operation. // This counts the number of times the reduction function is applied, so it @@ -594,8 +595,8 @@ absl::Status HloCostAnalysis::HandleReduce(const HloInstruction* reduce) { absl::Status HloCostAnalysis::HandleScan(const HloInstruction* scan) { HloComputation* function = scan->to_apply(); // Compute the cost of the user function. - TF_ASSIGN_OR_RETURN(const Properties sub_properties, - ProcessSubcomputation(function)); + ASSIGN_OR_RETURN(const Properties sub_properties, + ProcessSubcomputation(function)); // Compute the cost of all elements for this Scan operation. auto input = scan->operand(1); @@ -613,8 +614,7 @@ absl::Status HloCostAnalysis::HandleReduceWindow( const Window& window = reduce_window->window(); auto function = reduce_window->to_apply(); // Compute the properties of the reduction function. - TF_ASSIGN_OR_RETURN(Properties sub_properties, - ProcessSubcomputation(function)); + ASSIGN_OR_RETURN(Properties sub_properties, ProcessSubcomputation(function)); // Compute the cost of all elements for this ReduceWindow operation. For each // output element there are window_size - 1 reductions to perform. @@ -695,10 +695,10 @@ absl::Status HloCostAnalysis::HandleSelectAndScatter( const HloInstruction* instruction) { // Compute the properties of the select and scatter function. // Compute the properties of the reduction function. - TF_ASSIGN_OR_RETURN(Properties select_properties, - ProcessSubcomputation(instruction->select())); - TF_ASSIGN_OR_RETURN(Properties scatter_properties, - ProcessSubcomputation(instruction->scatter())); + ASSIGN_OR_RETURN(Properties select_properties, + ProcessSubcomputation(instruction->select())); + ASSIGN_OR_RETURN(Properties scatter_properties, + ProcessSubcomputation(instruction->scatter())); // Compute the cost of all elements for this operation. For each scatter // source element there are window_size - 1 select computations to perform and @@ -750,7 +750,7 @@ absl::Status HloCostAnalysis::HandlePad(const HloInstruction*) { absl::Status HloCostAnalysis::HandleAsyncStart( const HloInstruction* async_start) { - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( current_properties_, ProcessSubcomputation(async_start->called_computations()[0])); return absl::OkStatus(); @@ -1318,22 +1318,22 @@ absl::Status HloCostAnalysis::HandleFusion(const HloInstruction* fusion) { } } } - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( current_properties_, ProcessSubcomputation(fusion->fused_instructions_computation())); current_properties_[kBytesAccessedKey] = 0; - TF_RETURN_IF_ERROR(FusionProcessOutputBytesAccessed(fusion)); - TF_RETURN_IF_ERROR(FusionCalculateUtilizations(fusion)); - TF_RETURN_IF_ERROR(FusionCountConstantsMemoryAccess(fusion)); - TF_RETURN_IF_ERROR(FusionProcessOperandBytesRead(fusion)); + RETURN_IF_ERROR(FusionProcessOutputBytesAccessed(fusion)); + RETURN_IF_ERROR(FusionCalculateUtilizations(fusion)); + RETURN_IF_ERROR(FusionCountConstantsMemoryAccess(fusion)); + RETURN_IF_ERROR(FusionProcessOperandBytesRead(fusion)); return absl::OkStatus(); } absl::Status HloCostAnalysis::HandleCall(const HloInstruction* call) { - TF_ASSIGN_OR_RETURN(current_properties_, - ProcessSubcomputation(call->to_apply())); + ASSIGN_OR_RETURN(current_properties_, + ProcessSubcomputation(call->to_apply())); current_should_compute_bottleneck_time_ = false; return absl::OkStatus(); } @@ -1372,11 +1372,11 @@ absl::Status HloCostAnalysis::HandleWhile(const HloInstruction* xla_while) { // Since the number of iterations of the while node will not always be // something that we can statically analyze, we cannot precisely compute the // cost of a while node. For now compute the cost of a single iteration. - TF_ASSIGN_OR_RETURN(const Properties body_properties, - ProcessSubcomputation(xla_while->while_body())); + ASSIGN_OR_RETURN(const Properties body_properties, + ProcessSubcomputation(xla_while->while_body())); - TF_ASSIGN_OR_RETURN(const Properties condition_properties, - ProcessSubcomputation(xla_while->while_condition())); + ASSIGN_OR_RETURN(const Properties condition_properties, + ProcessSubcomputation(xla_while->while_condition())); current_properties_ = Properties(); body_properties.ForEach([&](absl::string_view key, float val) { @@ -1394,14 +1394,12 @@ absl::Status HloCostAnalysis::HandleConditional( const HloInstruction* conditional) { // Compute the cost of the branch computations and take the maximum from those // for each property. - TF_ASSIGN_OR_RETURN( - const Properties branch0_computation_properties, - ProcessSubcomputation(conditional->branch_computation(0))); + ASSIGN_OR_RETURN(const Properties branch0_computation_properties, + ProcessSubcomputation(conditional->branch_computation(0))); current_properties_ = branch0_computation_properties; for (int j = 1; j < conditional->branch_count(); ++j) { - TF_ASSIGN_OR_RETURN( - const Properties branch_computation_properties, - ProcessSubcomputation(conditional->branch_computation(j))); + ASSIGN_OR_RETURN(const Properties branch_computation_properties, + ProcessSubcomputation(conditional->branch_computation(j))); branch_computation_properties.ForEach( [&](absl::string_view key, float val) { auto& current_property = current_properties_[key]; @@ -1450,8 +1448,8 @@ absl::Status HloCostAnalysis::HandleScatter(const HloInstruction* hlo) { current_properties_.set_output_bytes_accessed(total_update_size); const int64_t element_count = ShapeUtil::ElementsIn(scatter->scatter_updates()[0]->shape()); - TF_ASSIGN_OR_RETURN(const Properties sub_properties, - ProcessSubcomputation(scatter->to_apply())); + ASSIGN_OR_RETURN(const Properties sub_properties, + ProcessSubcomputation(scatter->to_apply())); sub_properties.ForEach([&](absl::string_view key, float val) { if (KeyToCopyFromSubcomputation(key)) { current_properties_[key] = val * element_count; @@ -1596,7 +1594,7 @@ absl::StatusOr HloCostAnalysis::ProcessSubcomputation(HloComputation* computation) { auto visitor = CreateNestedCostAnalysis(); visitor->ReserveVisitStates(computation->instruction_count()); - TF_RETURN_IF_ERROR(computation->Accept(visitor.get())); + RETURN_IF_ERROR(computation->Accept(visitor.get())); for (auto& entry : visitor->hlo_properties_) { hlo_properties_[entry.first] = std::move(entry.second); } diff --git a/third_party/xla/xla/service/hlo_creation_utils.cc b/third_party/xla/xla/service/hlo_creation_utils.cc index 9a53fb28177e47..b8b495802c07d8 100644 --- a/third_party/xla/xla/service/hlo_creation_utils.cc +++ b/third_party/xla/xla/service/hlo_creation_utils.cc @@ -31,6 +31,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/comparison_util.h" #include "xla/hlo/builder/lib/comparators.h" #include "xla/hlo/builder/xla_builder.h" @@ -59,8 +60,8 @@ absl::StatusOr MakeUnaryHlo(HloOpcode opcode, HloInstruction* operand, const OpMetadata* metadata) { HloComputation* computation = operand->parent(); - TF_ASSIGN_OR_RETURN(Shape unary_op_shape, - ShapeInference::InferUnaryOpShape(opcode, operand)); + ASSIGN_OR_RETURN(Shape unary_op_shape, + ShapeInference::InferUnaryOpShape(opcode, operand)); return computation->AddInstruction( HloInstruction::CreateUnary(unary_op_shape, opcode, operand), metadata); } @@ -75,8 +76,8 @@ absl::StatusOr MakeBinaryHlo( const OpMetadata* metadata, const FrontendAttributes* frontend_attributes) { HloComputation* computation = lhs->parent(); CHECK_EQ(computation, rhs->parent()); - TF_ASSIGN_OR_RETURN(Shape binary_op_shape, - ShapeInference::InferBinaryOpShape(opcode, lhs, rhs)); + ASSIGN_OR_RETURN(Shape binary_op_shape, + ShapeInference::InferBinaryOpShape(opcode, lhs, rhs)); return computation->AddInstruction( HloInstruction::CreateBinary(binary_op_shape, opcode, lhs, rhs), metadata, frontend_attributes); @@ -87,9 +88,8 @@ absl::StatusOr MakeCompareHlo( const OpMetadata* metadata, const FrontendAttributes* frontend_attributes) { HloComputation* computation = lhs->parent(); CHECK_EQ(computation, rhs->parent()); - TF_ASSIGN_OR_RETURN( - Shape binary_op_shape, - ShapeInference::InferBinaryOpShape(HloOpcode::kCompare, lhs, rhs)); + ASSIGN_OR_RETURN(Shape binary_op_shape, ShapeInference::InferBinaryOpShape( + HloOpcode::kCompare, lhs, rhs)); return computation->AddInstruction( HloInstruction::CreateCompare(binary_op_shape, lhs, rhs, direction), metadata, frontend_attributes); @@ -101,7 +101,7 @@ absl::StatusOr MakePadHlo( const FrontendAttributes* frontend_attributes) { HloComputation* computation = operand->parent(); CHECK_EQ(computation, padding_value->parent()); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( Shape pad_shape, ShapeInference::InferPadShape(operand->shape(), padding_value->shape(), padding_config)); @@ -117,9 +117,9 @@ absl::StatusOr MakeSliceHlo( absl::Span limit_indices, absl::Span strides, const OpMetadata* metadata, const FrontendAttributes* frontend_attributes) { HloComputation* computation = operand->parent(); - TF_ASSIGN_OR_RETURN(Shape slice_shape, ShapeInference::InferSliceShape( - operand->shape(), start_indices, - limit_indices, strides)); + ASSIGN_OR_RETURN(Shape slice_shape, ShapeInference::InferSliceShape( + operand->shape(), start_indices, + limit_indices, strides)); return computation->AddInstruction( HloInstruction::CreateSlice(slice_shape, operand, start_indices, limit_indices, strides), @@ -136,7 +136,7 @@ absl::StatusOr MakeConvolveHlo( const FrontendAttributes* frontend_attributes) { HloComputation* computation = lhs->parent(); CHECK_EQ(computation, rhs->parent()); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( Shape convolve_shape, ShapeInference::InferConvolveShape( lhs->shape(), rhs->shape(), feature_group_count, batch_group_count, @@ -150,9 +150,8 @@ absl::StatusOr MakeConvolveHlo( absl::StatusOr MakeTransposeHlo( HloInstruction* operand, absl::Span dimensions) { - TF_ASSIGN_OR_RETURN( - Shape transpose_shape, - ShapeInference::InferTransposeShape(operand->shape(), dimensions)); + ASSIGN_OR_RETURN(Shape transpose_shape, ShapeInference::InferTransposeShape( + operand->shape(), dimensions)); return operand->AddInstruction( HloInstruction::CreateTranspose(transpose_shape, operand, dimensions)); } @@ -182,7 +181,7 @@ absl::StatusOr MakeDynamicSliceHlo( std::vector scalar_start_indices_shapes( start_indices.size(), ShapeUtil::MakeShape(start_indices[0]->shape().element_type(), {})); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( Shape dynamic_slice_shape, ShapeInference::InferDynamicSliceShape( operand->shape(), scalar_start_indices_shapes, slice_sizes)); @@ -211,7 +210,7 @@ absl::StatusOr MakeDynamicSliceHlo( } std::vector scalar_start_indices_shapes( rank, ShapeUtil::MakeShape(start_indices->shape().element_type(), {})); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( Shape dynamic_slice_shape, ShapeInference::InferDynamicSliceShape( operand->shape(), scalar_start_indices_shapes, slice_sizes)); @@ -241,7 +240,7 @@ absl::StatusOr MakeDynamicUpdateSliceHlo( } std::vector scalar_start_indices_shapes( rank, ShapeUtil::MakeShape(start_indices->shape().element_type(), {})); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( Shape dynamic_update_slice_shape, ShapeInference::InferDynamicUpdateSliceShape( operand->shape(), update->shape(), scalar_start_indices_shapes)); @@ -262,7 +261,7 @@ absl::StatusOr MakeDynamicUpdateSliceHlo( for (auto start_index : start_indices) { scalar_start_indices_shapes.push_back(start_index->shape()); } - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( Shape dynamic_update_slice_shape, ShapeInference::InferDynamicUpdateSliceShape( operand->shape(), update->shape(), scalar_start_indices_shapes)); @@ -296,9 +295,8 @@ absl::StatusOr MakeGetTupleElementHlo( HloInstruction* operand, int64_t index, const OpMetadata* metadata) { HloComputation* computation = operand->parent(); - TF_ASSIGN_OR_RETURN( - Shape gte_shape, - ShapeInference::InferGetTupleElementShape(operand->shape(), index)); + ASSIGN_OR_RETURN(Shape gte_shape, ShapeInference::InferGetTupleElementShape( + operand->shape(), index)); return computation->AddInstruction( HloInstruction::CreateGetTupleElement(gte_shape, operand, index), metadata); @@ -318,8 +316,8 @@ absl::StatusOr MakeConcatHlo( absl::c_transform(operands, std::back_inserter(operand_shapes), [](HloInstruction* instr) { return &instr->shape(); }); - TF_ASSIGN_OR_RETURN(Shape concat_shape, ShapeInference::InferConcatOpShape( - operand_shapes, dimension)); + ASSIGN_OR_RETURN(Shape concat_shape, ShapeInference::InferConcatOpShape( + operand_shapes, dimension)); return computation->AddInstruction( HloInstruction::CreateConcatenate(concat_shape, operands, dimension), metadata, frontend_attributes); @@ -381,10 +379,9 @@ absl::StatusOr MakeDotHlo( const OpMetadata* metadata) { HloComputation* computation = lhs->parent(); CHECK_EQ(computation, rhs->parent()); - TF_ASSIGN_OR_RETURN( - Shape dot_shape, - ShapeInference::InferDotOpShape(lhs->shape(), rhs->shape(), dim_numbers, - preferred_element_type)); + ASSIGN_OR_RETURN(Shape dot_shape, ShapeInference::InferDotOpShape( + lhs->shape(), rhs->shape(), dim_numbers, + preferred_element_type)); return computation->AddInstruction( HloInstruction::CreateDot(dot_shape, lhs, rhs, dim_numbers, precision_config), @@ -399,10 +396,10 @@ absl::StatusOr MakeRaggedDotHlo( HloComputation* computation = lhs->parent(); CHECK_EQ(computation, rhs->parent()); CHECK_EQ(computation, group_sizes->parent()); - TF_ASSIGN_OR_RETURN(Shape ragged_dot_shape, - ShapeInference::InferRaggedDotOpShape( - lhs->shape(), rhs->shape(), group_sizes->shape(), - dim_numbers, preferred_element_type)); + ASSIGN_OR_RETURN(Shape ragged_dot_shape, + ShapeInference::InferRaggedDotOpShape( + lhs->shape(), rhs->shape(), group_sizes->shape(), + dim_numbers, preferred_element_type)); return computation->AddInstruction(HloInstruction::CreateRaggedDot( ragged_dot_shape, lhs, rhs, group_sizes, dim_numbers, precision_config)); } @@ -416,10 +413,9 @@ absl::StatusOr MakeScaledDotHlo( CHECK_EQ(computation, lhs_scale->parent()); CHECK_EQ(computation, rhs->parent()); CHECK_EQ(computation, rhs_scale->parent()); - TF_ASSIGN_OR_RETURN( - Shape dot_shape, - ShapeInference::InferDotOpShape(lhs->shape(), rhs->shape(), dim_numbers, - preferred_element_type)); + ASSIGN_OR_RETURN(Shape dot_shape, ShapeInference::InferDotOpShape( + lhs->shape(), rhs->shape(), dim_numbers, + preferred_element_type)); return computation->AddInstruction( HloInstruction::CreateScaledDot(dot_shape, lhs, rhs, lhs_scale, rhs_scale, dim_numbers, precision_config)); @@ -441,7 +437,7 @@ absl::StatusOr MakeMapHlo( } std::vector map_dims(max_operand_rank); absl::c_iota(map_dims, 0); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( Shape map_shape, ShapeInference::InferMapShape( operand_shapes, map_computation->ComputeProgramShape(), map_dims)); @@ -495,10 +491,10 @@ absl::StatusOr MakeReduceHlo( absl::StatusOr MakeReduceWindowHlo( HloInstruction* operand, HloInstruction* init_value, const Window& window, HloComputation* reduce_computation, const OpMetadata* metadata) { - TF_ASSIGN_OR_RETURN(Shape inferred_shape, - ShapeInference::InferReduceWindowShape( - operand->shape(), init_value->shape(), window, - reduce_computation->ComputeProgramShape())); + ASSIGN_OR_RETURN(Shape inferred_shape, + ShapeInference::InferReduceWindowShape( + operand->shape(), init_value->shape(), window, + reduce_computation->ComputeProgramShape())); return operand->parent()->AddInstruction( HloInstruction::CreateReduceWindow(inferred_shape, operand, init_value, window, reduce_computation), @@ -511,10 +507,10 @@ absl::StatusOr MakeReduceWindowHlo( HloComputation* reduce_computation = MakeBinaryScalarComputation( binary_opcode, operand->shape().element_type(), operand, operand->GetModule()); - TF_ASSIGN_OR_RETURN(Shape inferred_shape, - ShapeInference::InferReduceWindowShape( - operand->shape(), init_value->shape(), window, - reduce_computation->ComputeProgramShape())); + ASSIGN_OR_RETURN(Shape inferred_shape, + ShapeInference::InferReduceWindowShape( + operand->shape(), init_value->shape(), window, + reduce_computation->ComputeProgramShape())); return operand->parent()->AddInstruction( HloInstruction::CreateReduceWindow(inferred_shape, operand, init_value, window, reduce_computation), @@ -569,8 +565,8 @@ absl::StatusOr MakeReduceHlo( operand->shape())); } - TF_ASSIGN_OR_RETURN(auto output_shape, - ShapeUtil::MakeValidatedMaybeTupleShape(expected_shapes)); + ASSIGN_OR_RETURN(auto output_shape, + ShapeUtil::MakeValidatedMaybeTupleShape(expected_shapes)); return operands[0]->parent()->AddInstruction( HloInstruction::CreateReduce(output_shape, operands, init_values, dimensions, reduce_computation), @@ -581,8 +577,8 @@ absl::StatusOr MakeReverseHlo( HloInstruction* operand, absl::Span dimensions, const OpMetadata* metadata) { HloComputation* computation = operand->parent(); - TF_ASSIGN_OR_RETURN(Shape reverse_shape, ShapeInference::InferReverseShape( - operand->shape(), dimensions)); + ASSIGN_OR_RETURN(Shape reverse_shape, ShapeInference::InferReverseShape( + operand->shape(), dimensions)); return computation->AddInstruction( HloInstruction::CreateReverse(reverse_shape, operand, dimensions), metadata); @@ -612,9 +608,9 @@ absl::StatusOr MakeSelectHlo( } TF_RET_CHECK(!op_shape.IsTuple()); HloOpcode select_op_code = HloOpcode::kSelect; - TF_ASSIGN_OR_RETURN(Shape select_shape, - ShapeInference::InferTernaryOpShape(select_op_code, pred, - on_true, on_false)); + ASSIGN_OR_RETURN(Shape select_shape, + ShapeInference::InferTernaryOpShape(select_op_code, pred, + on_true, on_false)); HloInstruction* select = computation->AddInstruction( HloInstruction::CreateTernary(select_shape, select_op_code, pred, on_true, on_false), @@ -636,10 +632,10 @@ HloInstruction* MaybeMakeTuple(absl::Span operands) { absl::StatusOr XlaComputationToHloComputation( XlaComputation& src_comp, HloModule* dest_module) { - TF_ASSIGN_OR_RETURN(ProgramShape program_shape, src_comp.GetProgramShape()); + ASSIGN_OR_RETURN(ProgramShape program_shape, src_comp.GetProgramShape()); HloModuleConfig config(program_shape); - TF_ASSIGN_OR_RETURN(auto new_module, - HloModule::CreateFromProto(src_comp.proto(), config)); + ASSIGN_OR_RETURN(auto new_module, + HloModule::CreateFromProto(src_comp.proto(), config)); HloCloneContext context(dest_module); return dest_module->DeepCloneComputation(new_module->entry_computation(), &context); @@ -659,8 +655,8 @@ absl::StatusOr MakeSortHlo( operand_types[i] = operands[i]->shape().element_type(); } XlaComputation comparator = CreateScalarLtComputation(operand_types, &b); - TF_ASSIGN_OR_RETURN(HloComputation * compare_computation, - XlaComputationToHloComputation(comparator, module)); + ASSIGN_OR_RETURN(HloComputation * compare_computation, + XlaComputationToHloComputation(comparator, module)); return builder->AddInstruction(HloInstruction::CreateSort( sort_shape, dimension_to_sort, operands, compare_computation, is_stable)); } @@ -815,7 +811,7 @@ absl::StatusOr MakeFusionInstruction( HloComputation* comp = fused->parent(); HloInstruction* fusion_instruction = comp->AddInstruction( HloInstruction::CreateFusion(fused->shape(), kind, fused)); - TF_RETURN_IF_ERROR(comp->ReplaceInstruction(fused, fusion_instruction)); + RETURN_IF_ERROR(comp->ReplaceInstruction(fused, fusion_instruction)); return fusion_instruction; } @@ -929,10 +925,10 @@ HloInstruction* ExpandDegenerateReshape(HloInstruction* inst) { absl::StatusOr MakeWithinBounds(HloInstruction* inst, HloInstruction* lower_bound, HloInstruction* upper_bound) { - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( HloInstruction * le, MakeCompareHlo(Comparison::Direction::kLe, lower_bound, inst)); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( HloInstruction * gt, MakeCompareHlo(Comparison::Direction::kGt, upper_bound, inst)); return MakeBinaryHlo(HloOpcode::kAnd, le, gt); diff --git a/third_party/xla/xla/service/hlo_creation_utils.h b/third_party/xla/xla/service/hlo_creation_utils.h index a3b0acb55ee488..39db1bc6e8fc10 100644 --- a/third_party/xla/xla/service/hlo_creation_utils.h +++ b/third_party/xla/xla/service/hlo_creation_utils.h @@ -23,6 +23,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/builder/xla_computation.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -299,7 +300,7 @@ absl::StatusOr MakeR1ConstantHlo( absl::Span values) { Literal literal = LiteralUtil::CreateR1(values); if (literal.shape().element_type() != type) { - TF_ASSIGN_OR_RETURN(literal, literal.Convert(type)); + ASSIGN_OR_RETURN(literal, literal.Convert(type)); } return computation->AddInstruction( HloInstruction::CreateConstant(std::move(literal))); diff --git a/third_party/xla/xla/service/hlo_cse.cc b/third_party/xla/xla/service/hlo_cse.cc index 7799aca46aa89d..e614790c835399 100644 --- a/third_party/xla/xla/service/hlo_cse.cc +++ b/third_party/xla/xla/service/hlo_cse.cc @@ -30,6 +30,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -63,7 +64,7 @@ absl::StatusOr CombineConstants( [&](const HloInstruction* instr) { return instr->opcode() == HloOpcode::kDomain; })) { - TF_ASSIGN_OR_RETURN(domain_map, HloDomainMap::Create(computation, "")); + ASSIGN_OR_RETURN(domain_map, HloDomainMap::Create(computation, "")); } // Map from the literal hash of a constant or the shape hash of an iota all @@ -266,13 +267,12 @@ absl::StatusOr HloCSE::RunOnComputation(HloComputation* computation) { return false; } - TF_ASSIGN_OR_RETURN( - bool changed, - is_layout_sensitive_ - ? CombineConstants(computation, - std::move(should_combine_constant_)) - : CombineConstants(computation, - std::move(should_combine_constant_))); + ASSIGN_OR_RETURN(bool changed, + is_layout_sensitive_ + ? CombineConstants( + computation, std::move(should_combine_constant_)) + : CombineConstants( + computation, std::move(should_combine_constant_))); const auto eq_instructions = [&](const HloInstruction* a, const HloInstruction* b) { @@ -321,9 +321,8 @@ absl::StatusOr HloCSE::RunOnComputation(HloComputation* computation) { auto pair = representatives.insert(CseKey{instruction}); if (!pair.second) { HloInstruction* equivalent_instruction = pair.first->hlo; - TF_RETURN_IF_ERROR( - instruction->ReplaceAllUsesWith(equivalent_instruction)); - TF_RETURN_IF_ERROR(computation->RemoveInstructionAndUnusedOperands( + RETURN_IF_ERROR(instruction->ReplaceAllUsesWith(equivalent_instruction)); + RETURN_IF_ERROR(computation->RemoveInstructionAndUnusedOperands( instruction, /*cleanup=*/std::nullopt, ignore_control_dependencies_)); VLOG(4) << "Replaced " << instruction->name() << " with " << equivalent_instruction->name(); @@ -340,10 +339,10 @@ absl::StatusOr HloCSE::RunOnComputation(HloComputation* computation) { if (a == b || !eq_instructions(a, b)) { continue; } - TF_RETURN_IF_ERROR(instruction->ReplaceOperandWith(j, a)); + RETURN_IF_ERROR(instruction->ReplaceOperandWith(j, a)); changed = true; if (b->IsDead()) { - TF_RETURN_IF_ERROR(computation->RemoveInstruction(b)); + RETURN_IF_ERROR(computation->RemoveInstruction(b)); } } } @@ -382,8 +381,7 @@ absl::StatusOr HloCSE::RunImpl( bool changed = false; for (auto* computation : module->computations(execution_threads)) { - TF_ASSIGN_OR_RETURN(bool computation_changed, - RunOnComputation(computation)); + ASSIGN_OR_RETURN(bool computation_changed, RunOnComputation(computation)); changed |= computation_changed; } return changed; diff --git a/third_party/xla/xla/service/hlo_cycle_detection.h b/third_party/xla/xla/service/hlo_cycle_detection.h index 2fc163c5164969..35234a63357eff 100644 --- a/third_party/xla/xla/service/hlo_cycle_detection.h +++ b/third_party/xla/xla/service/hlo_cycle_detection.h @@ -21,6 +21,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" @@ -36,7 +37,7 @@ class CycleDetectionVisitor : public DfsHloVisitorWithDefault { // detection by default. absl::Status VerifyNoCycle(HloModule* module) { for (auto* comp : module->computations()) { - TF_RETURN_IF_ERROR(comp->Accept(this)); + RETURN_IF_ERROR(comp->Accept(this)); } return absl::OkStatus(); } @@ -56,7 +57,7 @@ class HloCycleDetection : public HloModulePass { absl::StatusOr RunImpl(HloModule* module, const absl::flat_hash_set& execution_threads) override { - TF_RETURN_IF_ERROR(visitor_.VerifyNoCycle(module)); + RETURN_IF_ERROR(visitor_.VerifyNoCycle(module)); return false; } diff --git a/third_party/xla/xla/service/hlo_domain_isolator.cc b/third_party/xla/xla/service/hlo_domain_isolator.cc index 072a180f03a737..d77251f73085cf 100644 --- a/third_party/xla/xla/service/hlo_domain_isolator.cc +++ b/third_party/xla/xla/service/hlo_domain_isolator.cc @@ -17,6 +17,7 @@ limitations under the License. #include +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" @@ -47,8 +48,7 @@ absl::StatusOr AddExitDomains( // Call ReplaceUseWithDifferentShape even though the shapes are // expected to match to avoid an expensive shape check between the // original and the new instruction. - TF_RETURN_IF_ERROR( - instruction->ReplaceUseWithDifferentShape(user, domain)); + RETURN_IF_ERROR(instruction->ReplaceUseWithDifferentShape(user, domain)); ++added_domains; } } @@ -82,7 +82,7 @@ absl::StatusOr RunInternal( // Call ReplaceUseWithDifferentShape even though the shapes are // expected to match to avoid an expensive shape check between the // original and the new instruction. - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( operand->ReplaceUseWithDifferentShape(instruction, domain)); ++added_domains; } @@ -103,20 +103,20 @@ absl::StatusOr HloDomainIsolator::UpdateDomains( DomainCreator creator = creator_factory_(); bool changed = false; // Update exit domains. - TF_ASSIGN_OR_RETURN(const int64_t removed_domains, - HloDomainRemover::RemoveExitDomains( - instruction, ShardingMetadata::KindName())); - TF_ASSIGN_OR_RETURN(const int64_t added_domains, - AddExitDomains(instruction, &creator)); + ASSIGN_OR_RETURN(const int64_t removed_domains, + HloDomainRemover::RemoveExitDomains( + instruction, ShardingMetadata::KindName())); + ASSIGN_OR_RETURN(const int64_t added_domains, + AddExitDomains(instruction, &creator)); changed |= (removed_domains > 0 || added_domains > 0); // Update the instruction itself if it's a domain. if (instruction->opcode() == HloOpcode::kDomain) { for (HloInstruction* operand : instruction->operands()) { - TF_ASSIGN_OR_RETURN(const int64_t removed_domains, - HloDomainRemover::RemoveExitDomains( - operand, ShardingMetadata::KindName())); - TF_ASSIGN_OR_RETURN(const int64_t added_domains, - AddExitDomains(operand, &creator)); + ASSIGN_OR_RETURN(const int64_t removed_domains, + HloDomainRemover::RemoveExitDomains( + operand, ShardingMetadata::KindName())); + ASSIGN_OR_RETURN(const int64_t added_domains, + AddExitDomains(operand, &creator)); changed |= (removed_domains > 0 || added_domains > 0); } } diff --git a/third_party/xla/xla/service/hlo_domain_map.cc b/third_party/xla/xla/service/hlo_domain_map.cc index 6543b850a39d93..900fa89565d2a9 100644 --- a/third_party/xla/xla/service/hlo_domain_map.cc +++ b/third_party/xla/xla/service/hlo_domain_map.cc @@ -28,6 +28,7 @@ limitations under the License. #include "absl/memory/memory.h" #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_domain_metadata.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -43,7 +44,7 @@ namespace xla { /* static */ absl::StatusOr> HloDomainMap::Create( HloComputation* computation, std::string domain_kind) { auto domain_map = absl::WrapUnique(new HloDomainMap(std::move(domain_kind))); - TF_RETURN_IF_ERROR(domain_map->Populate(computation)); + RETURN_IF_ERROR(domain_map->Populate(computation)); return std::move(domain_map); } @@ -51,7 +52,7 @@ namespace xla { HloModule* module, std::string domain_kind) { auto domain_map = absl::WrapUnique(new HloDomainMap(std::move(domain_kind))); for (HloComputation* computation : module->computations()) { - TF_RETURN_IF_ERROR(domain_map->Populate(computation)); + RETURN_IF_ERROR(domain_map->Populate(computation)); } return std::move(domain_map); } @@ -81,13 +82,13 @@ absl::Status HloDomainMap::TryProcessEmptyDomain(HloInstruction* instruction) { auto domain = std::make_unique(); domain->enter_domains.insert(operand); domain->exit_domains.insert(instruction); - TF_RETURN_IF_ERROR(InsertDomain(std::move(domain))); + RETURN_IF_ERROR(InsertDomain(std::move(domain))); } } if (instruction == instruction->parent()->root_instruction()) { auto domain = std::make_unique(); domain->enter_domains.insert(instruction); - TF_RETURN_IF_ERROR(InsertDomain(std::move(domain))); + RETURN_IF_ERROR(InsertDomain(std::move(domain))); } return absl::OkStatus(); } @@ -102,7 +103,7 @@ absl::Status HloDomainMap::Populate(HloComputation* computation) { if (IsDomainInstruction(instruction)) { // If this is a kDomain of the kind we are currently processing, check // whether this is an "empty domain". - TF_RETURN_IF_ERROR(TryProcessEmptyDomain(instruction)); + RETURN_IF_ERROR(TryProcessEmptyDomain(instruction)); continue; } int64_t domain_id = FindOrDefault(instruction_to_domain_, instruction, -1); @@ -110,11 +111,11 @@ absl::Status HloDomainMap::Populate(HloComputation* computation) { // We have already processed this instruction. continue; } - TF_ASSIGN_OR_RETURN(std::unique_ptr domain, - CreateDomain(instruction, instructions_post_order)); - TF_RETURN_IF_ERROR(InsertDomain(std::move(domain))); + ASSIGN_OR_RETURN(std::unique_ptr domain, + CreateDomain(instruction, instructions_post_order)); + RETURN_IF_ERROR(InsertDomain(std::move(domain))); } - TF_RETURN_IF_ERROR(PopulateDomainMetadataMap()); + RETURN_IF_ERROR(PopulateDomainMetadataMap()); return absl::OkStatus(); } @@ -210,7 +211,7 @@ HloDomainMap::CreateDomain( HloInstruction* instruction, const InstructionOrderMap& instructions_order) const { auto domain = std::make_unique(); - TF_RETURN_IF_ERROR(ExpandDomain(instruction, domain.get())); + RETURN_IF_ERROR(ExpandDomain(instruction, domain.get())); domain->instructions = MakeNonDomainInstructions(domain->reach_set, instructions_order); return std::move(domain); diff --git a/third_party/xla/xla/service/hlo_domain_remover.cc b/third_party/xla/xla/service/hlo_domain_remover.cc index 560c696a5d46c5..e64f3149e7974e 100644 --- a/third_party/xla/xla/service/hlo_domain_remover.cc +++ b/third_party/xla/xla/service/hlo_domain_remover.cc @@ -23,6 +23,7 @@ limitations under the License. #include "absl/log/log.h" #include "absl/status/status.h" #include "absl/strings/string_view.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_domain_metadata.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -53,16 +54,16 @@ class HloDomainRemover::RunContext { absl::Status HloDomainRemover::RunContext::VerifyAndNormalizeDomain( const DomainMetadata::Domain& domain) { - TF_ASSIGN_OR_RETURN(const DomainMetadata* ref_metadata, - HloDomainVerifier::VerifyDomain(domain)); + ASSIGN_OR_RETURN(const DomainMetadata* ref_metadata, + HloDomainVerifier::VerifyDomain(domain)); if (ref_metadata != nullptr) { VLOG(4) << "Applying domain normalization: " << ref_metadata->ToString(); - TF_RETURN_IF_ERROR(remover_->normalizer_(domain, ref_metadata)); + RETURN_IF_ERROR(remover_->normalizer_(domain, ref_metadata)); } else { // No kDomain instruction was present within this domain, so call the // generic normalization functions and have them apply their heuristic. VLOG(2) << "Applying domain-less normalization"; - TF_RETURN_IF_ERROR(remover_->normalizer_(domain, nullptr)); + RETURN_IF_ERROR(remover_->normalizer_(domain, nullptr)); } return absl::OkStatus(); } @@ -74,11 +75,11 @@ absl::StatusOr HloDomainRemover::RunContext::Run( for (HloComputation* computation : module_->computations(execution_threads)) { // First create the domain instruction sets. A domain instruction set is // the set of instructions whose edges never cross a kDomain instruction. - TF_ASSIGN_OR_RETURN(std::unique_ptr domain_map, - HloDomainMap::Create(computation, remover_->kind_)); + ASSIGN_OR_RETURN(std::unique_ptr domain_map, + HloDomainMap::Create(computation, remover_->kind_)); // Verify and normalize every domain populated within the map. for (auto& domain : domain_map->GetDomains()) { - TF_RETURN_IF_ERROR(VerifyAndNormalizeDomain(*domain)); + RETURN_IF_ERROR(VerifyAndNormalizeDomain(*domain)); } // Now remove all the kDomain instructions of the kind specified by the @@ -89,9 +90,9 @@ absl::StatusOr HloDomainRemover::RunContext::Run( for (HloInstruction* operand : instruction->unique_operands()) { if (domain_map->IsDomainInstruction(operand)) { VLOG(5) << "Removing " << operand->name(); - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( operand->ReplaceAllUsesWith(operand->mutable_operand(0))); - TF_RETURN_IF_ERROR(computation->RemoveInstruction(operand)); + RETURN_IF_ERROR(computation->RemoveInstruction(operand)); ++removed_domains; } } @@ -100,7 +101,7 @@ absl::StatusOr HloDomainRemover::RunContext::Run( if (root != nullptr && domain_map->IsDomainInstruction(root)) { VLOG(5) << "Removing " << root->name(); computation->set_root_instruction(root->mutable_operand(0)); - TF_RETURN_IF_ERROR(computation->RemoveInstruction(root)); + RETURN_IF_ERROR(computation->RemoveInstruction(root)); ++removed_domains; } } @@ -121,8 +122,8 @@ absl::StatusOr HloDomainRemover::RemoveExitDomains( user->user_side_metadata().Kind() == domain_kind && user->operand_side_metadata().Kind() == domain_kind) { VLOG(5) << "Removing exit domain " << user->name(); - TF_RETURN_IF_ERROR(user->ReplaceAllUsesWith(instruction)); - TF_RETURN_IF_ERROR(computation->RemoveInstruction(user)); + RETURN_IF_ERROR(user->ReplaceAllUsesWith(instruction)); + RETURN_IF_ERROR(computation->RemoveInstruction(user)); ++removed_domains; } } diff --git a/third_party/xla/xla/service/hlo_domain_verifier.cc b/third_party/xla/xla/service/hlo_domain_verifier.cc index f3ab333b8289ef..705f66d3e2edd5 100644 --- a/third_party/xla/xla/service/hlo_domain_verifier.cc +++ b/third_party/xla/xla/service/hlo_domain_verifier.cc @@ -24,6 +24,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_domain_metadata.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -78,16 +79,16 @@ absl::Status HloDomainVerifier::RunContext::PopulateDomainKinds( absl::Status HloDomainVerifier::RunContext::Run( const absl::flat_hash_set& execution_threads) { VLOG(4) << "Running HLO Domain Verifier"; - TF_RETURN_IF_ERROR(PopulateDomainKinds(execution_threads)); + RETURN_IF_ERROR(PopulateDomainKinds(execution_threads)); for (HloComputation* computation : module_->computations(execution_threads)) { for (auto& kind : verifier_->kinds_) { // First create the domain instruction sets. A domain instruction set is // the set of instructions whose edges never cross a kDomain instruction. - TF_ASSIGN_OR_RETURN(std::unique_ptr domain_map, - HloDomainMap::Create(computation, kind)); + ASSIGN_OR_RETURN(std::unique_ptr domain_map, + HloDomainMap::Create(computation, kind)); // Verify every domain populated within the map. for (auto& domain : domain_map->GetDomains()) { - TF_RETURN_IF_ERROR(VerifyDomain(*domain).status()); + RETURN_IF_ERROR(VerifyDomain(*domain).status()); } } } @@ -98,7 +99,7 @@ absl::StatusOr HloDomainVerifier::RunImpl( HloModule* module, const absl::flat_hash_set& execution_threads) { RunContext run_context(module, this); - TF_RETURN_IF_ERROR(run_context.Run(execution_threads)); + RETURN_IF_ERROR(run_context.Run(execution_threads)); return false; } diff --git a/third_party/xla/xla/service/hlo_graph_dumper.cc b/third_party/xla/xla/service/hlo_graph_dumper.cc index b61cb3924da807..c80cb5e7746263 100644 --- a/third_party/xla/xla/service/hlo_graph_dumper.cc +++ b/third_party/xla/xla/service/hlo_graph_dumper.cc @@ -59,6 +59,7 @@ limitations under the License. #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" #include "absl/strings/str_replace.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_module.h" @@ -2001,12 +2002,12 @@ static absl::StatusOr CompressAndEncode(absl::string_view input) { auto gz_opts = tsl::io::ZlibCompressionOptions::GZIP(); tsl::io::ZlibOutputBuffer gz_file(&f, gz_opts.input_buffer_size, gz_opts.output_buffer_size, gz_opts); - TF_RETURN_IF_ERROR(gz_file.Init()); - TF_RETURN_IF_ERROR(gz_file.Append(input)); - TF_RETURN_IF_ERROR(gz_file.Close()); + RETURN_IF_ERROR(gz_file.Init()); + RETURN_IF_ERROR(gz_file.Append(input)); + RETURN_IF_ERROR(gz_file.Close()); std::string encoded; - TF_RETURN_IF_ERROR(tsl::Base64Encode(compressed, &encoded)); + RETURN_IF_ERROR(tsl::Base64Encode(compressed, &encoded)); return absl::StrReplaceAll(encoded, {{"_", "/"}, {"-", "+"}}); } @@ -2037,8 +2038,8 @@ absl::StatusOr WrapFusionExplorer( EscapeJSONString(p.to_highlight))); }); - TF_ASSIGN_OR_RETURN(std::string dot_graphs_compressed, - CompressAndEncode(dot_graphs)); + ASSIGN_OR_RETURN(std::string dot_graphs_compressed, + CompressAndEncode(dot_graphs)); return absl::StrReplaceAll( R"wrapper( diff --git a/third_party/xla/xla/service/hlo_module_config.cc b/third_party/xla/xla/service/hlo_module_config.cc index d1de84a2567124..3a0dcb79e73d1f 100644 --- a/third_party/xla/xla/service/hlo_module_config.cc +++ b/third_party/xla/xla/service/hlo_module_config.cc @@ -29,6 +29,7 @@ limitations under the License. #include "absl/strings/escaping.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/service/computation_layout.h" #include "xla/service/computation_placer.h" #include "xla/service/hlo.pb.h" @@ -353,9 +354,8 @@ HloModuleConfig::CreateFromProto(const HloModuleConfigProto& proto) { auto config = std::make_unique(); if (proto.has_entry_computation_layout()) { - TF_ASSIGN_OR_RETURN( - auto comp_layout, - ProgramShape::FromProto(proto.entry_computation_layout())); + ASSIGN_OR_RETURN(auto comp_layout, + ProgramShape::FromProto(proto.entry_computation_layout())); config->SetComputationLayoutIfExists(comp_layout); } else { config->clear_entry_computation_layout(); @@ -387,7 +387,7 @@ HloModuleConfig::CreateFromProto(const HloModuleConfigProto& proto) { config->debug_options_ = proto.debug_options(); } if (proto.has_static_device_assignment()) { - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( std::unique_ptr device_assignment, DeviceAssignment::Deserialize(proto.static_device_assignment())); config->static_device_assignment_ = std::move(*device_assignment); diff --git a/third_party/xla/xla/service/hlo_module_dce.cc b/third_party/xla/xla/service/hlo_module_dce.cc index 504c65f315e561..daedb6d2be0466 100644 --- a/third_party/xla/xla/service/hlo_module_dce.cc +++ b/third_party/xla/xla/service/hlo_module_dce.cc @@ -19,6 +19,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/analysis/hlo_liveness_analysis.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -81,8 +82,7 @@ absl::StatusOr RunWhileDCE( // Replace while.body.root Tuple operand at 'tuple_index' with // 'pass_thru_gte', making prior operand a dead root (to be cleaned // up with a subsequent DCE pass). - TF_RETURN_IF_ERROR( - while_body_root->ReplaceOperandWith(i, pass_thru_gte)); + RETURN_IF_ERROR(while_body_root->ReplaceOperandWith(i, pass_thru_gte)); changed = true; modified_while_body_comp = true; } @@ -94,10 +94,10 @@ absl::StatusOr RunWhileDCE( // Run DCE on while body computations that we modified. for (auto* while_body_comp : while_body_comps_to_dce) { - TF_ASSIGN_OR_RETURN(bool changed_for_computation, - HloDCE::RunOnComputation( - while_body_comp, - /*remove_cross_partition_collective_ops=*/false)); + ASSIGN_OR_RETURN(bool changed_for_computation, + HloDCE::RunOnComputation( + while_body_comp, + /*remove_cross_partition_collective_ops=*/false)); changed |= changed_for_computation; } return changed; @@ -112,27 +112,27 @@ absl::StatusOr HloModuleDCE::RunImpl( XLA_VLOG_LINES(3, module->ToString()); std::unique_ptr liveness; - TF_ASSIGN_OR_RETURN(liveness, HloLivenessAnalysis::Run(*module)); + ASSIGN_OR_RETURN(liveness, HloLivenessAnalysis::Run(*module)); // Sweep through while instructions, transforming dead while tuple element // computations to pass through tuple values (creating dead roots in while // body computation in the process). - TF_ASSIGN_OR_RETURN(bool hlo_module_dce_changed, - RunWhileDCE(module, liveness.get(), execution_threads)); + ASSIGN_OR_RETURN(bool hlo_module_dce_changed, + RunWhileDCE(module, liveness.get(), execution_threads)); // Run the while loop simplifier to remove dead tuple elements. WhileLoopSimplifier while_loop_simplifier; - TF_ASSIGN_OR_RETURN(bool while_loop_simplifier_changed, - while_loop_simplifier.Run(module, execution_threads)); + ASSIGN_OR_RETURN(bool while_loop_simplifier_changed, + while_loop_simplifier.Run(module, execution_threads)); TupleSimplifier tuple_simplifier; - TF_ASSIGN_OR_RETURN(bool tuple_simplifier_changed, - tuple_simplifier.Run(module, execution_threads)); + ASSIGN_OR_RETURN(bool tuple_simplifier_changed, + tuple_simplifier.Run(module, execution_threads)); // Run HloDCE to clean up any dead code created during HloModuleDCE. HloDCE hlo_dce; - TF_ASSIGN_OR_RETURN(bool hlo_dce_changed, - hlo_dce.Run(module, execution_threads)); + ASSIGN_OR_RETURN(bool hlo_dce_changed, + hlo_dce.Run(module, execution_threads)); VLOG(2) << "After HloModuleDCE:"; XLA_VLOG_LINES(3, module->ToString()); diff --git a/third_party/xla/xla/service/hlo_module_util.cc b/third_party/xla/xla/service/hlo_module_util.cc index e2802f6d27b657..14f133fdc10fe2 100644 --- a/third_party/xla/xla/service/hlo_module_util.cc +++ b/third_party/xla/xla/service/hlo_module_util.cc @@ -28,6 +28,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/debug_options_flags.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/parser/hlo_parser.h" @@ -48,7 +49,7 @@ namespace xla { namespace { absl::Status ValidateResultShape(const Shape& client_shape, const Shape& result_shape) { - TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(client_shape)); + RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(client_shape)); if (!ShapeUtil::Compatible(client_shape, result_shape)) { return InvalidArgument( "Shape used to set computation result layout %s is not compatible " @@ -70,7 +71,7 @@ absl::StatusOr> CreateModuleFromString( absl::StatusOr> CreateModuleFromProto( const HloModuleProto& proto, const DebugOptions& debug_options) { - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( HloModuleConfig config, HloModule::CreateModuleConfigFromProto(proto, debug_options)); return HloModule::CreateFromProto(proto, config, @@ -82,12 +83,12 @@ absl::StatusOr> CreateModuleFromProto( const HloModuleProto& proto, const HloModuleConfig& module_config, bool is_module_post_optimizations) { VLOG(4) << proto.ShortDebugString(); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( std::unique_ptr module, HloModule::CreateFromProto(proto, module_config, /*buffer_assignment_proto=*/nullptr, /*preserve_instruction_ids=*/false)); - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( HloVerifier(/*layout_sensitive=*/false, /*allow_mixed_precision=*/is_module_post_optimizations) .Run(module.get()) @@ -99,11 +100,11 @@ absl::StatusOr> ReadModuleFromBinaryProtoFile( absl::string_view filename, const DebugOptions& debug_options, bool remap_instruction_ids) { HloProto proto; - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( tsl::ReadBinaryProto(tsl::Env::Default(), std::string(filename), &proto)); if (remap_instruction_ids) { - TF_ASSIGN_OR_RETURN(HloModuleProto sanitized_proto, - HloModule::RemapInstructionIds(proto.hlo_module())); + ASSIGN_OR_RETURN(HloModuleProto sanitized_proto, + HloModule::RemapInstructionIds(proto.hlo_module())); return CreateModuleFromProto(sanitized_proto, debug_options); } return CreateModuleFromProto(proto.hlo_module(), debug_options); @@ -113,8 +114,8 @@ absl::StatusOr> ReadModuleFromHloTextFile( absl::string_view filename, const DebugOptions& debug_options, const HloParserOptions& options) { std::string hlo_string; - TF_RETURN_IF_ERROR(tsl::ReadFileToString(tsl::Env::Default(), - std::string(filename), &hlo_string)); + RETURN_IF_ERROR(tsl::ReadFileToString(tsl::Env::Default(), + std::string(filename), &hlo_string)); HloModuleConfig config; config.set_debug_options(debug_options); return ParseAndReturnUnverifiedModule(hlo_string, config, options); @@ -123,7 +124,7 @@ absl::StatusOr> ReadModuleFromHloTextFile( absl::StatusOr> ReadModuleFromTextProtoFile( absl::string_view hlo_file, const DebugOptions& debug_options) { HloProto proto; - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( tsl::ReadTextProto(tsl::Env::Default(), std::string(hlo_file), &proto)); return CreateModuleFromProto(proto.hlo_module(), debug_options); } @@ -131,10 +132,10 @@ absl::StatusOr> ReadModuleFromTextProtoFile( absl::StatusOr> ReadModuleFromModuleBinaryProtofile( absl::string_view filename, const DebugOptions& debug_options) { HloModuleProto module_proto; - TF_RETURN_IF_ERROR(tsl::ReadBinaryProto( - tsl::Env::Default(), std::string(filename), &module_proto)); + RETURN_IF_ERROR(tsl::ReadBinaryProto(tsl::Env::Default(), + std::string(filename), &module_proto)); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( HloModuleConfig module_config, HloModule::CreateModuleConfigFromProto(module_proto, debug_options)); @@ -146,10 +147,10 @@ absl::StatusOr> ReadModuleFromModuleBinaryProtofile( absl::StatusOr> ReadModuleFromModuleTextProtoFile( absl::string_view hlo_file, const DebugOptions& debug_options) { HloModuleProto module_proto; - TF_RETURN_IF_ERROR(tsl::ReadTextProto(tsl::Env::Default(), - std::string(hlo_file), &module_proto)); + RETURN_IF_ERROR(tsl::ReadTextProto(tsl::Env::Default(), std::string(hlo_file), + &module_proto)); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( HloModuleConfig module_config, HloModule::CreateModuleConfigFromProto(module_proto, debug_options)); @@ -183,18 +184,18 @@ absl::StatusOr> CreateModuleConfig( i, ShapeUtil::HumanString(program_shape.parameters(i)), ShapeUtil::HumanString(*argument_shapes[i])); } - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( computation_layout->mutable_parameter_layout(i)->CopyLayoutFromShape( *argument_shapes[i])); } if (execution_options != nullptr && execution_options->has_shape_with_output_layout()) { - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( const Shape shape_with_output_layout, Shape::FromProto(execution_options->shape_with_output_layout())); - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( ValidateResultShape(shape_with_output_layout, program_shape.result())); - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( computation_layout->mutable_result_layout()->CopyLayoutFromShape( shape_with_output_layout)); } else { @@ -243,9 +244,9 @@ absl::StatusOr> CreateModuleConfig( config->set_launch_id(execution_options->launch_id()); config->set_debug_options(execution_options->debug_options()); if (execution_options->has_device_assignment()) { - TF_ASSIGN_OR_RETURN(auto device_assignment, - DeviceAssignment::Deserialize( - execution_options->device_assignment())); + ASSIGN_OR_RETURN(auto device_assignment, + DeviceAssignment::Deserialize( + execution_options->device_assignment())); config->set_static_device_assignment(*device_assignment); } config->set_alias_passthrough_params( diff --git a/third_party/xla/xla/service/hlo_runner_interface.cc b/third_party/xla/xla/service/hlo_runner_interface.cc index 043e8c1a71b11c..f19fe1c398a463 100644 --- a/third_party/xla/xla/service/hlo_runner_interface.cc +++ b/third_party/xla/xla/service/hlo_runner_interface.cc @@ -22,6 +22,7 @@ limitations under the License. #include "absl/log/check.h" #include "absl/status/statusor.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/literal.h" #include "xla/tsl/platform/statusor.h" @@ -73,7 +74,7 @@ absl::StatusOr HloRunnerInterface::ExecuteWithExecutable( absl::StatusOr HloRunnerInterface::ExecuteWithExecutable( OpaqueExecutable* executable, absl::Span arguments) { - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( std::vector> results, ExecuteWithExecutable(executable, arguments, /*num_repeats=*/1)); CHECK_EQ(results.size(), 1); diff --git a/third_party/xla/xla/service/hlo_runner_pjrt.cc b/third_party/xla/xla/service/hlo_runner_pjrt.cc index ec80e85453269a..2043ad3f92fb4d 100644 --- a/third_party/xla/xla/service/hlo_runner_pjrt.cc +++ b/third_party/xla/xla/service/hlo_runner_pjrt.cc @@ -109,7 +109,7 @@ absl::Status SanityCheckParameterLayouts( absl::StatusOr MustFlattenInputTuple( const absl::Span layouts) { - TF_RETURN_IF_ERROR(SanityCheckParameterLayouts(layouts)); + RETURN_IF_ERROR(SanityCheckParameterLayouts(layouts)); // Strictly, we only need to flatten tuples with mixed host/device leaves // because mixed host/device PjRtBuffer's are not supported. // However, splitting all tuples makes the code simpler and is the way @@ -121,7 +121,7 @@ absl::StatusOr> FlattenedParameterLayouts( const absl::Span layouts) { std::vector result; for (const ShapeLayout& layout : layouts) { - TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( + RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( layout.shape(), [&result](const Shape& subshape, const ShapeIndex& index) -> absl::Status { @@ -208,10 +208,10 @@ absl::StatusOr GetMemorySpaceFromLayout( PjRtDevice* absl_nonnull const device, const Layout& layout) { PjRtMemorySpace* memory_space = nullptr; if (layout.memory_space() == Layout::kHostMemorySpace) { - TF_ASSIGN_OR_RETURN(memory_space, device->memory_space_by_kind( - PinnedHostMemorySpace::kKind)); + ASSIGN_OR_RETURN(memory_space, device->memory_space_by_kind( + PinnedHostMemorySpace::kKind)); } else { - TF_ASSIGN_OR_RETURN(memory_space, device->default_memory_space()); + ASSIGN_OR_RETURN(memory_space, device->default_memory_space()); } TF_RET_CHECK(memory_space != nullptr) << "Memory space " << layout.memory_space() @@ -244,8 +244,8 @@ class HloRunnerPjRtExecutable : public OpaqueExecutable { absl::StatusOr GetOrLoadExecutable( PjRtClient* const absl_nonnull client) { if (loaded_executable_ == nullptr) { - TF_ASSIGN_OR_RETURN(loaded_executable_, - client->Load(std::move(executable_), LoadOptions())); + ASSIGN_OR_RETURN(loaded_executable_, + client->Load(std::move(executable_), LoadOptions())); } return loaded_executable_.get(); } @@ -284,8 +284,8 @@ absl::StatusOr GetBestDeviceAssignment( return compile_options.executable_build_options.device_assignment(); } - TF_ASSIGN_OR_RETURN(std::vector> hlo_modules, - executable->executable()->GetHloModules()); + ASSIGN_OR_RETURN(std::vector> hlo_modules, + executable->executable()->GetHloModules()); TF_RET_CHECK(hlo_modules.size() == 1); return GetStaticDeviceAssignmentOrComputeDefault(*hlo_modules.front(), client); @@ -298,7 +298,7 @@ HloRunnerPjRt::HloRunnerPjRt(std::unique_ptr pjrt_client) absl::StatusOr HloRunnerPjRt::GenerateDefaultCompileOptions( HloModule* module, bool run_hlo_passes) { - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( const DeviceAssignment device_assignment, GetStaticDeviceAssignmentOrComputeDefault(*module, *pjrt_client_)); @@ -367,10 +367,9 @@ absl::StatusOr HloRunnerPjRt::GenerateDefaultCompileOptions( } compile_options.argument_layouts = parameter_shapes; - TF_ASSIGN_OR_RETURN( - bool flatten, - MustFlattenInputTuple( - module->entry_computation_layout().parameter_layouts())); + ASSIGN_OR_RETURN(bool flatten, + MustFlattenInputTuple( + module->entry_computation_layout().parameter_layouts())); compile_options.parameter_is_tupled_arguments = flatten; compile_options.executable_build_options.set_result_layout( @@ -418,8 +417,7 @@ absl::StatusOr HloRunnerPjRt::TransferLiteralsFromDevice( for (const std::unique_ptr& leaf_buffer : output_buffers) { const Shape& leaf_shape = leaf_buffer->on_device_shape(); if (leaf_shape.IsArray()) { - TF_ASSIGN_OR_RETURN(Literal leaf, - TransferLiteralFromDevice(*leaf_buffer)); + ASSIGN_OR_RETURN(Literal leaf, TransferLiteralFromDevice(*leaf_buffer)); result_leaves.push_back(std::move(leaf)); } else { // Untupled non-array buffers are not supported by @@ -438,8 +436,8 @@ absl::StatusOr HloRunnerPjRt::TransferLiteralsFromDevice( absl::StatusOr HloRunnerPjRt::Execute( std::unique_ptr module, absl::Span arguments, bool run_hlo_passes) { - TF_ASSIGN_OR_RETURN(const std::unique_ptr executable, - CreateExecutable(std::move(module), run_hlo_passes)); + ASSIGN_OR_RETURN(const std::unique_ptr executable, + CreateExecutable(std::move(module), run_hlo_passes)); return HloRunnerInterface::ExecuteWithExecutable(executable.get(), arguments); } @@ -448,8 +446,8 @@ HloRunnerPjRt::ExecuteWithDeviceBuffers( OpaqueExecutable* executable, const std::vector>& arguments, const ExecuteOptions* execute_options) { - TF_ASSIGN_OR_RETURN(HloRunnerPjRtExecutable* const wrapped_executable, - HloRunnerPjRtExecutable::TryUnwrap(*this, executable)); + ASSIGN_OR_RETURN(HloRunnerPjRtExecutable* const wrapped_executable, + HloRunnerPjRtExecutable::TryUnwrap(*this, executable)); HloRunnerInterface::ReplicatedExecuteOptions replicated_execute_options; ExecuteOptions new_execute_options = UpdateOrCreateDefaultExecuteOptions( @@ -460,18 +458,17 @@ HloRunnerPjRt::ExecuteWithDeviceBuffers( new_execute_options.strict_shape_checking = true; } - TF_ASSIGN_OR_RETURN( - PjRtLoadedExecutable * pjrt_executable, - wrapped_executable->GetOrLoadExecutable(pjrt_client_.get())); + ASSIGN_OR_RETURN(PjRtLoadedExecutable * pjrt_executable, + wrapped_executable->GetOrLoadExecutable(pjrt_client_.get())); std::vector argument_ptrs = BufferVecToPointerVec(arguments); std::optional> returned_future = {}; - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( std::vector> buffers, pjrt_executable->ExecuteSharded( argument_ptrs, pjrt_client_->addressable_devices()[kDeviceIdx], new_execute_options, returned_future, true)); if (returned_future.has_value()) { - TF_RETURN_IF_ERROR(returned_future->Await()); + RETURN_IF_ERROR(returned_future->Await()); } return buffers; } @@ -480,15 +477,15 @@ absl::StatusOr>> HloRunnerPjRt::ExecuteWithExecutable(OpaqueExecutable* executable, absl::Span arguments, int64_t num_repeats) { - TF_ASSIGN_OR_RETURN(HloRunnerPjRtExecutable* const wrapped_executable, - HloRunnerPjRtExecutable::TryUnwrap(*this, executable)); + ASSIGN_OR_RETURN(HloRunnerPjRtExecutable* const wrapped_executable, + HloRunnerPjRtExecutable::TryUnwrap(*this, executable)); - TF_ASSIGN_OR_RETURN(std::vector> hlo_modules, - wrapped_executable->executable()->GetHloModules()); + ASSIGN_OR_RETURN(std::vector> hlo_modules, + wrapped_executable->executable()->GetHloModules()); TF_RET_CHECK(hlo_modules.size() == 1); const HloModule& module = *hlo_modules.front(); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( std::vector> argument_handles, TransferLiteralsToDefaultDevice( module.entry_computation_layout().parameter_layouts(), arguments)); @@ -516,9 +513,8 @@ HloRunnerPjRt::ExecuteWithExecutable(OpaqueExecutable* executable, absl::StatusOr> HloRunnerPjRt::CreateExecutable(std::unique_ptr module, bool run_hlo_passes) { - TF_ASSIGN_OR_RETURN( - CompileOptions compile_options, - GenerateDefaultCompileOptions(module.get(), run_hlo_passes)); + ASSIGN_OR_RETURN(CompileOptions compile_options, + GenerateDefaultCompileOptions(module.get(), run_hlo_passes)); XlaComputation computation(module->ToProto()); // Attempt to compile without loading. If that fails, fall back to compile + @@ -534,7 +530,7 @@ HloRunnerPjRt::CreateExecutable(std::unique_ptr module, } // Fall back to compile + load if Compile() was not implemented. - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( std::unique_ptr pjrt_loaded_executable, pjrt_client_->CompileAndLoad(computation, std::move(compile_options))); return std::make_unique( @@ -560,10 +556,9 @@ HloRunnerPjRt::DeserializeExecutable(const absl::string_view serialized) const { // Fall back to deserialize + load if DeserializeExecutable() was not // implemented. This is similar to how we handle CreateExecutable above. - TF_ASSIGN_OR_RETURN( - std::unique_ptr pjrt_loaded_executable, - pjrt_client_->LoadSerializedExecutable( - serialized, /*options=*/std::nullopt, LoadOptions())); + ASSIGN_OR_RETURN(std::unique_ptr pjrt_loaded_executable, + pjrt_client_->LoadSerializedExecutable( + serialized, /*options=*/std::nullopt, LoadOptions())); return std::make_unique( this, std::move(pjrt_loaded_executable)); } @@ -579,9 +574,8 @@ absl::StatusOr> HloRunnerPjRt::ExecuteReplicated( std::unique_ptr module, const HloRunnerInterface::ReplicatedExecuteOptions& options, DeviceAssignment* device_assignment) { - TF_ASSIGN_OR_RETURN( - std::unique_ptr executable, - CreateExecutable(std::move(module), options.run_hlo_passes)); + ASSIGN_OR_RETURN(std::unique_ptr executable, + CreateExecutable(std::move(module), options.run_hlo_passes)); return ExecuteReplicatedWithExecutable(executable.get(), options, device_assignment); } @@ -654,10 +648,10 @@ absl::StatusOr> HloRunnerPjRt::ExecuteReplicated( for (const PjRtBuffer* const buffer : argument_buffer_slices[i]) { TF_RET_CHECK(buffer != nullptr); } - TF_ASSIGN_OR_RETURN(HloRunnerPjRtExecutable* const executable, - HloRunnerPjRtExecutable::TryUnwrap( - *this, executable_provider_arg(i))); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN(HloRunnerPjRtExecutable* const executable, + HloRunnerPjRtExecutable::TryUnwrap( + *this, executable_provider_arg(i))); + ASSIGN_OR_RETURN( PjRtLoadedExecutable * pjrt_executable, executable->GetOrLoadExecutable(pjrt_client_.get())); pool.Schedule([&per_replica_results, i, pjrt_executable, @@ -723,15 +717,15 @@ absl::StatusOr> HloRunnerPjRt::ExecuteReplicatedImpl( std::vector is_tuple_result(options.num_devices, false); for (int64_t i = 0; i < options.num_devices; ++i) { // Amortize device lookup. - TF_ASSIGN_OR_RETURN(PjRtDevice* const device_ptr, - pjrt_client_->LookupDevice( - DeviceIdForInvocation(*device_assignment, i))); + ASSIGN_OR_RETURN(PjRtDevice* const device_ptr, + pjrt_client_->LookupDevice( + DeviceIdForInvocation(*device_assignment, i))); id_to_device_ptr[i] = device_ptr; // Get the entry layout. OpaqueExecutable* const wrapped_executable = executable_provider(i); - TF_ASSIGN_OR_RETURN(const HloModule* const module, - HloModuleFromWrapped(wrapped_executable)); + ASSIGN_OR_RETURN(const HloModule* const module, + HloModuleFromWrapped(wrapped_executable)); const ComputationLayout& ecl = module->entry_computation_layout(); is_tuple_result[i] = ecl.result_shape().IsTuple(); @@ -815,7 +809,7 @@ absl::StatusOr> HloRunnerPjRt::ExecuteReplicatedImpl( ExecuteOptions execute_options = UpdateOrCreateDefaultExecuteOptions(options); VLOG(1) << "Replicated execution started"; - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( const std::vector>> result_buffers, execution_helper(BufferMatToPointerMat(argument_buffer_slices), @@ -827,7 +821,7 @@ absl::StatusOr> HloRunnerPjRt::ExecuteReplicatedImpl( std::vector result_literals; result_literals.reserve(options.num_devices); for (int64_t i = 0; i < options.num_devices; ++i) { - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( Literal literal, TransferLiteralsFromDevice(result_buffers[i], is_tuple_result[i])); result_literals.push_back(std::move(literal)); @@ -836,7 +830,7 @@ absl::StatusOr> HloRunnerPjRt::ExecuteReplicatedImpl( // Join infeed and outfeed threads, if they exist. The thread pool's threads // are joined on destruction. No-op otherwise. pool = nullptr; - TF_RETURN_IF_ERROR(infeed_outfeed_status); + RETURN_IF_ERROR(infeed_outfeed_status); return std::move(result_literals); } @@ -846,9 +840,9 @@ HloRunnerPjRt::TransferLiteralsToDevice( absl::Span layouts, absl::Span literals, PjRtDevice* absl_nonnull device) { - TF_ASSIGN_OR_RETURN(bool flatten, MustFlattenInputTuple(layouts)); - TF_ASSIGN_OR_RETURN(std::vector parameter_layouts, - FlattenedParameterLayouts(layouts)); + ASSIGN_OR_RETURN(bool flatten, MustFlattenInputTuple(layouts)); + ASSIGN_OR_RETURN(std::vector parameter_layouts, + FlattenedParameterLayouts(layouts)); absl::Span input_literals = literals; std::optional> flattened; @@ -873,9 +867,9 @@ HloRunnerPjRt::TransferLiteralsToDevice( const Literal* literal = input_literals[i]; TF_RET_CHECK(literal != nullptr); const Layout& on_device_layout = parameter_layouts[i]; - TF_ASSIGN_OR_RETURN(PjRtMemorySpace* absl_nonnull memory_space, - GetMemorySpaceFromLayout(device, on_device_layout)); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN(PjRtMemorySpace* absl_nonnull memory_space, + GetMemorySpaceFromLayout(device, on_device_layout)); + ASSIGN_OR_RETURN( std::unique_ptr buffer, TransferLiteralToDevice(*literal, memory_space, on_device_layout)); buffer_ready_futures.push_back(buffer->GetReadyFuture()); @@ -912,7 +906,7 @@ HloRunnerPjRt::TransferLiteralToDevice( absl::StatusOr HloRunnerPjRt::TransferLiteralFromDevice( PjRtBuffer& buffer) { - TF_RETURN_IF_ERROR(buffer.GetReadyFuture().Await()); + RETURN_IF_ERROR(buffer.GetReadyFuture().Await()); // Implementations of ToLiteralSync() do not support empty tuples. Since an // empty tuple literal is easy to construct, we do so here. @@ -920,8 +914,8 @@ absl::StatusOr HloRunnerPjRt::TransferLiteralFromDevice( on_device_shape.IsTuple() && on_device_shape.tuple_shapes().size() == 0) { return LiteralUtil::MakeTuple({}); } - TF_ASSIGN_OR_RETURN(std::shared_ptr literal, - buffer.ToLiteral().Await()); + ASSIGN_OR_RETURN(std::shared_ptr literal, + buffer.ToLiteral().Await()); return std::move(*literal); } @@ -945,12 +939,11 @@ bool HloRunnerPjRt::HasProperty(const HloRunnerPropertyTag::Type tag) const { absl::StatusOr HloRunnerPjRt::HloModuleFromWrapped(const OpaqueExecutable* wrapped) const { - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( const HloRunnerPjRtExecutable* const hlo_runner_pjrt_executable, HloRunnerPjRtExecutable::TryUnwrap(*this, wrapped)); - TF_ASSIGN_OR_RETURN( - std::vector> modules, - hlo_runner_pjrt_executable->executable()->GetHloModules()); + ASSIGN_OR_RETURN(std::vector> modules, + hlo_runner_pjrt_executable->executable()->GetHloModules()); if (!modules.empty()) { return modules.front().get(); } @@ -963,8 +956,7 @@ bool HloRunnerPjRt::ExecutablesAreEquivalent( constexpr auto kFingerprint = [](const absl::StatusOr wrapped) -> absl::StatusOr { - TF_ASSIGN_OR_RETURN(const HloRunnerPjRtExecutable* const executable, - wrapped); + ASSIGN_OR_RETURN(const HloRunnerPjRtExecutable* const executable, wrapped); return executable->executable()->FingerprintExecutable(); }; @@ -1049,16 +1041,16 @@ CompilePhaseHloRunnerPjRt::CreateExecutable(std::unique_ptr module, const bool run_hlo_passes) { const std::string path = tsl::io::JoinPath(artifact_dir_, MakeFilename(*module, run_hlo_passes)); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( std::unique_ptr wrapped_executable, HloRunnerPjRt::CreateExecutable(std::move(module), run_hlo_passes)); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( HloRunnerPjRtExecutable* const executable, HloRunnerPjRtExecutable::TryUnwrap(*this, wrapped_executable.get())); - TF_ASSIGN_OR_RETURN(const std::string serialized_executable, - executable->executable()->SerializeExecutable()); - TF_RETURN_IF_ERROR(WriteCompressedExecutable(path, serialized_executable)); + ASSIGN_OR_RETURN(const std::string serialized_executable, + executable->executable()->SerializeExecutable()); + RETURN_IF_ERROR(WriteCompressedExecutable(path, serialized_executable)); return wrapped_executable; } @@ -1066,16 +1058,16 @@ CompilePhaseHloRunnerPjRt::CreateExecutable(std::unique_ptr module, absl::Status CompilePhaseHloRunnerPjRt::WriteCompressedExecutable( absl::string_view path, absl::string_view serialized_executable) { std::unique_ptr file; - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( tsl::Env::Default()->NewWritableFile(std::string(path), &file)); tsl::io::ZlibCompressionOptions gz_opts = tsl::io::ZlibCompressionOptions::GZIP(); tsl::io::ZlibOutputBuffer gz_file(file.get(), gz_opts.input_buffer_size, gz_opts.output_buffer_size, gz_opts); - TF_RETURN_IF_ERROR(gz_file.Init()); - TF_RETURN_IF_ERROR(gz_file.Append(serialized_executable)); - TF_RETURN_IF_ERROR(gz_file.Close()); + RETURN_IF_ERROR(gz_file.Init()); + RETURN_IF_ERROR(gz_file.Append(serialized_executable)); + RETURN_IF_ERROR(gz_file.Close()); return file->Close(); } @@ -1134,7 +1126,7 @@ ExecutePhaseHloRunnerPjRt::CreateExecutable(std::unique_ptr module, absl::Status ExecutePhaseHloRunnerPjRt::ReadCompressedExecutable( absl::string_view path, tsl::tstring* serialized_executable) { std::unique_ptr file; - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( tsl::Env::Default()->NewRandomAccessFile(std::string(path), &file)); tsl::io::RandomAccessInputStream stream(file.get()); diff --git a/third_party/xla/xla/service/hlo_verifier.cc b/third_party/xla/xla/service/hlo_verifier.cc index 01dd81e3cf7570..32877e00289e4a 100644 --- a/third_party/xla/xla/service/hlo_verifier.cc +++ b/third_party/xla/xla/service/hlo_verifier.cc @@ -45,6 +45,7 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/strings/substitute.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/comparison_util.h" #include "xla/hlo/ir/collective_op_group_mode.h" #include "xla/hlo/ir/hlo_casting_utils.h" @@ -122,7 +123,7 @@ absl::Status ShapeVerifier::Preprocess(HloInstruction* hlo) { } std::optional arity = HloOpcodeArity(hlo->opcode()); if (arity) { - TF_RETURN_IF_ERROR(CheckOperandCount(hlo, *arity)); + RETURN_IF_ERROR(CheckOperandCount(hlo, *arity)); } if (!opts_.allow_unbounded_dynamism && hlo->shape().is_unbounded_dynamic()) { return InvalidArgument("Unbounded dynamism is disabled for instruction: %s", @@ -190,20 +191,19 @@ absl::Status ShapeVerifier::HandleCopy(HloInstruction* copy) { } absl::Status ShapeVerifier::HandleDot(HloInstruction* dot) { - TF_ASSIGN_OR_RETURN( - const Shape expected, - ShapeInference::InferDotOpShape( - dot->operand(0)->shape(), dot->operand(1)->shape(), - dot->dot_dimension_numbers(), - /*preferred_element_type=*/dot->shape().element_type())); + ASSIGN_OR_RETURN(const Shape expected, + ShapeInference::InferDotOpShape( + dot->operand(0)->shape(), dot->operand(1)->shape(), + dot->dot_dimension_numbers(), + /*preferred_element_type=*/dot->shape().element_type())); return CheckShape(dot, expected); } absl::Status ShapeVerifier::HandleRaggedDot(HloInstruction* ragged_dot) { - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( CheckOperandCount(ragged_dot, HloRaggedDotInstruction::kOperands)); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( const Shape expected, ShapeInference::InferRaggedDotOpShape( ragged_dot->operand(0)->shape(), ragged_dot->operand(1)->shape(), @@ -242,8 +242,8 @@ absl::Status ScalesShapeVerifier( const HloInstruction* operand = dot->operand(operand_number); const HloInstruction* scale_operand = dot->operand(scale_operand_number); - TF_ASSIGN_OR_RETURN(bool is_dummy_scale, - IsNoOpScale(dot, operand, scale_operand)); + ASSIGN_OR_RETURN(bool is_dummy_scale, + IsNoOpScale(dot, operand, scale_operand)); if (is_dummy_scale) { return absl::OkStatus(); } @@ -280,20 +280,19 @@ absl::Status ScalesShapeVerifier( } absl::Status ShapeVerifier::HandleScaledDot(HloInstruction* scaled_dot) { - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( CheckOperandCount(scaled_dot, HloScaledDotInstruction::kOperands)); - TF_ASSIGN_OR_RETURN(auto dim_numbers, - DotOperandDims::FromScaledDot(scaled_dot)); - TF_RETURN_IF_ERROR(ScalesShapeVerifier(scaled_dot, dim_numbers, 0, 2)); - TF_RETURN_IF_ERROR(ScalesShapeVerifier(scaled_dot, dim_numbers, 1, 3)); + ASSIGN_OR_RETURN(auto dim_numbers, DotOperandDims::FromScaledDot(scaled_dot)); + RETURN_IF_ERROR(ScalesShapeVerifier(scaled_dot, dim_numbers, 0, 2)); + RETURN_IF_ERROR(ScalesShapeVerifier(scaled_dot, dim_numbers, 1, 3)); if (ShapeUtil::IsScalar(scaled_dot->operand(2)->shape()) && ShapeUtil::IsScalar(scaled_dot->operand(3)->shape())) { return absl::FailedPreconditionError(absl::StrFormat( "At least one of the scales should be not a scalar in %s", scaled_dot->ToString())); } - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( const Shape expected, ShapeInference::InferDotOpShape( scaled_dot->operand(0)->shape(), scaled_dot->operand(1)->shape(), @@ -303,7 +302,7 @@ absl::Status ShapeVerifier::HandleScaledDot(HloInstruction* scaled_dot) { } absl::Status ShapeVerifier::HandleConvolution(HloInstruction* convolution) { - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( Shape expected, ShapeInference::InferConvolveShape( convolution->operand(0)->shape(), convolution->operand(1)->shape(), @@ -316,7 +315,7 @@ absl::Status ShapeVerifier::HandleConvolution(HloInstruction* convolution) { } absl::Status ShapeVerifier::HandleFft(HloInstruction* fft) { - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( const Shape expected, ShapeInference::InferFftShape(fft->operand(0)->shape(), fft->fft_type(), fft->fft_length())); @@ -324,22 +323,22 @@ absl::Status ShapeVerifier::HandleFft(HloInstruction* fft) { } absl::Status ShapeVerifier::HandleTriangularSolve(HloInstruction* hlo) { - TF_ASSIGN_OR_RETURN(const Shape expected, - ShapeInference::InferTriangularSolveShape( - hlo->operand(0)->shape(), hlo->operand(1)->shape(), - hlo->triangular_solve_options())); + ASSIGN_OR_RETURN(const Shape expected, + ShapeInference::InferTriangularSolveShape( + hlo->operand(0)->shape(), hlo->operand(1)->shape(), + hlo->triangular_solve_options())); return CheckShape(hlo, expected); } absl::Status ShapeVerifier::HandleCholesky(HloInstruction* hlo) { - TF_RETURN_IF_ERROR(CheckOperandCount(hlo, 1)); - TF_ASSIGN_OR_RETURN(const Shape expected, ShapeInference::InferCholeskyShape( - hlo->operand(0)->shape())); + RETURN_IF_ERROR(CheckOperandCount(hlo, 1)); + ASSIGN_OR_RETURN(const Shape expected, ShapeInference::InferCholeskyShape( + hlo->operand(0)->shape())); return CheckShape(hlo, expected); } absl::Status ShapeVerifier::HandleOptimizationBarrier(HloInstruction* hlo) { - TF_RETURN_IF_ERROR(CheckOperandCount(hlo, 1)); + RETURN_IF_ERROR(CheckOperandCount(hlo, 1)); return CheckShape(hlo, hlo->operand(0)->shape()); } @@ -463,11 +462,11 @@ static absl::Status CheckCommonAllGatherInvariants( HloAllGatherInstruction* ag, int64_t* computed_shard_count, bool check_replica_groups) { CHECK_NE(computed_shard_count, nullptr) << "Expected a shard count as input"; - TF_ASSIGN_OR_RETURN(CollectiveOpGroupMode group_mode, - GetCollectiveOpGroupMode(ag->channel_id().has_value(), - ag->use_global_device_ids())); + ASSIGN_OR_RETURN(CollectiveOpGroupMode group_mode, + GetCollectiveOpGroupMode(ag->channel_id().has_value(), + ag->use_global_device_ids())); if (check_replica_groups) { - TF_RETURN_IF_ERROR(CheckReplicaGroups(ag, group_mode)); + RETURN_IF_ERROR(CheckReplicaGroups(ag, group_mode)); } TF_RET_CHECK(ag->all_gather_dimension() >= 0); TF_RET_CHECK(ag->operand_count() >= 1); @@ -523,7 +522,7 @@ static absl::Status CheckCommonAllGatherInvariants( absl::Status ShapeVerifier::HandleAllGather(HloInstruction* hlo) { auto ag = Cast(hlo); int64_t shard_count; - TF_RETURN_IF_ERROR(CheckCommonAllGatherInvariants( + RETURN_IF_ERROR(CheckCommonAllGatherInvariants( ag, &shard_count, opts_.ShouldCheckReplicaGroups())); std::vector operand_shapes; for (const HloInstruction* operand : hlo->operands()) { @@ -537,7 +536,7 @@ absl::Status ShapeVerifier::HandleAllGather(HloInstruction* hlo) { absl::Status ShapeVerifier::HandleAllGatherStart(HloInstruction* hlo) { auto ag = Cast(hlo); int64_t shard_count; - TF_RETURN_IF_ERROR(CheckCommonAllGatherInvariants( + RETURN_IF_ERROR(CheckCommonAllGatherInvariants( ag, &shard_count, opts_.ShouldCheckReplicaGroups())); std::vector operand_shapes; for (const HloInstruction* operand : hlo->operands()) { @@ -556,11 +555,11 @@ absl::Status ShapeVerifier::HandleAllGatherDone(HloInstruction* hlo) { absl::Status ShapeVerifier::HandleAllReduce(HloInstruction* hlo) { auto ar = Cast(hlo); if (opts_.ShouldCheckReplicaGroups()) { - TF_ASSIGN_OR_RETURN(CollectiveOpGroupMode group_mode, - GetCollectiveOpGroupMode(ar->channel_id().has_value(), - ar->use_global_device_ids())); - TF_RETURN_IF_ERROR(CheckReplicaGroups( - ar, group_mode, /*uniform_replica_group_size=*/false)); + ASSIGN_OR_RETURN(CollectiveOpGroupMode group_mode, + GetCollectiveOpGroupMode(ar->channel_id().has_value(), + ar->use_global_device_ids())); + RETURN_IF_ERROR(CheckReplicaGroups(ar, group_mode, + /*uniform_replica_group_size=*/false)); } std::vector operand_shapes; for (const HloInstruction* operand : hlo->operands()) { @@ -571,11 +570,11 @@ absl::Status ShapeVerifier::HandleAllReduce(HloInstruction* hlo) { absl::Status ShapeVerifier::HandleReduceScatter(HloInstruction* hlo) { auto ars = Cast(hlo); - TF_ASSIGN_OR_RETURN(CollectiveOpGroupMode group_mode, - GetCollectiveOpGroupMode(ars->channel_id().has_value(), - ars->use_global_device_ids())); + ASSIGN_OR_RETURN(CollectiveOpGroupMode group_mode, + GetCollectiveOpGroupMode(ars->channel_id().has_value(), + ars->use_global_device_ids())); if (opts_.ShouldCheckReplicaGroups()) { - TF_RETURN_IF_ERROR(CheckReplicaGroups(ars, group_mode)); + RETURN_IF_ERROR(CheckReplicaGroups(ars, group_mode)); } TF_RET_CHECK(ars->scatter_dimension() >= 0); TF_RET_CHECK(ars->operand_count() >= 1); @@ -622,11 +621,11 @@ absl::Status ShapeVerifier::HandleReduceScatter(HloInstruction* hlo) { absl::Status ShapeVerifier::HandleAllReduceStart(HloInstruction* hlo) { auto ar = Cast(hlo); if (opts_.ShouldCheckReplicaGroups()) { - TF_ASSIGN_OR_RETURN(CollectiveOpGroupMode group_mode, - GetCollectiveOpGroupMode(ar->channel_id().has_value(), - ar->use_global_device_ids())); - TF_RETURN_IF_ERROR(CheckReplicaGroups( - ar, group_mode, /*uniform_replica_group_size=*/false)); + ASSIGN_OR_RETURN(CollectiveOpGroupMode group_mode, + GetCollectiveOpGroupMode(ar->channel_id().has_value(), + ar->use_global_device_ids())); + RETURN_IF_ERROR(CheckReplicaGroups(ar, group_mode, + /*uniform_replica_group_size=*/false)); } std::vector operand_shapes; for (const HloInstruction* operand : hlo->operands()) { @@ -643,12 +642,12 @@ absl::Status ShapeVerifier::HandleAllReduceDone(HloInstruction* hlo) { absl::Status ShapeVerifier::HandleAllToAll(HloInstruction* hlo) { auto* all_to_all = Cast(hlo); - TF_ASSIGN_OR_RETURN(CollectiveOpGroupMode group_mode, - GetCollectiveOpGroupMode( - all_to_all->channel_id().has_value(), std::nullopt)); + ASSIGN_OR_RETURN(CollectiveOpGroupMode group_mode, + GetCollectiveOpGroupMode( + all_to_all->channel_id().has_value(), std::nullopt)); if (opts_.ShouldCheckReplicaGroups()) { - TF_RETURN_IF_ERROR(CheckReplicaGroups(hlo, group_mode)); + RETURN_IF_ERROR(CheckReplicaGroups(hlo, group_mode)); } TF_RET_CHECK(all_to_all != nullptr); const int64_t split_count = GetSubgroupSize(all_to_all, group_mode); @@ -671,12 +670,11 @@ absl::Status ShapeVerifier::HandleAllToAll(HloInstruction* hlo) { absl::Status ShapeVerifier::HandleRaggedAllToAll(HloInstruction* hlo) { auto* all_to_all = Cast(hlo); if (opts_.ShouldCheckReplicaGroups()) { - TF_ASSIGN_OR_RETURN( - CollectiveOpGroupMode group_mode, - GetCollectiveOpGroupMode(all_to_all->channel_id().has_value(), - std::nullopt)); + ASSIGN_OR_RETURN(CollectiveOpGroupMode group_mode, + GetCollectiveOpGroupMode( + all_to_all->channel_id().has_value(), std::nullopt)); - TF_RETURN_IF_ERROR(CheckReplicaGroups(hlo, group_mode)); + RETURN_IF_ERROR(CheckReplicaGroups(hlo, group_mode)); } const int64_t kNumRaggedOperands = 6; @@ -773,9 +771,8 @@ absl::Status CheckInplaceCollectivePermute( const Shape& output_offset_shape = collective_permute->operand(3)->shape(); if (input_buffer_shape.IsArray() && output_buffer_shape.IsArray()) { - TF_RETURN_IF_ERROR( - CheckBufferOffset(input_buffer_shape, input_offset_shape)); - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR(CheckBufferOffset(input_buffer_shape, input_offset_shape)); + RETURN_IF_ERROR( CheckBufferOffset(output_buffer_shape, output_offset_shape)); } else if (input_buffer_shape.IsTuple() && output_buffer_shape.IsTuple()) { if (ShapeUtil::TupleElementCount(input_buffer_shape) != @@ -789,8 +786,8 @@ absl::Status CheckInplaceCollectivePermute( } for (int i = 0; i < input_buffer_shape.tuple_shapes().size(); ++i) { - TF_RETURN_IF_ERROR(CheckBufferOffset(input_buffer_shape.tuple_shapes(i), - input_offset_shape.tuple_shapes(i))); + RETURN_IF_ERROR(CheckBufferOffset(input_buffer_shape.tuple_shapes(i), + input_offset_shape.tuple_shapes(i))); } if (!output_offset_shape.IsTuple() || ShapeUtil::TupleElementCount(output_offset_shape) != @@ -798,9 +795,8 @@ absl::Status CheckInplaceCollectivePermute( return Internal("Unmatching output buffers and output offset."); } for (int i = 0; i < output_buffer_shape.tuple_shapes().size(); ++i) { - TF_RETURN_IF_ERROR( - CheckBufferOffset(output_buffer_shape.tuple_shapes(i), - output_offset_shape.tuple_shapes(i))); + RETURN_IF_ERROR(CheckBufferOffset(output_buffer_shape.tuple_shapes(i), + output_offset_shape.tuple_shapes(i))); } } else { return Internal("Unmatching input buffers and output buffers."); @@ -810,8 +806,8 @@ absl::Status CheckInplaceCollectivePermute( absl::Status CheckDuplicatedSourceOrTarget( HloCollectivePermuteInstruction* collective_permute) { - TF_ASSIGN_OR_RETURN(CollectiveOpGroupMode group_mode, - GetCollectiveOpGroupMode(collective_permute)); + ASSIGN_OR_RETURN(CollectiveOpGroupMode group_mode, + GetCollectiveOpGroupMode(collective_permute)); // A source or target cannot appear twice in the collective-permute's // source-target pairs. Also, based on the group formation mode, check if the @@ -917,8 +913,8 @@ absl::Status ShapeVerifier::HandleCollectiveBroadcast(HloInstruction* hlo) { absl::Status ShapeVerifier::HandleCollectivePermute(HloInstruction* hlo) { HloCollectivePermuteInstruction* collective_permute = Cast(hlo); - TF_RETURN_IF_ERROR(CheckInplaceCollectivePermute(collective_permute)); - TF_RETURN_IF_ERROR(CheckDuplicatedSourceOrTarget(collective_permute)); + RETURN_IF_ERROR(CheckInplaceCollectivePermute(collective_permute)); + RETURN_IF_ERROR(CheckDuplicatedSourceOrTarget(collective_permute)); std::vector operand_shapes; absl::c_transform( collective_permute->operands(), std::back_inserter(operand_shapes), @@ -931,8 +927,8 @@ absl::Status ShapeVerifier::HandleCollectivePermuteStart(HloInstruction* hlo) { HloCollectivePermuteInstruction* collective_permute_start = Cast(hlo); - TF_RETURN_IF_ERROR(CheckInplaceCollectivePermute(collective_permute_start)); - TF_RETURN_IF_ERROR(CheckDuplicatedSourceOrTarget(collective_permute_start)); + RETURN_IF_ERROR(CheckInplaceCollectivePermute(collective_permute_start)); + RETURN_IF_ERROR(CheckDuplicatedSourceOrTarget(collective_permute_start)); std::vector operand_shapes; absl::c_transform( collective_permute_start->operands(), std::back_inserter(operand_shapes), @@ -991,7 +987,7 @@ absl::Status ShapeVerifier::CheckOperandAndParameter( absl::Status ShapeVerifier::HandleInfeed(HloInstruction* instruction) { HloInfeedInstruction* infeed = Cast(instruction); - TF_RETURN_IF_ERROR(CheckIsTokenOperand(instruction, 0)); + RETURN_IF_ERROR(CheckIsTokenOperand(instruction, 0)); // The output of infeed is a tuple containing the data value and a token. return CheckShape(infeed, @@ -1002,7 +998,7 @@ absl::Status ShapeVerifier::HandleInfeed(HloInstruction* instruction) { absl::Status ShapeVerifier::HandleOutfeed(HloInstruction* instruction) { HloOutfeedInstruction* outfeed = Cast(instruction); - TF_RETURN_IF_ERROR(CheckIsTokenOperand(instruction, 1)); + RETURN_IF_ERROR(CheckIsTokenOperand(instruction, 1)); // Outfeed has a separate shape field for the value which is outfed to the // host. The shape of the instruction itself is always a token. @@ -1027,7 +1023,7 @@ bool ShapeVerifier::HasCompatibleElementTypes(const Shape& shape_0, } absl::Status ShapeVerifier::HandleRng(HloInstruction* instruction) { - TF_RETURN_IF_ERROR(CheckOperandCount(instruction, 2)); + RETURN_IF_ERROR(CheckOperandCount(instruction, 2)); const Shape& shape_0 = instruction->operand(0)->shape(); const Shape& shape_1 = instruction->operand(1)->shape(); @@ -1097,7 +1093,7 @@ absl::Status ShapeVerifier::HandleRngBitGenerator(HloInstruction* hlo) { absl::Status ShapeVerifier::HandleRngGetAndUpdateState( HloInstruction* instruction) { - TF_RETURN_IF_ERROR(CheckOperandCount(instruction, 0)); + RETURN_IF_ERROR(CheckOperandCount(instruction, 0)); const Shape& result_shape = instruction->shape(); const Shape expected_shape = ShapeUtil::MakeShape(U64, {2}); if (!ShapeUtil::Compatible(result_shape, expected_shape)) { @@ -1140,7 +1136,7 @@ absl::Status ShapeVerifier::HandleSort(HloInstruction* hlo) { // Check that the number of parameters of the 'compare' computation is // correct. - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( CheckParameterCount(sort, compare, sort->operand_count() * 2)); // Verify that the operands of the compare computation have the correct scalar @@ -1261,7 +1257,7 @@ absl::Status ShapeVerifier::HandleReduce(HloInstruction* reduce) { for (const HloInstruction* operand : reduce->operands()) { operand_shapes.push_back(&operand->shape()); } - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( CheckShape(reduce, ShapeInference::InferReduceShape( operand_shapes, reduce->dimensions(), reduce->to_apply()->ComputeProgramShape()))); @@ -1353,7 +1349,7 @@ absl::Status ShapeVerifier::HandleScan(HloInstruction* scan) { int64_t num_carries = scan_instr->num_carries(); - TF_RETURN_IF_ERROR(CheckScanOperandAndResultCounts( + RETURN_IF_ERROR(CheckScanOperandAndResultCounts( operand_shapes.size(), parameter_shapes.size(), root_shapes.size(), result_shapes.size(), num_carries)); @@ -1361,7 +1357,7 @@ absl::Status ShapeVerifier::HandleScan(HloInstruction* scan) { if (scan_dim < 0) { return Internal("Scan dimension %d should be non-negative", scan_dim); } - TF_ASSIGN_OR_RETURN(int64_t scan_dim_size, scan_instr->GetScanDimSize()); + ASSIGN_OR_RETURN(int64_t scan_dim_size, scan_instr->GetScanDimSize()); int64_t num_inputs = operand_shapes.size() - num_carries; int64_t num_outputs = result_shapes.size() - num_carries; @@ -1643,10 +1639,10 @@ absl::Status ShapeVerifier::HandleFusion(HloInstruction* fusion) { } absl::Status ShapeVerifier::HandleCall(HloInstruction* call) { - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( CheckParameterCount(call, call->to_apply(), call->operand_count())); for (int64_t i = 0; i < call->to_apply()->num_parameters(); ++i) { - TF_RETURN_IF_ERROR(CheckOperandAndParameter(call, i, call->to_apply(), i)); + RETURN_IF_ERROR(CheckOperandAndParameter(call, i, call->to_apply(), i)); } if (call->is_composite()) { TF_RET_CHECK(call->has_frontend_attributes()) @@ -1789,7 +1785,7 @@ absl::Status ShapeVerifier::HandleMap(HloInstruction* map) { std::vector map_dims(max_operand_rank); std::iota(map_dims.begin(), map_dims.end(), 0); - TF_RETURN_IF_ERROR(CheckShape( + RETURN_IF_ERROR(CheckShape( map, ShapeInference::InferMapShape( operand_shapes, map->to_apply()->ComputeProgramShape(), map_dims))); @@ -1804,7 +1800,7 @@ absl::Status ShapeVerifier::HandleReduceWindow(HloInstruction* reduce_window) { auto reduce_window_instr = Cast(reduce_window); auto input_shapes = reduce_window_instr->input_shapes(); auto init_shapes = reduce_window_instr->init_value_shapes(); - TF_RETURN_IF_ERROR(CheckShape( + RETURN_IF_ERROR(CheckShape( reduce_window, ShapeInference::InferReduceWindowShape( input_shapes, init_shapes, reduce_window->window(), reduce_window->to_apply()->ComputeProgramShape()))); @@ -1827,13 +1823,12 @@ absl::Status ShapeVerifier::HandleSelectAndScatter( } absl::Status ShapeVerifier::HandleWhile(HloInstruction* xla_while) { - TF_RETURN_IF_ERROR( - CheckParameterCount(xla_while, xla_while->while_body(), 1)); - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR(CheckParameterCount(xla_while, xla_while->while_body(), 1)); + RETURN_IF_ERROR( CheckParameterCount(xla_while, xla_while->while_condition(), 1)); - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( CheckOperandAndParameter(xla_while, 0, xla_while->while_body(), 0)); - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( CheckOperandAndParameter(xla_while, 0, xla_while->while_condition(), 0)); const Shape& conditional_shape = xla_while->while_condition()->root_instruction()->shape(); @@ -1869,13 +1864,13 @@ absl::Status ShapeVerifier::HandleConditional(HloInstruction* conditional) { } TF_RET_CHECK(num_branches >= 1); } - TF_RETURN_IF_ERROR(CheckOperandCount(conditional, num_branches + 1)); + RETURN_IF_ERROR(CheckOperandCount(conditional, num_branches + 1)); for (int j = 0; j < num_branches; ++j) { - TF_RETURN_IF_ERROR(CheckParameterCount( - conditional, conditional->branch_computation(j), 1)); - TF_RETURN_IF_ERROR(CheckOperandAndParameter( + RETURN_IF_ERROR(CheckParameterCount(conditional, + conditional->branch_computation(j), 1)); + RETURN_IF_ERROR(CheckOperandAndParameter( conditional, j + 1, conditional->branch_computation(j), 0)); - TF_RETURN_IF_ERROR(CheckShape( + RETURN_IF_ERROR(CheckShape( conditional, conditional->branch_computation(j)->root_instruction()->shape())); } @@ -1978,9 +1973,9 @@ absl::Status ShapeVerifier::CheckAsyncOpComputationShapes( } absl::Status ShapeVerifier::HandleAsyncStart(HloInstruction* async_start) { - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( CheckAsyncOpComputationShapes(async_start, async_start->shape())); - TF_RETURN_IF_ERROR(CheckAsyncOpComputationThreadName(async_start)); + RETURN_IF_ERROR(CheckAsyncOpComputationThreadName(async_start)); const Shape& param_shape = async_start->shape().tuple_shapes(0); for (int i = 0; i < async_start->operand_count(); ++i) { if (!ShapesSame(param_shape.tuple_shapes(i), @@ -2026,7 +2021,7 @@ absl::Status ShapeVerifier::HandleAsyncStart(HloInstruction* async_start) { } absl::Status ShapeVerifier::HandleAsyncUpdate(HloInstruction* async_update) { - TF_RETURN_IF_ERROR(CheckAsyncOpComputationThreadName(async_update)); + RETURN_IF_ERROR(CheckAsyncOpComputationThreadName(async_update)); if (!ShapesSame(async_update->operand(0)->shape(), async_update->shape())) { return Internal( "The %s expects the shape of operand and output to match (%s vs %s).", @@ -2034,14 +2029,14 @@ absl::Status ShapeVerifier::HandleAsyncUpdate(HloInstruction* async_update) { async_update->operand(0)->shape().ToString(true), async_update->shape().ToString(true)); } - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( CheckAsyncOpComputationShapes(async_update, async_update->shape())); return CheckAsyncOpOperand(async_update); } absl::Status ShapeVerifier::HandleAsyncDone(HloInstruction* async_done) { - TF_RETURN_IF_ERROR(CheckAsyncOpComputationThreadName(async_done)); - TF_RETURN_IF_ERROR(CheckAsyncOpComputationShapes( + RETURN_IF_ERROR(CheckAsyncOpComputationThreadName(async_done)); + RETURN_IF_ERROR(CheckAsyncOpComputationShapes( async_done, async_done->operand(0)->shape())); const Shape& root_shape = async_done->operand(0)->shape().tuple_shapes(1); if (!ShapesSame(root_shape, async_done->shape())) { @@ -2193,7 +2188,7 @@ absl::Status CheckMixedPrecisionOperands(const HloInstruction* instruction) { default: { PrimitiveType fp_type = PRIMITIVE_TYPE_INVALID; for (auto operand : instruction->operands()) { - TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( + RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( operand->shape(), [&](const Shape& subshape, const ShapeIndex& index) -> absl::Status { @@ -2248,7 +2243,7 @@ absl::Status ShapeVerifier::HandleAfterAll(HloInstruction* token) { absl::Status ShapeVerifier::HandleAddDependency( HloInstruction* add_dependency) { - TF_RETURN_IF_ERROR(CheckIsTokenOperand(add_dependency, 1)); + RETURN_IF_ERROR(CheckIsTokenOperand(add_dependency, 1)); return CheckShape(add_dependency, add_dependency->operand(0)->shape()); } @@ -2272,7 +2267,7 @@ absl::Status ShapeVerifier::CheckShape( // different precisions. We need this check because ShapeInference allows // mixed precision inputs. if (!opts_.allow_mixed_precision) { - TF_RETURN_IF_ERROR(CheckMixedPrecisionOperands(instruction)); + RETURN_IF_ERROR(CheckMixedPrecisionOperands(instruction)); } // Check if the output shape matches the expected shape. @@ -2422,7 +2417,7 @@ absl::Status ShapeVerifier::VerifyEntryComputationLayout( const auto& layout = module.entry_computation_layout(); const ShapeLayout& result_layout = layout.result_layout(); - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( ShapeUtil::ValidateShapeWithOptionalLayout(result_layout.shape())); // TPU layout assignment doesn't set the tiles on entry_computation_layout, so @@ -2449,7 +2444,7 @@ absl::Status ShapeVerifier::VerifyEntryComputationLayout( for (int i = 0; i < computation->num_parameters(); ++i) { const HloInstruction* parameter = computation->parameter_instruction(i); - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( ShapeUtil::ValidateShapeWithOptionalLayout(layout.parameter_shape(i))); // TPU layout assignment doesn't set the tiles on entry_computation_layout, // so let's not check that. @@ -2471,7 +2466,7 @@ absl::Status ShapeVerifier::VerifyEntryComputationLayout( // same shape, layout and memory space for them (for example we can't alias // parameter and result if they have different memory spaces). const auto& alias_config = module.input_output_alias_config(); - TF_RETURN_IF_ERROR(alias_config.ForEachAliasWithStatus( + RETURN_IF_ERROR(alias_config.ForEachAliasWithStatus( [&](ShapeIndex result_index, HloInputOutputAliasConfig::Alias alias) -> absl::Status { // We skip may-alias buffers as they do not force aliasing. @@ -2701,59 +2696,59 @@ absl::Status VerifyAsynchronousInstructionPairs(const HloModule& module) { for (const HloInstruction* instruction : computation->instructions()) { switch (instruction->opcode()) { case HloOpcode::kAsyncStart: { - TF_RETURN_IF_ERROR(VerifySingleUser( + RETURN_IF_ERROR(VerifySingleUser( instruction, {HloOpcode::kAsyncUpdate, HloOpcode::kAsyncDone})); break; } case HloOpcode::kAsyncUpdate: { - TF_RETURN_IF_ERROR(VerifySingleOperand( + RETURN_IF_ERROR(VerifySingleOperand( instruction, {HloOpcode::kAsyncStart, HloOpcode::kAsyncUpdate})); - TF_RETURN_IF_ERROR(VerifySingleUser( + RETURN_IF_ERROR(VerifySingleUser( instruction, {HloOpcode::kAsyncUpdate, HloOpcode::kAsyncDone})); break; } case HloOpcode::kAsyncDone: { - TF_RETURN_IF_ERROR(VerifySingleOperand( + RETURN_IF_ERROR(VerifySingleOperand( instruction, {HloOpcode::kAsyncStart, HloOpcode::kAsyncUpdate})); break; } case HloOpcode::kAllReduceStart: { - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( VerifySingleUser(instruction, {HloOpcode::kAllReduceDone})); break; } case HloOpcode::kAllReduceDone: { - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( VerifySingleOperand(instruction, {HloOpcode::kAllReduceStart})); break; } case HloOpcode::kAllGatherStart: { - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( VerifySingleUser(instruction, {HloOpcode::kAllGatherDone})); break; } case HloOpcode::kAllGatherDone: { - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( VerifySingleOperand(instruction, {HloOpcode::kAllGatherStart})); break; } case HloOpcode::kCopyStart: { - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( VerifySingleUser(instruction, {HloOpcode::kCopyDone})); break; } case HloOpcode::kCopyDone: { - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( VerifySingleOperand(instruction, {HloOpcode::kCopyStart})); break; } case HloOpcode::kCollectivePermuteStart: { - TF_RETURN_IF_ERROR(VerifySingleUser( + RETURN_IF_ERROR(VerifySingleUser( instruction, {HloOpcode::kCollectivePermuteDone})); break; } case HloOpcode::kCollectivePermuteDone: { - TF_RETURN_IF_ERROR(VerifySingleOperand( + RETURN_IF_ERROR(VerifySingleOperand( instruction, {HloOpcode::kCollectivePermuteStart})); break; } @@ -2764,12 +2759,12 @@ absl::Status VerifyAsynchronousInstructionPairs(const HloModule& module) { instruction->parent()->IsAsyncComputation()) { break; } - TF_RETURN_IF_ERROR(VerifySingleUser( + RETURN_IF_ERROR(VerifySingleUser( instruction, {HloOpcode::kSendDone, HloOpcode::kTuple})); break; } case HloOpcode::kSendDone: { - TF_RETURN_IF_ERROR(VerifySingleOperand( + RETURN_IF_ERROR(VerifySingleOperand( instruction, {HloOpcode::kSend, HloOpcode::kGetTupleElement})); break; } @@ -2780,12 +2775,12 @@ absl::Status VerifyAsynchronousInstructionPairs(const HloModule& module) { instruction->parent()->IsAsyncComputation()) { break; } - TF_RETURN_IF_ERROR(VerifySingleUser( + RETURN_IF_ERROR(VerifySingleUser( instruction, {HloOpcode::kRecvDone, HloOpcode::kTuple})); break; } case HloOpcode::kRecvDone: { - TF_RETURN_IF_ERROR(VerifySingleOperand( + RETURN_IF_ERROR(VerifySingleOperand( instruction, {HloOpcode::kRecv, HloOpcode::kGetTupleElement})); break; } @@ -2836,13 +2831,12 @@ absl::Status VerifyNoConflictingSourceTargetPairs( template absl::Status VerifyNoConflictingSendOrRecv( const T* instruction, absl::flat_hash_set& instructions) { - TF_ASSIGN_OR_RETURN(SourceTargetPairs source_target_pairs_array, - SourceTargetPairs::FromInstruction(instruction)); + ASSIGN_OR_RETURN(SourceTargetPairs source_target_pairs_array, + SourceTargetPairs::FromInstruction(instruction)); for (const T* existing_instruction : instructions) { - TF_ASSIGN_OR_RETURN( - SourceTargetPairs existing_source_target_pairs_array, - SourceTargetPairs::FromInstruction(existing_instruction)); - TF_RETURN_IF_ERROR(VerifyNoConflictingSourceTargetPairs( + ASSIGN_OR_RETURN(SourceTargetPairs existing_source_target_pairs_array, + SourceTargetPairs::FromInstruction(existing_instruction)); + RETURN_IF_ERROR(VerifyNoConflictingSourceTargetPairs( existing_instruction, SourceTargetPairs::Join(source_target_pairs_array, existing_source_target_pairs_array))); @@ -2867,9 +2861,9 @@ absl::StatusOr ShouldSkipDeadlockCheck(const T* instruction) { } // Check that the instruction itself does not have conflicting // source-target pairs. - TF_ASSIGN_OR_RETURN(SourceTargetPairs source_target_pairs_array, - SourceTargetPairs::FromInstruction(instruction)); - TF_RETURN_IF_ERROR(VerifyNoConflictingSourceTargetPairs( + ASSIGN_OR_RETURN(SourceTargetPairs source_target_pairs_array, + SourceTargetPairs::FromInstruction(instruction)); + RETURN_IF_ERROR(VerifyNoConflictingSourceTargetPairs( instruction, source_target_pairs_array)); return false; } @@ -2902,7 +2896,7 @@ absl::Status CheckDeadlocksForSend( const HloSendInstruction* send, DfaState& current_state, absl::flat_hash_set& pending_send_instructions, absl::flat_hash_set& pending_recv_instructions) { - TF_ASSIGN_OR_RETURN(bool skip, ShouldSkipDeadlockCheck(send)); + ASSIGN_OR_RETURN(bool skip, ShouldSkipDeadlockCheck(send)); if (skip) { return absl::OkStatus(); } @@ -2939,7 +2933,7 @@ absl::Status CheckDeadlocksForRecv( const HloRecvInstruction* recv, DfaState& current_state, absl::flat_hash_set& send_instructions, absl::flat_hash_set& recv_instructions) { - TF_ASSIGN_OR_RETURN(bool skip, ShouldSkipDeadlockCheck(recv)); + ASSIGN_OR_RETURN(bool skip, ShouldSkipDeadlockCheck(recv)); if (skip) { return absl::OkStatus(); } @@ -3054,15 +3048,15 @@ absl::Status VerifyNoCollectiveDeadlocksRecursive( for (const HloInstruction* instruction : computation->instructions()) { if (instruction->called_computations().empty()) { if (instruction->opcode() == HloOpcode::kSend) { - TF_RETURN_IF_ERROR(CheckDeadlocksForSend( + RETURN_IF_ERROR(CheckDeadlocksForSend( DynCast(instruction), current_state, send_instructions, recv_instructions)); } else if (instruction->opcode() == HloOpcode::kRecv) { - TF_RETURN_IF_ERROR(CheckDeadlocksForRecv( + RETURN_IF_ERROR(CheckDeadlocksForRecv( DynCast(instruction), current_state, send_instructions, recv_instructions)); } else if (IsOtherCollective(instruction)) { - TF_RETURN_IF_ERROR(CheckDeadlocksForOtherCollectives( + RETURN_IF_ERROR(CheckDeadlocksForOtherCollectives( instruction, current_state, send_instructions, recv_instructions)); } else { continue; @@ -3083,16 +3077,16 @@ absl::Status VerifyNoCollectiveDeadlocksRecursive( async_comp_send_instructions; absl::flat_hash_set async_comp_recv_instructions; - TF_RETURN_IF_ERROR(VerifyNoCollectiveDeadlocksRecursive( + RETURN_IF_ERROR(VerifyNoCollectiveDeadlocksRecursive( computation, async_comp_current_state, async_comp_send_instructions, async_comp_recv_instructions)); if (current_state != DfaState::kNoExpectation) { - TF_RETURN_IF_ERROR(CheckPendingSendRecvDeadlocks( + RETURN_IF_ERROR(CheckPendingSendRecvDeadlocks( async_comp_send_instructions, async_comp_recv_instructions)); } } else { // normal case - TF_RETURN_IF_ERROR(VerifyNoCollectiveDeadlocksRecursive( + RETURN_IF_ERROR(VerifyNoCollectiveDeadlocksRecursive( computation, current_state, send_instructions, recv_instructions)); } @@ -3106,11 +3100,11 @@ absl::Status VerifyNoCollectiveDeadlocks(const HloModule& module) { DfaState current_state = DfaState::kNoExpectation; absl::flat_hash_set send_instructions; absl::flat_hash_set recv_instructions; - TF_RETURN_IF_ERROR(VerifyNoCollectiveDeadlocksRecursive( + RETURN_IF_ERROR(VerifyNoCollectiveDeadlocksRecursive( module.entry_computation(), current_state, send_instructions, recv_instructions)); if (current_state != DfaState::kNoExpectation) { - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( CheckPendingSendRecvDeadlocks(send_instructions, recv_instructions)); } return absl::OkStatus(); @@ -3257,8 +3251,8 @@ absl::Status VerifyChannels(const HloModule& module, : HloOpcode::kRecvDone; const HloInstruction* done = instruction->users().front(); if (done->opcode() == done_opcode) { - TF_RETURN_IF_ERROR(CheckSameChannel(instruction, done)); - TF_RETURN_IF_ERROR(CheckSameIsHostTransfer(instruction, done)); + RETURN_IF_ERROR(CheckSameChannel(instruction, done)); + RETURN_IF_ERROR(CheckSameIsHostTransfer(instruction, done)); } break; } @@ -3576,7 +3570,7 @@ absl::Status CheckBufferHasUniqueWriters(const HloInstruction* inst) { inst->shape(), [&](const Shape& subshape, const ShapeIndex& index) -> absl::Status { if (subshape.IsBuffer()) { - TF_RETURN_IF_ERROR(CheckBufferHasUniqueWriter(inst, index)); + RETURN_IF_ERROR(CheckBufferHasUniqueWriter(inst, index)); } return absl::OkStatus(); }); @@ -3677,7 +3671,7 @@ absl::Status VerifyBuffersInOperands(const HloCustomCallInstruction* inst) { int64_t operand_index = 0; for (auto* operand : inst->operands()) { - TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( + RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( operand->shape(), [&](const Shape& subshape, const ShapeIndex& shape_index) -> absl::Status { @@ -3768,23 +3762,23 @@ absl::Status VerifyCustomCall(const HloCustomCallInstruction* inst, return VerifyUnpin(Cast(inst), layout_sensitive); } - TF_RETURN_IF_ERROR(VerifyBuffersInOperands(inst)); + RETURN_IF_ERROR(VerifyBuffersInOperands(inst)); // Record the ShapeIndex for the buffers in the results. absl::flat_hash_set buffer_results; - TF_RETURN_IF_ERROR(VerifyBuffersInResults(inst, buffer_results)); + RETURN_IF_ERROR(VerifyBuffersInResults(inst, buffer_results)); // Ensure that an SSA buffer result can have at most one writer. for (const auto& result_index : buffer_results) { - TF_RETURN_IF_ERROR(CheckBufferHasUniqueWriter(inst, result_index)); + RETURN_IF_ERROR(CheckBufferHasUniqueWriter(inst, result_index)); } return absl::OkStatus(); } absl::Status VerifyNoBuffersInContext(const HloInstruction* inst) { - TF_RETURN_IF_ERROR(VerifyNoBuffers(inst->shape(), inst)); + RETURN_IF_ERROR(VerifyNoBuffers(inst->shape(), inst)); for (auto* operand : inst->operands()) { - TF_RETURN_IF_ERROR(VerifyNoBuffers(operand->shape(), inst)); + RETURN_IF_ERROR(VerifyNoBuffers(operand->shape(), inst)); } return absl::OkStatus(); } @@ -3799,8 +3793,8 @@ absl::Status VerifyBuffers(const HloModule& module, bool layout_sensitive) { // allow the op to use buffers. HloInstruction* root = comp->root_instruction(); if (root->opcode() == HloOpcode::kCustomCall) { - TF_RETURN_IF_ERROR(VerifyCustomCall( - Cast(root), layout_sensitive)); + RETURN_IF_ERROR(VerifyCustomCall(Cast(root), + layout_sensitive)); } continue; } @@ -3816,22 +3810,21 @@ absl::Status VerifyBuffers(const HloModule& module, bool layout_sensitive) { continue; } if (inst->opcode() == HloOpcode::kCustomCall) { - TF_RETURN_IF_ERROR(VerifyCustomCall( - Cast(inst), layout_sensitive)); + RETURN_IF_ERROR(VerifyCustomCall(Cast(inst), + layout_sensitive)); } else if (inst->opcode() == HloOpcode::kWhile) { - TF_RETURN_IF_ERROR(CheckBufferHasUniqueWriters(inst)); + RETURN_IF_ERROR(CheckBufferHasUniqueWriters(inst)); } else if (inst->opcode() == HloOpcode::kParameter) { if (comp->IsEntryComputation()) { - TF_RETURN_IF_ERROR(VerifyNoBuffersInContext(inst)); + RETURN_IF_ERROR(VerifyNoBuffersInContext(inst)); } - TF_RETURN_IF_ERROR(CheckBufferHasUniqueWriters(inst)); + RETURN_IF_ERROR(CheckBufferHasUniqueWriters(inst)); } else if (inst->opcode() == HloOpcode::kDynamicUpdateSlice) { if (inst->operand(0)->shape().IsBuffer()) { - TF_RETURN_IF_ERROR(CheckBufferHasUniqueWriters(inst)); + RETURN_IF_ERROR(CheckBufferHasUniqueWriters(inst)); // Operand 1 and following should not be buffers. for (int i = 1; i < inst->operand_count(); ++i) { - TF_RETURN_IF_ERROR( - VerifyNoBuffers(inst->operand(i)->shape(), inst)); + RETURN_IF_ERROR(VerifyNoBuffers(inst->operand(i)->shape(), inst)); } if (!inst->shape().IsBuffer()) { return InvalidArgument( @@ -3839,11 +3832,11 @@ absl::Status VerifyBuffers(const HloModule& module, bool layout_sensitive) { "buffer"); } } else { - TF_RETURN_IF_ERROR(VerifyNoBuffersInContext(inst)); + RETURN_IF_ERROR(VerifyNoBuffersInContext(inst)); } } else if (inst->opcode() != HloOpcode::kGetTupleElement && inst->opcode() != HloOpcode::kTuple) { - TF_RETURN_IF_ERROR(VerifyNoBuffersInContext(inst)); + RETURN_IF_ERROR(VerifyNoBuffersInContext(inst)); } } } @@ -3858,7 +3851,7 @@ absl::Status InstructionVerifier::DefaultAction(HloInstruction*) { } absl::Status InstructionVerifier::HandleFusion(HloInstruction* fusion) { - TF_RETURN_IF_ERROR(CheckCallableInstructionThreadName(fusion)); + RETURN_IF_ERROR(CheckCallableInstructionThreadName(fusion)); return CheckFusionInstruction(fusion); } @@ -3905,11 +3898,11 @@ absl::Status InstructionVerifier::HandleWhile(HloInstruction* xla_while) { xla_while->operand_count(), xla_while->ToString()); } // Allow kWhile to contain computations on separate thread. - TF_RETURN_IF_ERROR(CheckCallableInstructionThreadName(xla_while)); + RETURN_IF_ERROR(CheckCallableInstructionThreadName(xla_while)); // Verify consistency of sharding of while instructions and related // instructions (parameters, root) in its called computations. - TF_RETURN_IF_ERROR(VerifyConsistentSharding( + RETURN_IF_ERROR(VerifyConsistentSharding( xla_while, {xla_while, xla_while->while_body()->root_instruction(), xla_while->while_body()->parameter_instruction(0), xla_while->while_condition()->parameter_instruction(0)})); @@ -3952,11 +3945,11 @@ absl::Status InstructionVerifier::HandleConditional( branch_computation->root_instruction()); } // Allow kConditional to contain computations on separate thread. - TF_RETURN_IF_ERROR(CheckCallableInstructionThreadName(conditional)); + RETURN_IF_ERROR(CheckCallableInstructionThreadName(conditional)); // Verify consistency of sharding of conditional instructions and roots of // its branches. - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( VerifyConsistentSharding(conditional, sharding_check_instructions)); return absl::OkStatus(); @@ -3964,7 +3957,7 @@ absl::Status InstructionVerifier::HandleConditional( absl::Status InstructionVerifier::HandleElementwiseUnary( HloInstruction* instruction) { - TF_RETURN_IF_ERROR(CheckUnaryOpWithResultAccuracy(instruction)); + RETURN_IF_ERROR(CheckUnaryOpWithResultAccuracy(instruction)); return CheckElementwiseInstruction(instruction); } @@ -4079,7 +4072,7 @@ absl::Status InstructionVerifier::Preprocess(HloInstruction* instruction) { absl::Status InstructionVerifier::Postprocess(HloInstruction* instruction) { if (opts_.verify_no_host_memory_space) { - TF_RETURN_IF_ERROR(VerifyNoHostMemorySpace(instruction)); + RETURN_IF_ERROR(VerifyNoHostMemorySpace(instruction)); } if (!opts_.InstructionCanChangeLayout(instruction) && instruction->shape().IsArray() && instruction->shape().has_layout()) { @@ -4106,7 +4099,7 @@ absl::Status InstructionVerifier::Postprocess(HloInstruction* instruction) { } else if (instruction->opcode() == HloOpcode::kDynamicSlice || instruction->opcode() == HloOpcode::kDynamicUpdateSlice || instruction->opcode() == HloOpcode::kCopy) { - TF_RETURN_IF_ERROR(HostOffloadInstructionCanChangeMemorySpace( + RETURN_IF_ERROR(HostOffloadInstructionCanChangeMemorySpace( instruction, operand_layout.memory_space(), result_layout.memory_space())); equal_predicate.IgnoreMemorySpace(); @@ -4204,11 +4197,11 @@ absl::StatusOr HloVerifier::RunImpl( "Module entry computation cannot be a fusion computation"); } - TF_RETURN_IF_ERROR(VerifyHloStructure(module)); - TF_RETURN_IF_ERROR(VerifyAsynchronousInstructionPairs(*module)); - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR(VerifyHloStructure(module)); + RETURN_IF_ERROR(VerifyAsynchronousInstructionPairs(*module)); + RETURN_IF_ERROR( VerifyChannels(*module, target_metadata_->GetVerifierOpts())); - TF_RETURN_IF_ERROR(VerifyInstructionNameUnchanged( + RETURN_IF_ERROR(VerifyInstructionNameUnchanged( *module, target_metadata_->GetVerifierOpts())); std::unique_ptr shape_verifier = @@ -4216,8 +4209,8 @@ absl::StatusOr HloVerifier::RunImpl( InstructionVerifier instruction_verifier( module, target_metadata_->GetVerifierOpts()); for (auto* computation : module->computations(execution_threads)) { - TF_RETURN_IF_ERROR(computation->Accept(shape_verifier.get())); - TF_RETURN_IF_ERROR(computation->Accept(&instruction_verifier)); + RETURN_IF_ERROR(computation->Accept(shape_verifier.get())); + RETURN_IF_ERROR(computation->Accept(&instruction_verifier)); // Verify that async computations contain a single instruction or a // collection of send/recv instructions. This is needed to represent NCCL // groups on GPU. @@ -4225,27 +4218,27 @@ absl::StatusOr HloVerifier::RunImpl( !computation->OnlyContainsSendRecv() && !IsCollectivesGroupComputation(computation) && !IsAsyncBarrierComputation(computation)) { - TF_RETURN_IF_ERROR(VerifyAsyncComputation(computation)); + RETURN_IF_ERROR(VerifyAsyncComputation(computation)); } } - TF_RETURN_IF_ERROR(VerifyBuffers( + RETURN_IF_ERROR(VerifyBuffers( *module, target_metadata_->GetVerifierOpts().IsLayoutSensitive())); - TF_RETURN_IF_ERROR(shape_verifier->VerifyEntryComputationLayout(*module)); + RETURN_IF_ERROR(shape_verifier->VerifyEntryComputationLayout(*module)); // If the module has a schedule, it must be valid. if (module->has_schedule()) { - TF_RETURN_IF_ERROR(module->schedule().Verify()); + RETURN_IF_ERROR(module->schedule().Verify()); if (target_metadata_->GetVerifierOpts().CheckForCollectiveDeadlocks()) { - TF_RETURN_IF_ERROR(VerifyNoCollectiveDeadlocks(*module)); + RETURN_IF_ERROR(VerifyNoCollectiveDeadlocks(*module)); } } if (HloInstruction::IsThreadIncluded( module->entry_computation()->execution_thread(), execution_threads)) { - TF_RETURN_IF_ERROR(module->input_output_alias_config().Verify( + RETURN_IF_ERROR(module->input_output_alias_config().Verify( *module, [this](const Shape& shape) -> int64_t { if (target_metadata_->GetVerifierOpts().IsLayoutSensitive()) { return target_metadata_->GetVerifierOpts().ShapeSize(shape); @@ -4254,9 +4247,9 @@ absl::StatusOr HloVerifier::RunImpl( })); } - TF_RETURN_IF_ERROR(module->buffer_donor_config().Verify(*module)); - TF_RETURN_IF_ERROR(VerifyLayoutConstrainedAllReduce(*module)); - TF_RETURN_IF_ERROR(VerifyOriginalValue(*module)); + RETURN_IF_ERROR(module->buffer_donor_config().Verify(*module)); + RETURN_IF_ERROR(VerifyLayoutConstrainedAllReduce(*module)); + RETURN_IF_ERROR(VerifyOriginalValue(*module)); return false; }(); if (status_or_changed.ok()) { diff --git a/third_party/xla/xla/service/host_offload_utils.cc b/third_party/xla/xla/service/host_offload_utils.cc index c48245151b6876..66f7565923a41b 100644 --- a/third_party/xla/xla/service/host_offload_utils.cc +++ b/third_party/xla/xla/service/host_offload_utils.cc @@ -31,6 +31,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_format.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -450,8 +451,8 @@ absl::Status MarkDynamicVariables(HloInstruction* while_loop) { } WhileLoopBackendConfig config; - TF_ASSIGN_OR_RETURN(config, - while_loop->backend_config()); + ASSIGN_OR_RETURN(config, + while_loop->backend_config()); config.clear_dynamic_variable_tuple_indices(); @@ -479,7 +480,7 @@ absl::Status MarkDynamicVariables(HloInstruction* while_loop) { config.add_dynamic_variable_tuple_indices(tuple_idx); } - TF_RETURN_IF_ERROR(while_loop->set_backend_config(config)); + RETURN_IF_ERROR(while_loop->set_backend_config(config)); return absl::OkStatus(); } diff --git a/third_party/xla/xla/service/instruction_fusion.cc b/third_party/xla/xla/service/instruction_fusion.cc index 3dc0721150bc63..ace65c88af68f4 100644 --- a/third_party/xla/xla/service/instruction_fusion.cc +++ b/third_party/xla/xla/service/instruction_fusion.cc @@ -39,6 +39,7 @@ limitations under the License. #include "absl/types/source_location.h" #endif // PLATFORM_GOOGLE #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/debug_options_flags.h" #include "xla/hlo/analysis/alias_info.h" #include "xla/hlo/analysis/hlo_operand_index.h" @@ -733,8 +734,8 @@ absl::StatusOr InstructionFusion::RunImpl( // Operand is now dead. Remove from queue. fusion_queue->RemoveInstruction(operand); // Remove from computation. - TF_RETURN_IF_ERROR(operand->SafelyDropAllControlDependencies()); - TF_RETURN_IF_ERROR(computation->RemoveInstruction(operand)); + RETURN_IF_ERROR(operand->SafelyDropAllControlDependencies()); + RETURN_IF_ERROR(computation->RemoveInstruction(operand)); } if (dump_fusion) { diff --git a/third_party/xla/xla/service/latency_hiding_scheduler.cc b/third_party/xla/xla/service/latency_hiding_scheduler.cc index 5664cc4f430900..4dd842a092ad2b 100644 --- a/third_party/xla/xla/service/latency_hiding_scheduler.cc +++ b/third_party/xla/xla/service/latency_hiding_scheduler.cc @@ -45,6 +45,7 @@ limitations under the License. #include "absl/strings/str_join.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "re2/re2.h" #include "xla/hlo/analysis/alias_info.h" #include "xla/hlo/analysis/hlo_alias_analysis.h" @@ -2454,7 +2455,7 @@ absl::Status DefaultSchedulerCore::ScheduleAnnotation( }()); VLOG(2) << "Current time: " << sched_state->current_time; // Find the best annotated node to schedule. - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( HloGraphNode * node, FindAndExtractBestAnnotatedNode( *sched_state, scheduling_instruction_crosses_overlap_limit_, @@ -2486,8 +2487,8 @@ absl::Status DefaultSchedulerCore::ScheduleAnnotation( sched_state->ready_set.pop_back(); // Schedule the node. - TF_ASSIGN_OR_RETURN(sched_state->current_time, - ScheduleNode(node, sched_state)); + ASSIGN_OR_RETURN(sched_state->current_time, + ScheduleNode(node, sched_state)); num_scheduled++; VLOG(2) << "Scheduled annotated node (" << num_scheduled << "/" << annotation_size << "): " << node->GetInstr().name(); @@ -3526,12 +3527,11 @@ absl::Status DefaultSchedulerCore::SchedulingStep( SchedulingState* sched_state) { // Get the first available node for scheduling that is the node that // satisfies our ready heuristic the best. - TF_ASSIGN_OR_RETURN(HloGraphNode * node, - FindAndExtractBestNodeAvailable( - *sched_state, /*should_skip_node=*/nullptr)); + ASSIGN_OR_RETURN(HloGraphNode * node, + FindAndExtractBestNodeAvailable( + *sched_state, /*should_skip_node=*/nullptr)); CHECK(node != nullptr); - TF_ASSIGN_OR_RETURN(sched_state->current_time, - ScheduleNode(node, sched_state)); + ASSIGN_OR_RETURN(sched_state->current_time, ScheduleNode(node, sched_state)); VLOG(2) << "Scheduled: " << node->GetInstr().name(); XLA_VLOG_LINES(5, node->ToString()); return absl::OkStatus(); @@ -3665,8 +3665,7 @@ absl::StatusOr DefaultSchedulerCore::TryScheduleOneAnnotationGroup( sched_state->ready_annotations.pop_back(); VLOG(2) << "------- BEGIN ANNOTATION: " << annotation << " -------"; sched_state->ongoing_annotation = annotation; - TF_RETURN_IF_ERROR( - ScheduleAnnotation(computation, annotation, sched_state)); + RETURN_IF_ERROR(ScheduleAnnotation(computation, annotation, sched_state)); VLOG(2) << "------- END ANNOTATION: " << annotation << " --------"; sched_state->ongoing_annotation = -1; return true; @@ -3690,7 +3689,7 @@ DefaultSchedulerCore::MakeSchedulingState(const HloComputation* computation) { absl::StatusOr> DefaultSchedulerCore::ScheduleComputation(const HloComputation* computation) { - TF_ASSIGN_OR_RETURN(auto sched_state, MakeSchedulingState(computation)); + ASSIGN_OR_RETURN(auto sched_state, MakeSchedulingState(computation)); // Activate the log filter for this computation. ScopedVlogFilter filter_guard(computation->name(), config_.log_computation_re); @@ -3746,7 +3745,7 @@ DefaultSchedulerCore::ScheduleComputation( .live_ids_at_bottom); if (graph_processing_hook_) { - TF_RETURN_IF_ERROR(graph_processing_hook_(&sched_state->sched_graph)); + RETURN_IF_ERROR(graph_processing_hook_(&sched_state->sched_graph)); } VLOG(5) << "Just built graph:"; @@ -4164,18 +4163,18 @@ LatencyHidingScheduler::ScheduleWithPreferences( HloModule* module, const std::vector& preferences, const HloComputation* computation, std::shared_ptr sched_state) { - TF_RETURN_IF_ERROR(scheduler_core_->ResetScheduler(module)); + RETURN_IF_ERROR(scheduler_core_->ResetScheduler(module)); auto set_preferences = [&](HloScheduleGraph* graph) -> absl::Status { VLOG(3) << "Setting scheduling preferences."; graph->SetPreferences(preferences); return absl::OkStatus(); }; - TF_RETURN_IF_ERROR(scheduler_core_->SetGraphProcessingHook(set_preferences)); + RETURN_IF_ERROR(scheduler_core_->SetGraphProcessingHook(set_preferences)); absl::Cleanup clear_hook = [&] { scheduler_core_->SetGraphProcessingHook(nullptr).IgnoreError(); }; - TF_ASSIGN_OR_RETURN(auto new_schedule, scheduler_core_->ScheduleComputation( - computation, sched_state)); + ASSIGN_OR_RETURN(auto new_schedule, scheduler_core_->ScheduleComputation( + computation, sched_state)); // Save the old schedule. auto old_schedule = std::vector( @@ -4247,10 +4246,10 @@ absl::StatusOr LatencyHidingScheduler::RunImpl( if (computations_to_schedule_.empty()) { return false; } - TF_RETURN_IF_ERROR(scheduler_core_->InitializeScheduler(module)); + RETURN_IF_ERROR(scheduler_core_->InitializeScheduler(module)); const auto& debug_options = module->config().debug_options(); if (debug_options.xla_dump_latency_hiding_schedule()) { - TF_RETURN_IF_ERROR(scheduler_core_->CaptureScheduleProto()); + RETURN_IF_ERROR(scheduler_core_->CaptureScheduleProto()); } if (VLOG_IS_ON(1)) { // Log the statistics before scheduling. We batch the per-computation @@ -4268,8 +4267,8 @@ absl::StatusOr LatencyHidingScheduler::RunImpl( } } for (HloComputation* computation : computations_to_schedule_) { - TF_ASSIGN_OR_RETURN(std::vector new_schedule, - scheduler_core_->ScheduleComputation(computation)); + ASSIGN_OR_RETURN(std::vector new_schedule, + scheduler_core_->ScheduleComputation(computation)); // Update target specific states that may include altering the // computation. scheduling_context_->GetAsyncTracker()->UpdateTargetDefinedStates( @@ -4297,11 +4296,11 @@ absl::StatusOr LatencyHidingScheduler::RunImpl( << " bytes, does not fit in initial limit: " << initial_memory_limit << ". Setting the new limit to " << static_cast(scheduler_core_->GetMemoryLimit() * 0.9); - TF_RETURN_IF_ERROR(scheduler_core_->InitializeScheduler(module)); + RETURN_IF_ERROR(scheduler_core_->InitializeScheduler(module)); scheduler_core_->SetMemoryLimit(scheduler_core_->GetMemoryLimit() * 0.9); for (HloComputation* computation : computations_to_schedule_) { - TF_ASSIGN_OR_RETURN(std::vector new_schedule, - scheduler_core_->ScheduleComputation(computation)); + ASSIGN_OR_RETURN(std::vector new_schedule, + scheduler_core_->ScheduleComputation(computation)); scheduling_context_->GetAsyncTracker()->UpdateTargetDefinedStates( computation); module->schedule().set_sequence(computation, @@ -4338,8 +4337,8 @@ absl::StatusOr LatencyHidingScheduler::RunImpl( } } if (debug_options.xla_dump_latency_hiding_schedule()) { - TF_ASSIGN_OR_RETURN(ScheduleProto proto, - scheduler_core_->GetCapturedScheduleProto()); + ASSIGN_OR_RETURN(ScheduleProto proto, + scheduler_core_->GetCapturedScheduleProto()); const std::string filename = absl::StrFormat("%s.schedule", module->name()); DumpProtobufToFile(proto, debug_options, filename); } diff --git a/third_party/xla/xla/service/latency_hiding_scheduler_test.cc b/third_party/xla/xla/service/latency_hiding_scheduler_test.cc index e2cf838df5b71c..73917c231f5d86 100644 --- a/third_party/xla/xla/service/latency_hiding_scheduler_test.cc +++ b/third_party/xla/xla/service/latency_hiding_scheduler_test.cc @@ -37,6 +37,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/debug_options_flags.h" #include "xla/hlo/analysis/alias_info.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -220,14 +221,14 @@ absl::StatusOr RunScheduler( /*convert_collective_permute=*/HloPredicateTrue}; bool value = false; if (!skip_async_collective_creator) { - TF_ASSIGN_OR_RETURN(value, - AsyncCollectiveCreator(std::move(config)).Run(module)); + ASSIGN_OR_RETURN(value, + AsyncCollectiveCreator(std::move(config)).Run(module)); } if (!legalizer_config) { legalizer_config = std::make_unique(); } - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( value, LegalizeSchedulingAnnotations(std::move(*legalizer_config)).Run(module)); HloCostAnalysis::ShapeSizeFunction shape_size_bytes = @@ -254,9 +255,9 @@ absl::StatusOr RunScheduler( &alias_info, shape_size_bytes); auto scheduler_core = std::make_unique(scheduling_context, sched_config); - TF_ASSIGN_OR_RETURN(value, LatencyHidingScheduler(scheduling_context, - std::move(scheduler_core)) - .Run(module)); + ASSIGN_OR_RETURN(value, LatencyHidingScheduler(scheduling_context, + std::move(scheduler_core)) + .Run(module)); return value; } @@ -295,11 +296,11 @@ class LatencyHidingSchedulerTest : public HloHardwareIndependentTestBase { /*convert_all_gather=*/HloPredicateTrue, /*convert_collective_broadcast=*/HloPredicateTrue, /*convert_collective_permute=*/HloPredicateTrue}; - TF_ASSIGN_OR_RETURN(bool value, - AsyncCollectiveCreator(std::move(config)).Run(module)); - TF_ASSIGN_OR_RETURN(value, LegalizeSchedulingAnnotations( - LegalizeSchedulingAnnotations::Config()) - .Run(module)); + ASSIGN_OR_RETURN(bool value, + AsyncCollectiveCreator(std::move(config)).Run(module)); + ASSIGN_OR_RETURN(value, LegalizeSchedulingAnnotations( + LegalizeSchedulingAnnotations::Config()) + .Run(module)); if (!async_tracker) { async_tracker = std::make_unique(sched_config); diff --git a/third_party/xla/xla/service/layout_assignment.cc b/third_party/xla/xla/service/layout_assignment.cc index 172e993686998d..97cecb74fbde22 100644 --- a/third_party/xla/xla/service/layout_assignment.cc +++ b/third_party/xla/xla/service/layout_assignment.cc @@ -36,6 +36,7 @@ limitations under the License. #include "absl/strings/str_join.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/analysis/tuple_points_to_analysis.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" @@ -273,7 +274,7 @@ absl::Status LayoutAssignment::SetBufferLayout(const Layout& layout, VLOG(3) << "SetBufferLayout : " << buffer << " : " << LayoutUtil::HumanString(layout) << " with priority " << priority << "; mandatory = " << mandatory << "; dfs = " << dfs << "\n"; - TF_RETURN_IF_ERROR(points_to_analysis_->VerifyBuffer(buffer)); + RETURN_IF_ERROR(points_to_analysis_->VerifyBuffer(buffer)); if (unconstrained_buffer_ids_.erase(buffer.id()) > 0) { VLOG(3) << "Erase buffer from unconstrained ids\n"; } @@ -284,8 +285,7 @@ absl::Status LayoutAssignment::SetBufferLayout(const Layout& layout, "array-shaped, has shape: %s", buffer.ToString(), ShapeUtil::HumanString(buffer.shape())); } - TF_RETURN_IF_ERROR( - LayoutUtil::ValidateLayoutForShape(layout, buffer.shape())); + RETURN_IF_ERROR(LayoutUtil::ValidateLayoutForShape(layout, buffer.shape())); auto& buffer_constraint = buffer_constraints_[&buffer]; if (buffer_constraint == nullptr) { @@ -323,9 +323,9 @@ absl::Status LayoutAssignment::SetBufferLayout(const Layout& layout, Shape shape(instruction->operand(operand_no)->shape()); *shape.mutable_layout() = layout; VLOG(3) << "operand_no=" << operand_no << ":" << shape.ToString(true); - TF_RETURN_IF_ERROR(LayoutUtil::ValidateLayoutInShape(shape)); - TF_RETURN_IF_ERROR(SetOperandLayout(shape, instruction, operand_no, - mandatory, dfs, priority)); + RETURN_IF_ERROR(LayoutUtil::ValidateLayoutInShape(shape)); + RETURN_IF_ERROR(SetOperandLayout(shape, instruction, operand_no, + mandatory, dfs, priority)); } } } @@ -406,7 +406,7 @@ absl::Status LayoutAssignment::SetArrayOperandLayout( TF_RET_CHECK(operand->shape().IsArray()); Shape shape(operand->shape()); *shape.mutable_layout() = layout; - TF_RETURN_IF_ERROR(LayoutUtil::ValidateLayoutInShape(shape)); + RETURN_IF_ERROR(LayoutUtil::ValidateLayoutInShape(shape)); return SetOperandLayout(shape, instruction, operand_no, mandatory, dfs, priority); } @@ -478,7 +478,7 @@ absl::Status LayoutAssignment::SetInstructionLayout( // Create a BufferLayoutConstraint for each array shape in the output of the // instruction. - TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( + RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( shape_with_layout, [this, dfs, instruction, mandatory, allow_alias, priority, subshape_index](const Shape& subshape, @@ -508,7 +508,7 @@ absl::Status LayoutAssignment::SetInstructionLayout( for (int i = 0; i < instruction->operand_count(); ++i) { if (instruction->operand(i)->shape().dimensions().size() == shape_with_layout.dimensions().size()) { - TF_RETURN_IF_ERROR(SetArrayOperandLayout( + RETURN_IF_ERROR(SetArrayOperandLayout( shape_with_layout.layout(), instruction, /*operand_no=*/i, /*mandatory=*/mandatory, /*dfs=*/dfs, priority)); } @@ -642,13 +642,13 @@ absl::Status PropagateParameterLayoutToUsers(const HloInstruction* instruction, auto tuple_index = user->tuple_index(); CHECK(shape.IsTuple()); auto elem_shape = shape.tuple_shapes(tuple_index); - TF_RETURN_IF_ERROR(constraints->SetInstructionLayout( + RETURN_IF_ERROR(constraints->SetInstructionLayout( elem_shape, user, /*mandatory=*/false, /*dfs=*/false, /*allow_alias=*/true)); - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( PropagateParameterLayoutToUsers(user, elem_shape, constraints)); } else { - TF_RETURN_IF_ERROR(constraints->SetOperandLayout( + RETURN_IF_ERROR(constraints->SetOperandLayout( shape, user, user->operand_index(instruction), /*mandatory=*/false, /*dfs=*/false)); } @@ -658,7 +658,7 @@ absl::Status PropagateParameterLayoutToUsers(const HloInstruction* instruction, absl::Status ResetMemorySpaceInLayout(ShapeLayout& mutable_shape_layout) { Shape shape = mutable_shape_layout.shape(); - TF_RETURN_IF_ERROR(ShapeUtil::ForEachMutableSubshapeWithStatus( + RETURN_IF_ERROR(ShapeUtil::ForEachMutableSubshapeWithStatus( &shape, [](Shape* subshape, const ShapeIndex& shape_index) { if (subshape->has_layout() && subshape->IsArray()) { subshape->mutable_layout()->set_memory_space( @@ -666,7 +666,7 @@ absl::Status ResetMemorySpaceInLayout(ShapeLayout& mutable_shape_layout) { } return absl::OkStatus(); })); - TF_RETURN_IF_ERROR(mutable_shape_layout.CopyLayoutFromShape(shape)); + RETURN_IF_ERROR(mutable_shape_layout.CopyLayoutFromShape(shape)); return absl::OkStatus(); } @@ -691,15 +691,15 @@ absl::Status LayoutAssignment::AddMandatoryConstraints( // instruction. // TODO(b/31425034): Change infeeds to be more like parameters, with // shapes in the ComputationLayout. - TF_RETURN_IF_ERROR(SetInstructionLayout(instruction->shape(), instruction, - /*mandatory=*/true, /*dfs=*/true, - /*allow_alias=*/false)); + RETURN_IF_ERROR(SetInstructionLayout(instruction->shape(), instruction, + /*mandatory=*/true, /*dfs=*/true, + /*allow_alias=*/false)); } else if (instruction->opcode() == HloOpcode::kOutfeed) { // Constrain the input to the Outfeed instruction to be the expected // layout of the Outfeed. - TF_RETURN_IF_ERROR(SetOperandLayout(instruction->outfeed_shape(), - instruction, 0, - /*mandatory=*/true, /*dfs=*/true)); + RETURN_IF_ERROR(SetOperandLayout(instruction->outfeed_shape(), + instruction, 0, + /*mandatory=*/true, /*dfs=*/true)); } else if (instruction->opcode() == HloOpcode::kParameter) { if (reverse_computation_order_ || (constraints->computation()->IsEntryComputation() && @@ -714,20 +714,19 @@ absl::Status LayoutAssignment::AddMandatoryConstraints( if (parameter_layout.AnyLayoutIsSet()) { // Clear out memory space in layout. Host offloader will do the // analysis later. - TF_RETURN_IF_ERROR(ResetMemorySpaceInLayout(parameter_layout)); + RETURN_IF_ERROR(ResetMemorySpaceInLayout(parameter_layout)); // Parameter layouts must match the respective layout in // ComputationLayout, if there is one. Shape param_shape = parameter_layout.shape(); - TF_RETURN_IF_ERROR(SetInstructionLayout(param_shape, instruction)); + RETURN_IF_ERROR(SetInstructionLayout(param_shape, instruction)); if (reverse_computation_order_) { - TF_RETURN_IF_ERROR(PropagateParameterLayoutToUsers( - instruction, param_shape, this)); + RETURN_IF_ERROR(PropagateParameterLayoutToUsers(instruction, + param_shape, this)); } } } } else if (IsLayoutConstrainedCollective(instruction)) { - TF_RETURN_IF_ERROR( - SetInstructionLayout(instruction->shape(), instruction)); + RETURN_IF_ERROR(SetInstructionLayout(instruction->shape(), instruction)); for (int64_t i = 0; i < instruction->operand_count(); ++i) { CHECK(instruction->shape().IsArray() || instruction->shape().IsTuple() && @@ -735,8 +734,8 @@ absl::Status LayoutAssignment::AddMandatoryConstraints( const Shape& shape = instruction->shape().IsTuple() ? instruction->shape().tuple_shapes(i) : instruction->shape(); - TF_RETURN_IF_ERROR(SetOperandLayout(shape, instruction, i, - /*mandatory=*/true, /*dfs=*/true)); + RETURN_IF_ERROR(SetOperandLayout(shape, instruction, i, + /*mandatory=*/true, /*dfs=*/true)); } } else if (instruction->IsCrossModuleAllReduce() && !instruction->GetModule()->config().use_spmd_partitioning()) { @@ -753,7 +752,7 @@ absl::Status LayoutAssignment::AddMandatoryConstraints( Shape new_buffer_shape = get_channel_constraints(instruction) ->LayoutShapeForChannel(buffer_shape, channel_id); - TF_RETURN_IF_ERROR(SetInstructionLayout(new_buffer_shape, instruction)); + RETURN_IF_ERROR(SetInstructionLayout(new_buffer_shape, instruction)); } } @@ -772,15 +771,15 @@ absl::Status LayoutAssignment::AddMandatoryConstraints( ->computation_layout(); auto result_shape = UnShardedShape( instruction, called_computation_layout.result_layout().shape(), -1); - TF_RETURN_IF_ERROR(SetInstructionLayout(result_shape, instruction)); + RETURN_IF_ERROR(SetInstructionLayout(result_shape, instruction)); TF_RET_CHECK(instruction->operand_count() == called_computation_layout.parameter_count()); for (int64_t i = 0; i < instruction->operand_count(); ++i) { auto operand_shape = UnShardedShape( instruction, called_computation_layout.parameter_layout(i).shape(), i); - TF_RETURN_IF_ERROR(SetOperandLayout(operand_shape, instruction, i, - /*mandatory=*/true, /*dfs=*/true)); + RETURN_IF_ERROR(SetOperandLayout(operand_shape, instruction, i, + /*mandatory=*/true, /*dfs=*/true)); } } else if (instruction->opcode() == HloOpcode::kWhile && computation_layouts_.find(instruction->while_body()) != @@ -835,9 +834,9 @@ absl::Status LayoutAssignment::AddMandatoryConstraints( // Constrain the output and the operand of the while instruction to match // the computations. - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( SetOperandLayout(body_layout.result_shape(), instruction, 0)); - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( SetInstructionLayout(body_layout.result_shape(), instruction)); } else if (instruction->opcode() == HloOpcode::kConditional && computation_layouts_.find(instruction->branch_computation(0)) != @@ -879,16 +878,16 @@ absl::Status LayoutAssignment::AddMandatoryConstraints( instruction->branch_computation(k), branch_computation_layout); } else { - TF_RETURN_IF_ERROR(SetOperandLayout( + RETURN_IF_ERROR(SetOperandLayout( branch_computation_layout.parameter_shape(0), instruction, k + 1, /*mandatory=*/true, /*dfs=*/true)); } } - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( SetOperandLayout(best_branch_computation_layout.parameter_shape(0), instruction, largest_branch + 1, /*mandatory=*/true, /*dfs=*/true)); - TF_RETURN_IF_ERROR(SetInstructionLayout( + RETURN_IF_ERROR(SetInstructionLayout( best_branch_computation_layout.result_shape(), instruction, /*mandatory=*/true, /*dfs=*/true, /*allow_alias=*/false)); } @@ -897,7 +896,7 @@ absl::Status LayoutAssignment::AddMandatoryConstraints( if (conditional_mismatch_.count(constraints->computation()) > 0) { VLOG(5) << "Setting mismatching conditional result:" << constraints->computation()->name() << "\n"; - TF_RETURN_IF_ERROR(constraints->SetResultLayout( + RETURN_IF_ERROR(constraints->SetResultLayout( this, FindOrDie(conditional_mismatch_, constraints->computation()) .result_layout() @@ -1121,8 +1120,8 @@ absl::StatusOr LayoutAssignment::CreateCopyWithNewLayout( } else { SetupCopiedInstruction(*instruction, gte, {i}); // Recurse to copy each element. - TF_ASSIGN_OR_RETURN(HloInstruction * element_copy, - CreateCopyWithNewLayout(target_shape, gte)); + ASSIGN_OR_RETURN(HloInstruction * element_copy, + CreateCopyWithNewLayout(target_shape, gte)); element_copies.push_back(element_copy); } } @@ -1131,7 +1130,7 @@ absl::StatusOr LayoutAssignment::CreateCopyWithNewLayout( HloInstruction::CreateTuple(element_copies)); SetupCopiedInstruction(*instruction, tuple_copy, {}); LayoutUtil::ClearLayout(tuple_copy->mutable_shape()); - TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes( + RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes( shape_with_layout, tuple_copy->mutable_shape())); return tuple_copy; } @@ -1142,8 +1141,8 @@ absl::StatusOr LayoutAssignment::CreateCopyWithNewLayout( RegisterAddedCopy(copy); SetupCopiedInstruction(*instruction, copy, {}); LayoutUtil::ClearLayout(copy->mutable_shape()); - TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes( - shape_with_layout, copy->mutable_shape())); + RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(shape_with_layout, + copy->mutable_shape())); return copy; } @@ -1181,10 +1180,10 @@ absl::Status LayoutAssignment::CopyOperandIfLayoutsDiffer( auto param = branch_comp->parameter_instruction(0); *param->mutable_shape() = operand->shape(); auto param_users = param->users(); - TF_ASSIGN_OR_RETURN(HloInstruction * param_copy, - CreateCopyWithNewLayout(operand_layout.shape(), param)); + ASSIGN_OR_RETURN(HloInstruction * param_copy, + CreateCopyWithNewLayout(operand_layout.shape(), param)); for (auto user : param_users) { - TF_RETURN_IF_ERROR(param->ReplaceUseWithDifferentShape(user, param_copy)); + RETURN_IF_ERROR(param->ReplaceUseWithDifferentShape(user, param_copy)); } VLOG(2) << "New copy of " << operand->ToString() << " is " << param_copy->ToString(); @@ -1205,8 +1204,8 @@ absl::Status LayoutAssignment::CopyOperandIfLayoutsDiffer( return absl::OkStatus(); } - TF_ASSIGN_OR_RETURN(HloInstruction * operand_copy, - CreateCopyWithNewLayout(operand_layout.shape(), operand)); + ASSIGN_OR_RETURN(HloInstruction * operand_copy, + CreateCopyWithNewLayout(operand_layout.shape(), operand)); VLOG(4) << "New copy of " << operand->ToString() << " is " << operand_copy->ToString(); @@ -1239,22 +1238,21 @@ void LayoutAssignment::SetupCopiedInstruction(const HloInstruction& instruction, absl::Status LayoutAssignment::CheckLayouts( HloModule* module, const absl::flat_hash_set& execution_threads) { - TF_ASSIGN_OR_RETURN(auto points_to_analysis, - TuplePointsToAnalysis::Run(module)); + ASSIGN_OR_RETURN(auto points_to_analysis, TuplePointsToAnalysis::Run(module)); for (auto* computation : module->MakeNonfusionComputations(execution_threads)) { for (auto* instruction : computation->instructions()) { // Verify every instruction has a layout and the layout is valid for the // shape. TF_RET_CHECK(LayoutUtil::HasLayout(instruction->shape())); - TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(instruction->shape())); + RETURN_IF_ERROR(ShapeUtil::ValidateShape(instruction->shape())); // Use points-to analysis to verify that every subshape element in the // output of the instruction matches the layout of the logical buffer // which could be the source of the subshape value. const PointsToSet& points_to_set = points_to_analysis->GetPointsToSet(instruction); - TF_RETURN_IF_ERROR(points_to_set.ForEachElementWithStatus( + RETURN_IF_ERROR(points_to_set.ForEachElementWithStatus( [&instruction]( ShapeIndex index, const PointsToSet::BufferList& buffers) -> absl::Status { @@ -1283,31 +1281,31 @@ absl::Status LayoutAssignment::CheckLayouts( // Verify instructions that have special layout constraints. switch (instruction->opcode()) { case HloOpcode::kCall: - TF_RETURN_IF_ERROR(CheckCallLayout( + RETURN_IF_ERROR(CheckCallLayout( instruction, FindOrDie(computation_layouts_, instruction->to_apply()) ->computation_layout())); break; case HloOpcode::kCustomCall: - TF_RETURN_IF_ERROR(CheckCustomCallLayout(instruction)); + RETURN_IF_ERROR(CheckCustomCallLayout(instruction)); break; case HloOpcode::kFusion: - TF_RETURN_IF_ERROR(CheckFusionLayout(instruction)); + RETURN_IF_ERROR(CheckFusionLayout(instruction)); break; case HloOpcode::kParameter: - TF_RETURN_IF_ERROR(CheckParameterLayout( + RETURN_IF_ERROR(CheckParameterLayout( instruction, FindOrDie(computation_layouts_, instruction->parent()) ->computation_layout())); break; case HloOpcode::kBroadcast: - TF_RETURN_IF_ERROR(CheckBroadcastLayout(instruction)); + RETURN_IF_ERROR(CheckBroadcastLayout(instruction)); break; case HloOpcode::kConstant: - TF_RETURN_IF_ERROR(CheckConstantLayout(instruction)); + RETURN_IF_ERROR(CheckConstantLayout(instruction)); break; case HloOpcode::kWhile: - TF_RETURN_IF_ERROR(CheckWhileLayout( + RETURN_IF_ERROR(CheckWhileLayout( instruction, FindOrDie(computation_layouts_, instruction->while_condition()) ->computation_layout(), @@ -1315,7 +1313,7 @@ absl::Status LayoutAssignment::CheckLayouts( ->computation_layout())); break; case HloOpcode::kOptimizationBarrier: - TF_RETURN_IF_ERROR(CheckOptimizationBarrierLayout(instruction)); + RETURN_IF_ERROR(CheckOptimizationBarrierLayout(instruction)); break; case HloOpcode::kConditional: { std::vector branch_computation_layouts; @@ -1326,7 +1324,7 @@ absl::Status LayoutAssignment::CheckLayouts( FindOrDie(computation_layouts_, branch_computation) ->computation_layout()); } - TF_RETURN_IF_ERROR(CheckConditionalLayout( + RETURN_IF_ERROR(CheckConditionalLayout( instruction, absl::MakeSpan(branch_computation_layouts))); break; } @@ -1606,17 +1604,17 @@ absl::Status LayoutAssignment::PropagateConstraints( << "; mandatory = " << layout_constraint->mandatory() << "\n"; if (auto* buffer_constraint = dynamic_cast(layout_constraint)) { - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( PropagateBufferConstraint(*buffer_constraint, constraints)); } else if (auto* operand_constraint = dynamic_cast( layout_constraint)) { - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( PropagateOperandConstraint(*operand_constraint, constraints)); } else if (auto* computation_constraint = dynamic_cast( layout_constraint)) { - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( PropagateResultConstraint(*computation_constraint, constraints)); } else { LOG(FATAL) << "Invalid constraint type: " << *layout_constraint; @@ -1674,9 +1672,9 @@ absl::Status LayoutAssignment::PropagateUseConstraintToDefs( if (buffer->shape().IsArray() && (buffer->instruction()->opcode() != HloOpcode::kReduce || !buffer->instruction()->shape().IsTuple())) { - TF_RETURN_IF_ERROR(SetBufferLayout(subshape.layout(), *buffer, - /*mandatory=*/false, - /*dfs=*/true, priority, user)); + RETURN_IF_ERROR(SetBufferLayout(subshape.layout(), *buffer, + /*mandatory=*/false, + /*dfs=*/true, priority, user)); } } } @@ -1720,7 +1718,7 @@ absl::Status LayoutAssignment::PropagateOperandConstraintToResultForCustomCall( shape_index)) { return absl::OkStatus(); } - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( const LogicalBuffer* buffer, points_to_analysis_->GetBufferDefinedAt(user, shape_index)); @@ -1744,14 +1742,14 @@ absl::Status LayoutAssignment::PropagateOperandConstraint( << "\n"; // Try to set the layout of the logical buffers in the given operand to match // the constrained layout. This avoids copies. - TF_RETURN_IF_ERROR(PropagateUseConstraintToDefs( + RETURN_IF_ERROR(PropagateUseConstraintToDefs( operand_constraint.shape_layout(), operand_constraint.operand(), constraints, operand_constraint.priority(), operand_constraint.instruction())); const HloInstruction* user = operand_constraint.instruction(); if (user->opcode() == HloOpcode::kCustomCall) { - TF_RETURN_IF_ERROR(PropagateOperandConstraintToResultForCustomCall( + RETURN_IF_ERROR(PropagateOperandConstraintToResultForCustomCall( user, operand_constraint)); } // CustomCall that can't change layout, such as TopK, are handled below. @@ -1775,12 +1773,12 @@ absl::Status LayoutAssignment::PropagateOperandConstraint( user->operand_count() == 1 ? ShapeIndex() : ShapeIndex({operand_constraint.operand_no()}); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( const LogicalBuffer* buffer, points_to_analysis_->GetBufferDefinedAt(user, shape_index)); - TF_RETURN_IF_ERROR( - SetBufferLayout(operand_constraint.shape_layout().layout(), *buffer, - /*mandatory=*/true, /*dfs=*/true)); + RETURN_IF_ERROR(SetBufferLayout(operand_constraint.shape_layout().layout(), + *buffer, + /*mandatory=*/true, /*dfs=*/true)); } if (InstructionCanChangeLayoutInstance(user) && !user->shape().IsArray() && @@ -1833,13 +1831,14 @@ absl::Status LayoutAssignment::PropagateOperandConstraint( if (operand_rank != sibling_rank) { continue; } - TF_RETURN_IF_ERROR(SetArrayOperandLayout( + RETURN_IF_ERROR(SetArrayOperandLayout( operand_constraint.shape_layout().layout(), user, operand_no, /*mandatory=*/true, /*dfs=*/true, operand_constraint.priority())); } - TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( + RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( user->shape(), - [&](const Shape& subshape, const ShapeIndex& shape_index) { + [&](const Shape& subshape, + const ShapeIndex& shape_index) -> absl::Status { if (subshape.IsTuple()) { return absl::OkStatus(); } @@ -1859,7 +1858,7 @@ absl::Status LayoutAssignment::PropagateOperandConstraint( } // TODO(b/67641796): Are there cases except fusion that use this code // path? - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( const LogicalBuffer* buffer, points_to_analysis_->GetBufferDefinedAt(user, shape_index)); // If we already have a constraint for the buffer it was assigned but @@ -1867,15 +1866,17 @@ absl::Status LayoutAssignment::PropagateOperandConstraint( // where one path is first evaluated in depth-first order (we're here) // and the other path is propagated later. We don't set the layout // here as it will always be overwritten later. - TF_RETURN_IF_ERROR(SetBufferLayout( + RETURN_IF_ERROR(SetBufferLayout( operand_constraint.shape_layout().layout(), *buffer, /*mandatory=*/true, /*dfs=*/true, operand_constraint.priority())); return absl::OkStatus(); })); return absl::OkStatus(); } - TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( - user->shape(), [&](const Shape& subshape, const ShapeIndex& shape_index) { + RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( + user->shape(), + [&](const Shape& subshape, + const ShapeIndex& shape_index) -> absl::Status { if (subshape.IsTuple()) { return absl::OkStatus(); } @@ -1886,14 +1887,14 @@ absl::Status LayoutAssignment::PropagateOperandConstraint( user, shape_index)) { return absl::OkStatus(); } - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( const LogicalBuffer* buffer, points_to_analysis_->GetBufferDefinedAt(user, shape_index)); std::unique_ptr layout = ChooseOutputLayoutFromOperandLayout( operand_constraint.shape_layout().layout(), user, operand_constraint.operand_no()); if (layout != nullptr) { - TF_RETURN_IF_ERROR(SetBufferLayout( + RETURN_IF_ERROR(SetBufferLayout( *layout, *buffer, /*mandatory=*/OperandLayoutAlwaysPropagateForward(user), /*dfs=*/InstructionShouldPropagateDepthFirst(*user), @@ -1917,7 +1918,7 @@ absl::Status LayoutAssignment::PropagateBufferConstraintToOperands( << buffer_constraint.ToString(); if (instruction->opcode() == HloOpcode::kAllReduce) { - TF_RETURN_IF_ERROR(SetArrayOperandLayout( + RETURN_IF_ERROR(SetArrayOperandLayout( buffer_constraint.layout(), instruction, instruction->operand_count() == 1 ? 0 : buffer.index()[0], /*mandatory=*/true, /*dfs=*/true, buffer_constraint.priority())); @@ -1934,7 +1935,7 @@ absl::Status LayoutAssignment::PropagateBufferConstraintToOperands( if (buffer.IsArray() && operand->shape().IsArray()) { if (operand->shape().dimensions().size() == LayoutUtil::MinorToMajor(buffer_constraint.layout()).size()) { - TF_RETURN_IF_ERROR(SetArrayOperandLayout( + RETURN_IF_ERROR(SetArrayOperandLayout( buffer_constraint.layout(), instruction, operand_no, /*mandatory=*/true, /*dfs=*/true, current_priority_)); } else if (instruction->opcode() == HloOpcode::kBitcastConvert) { @@ -1947,7 +1948,7 @@ absl::Status LayoutAssignment::PropagateBufferConstraintToOperands( ShapeUtil::AppendMinorDimension( operand->shape().dimensions().back(), &shape); } - TF_RETURN_IF_ERROR(SetArrayOperandLayout( + RETURN_IF_ERROR(SetArrayOperandLayout( shape.layout(), instruction, operand_no, /*mandatory=*/true, /*dfs=*/true, current_priority_)); } @@ -1956,7 +1957,7 @@ absl::Status LayoutAssignment::PropagateBufferConstraintToOperands( } else if (instruction->opcode() == HloOpcode::kBroadcast) { Layout layout = GetBroadcastLayoutFromOutput(buffer_constraint.layout(), instruction); - TF_RETURN_IF_ERROR(SetArrayOperandLayout( + RETURN_IF_ERROR(SetArrayOperandLayout( layout, instruction, operand_no, /*mandatory=*/true, /*dfs=*/ InstructionShouldPropagateDepthFirst(*instruction), @@ -1972,7 +1973,7 @@ absl::Status LayoutAssignment::PropagateBufferConstraintToOperands( ChooseOperandLayoutFromOutputLayout(buffer_constraint.layout(), instruction, operand_no); if (operand_layout != nullptr) { - TF_RETURN_IF_ERROR(SetArrayOperandLayout( + RETURN_IF_ERROR(SetArrayOperandLayout( *operand_layout, instruction, operand_no, /*mandatory=*/OutputLayoutAlwaysPropagateToOperands(instruction), /*dfs=*/ @@ -1992,7 +1993,7 @@ absl::Status LayoutAssignment::PropagateBufferConstraint( if (!buffer.IsArray()) { return absl::OkStatus(); } - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( PropagateBufferConstraintToOperands(buffer_constraint, constraints)); return PropagateBufferConstraintToUses(buffer_constraint, constraints); } @@ -2014,7 +2015,7 @@ absl::Status LayoutAssignment::PropagateBufferConstraintToUses( // Only add an operand constraint if the user does not forward the buffer // because this case is not handled is SetOperandLayout. if (!AnyOperandBufferForwarded(user, operand_no)) { - TF_RETURN_IF_ERROR(SetArrayOperandLayout( + RETURN_IF_ERROR(SetArrayOperandLayout( buffer_constraint.layout(), user, operand_no, /*mandatory=*/false, /*dfs=*/true, buffer_constraint.priority())); } @@ -2045,11 +2046,11 @@ absl::Status LayoutAssignment::PropagateBufferConstraintToUses( ShapeIndex used_index = buffer.index(); used_index.push_front(index); - TF_ASSIGN_OR_RETURN(auto buffer, points_to_analysis_->GetBufferDefinedAt( - inputs, used_index)); + ASSIGN_OR_RETURN(auto buffer, points_to_analysis_->GetBufferDefinedAt( + inputs, used_index)); - TF_RETURN_IF_ERROR(SetBufferLayout(buffer_constraint.layout(), *buffer, - /*mandatory=*/false)); + RETURN_IF_ERROR(SetBufferLayout(buffer_constraint.layout(), *buffer, + /*mandatory=*/false)); } } @@ -2064,7 +2065,7 @@ absl::Status LayoutAssignment::PropagateResultConstraint( // Clear out memory space in layout for entry computation root. Host offloader // will do the analysis later and add back the memory space for host outputs. if (constraints->computation()->IsEntryComputation()) { - TF_RETURN_IF_ERROR(ResetMemorySpaceInLayout(result_layout)); + RETURN_IF_ERROR(ResetMemorySpaceInLayout(result_layout)); } // Propagate the use constraint of the root instruction up to the logical @@ -2132,25 +2133,25 @@ absl::Status SetFusionLayouts(HloInstruction* fusion) { fusion->operand(fused_instruction->parameter_number()); DCHECK(ShapeUtil::Compatible(fusion_operand->shape(), fused_instruction->shape())); - TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes( + RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes( fusion_operand->shape(), fused_instruction->mutable_shape())); } else if (fused_instruction == fusion->fused_expression_root()) { // The layout of the root of the fused expression must match the fusion // instruction layout. DCHECK( ShapeUtil::Compatible(fusion->shape(), fused_instruction->shape())); - TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes( + RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes( fusion->shape(), fused_instruction->mutable_shape())); } else if (fused_instruction->opcode() == HloOpcode::kGetTupleElement) { // A GTE inherits its layout from its operand (which should ultimately be // a parameter). - TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes( + RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes( fused_instruction->operand(0)->shape().tuple_shapes( fused_instruction->tuple_index()), fused_instruction->mutable_shape())); } else if (fused_instruction->opcode() == HloOpcode::kConstant) { // Give constants the layout of their literal. - TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes( + RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes( fused_instruction->literal().shape(), fused_instruction->mutable_shape())); } else if (fused_instruction->opcode() == HloOpcode::kInfeed) { @@ -2209,7 +2210,7 @@ absl::Status LayoutAssignment::AssignLayouts(LayoutConstraints& constraints) { // Any remaining layouts in the output of the instruction must be // inferrable using points-to analysis. - TF_RETURN_IF_ERROR(ShapeUtil::ForEachMutableSubshapeWithStatus( + RETURN_IF_ERROR(ShapeUtil::ForEachMutableSubshapeWithStatus( instruction->mutable_shape(), [instruction, this](Shape* subshape, const ShapeIndex& index) { if (subshape->has_layout() || !subshape->IsArray()) { @@ -2217,8 +2218,8 @@ absl::Status LayoutAssignment::AssignLayouts(LayoutConstraints& constraints) { } // Set Layout of subshape to match layout of LogicalBuffer which // produces it. - TF_ASSIGN_OR_RETURN(*subshape->mutable_layout(), - InferArrayLayout(instruction, index)); + ASSIGN_OR_RETURN(*subshape->mutable_layout(), + InferArrayLayout(instruction, index)); return absl::OkStatus(); })); VLOG(3) << "Instruction layout:" << instruction->ToString(); @@ -2229,22 +2230,22 @@ absl::Status LayoutAssignment::AssignLayouts(LayoutConstraints& constraints) { const ShapeLayout* operand_layout = constraints.OperandLayout(instruction, operand_no); if (operand_layout != nullptr) { - TF_RETURN_IF_ERROR(CopyOperandIfLayoutsDiffer(*operand_layout, - instruction, operand_no)); + RETURN_IF_ERROR(CopyOperandIfLayoutsDiffer(*operand_layout, instruction, + operand_no)); } else { VLOG(2) << "operand " << operand_no << " has no constraint"; } } if (instruction->opcode() == HloOpcode::kFusion) { - TF_RETURN_IF_ERROR(SetFusionLayouts(instruction)); + RETURN_IF_ERROR(SetFusionLayouts(instruction)); } VLOG(3) << "Resulting instruction:" << instruction->ToString() << "\n"; // Execute extra verification step once the layout has been finalized. - TF_RETURN_IF_ERROR(Verify(instruction)); + RETURN_IF_ERROR(Verify(instruction)); // Shape must be valid. - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( ShapeUtil::ValidateShapeWithOptionalLayout(instruction->shape())); // Verify all layouts in the shape have been set. @@ -2257,14 +2258,14 @@ absl::Status LayoutAssignment::AssignLayouts(LayoutConstraints& constraints) { ShapeLayout result_layout = *constraints.ResultLayout(); // Clear out memory space in layout. Host offloader will do the // analysis later. - TF_RETURN_IF_ERROR(ResetMemorySpaceInLayout(result_layout)); + RETURN_IF_ERROR(ResetMemorySpaceInLayout(result_layout)); // Layout assignment at this point only does minor-to-major assignment so // tiling info should be ignored here for comparison. VLOG(5) << "Computation result layout needs root copying\n"; if (!result_layout.MatchesLayoutInShape( computation->root_instruction()->shape(), /*minor_to_major_only=*/true)) { - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( HloInstruction * new_root, CreateCopyWithNewLayout(result_layout.shape(), computation->root_instruction())); @@ -2361,7 +2362,7 @@ absl::Status LayoutAssignment::CalculateComputationLayout( constraints->computation()->MakeInstructionPostOrder()) { switch (instruction->opcode()) { case HloOpcode::kFusion: - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( SetCalleeLayout(instruction, instruction->operands(), mutable_computation_constraints( instruction->fused_instructions_computation()), @@ -2382,7 +2383,7 @@ absl::Status LayoutAssignment::CalculateComputationLayout( // If the branches don't yet have layouts, propagate existing layout // inside the branches. for (int i = 0; i < instruction->branch_count(); ++i) { - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( SetCalleeLayout(instruction, {instruction->operand(i + 1)}, mutable_computation_constraints( instruction->branch_computation(i)), @@ -2395,13 +2396,13 @@ absl::Status LayoutAssignment::CalculateComputationLayout( if (reverse_computation_order_) { VLOG(2) << "Populating while loop constraints inside loop body."; VLOG(2) << instruction->ToString(); - TF_RETURN_IF_ERROR(SetCalleeLayout( + RETURN_IF_ERROR(SetCalleeLayout( instruction, {instruction->operand(0)}, mutable_computation_constraints(instruction->while_body()), current_priority_ + 1)); VLOG(2) << "Populating while loop constraints inside loop condition."; VLOG(2) << instruction->ToString(); - TF_RETURN_IF_ERROR(SetCalleeLayout( + RETURN_IF_ERROR(SetCalleeLayout( instruction->operand(0), {instruction->operand(0)}, mutable_computation_constraints(instruction->while_condition()), current_priority_ + 1)); @@ -2414,7 +2415,7 @@ absl::Status LayoutAssignment::CalculateComputationLayout( // Reset the layout of the current computation from its body. if (current_priority_ == 0 || conditional_mismatch_.count(constraints->computation()) > 0) { - TF_RETURN_IF_ERROR(SetCalleeLayout( + RETURN_IF_ERROR(SetCalleeLayout( constraints->computation()->root_instruction(), constraints->computation()->parameter_instructions(), constraints, current_priority_ + kNumberOfPropagationRounds)); @@ -2476,10 +2477,10 @@ absl::Status LayoutAssignment::RunOnComputation( // Add constraints required for correctness on all backends (eg, entry // parameter layout constraints). - TF_RETURN_IF_ERROR(AddMandatoryConstraints(channel_constraints, constraints)); + RETURN_IF_ERROR(AddMandatoryConstraints(channel_constraints, constraints)); // Add any backend-specific constraints. - TF_RETURN_IF_ERROR(AddBackendConstraints(constraints)); + RETURN_IF_ERROR(AddBackendConstraints(constraints)); for (HloInstruction* instruction : constraints->computation()->MakeInstructionPostOrder()) { @@ -2489,22 +2490,21 @@ absl::Status LayoutAssignment::RunOnComputation( const HloCustomCallInstruction* custom_call = DynCast(instruction); - TF_RETURN_IF_ERROR(SetInstructionLayout(custom_call->shape(), custom_call, - /*mandatory=*/true, /*dfs=*/true, - /*allow_alias=*/true)); + RETURN_IF_ERROR(SetInstructionLayout(custom_call->shape(), custom_call, + /*mandatory=*/true, /*dfs=*/true, + /*allow_alias=*/true)); if (custom_call->IsCustomCall("LayoutConstraint")) { - TF_RETURN_IF_ERROR( - SetOperandLayout(custom_call->shape(), custom_call, 0)); + RETURN_IF_ERROR(SetOperandLayout(custom_call->shape(), custom_call, 0)); } else { for (int64_t i = 0; i < custom_call->operand_count(); ++i) { if (AnyOperandBufferForwarded(custom_call, i)) { TF_RET_CHECK(AllOperandBuffersForwarded(custom_call, i)) << "Partial alias of an operand is not supported"; } else { - TF_RETURN_IF_ERROR(SetOperandLayout( + RETURN_IF_ERROR(SetOperandLayout( custom_call->operand_shapes_with_layout()[i], custom_call, i)); if (instruction->operand(i)->opcode() == HloOpcode::kCopy) { - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( SetOperandLayout(custom_call->operand_shapes_with_layout()[i], custom_call->operand(i), 0)); } @@ -2514,7 +2514,7 @@ absl::Status LayoutAssignment::RunOnComputation( } // Propagates layouts from mandatory and backend constraints. - TF_RETURN_IF_ERROR(PropagateConstraints(constraints)); + RETURN_IF_ERROR(PropagateConstraints(constraints)); // Prior to applying default layouts, we take note of all HLO instructions // which lack a layout constraint. @@ -2544,22 +2544,21 @@ absl::Status LayoutAssignment::RunOnComputation( buffer.index()) .layout() : GetUnconstrainedLayout(buffer); - TF_RETURN_IF_ERROR(SetBufferLayout(new_layout, buffer, - /*mandatory=*/false)); + RETURN_IF_ERROR(SetBufferLayout(new_layout, buffer, + /*mandatory=*/false)); - TF_RETURN_IF_ERROR(PropagateConstraints(constraints)); + RETURN_IF_ERROR(PropagateConstraints(constraints)); // To verify progress has been made, check that the number of unconstrained // buffers has been reduced. CHECK_LT(unconstrained_buffer_ids_.size(), unconstrained_count); } - TF_RETURN_IF_ERROR(CalculateComputationLayout(constraints)); + RETURN_IF_ERROR(CalculateComputationLayout(constraints)); // Record the layouts assigned for any communication ops in // channel_constraints so that they are constrained for future modules. if (channel_constraints != nullptr) { - TF_RETURN_IF_ERROR( - ConstrainChannelLayouts(computation, channel_constraints)); + RETURN_IF_ERROR(ConstrainChannelLayouts(computation, channel_constraints)); } return absl::OkStatus(); @@ -2573,7 +2572,7 @@ absl::Status LayoutAssignment::ConstrainChannelLayouts( instruction->opcode() != HloOpcode::kAllReduceStart && instruction->opcode() != HloOpcode::kAllReduceDone) { // TODO: b/501070020 - Support asynchronous all-reduce. - TF_ASSIGN_OR_RETURN(auto op_layout, InferArrayLayout(instruction, {})); + ASSIGN_OR_RETURN(auto op_layout, InferArrayLayout(instruction, {})); VLOG(5) << "Constrain cross module all reduce: " << op_layout.ToString() << "\n"; channel_constraints->ConstrainChannel(instruction->channel_id().value(), @@ -2591,7 +2590,7 @@ absl::Status LayoutAssignment::PropagateComputationLayouts( for (int64_t i = 0; i < computed_computation_layout.parameter_count(); ++i) { ShapeLayout* param_layout = computation_layout->mutable_parameter_layout(i); bool needs_assign = false; - TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( + RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( param_layout->shape(), [&](const Shape& subshape, const ShapeIndex& shape_index) -> absl::Status { @@ -2639,7 +2638,7 @@ absl::StatusOr LayoutAssignment::RunImpl( HloModule* module, const absl::flat_hash_set& execution_threads) { VLOG(2) << "Running layout assignment on module " << module->name(); - TF_RETURN_IF_ERROR(Init(module)); + RETURN_IF_ERROR(Init(module)); std::vector> operands_to_copy; for (HloComputation* computation : module->computations(execution_threads)) { @@ -2676,7 +2675,7 @@ absl::StatusOr LayoutAssignment::RunImpl( } } for (const auto [instruction, operand_no] : operands_to_copy) { - TF_RETURN_IF_ERROR(AddCopyForOperand(instruction, operand_no)); + RETURN_IF_ERROR(AddCopyForOperand(instruction, operand_no)); } operands_to_copy.clear(); } @@ -2748,8 +2747,7 @@ absl::StatusOr LayoutAssignment::RunImpl( // So in the first pass, while allowing the layouts to flow to parameters and // root, we also fix up the eventually inconsistent ComputationLayout, which // will be then made mandatory by the second pass. - TF_ASSIGN_OR_RETURN(auto points_to_analysis, - TuplePointsToAnalysis::Run(module)); + ASSIGN_OR_RETURN(auto points_to_analysis, TuplePointsToAnalysis::Run(module)); points_to_analysis_ = std::move(points_to_analysis); auto computations_to_work = module->MakeNonfusionComputations(execution_threads); @@ -2773,11 +2771,11 @@ absl::StatusOr LayoutAssignment::RunImpl( for (int64_t i = 0; changed || i < kNumberOfPropagationRounds; ++i) { changed = false; VLOG(1) << "Running " << (i == 0 ? "un" : "") << "constrained pass"; - TF_RETURN_IF_ERROR(ClearPreviousPassSideEffects(module, execution_threads)); + RETURN_IF_ERROR(ClearPreviousPassSideEffects(module, execution_threads)); for (auto* computation : computations_to_work) { LayoutConstraints* constraints = mutable_computation_constraints(computation); - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( RunOnComputation(constraints, channel_layout_constraints_)); } current_priority_ += 1; @@ -2785,7 +2783,7 @@ absl::StatusOr LayoutAssignment::RunImpl( mutable_computation_constraints(module->entry_computation()) ->mutable_computation_constraint() ->mutable_computation_layout(); - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( module->input_output_alias_config().ForEachAliasWithStatus( [&](const ShapeIndex& output_index, const HloInputOutputAliasConfig::Alias& alias) { @@ -2805,10 +2803,10 @@ absl::StatusOr LayoutAssignment::RunImpl( return absl::OkStatus(); } auto* entry = module->entry_computation(); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( auto param_layout, InferArrayLayout(entry->parameter_instruction(param), index)); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( auto result_layout, InferArrayLayout(entry->root_instruction(), output_index)); if (param_layout.minor_to_major() == @@ -2845,13 +2843,13 @@ absl::StatusOr LayoutAssignment::RunImpl( // All logical buffers should have constraints at this point. All that // remains is assign the constraints to the buffers and infer layouts for // aliased buffers. - TF_RETURN_IF_ERROR(AssignLayouts(*constraints)); + RETURN_IF_ERROR(AssignLayouts(*constraints)); } - TF_RETURN_IF_ERROR(PropagateComputationLayouts(module->entry_computation(), - entry_computation_layout_)); + RETURN_IF_ERROR(PropagateComputationLayouts(module->entry_computation(), + entry_computation_layout_)); #ifndef NDEBUG - TF_RETURN_IF_ERROR(CheckLayouts(module, execution_threads)); + RETURN_IF_ERROR(CheckLayouts(module, execution_threads)); #endif // NDEBUG // All layouts are reset then reassigned by this pass. @@ -3037,15 +3035,15 @@ absl::Status LayoutAssignment::Init(HloModule* module) { for (HloInstruction* instruction : copies_to_remove) { VLOG(5) << "Removing added copy: " << instruction->ToString(); HloComputation* computation = instruction->parent(); - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( instruction->ReplaceAllUsesWith(instruction->mutable_operand(0))); - TF_RETURN_IF_ERROR(computation->RemoveInstruction(instruction)); + RETURN_IF_ERROR(computation->RemoveInstruction(instruction)); } added_copies_.clear(); TupleSimplifier tuple_simplifier; HloDCE dce; - TF_RETURN_IF_ERROR(tuple_simplifier.Run(module).status()); - TF_RETURN_IF_ERROR(dce.Run(module).status()); + RETURN_IF_ERROR(tuple_simplifier.Run(module).status()); + RETURN_IF_ERROR(dce.Run(module).status()); } return absl::OkStatus(); } @@ -3074,7 +3072,7 @@ absl::Status LayoutAssignment::AddCopyForOperand(HloInstruction* instruction, operand->shape(), HloOpcode::kCopy, operand)); SetupCopiedInstruction(*operand, copy, {}); LayoutUtil::ClearLayout(copy->mutable_shape()); - TF_RETURN_IF_ERROR(instruction->ReplaceOperandWith(operand_number, copy)); + RETURN_IF_ERROR(instruction->ReplaceOperandWith(operand_number, copy)); } return absl::OkStatus(); } diff --git a/third_party/xla/xla/service/layout_assignment_test.cc b/third_party/xla/xla/service/layout_assignment_test.cc index 0be67e482adb71..88ce73ace9674a 100644 --- a/third_party/xla/xla/service/layout_assignment_test.cc +++ b/third_party/xla/xla/service/layout_assignment_test.cc @@ -25,6 +25,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" @@ -537,9 +538,9 @@ class OperandsMustBeTheSameLayoutAssignment : public LayoutAssignment { operand->shape().dimensions().size()) { continue; } - TF_RETURN_IF_ERROR(SetArrayOperandLayout(buffer_constraint.layout(), - instruction, operand_no, - /*mandatory=*/true)); + RETURN_IF_ERROR(SetArrayOperandLayout(buffer_constraint.layout(), + instruction, operand_no, + /*mandatory=*/true)); } return PropagateBufferConstraintToUses(buffer_constraint, constraints); } diff --git a/third_party/xla/xla/service/layout_normalization.cc b/third_party/xla/xla/service/layout_normalization.cc index 420c28311b6145..cb891dcd3621c2 100644 --- a/third_party/xla/xla/service/layout_normalization.cc +++ b/third_party/xla/xla/service/layout_normalization.cc @@ -29,6 +29,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -105,7 +106,7 @@ class LayoutNormalizationVisitor : public DfsHloRewriteVisitor { *hlo->mutable_shape() = normalized_shape; HloInstruction* bc_to_orig = MaybeBitcast(hlo, shape); - TF_RETURN_IF_ERROR(hlo->ReplaceAllUsesWithDifferentShape(bc_to_orig)); + RETURN_IF_ERROR(hlo->ReplaceAllUsesWithDifferentShape(bc_to_orig)); MarkAsChanged(); return absl::OkStatus(); } @@ -117,8 +118,8 @@ class LayoutNormalizationVisitor : public DfsHloRewriteVisitor { const Shape& s = hlo->shape(); const Shape& operand_shape = operand->shape(); TF_RET_CHECK(s.layout() == operand_shape.layout()); - TF_ASSIGN_OR_RETURN(HloInstruction * normalized_input, - GetNormalizedInput(operand)); + ASSIGN_OR_RETURN(HloInstruction * normalized_input, + GetNormalizedInput(operand)); std::vector layout_as_permutation = ToTransposeDimensions(hlo->shape().layout()); @@ -127,17 +128,17 @@ class LayoutNormalizationVisitor : public DfsHloRewriteVisitor { return Permute(input, layout_as_permutation); }; - TF_ASSIGN_OR_RETURN(HloInstruction * normalized_slice, - MakeSliceHlo(normalized_input, - normalize_slice_attr(hlo->slice_starts()), - normalize_slice_attr(hlo->slice_limits()), - normalize_slice_attr(hlo->slice_strides()), - &hlo->metadata())); + ASSIGN_OR_RETURN(HloInstruction * normalized_slice, + MakeSliceHlo(normalized_input, + normalize_slice_attr(hlo->slice_starts()), + normalize_slice_attr(hlo->slice_limits()), + normalize_slice_attr(hlo->slice_strides()), + &hlo->metadata())); *normalized_slice->mutable_shape()->mutable_layout() = normalized_input->shape().layout(); SetVisited(*normalized_slice); HloInstruction* bc_to_orig = MaybeBitcast(normalized_slice, s); - TF_RETURN_IF_ERROR(ReplaceInstruction(hlo, bc_to_orig)); + RETURN_IF_ERROR(ReplaceInstruction(hlo, bc_to_orig)); return absl::OkStatus(); } @@ -164,7 +165,7 @@ class LayoutNormalizationVisitor : public DfsHloRewriteVisitor { SetVisited(*bc_to_normalized); auto bc_to_orig = MaybeBitcast(bc_to_normalized, shape); if (bc_to_orig != hlo) { - TF_RETURN_IF_ERROR(hlo->ReplaceUsesWith(users, bc_to_orig)); + RETURN_IF_ERROR(hlo->ReplaceUsesWith(users, bc_to_orig)); MarkAsChanged(); } return absl::OkStatus(); @@ -180,7 +181,7 @@ class LayoutNormalizationVisitor : public DfsHloRewriteVisitor { std::vector normalized_inputs; for (HloInstruction* operand : hlo->mutable_operands()) { - TF_ASSIGN_OR_RETURN(auto normalized_input, GetNormalizedInput(operand)); + ASSIGN_OR_RETURN(auto normalized_input, GetNormalizedInput(operand)); normalized_inputs.push_back(normalized_input); } auto normalized_shape = Normalize(s); @@ -192,7 +193,7 @@ class LayoutNormalizationVisitor : public DfsHloRewriteVisitor { normalized_shape, normalized_inputs, normalized_concat_dim)); SetVisited(*normalized_concat); auto bc_to_orig = MaybeBitcast(normalized_concat, hlo->shape()); - TF_RETURN_IF_ERROR(ReplaceInstruction(hlo, bc_to_orig)); + RETURN_IF_ERROR(ReplaceInstruction(hlo, bc_to_orig)); return absl::OkStatus(); } @@ -204,8 +205,8 @@ class LayoutNormalizationVisitor : public DfsHloRewriteVisitor { HloInstruction* operand = hlo->mutable_operand(0); TF_RET_CHECK(hlo->shape().layout() == operand->shape().layout()); - TF_ASSIGN_OR_RETURN(HloInstruction * normalized_input, - GetNormalizedInput(operand)); + ASSIGN_OR_RETURN(HloInstruction * normalized_input, + GetNormalizedInput(operand)); std::vector layout_as_permutation = ToTransposeDimensions(hlo->shape().layout()); @@ -221,16 +222,15 @@ class LayoutNormalizationVisitor : public DfsHloRewriteVisitor { *new_window.add_dimensions() = d; } - TF_ASSIGN_OR_RETURN( - HloInstruction * rw, - MakeReduceWindowHlo(normalized_input, hlo->mutable_operand(1), - new_window, hlo->called_computations()[0], - &hlo->metadata())); + ASSIGN_OR_RETURN(HloInstruction * rw, + MakeReduceWindowHlo( + normalized_input, hlo->mutable_operand(1), new_window, + hlo->called_computations()[0], &hlo->metadata())); normalization_->UpdateLayout(rw->mutable_shape()); SetVisited(*rw); HloInstruction* bc_to_orig = MaybeBitcast(rw, hlo->shape()); - TF_RETURN_IF_ERROR(ReplaceInstruction(hlo, bc_to_orig)); + RETURN_IF_ERROR(ReplaceInstruction(hlo, bc_to_orig)); return absl::OkStatus(); } @@ -247,7 +247,7 @@ class LayoutNormalizationVisitor : public DfsHloRewriteVisitor { VLOG(3) << "Input broadcast: " << hlo->ToString(); auto s = hlo->shape(); auto operand = hlo->mutable_operand(0); - TF_ASSIGN_OR_RETURN(auto normalized_input, GetNormalizedInput(operand)); + ASSIGN_OR_RETURN(auto normalized_input, GetNormalizedInput(operand)); auto normalized_shape = Normalize(s); std::vector layout_as_permutation = ToTransposeDimensions(operand->shape().layout()); @@ -267,7 +267,7 @@ class LayoutNormalizationVisitor : public DfsHloRewriteVisitor { SetVisited(*normalized_broadcast); VLOG(3) << "Generated broadcast: " << normalized_broadcast->ToString(); auto bc_to_orig = MaybeBitcast(normalized_broadcast, s); - TF_RETURN_IF_ERROR(ReplaceInstruction(hlo, bc_to_orig)); + RETURN_IF_ERROR(ReplaceInstruction(hlo, bc_to_orig)); return absl::OkStatus(); } @@ -284,7 +284,7 @@ class LayoutNormalizationVisitor : public DfsHloRewriteVisitor { SetVisited(*normalized_iota); VLOG(3) << "Generated iota: " << normalized_iota->ToString(); auto bc_to_orig = MaybeBitcast(normalized_iota, s); - TF_RETURN_IF_ERROR(ReplaceInstruction(hlo, bc_to_orig)); + RETURN_IF_ERROR(ReplaceInstruction(hlo, bc_to_orig)); return absl::OkStatus(); } @@ -310,15 +310,15 @@ class LayoutNormalizationVisitor : public DfsHloRewriteVisitor { if (ShapeUtil::LastDimIsMinorMost(shape_with_extra_dimension)) { const Shape original_shape = hlo->shape(); - TF_ASSIGN_OR_RETURN(HloInstruction * normalized_input, - GetNormalizedInput(operand)); + ASSIGN_OR_RETURN(HloInstruction * normalized_input, + GetNormalizedInput(operand)); HloInstruction* normalized = hlo->parent()->AddInstruction( HloInstruction::CreateBitcastConvert(Normalize(hlo->shape()), normalized_input), &hlo->metadata()); SetVisited(*normalized); HloInstruction* bitcast_back = MaybeBitcast(normalized, original_shape); - TF_RETURN_IF_ERROR(ReplaceInstruction(hlo, bitcast_back)); + RETURN_IF_ERROR(ReplaceInstruction(hlo, bitcast_back)); return absl::OkStatus(); } @@ -345,7 +345,7 @@ class LayoutNormalizationVisitor : public DfsHloRewriteVisitor { Layout::Equal().IgnoreElementSize()(s.layout(), operand_shape.layout())) << "Unexpected non-layout preserving elementwise unary: " << hlo->ToString(); - TF_ASSIGN_OR_RETURN(auto normalized_input, GetNormalizedInput(operand)); + ASSIGN_OR_RETURN(auto normalized_input, GetNormalizedInput(operand)); PrimitiveType to_element_type = s.element_type(); HloInstruction* new_unary; @@ -360,9 +360,8 @@ class LayoutNormalizationVisitor : public DfsHloRewriteVisitor { new_unary = MakeBitcastConvertToHlo(normalized_input, to_element_type, &hlo->metadata()); } else { - TF_ASSIGN_OR_RETURN( - new_unary, - MakeUnaryHlo(hlo->opcode(), normalized_input, &hlo->metadata())); + ASSIGN_OR_RETURN(new_unary, MakeUnaryHlo(hlo->opcode(), normalized_input, + &hlo->metadata())); } if (normalized_input != new_unary) { // SetVisited() should only be called for unvisited ops. @@ -370,7 +369,7 @@ class LayoutNormalizationVisitor : public DfsHloRewriteVisitor { SetVisited(*new_unary); } auto bc_to_orig = MaybeBitcast(new_unary, s); - TF_RETURN_IF_ERROR(ReplaceInstruction(hlo, bc_to_orig)); + RETURN_IF_ERROR(ReplaceInstruction(hlo, bc_to_orig)); return absl::OkStatus(); } @@ -398,21 +397,20 @@ class LayoutNormalizationVisitor : public DfsHloRewriteVisitor { layout_equal.IgnoreElementSize(); } TF_RET_CHECK(layout_equal(a->shape().layout(), s.layout())); - TF_ASSIGN_OR_RETURN(auto a0, GetNormalizedInput(a)); - TF_ASSIGN_OR_RETURN(auto b0, GetNormalizedInput(b)); + ASSIGN_OR_RETURN(auto a0, GetNormalizedInput(a)); + ASSIGN_OR_RETURN(auto b0, GetNormalizedInput(b)); HloInstruction* new_binary; if (hlo->opcode() == HloOpcode::kCompare) { - TF_ASSIGN_OR_RETURN(new_binary, - MakeCompareHlo(hlo->comparison_direction(), a0, b0, - &hlo->metadata())); + ASSIGN_OR_RETURN(new_binary, MakeCompareHlo(hlo->comparison_direction(), + a0, b0, &hlo->metadata())); } else { - TF_ASSIGN_OR_RETURN( - new_binary, MakeBinaryHlo(hlo->opcode(), a0, b0, &hlo->metadata())); + ASSIGN_OR_RETURN(new_binary, + MakeBinaryHlo(hlo->opcode(), a0, b0, &hlo->metadata())); } SetVisited(*new_binary); auto bc_to_orig = MaybeBitcast(new_binary, s); - TF_RETURN_IF_ERROR(ReplaceInstruction(hlo, bc_to_orig)); + RETURN_IF_ERROR(ReplaceInstruction(hlo, bc_to_orig)); return absl::OkStatus(); } @@ -429,13 +427,13 @@ class LayoutNormalizationVisitor : public DfsHloRewriteVisitor { auto s = hlo->shape(); auto operand = hlo->mutable_operand(0); TF_RET_CHECK(ShapeUtil::ReshapeIsBitcast(s, operand->shape())); - TF_ASSIGN_OR_RETURN(auto a0, GetNormalizedInput(operand)); + ASSIGN_OR_RETURN(auto a0, GetNormalizedInput(operand)); auto normalized_reshape_s = Normalize(s); - TF_ASSIGN_OR_RETURN(auto new_reshape, - MakeReshapeHlo(normalized_reshape_s, a0)); + ASSIGN_OR_RETURN(auto new_reshape, + MakeReshapeHlo(normalized_reshape_s, a0)); SetVisited(*new_reshape); auto bc_to_orig = MaybeBitcast(new_reshape, s); - TF_RETURN_IF_ERROR(ReplaceInstruction(hlo, bc_to_orig)); + RETURN_IF_ERROR(ReplaceInstruction(hlo, bc_to_orig)); return absl::OkStatus(); } @@ -451,7 +449,7 @@ class LayoutNormalizationVisitor : public DfsHloRewriteVisitor { return FailedPrecondition( "All scatter operands must have the same layout"); } - TF_ASSIGN_OR_RETURN(auto normalized_operand, GetNormalizedInput(operand)); + ASSIGN_OR_RETURN(auto normalized_operand, GetNormalizedInput(operand)); normalized_operands.push_back(normalized_operand); } std::vector normalized_updates; @@ -462,7 +460,7 @@ class LayoutNormalizationVisitor : public DfsHloRewriteVisitor { return FailedPrecondition( "All scatter updates must have the same layout"); } - TF_ASSIGN_OR_RETURN(auto normalized_update, GetNormalizedInput(operand)); + ASSIGN_OR_RETURN(auto normalized_update, GetNormalizedInput(operand)); normalized_updates.push_back(normalized_update); } @@ -481,8 +479,8 @@ class LayoutNormalizationVisitor : public DfsHloRewriteVisitor { "There should be just a single scatter dimension. Make sure to run " "ScatterSimplifier before LayoutNormalization"); } - TF_ASSIGN_OR_RETURN(auto normalized_indices, - GetNormalizedInput(scatter->scatter_indices())); + ASSIGN_OR_RETURN(auto normalized_indices, + GetNormalizedInput(scatter->scatter_indices())); // The scatter operands are normalized by applying a permutation such that // perm(layout) = standard layout -> inverse layout permutation is applied. @@ -556,7 +554,7 @@ class LayoutNormalizationVisitor : public DfsHloRewriteVisitor { scatter->indices_are_sorted(), scatter->unique_indices())); SetVisited(*normalized_scatter); auto bc_to_orig = MaybeBitcast(normalized_scatter, scatter->shape()); - TF_RETURN_IF_ERROR(ReplaceInstruction(scatter, bc_to_orig)); + RETURN_IF_ERROR(ReplaceInstruction(scatter, bc_to_orig)); return absl::OkStatus(); } @@ -580,7 +578,7 @@ class LayoutNormalizationVisitor : public DfsHloRewriteVisitor { auto s = hlo->shape(); auto operand = hlo->mutable_operand(0); auto operand_s = operand->shape(); - TF_ASSIGN_OR_RETURN(auto a0, GetNormalizedInput(operand)); + ASSIGN_OR_RETURN(auto a0, GetNormalizedInput(operand)); auto normalized_shape = Normalize(s); VLOG(3) << "Input transpose: " << hlo->ToString(); @@ -624,7 +622,7 @@ class LayoutNormalizationVisitor : public DfsHloRewriteVisitor { VLOG(3) << "Processing copy: " << hlo->ToString(); auto s = hlo->shape(); auto operand = hlo->mutable_operand(0); - TF_ASSIGN_OR_RETURN(auto a0, GetNormalizedInput(operand)); + ASSIGN_OR_RETURN(auto a0, GetNormalizedInput(operand)); auto s_normalized = Normalize(s); auto l0_perm = InversePermutation(ToTransposeDimensions(operand->shape().layout())); @@ -634,7 +632,7 @@ class LayoutNormalizationVisitor : public DfsHloRewriteVisitor { HloInstruction::CreateTranspose(s_normalized, a0, dimensions)); SetVisited(*t); auto bc_to_orig = MaybeBitcast(t, s); - TF_RETURN_IF_ERROR(ReplaceInstruction(hlo, bc_to_orig)); + RETURN_IF_ERROR(ReplaceInstruction(hlo, bc_to_orig)); return absl::OkStatus(); } @@ -642,7 +640,7 @@ class LayoutNormalizationVisitor : public DfsHloRewriteVisitor { absl::Status HandleReverse(HloInstruction* hlo) override { auto s = hlo->shape(); auto operand = hlo->mutable_operand(0); - TF_ASSIGN_OR_RETURN(auto a0, GetNormalizedInput(operand)); + ASSIGN_OR_RETURN(auto a0, GetNormalizedInput(operand)); std::vector layout_as_permutation = ToTransposeDimensions(hlo->shape().layout()); std::vector new_dimensions; @@ -656,7 +654,7 @@ class LayoutNormalizationVisitor : public DfsHloRewriteVisitor { HloInstruction::CreateReverse(a0->shape(), a0, new_dimensions)); SetVisited(*normalized_reverse); auto bc_to_orig = MaybeBitcast(normalized_reverse, s); - TF_RETURN_IF_ERROR(ReplaceInstruction(hlo, bc_to_orig)); + RETURN_IF_ERROR(ReplaceInstruction(hlo, bc_to_orig)); return absl::OkStatus(); } @@ -667,8 +665,8 @@ class LayoutNormalizationVisitor : public DfsHloRewriteVisitor { auto operand = hlo->mutable_operand(0); auto padded_by = hlo->mutable_operand(1); auto padded_config = hlo->padding_config(); - TF_ASSIGN_OR_RETURN(HloInstruction * normalized_input, - GetNormalizedInput(operand)); + ASSIGN_OR_RETURN(HloInstruction * normalized_input, + GetNormalizedInput(operand)); auto s_normalized = Normalize(s); auto layout_as_permutation = ToTransposeDimensions(s.layout()); @@ -689,7 +687,7 @@ class LayoutNormalizationVisitor : public DfsHloRewriteVisitor { s_normalized, normalized_input, padded_by, new_padding)); SetVisited(*padded_normalized); auto bc_to_orig = MaybeBitcast(padded_normalized, s); - TF_RETURN_IF_ERROR(ReplaceInstruction(hlo, bc_to_orig)); + RETURN_IF_ERROR(ReplaceInstruction(hlo, bc_to_orig)); return absl::OkStatus(); } @@ -701,12 +699,11 @@ class LayoutNormalizationVisitor : public DfsHloRewriteVisitor { custom_call->raw_backend_config_string(); absl::InlinedVector original_operands( custom_call->operands().begin(), custom_call->operands().end()); - TF_ASSIGN_OR_RETURN( - std::optional transformed_custom_call, - custom_call_transformer_(custom_call)); + ASSIGN_OR_RETURN(std::optional transformed_custom_call, + custom_call_transformer_(custom_call)); if (transformed_custom_call) { SetVisited(*(*transformed_custom_call)->operand(0)); - TF_RETURN_IF_ERROR(ReplaceInstruction(hlo, *transformed_custom_call)); + RETURN_IF_ERROR(ReplaceInstruction(hlo, *transformed_custom_call)); return absl::OkStatus(); } if (custom_call->custom_call_target() != original_target || @@ -721,8 +718,8 @@ class LayoutNormalizationVisitor : public DfsHloRewriteVisitor { absl::Status HandleConvolution(HloInstruction* hlo) override { std::vector normalized_operands; for (HloInstruction* operand : hlo->mutable_operands()) { - TF_ASSIGN_OR_RETURN(normalized_operands.emplace_back(), - GetNormalizedInput(operand)); + ASSIGN_OR_RETURN(normalized_operands.emplace_back(), + GetNormalizedInput(operand)); } Shape normalized_shape = Normalize(hlo->shape()); @@ -770,7 +767,7 @@ class LayoutNormalizationVisitor : public DfsHloRewriteVisitor { hlo->CloneWithNewOperands(normalized_shape, normalized_operands)); normalized_hlo->set_convolution_dimension_numbers(new_dnums); - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( ReplaceInstruction(hlo, MaybeBitcast(normalized_hlo, hlo->shape()))); return absl::OkStatus(); } @@ -790,8 +787,8 @@ class LayoutNormalizationVisitor : public DfsHloRewriteVisitor { const Shape& operand_shape = operand->shape(); TF_RET_CHECK(s.layout() == operand_shape.layout()); - TF_ASSIGN_OR_RETURN(HloInstruction * normalized_input, - GetNormalizedInput(operand)); + ASSIGN_OR_RETURN(HloInstruction * normalized_input, + GetNormalizedInput(operand)); Shape normalized = Normalize(operand_shape); @@ -803,7 +800,7 @@ class LayoutNormalizationVisitor : public DfsHloRewriteVisitor { auto normalize_slice_attr = [&](absl::Span input) { return Permute(input, layout_as_permutation); }; - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( HloInstruction * normalized_dynamic_slice, MakeDynamicSliceHlo(normalized_input, new_start_indices, normalize_slice_attr(hlo->dynamic_slice_sizes()), @@ -812,7 +809,7 @@ class LayoutNormalizationVisitor : public DfsHloRewriteVisitor { normalized_input->shape().layout(); SetVisited(*normalized_dynamic_slice); HloInstruction* bc_to_orig = MaybeBitcast(normalized_dynamic_slice, s); - TF_RETURN_IF_ERROR(ReplaceInstruction(hlo, bc_to_orig)); + RETURN_IF_ERROR(ReplaceInstruction(hlo, bc_to_orig)); return absl::OkStatus(); } @@ -825,14 +822,12 @@ class LayoutNormalizationVisitor : public DfsHloRewriteVisitor { std::vector layout_as_permutation = ToTransposeDimensions(hlo->shape().layout()); - TF_ASSIGN_OR_RETURN(HloInstruction * new_operand, - GetNormalizedInput(operand)); - TF_ASSIGN_OR_RETURN(HloInstruction * new_update, - GetNormalizedInput(update)); + ASSIGN_OR_RETURN(HloInstruction * new_operand, GetNormalizedInput(operand)); + ASSIGN_OR_RETURN(HloInstruction * new_update, GetNormalizedInput(update)); std::vector new_start_indices = GetNewStartIdxs(hlo, /*param_offset=*/2, layout_as_permutation); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( HloInstruction * new_dus, MakeDynamicUpdateSliceHlo(new_operand, new_update, new_start_indices, &hlo->metadata())); @@ -840,7 +835,7 @@ class LayoutNormalizationVisitor : public DfsHloRewriteVisitor { SetVisited(*new_dus); HloInstruction* bc_to_orig = MaybeBitcast(new_dus, s); - TF_RETURN_IF_ERROR(ReplaceInstruction(hlo, bc_to_orig)); + RETURN_IF_ERROR(ReplaceInstruction(hlo, bc_to_orig)); return absl::OkStatus(); } @@ -867,16 +862,16 @@ class LayoutNormalizationVisitor : public DfsHloRewriteVisitor { TF_RET_CHECK(false); } - TF_ASSIGN_OR_RETURN(HloInstruction * normalized_arg0, - GetNormalizedInput(arg0)); - TF_ASSIGN_OR_RETURN(HloInstruction * normalized_arg1, - GetNormalizedInput(arg1)); - TF_ASSIGN_OR_RETURN(HloInstruction * normalized_arg2, - GetNormalizedInput(arg2)); + ASSIGN_OR_RETURN(HloInstruction * normalized_arg0, + GetNormalizedInput(arg0)); + ASSIGN_OR_RETURN(HloInstruction * normalized_arg1, + GetNormalizedInput(arg1)); + ASSIGN_OR_RETURN(HloInstruction * normalized_arg2, + GetNormalizedInput(arg2)); - TF_ASSIGN_OR_RETURN(Shape new_shape, ShapeInference::InferTernaryOpShape( - opcode, normalized_arg0, - normalized_arg1, normalized_arg2)); + ASSIGN_OR_RETURN(Shape new_shape, ShapeInference::InferTernaryOpShape( + opcode, normalized_arg0, + normalized_arg1, normalized_arg2)); HloInstruction* normalized = hlo->parent()->AddInstruction( HloInstruction::CreateTernary(new_shape, opcode, normalized_arg0, normalized_arg1, normalized_arg2)); @@ -884,7 +879,7 @@ class LayoutNormalizationVisitor : public DfsHloRewriteVisitor { SetVisited(*normalized); HloInstruction* bc_to_orig = MaybeBitcast(normalized, s); - TF_RETURN_IF_ERROR(ReplaceInstruction(hlo, bc_to_orig)); + RETURN_IF_ERROR(ReplaceInstruction(hlo, bc_to_orig)); return absl::OkStatus(); } diff --git a/third_party/xla/xla/service/legalize_scheduling_annotations.cc b/third_party/xla/xla/service/legalize_scheduling_annotations.cc index 5986394d05d3b2..5a3d0cc3eaa278 100644 --- a/third_party/xla/xla/service/legalize_scheduling_annotations.cc +++ b/third_party/xla/xla/service/legalize_scheduling_annotations.cc @@ -30,6 +30,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/analysis/hlo_reachability.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -104,8 +105,8 @@ absl::Status AttachAnnotation( const absl::flat_hash_set& instructions, bool dry_run = false) { for (HloInstruction* instr : instructions) { - TF_ASSIGN_OR_RETURN(std::optional instr_annotation, - GetSchedulingAnnotation(instr)); + ASSIGN_OR_RETURN(std::optional instr_annotation, + GetSchedulingAnnotation(instr)); if (instr_annotation) { return absl::InternalError("Trying to propagate scheduling annotation " + annotation.ToString() + " to " + @@ -116,7 +117,7 @@ absl::Status AttachAnnotation( LOG(INFO) << "Propagating annotation " << annotation.ToString() << " to " << instr->name(); if (!dry_run) { - TF_RETURN_IF_ERROR(SetSchedulingAnnotation(instr, annotation)); + RETURN_IF_ERROR(SetSchedulingAnnotation(instr, annotation)); } } return absl::OkStatus(); @@ -224,8 +225,8 @@ absl::StatusOr HaulAnnotationToFusionInstruction( changed = true; std::optional seen_annotation; for (HloInstruction* instr : computation->instructions()) { - TF_ASSIGN_OR_RETURN(std::optional annotation, - GetSchedulingAnnotation(instr)); + ASSIGN_OR_RETURN(std::optional annotation, + GetSchedulingAnnotation(instr)); if (!annotation) { continue; } @@ -245,8 +246,8 @@ absl::StatusOr HaulAnnotationToFusionInstruction( if (!seen_annotation) { continue; } - TF_RETURN_IF_ERROR(SetSchedulingAnnotation(computation->FusionInstruction(), - seen_annotation->ToString())); + RETURN_IF_ERROR(SetSchedulingAnnotation(computation->FusionInstruction(), + seen_annotation->ToString())); } return changed; } @@ -255,8 +256,8 @@ absl::StatusOr RemoveLoopIterationAnnotation(HloModule* module) { bool changed = false; for (HloComputation* computation : module->MakeNonfusionComputations()) { for (HloInstruction* instr : computation->instructions()) { - TF_ASSIGN_OR_RETURN(bool removed, - RemoveSchedulingAnnotationIterationId(instr)); + ASSIGN_OR_RETURN(bool removed, + RemoveSchedulingAnnotationIterationId(instr)); changed |= removed; } } @@ -585,8 +586,8 @@ absl::StatusOr LegalizeSchedulingAnnotations::RunImpl( for (HloComputation* computation : module->MakeNonfusionComputations(execution_threads)) { for (HloInstruction* instr : computation->instructions()) { - TF_ASSIGN_OR_RETURN(std::optional annotation, - GetSchedulingAnnotation(instr)); + ASSIGN_OR_RETURN(std::optional annotation, + GetSchedulingAnnotation(instr)); if (!annotation) { continue; } @@ -594,20 +595,20 @@ absl::StatusOr LegalizeSchedulingAnnotations::RunImpl( annotation_to_instruction[*annotation][computation].push_back(instr); } } - TF_RETURN_IF_ERROR(CheckGapBetweenAnnotatedInstructions( + RETURN_IF_ERROR(CheckGapBetweenAnnotatedInstructions( annotation_to_instruction, instruction_to_annotation)); return false; } // Run verification if requested. if (config_.run_verification) { - TF_RETURN_IF_ERROR(Verify(module)); + RETURN_IF_ERROR(Verify(module)); } bool changed = false; // Remove loop iteration annotation if requested. if (config_.remove_loop_iteration_annotation_only) { - TF_ASSIGN_OR_RETURN(bool removed, RemoveLoopIterationAnnotation(module)); + ASSIGN_OR_RETURN(bool removed, RemoveLoopIterationAnnotation(module)); changed |= removed; return changed; } @@ -626,8 +627,8 @@ absl::StatusOr LegalizeSchedulingAnnotations::RunImpl( for (HloComputation* computation : module->MakeNonfusionComputations(execution_threads)) { for (HloInstruction* instr : computation->instructions()) { - TF_ASSIGN_OR_RETURN(std::optional annotation, - GetSchedulingAnnotation(instr)); + ASSIGN_OR_RETURN(std::optional annotation, + GetSchedulingAnnotation(instr)); if (!annotation) { continue; } @@ -639,7 +640,7 @@ absl::StatusOr LegalizeSchedulingAnnotations::RunImpl( // Move the annotation from inside fusion computation to the caller // instruction if the caller doesn't have an annotation. Return an error if // there are some fused instructions with different annotations. - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( bool haul_annotation_to_top_level, HaulAnnotationToFusionInstruction( module, execution_threads, annotation_to_instruction, @@ -688,7 +689,7 @@ absl::StatusOr LegalizeSchedulingAnnotations::RunImpl( changed |= result.value(); } } else { - TF_RETURN_IF_ERROR(CheckGapBetweenAnnotatedInstructions( + RETURN_IF_ERROR(CheckGapBetweenAnnotatedInstructions( annotation_to_instruction, instruction_to_annotation)); } @@ -702,15 +703,15 @@ absl::StatusOr CheckNoDataDependencyInSchedulingAnnotations::RunImpl( module->MakeNonfusionComputations(execution_threads)) { for (HloInstruction* instr : computation->instructions()) { if (HasSchedulingAnnotation(instr)) { - TF_ASSIGN_OR_RETURN(std::optional annotation, - GetSchedulingAnnotation(instr)); + ASSIGN_OR_RETURN(std::optional annotation, + GetSchedulingAnnotation(instr)); if (!annotation) { continue; } for (HloInstruction* operand : instr->operands()) { if (HasSchedulingAnnotation(operand)) { - TF_ASSIGN_OR_RETURN(std::optional operand_annotation, - GetSchedulingAnnotation(operand)); + ASSIGN_OR_RETURN(std::optional operand_annotation, + GetSchedulingAnnotation(operand)); if (!operand_annotation) { continue; } diff --git a/third_party/xla/xla/service/llvm_compiler.cc b/third_party/xla/xla/service/llvm_compiler.cc index 8f92011d2bc848..de8bc325677e70 100644 --- a/third_party/xla/xla/service/llvm_compiler.cc +++ b/third_party/xla/xla/service/llvm_compiler.cc @@ -21,6 +21,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/str_format.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/service/executable.h" #include "xla/service/stream_pool.h" @@ -54,11 +55,10 @@ absl::StatusOr>> LLVMCompiler::Compile( return absl::StrFormat("XlaCompile:#module=%s,program_id=%d#", hlo_module->name(), hlo_module->unique_id()); }}; - TF_ASSIGN_OR_RETURN(hlo_module, RunHloPasses(std::move(hlo_module), - stream_execs[0], options)); - TF_ASSIGN_OR_RETURN( - std::unique_ptr executable, - RunBackend(std::move(hlo_module), stream_execs[0], options)); + ASSIGN_OR_RETURN(hlo_module, RunHloPasses(std::move(hlo_module), + stream_execs[0], options)); + ASSIGN_OR_RETURN(std::unique_ptr executable, + RunBackend(std::move(hlo_module), stream_execs[0], options)); result.push_back(std::move(executable)); return std::move(result); diff --git a/third_party/xla/xla/service/llvm_ir/BUILD b/third_party/xla/xla/service/llvm_ir/BUILD index b6d917c22f84cd..22f7a29a052392 100644 --- a/third_party/xla/xla/service/llvm_ir/BUILD +++ b/third_party/xla/xla/service/llvm_ir/BUILD @@ -175,6 +175,7 @@ cc_library( "//xla:shape_util", "//xla/tsl/platform:errors", "//xla/tsl/platform:logging", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", @@ -199,6 +200,7 @@ cc_library( "//xla/hlo/ir:hlo", "//xla/service:elemental_ir_emitter", "//xla/tsl/platform:logging", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/log", @@ -223,6 +225,7 @@ cc_library( "//xla/service:buffer_assignment", "//xla/service/cpu:backend_config_proto_cc", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", @@ -263,6 +266,7 @@ cc_library( ":llvm_util", "//xla/service:hlo_module_config", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", diff --git a/third_party/xla/xla/service/llvm_ir/dynamic_update_slice_util.cc b/third_party/xla/xla/service/llvm_ir/dynamic_update_slice_util.cc index 348b78e47cd016..e760d52e347327 100644 --- a/third_party/xla/xla/service/llvm_ir/dynamic_update_slice_util.cc +++ b/third_party/xla/xla/service/llvm_ir/dynamic_update_slice_util.cc @@ -27,6 +27,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "llvm/IR/Constants.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Instructions.h" @@ -131,7 +132,7 @@ static absl::Status EmitDynamicUpdateSliceInPlaceImpl( const int64_t rank = output_shape.dimensions().size(); std::vector start_multi_index(rank); for (int64_t i = 0; i < rank; ++i) { - TF_ASSIGN_OR_RETURN(start_multi_index[i], start_indices_generator(i)); + ASSIGN_OR_RETURN(start_multi_index[i], start_indices_generator(i)); llvm::Value* output_dim_size = llvm::ConstantInt::get( start_multi_index[i]->getType(), output_shape.dimensions(i)); llvm::Value* update_dim_size = llvm::ConstantInt::get( @@ -172,8 +173,8 @@ static absl::Status EmitDynamicUpdateSliceInPlaceImpl( // Do output[output_index] = update[update_index]. IrArray::Index output_index(output_multi_index, output_shape, b->getInt64Ty()); - TF_ASSIGN_OR_RETURN(llvm::Value * update_data, - update_array_generator(update_index)); + ASSIGN_OR_RETURN(llvm::Value * update_data, + update_array_generator(update_index)); output_array.EmitWriteArrayElement(output_index, update_data, b); return absl::OkStatus(); }; @@ -240,23 +241,23 @@ static absl::Status EmitFusedDynamicUpdateSliceInPlaceImpl( // through the chain of ops that gives us the update operand and use the // layout of its source buffer(s). But this is no worse than we do with // fusion elsewhere.) - TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes( + RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes( dynamic_update_slice->shape(), &update_shape)); // Create element generators for update and start_indices. - TF_ASSIGN_OR_RETURN(ElementGenerator update_array_generator, - fused_emitter->GetGenerator(*update)); + ASSIGN_OR_RETURN(ElementGenerator update_array_generator, + fused_emitter->GetGenerator(*update)); IndexGenerator start_indices_generator = [&](int64_t index) -> absl::StatusOr { - TF_ASSIGN_OR_RETURN(ElementGenerator element_generator, - fused_emitter->GetGenerator( - *dynamic_update_slice->operand(2 + index))); + ASSIGN_OR_RETURN(ElementGenerator element_generator, + fused_emitter->GetGenerator( + *dynamic_update_slice->operand(2 + index))); return element_generator(IrArray::Index(b->getInt64Ty())); }; bool is_signed = ShapeUtil::ElementIsSigned(start_indices->shape()); - TF_RETURN_IF_ERROR(EmitDynamicUpdateSliceInPlaceImpl( + RETURN_IF_ERROR(EmitDynamicUpdateSliceInPlaceImpl( update_shape, start_indices_generator, is_signed, update_array_generator, fusion_output_array, IrName(dynamic_update_slice), b)); diff --git a/third_party/xla/xla/service/llvm_ir/fused_ir_emitter.cc b/third_party/xla/xla/service/llvm_ir/fused_ir_emitter.cc index 8a2c8de7a820e3..15fc252b5169e2 100644 --- a/third_party/xla/xla/service/llvm_ir/fused_ir_emitter.cc +++ b/third_party/xla/xla/service/llvm_ir/fused_ir_emitter.cc @@ -22,6 +22,7 @@ limitations under the License. #include "absl/log/log.h" #include "absl/status/statusor.h" +#include "xla/tsl/platform/status_macros.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Constants.h" #include "llvm/IR/IRBuilder.h" @@ -84,7 +85,7 @@ absl::StatusOr FusedIrEmitter::DefaultAction( } } - TF_ASSIGN_OR_RETURN(value, generator(index)); + ASSIGN_OR_RETURN(value, generator(index)); value_cache_[std::move(key)] = value; return value; }); @@ -143,8 +144,8 @@ absl::StatusOr FusedIrEmitter::HandleTuple( used_index = used_index.SourceIndexOfBitcast( tuple.operand(0)->shape(), tuple.operand(i)->shape(), b); } - TF_ASSIGN_OR_RETURN(llvm::Value * value, - indexed_generators_.at(tuple.operand(i))(used_index)); + ASSIGN_OR_RETURN(llvm::Value * value, + indexed_generators_.at(tuple.operand(i))(used_index)); ret = b->CreateInsertValue(ret, value, i); } return ret; @@ -178,7 +179,7 @@ absl::StatusOr FusedIrEmitter::GetGenerator( if (indexed_generator != nullptr) continue; stack.insert(stack.end(), instr.operands().begin(), instr.operands().end()); - TF_ASSIGN_OR_RETURN(indexed_generator, CreateGenerator(instr)); + ASSIGN_OR_RETURN(indexed_generator, CreateGenerator(instr)); } return indexed_generators_[&instruction]; } diff --git a/third_party/xla/xla/service/llvm_ir/kernel_support_library.cc b/third_party/xla/xla/service/llvm_ir/kernel_support_library.cc index 51d8ef32a247ca..6e039abb215870 100644 --- a/third_party/xla/xla/service/llvm_ir/kernel_support_library.cc +++ b/third_party/xla/xla/service/llvm_ir/kernel_support_library.cc @@ -26,6 +26,7 @@ limitations under the License. #include "absl/log/log.h" #include "absl/status/status.h" #include "absl/strings/string_view.h" +#include "xla/tsl/platform/status_macros.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/Function.h" @@ -47,7 +48,7 @@ absl::Status KernelSupportLibrary::ForWithStatus( llvm::Value* step, const std::function& for_body_generator) { return IfWithStatus(b_->CreateICmpSLT(start, end), [&]() -> absl::Status { - TF_RETURN_IF_ERROR(for_body_generator(start, /*is_first_iteration=*/true)); + RETURN_IF_ERROR(for_body_generator(start, /*is_first_iteration=*/true)); return ForWithStatus( name, b_->CreateAdd(start, step), end, step, [&](llvm::Value* iv) { return for_body_generator(iv, false); }); @@ -63,7 +64,7 @@ absl::Status KernelSupportLibrary::ForWithStatus( /*unroll_mode=*/unroll_mode_, /*prevent_vectorization=*/prevent_vectorization_); b_->SetInsertPoint(&loop->GetBodyBasicBlock()->back()); - TF_RETURN_IF_ERROR(for_body_generator(loop->GetIndVarValue())); + RETURN_IF_ERROR(for_body_generator(loop->GetIndVarValue())); llvm_ir::SetToLastInsertPoint(loop->GetExitBasicBlock(), b_); return absl::OkStatus(); } @@ -76,10 +77,10 @@ absl::Status KernelSupportLibrary::IfWithStatus( llvm_ir::EmitIfThenElse(condition, name, b_, /*emit_else=*/false_block_generator != nullptr); b_->SetInsertPoint(&if_data.true_block->back()); - TF_RETURN_IF_ERROR(true_block_generator()); + RETURN_IF_ERROR(true_block_generator()); if (false_block_generator != nullptr) { b_->SetInsertPoint(&if_data.false_block->back()); - TF_RETURN_IF_ERROR(false_block_generator()); + RETURN_IF_ERROR(false_block_generator()); } llvm_ir::SetToLastInsertPoint(if_data.after_block, b_); return absl::OkStatus(); diff --git a/third_party/xla/xla/service/llvm_ir/loop_emitter.cc b/third_party/xla/xla/service/llvm_ir/loop_emitter.cc index 13f6a67764b346..59a951cedb4275 100644 --- a/third_party/xla/xla/service/llvm_ir/loop_emitter.cc +++ b/third_party/xla/xla/service/llvm_ir/loop_emitter.cc @@ -25,6 +25,7 @@ limitations under the License. #include "absl/strings/str_format.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Constants.h" #include "llvm/IR/IRBuilder.h" @@ -86,8 +87,8 @@ BodyEmitter MakeBodyEmitter(const ElementGenerator& target_element_generator, CHECK_EQ(target_arrays.size(), 1); return [=](const llvm_ir::IrArray::Index array_index) -> absl::Status { // Convert target_element_generator to a BodyEmitter. - TF_ASSIGN_OR_RETURN(llvm::Value * target_element, - target_element_generator(array_index)); + ASSIGN_OR_RETURN(llvm::Value * target_element, + target_element_generator(array_index)); target_arrays_vec[0].EmitWriteArrayElement(array_index, target_element, b); return absl::OkStatus(); @@ -95,8 +96,8 @@ BodyEmitter MakeBodyEmitter(const ElementGenerator& target_element_generator, } return [=](const llvm_ir::IrArray::Index array_index) { - TF_ASSIGN_OR_RETURN(llvm::Value * target_element, - target_element_generator(array_index)); + ASSIGN_OR_RETURN(llvm::Value * target_element, + target_element_generator(array_index)); CHECK(target_element->getType()->isStructTy()) << "This BodyEmitter is for multi-output, but target element " "generator does not produce values of struct type."; @@ -209,7 +210,7 @@ absl::Status LoopEmitter::EmitLoop(absl::string_view loop_name, for (const IrArray::Index& array_index : EmitIndexAndSetExitBasicBlock(loop_name, index_type, /*base_index*/ nullptr)) { - TF_RETURN_IF_ERROR(body_emitter_(array_index)); + RETURN_IF_ERROR(body_emitter_(array_index)); } // Set the insertion point of b_ to the loop exit, so that diff --git a/third_party/xla/xla/service/local_service.cc b/third_party/xla/xla/service/local_service.cc index e612b0aab96472..d042741153315c 100644 --- a/third_party/xla/xla/service/local_service.cc +++ b/third_party/xla/xla/service/local_service.cc @@ -24,6 +24,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/str_format.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/client/executable_build_options.h" #include "xla/hlo/builder/xla_computation.h" #include "xla/service/backend.h" @@ -51,7 +52,7 @@ namespace xla { LocalService::NewService(const ServiceOptions& options) { se::Platform* platform = options.platform(); if (platform == nullptr) { - TF_ASSIGN_OR_RETURN(platform, PlatformUtil::GetDefaultPlatform()); + ASSIGN_OR_RETURN(platform, PlatformUtil::GetDefaultPlatform()); } BackendOptions backend_options; @@ -59,8 +60,8 @@ LocalService::NewService(const ServiceOptions& options) { .set_intra_op_parallelism_threads(options.intra_op_parallelism_threads()) .set_allowed_devices(options.allowed_devices()); - TF_ASSIGN_OR_RETURN(std::unique_ptr backend, - Backend::CreateBackend(backend_options)); + ASSIGN_OR_RETURN(std::unique_ptr backend, + Backend::CreateBackend(backend_options)); std::unique_ptr service( new LocalService(options, std::move(backend))); @@ -76,7 +77,7 @@ LocalService::CompileExecutables( const XlaComputation& computation, const absl::Span argument_layouts, const ExecutableBuildOptions& build_options) { - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( std::unique_ptr module_config, GetHloModuleConfig(computation, argument_layouts, build_options, &options_, execute_backend_.get())); @@ -84,7 +85,7 @@ LocalService::CompileExecutables( VLOG(3) << "Computation Layout: " << module_config->entry_computation_layout().ToString(); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( se::StreamExecutor * executor, execute_backend_->stream_executor(build_options.device_ordinal())); @@ -102,7 +103,7 @@ LocalService::CompileExecutables( build_options.process_count()}, build_options.slice_size()}; if (build_options.num_partitions() == 1) { - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( std::unique_ptr executable, BuildExecutable(computation.proto(), std::move(module_config), execute_backend_.get(), executor, compile_options, @@ -127,12 +128,12 @@ LocalService::CompileAotResults( const XlaComputation& computation, const absl::Span argument_layouts, const ExecutableBuildOptions& build_options) { - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( std::unique_ptr module_config, GetHloModuleConfig(computation, argument_layouts, build_options, &options_, execute_backend_.get())); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( se::StreamExecutor * executor, execute_backend_->stream_executor(build_options.device_ordinal())); @@ -151,16 +152,15 @@ LocalService::CompileAotResults( absl::StatusOr LocalService::ReplicaNumberToDeviceOrdinal( int replica_number) { - TF_ASSIGN_OR_RETURN( - DeviceAssignment da, - backend().computation_placer()->AssignDevices( - options_.number_of_replicas(), /*computation_count=*/1)); + ASSIGN_OR_RETURN(DeviceAssignment da, + backend().computation_placer()->AssignDevices( + options_.number_of_replicas(), /*computation_count=*/1)); return da.DeviceId(replica_number, /*computation=*/0); } absl::StatusOr LocalService::GlobalDataToShapedBuffer( const GlobalDataHandle& data, int replica_number) { - TF_ASSIGN_OR_RETURN(auto buffers, allocation_tracker_.Resolve(data)); + ASSIGN_OR_RETURN(auto buffers, allocation_tracker_.Resolve(data)); if (replica_number >= buffers.size()) { return InvalidArgument( "replica_number %d out of range; must be less than num_replicas = %u.", diff --git a/third_party/xla/xla/service/local_service_utils.cc b/third_party/xla/xla/service/local_service_utils.cc index af6d7f0c3f2541..5a52294660a16d 100644 --- a/third_party/xla/xla/service/local_service_utils.cc +++ b/third_party/xla/xla/service/local_service_utils.cc @@ -22,6 +22,7 @@ limitations under the License. #include "absl/log/check.h" #include "absl/strings/str_format.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/client/executable_build_options.h" #include "xla/hlo/builder/xla_computation.h" #include "xla/hlo/ir/hlo_opcode.h" @@ -72,8 +73,8 @@ absl::StatusOr> GetHloModuleConfig( Backend* backend) { const HloModuleProto& proto = computation.proto(); TF_RET_CHECK(proto.has_host_program_shape()); - TF_ASSIGN_OR_RETURN(ProgramShape program_shape, - ProgramShape::FromProto(proto.host_program_shape())); + ASSIGN_OR_RETURN(ProgramShape program_shape, + ProgramShape::FromProto(proto.host_program_shape())); // Validate incoming layouts. if (argument_layouts.size() != program_shape.parameters_size()) { @@ -84,8 +85,7 @@ absl::StatusOr> GetHloModuleConfig( for (int i = 0; i < argument_layouts.size(); ++i) { const Shape& argument_shape = *argument_layouts[i]; - TF_RETURN_IF_ERROR( - ShapeUtil::ValidateShapeWithOptionalLayout(argument_shape)); + RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(argument_shape)); if (!ShapeUtil::Compatible(argument_shape, program_shape.parameters(i))) { std::optional metadata = ParameterMetadata(computation, /*parameter_number=*/i); @@ -108,8 +108,8 @@ absl::StatusOr> GetHloModuleConfig( } } if (build_options.result_layout() != nullptr) { - TF_RETURN_IF_ERROR(Service::ValidateResultShape( - *build_options.result_layout(), program_shape.result())); + RETURN_IF_ERROR(Service::ValidateResultShape(*build_options.result_layout(), + program_shape.result())); } ExecutionOptions execution_options = diff --git a/third_party/xla/xla/service/loop_schedule_linearizer.cc b/third_party/xla/xla/service/loop_schedule_linearizer.cc index 39329ad2c7440c..feb8decabe2e49 100644 --- a/third_party/xla/xla/service/loop_schedule_linearizer.cc +++ b/third_party/xla/xla/service/loop_schedule_linearizer.cc @@ -25,6 +25,7 @@ limitations under the License. #include "absl/log/log.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/analysis/hlo_alias_analysis.h" #include "xla/hlo/analysis/hlo_dataflow_analysis.h" #include "xla/hlo/ir/hlo_computation.h" @@ -152,7 +153,7 @@ static absl::StatusOr AddControlEdgesForLoopWrites( // Add control dependency if it does not already exist. if (!absl::c_linear_search(read->control_successors(), write)) { // Unless we want a copy, read should happen before write. - TF_RETURN_IF_ERROR(read->AddControlDependencyTo(write)); + RETURN_IF_ERROR(read->AddControlDependencyTo(write)); VLOG(2) << "Adding dependency: " << read->ToShortString() << " before " << write->ToShortString(); changed = true; @@ -198,11 +199,11 @@ absl::StatusOr LoopScheduleLinearizer::RunImpl( } if (alias_analysis == nullptr) { - TF_ASSIGN_OR_RETURN(alias_analysis, - HloAliasAnalysis::Run(module, alias_info_)); + ASSIGN_OR_RETURN(alias_analysis, + HloAliasAnalysis::Run(module, alias_info_)); } - TF_ASSIGN_OR_RETURN(bool updated_loop, AddControlEdgesForLoopWrites( - instruction, *alias_analysis)); + ASSIGN_OR_RETURN(bool updated_loop, AddControlEdgesForLoopWrites( + instruction, *alias_analysis)); changed |= updated_loop; } } diff --git a/third_party/xla/xla/service/map_inliner.cc b/third_party/xla/xla/service/map_inliner.cc index e003f189f60203..8a0191fa623922 100644 --- a/third_party/xla/xla/service/map_inliner.cc +++ b/third_party/xla/xla/service/map_inliner.cc @@ -24,6 +24,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -62,7 +63,7 @@ class MapInlinerVisitor : public DfsHloVisitorWithDefault { absl::StatusOr MapInlinerVisitor::Run(HloComputation* computation) { changed_ = false; computation_ = computation; - TF_RETURN_IF_ERROR(computation->root_instruction()->Accept(this)); + RETURN_IF_ERROR(computation->root_instruction()->Accept(this)); return changed_; } @@ -81,9 +82,9 @@ absl::Status MapInlinerVisitor::HandleMap(HloInstruction* map) { if (root.opcode() == HloOpcode::kParameter) { // If the root is a parameter, then use the corresponding operand as the // result of the computation. - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( map->ReplaceAllUsesWith(map->operands()[root.parameter_number()])); - TF_RETURN_IF_ERROR(computation_->RemoveInstruction(map)); + RETURN_IF_ERROR(computation_->RemoveInstruction(map)); } else if (root.opcode() == HloOpcode::kConstant) { // If the input is a constant then the shape of the constant could be // different than the map shape. Hence, a broadcast is needed, else the @@ -94,7 +95,7 @@ absl::Status MapInlinerVisitor::HandleMap(HloInstruction* map) { HloInstruction* constant = computation_->AddInstruction(root.Clone()); HloInstruction* placed_instruction = computation_->AddInstruction( HloInstruction::CreateBroadcast(map->shape(), constant, {})); - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( computation_->ReplaceInstruction(map, placed_instruction)); } else { std::vector params; @@ -104,7 +105,7 @@ absl::Status MapInlinerVisitor::HandleMap(HloInstruction* map) { } HloInstruction* placed_instruction = computation_->AddInstruction( root.CloneWithNewOperands(map->shape(), params)); - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( computation_->ReplaceInstruction(map, placed_instruction)); } changed_ = true; @@ -120,7 +121,7 @@ absl::StatusOr MapInliner::RunImpl( MapInlinerVisitor visitor(/*computation=*/nullptr); bool changed = false; for (HloComputation* computation : module->computations(execution_threads)) { - TF_ASSIGN_OR_RETURN(bool computation_changed, visitor.Run(computation)); + ASSIGN_OR_RETURN(bool computation_changed, visitor.Run(computation)); changed |= computation_changed; } return changed; diff --git a/third_party/xla/xla/service/mapped_ptr_container_sorter.h b/third_party/xla/xla/service/mapped_ptr_container_sorter.h index 1e5cea62fc00f9..801049f7441c13 100644 --- a/third_party/xla/xla/service/mapped_ptr_container_sorter.h +++ b/third_party/xla/xla/service/mapped_ptr_container_sorter.h @@ -47,6 +47,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/util.h" #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" @@ -302,14 +303,14 @@ MappedPtrContainerSorter::SortedIndices::Flatten() const { const auto& indices = target_index_to_unmapped_element_index_.at(IndexBeforeMappedElements()); for (size_t index : indices) { - TF_ASSIGN_OR_RETURN(result[index], next_index_fn()); + ASSIGN_OR_RETURN(result[index], next_index_fn()); } } size_t num_inserted_mapped_elements = 0; for (const auto& mapped_element_indices : mapped_element_indices_by_partial_order_) { for (size_t mapped_element_index : mapped_element_indices) { - TF_ASSIGN_OR_RETURN(result[mapped_element_index], next_index_fn()); + ASSIGN_OR_RETURN(result[mapped_element_index], next_index_fn()); ++num_inserted_mapped_elements; if (target_index_to_unmapped_element_index_.contains( num_inserted_mapped_elements - 1)) { @@ -317,7 +318,7 @@ MappedPtrContainerSorter::SortedIndices::Flatten() const { target_index_to_unmapped_element_index_.at( num_inserted_mapped_elements - 1); for (size_t unmapped_element_index : unmapped_element_indices) { - TF_ASSIGN_OR_RETURN(result[unmapped_element_index], next_index_fn()); + ASSIGN_OR_RETURN(result[unmapped_element_index], next_index_fn()); } } } @@ -327,7 +328,7 @@ MappedPtrContainerSorter::SortedIndices::Flatten() const { const auto& indices = target_index_to_unmapped_element_index_.at(IndexAfterMappedElements()); for (size_t index : indices) { - TF_ASSIGN_OR_RETURN(result[index], next_index_fn()); + ASSIGN_OR_RETURN(result[index], next_index_fn()); } } @@ -409,7 +410,7 @@ MappedPtrContainerSorter::ComputeNewIndices( // Potentially, several elements in ordered_container map to ptr. // We assign ptr theindex corresponding to the next such ordered element. auto& index_list = mapped_ptr_to_partial_order[ptr]; - TF_RETURN_IF_ERROR(result.AddMappedElement(i, index_list.front())); + RETURN_IF_ERROR(result.AddMappedElement(i, index_list.front())); // Do not map more than one unordered element to the same index, unless we // have no choice. if (index_list.size() > 1) { @@ -445,9 +446,9 @@ absl::Status MappedPtrContainerSorter::Sort( MapPtrFn map_ptr, UnmappedPtrIndexFn unmapped_index, const OrderedTy& ordered_container, UnorderedTy& unordered_container) { std::vector indices; - TF_ASSIGN_OR_RETURN( - indices, ComputeNewIndices(map_ptr, unmapped_index, ordered_container, - unordered_container)); + ASSIGN_OR_RETURN(indices, + ComputeNewIndices(map_ptr, unmapped_index, ordered_container, + unordered_container)); Reorder(std::move(indices), unordered_container); return absl::OkStatus(); } diff --git a/third_party/xla/xla/service/matmul_indexing_utils.cc b/third_party/xla/xla/service/matmul_indexing_utils.cc index e03c6df8e0fdfa..d469aa72745603 100644 --- a/third_party/xla/xla/service/matmul_indexing_utils.cc +++ b/third_party/xla/xla/service/matmul_indexing_utils.cc @@ -30,6 +30,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/autotuning.pb.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" @@ -104,9 +105,9 @@ absl::StatusOr ContractingDimensionIndex(const HloInstruction& dot, absl::StatusOr NonContractingDimensionIndex(const HloInstruction& dot, const int operand_number) { - TF_ASSIGN_OR_RETURN(int64_t contracting_dim, - ContractingDimensionIndex(dot, operand_number)); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN(int64_t contracting_dim, + ContractingDimensionIndex(dot, operand_number)); + ASSIGN_OR_RETURN( std::vector non_contracting_dims, GetNonContractingDims(dot.operand(operand_number)->shape(), BatchDimensionsForOperand(dot, operand_number), @@ -129,23 +130,23 @@ DotOperandDims::DotOperandDims(Shape shape, absl::StatusOr> DotOperandDims::FromDot( const HloInstruction* dot) { - TF_ASSIGN_OR_RETURN(auto lhs_dims, FromDotOperand(dot, 0)); - TF_ASSIGN_OR_RETURN(auto rhs_dims, FromDotOperand(dot, 1)); + ASSIGN_OR_RETURN(auto lhs_dims, FromDotOperand(dot, 0)); + ASSIGN_OR_RETURN(auto rhs_dims, FromDotOperand(dot, 1)); return std::array{lhs_dims, rhs_dims}; } absl::StatusOr> DotOperandDims::FromScaledDot( const HloInstruction* scaled_dot) { - TF_ASSIGN_OR_RETURN(auto lhs_dims, FromDotOperand(scaled_dot, 0)); + ASSIGN_OR_RETURN(auto lhs_dims, FromDotOperand(scaled_dot, 0)); DotOperandDims lhs_scale_dims; if (!ShapeUtil::IsScalar(scaled_dot->operand(2)->shape())) { - TF_ASSIGN_OR_RETURN(lhs_scale_dims, FromDotOperand(scaled_dot, 2)); + ASSIGN_OR_RETURN(lhs_scale_dims, FromDotOperand(scaled_dot, 2)); } - TF_ASSIGN_OR_RETURN(auto rhs_dims, FromDotOperand(scaled_dot, 1)); + ASSIGN_OR_RETURN(auto rhs_dims, FromDotOperand(scaled_dot, 1)); DotOperandDims rhs_scale_dims; if (!ShapeUtil::IsScalar(scaled_dot->operand(3)->shape())) { - TF_ASSIGN_OR_RETURN(rhs_scale_dims, FromDotOperand(scaled_dot, 3)); + ASSIGN_OR_RETURN(rhs_scale_dims, FromDotOperand(scaled_dot, 3)); } return std::array{lhs_dims, rhs_dims, lhs_scale_dims, @@ -158,9 +159,8 @@ absl::StatusOr DotOperandDims::FromDotOperand( const auto& batch_dims = BatchDimensionsForOperand(*dot, operand_number); const auto& contracting_dims = ContractingDimensionsForOperand(*dot, operand_number); - TF_ASSIGN_OR_RETURN( - std::vector non_contracting_dims, - GetNonContractingDims(shape, batch_dims, contracting_dims)); + ASSIGN_OR_RETURN(std::vector non_contracting_dims, + GetNonContractingDims(shape, batch_dims, contracting_dims)); return DotOperandDims(shape, batch_dims, non_contracting_dims, contracting_dims); } @@ -220,8 +220,8 @@ absl::StatusOr DotOperandDims::ComputeOutputShape( operand.shape_.is_dynamic_dimension(nc_dim)); } } - TF_ASSIGN_OR_RETURN(Shape output_shape, ShapeUtil::MakeValidatedShape( - element_type, output_dimensions)); + ASSIGN_OR_RETURN(Shape output_shape, ShapeUtil::MakeValidatedShape( + element_type, output_dimensions)); for (int64_t i = 0; i < output_dynamic_dimensions.size(); ++i) { output_shape.set_dynamic_dimension(i, output_dynamic_dimensions[i]); } diff --git a/third_party/xla/xla/service/multi_output_fusion.cc b/third_party/xla/xla/service/multi_output_fusion.cc index e37d9c4a68527a..3997d80c7c6a0e 100644 --- a/third_party/xla/xla/service/multi_output_fusion.cc +++ b/third_party/xla/xla/service/multi_output_fusion.cc @@ -28,6 +28,7 @@ limitations under the License. #include "absl/strings/str_format.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/debug_options_flags.h" #include "xla/hlo/analysis/hlo_reachability.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -87,7 +88,7 @@ absl::StatusOr MultiOutputFusion::RunImpl( CHECK_OK(module->RemoveUnusedComputations()); if (changed) { HloDCE dce; - TF_RETURN_IF_ERROR(dce.Run(module, execution_threads).status()); + RETURN_IF_ERROR(dce.Run(module, execution_threads).status()); } return changed; } diff --git a/third_party/xla/xla/service/p2p_schedule_preparation.cc b/third_party/xla/xla/service/p2p_schedule_preparation.cc index da00aefd86a10b..d88fe3151af7fb 100644 --- a/third_party/xla/xla/service/p2p_schedule_preparation.cc +++ b/third_party/xla/xla/service/p2p_schedule_preparation.cc @@ -28,6 +28,7 @@ limitations under the License. #include "absl/log/log.h" #include "absl/status/status.h" #include "absl/strings/string_view.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/analysis/hlo_reachability.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" @@ -447,13 +448,13 @@ absl::Status MayAddWhileOpToPipelinedGroup(HloInstruction* while_op, return Internal( "Expecting up to two pipelined P2P groups for each while-loop"); } - TF_RETURN_IF_ERROR(group->second.RecordWhileOpToPipelinedGroup(while_op)); + RETURN_IF_ERROR(group->second.RecordWhileOpToPipelinedGroup(while_op)); } return absl::OkStatus(); } absl::Status OrderBefore(HloInstruction* i1, HloInstruction* i2) { - TF_RETURN_IF_ERROR(i1->AddControlDependencyTo(i2)); + RETURN_IF_ERROR(i1->AddControlDependencyTo(i2)); VLOG(10) << "Add control predecessor " << i2->ToString(); return absl::OkStatus(); } @@ -465,9 +466,9 @@ absl::Status ConnectP2P1NodeChain(const P2PGroupNode& node) { HloRecvInstruction* recv = node.recv; HloSendDoneInstruction* send_done = node.send_done; HloSendInstruction* send = node.send; - TF_RETURN_IF_ERROR(OrderBefore(recv, send)); - TF_RETURN_IF_ERROR(OrderBefore(send, recv_done)); - TF_RETURN_IF_ERROR(OrderBefore(recv_done, send_done)); + RETURN_IF_ERROR(OrderBefore(recv, send)); + RETURN_IF_ERROR(OrderBefore(send, recv_done)); + RETURN_IF_ERROR(OrderBefore(recv_done, send_done)); return absl::OkStatus(); } @@ -500,15 +501,15 @@ absl::Status ConnectP2P2NodeChain(const P2PGroupNode& node0, HloSendRecvInstruction* send_done1 = node1.send_done; HloSendInstruction* send1 = node1.send; - TF_RETURN_IF_ERROR(OrderBefore(recv_done0, recv_done1)); - TF_RETURN_IF_ERROR(OrderBefore(recv_done1, send_done0)); - TF_RETURN_IF_ERROR(OrderBefore(send_done0, send_done1)); + RETURN_IF_ERROR(OrderBefore(recv_done0, recv_done1)); + RETURN_IF_ERROR(OrderBefore(recv_done1, send_done0)); + RETURN_IF_ERROR(OrderBefore(send_done0, send_done1)); - TF_RETURN_IF_ERROR(OrderBefore(recv0, send0)); - TF_RETURN_IF_ERROR(OrderBefore(send0, recv1)); - TF_RETURN_IF_ERROR(OrderBefore(recv1, send1)); + RETURN_IF_ERROR(OrderBefore(recv0, send0)); + RETURN_IF_ERROR(OrderBefore(send0, recv1)); + RETURN_IF_ERROR(OrderBefore(recv1, send1)); - TF_RETURN_IF_ERROR(OrderBefore(send1, recv_done0)); + RETURN_IF_ERROR(OrderBefore(send1, recv_done0)); return absl::OkStatus(); } @@ -589,15 +590,15 @@ absl::Status GatherP2PGroupsAndCollectiveInfo( // P2P group and may turn it into a kPipelined group or kUnrecognized // group. P2PGroup group; - TF_RETURN_IF_ERROR(group.RecordP2POpForUnpipelinedGroup(p2p)); + RETURN_IF_ERROR(group.RecordP2POpForUnpipelinedGroup(p2p)); p2p_group_map[channel] = group; } else { P2PGroup& group = p2p_group->second; if (group.ChildComputation() == computation) { - TF_RETURN_IF_ERROR(group.RecordP2POpForUnpipelinedGroup(p2p)); + RETURN_IF_ERROR(group.RecordP2POpForUnpipelinedGroup(p2p)); } else { // We are at the parent computation for a pipelined P2P group. - TF_RETURN_IF_ERROR(group.RecordP2POpForPipelinedGroup(p2p)); + RETURN_IF_ERROR(group.RecordP2POpForPipelinedGroup(p2p)); } } // We can't rely on the operation on p2p_group_map above to find out @@ -615,7 +616,7 @@ absl::Status GatherP2PGroupsAndCollectiveInfo( } for (auto hlo : while_ops) { - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( MayAddWhileOpToPipelinedGroup(hlo, p2p_in_computation, p2p_group_map)); } @@ -654,7 +655,7 @@ absl::Status GatherP2PGroupsAndCollectiveInfo( p2p_group.InCycle() || p2p_group.runtime_stream != kStream1) { continue; } - TF_RETURN_IF_ERROR(p2p_group.RecordComplementGroup(p2p_group_map)); + RETURN_IF_ERROR(p2p_group.RecordComplementGroup(p2p_group_map)); } return absl::OkStatus(); @@ -683,23 +684,23 @@ absl::StatusOr> ConnectP2PChain( P2PGroupKind kind = p2p_group.kind; if (kind == P2PGroupKind::kUnpipelined) { if (!p2p_group.InCycle()) { - TF_RETURN_IF_ERROR(ConnectUnpipelinedP2P(p2p_group)); + RETURN_IF_ERROR(ConnectUnpipelinedP2P(p2p_group)); } else if (p2p_group.runtime_stream == kStream1) { - TF_RETURN_IF_ERROR(ConnectUnpipelined2P2P(p2p_group, p2p_group_map)); + RETURN_IF_ERROR(ConnectUnpipelined2P2P(p2p_group, p2p_group_map)); } continue; } if (!p2p_group.InCycle()) { if (computation == p2p_group.ParentComputation()) { - TF_RETURN_IF_ERROR(ConnectPipelined1P2PParent(p2p_group)); + RETURN_IF_ERROR(ConnectPipelined1P2PParent(p2p_group)); } else { // A pipeline of one group. if (pipelined_group != nullptr) { return Internal("Expected <=1 pipelined group in a while-body"); } pipelined_group = &p2p_group; - TF_RETURN_IF_ERROR(ConnectPipelined1P2PChild(p2p_group)); + RETURN_IF_ERROR(ConnectPipelined1P2PChild(p2p_group)); } continue; } @@ -711,7 +712,7 @@ absl::StatusOr> ConnectP2PChain( } if (computation == p2p_group.ParentComputation()) { - TF_RETURN_IF_ERROR(ConnectPipelined2P2PParent(p2p_group, p2p_group_map)); + RETURN_IF_ERROR(ConnectPipelined2P2PParent(p2p_group, p2p_group_map)); } else { if (pipelined_group != nullptr) { return Internal( @@ -719,7 +720,7 @@ absl::StatusOr> ConnectP2PChain( "while-body"); } pipelined_group = &p2p_group; - TF_RETURN_IF_ERROR(ConnectPipelined2P2PChild(p2p_group, p2p_group_map)); + RETURN_IF_ERROR(ConnectPipelined2P2PChild(p2p_group, p2p_group_map)); } } return std::make_pair(num_p2p_chains, pipelined_group); @@ -729,7 +730,7 @@ absl::Status OrderBefore(HloReachabilityMap* reachability, HloInstruction* a, HloInstruction* b) { VLOG(10) << "OrderBefore " << a->ToString() << " " << b->ToString(); if (!reachability->IsReachable(a, b)) { - TF_RETURN_IF_ERROR(a->AddControlDependencyTo(b)); + RETURN_IF_ERROR(a->AddControlDependencyTo(b)); VLOG(10) << "add control predecessor " << b->ToString(); reachability->UpdateReachabilityThroughInstruction(b); } @@ -787,11 +788,11 @@ absl::Status LinearizeCollectivesWithOtherP2P( if (reachability->IsReachable(start_end.first, cur_start_end.second)) { // Order chain A before chain B. - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( OrderBefore(reachability, start_end.second, cur_start_end.first)); } else { // Order chain B before chain A. - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( OrderBefore(reachability, cur_start_end.second, start_end.first)); } continue; @@ -809,20 +810,20 @@ absl::Status LinearizeCollectivesWithOtherP2P( if (hlo_query::IsAsyncCollectiveDoneOp(hlo, /*include_send_recv=*/false)) { if (reachability->IsReachable(start_end.first, hlo)) { // Order chain A before the async op. - TF_RETURN_IF_ERROR(OrderBefore(reachability, start_end.second, - GetStartOpForDoneOp(hlo))); + RETURN_IF_ERROR(OrderBefore(reachability, start_end.second, + GetStartOpForDoneOp(hlo))); } else { // Order the async op before chain A. - TF_RETURN_IF_ERROR(OrderBefore(reachability, hlo, start_end.first)); + RETURN_IF_ERROR(OrderBefore(reachability, hlo, start_end.first)); } } // CustomCall or other op that indirectly invoke collectives. if (reachability->IsReachable(start_end.first, hlo)) { // Order chain A before the op. - TF_RETURN_IF_ERROR(OrderBefore(reachability, start_end.second, hlo)); + RETURN_IF_ERROR(OrderBefore(reachability, start_end.second, hlo)); } else { // Order the op before chain A. - TF_RETURN_IF_ERROR(OrderBefore(reachability, hlo, start_end.first)); + RETURN_IF_ERROR(OrderBefore(reachability, hlo, start_end.first)); } } @@ -868,14 +869,14 @@ absl::Status LinearizeCollectivesWithPipelinedP2PChild( ChainStartEnd cur_start_end = cur_group.GetChainStartEnd(computation, p2p_group_map); - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( OrderBefore(reachability, cur_start_end.second, start_end.first)); continue; } // Async done, CustomCall, or other ops that indirectly invoke collectives. - TF_RETURN_IF_ERROR(OrderBefore(reachability, hlo, start_end.first)); + RETURN_IF_ERROR(OrderBefore(reachability, hlo, start_end.first)); } return absl::OkStatus(); @@ -901,7 +902,7 @@ absl::StatusOr P2PSchedulePreparation::RunImpl( ++iter) { VLOG(10) << "Gathering P2P groups and collective info for computation " << (*iter)->name(); - TF_RETURN_IF_ERROR(GatherP2PGroupsAndCollectiveInfo( + RETURN_IF_ERROR(GatherP2PGroupsAndCollectiveInfo( *iter, p2p_in_computation, p2p_group_map, collective_in_computation)); } @@ -927,8 +928,8 @@ absl::StatusOr P2PSchedulePreparation::RunImpl( // Connect P2P chains and return the number of chains and the P2P group // representation for pipelined P2P in the current computation as a // while-body. - TF_ASSIGN_OR_RETURN( - auto result, ConnectP2PChain(computation, p2p_group_map, p2p_channels)); + ASSIGN_OR_RETURN(auto result, + ConnectP2PChain(computation, p2p_group_map, p2p_channels)); if (result.first == 0) { continue; } @@ -942,7 +943,7 @@ absl::StatusOr P2PSchedulePreparation::RunImpl( // The current computation is a while-body with pipelined P2P chain. // Order all other collectives in a pipelined while-body before the // pipelined P2P chain. - TF_RETURN_IF_ERROR(LinearizeCollectivesWithPipelinedP2PChild( + RETURN_IF_ERROR(LinearizeCollectivesWithPipelinedP2PChild( p2p_group_map, *result.second, collective_in_computation, computation, reachability.get())); } @@ -982,7 +983,7 @@ absl::StatusOr P2PSchedulePreparation::RunImpl( VLOG(10) << "linearize other collectives with respect to channel " << hlo->ToString(); - TF_RETURN_IF_ERROR(LinearizeCollectivesWithOtherP2P( + RETURN_IF_ERROR(LinearizeCollectivesWithOtherP2P( p2p_group_map, group, collective_in_computation, instr_it, all_instructions.begin(), all_instructions.end(), reachability.get())); diff --git a/third_party/xla/xla/service/platform_util.cc b/third_party/xla/xla/service/platform_util.cc index 9d5be2d45d48a3..8d22fec7d6864c 100644 --- a/third_party/xla/xla/service/platform_util.cc +++ b/third_party/xla/xla/service/platform_util.cc @@ -30,6 +30,7 @@ limitations under the License. #include "absl/strings/match.h" #include "absl/strings/str_join.h" #include "absl/strings/string_view.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/debug_options_flags.h" #include "xla/service/compiler.h" #include "xla/status_macros.h" @@ -116,9 +117,9 @@ absl::Status IsDeviceSupported(se::StreamExecutor* executor) { absl::StatusOr ExecutorForDevice(se::Platform* platform, int device_ordinal) { - TF_ASSIGN_OR_RETURN(se::StreamExecutor * exec, - platform->ExecutorForDevice(device_ordinal)); - TF_RETURN_IF_ERROR(IsDeviceSupported(exec)); + ASSIGN_OR_RETURN(se::StreamExecutor * exec, + platform->ExecutorForDevice(device_ordinal)); + RETURN_IF_ERROR(IsDeviceSupported(exec)); return exec; } @@ -197,7 +198,7 @@ absl::StatusOr PlatformUtil::GetDefaultPlatform() { "double-check that you are using a PJRT-compatible test class.", allow_default); } - TF_ASSIGN_OR_RETURN(auto platforms, GetSupportedPlatforms()); + ASSIGN_OR_RETURN(auto platforms, GetSupportedPlatforms()); TF_RET_CHECK(!platforms.empty()) << "No platforms found"; if (platforms.size() == 1) { @@ -227,10 +228,10 @@ absl::StatusOr PlatformUtil::GetDefaultPlatform() { /*static*/ absl::StatusOr PlatformUtil::GetPlatform( absl::string_view platform_name) { - TF_ASSIGN_OR_RETURN(se::Platform * platform, - se::PlatformManager::PlatformWithName( - xla::CanonicalPlatformName(platform_name))); - TF_RETURN_IF_ERROR(Compiler::GetForPlatform(platform->id()).status()); + ASSIGN_OR_RETURN(se::Platform * platform, + se::PlatformManager::PlatformWithName( + xla::CanonicalPlatformName(platform_name))); + RETURN_IF_ERROR(Compiler::GetForPlatform(platform->id()).status()); return platform; } @@ -238,8 +239,8 @@ absl::StatusOr> PlatformUtil::GetStreamExecutors( se::Platform* platform, const std::optional>& allowed_devices) { - TF_ASSIGN_OR_RETURN(std::vector device_ordinals, - GetDeviceOrdinals(platform, allowed_devices)); + ASSIGN_OR_RETURN(std::vector device_ordinals, + GetDeviceOrdinals(platform, allowed_devices)); std::vector> executors( device_ordinals.size(), diff --git a/third_party/xla/xla/service/profile_guided_latency_estimator_test.cc b/third_party/xla/xla/service/profile_guided_latency_estimator_test.cc index 0e16f2be6316d1..72eb44df032004 100644 --- a/third_party/xla/xla/service/profile_guided_latency_estimator_test.cc +++ b/third_party/xla/xla/service/profile_guided_latency_estimator_test.cc @@ -29,6 +29,7 @@ limitations under the License. #include "absl/status/status_matchers.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/analysis/alias_info.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_schedule.h" @@ -81,10 +82,9 @@ absl::StatusOr RunScheduler( &alias_info, shape_size_bytes); auto scheduler_core = std::make_unique(scheduling_context, sched_config); - TF_ASSIGN_OR_RETURN( - bool value, - LatencyHidingScheduler(scheduling_context, std::move(scheduler_core)) - .Run(module)); + ASSIGN_OR_RETURN(bool value, LatencyHidingScheduler(scheduling_context, + std::move(scheduler_core)) + .Run(module)); return value; } diff --git a/third_party/xla/xla/service/reduce_scatter_combiner.cc b/third_party/xla/xla/service/reduce_scatter_combiner.cc index 84b504c26dfda4..f065e729140ab4 100644 --- a/third_party/xla/xla/service/reduce_scatter_combiner.cc +++ b/third_party/xla/xla/service/reduce_scatter_combiner.cc @@ -34,6 +34,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" @@ -152,7 +153,7 @@ absl::Status CombineReduceScatters( combined->set_metadata(MergeMetadata(to_combine)); combined->set_frontend_attributes(MergeFrontendAttributes(to_combine)); if (post_combine != nullptr) { - TF_RETURN_IF_ERROR(post_combine(to_combine, combined)); + RETURN_IF_ERROR(post_combine(to_combine, combined)); } // We have to propagate the sharding manually because Domain instructions are @@ -173,8 +174,7 @@ absl::Status CombineReduceScatters( replacement->shape()), replacement)); } - TF_RETURN_IF_ERROR( - computation.ReplaceInstruction(to_combine[i], replacement)); + RETURN_IF_ERROR(computation.ReplaceInstruction(to_combine[i], replacement)); } return absl::OkStatus(); } @@ -245,7 +245,7 @@ absl::StatusOr ReduceScatterCombiner::RunWithKeyCombiner( << computation->ToString(); continue; } - TF_ASSIGN_OR_RETURN(auto domain_map, HloDomainMap::Create(computation, "")); + ASSIGN_OR_RETURN(auto domain_map, HloDomainMap::Create(computation, "")); auto key_fn = [&](const HloInstruction* instruction) { return combine_key(instruction, *domain_map, combine_by_dim_); @@ -255,7 +255,7 @@ absl::StatusOr ReduceScatterCombiner::RunWithKeyCombiner( return CombineReduceScatters(to_combine, post_combine); }; - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( bool computation_changed, CombineInstructionsByKey( computation, key_fn, combine_fn, combine_threshold_in_bytes_, @@ -278,8 +278,8 @@ ReduceScatterCombiner::ReduceScatterCombiner(int64_t combine_threshold_in_bytes, absl::StatusOr ReduceScatterCombiner::RunImpl( HloModule* module, const absl::flat_hash_set& execution_threads) { - TF_ASSIGN_OR_RETURN( - bool changed, RunWithKeyCombiner(module, execution_threads, CombineKey)); + ASSIGN_OR_RETURN(bool changed, + RunWithKeyCombiner(module, execution_threads, CombineKey)); return changed; } diff --git a/third_party/xla/xla/service/reduce_scatter_combiner_test.cc b/third_party/xla/xla/service/reduce_scatter_combiner_test.cc index 57f3aca3e3dcd9..08ce74b86a08b6 100644 --- a/third_party/xla/xla/service/reduce_scatter_combiner_test.cc +++ b/third_party/xla/xla/service/reduce_scatter_combiner_test.cc @@ -22,6 +22,7 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/log/log.h" #include "absl/strings/string_view.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" @@ -43,7 +44,7 @@ class ReduceScatterCombinerTest : public HloHardwareIndependentTestBase { int64_t byte_threshold = kMaxByteCount, int64_t count_threshold = kMaxCombineCount, bool combine_by_dim = true, bool combine_while_loops = true) { - TF_ASSIGN_OR_RETURN(auto module, ParseAndReturnVerifiedModule(hlo_module)); + ASSIGN_OR_RETURN(auto module, ParseAndReturnVerifiedModule(hlo_module)); VLOG(1) << "Before running ReduceScatterCombiner: " << ReduceScatterCount(module.get()) << " reduce-scatter ops"; diff --git a/third_party/xla/xla/service/reduce_scatter_decomposer.cc b/third_party/xla/xla/service/reduce_scatter_decomposer.cc index f3cdf12259a934..d15690d6cc9cad 100644 --- a/third_party/xla/xla/service/reduce_scatter_decomposer.cc +++ b/third_party/xla/xla/service/reduce_scatter_decomposer.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -75,11 +76,10 @@ absl::StatusOr ReduceScatterDecomposer::RunImpl( // Create start indices for a dynamic slice to decompose the all-reduce // results. - TF_ASSIGN_OR_RETURN( - CollectiveOpGroupMode group_mode, - GetCollectiveOpGroupMode(rs->channel_id().has_value(), - rs->use_global_device_ids())); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN(CollectiveOpGroupMode group_mode, + GetCollectiveOpGroupMode(rs->channel_id().has_value(), + rs->use_global_device_ids())); + ASSIGN_OR_RETURN( std::vector start_indices, CreateStartIndicesForCollectiveDecomposition( group_mode, rs->replica_groups(), rs->shape(), @@ -89,8 +89,8 @@ absl::StatusOr ReduceScatterDecomposer::RunImpl( computation->AddInstruction(HloInstruction::CreateDynamicSlice( rs->shape(), ar, start_indices, rs->shape().dimensions())); - TF_RETURN_IF_ERROR(rs->ReplaceAllUsesWith(ds)); - TF_RETURN_IF_ERROR(computation->RemoveInstruction(rs)); + RETURN_IF_ERROR(rs->ReplaceAllUsesWith(ds)); + RETURN_IF_ERROR(computation->RemoveInstruction(rs)); changed = true; } } diff --git a/third_party/xla/xla/service/reduce_scatter_reassociate.cc b/third_party/xla/xla/service/reduce_scatter_reassociate.cc index 61d9ef774f6fab..0bbde0d28dfd2e 100644 --- a/third_party/xla/xla/service/reduce_scatter_reassociate.cc +++ b/third_party/xla/xla/service/reduce_scatter_reassociate.cc @@ -17,6 +17,7 @@ limitations under the License. #include +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -88,8 +89,8 @@ absl::StatusOr ReduceScatterReassociate::RunImpl( VLOG(2) << "Reduce-Scatter operations have > 1 users"; continue; } - TF_ASSIGN_OR_RETURN(auto rs0_annotation, GetSchedulingAnnotation(rs0)); - TF_ASSIGN_OR_RETURN(auto rs1_annotation, GetSchedulingAnnotation(rs1)); + ASSIGN_OR_RETURN(auto rs0_annotation, GetSchedulingAnnotation(rs0)); + ASSIGN_OR_RETURN(auto rs1_annotation, GetSchedulingAnnotation(rs1)); if (rs0_annotation.has_value() && rs1_annotation.has_value() && *rs0_annotation != *rs1_annotation) { VLOG(2) << "If two reduce scatters have different scheduling group do " @@ -115,14 +116,14 @@ absl::StatusOr ReduceScatterReassociate::RunImpl( new_rs->set_channel_id(next_channel_id++); } - TF_RETURN_IF_ERROR(inst->ReplaceAllUsesWith(new_rs)); + RETURN_IF_ERROR(inst->ReplaceAllUsesWith(new_rs)); // Note that RemoveInstructionAndUnusedOperands may not remove the 2 // reduce-scatter operands of `inst` if they are not safe to remove // otherwise, so manually these instructions. - TF_RETURN_IF_ERROR(computation->RemoveInstruction(inst)); - TF_RETURN_IF_ERROR(computation->RemoveInstruction(rs0)); + RETURN_IF_ERROR(computation->RemoveInstruction(inst)); + RETURN_IF_ERROR(computation->RemoveInstruction(rs0)); if (rs0 != rs1) { - TF_RETURN_IF_ERROR(computation->RemoveInstruction(rs1)); + RETURN_IF_ERROR(computation->RemoveInstruction(rs1)); } changed = true; } diff --git a/third_party/xla/xla/service/reduce_scatter_reassociate_test.cc b/third_party/xla/xla/service/reduce_scatter_reassociate_test.cc index fb0403f1a0a701..9961044be6dabf 100644 --- a/third_party/xla/xla/service/reduce_scatter_reassociate_test.cc +++ b/third_party/xla/xla/service/reduce_scatter_reassociate_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "xla/service/reduce_scatter_reassociate.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" @@ -31,7 +32,7 @@ class ReduceScatterReassociateTest : public HloHardwareIndependentTestBase { public: absl::StatusOr> RunPass( absl::string_view hlo_module, bool expect_change) { - TF_ASSIGN_OR_RETURN(auto module, ParseAndReturnVerifiedModule(hlo_module)); + ASSIGN_OR_RETURN(auto module, ParseAndReturnVerifiedModule(hlo_module)); auto changed = ReduceScatterReassociate().Run(module.get()); if (!changed.ok()) { return changed.status(); diff --git a/third_party/xla/xla/service/riegeli_dump_writer.cc b/third_party/xla/xla/service/riegeli_dump_writer.cc index a931737ca929ef..c586856f79a550 100644 --- a/third_party/xla/xla/service/riegeli_dump_writer.cc +++ b/third_party/xla/xla/service/riegeli_dump_writer.cc @@ -19,6 +19,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "xla/tsl/platform/status_macros.h" #include "riegeli/bytes/writer.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/service/dump.h" @@ -49,7 +50,7 @@ absl::StatusOr> CreateRiegeliDumpWriter( module == nullptr ? std::string(filename) : FilenameFor(*module, TimestampFor(*module), filename); - TF_RETURN_IF_ERROR(CreateDirIfNeeded(opts.dump_to, tsl::Env::Default())); + RETURN_IF_ERROR(CreateDirIfNeeded(opts.dump_to, tsl::Env::Default())); std::string file_path = tsl::io::JoinPath(opts.dump_to, SanitizeFileName(partial_path)); diff --git a/third_party/xla/xla/service/scan_expander.cc b/third_party/xla/xla/service/scan_expander.cc index a8a1419acabde4..589f688ef24d98 100644 --- a/third_party/xla/xla/service/scan_expander.cc +++ b/third_party/xla/xla/service/scan_expander.cc @@ -20,6 +20,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/comparison_util.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" @@ -39,7 +40,7 @@ namespace { absl::StatusOr BuildConditionComputation( HloScanInstruction* scan, const Shape& loop_state_shape) { - TF_ASSIGN_OR_RETURN(int64_t scan_dim_size, scan->GetScanDimSize()); + ASSIGN_OR_RETURN(int64_t scan_dim_size, scan->GetScanDimSize()); HloComputation::Builder builder(absl::StrCat(scan->name(), "_condition")); auto* param = builder.AddInstruction( @@ -147,7 +148,7 @@ absl::StatusOr BuildBodyComputation( } else { num_outputs = 1 - num_carries; } - TF_ASSIGN_OR_RETURN(int64_t scan_dim_size, scan->GetScanDimSize()); + ASSIGN_OR_RETURN(int64_t scan_dim_size, scan->GetScanDimSize()); Shape scalar_shape = ShapeUtil::MakeShape(S64, {}); HloComputation::Builder builder(absl::StrCat(scan->name(), "_body")); @@ -230,7 +231,7 @@ absl::StatusOr BuildBodyComputation( scan->parent()->parent()->AddEmbeddedComputation(builder.Build()); // Inline the call instruction within body_computation - TF_RETURN_IF_ERROR(CallInliner::Inline(call).status()); + RETURN_IF_ERROR(CallInliner::Inline(call).status()); return body_computation; } @@ -285,11 +286,11 @@ absl::StatusOr ScanExpander::ExpandInstruction( } Shape loop_state_shape = ShapeUtil::MakeTupleShape(loop_state_shapes); - TF_ASSIGN_OR_RETURN(HloComputation * condition_computation, - BuildConditionComputation(scan, loop_state_shape)); + ASSIGN_OR_RETURN(HloComputation * condition_computation, + BuildConditionComputation(scan, loop_state_shape)); - TF_ASSIGN_OR_RETURN(HloComputation * body_computation, - BuildBodyComputation(scan, loop_state_shape)); + ASSIGN_OR_RETURN(HloComputation * body_computation, + BuildBodyComputation(scan, loop_state_shape)); // 3. Build Init Loop State std::vector init_values; diff --git a/third_party/xla/xla/service/scan_loop_accumulator_input_unification.cc b/third_party/xla/xla/service/scan_loop_accumulator_input_unification.cc index 8771873f293a79..379ffaa2183848 100644 --- a/third_party/xla/xla/service/scan_loop_accumulator_input_unification.cc +++ b/third_party/xla/xla/service/scan_loop_accumulator_input_unification.cc @@ -25,6 +25,7 @@ limitations under the License. #include "absl/log/log.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/analysis/hlo_alias_analysis.h" #include "xla/hlo/analysis/hlo_dataflow_analysis.h" #include "xla/hlo/ir/hlo_computation.h" @@ -236,12 +237,12 @@ absl::StatusOr UnifyAccumulatorWithInput( VLOG(3) << while_instr->name() << " -> " << "tuple_index() << ": " << acc->name() << ", " << "input_@" << input->tuple_index() << ": " << input->name() << ">"; - TF_RETURN_IF_ERROR(input->ReplaceAllUsesWith(acc)); - TF_RETURN_IF_ERROR(while_instr->while_init()->ReplaceOperandWith( + RETURN_IF_ERROR(input->ReplaceAllUsesWith(acc)); + RETURN_IF_ERROR(while_instr->while_init()->ReplaceOperandWith( acc->tuple_index(), while_instr->while_init()->mutable_operand(input->tuple_index()))); if (input->user_count() == 0) { - TF_RETURN_IF_ERROR(while_instr->while_body()->RemoveInstruction(input)); + RETURN_IF_ERROR(while_instr->while_body()->RemoveInstruction(input)); } unified = true; } @@ -257,8 +258,8 @@ absl::StatusOr ScanLoopAccumulatorInputUnification::RunImpl( VLOG(2) << "HLO module before ScanLoopAccumulatorInputUnification:"; XLA_VLOG_LINES(2, module->ToString()); - TF_ASSIGN_OR_RETURN(std::unique_ptr dataflow_analysis, - HloDataflowAnalysis::Run(*module, /*ssa_form=*/true)); + ASSIGN_OR_RETURN(std::unique_ptr dataflow_analysis, + HloDataflowAnalysis::Run(*module, /*ssa_form=*/true)); // This pass can only be applied to unrollable loops since we need to find the // accumulators and inputs that are by definition updated and read fully via @@ -269,15 +270,15 @@ absl::StatusOr ScanLoopAccumulatorInputUnification::RunImpl( // TODO(b/337883537): We might want to simplify compare instructions before // this. It helps us identify more inputs and accumulators. - TF_ASSIGN_OR_RETURN(bool changed, UnifyAccumulatorWithInput( - *dataflow_analysis, unrollable_loops)); + ASSIGN_OR_RETURN(bool changed, UnifyAccumulatorWithInput(*dataflow_analysis, + unrollable_loops)); if (changed) { for (auto& [while_instr, loop_config] : unrollable_loops) { - TF_RETURN_IF_ERROR(TryRemoveDeadWhileParams(while_instr).status()); + RETURN_IF_ERROR(TryRemoveDeadWhileParams(while_instr).status()); } - TF_RETURN_IF_ERROR(TupleSimplifier{}.Run(module).status()); - TF_RETURN_IF_ERROR(module->RemoveUnusedComputations()); + RETURN_IF_ERROR(TupleSimplifier{}.Run(module).status()); + RETURN_IF_ERROR(module->RemoveUnusedComputations()); VLOG(2) << "HLO module after ScanLoopAccumulatorInputUnification:"; XLA_VLOG_LINES(2, module->ToString()); diff --git a/third_party/xla/xla/service/scatter_expander.cc b/third_party/xla/xla/service/scatter_expander.cc index 3b6b1776ec29f8..a529fe96d42307 100644 --- a/third_party/xla/xla/service/scatter_expander.cc +++ b/third_party/xla/xla/service/scatter_expander.cc @@ -25,6 +25,7 @@ limitations under the License. #include "absl/log/log.h" #include "absl/status/statusor.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/comparison_util.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" @@ -57,32 +58,31 @@ static absl::StatusOr CheckIndexValidity( // Check if the index has any negative values. HloInstruction* zero_index = BroadcastZeros( computation, index->shape().element_type(), index->shape().dimensions()); - TF_ASSIGN_OR_RETURN( - HloInstruction * negative_index_check, - MakeCompareHlo(ComparisonDirection::kLe, zero_index, index)); + ASSIGN_OR_RETURN(HloInstruction * negative_index_check, + MakeCompareHlo(ComparisonDirection::kLe, zero_index, index)); // Check if the index is OOB w.r.t. the operand dimensions and window sizes. std::vector max_valid_index(operand_dims.size()); for (int i = 0; i < operand_dims.size(); ++i) { max_valid_index[i] = operand_dims[i] - window_sizes[i]; } - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( HloInstruction * max_valid_index_constant, MakeR1ConstantHlo(computation, index->shape().element_type(), max_valid_index)); - TF_ASSIGN_OR_RETURN(HloInstruction * oob_index_check, - MakeCompareHlo(ComparisonDirection::kGe, - max_valid_index_constant, index)); + ASSIGN_OR_RETURN(HloInstruction * oob_index_check, + MakeCompareHlo(ComparisonDirection::kGe, + max_valid_index_constant, index)); // Combine the results of the two checks above. - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( HloInstruction * valid_index, MakeBinaryHlo(HloOpcode::kAnd, negative_index_check, oob_index_check)); // Reduce the index validity check vector into a scalar predicate. auto reduction_init = computation->AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(true))); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( HloInstruction * valid_index_reduced, MakeReduceHlo(valid_index, reduction_init, HloOpcode::kAnd, module)); @@ -131,23 +131,22 @@ absl::StatusOr> ScatterExpander::ScatterLoopBody( // and transform that to an index into the `operand` space. HloInstruction* index_vector; if (has_scalar_indices) { - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( index_vector, MakeDynamicSliceHlo(scatter_indices, induction_var_as_vector, {1})); } else { - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( HloInstruction * index_into_scatter_indices, PadVectorWithZeros(induction_var_as_vector, /*zeros_to_prepend=*/0, /*zeros_to_append=*/1)); int index_vector_size = scatter_indices->shape().dimensions(1); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( HloInstruction * index_vector_2d, MakeDynamicSliceHlo(scatter_indices, index_into_scatter_indices, {1, index_vector_size})); - TF_ASSIGN_OR_RETURN(index_vector, - ElideDegenerateDims(index_vector_2d, {0})); + ASSIGN_OR_RETURN(index_vector, ElideDegenerateDims(index_vector_2d, {0})); } - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( HloInstruction * scatter_slice_start, ExpandIndexVectorIntoOperandSpace( scatter->scatter_indices()->shape(), @@ -159,7 +158,7 @@ absl::StatusOr> ScatterExpander::ScatterLoopBody( // Extract the slice to be used to update from `updates` tensor for the // induction_var corresponding to this iteration of the while loop. - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( HloInstruction * index_into_updates, PadVectorWithZeros( induction_var_as_vector, /*zeros_to_prepend=*/0, @@ -182,12 +181,12 @@ absl::StatusOr> ScatterExpander::ScatterLoopBody( for (int i = 0, n = operands.size(); i < n; ++i) { HloInstruction* update = updates[i]; - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( HloInstruction * update_slice, MakeDynamicSliceHlo(update, index_into_updates, update_slice_bounds)); - TF_ASSIGN_OR_RETURN(HloInstruction * update_slice_for_scatter, - ElideDegenerateDims(update_slice, {0})); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN(HloInstruction * update_slice_for_scatter, + ElideDegenerateDims(update_slice, {0})); + ASSIGN_OR_RETURN( HloInstruction * update_slice_with_dims_inserted, InsertDegenerateDims(update_slice_for_scatter, degenerated_dims)); update_slices_with_dims_inserted[i] = update_slice_with_dims_inserted; @@ -203,9 +202,9 @@ absl::StatusOr> ScatterExpander::ScatterLoopBody( // Extract the slice to update from `operand` tensor. HloInstruction* operand = operands[i]; const Shape& update_slice_shape = update_slice_with_dims_inserted->shape(); - TF_ASSIGN_OR_RETURN(HloInstruction * operand_slice_to_update, - MakeDynamicSliceHlo(operand, scatter_slice_start, - update_slice_shape.dimensions())); + ASSIGN_OR_RETURN(HloInstruction * operand_slice_to_update, + MakeDynamicSliceHlo(operand, scatter_slice_start, + update_slice_shape.dimensions())); operand_slices_to_update[i] = operand_slice_to_update; if (i == 0) { actual_update_slice_dims = update_slice_shape.dimensions(); @@ -214,7 +213,7 @@ absl::StatusOr> ScatterExpander::ScatterLoopBody( } } - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( HloInstruction * is_index_valid, CheckIndexValidity(operands[0]->parent(), scatter_slice_start, operands[0]->shape().dimensions(), @@ -229,18 +228,18 @@ absl::StatusOr> ScatterExpander::ScatterLoopBody( // computation. // NOTE: For scatters with N outputs, we currently have duplicate the Map // computation N times because we don't support multioutput Map yet. - TF_ASSIGN_OR_RETURN(HloComputation * to_apply, - CallAndGetOutput(scatter->to_apply(), i)); - TF_ASSIGN_OR_RETURN(HloInstruction * updated_operand_slice, - MakeMapHlo(map_operands, to_apply)); + ASSIGN_OR_RETURN(HloComputation * to_apply, + CallAndGetOutput(scatter->to_apply(), i)); + ASSIGN_OR_RETURN(HloInstruction * updated_operand_slice, + MakeMapHlo(map_operands, to_apply)); // Select the updated operand only if the index is valid. If not, select the // original value. - TF_ASSIGN_OR_RETURN(HloInstruction * updates_to_apply, - MakeSelectHlo(is_index_valid, updated_operand_slice, - operand_slices_to_update[i])); - TF_ASSIGN_OR_RETURN(HloInstruction * updated_operand, - MakeDynamicUpdateSliceHlo(operands[i], updates_to_apply, - scatter_slice_start)); + ASSIGN_OR_RETURN(HloInstruction * updates_to_apply, + MakeSelectHlo(is_index_valid, updated_operand_slice, + operand_slices_to_update[i])); + ASSIGN_OR_RETURN(HloInstruction * updated_operand, + MakeDynamicUpdateSliceHlo(operands[i], updates_to_apply, + scatter_slice_start)); updated_loop_state.push_back(updated_operand); } updated_loop_state.push_back(scatter_indices); @@ -297,9 +296,9 @@ absl::StatusOr ScatterExpander::ExpandInstruction( // Canonicalize the scatter_indices, after which the size of its most-major // dimension must be same as the while loop trip count. - TF_ASSIGN_OR_RETURN(HloInstruction * canonical_scatter_indices, - CanonicalizeScatterIndices( - scatter_indices, dim_numbers.index_vector_dim())); + ASSIGN_OR_RETURN(HloInstruction * canonical_scatter_indices, + CanonicalizeScatterIndices(scatter_indices, + dim_numbers.index_vector_dim())); CHECK_EQ(scatter_loop_trip_count, canonical_scatter_indices->shape().dimensions(0)); @@ -308,10 +307,10 @@ absl::StatusOr ScatterExpander::ExpandInstruction( std::vector adjusted_canonical_updates; adjusted_canonical_updates.reserve(scatter_updates.size()); for (HloInstruction* update : scatter_updates) { - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( HloInstruction * canonical_update, PermuteScatterAndWindowDims(update, dim_numbers.update_window_dims())); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( HloInstruction * adjusted_canonical_update, AdjustScatterDims(scatter_indices->shape(), canonical_update, dim_numbers.index_vector_dim())); @@ -334,8 +333,8 @@ absl::StatusOr ScatterExpander::ExpandInstruction( return ScatterLoopBody(scatter, induction_var, loop_state); }, scatter->metadata()); - TF_ASSIGN_OR_RETURN(std::vector scatter_loop_result, - scatter_loop_result_status); + ASSIGN_OR_RETURN(std::vector scatter_loop_result, + scatter_loop_result_status); auto results = absl::MakeSpan(scatter_loop_result).first(scatter_operands.size()); return MaybeMakeTuple(results); diff --git a/third_party/xla/xla/service/scatter_utils.cc b/third_party/xla/xla/service/scatter_utils.cc index 63cb38c62fbd7d..75f64968ad3094 100644 --- a/third_party/xla/xla/service/scatter_utils.cc +++ b/third_party/xla/xla/service/scatter_utils.cc @@ -24,6 +24,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" @@ -96,15 +97,15 @@ absl::StatusOr AdjustScatterDims( absl::StatusOr CanonicalizeScatterIndices( HloInstruction* scatter_indices, int64_t index_vector_dim) { // Transpose the non-index-vector dimensions to the front. - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( HloInstruction * transposed_scatter_indices, TransposeIndexVectorDimToLast(scatter_indices, index_vector_dim)); if (scatter_indices->shape().dimensions().size() == index_vector_dim + 1 && scatter_indices->shape().dimensions(index_vector_dim) == 1) { auto new_shape = ShapeUtil::DeleteDimension(index_vector_dim, scatter_indices->shape()); - TF_ASSIGN_OR_RETURN(scatter_indices, - MakeReshapeHlo(new_shape, scatter_indices)); + ASSIGN_OR_RETURN(scatter_indices, + MakeReshapeHlo(new_shape, scatter_indices)); } bool indices_are_scalar = index_vector_dim == scatter_indices->shape().dimensions().size(); @@ -152,7 +153,7 @@ absl::StatusOr CallAndGetOutput(HloComputation* original, new_comp->AddInstruction( HloInstruction::CreateGetTupleElement(call_original, output_index)), /*accept_different_shape=*/true); - TF_RETURN_IF_ERROR(CallInliner::Inline(call_original).status()); + RETURN_IF_ERROR(CallInliner::Inline(call_original).status()); return new_comp; } @@ -196,7 +197,7 @@ absl::StatusOr CallComputationAndGetIthOutputWithBinaryParams( new_comp->AddInstruction( HloInstruction::CreateGetTupleElement(call_original, output_index)), /*accept_different_shape=*/true); - TF_RETURN_IF_ERROR(CallInliner::Inline(call_original).status()); + RETURN_IF_ERROR(CallInliner::Inline(call_original).status()); return new_comp; } diff --git a/third_party/xla/xla/service/scheduling_annotations_util.cc b/third_party/xla/xla/service/scheduling_annotations_util.cc index d76e54920d8880..20890cf2283894 100644 --- a/third_party/xla/xla/service/scheduling_annotations_util.cc +++ b/third_party/xla/xla/service/scheduling_annotations_util.cc @@ -31,6 +31,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/str_split.h" #include "absl/strings/string_view.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/service/collective_pipeliner_utils.h" #include "xla/side_effect_util.h" @@ -75,10 +76,10 @@ absl::Status VerifyAnnotation(const HloInstruction* instr, "Instruction has more than 2 scheduling annotation fields, inst: ", instr->name(), ", annotation: ", annotation)); } - TF_RETURN_IF_ERROR(verify_integer_or_empty( + RETURN_IF_ERROR(verify_integer_or_empty( annotation_fields[0], "group id", /*verify_non_negative_integer=*/true)); if (annotation_fields.size() == 2) { - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( verify_integer_or_empty(annotation_fields[1], "iteration id")); } return absl::OkStatus(); @@ -96,7 +97,7 @@ absl::StatusOr> ParseAnnotation( if (annotation_str == kXlaNoOpSchedulingGroup) { return std::nullopt; } - TF_RETURN_IF_ERROR(VerifyAnnotation(instr, annotation_str)); + RETURN_IF_ERROR(VerifyAnnotation(instr, annotation_str)); std::vector annotation_fields = absl::StrSplit(annotation_str, delimiter); @@ -133,7 +134,7 @@ absl::StatusOr> GetSchedulingAnnotation( absl::Status SetSchedulingAnnotation(HloInstruction* instr, std::string annotation) { - TF_RETURN_IF_ERROR(VerifyAnnotation(instr, annotation)); + RETURN_IF_ERROR(VerifyAnnotation(instr, annotation)); FrontendAttributes frontend_attributes = instr->frontend_attributes(); if (frontend_attributes.map().contains(kXlaSchedulingGroupIdAttr)) { frontend_attributes.mutable_map()->find(kXlaSchedulingGroupIdAttr)->second = @@ -163,7 +164,7 @@ bool RemoveSchedulingAnnotation(HloInstruction* instr) { absl::StatusOr> GetSchedulingAnnotationIterationId(const HloInstruction* instr) { - TF_ASSIGN_OR_RETURN(auto annotation, ParseAnnotation(instr)); + ASSIGN_OR_RETURN(auto annotation, ParseAnnotation(instr)); if (!annotation.has_value()) { return std::nullopt; } @@ -172,8 +173,8 @@ GetSchedulingAnnotationIterationId(const HloInstruction* instr) { absl::StatusOr RemoveSchedulingAnnotationIterationId( HloInstruction* instr) { - TF_ASSIGN_OR_RETURN(std::optional annotation, - GetSchedulingAnnotation(instr)); + ASSIGN_OR_RETURN(std::optional annotation, + GetSchedulingAnnotation(instr)); if (!annotation || !annotation->iteration_id) { return false; } @@ -182,13 +183,13 @@ absl::StatusOr RemoveSchedulingAnnotationIterationId( return RemoveSchedulingAnnotation(instr); } annotation->iteration_id = std::nullopt; - TF_RETURN_IF_ERROR(SetSchedulingAnnotation(instr, *annotation)); + RETURN_IF_ERROR(SetSchedulingAnnotation(instr, *annotation)); return true; } absl::StatusOr> GetSchedulingAnnotationGroupId( const HloInstruction* instr) { - TF_ASSIGN_OR_RETURN(auto annotation, ParseAnnotation(instr)); + ASSIGN_OR_RETURN(auto annotation, ParseAnnotation(instr)); if (!annotation.has_value()) { return std::nullopt; } @@ -204,8 +205,8 @@ absl::StatusOr NextSchedulingGroupId( int64_t next_scheduling_id = 0; for (const HloComputation* comp : module.computations()) { for (const HloInstruction* hlo : comp->instructions()) { - TF_ASSIGN_OR_RETURN(std::optional scheduling_id, - GetSchedulingAnnotationGroupId(hlo)); + ASSIGN_OR_RETURN(std::optional scheduling_id, + GetSchedulingAnnotationGroupId(hlo)); if (scheduling_id.has_value()) { next_scheduling_id = std::max(next_scheduling_id, scheduling_id.value()); diff --git a/third_party/xla/xla/service/service.cc b/third_party/xla/xla/service/service.cc index 567f136e82a5b7..82125f4b87ee3e 100644 --- a/third_party/xla/xla/service/service.cc +++ b/third_party/xla/xla/service/service.cc @@ -35,6 +35,7 @@ limitations under the License. #include "absl/strings/str_format.h" #include "absl/synchronization/mutex.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/executable_run_options.h" #include "xla/hlo/builder/xla_computation.h" #include "xla/hlo/evaluator/hlo_evaluator.h" @@ -108,7 +109,7 @@ absl::Status RecordArguments( TransferManager* transfer_manager, HloSnapshot* module) { module->clear_arguments(); for (const ShapedBuffer* argument : arguments) { - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( Literal literal, transfer_manager->TransferLiteralFromDevice(stream, *argument)); *module->add_arguments() = literal.ToProto(); @@ -121,9 +122,8 @@ absl::Status RecordResult(const ShapedBuffer& result, se::Stream* stream, TransferManager* transfer_manager, HloSnapshot* module) { module->clear_result(); - TF_ASSIGN_OR_RETURN( - Literal literal, - transfer_manager->TransferLiteralFromDevice(stream, result)); + ASSIGN_OR_RETURN(Literal literal, + transfer_manager->TransferLiteralFromDevice(stream, result)); *module->mutable_result() = literal.ToProto(); return absl::OkStatus(); } @@ -216,8 +216,8 @@ absl::Status Service::Unregister(const GlobalDataHandle& data) { // Deconstructs a previously-allocated global handle. absl::StatusOr>> Service::DeconstructTuple(const GlobalData& data) { - TF_ASSIGN_OR_RETURN(std::vector elements, - allocation_tracker_.DeconstructTuple(data.handle())); + ASSIGN_OR_RETURN(std::vector elements, + allocation_tracker_.DeconstructTuple(data.handle())); std::vector> out; out.reserve(elements.size()); for (GlobalDataHandle& element : elements) { @@ -228,7 +228,7 @@ Service::DeconstructTuple(const GlobalData& data) { absl::Status Service::ValidateResultShape(const Shape& client_shape, const Shape& result_shape) { - TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(client_shape)); + RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(client_shape)); if (!ShapeUtil::Compatible(client_shape, result_shape)) { return InvalidArgument( "Shape used to set computation result layout %s is not compatible " @@ -247,8 +247,8 @@ Service::ResolveAndValidateArguments( std::vector> replicated_arguments; replicated_arguments.resize(options_.number_of_replicas()); for (size_t i = 0; i < arguments.size(); ++i) { - TF_ASSIGN_OR_RETURN(std::vector replicated_buffers, - allocation_tracker_.Resolve(arguments[i]->handle())); + ASSIGN_OR_RETURN(std::vector replicated_buffers, + allocation_tracker_.Resolve(arguments[i]->handle())); CHECK_EQ(options_.number_of_replicas(), replicated_buffers.size()); for (int replica = 0; replica < options_.number_of_replicas(); ++replica) { const ShapedBuffer* shaped_buffer = replicated_buffers[replica]; @@ -298,7 +298,7 @@ Service::BuildExecutables(const HloModuleProto* module_proto, VLOG(1) << StrFormat("BuildExecutable on service %p", this); VLOG(1) << "Computation :" << module_proto->name(); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( auto module, CreateModuleFromProto(*module_proto, *module_config, run_backend_only)); module->set_layout_canonicalization_callback( @@ -310,13 +310,13 @@ Service::BuildExecutables(const HloModuleProto* module_proto, std::vector> executables; if (!run_backend_only) { - TF_ASSIGN_OR_RETURN(executables, - backend->compiler()->Compile( - std::move(module), std::move(executors), options)); + ASSIGN_OR_RETURN(executables, + backend->compiler()->Compile( + std::move(module), std::move(executors), options)); } else { - TF_ASSIGN_OR_RETURN(std::unique_ptr executable, - backend->compiler()->RunBackend(std::move(module), - executors[0], options)); + ASSIGN_OR_RETURN(std::unique_ptr executable, + backend->compiler()->RunBackend(std::move(module), + executors[0], options)); executables.push_back(std::move(executable)); } @@ -333,7 +333,7 @@ Service::BuildAotResults( VLOG(1) << "Computation: " << module_proto->name(); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( auto module, CreateModuleFromProto(*module_proto, *module_config, run_backend_only)); DumpHloModuleIfEnabled(*module, kBeforeOptimizationsDumpName); @@ -343,7 +343,7 @@ Service::BuildAotResults( aot_options.set_device_allocator(options.device_allocator); aot_options.set_run_backend_only(run_backend_only); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( std::vector> aot_results, backend->compiler()->CompileAheadOfTime(std::move(module), aot_options)); return std::move(aot_results); @@ -357,11 +357,10 @@ absl::StatusOr Service::ExecuteAndRegisterResult( // Set up streams. std::vector streams; - TF_ASSIGN_OR_RETURN(auto replicas, Replicas(*backend, device_handle)); + ASSIGN_OR_RETURN(auto replicas, Replicas(*backend, device_handle)); TF_RET_CHECK(!replicas.empty()); for (se::StreamExecutor* executor : replicas) { - TF_ASSIGN_OR_RETURN(StreamPool::Ptr stream, - backend->BorrowStream(executor)); + ASSIGN_OR_RETURN(StreamPool::Ptr stream, backend->BorrowStream(executor)); streams.push_back(std::move(stream)); } @@ -394,8 +393,8 @@ absl::StatusOr Service::ExecuteAndRegisterResult( } if (options_.number_of_replicas() == 1) { - TF_ASSIGN_OR_RETURN(auto result, executable->ExecuteOnStreamWrapper( - run_options.data(), arguments[0])); + ASSIGN_OR_RETURN(auto result, executable->ExecuteOnStreamWrapper( + run_options.data(), arguments[0])); return allocation_tracker_.Register(std::move(result), result_tag); } @@ -406,8 +405,8 @@ absl::StatusOr Service::ExecuteAndRegisterResult( replicated_arguments.push_back(arg); } - TF_ASSIGN_OR_RETURN(auto results, executable->ExecuteOnStreams( - run_options, replicated_arguments)); + ASSIGN_OR_RETURN(auto results, executable->ExecuteOnStreams( + run_options, replicated_arguments)); TF_RET_CHECK(!results.empty()); return allocation_tracker_.RegisterReplicatedBuffers(std::move(results), result_tag); @@ -429,8 +428,7 @@ absl::StatusOr> Service::GetExecutors( } std::vector executors; for (const auto& device_handle : execution_options.device_handles()) { - TF_ASSIGN_OR_RETURN(auto replicas, - Replicas(*execute_backend_, device_handle)); + ASSIGN_OR_RETURN(auto replicas, Replicas(*execute_backend_, device_handle)); se::StreamExecutor* executor = replicas[0]; CHECK(executor != nullptr); executors.push_back(executor); @@ -445,10 +443,10 @@ Service::GetArguments(const ExecutionOptions& execution_options, // a vector of device memory offsets for the arguments from the allocations. // In the case of partitioned computations, assume all arguments go on the // zeroth core. - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( auto replicas, Replicas(*execute_backend_, execution_options.device_handles(0))); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( std::vector> replicated_arguments, ResolveAndValidateArguments(arguments, replicas)); return replicated_arguments; @@ -475,12 +473,12 @@ absl::StatusOr>> Service::ExecuteGraph( << "program shape may not be empty"; // Get the executors. - TF_ASSIGN_OR_RETURN(std::vector executors, - GetExecutors(execution_options, /*requests_size=*/1, - /*request_index=*/0)); + ASSIGN_OR_RETURN(std::vector executors, + GetExecutors(execution_options, /*requests_size=*/1, + /*request_index=*/0)); // Get the replicated arguments. - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( std::vector> replicated_arguments, GetArguments(execution_options, computation.arguments)); @@ -504,11 +502,10 @@ absl::StatusOr>> Service::ExecuteGraph( // the program and the argument allocations. Here, we care only about the // shapes of the arguments, so, it is sufficient to use the arguments of // replica 0. - TF_ASSIGN_OR_RETURN( - ProgramShape program_shape, - ProgramShape::FromProto( - computation.computation.proto().host_program_shape())); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN(ProgramShape program_shape, + ProgramShape::FromProto( + computation.computation.proto().host_program_shape())); + ASSIGN_OR_RETURN( std::unique_ptr module_config, CreateModuleConfig(program_shape, replicated_arguments.front(), computation.execution_options)); @@ -524,7 +521,7 @@ absl::StatusOr>> Service::ExecuteGraph( // // TODO(jlebar): There's currently no way to pass a device allocator to // ExecuteGraph, so we have to pass a null device_allocator below. - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( std::vector> executables, BuildExecutables(&computation.computation.proto(), std::move(module_config), execute_backend_.get(), @@ -540,11 +537,11 @@ absl::StatusOr>> Service::ExecuteGraph( if (executable_ptr->dumping_snapshot()) { *snapshot.mutable_hlo() = *executable_ptr->hlo_proto(); - TF_ASSIGN_OR_RETURN(auto stream, execute_backend_->BorrowStream( - executors[0]->device_ordinal())); - TF_RETURN_IF_ERROR( - RecordArguments(replicated_arguments.front(), stream.get(), - execute_backend_->transfer_manager(), &snapshot)); + ASSIGN_OR_RETURN(auto stream, execute_backend_->BorrowStream( + executors[0]->device_ordinal())); + RETURN_IF_ERROR(RecordArguments(replicated_arguments.front(), stream.get(), + execute_backend_->transfer_manager(), + &snapshot)); } ExecutionProfile profile; @@ -572,7 +569,7 @@ absl::StatusOr>> Service::ExecuteGraph( DumpHloSnapshotIfEnabled(executable_ptr->module(), snapshot); } - TF_RETURN_IF_ERROR(execution_status); + RETURN_IF_ERROR(execution_status); std::vector> out; out.reserve(outputs.size()); @@ -581,13 +578,12 @@ absl::StatusOr>> Service::ExecuteGraph( } if (executable_ptr->dumping_snapshot()) { - TF_ASSIGN_OR_RETURN(const ShapedBuffer* result_buffer, - allocation_tracker_.ResolveForReplica(outputs[0], 0)); - TF_ASSIGN_OR_RETURN(auto stream, - execute_backend_->BorrowStream(executors[0])); - TF_RETURN_IF_ERROR(RecordResult(*result_buffer, stream.get(), - execute_backend_->transfer_manager(), - &snapshot)); + ASSIGN_OR_RETURN(const ShapedBuffer* result_buffer, + allocation_tracker_.ResolveForReplica(outputs[0], 0)); + ASSIGN_OR_RETURN(auto stream, execute_backend_->BorrowStream(executors[0])); + RETURN_IF_ERROR(RecordResult(*result_buffer, stream.get(), + execute_backend_->transfer_manager(), + &snapshot)); DumpHloSnapshotIfEnabled(executable_ptr->module(), snapshot); } @@ -634,7 +630,7 @@ absl::StatusOr> Service::BuildExecutable( return absl::StrFormat("XlaCompile:#module=%s#", module_proto.name()); }}; - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( std::unique_ptr module, CreateModuleFromProto(module_proto, *module_config, run_backend_only)); UpdateEntryComputationLayout( @@ -650,11 +646,11 @@ absl::StatusOr> Service::BuildExecutable( if (DumpingEnabledForHloModule(*module)) { hlo_proto_before_opt = std::make_unique(MakeHloProto(*module)); } - TF_ASSIGN_OR_RETURN(module, backend->compiler()->RunHloPasses( - std::move(module), executor, options)); + ASSIGN_OR_RETURN(module, backend->compiler()->RunHloPasses( + std::move(module), executor, options)); } - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( std::unique_ptr executable, backend->compiler()->RunBackend(std::move(module), executor, options)); @@ -695,16 +691,16 @@ absl::StatusOr Service::Compile( argument_shape_ptrs.push_back(&shape); } - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( ProgramShape program_shape, ProgramShape::FromProto(computation.proto().host_program_shape())); - TF_ASSIGN_OR_RETURN(std::unique_ptr module_config, - CreateModuleConfig(program_shape, argument_shape_ptrs, - &execution_options)); + ASSIGN_OR_RETURN(std::unique_ptr module_config, + CreateModuleConfig(program_shape, argument_shape_ptrs, + &execution_options)); VLOG(3) << "Compile created HloModuleConfig computation layout: " << module_config->entry_computation_layout().ToString(); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( std::unique_ptr executable, BuildExecutable(computation.proto(), std::move(module_config), execute_backend_.get(), @@ -720,13 +716,13 @@ absl::StatusOr> Service::Execute( ExecutionProfile* execution_profile) { VLOG(1) << "running execute request"; - TF_ASSIGN_OR_RETURN(std::shared_ptr executable, - compilation_cache_.LookUp(handle)); + ASSIGN_OR_RETURN(std::shared_ptr executable, + compilation_cache_.LookUp(handle)); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( std::vector replicas, Replicas(*execute_backend_, SingleComputationDeviceHandle())); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( std::vector> replicated_arguments, ResolveAndValidateArguments(arguments, replicas)); @@ -752,19 +748,19 @@ absl::StatusOr> Service::Execute( } } - TF_ASSIGN_OR_RETURN(auto stream, - execute_backend_->BorrowStream( - execute_backend_->default_stream_executor())); + ASSIGN_OR_RETURN(auto stream, + execute_backend_->BorrowStream( + execute_backend_->default_stream_executor())); HloSnapshot snapshot; if (executable->dumping_snapshot()) { *snapshot.mutable_hlo() = *executable->hlo_proto(); snapshot.set_execution_platform(execute_backend_->platform()->Name()); - TF_RETURN_IF_ERROR( - RecordArguments(replicated_arguments.front(), stream.get(), - execute_backend_->transfer_manager(), &snapshot)); + RETURN_IF_ERROR(RecordArguments(replicated_arguments.front(), stream.get(), + execute_backend_->transfer_manager(), + &snapshot)); } - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( GlobalDataHandle output, ExecuteAndRegisterResult( executable.get(), replicated_arguments, execute_backend_.get(), @@ -772,11 +768,11 @@ absl::StatusOr> Service::Execute( absl::StrCat("result of ", executable->name()), execution_profile)); if (executable->dumping_snapshot()) { - TF_ASSIGN_OR_RETURN(const ShapedBuffer* result_buffer, - allocation_tracker_.ResolveForReplica(output, 0)); - TF_RETURN_IF_ERROR(RecordResult(*result_buffer, stream.get(), - execute_backend_->transfer_manager(), - &snapshot)); + ASSIGN_OR_RETURN(const ShapedBuffer* result_buffer, + allocation_tracker_.ResolveForReplica(output, 0)); + RETURN_IF_ERROR(RecordResult(*result_buffer, stream.get(), + execute_backend_->transfer_manager(), + &snapshot)); DumpHloSnapshotIfEnabled(executable->module(), snapshot); } @@ -785,8 +781,8 @@ absl::StatusOr> Service::Execute( absl::StatusOr Service::TransferToClient( const GlobalData& data, const Shape* shape_with_layout) { - TF_ASSIGN_OR_RETURN(const ShapedBuffer* shaped_buffer, - allocation_tracker_.ResolveForReplica(data.handle(), 0)); + ASSIGN_OR_RETURN(const ShapedBuffer* shaped_buffer, + allocation_tracker_.ResolveForReplica(data.handle(), 0)); Shape return_shape; if (shape_with_layout) { @@ -809,11 +805,10 @@ absl::StatusOr Service::TransferToClient( } } - TF_ASSIGN_OR_RETURN( - auto stream, - execute_backend_->BorrowStream(shaped_buffer->physical_device_ordinal())); + ASSIGN_OR_RETURN(auto stream, execute_backend_->BorrowStream( + shaped_buffer->physical_device_ordinal())); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( Literal result_literal, execute_backend_->transfer_manager()->TransferLiteralFromDevice( stream.get(), *shaped_buffer)); @@ -829,9 +824,9 @@ absl::StatusOr> Service::TransferToServer( const Shape& shape = literal_slice.shape(); std::vector replicas; if (device_handle) { - TF_ASSIGN_OR_RETURN(replicas, Replicas(*execute_backend_, *device_handle)); + ASSIGN_OR_RETURN(replicas, Replicas(*execute_backend_, *device_handle)); } else { - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( replicas, Replicas(*execute_backend_, SingleComputationDeviceHandle())); } @@ -843,23 +838,23 @@ absl::StatusOr> Service::TransferToServer( return execute_backend_->compiler()->DefaultDeviceShapeRepresentation( shape); }; - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( ScopedShapedBuffer shaped_buffer, execute_backend_->transfer_manager()->AllocateScopedShapedBuffer( shape, execute_backend_->memory_allocator(), executor->device_ordinal(), device_shape_representation_fn)); - TF_ASSIGN_OR_RETURN(auto stream, execute_backend_->BorrowStream(executor)); - TF_RETURN_IF_ERROR( + ASSIGN_OR_RETURN(auto stream, execute_backend_->BorrowStream(executor)); + RETURN_IF_ERROR( execute_backend_->transfer_manager()->TransferLiteralToDevice( stream.get(), literal_slice, shaped_buffer)); replicated_buffers.emplace_back(std::move(shaped_buffer)); } - TF_ASSIGN_OR_RETURN(GlobalDataHandle out, - allocation_tracker_.RegisterReplicatedBuffers( - std::move(replicated_buffers), - StrCat("TransferToServer literal of shape ", - ShapeUtil::HumanString(shape)))); + ASSIGN_OR_RETURN(GlobalDataHandle out, + allocation_tracker_.RegisterReplicatedBuffers( + std::move(replicated_buffers), + StrCat("TransferToServer literal of shape ", + ShapeUtil::HumanString(shape)))); return std::make_unique(this, out); } @@ -878,13 +873,12 @@ absl::Status Service::TransferToInfeed(const LiteralSlice& literal, se::StreamExecutor* executor; if (device_handle) { - TF_ASSIGN_OR_RETURN(auto replicas, - Replicas(*execute_backend_, *device_handle)); + ASSIGN_OR_RETURN(auto replicas, + Replicas(*execute_backend_, *device_handle)); executor = replicas[replica_id]; } else { - TF_ASSIGN_OR_RETURN( - auto replicas, - Replicas(*execute_backend_, SingleComputationDeviceHandle())); + ASSIGN_OR_RETURN(auto replicas, Replicas(*execute_backend_, + SingleComputationDeviceHandle())); executor = replicas[replica_id]; } @@ -904,19 +898,18 @@ absl::StatusOr Service::TransferFromOutfeed( se::StreamExecutor* executor; if (device_handle) { - TF_ASSIGN_OR_RETURN(auto replicas, - Replicas(*execute_backend_, *device_handle)); + ASSIGN_OR_RETURN(auto replicas, + Replicas(*execute_backend_, *device_handle)); executor = replicas[replica_id]; } else { - TF_ASSIGN_OR_RETURN( - auto replicas, - Replicas(*execute_backend_, SingleComputationDeviceHandle())); + ASSIGN_OR_RETURN(auto replicas, Replicas(*execute_backend_, + SingleComputationDeviceHandle())); executor = replicas[replica_id]; } auto literal = Literal::CreateFromShape(*shape_with_layout); - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( execute_backend_->transfer_manager()->TransferLiteralFromOutfeed( executor, &literal)); return literal; @@ -934,25 +927,25 @@ absl::StatusOr Service::ComputeConstantGraph( "constant computation may not depend on any parameters."); } - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( ProgramShape program_shape, ProgramShape::FromProto(computation.proto().host_program_shape())); DCHECK_OK(ShapeUtil::ValidateShape(program_shape.result())); if (output_layout) { - TF_RETURN_IF_ERROR(LayoutUtil::ValidateLayoutForShape( - *output_layout, program_shape.result())); + RETURN_IF_ERROR(LayoutUtil::ValidateLayoutForShape(*output_layout, + program_shape.result())); } HloModuleConfig config(program_shape); - TF_ASSIGN_OR_RETURN(std::unique_ptr module, - CreateModuleFromProto(computation.proto(), config)); + ASSIGN_OR_RETURN(std::unique_ptr module, + CreateModuleFromProto(computation.proto(), config)); DynamicPadder dynamic_padder; - TF_RETURN_IF_ERROR(dynamic_padder.Run(module.get()).status()); + RETURN_IF_ERROR(dynamic_padder.Run(module.get()).status()); - TF_ASSIGN_OR_RETURN(DynamicDimensionInference dynamic_dimension_inference, - DynamicDimensionInference::Run(module.get())); + ASSIGN_OR_RETURN(DynamicDimensionInference dynamic_dimension_inference, + DynamicDimensionInference::Run(module.get())); HloEvaluator evaluator; evaluator.set_dynamic_dimension_inference(&dynamic_dimension_inference); @@ -970,7 +963,7 @@ absl::StatusOr Service::ComputeConstantGraph( custom_call->custom_call_target(), custom_call->ToString()); }); - TF_ASSIGN_OR_RETURN(auto result_literal, evaluator.Evaluate(*module, {})); + ASSIGN_OR_RETURN(auto result_literal, evaluator.Evaluate(*module, {})); // Since the result layout is non-effective to the Evaluator results, explicit // relayout here. @@ -983,8 +976,8 @@ absl::StatusOr Service::ComputeConstantGraph( } absl::StatusOr Service::GetShape(const GlobalData& data) { - TF_ASSIGN_OR_RETURN(const ShapedBuffer* buffer, - allocation_tracker_.ResolveForReplica(data.handle(), 0)); + ASSIGN_OR_RETURN(const ShapedBuffer* buffer, + allocation_tracker_.ResolveForReplica(data.handle(), 0)); return buffer->on_device_shape(); } @@ -997,7 +990,7 @@ DeviceHandle Service::SingleComputationDeviceHandle() const { absl::StatusOr> Service::Replicas( const Backend& backend, const DeviceHandle& device_handle) const { - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( DeviceAssignment da, backend.computation_placer()->AssignDevices( options_.number_of_replicas(), device_handle.device_count())); @@ -1006,7 +999,7 @@ absl::StatusOr> Service::Replicas( // From the computation placer, find out the device ids of the replicas for // the given device handle. int64_t device_ordinal = da.DeviceId(replica, device_handle.handle()); - TF_ASSIGN_OR_RETURN(auto executor, backend.stream_executor(device_ordinal)); + ASSIGN_OR_RETURN(auto executor, backend.stream_executor(device_ordinal)); replicas.push_back(executor); } return replicas; diff --git a/third_party/xla/xla/service/service_executable_run_options.h b/third_party/xla/xla/service/service_executable_run_options.h index 98c983ce4f0b39..a25eee4e06bd85 100644 --- a/third_party/xla/xla/service/service_executable_run_options.h +++ b/third_party/xla/xla/service/service_executable_run_options.h @@ -22,6 +22,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/executable_run_options.h" #include "xla/service/stream_pool.h" #include "xla/stream_executor/platform.h" @@ -71,7 +72,7 @@ class ServiceExecutableRunOptions { "No stream borrower"); } - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( std::vector streams, stream_borrower_(device_ordinal, /*num_streams=*/1, priority)); StreamPool::Ptr stream = std::move(streams.back()); diff --git a/third_party/xla/xla/service/shape_inference.cc b/third_party/xla/xla/service/shape_inference.cc index d2184d441db9ef..a4585365846172 100644 --- a/third_party/xla/xla/service/shape_inference.cc +++ b/third_party/xla/xla/service/shape_inference.cc @@ -38,6 +38,7 @@ limitations under the License. #include "absl/strings/str_join.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/overflow_util.h" @@ -342,7 +343,7 @@ absl::StatusOr InferMostSpecificDimAndBound(int64_t dim, return shape; } - TF_RETURN_IF_ERROR(ExpectArray(shape, "operand of unary operation")); + RETURN_IF_ERROR(ExpectArray(shape, "operand of unary operation")); DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(shape)); switch (opcode) { @@ -474,7 +475,7 @@ absl::StatusOr InferMostSpecificDimAndBound(int64_t dim, /* static */ absl::StatusOr ShapeInference::InferTopKShape( const Shape& operand_shape, int64_t k) { - TF_RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of top-k operation")); + RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of top-k operation")); int64_t last_dim = static_cast(operand_shape.dimensions().size()) - 1; std::vector is_dynamic(operand_shape.dimensions().size()); @@ -507,7 +508,7 @@ absl::StatusOr InferMostSpecificDimAndBound(int64_t dim, const Shape* arg_shape = nullptr; PrimitiveType element_type = PRIMITIVE_TYPE_INVALID; for (const Shape* shape : arg_shapes) { - TF_RETURN_IF_ERROR(ExpectArray(*shape, "operand of concatenation")); + RETURN_IF_ERROR(ExpectArray(*shape, "operand of concatenation")); if (!arg_shape) { arg_shape = shape; element_type = arg_shape->element_type(); @@ -569,10 +570,9 @@ absl::StatusOr InferMostSpecificDimAndBound(int64_t dim, inferred_dim_and_bound = InferConcatenatedDimAndBound( leftSize, rightSize, leftBound, rightBound); } else { - TF_ASSIGN_OR_RETURN( - inferred_dim_and_bound, - InferMostSpecificDimAndBound(dim, leftSize, rightSize, leftBound, - rightBound)); + ASSIGN_OR_RETURN(inferred_dim_and_bound, + InferMostSpecificDimAndBound(dim, leftSize, rightSize, + leftBound, rightBound)); } inferred_sizes[dim] = inferred_dim_and_bound.dimension; inferred_bounds[dim] = inferred_dim_and_bound.bound; @@ -666,9 +666,9 @@ absl::StatusOr InferMostSpecificDimAndBound(int64_t dim, DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(operand_shape)); DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(random_shape)); - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( ExpectArray(operand_shape, "lhs of stochastic convert operation")); - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( ExpectArray(random_shape, "rhs of stochastic convert operation")); if (!primitive_util::IsUnsignedIntegralType(random_shape.element_type())) { @@ -954,14 +954,14 @@ void GenerateDotResultDimensions( const Shape& lhs, const Shape& rhs, const DotDimensionNumbers& dimension_numbers, std::optional preferred_element_type) { - TF_RETURN_IF_ERROR(ExpectArray(lhs, "lhs of dot")); - TF_RETURN_IF_ERROR(ExpectArray(rhs, "rhs of dot")); + RETURN_IF_ERROR(ExpectArray(lhs, "lhs of dot")); + RETURN_IF_ERROR(ExpectArray(rhs, "rhs of dot")); // Validate basic properties of dot dimension numbers. - TF_RETURN_IF_ERROR(ValidateDotDimensionNumbers(lhs, rhs, dimension_numbers)); + RETURN_IF_ERROR(ValidateDotDimensionNumbers(lhs, rhs, dimension_numbers)); // Check the number and sizes of batch and contracting dimensions. - TF_RETURN_IF_ERROR(CheckDotDimensionConstraints(lhs, rhs, dimension_numbers)); + RETURN_IF_ERROR(CheckDotDimensionConstraints(lhs, rhs, dimension_numbers)); std::vector dimensions; std::vector is_dynamic; @@ -999,9 +999,9 @@ void GenerateDotResultDimensions( const Shape& lhs, const Shape& rhs, const Shape& group_sizes, const RaggedDotDimensionNumbers& ragged_dot_dim_nums, std::optional preferred_element_type) { - TF_RETURN_IF_ERROR(ExpectArray(lhs, "lhs of ragged dot")); - TF_RETURN_IF_ERROR(ExpectArray(rhs, "rhs of ragged dot")); - TF_RETURN_IF_ERROR(ExpectArray(group_sizes, "group_sizes of ragged dot")); + RETURN_IF_ERROR(ExpectArray(lhs, "lhs of ragged dot")); + RETURN_IF_ERROR(ExpectArray(rhs, "rhs of ragged dot")); + RETURN_IF_ERROR(ExpectArray(group_sizes, "group_sizes of ragged dot")); auto fail = [lhs, rhs](const std::string& addendum) -> absl::Status { std::string message = StrFormat( @@ -1017,9 +1017,9 @@ void GenerateDotResultDimensions( ragged_dot_dim_nums.dot_dimension_numbers(); // Validate basic properties of dot dimension numbers. - TF_RETURN_IF_ERROR(ValidateDotDimensionNumbers(lhs, rhs, dimension_numbers)); + RETURN_IF_ERROR(ValidateDotDimensionNumbers(lhs, rhs, dimension_numbers)); // Check the number and sizes of batch and contracting dimensions. - TF_RETURN_IF_ERROR(CheckDotDimensionConstraints(lhs, rhs, dimension_numbers)); + RETURN_IF_ERROR(CheckDotDimensionConstraints(lhs, rhs, dimension_numbers)); // Check that there is exactly one lhs ragged dimension. if (ragged_dot_dim_nums.lhs_ragged_dimensions_size() != 1) { @@ -1373,8 +1373,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(const Shape& lhs, ShapeInference::InferElementwiseBinaryOpShape( HloOpcode operation, const Shape& lhs, const Shape& rhs, absl::Span broadcast_dimensions) { - TF_RETURN_IF_ERROR(ExpectArray(lhs, "lhs of elementwise binary operation")); - TF_RETURN_IF_ERROR(ExpectArray(rhs, "rhs of elementwise binary operation")); + RETURN_IF_ERROR(ExpectArray(lhs, "lhs of elementwise binary operation")); + RETURN_IF_ERROR(ExpectArray(rhs, "rhs of elementwise binary operation")); if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(lhs, rhs)) { return InvalidArgument( @@ -1426,9 +1426,9 @@ ShapeInference::InferElementwiseBinaryOpShape( lhs.dimensions().size() > rhs.dimensions().size() ? rhs : lhs; // After InDim broadcasting, perform degenerate dimensions broadcasting. - TF_ASSIGN_OR_RETURN(Shape indim_broadcast_shape, - InferInDimBroadcastShape(smaller_shape, larger_shape, - broadcast_dimensions)); + ASSIGN_OR_RETURN(Shape indim_broadcast_shape, + InferInDimBroadcastShape(smaller_shape, larger_shape, + broadcast_dimensions)); return InferDegenerateDimensionBroadcastShape(indim_broadcast_shape, larger_shape); } @@ -1473,9 +1473,9 @@ ShapeInference::InferScalarBroadcastShape(absl::Span shapes) { DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(lhs)); DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(rhs)); - TF_RETURN_IF_ERROR(ExpectArray( + RETURN_IF_ERROR(ExpectArray( lhs, absl::StrCat("lhs of binary operation ", HloOpcodeString(opcode)))); - TF_RETURN_IF_ERROR(ExpectArray( + RETURN_IF_ERROR(ExpectArray( rhs, absl::StrCat("rhs of binary operation ", HloOpcodeString(opcode)))); switch (opcode) { case HloOpcode::kAdd: @@ -1509,9 +1509,9 @@ ShapeInference::InferScalarBroadcastShape(absl::Span shapes) { "operation; got %s.", PrimitiveType_Name(lhs.element_type())); } - TF_ASSIGN_OR_RETURN(const Shape& shape, - InferElementwiseBinaryOpShape(opcode, lhs, rhs, - broadcast_dimensions)); + ASSIGN_OR_RETURN(const Shape& shape, + InferElementwiseBinaryOpShape(opcode, lhs, rhs, + broadcast_dimensions)); if (lhs.element_type() == F32 && rhs.element_type() == F32) { return ShapeUtil::ChangeElementType(shape, C64); } else if (lhs.element_type() == F64 && rhs.element_type() == F64) { @@ -1533,9 +1533,9 @@ ShapeInference::InferScalarBroadcastShape(absl::Span shapes) { return InferElementwiseBinaryOpShape(opcode, lhs, rhs, broadcast_dimensions); case HloOpcode::kCompare: { - TF_ASSIGN_OR_RETURN(const Shape& shape, - InferElementwiseBinaryOpShape(opcode, lhs, rhs, - broadcast_dimensions)); + ASSIGN_OR_RETURN(const Shape& shape, + InferElementwiseBinaryOpShape(opcode, lhs, rhs, + broadcast_dimensions)); return ShapeUtil::ChangeElementType(shape, PRED); } default: @@ -1623,7 +1623,7 @@ ShapeInference::InferScalarBroadcastShape(absl::Span shapes) { // All arguments must have the same shape ignoring the element types. const Shape* arg_shape = arg_shapes[0]; for (size_t i = 1; i < arg_shapes.size(); ++i) { - TF_RETURN_IF_ERROR(ExpectArray(*arg_shapes[i], "operand of map")); + RETURN_IF_ERROR(ExpectArray(*arg_shapes[i], "operand of map")); if (ShapeUtil::CompatibleIgnoringElementType(*arg_shapes[i], *arg_shape)) { continue; @@ -1711,11 +1711,10 @@ ShapeInference::InferScalarBroadcastShape(absl::Span shapes) { /* static */ absl::StatusOr ShapeInference::InferBatchNormTrainingShape( const Shape& operand_shape, const Shape& scale_shape, const Shape& offset_shape, int64_t feature_index) { - TF_RETURN_IF_ERROR( - ExpectArray(operand_shape, "operand of batch norm training")); - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of batch norm training")); + RETURN_IF_ERROR( ExpectArray(offset_shape, "offset input of batch norm training")); - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( ExpectArray(scale_shape, "scale input of batch norm training")); TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(operand_shape) == @@ -1821,19 +1820,18 @@ ShapeInference::InferScalarBroadcastShape(absl::Span shapes) { const Shape& operand_shape, const Shape& scale_shape, const Shape& offset_shape, const Shape& mean_shape, const Shape& variance_shape, int64_t feature_index) { - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( ExpectArray(operand_shape, "operand of batch norm inference")); - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( ExpectArray(offset_shape, "offset input of batch norm inference")); - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( ExpectArray(scale_shape, "scale input of batch norm inference")); - TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(operand_shape)); - TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(offset_shape)); - TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(scale_shape)); - TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(mean_shape)); - TF_RETURN_IF_ERROR( - ShapeUtil::ValidateShapeWithOptionalLayout(variance_shape)); + RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(operand_shape)); + RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(offset_shape)); + RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(scale_shape)); + RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(mean_shape)); + RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(variance_shape)); if (feature_index >= static_cast(operand_shape.dimensions().size())) { @@ -1968,19 +1966,18 @@ ShapeInference::InferScalarBroadcastShape(absl::Span shapes) { const Shape& operand_shape, const Shape& scale_shape, const Shape& mean_shape, const Shape& var_shape, const Shape& output_grad_shape, int64_t feature_index) { - TF_RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of batch norm grad")); - TF_RETURN_IF_ERROR( - ExpectArray(scale_shape, "scale input of batch norm grad")); - TF_RETURN_IF_ERROR(ExpectArray(mean_shape, "mean input of batch norm grad")); - TF_RETURN_IF_ERROR(ExpectArray(var_shape, "var input of batch norm grad")); - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of batch norm grad")); + RETURN_IF_ERROR(ExpectArray(scale_shape, "scale input of batch norm grad")); + RETURN_IF_ERROR(ExpectArray(mean_shape, "mean input of batch norm grad")); + RETURN_IF_ERROR(ExpectArray(var_shape, "var input of batch norm grad")); + RETURN_IF_ERROR( ExpectArray(output_grad_shape, "output_grad input of batch norm grad")); - TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(operand_shape)); - TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(mean_shape)); - TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(scale_shape)); - TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(var_shape)); - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(operand_shape)); + RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(mean_shape)); + RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(scale_shape)); + RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(var_shape)); + RETURN_IF_ERROR( ShapeUtil::ValidateShapeWithOptionalLayout(output_grad_shape)); if (feature_index >= @@ -2145,8 +2142,8 @@ ShapeInference::InferScalarBroadcastShape(absl::Span shapes) { } const Shape& rhs = *rhs_ptr; - TF_RETURN_IF_ERROR(ExpectArray(lhs, "lhs of convolution")); - TF_RETURN_IF_ERROR(ExpectArray(rhs, "rhs of convolution")); + RETURN_IF_ERROR(ExpectArray(lhs, "lhs of convolution")); + RETURN_IF_ERROR(ExpectArray(rhs, "rhs of convolution")); if (feature_group_count <= 0) { return InvalidArgument( @@ -2362,7 +2359,7 @@ ShapeInference::InferScalarBroadcastShape(absl::Span shapes) { } Shape base_shape = ShapeUtil::MakeShape( lhs.element_type(), input_spatial_dims, dynamic_dimensions); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( Shape window_output_shape, InferWindowOutputShape(base_shape, window, lhs.element_type())); @@ -2600,7 +2597,7 @@ ShapeInference::InferScalarBroadcastShape(absl::Span shapes) { for (const Shape* operand_shape : operand_shapes) { TF_RET_CHECK(all_gather_dimension < static_cast(operand_shape->dimensions().size())); - TF_RETURN_IF_ERROR(ExpectArray(*operand_shape, "operand of all-gather")); + RETURN_IF_ERROR(ExpectArray(*operand_shape, "operand of all-gather")); Shape output_shape = *operand_shape; int64_t output_shape_dimension = @@ -2629,7 +2626,7 @@ ShapeInference::InferScalarBroadcastShape(absl::Span shapes) { /* static */ absl::StatusOr ShapeInference::InferAllGatherStartShape( absl::Span operand_shapes, int64_t all_gather_dimension, int64_t shard_count) { - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( Shape ag_shape, InferAllGatherShape(operand_shapes, all_gather_dimension, shard_count)); Shape input_shape; @@ -2649,7 +2646,7 @@ ShapeInference::InferScalarBroadcastShape(absl::Span shapes) { /* static */ absl::StatusOr ShapeInference::InferAllReduceShape( absl::Span operand_shapes) { for (const Shape* operand_shape : operand_shapes) { - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( ExpectArray(*operand_shape, "operand of cross replica sum")); } if (operand_shapes.size() == 1) { @@ -2669,8 +2666,7 @@ ShapeInference::InferScalarBroadcastShape(absl::Span shapes) { for (const Shape* operand_shape : operand_shapes) { TF_RET_CHECK(scatter_dimension < static_cast(operand_shape->dimensions().size())); - TF_RETURN_IF_ERROR( - ExpectArray(*operand_shape, "operand of reduce-scatter")); + RETURN_IF_ERROR(ExpectArray(*operand_shape, "operand of reduce-scatter")); int64_t scatter_dim_input_size = operand_shape->dimensions(scatter_dimension); @@ -2780,7 +2776,7 @@ ShapeInference::InferScalarBroadcastShape(absl::Span shapes) { /* static */ absl::StatusOr ShapeInference::InferRaggedAllToAllShape( absl::Span operand_shapes) { - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( ExpectArray(*(operand_shapes[1]), "operand 1 of ragged-all-to-all")); return *(operand_shapes[1]); } @@ -2788,7 +2784,7 @@ ShapeInference::InferScalarBroadcastShape(absl::Span shapes) { /* static */ absl::StatusOr ShapeInference::InferCollectiveBroadcastShape( absl::Span operand_shapes) { - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( ExpectArray(*(operand_shapes[0]), "operand of collective-broadcast")); return *(operand_shapes[0]); } @@ -2797,7 +2793,7 @@ ShapeInference::InferCollectiveBroadcastShape( absl::Span operand_shapes, bool inplace) { if (!inplace) { for (const Shape* operand_shape : operand_shapes) { - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( ExpectArray(*operand_shape, "operand of collective-permute")); } if (operand_shapes.size() == 1) { @@ -2817,8 +2813,8 @@ ShapeInference::InferCollectivePermuteStartShape( absl::InlinedVector shapes; if (!inplace) { if (operand_shapes.size() == 1) { - TF_RETURN_IF_ERROR(ExpectArray(*(operand_shapes[0]), - "operand of collective-permute-start")); + RETURN_IF_ERROR(ExpectArray(*(operand_shapes[0]), + "operand of collective-permute-start")); shapes = {*operand_shapes[0], *operand_shapes[0]}; } else { Shape tuple_shape = ShapeUtil::MakeTupleShapeWithPtrs(operand_shapes); @@ -2881,8 +2877,8 @@ ShapeInference::InferCollectivePermuteDoneShape(const Shape& operand_shape) { for (const Shape* arg : reduced_args) { element_types.push_back(arg->element_type()); } - TF_RETURN_IF_ERROR(VerifyReducerShape(to_apply, init_values, element_types, - num_reduced_args)); + RETURN_IF_ERROR(VerifyReducerShape(to_apply, init_values, element_types, + num_reduced_args)); absl::flat_hash_set dimensions_to_reduce_set; for (int64_t dim_to_reduce : dimensions_to_reduce) { @@ -2919,9 +2915,9 @@ ShapeInference::InferCollectivePermuteDoneShape(const Shape& operand_shape) { /* static */ absl::StatusOr ShapeInference::InferReduceWindowShape( const Shape& operand_shape, const Shape& init_value_shape, const Window& window, const ProgramShape& to_apply_shape) { - TF_RETURN_IF_ERROR(VerifyReducerShape(to_apply_shape, {&init_value_shape}, - {operand_shape.element_type()}, - /*inputs=*/1)); + RETURN_IF_ERROR(VerifyReducerShape(to_apply_shape, {&init_value_shape}, + {operand_shape.element_type()}, + /*inputs=*/1)); return InferReduceWindowShape(operand_shape, init_value_shape, window); } @@ -2946,14 +2942,14 @@ ShapeInference::InferCollectivePermuteDoneShape(const Shape& operand_shape) { for (const Shape* s : operands) { operand_element_type_vec.push_back(s->element_type()); } - TF_RETURN_IF_ERROR(VerifyReducerShape(to_apply_shape, init_values, - operand_element_type_vec, - /*inputs=*/number_of_input)); + RETURN_IF_ERROR(VerifyReducerShape(to_apply_shape, init_values, + operand_element_type_vec, + /*inputs=*/number_of_input)); std::vector output_shape_vec; const size_t n = operands.size(); output_shape_vec.reserve(n); for (size_t i = 0; i < operands.size(); ++i) { - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( auto cur_output_shape, InferReduceWindowShape(*operands[i], *init_values[i], window)); output_shape_vec.push_back(cur_output_shape); @@ -2969,7 +2965,7 @@ ShapeInference::InferCollectivePermuteDoneShape(const Shape& operand_shape) { /* static */ absl::StatusOr ShapeInference::InferReduceWindowShape( const Shape& operand_shape, const Shape& init_value_shape, const Window& window) { - TF_RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of reduce-window")); + RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of reduce-window")); return InferWindowOutputShape(operand_shape, window, init_value_shape.element_type()); } @@ -2978,8 +2974,7 @@ ShapeInference::InferCollectivePermuteDoneShape(const Shape& operand_shape) { const Shape& operand_shape, const ProgramShape& select_shape, const Window& window, const Shape& source_shape, const Shape& init_value_shape, const ProgramShape& scatter_shape) { - TF_RETURN_IF_ERROR( - ExpectArray(operand_shape, "operand of select-and-scatter")); + RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of select-and-scatter")); // Check if the select function has a proper shape of (T,T) -> PRED. if (select_shape.parameters_size() != 2) { @@ -3013,14 +3008,14 @@ ShapeInference::InferCollectivePermuteDoneShape(const Shape& operand_shape) { } // Check if the scatter function has a proper shape as a reduction. - TF_RETURN_IF_ERROR(VerifyReducerShape(scatter_shape, {&init_value_shape}, - {source_shape.element_type()}, - /*inputs=*/1)); + RETURN_IF_ERROR(VerifyReducerShape(scatter_shape, {&init_value_shape}, + {source_shape.element_type()}, + /*inputs=*/1)); // Check if the result shape of window operation matches the source shape. - TF_ASSIGN_OR_RETURN(const Shape& window_result_shape, - InferWindowOutputShape(operand_shape, window, - operand_shape.element_type())); + ASSIGN_OR_RETURN(const Shape& window_result_shape, + InferWindowOutputShape(operand_shape, window, + operand_shape.element_type())); if (!ShapeUtil::CompatibleIgnoringFpPrecision(source_shape, window_result_shape)) { return InvalidArgument( @@ -3098,12 +3093,12 @@ ShapeInference::InferCollectivePermuteDoneShape(const Shape& operand_shape) { "\nNumber of ", x_name, ": ", x, "\n")); } }; - TF_RETURN_IF_ERROR(verify_size(window_strides.size(), "window strides")); - TF_RETURN_IF_ERROR(verify_size(padding.size(), "padding entries")); - TF_RETURN_IF_ERROR(verify_size(lhs_dilation.size(), "lhs dilation factors")); - TF_RETURN_IF_ERROR(verify_size(rhs_dilation.size(), "rhs dilation factors")); + RETURN_IF_ERROR(verify_size(window_strides.size(), "window strides")); + RETURN_IF_ERROR(verify_size(padding.size(), "padding entries")); + RETURN_IF_ERROR(verify_size(lhs_dilation.size(), "lhs dilation factors")); + RETURN_IF_ERROR(verify_size(rhs_dilation.size(), "rhs dilation factors")); if (window_reversal.has_value()) { - TF_RETURN_IF_ERROR(verify_size(window_reversal->size(), "window reversal")); + RETURN_IF_ERROR(verify_size(window_reversal->size(), "window reversal")); } Window window; @@ -3151,7 +3146,7 @@ ShapeInference::InferCollectivePermuteDoneShape(const Shape& operand_shape) { message, ShapeUtil::HumanString(arg), StrJoin(starts, ","), StrJoin(limits, ","), StrJoin(strides, ",")); }; - TF_RETURN_IF_ERROR(ExpectArray(arg, "operand of slice")); + RETURN_IF_ERROR(ExpectArray(arg, "operand of slice")); VLOG(2) << StrFormat("slicing shape %s starts={%s} limits={%s}", ShapeUtil::HumanString(arg), StrJoin(starts, ", "), StrJoin(limits, ", ")); @@ -3219,7 +3214,7 @@ ShapeInference::InferCollectivePermuteDoneShape(const Shape& operand_shape) { /* static */ absl::StatusOr ShapeInference::InferDynamicSliceShape( const Shape& operand_shape, absl::Span start_index_shapes, absl::Span slice_sizes, bool allow_scalar_indices) { - TF_RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of dynamic slice")); + RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of dynamic slice")); auto number_of_indices = start_index_shapes.size(); // TODO(b/118437727): Remove this path. if (!allow_scalar_indices || @@ -3238,7 +3233,7 @@ ShapeInference::InferCollectivePermuteDoneShape(const Shape& operand_shape) { ShapeUtil::HumanString(start_indices_shape), StrJoin(slice_sizes, ", ")); - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( ExpectArray(start_indices_shape, "start indices of dynamic slice")); if (start_indices_shape.dimensions().size() != 1) { @@ -3343,8 +3338,7 @@ ShapeInference::InferCollectivePermuteDoneShape(const Shape& operand_shape) { "but got %s.", ShapeUtil::HumanString(operand_shape)); } - TF_RETURN_IF_ERROR( - ExpectArray(update_shape, "update of dynamic update slice")); + RETURN_IF_ERROR(ExpectArray(update_shape, "update of dynamic update slice")); auto number_of_indices = start_index_shapes.size(); // TODO(b/118437727): Remove this path. @@ -3357,8 +3351,8 @@ ShapeInference::InferCollectivePermuteDoneShape(const Shape& operand_shape) { number_of_indices); } const Shape& start_indices_shape = start_index_shapes[0]; - TF_RETURN_IF_ERROR(ExpectArray(start_indices_shape, - "start indices of dynamic update slice")); + RETURN_IF_ERROR(ExpectArray(start_indices_shape, + "start indices of dynamic update slice")); VLOG(2) << StrFormat( "updating slice of shape %s at dynamic start_indices %s with update " @@ -3479,7 +3473,7 @@ ShapeInference::InferCollectivePermuteDoneShape(const Shape& operand_shape) { /*static */ absl::StatusOr ShapeInference::InferReverseShape( const Shape& operand_shape, absl::Span dimensions) { - TF_RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of reverse")); + RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of reverse")); if (!AllUnique(dimensions)) { return InvalidArgument("a dimension number is duplicated in reverse"); } @@ -3626,7 +3620,7 @@ ShapeInference::InferCollectivePermuteDoneShape(const Shape& operand_shape) { /* static */ absl::StatusOr ShapeInference::InferBroadcastShape( const Shape& operand, absl::Span broadcast_sizes) { // This method is used to infer shape for xla::BroadcastInDim. - TF_RETURN_IF_ERROR(ExpectArray(operand, "operand of broadcast")); + RETURN_IF_ERROR(ExpectArray(operand, "operand of broadcast")); TF_RET_CHECK(!operand.is_unbounded_dynamic()); for (int64_t size : broadcast_sizes) { if (size == Shape::kUnboundedSize) { @@ -3644,8 +3638,8 @@ ShapeInference::InferCollectivePermuteDoneShape(const Shape& operand_shape) { absl::c_copy(operand.dimensions(), dimensions.begin() + broadcast_sizes.size()); - TF_ASSIGN_OR_RETURN(Shape result, ShapeUtil::MakeValidatedShape( - operand.element_type(), dimensions)); + ASSIGN_OR_RETURN(Shape result, ShapeUtil::MakeValidatedShape( + operand.element_type(), dimensions)); for (int64_t i = 0; i < operand.dimensions().size(); ++i) { result.set_dynamic_dimension(broadcast_sizes.size() + i, operand.is_dynamic_dimension(i)); @@ -3657,8 +3651,8 @@ ShapeInference::InferCollectivePermuteDoneShape(const Shape& operand_shape) { const Shape& operand_shape, const Shape& output_shape, absl::Span broadcast_dimensions) { // This method is used to infer shape for xla::BroadcastInDim. - TF_RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of broadcast")); - TF_RETURN_IF_ERROR(ExpectArray(output_shape, "operand of broadcast")); + RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of broadcast")); + RETURN_IF_ERROR(ExpectArray(output_shape, "operand of broadcast")); TF_RET_CHECK(!output_shape.is_unbounded_dynamic()); const int64_t operand_rank = operand_shape.dimensions().size(); const int64_t output_rank = output_shape.dimensions().size(); @@ -3746,7 +3740,7 @@ ShapeInference::InferCollectivePermuteDoneShape(const Shape& operand_shape) { /* static */ absl::StatusOr ShapeInference::InferReshapeShape( const Shape& operand, absl::Span dimensions, int64_t inferred_dimension) { - TF_RETURN_IF_ERROR(ExpectArray(operand, "reshape")); + RETURN_IF_ERROR(ExpectArray(operand, "reshape")); Shape inferred_shape = ShapeUtil::MakeShape(operand.element_type(), dimensions); VLOG(3) << "Reshape inferred shape: " @@ -3897,7 +3891,7 @@ ShapeInference::InferCollectivePermuteDoneShape(const Shape& operand_shape) { /* static */ absl::StatusOr ShapeInference::InferTransposeShape( const Shape& operand, absl::Span dimensions) { - TF_RETURN_IF_ERROR(ExpectArray(operand, "transpose")); + RETURN_IF_ERROR(ExpectArray(operand, "transpose")); if (dimensions.size() != operand.dimensions().size() || !IsPermutation(dimensions)) { @@ -3912,9 +3906,9 @@ ShapeInference::InferCollectivePermuteDoneShape(const Shape& operand_shape) { /* static */ absl::StatusOr ShapeInference::InferClampShape( const Shape& min, const Shape& operand, const Shape& max) { - TF_RETURN_IF_ERROR(ExpectArray(min, "clamp min")); - TF_RETURN_IF_ERROR(ExpectArray(operand, "clamp operand")); - TF_RETURN_IF_ERROR(ExpectArray(max, "clamp max")); + RETURN_IF_ERROR(ExpectArray(min, "clamp min")); + RETURN_IF_ERROR(ExpectArray(operand, "clamp operand")); + RETURN_IF_ERROR(ExpectArray(max, "clamp max")); // min, operand, and max must have compatible element types. if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(min, operand) || @@ -3942,9 +3936,9 @@ ShapeInference::InferCollectivePermuteDoneShape(const Shape& operand_shape) { /* static */ absl::StatusOr ShapeInference::InferSelectShape( const Shape& pred, const Shape& on_true, const Shape& on_false) { - TF_RETURN_IF_ERROR(ExpectArray(pred, "select pred")); - TF_RETURN_IF_ERROR(ExpectArray(on_true, "select on-true")); - TF_RETURN_IF_ERROR(ExpectArray(on_false, "select on-false")); + RETURN_IF_ERROR(ExpectArray(pred, "select pred")); + RETURN_IF_ERROR(ExpectArray(on_true, "select on-true")); + RETURN_IF_ERROR(ExpectArray(on_false, "select on-false")); if (!ShapeUtil::CompatibleIgnoringFpPrecision(on_true, on_false)) { return InvalidArgument( @@ -4196,9 +4190,9 @@ static absl::Status ValidateGatherDimensionNumbers( const Shape& input_shape, const Shape& start_indices_shape, const GatherDimensionNumbers& gather_dim_numbers, absl::Span slice_sizes) { - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( ExpectArray(input_shape, "input tensor operand of gather op")); - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( ExpectArray(start_indices_shape, "gather indices operand of gather op")); if (!ShapeUtil::ElementIsIntegral(start_indices_shape)) { @@ -4239,7 +4233,7 @@ static absl::Status ValidateGatherDimensionNumbers( expanded_start_indices_shape_dynamic_dimensions.push_back(false); } - TF_RETURN_IF_ERROR(ValidateGatherDimensionNumbers( + RETURN_IF_ERROR(ValidateGatherDimensionNumbers( input_shape, expanded_start_indices_shape, gather_dim_numbers)); if (slice_sizes.size() != input_shape.dimensions().size()) { @@ -4513,7 +4507,7 @@ absl::Status ValidateScatterDimensionNumbers( } const Shape& scatter_indices_shape = *arg_shapes[operand_count]; - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( ExpectArray(scatter_indices_shape, "scatter indices of scatter op")); if (!ShapeUtil::ElementIsIntegral(scatter_indices_shape)) { return InvalidArgument( @@ -4543,9 +4537,9 @@ absl::Status ValidateScatterDimensionNumbers( for (int64_t operand_i = 0; operand_i < operand_count; ++operand_i) { const Shape& operand_shape = *operand_shapes[operand_i]; const Shape& updates_shape = *updates_shapes[operand_i]; - TF_RETURN_IF_ERROR(ExpectArray( + RETURN_IF_ERROR(ExpectArray( operand_shape, absl::StrCat("operand ", operand_i, " of scatter op"))); - TF_RETURN_IF_ERROR(ExpectArray( + RETURN_IF_ERROR(ExpectArray( updates_shape, absl::StrCat("updates ", operand_i, " of scatter op"))); int64_t inserted_dims_seen = 0, input_batching_dims_seen = 0; @@ -4575,7 +4569,7 @@ absl::Status ValidateScatterDimensionNumbers( updates_shape.dimensions().size()); } - TF_RETURN_IF_ERROR(ValidateScatterDimensionNumbers( + RETURN_IF_ERROR(ValidateScatterDimensionNumbers( operand_shape, expanded_scatter_indices_shape, updates_shape, scatter_dim_numbers)); @@ -4630,8 +4624,8 @@ absl::Status ValidateScatterDimensionNumbers( init_element_shape_ptrs.push_back(&init_element_shapes.back()); updates_element_types.push_back(updates_shapes[i]->element_type()); } - TF_RETURN_IF_ERROR(VerifyReducerShape(to_apply_shape, init_element_shape_ptrs, - updates_element_types, operand_count)); + RETURN_IF_ERROR(VerifyReducerShape(to_apply_shape, init_element_shape_ptrs, + updates_element_types, operand_count)); return operand_count == 1 ? *operand_shapes[0] : ShapeUtil::MakeTupleShapeWithPtrs(operand_shapes); diff --git a/third_party/xla/xla/service/shaped_buffer.cc b/third_party/xla/xla/service/shaped_buffer.cc index 21ac29b100f538..e1cf979074264c 100644 --- a/third_party/xla/xla/service/shaped_buffer.cc +++ b/third_party/xla/xla/service/shaped_buffer.cc @@ -24,6 +24,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/shape.h" #include "xla/shape_tree.h" #include "xla/shape_util.h" @@ -78,12 +79,12 @@ ShapedBuffer::~ShapedBuffer() {} absl::StatusOr ShapedBuffer::SubShapedBuffer( const ShapeIndex& index) const { - TF_ASSIGN_OR_RETURN(const Shape* device_sub_shape, - ShapeUtil::TryGetSubshape(on_device_shape(), index)); + ASSIGN_OR_RETURN(const Shape* device_sub_shape, + ShapeUtil::TryGetSubshape(on_device_shape(), index)); ShapedBuffer sub_shaped_buffer(*device_sub_shape, device_ordinal_, physical_device_ordinal_); - TF_ASSIGN_OR_RETURN(ShapeTree sub_buffers, - buffers_.SubShapeTree(index)); + ASSIGN_OR_RETURN(ShapeTree sub_buffers, + buffers_.SubShapeTree(index)); sub_shaped_buffer.set_buffers(std::move(sub_buffers)); return std::move(sub_shaped_buffer); } diff --git a/third_party/xla/xla/service/shaped_slice.cc b/third_party/xla/xla/service/shaped_slice.cc index 544075c77170c5..444e10bcde410d 100644 --- a/third_party/xla/xla/service/shaped_slice.cc +++ b/third_party/xla/xla/service/shaped_slice.cc @@ -21,6 +21,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/service/buffer_assignment.h" #include "xla/service/shaped_slice.pb.h" #include "xla/shape.h" @@ -32,19 +33,18 @@ absl::StatusOr ShapedSlice::FromProto( const ShapedSliceProto& proto, absl::Span buffer_allocations) { ShapedSlice shaped_slice; - TF_ASSIGN_OR_RETURN( - shaped_slice.slice, - BufferAllocation::Slice::FromProto(proto.slice(), buffer_allocations)); + ASSIGN_OR_RETURN(shaped_slice.slice, BufferAllocation::Slice::FromProto( + proto.slice(), buffer_allocations)); if (!proto.has_shape()) { return absl::InvalidArgumentError("ShapedSlice proto has no shape"); } - TF_ASSIGN_OR_RETURN(shaped_slice.shape, Shape::FromProto(proto.shape())); + ASSIGN_OR_RETURN(shaped_slice.shape, Shape::FromProto(proto.shape())); return shaped_slice; } absl::StatusOr ShapedSlice::ToProto() const { ShapedSliceProto proto; - TF_ASSIGN_OR_RETURN(*proto.mutable_slice(), slice.ToProto()); + ASSIGN_OR_RETURN(*proto.mutable_slice(), slice.ToProto()); *proto.mutable_shape() = shape.ToProto(); return proto; } @@ -53,7 +53,7 @@ absl::StatusOr NullableShapedSlice::FromProto( const NullableShapedSliceProto& proto, absl::Span buffer_allocations) { if (proto.has_shaped_slice()) { - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( ShapedSlice shaped_slice, ShapedSlice::FromProto(proto.shaped_slice(), buffer_allocations)); return NullableShapedSlice(std::move(shaped_slice)); @@ -64,7 +64,7 @@ absl::StatusOr NullableShapedSlice::FromProto( absl::StatusOr NullableShapedSlice::ToProto() const { NullableShapedSliceProto proto; if (has_value()) { - TF_ASSIGN_OR_RETURN(*proto.mutable_shaped_slice(), value().ToProto()); + ASSIGN_OR_RETURN(*proto.mutable_shaped_slice(), value().ToProto()); } return proto; } diff --git a/third_party/xla/xla/service/sharding_propagation.cc b/third_party/xla/xla/service/sharding_propagation.cc index e02bbeab32b6ce..affd2898ddb94a 100644 --- a/third_party/xla/xla/service/sharding_propagation.cc +++ b/third_party/xla/xla/service/sharding_propagation.cc @@ -36,6 +36,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/array.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" @@ -1485,7 +1486,7 @@ absl::StatusOr ProcessShardingInstruction( HloSharding original_sharding = instruction->sharding(); std::vector unspec_dims; - TF_RETURN_IF_ERROR(sharding_op_util::ParseAttributes( + RETURN_IF_ERROR(sharding_op_util::ParseAttributes( Cast(instruction)->opaque(), &unspec_dims)); @@ -1499,8 +1500,8 @@ absl::StatusOr ProcessShardingInstruction( auto copy = computation->AddInstruction(HloInstruction::CreateUnary( instruction->shape(), HloOpcode::kCopy, instruction->mutable_operand(0))); - TF_ASSIGN_OR_RETURN( - std::ignore, computation->ReplaceInstruction( + ASSIGN_OR_RETURN(std::ignore, + computation->ReplaceInstruction( instruction, copy, /*preserve_sharding=*/false, /*relay_control_dependency=*/false, /*remove_unused_operands=*/false)); @@ -1509,7 +1510,7 @@ absl::StatusOr ProcessShardingInstruction( changed = true; } - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( bool shard_group_remove_instruction, process_shard_group_instruction(instruction, replaced_with_copy)); if (!unspec_dims.empty()) { @@ -1520,17 +1521,17 @@ absl::StatusOr ProcessShardingInstruction( instruction->sharding()); } if (shard_group_remove_instruction) { - TF_ASSIGN_OR_RETURN(std::ignore, - computation->ReplaceInstruction( - instruction, instruction->mutable_operand(0), - /*preserve_sharding=*/false, - /*relay_control_dependency=*/false, - /*remove_unused_operands=*/false)); + ASSIGN_OR_RETURN(std::ignore, + computation->ReplaceInstruction( + instruction, instruction->mutable_operand(0), + /*preserve_sharding=*/false, + /*relay_control_dependency=*/false, + /*remove_unused_operands=*/false)); } } else { - TF_ASSIGN_OR_RETURN(std::ignore, - process_shard_group_instruction( - instruction, /*replaced_with_copy=*/false)); + ASSIGN_OR_RETURN(std::ignore, + process_shard_group_instruction( + instruction, /*replaced_with_copy=*/false)); } } } @@ -1571,8 +1572,8 @@ int64_t ComputeNonRootUsers(const HloInstruction* instr) { /*static*/ absl::Status ShardingPropagation::NormalizeDomain( const DomainMetadata::Domain& domain, const DomainMetadata* metadata) { if (metadata != nullptr) { - TF_ASSIGN_OR_RETURN(const auto& sharding_metadata, - ShardingMetadata::ToShardingMetadata(metadata)); + ASSIGN_OR_RETURN(const auto& sharding_metadata, + ShardingMetadata::ToShardingMetadata(metadata)); const auto& sharding = sharding_metadata->sharding(); if (sharding != nullptr) { bool is_spatially_partitioned = !sharding->IsSingleDevice(); @@ -3182,7 +3183,7 @@ absl::StatusOr ShardingPropagation::RunImpl( shard_group_id_to_shard_as_group; absl::flat_hash_map> shard_group_id_to_shard_like_group; - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( bool changed, ProcessShardingInstruction( module, execution_threads, !cse_prevention_only_, &unspecified_dims, @@ -3250,7 +3251,7 @@ absl::StatusOr ShardingPropagation::RunImpl( for (auto computation : module->computations(execution_threads)) { for (auto instruction : computation->instructions()) { if (instruction->opcode() == HloOpcode::kWhile) { - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( CheckAndUpdateDeviceAssignmentsInWhileBody(instruction)); } } @@ -3350,7 +3351,7 @@ absl::StatusOr ShardingPropagation::RunImpl( int64_t iterations = 0; std::unique_ptr call_graph = CallGraph::Build(module); for (int64_t aggressiveness = 0; aggressiveness < 4; ++aggressiveness) { - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( bool changed, RunToFixPoint(aggressiveness, /*propagate_shard_group=*/true, computation_map, provided_shardings, *call_graph, module, @@ -3429,7 +3430,7 @@ absl::StatusOr ShardingPropagation::RunImpl( } } { - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( bool changed, RunToFixPoint(/*aggressiveness=*/3, /*propagate_shard_group=*/true, computation_map, provided_shardings, *call_graph, module, @@ -3459,12 +3460,12 @@ absl::StatusOr ShardingPropagation::RunImpl( } if (instruction->IsCustomCall(spmd::kShardBarrierFrom) || instruction->IsCustomCall(spmd::kShardBarrierTo)) { - TF_ASSIGN_OR_RETURN(std::ignore, - computation->ReplaceInstruction( - instruction, instruction->mutable_operand(0), - /*preserve_sharding=*/false, - /*relay_control_dependency=*/false, - /*remove_unused_operands=*/false)); + ASSIGN_OR_RETURN(std::ignore, + computation->ReplaceInstruction( + instruction, instruction->mutable_operand(0), + /*preserve_sharding=*/false, + /*relay_control_dependency=*/false, + /*remove_unused_operands=*/false)); } } } @@ -3547,10 +3548,9 @@ absl::StatusOr ShardingPropagation::RunImpl( module, allow_spmd_sharding_propagation_to_output_vector_, allow_spmd_sharding_propagation_to_parameters_vector_); - TF_RETURN_IF_ERROR( - hlo_sharding_util::CanonicalizeLayoutAfterShardingPropagation( - module, allow_spmd_sharding_propagation_to_output_vector_, - allow_spmd_sharding_propagation_to_parameters_vector_)); + RETURN_IF_ERROR(hlo_sharding_util::CanonicalizeLayoutAfterShardingPropagation( + module, allow_spmd_sharding_propagation_to_output_vector_, + allow_spmd_sharding_propagation_to_parameters_vector_)); VLOG(1) << "Sharding propagation completed after " << iterations << " iterations"; diff --git a/third_party/xla/xla/service/sharding_remover.cc b/third_party/xla/xla/service/sharding_remover.cc index d7c52c68a534a9..ac841231720ae5 100644 --- a/third_party/xla/xla/service/sharding_remover.cc +++ b/third_party/xla/xla/service/sharding_remover.cc @@ -25,6 +25,7 @@ limitations under the License. #include "absl/log/check.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" @@ -68,11 +69,11 @@ absl::StatusOr ShardingRemover::RunImpl( // ShardingGroupOp is dangling so we just remove it. if (instruction->custom_call_target() == sdy::kShardingGroupCustomCallTargetName) { - TF_RETURN_IF_ERROR(computation->RemoveInstruction(instruction)); + RETURN_IF_ERROR(computation->RemoveInstruction(instruction)); continue; } - TF_RETURN_IF_ERROR(instruction->ReplaceAllUsesWith( + RETURN_IF_ERROR(instruction->ReplaceAllUsesWith( instruction->mutable_operand(0), name())); changed = true; @@ -88,7 +89,7 @@ absl::StatusOr ShardingRemover::RunImpl( auto copy = computation->AddInstruction( HloInstruction::CreateUnary(instruction->shape(), HloOpcode::kCopy, instruction->mutable_operand(0))); - TF_RETURN_IF_ERROR(computation->ReplaceInstruction(instruction, copy)); + RETURN_IF_ERROR(computation->ReplaceInstruction(instruction, copy)); instruction = copy; } } diff --git a/third_party/xla/xla/service/source_target_pairs.h b/third_party/xla/xla/service/source_target_pairs.h index dda3ef7ab34e9f..34aa9ab8275f0e 100644 --- a/third_party/xla/xla/service/source_target_pairs.h +++ b/third_party/xla/xla/service/source_target_pairs.h @@ -29,6 +29,7 @@ limitations under the License. #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" #include "absl/strings/string_view.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/parser/hlo_parser.h" #include "xla/tsl/platform/statusor.h" #include "xla/util.h" @@ -70,8 +71,8 @@ class SourceTargetPairs { static absl::StatusOr FromString(absl::string_view str) { // reusing replica groups parsing. - TF_ASSIGN_OR_RETURN(std::vector groups, - ParseReplicaGroupsOnly(str)); + ASSIGN_OR_RETURN(std::vector groups, + ParseReplicaGroupsOnly(str)); SourceTargetPairs res; for (const ReplicaGroup& group : groups) { if (group.replica_ids_size() != 2) { @@ -87,8 +88,8 @@ class SourceTargetPairs { auto source_target_pairs = instruction->frontend_attributes().map().find( kSendRecvSourceTargetPairsAttr); if (source_target_pairs != instruction->frontend_attributes().map().end()) { - TF_ASSIGN_OR_RETURN(SourceTargetPairs res, - FromString(source_target_pairs->second)); + ASSIGN_OR_RETURN(SourceTargetPairs res, + FromString(source_target_pairs->second)); return res; } return Internal( diff --git a/third_party/xla/xla/service/space_to_batch_converter.cc b/third_party/xla/xla/service/space_to_batch_converter.cc index e7c71751be70c8..d25558ba041dc2 100644 --- a/third_party/xla/xla/service/space_to_batch_converter.cc +++ b/third_party/xla/xla/service/space_to_batch_converter.cc @@ -32,6 +32,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/comparison_util.h" #include "xla/debug_options_flags.h" #include "xla/hlo/ir/hlo_computation.h" @@ -525,8 +526,8 @@ absl::StatusOr ConvolutionVisitor::HaloDuplicateWithSlice( activations->shape().dimensions(spatial_dimensions_to_split[0]); const int64_t batch_size = ctrl_.number_of_splits; - TF_ASSIGN_OR_RETURN( - activations, SplitAndTransposeMergedBatch( + ASSIGN_OR_RETURN(activations, + SplitAndTransposeMergedBatch( activations, activations_batch_dim, original_batch_size, spatial_dimensions_to_split)); @@ -560,10 +561,10 @@ absl::StatusOr ConvolutionVisitor::HaloDuplicateWithSlice( end_indices[remapped_batch_dimension] = batch_size - 1; end_indices[spatial_dimension_to_split] = spatial_split_size; - TF_ASSIGN_OR_RETURN(first_slice, - MakeSliceHlo(activations, start_indices, end_indices, - strides, &activations->metadata(), - &activations->frontend_attributes())); + ASSIGN_OR_RETURN(first_slice, + MakeSliceHlo(activations, start_indices, end_indices, + strides, &activations->metadata(), + &activations->frontend_attributes())); VLOG(1) << "first slice " << first_slice->ToString(); PaddingConfig padding_config = @@ -571,10 +572,10 @@ absl::StatusOr ConvolutionVisitor::HaloDuplicateWithSlice( padding_config.mutable_dimensions(remapped_batch_dimension) ->set_edge_padding_low(1); - TF_ASSIGN_OR_RETURN(first_slice, - MakePadHlo(first_slice, padding, padding_config, - &first_slice->metadata(), - &first_slice->frontend_attributes())); + ASSIGN_OR_RETURN(first_slice, + MakePadHlo(first_slice, padding, padding_config, + &first_slice->metadata(), + &first_slice->frontend_attributes())); } HloInstruction* halo_region = nullptr; @@ -585,7 +586,7 @@ absl::StatusOr ConvolutionVisitor::HaloDuplicateWithSlice( start_indices_halo[remapped_batch_dimension] = 1; end_indices_halo[spatial_dimension_to_split] = halo_size - low_padding; - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( halo_region, MakeSliceHlo(activations, start_indices_halo, end_indices_halo, strides, &activations->metadata(), @@ -595,10 +596,10 @@ absl::StatusOr ConvolutionVisitor::HaloDuplicateWithSlice( MakeNoPaddingConfig(halo_region->shape().dimensions().size()); padding_config_halo.mutable_dimensions(remapped_batch_dimension) ->set_edge_padding_high(1); - TF_ASSIGN_OR_RETURN(halo_region, - MakePadHlo(halo_region, padding, padding_config_halo, - &halo_region->metadata(), - &halo_region->frontend_attributes())); + ASSIGN_OR_RETURN(halo_region, + MakePadHlo(halo_region, padding, padding_config_halo, + &halo_region->metadata(), + &halo_region->frontend_attributes())); } if ((halo_size == 0 && low_padding != 0) || low_padding < 0) { @@ -615,15 +616,15 @@ absl::StatusOr ConvolutionVisitor::HaloDuplicateWithSlice( end_indices_activations_cut[spatial_dimension_to_split] = spatial_split_size; } - TF_ASSIGN_OR_RETURN( - activations, MakeSliceHlo(activations, start_indices_activations_cut, + ASSIGN_OR_RETURN(activations, + MakeSliceHlo(activations, start_indices_activations_cut, end_indices_activations_cut, strides, &activations->metadata(), &activations->frontend_attributes())); } if (first_slice != nullptr) { - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( activations, MakeConcatHlo({first_slice, activations}, spatial_dimension_to_split, &activations->metadata(), @@ -631,7 +632,7 @@ absl::StatusOr ConvolutionVisitor::HaloDuplicateWithSlice( } if (halo_region != nullptr) { - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( activations, MakeConcatHlo({activations, halo_region}, spatial_dimension_to_split, &activations->metadata(), @@ -639,7 +640,7 @@ absl::StatusOr ConvolutionVisitor::HaloDuplicateWithSlice( } } - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( activations, TransposeAndMergeBatch( activations, @@ -700,8 +701,8 @@ ConvolutionVisitor::BringSpaceNextToBatch( activations_batch_dim = new_batch_dim; spatial_dimension_to_split = new_spatial_dim; - TF_ASSIGN_OR_RETURN(activations, - MakeTransposeHlo(activations, transpose_dims)); + ASSIGN_OR_RETURN(activations, + MakeTransposeHlo(activations, transpose_dims)); new_dim_numbers.set_kernel_input_feature_dimension(activations_batch_dim); @@ -734,8 +735,8 @@ ConvolutionVisitor::BringSpaceNextToBatch( activations_batch_dim = new_batch_dim; spatial_dimension_to_split = new_spatial_dim; - TF_ASSIGN_OR_RETURN(activations, - MakeTransposeHlo(activations, transpose_dims)); + ASSIGN_OR_RETURN(activations, + MakeTransposeHlo(activations, transpose_dims)); if (is_backprop) { new_dim_numbers.set_input_feature_dimension(activations_batch_dim); @@ -779,8 +780,8 @@ ConvolutionVisitor::SplitAndTransposeMergedBatch( } // Reshape the output of the new conv into the old convolutions shape. - TF_ASSIGN_OR_RETURN(HloInstruction * batch_split_activations, - MakeReshapeHlo(new_dimensions, activations)); + ASSIGN_OR_RETURN(HloInstruction * batch_split_activations, + MakeReshapeHlo(new_dimensions, activations)); if (spatial_dim_count > 1) { // Transpose such that we get // B, B0, S0, B1, S1,... @@ -800,9 +801,8 @@ ConvolutionVisitor::SplitAndTransposeMergedBatch( batch_dimension + spatial_dim_count + 1 + i; } - TF_ASSIGN_OR_RETURN( - batch_split_activations, - MakeTransposeHlo(batch_split_activations, transpose_dims)); + ASSIGN_OR_RETURN(batch_split_activations, + MakeTransposeHlo(batch_split_activations, transpose_dims)); } return batch_split_activations; } @@ -822,7 +822,7 @@ ConvolutionVisitor::ChangeSpatialSizeOnSpaceToBatchedShape( const int64_t reshaped_space_size = spatial_dim_size * ctrl_.number_of_splits; // Reshape the output of the new conv into the old convolutions shape. - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( HloInstruction * batch_split_activations, SplitAndTransposeMergedBatch(activations, batch_dimension, old_batch_size, spatial_dimensions)); @@ -841,9 +841,9 @@ ConvolutionVisitor::ChangeSpatialSizeOnSpaceToBatchedShape( batch_space_collapse_reshape_dims[spatial_dimension] = reshaped_space_size; } - TF_ASSIGN_OR_RETURN(HloInstruction * batch_space_collapsed_reshape, - MakeReshapeHlo(batch_space_collapse_reshape_dims, - batch_split_activations)); + ASSIGN_OR_RETURN(HloInstruction * batch_space_collapsed_reshape, + MakeReshapeHlo(batch_space_collapse_reshape_dims, + batch_split_activations)); VLOG(3) << "First reshape done"; @@ -866,7 +866,7 @@ ConvolutionVisitor::ChangeSpatialSizeOnSpaceToBatchedShape( HloInstruction::CreateConstant(LiteralUtil::Zero( batch_space_collapsed_reshape->shape().element_type()))); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( batch_space_collapsed_reshape, MakePadHlo(batch_space_collapsed_reshape, padding, padding_config, &batch_space_collapsed_reshape->metadata(), @@ -881,13 +881,13 @@ ConvolutionVisitor::ChangeSpatialSizeOnSpaceToBatchedShape( new_spatial_dim_size * ctrl_.number_of_splits; } // This is the slice from halo padding. - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( batch_space_collapsed_reshape, MakeSliceHlo(batch_space_collapsed_reshape, start_indices, end_indices, strides, &batch_space_collapsed_reshape->metadata(), &batch_space_collapsed_reshape->frontend_attributes())); } - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( HloInstruction * activations_new, PerformSplitSpace(batch_space_collapsed_reshape, spatial_dimensions, batch_dimension, new_spatial_dim_size, @@ -933,8 +933,8 @@ absl::StatusOr ConvolutionVisitor::Run() { if (producer) { if (CanPropagate(instr, producer)) { bool needs_further_propagation; - TF_ASSIGN_OR_RETURN(needs_further_propagation, - Propagate(instr, producer)); + ASSIGN_OR_RETURN(needs_further_propagation, + Propagate(instr, producer)); CHECK_OK(computation_->ReplaceInstruction(instr, old_to_new_instrs_[instr])); continue; @@ -945,8 +945,8 @@ absl::StatusOr ConvolutionVisitor::Run() { absl::flat_hash_map operand_map; for (int64_t i = 0; i < instr->operand_count(); ++i) { if (old_to_new_instrs_.count(instr->mutable_operand(i))) { - TF_ASSIGN_OR_RETURN(operand_map[i], - BatchToSpace(instr->mutable_operand(i))); + ASSIGN_OR_RETURN(operand_map[i], + BatchToSpace(instr->mutable_operand(i))); } } for (auto entry : operand_map) { @@ -1931,8 +1931,8 @@ absl::StatusOr ConvolutionVisitor::Propagate(HloInstruction* consumer, new_dimensions[space_dim] *= (batch_size / old_batch_size); new_dimensions[batch_dim] = old_batch_size; - TF_ASSIGN_OR_RETURN(HloInstruction * reshape, - MakeReshapeHlo(new_dimensions, new_instr)); + ASSIGN_OR_RETURN(HloInstruction * reshape, + MakeReshapeHlo(new_dimensions, new_instr)); const int64_t pivot_space_size = pivot_new_instr->shape().dimensions(space_dim) * batch_size / @@ -1949,15 +1949,14 @@ absl::StatusOr ConvolutionVisitor::Propagate(HloInstruction* consumer, HloInstruction* padding = consumer->AddInstruction(HloInstruction::CreateConstant( LiteralUtil::Zero(reshape->shape().element_type()))); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( HloInstruction * padded_operand, MakePadHlo(reshape, padding, padding_config, &reshape->metadata(), &reshape->frontend_attributes())); - TF_ASSIGN_OR_RETURN( - operand_to_use, - MakeReshapeHlo(pivot_new_instr->shape().dimensions(), - padded_operand)); + ASSIGN_OR_RETURN(operand_to_use, + MakeReshapeHlo(pivot_new_instr->shape().dimensions(), + padded_operand)); } else { operand_to_use = old_to_new_instrs_[consumer->mutable_operand(i)]; @@ -1973,7 +1972,7 @@ absl::StatusOr ConvolutionVisitor::Propagate(HloInstruction* consumer, CHECK_OK(new_consumer->ReplaceOperandWithDifferentShape( i, old_to_new_instrs_[consumer->mutable_operand(i)])); } else if (consumer->operand(i)->opcode() == HloOpcode::kConstant) { - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( auto new_constant, PropagateOnConstant(consumer->mutable_operand(i), producer)); CHECK_OK( @@ -2125,7 +2124,7 @@ absl::StatusOr ConvolutionVisitor::Propagate(HloInstruction* consumer, } HloInstruction* new_consumer = nullptr; - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( new_consumer, MakeReduceHlo(first_operand, consumer->mutable_operand(1), changed_dims, consumer->called_computations()[0])); @@ -2190,7 +2189,7 @@ absl::StatusOr ConvolutionVisitor::Propagate(HloInstruction* consumer, std::vector old_spatial_dims = retval.first; std::vector new_spatial_dims = retval.second; - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( first_operand, SelectValidPortion(first_operand, consumer->mutable_operand(0), consumer->mutable_operand(1), new_batch_dim, @@ -2292,7 +2291,7 @@ absl::StatusOr ConvolutionVisitor::Propagate(HloInstruction* consumer, HloInstruction::CreateConstant(LiteralUtil::MinValue( consumer->operand(2)->shape().element_type()))) : init_val; - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( first_operand, SelectValidPortion(first_operand, consumer->mutable_operand(0), pad_val, new_batch_dim, new_spatial_dims, old_batch_dim, @@ -2308,12 +2307,12 @@ absl::StatusOr ConvolutionVisitor::Propagate(HloInstruction* consumer, if ((new_space_size - extra_space) * old_batch_size * ctrl_.number_of_splits >= old_batch_size * old_space_size) { - TF_ASSIGN_OR_RETURN( - first_operand, ChangeSpatialSizeOnSpaceToBatchedShape( - first_operand, new_batch_dim, old_batch_size, - new_spatial_dims, new_space_size - extra_space)); + ASSIGN_OR_RETURN(first_operand, + ChangeSpatialSizeOnSpaceToBatchedShape( + first_operand, new_batch_dim, old_batch_size, + new_spatial_dims, new_space_size - extra_space)); } else { - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( first_operand, ChangeSpatialSizeOnSpaceToBatchedShape( first_operand, new_batch_dim, old_batch_size, new_spatial_dims, @@ -2329,7 +2328,7 @@ absl::StatusOr ConvolutionVisitor::Propagate(HloInstruction* consumer, const int64_t halo_size = last_overlap_point + window_size - new_space_size; if (halo_size > 0) { - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( first_operand, HaloDuplicateWithSlice(first_operand, new_spatial_dims, new_batch_dim, /*low_padding=*/0, halo_size, init_val)); @@ -2369,12 +2368,11 @@ absl::StatusOr ConvolutionVisitor::Propagate(HloInstruction* consumer, auto select_comp = consumer->select(); auto scatter_comp = consumer->scatter(); - TF_ASSIGN_OR_RETURN( - auto new_select_and_scatter_shape, - ShapeInference::InferSelectAndScatterShape( - new_shape, select_comp->ComputeProgramShape(), new_win, - second_operand->shape(), init_val->shape(), - scatter_comp->ComputeProgramShape())); + ASSIGN_OR_RETURN(auto new_select_and_scatter_shape, + ShapeInference::InferSelectAndScatterShape( + new_shape, select_comp->ComputeProgramShape(), + new_win, second_operand->shape(), init_val->shape(), + scatter_comp->ComputeProgramShape())); new_consumer = computation_->AddInstruction( HloInstruction::CreateSelectAndScatter( new_select_and_scatter_shape, first_operand, select_comp, new_win, @@ -2405,11 +2403,10 @@ absl::StatusOr ConvolutionVisitor::Propagate(HloInstruction* consumer, end_indices[new_batch_dim] = batch_size - 1; // This is the slice from halo padding. - TF_ASSIGN_OR_RETURN( - HloInstruction * bottom, - MakeSliceHlo(new_consumer, start_indices, end_indices, strides, - &consumer->metadata(), - &consumer->frontend_attributes())); + ASSIGN_OR_RETURN(HloInstruction * bottom, + MakeSliceHlo(new_consumer, start_indices, end_indices, + strides, &consumer->metadata(), + &consumer->frontend_attributes())); std::vector start_indices_top(rank, 0), end_indices_top(new_consumer->shape().dimensions().begin(), @@ -2419,7 +2416,7 @@ absl::StatusOr ConvolutionVisitor::Propagate(HloInstruction* consumer, start_indices_top[new_batch_dim] = 1; // This is the original area from where halo pad was extracted. - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( HloInstruction * top, MakeSliceHlo(new_consumer, start_indices_top, end_indices_top, strides, &consumer->metadata(), @@ -2430,7 +2427,7 @@ absl::StatusOr ConvolutionVisitor::Propagate(HloInstruction* consumer, &init_val->frontend_attributes()); // Compare to see if the bottom area was changed. - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( HloInstruction * bottom_compare, // TODO(hanrach): Verify that this is the correct metadata MakeCompareHlo(ComparisonDirection::kNe, bottom, default_fill, @@ -2438,42 +2435,40 @@ absl::StatusOr ConvolutionVisitor::Propagate(HloInstruction* consumer, &bottom->frontend_attributes())); // Take out only the changed values. - TF_ASSIGN_OR_RETURN( - HloInstruction * bottom_taken, - MakeSelectHlo(bottom_compare, bottom, default_fill, nullptr, - &bottom_compare->metadata(), - &bottom_compare->frontend_attributes())); + ASSIGN_OR_RETURN(HloInstruction * bottom_taken, + MakeSelectHlo(bottom_compare, bottom, default_fill, + nullptr, &bottom_compare->metadata(), + &bottom_compare->frontend_attributes())); // Compare to see if the top area was changed. - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( HloInstruction * top_compare, MakeCompareHlo(ComparisonDirection::kNe, top, default_fill, &top->metadata(), &top->frontend_attributes())); // Take out only the changed values. - TF_ASSIGN_OR_RETURN(HloInstruction * top_taken, - MakeSelectHlo(top_compare, top, bottom_taken, - nullptr, &top_compare->metadata(), - &top_compare->frontend_attributes())); + ASSIGN_OR_RETURN(HloInstruction * top_taken, + MakeSelectHlo(top_compare, top, bottom_taken, nullptr, + &top_compare->metadata(), + &top_compare->frontend_attributes())); // This makes checks if the area was updated by both overlaps. - TF_ASSIGN_OR_RETURN(HloInstruction * both_compare, - MakeBinaryHlo(HloOpcode::kAnd, top_compare, - bottom_compare, &consumer->metadata(), - &consumer->frontend_attributes())); + ASSIGN_OR_RETURN(HloInstruction * both_compare, + MakeBinaryHlo(HloOpcode::kAnd, top_compare, + bottom_compare, &consumer->metadata(), + &consumer->frontend_attributes())); // If it was, add them up. - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( HloInstruction * both_added, MakeBinaryHlo(HloOpcode::kAdd, top, bottom, &consumer->metadata(), &consumer->frontend_attributes())); // Pad the final result to the original shape. - TF_ASSIGN_OR_RETURN( - HloInstruction * final_selection, - MakeSelectHlo(both_compare, both_added, top_taken, nullptr, - &both_compare->metadata(), - &both_compare->frontend_attributes())); + ASSIGN_OR_RETURN(HloInstruction * final_selection, + MakeSelectHlo(both_compare, both_added, top_taken, + nullptr, &both_compare->metadata(), + &both_compare->frontend_attributes())); PaddingConfig padding_config = MakeNoPaddingConfig(final_selection->shape().dimensions().size()); @@ -2486,11 +2481,10 @@ absl::StatusOr ConvolutionVisitor::Propagate(HloInstruction* consumer, LiteralUtil::Zero(final_selection->shape().element_type())), &consumer->metadata(), &consumer->frontend_attributes()); - TF_ASSIGN_OR_RETURN( - final_selection, - MakePadHlo(final_selection, padding, padding_config, - &final_selection->metadata(), - &final_selection->frontend_attributes())); + ASSIGN_OR_RETURN(final_selection, + MakePadHlo(final_selection, padding, padding_config, + &final_selection->metadata(), + &final_selection->frontend_attributes())); tsl::core::Bitmap b(batch_size * (new_space_size + halo_size)); for (int k = 0; k < batch_size * (new_space_size + halo_size); ++k) { @@ -2513,9 +2507,8 @@ absl::StatusOr ConvolutionVisitor::Propagate(HloInstruction* consumer, slice_mask_reshape_dims[0] = batch_size; slice_mask_reshape_dims[1] = (new_space_size + halo_size); - TF_ASSIGN_OR_RETURN( - HloInstruction * slice_mask_reshaped, - MakeReshapeHlo(slice_mask_reshape_dims, slice_mask)); + ASSIGN_OR_RETURN(HloInstruction * slice_mask_reshaped, + MakeReshapeHlo(slice_mask_reshape_dims, slice_mask)); // Broadcast the mask in all dimensions. HloInstruction* shape_mask = MakeBroadcastHlo( @@ -2523,7 +2516,7 @@ absl::StatusOr ConvolutionVisitor::Propagate(HloInstruction* consumer, final_selection->shape().dimensions(), &slice_mask->metadata(), &slice_mask->frontend_attributes()); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( new_consumer, MakeSelectHlo(shape_mask, new_consumer, final_selection, nullptr, &shape_mask->metadata(), @@ -2537,16 +2530,16 @@ absl::StatusOr ConvolutionVisitor::Propagate(HloInstruction* consumer, previous_shape.dimensions().end()), strides(previous_shape.dimensions().size(), 1); - TF_ASSIGN_OR_RETURN(new_consumer, - MakeSliceHlo(new_consumer, start_indices, end_indices, - strides, &consumer->metadata(), - &consumer->frontend_attributes())); + ASSIGN_OR_RETURN(new_consumer, + MakeSliceHlo(new_consumer, start_indices, end_indices, + strides, &consumer->metadata(), + &consumer->frontend_attributes())); } else { auto reduce_comp = consumer->to_apply(); - TF_ASSIGN_OR_RETURN(auto new_reduce_window_shape, - ShapeInference::InferReduceWindowShape( - new_shape, init_val->shape(), new_win)); + ASSIGN_OR_RETURN(auto new_reduce_window_shape, + ShapeInference::InferReduceWindowShape( + new_shape, init_val->shape(), new_win)); new_consumer = computation_->AddInstruction( HloInstruction::CreateReduceWindow(new_reduce_window_shape, first_operand, init_val, new_win, @@ -2636,8 +2629,8 @@ absl::StatusOr ConvolutionVisitor::SelectValidPortion( new_space_size); slice_mask_reshape_dims[0] = new_batch_size; - TF_ASSIGN_OR_RETURN(HloInstruction * slice_mask_reshaped, - MakeReshapeHlo(slice_mask_reshape_dims, slice_mask)); + ASSIGN_OR_RETURN(HloInstruction * slice_mask_reshaped, + MakeReshapeHlo(slice_mask_reshape_dims, slice_mask)); std::vector broadcast_dims(new_space_dims.begin(), new_space_dims.end()); @@ -2654,10 +2647,10 @@ absl::StatusOr ConvolutionVisitor::SelectValidPortion( select_val, {}, new_instr->shape().dimensions(), &select_val->metadata(), &select_val->frontend_attributes()); - TF_ASSIGN_OR_RETURN(new_instr, - MakeSelectHlo(shape_mask, new_instr, zeroes, nullptr, - &shape_mask->metadata(), - &shape_mask->frontend_attributes())); + ASSIGN_OR_RETURN(new_instr, + MakeSelectHlo(shape_mask, new_instr, zeroes, nullptr, + &shape_mask->metadata(), + &shape_mask->frontend_attributes())); return new_instr; } @@ -2692,9 +2685,9 @@ absl::StatusOr ConvolutionVisitor::BatchToSpace( ctrl_.count_of_dimensions_to_convert); absl::c_iota(split_spatial_dimensions, space_dim); - TF_ASSIGN_OR_RETURN(new_instr, SplitAndTransposeMergedBatch( - new_instr, batch_dim, old_batch_size, - split_spatial_dimensions)); + ASSIGN_OR_RETURN(new_instr, SplitAndTransposeMergedBatch( + new_instr, batch_dim, old_batch_size, + split_spatial_dimensions)); std::vector new_dimensions(new_instr->shape().dimensions().begin(), new_instr->shape().dimensions().end()); @@ -2709,8 +2702,8 @@ absl::StatusOr ConvolutionVisitor::BatchToSpace( } // Reshape the output of the new conv into the old convolutions shape. - TF_ASSIGN_OR_RETURN(HloInstruction * reshape, - MakeReshapeHlo(new_dimensions, new_instr)); + ASSIGN_OR_RETURN(HloInstruction * reshape, + MakeReshapeHlo(new_dimensions, new_instr)); VLOG(1) << "Batch to space reshape " << reshape->ToString(); const int64_t rank = old_instr->shape().dimensions().size(); @@ -2724,14 +2717,14 @@ absl::StatusOr ConvolutionVisitor::BatchToSpace( } // This slicing is getting rid of the padding we added to evenly divide space. - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( HloInstruction * output_slice, MakeSliceHlo(reshape, start_indices, end_indices, strides, &reshape->metadata(), &reshape->frontend_attributes())); VLOG(1) << "Batch to space slice " << output_slice->ToString(); std::vector transpose_dims(permute_dims); - TF_ASSIGN_OR_RETURN(HloInstruction * output_transpose, - MakeTransposeHlo(output_slice, transpose_dims)); + ASSIGN_OR_RETURN(HloInstruction * output_transpose, + MakeTransposeHlo(output_slice, transpose_dims)); old_instr->SetupDerivedInstruction(output_transpose); batch_to_space_map_[old_instr] = output_transpose; @@ -2742,8 +2735,7 @@ absl::Status ConvolutionVisitor::PropagateOnUsers(HloInstruction* old_conv) { std::queue> propagation_worklist; if (old_conv->user_count() == 0) { - TF_ASSIGN_OR_RETURN(HloInstruction * batch_to_space, - BatchToSpace(old_conv)); + ASSIGN_OR_RETURN(HloInstruction * batch_to_space, BatchToSpace(old_conv)); VLOG(1) << "Replacing the root instruction to " << batch_to_space->ToString(); CHECK_OK(computation_->ReplaceInstruction(old_conv, batch_to_space)); @@ -2770,7 +2762,7 @@ absl::Status ConvolutionVisitor::PropagateOnUsers(HloInstruction* old_conv) { bool needs_further_propagation = true; if (iteration_count != 0) { // Do the space-to-batch propagation on this node. - TF_ASSIGN_OR_RETURN(needs_further_propagation, Propagate(node, parent)); + ASSIGN_OR_RETURN(needs_further_propagation, Propagate(node, parent)); } iteration_count++; // If this is the root, no room for further propagation. @@ -2784,7 +2776,7 @@ absl::Status ConvolutionVisitor::PropagateOnUsers(HloInstruction* old_conv) { continue; } - TF_ASSIGN_OR_RETURN(HloInstruction * batch_to_space, BatchToSpace(node)); + ASSIGN_OR_RETURN(HloInstruction * batch_to_space, BatchToSpace(node)); VLOG(1) << "Replacing the root instruction to " << batch_to_space->ToString(); CHECK_OK(computation_->ReplaceInstruction(node, batch_to_space)); @@ -2816,8 +2808,7 @@ absl::Status ConvolutionVisitor::PropagateOnUsers(HloInstruction* old_conv) { } if (!unsupported_users.empty()) { - TF_ASSIGN_OR_RETURN(HloInstruction * batch_to_space, - BatchToSpace(node)); + ASSIGN_OR_RETURN(HloInstruction * batch_to_space, BatchToSpace(node)); for (auto user : unsupported_users) { for (int64_t i = 0; i < user->operand_count(); ++i) { if (user->operand(i) == node) { @@ -2873,10 +2864,9 @@ absl::Status ConvolutionVisitor::PropagateOnConv(HloInstruction* convolution) { << c.spatial_dimensions_to_split[0] << " old_batch_size " << old_batch_size; - TF_ASSIGN_OR_RETURN( - auto retval, - BringSpaceNextToBatch(activations_new, permuted_conv_dims_numbers, - activations_batch_dim, &new_spatial_dims)); + ASSIGN_OR_RETURN(auto retval, BringSpaceNextToBatch( + activations_new, permuted_conv_dims_numbers, + activations_batch_dim, &new_spatial_dims)); activations_new = retval.instr; std::vector trans_dims = retval.transpose_dims; CHECK(!trans_dims.empty()); @@ -2885,7 +2875,7 @@ absl::Status ConvolutionVisitor::PropagateOnConv(HloInstruction* convolution) { LiteralUtil::Zero(activations_new->shape().element_type())), &convolution->metadata(), &convolution->frontend_attributes()); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( activations_new, SelectValidPortion(activations_new, activations_old, select_val, activations_batch_dim, new_spatial_dims, old_batch_dim, @@ -2925,12 +2915,11 @@ absl::Status ConvolutionVisitor::PropagateOnConv(HloInstruction* convolution) { // In the below case, we cannot use the activations directly for Halo // Duplication. We must reshape them. if (spatial_split_size > new_space_size) { - TF_ASSIGN_OR_RETURN( - activations_new, - ChangeSpatialSizeOnSpaceToBatchedShape( - activations_new, activations_batch_dim, old_batch_size, - new_spatial_dims, spatial_split_size, - /*increase_spatial_size*/ true)); + ASSIGN_OR_RETURN(activations_new, + ChangeSpatialSizeOnSpaceToBatchedShape( + activations_new, activations_batch_dim, old_batch_size, + new_spatial_dims, spatial_split_size, + /*increase_spatial_size*/ true)); } else { // If the ideal spatial_split_size was smaller than the incoming spatial @@ -2944,7 +2933,7 @@ absl::Status ConvolutionVisitor::PropagateOnConv(HloInstruction* convolution) { // If there's a stride mismatch, we change the new_space_size be // smaller (equal to spatial_split_size). if (new_space_size % c.stride != 0 || c.base_dilation_factor != 1) { - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( activations_new, ChangeSpatialSizeOnSpaceToBatchedShape( activations_new, activations_batch_dim, old_batch_size, @@ -2962,7 +2951,7 @@ absl::Status ConvolutionVisitor::PropagateOnConv(HloInstruction* convolution) { // For space-to-batch supported base-dilated convolutions, the low padding is // passed on to the new convolutions. Halo does not have to account for it. - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( activations_new, HaloDuplicateWithSlice( activations_new, new_spatial_dims, activations_batch_dim, @@ -3010,7 +2999,7 @@ absl::Status ConvolutionVisitor::PropagateOnConv(HloInstruction* convolution) { new_window.mutable_dimensions(first_dim + i) ->set_padding_low(c.low_padding_for_conv); } - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( HloInstruction * new_conv, MakeConvolveHlo( activations_new, /*rhs=*/convolution->mutable_operand(1), @@ -3048,7 +3037,7 @@ absl::Status ConvolutionVisitor::PropagateOnConcat(HloInstruction* concat) { for (int64_t i = 0; i < concat->operand_count(); ++i) { new_operands[i] = old_to_new_instrs_[concat->mutable_operand(i)]; } - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( HloInstruction * new_concat, MakeConcatHlo(new_operands, new_concat_dim, &concat->metadata(), &concat->frontend_attributes())); @@ -3071,8 +3060,8 @@ absl::Status ConvolutionVisitor::PropagateOnReverse(HloInstruction* reverse) { for (auto dim : reverse->dimensions()) { new_reverse_dimensions[dim_count++] = DimLookUp(permute_dims, dim); } - TF_ASSIGN_OR_RETURN(HloInstruction * new_reverse, - MakeReverseHlo(first_operand, new_reverse_dimensions)); + ASSIGN_OR_RETURN(HloInstruction * new_reverse, + MakeReverseHlo(first_operand, new_reverse_dimensions)); old_to_new_instrs_[reverse] = new_reverse; // Set mappings from operand 0. instr_to_dim_map_[reverse] = @@ -3099,10 +3088,10 @@ absl::Status ConvolutionVisitor::PropagateOnPad(HloInstruction* pad) { HloInstruction* padding = pad->mutable_operand(1); - TF_ASSIGN_OR_RETURN(auto new_pad, - MakePadHlo(first_operand, padding, padding_config, - &first_operand->metadata(), - &first_operand->frontend_attributes())); + ASSIGN_OR_RETURN(auto new_pad, + MakePadHlo(first_operand, padding, padding_config, + &first_operand->metadata(), + &first_operand->frontend_attributes())); old_to_new_instrs_[pad] = new_pad; // Set mappings from operand 0. @@ -3135,7 +3124,7 @@ absl::Status ConvolutionVisitor::PropagateOnSlice(HloInstruction* slice) { limits[i] = slice->slice_limits(old_dim); } - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( auto new_slice, MakeSliceHlo(operand, starts, limits, strides, &operand->metadata(), &operand->frontend_attributes())); @@ -3171,7 +3160,7 @@ absl::StatusOr ConvolutionVisitor::TransposeAndMergeBatch( start_batch_dim_position + i * 2 + 1; } - TF_ASSIGN_OR_RETURN(activations, MakeTransposeHlo(activations, trans_dims)); + ASSIGN_OR_RETURN(activations, MakeTransposeHlo(activations, trans_dims)); } std::vector batch_collapse_reshape_dims( @@ -3187,8 +3176,8 @@ absl::StatusOr ConvolutionVisitor::TransposeAndMergeBatch( spatial_dim_count); batch_collapse_reshape_dims[activations_batch_dim] = collapsed_batch_size; - TF_ASSIGN_OR_RETURN(HloInstruction * batch_collapsed_reshape, - MakeReshapeHlo(batch_collapse_reshape_dims, activations)); + ASSIGN_OR_RETURN(HloInstruction * batch_collapsed_reshape, + MakeReshapeHlo(batch_collapse_reshape_dims, activations)); return batch_collapsed_reshape; } @@ -3229,8 +3218,8 @@ absl::StatusOr ConvolutionVisitor::PerformSplitSpace( counter++; } - TF_ASSIGN_OR_RETURN(HloInstruction * batch_increased_reshape, - MakeReshapeHlo(reshape_dimensions, activations)); + ASSIGN_OR_RETURN(HloInstruction * batch_increased_reshape, + MakeReshapeHlo(reshape_dimensions, activations)); return TransposeAndMergeBatch( batch_increased_reshape, @@ -3261,10 +3250,10 @@ absl::StatusOr ConvolutionVisitor::PadAndSplitSpace( HloInstruction::CreateConstant( LiteralUtil::Zero(activations->shape().element_type())), &activations->metadata(), &activations->frontend_attributes()); - TF_ASSIGN_OR_RETURN(activations, - MakePadHlo(activations, padding, padding_config, - &activations->metadata(), - &activations->frontend_attributes())); + ASSIGN_OR_RETURN(activations, + MakePadHlo(activations, padding, padding_config, + &activations->metadata(), + &activations->frontend_attributes())); } VLOG(1) << "Initial padded activations shape " << activations->shape().ToString() << " old_batch_size " @@ -3282,14 +3271,14 @@ ConvolutionVisitor::SplitSpace( int64_t spatial_split_size, int64_t num_splits, std::vector* spatial_dimensions_to_split, bool is_backprop, bool is_rhs) { - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( auto retval, BringSpaceNextToBatch(activations, dim_numbers, activations_batch_dim, spatial_dimensions_to_split, is_backprop, is_rhs)); activations = retval.instr; std::vector transpose_dims = retval.transpose_dims; - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( auto new_activations, PadAndSplitSpace(activations, *spatial_dimensions_to_split, activations_batch_dim, high_padding, low_padding, @@ -3307,8 +3296,8 @@ absl::StatusOr ConvolutionVisitor::PropagateOnConstant( reversed_transpose_dims[i] = ReverseDimLookUp(prod_transpose_dims, i); } // Bring space next to batch. - TF_ASSIGN_OR_RETURN(consumer, - MakeTransposeHlo(consumer, reversed_transpose_dims)); + ASSIGN_OR_RETURN(consumer, + MakeTransposeHlo(consumer, reversed_transpose_dims)); auto retval = GetSpatialDimsToSplit(producer); std::vector old_spatial_dims = retval.first; @@ -3407,8 +3396,8 @@ absl::Status ConvolutionVisitor::PropagateOnBackpropFilterConv( needed_spatial_size * ctrl_.number_of_splits - old_split_dim_size; ConvolutionDimensionNumbers tmp_dim_numbers; tmp_dim_numbers = original_conv_dims; - TF_ASSIGN_OR_RETURN( - auto retval, SplitSpace(activations_old, tmp_dim_numbers, old_batch_dim, + ASSIGN_OR_RETURN(auto retval, + SplitSpace(activations_old, tmp_dim_numbers, old_batch_dim, /*high_padding=*/pad_size, /*low_padding=*/0, needed_spatial_size, ctrl_.number_of_splits, &old_split_spatial_dims, @@ -3444,7 +3433,7 @@ absl::Status ConvolutionVisitor::PropagateOnBackpropFilterConv( ConvolutionDimensionNumbers tmp_dim_numbers; tmp_dim_numbers = original_conv_dims; - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( auto retval, SplitSpace(kernel_old, tmp_dim_numbers, kernel_old_batch_dim, /*high_padding=*/pad_size, /*low_padding=*/0, @@ -3548,7 +3537,7 @@ absl::Status ConvolutionVisitor::PropagateOnBackpropFilterConv( << spatial_dimensions_to_split[0] << " old_batch_size " << old_batch_size << " new_split_dim_size " << new_split_dim_size; - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( auto retval, BringSpaceNextToBatch(activations_new, permuted_conv_dims_numbers, activations_batch_dim, &spatial_dimensions_to_split, @@ -3569,7 +3558,7 @@ absl::Status ConvolutionVisitor::PropagateOnBackpropFilterConv( if (new_split_dim_size != expected_split_dim_size) { CHECK_LT(new_split_dim_size, expected_split_dim_size); new_split_dim_size = expected_split_dim_size; - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( activations_new, ChangeSpatialSizeOnSpaceToBatchedShape( activations_new, activations_batch_dim, old_batch_size, @@ -3584,7 +3573,7 @@ absl::Status ConvolutionVisitor::PropagateOnBackpropFilterConv( if (!activations_locally_space_to_batched) { // Select activations correctly by masking additional space. - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( activations_new, SelectValidPortion(activations_new, activations_old, select_val, activations_batch_dim, spatial_dimensions_to_split, @@ -3601,7 +3590,7 @@ absl::Status ConvolutionVisitor::PropagateOnBackpropFilterConv( // IncreaseSpatialSizeOnSpaceToBatchedShape returns all dimensions. new_kernel_split_spatial_dims[0] = kernel_spatial_dimension_to_split; - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( kernel_new, SelectValidPortion(kernel_new, kernel_old, select_val, /*new_batch_dim=*/kernel_input_feature_dim, @@ -3636,7 +3625,7 @@ absl::Status ConvolutionVisitor::PropagateOnBackpropFilterConv( } else { activations_to_use = activations_chunks.back(); } - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( HloInstruction * activations_slice, HaloDuplicateWithSlice(activations_to_use, spatial_dimensions_to_split, activations_batch_dim, /*low_padding=*/1, @@ -3668,7 +3657,7 @@ absl::Status ConvolutionVisitor::PropagateOnBackpropFilterConv( if (i == 0) { activations_to_use = activations_new; if (inherent_low_padding < 0) { - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( activations_slice, HaloDuplicateWithSlice( activations_to_use, spatial_dimensions_to_split, @@ -3680,11 +3669,11 @@ absl::Status ConvolutionVisitor::PropagateOnBackpropFilterConv( } else { activations_to_use = activations_chunks.back(); - TF_ASSIGN_OR_RETURN(activations_slice, - HaloDuplicateWithSlice( - activations_to_use, spatial_dimensions_to_split, - activations_batch_dim, /*low_padding=*/-1, - /*halo_size=*/0)); + ASSIGN_OR_RETURN(activations_slice, + HaloDuplicateWithSlice( + activations_to_use, spatial_dimensions_to_split, + activations_batch_dim, /*low_padding=*/-1, + /*halo_size=*/0)); } activations_chunks.push_back(activations_slice); @@ -3705,7 +3694,7 @@ absl::Status ConvolutionVisitor::PropagateOnBackpropFilterConv( HloInstruction* activations_to_use = nullptr; activations_to_use = activations_chunks.back(); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( HloInstruction * activations_slice, HaloDuplicateWithSlice(activations_to_use, spatial_dimensions_to_split, activations_batch_dim, @@ -3719,13 +3708,13 @@ absl::Status ConvolutionVisitor::PropagateOnBackpropFilterConv( activations_chunks[i]->shape().dimensions().end()); // Insert 1-sized dimension at the end input_sizes.push_back(1); - TF_ASSIGN_OR_RETURN(activations_chunks[i], - MakeReshapeHlo(input_sizes, activations_chunks[i])); + ASSIGN_OR_RETURN(activations_chunks[i], + MakeReshapeHlo(input_sizes, activations_chunks[i])); VLOG(1) << "new_spatial_dimension " << new_spatial_dimension << " slice " << activations_chunks[i]->ToString(); } - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( activations_new, MakeConcatHlo(absl::MakeSpan(activations_chunks), new_spatial_dimension, &activations_old->metadata(), @@ -3736,7 +3725,7 @@ absl::Status ConvolutionVisitor::PropagateOnBackpropFilterConv( kernel_new->shape().dimensions().end()); // Insert 1-sized dimension at the end kernel_sizes.push_back(1); - TF_ASSIGN_OR_RETURN(kernel_new, MakeReshapeHlo(kernel_sizes, kernel_new)); + ASSIGN_OR_RETURN(kernel_new, MakeReshapeHlo(kernel_sizes, kernel_new)); auto new_window = convolution->window(); new_window.mutable_dimensions(GetFirstChosenSpatialDim(convolution)) @@ -3762,7 +3751,7 @@ absl::Status ConvolutionVisitor::PropagateOnBackpropFilterConv( window_dim->set_window_reversal(false); window_dim->set_window_dilation(1); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( HloInstruction * new_conv, MakeConvolveHlo( activations_new, kernel_new, convolution->feature_group_count(), @@ -3781,7 +3770,7 @@ absl::Status ConvolutionVisitor::PropagateOnBackpropFilterConv( new_dim_numbers.output_spatial_dimensions( GetFirstChosenSpatialDim(convolution))); - TF_ASSIGN_OR_RETURN(new_conv, MakeReshapeHlo(output_sizes, new_conv)); + ASSIGN_OR_RETURN(new_conv, MakeReshapeHlo(output_sizes, new_conv)); old_to_new_instrs_[convolution] = new_conv; VLOG(1) << "Space-to-featured convolution " << new_conv->ToString(); @@ -4063,7 +4052,7 @@ absl::Status ConvolutionVisitor::PerformSpaceToBatchOnConvolution( << " kernel_spatial_dim_size " << c.kernel_spatial_dim_size; std::vector spatial_dimensions_to_split = c.spatial_dimensions_to_split; - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( auto retval, SplitSpace( activations, dim_numbers, activations_batch_dim, @@ -4078,7 +4067,7 @@ absl::Status ConvolutionVisitor::PerformSpaceToBatchOnConvolution( VLOG(1) << "First reshape done " << batch_increased_reshape->ToString(); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( activations, HaloDuplicateWithSlice( batch_increased_reshape, spatial_dimensions_to_split, @@ -4129,7 +4118,7 @@ absl::Status ConvolutionVisitor::PerformSpaceToBatchOnConvolution( new_window.mutable_dimensions(first_dim + i) ->set_padding_low(c.low_padding_for_conv); } - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( HloInstruction * new_conv, MakeConvolveHlo( activations, /*rhs=*/convolution->mutable_operand(1), @@ -4164,7 +4153,7 @@ absl::Status ConvolutionVisitor::PerformSpaceToBatchOnConvolution( LiteralUtil::Zero(new_conv->shape().element_type())), &convolution->metadata(), &convolution->frontend_attributes()); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( new_conv, SelectValidPortion(new_conv, original_conv, select_val, output_batch_dim, new_output_split_spatial_dims, diff --git a/third_party/xla/xla/service/spmd/shardy/BUILD b/third_party/xla/xla/service/spmd/shardy/BUILD index efca1b656bea07..2457391f9afe41 100644 --- a/third_party/xla/xla/service/spmd/shardy/BUILD +++ b/third_party/xla/xla/service/spmd/shardy/BUILD @@ -43,6 +43,7 @@ cc_library( "//xla/tsl/framework/mlir:status_scoped_diagnostic_handler", "//xla/tsl/platform:env", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", diff --git a/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/import_shardy_attrs.cc b/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/import_shardy_attrs.cc index 67b53d891476ef..6c1cdd29c96373 100644 --- a/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/import_shardy_attrs.cc +++ b/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/import_shardy_attrs.cc @@ -245,6 +245,7 @@ void convertShardyAttrsWithHloShardingV3(FuncOp funcOp) { } else if (auto customCallOp = mlir::dyn_cast(op)) { StringRef targetName = customCallOp.getCallTargetName(); if (targetName == kShardingCustomCallTargetName || + targetName == "X64Combine" || isPythonCallbackCustomCall(customCallOp)) { customCallOp->setAttr( kShardingAttr, diff --git a/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/test_utils/BUILD b/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/test_utils/BUILD index 970b1b2dd80872..1984028b06ce1a 100644 --- a/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/test_utils/BUILD +++ b/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/test_utils/BUILD @@ -24,6 +24,7 @@ cc_library( "//xla/hlo/translate:stablehlo", "//xla/service:hlo_module_config", "//xla/service:hlo_proto_cc", + "//xla/tsl/platform:status_macros", "@com_google_absl//absl/log", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", diff --git a/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/test_utils/stablehlo_to_hlo_to_stablehlo.cc b/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/test_utils/stablehlo_to_hlo_to_stablehlo.cc index e173308a260616..906144b8f3e3e1 100644 --- a/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/test_utils/stablehlo_to_hlo_to_stablehlo.cc +++ b/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/test_utils/stablehlo_to_hlo_to_stablehlo.cc @@ -22,6 +22,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" +#include "xla/tsl/platform/status_macros.h" #include "llvm/ADT/StringRef.h" #include "mlir/Dialect/UB/IR/UBOps.h" #include "mlir/IR/BuiltinOps.h" @@ -49,8 +50,8 @@ using ::mlir::StringRef; // Converts a StableHLO module to an HLO module. absl::StatusOr> toHlo(ModuleOp module) { - TF_ASSIGN_OR_RETURN(std::unique_ptr hloModule, - xla::ConvertStablehloToHlo(module)); + ASSIGN_OR_RETURN(std::unique_ptr hloModule, + xla::ConvertStablehloToHlo(module)); hloModule->mutable_config().set_use_spmd_partitioning(true); return hloModule; } @@ -58,7 +59,7 @@ absl::StatusOr> toHlo(ModuleOp module) { // Converts an HLO module to a StableHLO module. absl::Status toStablehlo(std::unique_ptr hloModule, ModuleOp& module) { - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( mlir::OwningOpRef newModule, xla::ConvertHloToStablehlo(*module->getContext(), hloModule.get())); // Erase the old body region and replace it with the new one. diff --git a/third_party/xla/xla/service/spmd/shardy/shardy_xla_pass.cc b/third_party/xla/xla/service/spmd/shardy/shardy_xla_pass.cc index 01dd8528d4d9c1..a46a788a2b100f 100644 --- a/third_party/xla/xla/service/spmd/shardy/shardy_xla_pass.cc +++ b/third_party/xla/xla/service/spmd/shardy/shardy_xla_pass.cc @@ -35,6 +35,7 @@ limitations under the License. #include "absl/strings/str_format.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/MLIRContext.h" @@ -100,7 +101,7 @@ absl::Status createFromProtoAndReplaceComputations( // Create HLO computations from proto. for (const HloComputationProto& computationProto : proto.computations()) { - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( std::unique_ptr computation, HloComputation::CreateFromProto(computationProto, idToComputation)); CHECK_NE(computation.get(), nullptr); @@ -136,8 +137,8 @@ absl::Status createFromProtoAndReplaceComputations( // Remove the old computations, which are currently dead. CHECK_OK(HloDCE().Run(module)); - TF_ASSIGN_OR_RETURN(StackFrames stack_frames, - StackFrames::FromProto(proto.stack_frame_index())); + ASSIGN_OR_RETURN(StackFrames stack_frames, + StackFrames::FromProto(proto.stack_frame_index())); module->set_stack_frames(std::move(stack_frames)); return absl::OkStatus(); } @@ -348,7 +349,7 @@ absl::Status runShardingPropagation(HloModule* hloModule, tsl::io::JoinPath(shardyDir, "shardy", uniqueModuleName(*hloModule)); LOG(INFO) << "Using Shardy output directory: " << shardyDir; } - TF_RETURN_IF_ERROR(tsl::Env::Default()->RecursivelyCreateDir(shardyDir)); + RETURN_IF_ERROR(tsl::Env::Default()->RecursivelyCreateDir(shardyDir)); // MLIR pipeline: (1) import, (2) Shardy, and (3) export. bool enableVerifier = false; @@ -458,9 +459,8 @@ absl::StatusOr ShardyXLA::RunImpl( // HLO -> StableHLO auto mlirContext = std::make_unique(); loadAllRequiredDialects(mlirContext.get()); - TF_ASSIGN_OR_RETURN( - mlir::OwningOpRef mlirModule, - xla::ConvertHloToStablehlo(*mlirContext.get(), hloModule)); + ASSIGN_OR_RETURN(mlir::OwningOpRef mlirModule, + xla::ConvertHloToStablehlo(*mlirContext.get(), hloModule)); // Store the entry computation layout, input-output alias config, and buffer // donors, which will be restored in the end, since MLIR does not preserve @@ -480,7 +480,7 @@ absl::StatusOr ShardyXLA::RunImpl( useTupleArgs); if (runSdyShardingPropagation) { - TF_RETURN_IF_ERROR(runShardingPropagation( + RETURN_IF_ERROR(runShardingPropagation( hloModule, mlirModule.get(), importMhloShardings, propagationOptions, enableNativeNonFlatSupport, name())); } @@ -494,9 +494,9 @@ absl::StatusOr ShardyXLA::RunImpl( // StableHlo -> HLO HloProto hloProto; - TF_RETURN_IF_ERROR(ConvertStablehloWithManyArgsToHloProto( - *mlirModule, &hloProto, useTupleArgs)); - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR(ConvertStablehloWithManyArgsToHloProto(*mlirModule, &hloProto, + useTupleArgs)); + RETURN_IF_ERROR( createFromProtoAndReplaceComputations(hloModule, hloProto.hlo_module())); // If the module returns a single tensor as result with sharding, @@ -511,11 +511,10 @@ absl::StatusOr ShardyXLA::RunImpl( std::move(flattenedInputOutputAliasConfig)); hloModule->set_buffer_donor_config(std::move(flattenedBufferDonorsConfig)); - TF_RETURN_IF_ERROR( - hlo_sharding_util::CanonicalizeLayoutAfterShardingPropagation( - hloModule, - hloModule->config().allow_spmd_sharding_propagation_to_output(), - hloModule->config().allow_spmd_sharding_propagation_to_parameters())); + RETURN_IF_ERROR(hlo_sharding_util::CanonicalizeLayoutAfterShardingPropagation( + hloModule, + hloModule->config().allow_spmd_sharding_propagation_to_output(), + hloModule->config().allow_spmd_sharding_propagation_to_parameters())); // We don't fully replace the HLO module, so it will continue to have the // temporary frontend attributes. So clean them up as XLA won't need them. diff --git a/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_import_pipeline_hlo_sharding_v3.mlir b/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_import_pipeline_hlo_sharding_v3.mlir index 5e5b1b8987f8a8..67790a22feaf3f 100644 --- a/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_import_pipeline_hlo_sharding_v3.mlir +++ b/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_import_pipeline_hlo_sharding_v3.mlir @@ -32,6 +32,21 @@ module @module_1 { return %arg0, %arg1, %arg0, %arg1, %arg1, %arg2 : tensor<32xi32>, tensor<32xi32>, tensor<32xi32>, tensor<32xi32>, tensor<32xi32>, tensor<32xi32> } + // CHECK-LABEL: func @x64_combine(%arg0: tensor<16xi64>) + // CHECK-SAME: -> (tensor<16xi64> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}]>}) { + func.func @x64_combine( + %arg0: tensor<16xi64>) -> (tensor<16xi64> {mhlo.sharding = "{mesh['a'=8,'b'=8,'c'=8], [{'a'}]}"}) { + // CHECK-NEXT: %[[SPLIT_LOW:.*]] = stablehlo.custom_call @X64SplitLow(%arg0) : (tensor<16xi64>) -> tensor<16xui32> + // CHECK-NEXT: %[[SPLIT_HIGH:.*]] = stablehlo.custom_call @X64SplitHigh(%arg0) : (tensor<16xi64>) -> tensor<16xui32> + // CHECK-NEXT: %[[COMBINE:.*]] = stablehlo.custom_call @X64Combine(%[[SPLIT_LOW]], %[[SPLIT_HIGH]]) + // CHECK-SAME: {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"a"}]>]>} : (tensor<16xui32>, tensor<16xui32>) -> tensor<16xi64> + // CHECK-NEXT: return %[[COMBINE]] + %0 = stablehlo.custom_call @X64SplitLow(%arg0) : (tensor<16xi64>) -> tensor<16xui32> + %1 = stablehlo.custom_call @X64SplitHigh(%arg0) : (tensor<16xi64>) -> tensor<16xui32> + %6 = stablehlo.custom_call @X64Combine(%0, %1) {mhlo.sharding = "{mesh['a'=8,'b'=8,'c'=8], [{'a'}]}"} : (tensor<16xui32>, tensor<16xui32>) -> tensor<16xi64> + return %6 : tensor<16xi64> + } + // CHECK-LABEL: func @while_with_free_variables func.func @while_with_free_variables( %arg0: tensor<32x96xf32>, diff --git a/third_party/xla/xla/service/topk_rewriter.cc b/third_party/xla/xla/service/topk_rewriter.cc index f72ed2f4bc92a2..68eafad93526ba 100644 --- a/third_party/xla/xla/service/topk_rewriter.cc +++ b/third_party/xla/xla/service/topk_rewriter.cc @@ -30,6 +30,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/builder/lib/comparators.h" #include "xla/hlo/builder/xla_builder.h" #include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" @@ -402,9 +403,9 @@ absl::StatusOr TopkRewriter::TransformPatternToCustomCall( HloInstruction* gte = user; for (HloInstruction* slice : gte->users()) { if (gte->tuple_index() == 0) { - TF_RETURN_IF_ERROR(slice->ReplaceAllUsesWith(topkcc.value_gte)); + RETURN_IF_ERROR(slice->ReplaceAllUsesWith(topkcc.value_gte)); } else if (gte->tuple_index() == 1) { - TF_RETURN_IF_ERROR(slice->ReplaceAllUsesWith(topkcc.index_gte)); + RETURN_IF_ERROR(slice->ReplaceAllUsesWith(topkcc.index_gte)); } else { // The line below should be unreachable. SortIsInTopK() already checks // that sort has either 1 or 2 operands. Reaching this line indicates @@ -414,7 +415,7 @@ absl::StatusOr TopkRewriter::TransformPatternToCustomCall( } } } else { - TF_RETURN_IF_ERROR(user->ReplaceAllUsesWith(topkcc.value_gte)); + RETURN_IF_ERROR(user->ReplaceAllUsesWith(topkcc.value_gte)); } } @@ -427,8 +428,8 @@ absl::StatusOr TopkRewriter::TransformToCustomCall( bool changed = false; for (HloComputation* comp : module->computations(execution_threads)) { for (HloInstruction* inst : comp->MakeInstructionPostOrder()) { - TF_ASSIGN_OR_RETURN(HloInstruction * topkcc, - TransformPatternToCustomCall(inst)); + ASSIGN_OR_RETURN(HloInstruction * topkcc, + TransformPatternToCustomCall(inst)); if (topkcc != nullptr) { VLOG(2) << "Rewritten Topk: " << topkcc->ToString(); changed = true; @@ -442,8 +443,8 @@ absl::StatusOr TopkRewriter::RunImpl( HloModule* module, const absl::flat_hash_set& execution_threads) { bool changed = false; - TF_ASSIGN_OR_RETURN(auto transform_to_customcall_changed, - TransformToCustomCall(module, execution_threads)); + ASSIGN_OR_RETURN(auto transform_to_customcall_changed, + TransformToCustomCall(module, execution_threads)); changed |= transform_to_customcall_changed; return changed; } @@ -469,8 +470,8 @@ class TopkDecomposerVisitor : public DfsHloRewriteVisitor { if (should_decompose_ && !should_decompose_(topk)) { return absl::OkStatus(); } - TF_ASSIGN_OR_RETURN(HloComputation * comparator, - CreateVariadicComparator(topk)); + ASSIGN_OR_RETURN(HloComputation * comparator, + CreateVariadicComparator(topk)); return DecomposeTopK(topk, comparator); } @@ -493,7 +494,7 @@ class TopkDecomposerVisitor : public DfsHloRewriteVisitor { XlaComputation comparison = topk->largest() ? CreateScalarGtComputation(ptypes, &b) : CreateScalarLtComputation(ptypes, &b); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( HloComputation * comparator, XlaComputationToHloComputation(comparison, topk->parent()->parent())); return comparator; @@ -514,7 +515,7 @@ class TopkDecomposerVisitor : public DfsHloRewriteVisitor { HloInstruction* sort = call->AddInstruction(HloInstruction::CreateSort( input->shape(), sort_dimension, {input}, variadic_comparator, /*is_stable=*/true)); - TF_RETURN_IF_ERROR(ReplaceInstruction( + RETURN_IF_ERROR(ReplaceInstruction( call->users().front(), call->AddInstruction(HloInstruction::CreateSlice( call->shape().tuple_shapes(0), sort, zeroes, @@ -534,7 +535,7 @@ class TopkDecomposerVisitor : public DfsHloRewriteVisitor { sort->shape().tuple_shapes(index), sort, index)), zeroes, call->shape().tuple_shapes(index).dimensions(), ones)); }; - TF_RETURN_IF_ERROR(ReplaceInstruction( + RETURN_IF_ERROR(ReplaceInstruction( call, call->AddInstruction(HloInstruction::CreateTuple( {slice_tuple(0), slice_tuple(1)})))); } diff --git a/third_party/xla/xla/service/transfer_manager.cc b/third_party/xla/xla/service/transfer_manager.cc index 312dbb442f22ee..43bbf9f566cba8 100644 --- a/third_party/xla/xla/service/transfer_manager.cc +++ b/third_party/xla/xla/service/transfer_manager.cc @@ -30,6 +30,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/synchronization/mutex.h" #include "absl/synchronization/notification.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/literal.h" #include "xla/service/compiler.h" #include "xla/service/maybe_owning_device_address.h" @@ -62,8 +63,8 @@ absl::StatusOr TransferManager::TransferLiteralFromDevice( se::Stream* stream, const ShapedBuffer& device_buffer, const TransferMetadata* transfer_metadata) { Literal literal(device_buffer.on_host_shape()); - TF_RETURN_IF_ERROR(TransferLiteralFromDevice(stream, device_buffer, &literal, - transfer_metadata)); + RETURN_IF_ERROR(TransferLiteralFromDevice(stream, device_buffer, &literal, + transfer_metadata)); return std::move(literal); } @@ -71,8 +72,8 @@ absl::Status TransferManager::TransferLiteralFromDevice( se::Stream* stream, const ShapedBuffer& device_buffer, const MutableBorrowingLiteral& literal, const TransferMetadata* transfer_metadata) { - TF_ASSIGN_OR_RETURN(se::Stream * substream, stream->GetOrCreateSubStream()); - TF_RETURN_IF_ERROR(substream->WaitFor(stream)); + ASSIGN_OR_RETURN(se::Stream * substream, stream->GetOrCreateSubStream()); + RETURN_IF_ERROR(substream->WaitFor(stream)); absl::Cleanup cleanup = [&]() { stream->ReturnSubStream(substream); }; absl::Status ret; @@ -95,10 +96,10 @@ absl::Status TransferManager::TransferLiteralToDevice( // Implement the synchronous version by waiting on the asynchronous version. // Use a substream so that if we are called from a HostCallback we don't // deadlock. - TF_ASSIGN_OR_RETURN(se::Stream * substream, stream->GetOrCreateSubStream()); - TF_RETURN_IF_ERROR(substream->WaitFor(stream)); + ASSIGN_OR_RETURN(se::Stream * substream, stream->GetOrCreateSubStream()); + RETURN_IF_ERROR(substream->WaitFor(stream)); absl::Cleanup cleanup = [&]() { stream->ReturnSubStream(substream); }; - TF_RETURN_IF_ERROR(TransferLiteralToDeviceAsync( + RETURN_IF_ERROR(TransferLiteralToDeviceAsync( substream, literal, device_buffer, transfer_metadata)); return substream->BlockHostUntilDone(); } @@ -112,8 +113,8 @@ absl::StatusOr TransferManager::TransferArrayFromDevice( Literal literal(shape); ShapedBuffer shaped_buffer(shape, stream->parent()->device_ordinal()); shaped_buffer.set_buffer(source, /*index=*/{}); - TF_RETURN_IF_ERROR(TransferLiteralFromDevice(stream, shaped_buffer, &literal, - transfer_metadata)); + RETURN_IF_ERROR(TransferLiteralFromDevice(stream, shaped_buffer, &literal, + transfer_metadata)); return std::move(literal); } @@ -124,10 +125,10 @@ absl::Status TransferManager::TransferArrayToDevice( // Implement the synchronous version by waiting on the asynchronous version. // Use a substream so that if we are called from a HostCallback we don't // deadlock. - TF_ASSIGN_OR_RETURN(se::Stream * substream, stream->GetOrCreateSubStream()); - TF_RETURN_IF_ERROR(substream->WaitFor(stream)); + ASSIGN_OR_RETURN(se::Stream * substream, stream->GetOrCreateSubStream()); + RETURN_IF_ERROR(substream->WaitFor(stream)); absl::Cleanup cleanup = [&]() { stream->ReturnSubStream(substream); }; - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( TransferArrayToDeviceAsync(substream, literal, dest, transfer_metadata)); return substream->BlockHostUntilDone(); } @@ -149,12 +150,11 @@ absl::Status TransferManager::ReadDynamicShapes( Shape* device_shape) { DCHECK(device_shape->is_dynamic()); Shape original_device_shape = *device_shape; - TF_RETURN_IF_ERROR(stream->BlockHostUntilDone()); + RETURN_IF_ERROR(stream->BlockHostUntilDone()); - TF_ASSIGN_OR_RETURN( - auto compiler, - Compiler::GetForPlatform(stream->parent()->GetPlatform()->id())); - TF_RETURN_IF_ERROR(device_buffer->buffers().ForEachElementWithStatus( + ASSIGN_OR_RETURN(auto compiler, Compiler::GetForPlatform( + stream->parent()->GetPlatform()->id())); + RETURN_IF_ERROR(device_buffer->buffers().ForEachElementWithStatus( [&](const ShapeIndex& index, const se::DeviceAddressBase& buffer) -> absl::Status { const Shape& buffer_shape = @@ -179,7 +179,7 @@ absl::Status TransferManager::ReadDynamicShapes( } auto buffer_8 = se::DeviceAddress(buffer); auto metadata_buffer = buffer_8.GetSlice(offset, metadata_size); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( auto metadata, TransferArrayFromDevice( stream, @@ -233,7 +233,7 @@ absl::Status TransferManager::ReadDynamicShapes( absl::Status TransferManager::WriteTupleIndexTables( se::Stream* stream, const ShapedBuffer& device_buffer) { - TF_RETURN_IF_ERROR(WriteTupleIndexTablesAsync(stream, device_buffer)); + RETURN_IF_ERROR(WriteTupleIndexTablesAsync(stream, device_buffer)); return stream->BlockHostUntilDone(); } @@ -318,7 +318,7 @@ absl::StatusOr TransferManager::AllocateScopedShapedBuffer( return InvalidArgument("Shape must have a layout: %s", ShapeUtil::HumanStringWithLayout(on_host_shape)); } - TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(on_host_shape)); + RETURN_IF_ERROR(ShapeUtil::ValidateShape(on_host_shape)); Shape on_device_shape = (shape_representation_fn == nullptr) ? HostShapeToDeviceShape(on_host_shape) : shape_representation_fn(on_host_shape); @@ -334,11 +334,11 @@ absl::StatusOr TransferManager::AllocateScopedShapedBuffer( se::DeviceAddressBase& memory_base = pair.second; const Shape& subshape = ShapeUtil::GetSubshape(shaped_buffer.on_device_shape(), index); - TF_ASSIGN_OR_RETURN(auto memory, - allocator->Allocate(shaped_buffer.device_ordinal(), - GetByteSizeRequirement(subshape), - /*retry_on_failure=*/true, - LayoutUtil::MemorySpace(subshape))); + ASSIGN_OR_RETURN(auto memory, + allocator->Allocate(shaped_buffer.device_ordinal(), + GetByteSizeRequirement(subshape), + /*retry_on_failure=*/true, + LayoutUtil::MemorySpace(subshape))); // Move the allocated buffer into the ScopedShapedBuffer, which owns it. memory_base = memory.Release(); } diff --git a/third_party/xla/xla/service/transpose_folding.cc b/third_party/xla/xla/service/transpose_folding.cc index a04814c31dc262..22454385f7c39e 100644 --- a/third_party/xla/xla/service/transpose_folding.cc +++ b/third_party/xla/xla/service/transpose_folding.cc @@ -26,6 +26,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -214,8 +215,8 @@ absl::StatusOr TransposeFolding::RunImpl( continue; } - TF_ASSIGN_OR_RETURN(bool can_fold_operand, - dot_can_fold_transpose_operand_(*instruction, i)); + ASSIGN_OR_RETURN(bool can_fold_operand, + dot_can_fold_transpose_operand_(*instruction, i)); if (can_fold_operand) { operand_indices.push_back(i); @@ -238,12 +239,12 @@ absl::StatusOr TransposeFolding::RunImpl( }); for (auto* comp : module->MakeNonfusionComputations(execution_threads)) { - TF_RETURN_IF_ERROR(comp->Accept(&visit_fn)); + RETURN_IF_ERROR(comp->Accept(&visit_fn)); } bool changed = false; for (InstructionOperandsPair& pair : foldable_dots) { - TF_RETURN_IF_ERROR(FoldTransposeIntoDot(pair)); + RETURN_IF_ERROR(FoldTransposeIntoDot(pair)); changed = true; } for (InstructionOperandsPair& pair : foldable_convolutions) { diff --git a/third_party/xla/xla/service/triangular_solve_expander.cc b/third_party/xla/xla/service/triangular_solve_expander.cc index 5c8577a47eca98..ab187456107613 100644 --- a/third_party/xla/xla/service/triangular_solve_expander.cc +++ b/third_party/xla/xla/service/triangular_solve_expander.cc @@ -27,6 +27,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/str_format.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/builder/lib/constants.h" #include "xla/hlo/builder/lib/math.h" #include "xla/hlo/builder/lib/matrix.h" @@ -51,7 +52,7 @@ namespace { XlaOp DiagonalBlocks(XlaOp a, int64_t block_size) { XlaBuilder* builder = a.builder(); return builder->ReportErrorOrReturn([&]() -> absl::StatusOr { - TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(a)); + ASSIGN_OR_RETURN(Shape shape, builder->GetShape(a)); int ndims = shape.dimensions().size(); int64_t n = ShapeUtil::GetDimension(shape, -1); int64_t num_blocks = n / block_size; @@ -116,7 +117,7 @@ XlaOp DiagonalBlocks(XlaOp a, int64_t block_size) { // Add a singleton dimension // i.e. [..., block_size, block_size] -> [..., 1, block_size, block_size] - TF_ASSIGN_OR_RETURN(Shape blocks_shape, builder->GetShape(last_blocks)); + ASSIGN_OR_RETURN(Shape blocks_shape, builder->GetShape(last_blocks)); auto shape_dims = blocks_shape.dimensions(); auto last_blocks_dims = std::vector(ndims); absl::c_copy(shape_dims, last_blocks_dims.begin()); @@ -142,11 +143,11 @@ XlaOp SolveWithInvertedDiagonalBlocks(XlaOp a, XlaOp b, XlaOp inv_diag_blocks, PrecisionConfig::Precision precision) { XlaBuilder* builder = a.builder(); return builder->ReportErrorOrReturn([&]() -> absl::StatusOr { - TF_ASSIGN_OR_RETURN(Shape blocks_shape, builder->GetShape(inv_diag_blocks)); - TF_ASSIGN_OR_RETURN(Shape b_shape, builder->GetShape(b)); + ASSIGN_OR_RETURN(Shape blocks_shape, builder->GetShape(inv_diag_blocks)); + ASSIGN_OR_RETURN(Shape b_shape, builder->GetShape(b)); int64_t block_size = ShapeUtil::GetDimension(blocks_shape, -1); - TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a)); + ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a)); int64_t ndims = a_shape.dimensions().size(); int64_t n = ShapeUtil::GetDimension(a_shape, -1); int64_t num_blocks = n / block_size + (n % block_size != 0); @@ -253,7 +254,7 @@ XlaOp TriangularSolveExpander::InvertDiagonalBlocks( return builder->ReportErrorOrReturn([&]() -> absl::StatusOr { // Input is a batch of square lower triangular square matrices. Its shape is // (..., size, size). We resize this to (num_blocks, size, size). - TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(diag_blocks)); + ASSIGN_OR_RETURN(Shape shape, builder->GetShape(diag_blocks)); int64_t block_size = ShapeUtil::GetDimension(shape, -1); int64_t num_blocks = ShapeUtil::ElementsIn(shape) / IPow(block_size, 2); diag_blocks = Reshape(diag_blocks, {num_blocks, block_size, block_size}); @@ -321,7 +322,7 @@ XlaOp TriangularSolveExpander::InvertDiagonalBlocks( Parameter(condb.get(), 0, tuple_shape, "InvertDiagCondTuple"), 0); Lt(i, ConstantR0(condb.get(), block_size)); } - TF_ASSIGN_OR_RETURN(auto cond, condb->Build()); + ASSIGN_OR_RETURN(auto cond, condb->Build()); // Construct the loop body function. std::unique_ptr bodyb = @@ -356,7 +357,7 @@ XlaOp TriangularSolveExpander::InvertDiagonalBlocks( auto next_i = i + ScalarLike(i, 1); Tuple(bodyb.get(), {next_i, body_out, body_input}); } - TF_ASSIGN_OR_RETURN(auto body, bodyb->Build()); + ASSIGN_OR_RETURN(auto body, bodyb->Build()); // Construct the While loop and return the result, // return while_loop(cond_fun, body_fun, init)[1] @@ -377,7 +378,7 @@ XlaOp TriangularSolveExpander::SolveByInvertingDiagonalBlocks( PrecisionConfig::Precision precision) { XlaBuilder* builder = a.builder(); return builder->ReportErrorOrReturn([&]() -> absl::StatusOr { - TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a)); + ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a)); const int64_t ndims = a_shape.dimensions().size(); int64_t k = ShapeUtil::GetDimension(a_shape, -1); @@ -421,8 +422,8 @@ XlaOp TriangularSolveExpander::SolveDirectly( PrecisionConfig::Precision precision) { XlaBuilder* builder = a.builder(); return builder->ReportErrorOrReturn([&]() -> absl::StatusOr { - TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a)); - TF_ASSIGN_OR_RETURN(Shape b_shape, builder->GetShape(b)); + ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a)); + ASSIGN_OR_RETURN(Shape b_shape, builder->GetShape(b)); int64_t m = ShapeUtil::GetDimension(b_shape, -2); int64_t n = ShapeUtil::GetDimension(b_shape, -1); const int64_t a_size = ShapeUtil::GetDimension(a_shape, -1); @@ -478,8 +479,8 @@ XlaOp TriangularSolveExpander::BuildTriangularSolve( PrecisionConfig::Precision precision) { XlaBuilder* builder = a.builder(); return builder->ReportErrorOrReturn([&]() -> absl::StatusOr { - TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a)); - TF_ASSIGN_OR_RETURN(Shape b_shape, builder->GetShape(b)); + ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a)); + ASSIGN_OR_RETURN(Shape b_shape, builder->GetShape(b)); if (a_shape.dimensions().size() != b_shape.dimensions().size()) { return InvalidArgument( "Arguments to TriangularSolve have shapes with different ranks: " @@ -600,9 +601,9 @@ absl::StatusOr TriangularSolveExpander::ExpandInstruction( transpose_a, conjugate_a, options.unit_diagonal(), /*block_size=*/block_size_, /*precision=*/PrecisionConfig::HIGHEST); - TF_ASSIGN_OR_RETURN(XlaComputation xla_computation, builder.Build()); - TF_ASSIGN_OR_RETURN( - computation, XlaComputationToHloComputation(xla_computation, module)); + ASSIGN_OR_RETURN(XlaComputation xla_computation, builder.Build()); + ASSIGN_OR_RETURN(computation, + XlaComputationToHloComputation(xla_computation, module)); } return instruction->parent()->AddInstruction(HloInstruction::CreateCall( diff --git a/third_party/xla/xla/service/tuple_util.cc b/third_party/xla/xla/service/tuple_util.cc index f3a416fe79c93a..54137c75c5af73 100644 --- a/third_party/xla/xla/service/tuple_util.cc +++ b/third_party/xla/xla/service/tuple_util.cc @@ -28,6 +28,7 @@ limitations under the License. #include "absl/strings/str_join.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" @@ -133,10 +134,10 @@ namespace xla { // If the subshape is still a tuple, recurse and pass a new shape index // for the one level deeper. if (subshape.IsTuple()) { - TF_ASSIGN_OR_RETURN(tuple_args[i], - ReplaceTupleWith(new_instruction, get_operand(), - ShapeIndex(shape_index.begin() + 1, - shape_index.end()))); + ASSIGN_OR_RETURN(tuple_args[i], + ReplaceTupleWith(new_instruction, get_operand(), + ShapeIndex(shape_index.begin() + 1, + shape_index.end()))); } else { if (subshape != new_instruction->shape() && insert_bitcast_if_different_shape) { diff --git a/third_party/xla/xla/service/while_loop_all_reduce_code_motion.cc b/third_party/xla/xla/service/while_loop_all_reduce_code_motion.cc index eea52acb8b110a..0bc9793cf4511b 100644 --- a/third_party/xla/xla/service/while_loop_all_reduce_code_motion.cc +++ b/third_party/xla/xla/service/while_loop_all_reduce_code_motion.cc @@ -33,6 +33,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/core/collectives/reduction_kind.h" #include "xla/hlo/analysis/hlo_replication_analysis.h" #include "xla/hlo/analysis/while_loop_analysis.h" @@ -1106,7 +1107,7 @@ absl::Status ChangeAccumulatorShapesInLoopBodies( HloInstruction* pred = body->AddInstruction(HloInstruction::CreateBroadcast( pred_shape, scalar_predicate, {})); - TF_RETURN_IF_ERROR(user->ReplaceOperandWithDifferentShape(0, pred)); + RETURN_IF_ERROR(user->ReplaceOperandWithDifferentShape(0, pred)); HloInstruction *new_operand_1, *new_operand_2; if (user->operand_index(loop_reduce_scatter) == 1) { new_operand_1 = loop_reduce_scatter->mutable_operand(0); @@ -1115,9 +1116,9 @@ absl::Status ChangeAccumulatorShapesInLoopBodies( new_operand_1 = zero; new_operand_2 = loop_reduce_scatter->mutable_operand(0); } - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( user->ReplaceOperandWithDifferentShape(1, new_operand_1)); - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( user->ReplaceOperandWithDifferentShape(2, new_operand_2)); *user->mutable_shape() = accumulation_shape; } else { @@ -1265,7 +1266,7 @@ absl::Status AddSinkedAllReducesAndReplaceWhile( // For reduce-scatter, we need to adjust all the accumulator shapes to use // the pre-scatter shape. - TF_RETURN_IF_ERROR(ChangeAccumulatorShapesInLoopBodies( + RETURN_IF_ERROR(ChangeAccumulatorShapesInLoopBodies( while_instruction, all_reduce_to_accumulations)); // Step 2) create the new while instruction. @@ -1282,7 +1283,7 @@ absl::Status AddSinkedAllReducesAndReplaceWhile( // its uses. HloInstruction* new_while_result = CreateNewWhileResult(new_while_instruction, tuple_index_to_new_buffer); - TF_RETURN_IF_ERROR(while_instruction->parent()->ReplaceInstruction( + RETURN_IF_ERROR(while_instruction->parent()->ReplaceInstruction( while_instruction, new_while_result)); return absl::OkStatus(); } @@ -1377,7 +1378,7 @@ absl::StatusOr AddSinkedAllReducesAndReplaceWhile( // Replace the old while instruction with the new one. HloInstruction* new_while_result = CreateNewWhileResult(new_while_instruction, tuple_index_to_new_buffer); - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( while_parent->ReplaceInstruction(while_instruction, new_while_result)); return new_while_instruction; } @@ -1402,18 +1403,18 @@ absl::StatusOr WhileLoopAllReduceCodeMotion::RunImpl( if (module->config().replica_count() > 1) { VLOG(5) << "num_replicas: " << module->config().replica_count() << " run HloReplicationAnalysis across replicas"; - TF_ASSIGN_OR_RETURN(cross_replica_replication_analysis, - HloReplicationAnalysis::RunWithPartialReplication( - module, /*cross_partition_spmd=*/false)); + ASSIGN_OR_RETURN(cross_replica_replication_analysis, + HloReplicationAnalysis::RunWithPartialReplication( + module, /*cross_partition_spmd=*/false)); } std::unique_ptr cross_partition_replication_analysis; if (module->config().use_spmd_partitioning() && module->config().num_partitions() > 1) { VLOG(5) << "num_partitions: " << module->config().num_partitions() << " run HloReplicationAnalysis across partitions"; - TF_ASSIGN_OR_RETURN(cross_partition_replication_analysis, - HloReplicationAnalysis::RunWithPartialReplication( - module, /*cross_partition_spmd=*/true)); + ASSIGN_OR_RETURN(cross_partition_replication_analysis, + HloReplicationAnalysis::RunWithPartialReplication( + module, /*cross_partition_spmd=*/true)); } // Run setup passes that may setup the add(all-reduce/reduce-scatter, @@ -1425,7 +1426,7 @@ absl::StatusOr WhileLoopAllReduceCodeMotion::RunImpl( } pipeline.AddPass( /*enable_reduce_scatter=*/enable_reduce_scatter_); - TF_RETURN_IF_ERROR(pipeline.Run(module, execution_threads).status()); + RETURN_IF_ERROR(pipeline.Run(module, execution_threads).status()); } // The while instruction's parent could be a while body for another while @@ -1504,15 +1505,14 @@ absl::StatusOr WhileLoopAllReduceCodeMotion::RunImpl( // For each while instruction calling this computation, create the // corresponding all-reduces after the while loop. for (auto& while_instruction : while_caller_instructions) { - TF_ASSIGN_OR_RETURN( - while_instruction, - AddSinkedAllReducesAndReplaceWhile(while_instruction, - all_reduce_to_update_slices)); + ASSIGN_OR_RETURN(while_instruction, + AddSinkedAllReducesAndReplaceWhile( + while_instruction, all_reduce_to_update_slices)); } // Remove all-reduce instructions in the loop body. for (const auto& [all_reduce, _] : all_reduce_to_update_slices) { ++count_all_reduce; - TF_RETURN_IF_ERROR(computation->ReplaceInstruction( + RETURN_IF_ERROR(computation->ReplaceInstruction( all_reduce, all_reduce->mutable_operand(0))); } is_changed = true; @@ -1521,7 +1521,7 @@ absl::StatusOr WhileLoopAllReduceCodeMotion::RunImpl( // For each while instruction calling this computation, create the // corresponding all-reduces after the while loop. for (HloInstruction* while_instruction : while_caller_instructions) { - TF_RETURN_IF_ERROR(AddSinkedAllReducesAndReplaceWhile( + RETURN_IF_ERROR(AddSinkedAllReducesAndReplaceWhile( while_instruction, all_reduce_to_accumulations)); } // At last, remove the old all-reduce instructions in the while body. @@ -1533,7 +1533,7 @@ absl::StatusOr WhileLoopAllReduceCodeMotion::RunImpl( } else { ++count_reduce_scatter; } - TF_RETURN_IF_ERROR(computation->ReplaceInstructionWithDifferentShape( + RETURN_IF_ERROR(computation->ReplaceInstructionWithDifferentShape( all_reduce, all_reduce->mutable_operand(0))); } is_changed = true; diff --git a/third_party/xla/xla/service/while_loop_concat_code_motion.cc b/third_party/xla/xla/service/while_loop_concat_code_motion.cc index b7d7888ac99f6b..e0582b2c587625 100644 --- a/third_party/xla/xla/service/while_loop_concat_code_motion.cc +++ b/third_party/xla/xla/service/while_loop_concat_code_motion.cc @@ -31,6 +31,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -678,7 +679,7 @@ absl::Status AddCopiesToRoot(HloComputation* body, } copies[i] = body->AddInstruction(HloInstruction::CreateUnary( element->shape(), HloOpcode::kCopy, element)); - TF_RETURN_IF_ERROR(root->ReplaceOperandWith(i, copies[i])); + RETURN_IF_ERROR(root->ReplaceOperandWith(i, copies[i])); } for (int64_t i = 0; i < copies.size(); ++i) { auto copy = copies[i]; @@ -710,7 +711,7 @@ absl::Status RemoveCopiesFromRoot(HloComputation* body) { for (int64_t i = 0; i < root->operand_count(); ++i) { auto copy = root->mutable_operand(i); if (copy->opcode() == HloOpcode::kCopy) { - TF_RETURN_IF_ERROR(root->ReplaceOperandWith(i, copy->mutable_operand(0))); + RETURN_IF_ERROR(root->ReplaceOperandWith(i, copy->mutable_operand(0))); } } return absl::OkStatus(); @@ -761,7 +762,7 @@ absl::Status RewriteLoopWithConcatGroups( init_elements[i] = group.CreateConcat(std::move(input_concat_elements), loop->parent()); } - TF_RETURN_IF_ERROR(loop->ReplaceOperandWithDifferentShape( + RETURN_IF_ERROR(loop->ReplaceOperandWithDifferentShape( 0, loop->parent()->AddInstruction( HloInstruction::CreateTuple(init_elements)))); // Adjust loop users. @@ -790,8 +791,7 @@ absl::Status RewriteLoopWithConcatGroups( auto new_output_tuple = loop->parent()->AddInstruction( HloInstruction::CreateTuple(output_elements)); for (auto user : original_loop_users) { - TF_RETURN_IF_ERROR( - loop->ReplaceUseWithDifferentShape(user, new_output_tuple)); + RETURN_IF_ERROR(loop->ReplaceUseWithDifferentShape(user, new_output_tuple)); } if (loop_is_root) { loop->parent()->set_root_instruction(new_output_tuple, @@ -868,8 +868,7 @@ absl::Status RewriteLoopWithConcatGroups( new_dims), hlo->mutable_operand(i))); new_reshapes.insert(reshape); - TF_RETURN_IF_ERROR( - hlo->ReplaceOperandWithDifferentShape(i, reshape)); + RETURN_IF_ERROR(hlo->ReplaceOperandWithDifferentShape(i, reshape)); } continue; } @@ -914,7 +913,7 @@ absl::Status RewriteLoopWithConcatGroups( broadcast = body->AddInstruction( HloInstruction::CreateReshape(data_shape, broadcast)); } - TF_RETURN_IF_ERROR(hlo->ReplaceOperandWithDifferentShape(i, broadcast)); + RETURN_IF_ERROR(hlo->ReplaceOperandWithDifferentShape(i, broadcast)); } } VLOG(2) << "Modifying HLO to full shape " << hlo->ToString(); @@ -947,13 +946,13 @@ absl::Status RewriteLoopWithConcatGroups( const auto& operand_group = groups.GetGroup(operand_group_index->first); auto slice = operand_group.CreateSlice( operand_group.elements[0], operand_group_index->second, body); - TF_RETURN_IF_ERROR(hlo->ReplaceOperandWithDifferentShape(i, slice)); + RETURN_IF_ERROR(hlo->ReplaceOperandWithDifferentShape(i, slice)); } } } for (auto slice : slices_to_remove) { - TF_RETURN_IF_ERROR(slice->ReplaceAllUsesWith(slice->mutable_operand(0))); - TF_RETURN_IF_ERROR(body->RemoveInstruction(slice)); + RETURN_IF_ERROR(slice->ReplaceAllUsesWith(slice->mutable_operand(0))); + RETURN_IF_ERROR(body->RemoveInstruction(slice)); } return absl::OkStatus(); } @@ -1013,8 +1012,8 @@ absl::StatusOr RunOnLoop(HloInstruction* loop, return false; } - TF_RETURN_IF_ERROR(AddCopiesToRoot(body, gtes, &groups)); - TF_RETURN_IF_ERROR(RewriteLoopWithConcatGroups(loop, gtes, groups)); + RETURN_IF_ERROR(AddCopiesToRoot(body, gtes, &groups)); + RETURN_IF_ERROR(RewriteLoopWithConcatGroups(loop, gtes, groups)); for (auto concat : concats) { if (concat == nullptr) { continue; @@ -1022,17 +1021,16 @@ absl::StatusOr RunOnLoop(HloInstruction* loop, // We have repalced the operands of the concat with slices of full data. auto new_slice = concat->mutable_operand(0); CHECK_EQ(new_slice->opcode(), HloOpcode::kSlice); - TF_RETURN_IF_ERROR( - concat->ReplaceAllUsesWith(new_slice->mutable_operand(0))); - TF_RETURN_IF_ERROR(body->RemoveInstruction(concat)); + RETURN_IF_ERROR(concat->ReplaceAllUsesWith(new_slice->mutable_operand(0))); + RETURN_IF_ERROR(body->RemoveInstruction(concat)); } - TF_RETURN_IF_ERROR(RemoveCopiesFromRoot(body)); + RETURN_IF_ERROR(RemoveCopiesFromRoot(body)); // Finally pass-through replaced elements from parameter to root, so that // while loop simplifier can get rid of them. for (auto gte : gtes) { auto group_index = groups.GetGroupIndex(gte); if (group_index.has_value() && group_index->second > 0) { - TF_RETURN_IF_ERROR(root->ReplaceOperandWith(gte->tuple_index(), gte)); + RETURN_IF_ERROR(root->ReplaceOperandWith(gte->tuple_index(), gte)); } } return true; @@ -1048,8 +1046,8 @@ absl::StatusOr WhileLoopConcatCodeMotion::RunImpl( module->MakeComputationPostOrder(execution_threads)) { for (HloInstruction* hlo : comp->MakeInstructionPostOrder()) { if (hlo->opcode() == HloOpcode::kWhile) { - TF_ASSIGN_OR_RETURN(bool loop_changed, - RunOnLoop(hlo, min_operand_count_to_optimize_)); + ASSIGN_OR_RETURN(bool loop_changed, + RunOnLoop(hlo, min_operand_count_to_optimize_)); changed |= loop_changed; } } @@ -1061,7 +1059,7 @@ absl::StatusOr WhileLoopConcatCodeMotion::RunImpl( pipeline.AddPass(); pipeline.AddPass(); pipeline.AddPass(); - TF_RETURN_IF_ERROR(pipeline.Run(module, execution_threads).status()); + RETURN_IF_ERROR(pipeline.Run(module, execution_threads).status()); } return changed; } diff --git a/third_party/xla/xla/service/while_loop_constant_sinking.cc b/third_party/xla/xla/service/while_loop_constant_sinking.cc index 396b221748c30c..0ac4392e73e165 100644 --- a/third_party/xla/xla/service/while_loop_constant_sinking.cc +++ b/third_party/xla/xla/service/while_loop_constant_sinking.cc @@ -29,6 +29,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/ir/hlo_clone_context.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/service/while_util.h" @@ -58,7 +59,7 @@ absl::Status ReplaceUsesWhileKeepingLoopInvariance( for (int64_t i = 0, e = user->operand_count(); i < e; i++) { if (user->operand(i) == old_instr && !(user == while_body_root && i == tuple_index)) { - TF_RETURN_IF_ERROR(user->ReplaceOperandWith(i, new_instr)); + RETURN_IF_ERROR(user->ReplaceOperandWith(i, new_instr)); } } } @@ -129,7 +130,7 @@ absl::StatusOr WhileLoopConstantSinking::TrySinkingConstantsIntoWhileLoop( } HloInstruction* constant_instr = CloneHelper(&invariant_value, body_clone); - TF_RETURN_IF_ERROR(ReplaceUsesWhileKeepingLoopInvariance( + RETURN_IF_ERROR(ReplaceUsesWhileKeepingLoopInvariance( body_clone_context.FindInstruction(invariant_body_gte), constant_instr, body_clone_context.FindInstruction(while_body->root_instruction()), @@ -154,8 +155,8 @@ absl::StatusOr WhileLoopConstantSinking::TrySinkingConstantsIntoWhileLoop( CloneHelper(&invariant_value, cond_clone); HloInstruction* cond_gte = cond_clone_context.FindInstruction(invariant_cond_gte); - TF_RETURN_IF_ERROR(cond_gte->ReplaceAllUsesWith(constant_instr)); - TF_RETURN_IF_ERROR(cond_clone->RemoveInstruction(cond_gte)); + RETURN_IF_ERROR(cond_gte->ReplaceAllUsesWith(constant_instr)); + RETURN_IF_ERROR(cond_clone->RemoveInstruction(cond_gte)); } } } @@ -206,8 +207,8 @@ absl::StatusOr WhileLoopConstantSinking::RunImpl( // Sinking constants may change the called computations, so do that first // if this is a while instruction. if (instr->opcode() == HloOpcode::kWhile) { - TF_ASSIGN_OR_RETURN(bool result, - TrySinkingConstantsIntoWhileLoop(module, instr)); + ASSIGN_OR_RETURN(bool result, + TrySinkingConstantsIntoWhileLoop(module, instr)); changed |= result; } for (HloComputation* child : instr->called_computations()) { @@ -217,7 +218,7 @@ absl::StatusOr WhileLoopConstantSinking::RunImpl( } if (changed) { - TF_RETURN_IF_ERROR(module->RemoveUnusedComputations()); + RETURN_IF_ERROR(module->RemoveUnusedComputations()); VLOG(2) << "HLO module after WhileLoopConstantSinking:"; XLA_VLOG_LINES(2, module->ToString()); } else { diff --git a/third_party/xla/xla/service/while_loop_expensive_invariant_code_motion.cc b/third_party/xla/xla/service/while_loop_expensive_invariant_code_motion.cc index 2fb18f31dd3f73..39410606b8fe9a 100644 --- a/third_party/xla/xla/service/while_loop_expensive_invariant_code_motion.cc +++ b/third_party/xla/xla/service/while_loop_expensive_invariant_code_motion.cc @@ -30,6 +30,7 @@ limitations under the License. #include "absl/log/log.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/analysis/while_loop_analysis.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -276,7 +277,7 @@ absl::StatusOr WhileLoopExpensiveInvariantCodeMotion:: return false; } - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( WhileUtil::MakeInstructionsLiveInResult live_in_instructions_result, WhileUtil::MakeInstructionsLiveIn(while_instr, replacement_instructions)); @@ -287,7 +288,7 @@ absl::StatusOr WhileLoopExpensiveInvariantCodeMotion:: HloInstruction* instruction_to_replace_in_new_while = FindOrDie(live_in_instructions_result.while_body_instruction_map, instructions_to_replace[i]); - TF_RETURN_IF_ERROR(new_while_body->ReplaceInstruction( + RETURN_IF_ERROR(new_while_body->ReplaceInstruction( instruction_to_replace_in_new_while, live_in_instructions_result.while_body_live_in_values[i])); } @@ -340,9 +341,8 @@ absl::StatusOr WhileLoopExpensiveInvariantCodeMotion::RunImpl( continue; } - TF_ASSIGN_OR_RETURN( - bool result, - TryHoistingInvariantInstructionsFromWhileBody(while_instr)); + ASSIGN_OR_RETURN(bool result, TryHoistingInvariantInstructionsFromWhileBody( + while_instr)); changed |= result; } diff --git a/third_party/xla/xla/service/while_loop_fusible_sinking.cc b/third_party/xla/xla/service/while_loop_fusible_sinking.cc index 3d35b2336bd627..b09a88c78c8c8e 100644 --- a/third_party/xla/xla/service/while_loop_fusible_sinking.cc +++ b/third_party/xla/xla/service/while_loop_fusible_sinking.cc @@ -29,6 +29,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/comparison_util.h" #include "xla/hlo/analysis/while_loop_analysis.h" #include "xla/hlo/ir/hlo_computation.h" @@ -89,7 +90,7 @@ absl::Status UpdateWhileUsesWithTuple(HloInstruction* while_instr, while_instr->parent()->set_root_instruction(tuple); } if (!users.empty()) { - TF_RETURN_IF_ERROR(while_instr->ReplaceUsesWith(users, tuple)); + RETURN_IF_ERROR(while_instr->ReplaceUsesWith(users, tuple)); } return absl::OkStatus(); } @@ -118,7 +119,7 @@ absl::StatusOr AppendToWhileState( *condition->parameter_instruction(0)->mutable_shape() = while_input->shape(); // Finalize the update by changing the uses of the while loop and updating its // shape. - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( UpdateWhileUsesWithTuple(while_instr, while_input->operand_count() - 1)); *while_instr->mutable_shape() = while_input->shape(); // The new body root tuple element has the same value as the new operand. @@ -225,7 +226,7 @@ WhileLoopFusibleSinking::TryRewritingBroadcastAsAllocateBuffer( // inside the body to create a predicate that checks if the loop iteration // variable is equal to the first iteration value. This is done only once // regardless of the number of sinkable indices. - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( HloInstruction * loop_iteration_variable_initial_value_gte, AppendToWhileState(while_instr, loop_iteration_variable_initial_value)); HloInstruction* iteration_var_gte = hlo_query::GetUniqueGteInstruction( @@ -248,15 +249,15 @@ WhileLoopFusibleSinking::TryRewritingBroadcastAsAllocateBuffer( // It is possible that the same broadcast has multiple users, first clone // the buffer and then replace this specific use with the clone. HloInstruction* buffer_clone = buffer->AddInstruction(buffer->Clone()); - TF_RETURN_IF_ERROR(while_instr->while_init()->ReplaceOperandWith( + RETURN_IF_ERROR(while_instr->while_init()->ReplaceOperandWith( loop_index, buffer_clone)); // Replace the clone with a free AllocateBuffer. HloInstruction* new_buffer = while_instr->parent()->AddInstruction(HloInstruction::CreateCustomCall( buffer_clone->shape(), {}, "AllocateBuffer")); - TF_RETURN_IF_ERROR(buffer_clone->ReplaceAllUsesWith(new_buffer)); - TF_RETURN_IF_ERROR(buffer_clone->parent()->RemoveInstruction(buffer_clone)); + RETURN_IF_ERROR(buffer_clone->ReplaceAllUsesWith(new_buffer)); + RETURN_IF_ERROR(buffer_clone->parent()->RemoveInstruction(buffer_clone)); // Broadcast the predicate to the shape of the buffer. HloInstruction* is_first_iteration_pred_broadcast = while_body->AddInstruction(HloInstruction::CreateBroadcast( @@ -278,9 +279,9 @@ WhileLoopFusibleSinking::TryRewritingBroadcastAsAllocateBuffer( new_buffer->shape(), HloOpcode::kSelect, is_first_iteration_pred_broadcast, sunk_constant_broadcast, buffer_body_gte)); - TF_RETURN_IF_ERROR(buffer_body_gte->ReplaceAllUsesWith(new_buffer_value)); + RETURN_IF_ERROR(buffer_body_gte->ReplaceAllUsesWith(new_buffer_value)); if (buffer->user_count() == 0) { - TF_RETURN_IF_ERROR(buffer->parent()->RemoveInstruction(buffer)); + RETURN_IF_ERROR(buffer->parent()->RemoveInstruction(buffer)); } changed = true; } @@ -409,7 +410,7 @@ absl::StatusOr WhileLoopFusibleSinking::TrySinkingFusiblesIntoWhileLoop( if (init_value->IsRoot() || init_value->user_count() > 1) { init_value = init_value->AddInstruction(init_value->Clone()); - TF_RETURN_IF_ERROR(while_instr->ReplaceOperandWith(0, init_value)); + RETURN_IF_ERROR(while_instr->ReplaceOperandWith(0, init_value)); } // Original value should be a fusible subgraph. if (!IsSinkableFusion(invariant_value)) { @@ -436,7 +437,7 @@ absl::StatusOr WhileLoopFusibleSinking::TrySinkingFusiblesIntoWhileLoop( while_instr->parent()->set_root_instruction(tuple); } if (!uses.empty()) { - TF_RETURN_IF_ERROR(while_instr->ReplaceUsesWith(uses, tuple)); + RETURN_IF_ERROR(while_instr->ReplaceUsesWith(uses, tuple)); } } @@ -448,7 +449,7 @@ absl::StatusOr WhileLoopFusibleSinking::TrySinkingFusiblesIntoWhileLoop( } } for (auto use : invariant_output_uses) { - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( while_instr->parent()->ReplaceInstruction(use, invariant_value)); } @@ -477,10 +478,10 @@ absl::StatusOr WhileLoopFusibleSinking::TrySinkingFusiblesIntoWhileLoop( auto cloned_fusion = while_body->AddInstruction( fusion->CloneWithNewOperands(fusion->shape(), new_operands)); - TF_RETURN_IF_ERROR(fusion->parent()->RemoveInstruction(fusion)); - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR(fusion->parent()->RemoveInstruction(fusion)); + RETURN_IF_ERROR( while_body->ReplaceInstruction(invariant_body_gte, cloned_fusion)); - TF_RETURN_IF_ERROR(cloned_fusion->Defuse()); + RETURN_IF_ERROR(cloned_fusion->Defuse()); } return changed; } @@ -526,8 +527,7 @@ absl::StatusOr WhileLoopFusibleSinking::RunImpl( } for (HloInstruction* while_instr : while_instrs) { - TF_ASSIGN_OR_RETURN(bool result, - TrySinkingFusiblesIntoWhileLoop(while_instr)); + ASSIGN_OR_RETURN(bool result, TrySinkingFusiblesIntoWhileLoop(while_instr)); changed |= result; } @@ -536,8 +536,8 @@ absl::StatusOr WhileLoopFusibleSinking::RunImpl( for (HloInstruction* instr : comp->instructions()) { // TODO: b/358837872 - Handle loops with sharding. if (Match(instr, match::While()) && !instr->has_sharding()) { - TF_ASSIGN_OR_RETURN(bool result, - TryRewritingBroadcastAsAllocateBuffer(instr)); + ASSIGN_OR_RETURN(bool result, + TryRewritingBroadcastAsAllocateBuffer(instr)); changed |= result; } } diff --git a/third_party/xla/xla/service/while_loop_invariant_code_motion.cc b/third_party/xla/xla/service/while_loop_invariant_code_motion.cc index ba68c757c13f79..b4820391d90803 100644 --- a/third_party/xla/xla/service/while_loop_invariant_code_motion.cc +++ b/third_party/xla/xla/service/while_loop_invariant_code_motion.cc @@ -27,6 +27,7 @@ limitations under the License. #include "absl/log/check.h" #include "absl/log/log.h" #include "absl/strings/string_view.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/analysis/while_loop_analysis.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -269,7 +270,7 @@ WhileLoopInvariantCodeMotion::TryHoistingInvariantInstructionsFromWhileBody( return false; } - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( WhileUtil::MakeInstructionsLiveInResult live_in_instructions_result, WhileUtil::MakeInstructionsLiveIn(while_instr, replacement_instructions)); @@ -280,7 +281,7 @@ WhileLoopInvariantCodeMotion::TryHoistingInvariantInstructionsFromWhileBody( HloInstruction* instruction_to_replace_in_new_while = FindOrDie(live_in_instructions_result.while_body_instruction_map, instructions_to_replace[i]); - TF_RETURN_IF_ERROR(new_while_body->ReplaceInstruction( + RETURN_IF_ERROR(new_while_body->ReplaceInstruction( instruction_to_replace_in_new_while, live_in_instructions_result.while_body_live_in_values[i])); } @@ -338,9 +339,8 @@ absl::StatusOr WhileLoopInvariantCodeMotion::RunImpl( continue; } - TF_ASSIGN_OR_RETURN( - bool result, - TryHoistingInvariantInstructionsFromWhileBody(while_instr, &allowance)); + ASSIGN_OR_RETURN(bool result, TryHoistingInvariantInstructionsFromWhileBody( + while_instr, &allowance)); changed |= result; } @@ -350,10 +350,10 @@ absl::StatusOr WhileLoopInvariantCodeMotion::RunImpl( // verification failures (e.g., the verifier may see multiple channel // instructions that have the same channel ids). HloDCE dce; - TF_RETURN_IF_ERROR(dce.Run(module).status()); + RETURN_IF_ERROR(dce.Run(module).status()); // Simplify while loops after narrowing / widening. TupleSimplifier tuple_simplifier; - TF_RETURN_IF_ERROR(tuple_simplifier.Run(module).status()); + RETURN_IF_ERROR(tuple_simplifier.Run(module).status()); } if (changed) { diff --git a/third_party/xla/xla/service/while_loop_simplifier.cc b/third_party/xla/xla/service/while_loop_simplifier.cc index 3bcbfd9b30fdbb..fbeab5a1156bfb 100644 --- a/third_party/xla/xla/service/while_loop_simplifier.cc +++ b/third_party/xla/xla/service/while_loop_simplifier.cc @@ -30,6 +30,7 @@ limitations under the License. #include "absl/log/log.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/comparison_util.h" #include "xla/hlo/analysis/while_loop_analysis.h" #include "xla/hlo/ir/hlo_casting_utils.h" @@ -97,12 +98,12 @@ static absl::StatusOr TryRemoveTrivialCompare(HloInstruction* while_op) { if (constant_value.value() <= init_value.value()) { if (body_instr->comparison_direction() == ComparisonDirection::kLt) { - TF_RETURN_IF_ERROR(while_op->while_body()->ReplaceInstruction( + RETURN_IF_ERROR(while_op->while_body()->ReplaceInstruction( body_instr, MakeScalarLike(body_instr, false))); return true; } else if (body_instr->comparison_direction() == ComparisonDirection::kGt) { - TF_RETURN_IF_ERROR(while_op->while_body()->ReplaceInstruction( + RETURN_IF_ERROR(while_op->while_body()->ReplaceInstruction( body_instr, MakeScalarLike(body_instr, true))); return true; } @@ -112,12 +113,12 @@ static absl::StatusOr TryRemoveTrivialCompare(HloInstruction* while_op) { init_value.value() + trip_count.value()) { if (body_instr->comparison_direction() == ComparisonDirection::kLt) { - TF_RETURN_IF_ERROR(while_op->while_body()->ReplaceInstruction( + RETURN_IF_ERROR(while_op->while_body()->ReplaceInstruction( body_instr, MakeScalarLike(body_instr, true))); return true; } else if (body_instr->comparison_direction() == ComparisonDirection::kGt) { - TF_RETURN_IF_ERROR(while_op->while_body()->ReplaceInstruction( + RETURN_IF_ERROR(while_op->while_body()->ReplaceInstruction( body_instr, MakeScalarLike(body_instr, false))); return true; } @@ -320,7 +321,7 @@ static absl::StatusOr RemoveDeadTupleIndices( } HloInstruction* new_tuple = computation->AddInstruction(HloInstruction::CreateTuple(new_tuple_elems)); - TF_RETURN_IF_ERROR(computation->ReplaceInstruction(while_op, new_tuple)); + RETURN_IF_ERROR(computation->ReplaceInstruction(while_op, new_tuple)); return new_while_op; } @@ -575,8 +576,8 @@ absl::StatusOr TryRemoveDeadWhileParams(HloInstruction* while_op) { << " elements from tuple of " << while_op->ToString(print_no_metadata); - TF_ASSIGN_OR_RETURN(while_op, - RemoveDeadTupleIndices(while_op, used_tuple_indices)); + ASSIGN_OR_RETURN(while_op, + RemoveDeadTupleIndices(while_op, used_tuple_indices)); return true; } @@ -648,7 +649,7 @@ absl::StatusOr RemoveRepeatedWhileTupleIndices( } else { surviving_gte = it->second; } - TF_RETURN_IF_ERROR(comp->ReplaceInstruction(gte, surviving_gte)); + RETURN_IF_ERROR(comp->ReplaceInstruction(gte, surviving_gte)); } } } @@ -663,7 +664,7 @@ absl::StatusOr RemoveRepeatedWhileTupleIndices( } } - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( while_op, RemoveDeadTupleIndices( while_op, used_tuple_indices, @@ -756,9 +757,9 @@ static absl::StatusOr TryRemoveRepeatedWhileTupleIndices( // Only keep one index for each equivalence set. HloInstruction* original_while_op = while_op; - TF_ASSIGN_OR_RETURN( - while_op, RemoveRepeatedWhileTupleIndices(while_op, init_to_indices, - /*replace_with_init=*/true)); + ASSIGN_OR_RETURN(while_op, + RemoveRepeatedWhileTupleIndices(while_op, init_to_indices, + /*replace_with_init=*/true)); // In theory, we could handle the "simple" case and the "dynamic-update-slice" // case in one go, but it's probably not worth the added complexity, so do it @@ -782,7 +783,7 @@ static absl::StatusOr TryRemoveRepeatedWhileTupleIndices( .push_back(index); } } - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( while_op, RemoveRepeatedWhileTupleIndices(while_op, dus_key_to_indices, /*replace_with_init=*/false)); @@ -903,7 +904,7 @@ static absl::StatusOr TryRemoveConstantParams(HloInstruction* while_op) { // CloneWithReplacementPairs will *leave the parameter out entirely*, creating // invalid HLO. if (ShapeUtil::IsEmptyTuple(new_while_shape)) { - TF_RETURN_IF_ERROR(computation->ReplaceInstruction(while_op, while_init)); + RETURN_IF_ERROR(computation->ReplaceInstruction(while_op, while_init)); return true; } @@ -951,7 +952,7 @@ static absl::StatusOr TryRemoveConstantParams(HloInstruction* while_op) { new_while_op->CopyBackendConfigFrom(while_op); CopyFrontendAttributes(while_op, new_while_op); CopyMetadata(while_op, new_while_op); - TF_RETURN_IF_ERROR(computation->ReplaceWithNewInstruction( + RETURN_IF_ERROR(computation->ReplaceWithNewInstruction( while_op, add_constant_elems(new_while_op))); for (auto& instr : new_instrs) { computation->AddInstruction(std::move(instr)); @@ -998,7 +999,7 @@ static absl::StatusOr TryRemoveWhileLoop(HloInstruction* while_op) { // Remove while_op (i.e., call ReplaceInstruction rather than // ReplaceUsesWithInstruction) so that if the algebraic simplifier is run in // a loop without an intervening DCE, we don't try to re-remove the loop. - TF_RETURN_IF_ERROR(computation->ReplaceInstruction( + RETURN_IF_ERROR(computation->ReplaceInstruction( while_op, while_op->mutable_operand(0))); return true; } @@ -1039,10 +1040,10 @@ static absl::StatusOr TryRemoveWhileLoop(HloInstruction* while_op) { auto call_op = computation->AddInstruction(HloInstruction::CreateCall( while_op->shape(), while_op->operands(), while_op->while_body())); call_op->set_original_value(while_op->original_value()); - TF_RETURN_IF_ERROR(computation->ReplaceInstruction(while_op, call_op)); + RETURN_IF_ERROR(computation->ReplaceInstruction(while_op, call_op)); call_op->set_metadata_op_name(""); - TF_ASSIGN_OR_RETURN(auto inlined_instructions_map, - CallInliner::Inline(call_op)); + ASSIGN_OR_RETURN(auto inlined_instructions_map, + CallInliner::Inline(call_op)); (void)inlined_instructions_map; return true; } else { @@ -1109,7 +1110,7 @@ static absl::StatusOr TryPropagateConstant(HloInstruction* while_op) { const HloInstruction* hlo_constant = (*iter).second; VLOG(3) << "Replace use of " << instr->ToString() << " with " << hlo_constant->ToString(); - TF_RETURN_IF_ERROR(instr->ReplaceAllUsesWith( + RETURN_IF_ERROR(instr->ReplaceAllUsesWith( computation->AddInstruction(hlo_constant->Clone()))); changed = true; } @@ -1118,9 +1119,9 @@ static absl::StatusOr TryPropagateConstant(HloInstruction* while_op) { return changed; }; - TF_ASSIGN_OR_RETURN(bool changed_cond, - propagate_constant(while_op->while_condition())); - TF_ASSIGN_OR_RETURN(bool changed_body, propagate_constant(while_body)); + ASSIGN_OR_RETURN(bool changed_cond, + propagate_constant(while_op->while_condition())); + ASSIGN_OR_RETURN(bool changed_body, propagate_constant(while_body)); return changed_cond || changed_body; } @@ -1302,7 +1303,7 @@ static absl::StatusOr TryFlattenNestedTuples(HloInstruction* while_op) { new_while_op->CopyBackendConfigFrom(while_op); CopyFrontendAttributes(while_op, new_while_op); CopyMetadata(while_op, new_while_op); - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( computation->ReplaceWithNewInstruction(while_op, nested(new_while_op))); for (auto& instr : new_instrs) { computation->AddInstruction(std::move(instr)); @@ -1536,7 +1537,7 @@ static absl::StatusOr TryMergeInductionVariables( Cast( temp_new_while_body->parameter_instruction(0))), }); - TF_RETURN_IF_ERROR(module->RemoveEmbeddedComputation(temp_new_while_body)); + RETURN_IF_ERROR(module->RemoveEmbeddedComputation(temp_new_while_body)); // Create the final while loop, and add any new instructions created to // `computation`. @@ -1558,7 +1559,7 @@ static absl::StatusOr TryMergeInductionVariables( } CopyFrontendAttributes(while_op, new_while); CopyMetadata(while_op, new_while); - TF_RETURN_IF_ERROR(computation->ReplaceWithNewInstruction( + RETURN_IF_ERROR(computation->ReplaceWithNewInstruction( while_op, convert_to_old_form(new_while))); for (auto& instr : new_instrs) { computation->AddInstruction(std::move(instr)); @@ -1601,33 +1602,32 @@ absl::StatusOr WhileLoopSimplifier::RunImpl( // These optimizations should be fine even with send/recv nodes within the // loop. - TF_ASSIGN_OR_RETURN(bool result, - TryRemoveRepeatedWhileTupleIndices(while_op)); + ASSIGN_OR_RETURN(bool result, TryRemoveRepeatedWhileTupleIndices(while_op)); changed |= result; if (result) { continue; } - TF_ASSIGN_OR_RETURN(result, TryFlattenNestedTuples(while_op)); + ASSIGN_OR_RETURN(result, TryFlattenNestedTuples(while_op)); changed |= result; if (result) { continue; } - TF_ASSIGN_OR_RETURN(result, TryRemoveDeadWhileParams(while_op)); + ASSIGN_OR_RETURN(result, TryRemoveDeadWhileParams(while_op)); changed |= result; if (result) { continue; } - TF_ASSIGN_OR_RETURN(result, TryRemoveConstantParams(while_op)); + ASSIGN_OR_RETURN(result, TryRemoveConstantParams(while_op)); changed |= result; if (result) { continue; } if (simplify_compare_instrs_) { - TF_ASSIGN_OR_RETURN(result, TryRemoveTrivialCompare(while_op)); + ASSIGN_OR_RETURN(result, TryRemoveTrivialCompare(while_op)); changed |= result; if (result) { continue; @@ -1650,10 +1650,10 @@ absl::StatusOr WhileLoopSimplifier::RunImpl( continue; } - TF_ASSIGN_OR_RETURN(result, TryPropagateConstant(while_op)); + ASSIGN_OR_RETURN(result, TryPropagateConstant(while_op)); changed |= result; - TF_ASSIGN_OR_RETURN(result, TryRemoveWhileLoop(while_op)); + ASSIGN_OR_RETURN(result, TryRemoveWhileLoop(while_op)); changed |= result; if (result) { @@ -1675,8 +1675,8 @@ absl::StatusOr WhileLoopSimplifier::RunImpl( // Notably missing from this list are S16 and U16. These don't currently // work because S/U16 literals are not implemented. for (auto elem_ty : {S8, U8, S32, U32, S64, U64}) { - TF_ASSIGN_OR_RETURN(auto* new_while_op, - TryMergeInductionVariables(while_op, elem_ty)); + ASSIGN_OR_RETURN(auto* new_while_op, + TryMergeInductionVariables(while_op, elem_ty)); if (new_while_op) { while_op = new_while_op; changed = true; @@ -1689,7 +1689,7 @@ absl::StatusOr WhileLoopSimplifier::RunImpl( } if (changed) { HloDCE dce; - TF_RETURN_IF_ERROR(dce.Run(module).status()); + RETURN_IF_ERROR(dce.Run(module).status()); } XLA_VLOG_LINES( 3, "WhileLoopSimplifier::RunImpl(), after:\n" + module->ToString()); diff --git a/third_party/xla/xla/service/while_loop_unroller.cc b/third_party/xla/xla/service/while_loop_unroller.cc index a9dbb06099045e..572e47f3737cc3 100644 --- a/third_party/xla/xla/service/while_loop_unroller.cc +++ b/third_party/xla/xla/service/while_loop_unroller.cc @@ -31,6 +31,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/comparison_util.h" #include "xla/hlo/analysis/while_loop_analysis.h" #include "xla/hlo/evaluator/hlo_evaluator.h" @@ -98,7 +99,7 @@ std::unique_ptr MakeTrivialLoopCondition( absl::Status HandleDynamicGteOrTuple(HloInstruction* instr) { if (instr->IsCustomCall("DynamicGte")) { HloEvaluator evaluator(/*max_loop_iterations=*/0); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( Literal index_lit, evaluator.Evaluate(instr->mutable_operand(1), /*precomputed_analyses=*/{}, @@ -112,7 +113,7 @@ absl::Status HandleDynamicGteOrTuple(HloInstruction* instr) { } else if (instr->IsCustomCall("DynamicTuple")) { HloEvaluator evaluator(/*max_loop_iterations=*/0); std::vector tuple_operands; - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( Literal index_lit, evaluator.Evaluate(instr->mutable_operand(2), /*precomputed_analyses=*/{}, @@ -174,7 +175,7 @@ absl::Status ReplaceInductionVarUses(HloComputation* body, const HloInstruction* indvar_use_operand = indvar_use->operand(i); // Found the induction var user. if (indvar_use_operand == body_inst) { - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( indvar_use->ReplaceOperandWith(i, induction_value_constant)); } } @@ -206,9 +207,9 @@ UnrollSingleIterationOfTrivialLoop(HloInstruction* while_op, HloInstruction* induction_value_constant = while_body_clone->AddInstruction( MakeScalarConstantWithShape(induction_var_hlo->shape(), induction_value)); - TF_RETURN_IF_ERROR(ReplaceInductionVarUses(while_body_clone.get(), - induction_value_constant, - config.induction_var_idx)); + RETURN_IF_ERROR(ReplaceInductionVarUses(while_body_clone.get(), + induction_value_constant, + config.induction_var_idx)); absl::flat_hash_set seen_scheduling_ids; for (HloInstruction* body_inst : while_body_clone->instructions()) { @@ -224,21 +225,21 @@ UnrollSingleIterationOfTrivialLoop(HloInstruction* while_op, // We need to assign a unique id to each scheduling group (of instructions) // that are unrolled within the while loop body. - TF_ASSIGN_OR_RETURN(std::optional scheduling_id, - GetSchedulingAnnotationGroupId(body_inst)); + ASSIGN_OR_RETURN(std::optional scheduling_id, + GetSchedulingAnnotationGroupId(body_inst)); if (scheduling_id.has_value()) { if (!seen_scheduling_ids.contains(scheduling_id.value())) { seen_scheduling_ids.insert(scheduling_id.value()); next_scheduling_id++; } - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( SetSchedulingAnnotationGroupId(body_inst, next_scheduling_id)); } // Handle DynamicGte and DynamicTuple custom-calls created during unstacking // pass. All custom-calls must be replaced for the loop to be unrolled // successfully. - TF_RETURN_IF_ERROR(HandleDynamicGteOrTuple(body_inst)); + RETURN_IF_ERROR(HandleDynamicGteOrTuple(body_inst)); } return while_body_clone; } @@ -298,8 +299,8 @@ absl::StatusOr UnrollInternal(HloInstruction* while_op, HloInstruction* unrolled_body_call_op; std::vector call_operands = {while_op->operands().at(0)}; - TF_ASSIGN_OR_RETURN(int64_t next_scheduling_id, - NextSchedulingGroupId(*while_op->GetModule())); + ASSIGN_OR_RETURN(int64_t next_scheduling_id, + NextSchedulingGroupId(*while_op->GetModule())); std::vector new_calls; for (int64_t i = config.init; i < config.trip_count + config.init; ++i) { CHECK(OverflowSafeAdd(i, (int64_t)1).has_value()); @@ -316,11 +317,11 @@ absl::StatusOr UnrollInternal(HloInstruction* while_op, call_operands.clear(); call_operands.push_back(unrolled_body_call_op); } - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( computation->ReplaceInstruction(while_op, unrolled_body_call_op)); unrolled_body_call_op->set_metadata_op_name(""); for (HloInstruction* call : new_calls) { - TF_RETURN_IF_ERROR(CallInliner::Inline(call).status()); + RETURN_IF_ERROR(CallInliner::Inline(call).status()); } return true; } @@ -344,8 +345,8 @@ absl::StatusOr UnrollInternalWrappedAndReturnReplacement( // We assume while has only one tuple parameter call_operands.emplace_back(std::move(p.value())); - TF_ASSIGN_OR_RETURN(int64_t next_scheduling_id, - NextSchedulingGroupId(*while_op->GetModule())); + ASSIGN_OR_RETURN(int64_t next_scheduling_id, + NextSchedulingGroupId(*while_op->GetModule())); std::vector new_calls; for (int64_t i = config.init; i < config.trip_count + config.init; ++i) { @@ -379,7 +380,7 @@ absl::StatusOr UnrollInternalWrappedAndReturnReplacement( while_op->SetupDerivedInstruction(new_while_op); CHECK_OK(computation->ReplaceInstruction(while_op, new_while_op)); for (HloInstruction* call : new_calls) { - TF_RETURN_IF_ERROR(CallInliner::Inline(call).status()); + RETURN_IF_ERROR(CallInliner::Inline(call).status()); } UnrollResult result; @@ -390,9 +391,8 @@ absl::StatusOr UnrollInternalWrappedAndReturnReplacement( absl::StatusOr UnrollInternalWrapped(HloInstruction* while_op, WhileLoopConfig config) { - TF_ASSIGN_OR_RETURN( - UnrollResult result, - UnrollInternalWrappedAndReturnReplacement(while_op, config)); + ASSIGN_OR_RETURN(UnrollResult result, + UnrollInternalWrappedAndReturnReplacement(while_op, config)); return result.unrolled; } @@ -652,8 +652,8 @@ absl::Status FindIndicesCoveredByDynamicInstructionsInInnerLoop( predefined_ranges[induction_var_gte] = loop_range.value(); // Step 2: Find dynamic instructions inside the while body. - TF_ASSIGN_OR_RETURN(std::vector dynamic_instructions, - FindDynamicInstructions(input, while_instr)); + ASSIGN_OR_RETURN(std::vector dynamic_instructions, + FindDynamicInstructions(input, while_instr)); const Shape& input_shape = input->shape(); const int64_t dimension_size = input_shape.dimensions(dynamic_indices.first); @@ -969,8 +969,8 @@ absl::StatusOr IsInputShapeCoveredByDynamicUpdateSliceInstructions( const HloInstruction* input = config.while_instr->while_init()->operand(input_idx); - TF_ASSIGN_OR_RETURN(std::vector dynamic_instructions, - FindDynamicInstructions(input, config.while_instr)); + ASSIGN_OR_RETURN(std::vector dynamic_instructions, + FindDynamicInstructions(input, config.while_instr)); TF_RET_CHECK(dynamic_instructions.size() == 1); const HloInstruction* dus = dynamic_instructions.front(); @@ -1311,7 +1311,7 @@ std::optional AdvancedMatchShapeCoveringDynamicIndexInstruction( HloModule* module, const absl::flat_hash_set& execution_threads) { bool changed = false; - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( bool applied_cse, HloCSE(/*is_layout_sensitive=*/true, /*ignore_control_dependencies=*/false, @@ -1330,8 +1330,8 @@ std::optional AdvancedMatchShapeCoveringDynamicIndexInstruction( changed = true; VLOG(3) << "Applied hlo cse to module " << module->name(); } - TF_ASSIGN_OR_RETURN(bool applied_tuple_simplifier, - TupleSimplifier{}.Run(module, execution_threads)); + ASSIGN_OR_RETURN(bool applied_tuple_simplifier, + TupleSimplifier{}.Run(module, execution_threads)); if (applied_tuple_simplifier) { changed = true; VLOG(3) << "Applied tuple simplifier to module " << module->name(); @@ -1341,8 +1341,8 @@ std::optional AdvancedMatchShapeCoveringDynamicIndexInstruction( HloPassFix constant_sinking( /*sink_broadcast_of_constants=*/true, /*sink_only_scalar_constants=*/true); - TF_ASSIGN_OR_RETURN(bool applied_constant_sinking, - constant_sinking.Run(module, execution_threads)); + ASSIGN_OR_RETURN(bool applied_constant_sinking, + constant_sinking.Run(module, execution_threads)); if (applied_constant_sinking) { changed = true; VLOG(3) << "Applied constant sinking to module " << module->name(); @@ -1398,7 +1398,7 @@ WhileLoopUnroller::UnrollAndReturnReplacement( if (prepare) { // Make sure all the necessary passes are executed before unrolling in order // to unroll every possible loop. - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( PrepareModuleForUnrolling(module, /*execution_threads=*/{}).status()); } @@ -1415,17 +1415,16 @@ WhileLoopUnroller::UnrollAndReturnReplacement( return result; } if (wrap_in_trivial_loop) { - TF_ASSIGN_OR_RETURN(result, UnrollInternalWrappedAndReturnReplacement( - while_op, config.value())); + ASSIGN_OR_RETURN(result, UnrollInternalWrappedAndReturnReplacement( + while_op, config.value())); } else { - TF_ASSIGN_OR_RETURN(result.unrolled, - UnrollInternal(while_op, config.value())); + ASSIGN_OR_RETURN(result.unrolled, UnrollInternal(while_op, config.value())); } if (result.unrolled) { // Inlining calls created during unrolling may have left unused computations // around, run DCE to clean them up. - TF_RETURN_IF_ERROR(HloDCE().Run(module, /*execution_threads=*/{}).status()); + RETURN_IF_ERROR(HloDCE().Run(module, /*execution_threads=*/{}).status()); } return result; @@ -1444,8 +1443,8 @@ absl::StatusOr WhileLoopUnroller::RunImpl( bool changed = false; // Make sure all the necessary passes are executed before unrolling in order // to unroll every possible loop. - TF_ASSIGN_OR_RETURN(changed, - PrepareModuleForUnrolling(module, execution_threads)); + ASSIGN_OR_RETURN(changed, + PrepareModuleForUnrolling(module, execution_threads)); // Processing the while loops in the reverse of topological order. If the body // of while loop A calls while loop B, B comes before A. std::vector all_while_ops; @@ -1465,9 +1464,9 @@ absl::StatusOr WhileLoopUnroller::RunImpl( bool unrolled = false; for (auto& [while_op, config] : unrollable_while_ops) { if (wrap_in_trivial_loop_) { - TF_ASSIGN_OR_RETURN(unrolled, UnrollInternalWrapped(while_op, config)); + ASSIGN_OR_RETURN(unrolled, UnrollInternalWrapped(while_op, config)); } else { - TF_ASSIGN_OR_RETURN(unrolled, UnrollInternal(while_op, config)); + ASSIGN_OR_RETURN(unrolled, UnrollInternal(while_op, config)); } changed |= unrolled; } @@ -1475,7 +1474,7 @@ absl::StatusOr WhileLoopUnroller::RunImpl( if (changed) { // Inlining calls created during unrolling may have left unused computations // around, run DCE to clean them up. - TF_RETURN_IF_ERROR(HloDCE().Run(module, execution_threads).status()); + RETURN_IF_ERROR(HloDCE().Run(module, execution_threads).status()); } XLA_VLOG_LINES(3, @@ -1500,16 +1499,16 @@ absl::StatusOr> CreatePartiallyUnrolledLoop( const std::vector& loop_state) -> absl::StatusOr> { std::vector inner_loop_state = loop_state; - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( HloInstruction * inner_loop_indvar, MakeBinaryHlo( HloOpcode::kMultiply, induction_var, MakeR0ConstantHlo(induction_var->parent(), unroll_factor))); for (int i = 0; i < unroll_factor; ++i) { - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( inner_loop_state, loop_body_generator(inner_loop_indvar, inner_loop_state)); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( inner_loop_indvar, MakeBinaryHlo( HloOpcode::kAdd, inner_loop_indvar, @@ -1528,7 +1527,7 @@ absl::StatusOr> CreatePartiallyUnrolledLoop( HloInstruction* induction_var, const std::vector& loop_state) -> absl::StatusOr> { - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( HloInstruction * adjusted_induction_var, MakeBinaryHlo(HloOpcode::kAdd, induction_var, MakeR0ConstantHlo(induction_var->parent(), diff --git a/third_party/xla/xla/service/while_util.cc b/third_party/xla/xla/service/while_util.cc index 7e092d2bc2682d..54bf2a4d88f7a8 100644 --- a/third_party/xla/xla/service/while_util.cc +++ b/third_party/xla/xla/service/while_util.cc @@ -35,6 +35,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/comparison_util.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -91,8 +92,8 @@ WidenWhileCondition(HloComputation* narrow_condition, const Shape& wide_shape) { wide_while_cond->set_root_instruction(call_narrow_cond); - TF_ASSIGN_OR_RETURN(auto inlined_instructions_map, - CallInliner::Inline(call_narrow_cond)); + ASSIGN_OR_RETURN(auto inlined_instructions_map, + CallInliner::Inline(call_narrow_cond)); return {{wide_while_cond, std::move(inlined_instructions_map)}}; } @@ -133,8 +134,8 @@ WidenWhileBody(HloComputation* narrow_body, const Shape& wide_shape) { wide_while_body->set_root_instruction( TupleUtil::AppendSuffix(call_narrow_body, live_through_values)); - TF_ASSIGN_OR_RETURN(auto inlined_instructions_map, - CallInliner::Inline(call_narrow_body)); + ASSIGN_OR_RETURN(auto inlined_instructions_map, + CallInliner::Inline(call_narrow_body)); return {{wide_while_body, std::move(inlined_instructions_map)}}; } @@ -236,15 +237,14 @@ WhileUtil::MakeInstructionsLiveIn( HloComputation* new_while_condition; CallInliner::InlinedInstructionMap inlined_condition_instructions_map; - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( std::tie(new_while_condition, inlined_condition_instructions_map), WidenWhileCondition(while_instr->while_condition(), new_while_shape)); HloComputation* new_while_body; CallInliner::InlinedInstructionMap inlined_instructions_map; - TF_ASSIGN_OR_RETURN( - std::tie(new_while_body, inlined_instructions_map), - WidenWhileBody(while_instr->while_body(), new_while_shape)); + ASSIGN_OR_RETURN(std::tie(new_while_body, inlined_instructions_map), + WidenWhileBody(while_instr->while_body(), new_while_shape)); HloInstruction* new_while_init = TupleUtil::AppendSuffix(while_instr->mutable_operand(0), instructions); @@ -285,10 +285,10 @@ WhileUtil::MakeInstructionsLiveIn( // instead of relying on HloComputation::ReplaceInstruction. HloInstruction* replacement_instr = TupleUtil::ExtractPrefix( new_while, while_instr->shape().tuple_shapes().size()); - TF_RETURN_IF_ERROR(new_while->CopyAllControlDepsFrom(while_instr)); - TF_RETURN_IF_ERROR(while_instr->DropAllControlDeps()); - TF_RETURN_IF_ERROR(while_instr->ReplaceAllUsesWith(replacement_instr)); - TF_RETURN_IF_ERROR(containing_computation->RemoveInstruction(while_instr)); + RETURN_IF_ERROR(new_while->CopyAllControlDepsFrom(while_instr)); + RETURN_IF_ERROR(while_instr->DropAllControlDeps()); + RETURN_IF_ERROR(while_instr->ReplaceAllUsesWith(replacement_instr)); + RETURN_IF_ERROR(containing_computation->RemoveInstruction(while_instr)); HloInstruction* while_body_param = new_while_body->parameter_instruction(0); std::vector live_in_instructions; @@ -319,19 +319,18 @@ MakeCountedLoopConditionComputation(const Shape& loop_state_shape, int32_t trip_count) { Shape scalar_pred = ShapeUtil::MakeShape(PRED, {}); - TF_ASSIGN_OR_RETURN(std::unique_ptr cond_computation, - CreateComputationWithSignature( - {&loop_state_shape}, scalar_pred, "while_cond")); + ASSIGN_OR_RETURN(std::unique_ptr cond_computation, + CreateComputationWithSignature({&loop_state_shape}, + scalar_pred, "while_cond")); HloInstruction* trip_count_constant = cond_computation->AddInstruction(HloInstruction::CreateConstant( LiteralUtil::CreateR0(trip_count))); HloInstruction* param = cond_computation->parameter_instruction(0); - TF_ASSIGN_OR_RETURN(HloInstruction * indvar, - MakeGetTupleElementHlo(param, 0)); + ASSIGN_OR_RETURN(HloInstruction * indvar, MakeGetTupleElementHlo(param, 0)); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( HloInstruction * compare, MakeCompareHlo(ComparisonDirection::kLt, indvar, trip_count_constant)); cond_computation->set_root_instruction(compare); @@ -344,25 +343,24 @@ MakeCountedLoopBodyComputation( absl::FunctionRef( HloInstruction*, const WhileUtil::LoopStateTy&)> loop_body_generator) { - TF_ASSIGN_OR_RETURN(std::unique_ptr body_computation, - CreateComputationWithSignature( - {&loop_state_shape}, loop_state_shape, "while_body")); + ASSIGN_OR_RETURN(std::unique_ptr body_computation, + CreateComputationWithSignature( + {&loop_state_shape}, loop_state_shape, "while_body")); HloInstruction* one = body_computation->AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(1))); HloInstruction* param = body_computation->parameter_instruction(0); - TF_ASSIGN_OR_RETURN(HloInstruction * indvar, - MakeGetTupleElementHlo(param, 0)); - TF_ASSIGN_OR_RETURN(HloInstruction * next_indvar, - MakeBinaryHlo(HloOpcode::kAdd, indvar, one)); + ASSIGN_OR_RETURN(HloInstruction * indvar, MakeGetTupleElementHlo(param, 0)); + ASSIGN_OR_RETURN(HloInstruction * next_indvar, + MakeBinaryHlo(HloOpcode::kAdd, indvar, one)); std::vector loop_body_generator_args; for (int i = 1, e = loop_state_shape.tuple_shapes().size(); i < e; i++) { - TF_ASSIGN_OR_RETURN(HloInstruction * tuple_element, - MakeGetTupleElementHlo(param, i)); + ASSIGN_OR_RETURN(HloInstruction * tuple_element, + MakeGetTupleElementHlo(param, i)); loop_body_generator_args.push_back(tuple_element); } - TF_ASSIGN_OR_RETURN(std::vector next_state, - loop_body_generator(indvar, loop_body_generator_args)); + ASSIGN_OR_RETURN(std::vector next_state, + loop_body_generator(indvar, loop_body_generator_args)); next_state.insert(next_state.begin(), next_indvar); HloInstruction* next_state_tuple = body_computation->AddInstruction(HloInstruction::CreateTuple(next_state)); @@ -415,10 +413,10 @@ WhileUtil::MakeCountedLoop(HloModule* module, int32_t trip_count, // use loop_state_shape to create a literal, which requires loop_state_shape // to have a layout. Shape loop_state_shape = MakeLoopStateShapeWithLayout(init_values); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( std::unique_ptr cond, MakeCountedLoopConditionComputation(loop_state_shape, trip_count)); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( std::unique_ptr body, MakeCountedLoopBodyComputation(loop_state_shape, loop_body_generator)); std::unique_ptr owned_indvar; @@ -451,10 +449,9 @@ WhileUtil::MakeCountedLoop(HloModule* module, int32_t trip_count, const WhileUtil::LoopStateTy& init_values, WhileUtil::LoopBodyGeneratorTy loop_body_generator, const OpMetadata& metadata) { - TF_ASSIGN_OR_RETURN( - auto owning_loop_state, - MakeCountedLoop(computation->parent(), trip_count, init_values, - loop_body_generator, metadata)); + ASSIGN_OR_RETURN(auto owning_loop_state, + MakeCountedLoop(computation->parent(), trip_count, + init_values, loop_body_generator, metadata)); for (auto& instruction_to_add : owning_loop_state.instructions_to_add) { computation->AddInstruction(std::move(instruction_to_add)); } diff --git a/third_party/xla/xla/service/while_util_test.cc b/third_party/xla/xla/service/while_util_test.cc index 35060031b97a5b..7834901f837b9b 100644 --- a/third_party/xla/xla/service/while_util_test.cc +++ b/third_party/xla/xla/service/while_util_test.cc @@ -22,6 +22,7 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" @@ -63,7 +64,7 @@ ENTRY entry { } )"; - TF_ASSIGN_OR_RETURN(auto module, ParseAndReturnVerifiedModule(hlo_string)); + ASSIGN_OR_RETURN(auto module, ParseAndReturnVerifiedModule(hlo_string)); *entry_computation = module->entry_computation(); *param0 = (*entry_computation)->parameter_instruction(0); diff --git a/third_party/xla/xla/stream_executor/BUILD b/third_party/xla/xla/stream_executor/BUILD index 6307e1da7cb16a..597bd2f5d14333 100644 --- a/third_party/xla/xla/stream_executor/BUILD +++ b/third_party/xla/xla/stream_executor/BUILD @@ -220,6 +220,28 @@ cc_library( ], ) +cc_library( + name = "mock_command_buffer", + testonly = True, + hdrs = ["mock_command_buffer.h"], + deps = [ + ":bit_pattern", + ":command_buffer", + ":device_address", + ":dnn", + ":kernel", + ":kernel_args", + ":launch_dim", + ":platform", + ":stream", + "//xla/hlo/testlib:test", + "@com_google_absl//absl/functional:any_invocable", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/types:span", + ], +) + cc_library( name = "mock_platform", testonly = True, diff --git a/third_party/xla/xla/stream_executor/cuda/BUILD b/third_party/xla/xla/stream_executor/cuda/BUILD index 7decfa4a45a65c..d3c1e29fb09fdf 100644 --- a/third_party/xla/xla/stream_executor/cuda/BUILD +++ b/third_party/xla/xla/stream_executor/cuda/BUILD @@ -139,6 +139,7 @@ cc_library( "//xla/tsl/cuda:nvml", "//xla/tsl/platform:errors", "//xla/tsl/platform:status", + "//xla/tsl/platform:status_macros", "@com_google_absl//absl/base", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/base:nullability", @@ -212,6 +213,7 @@ cc_library( "//xla/stream_executor/gpu:context_map", "//xla/stream_executor/gpu:scoped_activate_context", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", @@ -270,6 +272,7 @@ cc_library( "//xla/stream_executor:scratch_allocator", "//xla/stream_executor:stream_executor_h", "//xla/stream_executor/gpu:gpu_blas_lt", + "//xla/tsl/platform:status_macros", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -322,6 +325,7 @@ cc_library( "//xla/tsl/cuda:cublas_lt", "//xla/tsl/platform:errors", "//xla/tsl/platform:logging", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "//xla/tsl/protobuf:dnn_proto_cc", "@com_google_absl//absl/base", @@ -359,6 +363,7 @@ cc_library( deps = [ "//xla/stream_executor:blas", "//xla/tsl/cuda:cublas", + "//xla/tsl/platform:status_macros", "@com_google_absl//absl/log", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", @@ -389,6 +394,7 @@ cc_library( "//xla/stream_executor/gpu:gpu_helpers_header", "//xla/stream_executor/platform:initialize", "//xla/tsl/cuda:cufft", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/base", "@com_google_absl//absl/log", @@ -418,6 +424,7 @@ cuda_library( "//xla/stream_executor:stream", "//xla/stream_executor:typed_kernel_factory", "//xla/stream_executor/gpu:gpu_semaphore", + "//xla/tsl/platform:status_macros", "@com_google_absl//absl/status:statusor", ], ) @@ -464,6 +471,7 @@ xla_test( "//xla/stream_executor/gpu:gpu_kernel_registry", "//xla/tsl/lib/core:status_test_util", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/cleanup", "@com_google_absl//absl/log", @@ -523,6 +531,7 @@ xla_test( "//xla/stream_executor/gpu:gpu_kernel_registry", "//xla/tsl/lib/core:status_test_util", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/cleanup", "@com_google_absl//absl/log", @@ -566,6 +575,7 @@ cc_library( "//xla/tsl/cuda:cudnn", "//xla/tsl/cuda:nvrtc", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "//xla/tsl/protobuf:dnn_proto_cc", "//xla/tsl/util:env_var", @@ -605,6 +615,7 @@ cc_library( deps = [ "//xla/stream_executor:semantic_version", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -650,6 +661,7 @@ cc_library( "//xla/hlo/ir:hlo", "//xla/service/gpu:stream_executor_util", "//xla/stream_executor:dnn", + "//xla/tsl/platform:status_macros", "//xla/tsl/protobuf:dnn_proto_cc", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/functional:function_ref", @@ -804,6 +816,7 @@ cc_library( "//xla/stream_executor:activate_context", "//xla/stream_executor:event", "//xla/stream_executor:stream_executor_h", + "//xla/tsl/platform:status_macros", "@com_google_absl//absl/base", "@com_google_absl//absl/log", "@com_google_absl//absl/status", @@ -829,6 +842,7 @@ cc_library( "//xla/stream_executor:device_address", "//xla/stream_executor:memory_allocation", "//xla/stream_executor:stream_executor_h", + "//xla/tsl/platform:status_macros", "@com_google_absl//absl/log", "@com_google_absl//absl/status:statusor", "@local_config_cuda//cuda:cuda_headers", @@ -978,6 +992,7 @@ cc_library( "//xla/stream_executor:stream", "//xla/stream_executor:stream_executor_h", "//xla/stream_executor:vmm_device_address_allocator", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/log", "@com_google_absl//absl/memory", @@ -1095,6 +1110,7 @@ cc_library( "//xla/stream_executor:kernel_stats", "//xla/stream_executor:semantic_version", "//xla/stream_executor/gpu:gpu_asm_opts", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", @@ -1160,6 +1176,7 @@ xla_cc_test( "//xla/stream_executor:kernel_stats", "//xla/stream_executor:semantic_version", "//xla/stream_executor/gpu:gpu_asm_opts", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/status", "@com_google_absl//absl/status:status_matchers", @@ -1365,6 +1382,7 @@ cc_library( "//xla/pjrt/distributed:key_value_store_interface", "//xla/runtime:process_id", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/base", "@com_google_absl//absl/base:no_destructor", @@ -1427,6 +1445,7 @@ cc_library( "//xla/stream_executor:kernel_stats", "//xla/stream_executor/gpu:gpu_asm_opts", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", @@ -1556,6 +1575,7 @@ cc_library( ":subprocess_compilation", "//xla/stream_executor/gpu:gpu_asm_opts", "//xla/tsl/platform:logging", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", @@ -1790,6 +1810,7 @@ cuda_library( deps = [ ":cuda_status", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "@com_google_absl//absl/status", "@local_config_cuda//cuda:cuda_runtime", ], @@ -1823,6 +1844,7 @@ xla_test( "//xla/stream_executor/gpu:gpu_init", "//xla/stream_executor/gpu:multicast_memory", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/status", "@com_google_absl//absl/status:status_matchers", @@ -1962,6 +1984,7 @@ cc_library( "//xla/stream_executor:stream", "//xla/stream_executor:stream_common", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "//xla/tsl/util:env_var", "@com_google_absl//absl/base", @@ -2061,6 +2084,7 @@ cc_library( "//xla/stream_executor:stream_executor_h", "//xla/stream_executor/gpu:gpu_semaphore", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/log", "@com_google_absl//absl/status", @@ -2217,6 +2241,7 @@ cc_library( "//xla/stream_executor/gpu:gpu_asm_opts", "//xla/tsl/platform:env", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "//xla/tsl/platform:subprocess", "@com_google_absl//absl/algorithm:container", @@ -2322,6 +2347,7 @@ cc_library( "//xla/stream_executor:activate_context", "//xla/stream_executor:device_description", "//xla/stream_executor:stream_executor_h", + "//xla/tsl/platform:status_macros", "@com_google_absl//absl/base", "@com_google_absl//absl/log", "@com_google_absl//absl/status:statusor", @@ -2395,6 +2421,7 @@ cc_library( ":ptx_compiler_helpers", ":subprocess_compilation", "//xla/stream_executor/gpu:gpu_asm_opts", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "//xla/tsl/platform:subprocess", "@com_google_absl//absl/status", @@ -2516,6 +2543,7 @@ cc_library( ":cuda_compute_capability", ":ptx_compiler", "//xla/stream_executor/gpu:gpu_asm_opts", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -2533,6 +2561,7 @@ cc_library( ":compilation_provider", ":cuda_compute_capability", "//xla/stream_executor:device_description", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/memory", "@com_google_absl//absl/status", @@ -2568,6 +2597,7 @@ cc_library( ":compilation_provider", ":cuda_compute_capability", "//xla/stream_executor:device_description", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:node_hash_map", @@ -2721,6 +2751,7 @@ cc_library( "//xla:xla_proto_cc", "//xla/stream_executor:semantic_version", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", @@ -3071,6 +3102,7 @@ cuda_library( ":cuda_status", "//xla:xla_data_proto_cc", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -3122,6 +3154,7 @@ xla_test( "//xla/stream_executor:stream_executor_h", "//xla/tsl/lib/core:status_test_util", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/cleanup", "@com_google_absl//absl/log", diff --git a/third_party/xla/xla/stream_executor/cuda/assemble_compilation_provider.cc b/third_party/xla/xla/stream_executor/cuda/assemble_compilation_provider.cc index e9442454c753d8..ed4911d8b0210d 100644 --- a/third_party/xla/xla/stream_executor/cuda/assemble_compilation_provider.cc +++ b/third_party/xla/xla/stream_executor/cuda/assemble_compilation_provider.cc @@ -26,6 +26,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/stream_executor/cuda/compilation_provider.h" #include "xla/stream_executor/cuda/compilation_provider_options.h" #include "xla/stream_executor/cuda/composite_compilation_provider.h" @@ -135,7 +136,7 @@ absl::StatusOr> AssembleCompilationProvider(const CompilationProviderOptions& options) { // TODO(b/381059098): Simplify this logic - TF_RETURN_IF_ERROR(CheckIncompatibleFlagSettings(options)); + RETURN_IF_ERROR(CheckIncompatibleFlagSettings(options)); std::string decision_log; const auto append_to_decision_log = [&](absl::string_view decision) { diff --git a/third_party/xla/xla/stream_executor/cuda/buffer_debug_float_check_kernel_cuda_test.cc b/third_party/xla/xla/stream_executor/cuda/buffer_debug_float_check_kernel_cuda_test.cc index fbe2aaf161b804..3d5568417026bc 100644 --- a/third_party/xla/xla/stream_executor/cuda/buffer_debug_float_check_kernel_cuda_test.cc +++ b/third_party/xla/xla/stream_executor/cuda/buffer_debug_float_check_kernel_cuda_test.cc @@ -29,6 +29,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/str_format.h" #include "absl/strings/string_view.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/backends/gpu/runtime/buffer_debug_log_structs.h" #include "xla/backends/gpu/runtime/buffer_debug_log_structs_test_matchers.h" #include "xla/backends/gpu/runtime/thunk_id.h" @@ -101,23 +102,23 @@ class FloatCheckKernelTest : public ::testing::Test { // Load kernel gpu::GpuKernelRegistry registry = gpu::GpuKernelRegistry::GetGlobalRegistry(); - TF_ASSIGN_OR_RETURN(auto check_kernel, - registry.LoadKernel(executor_)); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN(auto check_kernel, + registry.LoadKernel(executor_)); + ASSIGN_OR_RETURN( auto reduce_kernel, registry .LoadKernel( executor_)); // Setup device buffers - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( se::DeviceAddress device_input, CheckNotNull(executor_->AllocateArray(input.size()), "input")); auto cleanup_input = absl::MakeCleanup([&]() { executor_->Deallocate(&device_input); }); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( se::DeviceAddress device_tmp, CheckNotNull(executor_->AllocateArray( temp_buffer_size_elements), @@ -128,18 +129,18 @@ class FloatCheckKernelTest : public ::testing::Test { const se::ThreadDim thread_dim(1024, 1, 1); // Call kernel - TF_RETURN_IF_ERROR(stream_->Memcpy(&device_input, input.data(), - input.size() * sizeof(input[0]))); - TF_RETURN_IF_ERROR(check_kernel.Launch( + RETURN_IF_ERROR(stream_->Memcpy(&device_input, input.data(), + input.size() * sizeof(input[0]))); + RETURN_IF_ERROR(check_kernel.Launch( thread_dim, block_dim, stream_.get(), device_input, device_input.ElementCount(), device_tmp, device_tmp.ElementCount())); - TF_RETURN_IF_ERROR(reduce_kernel.Launch( + RETURN_IF_ERROR(reduce_kernel.Launch( thread_dim, se::BlockDim(1, 1, 1), stream_.get(), device_tmp, std::min(device_tmp.ElementCount(), block_dim.x * block_dim.y * block_dim.z), entry_id, buffer_debug_log.GetDeviceHeader(), buffer_debug_log.GetDeviceEntries())); - TF_RETURN_IF_ERROR(stream_->BlockHostUntilDone()); + RETURN_IF_ERROR(stream_->BlockHostUntilDone()); // The result gets stored in `buffer_debug_log`. return absl::OkStatus(); diff --git a/third_party/xla/xla/stream_executor/cuda/buffer_debug_xor_checksum_kernel_cuda_test.cc b/third_party/xla/xla/stream_executor/cuda/buffer_debug_xor_checksum_kernel_cuda_test.cc index b0583d8000dc1f..2a5afaa9d76508 100644 --- a/third_party/xla/xla/stream_executor/cuda/buffer_debug_xor_checksum_kernel_cuda_test.cc +++ b/third_party/xla/xla/stream_executor/cuda/buffer_debug_xor_checksum_kernel_cuda_test.cc @@ -26,6 +26,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/str_format.h" #include "absl/strings/string_view.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/backends/gpu/runtime/buffer_debug_log_structs.h" #include "xla/stream_executor/device_address.h" #include "xla/stream_executor/gpu/buffer_debug_log.h" @@ -88,27 +89,27 @@ class ChecksumKernelTest : public ::testing::Test { // Load kernel gpu::GpuKernelRegistry registry = gpu::GpuKernelRegistry::GetGlobalRegistry(); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( auto kernel, registry.LoadKernel(executor_)); // Setup device buffers - TF_ASSIGN_OR_RETURN(se::DeviceAddress device_input, - CheckNotNull(executor_->AllocateArray( - input.size() * sizeof(input[0])), - "input")); + ASSIGN_OR_RETURN(se::DeviceAddress device_input, + CheckNotNull(executor_->AllocateArray( + input.size() * sizeof(input[0])), + "input")); auto cleanup_input = absl::MakeCleanup([&]() { executor_->Deallocate(&device_input); }); // Call kernel - TF_RETURN_IF_ERROR(stream_->Memcpy(&device_input, input.data(), - input.size() * sizeof(input[0]))); - TF_RETURN_IF_ERROR(kernel.Launch(dim, stream_executor::BlockDim(1, 1, 1), - stream_.get(), entry_id, device_input, - device_input.ElementCount(), - buffer_debug_log.GetDeviceHeader(), - buffer_debug_log.GetDeviceEntries())); - TF_RETURN_IF_ERROR(stream_->BlockHostUntilDone()); + RETURN_IF_ERROR(stream_->Memcpy(&device_input, input.data(), + input.size() * sizeof(input[0]))); + RETURN_IF_ERROR(kernel.Launch(dim, stream_executor::BlockDim(1, 1, 1), + stream_.get(), entry_id, device_input, + device_input.ElementCount(), + buffer_debug_log.GetDeviceHeader(), + buffer_debug_log.GetDeviceEntries())); + RETURN_IF_ERROR(stream_->BlockHostUntilDone()); // The result gets stored in `buffer_debug_log`. return absl::OkStatus(); diff --git a/third_party/xla/xla/stream_executor/cuda/caching_compilation_provider.cc b/third_party/xla/xla/stream_executor/cuda/caching_compilation_provider.cc index bf16a1eee71399..1661df9b90d64a 100644 --- a/third_party/xla/xla/stream_executor/cuda/caching_compilation_provider.cc +++ b/third_party/xla/xla/stream_executor/cuda/caching_compilation_provider.cc @@ -26,6 +26,7 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/stream_executor/cuda/compilation_options.h" #include "xla/stream_executor/cuda/compilation_provider.h" #include "xla/stream_executor/cuda/cuda_compute_capability.h" @@ -127,7 +128,7 @@ absl::StatusOr CachingCompilationProvider::CompileAndLink( if (std::holds_alternative(input)) { modules.push_back(std::get(input)); } else { - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( RelocatableModule relocatable_module, CompileToRelocatableModule(cc, std::get(input).ptx, options)); modules.push_back(std::move(relocatable_module)); diff --git a/third_party/xla/xla/stream_executor/cuda/composite_compilation_provider.cc b/third_party/xla/xla/stream_executor/cuda/composite_compilation_provider.cc index b1a50a80c0694f..b69430e6d51dc0 100644 --- a/third_party/xla/xla/stream_executor/cuda/composite_compilation_provider.cc +++ b/third_party/xla/xla/stream_executor/cuda/composite_compilation_provider.cc @@ -29,6 +29,7 @@ limitations under the License. #include "absl/strings/str_join.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/stream_executor/cuda/compilation_options.h" #include "xla/stream_executor/cuda/compilation_provider.h" #include "xla/stream_executor/cuda/cuda_compute_capability.h" @@ -114,8 +115,7 @@ absl::StatusOr CompositeCompilationProvider::GetLatestPtxIsaVersion() const { std::optional latest_supported_version; for (const auto& provider : providers_) { - TF_ASSIGN_OR_RETURN(int provider_version, - provider->GetLatestPtxIsaVersion()); + ASSIGN_OR_RETURN(int provider_version, provider->GetLatestPtxIsaVersion()); if (!latest_supported_version.has_value()) { latest_supported_version = provider_version; continue; diff --git a/third_party/xla/xla/stream_executor/cuda/cub_scan_kernel_cuda_impl.cu.cc b/third_party/xla/xla/stream_executor/cuda/cub_scan_kernel_cuda_impl.cu.cc index 8abde076bf93d0..4149a361331295 100644 --- a/third_party/xla/xla/stream_executor/cuda/cub_scan_kernel_cuda_impl.cu.cc +++ b/third_party/xla/xla/stream_executor/cuda/cub_scan_kernel_cuda_impl.cu.cc @@ -25,6 +25,7 @@ limitations under the License. #include "cub/device/device_scan.cuh" #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "xla/tsl/platform/status_macros.h" #include "third_party/gpus/cuda/include/cuda.h" #include "third_party/gpus/cuda/include/cuda_bf16.h" // IWYU pragma: keep #include "third_party/gpus/cuda/include/cuda_fp16.h" // IWYU pragma: keep @@ -136,7 +137,7 @@ absl::Status CubThreadScanDispatch(const T* d_in, T* d_out, int64_t row_length, constexpr int block_size = 256; auto* kernel = ThreadScanKernel; size_t shared_mem_bytes = block_size * row_length * sizeof(T); - TF_RETURN_IF_ERROR(ToStatus(cudaFuncSetAttribute( + RETURN_IF_ERROR(ToStatus(cudaFuncSetAttribute( reinterpret_cast(kernel), cudaFuncAttributeMaxDynamicSharedMemorySize, shared_mem_bytes))); int grid_size = (column_length + block_size - 1) / block_size; @@ -183,10 +184,10 @@ absl::Status CubScanDispatch(void* d_temp_storage, size_t* temp_bytes, // max threads per block, which should match ScanPolicyT::BLOCK_THREADS // because we use that as __launch_bounds__. cudaFunction_t function; - TF_RETURN_IF_ERROR(ToStatus( + RETURN_IF_ERROR(ToStatus( cudaGetFuncBySymbol(&function, reinterpret_cast(kernel)))); int block_size; - TF_RETURN_IF_ERROR(ToStatus(cuFuncGetAttribute( + RETURN_IF_ERROR(ToStatus(cuFuncGetAttribute( &block_size, CU_FUNC_ATTRIBUTE_MAX_THREADS_PER_BLOCK, function))); cudaLaunchConfig_t config = { @@ -273,9 +274,9 @@ absl::StatusOr CubScanGetScratchSize( xla::PrimitiveType type, int64_t vector_length, int64_t row_length, int64_t column_length, CubScanKind kind, bool is_reverse) { size_t temp_bytes = 0; - TF_RETURN_IF_ERROR(CubScanDispatch(type, nullptr, &temp_bytes, nullptr, - nullptr, vector_length, row_length, - column_length, kind, is_reverse, nullptr)); + RETURN_IF_ERROR(CubScanDispatch(type, nullptr, &temp_bytes, nullptr, nullptr, + vector_length, row_length, column_length, + kind, is_reverse, nullptr)); return temp_bytes; } diff --git a/third_party/xla/xla/stream_executor/cuda/cub_scan_kernel_cuda_test.cc b/third_party/xla/xla/stream_executor/cuda/cub_scan_kernel_cuda_test.cc index 6b07b7b1334b79..ccd05f828d12f1 100644 --- a/third_party/xla/xla/stream_executor/cuda/cub_scan_kernel_cuda_test.cc +++ b/third_party/xla/xla/stream_executor/cuda/cub_scan_kernel_cuda_test.cc @@ -30,6 +30,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/strings/str_format.h" #include "absl/strings/string_view.h" +#include "xla/tsl/platform/status_macros.h" #include "third_party/gpus/cuda/include/cuda.h" #include "third_party/gpus/cuda/include/cuda_bf16.h" #include "third_party/gpus/cuda/include/cuda_fp16.h" @@ -90,9 +91,9 @@ class CubScanKernelCudaTest } // Get scratch size. - TF_ASSIGN_OR_RETURN(size_t temp_bytes, - CubScanGetScratchSize(type, vector_length, row_length, - col_length, kind, is_reverse)); + ASSIGN_OR_RETURN(size_t temp_bytes, + CubScanGetScratchSize(type, vector_length, row_length, + col_length, kind, is_reverse)); // Allocate device buffers se::DeviceAddress device_data = @@ -108,18 +109,17 @@ class CubScanKernelCudaTest // Copy data to device. size_t size_bytes = num_elements * sizeof(T); - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( stream_->Memcpy(&device_data, host_data.data(), size_bytes)); - TF_RETURN_IF_ERROR(CubScanLaunchKernel( + RETURN_IF_ERROR(CubScanLaunchKernel( type, device_temp.opaque(), temp_bytes, device_data.opaque(), device_data.opaque(), vector_length, row_length, col_length, kind, is_reverse, static_cast(stream_->platform_specific_handle().stream))); - TF_RETURN_IF_ERROR(stream_->BlockHostUntilDone()); - TF_RETURN_IF_ERROR( - stream_->Memcpy(host_data.data(), device_data, size_bytes)); + RETURN_IF_ERROR(stream_->BlockHostUntilDone()); + RETURN_IF_ERROR(stream_->Memcpy(host_data.data(), device_data, size_bytes)); if constexpr (std::is_same_v) { EXPECT_THAT(host_data, diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_asm_compiler.cc b/third_party/xla/xla/stream_executor/cuda/cuda_asm_compiler.cc index 29e219e135213f..7ffcc4e98b9be9 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_asm_compiler.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_asm_compiler.cc @@ -24,6 +24,7 @@ limitations under the License. #include "absl/log/check.h" #include "absl/log/log.h" #include "absl/status/statusor.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/stream_executor/cuda/cubin_or_ptx_image.h" #include "xla/stream_executor/cuda/cuda_compute_capability.h" #include "xla/stream_executor/cuda/ptx_compiler.h" @@ -45,14 +46,14 @@ absl::StatusOr> CompileGpuAsm( GpuAsmOpts options) { if (IsLibNvPtxCompilerSupported()) { VLOG(3) << "Compiling GPU ASM with libnvptxcompiler"; - TF_ASSIGN_OR_RETURN(auto assembly, - CompileGpuAsmUsingLibNvPtxCompiler(cc, ptx, options)); + ASSIGN_OR_RETURN(auto assembly, + CompileGpuAsmUsingLibNvPtxCompiler(cc, ptx, options)); return std::move(assembly.cubin); } VLOG(3) << "Compiling GPU ASM with PTXAS. Libnvptxcompiler compilation " "not supported."; - TF_ASSIGN_OR_RETURN(auto assembly, CompileGpuAsmUsingPtxAs(cc, ptx, options)); + ASSIGN_OR_RETURN(auto assembly, CompileGpuAsmUsingPtxAs(cc, ptx, options)); return std::move(assembly.cubin); } diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_blas.cc b/third_party/xla/xla/stream_executor/cuda/cuda_blas.cc index b17f0eeec568d1..6ce57d3b50875a 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_blas.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_blas.cc @@ -33,6 +33,7 @@ limitations under the License. #include "absl/synchronization/mutex.h" #include "absl/time/time.h" #include "Eigen/Core" +#include "xla/tsl/platform/status_macros.h" #include "third_party/gpus/cuda/include/cuComplex.h" #include "third_party/gpus/cuda/include/cublas_v2.h" #include "third_party/gpus/cuda/include/cuda.h" @@ -702,7 +703,7 @@ static absl::Status PopulateProfileFromTimer( EventBasedTimer *timer, blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) { if (output_profile_result) { - TF_ASSIGN_OR_RETURN(absl::Duration duration, timer->GetElapsedDuration()); + ASSIGN_OR_RETURN(absl::Duration duration, timer->GetElapsedDuration()); output_profile_result->set_is_valid(true); output_profile_result->set_algorithm(algorithm); output_profile_result->set_elapsed_time_in_ms( @@ -719,29 +720,28 @@ absl::Status CUDABlas::DoBlasGemmWithAlgorithm( blas::DataType type_c, int ldc, blas::ComputationType computation_type, blas::AlgorithmType algorithm, const EngineOptions& engine_options, blas::ProfileResult* output_profile_result, blas::CallContext context) { - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( cublasMath_t math_type, GetMathTypeForGemmEx(stream, algorithm, type_a, type_b, engine_options)); std::unique_ptr timer; if (output_profile_result != nullptr) { - TF_ASSIGN_OR_RETURN(timer, - stream->CreateEventBasedTimer( - output_profile_result->warmup_run_executed())); + ASSIGN_OR_RETURN(timer, stream->CreateEventBasedTimer( + output_profile_result->warmup_run_executed())); } // Since we are converting 'algorithm' to cublasGemmAlgo_t by static_cast, // we do the following compile-time check on the default value: static_assert(blas::kDefaultGemmAlgo == CUBLAS_GEMM_DFALT, ""); - TF_RETURN_IF_ERROR(DoBlasInternalImpl( + RETURN_IF_ERROR(DoBlasInternalImpl( AS_LAMBDA(cublasGemmEx), stream, /*pointer_mode_host=*/true, math_type, AsCublasOperation(transa), AsCublasOperation(transb), m, n, k, alpha, a.opaque(), AsCudaDataType(type_a), lda, b.opaque(), AsCudaDataType(type_b), ldb, beta, c->opaque(), AsCudaDataType(type_c), ldc, AsCublasComputeType(computation_type), static_cast(algorithm))); - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( PopulateProfileFromTimer(timer.get(), algorithm, output_profile_result)); return absl::OkStatus(); } @@ -756,14 +756,13 @@ absl::Status CUDABlas::DoBlasGemmStridedBatchedWithAlgorithm( blas::ComputationType computation_type, blas::AlgorithmType algorithm, const EngineOptions& engine_options, blas::ProfileResult* output_profile_result, blas::CallContext context) { - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( cublasMath_t math_type, GetMathTypeForGemmEx(stream, algorithm, type_a, type_b, engine_options)); std::unique_ptr timer; if (output_profile_result != nullptr) { - TF_ASSIGN_OR_RETURN(timer, - stream->CreateEventBasedTimer( - output_profile_result->warmup_run_executed())); + ASSIGN_OR_RETURN(timer, stream->CreateEventBasedTimer( + output_profile_result->warmup_run_executed())); } cudaDataType_t cuda_in_type = AsCudaDataType(type_a); @@ -781,21 +780,21 @@ absl::Status CUDABlas::DoBlasGemmStridedBatchedWithAlgorithm( if (AsCudaDataType(type_c) == CUDA_R_16BF) { auto *c_matrix = reinterpret_cast<__nv_bfloat16 *>( static_cast(c->opaque()) + batch * stride_c); - TF_RETURN_IF_ERROR(DoBlasInternalImpl( + RETURN_IF_ERROR(DoBlasInternalImpl( AS_LAMBDA(cublasGemmEx), stream, /*pointer_mode_host=*/true, math_type, AsCublasOperation(transa), AsCublasOperation(transb), m, - n, k, static_cast(alpha), a_matrix, CUDA_R_16BF, lda, - b_matrix, CUDA_R_16BF, ldb, static_cast(beta), + n, k, static_cast(alpha), a_matrix, CUDA_R_16BF, lda, + b_matrix, CUDA_R_16BF, ldb, static_cast(beta), c_matrix, AsCudaDataType(type_c), ldc, AsCublasComputeType(computation_type), static_cast(algorithm))); } else if (AsCudaDataType(type_c) == CUDA_R_32F) { auto *c_matrix = static_cast(c->opaque()) + batch * stride_c; - TF_RETURN_IF_ERROR(DoBlasInternalImpl( + RETURN_IF_ERROR(DoBlasInternalImpl( AS_LAMBDA(cublasGemmEx), stream, /*pointer_mode_host=*/true, math_type, AsCublasOperation(transa), AsCublasOperation(transb), m, - n, k, static_cast(alpha), a_matrix, CUDA_R_16BF, lda, - b_matrix, CUDA_R_16BF, ldb, static_cast(beta), + n, k, static_cast(alpha), a_matrix, CUDA_R_16BF, lda, + b_matrix, CUDA_R_16BF, ldb, static_cast(beta), c_matrix, AsCudaDataType(type_c), ldc, AsCublasComputeType(computation_type), static_cast(algorithm))); @@ -805,20 +804,20 @@ absl::Status CUDABlas::DoBlasGemmStridedBatchedWithAlgorithm( blas::DataTypeString(type_a), blas::DataTypeString(type_c))); } } - TF_RETURN_IF_ERROR(PopulateProfileFromTimer(timer.get(), algorithm, - output_profile_result)); + RETURN_IF_ERROR(PopulateProfileFromTimer(timer.get(), algorithm, + output_profile_result)); return absl::OkStatus(); } #endif - TF_RETURN_IF_ERROR(DoBlasInternalImpl( + RETURN_IF_ERROR(DoBlasInternalImpl( AS_LAMBDA(cublasGemmStridedBatchedEx), stream, /*pointer_mode_host=*/true, math_type, AsCublasOperation(transa), AsCublasOperation(transb), m, n, k, alpha, a.opaque(), cuda_in_type, lda, stride_a, b.opaque(), cuda_in_type, ldb, stride_b, beta, c->opaque(), AsCudaDataType(type_c), ldc, stride_c, batch_count, AsCublasComputeType(computation_type), static_cast(algorithm))); - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( PopulateProfileFromTimer(timer.get(), algorithm, output_profile_result)); return absl::OkStatus(); } @@ -936,19 +935,19 @@ absl::Status CUDABlas::DoBlasGemmBatchedInternal( if (scratch_allocator == nullptr) { return absl::InternalError("scratch_allocator is null"); } - TF_ASSIGN_OR_RETURN(DeviceAddress a_bytes, - scratch_allocator->AllocateBytes(size)); - TF_ASSIGN_OR_RETURN(DeviceAddress b_bytes, - scratch_allocator->AllocateBytes(size)); - TF_ASSIGN_OR_RETURN(DeviceAddress c_bytes, - scratch_allocator->AllocateBytes(size)); + ASSIGN_OR_RETURN(DeviceAddress a_bytes, + scratch_allocator->AllocateBytes(size)); + ASSIGN_OR_RETURN(DeviceAddress b_bytes, + scratch_allocator->AllocateBytes(size)); + ASSIGN_OR_RETURN(DeviceAddress c_bytes, + scratch_allocator->AllocateBytes(size)); DeviceAddress a(a_bytes); DeviceAddress b(b_bytes); DeviceAddress c(c_bytes); - TF_RETURN_IF_ERROR(stream->Memcpy(&a, a_raw_ptrs.data(), size)); - TF_RETURN_IF_ERROR(stream->Memcpy(&b, b_raw_ptrs.data(), size)); - TF_RETURN_IF_ERROR(stream->Memcpy(&c, c_raw_ptrs.data(), size)); + RETURN_IF_ERROR(stream->Memcpy(&a, a_raw_ptrs.data(), size)); + RETURN_IF_ERROR(stream->Memcpy(&b, b_raw_ptrs.data(), size)); + RETURN_IF_ERROR(stream->Memcpy(&c, c_raw_ptrs.data(), size)); cudaDataType_t data_type = CUDADataType::type; @@ -1017,10 +1016,10 @@ absl::Status CUDABlas::DoBlasGemmBatchedInternal( const DeviceAddress& a_matrix = *a_ptrs_to_wrappers[b]; const DeviceAddress& b_matrix = *b_ptrs_to_wrappers[b]; DeviceAddress* c_matrix = c_ptrs_to_wrappers[b]; - TF_RETURN_IF_ERROR(DoBlasGemm( - stream, transa, transb, m, n, k, blas::ToDataType::value, &alpha, - a_matrix, lda, b_matrix, ldb, &beta, c_matrix, ldc, engine_options, - blas::CallContext::kNone)); + RETURN_IF_ERROR(DoBlasGemm(stream, transa, transb, m, n, k, + blas::ToDataType::value, &alpha, a_matrix, + lda, b_matrix, ldb, &beta, c_matrix, ldc, + engine_options, blas::CallContext::kNone)); } return absl::OkStatus(); } @@ -1182,12 +1181,12 @@ absl::Status CUDABlas::DoBlasGemmStridedBatched( batch * stride_b); auto *c_matrix = reinterpret_cast<__nv_bfloat16 *>( static_cast(c->opaque()) + batch * stride_c); - TF_RETURN_IF_ERROR(DoBlasInternalImpl( + RETURN_IF_ERROR(DoBlasInternalImpl( cublasSgemmEx, stream, true /* = pointer_mode_host */, CUBLAS_DEFAULT_MATH, AsCublasOperation(transa), AsCublasOperation(transb), m, n, k, - static_cast(alpha), a_matrix, CUDA_R_16BF, lda, - b_matrix, CUDA_R_16BF, ldb, static_cast(beta), + static_cast(alpha), a_matrix, CUDA_R_16BF, lda, + b_matrix, CUDA_R_16BF, ldb, static_cast(beta), c_matrix, CUDA_R_16BF, ldc)); } return absl::OkStatus(); @@ -1214,12 +1213,12 @@ absl::Status CUDABlas::DoBlasGemmStridedBatched( static_cast(b.opaque()) + batch * stride_b); auto *c_matrix = reinterpret_cast<__half *>( static_cast(c->opaque()) + batch * stride_c); - TF_RETURN_IF_ERROR(DoBlasInternalImpl( + RETURN_IF_ERROR(DoBlasInternalImpl( cublasSgemmEx, stream, true /* = pointer_mode_host */, CUBLAS_DEFAULT_MATH, AsCublasOperation(transa), AsCublasOperation(transb), m, n, k, - static_cast(alpha), a_matrix, CUDA_R_16F, lda, - b_matrix, CUDA_R_16F, ldb, static_cast(beta), + static_cast(alpha), a_matrix, CUDA_R_16F, lda, + b_matrix, CUDA_R_16F, ldb, static_cast(beta), c_matrix, CUDA_R_16F, ldc)); } return absl::OkStatus(); diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_blas_lt.cc b/third_party/xla/xla/stream_executor/cuda/cuda_blas_lt.cc index 3de8db14cf4232..9e4938b2ccf6c4 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_blas_lt.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_blas_lt.cc @@ -15,7 +15,6 @@ limitations under the License. #include "xla/stream_executor/cuda/cuda_blas_lt.h" -#include #include #include #include @@ -33,6 +32,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/synchronization/mutex.h" #include "absl/time/time.h" +#include "xla/tsl/platform/status_macros.h" #include "third_party/gpus/cuda/include/cublasLt.h" #include "third_party/gpus/cuda/include/cublas_v2.h" #include "third_party/gpus/cuda/include/cuda.h" @@ -59,7 +59,7 @@ limitations under the License. #define GET_ATTR(getter, handle, attr, ValueT) \ [&]() -> absl::StatusOr { \ ValueT value; \ - TF_RETURN_IF_ERROR(ToStatus( \ + RETURN_IF_ERROR(ToStatus( \ getter(handle, attr, &value, sizeof(ValueT), nullptr), #getter)); \ return std::move(value); \ }() @@ -150,16 +150,16 @@ absl::StatusOr AsCublasLtEpilogue( } // namespace absl::Status BlasLt::Init() { - cublasLtHandle_t blas_lt; - SE_CUBLAS_RETURN_IF_ERROR(cublasLtCreate(&blas_lt)); + cublasLtHandle_t handle; + SE_CUBLAS_RETURN_IF_ERROR(cublasLtCreate(&handle)); absl::MutexLock lock(mu_); - blas_lt_.reset(blas_lt); + handle_.reset(handle); return absl::OkStatus(); } /*static*/ absl::StatusOr BlasLt::MatrixLayout::Create( const gpu::MatrixLayout& m) { - TF_ASSIGN_OR_RETURN(auto type, gpu::AsBlasDataType(m.dtype)); + ASSIGN_OR_RETURN(auto type, gpu::AsBlasDataType(m.dtype)); cublasLtMatrixLayout_t cu_layout; SE_CUBLAS_RETURN_IF_ERROR( @@ -167,21 +167,22 @@ absl::Status BlasLt::Init() { m.num_cols, m.leading_dim_stride)); // Wrap cublas handle immediately, so it is cleaned up if an error occurs. BlasLt::MatrixLayout layout(cu_layout); - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( SetAttr(cu_layout, CUBLASLT_MATRIX_LAYOUT_ORDER, int32_t{(m.order == gpu::MatrixLayout::Order::kRowMajor) ? CUBLASLT_ORDER_ROW : CUBLASLT_ORDER_COL})); - TF_RETURN_IF_ERROR(SetAttr(cu_layout, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, - static_cast(m.batch_size))); + RETURN_IF_ERROR(SetAttr(cu_layout, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, + static_cast(m.batch_size))); VLOG(2) << "MatrixLayout::Create: num_rows: " << m.num_rows << " num_cols:" << (int)m.num_cols << ", order: " << (int)m.order - << "," << " batchsz " << m.batch_size + << "," + << " batchsz " << m.batch_size << " leaddimstride: " << m.leading_dim_stride << " batch_stride: " << m.batch_stride; - TF_RETURN_IF_ERROR(SetAttr( + RETURN_IF_ERROR(SetAttr( cu_layout, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, m.batch_stride)); return std::move(layout); } @@ -206,19 +207,19 @@ cudaDataType_t BlasLt::MatrixLayout::type() const { &cu_desc, AsCublasComputeType(compute_type), AsCudaDataType(scale_type))); // Wrap cublas handle immediately, so it is cleaned up if an error occurs. BlasLt::MatmulDesc desc(cu_desc); - TF_RETURN_IF_ERROR(SetAttr(cu_desc, CUBLASLT_MATMUL_DESC_POINTER_MODE, - AsCublasLtPointerMode(pointer_mode))); - TF_RETURN_IF_ERROR(SetAttr(cu_desc, CUBLASLT_MATMUL_DESC_TRANSA, - AsCublasOperation(trans_a))); - TF_RETURN_IF_ERROR(SetAttr(cu_desc, CUBLASLT_MATMUL_DESC_TRANSB, - AsCublasOperation(trans_b))); - TF_ASSIGN_OR_RETURN(cublasLtEpilogue_t epi, AsCublasLtEpilogue(epilogue)); - TF_RETURN_IF_ERROR(SetAttr(cu_desc, CUBLASLT_MATMUL_DESC_EPILOGUE, epi)); + RETURN_IF_ERROR(SetAttr(cu_desc, CUBLASLT_MATMUL_DESC_POINTER_MODE, + AsCublasLtPointerMode(pointer_mode))); + RETURN_IF_ERROR(SetAttr(cu_desc, CUBLASLT_MATMUL_DESC_TRANSA, + AsCublasOperation(trans_a))); + RETURN_IF_ERROR(SetAttr(cu_desc, CUBLASLT_MATMUL_DESC_TRANSB, + AsCublasOperation(trans_b))); + ASSIGN_OR_RETURN(cublasLtEpilogue_t epi, AsCublasLtEpilogue(epilogue)); + RETURN_IF_ERROR(SetAttr(cu_desc, CUBLASLT_MATMUL_DESC_EPILOGUE, epi)); // The CUBLASLT_MATMUL_DESC_FAST_ACCUM flag only impacts FP8 gemms. It speeds // up gemms at the expense of accumulation precision. In practice, it is safe // to set on the forward pass but not the backward pass. - TF_RETURN_IF_ERROR(SetAttr(cu_desc, CUBLASLT_MATMUL_DESC_FAST_ACCUM, - static_cast(enable_fast_accum))); + RETURN_IF_ERROR(SetAttr(cu_desc, CUBLASLT_MATMUL_DESC_FAST_ACCUM, + static_cast(enable_fast_accum))); return std::move(desc); } @@ -239,16 +240,14 @@ cublasLtPointerMode_t BlasLt::MatmulDesc::pointer_mode() const { .value()); } -auto BlasLt::MatmulPlan::GetAlgorithms(const Stream* stream, - size_t max_algorithm_count, +auto BlasLt::MatmulPlan::GetAlgorithms(size_t max_algorithm_count, size_t max_workspace_size) const -> absl::StatusOr> { max_algorithm_count = std::min(max_algorithm_count, size_t{INT_MAX}); std::vector results(max_algorithm_count); { - auto blas_lt = static_cast(gpu::BlasLt::Get(stream)); - absl::MutexLock lock(blas_lt->mu_); - TF_RET_CHECK(blas_lt->blas_lt_ != nullptr); + absl::MutexLock lock(blas_lt_.mu_); + TF_RET_CHECK(blas_lt_.handle_.get() != nullptr); cublasLtMatmulPreference_t cu_preference; SE_CUBLAS_RETURN_IF_ERROR(cublasLtMatmulPreferenceCreate(&cu_preference)); @@ -257,9 +256,9 @@ auto BlasLt::MatmulPlan::GetAlgorithms(const Stream* stream, Owned preference( cu_preference, cublasLtMatmulPreferenceDestroy); - TF_RETURN_IF_ERROR(SetAttr( - cu_preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, - max_workspace_size)); + RETURN_IF_ERROR(SetAttr(cu_preference, + CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, + max_workspace_size)); #if CUDA_VERSION >= 11080 // Set dummy (non-null, aligned) scale pointers before querying heuristics @@ -271,25 +270,26 @@ auto BlasLt::MatmulPlan::GetAlgorithms(const Stream* stream, bool is_fp8_scaled = is_fp8(a_desc_.type()) || is_fp8(b_desc_.type()); if (is_fp8_scaled) { void* dummy = reinterpret_cast(0x40); - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( SetAttr(op_desc_.get(), CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, dummy)); - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( SetAttr(op_desc_.get(), CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, dummy)); if (is_fp8(c_desc_.type())) { - TF_RETURN_IF_ERROR(SetAttr( - op_desc_.get(), CUBLASLT_MATMUL_DESC_C_SCALE_POINTER, dummy)); + RETURN_IF_ERROR(SetAttr(op_desc_.get(), + CUBLASLT_MATMUL_DESC_C_SCALE_POINTER, dummy)); } if (is_fp8(d_desc_.type())) { - TF_RETURN_IF_ERROR(SetAttr( - op_desc_.get(), CUBLASLT_MATMUL_DESC_D_SCALE_POINTER, dummy)); + RETURN_IF_ERROR(SetAttr(op_desc_.get(), + CUBLASLT_MATMUL_DESC_D_SCALE_POINTER, dummy)); } } #endif - std::unique_ptr activation = blas_lt->parent_->Activate(); + std::unique_ptr activation = + blas_lt_.executor_->Activate(); int found_algorithm_count = 0; SE_CUBLAS_RETURN_IF_ERROR(cublasLtMatmulAlgoGetHeuristic( - blas_lt->blas_lt_.get(), op_desc_.get(), a_desc_.get(), b_desc_.get(), + blas_lt_.handle_.get(), op_desc_.get(), a_desc_.get(), b_desc_.get(), c_desc_.get(), d_desc_.get(), preference.get(), max_algorithm_count, results.data(), &found_algorithm_count)); results.resize(found_algorithm_count); @@ -336,17 +336,16 @@ absl::StatusOr BlasLt::GetMatmulPlan( bool must_swap_operands = MakeOutputColumnMajor(lhs_layout, rhs_layout, output_layout, &c_layout); - TF_ASSIGN_OR_RETURN(auto output_dtype, - gpu::AsBlasDataType(output_layout.dtype)); + ASSIGN_OR_RETURN(auto output_dtype, gpu::AsBlasDataType(output_layout.dtype)); auto compute_type = cfg.compute_type; if (!compute_type) { // obtain compute_type unless provided by the user - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( compute_type, gpu::GetBlasComputationType( cfg.precision_algorithm, lhs_layout.dtype, output_layout.dtype, cfg.compute_precision, - parent_->GetDeviceDescription().gpu_compute_capability())); + executor_->GetDeviceDescription().gpu_compute_capability())); } // FP8 matmuls have a fast accumulation mode that is less precise than the @@ -356,21 +355,21 @@ absl::StatusOr BlasLt::GetMatmulPlan( IsFastAccumEnabled(cfg.precision_algorithm, lhs_layout.dtype, rhs_layout.dtype, cfg.compute_precision); auto trans_a = lhs_layout.transpose, trans_b = rhs_layout.transpose; - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( auto op_desc, MatmulDesc::Create(*compute_type, gpu::GetScaleType(output_dtype, *compute_type), trans_a, trans_b, epilogue, enable_fast_accum)); - TF_ASSIGN_OR_RETURN(auto a_desc, MatrixLayout::Create(lhs_layout)); - TF_ASSIGN_OR_RETURN(auto b_desc, MatrixLayout::Create(rhs_layout)); - TF_ASSIGN_OR_RETURN(auto c_desc, MatrixLayout::Create(c_layout)); - TF_ASSIGN_OR_RETURN(auto d_desc, MatrixLayout::Create(output_layout)); + ASSIGN_OR_RETURN(auto a_desc, MatrixLayout::Create(lhs_layout)); + ASSIGN_OR_RETURN(auto b_desc, MatrixLayout::Create(rhs_layout)); + ASSIGN_OR_RETURN(auto c_desc, MatrixLayout::Create(c_layout)); + ASSIGN_OR_RETURN(auto d_desc, MatrixLayout::Create(output_layout)); - return std::make_unique(std::move(op_desc), std::move(a_desc), - std::move(b_desc), std::move(c_desc), - std::move(d_desc), cfg.alpha, cfg.beta, - must_swap_operands); + return std::make_unique(*this, std::move(op_desc), + std::move(a_desc), std::move(b_desc), + std::move(c_desc), std::move(d_desc), + cfg.alpha, cfg.beta, must_swap_operands); } absl::Status BlasLt::MatmulPlan::DoMatmul( @@ -388,22 +387,18 @@ absl::Status BlasLt::MatmulPlan::DoMatmul( std::swap(a_scale, b_scale); } - auto blas_lt = static_cast(gpu::BlasLt::Get(stream)); - TF_RET_CHECK(blas_lt != nullptr); - std::unique_ptr timer; if (profile_result != nullptr) { - TF_ASSIGN_OR_RETURN(timer, stream->CreateEventBasedTimer( - profile_result->warmup_run_executed())); + ASSIGN_OR_RETURN(timer, stream->CreateEventBasedTimer( + profile_result->warmup_run_executed())); } void* workspace_addr = nullptr; uint64_t workspace_size = algorithm_->workspace_size; if (workspace_size > 0) { if (args.scratch_allocator != nullptr) { - TF_ASSIGN_OR_RETURN( - DeviceAddress alloc, - args.scratch_allocator->AllocateBytes(workspace_size)); + ASSIGN_OR_RETURN(DeviceAddress alloc, + args.scratch_allocator->AllocateBytes(workspace_size)); workspace_addr = gpu::GpuMemoryMutable(&alloc); } else { workspace_addr = args.workspace.opaque(); @@ -415,33 +410,31 @@ absl::Status BlasLt::MatmulPlan::DoMatmul( auto palgo = std::any_cast(&algorithm_->opaque_algo); { - absl::MutexLock lock(blas_lt->mu_); - TF_RET_CHECK(blas_lt->blas_lt_ != nullptr); + absl::MutexLock lock(blas_lt_.mu_); + TF_RET_CHECK(blas_lt_.handle_.get() != nullptr); // We must set the bias and aux pointers while holding the mutex, to avoid a // potential race condition from multiple threads sharing the same plan. if (args.bias != nullptr) { - TF_RETURN_IF_ERROR(SetAttr(op_desc_.get(), - CUBLASLT_MATMUL_DESC_BIAS_POINTER, - args.bias.opaque())); + RETURN_IF_ERROR(SetAttr(op_desc_.get(), CUBLASLT_MATMUL_DESC_BIAS_POINTER, + args.bias.opaque())); } #if CUDA_VERSION >= 11080 // Always set scale pointers (null when not provided) to overwrite any // dummy values left by GetAlgorithms(). - TF_RETURN_IF_ERROR(SetAttr(op_desc_.get(), - CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, - a_scale.opaque())); - TF_RETURN_IF_ERROR(SetAttr(op_desc_.get(), - CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, - b_scale.opaque())); - TF_RETURN_IF_ERROR(SetAttr(op_desc_.get(), - CUBLASLT_MATMUL_DESC_C_SCALE_POINTER, - args.c_scale.opaque())); - TF_RETURN_IF_ERROR(SetAttr(op_desc_.get(), - CUBLASLT_MATMUL_DESC_D_SCALE_POINTER, - args.d_scale.opaque())); - TF_RETURN_IF_ERROR(SetAttr(op_desc_.get(), - CUBLASLT_MATMUL_DESC_AMAX_D_POINTER, - args.d_amax.opaque())); + RETURN_IF_ERROR(SetAttr(op_desc_.get(), + CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, + a_scale.opaque())); + RETURN_IF_ERROR(SetAttr(op_desc_.get(), + CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, + b_scale.opaque())); + RETURN_IF_ERROR(SetAttr(op_desc_.get(), + CUBLASLT_MATMUL_DESC_C_SCALE_POINTER, + args.c_scale.opaque())); + RETURN_IF_ERROR(SetAttr(op_desc_.get(), + CUBLASLT_MATMUL_DESC_D_SCALE_POINTER, + args.d_scale.opaque())); + RETURN_IF_ERROR(SetAttr(op_desc_.get(), CUBLASLT_MATMUL_DESC_AMAX_D_POINTER, + args.d_amax.opaque())); #else if (!(a_scale == nullptr && b_scale == nullptr && args.c_scale == nullptr && args.d_scale == nullptr && args.d_amax == nullptr)) { @@ -452,44 +445,41 @@ absl::Status BlasLt::MatmulPlan::DoMatmul( if (args.aux != nullptr) { #if CUDA_VERSION >= 11040 - TF_RETURN_IF_ERROR(SetAttr(op_desc_.get(), - CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER, - args.aux.opaque())); + RETURN_IF_ERROR(SetAttr(op_desc_.get(), + CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER, + args.aux.opaque())); // Set leading dim and batch stride of auxiliary output to match output. // TODO(cjfj): Set this once at initialization. - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( int64_t output_leading_dim, GetAttr(d_desc_.get(), CUBLASLT_MATRIX_LAYOUT_LD)); - TF_RETURN_IF_ERROR(SetAttr(op_desc_.get(), - CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, - output_leading_dim)); + RETURN_IF_ERROR(SetAttr(op_desc_.get(), + CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, + output_leading_dim)); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( int64_t output_batch_stride, GetAttr(d_desc_.get(), CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET)); - TF_RETURN_IF_ERROR(SetAttr(op_desc_.get(), - CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_BATCH_STRIDE, - output_batch_stride)); + RETURN_IF_ERROR(SetAttr(op_desc_.get(), + CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_BATCH_STRIDE, + output_batch_stride)); #else return absl::InternalError( "Auxiliary inputs / outputs require cublasLt >= 11.4"); #endif } - std::unique_ptr activation = blas_lt->parent_->Activate(); - - void* c_ptr = args.c.opaque(); - if (beta_ == 0.0) { - c_ptr = nullptr; - } + std::unique_ptr activation = + blas_lt_.executor_->Activate(); + void* c_ptr = beta_ == 0.0 ? nullptr : args.c.opaque(); if (palgo != nullptr) { SE_CUBLAS_RETURN_IF_ERROR(cublasLtMatmul( - blas_lt->blas_lt_.get(), op_desc_.get(), alpha, a.opaque(), + blas_lt_.handle_.get(), op_desc_.get(), alpha, a.opaque(), a_desc_.get(), b.opaque(), b_desc_.get(), beta, c_ptr, c_desc_.get(), args.d.opaque(), d_desc_.get(), palgo, workspace_addr, workspace_size, absl::bit_cast(stream->platform_specific_handle().stream))); @@ -499,7 +489,7 @@ absl::Status BlasLt::MatmulPlan::DoMatmul( } if (profile_result != nullptr) { - TF_ASSIGN_OR_RETURN(absl::Duration elapsed, timer->GetElapsedDuration()); + ASSIGN_OR_RETURN(absl::Duration elapsed, timer->GetElapsedDuration()); // set algorithm ID to be unique (otherwise it gets kDefaultAlgorithm ID) profile_result->set_algorithm(reinterpret_cast(palgo)); profile_result->set_is_valid(true); @@ -584,12 +574,6 @@ absl::Status BlasLt::MatmulPlan::ExecuteOnStream( return xla::Internal("Unexpected dtype"); } -absl::StatusOr BlasLt::GetGroupedMatmulPlan( - gpu::GroupedGemmConfig& config, Epilogue epilogue) const { - return absl::UnimplementedError( - "Grouped GEMM is not supported for CUDA BlasLt"); -} - } // namespace cuda } // namespace stream_executor diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_blas_lt.h b/third_party/xla/xla/stream_executor/cuda/cuda_blas_lt.h index 001ccc73449af6..f257520dcd0026 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_blas_lt.h +++ b/third_party/xla/xla/stream_executor/cuda/cuda_blas_lt.h @@ -86,11 +86,12 @@ class BlasLt : public gpu::BlasLt { class MatmulPlan : public gpu::BlasLt::MatmulPlan { public: - MatmulPlan(MatmulDesc&& op_desc, MatrixLayout&& a_desc, - MatrixLayout&& b_desc, MatrixLayout&& c_desc, - MatrixLayout&& d_desc, xla::complex128 alpha, double beta, - bool must_swap_operands) - : op_desc_(std::move(op_desc)), + MatmulPlan(const BlasLt& blas_lt, MatmulDesc&& op_desc, + MatrixLayout&& a_desc, MatrixLayout&& b_desc, + MatrixLayout&& c_desc, MatrixLayout&& d_desc, + xla::complex128 alpha, double beta, bool must_swap_operands) + : blas_lt_(blas_lt), + op_desc_(std::move(op_desc)), a_desc_(std::move(a_desc)), b_desc_(std::move(b_desc)), c_desc_(std::move(c_desc)), @@ -106,8 +107,7 @@ class BlasLt : public gpu::BlasLt { blas::ProfileResult* profile_result) const override; absl::StatusOr> GetAlgorithms( - const Stream* stream, size_t max_algorithm_count = 128, - size_t max_workspace_size = 1ll << 32) const override; + size_t max_algorithm_count, size_t max_workspace_size) const override; absl::Status SetAlgorithm(const MatmulAlgorithm& algorithm) override { algorithm_ = algorithm; @@ -119,7 +119,7 @@ class BlasLt : public gpu::BlasLt { const gpu::BlasLt::MemoryArgs& args, blas::ProfileResult* profile_result) const; - // TODO(cjfj): Add consistency checks for types, shapes, etc.? + const BlasLt& blas_lt_; MatmulDesc op_desc_; MatrixLayout a_desc_; MatrixLayout b_desc_; @@ -131,8 +131,8 @@ class BlasLt : public gpu::BlasLt { std::optional algorithm_; // selected algorithm }; // class MatmulPlan - explicit BlasLt(StreamExecutor* parent) - : parent_(parent), blas_lt_(nullptr, cublasLtDestroy) {} + explicit BlasLt(StreamExecutor* executor) + : executor_(executor), handle_(nullptr, cublasLtDestroy) {} absl::Status Init() override; @@ -140,14 +140,17 @@ class BlasLt : public gpu::BlasLt { Epilogue epilogue) const override; absl::StatusOr GetGroupedMatmulPlan( - gpu::GroupedGemmConfig& config, Epilogue epilogue) const override; + const gpu::GroupedGemmConfig& config, Epilogue epilogue) const override { + return absl::UnimplementedError( + "Grouped GEMM is not supported for CUDA BlasLt"); + }; ~BlasLt() override = default; private: - StreamExecutor* parent_; + StreamExecutor* executor_; mutable absl::Mutex mu_; - Owned blas_lt_ ABSL_GUARDED_BY(mu_); + Owned handle_ ABSL_GUARDED_BY(mu_); }; } // namespace cuda diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_blas_lt_test.cc b/third_party/xla/xla/stream_executor/cuda/cuda_blas_lt_test.cc index b3aa482fab7df5..75521bc01c380c 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_blas_lt_test.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_blas_lt_test.cc @@ -42,8 +42,7 @@ class CudaBlasLtTest : public ::testing::Test { ASSERT_OK_AND_ASSIGN(executor_, platform_->ExecutorForDevice(0)); LOG(INFO) << "Device name: " << executor_->GetDeviceDescription().name(); ASSERT_OK_AND_ASSIGN(stream_, executor_->CreateStream()); - blas_lt_ = gpu::BlasLt::Get(stream_.get()); - ASSERT_NE(blas_lt_, nullptr); + ASSERT_OK_AND_ASSIGN(blas_lt_, gpu::BlasLt::Get(executor_)); } template @@ -102,9 +101,8 @@ class CudaBlasLtTest : public ::testing::Test { cfg, gpu::BlasLt::Epilogue::kDefault)); uint32_t workspace_size = 32 * 1024 * 1024; // 32 MB - ASSERT_OK_AND_ASSIGN( - auto algorithms, - plan->GetAlgorithms(stream_.get(), 128, workspace_size)); + ASSERT_OK_AND_ASSIGN(auto algorithms, + plan->GetAlgorithms(128, workspace_size)); ASSERT_FALSE(algorithms.empty()); ASSERT_OK(plan->SetAlgorithm(algorithms[0])); diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_blas_utils.h b/third_party/xla/xla/stream_executor/cuda/cuda_blas_utils.h index aaaf4257f4f5b5..4d8cdf61b7c996 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_blas_utils.h +++ b/third_party/xla/xla/stream_executor/cuda/cuda_blas_utils.h @@ -16,15 +16,15 @@ limitations under the License. #ifndef XLA_STREAM_EXECUTOR_CUDA_CUDA_BLAS_UTILS_H_ #define XLA_STREAM_EXECUTOR_CUDA_CUDA_BLAS_UTILS_H_ - #include "absl/status/status.h" +#include "xla/tsl/platform/status_macros.h" #include "third_party/gpus/cuda/include/cublas_v2.h" #include "third_party/gpus/cuda/include/library_types.h" #include "xla/stream_executor/blas.h" #include "tsl/platform/errors.h" #define SE_CUBLAS_RETURN_IF_ERROR(expr) \ - TF_RETURN_IF_ERROR(::stream_executor::cuda::ToStatus(expr, #expr)) + RETURN_IF_ERROR(::stream_executor::cuda::ToStatus(expr, #expr)) namespace stream_executor { namespace cuda { diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_command_buffer.cc b/third_party/xla/xla/stream_executor/cuda/cuda_command_buffer.cc index e53827f1765763..448fb782f500c2 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_command_buffer.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_command_buffer.cc @@ -62,8 +62,8 @@ namespace { absl::StatusOr CreateGraph() { VLOG(2) << "Create new CUDA graph"; CUgraph graph = nullptr; - TF_RETURN_IF_ERROR(cuda::ToStatus(cuGraphCreate(&graph, /*flags=*/0), - "Failed to create CUDA graph")); + RETURN_IF_ERROR(cuda::ToStatus(cuGraphCreate(&graph, /*flags=*/0), + "Failed to create CUDA graph")); VLOG(2) << "Created CUDA graph " << graph; return graph; } @@ -137,7 +137,7 @@ absl::Status GraphInstantiate(CUgraphExec* exec, CUgraph graph) { absl::StatusOr> CudaCommandBuffer::Create( Mode mode, StreamExecutor* executor, CudaContext* cuda_context) { - TF_ASSIGN_OR_RETURN(CUgraph graph, CreateGraph()); + ASSIGN_OR_RETURN(CUgraph graph, CreateGraph()); return std::unique_ptr(new CudaCommandBuffer( mode, executor, cuda_context, graph, /*is_owned_graph=*/true)); } @@ -150,9 +150,8 @@ absl::StatusOr CudaCommandBuffer::CreateSetWhileConditionNode( GraphConditionalHandle conditional, DeviceAddress predicate, absl::Span dependencies) { if (!set_while_condition_kernel_) { - TF_ASSIGN_OR_RETURN(auto spec, - cuda::GetSetWhileConditionKernelLoaderSpec()); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN(auto spec, cuda::GetSetWhileConditionKernelLoaderSpec()); + ASSIGN_OR_RETURN( set_while_condition_kernel_, SetWhileConditionKernel::FactoryType::Create(stream_exec_, spec)); } @@ -203,8 +202,8 @@ absl::StatusOr CudaCommandBuffer::CreateSetCaseConditionNode( bool enable_conditional_default, absl::Span dependencies) { if (!set_case_condition_kernel_) { - TF_ASSIGN_OR_RETURN(auto spec, cuda::GetSetCaseConditionKernelLoaderSpec()); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN(auto spec, cuda::GetSetCaseConditionKernelLoaderSpec()); + ASSIGN_OR_RETURN( set_case_condition_kernel_, SetCaseConditionKernel::FactoryType::Create(stream_exec_, spec)); } @@ -234,9 +233,9 @@ absl::Status CudaCommandBuffer::UpdateSetCaseConditionNode( absl::StatusOr CudaCommandBuffer::GetNoOpKernel() { if (!noop_kernel_) { - TF_ASSIGN_OR_RETURN(auto spec, cuda::GetNoOpKernelLoaderSpec()); - TF_ASSIGN_OR_RETURN(noop_kernel_, - NoOpKernel::FactoryType::Create(stream_exec_, spec)); + ASSIGN_OR_RETURN(auto spec, cuda::GetNoOpKernelLoaderSpec()); + ASSIGN_OR_RETURN(noop_kernel_, + NoOpKernel::FactoryType::Create(stream_exec_, spec)); } return &noop_kernel_; } @@ -276,7 +275,7 @@ CudaCommandBuffer::CreateConditionalNode( std::vector deps = ToCudaGraphHandles(dependencies); CUgraphNode node_handle = nullptr; - TF_RETURN_IF_ERROR(cuda::ToStatus( + RETURN_IF_ERROR(cuda::ToStatus( cuGraphAddNode_v2(&node_handle, graph_, deps.data(), /*dependencyData=*/nullptr, deps.size(), &cu_params), "Failed to add conditional node to a CUDA graph")); @@ -318,7 +317,7 @@ absl::StatusOr CudaCommandBuffer::CreateMemsetNode( std::vector deps = ToCudaGraphHandles(dependencies); CUgraphNode node_handle = nullptr; - TF_RETURN_IF_ERROR(cuda::ToStatus( + RETURN_IF_ERROR(cuda::ToStatus( cuGraphAddMemsetNode(&node_handle, graph_, deps.data(), deps.size(), ¶ms, cuda_context_->context()), "Failed to add memset node to a CUDA graph")); @@ -370,7 +369,7 @@ absl::StatusOr CudaCommandBuffer::CreateMemcpyD2DNode( std::vector deps = ToCudaGraphHandles(dependencies); CUgraphNode node_handle = nullptr; - TF_RETURN_IF_ERROR(cuda::ToStatus( + RETURN_IF_ERROR(cuda::ToStatus( cuGraphAddMemcpyNode(&node_handle, graph_, deps.data(), deps.size(), ¶ms, cuda_context_->context()), "Failed to add memcpy d2d node to a CUDA graph")); @@ -410,9 +409,9 @@ absl::Status CudaCommandBuffer::UpdateDnnGraphNode( dnn::DnnGraph& dnn_graph, Stream& stream, absl::Span operands, GraphNodeHandle node_handle) { CUgraph child_graph; - TF_RETURN_IF_ERROR(cuda::ToStatus(cuGraphChildGraphNodeGetGraph( + RETURN_IF_ERROR(cuda::ToStatus(cuGraphChildGraphNodeGetGraph( ToCudaGraphHandle(node_handle), &child_graph))); - TF_RETURN_IF_ERROR(dnn_graph.PopulateOrUpdateRawCommandBuffer( + RETURN_IF_ERROR(dnn_graph.PopulateOrUpdateRawCommandBuffer( stream, operands, child_graph, true)); return cuda::ToStatus( cuGraphExecChildGraphNodeSetParams( @@ -435,7 +434,7 @@ absl::StatusOr CudaCommandBuffer::CreateClonedChildNode( << " and add it to " << graph_ << "; deps: " << dependencies.size(); CUgraphNode node_handle; - TF_RETURN_IF_ERROR(cuda::ToStatus( + RETURN_IF_ERROR(cuda::ToStatus( cuGraphAddChildGraphNode(&node_handle, graph_, deps.data(), deps.size(), child_graph), "Failed to create a child graph node and add it to a CUDA graph")); @@ -474,7 +473,7 @@ absl::StatusOr CudaCommandBuffer::CreateMovedChildNode( nodeParams.graph.ownership = CU_GRAPH_CHILD_GRAPH_OWNERSHIP_MOVE; CUgraphNode node_handle; - TF_RETURN_IF_ERROR(cuda::ToStatus( + RETURN_IF_ERROR(cuda::ToStatus( cuGraphAddNode_v2(&node_handle, graph_, deps.data(), /*dependencyData=*/nullptr, deps.size(), &nodeParams), "Failed to create a child graph node and add it to a CUDA graph")); @@ -517,7 +516,7 @@ absl::StatusOr CudaCommandBuffer::CreateKernelNode( CUgraphNode node_handle = nullptr; const auto& cuda_kernel = static_cast(kernel); CUfunction function = cuda_kernel.gpu_function(); - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( cuda_kernel.UpdateMaxDynamicSharedMemoryBytes(shared_mem_bytes)); std::unique_ptr repacked; @@ -560,7 +559,7 @@ absl::StatusOr CudaCommandBuffer::CreateKernelNode( CUgraphEdgeData edge_data_item; std::memset(&edge_data_item, 0, sizeof(edge_data_item)); CUgraphNodeType type; - TF_RETURN_IF_ERROR(cuda::ToStatus( + RETURN_IF_ERROR(cuda::ToStatus( cuGraphNodeGetType(deps[i], &type), absl::StrCat("Failed to get CUDA graph node type for dependency ", i))); @@ -570,7 +569,7 @@ absl::StatusOr CudaCommandBuffer::CreateKernelNode( } edge_data.push_back(edge_data_item); } - TF_RETURN_IF_ERROR(cuda::ToStatus( + RETURN_IF_ERROR(cuda::ToStatus( cuGraphAddNode_v2(&node_handle, graph_, deps.data(), edge_data.data(), deps.size(), &cu_params), "Failed to add kernel node to a CUDA graph")); @@ -582,7 +581,7 @@ absl::StatusOr CudaCommandBuffer::CreateKernelNode( CUDA_KERNEL_NODE_PARAMS params{}; set_params(params); - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( cuda::ToStatus(cuGraphAddKernelNode(&node_handle, graph_, deps.data(), deps.size(), ¶ms), "Failed to add kernel node to a CUDA graph")); @@ -592,7 +591,7 @@ absl::StatusOr CudaCommandBuffer::CreateKernelNode( CUDA_KERNEL_NODE_PARAMS params{}; set_params(params); - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( cuda::ToStatus(cuGraphAddKernelNode(&node_handle, graph_, deps.data(), deps.size(), ¶ms), "Failed to add kernel node to a CUDA graph")); @@ -601,7 +600,7 @@ absl::StatusOr CudaCommandBuffer::CreateKernelNode( if (priority != StreamPriority::Default) { CUlaunchAttributeValue value; value.priority = stream_exec_->GetGpuStreamPriority(priority); - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( cuda::ToStatus(cuGraphKernelNodeSetAttribute( node_handle, CU_LAUNCH_ATTRIBUTE_PRIORITY, &value), "Failed to set kernel node priority")); @@ -647,7 +646,7 @@ absl::Status CudaCommandBuffer::UpdateKernelNode( const_cast(packed_args->argument_addresses().data()); params.extra = nullptr; - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( cuda_kernel.UpdateMaxDynamicSharedMemoryBytes(shared_mem_bytes)); return cuda::ToStatus( @@ -664,7 +663,7 @@ absl::StatusOr CudaCommandBuffer::CreateEmptyNode( std::vector deps = ToCudaGraphHandles(dependencies); CUgraphNode node_handle = nullptr; - TF_RETURN_IF_ERROR(cuda::ToStatus( + RETURN_IF_ERROR(cuda::ToStatus( cuGraphAddEmptyNode(&node_handle, graph_, deps.data(), deps.size()), "Failed to add empty node to a CUDA graph")); @@ -685,7 +684,7 @@ absl::Status CudaCommandBuffer::Trace( "12.3. Therefore tracing is not supported."); } - TF_RETURN_IF_ERROR(CheckNotFinalized()); + RETURN_IF_ERROR(CheckNotFinalized()); VLOG(5) << "Trace into GPU command buffer graph " << graph_ << " on a stream: " << stream; @@ -696,7 +695,7 @@ absl::Status CudaCommandBuffer::Trace( // Switch stream into the capture mode. uint64_t start_nanos = tsl::Env::Default()->NowNanos(); - TF_RETURN_IF_ERROR(cuda::ToStatus( + RETURN_IF_ERROR(cuda::ToStatus( cuStreamBeginCaptureToGraph(stream_handle, graph_, /*dependencies=*/nullptr, /*dependencyData=*/nullptr, @@ -708,11 +707,11 @@ absl::Status CudaCommandBuffer::Trace( // Always stop capturing the stream before checking `traced` result. VLOG(5) << "End stream " << stream << " capture"; CUgraph captured_graph; - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( cuda::ToStatus(cuStreamEndCapture(stream_handle, &captured_graph), "Failed to end stream capture")); DCHECK(captured_graph == graph_) << "Stream capture should update graph_"; - TF_RETURN_IF_ERROR(traced); + RETURN_IF_ERROR(traced); uint64_t end_nanos = tsl::Env::Default()->NowNanos(); VLOG(5) << "Traced into the GPU command buffer graph " << graph_ << " (took " @@ -721,7 +720,7 @@ absl::Status CudaCommandBuffer::Trace( // Check that traced graph is not empty. Trying to instantiate a CUDA graph // with empty child node leads to a crash. size_t num_root_nodes = 0; - TF_RETURN_IF_ERROR(cuda::ToStatus( + RETURN_IF_ERROR(cuda::ToStatus( cuGraphGetRootNodes(captured_graph, nullptr, &num_root_nodes))); if (num_root_nodes == 0) { @@ -751,30 +750,30 @@ absl::Status CudaCommandBuffer::LaunchGraph(Stream* stream) { absl::StatusOr CudaCommandBuffer::GetNodeCount() const { size_t num_nodes; - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( cuda::ToStatus(cuGraphGetNodes(graph_, /*nodes=*/nullptr, &num_nodes))); return num_nodes; } absl::Status CudaCommandBuffer::SetPriority(StreamPriority priority) { size_t num_nodes; - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( cuda::ToStatus(cuGraphGetNodes(graph_, /*nodes=*/nullptr, &num_nodes))); std::vector nodes(num_nodes); - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( cuda::ToStatus(cuGraphGetNodes(graph_, nodes.data(), &num_nodes))); int priority_value = stream_exec_->GetGpuStreamPriority(priority); for (size_t i = 0; i < num_nodes; i++) { CUgraphNodeType type; - TF_RETURN_IF_ERROR(cuda::ToStatus(cuGraphNodeGetType(nodes[i], &type), - "Failed to get kernel node type")); + RETURN_IF_ERROR(cuda::ToStatus(cuGraphNodeGetType(nodes[i], &type), + "Failed to get kernel node type")); if (type == CU_GRAPH_NODE_TYPE_KERNEL) { CUlaunchAttributeValue value; value.priority = priority_value; - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( cuda::ToStatus(cuGraphKernelNodeSetAttribute( nodes[i], CU_LAUNCH_ATTRIBUTE_PRIORITY, &value), "Failed to set kernel node priority")); @@ -788,14 +787,13 @@ absl::Status CudaCommandBuffer::PrepareFinalization() { SemanticVersion{12, 8, 0}) { // For CUDA < 12080, cuda graph conditional node does not support // empty body graph. - TF_ASSIGN_OR_RETURN(auto node_count, GetNodeCount()); + ASSIGN_OR_RETURN(auto node_count, GetNodeCount()); if (node_count > 0) { return absl::OkStatus(); } - TF_ASSIGN_OR_RETURN(NoOpKernel * noop, GetNoOpKernel()); - TF_RETURN_IF_ERROR( - CreateLaunch(*noop, ThreadDim(), BlockDim(), {}).status()); + ASSIGN_OR_RETURN(NoOpKernel * noop, GetNoOpKernel()); + RETURN_IF_ERROR(CreateLaunch(*noop, ThreadDim(), BlockDim(), {}).status()); } return absl::OkStatus(); } @@ -811,7 +809,7 @@ CudaCommandBuffer::CreateConditionalHandle() { #if CUDA_VERSION >= 12030 CUgraphConditionalHandle handle; - TF_RETURN_IF_ERROR(cuda::ToStatus( + RETURN_IF_ERROR(cuda::ToStatus( cuGraphConditionalHandleCreate(&handle, graph_, cuda_context_->context(), kDefaultLaunchValue, kNoFlags), "Failed to create conditional handle for a CUDA graph")); @@ -843,14 +841,14 @@ absl::Status CudaCommandBuffer::InstantiateGraph() { if (instantiated.code() == absl::StatusCode::kResourceExhausted) { LOG(WARNING) << "Retry CUDA graph instantiation after OOM error"; CUdevice device; - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( cuda::ToStatus(cuDeviceGet(&device, stream_exec_->device_ordinal()), "Failed call to cuDeviceGet")); - TF_RETURN_IF_ERROR(cuda::ToStatus(cuDeviceGraphMemTrim(device), - "Failed to trim device graph memory")); - TF_RETURN_IF_ERROR(GraphInstantiate(&graph_exec_, graph_)); + RETURN_IF_ERROR(cuda::ToStatus(cuDeviceGraphMemTrim(device), + "Failed to trim device graph memory")); + RETURN_IF_ERROR(GraphInstantiate(&graph_exec_, graph_)); } else { - TF_RETURN_IF_ERROR(instantiated); + RETURN_IF_ERROR(instantiated); } return absl::OkStatus(); diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_context.cc b/third_party/xla/xla/stream_executor/cuda/cuda_context.cc index c0f11c6fd37605..41a151817fe04e 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_context.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_context.cc @@ -20,6 +20,7 @@ limitations under the License. #include "absl/log/check.h" #include "absl/log/log.h" +#include "xla/tsl/platform/status_macros.h" #include "third_party/gpus/cuda/include/cuda.h" #include "xla/stream_executor/cuda/cuda_status.h" #include "xla/stream_executor/gpu/context_map.h" @@ -132,7 +133,7 @@ absl::StatusOr CudaContext::Create(int device_ordinal, unsigned int former_primary_context_flags; int former_primary_context_is_active; - TF_RETURN_IF_ERROR(cuda::ToStatus( + RETURN_IF_ERROR(cuda::ToStatus( cuDevicePrimaryCtxGetState(device, &former_primary_context_flags, &former_primary_context_is_active))); if (former_primary_context_flags != flags) { @@ -142,14 +143,14 @@ absl::StatusOr CudaContext::Create(int device_ordinal, << former_primary_context_flags << ") than the desired flag set (" << flags << ")."; } else { - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( cuda::ToStatus(cuDevicePrimaryCtxSetFlags(device, flags))); } } CUcontext former_context = CurrentContextOrDie(); CUcontext new_context; - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( cuda::ToStatus(cuDevicePrimaryCtxRetain(&new_context, device))); if (former_context != nullptr) { CUdevice former_device; @@ -172,7 +173,7 @@ absl::StatusOr CudaContext::Create(int device_ordinal, << former_context; } } - TF_RETURN_IF_ERROR(cuda::ToStatus(cuCtxSetCurrent(former_context))); + RETURN_IF_ERROR(cuda::ToStatus(cuCtxSetCurrent(former_context))); context = GetContextMap()->Add(new_context, device_ordinal); CHECK(context != nullptr) diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_device_address_vmm_allocator.cc b/third_party/xla/xla/stream_executor/cuda/cuda_device_address_vmm_allocator.cc index c205e4154c4e22..5ce4565f7eb468 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_device_address_vmm_allocator.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_device_address_vmm_allocator.cc @@ -28,6 +28,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/str_format.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "third_party/gpus/cuda/include/cuda.h" #include "xla/stream_executor/activate_context.h" #include "xla/stream_executor/cuda/cuda_memory_reservation.h" @@ -52,7 +53,7 @@ CudaDeviceAddressVmmAllocator::Create(const Platform* platform, absl::Span devices) { auto allocator = absl::WrapUnique(new CudaDeviceAddressVmmAllocator(platform)); - TF_RETURN_IF_ERROR(PopulateDevices(allocator.get(), devices)); + RETURN_IF_ERROR(PopulateDevices(allocator.get(), devices)); return allocator; } @@ -99,10 +100,10 @@ absl::Status CudaDeviceAddressVmmAllocator::InitializeDeviceState( // Verify that the device supports 64-bit stream memory operations // (cuStreamWriteValue64), which requires compute capability >= 7.0. CUdevice cu_device; - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( cuda::ToStatus(cuDeviceGet(&cu_device, ordinal), "cuDeviceGet")); int supported = 0; - TF_RETURN_IF_ERROR(cuda::ToStatus( + RETURN_IF_ERROR(cuda::ToStatus( cuDeviceGetAttribute(&supported, CU_DEVICE_ATTRIBUTE_CAN_USE_64_BIT_STREAM_MEM_OPS, cu_device), @@ -124,7 +125,7 @@ absl::Status CudaDeviceAddressVmmAllocator::InitializeDeviceState( CUdeviceptr dev_ptr = 0; { std::unique_ptr activation = state.executor->Activate(); - TF_RETURN_IF_ERROR(cuda::ToStatus( + RETURN_IF_ERROR(cuda::ToStatus( cuMemHostAlloc(&host_ptr, sizeof(uint64_t), CU_MEMHOSTALLOC_PORTABLE), "cuMemHostAlloc for timeline counter")); *static_cast(host_ptr) = 0; diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_dnn.cc b/third_party/xla/xla/stream_executor/cuda/cuda_dnn.cc index 6ca4af481bfcd3..d7fb078113cee3 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_dnn.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_dnn.cc @@ -98,6 +98,7 @@ limitations under the License. #include "third_party/cudnn_frontend/include/cudnn_frontend_PointWiseDesc.h" #include "third_party/cudnn_frontend/include/cudnn_frontend_Tensor.h" #include "third_party/cudnn_frontend/include/cudnn_frontend_VariantPack.h" +#include "xla/tsl/platform/status_macros.h" // clang-format on #ifdef __clang__ @@ -414,8 +415,8 @@ absl::Status CudnnSupport::Init() { constexpr SemanticVersion kSourceVersion(CUDNN_MAJOR, CUDNN_MINOR, CUDNN_PATCHLEVEL); - TF_ASSIGN_OR_RETURN(SemanticVersion loaded_version, - cuda::GetLoadedCudnnVersion()); + ASSIGN_OR_RETURN(SemanticVersion loaded_version, + cuda::GetLoadedCudnnVersion()); if (!IsSourceCompatibleWithCudnnLibrary(kSourceVersion, loaded_version)) { const std::string error = absl::StrCat( "Loaded runtime CuDNN library: ", loaded_version.ToString(), @@ -431,7 +432,7 @@ absl::Status CudnnSupport::Init() { } cudnn_ = std::make_unique(cudnn_handle); - TF_RETURN_IF_ERROR(cudnn_->InitializeCompilationHandle()); + RETURN_IF_ERROR(cudnn_->InitializeCompilationHandle()); LOG(INFO) << "Loaded cuDNN version " << cudnnGetVersion(); return absl::OkStatus(); } @@ -466,8 +467,8 @@ void CudnnSupport::NotifyStreamDestroyed(Stream* stream) /* override */ { } absl::StatusOr CudnnSupport::GetVersion() { - TF_ASSIGN_OR_RETURN(SemanticVersion version, - stream_executor::cuda::GetLoadedCudnnVersion()); + ASSIGN_OR_RETURN(SemanticVersion version, + stream_executor::cuda::GetLoadedCudnnVersion()); return stream_executor::dnn::VersionInfo(version.major_version(), version.minor_version(), version.patch_version()); @@ -1183,8 +1184,8 @@ class CudnnDropoutDescriptor { size_t state_sizes_in_bytes = 0; RETURN_IF_CUDNN_ERROR( cudnnDropoutGetStatesSize(cudnn.handle(), &state_sizes_in_bytes)); - TF_ASSIGN_OR_RETURN(state_memory, - state_allocator->AllocateBytes(state_sizes_in_bytes)); + ASSIGN_OR_RETURN(state_memory, + state_allocator->AllocateBytes(state_sizes_in_bytes)); } RETURN_IF_CUDNN_ERROR(cudnnSetDropoutDescriptor( handle.get(), cudnn.handle(), dropout, state_memory.opaque(), @@ -1275,7 +1276,7 @@ class CudnnRnnDescriptor : public dnn::RnnDescriptor { const dnn::AlgorithmConfig& algorithm_config, const EngineOptions& engine_options, float dropout, uint64_t seed, ScratchAllocator* state_allocator, bool use_padded_io) { - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( CudnnDropoutDescriptor dropout_desc, CudnnDropoutDescriptor::Create(cudnn, dropout, seed, state_allocator)); @@ -1337,10 +1338,10 @@ class CudnnRnnDescriptor : public dnn::RnnDescriptor { // TODO(kaixih@nvidia.com): Should be removed when cudnnRNNForward*** and // cudnnRNNForward***Ex are removed from the codebase, since the new API // doesn't need param descriptors any more. - TF_ASSIGN_OR_RETURN(auto params_desc, - CudnnRnnParamsDescriptor::Create( - cudnn, input_size, data_type, rnn_desc.get(), - rnn_mode, direction_mode, num_layers)); + ASSIGN_OR_RETURN(auto params_desc, + CudnnRnnParamsDescriptor::Create( + cudnn, input_size, data_type, rnn_desc.get(), rnn_mode, + direction_mode, num_layers)); return CudnnRnnDescriptor(cudnn, std::move(rnn_desc), std::move(rnn_plan), num_layers, hidden_size, input_size, cell_size, @@ -1584,17 +1585,17 @@ absl::StatusOr CudnnRnnParamsDescriptor::Create( dims[0] * dims[1] * dims[2] * CudnnDataTypeToByteSize(data_type); return size; }; - TF_ASSIGN_OR_RETURN(int64_t m_size, get_size(m_region_desc_handle)); + ASSIGN_OR_RETURN(int64_t m_size, get_size(m_region_desc_handle)); int64_t m_offset = static_cast(m_ptr) - static_cast(w_ptr); dnn::RnnDescriptor::ParamsRegion m_region = {m_offset, m_size}; weights.push_back(m_region); - TF_ASSIGN_OR_RETURN(int64_t b_size, get_size(b_region_desc_handle)); + ASSIGN_OR_RETURN(int64_t b_size, get_size(b_region_desc_handle)); int64_t b_offset = static_cast(b_ptr) - static_cast(w_ptr); dnn::RnnDescriptor::ParamsRegion b_region = {b_offset, b_size}; biases.push_back(b_region); } - TF_RETURN_IF_ERROR(CheckAndFetchProjectionWeights( + RETURN_IF_ERROR(CheckAndFetchProjectionWeights( cudnn, rnn_desc, layer, input_desc, filter_desc, params_size_in_bytes, region_desc_handle, &weights)); } @@ -1840,13 +1841,13 @@ absl::Status CreateRnnTempSpace( } if (workspace_size_in_bytes > 0) { - TF_ASSIGN_OR_RETURN(*workspace, workspace_allocator->AllocateBytes( - workspace_size_in_bytes)); + ASSIGN_OR_RETURN(*workspace, workspace_allocator->AllocateBytes( + workspace_size_in_bytes)); } if (reserve_space_allocator != nullptr && is_fwd_training && reserve_space_size_in_bytes > 0) { - TF_ASSIGN_OR_RETURN(*reserve_space, reserve_space_allocator->AllocateBytes( - reserve_space_size_in_bytes)); + ASSIGN_OR_RETURN(*reserve_space, reserve_space_allocator->AllocateBytes( + reserve_space_size_in_bytes)); } return absl::OkStatus(); } @@ -1909,7 +1910,7 @@ static absl::Status PopulateProfileFromTimer( dnn::ProfileResult* profile_result, std::optional scratch_size = std::nullopt) { if (profile_result) { - TF_ASSIGN_OR_RETURN(absl::Duration duration, timer->GetElapsedDuration()); + ASSIGN_OR_RETURN(absl::Duration duration, timer->GetElapsedDuration()); profile_result->set_algorithm(algorithm); profile_result->set_elapsed_time_in_ms( absl::ToDoubleMilliseconds(duration)); @@ -1939,7 +1940,7 @@ absl::Status CudnnSupport::DoRnnForwardImpl( ScratchAllocator* reserve_space_allocator, ScratchAllocator* workspace_allocator, dnn::ProfileResult* output_profile_result) { - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( RnnModelDims model_dims, ExtractAndCheckRnnForward( rnn_desc, input_desc, input_data, input_h_desc, input_h_data, @@ -1948,19 +1949,18 @@ absl::Status CudnnSupport::DoRnnForwardImpl( auto cudnn = cudnn_->GetHandle(parent_, stream); - TF_RETURN_IF_ERROR(CheckRNNParameterSize(cudnn, rnn_desc, input_desc)); + RETURN_IF_ERROR(CheckRNNParameterSize(cudnn, rnn_desc, input_desc)); DeviceAddress reserve_space; DeviceAddress workspace; - TF_RETURN_IF_ERROR(CreateRnnTempSpace( + RETURN_IF_ERROR(CreateRnnTempSpace( stream, cudnn, rnn_desc, model_dims, input_desc, workspace_allocator, reserve_space_allocator, is_training, &workspace, &reserve_space)); std::unique_ptr timer; if (output_profile_result != nullptr) { - TF_ASSIGN_OR_RETURN(timer, - stream->CreateEventBasedTimer( - output_profile_result->warmup_run_executed())); + ASSIGN_OR_RETURN(timer, stream->CreateEventBasedTimer( + output_profile_result->warmup_run_executed())); } if (input_desc.is_var_seq_lengths()) { @@ -2020,7 +2020,7 @@ absl::Status CudnnSupport::DoRnnForwardImpl( } if (timer != nullptr) { - TF_RETURN_IF_ERROR(PopulateProfileFromTimer( + RETURN_IF_ERROR(PopulateProfileFromTimer( timer.get(), *rnn_desc.algorithm_config().algorithm(), output_profile_result)); } @@ -2054,7 +2054,7 @@ absl::Status CudnnSupport::DoRnnBackwardImpl( DeviceAddress* reserve_space_data, ScratchAllocator* workspace_allocator, dnn::ProfileResult* output_profile_result) { - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( RnnModelDims model_dims, ExtractAndCheckRnnForward(rnn_desc, input_desc, input_data, input_h_desc, input_h_data, input_c_desc, input_c_data, @@ -2063,18 +2063,17 @@ absl::Status CudnnSupport::DoRnnBackwardImpl( auto cudnn = cudnn_->GetHandle(parent_, stream); - TF_RETURN_IF_ERROR(CheckRNNParameterSize(cudnn, rnn_desc, input_desc)); + RETURN_IF_ERROR(CheckRNNParameterSize(cudnn, rnn_desc, input_desc)); DeviceAddress workspace; - TF_RETURN_IF_ERROR(CreateRnnTempSpace(stream, cudnn, rnn_desc, model_dims, - input_desc, workspace_allocator, - nullptr, true, &workspace, nullptr)); + RETURN_IF_ERROR(CreateRnnTempSpace(stream, cudnn, rnn_desc, model_dims, + input_desc, workspace_allocator, nullptr, + true, &workspace, nullptr)); std::unique_ptr timer; if (output_profile_result != nullptr) { - TF_ASSIGN_OR_RETURN(timer, - stream->CreateEventBasedTimer( - output_profile_result->warmup_run_executed())); + ASSIGN_OR_RETURN(timer, stream->CreateEventBasedTimer( + output_profile_result->warmup_run_executed())); } if (input_desc.is_var_seq_lengths()) { @@ -2100,7 +2099,7 @@ absl::Status CudnnSupport::DoRnnBackwardImpl( if (params_backprop_data != nullptr) { // Clear the dw to zeros. - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( stream->MemZero(params_backprop_data, params_backprop_data->size())); RETURN_IF_CUDNN_ERROR(cudnnRNNBackwardWeights_v8( /*handle=*/cudnn.handle(), @@ -2151,7 +2150,7 @@ absl::Status CudnnSupport::DoRnnBackwardImpl( if (params_backprop_data != nullptr) { // Clear the dw to zeros. - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( stream->MemZero(params_backprop_data, params_backprop_data->size())); // make the backward weight call RETURN_IF_CUDNN_ERROR(cudnnRNNBackwardWeights( @@ -2171,7 +2170,7 @@ absl::Status CudnnSupport::DoRnnBackwardImpl( } if (timer != nullptr) { - TF_RETURN_IF_ERROR(PopulateProfileFromTimer( + RETURN_IF_ERROR(PopulateProfileFromTimer( timer.get(), *rnn_desc.algorithm_config().algorithm(), output_profile_result)); } @@ -2223,7 +2222,7 @@ CudnnSupport::CreateRnnDescriptor( // Setting up a cudnnRNNDescriptor requires a cuDNN handle, but because it's // not enqueueing anything into a stream, we pass in the null stream. auto cudnn = cudnn_->GetHandle(parent_, /*stream=*/nullptr); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( CudnnRnnDescriptor rnn_desc, CudnnRnnDescriptor::Create( cudnn, num_layers, hidden_size, input_size, cell_size, batch_size, @@ -2240,10 +2239,10 @@ absl::StatusOr> CudnnSupport::CreateRnnSequenceTensorDescriptor(int max_seq_length, int batch_size, int data_size, dnn::DataType data_type) { - TF_ASSIGN_OR_RETURN(CudnnRnnSequenceTensorDescriptor descriptor, - CudnnRnnSequenceTensorDescriptor::Create( - parent_, max_seq_length, batch_size, data_size, - ToCudnnDataType(data_type))); + ASSIGN_OR_RETURN(CudnnRnnSequenceTensorDescriptor descriptor, + CudnnRnnSequenceTensorDescriptor::Create( + parent_, max_seq_length, batch_size, data_size, + ToCudnnDataType(data_type))); return std::unique_ptr( new CudnnRnnSequenceTensorDescriptor(std::move(descriptor))); } @@ -2253,10 +2252,10 @@ CudnnSupport::CreateRnnSequenceTensorDescriptor( int max_seq_length, int batch_size, int data_size, const absl::Span& seq_lengths, bool time_major, dnn::DataType data_type) { - TF_ASSIGN_OR_RETURN(CudnnRnnSequenceTensorDescriptor descriptor, - CudnnRnnSequenceTensorDescriptor::Create( - parent_, max_seq_length, batch_size, data_size, - seq_lengths, time_major, ToCudnnDataType(data_type))); + ASSIGN_OR_RETURN(CudnnRnnSequenceTensorDescriptor descriptor, + CudnnRnnSequenceTensorDescriptor::Create( + parent_, max_seq_length, batch_size, data_size, + seq_lengths, time_major, ToCudnnDataType(data_type))); return std::unique_ptr( new CudnnRnnSequenceTensorDescriptor(std::move(descriptor))); } @@ -2852,9 +2851,9 @@ GetCudnnOperationGraph(dnn::ConvolutionKind kind, dnn::DataType input_type, std::vector input_strides = input_descriptor.vectorized_strides( dnn::DataLayout::kBatchDepthYX, vector_size, vector_dim); - TF_ASSIGN_OR_RETURN(auto tensor_x, - CreateCudnnTensor(input_dims, input_strides, 'x', - input_type, vector_size, vector_dim)); + ASSIGN_OR_RETURN(auto tensor_x, + CreateCudnnTensor(input_dims, input_strides, 'x', input_type, + vector_size, vector_dim)); // y tensor. std::tie(vector_size, vector_dim) = @@ -2864,9 +2863,9 @@ GetCudnnOperationGraph(dnn::ConvolutionKind kind, dnn::DataType input_type, std::vector output_strides = output_descriptor.vectorized_strides( dnn::DataLayout::kBatchDepthYX, vector_size, vector_dim); - TF_ASSIGN_OR_RETURN(auto tensor_y, - CreateCudnnTensor(output_dims, output_strides, 'y', - output_type, vector_size, vector_dim)); + ASSIGN_OR_RETURN(auto tensor_y, + CreateCudnnTensor(output_dims, output_strides, 'y', + output_type, vector_size, vector_dim)); // w tensor. std::tie(vector_size, vector_dim) = @@ -2882,7 +2881,7 @@ GetCudnnOperationGraph(dnn::ConvolutionKind kind, dnn::DataType input_type, ? CUDNN_TENSOR_REORDERING_INT8x32 : CUDNN_TENSOR_REORDERING_NONE; - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( auto tensor_w, CreateCudnnTensor(filter_dims, filter_strides, 'w', input_type, vector_size, vector_dim, @@ -3116,8 +3115,8 @@ GetGenericCudnnOperationGraph( } pos = l + 2; - TF_ASSIGN_OR_RETURN(output_type, - PrimitiveTypeStringToDnnType(data_type_string)); + ASSIGN_OR_RETURN(output_type, + PrimitiveTypeStringToDnnType(data_type_string)); TensorKind binary_operand_kind, output_kind; if (op_string == "conv") { if (!op_graph.Empty()) { @@ -3143,16 +3142,16 @@ GetGenericCudnnOperationGraph( "Non-convolution op must have one or more operands in the " "graph."); } - TF_ASSIGN_OR_RETURN(std::tie(binary_operand_kind, output_kind, mode), - OpNameStringToOperandKindAndMode(op_string)); + ASSIGN_OR_RETURN(std::tie(binary_operand_kind, output_kind, mode), + OpNameStringToOperandKindAndMode(op_string)); } - TF_RETURN_IF_ERROR(op_graph.AddOp( - uid, operands, mode, binary_operand_kind, output_kind, output_type)); + RETURN_IF_ERROR(op_graph.AddOp(uid, operands, mode, binary_operand_kind, + output_kind, output_type)); } return op_graph; }; - TF_ASSIGN_OR_RETURN(OpGraph op_graph, deserialize_cudnn_graph()); + ASSIGN_OR_RETURN(OpGraph op_graph, deserialize_cudnn_graph()); if (op_graph.Empty()) { return absl::InternalError("No supported ops in convolution graph."); } @@ -3198,7 +3197,7 @@ GetGenericCudnnOperationGraph( std::vector input_strides = input_descriptor.vectorized_strides( dnn::DataLayout::kBatchDepthYX, vector_size, vector_dim); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( auto tensor_x, CreateCudnnTensor(input_dims, input_strides, next_uid(/*is_operand=*/true, /*is_virtual=*/false), @@ -3217,7 +3216,7 @@ GetGenericCudnnOperationGraph( dnn::FilterLayout::kOutputInputYX32_CudnnReordered ? CUDNN_TENSOR_REORDERING_INT8x32 : CUDNN_TENSOR_REORDERING_NONE; - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( auto tensor_w, CreateCudnnTensor(filter_dims, filter_strides, next_uid(/*is_operand=*/true, /*is_virtual=*/false), @@ -3225,7 +3224,7 @@ GetGenericCudnnOperationGraph( /*is_virtual=*/false, tensor_ordering_type)); // Result tensor. - TF_ASSIGN_OR_RETURN(OpDescriptor op_descriptor, op_graph.OpDescriptorAt(0)); + ASSIGN_OR_RETURN(OpDescriptor op_descriptor, op_graph.OpDescriptorAt(0)); std::tie(vector_size, vector_dim) = GetTensorVectorSizeAndDim(output_descriptor, op_descriptor.result_type); std::vector output_dims = output_descriptor.vectorized_dims( @@ -3233,7 +3232,7 @@ GetGenericCudnnOperationGraph( std::vector output_strides = output_descriptor.vectorized_strides( dnn::DataLayout::kBatchDepthYX, vector_size, vector_dim); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( auto tensor_y, CreateCudnnTensor(output_dims, output_strides, next_uid(/*is_operand=*/false, @@ -3279,11 +3278,10 @@ GetGenericCudnnOperationGraph( // Add the convolution to the cuDNN graph. ops.push_back(std::move(op)); result_tensors.push_back(std::move(tensor_y)); - TF_RETURN_IF_ERROR( - op_graph.SetSequenceIndex(op_descriptor.uid, ops.size() - 1)); + RETURN_IF_ERROR(op_graph.SetSequenceIndex(op_descriptor.uid, ops.size() - 1)); for (int op_index = 1; op_index < op_graph.Size(); ++op_index) { - TF_ASSIGN_OR_RETURN(op_descriptor, op_graph.OpDescriptorAt(op_index)); + ASSIGN_OR_RETURN(op_descriptor, op_graph.OpDescriptorAt(op_index)); std::vector preceding_ops; preceding_ops.reserve(op_descriptor.operand_uids.size()); for (int64_t operand_uid : op_descriptor.operand_uids) { @@ -3297,7 +3295,7 @@ GetGenericCudnnOperationGraph( if (op_descriptor.operand_kind == TensorKind::kScalar && preceding_ops.size() == 1) { std::vector scale_dim(4, 1); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( external_operand, CreateCudnnTensor(scale_dim, scale_dim, next_uid(/*is_operand=*/true, /*is_virtual=*/false), @@ -3305,7 +3303,7 @@ GetGenericCudnnOperationGraph( VLOG(4) << "\nPointwise operand: " << external_operand->describe(); } else if (op_descriptor.operand_kind == TensorKind::kTensor && preceding_ops.size() == 1) { - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( external_operand, CreateCudnnTensor(tensor_y, next_uid(/*is_operand=*/true, /*is_virtual=*/false), @@ -3317,15 +3315,15 @@ GetGenericCudnnOperationGraph( // Create the result tensor of the op. if (op_descriptor.result_kind == TensorKind::kScalar) { std::vector scale_dim(4, 1); - TF_ASSIGN_OR_RETURN(cudnn_frontend::Tensor result, - CreateCudnnTensor(scale_dim, scale_dim, - next_uid(/*is_operand=*/false, - /*is_virtual=*/false), - op_descriptor.result_type, 1, -1)); + ASSIGN_OR_RETURN(cudnn_frontend::Tensor result, + CreateCudnnTensor(scale_dim, scale_dim, + next_uid(/*is_operand=*/false, + /*is_virtual=*/false), + op_descriptor.result_type, 1, -1)); VLOG(4) << "\nScalar result: " << result.describe(); result_tensors.push_back(std::move(result)); } else if (op_descriptor.result_kind == TensorKind::kTensor) { - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( cudnn_frontend::Tensor result, CreateCudnnTensor(tensor_y, next_uid(/*is_operand=*/false, @@ -3392,7 +3390,7 @@ GetGenericCudnnOperationGraph( .setreductionDesc(desc) .build()); } - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( op_graph.SetSequenceIndex(op_descriptor.uid, ops.size() - 1)); } @@ -3460,9 +3458,9 @@ GetCudnnFusedOperationGraph( std::vector input_strides = input_descriptor.vectorized_strides( dnn::DataLayout::kBatchDepthYX, vector_size, vector_dim); - TF_ASSIGN_OR_RETURN(auto tensor_x, - CreateCudnnTensor(input_dims, input_strides, 'x', - input_type, vector_size, vector_dim)); + ASSIGN_OR_RETURN(auto tensor_x, + CreateCudnnTensor(input_dims, input_strides, 'x', input_type, + vector_size, vector_dim)); std::tie(vector_size, vector_dim) = GetTensorVectorSizeAndDim(output_descriptor, output_type); @@ -3470,13 +3468,13 @@ GetCudnnFusedOperationGraph( dnn::DataLayout::kBatchDepthYX, vector_size, vector_dim); std::vector output_strides = output_descriptor.vectorized_strides( dnn::DataLayout::kBatchDepthYX, vector_size, vector_dim); - TF_ASSIGN_OR_RETURN(auto tensor_y, - CreateCudnnTensor(output_dims, output_strides, 'y', - output_type, vector_size, vector_dim)); + ASSIGN_OR_RETURN(auto tensor_y, + CreateCudnnTensor(output_dims, output_strides, 'y', + output_type, vector_size, vector_dim)); - TF_ASSIGN_OR_RETURN(auto tensor_z, - CreateCudnnTensor(output_dims, output_strides, 'z', - output_type, vector_size, vector_dim)); + ASSIGN_OR_RETURN(auto tensor_z, + CreateCudnnTensor(output_dims, output_strides, 'z', + output_type, vector_size, vector_dim)); std::tie(vector_size, vector_dim) = GetTensorVectorSizeAndDim(filter_descriptor, input_type); @@ -3491,7 +3489,7 @@ GetCudnnFusedOperationGraph( ? CUDNN_TENSOR_REORDERING_INT8x32 : CUDNN_TENSOR_REORDERING_NONE; - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( auto tensor_w, CreateCudnnTensor(filter_dims, filter_strides, 'w', input_type, vector_size, vector_dim, @@ -3537,21 +3535,21 @@ GetCudnnFusedOperationGraph( maybe_tensor_b = CreateCudnnTensor(bias_dims, bias_strides, 'b', bias_type, vector_size, vector_dim); } - TF_ASSIGN_OR_RETURN(auto tensor_b, std::move(maybe_tensor_b)); + ASSIGN_OR_RETURN(auto tensor_b, std::move(maybe_tensor_b)); std::tie(vector_size, vector_dim) = GetTensorVectorSizeAndDim(output_descriptor, output_type); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( auto tensor_conv, CreateCudnnTensor(output_dims, output_strides, 'C', accumulator_type, vector_size, vector_dim, /*is_virtual=*/true)); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( auto tensor_add, CreateCudnnTensor(output_dims, output_strides, 'A', activation_type, vector_size, vector_dim, /*is_virtual=*/true)); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( auto tensor_bias, CreateCudnnTensor(output_dims, output_strides, 'B', activation_type, vector_size, vector_dim, /*is_virtual=*/true)); @@ -3743,36 +3741,36 @@ GetCudnnFusedMatmulGraph(dnn::DataType input_type, dnn::DataType bias_type, int64_t stride1 = trans_a ? 1 : lda; int64_t stride2 = trans_a ? lda : 1; std::vector a_strides = {m * k, stride1, stride2}; - TF_ASSIGN_OR_RETURN(auto tensor_a, - CreateCudnnTensor(a_dims, a_strides, 'a', input_type, - vector_size, vector_dim)); + ASSIGN_OR_RETURN(auto tensor_a, + CreateCudnnTensor(a_dims, a_strides, 'a', input_type, + vector_size, vector_dim)); std::vector b_dims = {1, k, n}; stride1 = trans_b ? 1 : ldb; stride2 = trans_b ? ldb : 1; std::vector b_strides = {k * n, stride1, stride2}; - TF_ASSIGN_OR_RETURN(auto tensor_b, - CreateCudnnTensor(b_dims, b_strides, 'b', input_type, - vector_size, vector_dim)); + ASSIGN_OR_RETURN(auto tensor_b, + CreateCudnnTensor(b_dims, b_strides, 'b', input_type, + vector_size, vector_dim)); std::vector c_dims = {1, m, n}; std::vector c_strides = {m * n, ldc, 1}; - TF_ASSIGN_OR_RETURN(auto tensor_c, - CreateCudnnTensor(c_dims, c_strides, 'c', output_type, - vector_size, vector_dim)); + ASSIGN_OR_RETURN(auto tensor_c, + CreateCudnnTensor(c_dims, c_strides, 'c', output_type, + vector_size, vector_dim)); std::vector z_dims = {1, 1, n}; std::vector z_strides = {n, n, 1}; - TF_ASSIGN_OR_RETURN(auto tensor_z, - CreateCudnnTensor(z_dims, z_strides, 'z', bias_type, - vector_size, vector_dim)); + ASSIGN_OR_RETURN(auto tensor_z, + CreateCudnnTensor(z_dims, z_strides, 'z', bias_type, + vector_size, vector_dim)); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( auto tensor_matmul, CreateCudnnTensor(c_dims, c_strides, 'M', accumulator_type, vector_size, vector_dim, /*is_virtual=*/true)); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( auto tensor_bias, CreateCudnnTensor(c_dims, c_strides, 'B', activation_type, vector_size, vector_dim, /*is_virtual=*/true)); @@ -4158,7 +4156,7 @@ absl::StatusOr GetCudnnFlashAttentionOperationGraph( } if (score_mod) { - TF_RETURN_IF_ERROR(score_mod->UpdateCudnnMap(graph, next_uid)); + RETURN_IF_ERROR(score_mod->UpdateCudnnMap(graph, next_uid)); sdpa_options.set_score_mod([=](Graph graph, Tensor score) -> Tensor { return score_mod->Forward(graph, score); }); @@ -4200,11 +4198,11 @@ absl::StatusOr GetCudnnFlashAttentionOperationGraph( offset_tensor->set_uid(next_uid()); } CudnnGraph cudnnGraph(std::move(graph)); - TF_RETURN_IF_ERROR(cudnnGraph.Prepare( + RETURN_IF_ERROR(cudnnGraph.Prepare( dnn_support, EngineOptions{/*require_determinism=*/false, /*allow_tf32=*/true, /*require_command_buffer=*/false})); - TF_RETURN_IF_ERROR(cudnnGraph.Build(dnn_support, /*plan_id=*/std::nullopt)); + RETURN_IF_ERROR(cudnnGraph.Build(dnn_support, /*plan_id=*/std::nullopt)); VLOG(4) << "\b flash attention operation graph: " << cudnnGraph.Graph(); return cudnnGraph; @@ -4333,11 +4331,11 @@ absl::StatusOr GetCudnnFlashAttentionF8OperationGraph( .set_uid(next_uid()); } CudnnGraph cudnnGraph(std::move(graph)); - TF_RETURN_IF_ERROR(cudnnGraph.Prepare( + RETURN_IF_ERROR(cudnnGraph.Prepare( dnn_support, EngineOptions{/*require_determinism=*/false, /*allow_tf32=*/true, /*require_command_buffer=*/false})); - TF_RETURN_IF_ERROR(cudnnGraph.Build(dnn_support, /*plan_id=*/std::nullopt)); + RETURN_IF_ERROR(cudnnGraph.Build(dnn_support, /*plan_id=*/std::nullopt)); VLOG(4) << "\b workspace size:" << cudnnGraph.Graph().get_workspace_size(); VLOG(4) << "\b flash attention operation graph: " << cudnnGraph.Graph(); @@ -4523,11 +4521,11 @@ absl::StatusOr GetCudnnFlashAttentionBackwardF8OperationGraph( .set_uid(next_uid()); CudnnGraph cudnnGraph(std::move(graph)); - TF_RETURN_IF_ERROR(cudnnGraph.Prepare( + RETURN_IF_ERROR(cudnnGraph.Prepare( dnn_support, EngineOptions{/*require_determinism=*/false, /*allow_tf32=*/true, /*require_command_buffer=*/false})); - TF_RETURN_IF_ERROR(cudnnGraph.Build(dnn_support, /*plan_id=*/std::nullopt)); + RETURN_IF_ERROR(cudnnGraph.Build(dnn_support, /*plan_id=*/std::nullopt)); VLOG(4) << "\b workspace size:" << cudnnGraph.Graph().get_workspace_size(); VLOG(4) << "\b flash attention f8 operation backward graph: " @@ -4568,8 +4566,8 @@ absl::StatusOr GetCudnnBlockScaledDotOperationGraph( auto next_uid = [uid = 0]() mutable -> int { return CuDnnTensorUID(uid++); }; auto get_tensor_attr = [&](const dnn::TensorDescriptor& desc, bool is_rhs) -> absl::StatusOr { - TF_ASSIGN_OR_RETURN(std::vector dimensions, - desc.GetPhysicalDimensionsMajorToMinor()); + ASSIGN_OR_RETURN(std::vector dimensions, + desc.GetPhysicalDimensionsMajorToMinor()); std::vector strides = desc.GetPhysicalStridesMajorToMinor(); if (dimensions.size() == 2) { dimensions.insert(dimensions.begin(), 1); // Batch dimension is implicit. @@ -4586,10 +4584,10 @@ absl::StatusOr GetCudnnBlockScaledDotOperationGraph( .set_stride(strides) .set_data_type(ToCudnnFrontendDataType(desc.type())); }; - TF_ASSIGN_OR_RETURN(auto a_data_attr, get_tensor_attr(lhs_data, false)); - TF_ASSIGN_OR_RETURN(auto b_data_attr, get_tensor_attr(rhs_data, true)); - TF_ASSIGN_OR_RETURN(auto a_scale_attr, get_tensor_attr(lhs_scale, false)); - TF_ASSIGN_OR_RETURN(auto b_scale_attr, get_tensor_attr(rhs_scale, true)); + ASSIGN_OR_RETURN(auto a_data_attr, get_tensor_attr(lhs_data, false)); + ASSIGN_OR_RETURN(auto b_data_attr, get_tensor_attr(rhs_data, true)); + ASSIGN_OR_RETURN(auto a_scale_attr, get_tensor_attr(lhs_scale, false)); + ASSIGN_OR_RETURN(auto b_scale_attr, get_tensor_attr(rhs_scale, true)); a_scale_attr.set_reordering_type( cudnn_frontend::TensorReordering_t::F8_128x4); @@ -4630,11 +4628,11 @@ absl::StatusOr GetCudnnBlockScaledDotOperationGraph( } CudnnGraph cudnnGraph(std::move(graph)); - TF_RETURN_IF_ERROR(cudnnGraph.Prepare( + RETURN_IF_ERROR(cudnnGraph.Prepare( dnn_support, EngineOptions{/*require_determinism=*/false, /*allow_tf32=*/true, /*require_command_buffer=*/false})); - TF_RETURN_IF_ERROR(cudnnGraph.Build(dnn_support, /*plan_id=*/std::nullopt)); + RETURN_IF_ERROR(cudnnGraph.Build(dnn_support, /*plan_id=*/std::nullopt)); VLOG(4) << "\b workspace size:" << cudnnGraph.Graph().get_workspace_size(); VLOG(4) << "\b block scaled dot graph: " << cudnnGraph.Graph(); @@ -4877,7 +4875,7 @@ absl::StatusOr GetCudnnFlashAttentionBackwardOperationGraph( } if (score_mod) { - TF_RETURN_IF_ERROR(score_mod->UpdateCudnnMap(graph, next_uid)); + RETURN_IF_ERROR(score_mod->UpdateCudnnMap(graph, next_uid)); sdpa_backward_options.set_score_mod( [=](Graph graph, Tensor score) -> Tensor { return score_mod->Forward(graph, score); @@ -4924,11 +4922,11 @@ absl::StatusOr GetCudnnFlashAttentionBackwardOperationGraph( } CudnnGraph cudnnGraph(std::move(graph)); - TF_RETURN_IF_ERROR(cudnnGraph.Prepare( + RETURN_IF_ERROR(cudnnGraph.Prepare( dnn_support, EngineOptions{force_deterministic, /*allow_tf32=*/true, /*require_command_buffer=*/false})); - TF_RETURN_IF_ERROR(cudnnGraph.Build(dnn_support, /*plan_id=*/std::nullopt)); + RETURN_IF_ERROR(cudnnGraph.Build(dnn_support, /*plan_id=*/std::nullopt)); VLOG(4) << "\b flash attention operation backward graph: " << cudnnGraph.Graph(); @@ -5031,7 +5029,7 @@ absl::StatusOr> GetDescriptorAttribute( std::vector result(n); for (int i = 0; i < n; ++i) { - TF_ASSIGN_OR_RETURN(result[i], CreateBackendDesc(type)); + ASSIGN_OR_RETURN(result[i], CreateBackendDesc(type)); } std::vector raw_ptrs; @@ -5054,7 +5052,7 @@ absl::StatusOr> GetDescriptorAttribute( // them in the form of an AlgorithmDesc for use with RebuildExecutionPlan. absl::StatusOr ExecutionPlanToAlgorithmDesc( const cudnn_frontend::ExecutionPlan& plan, size_t workspace_size) { - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( auto engine_cfgs, GetDescriptorAttribute(plan.get_raw_desc(), CUDNN_ATTR_EXECUTION_PLAN_ENGINE_CONFIG, @@ -5064,7 +5062,7 @@ absl::StatusOr ExecutionPlanToAlgorithmDesc( "CUDNN_ATTR_EXECUTION_PLAN_ENGINE_CONFIG had more than one element."); } - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( auto engines, GetDescriptorAttribute(engine_cfgs[0].get(), CUDNN_ATTR_ENGINECFG_ENGINE, CUDNN_BACKEND_ENGINE_DESCRIPTOR)); @@ -5087,8 +5085,8 @@ absl::StatusOr ExecutionPlanToAlgorithmDesc( // were filled. std::vector knobs(CUDNN_KNOB_TYPE_COUNTS); for (int i = 0; i < knobs.size(); ++i) { - TF_ASSIGN_OR_RETURN( - knobs[i], CreateBackendDesc(CUDNN_BACKEND_KNOB_CHOICE_DESCRIPTOR)); + ASSIGN_OR_RETURN(knobs[i], + CreateBackendDesc(CUDNN_BACKEND_KNOB_CHOICE_DESCRIPTOR)); } std::vector raw_knob_ptrs; raw_knob_ptrs.reserve(knobs.size()); @@ -5219,8 +5217,8 @@ class CudnnExecutionPlanRunner std::unique_ptr timer; if (profile_result != nullptr) { - TF_ASSIGN_OR_RETURN(timer, stream->CreateEventBasedTimer( - profile_result->warmup_run_executed())); + ASSIGN_OR_RETURN(timer, stream->CreateEventBasedTimer( + profile_result->warmup_run_executed())); } cudnnStatus_t status = cudnnBackendExecute( @@ -5228,8 +5226,8 @@ class CudnnExecutionPlanRunner RETURN_IF_CUDNN_ERROR(status); if (timer != nullptr) { - TF_ASSIGN_OR_RETURN(auto desc, ToAlgorithmDesc()); - TF_RETURN_IF_ERROR(PopulateProfileFromTimer( + ASSIGN_OR_RETURN(auto desc, ToAlgorithmDesc()); + RETURN_IF_ERROR(PopulateProfileFromTimer( timer.get(), desc, profile_result, scratch_memory.size())); VLOG(4) << "cudnn op with plan " << plan_.getTag() @@ -5427,7 +5425,7 @@ absl::Status CudnnSupport::GetConvolveRunners( const EngineOptions& engine_options, std::vector>* out_exec_plans) { auto cudnn = cudnn_->GetHandle(parent_, stream); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( auto op_graph, GetCudnnOperationGraph(kind, input_type, output_type, input_descriptor, filter_descriptor, output_descriptor, @@ -5450,7 +5448,7 @@ absl::Status CudnnSupport::GetGraphConvolveRunners( std::vector>* out_exec_plans, std::string serialized_graph) { auto cudnn = cudnn_->GetHandle(parent_, stream); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( auto op_graph_and_uids, GetGenericCudnnOperationGraph( kind, input_type, input_descriptor, filter_descriptor, @@ -5471,16 +5469,16 @@ CudnnSupport::ConvolveRunnerFromDesc( const dnn::ConvolutionDescriptor& convolution_descriptor) { auto cudnn = cudnn_->GetHandle(parent_, stream); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( auto op_graph, GetCudnnOperationGraph(kind, input_type, output_type, input_descriptor, filter_descriptor, output_descriptor, convolution_descriptor, cudnn)); - TF_ASSIGN_OR_RETURN(auto execution_plan, - RebuildExecutionPlan(cudnn, algorithm_desc, *op_graph)); + ASSIGN_OR_RETURN(auto execution_plan, + RebuildExecutionPlan(cudnn, algorithm_desc, *op_graph)); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( auto runner, CudnnExecutionPlanRunner::Create( parent_, cudnn_.get(), std::move(execution_plan), {'x', 'w', 'y'}, @@ -5500,21 +5498,21 @@ CudnnSupport::GraphConvolveRunnerFromDesc( std::string serialized_graph) { auto cudnn = cudnn_->GetHandle(parent_, stream); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( auto op_graph_and_uids, GetGenericCudnnOperationGraph( kind, input_type, input_descriptor, filter_descriptor, output_descriptor, convolution_descriptor, cudnn, serialized_graph)); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( auto execution_plan, RebuildExecutionPlan(cudnn, algorithm_desc, *op_graph_and_uids.first)); - TF_ASSIGN_OR_RETURN(auto runner, - CudnnExecutionPlanRunner::Create( - parent_, cudnn_.get(), std::move(execution_plan), - op_graph_and_uids.second, - /*need_side_input=*/false)); + ASSIGN_OR_RETURN(auto runner, + CudnnExecutionPlanRunner::Create( + parent_, cudnn_.get(), std::move(execution_plan), + op_graph_and_uids.second, + /*need_side_input=*/false)); return {std::make_unique>( std::move(runner))}; } @@ -5533,22 +5531,22 @@ CudnnSupport::FusedConvolveRunnerFromDesc( dnn::ActivationMode activation_mode) { auto cudnn = cudnn_->GetHandle(parent_, stream); - TF_ASSIGN_OR_RETURN(auto op_graph, - GetCudnnFusedOperationGraph( - kind, input_type, bias_type, output_type, conv_scale, - side_input_scale, leakyrelu_alpha, input_descriptor, - filter_descriptor, bias_descriptor, output_descriptor, - convolution_descriptor, activation_mode, cudnn)); + ASSIGN_OR_RETURN(auto op_graph, + GetCudnnFusedOperationGraph( + kind, input_type, bias_type, output_type, conv_scale, + side_input_scale, leakyrelu_alpha, input_descriptor, + filter_descriptor, bias_descriptor, output_descriptor, + convolution_descriptor, activation_mode, cudnn)); - TF_ASSIGN_OR_RETURN(auto execution_plan, - RebuildExecutionPlan(cudnn, algorithm_desc, *op_graph)); + ASSIGN_OR_RETURN(auto execution_plan, + RebuildExecutionPlan(cudnn, algorithm_desc, *op_graph)); bool need_side_input = SideInputNeeded(activation_mode, conv_scale, side_input_scale); - TF_ASSIGN_OR_RETURN(auto runner, - CudnnExecutionPlanRunner::Create( - parent_, cudnn_.get(), std::move(execution_plan), - {'x', 'w', 'z', 'b', 'y'}, need_side_input)); + ASSIGN_OR_RETURN(auto runner, + CudnnExecutionPlanRunner::Create( + parent_, cudnn_.get(), std::move(execution_plan), + {'x', 'w', 'z', 'b', 'y'}, need_side_input)); return {std::make_unique>( std::move(runner))}; } @@ -5705,40 +5703,39 @@ CudnnSupport::NormRunnerFromDesc( next_uid(), tensor_descriptor.type(), 1, -1); }; - TF_ASSIGN_OR_RETURN(auto x_tensor, create_cudnn_tensor(x_descriptor)); - TF_ASSIGN_OR_RETURN(auto scale_tensor, create_cudnn_tensor(scale_descriptor)); - TF_ASSIGN_OR_RETURN(auto y_or_dx_tensor, - create_cudnn_tensor(y_or_dx_descriptor)); + ASSIGN_OR_RETURN(auto x_tensor, create_cudnn_tensor(x_descriptor)); + ASSIGN_OR_RETURN(auto scale_tensor, create_cudnn_tensor(scale_descriptor)); + ASSIGN_OR_RETURN(auto y_or_dx_tensor, + create_cudnn_tensor(y_or_dx_descriptor)); std::optional bias_tensor, expectation_tensor, norm_factor_tensor, dy_tensor, dscale_tensor, dbias_tensor; if (kind == dnn::NormKind::LAYER_FWD_INFER || kind == dnn::NormKind::LAYER_FWD_TRAIN) { - TF_ASSIGN_OR_RETURN(bias_tensor, - create_cudnn_tensor(bias_descriptor.value())); + ASSIGN_OR_RETURN(bias_tensor, create_cudnn_tensor(bias_descriptor.value())); } if (kind == dnn::LAYER_FWD_TRAIN) { - TF_ASSIGN_OR_RETURN(expectation_tensor, - create_cudnn_tensor(expectation_descriptor.value())); - TF_ASSIGN_OR_RETURN(norm_factor_tensor, - create_cudnn_tensor(norm_factor_descriptor.value())); + ASSIGN_OR_RETURN(expectation_tensor, + create_cudnn_tensor(expectation_descriptor.value())); + ASSIGN_OR_RETURN(norm_factor_tensor, + create_cudnn_tensor(norm_factor_descriptor.value())); } if (kind == dnn::LAYER_BWD) { - TF_ASSIGN_OR_RETURN(dy_tensor, create_cudnn_tensor(dy_descriptor.value())); - TF_ASSIGN_OR_RETURN(expectation_tensor, - create_cudnn_tensor(expectation_descriptor.value())); - TF_ASSIGN_OR_RETURN(norm_factor_tensor, - create_cudnn_tensor(norm_factor_descriptor.value())); - TF_ASSIGN_OR_RETURN(dscale_tensor, - create_cudnn_tensor(dscale_descriptor.value())); - TF_ASSIGN_OR_RETURN(dbias_tensor, - create_cudnn_tensor(dbias_descriptor.value())); + ASSIGN_OR_RETURN(dy_tensor, create_cudnn_tensor(dy_descriptor.value())); + ASSIGN_OR_RETURN(expectation_tensor, + create_cudnn_tensor(expectation_descriptor.value())); + ASSIGN_OR_RETURN(norm_factor_tensor, + create_cudnn_tensor(norm_factor_descriptor.value())); + ASSIGN_OR_RETURN(dscale_tensor, + create_cudnn_tensor(dscale_descriptor.value())); + ASSIGN_OR_RETURN(dbias_tensor, + create_cudnn_tensor(dbias_descriptor.value())); } std::vector scale_dim(4, 1), scalar_uids; - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( auto epsilon_tensor, CreateCudnnTensor(scale_dim, scale_dim, scalar_uids.emplace_back(uids.back() + 1), @@ -5798,14 +5795,14 @@ CudnnSupport::NormRunnerFromDesc( .setOperationGraph(ops.size(), ops.data()) .build(); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( auto execution_plan, GetExecPlanFromHeuristics(std::move(op_graph), cudnn, /*include_fallback_heuristics=*/true)); std::vector scalar_input_values = { ScalingParam(epsilon, dnn::DataType::kDouble)}; - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( auto runner, CudnnExecutionPlanRunner::Create( parent_, cudnn_.get(), std::move(execution_plan), uids, @@ -6028,7 +6025,7 @@ absl::Status CudnnSupport::DoBatchNormalizationForwardImpl( activation_mode, CUDNN_PROPAGATE_NAN, x_desc.value_max()); if (reserve_space_allocator != nullptr && workspace_allocator != nullptr) { - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( workspace, CreateBatchNormForwardWorkspace( stream, cudnn, mode, bn_ops, activation_desc.handle(), x_descriptor, @@ -6041,8 +6038,8 @@ absl::Status CudnnSupport::DoBatchNormalizationForwardImpl( /*activationDesc=*/activation_desc.handle(), /*xDesc=*/x_descriptor.handle(), /*sizeInBytes=*/&reserve_space_size_in_bytes)); - TF_ASSIGN_OR_RETURN(reserve_space, reserve_space_allocator->AllocateBytes( - reserve_space_size_in_bytes)); + ASSIGN_OR_RETURN(reserve_space, reserve_space_allocator->AllocateBytes( + reserve_space_size_in_bytes)); } } @@ -6066,8 +6063,8 @@ absl::Status CudnnSupport::DoBatchNormalizationForwardImpl( void* batch_var_opaque; if (!batch_mean->is_null() && !batch_var->is_null()) { if (exponential_average_factor == 1.0) { - TF_RETURN_IF_ERROR(stream->MemZero(batch_mean, batch_mean->size())); - TF_RETURN_IF_ERROR(stream->MemZero(batch_var, batch_var->size())); + RETURN_IF_ERROR(stream->MemZero(batch_mean, batch_mean->size())); + RETURN_IF_ERROR(stream->MemZero(batch_var, batch_var->size())); } batch_mean_opaque = batch_mean->opaque(); batch_var_opaque = batch_var->opaque(); @@ -6107,7 +6104,7 @@ absl::Status CudnnSupport::DoBatchNormalizationForwardImpl( /*reserveSpaceSizeInBytes=*/reserve_space.size())); } if (!called) { - TF_RETURN_IF_ERROR(check_no_side_input_or_activation()); + RETURN_IF_ERROR(check_no_side_input_or_activation()); RETURN_IF_CUDNN_ERROR(cudnnBatchNormalizationForwardTraining( cudnn.handle(), mode, &one, &zero, x_descriptor.handle(), x.opaque(), x_descriptor.handle(), y->opaque(), scale_offset_descriptor.handle(), @@ -6117,7 +6114,7 @@ absl::Status CudnnSupport::DoBatchNormalizationForwardImpl( } } else { const void* maybe_inv_var = estimated_variance.opaque(); - TF_RETURN_IF_ERROR(check_no_side_input_or_activation()); + RETURN_IF_ERROR(check_no_side_input_or_activation()); RETURN_IF_CUDNN_ERROR(cudnnBatchNormalizationForwardInference( cudnn.handle(), mode, &one, &zero, x_descriptor.handle(), x.opaque(), x_descriptor.handle(), y->opaque(), scale_offset_descriptor.handle(), @@ -6235,7 +6232,7 @@ absl::Status CudnnSupport::DoBatchNormalizationBackwardImpl( CudnnActivationDescriptor activation_desc( activation_mode, CUDNN_PROPAGATE_NAN, x_desc.value_max()); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( DeviceAddress workspace, CreateBatchNormBackwardWorkspace( stream, cudnn, mode, bn_ops, activation_desc.handle(), x_descriptor, @@ -6802,8 +6799,8 @@ absl::StatusOr> CudnnSupport::DeserializeGraph( absl::Status CudnnGraph::Prepare(dnn::DnnSupport& dnn_support, const EngineOptions& engine_options) { const CudnnSupport& cudnn_support = static_cast(dnn_support); - TF_ASSIGN_OR_RETURN(auto cudnn_handle, - cudnn_support.cudnn_->GetCompilationHandle()); + ASSIGN_OR_RETURN(auto cudnn_handle, + cudnn_support.cudnn_->GetCompilationHandle()); RETURN_IF_CUDNN_FRONTEND_ERROR(graph_.validate()); RETURN_IF_CUDNN_FRONTEND_ERROR(graph_.build_operation_graph(cudnn_handle)); RETURN_IF_CUDNN_FRONTEND_ERROR( @@ -6822,8 +6819,8 @@ absl::Status CudnnGraph::Prepare(dnn::DnnSupport& dnn_support, absl::Status CudnnGraph::Build(dnn::DnnSupport& dnn_support, const std::optional plan_id) { const CudnnSupport& cudnn_support = static_cast(dnn_support); - TF_ASSIGN_OR_RETURN(auto cudnn_handle, - cudnn_support.cudnn_->GetCompilationHandle()); + ASSIGN_OR_RETURN(auto cudnn_handle, + cudnn_support.cudnn_->GetCompilationHandle()); if (plan_id.has_value()) { RETURN_CUDNN_FRONTEND_STATUS( graph_.build_plan_at_index(cudnn_handle, *plan_id)); diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_event.cc b/third_party/xla/xla/stream_executor/cuda/cuda_event.cc index f232232db3f053..2a39be3fe8c27e 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_event.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_event.cc @@ -21,6 +21,7 @@ limitations under the License. #include "absl/base/casts.h" #include "absl/log/log.h" #include "absl/status/status.h" +#include "xla/tsl/platform/status_macros.h" #include "third_party/gpus/cuda/include/cuda.h" #include "xla/stream_executor/activate_context.h" #include "xla/stream_executor/cuda/cuda_status.h" @@ -66,7 +67,7 @@ absl::StatusOr InitEvent(StreamExecutor *executor, EventFlags flags) { std::unique_ptr activation = executor->Activate(); CUevent event_handle; - TF_RETURN_IF_ERROR(cuda::ToStatus(cuEventCreate(&event_handle, cuflags))); + RETURN_IF_ERROR(cuda::ToStatus(cuEventCreate(&event_handle, cuflags))); return event_handle; } @@ -95,7 +96,7 @@ absl::Status CudaEvent::Synchronize() { absl::StatusOr CudaEvent::Create(StreamExecutor *executor, bool allow_timing) { - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( CUevent event_handle, InitEvent(executor, allow_timing ? EventFlags::kDefault : EventFlags::kDisableTiming)); diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc b/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc index 1f248a9bb72e92..1d67b90f808366 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc @@ -218,7 +218,7 @@ absl::StatusOr LoadPtx(Context* context, const char* ptx_contents) { }); notification.WaitForNotification(); - TF_RETURN_IF_ERROR(returned_status); + RETURN_IF_ERROR(returned_status); return module; } @@ -227,7 +227,7 @@ absl::StatusOr LoadPtx(Context* context, const char* ptx_contents) { absl::StatusOr LoadCubin(Context* context, const char* cubin_bytes) { ScopedActivateContext activation(context); CUmodule module; - TF_RETURN_IF_ERROR(cuda::ToStatus( + RETURN_IF_ERROR(cuda::ToStatus( cuModuleLoadFatBinary(&module, cubin_bytes), absl::StrCat(xla::XlaFormatDevice(context->device_ordinal()), "Failed to load in-memory CUBIN " @@ -251,7 +251,7 @@ absl::StatusOr GetModuleFunction(Context* context, CUmodule module, cudaGetErrorString(cuda_error))); } CUfunction function; - TF_RETURN_IF_ERROR(cuda::ToStatus( + RETURN_IF_ERROR(cuda::ToStatus( cuModuleGetFunction(&function, module, kernel_name), absl::StrCat(xla::XlaFormatDevice(context->device_ordinal()), "Failed to get module function ", kernel_name))); @@ -287,7 +287,7 @@ void UnloadCudaModule(Context* context, CUmodule module) { absl::StatusOr GetDeviceAttribute(CUdevice_attribute attribute, CUdevice device) { int val; - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( cuda::ToStatus(cuDeviceGetAttribute(&val, attribute, device))); return val; } @@ -295,7 +295,7 @@ absl::StatusOr GetDeviceAttribute(CUdevice_attribute attribute, // Returns the name of the device. absl::StatusOr GetDeviceName(CUdevice device) { std::array chars; - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( cuda::ToStatus(cuDeviceGetName(chars.begin(), chars.size() - 1, device), "Failed to get device name")); chars[chars.size() - 1] = '\0'; @@ -305,11 +305,11 @@ absl::StatusOr GetDeviceName(CUdevice device) { // Returns the compute capability for the device; i.e (3, 5). absl::StatusOr GetComputeCapability(CUdevice device) { int cc_major = 0; - TF_RETURN_IF_ERROR(cuda::ToStatus(cuDeviceGetAttribute( + RETURN_IF_ERROR(cuda::ToStatus(cuDeviceGetAttribute( &cc_major, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, device))); int cc_minor = 0; - TF_RETURN_IF_ERROR(cuda::ToStatus(cuDeviceGetAttribute( + RETURN_IF_ERROR(cuda::ToStatus(cuDeviceGetAttribute( &cc_minor, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR, device))); bool has_accelerated_features = cc_major >= 9; @@ -326,7 +326,7 @@ template static absl::StatusOr GetSimpleAttribute(CUdevice device, CUdevice_attribute attribute) { int value = -1; - TF_RETURN_IF_ERROR(cuda::ToStatus( + RETURN_IF_ERROR(cuda::ToStatus( cuDeviceGetAttribute(&value, attribute, device), absl::StrCat("Could not retrieve CUDA device attribute (", attribute))); T converted = value; @@ -371,17 +371,17 @@ absl::StatusOr GetThreadsPerWarp(CUdevice device) { absl::Status GetGridLimits(int* x, int* y, int* z, CUdevice device) { int value; - TF_RETURN_IF_ERROR(cuda::ToStatus( + RETURN_IF_ERROR(cuda::ToStatus( cuDeviceGetAttribute(&value, CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_X, device), "Could not get device attribute")); *x = value; - TF_RETURN_IF_ERROR(cuda::ToStatus( + RETURN_IF_ERROR(cuda::ToStatus( cuDeviceGetAttribute(&value, CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_Y, device), "Could not get device attribute")); *y = value; - TF_RETURN_IF_ERROR(cuda::ToStatus( + RETURN_IF_ERROR(cuda::ToStatus( cuDeviceGetAttribute(&value, CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_Z, device), "Could not get device attribute")); *z = value; @@ -391,8 +391,8 @@ absl::Status GetGridLimits(int* x, int* y, int* z, CUdevice device) { // Returns the device associated with the given device_ordinal. absl::StatusOr GetDevice(int device_ordinal) { CUdevice device; - TF_RETURN_IF_ERROR(cuda::ToStatus(cuDeviceGet(&device, device_ordinal), - "Failed call to cuDeviceGet")); + RETURN_IF_ERROR(cuda::ToStatus(cuDeviceGet(&device, device_ordinal), + "Failed call to cuDeviceGet")); return device; } @@ -512,7 +512,7 @@ bool HostUnregister(Context* context, void* location) { absl::StatusOr IsVmmSupported(CUdevice device) { int deviceSupportsVmm = 0; - TF_RETURN_IF_ERROR(cuda::ToStatus(cuDeviceGetAttribute( + RETURN_IF_ERROR(cuda::ToStatus(cuDeviceGetAttribute( &deviceSupportsVmm, CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED, device))); return deviceSupportsVmm; @@ -547,7 +547,7 @@ CUmemAccessDesc GetVmmAccessDesc(int device) { absl::StatusOr IsRdmaSupported(CUdevice device) { int rdma_supported = 0; - TF_RETURN_IF_ERROR(cuda::ToStatus(cuDeviceGetAttribute( + RETURN_IF_ERROR(cuda::ToStatus(cuDeviceGetAttribute( &rdma_supported, CU_DEVICE_ATTRIBUTE_GPU_DIRECT_RDMA_WITH_CUDA_VMM_SUPPORTED, device))); return rdma_supported; @@ -555,7 +555,7 @@ absl::StatusOr IsRdmaSupported(CUdevice device) { absl::StatusOr IsMulticastSupported(CUdevice device) { int is_multicast_supported = 0; - TF_RETURN_IF_ERROR(cuda::ToStatus( + RETURN_IF_ERROR(cuda::ToStatus( cuDeviceGetAttribute(&is_multicast_supported, CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED, device))); return is_multicast_supported; @@ -575,7 +575,7 @@ absl::StatusOr IsFabricSupported(CUdevice device) { return false; } - TF_RETURN_IF_ERROR(cuda::ToStatus(result)); + RETURN_IF_ERROR(cuda::ToStatus(result)); return fabric_supported > 0; } @@ -584,8 +584,8 @@ absl::StatusOr IsFabricSupported(CUdevice device) { // (e.g. MIG partitions, containers, or older drivers). Returns VMM allocator // options with alignment set to the queried granularity. absl::StatusOr QueryVmmOptions(CUdevice device) { - TF_ASSIGN_OR_RETURN(bool rdma, IsRdmaSupported(device)); - TF_ASSIGN_OR_RETURN(bool fabric, IsFabricSupported(device)); + ASSIGN_OR_RETURN(bool rdma, IsRdmaSupported(device)); + ASSIGN_OR_RETURN(bool fabric, IsFabricSupported(device)); bool posix_fd = true; size_t granularity = 0; @@ -643,7 +643,7 @@ absl::StatusOr CreateMulticastObjectProperties( multicast_properties.flags = 0; size_t multicast_granularity = 0; - TF_RETURN_IF_ERROR(cuda::ToStatus( + RETURN_IF_ERROR(cuda::ToStatus( cuMulticastGetGranularity(&multicast_granularity, &multicast_properties, CU_MULTICAST_GRANULARITY_RECOMMENDED))); @@ -670,7 +670,7 @@ absl::Status ToStatus(nvmlReturn_t result) { // CUDA and Nvml can have different device ordering. absl::StatusOr GetNvmlDevice(const std::string& pci_bus_id) { nvmlDevice_t device; - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( ToStatus(nvmlDeviceGetHandleByPciBusId_v2(pci_bus_id.c_str(), &device))); return device; } @@ -681,10 +681,10 @@ absl::StatusOr GetDevicePcieBandwidth(nvmlDevice_t nvml_device) { unsigned int link_gen, link_width; nvmlReturn_t result = nvmlDeviceGetCurrPcieLinkGeneration(nvml_device, &link_gen); - TF_RETURN_IF_ERROR(ToStatus(result)); + RETURN_IF_ERROR(ToStatus(result)); result = nvmlDeviceGetCurrPcieLinkWidth(nvml_device, &link_width); - TF_RETURN_IF_ERROR(ToStatus(result)); + RETURN_IF_ERROR(ToStatus(result)); // PCIe v1 single lane speed. 0.25 GB/s int64_t lane_speed = 0.25 * 1024 * 1024 * 1024; @@ -705,7 +705,7 @@ absl::StatusOr GetNumberOfActiveP2PNvlinks(nvmlDevice_t nvml_device) { if (result == NVML_ERROR_NOT_SUPPORTED) { break; } - TF_RETURN_IF_ERROR(ToStatus(result)); + RETURN_IF_ERROR(ToStatus(result)); if (is_active == NVML_FEATURE_DISABLED) { break; } @@ -714,7 +714,7 @@ absl::StatusOr GetNumberOfActiveP2PNvlinks(nvmlDevice_t nvml_device) { result = nvmlDeviceGetNvLinkCapability( nvml_device, i, NVML_NVLINK_CAP_P2P_SUPPORTED, &supported_p2p); if (result != NVML_ERROR_NOT_SUPPORTED) { - TF_RETURN_IF_ERROR(ToStatus(result)); + RETURN_IF_ERROR(ToStatus(result)); } if (supported_p2p) { ++p2p_links; @@ -734,7 +734,7 @@ absl::StatusOr GetDeviceFabricInfo(nvmlDevice_t device) { fabricInfo.state = NVML_GPU_FABRIC_STATE_NOT_SUPPORTED; nvmlReturn_t result = nvmlDeviceGetGpuFabricInfoV(device, &fabricInfo); - TF_RETURN_IF_ERROR(ToStatus(result)); + RETURN_IF_ERROR(ToStatus(result)); if (fabricInfo.state == NVML_GPU_FABRIC_STATE_NOT_SUPPORTED) { std::string error_message = @@ -800,7 +800,7 @@ absl::StatusOr CudaExecutor::GetMemoryRange( const DeviceAddressBase& location) const { CUdeviceptr device_pointer; size_t size; - TF_RETURN_IF_ERROR(cuda::ToStatus( + RETURN_IF_ERROR(cuda::ToStatus( cuMemGetAddressRange(&device_pointer, &size, AsCudaDevicePtr(location)))); return DeviceAddressBase(reinterpret_cast(device_pointer), size); } @@ -819,8 +819,8 @@ static bool IsNvshmemEnabled() { } static absl::StatusOr GetNvshmemCollectives() { - TF_ASSIGN_OR_RETURN(xla::Collectives * collectives, - xla::CollectivesRegistry::Get("gpu", "nvshmem")); + ASSIGN_OR_RETURN(xla::Collectives * collectives, + xla::CollectivesRegistry::Get("gpu", "nvshmem")); auto* gpu_collectives = absl::down_cast(collectives); if (gpu_collectives == nullptr) { @@ -833,7 +833,7 @@ CudaExecutor::VmmMemoryHandle::~VmmMemoryHandle() { CHECK_OK(Release()); } absl::Status CudaExecutor::VmmMemoryHandle::Release() { if (handle_ != 0) { - TF_RETURN_IF_ERROR(cuda::ToStatus( + RETURN_IF_ERROR(cuda::ToStatus( cuMemRelease(static_cast(handle_)))); handle_ = 0; } @@ -859,7 +859,7 @@ CudaExecutor::VmmMemoryHandle& CudaExecutor::VmmMemoryHandle::operator=( absl::StatusOr CudaExecutor::RetainVmmMemoryHandle(void* ptr) const { CUmemGenericAllocationHandle handle; - TF_RETURN_IF_ERROR(cuda::ToStatus(cuMemRetainAllocationHandle(&handle, ptr))); + RETURN_IF_ERROR(cuda::ToStatus(cuMemRetainAllocationHandle(&handle, ptr))); return CudaExecutor::VmmMemoryHandle(static_cast(handle)); } @@ -867,7 +867,7 @@ CudaExecutor::RetainVmmMemoryHandle(void* ptr) const { absl::StatusOr CudaExecutor::GetVmmGranularity() const { CUmemAllocationProp properties = GetVmmAllocationProp(device_, vmm_options_); size_t granularity = 0; - TF_RETURN_IF_ERROR(cuda::ToStatus(cuMemGetAllocationGranularity( + RETURN_IF_ERROR(cuda::ToStatus(cuMemGetAllocationGranularity( &granularity, &properties, CU_MEM_ALLOC_GRANULARITY_RECOMMENDED))); return granularity; } @@ -878,14 +878,14 @@ static absl::StatusOr NvshmemCollectiveMemoryAllocate( return nullptr; } std::unique_ptr activation = executor->Activate(); - TF_ASSIGN_OR_RETURN(auto* collectives, GetNvshmemCollectives()); + ASSIGN_OR_RETURN(auto* collectives, GetNvshmemCollectives()); return collectives->Allocate(bytes); } static absl::Status NvshmemCollectiveMemoryDeallocate(StreamExecutor* executor, void* location) { std::unique_ptr activation = executor->Activate(); - TF_ASSIGN_OR_RETURN(auto* collectives, GetNvshmemCollectives()); + ASSIGN_OR_RETURN(auto* collectives, GetNvshmemCollectives()); return collectives->Deallocate(location); } @@ -900,8 +900,8 @@ CudaExecutor::CreateMemoryAllocator(MemorySpace type) { return std::make_unique( [this](uint64_t size) -> absl::StatusOr> { - TF_ASSIGN_OR_RETURN(void* ptr, - NvshmemCollectiveMemoryAllocate(this, size)); + ASSIGN_OR_RETURN(void* ptr, + NvshmemCollectiveMemoryAllocate(this, size)); return std::make_unique( ptr, size, [this](void* location, uint64_t size) { auto status = @@ -926,9 +926,9 @@ CudaExecutor::CreateMemoryAllocator(MemorySpace type) { } absl::Status CudaExecutor::Init() { - TF_ASSIGN_OR_RETURN(device_, GetDevice(device_ordinal())); + ASSIGN_OR_RETURN(device_, GetDevice(device_ordinal())); - TF_ASSIGN_OR_RETURN(bool is_vmm_supported, IsVmmSupported(device_)); + ASSIGN_OR_RETURN(bool is_vmm_supported, IsVmmSupported(device_)); if (!is_vmm_supported) { return absl::InternalError(absl::StrFormat( "Device %d does not support CUDA Virtual Memory Management (VMM). " @@ -936,11 +936,11 @@ absl::Status CudaExecutor::Init() { device_ordinal())); } - TF_ASSIGN_OR_RETURN(is_multicast_supported_, IsMulticastSupported(device_)); - TF_ASSIGN_OR_RETURN(CudaContext * context, - CudaContext::Create(device_ordinal(), device_)); + ASSIGN_OR_RETURN(is_multicast_supported_, IsMulticastSupported(device_)); + ASSIGN_OR_RETURN(CudaContext * context, + CudaContext::Create(device_ordinal(), device_)); cuda_context_ = context; - TF_ASSIGN_OR_RETURN(delay_kernels_supported_, DelayKernelIsSupported()); + ASSIGN_OR_RETURN(delay_kernels_supported_, DelayKernelIsSupported()); numa_node_ = ReadNumaNode(GetPCIBusID(device_), device_ordinal()) .value_or(tsl::port::kNUMANoAffinity); if (numa_node_ == tsl::port::kNUMANoAffinity) { @@ -948,7 +948,7 @@ absl::Status CudaExecutor::Init() { } int cuda_device_count = 0; - TF_RETURN_IF_ERROR(cuda::ToStatus(cudaGetDeviceCount(&cuda_device_count))); + RETURN_IF_ERROR(cuda::ToStatus(cudaGetDeviceCount(&cuda_device_count))); for (int i = 0; i < cuda_device_count; ++i) { if (i == device_ordinal()) { peer_access_cache_[i] = true; @@ -957,7 +957,7 @@ absl::Status CudaExecutor::Init() { peer_access_cache_[i] = CanEnablePeerAccess(device_, i); } - TF_ASSIGN_OR_RETURN(vmm_options_, QueryVmmOptions(device_)); + ASSIGN_OR_RETURN(vmm_options_, QueryVmmOptions(device_)); vmm_options_.enable_peer_access = absl::c_any_of( peer_access_cache_, [](const auto& p) { return p.second; }); @@ -980,7 +980,7 @@ absl::Status CudaExecutor::Init() { absl::StatusOr CudaExecutor::DelayKernelIsSupported() { // Check the assumption that this device supports unified addressing, // otherwise skip the delay kernel - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( int status, GetDeviceAttribute(CU_DEVICE_ATTRIBUTE_UNIFIED_ADDRESSING, device_)); @@ -995,7 +995,7 @@ absl::StatusOr CudaExecutor::LoadModuleFromCuBin( std::tie(module, module_refcount) = gpu_binary_to_module_[module_handle]; if (module == nullptr) { - TF_ASSIGN_OR_RETURN(module, LoadCubin(cuda_context_, cubin)); + ASSIGN_OR_RETURN(module, LoadCubin(cuda_context_, cubin)); module_refcount = 1; XLA_VLOG_DEVICE(3, device_ordinal()) << "Loaded CUBIN " << static_cast(cubin) << " as module " @@ -1017,7 +1017,7 @@ absl::StatusOr CudaExecutor::LoadModuleFromPtx(const char* ptx) { std::tie(module, module_refcount) = gpu_binary_to_module_[module_handle]; if (module == nullptr) { - TF_ASSIGN_OR_RETURN(module, LoadPtx(cuda_context_, ptx)); + ASSIGN_OR_RETURN(module, LoadPtx(cuda_context_, ptx)); XLA_VLOG_DEVICE(3, device_ordinal()) << "Loaded PTX " << static_cast(ptx) << " as module " << module; @@ -1041,13 +1041,13 @@ absl::StatusOr> CudaExecutor::LoadKernel( absl::MutexLock lock{in_memory_modules_mu_}; const char* cubin = reinterpret_cast( spec.cuda_cubin_in_memory()->cubin_bytes.data()); - TF_ASSIGN_OR_RETURN(ModuleHandle module_handle, LoadModuleFromCuBin(cubin)); + ASSIGN_OR_RETURN(ModuleHandle module_handle, LoadModuleFromCuBin(cubin)); kernel_to_gpu_binary_[cuda_kernel.get()] = module_handle; CUmodule module = gpu_binary_to_module_.at(module_handle).first; XLA_VLOG_DEVICE(2, device_ordinal()) << "getting function " << kernel_name << " from module " << module; - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( CUfunction function, GetModuleFunction(cuda_context_, module, kernel_name.c_str())); cuda_kernel->set_gpu_function(function); @@ -1060,13 +1060,13 @@ absl::StatusOr> CudaExecutor::LoadKernel( } absl::MutexLock lock{in_memory_modules_mu_}; - TF_ASSIGN_OR_RETURN(ModuleHandle module_handle, LoadModuleFromPtx(ptx)); + ASSIGN_OR_RETURN(ModuleHandle module_handle, LoadModuleFromPtx(ptx)); kernel_to_gpu_binary_[cuda_kernel.get()] = module_handle; CUmodule module = gpu_binary_to_module_.at(module_handle).first; XLA_VLOG_DEVICE(2, device_ordinal()) << "getting function " << kernel_name << " from module " << module; - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( CUfunction function, GetModuleFunction(cuda_context_, module, kernel_name.c_str())); cuda_kernel->set_gpu_function(function); @@ -1079,7 +1079,7 @@ absl::StatusOr> CudaExecutor::LoadKernel( << " from symbol pointer: " << symbol; cudaFunction_t func; std::unique_ptr scoped_activation = Activate(); - TF_RETURN_IF_ERROR(cuda::ToStatus( + RETURN_IF_ERROR(cuda::ToStatus( cudaGetFuncBySymbol(&func, symbol), absl::StrFormat("[%d] Failed call to cudaGetFuncBySymbol", device_ordinal()))); @@ -1104,8 +1104,8 @@ absl::StatusOr> CudaExecutor::LoadKernel( // to be a way to reflect on the number of expected arguments w/the CUDA API. cuda_kernel->set_arity(spec.arity()); - TF_ASSIGN_OR_RETURN(KernelMetadata kernel_metadata, - cuda_kernel->GetKernelMetadata()); + ASSIGN_OR_RETURN(KernelMetadata kernel_metadata, + cuda_kernel->GetKernelMetadata()); cuda_kernel->set_metadata(kernel_metadata); if (std::holds_alternative( spec.kernel_args_packing())) { @@ -1134,8 +1134,8 @@ CudaExecutor::CreateEventBasedTimer(Stream* stream, bool use_delay_kernel) { ? CudaTimer::TimerType::kDelayKernel : CudaTimer::TimerType::kEventBased; - TF_ASSIGN_OR_RETURN(CudaTimer timer, - CudaTimer::Create(this, stream, timer_type)); + ASSIGN_OR_RETURN(CudaTimer timer, + CudaTimer::Create(this, stream, timer_type)); return std::make_unique(std::move(timer)); } @@ -1245,7 +1245,7 @@ CudaExecutor::CreateOrShareConstant(Stream* stream, "Failed to allocate %d bytes for new constant", content.size())); } - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( stream->Memcpy(new_constant.get(), content.data(), content.size())); absl::Status status = stream->BlockHostUntilDone(); if (!status.ok()) { @@ -1369,7 +1369,7 @@ absl::Status CudaExecutor::SynchronousMemcpy(DeviceAddressBase* gpu_dst, const void* host_src, uint64_t size) { std::unique_ptr activation = Activate(); - TF_RETURN_IF_ERROR(cuda::ToStatus( + RETURN_IF_ERROR(cuda::ToStatus( cuMemcpyHtoD(AsCudaDevicePtr(gpu_dst), host_src, size), absl::StrFormat("%sfailed to synchronous memcpy from " "host to device: GPU dst: %llx;" @@ -1385,7 +1385,7 @@ absl::Status CudaExecutor::SynchronousMemcpy(void* host_dst, const DeviceAddressBase& gpu_src, uint64_t size) { std::unique_ptr activation = Activate(); - TF_RETURN_IF_ERROR(cuda::ToStatus( + RETURN_IF_ERROR(cuda::ToStatus( cuMemcpyDtoH(host_dst, AsCudaDevicePtr(gpu_src), size), absl::StrFormat("%sfailed to synchronous memcpy from device to host " "host dst: %p; GPU src: %llx; size: %u=0x%x", @@ -1525,7 +1525,7 @@ absl::StatusOr CudaExecutor::GetSymbol( CUmodule gpu_module_handle = it->second.first; CHECK(gpu_module_handle != nullptr); - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( GetModuleSymbol(cuda_context_, gpu_module_handle, symbol_name.c_str(), reinterpret_cast(&mem), &bytes)); return DeviceAddressBase(mem, bytes); @@ -1740,12 +1740,12 @@ CudaExecutor::CreateDeviceDescription(int device_ordinal) { { BlockDim block_dim_limit; - TF_RETURN_IF_ERROR(FillBlockDimLimit(device, &block_dim_limit)); + RETURN_IF_ERROR(FillBlockDimLimit(device, &block_dim_limit)); desc.set_block_dim_limit(block_dim_limit); } { - TF_ASSIGN_OR_RETURN(std::string device_name, GetDeviceName(device)); + ASSIGN_OR_RETURN(std::string device_name, GetDeviceName(device)); desc.set_name(device_name); } @@ -1810,7 +1810,7 @@ absl::StatusOr CudaExecutor::GetPointerMemorySpace( const void* ptr) { CUdeviceptr pointer = reinterpret_cast(const_cast(ptr)); unsigned int is_managed; - TF_RETURN_IF_ERROR(cuda::ToStatus(cuPointerGetAttribute( + RETURN_IF_ERROR(cuda::ToStatus(cuPointerGetAttribute( &is_managed, CU_POINTER_ATTRIBUTE_IS_MANAGED, pointer))); if (is_managed) { @@ -1818,7 +1818,7 @@ absl::StatusOr CudaExecutor::GetPointerMemorySpace( } unsigned int value; - TF_RETURN_IF_ERROR(cuda::ToStatus(cuPointerGetAttribute( + RETURN_IF_ERROR(cuda::ToStatus(cuPointerGetAttribute( &value, CU_POINTER_ATTRIBUTE_MEMORY_TYPE, pointer))); switch (value) { case CU_MEMORYTYPE_DEVICE: @@ -1872,8 +1872,8 @@ absl::StatusOr CudaExecutor::GetCudaKernel( absl::StatusOr CudaExecutor::CreateTensorMap( const TmaDescriptor& tma_desc, void* global_address) { - TF_ASSIGN_OR_RETURN(CUtensorMapDataType data_type, - GetTensorMapDataType(tma_desc.element_size())); + ASSIGN_OR_RETURN(CUtensorMapDataType data_type, + GetTensorMapDataType(tma_desc.element_size())); CUtensorMapSwizzle swizzle = GetTensorMapSwizzle(tma_desc.swizzle()); CUtensorMapL2promotion l2_promotion = GetTensorMapL2Promotion(tma_desc.l2_promotion()); @@ -1912,7 +1912,7 @@ CudaExecutor::CreateMulticastMemory(uint64_t size, int num_devices) const { } auto multicast_memory = std::make_unique(); - TF_RETURN_IF_ERROR(multicast_memory->Initialize(size, num_devices, this)); + RETURN_IF_ERROR(multicast_memory->Initialize(size, num_devices, this)); return multicast_memory; } @@ -1956,15 +1956,15 @@ absl::Status CudaExecutor::CudaMulticastMemory::Initialize( CUmemAllocationProp properties = GetVmmAllocationProp(cuda_executor->device_, cuda_executor->vmm_options_); - TF_RETURN_IF_ERROR(cuda::ToStatus(cuMemGetAllocationGranularity( + RETURN_IF_ERROR(cuda::ToStatus(cuMemGetAllocationGranularity( &granularity_, &properties, CU_MEM_ALLOC_GRANULARITY_RECOMMENDED))); padded_size_ = xla::RoundUpTo(size, granularity_); num_devices_ = num_devices; - TF_ASSIGN_OR_RETURN(CUmulticastObjectProp multicast_properties, - CreateMulticastObjectProperties(num_devices_, size)); + ASSIGN_OR_RETURN(CUmulticastObjectProp multicast_properties, + CreateMulticastObjectProperties(num_devices_, size)); - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( cuda::ToStatus(cuMulticastCreate(&handle_, &multicast_properties))); XLA_VLOG_DEVICE(3, cuda_executor->device_ordinal()) << "Created multicast memory: " << static_cast(handle_) @@ -1985,8 +1985,7 @@ absl::Status CudaExecutor::CudaMulticastMemory::SubscribeDevice( } XLA_VLOG_DEVICE(3, device_number) << "Subscribe to multicast: " << handle_; - TF_RETURN_IF_ERROR( - cuda::ToStatus(cuMulticastAddDevice(handle_, device_number))); + RETURN_IF_ERROR(cuda::ToStatus(cuMulticastAddDevice(handle_, device_number))); subscribed_devices_++; return absl::OkStatus(); } @@ -2012,19 +2011,18 @@ absl::StatusOr CudaExecutor::CudaMulticastMemory::MapMemory( return absl::FailedPreconditionError("All devices should be subscribed."); } - TF_ASSIGN_OR_RETURN(CudaExecutor::VmmMemoryHandle memory_handle, - cuda_executor->RetainVmmMemoryHandle(location.opaque())); + ASSIGN_OR_RETURN(CudaExecutor::VmmMemoryHandle memory_handle, + cuda_executor->RetainVmmMemoryHandle(location.opaque())); CUmemGenericAllocationHandle retained_memory_handle = static_cast(memory_handle.handle()); - TF_ASSIGN_OR_RETURN(auto base_address, - cuda_executor->GetMemoryRange(location)); + ASSIGN_OR_RETURN(auto base_address, cuda_executor->GetMemoryRange(location)); uint64_t offset = reinterpret_cast(location.opaque()) - reinterpret_cast(base_address.opaque()); // Bind the memory to the multicast object. - TF_RETURN_IF_ERROR(cuda::ToStatus( + RETURN_IF_ERROR(cuda::ToStatus( cuMulticastBindMem(handle_, /*mcOffset=*/0, retained_memory_handle, /*memOffset=*/offset, padded_size_, /*flags=*/0))); @@ -2037,14 +2035,14 @@ absl::StatusOr CudaExecutor::CudaMulticastMemory::MapMemory( // Map a virtual address range for the multicast memory. Multicast // memory is used to reduce the data stored in the multicast object. CUdeviceptr multicast_device_ptr; - TF_RETURN_IF_ERROR(cuda::ToStatus(cuMemAddressReserve( + RETURN_IF_ERROR(cuda::ToStatus(cuMemAddressReserve( &multicast_device_ptr, padded_size_, granularity_, 0, 0))); - TF_RETURN_IF_ERROR(cuda::ToStatus( + RETURN_IF_ERROR(cuda::ToStatus( cuMemMap(multicast_device_ptr, padded_size_, 0, handle_, 0))); CUmemAccessDesc accessDesc = GetVmmAccessDesc(cuda_executor->device_); - TF_RETURN_IF_ERROR(cuda::ToStatus( + RETURN_IF_ERROR(cuda::ToStatus( cuMemSetAccess(multicast_device_ptr, padded_size_, &accessDesc, 1))); absl::MutexLock subscription_lock(mapped_devices_mu_); diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_executor_multigpu_test.cc b/third_party/xla/xla/stream_executor/cuda/cuda_executor_multigpu_test.cc index abce4a83702949..a790330984f652 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_executor_multigpu_test.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_executor_multigpu_test.cc @@ -23,6 +23,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/status_matchers.h" #include "absl/status/statusor.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/stream_executor/cuda/cuda_executor.h" #include "xla/stream_executor/cuda/cuda_executor_multigpu_test_kernels.h" #include "xla/stream_executor/device_address.h" @@ -55,7 +56,7 @@ absl::StatusOr AllocateInitializedMemory( std::vector device_memory_vector(num_initialized_elements, value); auto stride_memory = device_memory.GetByteSlice(offset, size); - TF_RETURN_IF_ERROR(executor->SynchronousMemcpy( + RETURN_IF_ERROR(executor->SynchronousMemcpy( &stride_memory, device_memory_vector.data(), size)); return stride_memory; } @@ -66,7 +67,7 @@ absl::Status CheckMemory(CudaExecutor* executor, T expected_value) { size_t num_elements = device_memory.size() / sizeof(T); std::vector device_memory_vector(num_elements, 0); - TF_RETURN_IF_ERROR(executor->SynchronousMemcpy( + RETURN_IF_ERROR(executor->SynchronousMemcpy( device_memory_vector.data(), device_memory, device_memory.size())); for (int i = 0; i < device_memory_vector.size(); ++i) { EXPECT_EQ(device_memory_vector[i], expected_value); diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_executor_multigpu_test_kernels.cu.cc b/third_party/xla/xla/stream_executor/cuda/cuda_executor_multigpu_test_kernels.cu.cc index fc2652cf210d94..b896bdb76e54d1 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_executor_multigpu_test_kernels.cu.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_executor_multigpu_test_kernels.cu.cc @@ -15,6 +15,7 @@ limitations under the License. #include "xla/stream_executor/cuda/cuda_executor_multigpu_test_kernels.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/stream_executor/cuda/cuda_status.h" #include "xla/tsl/platform/errors.h" @@ -38,8 +39,8 @@ __global__ void MulticastReduceKernel(int* input, int* output, size_t size) { } // namespace __host__ absl::Status MulticastReduce(int* input, int* output, size_t size) { - TF_RETURN_IF_ERROR(stream_executor::cuda::ToStatus(cudaSetDevice(0))); - TF_RETURN_IF_ERROR(stream_executor::cuda::ToStatus(cudaDeviceSynchronize())); + RETURN_IF_ERROR(stream_executor::cuda::ToStatus(cudaSetDevice(0))); + RETURN_IF_ERROR(stream_executor::cuda::ToStatus(cudaDeviceSynchronize())); MulticastReduceKernel<<<1, 1, 0>>>(input, output, size); cudaError_t err = cudaGetLastError(); if (err != cudaSuccess) { diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_fft.cc b/third_party/xla/xla/stream_executor/cuda/cuda_fft.cc index 41a44b53e66656..38112ed7eca784 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_fft.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_fft.cc @@ -28,6 +28,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" +#include "xla/tsl/platform/status_macros.h" #include "third_party/gpus/cuda/include/cuda.h" #include "third_party/gpus/cuda/include/cufft.h" #include "xla/stream_executor/activate_context.h" @@ -218,12 +219,12 @@ absl::Status CUDAFftPlan::Initialize( // For either multiple batches or rank higher than 3, use cufft*PlanMany*(). if (scratch_allocator == nullptr) { // Downsize 64b arrays to 32b as there's no 64b version of cufftPlanMany - TF_ASSIGN_OR_RETURN(auto elem_count_32b_, - Downsize64bArray(elem_count_, rank)); - TF_ASSIGN_OR_RETURN(auto input_embed_32b_, - Downsize64bArray(input_embed_, rank)); - TF_ASSIGN_OR_RETURN(auto output_embed_32b_, - Downsize64bArray(output_embed_, rank)); + ASSIGN_OR_RETURN(auto elem_count_32b_, + Downsize64bArray(elem_count_, rank)); + ASSIGN_OR_RETURN(auto input_embed_32b_, + Downsize64bArray(input_embed_, rank)); + ASSIGN_OR_RETURN(auto output_embed_32b_, + Downsize64bArray(output_embed_, rank)); auto ret = cufftPlanMany( &plan_, rank, elem_count_32b_.data(), input_embed ? input_embed_32b_.data() : nullptr, input_stride, diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_kernel.cc b/third_party/xla/xla/stream_executor/cuda/cuda_kernel.cc index aa398e0d666259..c223990388e0bc 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_kernel.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_kernel.cc @@ -75,7 +75,7 @@ absl::StatusOr CudaKernel::GetMaxOccupiedBlocksPerCore( std::unique_ptr activation = executor_->Activate(); int max_blocks; - TF_RETURN_IF_ERROR(cuda::ToStatus( + RETURN_IF_ERROR(cuda::ToStatus( cuOccupancyMaxActiveBlocksPerMultiprocessorWithFlags( &max_blocks, gpu_function_, threads_per_block, dynamic_shared_memory_bytes, CU_OCCUPANCY_DISABLE_CACHING_OVERRIDE), @@ -87,12 +87,12 @@ absl::StatusOr CudaKernel::GetMaxOccupiedBlocksPerCore( absl::StatusOr CudaKernel::GetKernelMetadata() { KernelMetadata kernel_metadata; int value; - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( GetCudaAttribute(CU_FUNC_ATTRIBUTE_NUM_REGS, gpu_function_, &value)); kernel_metadata.set_registers_per_thread(value); - TF_RETURN_IF_ERROR(GetCudaAttribute(CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, - gpu_function_, &value)); + RETURN_IF_ERROR(GetCudaAttribute(CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, + gpu_function_, &value)); kernel_metadata.set_shared_memory_bytes(value); return kernel_metadata; } @@ -111,10 +111,10 @@ absl::Status CudaKernel::UpdateMaxDynamicSharedMemoryBytes( } std::unique_ptr activation = executor_->Activate(); - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( SetCudaAttribute(CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, gpu_function_, shared_memory_bytes)); - TF_RETURN_IF_ERROR(cuda::ToStatus( + RETURN_IF_ERROR(cuda::ToStatus( cuFuncSetCacheConfig(gpu_function_, CU_FUNC_CACHE_PREFER_SHARED))); max_dynamic_shared_memory_bytes_.store(shared_memory_bytes, @@ -132,8 +132,9 @@ absl::Status CudaKernel::Launch(const ThreadDim& thread_dims, CUfunction function = gpu_function(); // Launch kernels with packed arguments. - auto launch = [this, stream, &cluster_dims, &thread_dims, &block_dims, - function](const KernelArgsPackedArrayBase& packed) { + auto launch = + [this, stream, &cluster_dims, &thread_dims, &block_dims, + function](const KernelArgsPackedArrayBase& packed) -> absl::Status { TraceMe trace([] { return TraceMeEncode("CudaKernel::Launch/launch", {}); }, /*level=*/TraceMeLevel::kVerbose); @@ -148,7 +149,7 @@ absl::Status CudaKernel::Launch(const ThreadDim& thread_dims, void** params = const_cast(packed.argument_addresses().data()); - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( UpdateMaxDynamicSharedMemoryBytes(packed.number_of_shared_bytes())); return stream->LaunchKernel(thread_dims, block_dims, cluster_dims, function, diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_memory_reservation.cc b/third_party/xla/xla/stream_executor/cuda/cuda_memory_reservation.cc index 8d50e1f0c5a767..44b1e8ec1bea7b 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_memory_reservation.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_memory_reservation.cc @@ -54,19 +54,19 @@ CudaMemoryReservation::Create(StreamExecutor* executor, uint64_t size) { std::unique_ptr activation = executor->Activate(); CUdevice device; - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( cuda::ToStatus(cuDeviceGet(&device, executor->device_ordinal()))); CUmemAllocationProp props = BuildAllocationProperties(device); size_t granularity = 0; - TF_RETURN_IF_ERROR(cuda::ToStatus(cuMemGetAllocationGranularity( + RETURN_IF_ERROR(cuda::ToStatus(cuMemGetAllocationGranularity( &granularity, &props, CU_MEM_ALLOC_GRANULARITY_RECOMMENDED))); uint64_t padded_size = xla::RoundUpTo(size, granularity); CUdeviceptr ptr; - TF_RETURN_IF_ERROR(cuda::ToStatus( + RETURN_IF_ERROR(cuda::ToStatus( cuMemAddressReserve(&ptr, padded_size, granularity, 0, 0))); return std::unique_ptr( @@ -112,7 +112,7 @@ absl::Status CudaMemoryReservation::SetAccess(uint64_t reservation_offset, // automatically inherit peer access, so NCCL collective operations using // NVLink to read/write peer GPU buffers will deadlock without this. int device_count = 0; - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( cuda::ToStatus(cudaGetDeviceCount(&device_count), "cudaGetDeviceCount")); for (int32_t peer = 0; peer < device_count; ++peer) { if (peer == executor_->device_ordinal()) { diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_platform.cc b/third_party/xla/xla/stream_executor/cuda/cuda_platform.cc index c425a4bb889704..ca103f2d84aa77 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_platform.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_platform.cc @@ -26,6 +26,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" +#include "xla/tsl/platform/status_macros.h" #include "third_party/gpus/cuda/include/cuda.h" #include "third_party/gpus/cuda/nvml/include/nvml.h" #include "third_party/gpus/cudnn/cudnn_version.h" @@ -111,12 +112,12 @@ const std::string& CudaPlatform::Name() const { return name_; } absl::StatusOr> CudaPlatform::DescriptionForDevice(int ordinal) const { - TF_RETURN_IF_ERROR(PlatformInitialize()); + RETURN_IF_ERROR(PlatformInitialize()); return CudaExecutor::CreateDeviceDescription(ordinal); } absl::StatusOr CudaPlatform::ExecutorForDevice(int ordinal) { - TF_RETURN_IF_ERROR(PlatformInitialize()); + RETURN_IF_ERROR(PlatformInitialize()); return executor_cache_.GetOrCreate( ordinal, [this, ordinal]() { return GetUncachedExecutor(ordinal); }); } @@ -128,7 +129,7 @@ absl::StatusOr CudaPlatform::FindExisting(int ordinal) { absl::StatusOr> CudaPlatform::GetUncachedExecutor(int ordinal) { auto executor = std::make_unique(this, ordinal); - TF_RETURN_IF_ERROR(executor->Init()); + RETURN_IF_ERROR(executor->Init()); return std::move(executor); } diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_raw_memory_allocation.cc b/third_party/xla/xla/stream_executor/cuda/cuda_raw_memory_allocation.cc index 433d61227bac01..8270a339be70f5 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_raw_memory_allocation.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_raw_memory_allocation.cc @@ -21,6 +21,7 @@ limitations under the License. #include "absl/log/log.h" #include "absl/status/statusor.h" +#include "xla/tsl/platform/status_macros.h" #include "third_party/gpus/cuda/include/cuda.h" #include "xla/stream_executor/activate_context.h" #include "xla/stream_executor/cuda/cuda_status.h" @@ -50,20 +51,19 @@ CudaRawMemoryAllocation::Create(StreamExecutor* executor, uint64_t size) { std::unique_ptr activation = executor->Activate(); CUdevice device; - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( cuda::ToStatus(cuDeviceGet(&device, executor->device_ordinal()))); CUmemAllocationProp props = BuildAllocationProperties(device); size_t granularity = 0; - TF_RETURN_IF_ERROR(cuda::ToStatus(cuMemGetAllocationGranularity( + RETURN_IF_ERROR(cuda::ToStatus(cuMemGetAllocationGranularity( &granularity, &props, CU_MEM_ALLOC_GRANULARITY_RECOMMENDED))); uint64_t padded_size = xla::RoundUpTo(size, granularity); CUmemGenericAllocationHandle handle; - TF_RETURN_IF_ERROR( - cuda::ToStatus(cuMemCreate(&handle, padded_size, &props, 0))); + RETURN_IF_ERROR(cuda::ToStatus(cuMemCreate(&handle, padded_size, &props, 0))); return std::unique_ptr( new CudaRawMemoryAllocation(executor, handle, padded_size)); diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_stream.cc b/third_party/xla/xla/stream_executor/cuda/cuda_stream.cc index 0ff5a0440d5bb6..fe6f3b210a0cba 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_stream.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_stream.cc @@ -36,6 +36,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" +#include "xla/tsl/platform/status_macros.h" #include "third_party/gpus/cuda/include/cuda.h" #include "xla/stream_executor/activate_context.h" #include "xla/stream_executor/cuda/cuda_context.h" @@ -84,10 +85,10 @@ absl::StatusOr CreateStream(StreamExecutor* executor, int priority) { // the default priority for backward compatibility. Probably there is no // difference in using the new api call but leaving it as is for now. if (priority == 0) { - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( cuda::ToStatus(cuStreamCreate(&stream, CU_STREAM_NON_BLOCKING))); } else { - TF_RETURN_IF_ERROR(cuda::ToStatus( + RETURN_IF_ERROR(cuda::ToStatus( cuStreamCreateWithPriority(&stream, CU_STREAM_NON_BLOCKING, priority))); } @@ -100,8 +101,8 @@ absl::StatusOr StreamIsCapturing(CUstream stream) { VLOG(2) << "Checking if stream " << stream << " is capturing"; CUstreamCaptureStatus status; - TF_RETURN_IF_ERROR(cuda::ToStatus(cuStreamIsCapturing(stream, &status), - "Failed to check stream capturing status")); + RETURN_IF_ERROR(cuda::ToStatus(cuStreamIsCapturing(stream, &status), + "Failed to check stream capturing status")); return status == CU_STREAM_CAPTURE_STATUS_ACTIVE; } @@ -111,7 +112,7 @@ absl::Status AsynchronousMemcpyD2H(StreamExecutor* executor, void* host_dst, CUstream stream) { std::unique_ptr activation = executor->Activate(); - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( cuda::ToStatus(cuMemcpyDtoHAsync(host_dst, gpu_src, size, stream))); VLOG(2) << "successfully enqueued async memcpy d2h of " << size @@ -124,7 +125,7 @@ absl::Status AsynchronousMemcpyH2D(StreamExecutor* executor, CUdeviceptr gpu_dst, const void* host_src, uint64_t size, CUstream stream) { std::unique_ptr activation = executor->Activate(); - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( cuda::ToStatus(cuMemcpyHtoDAsync(gpu_dst, host_src, size, stream))); VLOG(2) << "successfully enqueued async memcpy h2d of " << size << " bytes" @@ -140,12 +141,12 @@ absl::Status AsynchronousMemcpyD2D(StreamExecutor* executor, // In graph capture mode we never have operations that access peer memory, so // we can always make a call to cuMemcpyDtoDAsync. - TF_ASSIGN_OR_RETURN(bool is_capturing, StreamIsCapturing(stream)); + ASSIGN_OR_RETURN(bool is_capturing, StreamIsCapturing(stream)); if ((gpu_dst == 0 || gpu_src == 0) || is_capturing) { // GetContextMap()->GetAnyContext() doesn't work when ptr == 0. // This happens when the size is 0. - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( cuda::ToStatus(cuMemcpyDtoDAsync(gpu_dst, gpu_src, size, stream))); } else { // Any context work here. @@ -157,10 +158,10 @@ absl::Status AsynchronousMemcpyD2D(StreamExecutor* executor, if (dst_context == src_context) { // Since the CUDA context is the same, the src and dst are within the same // GPU. So we can use cuMemcpyDtoD. - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( cuda::ToStatus(cuMemcpyDtoDAsync(gpu_dst, gpu_src, size, stream))); } else { - TF_RETURN_IF_ERROR(cuda::ToStatus(cuMemcpyPeerAsync( + RETURN_IF_ERROR(cuda::ToStatus(cuMemcpyPeerAsync( gpu_dst, dst_context, gpu_src, src_context, size, stream))); } } @@ -208,12 +209,11 @@ absl::StatusOr> CudaStream::Create( return executor->GetGpuStreamPriority( std::get(priority.value_or(StreamPriority::Default))); }(); - TF_ASSIGN_OR_RETURN(auto stream_handle, - CreateStream(executor, stream_priority)); + ASSIGN_OR_RETURN(auto stream_handle, CreateStream(executor, stream_priority)); - TF_ASSIGN_OR_RETURN(auto completed_event, - CudaEvent::Create(executor, - /*allow_timing=*/false)); + ASSIGN_OR_RETURN(auto completed_event, + CudaEvent::Create(executor, + /*allow_timing=*/false)); return std::unique_ptr(new CudaStream( executor, std::move(completed_event), priority, stream_handle)); @@ -222,7 +222,7 @@ absl::StatusOr> CudaStream::Create( absl::Status CudaStream::WaitFor(Stream* other) { CudaStream* other_stream = static_cast(other); - TF_RETURN_IF_ERROR(other_stream->RecordCompletedEvent()); + RETURN_IF_ERROR(other_stream->RecordCompletedEvent()); return WaitStreamOnEvent(executor_, stream_handle_, other_stream->completed_event_.GetHandle()); } diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_timer.cc b/third_party/xla/xla/stream_executor/cuda/cuda_timer.cc index 89c2018dfeaf12..616a56989bc581 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_timer.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_timer.cc @@ -22,6 +22,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/time/time.h" +#include "xla/tsl/platform/status_macros.h" #include "third_party/gpus/cuda/include/cuda.h" #include "xla/stream_executor/activate_context.h" #include "xla/stream_executor/cuda/cuda_event.h" @@ -41,11 +42,11 @@ absl::StatusOr GetEventElapsedTime(StreamExecutor *executor, std::unique_ptr activation = executor->Activate(); // The stop event must have completed in order for cuEventElapsedTime to // work. - TF_RETURN_IF_ERROR(cuda::ToStatus(cuEventSynchronize(stop))); + RETURN_IF_ERROR(cuda::ToStatus(cuEventSynchronize(stop))); float elapsed_milliseconds; - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( cuda::ToStatus(cuEventElapsedTime(&elapsed_milliseconds, start, stop))); return elapsed_milliseconds; @@ -79,7 +80,7 @@ absl::StatusOr CudaTimer::GetElapsedDuration() { if (is_stopped_) { return absl::FailedPreconditionError("Measuring inactive timer"); } - TF_RETURN_IF_ERROR(stream_->RecordEvent(&stop_event_)); + RETURN_IF_ERROR(stream_->RecordEvent(&stop_event_)); // If we launched the delay kernel then check if it already timed out. if (semaphore_) { if (*semaphore_ == GpuSemaphoreState::kTimedOut) { @@ -92,9 +93,9 @@ absl::StatusOr CudaTimer::GetElapsedDuration() { *semaphore_ = GpuSemaphoreState::kRelease; } } - TF_ASSIGN_OR_RETURN(float elapsed_milliseconds, - GetEventElapsedTime(executor_, start_event_.GetHandle(), - stop_event_.GetHandle())); + ASSIGN_OR_RETURN(float elapsed_milliseconds, + GetEventElapsedTime(executor_, start_event_.GetHandle(), + stop_event_.GetHandle())); is_stopped_ = true; return absl::Milliseconds(elapsed_milliseconds); } @@ -105,15 +106,15 @@ absl::StatusOr CudaTimer::Create(StreamExecutor *executor, GpuSemaphore semaphore{}; if (timer_type == TimerType::kDelayKernel) { - TF_ASSIGN_OR_RETURN(semaphore, LaunchDelayKernel(stream)); + ASSIGN_OR_RETURN(semaphore, LaunchDelayKernel(stream)); } - TF_ASSIGN_OR_RETURN(CudaEvent start_event, - CudaEvent::Create(executor, /*allow_timing=*/true)); - TF_ASSIGN_OR_RETURN(CudaEvent stop_event, - CudaEvent::Create(executor, /*allow_timing=*/true)); + ASSIGN_OR_RETURN(CudaEvent start_event, + CudaEvent::Create(executor, /*allow_timing=*/true)); + ASSIGN_OR_RETURN(CudaEvent stop_event, + CudaEvent::Create(executor, /*allow_timing=*/true)); - TF_RETURN_IF_ERROR(stream->RecordEvent(&start_event)); + RETURN_IF_ERROR(stream->RecordEvent(&start_event)); return CudaTimer(executor, std::move(start_event), std::move(stop_event), stream, std::move(semaphore)); diff --git a/third_party/xla/xla/stream_executor/cuda/cudnn_api_wrappers.cc b/third_party/xla/xla/stream_executor/cuda/cudnn_api_wrappers.cc index b170ebe8f830c7..53d6ac88f64614 100644 --- a/third_party/xla/xla/stream_executor/cuda/cudnn_api_wrappers.cc +++ b/third_party/xla/xla/stream_executor/cuda/cudnn_api_wrappers.cc @@ -23,9 +23,10 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" +#include "xla/tsl/platform/status_macros.h" #include "third_party/gpus/cuda/include/library_types.h" -#include "third_party/gpus/cudnn/cudnn_version.h" #include "third_party/gpus/cudnn/cudnn_graph.h" +#include "third_party/gpus/cudnn/cudnn_version.h" #include "xla/stream_executor/semantic_version.h" #include "xla/tsl/platform/errors.h" #include "xla/tsl/platform/statusor.h" @@ -102,18 +103,16 @@ absl::StatusOr GetCudnnProperty(CudnnProperty type) { return absl::NotFoundError("cuDNN is not linked into the application."); } int value{}; - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( ToStatus(cudnnGetProperty(ToLibraryPropertyType(type), &value))); return value; } absl::StatusOr GetLoadedCudnnVersion() { - TF_ASSIGN_OR_RETURN(int major, - GetCudnnProperty(CudnnProperty::kMajorVersion)); - TF_ASSIGN_OR_RETURN(int minor, - GetCudnnProperty(CudnnProperty::kMinorVersion)); - TF_ASSIGN_OR_RETURN(int patch, - GetCudnnProperty(CudnnProperty::kPatchLevelVersion)); + ASSIGN_OR_RETURN(int major, GetCudnnProperty(CudnnProperty::kMajorVersion)); + ASSIGN_OR_RETURN(int minor, GetCudnnProperty(CudnnProperty::kMinorVersion)); + ASSIGN_OR_RETURN(int patch, + GetCudnnProperty(CudnnProperty::kPatchLevelVersion)); return SemanticVersion(major, minor, patch); } diff --git a/third_party/xla/xla/stream_executor/cuda/cudnn_sdpa_score_mod.cc b/third_party/xla/xla/stream_executor/cuda/cudnn_sdpa_score_mod.cc index edc2cf75d3f899..f7a7b8a48b15c5 100644 --- a/third_party/xla/xla/stream_executor/cuda/cudnn_sdpa_score_mod.cc +++ b/third_party/xla/xla/stream_executor/cuda/cudnn_sdpa_score_mod.cc @@ -26,6 +26,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "third_party/cudnn_frontend/include/cudnn_frontend.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -185,12 +186,12 @@ Tensor LiteralToCudnnTensor(const xla::HloInstruction* hlo, absl::Status ScoreModFunc::UpdateCudnnMap(cudnn_frontend::graph::Graph& graph, UidGenerator next_uid) { - TF_RETURN_IF_ERROR(UpdateHloParameterToCudnnMap(graph, fwd_hlo_to_cudnn_, - fwd_comp_, next_uid)); - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR(UpdateHloParameterToCudnnMap(graph, fwd_hlo_to_cudnn_, + fwd_comp_, next_uid)); + RETURN_IF_ERROR( UpdateHloConstantToCudnnMap(graph, fwd_hlo_to_cudnn_, fwd_comp_)); if (bwd_comp_) { - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( UpdateHloConstantToCudnnMap(graph, bwd_hlo_to_cudnn_, bwd_comp_)); } return absl::OkStatus(); @@ -202,9 +203,9 @@ absl::Status ScoreModFunc::UpdateHloParameterToCudnnMap( const xla::HloComputation* computation, UidGenerator next_uid) { for (int i = 1; i < computation->num_parameters(); i++) { auto parameter = computation->parameter_instruction(i); - TF_ASSIGN_OR_RETURN(const dnn::DataType type, - xla::gpu::GetDNNDataTypeFromPrimitiveType( - parameter->shape().element_type())); + ASSIGN_OR_RETURN(const dnn::DataType type, + xla::gpu::GetDNNDataTypeFromPrimitiveType( + parameter->shape().element_type())); auto desc = dnn::TensorDescriptor::For( type, parameter->shape().dimensions(), parameter->shape().layout().minor_to_major()); diff --git a/third_party/xla/xla/stream_executor/cuda/delay_kernel_cuda.cu.cc b/third_party/xla/xla/stream_executor/cuda/delay_kernel_cuda.cu.cc index b06499c152547f..9d39b954639385 100644 --- a/third_party/xla/xla/stream_executor/cuda/delay_kernel_cuda.cu.cc +++ b/third_party/xla/xla/stream_executor/cuda/delay_kernel_cuda.cu.cc @@ -16,6 +16,7 @@ limitations under the License. #include #include "absl/status/statusor.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/stream_executor/cuda/delay_kernel.h" #include "xla/stream_executor/gpu/gpu_semaphore.h" #include "xla/stream_executor/typed_kernel_factory.h" @@ -54,11 +55,11 @@ absl::StatusOr LaunchDelayKernel(Stream* stream) { // Allocate a semaphore value that will be used to signal to the delay // kernel that it may exit. - TF_ASSIGN_OR_RETURN(auto semaphore, GpuSemaphore::Create(executor)); + ASSIGN_OR_RETURN(auto semaphore, GpuSemaphore::Create(executor)); *semaphore = GpuSemaphoreState::kHold; // In principle the kernel could be loaded lazily and shared across // multiple GpuTimer objects. - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( auto kernel, (TypedKernelFactory, GpuSemaphoreState>::Create(executor, "DelayKernel", @@ -67,9 +68,9 @@ absl::StatusOr LaunchDelayKernel(Stream* stream) { // Launch a delay kernel into this stream, which will spin until // GetElapsedDuration() is called, the timer is destroyed, or the timeout // in the kernel is reached. - TF_RETURN_IF_ERROR(kernel.Launch(ThreadDim(1, 1, 1), BlockDim(1, 1, 1), - stream, semaphore.device(), - GpuSemaphoreState::kRelease)); + RETURN_IF_ERROR(kernel.Launch(ThreadDim(1, 1, 1), BlockDim(1, 1, 1), stream, + semaphore.device(), + GpuSemaphoreState::kRelease)); return semaphore; } diff --git a/third_party/xla/xla/stream_executor/cuda/driver_compilation.cc b/third_party/xla/xla/stream_executor/cuda/driver_compilation.cc index c7615278d5a07f..2eb501c6124bee 100644 --- a/third_party/xla/xla/stream_executor/cuda/driver_compilation.cc +++ b/third_party/xla/xla/stream_executor/cuda/driver_compilation.cc @@ -24,6 +24,7 @@ limitations under the License. #include "absl/log/log.h" #include "absl/status/statusor.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "third_party/gpus/cuda/include/cuda.h" #include "xla/stream_executor/activate_context.h" #include "xla/stream_executor/cuda/cuda_status.h" @@ -58,7 +59,7 @@ absl::StatusOr> LinkGpuAsmUsingDriver( static_assert(sizeof(options) / sizeof(options[0]) == sizeof(option_values) / sizeof(option_values[0])); - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( cuda::ToStatus(cuLinkCreate(sizeof(options) / sizeof(options[0]), options, option_values, &link_state))); for (const std::vector& image : images) { @@ -73,11 +74,11 @@ absl::StatusOr> LinkGpuAsmUsingDriver( } void* cubin_out; size_t cubin_size; - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( cuda::ToStatus(cuLinkComplete(link_state, &cubin_out, &cubin_size))); std::vector cubin(static_cast(cubin_out), static_cast(cubin_out) + cubin_size); - TF_RETURN_IF_ERROR(cuda::ToStatus(cuLinkDestroy(link_state))); + RETURN_IF_ERROR(cuda::ToStatus(cuLinkDestroy(link_state))); return cubin; } diff --git a/third_party/xla/xla/stream_executor/cuda/driver_compilation_provider.cc b/third_party/xla/xla/stream_executor/cuda/driver_compilation_provider.cc index 5c9be3cdea152b..6cc083cd02649b 100644 --- a/third_party/xla/xla/stream_executor/cuda/driver_compilation_provider.cc +++ b/third_party/xla/xla/stream_executor/cuda/driver_compilation_provider.cc @@ -185,8 +185,8 @@ absl::StatusOr DriverCompilationProvider::CompileAndLink( } VLOG(3) << "Driver compilation info log output: " << info_log_buffer; - TF_RETURN_IF_ERROR(CreateErrorFromPTXASLog(info_log_buffer, architecture, - options.cancel_if_reg_spill)); + RETURN_IF_ERROR(CreateErrorFromPTXASLog(info_log_buffer, architecture, + options.cancel_if_reg_spill)); std::vector cubin(static_cast(cubin_out), static_cast(cubin_out) + cubin_size); diff --git a/third_party/xla/xla/stream_executor/cuda/nvjitlink_impl.cc b/third_party/xla/xla/stream_executor/cuda/nvjitlink_impl.cc index af829ca48ff629..51c547f1a17565 100644 --- a/third_party/xla/xla/stream_executor/cuda/nvjitlink_impl.cc +++ b/third_party/xla/xla/stream_executor/cuda/nvjitlink_impl.cc @@ -35,6 +35,7 @@ limitations under the License. #include "absl/strings/str_join.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "third_party/gpus/cuda/include/nvJitLink.h" #include "xla/stream_executor/cuda/compilation_provider.h" #include "xla/stream_executor/cuda/cuda_compute_capability.h" @@ -136,7 +137,7 @@ absl::StatusOr CompileAndLinkUsingLibNvJitLink( return cuda::Assembly{}; } - TF_ASSIGN_OR_RETURN(NvJitLinkVersion version, GetNvJitLinkVersion()); + ASSIGN_OR_RETURN(NvJitLinkVersion version, GetNvJitLinkVersion()); auto [version_major, version_minor] = version; WarnIfBadPtxasVersion("nvJitLink", cc, {version_major, version_minor, 0}); @@ -175,8 +176,7 @@ absl::StatusOr CompileAndLinkUsingLibNvJitLink( }; if (create_result != NVJITLINK_SUCCESS) { - TF_ASSIGN_OR_RETURN(std::string error_log, - nvJitLinkGetErrorLog(link_handle)); + ASSIGN_OR_RETURN(std::string error_log, nvJitLinkGetErrorLog(link_handle)); VLOG(3) << "libnvjitlink error log output: " << error_log; @@ -203,15 +203,15 @@ absl::StatusOr CompileAndLinkUsingLibNvJitLink( std::optional error_log; if (dump_compilation_log || result != NVJITLINK_SUCCESS) { - TF_ASSIGN_OR_RETURN(error_log, nvJitLinkGetErrorLog(link_handle)); + ASSIGN_OR_RETURN(error_log, nvJitLinkGetErrorLog(link_handle)); } if (result != NVJITLINK_SUCCESS) { // Print the verbose output of ptxas. VLOG(3) << "libnvjitlink error log output: " << *error_log; - TF_RETURN_IF_ERROR(CreateErrorFromPTXASLog(*error_log, architecture, - cancel_if_reg_spill)); + RETURN_IF_ERROR(CreateErrorFromPTXASLog(*error_log, architecture, + cancel_if_reg_spill)); return ToStatus(result, *error_log); } @@ -226,23 +226,22 @@ absl::StatusOr CompileAndLinkUsingLibNvJitLink( nvJitLinkResult linking_result = nvJitLinkComplete(link_handle); if (linking_result != NVJITLINK_SUCCESS) { - TF_ASSIGN_OR_RETURN(std::string error_log, - nvJitLinkGetErrorLog(link_handle)); + ASSIGN_OR_RETURN(std::string error_log, nvJitLinkGetErrorLog(link_handle)); // Print the verbose output of ptxas. VLOG(3) << "libnvjitlink error log output: " << error_log; - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( CreateErrorFromPTXASLog(error_log, architecture, cancel_if_reg_spill)); return ToStatus(linking_result, error_log); } - TF_ASSIGN_OR_RETURN(std::string info_log, nvJitLinkGetInfoLog(link_handle)); + ASSIGN_OR_RETURN(std::string info_log, nvJitLinkGetInfoLog(link_handle)); // Print the verbose output of ptxas. VLOG(3) << "libnvjitlink info log output: " << info_log; - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( CreateErrorFromPTXASLog(info_log, architecture, cancel_if_reg_spill)); ModuleStats module_stats = ExtractModuleStatsFromLog(info_log); @@ -279,8 +278,7 @@ absl::StatusOr GetLatestPtxIsaVersionForLibNvJitLink() { }; if (create_result != NVJITLINK_SUCCESS) { - TF_ASSIGN_OR_RETURN(std::string error_log, - nvJitLinkGetErrorLog(link_handle)); + ASSIGN_OR_RETURN(std::string error_log, nvJitLinkGetErrorLog(link_handle)); VLOG(3) << "libnvjitlink error log output: " << error_log; @@ -304,7 +302,7 @@ absl::StatusOr GetLatestPtxIsaVersionForLibNvJitLink() { "libnvjitlink compilation succeeded where it was expected to fail"); } - TF_ASSIGN_OR_RETURN(std::string error_log, nvJitLinkGetErrorLog(link_handle)); + ASSIGN_OR_RETURN(std::string error_log, nvJitLinkGetErrorLog(link_handle)); return GetLatestPtxIsaVersionFromUnsupportedVersionErrorLog(error_log); } diff --git a/third_party/xla/xla/stream_executor/cuda/nvptxcompiler_compilation_provider.cc b/third_party/xla/xla/stream_executor/cuda/nvptxcompiler_compilation_provider.cc index f544efa99ebdcd..0c9f9c0d1825f6 100644 --- a/third_party/xla/xla/stream_executor/cuda/nvptxcompiler_compilation_provider.cc +++ b/third_party/xla/xla/stream_executor/cuda/nvptxcompiler_compilation_provider.cc @@ -23,6 +23,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/stream_executor/cuda/compilation_options.h" #include "xla/stream_executor/cuda/compilation_provider.h" #include "xla/stream_executor/cuda/cuda_compute_capability.h" @@ -64,9 +65,9 @@ absl::StatusOr NvptxcompilerCompilationProvider::CompileToRelocatableModule( const CudaComputeCapability& cc, absl::string_view ptx, const CompilationOptions& options) const { - TF_ASSIGN_OR_RETURN(Assembly assembly, - CompileHelper(cc, ptx, options, - /*compile_to_relocatable_module=*/true)); + ASSIGN_OR_RETURN(Assembly assembly, + CompileHelper(cc, ptx, options, + /*compile_to_relocatable_module=*/true)); return RelocatableModule{std::move(assembly.cubin), std::move(assembly.compilation_log), std::move(assembly.module_stats)}; diff --git a/third_party/xla/xla/stream_executor/cuda/nvshmem.cc b/third_party/xla/xla/stream_executor/cuda/nvshmem.cc index 77ed480c34b7c8..20aeabd2dc0fbe 100644 --- a/third_party/xla/xla/stream_executor/cuda/nvshmem.cc +++ b/third_party/xla/xla/stream_executor/cuda/nvshmem.cc @@ -27,6 +27,7 @@ limitations under the License. #include "absl/strings/str_format.h" #include "absl/strings/string_view.h" #include "absl/time/time.h" +#include "xla/tsl/platform/status_macros.h" #include "third_party/nvshmem/nvshmem.h" // IWYU pragma: keep #include "third_party/nvshmem/nvshmemx.h" // IWYU pragma: keep #include "xla/pjrt/distributed/key_value_store_interface.h" @@ -89,10 +90,10 @@ absl::Status InitializeOnce() { char buf[sizeof(nvshmemx_uniqueid_t)]; std::memcpy(buf, &nvshmem_id, sizeof(nvshmemx_uniqueid_t)); absl::string_view nvshmem_id_str{buf, sizeof(buf)}; - TF_RETURN_IF_ERROR(kv_store->Set(kKvStoreKey, nvshmem_id_str)); + RETURN_IF_ERROR(kv_store->Set(kKvStoreKey, nvshmem_id_str)); } else { - TF_ASSIGN_OR_RETURN(std::string id_str, - kv_store->Get(kKvStoreKey, absl::Minutes(10))); + ASSIGN_OR_RETURN(std::string id_str, + kv_store->Get(kKvStoreKey, absl::Minutes(10))); CHECK(id_str.size() >= sizeof(nvshmemx_uniqueid_t)); std::memcpy(&nvshmem_id, id_str.data(), sizeof(nvshmemx_uniqueid_t)); } diff --git a/third_party/xla/xla/stream_executor/cuda/ptx_compiler_impl.cc b/third_party/xla/xla/stream_executor/cuda/ptx_compiler_impl.cc index 5308c890053f7a..aba081516900c4 100644 --- a/third_party/xla/xla/stream_executor/cuda/ptx_compiler_impl.cc +++ b/third_party/xla/xla/stream_executor/cuda/ptx_compiler_impl.cc @@ -37,6 +37,7 @@ limitations under the License. #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" #include "absl/strings/string_view.h" +#include "xla/tsl/platform/status_macros.h" #include "third_party/gpus/cuda/include/cuda.h" #include "third_party/gpus/cuda/include/nvPTXCompiler.h" #include "xla/stream_executor/cuda/compilation_provider.h" @@ -91,7 +92,7 @@ static absl::string_view ToString(nvPTXCompileResult status) { absl::StatusOr CompileGpuAsmUsingLibNvPtxCompiler( const CudaComputeCapability& cc, const std::string& ptx_contents, GpuAsmOpts options, bool cancel_if_reg_spill, bool dump_compilation_log) { - TF_ASSIGN_OR_RETURN(auto version, GetLibNvPtxCompilerVersion()); + ASSIGN_OR_RETURN(auto version, GetLibNvPtxCompilerVersion()); WarnIfBadPtxasVersion("nvPTXCompiler", cc, version); nvPTXCompilerHandle compiler_handle{}; @@ -225,7 +226,7 @@ absl::StatusOr GetLatestPtxIsaVersionForNvptxCompiler() { }; std::optional disabler; - TF_ASSIGN_OR_RETURN(SemanticVersion version, GetLibNvPtxCompilerVersion()); + ASSIGN_OR_RETURN(SemanticVersion version, GetLibNvPtxCompilerVersion()); if (version < SemanticVersion(13, 0, 0)) { // libNvptxCompiler prior to CUDA 13 has a memory leak when calling // nvPTXCompilerCompile when the input PTX is invalid. diff --git a/third_party/xla/xla/stream_executor/cuda/ptx_compiler_test.cc b/third_party/xla/xla/stream_executor/cuda/ptx_compiler_test.cc index 2ca4b32f070f02..1813c0c446553b 100644 --- a/third_party/xla/xla/stream_executor/cuda/ptx_compiler_test.cc +++ b/third_party/xla/xla/stream_executor/cuda/ptx_compiler_test.cc @@ -27,6 +27,7 @@ limitations under the License. #include "absl/status/status_matchers.h" #include "absl/status/statusor.h" #include "absl/strings/str_format.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/stream_executor/cuda/compilation_provider.h" #include "xla/stream_executor/cuda/cuda_compute_capability.h" #include "xla/stream_executor/cuda/ptx_compiler_support.h" @@ -185,10 +186,10 @@ absl::StatusOr> CompileHelper( stream_executor::GpuAsmOpts options(disable_gpuasm_optimizations, /*preferred_cuda_dir=*/"", extra_flags); - TF_ASSIGN_OR_RETURN(stream_executor::cuda::Assembly assembly, - stream_executor::CompileGpuAsmUsingLibNvPtxCompiler( - cc, ptx_input, options, cancel_if_reg_spill, - /*dump_compilation_log=*/false)); + ASSIGN_OR_RETURN(stream_executor::cuda::Assembly assembly, + stream_executor::CompileGpuAsmUsingLibNvPtxCompiler( + cc, ptx_input, options, cancel_if_reg_spill, + /*dump_compilation_log=*/false)); return assembly.cubin; } diff --git a/third_party/xla/xla/stream_executor/cuda/subprocess_compilation.cc b/third_party/xla/xla/stream_executor/cuda/subprocess_compilation.cc index 0d7298a869934d..1419937869441c 100644 --- a/third_party/xla/xla/stream_executor/cuda/subprocess_compilation.cc +++ b/third_party/xla/xla/stream_executor/cuda/subprocess_compilation.cc @@ -44,6 +44,7 @@ limitations under the License. #include "absl/strings/strip.h" #include "absl/synchronization/mutex.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/status_macros.h" #include "xla/stream_executor/cuda/compilation_provider.h" #include "xla/stream_executor/cuda/cubin_or_ptx_image.h" @@ -257,8 +258,8 @@ static void AppendArgsFromOptions(GpuAsmOpts options, absl::StatusOr CompileGpuAsmUsingPtxAs( const CudaComputeCapability& cc, absl::string_view ptx, GpuAsmOpts options, bool cancel_if_reg_spill, bool dump_compilation_log) { - TF_ASSIGN_OR_RETURN(std::string ptxas_path, - FindPtxAsExecutable(options.preferred_cuda_dir)); + ASSIGN_OR_RETURN(std::string ptxas_path, + FindPtxAsExecutable(options.preferred_cuda_dir)); return CompileGpuAsmUsingPtxAs(ptxas_path, cc, ptx, options, cancel_if_reg_spill, dump_compilation_log); } @@ -267,7 +268,7 @@ absl::StatusOr CompileGpuAsmUsingPtxAs( absl::string_view ptxas_path, const CudaComputeCapability& cc, absl::string_view ptx, GpuAsmOpts options, bool cancel_if_reg_spill, bool dump_compilation_log) { - TF_ASSIGN_OR_RETURN(auto version, GetToolVersion(ptxas_path)); + ASSIGN_OR_RETURN(auto version, GetToolVersion(ptxas_path)); WarnIfBadPtxasVersion("ptxas", cc, version); // Write ptx into a temporary file. @@ -359,7 +360,7 @@ absl::StatusOr CompileGpuAsmUsingPtxAs( // Read in the result of compilation and return it as a byte vector. std::string cubin; - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( tsl::ReadFileToString(tsl::Env::Default(), cubin_path, &cubin)); std::vector cubin_vector(cubin.begin(), cubin.end()); std::optional maybe_compilation_error_log; @@ -372,16 +373,15 @@ absl::StatusOr CompileGpuAsmUsingPtxAs( absl::StatusOr GetAsmCompilerVersion( absl::string_view preferred_cuda_dir) { - TF_ASSIGN_OR_RETURN(std::string ptxas_path, - FindPtxAsExecutable(preferred_cuda_dir)); + ASSIGN_OR_RETURN(std::string ptxas_path, + FindPtxAsExecutable(preferred_cuda_dir)); return GetToolVersion(ptxas_path); } absl::StatusOr> BundleGpuAsmUsingFatbin( std::vector images, GpuAsmOpts options) { - TF_ASSIGN_OR_RETURN( - std::string fatbinary_path, - FindCudaExecutable("fatbinary", options.preferred_cuda_dir)); + ASSIGN_OR_RETURN(std::string fatbinary_path, + FindCudaExecutable("fatbinary", options.preferred_cuda_dir)); // Write images to temporary files. std::vector image_paths; @@ -392,7 +392,7 @@ absl::StatusOr> BundleGpuAsmUsingFatbin( return absl::InternalError( "Could not get temporary filenames for images."); } - TF_RETURN_IF_ERROR(tsl::WriteStringToFile( + RETURN_IF_ERROR(tsl::WriteStringToFile( env, img_path, std::string(img.bytes.begin(), img.bytes.end()))); VLOG(2) << "image written to " << img_path; image_paths.push_back(std::move(img_path)); @@ -460,7 +460,7 @@ absl::StatusOr> BundleGpuAsmUsingFatbin( // Read in the result and return it as a byte vector. std::string result_blob; - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( tsl::ReadFileToString(tsl::Env::Default(), result_path, &result_blob)); return std::vector(result_blob.begin(), result_blob.end()); } @@ -478,8 +478,8 @@ absl::StatusOr FindNvlinkExecutable( absl::StatusOr GetNvLinkVersion( absl::string_view preferred_cuda_dir) { // Make sure nvlink exists and is executable. - TF_ASSIGN_OR_RETURN(std::string bin_path, - FindNvlinkExecutable(preferred_cuda_dir)); + ASSIGN_OR_RETURN(std::string bin_path, + FindNvlinkExecutable(preferred_cuda_dir)); return GetToolVersion(bin_path); } @@ -488,8 +488,8 @@ absl::StatusOr> LinkUsingNvlink( stream_executor::CudaComputeCapability cc, absl::string_view preferred_cuda_dir, absl::Span> images) { - TF_ASSIGN_OR_RETURN(std::string bin_path, - FindNvlinkExecutable(preferred_cuda_dir)); + ASSIGN_OR_RETURN(std::string bin_path, + FindNvlinkExecutable(preferred_cuda_dir)); return LinkUsingNvlink(bin_path, cc, images); } @@ -514,7 +514,7 @@ absl::StatusOr> LinkUsingNvlink( temp_files.emplace_back(); TF_RET_CHECK(env->LocalTempFilename(&temp_files.back())); temp_files.back() += ".cubin"; - TF_RETURN_IF_ERROR(tsl::WriteStringToFile( + RETURN_IF_ERROR(tsl::WriteStringToFile( env, temp_files.back(), absl::string_view(reinterpret_cast(images[i].data()), images[i].size()))); @@ -560,7 +560,7 @@ absl::StatusOr> LinkUsingNvlink( // Read in the result of compilation and return it as a byte vector. std::string cubin; - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( tsl::ReadFileToString(tsl::Env::Default(), output_path, &cubin)); std::vector cubin_vector(cubin.begin(), cubin.end()); return cubin_vector; @@ -580,8 +580,8 @@ absl::StatusOr FindNvdisasmExecutable( absl::StatusOr GetNvdisasmVersion( absl::string_view preferred_cuda_dir) { // Make sure nvdisasm exists and is executable. - TF_ASSIGN_OR_RETURN(std::string bin_path, - FindNvdisasmExecutable(preferred_cuda_dir)); + ASSIGN_OR_RETURN(std::string bin_path, + FindNvdisasmExecutable(preferred_cuda_dir)); return GetToolVersion(bin_path); } diff --git a/third_party/xla/xla/stream_executor/cuda/subprocess_compilation_provider.cc b/third_party/xla/xla/stream_executor/cuda/subprocess_compilation_provider.cc index ef8d2f57fe6cca..cc9a821e2b2c45 100644 --- a/third_party/xla/xla/stream_executor/cuda/subprocess_compilation_provider.cc +++ b/third_party/xla/xla/stream_executor/cuda/subprocess_compilation_provider.cc @@ -28,6 +28,7 @@ limitations under the License. #include "absl/strings/str_format.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/stream_executor/cuda/compilation_options.h" #include "xla/stream_executor/cuda/compilation_provider.h" #include "xla/stream_executor/cuda/cuda_compute_capability.h" @@ -77,9 +78,9 @@ absl::StatusOr SubprocessCompilationProvider::CompileToRelocatableModule( const CudaComputeCapability& cc, absl::string_view ptx, const CompilationOptions& options) const { - TF_ASSIGN_OR_RETURN(auto assembly, - CompileHelper(path_to_ptxas_, cc, ptx, options, - /*compile_to_relocatable_module=*/true)); + ASSIGN_OR_RETURN(auto assembly, + CompileHelper(path_to_ptxas_, cc, ptx, options, + /*compile_to_relocatable_module=*/true)); return RelocatableModule{std::move(assembly.cubin), std::move(assembly.compilation_log)}; } @@ -94,14 +95,14 @@ absl::StatusOr SubprocessCompilationProvider::CompileAndLink( images.push_back(std::get(input).cubin); } else { // If we have a PTX string, we need to compile it to CUBIN first. - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( RelocatableModule module, CompileToRelocatableModule(cc, std::get(input).ptx, options)); images.push_back(std::move(module.cubin)); } } - TF_ASSIGN_OR_RETURN(auto cubin, LinkUsingNvlink(path_to_nvlink_, cc, images)); + ASSIGN_OR_RETURN(auto cubin, LinkUsingNvlink(path_to_nvlink_, cc, images)); return Assembly{std::move(cubin)}; } diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_blas_lt.cc b/third_party/xla/xla/stream_executor/gpu/gpu_blas_lt.cc index 212bb0439bda33..549c6e2580e934 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_blas_lt.cc +++ b/third_party/xla/xla/stream_executor/gpu/gpu_blas_lt.cc @@ -286,18 +286,13 @@ bool MakeOutputColumnMajor(MatrixLayout& lhs, MatrixLayout& rhs, return swap_operands; } -/*static*/ absl::StatusOr BlasLt::GetMatmulPlan( - const Stream* stream, const GemmConfig& cfg, Epilogue epilogue) { - auto blas = Get(stream); - if (blas == nullptr) { - return xla::Internal("BlasLt is unavailable"); +/*static*/ absl::StatusOr BlasLt::Get(StreamExecutor* executor) { + auto blas = executor->AsBlas(); + auto blas_lt = blas != nullptr ? blas->GetBlasLt() : nullptr; + if (blas_lt == nullptr) { + return absl::InternalError("BlasLt is unavailable"); } - return blas->GetMatmulPlan(cfg, epilogue); -} - -/*static*/ BlasLt* BlasLt::Get(const Stream* stream) { - auto blas = stream->parent()->AsBlas(); - return (blas != nullptr ? blas->GetBlasLt() : nullptr); + return blas_lt; } DataType GetScaleType(DataType c_type, ComputationType computation_type) { @@ -309,7 +304,7 @@ DataType GetScaleType(DataType c_type, ComputationType computation_type) { absl::StatusOr BlasLt::GetOrCreateMatmulPlan( const std::string& key, PlanCreateFunc create) { - absl::MutexLock lock(plan_cache_mu_); // double mutex ??? + absl::MutexLock lock(plan_cache_mu_); auto res = plan_cache_.emplace(key, MatmulPlanPtr{}); // New entry inserted: always create a new matmul plan if key is empty, // this is used by command_buffer_thunk test. @@ -331,21 +326,10 @@ size_t BlasLt::GetMatmulPlanCacheSize() const { return plan_cache_.size(); } -/*static*/ absl::StatusOr BlasLt::GetGroupedMatmulPlan( - const Stream* stream, GroupedGemmConfig& cfg, Epilogue epilogue) { - BlasLt* blas = BlasLt::Get(stream); - if (blas == nullptr) { - return xla::Internal("BlasLt is unavailable"); - } - return blas->GetGroupedMatmulPlan(cfg, epilogue); -} - absl::StatusOr BlasLt::GetOrCreateGroupedMatmulPlan( const std::string& key, PlanCreateFunc create) { absl::MutexLock lock(plan_cache_mu_); auto res = grouped_plan_cache_.emplace(key, MatmulPlanPtr{}); - // New entry inserted: always create a new matmul plan if key is empty, - // this is used by command_buffer_thunk test. if (res.second || key.empty()) { VLOG(2) << "Creating a grouped plan for: " << key; ASSIGN_OR_RETURN(res.first->second, create()); diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_blas_lt.h b/third_party/xla/xla/stream_executor/gpu/gpu_blas_lt.h index ad12759a164b68..0a41ba5180e494 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_blas_lt.h +++ b/third_party/xla/xla/stream_executor/gpu/gpu_blas_lt.h @@ -39,6 +39,7 @@ limitations under the License. #include "xla/stream_executor/gpu/gpu_blas_lt.pb.h" #include "xla/stream_executor/scratch_allocator.h" #include "xla/stream_executor/stream.h" +#include "xla/tsl/platform/statusor.h" #include "xla/types.h" #include "xla/xla_data.pb.h" @@ -233,37 +234,6 @@ struct BlasLt { }; struct MatmulPlan { - // This function is to be removed once TF interface is fixed, - // see tensorflow/core/kernels/matmul_util.cc - absl::Status ExecuteOnStream( - Stream* stream, DeviceAddressBase a, DeviceAddressBase b, - DeviceAddressBase c, DeviceAddressBase d, - DeviceAddressBase bias, // may be null - DeviceAddressBase aux, // may be null - DeviceAddressBase a_scale, DeviceAddressBase b_scale, - DeviceAddressBase c_scale, DeviceAddressBase d_scale, - DeviceAddressBase d_amax, const MatmulAlgorithm& algorithm, - ScratchAllocator& scratch_allocator, - blas::ProfileResult* profile_result = nullptr) const { - // Temporary hack until Tensorflow side is fixed - RETURN_IF_ERROR(const_cast(this)->SetAlgorithm(algorithm)); - return ExecuteOnStream(stream, - MemoryArgs{a, - b, - c, - d, - bias, - aux, - a_scale, - b_scale, - c_scale, - d_scale, - {d_amax}, - DeviceAddressBase{}, - &scratch_allocator}, - profile_result); - } - // API that uses scratch_allocator to allocate workspace. // This version is used by TF: see tensorflow/core/kernels/matmul_util.cc absl::Status ExecuteOnStream( @@ -292,71 +262,24 @@ struct BlasLt { profile_result); } - // API that uses pre-allocated buffer as workspace (regular matmul). - absl::Status ExecuteOnStream( - Stream* stream, DeviceAddressBase a, DeviceAddressBase b, - DeviceAddressBase c, DeviceAddressBase d, - DeviceAddressBase bias, // may be null - DeviceAddressBase aux, // may be null - DeviceAddressBase a_scale, DeviceAddressBase b_scale, - DeviceAddressBase c_scale, DeviceAddressBase d_scale, - DeviceAddressBase d_amax, DeviceAddressBase workspace, - blas::ProfileResult* profile_result = nullptr) const { - return ExecuteOnStream(stream, - MemoryArgs{a, - b, - c, - d, - bias, - aux, - a_scale, - b_scale, - c_scale, - d_scale, - {d_amax}, - workspace, - nullptr}, - profile_result); - } - - // API that uses pre-allocated buffer as workspace (grouped matmul). - absl::Status ExecuteOnStream( - Stream* stream, DeviceAddressBase a, DeviceAddressBase b, - DeviceAddressBase c, DeviceAddressBase d, DeviceAddressBase group_sizes, - DeviceAddressBase bias, // may be null - DeviceAddressBase aux, // may be null - DeviceAddressBase a_scale, DeviceAddressBase b_scale, - DeviceAddressBase c_scale, DeviceAddressBase d_scale, - DeviceAddressBase d_amax, DeviceAddressBase workspace, - blas::ProfileResult* profile_result = nullptr) const { - return ExecuteOnStream(stream, - MemoryArgs{a, - b, - c, - d, - bias, - aux, - a_scale, - b_scale, - c_scale, - d_scale, - {group_sizes}, - workspace, - nullptr}, - profile_result); - } - // The most general form: to be implemented by derived clases. virtual absl::Status ExecuteOnStream( Stream* stream, const MemoryArgs& args, - blas::ProfileResult* profile_result) const = 0; + blas::ProfileResult* profile_result = nullptr) const = 0; // Returns a list of supported algorithms for DoMatmul. The algorithms are // returned in the order of increasing estimated compute time according to // an internal heuristic. virtual absl::StatusOr> GetAlgorithms( - const Stream* stream, size_t max_algorithm_count, - size_t max_workspace_size) const = 0; + size_t max_algorithm_count, size_t max_workspace_size) const = 0; + + // Shim for Tensorflow: to be removed once Tensorflow BlasLt interface is + // updated. Do not use this function directly ! + virtual absl::StatusOr> GetAlgorithms( + const Stream* /*stream*/, size_t max_algorithm_count, + size_t max_workspace_size) const { + return GetAlgorithms(max_algorithm_count, max_workspace_size); + } // Algorithm must to be set before calling ExecuteOnStream function(s). // Usually, we call ExecuteOnStream with the same algorithm ID, hence using @@ -376,17 +299,18 @@ struct BlasLt { const GemmConfig& cfg, Epilogue epilogue) const = 0; virtual absl::StatusOr GetGroupedMatmulPlan( - gpu::GroupedGemmConfig& config, Epilogue epilogue) const = 0; + const gpu::GroupedGemmConfig& config, Epilogue epilogue) const = 0; - static BlasLt* Get(const Stream* stream); + static absl::StatusOr Get(StreamExecutor* executor); - // convenience function to create MatmulPlan directly using stream - static absl::StatusOr GetMatmulPlan(const Stream* stream, + // Shim for Tensorflow: to be removed once Tensorflow BlasLt interface + // is updated. Do not use this function directly ! + static absl::StatusOr GetMatmulPlan(Stream* stream, const GemmConfig& cfg, - Epilogue epilogue); - - static absl::StatusOr GetGroupedMatmulPlan( - const Stream* stream, gpu::GroupedGemmConfig& cfg, Epilogue epilogue); + Epilogue epilogue) { + TF_ASSIGN_OR_RETURN(auto* blas_lt, Get(stream->parent())); + return blas_lt->GetMatmulPlan(cfg, epilogue); + } absl::StatusOr GetOrCreateMatmulPlan(const std::string& key, PlanCreateFunc create); diff --git a/third_party/xla/xla/stream_executor/memory_allocator.cc b/third_party/xla/xla/stream_executor/memory_allocator.cc index b884df059fff4d..0681564330d5eb 100644 --- a/third_party/xla/xla/stream_executor/memory_allocator.cc +++ b/third_party/xla/xla/stream_executor/memory_allocator.cc @@ -15,6 +15,7 @@ limitations under the License. #include "xla/stream_executor/memory_allocator.h" +#include #include #include @@ -37,30 +38,89 @@ absl::StatusOr MemoryAllocator::AllocationTracker::Track( void* ptr = addr.opaque(); absl::MutexLock lock(&mu_); - auto [it, inserted] = allocations_.emplace(ptr, std::move(allocation)); - if (!inserted) { - return absl::AlreadyExistsError( - absl::StrFormat("Allocation at address %p (size %d) is already tracked", - ptr, addr.size())); + uint64_t id = next_allocation_id_++; + // Set the payload on the DeviceAddressBase handle being returned. This + // allows subsequent operations using this specific handle to quickly find + // the tracked allocation using the payload ID. The MemoryAllocation object + // stored in 'allocations_' does not have its internal DeviceAddressBase + // updated with this payload, as the tracker manages the IDs externally. + // Set the payload on the DeviceAddressBase handle being returned. This + // allows subsequent operations using this specific handle to quickly find + // the tracked allocation using the payload ID. The MemoryAllocation object + // stored in 'allocations_' does not have its internal DeviceAddressBase + // updated with this payload, as the tracker manages the IDs externally. + addr.SetPayload(id); + + if (ptr != nullptr) { + if (ptr_to_id_.contains(ptr)) { + return absl::AlreadyExistsError(absl::StrFormat( + "Allocation at address %p (size %d) is already tracked", ptr, + addr.size())); + } + ptr_to_id_.emplace(ptr, id); } + + allocations_.emplace(id, std::move(allocation)); return addr; } bool MemoryAllocator::AllocationTracker::IsTracked( const DeviceAddressBase& addr) const { absl::MutexLock lock(&mu_); - return allocations_.contains(addr.opaque()); + uint64_t id = addr.payload(); + if (id == 0 && addr.opaque() != nullptr) { + auto it = ptr_to_id_.find(addr.opaque()); + if (it != ptr_to_id_.end()) { + id = it->second; + } + } + return id != 0 && allocations_.contains(id); } absl::Status MemoryAllocator::AllocationTracker::Free(DeviceAddressBase addr) { absl::MutexLock lock(&mu_); - auto it = allocations_.find(addr.opaque()); - if (it == allocations_.end()) { + uint64_t id = addr.payload(); + + // If payload is 0, the caller may have reconstructed the DeviceAddressBase + // using only the void* pointer. We fall back to looking up the unique ID + // using the pointer, if it is not null. + if (id == 0 && addr.opaque() != nullptr) { + auto it = ptr_to_id_.find(addr.opaque()); + if (it != ptr_to_id_.end()) { + id = it->second; + } + } + + if (id == 0) { return absl::NotFoundError( - absl::StrFormat("No tracked allocation at address %p (size %d)", + absl::StrFormat("No tracked allocation for address %p (size %d)", addr.opaque(), addr.size())); } - allocations_.erase(it); + + auto alloc_it = allocations_.find(id); + if (alloc_it == allocations_.end()) { + return absl::NotFoundError(absl::StrFormat( + "No tracked allocation for payload ID %v (address %p, size %d)", id, + addr.opaque(), addr.size())); + } + + const DeviceAddressBase& tracked_addr = alloc_it->second->address(); + if (addr.opaque() != nullptr && addr.opaque() != tracked_addr.opaque()) { + return absl::InvalidArgumentError(absl::StrFormat( + "Address mismatch for payload ID %v: provided %p, tracked %p", id, + addr.opaque(), tracked_addr.opaque())); + } + if (addr.size() != 0 && addr.size() != tracked_addr.size()) { + return absl::InvalidArgumentError(absl::StrFormat( + "Size mismatch for payload ID %v: provided %d, tracked %d", id, + addr.size(), tracked_addr.size())); + } + + void* stored_ptr = alloc_it->second->address().opaque(); + if (stored_ptr != nullptr) { + ptr_to_id_.erase(stored_ptr); + } + allocations_.erase(alloc_it); return absl::OkStatus(); } diff --git a/third_party/xla/xla/stream_executor/memory_allocator.h b/third_party/xla/xla/stream_executor/memory_allocator.h index 3e621d312cdfe5..400a05e603fd6f 100644 --- a/third_party/xla/xla/stream_executor/memory_allocator.h +++ b/third_party/xla/xla/stream_executor/memory_allocator.h @@ -74,11 +74,17 @@ class MemoryAllocator { private: mutable absl::Mutex mu_; - // Keyed by the raw opaque pointer rather than DeviceAddressBase, because - // callers of Free() may not know the original allocation size (e.g. - // DeviceMemAllocator::Free constructs a DeviceAddressBase with size=0). - absl::flat_hash_map> allocations_ - ABSL_GUARDED_BY(mu_); + uint64_t next_allocation_id_ ABSL_GUARDED_BY(mu_) = 1; + + // Primary map keyed by unique allocation ID. + absl::flat_hash_map> + allocations_ ABSL_GUARDED_BY(mu_); + + // Secondary map keyed by the raw opaque pointer, because callers of Free() + // may not know the original payload/ID (e.g., DeviceMemAllocator::Free + // constructs a DeviceAddressBase with payload=0). Non-addressable + // allocations (nullptr) are not stored here. + absl::flat_hash_map ptr_to_id_ ABSL_GUARDED_BY(mu_); }; }; diff --git a/third_party/xla/xla/stream_executor/mock_command_buffer.h b/third_party/xla/xla/stream_executor/mock_command_buffer.h new file mode 100644 index 00000000000000..32c3d3b94ccf00 --- /dev/null +++ b/third_party/xla/xla/stream_executor/mock_command_buffer.h @@ -0,0 +1,144 @@ +/* Copyright 2026 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_STREAM_EXECUTOR_MOCK_COMMAND_BUFFER_H_ +#define XLA_STREAM_EXECUTOR_MOCK_COMMAND_BUFFER_H_ + +#include +#include +#include +#include + +#include "absl/functional/any_invocable.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/types/span.h" +#include "xla/hlo/testlib/test.h" +#include "xla/stream_executor/bit_pattern.h" +#include "xla/stream_executor/command_buffer.h" +#include "xla/stream_executor/device_address.h" +#include "xla/stream_executor/dnn.h" +#include "xla/stream_executor/kernel.h" +#include "xla/stream_executor/kernel_args.h" +#include "xla/stream_executor/launch_dim.h" +#include "xla/stream_executor/platform.h" +#include "xla/stream_executor/stream.h" + +namespace stream_executor { + +// Implements CommandBuffer for testing. +class MockCommandBuffer : public CommandBuffer { + public: + MockCommandBuffer() = default; + + MOCK_METHOD(absl::StatusOr, CreateEmptyCmd, + (absl::Span dependencies, + StreamPriority priority), + (override)); + MOCK_METHOD(absl::StatusOr, CreateLaunch, + (const ThreadDim& threads, const BlockDim& blocks, + const Kernel& kernel, const KernelArgs& args, + absl::Span dependencies, + StreamPriority priority), + (override)); + MOCK_METHOD(absl::Status, UpdateLaunch, + (const Command* command, const ThreadDim& threads, + const BlockDim& blocks, const Kernel& kernel, + const KernelArgs& args), + (override)); + MOCK_METHOD(absl::StatusOr, CreateChildCommand, + (const CommandBuffer& nested, + absl::Span dependencies), + (override)); + MOCK_METHOD(absl::Status, UpdateChildCommand, + (const Command* command, const CommandBuffer& nested), + (override)); + MOCK_METHOD(absl::StatusOr, CreateChildCommand, + (absl::AnyInvocable record_fn, + absl::Span dependencies), + (override)); + MOCK_METHOD(absl::Status, UpdateChildCommand, + (const Command* command, + absl::AnyInvocable update_fn), + (override)); + MOCK_METHOD(absl::StatusOr, CreateMemcpyD2D, + (DeviceAddressBase * dst, const DeviceAddressBase& src, + uint64_t size, absl::Span dependencies), + (override)); + MOCK_METHOD(absl::Status, UpdateMemcpyD2D, + (const Command* command, DeviceAddressBase* dst, + const DeviceAddressBase& src, uint64_t size), + (override)); + MOCK_METHOD(absl::StatusOr, CreateMemset, + (DeviceAddressBase * dst, BitPattern bit_pattern, + size_t num_elements, + absl::Span dependencies), + (override)); + MOCK_METHOD(absl::Status, UpdateMemset, + (const Command* command, DeviceAddressBase* dst, + const BitPattern& bit_pattern, size_t num_elements), + (override)); + MOCK_METHOD(absl::StatusOr, CreateDnnGraphCommand, + (dnn::DnnGraph&, Stream&, absl::Span operands, + absl::Span dependencies), + (override)); + MOCK_METHOD(absl::Status, UpdateDnnGraphCommand, + (const Command*, dnn::DnnGraph&, Stream&, + absl::Span operands), + (override)); + MOCK_METHOD(absl::StatusOr, CreateCase, + (DeviceAddress index, + std::vector create_branches, + absl::Span dependencies), + (override)); + MOCK_METHOD(absl::StatusOr, CreateCase, + (DeviceAddress index, + std::vector create_branches, + absl::Span dependencies), + (override)); + MOCK_METHOD(absl::Status, UpdateCase, + (const Command* command, DeviceAddress index, + std::vector update_branches), + (override)); + MOCK_METHOD(absl::Status, UpdateCase, + (const Command* command, DeviceAddress index, + std::vector update_branches), + (override)); + MOCK_METHOD(absl::StatusOr, CreateWhile, + (DeviceAddress pred, CreateCommands create_cond, + CreateCommands create_body, + absl::Span dependencies), + (override)); + MOCK_METHOD(absl::Status, UpdateWhile, + (const Command* command, DeviceAddress pred, + UpdateCommands update_cond, UpdateCommands update_body), + (override)); + MOCK_METHOD(absl::Status, SetPriority, (StreamPriority priority), (override)); + MOCK_METHOD(absl::Status, Submit, (Stream * stream), (override)); + MOCK_METHOD(absl::Status, Finalize, (), (override)); + MOCK_METHOD(absl::Status, Update, (), (override)); + MOCK_METHOD(Mode, mode, (), (const, override)); + MOCK_METHOD(State, state, (), (const, override)); + MOCK_METHOD(std::string, ToString, (), (const, override)); + + private: + MOCK_METHOD(absl::Status, Trace, + (Stream * stream, absl::AnyInvocable function), + (override)); +}; + +} // namespace stream_executor + +#endif // XLA_STREAM_EXECUTOR_MOCK_COMMAND_BUFFER_H_ diff --git a/third_party/xla/xla/stream_executor/platform/BUILD b/third_party/xla/xla/stream_executor/platform/BUILD index d10d1d854fb180..3eddcb29063f24 100644 --- a/third_party/xla/xla/stream_executor/platform/BUILD +++ b/third_party/xla/xla/stream_executor/platform/BUILD @@ -36,6 +36,7 @@ cc_library( ":initialize", "//xla/stream_executor:platform", "//xla/stream_executor:platform_manager", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", diff --git a/third_party/xla/xla/stream_executor/platform/platform_object_registry.h b/third_party/xla/xla/stream_executor/platform/platform_object_registry.h index e5f5c3bf71c4bc..a2884b5cb4f116 100644 --- a/third_party/xla/xla/stream_executor/platform/platform_object_registry.h +++ b/third_party/xla/xla/stream_executor/platform/platform_object_registry.h @@ -29,6 +29,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/synchronization/mutex.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/stream_executor/platform.h" #include "xla/stream_executor/platform/initialize.h" // IWYU pragma: keep #include "xla/tsl/platform/statusor.h" @@ -79,8 +80,8 @@ class PlatformObjectRegistry { template absl::StatusOr> FindObject( Platform::Id platform_id) { - TF_ASSIGN_OR_RETURN(const Container& obj, - FindObject(typeid(Trait), platform_id)); + ASSIGN_OR_RETURN(const Container& obj, + FindObject(typeid(Trait), platform_id)); return std::any_cast(obj.element); } diff --git a/third_party/xla/xla/stream_executor/rocm/BUILD b/third_party/xla/xla/stream_executor/rocm/BUILD index 63405cf56aea40..92c6b5b582c2a6 100644 --- a/third_party/xla/xla/stream_executor/rocm/BUILD +++ b/third_party/xla/xla/stream_executor/rocm/BUILD @@ -424,7 +424,6 @@ cc_library( "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "//xla/tsl/util:determinism", - "//xla/tsl/util:determinism_hdr_lib", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", diff --git a/third_party/xla/xla/stream_executor/rocm/hip_blas_lt.cc b/third_party/xla/xla/stream_executor/rocm/hip_blas_lt.cc index 192b5bb544c401..84cf0434f1008a 100644 --- a/third_party/xla/xla/stream_executor/rocm/hip_blas_lt.cc +++ b/third_party/xla/xla/stream_executor/rocm/hip_blas_lt.cc @@ -32,6 +32,7 @@ limitations under the License. #include "xla/tsl/platform/status_macros.h" #include "rocm/include/hip/library_types.h" #include "rocm/include/hipblas/hipblas.h" +#include "rocm/include/hipblaslt/hipblaslt-ext.hpp" #include "rocm/include/hipblaslt/hipblaslt.h" #include "rocm/include/rocblas/internal/rocblas-types.h" #include "rocm/rocm_config.h" @@ -156,10 +157,10 @@ static absl::StatusOr AsHipblasLtEpilogue( } // namespace absl::Status BlasLt::Init() { - hipblasLtHandle_t blas_lt; - SE_HIPBLAS_RETURN_IF_ERROR(hipblasLtCreate(&blas_lt)); + hipblasLtHandle_t handle; + SE_HIPBLAS_RETURN_IF_ERROR(hipblasLtCreate(&handle)); absl::MutexLock lock(mu_); - blas_lt_.reset(blas_lt); + handle_.reset(handle); return absl::OkStatus(); } @@ -225,56 +226,14 @@ absl::Status BlasLt::Init() { return std::move(desc); } -auto BlasLt::MatmulPlan::GetAlgorithmsForGroupedMatmul( - const Stream* stream, size_t max_algorithm_count, - size_t max_workspace_size) const - -> absl::StatusOr> { - std::vector heuristicResult; - - auto blas_lt = static_cast(gpu::BlasLt::Get(stream)); - absl::MutexLock lock(&blas_lt->mu_); - - std::unique_ptr activation = blas_lt->parent_->Activate(); - - auto problem = grouped_gemm_->getProblemTypes()[0]; - - grouped_gemm_->setMaxWorkspaceBytes(max_workspace_size); - - SE_HIPBLAS_RETURN_IF_ERROR(hipblaslt_ext::getAllAlgos( - blas_lt->blas_lt_.get(), hipblaslt_ext::GemmType::HIPBLASLT_GROUPED_GEMM, - problem.getOpA(), problem.getOpB(), problem.getTypeA(), - problem.getTypeB(), problem.getTypeC(), problem.getTypeD(), - problem.getTypeCompute(), heuristicResult)); - VLOG(2) << "Total heuristics found: " << heuristicResult.size(); - std::vector algorithms; - algorithms.reserve(max_algorithm_count); - for (hipblasLtMatmulHeuristicResult_t& result : heuristicResult) { - if (algorithms.size() >= max_algorithm_count) break; - size_t workspace_size = 0; - if ((result.state == HIPBLAS_STATUS_SUCCESS) && - (grouped_gemm_->isAlgoSupported(result.algo, workspace_size) == - HIPBLAS_STATUS_SUCCESS)) { - algorithms.push_back({result.algo, result.workspaceSize}); - } - } - - VLOG(2) << "Grouped GEMM algorithms found with epilogue " - << static_cast(grouped_gemm_epilogue_) << ": " - << algorithms.size(); - - return std::move(algorithms); -} - -auto BlasLt::MatmulPlan::GetAlgorithmsForMatmul(const Stream* stream, - size_t max_algorithm_count, - size_t max_workspace_size) const +auto BlasLt::RegularMatmulPlan::GetAlgorithms(size_t max_algorithm_count, + size_t max_workspace_size) const -> absl::StatusOr> { max_algorithm_count = std::min(max_algorithm_count, size_t{INT_MAX}); std::vector results(max_algorithm_count); { - auto blas_lt = static_cast(gpu::BlasLt::Get(stream)); - absl::MutexLock lock(blas_lt->mu_); - TF_RET_CHECK(blas_lt->blas_lt_ != nullptr); + absl::MutexLock lock(blas_lt_.mu_); + TF_RET_CHECK(blas_lt_.handle_ != nullptr); hipblasLtMatmulPreference_t hip_preference; SE_HIPBLAS_RETURN_IF_ERROR( @@ -288,29 +247,30 @@ auto BlasLt::MatmulPlan::GetAlgorithmsForMatmul(const Stream* stream, HIPBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, max_workspace_size)); - std::unique_ptr activation = blas_lt->parent_->Activate(); + std::unique_ptr activation = + blas_lt_.executor_->Activate(); // hipBlasLt requires setting the bias pointer (even a dummy one), otherwise // no algorithms can be found for "bias epilogues". This is to be removed // later when this limitation is gone. - if (op_desc_->has_bias_epilogue()) { + if (op_desc_.has_bias_epilogue()) { static int64_t dummy_pointer = 0xACEBALL; RETURN_IF_ERROR(SetAttr( - op_desc_->get(), HIPBLASLT_MATMUL_DESC_BIAS_POINTER, &dummy_pointer)); + op_desc_.get(), HIPBLASLT_MATMUL_DESC_BIAS_POINTER, &dummy_pointer)); } // hipBlasLt requires setting the a/b scale pointer (even a dummy one), // otherwise no algorithms can be found for "a/b scaling". This is to be // removed later when this limitation is gone. - switch (op_desc_->scale_mode()) { + switch (op_desc_.scale_mode()) { case gpu::ScaleMode::kNone: break; case gpu::ScaleMode::kTensorScaling: { static int64_t dummy_pointer = 0xACEBALL; - RETURN_IF_ERROR(SetAttr(op_desc_->get(), + RETURN_IF_ERROR(SetAttr(op_desc_.get(), HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER, &dummy_pointer)); - RETURN_IF_ERROR(SetAttr(op_desc_->get(), + RETURN_IF_ERROR(SetAttr(op_desc_.get(), HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER, &dummy_pointer)); break; @@ -318,17 +278,17 @@ auto BlasLt::MatmulPlan::GetAlgorithmsForMatmul(const Stream* stream, case gpu::ScaleMode::kBlockScaling: { #if TF_ROCM_VERSION >= 70000 static int64_t dummy_pointer = 0xACEBALL; - RETURN_IF_ERROR(SetAttr(op_desc_->get(), + RETURN_IF_ERROR(SetAttr(op_desc_.get(), HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER, &dummy_pointer)); - RETURN_IF_ERROR(SetAttr(op_desc_->get(), + RETURN_IF_ERROR(SetAttr(op_desc_.get(), HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER, &dummy_pointer)); hipblasLtMatmulMatrixScale_t mx_scale = HIPBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0; - RETURN_IF_ERROR(SetAttr(op_desc_->get(), + RETURN_IF_ERROR(SetAttr(op_desc_.get(), HIPBLASLT_MATMUL_DESC_A_SCALE_MODE, mx_scale)); - RETURN_IF_ERROR(SetAttr(op_desc_->get(), + RETURN_IF_ERROR(SetAttr(op_desc_.get(), HIPBLASLT_MATMUL_DESC_B_SCALE_MODE, mx_scale)); #else return absl::InternalError("Block scaling requires ROCm >= 7.0"); @@ -339,9 +299,9 @@ auto BlasLt::MatmulPlan::GetAlgorithmsForMatmul(const Stream* stream, int found_algorithm_count = 0; auto error = hipblasLtMatmulAlgoGetHeuristic( - blas_lt->blas_lt_.get(), op_desc_->get(), a_desc_->get(), - b_desc_->get(), c_desc_->get(), d_desc_->get(), preference.get(), - max_algorithm_count, results.data(), &found_algorithm_count); + blas_lt_.handle_.get(), op_desc_.get(), a_desc_.get(), b_desc_.get(), + c_desc_.get(), d_desc_.get(), preference.get(), max_algorithm_count, + results.data(), &found_algorithm_count); if (error != 0) { VLOG(0) << "hipblasLtMatmulAlgoGetHeuristic returned " << (int)error; SE_HIPBLAS_RETURN_IF_ERROR(error); @@ -356,20 +316,7 @@ auto BlasLt::MatmulPlan::GetAlgorithmsForMatmul(const Stream* stream, algorithms.push_back({result.algo, result.workspaceSize}); } } - return std::move(algorithms); -} - -auto BlasLt::MatmulPlan::GetAlgorithms(const Stream* stream, - size_t max_algorithm_count, - size_t max_workspace_size) const - -> absl::StatusOr> { - if (is_grouped()) { - return GetAlgorithmsForGroupedMatmul(stream, max_algorithm_count, - max_workspace_size); - } else { - return GetAlgorithmsForMatmul(stream, max_algorithm_count, - max_workspace_size); - } + return algorithms; } absl::StatusOr BlasLt::GetMatmulPlan( @@ -416,7 +363,7 @@ absl::StatusOr BlasLt::GetMatmulPlan( gpu::GetBlasComputationType( cfg.precision_algorithm, lhs_layout.dtype, output_layout.dtype, cfg.compute_precision, - parent_->GetDeviceDescription().gpu_compute_capability())); + executor_->GetDeviceDescription().gpu_compute_capability())); } if (lhs_layout.order == gpu::MatrixLayout::Order::kRowMajor) { @@ -444,32 +391,25 @@ absl::StatusOr BlasLt::GetMatmulPlan( // data type for fp8 matmul, which is different from cublasLt. This is a // workaround to match cublasLt behavior. if (epilogue == gpu::BlasLt::Epilogue::kBias) { - auto a_dtype = a_desc.type(); - auto b_dtype = b_desc.type(); - - auto bias_dtype = d_desc.type(); + auto a_dtype = a_desc.type(), b_dtype = b_desc.type(); if ((a_dtype == HIP_R_8F_E4M3_FNUZ || a_dtype == HIP_R_8F_E5M2_FNUZ) && (b_dtype == HIP_R_8F_E4M3_FNUZ || b_dtype == HIP_R_8F_E5M2_FNUZ)) { - auto d_dtype = d_desc.type(); - if (d_dtype == HIP_R_32F) { - bias_dtype = HIP_R_16BF; - } - - if (bias_dtype != d_dtype) { + auto bias_dtype = d_desc.type(); + if (bias_dtype == HIP_R_32F) { RETURN_IF_ERROR(SetAttr( - op_desc.get(), HIPBLASLT_MATMUL_DESC_BIAS_DATA_TYPE, bias_dtype)); + op_desc.get(), HIPBLASLT_MATMUL_DESC_BIAS_DATA_TYPE, HIP_R_16BF)); } } } #endif // TF_ROCM_VERSION >= 60000 - return std::make_unique(std::move(op_desc), std::move(a_desc), - std::move(b_desc), std::move(c_desc), - std::move(d_desc), cfg.alpha, cfg.beta, - must_swap_operands); + return std::make_unique( + *this, std::move(op_desc), std::move(a_desc), std::move(b_desc), + std::move(c_desc), std::move(d_desc), cfg.alpha, cfg.beta, + must_swap_operands); } -absl::Status BlasLt::MatmulPlan::DoMatmul( +absl::Status BlasLt::RegularMatmulPlan::DoMatmul( Stream* stream, const void* alpha, const void* beta, const gpu::BlasLt::MemoryArgs& args, blas::ProfileResult* profile_result) const { @@ -486,10 +426,8 @@ absl::Status BlasLt::MatmulPlan::DoMatmul( } } - auto blas_lt = static_cast(gpu::BlasLt::Get(stream)); - TF_RET_CHECK(blas_lt != nullptr); absl::Status status = - blas_lt->parent_->RecordApiTrace(StreamExecutor::GemmCallTrace{ + blas_lt_.executor_->RecordApiTrace(StreamExecutor::GemmCallTrace{ StreamExecutor::GemmCallTrace::GemmType::kBlasLt, 0, a.size(), b.size()}); std::unique_ptr timer; @@ -516,34 +454,34 @@ absl::Status BlasLt::MatmulPlan::DoMatmul( auto palgo = std::any_cast(&algorithm_->opaque_algo); { - absl::MutexLock lock(blas_lt->mu_); - TF_RET_CHECK(blas_lt->blas_lt_ != nullptr); + absl::MutexLock lock(blas_lt_.mu_); + TF_RET_CHECK(blas_lt_.handle_ != nullptr); // We must set the bias and aux pointers while holding the mutex, to avoid a // potential race condition from multiple threads sharing the same plan. - if (op_desc_->has_bias_epilogue() && args.bias != nullptr) { - RETURN_IF_ERROR(SetAttr(op_desc_->get(), + if (op_desc_.has_bias_epilogue() && args.bias != nullptr) { + RETURN_IF_ERROR(SetAttr(op_desc_.get(), HIPBLASLT_MATMUL_DESC_BIAS_POINTER, args.bias.opaque())); } #if TF_ROCM_VERSION >= 60000 if (a_scale != nullptr) { - RETURN_IF_ERROR(SetAttr(op_desc_->get(), + RETURN_IF_ERROR(SetAttr(op_desc_.get(), HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER, a_scale.opaque())); } if (b_scale != nullptr) { - RETURN_IF_ERROR(SetAttr(op_desc_->get(), + RETURN_IF_ERROR(SetAttr(op_desc_.get(), HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER, b_scale.opaque())); } if (args.c_scale != nullptr) { - RETURN_IF_ERROR(SetAttr(op_desc_->get(), + RETURN_IF_ERROR(SetAttr(op_desc_.get(), HIPBLASLT_MATMUL_DESC_C_SCALE_POINTER, args.c_scale.opaque())); } if (args.d_scale != nullptr) { - RETURN_IF_ERROR(SetAttr(op_desc_->get(), + RETURN_IF_ERROR(SetAttr(op_desc_.get(), HIPBLASLT_MATMUL_DESC_D_SCALE_POINTER, args.d_scale.opaque())); } @@ -563,40 +501,33 @@ absl::Status BlasLt::MatmulPlan::DoMatmul( "hipblaslt does not support auxiliary inputs / outputs"); } - std::unique_ptr activation = blas_lt->parent_->Activate(); + std::unique_ptr activation = + blas_lt_.executor_->Activate(); if (palgo != nullptr) { - SE_HIPBLAS_RETURN_IF_ERROR(hipblasLtMatmul( - blas_lt->blas_lt_.get(), op_desc_->get(), alpha, a.opaque(), - a_desc_->get(), b.opaque(), b_desc_->get(), beta, args.c.opaque(), - c_desc_->get(), args.d.opaque(), d_desc_->get(), palgo, - workspace_addr, workspace_size, - absl::bit_cast( - stream->platform_specific_handle().stream))); + SE_HIPBLAS_RETURN_IF_ERROR( + hipblasLtMatmul(blas_lt_.handle_.get(), op_desc_.get(), alpha, + a.opaque(), a_desc_.get(), b.opaque(), b_desc_.get(), + beta, args.c.opaque(), c_desc_.get(), args.d.opaque(), + d_desc_.get(), palgo, workspace_addr, workspace_size, + absl::bit_cast( + stream->platform_specific_handle().stream))); } else { return absl::InternalError("hipblaslt: Invalid algorithm type"); } } - typedef struct __attribute__((packed, aligned(8))) _rocblaslt_matmul_algo { - uint8_t data[8] = {0}; - bool fallback = false; - size_t max_workspace_bytes = 0; - } rocblaslt_matmul_algo; - if (profile_result != nullptr) { ASSIGN_OR_RETURN(absl::Duration elapsed, timer->GetElapsedDuration()); // set algorithm ID to be unique (otherwise it gets kDefaultAlgorithm ID) - auto roc_algo = (const rocblaslt_matmul_algo*)palgo; - auto pindex = (int*)roc_algo->data; - profile_result->set_algorithm(static_cast(*pindex)); + profile_result->set_algorithm(hipblaslt_ext::getIndexFromAlgo(*palgo)); profile_result->set_is_valid(true); profile_result->set_elapsed_time_in_ms(absl::ToDoubleMilliseconds(elapsed)); } return absl::OkStatus(); } -absl::Status BlasLt::MatmulPlan::ExecuteRegularMatmul( +absl::Status BlasLt::RegularMatmulPlan::ExecuteOnStream( Stream* stream, const gpu::BlasLt::MemoryArgs& args, blas::ProfileResult* profile_result) const { auto wrapped_matmul = [&](auto scale) { @@ -604,16 +535,16 @@ absl::Status BlasLt::MatmulPlan::ExecuteRegularMatmul( Scale salpha; if constexpr (std::is_same_v || std::is_same_v) { - salpha = static_cast(*alpha_); + salpha = static_cast(alpha_); } else { - salpha = static_cast(alpha_->real()); + salpha = static_cast(alpha_.real()); } - Scale sbeta = static_cast(*beta_); + Scale sbeta = static_cast(beta_); return DoMatmul(stream, &salpha, &sbeta, args, profile_result); }; - std::tuple operand_types{a_desc_->type(), b_desc_->type(), c_desc_->type(), - d_desc_->type()}; + std::tuple operand_types{a_desc_.type(), b_desc_.type(), c_desc_.type(), + d_desc_.type()}; #define TYPED_MATMUL(Scale, ATYPE, BTYPE, CTYPE, DTYPE) \ if (operand_types == std::tuple{ATYPE, BTYPE, CTYPE, DTYPE}) { \ @@ -750,54 +681,73 @@ absl::Status BlasLt::MatmulPlan::ExecuteRegularMatmul( return xla::Internal("Unexpected dtype"); } -/*static*/ absl::StatusOr> -BlasLt::MatmulPlan::InitializeGroupedGemm(gpu::GroupedGemmConfig cfg, - Epilogue epilogue, - hipblasLtHandle_t blas_lt_handle, - blas::ComputationType compute_type) { - const bool must_swap_operands = cfg.must_swap_operands; - // Use `new` directly because the grouped-matmul constructor is private. - // This is equivalent to make_unique; the private constructor is accessible - // here because this is a static member of MatmulPlan. - auto plan = std::unique_ptr( - new MatmulPlan(std::move(cfg), must_swap_operands, epilogue)); - RETURN_IF_ERROR(plan->DoInitializeGroupedGemm(blas_lt_handle, compute_type)); - return plan; +auto BlasLt::GroupedMatmulPlan::GetAlgorithms(size_t max_algorithm_count, + size_t max_workspace_size) const + -> absl::StatusOr> { + std::vector heuristicResult; + + absl::MutexLock lock(blas_lt_.mu_); + + std::unique_ptr activation = blas_lt_.executor_->Activate(); + + auto problem = grouped_gemm_->getProblemTypes()[0]; + + grouped_gemm_->setMaxWorkspaceBytes(max_workspace_size); + + SE_HIPBLAS_RETURN_IF_ERROR(hipblaslt_ext::getAllAlgos( + blas_lt_.handle_.get(), hipblaslt_ext::GemmType::HIPBLASLT_GROUPED_GEMM, + problem.getOpA(), problem.getOpB(), problem.getTypeA(), + problem.getTypeB(), problem.getTypeC(), problem.getTypeD(), + problem.getTypeCompute(), heuristicResult)); + VLOG(2) << "Total heuristics found: " << heuristicResult.size(); + std::vector algorithms; + algorithms.reserve(max_algorithm_count); + for (hipblasLtMatmulHeuristicResult_t& result : heuristicResult) { + if (algorithms.size() >= max_algorithm_count) { + break; + } + size_t workspace_size = 0; + if ((result.state == HIPBLAS_STATUS_SUCCESS) && + (grouped_gemm_->isAlgoSupported(result.algo, workspace_size) == + HIPBLAS_STATUS_SUCCESS)) { + algorithms.push_back({result.algo, result.workspaceSize}); + } + } + + VLOG(2) << "Grouped GEMM algorithms found with epilogue " + << static_cast(epilogue_) << ": " << algorithms.size(); + + return std::move(algorithms); } -absl::Status BlasLt::MatmulPlan::DoInitializeGroupedGemm( - hipblasLtHandle_t blas_lt_handle, blas::ComputationType compute_type) { - auto batch_stride_a = (cfg_->m * cfg_->k); - auto batch_stride_b = (cfg_->n * cfg_->k); - if (cfg_->ragged_mode == gpu::RaggedDotMode::kRaggedNonContracting) { - if (cfg_->must_swap_operands) { - batch_stride_a *= cfg_->group_count; +absl::Status BlasLt::GroupedMatmulPlan::DoInitialize( + blas::ComputationType compute_type, Epilogue epilogue) { + epilogue_ = epilogue; + auto batch_stride_a = (cfg_.m * cfg_.k); + auto batch_stride_b = (cfg_.n * cfg_.k); + if (cfg_.ragged_mode == gpu::RaggedDotMode::kRaggedNonContracting) { + if (cfg_.must_swap_operands) { + batch_stride_a *= cfg_.group_count; } else { - batch_stride_b *= cfg_->group_count; + batch_stride_b *= cfg_.group_count; } } - grouped_gemm_ = std::make_unique( - blas_lt_handle, AsHipblasOperation(cfg_->trans_a), - AsHipblasOperation(cfg_->trans_b), AsHipblasDataType(cfg_->type_a), - AsHipblasDataType(cfg_->type_b), AsHipblasDataType(cfg_->type_c), - AsHipblasDataType(cfg_->type_d), AsHipblasComputeType(compute_type)); - - std::vector v_m(cfg_->group_count, cfg_->m), - v_n(cfg_->group_count, cfg_->n), v_k(cfg_->group_count, cfg_->k), - v_batch_count(cfg_->group_count, cfg_->batch_count), - v_lda(cfg_->group_count, cfg_->lhs_leading_dim_stride), - v_ldb(cfg_->group_count, cfg_->rhs_leading_dim_stride), - v_ldc(cfg_->group_count, cfg_->output_leading_dim_stride), - v_ldd(cfg_->group_count, cfg_->output_leading_dim_stride), - v_strideA(cfg_->group_count, batch_stride_a), - v_strideB(cfg_->group_count, batch_stride_b), - v_strideC(cfg_->group_count, (cfg_->m * cfg_->n)), - v_strideD(cfg_->group_count, (cfg_->m * cfg_->n)); - - switch (cfg_->ragged_mode) { + std::vector v_m(cfg_.group_count, cfg_.m), + v_n(cfg_.group_count, cfg_.n), v_k(cfg_.group_count, cfg_.k), + v_batch_count(cfg_.group_count, cfg_.batch_count), + v_lda(cfg_.group_count, cfg_.lhs_leading_dim_stride), + v_ldb(cfg_.group_count, cfg_.rhs_leading_dim_stride), + v_ldc(cfg_.group_count, cfg_.output_leading_dim_stride), + v_ldd(cfg_.group_count, cfg_.output_leading_dim_stride), + v_strideA(cfg_.group_count, batch_stride_a), + v_strideB(cfg_.group_count, batch_stride_b), + v_strideC(cfg_.group_count, (cfg_.m * cfg_.n)), + v_strideD(cfg_.group_count, (cfg_.m * cfg_.n)); + + switch (cfg_.ragged_mode) { case gpu::RaggedDotMode::kRaggedNonContracting: { - if (cfg_->must_swap_operands) { + if (cfg_.must_swap_operands) { // ragged dimension in the n dimension std::fill(v_n.begin() + 1, v_n.end(), 1); } else { @@ -815,26 +765,24 @@ absl::Status BlasLt::MatmulPlan::DoInitializeGroupedGemm( } } - std::vector epilogue(cfg_->group_count); - std::vector inputs(cfg_->group_count); + std::vector epilogues(cfg_.group_count); + std::vector inputs(cfg_.group_count); // Convert the epilogue from the stored member variable - ASSIGN_OR_RETURN(auto hip_epilogue, - AsHipblasLtEpilogue(grouped_gemm_epilogue_)); + ASSIGN_OR_RETURN(auto hip_epilogue, AsHipblasLtEpilogue(epilogue)); - float salpha = cfg_->alpha.real(); - float sbeta = cfg_->beta; + float salpha = cfg_.alpha.real(); + float sbeta = cfg_.beta; // Dummy bias pointer for initialization (similar to A, B, C, D) static void* dummy_bias = reinterpret_cast(~0ULL); - for (int64_t i = 0; i < cfg_->group_count; i++) { - epilogue[i].setMode(hip_epilogue); + for (int64_t i = 0; i < cfg_.group_count; i++) { + epilogues[i].setMode(hip_epilogue); // Set bias data type and dummy bias pointer for bias epilogues - if (grouped_gemm_epilogue_ == Epilogue::kBias || - grouped_gemm_epilogue_ == Epilogue::kBiasThenReLU || - grouped_gemm_epilogue_ == Epilogue::kBiasThenGELU || - grouped_gemm_epilogue_ == Epilogue::kBiasThenSILU) { - epilogue[i].setBiasDataType(AsHipblasDataType(cfg_->type_d)); + if (epilogue == Epilogue::kBias || epilogue == Epilogue::kBiasThenReLU || + epilogue == Epilogue::kBiasThenGELU || + epilogue == Epilogue::kBiasThenSILU) { + epilogues[i].setBiasDataType(AsHipblasDataType(cfg_.type_d)); inputs[i].setBias(dummy_bias); } inputs[i].setA(reinterpret_cast(~0ULL)); @@ -846,17 +794,23 @@ absl::Status BlasLt::MatmulPlan::DoInitializeGroupedGemm( } hipblaslt_ext::GemmProblemType problem( - AsHipblasOperation(cfg_->trans_a), AsHipblasOperation(cfg_->trans_b), - AsHipblasDataType(cfg_->type_a), AsHipblasDataType(cfg_->type_b), - AsHipblasDataType(cfg_->type_c), AsHipblasDataType(cfg_->type_d), + AsHipblasOperation(cfg_.trans_a), AsHipblasOperation(cfg_.trans_b), + AsHipblasDataType(cfg_.type_a), AsHipblasDataType(cfg_.type_b), + AsHipblasDataType(cfg_.type_c), AsHipblasDataType(cfg_.type_d), AsHipblasComputeType(compute_type)); + absl::MutexLock lock(blas_lt_.mu_); + grouped_gemm_ = std::make_unique( + blas_lt_.handle_.get(), problem.getOpA(), problem.getOpB(), + problem.getTypeA(), problem.getTypeB(), problem.getTypeC(), + problem.getTypeD(), problem.getTypeCompute()); + // Note that Matrices given to HipBlasLt Group-Gemm // are expected to be in COLUMN-MAJOR order. SE_HIPBLAS_RETURN_IF_ERROR(grouped_gemm_->setProblem( v_m, v_n, v_k, v_batch_count, v_lda, v_ldb, v_ldc, v_ldd, v_strideA, - v_strideB, v_strideC, v_strideD, epilogue, inputs, problem)); + v_strideB, v_strideC, v_strideD, epilogues, inputs, problem)); // UserArgument is expecting specific code for activation and bias types. // These are defined by the hipBLASLt library during the problem @@ -872,9 +826,10 @@ absl::Status BlasLt::MatmulPlan::DoInitializeGroupedGemm( // Get default UserArguments from hipBLASLt and save required parameters auto default_ua = - std::make_unique(cfg_->group_count); - grouped_gemm_->getDefaultValueForDeviceUserArguments( - static_cast(default_ua.get())); + std::make_unique(cfg_.group_count); + SE_HIPBLAS_RETURN_IF_ERROR( + grouped_gemm_->getDefaultValueForDeviceUserArguments( + static_cast(default_ua.get()))); // The ragged-dot API enforce that activation and bias types are the same // for all the gemm operations. // We can therefore only retrieve/verify value of the first gemm op. @@ -892,7 +847,7 @@ absl::Status BlasLt::MatmulPlan::DoInitializeGroupedGemm( } absl::StatusOr BlasLt::GetGroupedMatmulPlan( - gpu::GroupedGemmConfig& cfg, Epilogue epilogue) const { + const gpu::GroupedGemmConfig& cfg, Epilogue epilogue) const { auto compute_type = cfg.compute_type; if (!compute_type) { // obtain compute_type unless provided by the user ASSIGN_OR_RETURN(xla::PrimitiveType primitive_type_a, @@ -904,26 +859,15 @@ absl::StatusOr BlasLt::GetGroupedMatmulPlan( gpu::GetBlasComputationType( cfg.precision_algorithm, primitive_type_a, primitive_type_d, cfg.compute_precision, - parent_->GetDeviceDescription().gpu_compute_capability())); - } - if (!compute_type) { - return absl::InternalError( - "This algorithm requires a non-zero compute_type!"); - } - - hipblasLtHandle_t blas_lt_handle; - { - absl::MutexLock lock(&mu_); - blas_lt_handle = blas_lt_.get(); + executor_->GetDeviceDescription().gpu_compute_capability())); } - ASSIGN_OR_RETURN( - auto plan, MatmulPlan::InitializeGroupedGemm( - std::move(cfg), epilogue, blas_lt_handle, *compute_type)); - return absl::StatusOr(std::move(plan)); + auto plan = std::make_unique(*this, cfg); + RETURN_IF_ERROR(plan->DoInitialize(*compute_type, epilogue)); + return plan; } -absl::Status BlasLt::MatmulPlan::ExecuteGroupedMatmul( +absl::Status BlasLt::GroupedMatmulPlan::ExecuteOnStream( Stream* stream, const MemoryArgs& args, blas::ProfileResult* profile_result) const { if (!algorithm_.has_value()) { @@ -932,18 +876,19 @@ absl::Status BlasLt::MatmulPlan::ExecuteGroupedMatmul( } auto palgo = std::any_cast(&algorithm_->opaque_algo); - auto blas_lt = static_cast(gpu::BlasLt::Get(stream)); - absl::MutexLock lock(&blas_lt->mu_); + if (palgo == nullptr) { + return absl::InternalError("Invalid algorithm type!"); + } + absl::MutexLock lock(blas_lt_.mu_); // The first chunk of the workspace is reserved for userargs. - if (algorithm_must_be_initialized_ || - !args.workspace.IsSameAs(saved_address_workspace_)) { + if (algorithm_dirty_ || !args.workspace.IsSameAs(saved_address_workspace_)) { void* addr_workspace = static_cast( static_cast(args.workspace.opaque()) + - sizeof(hipblaslt_ext::UserArguments) * cfg_->group_count); + sizeof(hipblaslt_ext::UserArguments) * cfg_.group_count); SE_HIPBLAS_RETURN_IF_ERROR( grouped_gemm_->initialize(*palgo, addr_workspace)); - algorithm_must_be_initialized_ = false; + algorithm_dirty_ = false; saved_address_workspace_ = args.workspace; } @@ -975,27 +920,28 @@ absl::Status BlasLt::MatmulPlan::ExecuteGroupedMatmul( } }; - DeviceAddressBase a = cfg_->must_swap_operands ? args.b : args.a; - DeviceAddressBase b = cfg_->must_swap_operands ? args.a : args.b; + DeviceAddressBase a = args.a, b = args.b; + if (cfg_.must_swap_operands) { + std::swap(a, b); + } const DeviceAddressBase& d_userArgs = args.workspace; - uint8_t log2_byte_width_elem_a = Log2ByteWidth(cfg_->type_a); - uint8_t log2_byte_width_elem_b = Log2ByteWidth(cfg_->type_b); - uint8_t log2_byte_width_elem_c = Log2ByteWidth(cfg_->type_c); - uint8_t log2_byte_width_elem_d = Log2ByteWidth(cfg_->type_d); + uint8_t log2_byte_width_elem_a = Log2ByteWidth(cfg_.type_a); + uint8_t log2_byte_width_elem_b = Log2ByteWidth(cfg_.type_b); + uint8_t log2_byte_width_elem_c = Log2ByteWidth(cfg_.type_c); + uint8_t log2_byte_width_elem_d = Log2ByteWidth(cfg_.type_d); auto group_size_bytewidth = - (cfg_->ragged_mode != gpu::RaggedDotMode::kRaggedBatch) + (cfg_.ragged_mode != gpu::RaggedDotMode::kRaggedBatch) ? static_cast(args.group_sizes.size() / - (cfg_->group_count * cfg_->batch_count)) - : static_cast(args.group_sizes.size() / cfg_->group_count); + (cfg_.group_count * cfg_.batch_count)) + : static_cast(args.group_sizes.size() / cfg_.group_count); - TF_RET_CHECK(blas_lt != nullptr); absl::Status status = - blas_lt->parent_->RecordApiTrace(StreamExecutor::GemmCallTrace{ + blas_lt_.executor_->RecordApiTrace(StreamExecutor::GemmCallTrace{ StreamExecutor::GemmCallTrace::GemmType::kBlasLt, 0, - cfg_->m * cfg_->k * cfg_->batch_count, - cfg_->k * cfg_->n * cfg_->batch_count * cfg_->group_count}); + cfg_.m * cfg_.k * cfg_.batch_count, + cfg_.k * cfg_.n * cfg_.batch_count * cfg_.group_count}); std::unique_ptr timer; if (profile_result != nullptr) { @@ -1003,36 +949,36 @@ absl::Status BlasLt::MatmulPlan::ExecuteGroupedMatmul( profile_result->warmup_run_executed())); } - uint32_t strideA1 = cfg_->lhs_leading_dim_stride; - uint32_t strideA2 = cfg_->m * cfg_->k; - uint32_t strideB1 = cfg_->rhs_leading_dim_stride; - uint32_t strideB2 = cfg_->n * cfg_->k; - if (cfg_->ragged_mode == gpu::RaggedDotMode::kRaggedNonContracting) { - if (cfg_->must_swap_operands) { - strideA2 *= cfg_->group_count; + uint32_t strideA1 = cfg_.lhs_leading_dim_stride; + uint32_t strideA2 = cfg_.m * cfg_.k; + uint32_t strideB1 = cfg_.rhs_leading_dim_stride; + uint32_t strideB2 = cfg_.n * cfg_.k; + if (cfg_.ragged_mode == gpu::RaggedDotMode::kRaggedNonContracting) { + if (cfg_.must_swap_operands) { + strideA2 *= cfg_.group_count; } else { - strideB2 *= cfg_->group_count; + strideB2 *= cfg_.group_count; } } - uint32_t strideC1 = cfg_->c_leading_dim_stride; - uint32_t strideC2 = cfg_->m * cfg_->n; - uint32_t strideD1 = cfg_->output_leading_dim_stride; - uint32_t strideD2 = cfg_->m * cfg_->n; + uint32_t strideC1 = cfg_.c_leading_dim_stride; + uint32_t strideC2 = cfg_.m * cfg_.n; + uint32_t strideD1 = cfg_.output_leading_dim_stride; + uint32_t strideD2 = cfg_.m * cfg_.n; auto hip_stream = absl::bit_cast(stream->platform_specific_handle().stream); - bool has_matrix_bias = (cfg_->beta != 0.0); + bool has_matrix_bias = (cfg_.beta != 0.0); GroupGemmUpdateArgs( hip_stream, d_userArgs, a, b, args.c, args.d, args.bias, args.group_sizes, group_size_bytewidth, log2_byte_width_elem_a, log2_byte_width_elem_b, - log2_byte_width_elem_c, log2_byte_width_elem_d, cfg_->stride_ragged_dim, - cfg_->stride_group_dim, cfg_->c_stride_ragged_dim, - cfg_->output_stride_ragged_dim, cfg_->must_swap_operands, cfg_->m, - cfg_->n, cfg_->k, cfg_->batch_count, strideA1, strideA2, strideB1, - strideB2, strideC1, strideC2, strideD1, strideD2, cfg_->ragged_mode, - cfg_->group_count, activation_type_, bias_type_, has_matrix_bias); + log2_byte_width_elem_c, log2_byte_width_elem_d, cfg_.stride_ragged_dim, + cfg_.stride_group_dim, cfg_.c_stride_ragged_dim, + cfg_.output_stride_ragged_dim, cfg_.must_swap_operands, cfg_.m, cfg_.n, + cfg_.k, cfg_.batch_count, strideA1, strideA2, strideB1, strideB2, + strideC1, strideC2, strideD1, strideD2, cfg_.ragged_mode, + cfg_.group_count, activation_type_, bias_type_, has_matrix_bias); SE_HIPBLAS_RETURN_IF_ERROR( grouped_gemm_->run(d_userArgs.opaque(), hip_stream)); @@ -1041,25 +987,13 @@ absl::Status BlasLt::MatmulPlan::ExecuteGroupedMatmul( if (profile_result != nullptr) { ASSIGN_OR_RETURN(absl::Duration elapsed, timer->GetElapsedDuration()); // set algorithm ID to be unique (otherwise it gets kDefaultAlgorithm ID) - hipblasLtMatmulAlgo_t algo = - std::any_cast(algorithm_->opaque_algo); - profile_result->set_algorithm(hipblaslt_ext::getIndexFromAlgo(algo)); + profile_result->set_algorithm(hipblaslt_ext::getIndexFromAlgo(*palgo)); profile_result->set_is_valid(true); profile_result->set_elapsed_time_in_ms(absl::ToDoubleMilliseconds(elapsed)); } return absl::OkStatus(); } -absl::Status BlasLt::MatmulPlan::ExecuteOnStream( - Stream* stream, const gpu::BlasLt::MemoryArgs& args, - blas::ProfileResult* profile_result) const { - if (is_grouped()) { - return ExecuteGroupedMatmul(stream, args, profile_result); - } else { - return ExecuteRegularMatmul(stream, args, profile_result); - } -} - } // namespace rocm } // namespace stream_executor diff --git a/third_party/xla/xla/stream_executor/rocm/hip_blas_lt.h b/third_party/xla/xla/stream_executor/rocm/hip_blas_lt.h index 3c74a721ede4c6..b404693dfb4597 100644 --- a/third_party/xla/xla/stream_executor/rocm/hip_blas_lt.h +++ b/third_party/xla/xla/stream_executor/rocm/hip_blas_lt.h @@ -14,6 +14,7 @@ limitations under the License. #define XLA_STREAM_EXECUTOR_ROCM_HIP_BLAS_LT_H_ #include +#include #include #include "absl/base/thread_annotations.h" @@ -21,6 +22,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/synchronization/mutex.h" #include "rocm/include/hipblaslt/hipblaslt-ext.hpp" +#include "rocm/include/hipblaslt/hipblaslt.h" #include "rocm/rocm_config.h" #include "xla/stream_executor/blas.h" #include "xla/stream_executor/device_address.h" @@ -94,104 +96,95 @@ class BlasLt : public gpu::BlasLt { gpu::ScaleMode scale_mode_; }; - struct MatmulPlan : public gpu::BlasLt::MatmulPlan { - // Constructor for regular matmul - MatmulPlan(MatmulDesc&& op_desc, MatrixLayout&& a_desc, - MatrixLayout&& b_desc, MatrixLayout&& c_desc, - MatrixLayout&& d_desc, xla::complex128 alpha, double beta, - bool must_swap_operands) - : op_desc_(std::move(op_desc)), + class RegularMatmulPlan : public gpu::BlasLt::MatmulPlan { + public: + RegularMatmulPlan(const BlasLt& blas_lt, MatmulDesc&& op_desc, + MatrixLayout&& a_desc, MatrixLayout&& b_desc, + MatrixLayout&& c_desc, MatrixLayout&& d_desc, + xla::complex128 alpha, double beta, + bool must_swap_operands) + : blas_lt_(blas_lt), + op_desc_(std::move(op_desc)), a_desc_(std::move(a_desc)), b_desc_(std::move(b_desc)), c_desc_(std::move(c_desc)), d_desc_(std::move(d_desc)), alpha_(alpha), beta_(beta), - must_swap_operands_(must_swap_operands), - grouped_gemm_(nullptr) {} + must_swap_operands_(must_swap_operands) {} - ~MatmulPlan() override = default; + ~RegularMatmulPlan() override = default; absl::Status ExecuteOnStream( Stream* stream, const gpu::BlasLt::MemoryArgs& args, blas::ProfileResult* profile_result) const override; absl::StatusOr> GetAlgorithms( - const Stream* stream, size_t max_algorithm_count = 128, - size_t max_workspace_size = 1ll << 32) const override; + size_t max_algorithm_count, size_t max_workspace_size) const override; absl::Status SetAlgorithm(const MatmulAlgorithm& algorithm) override { algorithm_ = algorithm; - algorithm_must_be_initialized_ = true; return absl::OkStatus(); } - bool is_grouped() const { return grouped_gemm_ != nullptr; } - - // Static factory for grouped-GEMM plans. Creates a fully-initialized - // MatmulPlan or returns an error; it is impossible to construct a grouped - // MatmulPlan without going through this function. - static absl::StatusOr> InitializeGroupedGemm( - gpu::GroupedGemmConfig cfg, Epilogue epilogue, - hipblasLtHandle_t blas_lt_handle, blas::ComputationType compute_type); - protected: absl::Status DoMatmul(Stream* stream, const void* alpha, const void* beta, const gpu::BlasLt::MemoryArgs& args, blas::ProfileResult* profile_result) const; private: - // Private constructor for grouped matmul. Callers must use - // InitializeGroupedGemm() instead. - MatmulPlan(gpu::GroupedGemmConfig&& cfg, bool must_swap_operands, - Epilogue epilogue) - : must_swap_operands_(must_swap_operands), - cfg_(std::move(cfg)), - grouped_gemm_epilogue_(epilogue), - grouped_gemm_(nullptr) {} - - // Performs the hipBLASLt grouped-GEMM initialization work. Called by the - // static factory InitializeGroupedGemm(). - absl::Status DoInitializeGroupedGemm(hipblasLtHandle_t blas_lt_handle, - blas::ComputationType compute_type); - - absl::StatusOr> GetAlgorithmsForGroupedMatmul( - const Stream* stream, size_t max_algorithm_count, - size_t max_workspace_size) const; - absl::StatusOr> GetAlgorithmsForMatmul( - const Stream* stream, size_t max_algorithm_count, - size_t max_workspace_size) const; - absl::Status ExecuteRegularMatmul( - Stream* stream, const gpu::BlasLt::MemoryArgs& args, - blas::ProfileResult* profile_result) const; - absl::Status ExecuteGroupedMatmul( - Stream* stream, const gpu::BlasLt::MemoryArgs& args, - blas::ProfileResult* profile_result) const; - - // TODO(cjfj): Add consistency checks for types, shapes, etc.? - // Regular matmul members (optional for grouped matmul) - std::optional op_desc_; - std::optional a_desc_; - std::optional b_desc_; - std::optional c_desc_; - std::optional d_desc_; - std::optional alpha_; - std::optional beta_; + const BlasLt& blas_lt_; + MatmulDesc op_desc_; + MatrixLayout a_desc_; + MatrixLayout b_desc_; + MatrixLayout c_desc_; + MatrixLayout d_desc_; + xla::complex128 alpha_; + double beta_; bool must_swap_operands_; - std::optional algorithm_; // selected algorithm - // Grouped matmul members - std::optional cfg_; - Epilogue grouped_gemm_epilogue_ = Epilogue::kDefault; + mutable std::optional algorithm_; // selected algorithm + }; // class RegularMatmulPlan + + class GroupedMatmulPlan : public gpu::BlasLt::MatmulPlan { + public: + friend class BlasLt; + + GroupedMatmulPlan(const BlasLt& blas_lt, const gpu::GroupedGemmConfig& cfg) + : blas_lt_(blas_lt), cfg_(cfg) {} + + ~GroupedMatmulPlan() override = default; + + absl::Status ExecuteOnStream( + Stream* stream, const gpu::BlasLt::MemoryArgs& args, + blas::ProfileResult* profile_result) const override; + + absl::StatusOr> GetAlgorithms( + size_t max_algorithm_count, size_t max_workspace_size) const override; + + absl::Status SetAlgorithm(const MatmulAlgorithm& algorithm) override { + algorithm_ = algorithm; + algorithm_dirty_ = true; + return absl::OkStatus(); + } + + private: + absl::Status DoInitialize(blas::ComputationType compute_type, + Epilogue epilogue); + + const BlasLt& blas_lt_; + gpu::GroupedGemmConfig cfg_; + Epilogue epilogue_ = Epilogue::kDefault; std::unique_ptr grouped_gemm_; - mutable bool algorithm_must_be_initialized_ = false; + mutable std::optional algorithm_; // selected algorithm + mutable bool algorithm_dirty_ = false; mutable DeviceAddressBase saved_address_workspace_{}; // Saved default activation parameters from hipBLASLt int32_t activation_type_ = 0; int8_t bias_type_ = 0; - }; // class MatmulPlan + }; // class GroupedMatmulPlan - explicit BlasLt(StreamExecutor* parent) - : parent_(parent), blas_lt_(nullptr, hipblasLtDestroy) {} + explicit BlasLt(StreamExecutor* executor) + : executor_(executor), handle_(nullptr, hipblasLtDestroy) {} absl::Status Init() override; @@ -199,14 +192,14 @@ class BlasLt : public gpu::BlasLt { Epilogue epilogue) const override; absl::StatusOr GetGroupedMatmulPlan( - gpu::GroupedGemmConfig& config, Epilogue epilogue) const override; + const gpu::GroupedGemmConfig& config, Epilogue epilogue) const override; ~BlasLt() override = default; private: - StreamExecutor* parent_; + StreamExecutor* executor_; mutable absl::Mutex mu_; - Owned blas_lt_ ABSL_GUARDED_BY(mu_); + Owned handle_ ABSL_GUARDED_BY(mu_); }; } // namespace rocm diff --git a/third_party/xla/xla/stream_executor/sycl/sycl_blas_lt.cc b/third_party/xla/xla/stream_executor/sycl/sycl_blas_lt.cc index bc351e7f41f2da..f837c75200fde2 100644 --- a/third_party/xla/xla/stream_executor/sycl/sycl_blas_lt.cc +++ b/third_party/xla/xla/stream_executor/sycl/sycl_blas_lt.cc @@ -31,12 +31,6 @@ auto BlasLt::GetMatmulPlan(const gpu::GemmConfig& config, return std::make_unique(config, epilogue); } -absl::StatusOr BlasLt::GetGroupedMatmulPlan( - gpu::GroupedGemmConfig& config, Epilogue epilogue) const { - return absl::UnimplementedError( - "Grouped GEMM is not supported for Sycl BlasLt"); -} - absl::Status BlasLt::MatmulPlan::ExecuteOnStream( Stream* stream, const gpu::BlasLt::MemoryArgs& args, blas::ProfileResult* profile_result) const { @@ -44,8 +38,7 @@ absl::Status BlasLt::MatmulPlan::ExecuteOnStream( "SyclBlasLt MatmulPlan::ExecuteOnStream not implemented"); } -auto BlasLt::MatmulPlan::GetAlgorithms(const Stream* stream, - size_t max_algorithm_count, +auto BlasLt::MatmulPlan::GetAlgorithms(size_t max_algorithm_count, size_t max_workspace_size) const -> absl::StatusOr> { absl::MutexLock lock(&mu_); diff --git a/third_party/xla/xla/stream_executor/sycl/sycl_blas_lt.h b/third_party/xla/xla/stream_executor/sycl/sycl_blas_lt.h index c70598be5d68d4..0fed6047c8ec2d 100644 --- a/third_party/xla/xla/stream_executor/sycl/sycl_blas_lt.h +++ b/third_party/xla/xla/stream_executor/sycl/sycl_blas_lt.h @@ -36,7 +36,10 @@ class BlasLt : public gpu::BlasLt { const gpu::GemmConfig& config, Epilogue epilogue) const override; absl::StatusOr GetGroupedMatmulPlan( - gpu::GroupedGemmConfig& config, Epilogue epilogue) const override; + const gpu::GroupedGemmConfig& config, Epilogue epilogue) const override { + return absl::UnimplementedError( + "Grouped GEMM is not supported for Sycl BlasLt"); + } ~BlasLt() override = default; @@ -51,8 +54,7 @@ class BlasLt : public gpu::BlasLt { blas::ProfileResult* profile_result) const override; absl::StatusOr> GetAlgorithms( - const Stream* stream, size_t max_algorithm_count, - size_t max_workspace_size) const override; + size_t max_algorithm_count, size_t max_workspace_size) const override; absl::Status SetAlgorithm(const MatmulAlgorithm& algorithm) override { // TODO(intel-tf): Do we need a lock here? diff --git a/third_party/xla/xla/stream_executor/tpu/BUILD b/third_party/xla/xla/stream_executor/tpu/BUILD index 84c4d37c21d82f..86a61a8f6fff1f 100644 --- a/third_party/xla/xla/stream_executor/tpu/BUILD +++ b/third_party/xla/xla/stream_executor/tpu/BUILD @@ -341,6 +341,7 @@ cc_library( "//xla/stream_executor:stream", "//xla/stream_executor:stream_executor_h", "//xla/tsl/c:tsl_status_internal", + "//xla/tsl/platform:status_macros", "@com_google_absl//absl/base", "@com_google_absl//absl/cleanup", "@com_google_absl//absl/container:flat_hash_map", @@ -579,6 +580,7 @@ cc_library( "//xla/service:hlo_cost_analysis", "//xla/service:hlo_proto_cc", "//xla/stream_executor:platform", + "//xla/tsl/platform:status_macros", "@com_google_absl//absl/cleanup", "@com_google_absl//absl/status:statusor", "@tsl//tsl/platform:statusor", @@ -618,6 +620,7 @@ cc_library( "//xla/stream_executor:stream", "//xla/stream_executor:stream_executor_h", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/log", @@ -649,6 +652,7 @@ cc_library( "//xla/service:shaped_buffer", "//xla/stream_executor:device_address", "//xla/stream_executor:stream", + "//xla/tsl/platform:status_macros", "@com_google_absl//absl/cleanup", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/log", diff --git a/third_party/xla/xla/stream_executor/tpu/tpu_executable.cc b/third_party/xla/xla/stream_executor/tpu/tpu_executable.cc index b5f5c6d80017ab..77d531d660bdf6 100644 --- a/third_party/xla/xla/stream_executor/tpu/tpu_executable.cc +++ b/third_party/xla/xla/stream_executor/tpu/tpu_executable.cc @@ -27,6 +27,7 @@ limitations under the License. #include "absl/log/check.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/service/executable.h" #include "xla/service/service_executable_run_options.h" @@ -222,8 +223,8 @@ absl::StatusOr> TpuExecutable::Deserialize( absl::Cleanup cleanup_c_module = [&c_module]() { ApiConverter::Destroy(&c_module); }; - TF_ASSIGN_OR_RETURN(std::unique_ptr hlo_module, - ApiConverter::FromC(c_module)); + ASSIGN_OR_RETURN(std::unique_ptr hlo_module, + ApiConverter::FromC(c_module)); return std::make_unique(se_executable, std::move(hlo_module)); } diff --git a/third_party/xla/xla/stream_executor/tpu/tpu_executable_interface.cc b/third_party/xla/xla/stream_executor/tpu/tpu_executable_interface.cc index 7ff8c7855457d0..d93fdfbde272c8 100644 --- a/third_party/xla/xla/stream_executor/tpu/tpu_executable_interface.cc +++ b/third_party/xla/xla/stream_executor/tpu/tpu_executable_interface.cc @@ -27,6 +27,7 @@ limitations under the License. #include "absl/log/log.h" #include "absl/status/status.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/ir/hlo_input_output_alias_config.h" #include "xla/layout_util.h" #include "xla/service/compiler.h" @@ -53,15 +54,14 @@ namespace { static absl::Status PopulateResultTupleBuffers(const ShapedBuffer& result, se::Stream* stream, se::Stream* transfer_stream) { - TF_ASSIGN_OR_RETURN( - auto transfer_manager, - TransferManager::GetForPlatform(stream->parent()->GetPlatform())); + ASSIGN_OR_RETURN(auto transfer_manager, TransferManager::GetForPlatform( + stream->parent()->GetPlatform())); if (transfer_manager->CanShapedBufferBeAccessedNow(stream->parent(), result)) { - TF_RETURN_IF_ERROR(transfer_manager->WriteTupleIndexTablesAsync( + RETURN_IF_ERROR(transfer_manager->WriteTupleIndexTablesAsync( transfer_stream ? transfer_stream : stream, result)); if (transfer_stream && transfer_stream != stream) { - TF_RETURN_IF_ERROR(stream->WaitFor(transfer_stream)); + RETURN_IF_ERROR(stream->WaitFor(transfer_stream)); } return absl::OkStatus(); } else { @@ -79,9 +79,9 @@ TpuExecutableInterface::AllocateOutputMemoryWithInputReuse( se::Stream* transfer_stream) { auto stream_exec = stream->parent(); auto platform = stream_exec->GetPlatform(); - TF_ASSIGN_OR_RETURN(auto transfer_manager, - TransferManager::GetForPlatform(platform)); - TF_ASSIGN_OR_RETURN(auto compiler, Compiler::GetForPlatform(platform->id())); + ASSIGN_OR_RETURN(auto transfer_manager, + TransferManager::GetForPlatform(platform)); + ASSIGN_OR_RETURN(auto compiler, Compiler::GetForPlatform(platform->id())); auto shape_size_fn = compiler->ShapeSizeBytesFunction(); auto device_ordinal = stream_exec->device_ordinal(); VLOG(3) << "AllocateOutputMemoryWithInputReuse, device = " << device_ordinal @@ -100,7 +100,7 @@ TpuExecutableInterface::AllocateOutputMemoryWithInputReuse( Shape device_shape = shape; xla::ShapeUtil::ForEachMutableSubshape(&device_shape, update_layout); - TF_RETURN_IF_ERROR(alias_config.ForEachAliasWithStatus( + RETURN_IF_ERROR(alias_config.ForEachAliasWithStatus( [&](const ShapeIndex& output_index, std::optional alias) -> absl::Status { @@ -185,7 +185,7 @@ TpuExecutableInterface::AllocateOutputMemoryWithInputReuse( const Shape& on_device_shape = result.Result().on_device_shape(); const Shape& on_device_subshape = ShapeUtil::GetSubshape(on_device_shape, result_index); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( auto allocated_buffer, allocator->Allocate(device_ordinal, allocation_bytes, /*retry_on_failure=*/true, @@ -201,7 +201,7 @@ TpuExecutableInterface::AllocateOutputMemoryWithInputReuse( << " parameter buffers (total result buffer size: " << total_result_buffer_bytes << ")"; - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( PopulateResultTupleBuffers(result.Result(), stream, transfer_stream)); return std::move(result); } @@ -221,7 +221,7 @@ absl::StatusOr TpuExecutableInterface::ExecuteAsyncOnStream( const HloInputOutputAliasConfig& alias_config = !has_module() ? HloInputOutputAliasConfig() : module().input_output_alias_config(); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( ExecutionOutput result, AllocateOutputMemoryWithInputReuse( shape, alias_config, run_options->allocator(), &arguments, stream, @@ -261,7 +261,7 @@ absl::StatusOr TpuExecutableInterface::ExecuteAsyncOnStream( // arguments. MarkToBeReleasedArguments(absl::MakeSpan(arguments), result); - TF_RETURN_IF_ERROR(LoadProgramAndEnqueueToStream( + RETURN_IF_ERROR(LoadProgramAndEnqueueToStream( *run_options, memory_bases, result.Result().root_buffer(), cross_program_prefetch_addrs, cross_program_prefetch_offsets)); return std::move(result); diff --git a/third_party/xla/xla/stream_executor/tpu/tpu_executor.cc b/third_party/xla/xla/stream_executor/tpu/tpu_executor.cc index f5b89cf2e09499..2c55a0764ccd95 100644 --- a/third_party/xla/xla/stream_executor/tpu/tpu_executor.cc +++ b/third_party/xla/xla/stream_executor/tpu/tpu_executor.cc @@ -26,6 +26,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/stream_executor/allocator_stats.h" #include "xla/stream_executor/device_address.h" #include "xla/stream_executor/device_description.h" @@ -91,7 +92,7 @@ absl::StatusOr> TpuExecutor::CreateEvent() { StatusHelper status; ExecutorApiFn()->TpuExecutor_AllocateEventFn(executor_, se_event, status.c_status); - TF_RETURN_IF_ERROR(status.status()); + RETURN_IF_ERROR(status.status()); return std::move(tpu_event); } diff --git a/third_party/xla/xla/stream_executor/tpu/tpu_on_demand_compiler.cc b/third_party/xla/xla/stream_executor/tpu/tpu_on_demand_compiler.cc index 410e9a8dd4a7e3..c242761600bd48 100644 --- a/third_party/xla/xla/stream_executor/tpu/tpu_on_demand_compiler.cc +++ b/third_party/xla/xla/stream_executor/tpu/tpu_on_demand_compiler.cc @@ -20,6 +20,7 @@ limitations under the License. #include "absl/cleanup/cleanup.h" #include "absl/status/statusor.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/service/compiler.h" #include "xla/service/executable.h" @@ -163,8 +164,8 @@ class TpuCompiler : public Compiler { ExecutorApiFn()->TpuExecutable_HloModuleFn(se_executables[0]); auto cleanup_c_module = absl::MakeCleanup([&c_module]() { ApiConverter::Destroy(&c_module); }); - TF_ASSIGN_OR_RETURN(std::unique_ptr module, - ApiConverter::FromC(c_module)); + ASSIGN_OR_RETURN(std::unique_ptr module, + ApiConverter::FromC(c_module)); std::shared_ptr module_shared(module.release()); executables.emplace_back(std::make_unique( se_executables[0], std::move(module_shared))); diff --git a/third_party/xla/xla/tools/hlo_bisect/BUILD b/third_party/xla/xla/tools/hlo_bisect/BUILD index 2e4a3c4f57fdf0..ed190a80c1161d 100644 --- a/third_party/xla/xla/tools/hlo_bisect/BUILD +++ b/third_party/xla/xla/tools/hlo_bisect/BUILD @@ -22,6 +22,7 @@ cc_library( "//xla:util", "//xla/hlo/ir:hlo", "//xla/hlo/transforms/simplifiers:hlo_dce", + "//xla/tsl/platform:status_macros", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", @@ -42,6 +43,7 @@ xla_cc_test( "//xla/hlo/testlib:pattern_matcher_gmock", "//xla/service:pattern_matcher", "//xla/tests:xla_internal_test_main", + "//xla/tsl/platform:status_macros", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/status:statusor", ], diff --git a/third_party/xla/xla/tools/hlo_bisect/hlo_bisect_state.cc b/third_party/xla/xla/tools/hlo_bisect/hlo_bisect_state.cc index 6db6e970d22f7a..31c99f54e3be79 100644 --- a/third_party/xla/xla/tools/hlo_bisect/hlo_bisect_state.cc +++ b/third_party/xla/xla/tools/hlo_bisect/hlo_bisect_state.cc @@ -24,6 +24,7 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/container/flat_hash_map.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" @@ -123,7 +124,7 @@ absl::Status MorphModuleWithLiterals( HloInstruction::CreateConstant(std::move(literal))); absl::Status replace_status = entry_computation->ReplaceInstruction(instruction, new_instruction); - TF_RETURN_IF_ERROR(replace_status); + RETURN_IF_ERROR(replace_status); } } @@ -156,15 +157,15 @@ absl::StatusOr HloBisectState::TrimEntryComputation() { for (int iter = 0; changed || iter < 2; iter++) { if (iter % 2 == 0) { VLOG(2) << "Trimming by outputs, iteration " << iter; - TF_ASSIGN_OR_RETURN(changed, TrimByOutputs()); + ASSIGN_OR_RETURN(changed, TrimByOutputs()); } else { VLOG(2) << "Trimming by instructions, iteration " << iter; - TF_ASSIGN_OR_RETURN(changed, TrimByInstructions()); + ASSIGN_OR_RETURN(changed, TrimByInstructions()); } changed_in_loop |= changed; } VLOG(2) << "Trimming by replacing instructions with literals"; - TF_ASSIGN_OR_RETURN(changed, TrimByUsingConstants()); + ASSIGN_OR_RETURN(changed, TrimByUsingConstants()); VLOG(2) << "Final module: " << module_->ToString(); return changed || changed_in_loop; } @@ -178,7 +179,7 @@ absl::StatusOr HloBisectState::RunModule(const HloModule& module) { // Run the modified module with the bug checker. absl::StatusOr bug_result = bug_checker_->Run(module); - TF_RETURN_IF_ERROR(bug_result.status()); + RETURN_IF_ERROR(bug_result.status()); VLOG(3) << "Bug checker result: " << bug_result.value(); // Update foldable instructions data. @@ -207,7 +208,7 @@ absl::StatusOr HloBisectState::TrimByOutputs() { std::unique_ptr new_module = module_->Clone(/*suffix=*/""); HloInstruction* const* new_operands = new_module->entry_computation()->root_instruction()->operands().begin(); - TF_RETURN_IF_ERROR(MorphModuleWithOutputs( + RETURN_IF_ERROR(MorphModuleWithOutputs( new_module.get(), absl::MakeSpan(new_operands + start, end - start + 1))); return RunModule(*new_module); @@ -220,11 +221,11 @@ absl::StatusOr HloBisectState::TrimByOutputs() { int64_t cur = bisect_low + (bisect_high - bisect_low) / 2; VLOG(2) << "Number of outputs: " << (cur - bisect_low + 1) << " [" << bisect_low << ".." << cur << "]"; - TF_ASSIGN_OR_RETURN(bool has_bug, run_modified(bisect_low, cur)); + ASSIGN_OR_RETURN(bool has_bug, run_modified(bisect_low, cur)); if (has_bug) { bisect_high = cur; } else { - TF_ASSIGN_OR_RETURN(has_bug, run_modified(cur + 1, bisect_high)); + ASSIGN_OR_RETURN(has_bug, run_modified(cur + 1, bisect_high)); if (has_bug) { bisect_low = cur + 1; } else { @@ -237,11 +238,11 @@ absl::StatusOr HloBisectState::TrimByOutputs() { bool changed = (bisect_high - bisect_low) < (root_instruction->operand_count() - 1); if (changed) { - TF_RETURN_IF_ERROR(MorphModuleWithOutputs( + RETURN_IF_ERROR(MorphModuleWithOutputs( module_.get(), absl::MakeSpan(root_instruction->operands().begin() + bisect_low, bisect_high - bisect_low + 1))); - TF_RETURN_IF_ERROR(ExpectModuleIsBuggy()); + RETURN_IF_ERROR(ExpectModuleIsBuggy()); } return changed; } @@ -261,8 +262,8 @@ absl::StatusOr HloBisectState::TrimByInstructions() { VLOG(2) << "Number of instructions: " << cur << " (of " << computation->instruction_count() << ")"; std::unique_ptr new_module = module_->Clone(/*suffix=*/""); - TF_RETURN_IF_ERROR(MorphModuleWithInstructions(new_module.get(), cur)); - TF_ASSIGN_OR_RETURN(bool has_bug, RunModule(*new_module)); + RETURN_IF_ERROR(MorphModuleWithInstructions(new_module.get(), cur)); + ASSIGN_OR_RETURN(bool has_bug, RunModule(*new_module)); if (has_bug) { bisect_high = cur; } else { @@ -280,8 +281,8 @@ absl::StatusOr HloBisectState::TrimByInstructions() { // Update the current module and verify that the bug is present, if changed. bool changed = bisect_high < upper_bound; if (changed) { - TF_RETURN_IF_ERROR(MorphModuleWithInstructions(module_.get(), bisect_high)); - TF_RETURN_IF_ERROR(ExpectModuleIsBuggy()); + RETURN_IF_ERROR(MorphModuleWithInstructions(module_.get(), bisect_high)); + RETURN_IF_ERROR(ExpectModuleIsBuggy()); } return changed; } @@ -300,7 +301,7 @@ absl::StatusOr HloBisectState::TrimByUsingConstants() { literal_map.insert(std::move(it)); } else if (foldable_instructions_.contains(instr->name())) { absl::StatusOr literal_status = MakeFakeLiteral(instr->shape()); - TF_RETURN_IF_ERROR(literal_status.status()); + RETURN_IF_ERROR(literal_status.status()); literal_map[instr->name()] = std::move(literal_status).value(); ++random_literals_count; } @@ -312,9 +313,9 @@ absl::StatusOr HloBisectState::TrimByUsingConstants() { // It is possible that the random literals will make the bug disappear, in // which case the module will not get reduced. std::unique_ptr new_module = module_->Clone(/*suffix=*/""); - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( MorphModuleWithLiterals(new_module.get(), std::move(literal_map))); - TF_ASSIGN_OR_RETURN(bool has_bug, RunModule(*new_module)); + ASSIGN_OR_RETURN(bool has_bug, RunModule(*new_module)); if (has_bug) { std::swap(module_, new_module); } @@ -323,7 +324,7 @@ absl::StatusOr HloBisectState::TrimByUsingConstants() { absl::Status HloBisectState::ExpectModuleIsBuggy() { // Verify that the current module has a bug. - TF_ASSIGN_OR_RETURN(bool has_bug, RunModule(*module_)); + ASSIGN_OR_RETURN(bool has_bug, RunModule(*module_)); if (has_bug) { return absl::OkStatus(); } @@ -332,7 +333,7 @@ absl::Status HloBisectState::ExpectModuleIsBuggy() { const int retry_count = 5; int bug_count = 0; for (int i = 0; i < retry_count; i++) { - TF_ASSIGN_OR_RETURN(has_bug, bug_checker_->Run(*module_)); + ASSIGN_OR_RETURN(has_bug, bug_checker_->Run(*module_)); if (has_bug) { bug_count++; } diff --git a/third_party/xla/xla/tools/hlo_bisect/hlo_bisect_state_test.cc b/third_party/xla/xla/tools/hlo_bisect/hlo_bisect_state_test.cc index e2a33ab1eb65f2..f1bfd85f238cd7 100644 --- a/third_party/xla/xla/tools/hlo_bisect/hlo_bisect_state_test.cc +++ b/third_party/xla/xla/tools/hlo_bisect/hlo_bisect_state_test.cc @@ -22,6 +22,7 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/status/statusor.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" @@ -174,7 +175,7 @@ TEST_F(HloBisectStateTest, TrimByOutputsLostBug) { public: CustomBugSearch() : TestBugSearch({HloOpcode::kConstant}) {} absl::StatusOr Run(const HloModule& module) override { - TF_ASSIGN_OR_RETURN(bool has_constants, TestBugSearch::Run(module)); + ASSIGN_OR_RETURN(bool has_constants, TestBugSearch::Run(module)); int program_size = module.entry_computation()->instruction_count(); return program_size == 5 && !has_constants; } diff --git a/third_party/xla/xla/tools/hlo_isolation/BUILD b/third_party/xla/xla/tools/hlo_isolation/BUILD index c145225ab47338..fb9ad744e311f9 100644 --- a/third_party/xla/xla/tools/hlo_isolation/BUILD +++ b/third_party/xla/xla/tools/hlo_isolation/BUILD @@ -49,6 +49,7 @@ cc_library( "//xla/tools:hlo_module_loader", "//xla/tsl/platform:env", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "//xla/tsl/platform:test", "@com_google_absl//absl/container:flat_hash_set", diff --git a/third_party/xla/xla/tools/hlo_isolation/hlo_isolation_api.cc b/third_party/xla/xla/tools/hlo_isolation/hlo_isolation_api.cc index b6a3517bf1a317..3bd7583c73e3e4 100644 --- a/third_party/xla/xla/tools/hlo_isolation/hlo_isolation_api.cc +++ b/third_party/xla/xla/tools/hlo_isolation/hlo_isolation_api.cc @@ -37,6 +37,7 @@ limitations under the License. #include "absl/strings/str_replace.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "re2/re2.h" #include "xla/comparison_util.h" #include "xla/error_spec.h" @@ -245,8 +246,8 @@ absl::StatusOr RunModule(std::unique_ptr module, HloRunnerInterface* runner, absl::Span input_data, bool run_hlo_passes) { - TF_ASSIGN_OR_RETURN(auto executable, runner->CreateExecutable( - std::move(module), run_hlo_passes)); + ASSIGN_OR_RETURN(auto executable, + runner->CreateExecutable(std::move(module), run_hlo_passes)); return runner->ExecuteWithExecutable(executable.get(), input_data); } @@ -257,11 +258,11 @@ absl::StatusOr RunIsolationTestOnModule( HloIsolationTestResult result; result.set_module_name(module.name()); - TF_RETURN_IF_ERROR(InitIsolatorOptions(options)); + RETURN_IF_ERROR(InitIsolatorOptions(options)); std::vector local_inputs; if (input_data.empty()) { - TF_ASSIGN_OR_RETURN(local_inputs, options.make_fake_arguments_fn(module)); + ASSIGN_OR_RETURN(local_inputs, options.make_fake_arguments_fn(module)); input_data = local_inputs; } @@ -291,7 +292,7 @@ absl::StatusOr RunIsolationTestOnModule( // Run defused test runner. std::unique_ptr defused_module = module.Clone("defused"); - TF_RETURN_IF_ERROR(DefuseModule(defused_module.get())); + RETURN_IF_ERROR(DefuseModule(defused_module.get())); absl::StatusOr defused_output = run_module(std::move(defused_module), test_runner, input_data); if (!defused_output.ok()) { @@ -331,7 +332,7 @@ absl::StatusOr RunIsolationTestOnModule( std::unique_ptr despecialized_module = module.Clone("despecialized"); Despecializer despecializer; - TF_RETURN_IF_ERROR(despecializer.Run(despecialized_module.get()).status()); + RETURN_IF_ERROR(despecializer.Run(despecialized_module.get()).status()); std::string despecialized_module_name = despecialized_module->name(); // Run the reference runner. @@ -364,10 +365,10 @@ absl::StatusOr RunIsolationTestOnModule( absl::StatusOr> RunIsolationPipeline( const HloModule& input_module, HloRunnerInterface* test_runner, HloRunnerInterface* reference_runner, PipelineIsolationOptions options) { - TF_RETURN_IF_ERROR(ValidatePipelineOptions(options)); - TF_RETURN_IF_ERROR(InitIsolatorOptions(options.module_options)); + RETURN_IF_ERROR(ValidatePipelineOptions(options)); + RETURN_IF_ERROR(InitIsolatorOptions(options.module_options)); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( std::vector> modules, DecomposeHloModule(input_module, /*deduplicate_modules=*/true)); @@ -465,8 +466,8 @@ absl::StatusOr> RunIsolationPipeline( absl::StatusOr> RunIsolationPipeline( const std::string& input_path, HloRunnerInterface* test_runner, HloRunnerInterface* reference_runner, PipelineIsolationOptions options) { - TF_ASSIGN_OR_RETURN(std::unique_ptr loaded_module, - LoadModuleFromFile(input_path)); + ASSIGN_OR_RETURN(std::unique_ptr loaded_module, + LoadModuleFromFile(input_path)); return RunIsolationPipeline(*loaded_module, test_runner, reference_runner, options); } @@ -603,8 +604,8 @@ absl::StatusOr> ExtractTopMismatches( absl::StatusOr ExtractTopRelativeErrorMismatch( std::string error_message) { - TF_ASSIGN_OR_RETURN(std::vector mismatches, - ExtractTopMismatches(error_message, false)); + ASSIGN_OR_RETURN(std::vector mismatches, + ExtractTopMismatches(error_message, false)); if (mismatches.empty()) { return absl::NotFoundError( "Could not find top relative error mismatch in the error message."); @@ -627,7 +628,7 @@ absl::StatusOr> DetectReducesInModuleOutput( } std::vector reduce_in_output(num_outputs, false); std::unique_ptr defused_module = module->Clone("defused"); - TF_RETURN_IF_ERROR(DefuseModule(defused_module.get())); + RETURN_IF_ERROR(DefuseModule(defused_module.get())); auto bfs = [&reduce_in_output](HloModule* module, int64_t output_index) -> void { @@ -681,10 +682,10 @@ absl::StatusOr> ExtractAndEnrichTopMismatches( int64_t num_outputs = is_tuple ? module->result_shape().tuple_shapes().size() : 1; - TF_ASSIGN_OR_RETURN(std::vector mismatches, - ExtractTopMismatches(error_message, is_tuple)); - TF_ASSIGN_OR_RETURN(std::vector reduce_in_output, - DetectReducesInModuleOutput(module)); + ASSIGN_OR_RETURN(std::vector mismatches, + ExtractTopMismatches(error_message, is_tuple)); + ASSIGN_OR_RETURN(std::vector reduce_in_output, + DetectReducesInModuleOutput(module)); for (NumericMismatch& mismatch : mismatches) { int output_index = mismatch.output_shape_index(); if (output_index >= num_outputs) { diff --git a/third_party/xla/xla/tools/hlo_opt/BUILD b/third_party/xla/xla/tools/hlo_opt/BUILD index 2720679a32e234..90f529b3f32173 100644 --- a/third_party/xla/xla/tools/hlo_opt/BUILD +++ b/third_party/xla/xla/tools/hlo_opt/BUILD @@ -211,6 +211,7 @@ cc_library( "//xla/stream_executor/host:host_platform", "//xla/stream_executor/platform:initialize", "//xla/tsl/platform:logging", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/log", "@com_google_absl//absl/status:statusor", diff --git a/third_party/xla/xla/tools/hlo_opt/compiled_opt_lib.cc b/third_party/xla/xla/tools/hlo_opt/compiled_opt_lib.cc index aec10578771426..9d08ee5e4f403c 100644 --- a/third_party/xla/xla/tools/hlo_opt/compiled_opt_lib.cc +++ b/third_party/xla/xla/tools/hlo_opt/compiled_opt_lib.cc @@ -66,12 +66,12 @@ namespace xla { absl::StatusOr CompiledOptProvider::GetExecutor() { DebugOptions debug_opts = GetDebugOptionsFromFlags(); - TF_ASSIGN_OR_RETURN(se::Platform * platform, - PlatformUtil::GetPlatform(GetPlatformName())); + ASSIGN_OR_RETURN(se::Platform * platform, + PlatformUtil::GetPlatform(GetPlatformName())); if (debug_opts.xla_gpu_target_config_filename().empty()) { - TF_ASSIGN_OR_RETURN(std::vector stream_executors, - PlatformUtil::GetStreamExecutors( - platform, /*allowed_devices=*/std::nullopt)); + ASSIGN_OR_RETURN(std::vector stream_executors, + PlatformUtil::GetStreamExecutors( + platform, /*allowed_devices=*/std::nullopt)); return stream_executors[0]; } return nullptr; @@ -80,17 +80,17 @@ absl::StatusOr CompiledOptProvider::GetExecutor() { absl::StatusOr> CompiledOptProvider::GenerateStage( std::unique_ptr module, absl::string_view stage) { if (stage == "hlo") { - TF_ASSIGN_OR_RETURN(std::unique_ptr optimized_module, - GetOptimizedHlo(std::move(module))); + ASSIGN_OR_RETURN(std::unique_ptr optimized_module, + GetOptimizedHlo(std::move(module))); return optimized_module->ToString(); } else if (stage == "html") { - TF_ASSIGN_OR_RETURN(std::unique_ptr optimized_module, - GetOptimizedHlo(std::move(module))); - TF_ASSIGN_OR_RETURN(std::string cmps, - RenderAllComputationsToHtml(*optimized_module)); + ASSIGN_OR_RETURN(std::unique_ptr optimized_module, + GetOptimizedHlo(std::move(module))); + ASSIGN_OR_RETURN(std::string cmps, + RenderAllComputationsToHtml(*optimized_module)); return cmps; } else if (stage == "hlo-backend") { - TF_ASSIGN_OR_RETURN(auto executable, GetExecutable(std::move(module))); + ASSIGN_OR_RETURN(auto executable, GetExecutable(std::move(module))); return executable->module().ToString(); } @@ -98,21 +98,21 @@ absl::StatusOr> CompiledOptProvider::GenerateStage( } absl::StatusOr> CompiledOptProvider::GetCompiler() { - TF_ASSIGN_OR_RETURN(se::Platform * platform, - PlatformUtil::GetPlatform(GetPlatformName())); + ASSIGN_OR_RETURN(se::Platform * platform, + PlatformUtil::GetPlatform(GetPlatformName())); - TF_ASSIGN_OR_RETURN(std::unique_ptr compiler, - Compiler::GetForPlatform(platform->id())); + ASSIGN_OR_RETURN(std::unique_ptr compiler, + Compiler::GetForPlatform(platform->id())); return compiler; } absl::StatusOr> CompiledOptProvider::GetOptimizedHlo( std::unique_ptr input_module) { - TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor, GetExecutor()); + ASSIGN_OR_RETURN(se::StreamExecutor * executor, GetExecutor()); DebugOptions debug_opts = GetDebugOptionsFromFlags(); Compiler::CompileOptions opts; - TF_ASSIGN_OR_RETURN(std::unique_ptr compiler, GetCompiler()); + ASSIGN_OR_RETURN(std::unique_ptr compiler, GetCompiler()); DebugOptions d = input_module->config().debug_options(); d.set_xla_embed_ir_in_executable(true); input_module->mutable_config().set_debug_options(d); @@ -122,7 +122,7 @@ absl::StatusOr> CompiledOptProvider::GetOptimizedHlo( } // But run-hlo-passes does not actually run the scheduling. - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( std::unique_ptr optimized_module, compiler->RunHloPasses(std::move(input_module), executor, opts)); diff --git a/third_party/xla/xla/tools/hlo_opt/cpu_opt.cc b/third_party/xla/xla/tools/hlo_opt/cpu_opt.cc index 8057e720862911..965246b1236e41 100644 --- a/third_party/xla/xla/tools/hlo_opt/cpu_opt.cc +++ b/third_party/xla/xla/tools/hlo_opt/cpu_opt.cc @@ -23,6 +23,7 @@ limitations under the License. #include "absl/log/log.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "xla/tsl/platform/status_macros.h" #include "llvm/MC/TargetRegistry.h" #include "llvm/Support/CodeGen.h" #include "llvm/Target/TargetOptions.h" @@ -84,8 +85,8 @@ class CpuOptProvider : public CompiledOptProvider { absl::StatusOr> GenerateStage( std::unique_ptr module, absl::string_view s) override { if (s == "llvm-before-optimizations") { - TF_ASSIGN_OR_RETURN(std::unique_ptr executable, - GetExecutable(std::move(module))); + ASSIGN_OR_RETURN(std::unique_ptr executable, + GetExecutable(std::move(module))); return static_cast(executable.get()) ->ir_module_string(); } diff --git a/third_party/xla/xla/tsl/concurrency/BUILD b/third_party/xla/xla/tsl/concurrency/BUILD index 4c0001a66af583..4269717413a7b3 100644 --- a/third_party/xla/xla/tsl/concurrency/BUILD +++ b/third_party/xla/xla/tsl/concurrency/BUILD @@ -169,6 +169,7 @@ tsl_cc_test( ":executor", ":future", "//xla/tsl/platform:env", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "//xla/tsl/platform:test", "//xla/tsl/platform:test_benchmark", diff --git a/third_party/xla/xla/tsl/concurrency/future_test.cc b/third_party/xla/xla/tsl/concurrency/future_test.cc index 6bdc1996e77439..761368ebb2605a 100644 --- a/third_party/xla/xla/tsl/concurrency/future_test.cc +++ b/third_party/xla/xla/tsl/concurrency/future_test.cc @@ -28,6 +28,7 @@ limitations under the License. #include "absl/status/status_matchers.h" #include "absl/status/statusor.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/tsl/concurrency/executor.h" #include "xla/tsl/platform/env.h" #include "xla/tsl/platform/statusor.h" @@ -41,6 +42,7 @@ using ::absl_testing::IsOk; using ::absl_testing::IsOkAndHolds; using ::absl_testing::StatusIs; using ::testing::Not; +using ::testing::Pointee; // Check that we correctly detect move-only types. static_assert(internal::IsMoveOnly>::value); @@ -188,7 +190,7 @@ TEST(FutureTest, ValueImplicitConversion) { TEST(FutureTest, StatusMacro) { auto f = [&](absl::StatusOr value) -> tsl::Future { - TF_ASSIGN_OR_RETURN(const int x, value); + ASSIGN_OR_RETURN(const int x, value); return x; }; @@ -211,8 +213,13 @@ TEST(FutureTest, OnReadyRvalueFuture) { promise.Set(42); - std::move(future).OnReady( - [](absl::StatusOr value) { EXPECT_EQ(*value, 42); }); + future.OnReady([](const absl::StatusOr& value) { + EXPECT_THAT(value, IsOkAndHolds(42)); + }); + + std::move(future).OnReady([](absl::StatusOr value) { + EXPECT_THAT(value, IsOkAndHolds(42)); + }); } TEST(FutureTest, OnReadyMoveOnlyFuture) { @@ -220,8 +227,12 @@ TEST(FutureTest, OnReadyMoveOnlyFuture) { promise.Set(std::make_unique(42)); + future.OnReady([](const absl::StatusOr>& value) { + EXPECT_THAT(value, IsOkAndHolds(Pointee(42))); + }); + std::move(future).OnReady([](absl::StatusOr> value) { - EXPECT_EQ(**value, 42); + EXPECT_THAT(value, IsOkAndHolds(Pointee(42))); }); } diff --git a/third_party/xla/xla/tsl/distributed_runtime/coordination/BUILD b/third_party/xla/xla/tsl/distributed_runtime/coordination/BUILD index 42c9d1920bb5a4..4fa2a68cbde121 100644 --- a/third_party/xla/xla/tsl/distributed_runtime/coordination/BUILD +++ b/third_party/xla/xla/tsl/distributed_runtime/coordination/BUILD @@ -60,6 +60,7 @@ cc_library( "//xla/tsl/platform:env", "//xla/tsl/platform:errors", "//xla/tsl/platform:status", + "//xla/tsl/platform:status_macros", "//xla/tsl/protobuf:coordination_config_proto_cc", "//xla/tsl/protobuf:coordination_service_proto_cc", "//xla/tsl/util:device_name_utils", @@ -289,6 +290,7 @@ tsl_cc_test( "//xla/tsl/platform:env_impl", # buildcleaner: keep "//xla/tsl/platform:errors", "//xla/tsl/platform:status", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:test", "//xla/tsl/protobuf:coordination_config_proto_cc_impl", "//xla/tsl/protobuf:coordination_service_proto_cc_impl", diff --git a/third_party/xla/xla/tsl/distributed_runtime/coordination/client_server_test.cc b/third_party/xla/xla/tsl/distributed_runtime/coordination/client_server_test.cc index b78a62726800bb..f915d3f863e45e 100644 --- a/third_party/xla/xla/tsl/distributed_runtime/coordination/client_server_test.cc +++ b/third_party/xla/xla/tsl/distributed_runtime/coordination/client_server_test.cc @@ -31,6 +31,7 @@ limitations under the License. #include "absl/synchronization/notification.h" #include "absl/time/clock.h" #include "absl/time/time.h" +#include "xla/tsl/platform/status_macros.h" #include "grpcpp/channel.h" #include "grpcpp/create_channel.h" #include "grpcpp/security/credentials.h" @@ -255,7 +256,7 @@ TEST_F(ClientServerTest, ConnectAndShutdownAreBarriers) { mu.Await(absl::Condition(&my_connect_turn)); ++connect_count; } - TF_RETURN_IF_ERROR(client->Connect()); + RETURN_IF_ERROR(client->Connect()); // Verify that all of the threads have called Connect() by the time we get // here. { @@ -276,7 +277,7 @@ TEST_F(ClientServerTest, ConnectAndShutdownAreBarriers) { mu.Await(absl::Condition(&my_shutdown_turn)); ++shutdown_count; } - TF_RETURN_IF_ERROR(client->Shutdown()); + RETURN_IF_ERROR(client->Shutdown()); { absl::MutexLock lock(mu); if (shutdown_count != num_nodes) { @@ -310,7 +311,7 @@ TEST_F(ClientServerTest, ClientsTerminateShutdownIfAnyClientGoesAway) { /*init_and_shutdown_timeout=*/absl::Seconds(3), /*shutdown_on_destruction=*/node_id != 0); - TF_RETURN_IF_ERROR(client->Connect()); + RETURN_IF_ERROR(client->Connect()); if (node_id == 0) { return absl::OkStatus(); @@ -354,7 +355,7 @@ TEST_F(ClientServerTest, ClientsShutdownSuccessfully) { auto thread_fn = [&](int node_id) -> absl::Status { auto client = GetClient(node_id); - TF_RETURN_IF_ERROR(client->Connect()); + RETURN_IF_ERROR(client->Connect()); return client->Shutdown(); // The error polling request will be cancelled automatically when the // client is shutting down. @@ -388,7 +389,7 @@ TEST_F(ClientServerTest, MissedHeartbeatCallbackIsExecutedIfAnyClientGoesAway) { shutdown.Notify(); }); - TF_RETURN_IF_ERROR(client->Connect()); + RETURN_IF_ERROR(client->Connect()); if (node_id == 0) { return absl::OkStatus(); @@ -473,12 +474,12 @@ TEST_F(ClientServerTest, ClientsTerminateIfServiceGoesAway) { shutdown.Notify(); }); - TF_RETURN_IF_ERROR(client->Connect()); + RETURN_IF_ERROR(client->Connect()); barrier.Block(); shutdown.WaitForNotification(); - TF_RETURN_IF_ERROR(client->Shutdown()); + RETURN_IF_ERROR(client->Shutdown()); return absl::OkStatus(); }; @@ -510,8 +511,8 @@ TEST_F(ClientServerTest, LateClientsAreOk) { barrier.Block(); absl::SleepFor(absl::Milliseconds(200) * node_id); - TF_RETURN_IF_ERROR(client->Connect()); - TF_RETURN_IF_ERROR(client->Shutdown()); + RETURN_IF_ERROR(client->Connect()); + RETURN_IF_ERROR(client->Shutdown()); return absl::OkStatus(); }; @@ -536,8 +537,8 @@ TEST_F(ClientServerTest, ConnectEventuallyTimesOutIfAClientDoesNotShowUp) { auto thread_fn = [&](int node_id) -> absl::Status { auto client = GetClient(node_id); - TF_RETURN_IF_ERROR(client->Connect()); - TF_RETURN_IF_ERROR(client->Shutdown()); + RETURN_IF_ERROR(client->Connect()); + RETURN_IF_ERROR(client->Shutdown()); return absl::OkStatus(); }; @@ -568,7 +569,7 @@ TEST_F(ClientServerTest, ClientRestart_AfterConnect_Fails) { auto client = GetClient(node_id, /*init_and_shutdown_timeout=*/absl::Seconds(5)); - TF_RETURN_IF_ERROR(client->Connect()); + RETURN_IF_ERROR(client->Connect()); // All clients have successfully connected at this point. // Simulate client restart by creating a new client. if (node_id == 2) { @@ -580,7 +581,7 @@ TEST_F(ClientServerTest, ClientRestart_AfterConnect_Fails) { return status; } n.WaitForNotification(); - TF_RETURN_IF_ERROR(client->Shutdown()); + RETURN_IF_ERROR(client->Shutdown()); return absl::OkStatus(); }; @@ -627,14 +628,14 @@ TEST_F(ClientServerTest, ClientRestart_DuringConnect_Succeeds) { // 3. Node 1 connects. // 4. All attempts succeed, except the initial node 2 connection attempt. if (node_id == 0) { - TF_RETURN_IF_ERROR(client->Connect()); - TF_RETURN_IF_ERROR(client->Shutdown()); + RETURN_IF_ERROR(client->Connect()); + RETURN_IF_ERROR(client->Shutdown()); return absl::OkStatus(); } else if (node_id == 1) { node_2_restarted.WaitForNotification(); absl::SleepFor(absl::Seconds(1)); // Give time for node 2 to connect. - TF_RETURN_IF_ERROR(client->Connect()); - TF_RETURN_IF_ERROR(client->Shutdown()); + RETURN_IF_ERROR(client->Connect()); + RETURN_IF_ERROR(client->Shutdown()); return absl::OkStatus(); } else if (node_id == 2 && !restarted_node_2) { previous_node_2_connecting.Notify(); @@ -644,8 +645,8 @@ TEST_F(ClientServerTest, ClientRestart_DuringConnect_Succeeds) { previous_node_2_connecting.WaitForNotification(); absl::SleepFor(absl::Seconds(1)); // Give time for node 2 to connect. node_2_restarted.Notify(); - TF_RETURN_IF_ERROR(client->Connect()); - TF_RETURN_IF_ERROR(client->Shutdown()); + RETURN_IF_ERROR(client->Connect()); + RETURN_IF_ERROR(client->Shutdown()); return absl::OkStatus(); } }; @@ -672,12 +673,12 @@ TEST_F(ClientServerTest, WaitAtBarrier_Succeed) { auto thread_fn = [&](int node_id) -> absl::Status { auto client = GetClient(node_id); - TF_RETURN_IF_ERROR(client->Connect()); + RETURN_IF_ERROR(client->Connect()); - TF_RETURN_IF_ERROR(client->WaitAtBarrier("barrier_1", kBarrierTimeout, {})); - TF_RETURN_IF_ERROR(client->WaitAtBarrier("barrier_2", kBarrierTimeout, {})); + RETURN_IF_ERROR(client->WaitAtBarrier("barrier_1", kBarrierTimeout, {})); + RETURN_IF_ERROR(client->WaitAtBarrier("barrier_2", kBarrierTimeout, {})); - TF_RETURN_IF_ERROR(client->Shutdown()); + RETURN_IF_ERROR(client->Shutdown()); return absl::OkStatus(); }; @@ -701,7 +702,7 @@ TEST_F(ClientServerTest, WaitAtBarrier_Timeout) { auto thread_fn = [&](int node_id) -> absl::Status { auto client = GetClient(node_id); - TF_RETURN_IF_ERROR(client->Connect()); + RETURN_IF_ERROR(client->Connect()); // Node 1 waits for barrier to time out before proceeding. if (node_id == 1) { @@ -713,9 +714,9 @@ TEST_F(ClientServerTest, WaitAtBarrier_Timeout) { if (node_id == 0) { n.Notify(); } - TF_RETURN_IF_ERROR(barrier_status); + RETURN_IF_ERROR(barrier_status); - TF_RETURN_IF_ERROR(client->Shutdown()); + RETURN_IF_ERROR(client->Shutdown()); return absl::OkStatus(); }; @@ -741,7 +742,7 @@ TEST_F(ClientServerTest, WaitAtBarrier_TimeoutWithDifferentBarrierId) { auto thread_fn = [&](int node_id) -> absl::Status { auto client = GetClient(node_id); - TF_RETURN_IF_ERROR(client->Connect()); + RETURN_IF_ERROR(client->Connect()); std::string barrier_id; if (node_id == 0) { @@ -749,9 +750,9 @@ TEST_F(ClientServerTest, WaitAtBarrier_TimeoutWithDifferentBarrierId) { } else if (node_id == 1) { barrier_id = "barrier_1"; } - TF_RETURN_IF_ERROR(client->WaitAtBarrier(barrier_id, kBarrierTimeout, {})); + RETURN_IF_ERROR(client->WaitAtBarrier(barrier_id, kBarrierTimeout, {})); - TF_RETURN_IF_ERROR(client->Shutdown()); + RETURN_IF_ERROR(client->Shutdown()); return absl::OkStatus(); }; @@ -775,14 +776,14 @@ TEST_F(ClientServerTest, WaitAtBarrier_ReuseSameId_Succeeds) { auto thread_fn = [&](int node_id) -> absl::Status { auto client = GetClient(node_id); - TF_RETURN_IF_ERROR(client->Connect()); + RETURN_IF_ERROR(client->Connect()); - TF_RETURN_IF_ERROR(client->WaitAtBarrier("barrier_1", kBarrierTimeout, {})); - TF_RETURN_IF_ERROR(client->WaitAtBarrier("barrier_2", kBarrierTimeout, {})); - TF_RETURN_IF_ERROR(client->WaitAtBarrier("barrier_1", kBarrierTimeout, {})); - TF_RETURN_IF_ERROR(client->WaitAtBarrier("barrier_2", kBarrierTimeout, {})); + RETURN_IF_ERROR(client->WaitAtBarrier("barrier_1", kBarrierTimeout, {})); + RETURN_IF_ERROR(client->WaitAtBarrier("barrier_2", kBarrierTimeout, {})); + RETURN_IF_ERROR(client->WaitAtBarrier("barrier_1", kBarrierTimeout, {})); + RETURN_IF_ERROR(client->WaitAtBarrier("barrier_2", kBarrierTimeout, {})); - TF_RETURN_IF_ERROR(client->Shutdown()); + RETURN_IF_ERROR(client->Shutdown()); return absl::OkStatus(); }; @@ -929,14 +930,14 @@ TEST_F(ClientServerTest, WaitAtBarrierSubset_Succeeds) { auto thread_fn = [&](int node_id) -> absl::Status { auto client = GetClient(node_id); - TF_RETURN_IF_ERROR(client->Connect()); + RETURN_IF_ERROR(client->Connect()); if (node_id != 2) { - TF_RETURN_IF_ERROR(client->WaitAtBarrier("barrier_1", kBarrierTimeout, - {GetTask(0), GetTask(1)})); + RETURN_IF_ERROR(client->WaitAtBarrier("barrier_1", kBarrierTimeout, + {GetTask(0), GetTask(1)})); } - TF_RETURN_IF_ERROR(client->Shutdown()); + RETURN_IF_ERROR(client->Shutdown()); return absl::OkStatus(); }; @@ -1013,7 +1014,7 @@ TEST_F(ClientServerTest, auto thread_fn = [&](int node_id) -> absl::Status { auto client = GetClient(node_id); - TF_RETURN_IF_ERROR(client->Connect()); + RETURN_IF_ERROR(client->Connect()); // Node 0 will be notified only after the barrier has failed and will thus // fail too. @@ -1060,13 +1061,13 @@ TEST_F(ClientServerTest, GetAliveTasks_Succeed) { auto thread_fn = [&](int node_id) -> absl::Status { auto client = GetClient(node_id); - TF_RETURN_IF_ERROR(client->Connect()); + RETURN_IF_ERROR(client->Connect()); absl::StatusOr> alive_tasks = client->GetAliveTasks({GetTask(0), GetTask(1)}); if (!alive_tasks.ok()) { return alive_tasks.status(); } - TF_RETURN_IF_ERROR(client->Shutdown()); + RETURN_IF_ERROR(client->Shutdown()); return absl::OkStatus(); }; diff --git a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service.cc b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service.cc index 4a970f83aa0f80..b89e90cffcf1ba 100644 --- a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service.cc +++ b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service.cc @@ -43,6 +43,7 @@ limitations under the License. #include "absl/time/clock.h" #include "absl/time/time.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/tsl/distributed_runtime/call_options.h" #include "xla/tsl/distributed_runtime/coordination/coordination_client.h" #include "xla/tsl/distributed_runtime/coordination/coordination_service_error_util.h" @@ -1186,7 +1187,7 @@ absl::Status CoordinationService::InitializeBarrier( barrier->result = absl::UnknownError("Invalid barrier result."); barrier->initiating_task = task; barrier->done_callbacks.clear(); - TF_RETURN_IF_ERROR(InitializeTasksAtBarrier(barrier, participating_tasks)); + RETURN_IF_ERROR(InitializeTasksAtBarrier(barrier, participating_tasks)); barrier->num_pending_tasks = barrier->tasks_at_barrier.size(); diff --git a/third_party/xla/xla/tsl/distributed_runtime/preemption/BUILD b/third_party/xla/xla/tsl/distributed_runtime/preemption/BUILD index a26f1de6842724..c189214a83cafc 100644 --- a/third_party/xla/xla/tsl/distributed_runtime/preemption/BUILD +++ b/third_party/xla/xla/tsl/distributed_runtime/preemption/BUILD @@ -58,6 +58,7 @@ cc_library( "//xla/tsl/distributed_runtime/coordination:coordination_service_agent", "//xla/tsl/lib/monitoring:gauge", "//xla/tsl/platform:env", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "//xla/tsl/protobuf:coordination_service_proto_cc", "@com_google_absl//absl/base:core_headers", diff --git a/third_party/xla/xla/tsl/distributed_runtime/preemption/preemption_sync_manager.cc b/third_party/xla/xla/tsl/distributed_runtime/preemption/preemption_sync_manager.cc index f5a0a0a4f3313a..37e5aef1abf0d3 100644 --- a/third_party/xla/xla/tsl/distributed_runtime/preemption/preemption_sync_manager.cc +++ b/third_party/xla/xla/tsl/distributed_runtime/preemption/preemption_sync_manager.cc @@ -33,6 +33,7 @@ limitations under the License. #include "absl/synchronization/notification.h" #include "absl/time/clock.h" #include "absl/time/time.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/tsl/distributed_runtime/coordination/coordination_service_agent.h" #include "xla/tsl/distributed_runtime/preemption/preemption_notifier.h" #include "xla/tsl/lib/monitoring/gauge.h" @@ -83,7 +84,7 @@ absl::Status PreemptionSyncManager::Initialize( absl::Status PreemptionSyncManager::Initialize( CoordinationServiceAgent* agent, const std::string& preemption_notifier_type) { - TF_ASSIGN_OR_RETURN(Env * env, agent->GetEnv()); + ASSIGN_OR_RETURN(Env * env, agent->GetEnv()); return Initialize(agent, PreemptionNotifier::CreatePreemptionNotifier( preemption_notifier_type, env)); } @@ -96,11 +97,11 @@ absl::Status PreemptionSyncManager::Initialize( CHECK(!shut_down_); } - TF_ASSIGN_OR_RETURN(Env * env, agent->GetEnv()); + ASSIGN_OR_RETURN(Env * env, agent->GetEnv()); env_ = env; agent_ = agent; preemption_notifier_ = std::move(notifier); - TF_ASSIGN_OR_RETURN(CoordinatedTask own_task, agent->GetOwnTask()); + ASSIGN_OR_RETURN(CoordinatedTask own_task, agent->GetOwnTask()); const std::string task_name = absl::StrCat("/job:", own_task.job_name(), "/task:", own_task.task_id()); current_call_counter_key_ = absl::StrCat(kPreemptionCounterDirKey, task_name); diff --git a/third_party/xla/xla/tsl/distributed_runtime/rpc/BUILD b/third_party/xla/xla/tsl/distributed_runtime/rpc/BUILD index 249260b4060826..0260caf5198217 100644 --- a/third_party/xla/xla/tsl/distributed_runtime/rpc/BUILD +++ b/third_party/xla/xla/tsl/distributed_runtime/rpc/BUILD @@ -95,6 +95,7 @@ cc_library( "//xla/tsl/platform:logging", "//xla/tsl/platform:macros", "//xla/tsl/platform:status", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:types", "//xla/tsl/protobuf:rpc_options_proto_cc", "//xla/tsl/util:device_name_utils", diff --git a/third_party/xla/xla/tsl/distributed_runtime/rpc/grpc_channel.cc b/third_party/xla/xla/tsl/distributed_runtime/rpc/grpc_channel.cc index becff37233c630..e7ff75b9c0fc6e 100644 --- a/third_party/xla/xla/tsl/distributed_runtime/rpc/grpc_channel.cc +++ b/third_party/xla/xla/tsl/distributed_runtime/rpc/grpc_channel.cc @@ -27,6 +27,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/str_split.h" #include "absl/synchronization/mutex.h" +#include "xla/tsl/platform/status_macros.h" #include "grpcpp/create_channel.h" #include "xla/tsl/distributed_runtime/rpc/grpc_channel_common.h" #include "xla/tsl/lib/gtl/map_util.h" @@ -147,7 +148,7 @@ absl::Status NewHostPortGrpcChannel(const std::string& target, const RPCOptions* rpc_options, SharedGrpcChannelPtr* channel_pointer) { // Minimally ensure that the target is valid - TF_RETURN_IF_ERROR(ValidateHostPortPair(target)); + RETURN_IF_ERROR(ValidateHostPortPair(target)); ::grpc::ChannelArguments args = GetChannelArguments(rpc_options); *channel_pointer = ::grpc::CreateCustomChannel( @@ -178,7 +179,7 @@ absl::Status GrpcChannelSpec::AddHostPortsJob( absl::StrCat("Duplicate job ID in cluster specification: ", job_id)); } for (const auto& id_host_port : host_ports) { - TF_RETURN_IF_ERROR(ValidateHostPortPair(id_host_port.second)); + RETURN_IF_ERROR(ValidateHostPortPair(id_host_port.second)); } host_ports_jobs_.emplace_back(job_id, host_ports); return absl::OkStatus(); diff --git a/third_party/xla/xla/tsl/framework/BUILD b/third_party/xla/xla/tsl/framework/BUILD index 7abf5b14e6282c..675de5487fc6eb 100644 --- a/third_party/xla/xla/tsl/framework/BUILD +++ b/third_party/xla/xla/tsl/framework/BUILD @@ -285,6 +285,7 @@ cc_library( ":device_type", "//xla/tsl/platform:errors", "//xla/tsl/platform:status", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "//xla/tsl/util:device_name_utils", "@com_google_absl//absl/algorithm:container", diff --git a/third_party/xla/xla/tsl/framework/device_id_utils.cc b/third_party/xla/xla/tsl/framework/device_id_utils.cc index d26624d05e2982..1b393b5d847528 100644 --- a/third_party/xla/xla/tsl/framework/device_id_utils.cc +++ b/third_party/xla/xla/tsl/framework/device_id_utils.cc @@ -30,6 +30,7 @@ limitations under the License. #include "absl/strings/numbers.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/tsl/framework/device_id.h" #include "xla/tsl/framework/device_id_manager.h" #include "xla/tsl/framework/device_type.h" @@ -93,7 +94,7 @@ absl::Status ParseVisibleDeviceList( tsl::str_util::Split(visible_device_list, ','); // non-absl ok for (const std::string& platform_device_id_str : order_str) { int32_t platform_device_id; - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( platform_device_id, ParsePlatformDeviceIdString(platform_device_id_str, device_type)); if (platform_device_id == -1) { @@ -136,16 +137,16 @@ absl::StatusOr GetNumberTfDevicesAndConfigurePlatformDeviceId( return 0; } std::vector visible_device_order; - TF_RETURN_IF_ERROR(ParseVisibleDeviceList( - std::string(visible_device_list), visible_device_count, - &visible_device_order, device_type)); + RETURN_IF_ERROR(ParseVisibleDeviceList(std::string(visible_device_list), + visible_device_count, + &visible_device_order, device_type)); if (num_tf_devices > visible_device_order.size()) { num_tf_devices = visible_device_order.size(); } for (int i = 0; i < num_tf_devices; ++i) { const PlatformDeviceId platform_device_id = visible_device_order[i]; const TfDeviceId tf_device_id(i); - TF_RETURN_IF_ERROR(tsl::DeviceIdManager::InsertTfPlatformDeviceIdPair( + RETURN_IF_ERROR(tsl::DeviceIdManager::InsertTfPlatformDeviceIdPair( DeviceType(device_type), tf_device_id, platform_device_id)); } return num_tf_devices; diff --git a/third_party/xla/xla/tsl/lib/core/BUILD b/third_party/xla/xla/tsl/lib/core/BUILD index 3c304b17bf1ddd..aca26c9dd8199e 100644 --- a/third_party/xla/xla/tsl/lib/core/BUILD +++ b/third_party/xla/xla/tsl/lib/core/BUILD @@ -42,6 +42,7 @@ cc_library( deps = [ "//xla/tsl/platform:status_matchers", "//xla/tsl/platform:test", + "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/status:status_matchers", ], ) diff --git a/third_party/xla/xla/tsl/lib/core/status_test_util.h b/third_party/xla/xla/tsl/lib/core/status_test_util.h index 811a42f4d5fab7..20fcdb5ba05922 100644 --- a/third_party/xla/xla/tsl/lib/core/status_test_util.h +++ b/third_party/xla/xla/tsl/lib/core/status_test_util.h @@ -16,13 +16,27 @@ limitations under the License. #ifndef XLA_TSL_LIB_CORE_STATUS_TEST_UTIL_H_ #define XLA_TSL_LIB_CORE_STATUS_TEST_UTIL_H_ +#include "absl/base/attributes.h" #include "absl/status/status_matchers.h" // IWYU pragma: keep #include "xla/tsl/platform/status_matchers.h" #include "xla/tsl/platform/test.h" +namespace tsl { +ABSL_DEPRECATED("TF_EXPECT_OK is deprecated. Call EXPECT_OK instead") +inline void TfExpectOkDeprecationMarker() {} + +ABSL_DEPRECATED("TF_ASSERT_OK is deprecated. Call ASSERT_OK instead") +inline void TfAssertOkDeprecationMarker() {} +} // namespace tsl + // Macros for testing the results of functions that return tensorflow::Status. -#define TF_EXPECT_OK(statement) EXPECT_THAT((statement), ::absl_testing::IsOk()) -#define TF_ASSERT_OK(statement) ASSERT_THAT((statement), ::absl_testing::IsOk()) +#define TF_EXPECT_OK(statement) \ + EXPECT_THAT((::tsl::TfExpectOkDeprecationMarker(), (statement)), \ + ::absl_testing::IsOk()) + +#define TF_ASSERT_OK(statement) \ + ASSERT_THAT((::tsl::TfAssertOkDeprecationMarker(), (statement)), \ + ::absl_testing::IsOk()) // There are no EXPECT_NOT_OK/ASSERT_NOT_OK macros since they would not // provide much value (when they fail, they would just print the OK status diff --git a/third_party/xla/xla/tsl/lib/io/snappy/BUILD b/third_party/xla/xla/tsl/lib/io/snappy/BUILD index d62c834064e295..230eb8ddad3978 100644 --- a/third_party/xla/xla/tsl/lib/io/snappy/BUILD +++ b/third_party/xla/xla/tsl/lib/io/snappy/BUILD @@ -39,6 +39,7 @@ cc_library( "//xla/tsl/platform:errors", "//xla/tsl/platform:macros", "//xla/tsl/platform:status", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:types", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", @@ -60,6 +61,7 @@ cc_library( "//xla/tsl/platform:env", "//xla/tsl/platform:macros", "//xla/tsl/platform:status", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:types", "@tsl//tsl/platform", "@tsl//tsl/platform:platform_port", @@ -74,6 +76,7 @@ cc_library( deps = [ "//xla/tsl/lib/io:inputstream_interface", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", @@ -105,6 +108,7 @@ tsl_cc_test( "//xla/tsl/lib/io:random_inputstream", "//xla/tsl/platform:env", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:test", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", diff --git a/third_party/xla/xla/tsl/lib/io/snappy/snappy_inputbuffer.cc b/third_party/xla/xla/tsl/lib/io/snappy/snappy_inputbuffer.cc index c1b222fe72d4a8..7c99b4245ea947 100644 --- a/third_party/xla/xla/tsl/lib/io/snappy/snappy_inputbuffer.cc +++ b/third_party/xla/xla/tsl/lib/io/snappy/snappy_inputbuffer.cc @@ -25,6 +25,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/tsl/platform/errors.h" #include "xla/tsl/platform/file_system.h" #include "tsl/platform/snappy.h" @@ -62,7 +63,7 @@ absl::Status SnappyInputBuffer::ReadNBytes(int64_t bytes_to_read, DCHECK_EQ(avail_out_, 0); // Now that the cache is empty we need to inflate more data. - TF_RETURN_IF_ERROR(Inflate()); + RETURN_IF_ERROR(Inflate()); bytes_read = ReadBytesFromCache(bytes_to_read, result_ptr); bytes_to_read -= bytes_read; @@ -98,11 +99,11 @@ size_t SnappyInputBuffer::ReadBytesFromCache(size_t bytes_to_read, absl::Status SnappyInputBuffer::Inflate() { // Read length of compressed block. uint32_t compressed_block_length; - TF_RETURN_IF_ERROR(ReadCompressedBlockLength(&compressed_block_length)); + RETURN_IF_ERROR(ReadCompressedBlockLength(&compressed_block_length)); // If the entire block is not in cache do a read from file. if (avail_in_ < compressed_block_length) { - TF_RETURN_IF_ERROR(ReadFromFile()); + RETURN_IF_ERROR(ReadFromFile()); if (avail_in_ < compressed_block_length) { if (compressed_block_length > input_buffer_capacity_) { return absl::ResourceExhaustedError( @@ -146,7 +147,7 @@ absl::Status SnappyInputBuffer::ReadCompressedBlockLength(uint32_t* length) { size_t bytes_to_read = 4; while (bytes_to_read > 0) { if (avail_in_ == 0) { - TF_RETURN_IF_ERROR(ReadFromFile()); + RETURN_IF_ERROR(ReadFromFile()); } size_t readable = std::min(bytes_to_read, avail_in_); diff --git a/third_party/xla/xla/tsl/lib/io/snappy/snappy_inputstream.cc b/third_party/xla/xla/tsl/lib/io/snappy/snappy_inputstream.cc index f9865d7e7ca1f0..77105cb8cac22a 100644 --- a/third_party/xla/xla/tsl/lib/io/snappy/snappy_inputstream.cc +++ b/third_party/xla/xla/tsl/lib/io/snappy/snappy_inputstream.cc @@ -24,6 +24,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/strings/cord.h" #include "absl/strings/str_cat.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/tsl/lib/io/inputstream_interface.h" #include "xla/tsl/platform/errors.h" #include "tsl/platform/snappy.h" @@ -69,7 +70,7 @@ absl::Status SnappyInputStream::ReadNBytes(int64_t bytes_to_read, DCHECK_EQ(avail_out_, 0); // Fill the cache with more data. - TF_RETURN_IF_ERROR(Inflate()); + RETURN_IF_ERROR(Inflate()); size_t bytes_read = ReadBytesFromCache(bytes_to_read, result_ptr); bytes_to_read -= bytes_read; @@ -84,7 +85,7 @@ absl::Status SnappyInputStream::ReadNBytes(int64_t bytes_to_read, absl::Cord* result) { // TODO(frankchn): Optimize this instead of bouncing through the buffer. tstring buf; - TF_RETURN_IF_ERROR(ReadNBytes(bytes_to_read, &buf)); + RETURN_IF_ERROR(ReadNBytes(bytes_to_read, &buf)); result->Clear(); result->Append(buf.data()); return absl::OkStatus(); @@ -95,7 +96,7 @@ absl::Status SnappyInputStream::Inflate() { tstring compressed_block_length_ts; uint32_t compressed_block_length; - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( input_stream_->ReadNBytes(sizeof(uint32_t), &compressed_block_length_ts)); for (int i = 0; i < sizeof(uint32_t); ++i) { compressed_block_length = @@ -113,7 +114,7 @@ absl::Status SnappyInputStream::Inflate() { absl::StrCat("Failed to read ", compressed_block_length, " bytes from file. Possible data corruption.")); } - TF_RETURN_IF_ERROR(s); + RETURN_IF_ERROR(s); size_t uncompressed_length; if (!port::Snappy_GetUncompressedLength(compressed_block.data(), @@ -155,7 +156,7 @@ size_t SnappyInputStream::ReadBytesFromCache(size_t bytes_to_read, int64_t SnappyInputStream::Tell() const { return bytes_read_; } absl::Status SnappyInputStream::Reset() { - TF_RETURN_IF_ERROR(input_stream_->Reset()); + RETURN_IF_ERROR(input_stream_->Reset()); avail_out_ = 0; bytes_read_ = 0; return absl::OkStatus(); diff --git a/third_party/xla/xla/tsl/lib/io/snappy/snappy_outputbuffer.cc b/third_party/xla/xla/tsl/lib/io/snappy/snappy_outputbuffer.cc index 41fde85ca7f7b9..fea84b719bff2f 100644 --- a/third_party/xla/xla/tsl/lib/io/snappy/snappy_outputbuffer.cc +++ b/third_party/xla/xla/tsl/lib/io/snappy/snappy_outputbuffer.cc @@ -17,6 +17,8 @@ limitations under the License. #include +#include "xla/tsl/platform/status_macros.h" + namespace tsl { namespace io { @@ -47,7 +49,7 @@ absl::Status SnappyOutputBuffer::Append(absl::string_view data) { #if defined(TF_CORD_SUPPORT) absl::Status SnappyOutputBuffer::Append(const absl::Cord& cord) { for (absl::string_view fragment : cord.Chunks()) { - TF_RETURN_IF_ERROR(Append(fragment)); + RETURN_IF_ERROR(Append(fragment)); } return absl::OkStatus(); } @@ -63,7 +65,7 @@ absl::Status SnappyOutputBuffer::Name(absl::string_view* result) const { } absl::Status SnappyOutputBuffer::Sync() { - TF_RETURN_IF_ERROR(Flush()); + RETURN_IF_ERROR(Flush()); return file_->Sync(); } @@ -88,7 +90,7 @@ absl::Status SnappyOutputBuffer::Write(absl::string_view data) { // If there isn't enough available space in the input_buffer_ we empty it // by uncompressing its contents. If data now fits in input_buffer_ // we add it there else we directly deflate it. - TF_RETURN_IF_ERROR(DeflateBuffered()); + RETURN_IF_ERROR(DeflateBuffered()); // input_buffer_ should be empty at this point. if (static_cast(bytes_to_write) <= AvailableInputSpace()) { @@ -102,7 +104,7 @@ absl::Status SnappyOutputBuffer::Write(absl::string_view data) { next_in_ = const_cast(data.data()); avail_in_ = bytes_to_write; - TF_RETURN_IF_ERROR(Deflate()); + RETURN_IF_ERROR(Deflate()); DCHECK_EQ(avail_in_, 0); // All input will be used up. @@ -112,8 +114,8 @@ absl::Status SnappyOutputBuffer::Write(absl::string_view data) { } absl::Status SnappyOutputBuffer::Flush() { - TF_RETURN_IF_ERROR(DeflateBuffered()); - TF_RETURN_IF_ERROR(FlushOutputBufferToFile()); + RETURN_IF_ERROR(DeflateBuffered()); + RETURN_IF_ERROR(FlushOutputBufferToFile()); return absl::OkStatus(); } @@ -166,14 +168,14 @@ absl::Status SnappyOutputBuffer::AddToOutputBuffer(const char* data, avail_out_ -= bytes_to_copy; length -= bytes_to_copy; if (avail_out_ == 0) { - TF_RETURN_IF_ERROR(FlushOutputBufferToFile()); + RETURN_IF_ERROR(FlushOutputBufferToFile()); } } return absl::OkStatus(); } absl::Status SnappyOutputBuffer::DeflateBuffered() { - TF_RETURN_IF_ERROR(Deflate()); + RETURN_IF_ERROR(Deflate()); DCHECK_EQ(avail_in_, 0); next_in_ = input_buffer_.get(); return absl::OkStatus(); @@ -209,10 +211,10 @@ absl::Status SnappyOutputBuffer::Deflate() { // Little endian. compressed_length_array[i] = output.size() >> (8 * (3 - i)); } - TF_RETURN_IF_ERROR(AddToOutputBuffer(compressed_length_array, 4)); + RETURN_IF_ERROR(AddToOutputBuffer(compressed_length_array, 4)); // Write compressed output to buffer. - TF_RETURN_IF_ERROR(AddToOutputBuffer(output.data(), output.size())); + RETURN_IF_ERROR(AddToOutputBuffer(output.data(), output.size())); next_in_ += avail_in_; avail_in_ = 0; diff --git a/third_party/xla/xla/tsl/lib/io/snappy/snappy_test.cc b/third_party/xla/xla/tsl/lib/io/snappy/snappy_test.cc index e548ff7fd945c6..d75d57251351f3 100644 --- a/third_party/xla/xla/tsl/lib/io/snappy/snappy_test.cc +++ b/third_party/xla/xla/tsl/lib/io/snappy/snappy_test.cc @@ -26,6 +26,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/tsl/lib/io/random_inputstream.h" #include "xla/tsl/lib/io/snappy/snappy_inputbuffer.h" #include "xla/tsl/lib/io/snappy/snappy_inputstream.h" @@ -84,30 +85,29 @@ absl::Status TestMultipleWritesWriteFile(size_t compress_input_buf_size, data = GenTestString(num_copies); std::unique_ptr file_writer; - TF_RETURN_IF_ERROR(env->NewWritableFile(fname, &file_writer)); + RETURN_IF_ERROR(env->NewWritableFile(fname, &file_writer)); io::SnappyOutputBuffer out(file_writer.get(), compress_input_buf_size, compress_output_buf_size); for (int i = 0; i < num_writes; i++) { - TF_RETURN_IF_ERROR(out.Write(absl::string_view(data))); + RETURN_IF_ERROR(out.Write(absl::string_view(data))); if (with_flush) { - TF_RETURN_IF_ERROR(out.Flush()); + RETURN_IF_ERROR(out.Flush()); } absl::StrAppend(&expected_result, data); } - TF_RETURN_IF_ERROR(out.Flush()); - TF_RETURN_IF_ERROR(file_writer->Flush()); - TF_RETURN_IF_ERROR(file_writer->Close()); + RETURN_IF_ERROR(out.Flush()); + RETURN_IF_ERROR(file_writer->Flush()); + RETURN_IF_ERROR(file_writer->Close()); if (corrupt_compressed_file) { std::string corrupt_fname = testing::TmpDir() + "/snappy_buffers_test_corrupt"; std::unique_ptr corrupt_file_writer; - TF_RETURN_IF_ERROR( - env->NewWritableFile(corrupt_fname, &corrupt_file_writer)); + RETURN_IF_ERROR(env->NewWritableFile(corrupt_fname, &corrupt_file_writer)); std::unique_ptr file_reader; - TF_RETURN_IF_ERROR(env->NewRandomAccessFile(fname, &file_reader)); + RETURN_IF_ERROR(env->NewRandomAccessFile(fname, &file_reader)); absl::string_view data; size_t file_pos = 0; @@ -152,12 +152,12 @@ absl::Status TestMultipleWrites(size_t compress_input_buf_size, std::string fname; std::string data; - TF_RETURN_IF_ERROR(TestMultipleWritesWriteFile( + RETURN_IF_ERROR(TestMultipleWritesWriteFile( compress_input_buf_size, compress_output_buf_size, num_writes, with_flush, num_copies, corrupt_compressed_file, fname, data, expected_result)); std::unique_ptr file_reader; - TF_RETURN_IF_ERROR(env->NewRandomAccessFile(fname, &file_reader)); + RETURN_IF_ERROR(env->NewRandomAccessFile(fname, &file_reader)); io::SnappyInputBuffer in(file_reader.get(), uncompress_input_buf_size, uncompress_output_buf_size); @@ -166,14 +166,14 @@ absl::Status TestMultipleWrites(size_t compress_input_buf_size, std::string actual_result; for (int i = 0; i < num_writes; i++) { tstring decompressed_output; - TF_RETURN_IF_ERROR(in.ReadNBytes(data.size(), &decompressed_output)); + RETURN_IF_ERROR(in.ReadNBytes(data.size(), &decompressed_output)); absl::StrAppend(&actual_result, decompressed_output); } if (actual_result != expected_result) { return absl::DataLossError("Actual and expected results don't match."); } - TF_RETURN_IF_ERROR(in.Reset()); + RETURN_IF_ERROR(in.Reset()); } return absl::OkStatus(); @@ -190,12 +190,12 @@ absl::Status TestMultipleWritesInputStream( std::string fname; std::string data; - TF_RETURN_IF_ERROR(TestMultipleWritesWriteFile( + RETURN_IF_ERROR(TestMultipleWritesWriteFile( compress_input_buf_size, compress_output_buf_size, num_writes, with_flush, num_copies, corrupt_compressed_file, fname, data, expected_result)); std::unique_ptr file_reader; - TF_RETURN_IF_ERROR(env->NewRandomAccessFile(fname, &file_reader)); + RETURN_IF_ERROR(env->NewRandomAccessFile(fname, &file_reader)); io::RandomAccessInputStream random_input_stream(file_reader.get(), false); io::SnappyInputStream snappy_input_stream(&random_input_stream, uncompress_output_buf_size); @@ -204,7 +204,7 @@ absl::Status TestMultipleWritesInputStream( std::string actual_result; for (int i = 0; i < num_writes; ++i) { tstring decompressed_output; - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( snappy_input_stream.ReadNBytes(data.size(), &decompressed_output)); absl::StrAppend(&actual_result, decompressed_output); } @@ -212,7 +212,7 @@ absl::Status TestMultipleWritesInputStream( if (actual_result != expected_result) { return absl::DataLossError("Actual and expected results don't match."); } - TF_RETURN_IF_ERROR(snappy_input_stream.Reset()); + RETURN_IF_ERROR(snappy_input_stream.Reset()); } return absl::OkStatus(); } diff --git a/third_party/xla/xla/tsl/lib/monitoring/BUILD b/third_party/xla/xla/tsl/lib/monitoring/BUILD index 8d90b430a45edf..813a402bb63fda 100644 --- a/third_party/xla/xla/tsl/lib/monitoring/BUILD +++ b/third_party/xla/xla/tsl/lib/monitoring/BUILD @@ -213,6 +213,7 @@ cc_library( ":test_utils", "//xla/tsl/platform:errors", "//xla/tsl/platform:status", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/log", "@com_google_absl//absl/status", diff --git a/third_party/xla/xla/tsl/lib/monitoring/cell_reader-inl.cc b/third_party/xla/xla/tsl/lib/monitoring/cell_reader-inl.cc index 4aa0e12b988734..54d80292e41d0c 100644 --- a/third_party/xla/xla/tsl/lib/monitoring/cell_reader-inl.cc +++ b/third_party/xla/xla/tsl/lib/monitoring/cell_reader-inl.cc @@ -25,6 +25,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/tsl/lib/monitoring/collected_metrics.h" #include "xla/tsl/lib/monitoring/collection_registry.h" #include "xla/tsl/lib/monitoring/metric_def.h" @@ -97,8 +98,8 @@ absl::StatusOr> GetPoints( absl::StatusOr GetLatestPoint(const CollectedMetrics& metrics, const std::string& metric_name, const std::vector& labels) { - TF_ASSIGN_OR_RETURN(std::vector points, - GetPoints(metrics, metric_name, labels)); + ASSIGN_OR_RETURN(std::vector points, + GetPoints(metrics, metric_name, labels)); if (points.empty()) { return absl::UnavailableError( absl::StrCat("No data collected for metric ", metric_name, diff --git a/third_party/xla/xla/tsl/platform/cloud/BUILD b/third_party/xla/xla/tsl/platform/cloud/BUILD index a3e01532382ab6..9ffd42df9b9d1d 100644 --- a/third_party/xla/xla/tsl/platform/cloud/BUILD +++ b/third_party/xla/xla/tsl/platform/cloud/BUILD @@ -67,6 +67,7 @@ cc_library( ":file_block_cache", "//xla/tsl/platform:env", "//xla/tsl/platform:status", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:types", "@com_google_absl//absl/cleanup", "@com_google_absl//absl/status", @@ -132,6 +133,7 @@ cc_library( "//xla/tsl/platform:errors", "//xla/tsl/platform:file_statistics", "//xla/tsl/platform:status", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:types", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/log", @@ -180,6 +182,7 @@ cc_library( "//xla/tsl/platform:errors", "//xla/tsl/platform:file_statistics", "//xla/tsl/platform:status", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:types", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/log", @@ -229,6 +232,7 @@ cc_library( "//xla/tsl/platform:errors", "//xla/tsl/platform:macros", "//xla/tsl/platform:status", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:types", "//xla/tsl/util:env_var", "@com_google_absl//absl/log", @@ -279,6 +283,7 @@ cc_library( "//xla/tsl/platform:env", "//xla/tsl/platform:errors", "//xla/tsl/platform:status", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:types", "@com_google_absl//absl/log", "@com_google_absl//absl/status", @@ -303,6 +308,7 @@ cc_library( ":curl_http_request", ":http_request", "//xla/tsl/platform:status", + "//xla/tsl/platform:status_macros", "@com_google_absl//absl/strings", "@tsl//tsl/platform:retrying_utils", ], @@ -322,6 +328,7 @@ cc_library( ":compute_engine_metadata_client", "//xla/tsl/platform:errors", "//xla/tsl/platform:status", + "//xla/tsl/platform:status_macros", "@tsl//tsl/platform:str_util", ], ) @@ -353,6 +360,7 @@ cc_library( "//xla/tsl/platform:env", "//xla/tsl/platform:errors", "//xla/tsl/platform:status", + "//xla/tsl/platform:status_macros", "@boringssl//:crypto", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", diff --git a/third_party/xla/xla/tsl/platform/cloud/compute_engine_metadata_client.cc b/third_party/xla/xla/tsl/platform/cloud/compute_engine_metadata_client.cc index 2590dcd743e0d6..54919ce2a1cb22 100644 --- a/third_party/xla/xla/tsl/platform/cloud/compute_engine_metadata_client.cc +++ b/third_party/xla/xla/tsl/platform/cloud/compute_engine_metadata_client.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include "absl/strings/str_cat.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/tsl/platform/cloud/curl_http_request.h" namespace tsl { @@ -42,7 +43,8 @@ ComputeEngineMetadataClient::ComputeEngineMetadataClient( absl::Status ComputeEngineMetadataClient::GetMetadata( const std::string& path, std::vector* response_buffer) { - const auto get_metadata_from_gce = [path, response_buffer, this]() { + const auto get_metadata_from_gce = [path, response_buffer, + this]() -> absl::Status { std::string metadata_url; const char* metadata_url_override = std::getenv(kGceMetadataHost); if (metadata_url_override) { @@ -55,7 +57,7 @@ absl::Status ComputeEngineMetadataClient::GetMetadata( request->SetUri(metadata_url + path); request->AddHeader("Metadata-Flavor", "Google"); request->SetResultBuffer(response_buffer); - TF_RETURN_IF_ERROR(request->Send()); + RETURN_IF_ERROR(request->Send()); return absl::OkStatus(); }; diff --git a/third_party/xla/xla/tsl/platform/cloud/compute_engine_zone_provider.cc b/third_party/xla/xla/tsl/platform/cloud/compute_engine_zone_provider.cc index 2c1705802ebdb5..4d7a0ee9df56fa 100644 --- a/third_party/xla/xla/tsl/platform/cloud/compute_engine_zone_provider.cc +++ b/third_party/xla/xla/tsl/platform/cloud/compute_engine_zone_provider.cc @@ -17,6 +17,7 @@ limitations under the License. #include +#include "xla/tsl/platform/status_macros.h" #include "tsl/platform/str_util.h" namespace tsl { @@ -34,8 +35,8 @@ absl::Status ComputeEngineZoneProvider::GetZone(std::string* zone) { return absl::OkStatus(); } std::vector response_buffer; - TF_RETURN_IF_ERROR(google_metadata_client_->GetMetadata(kGceMetadataZonePath, - &response_buffer)); + RETURN_IF_ERROR(google_metadata_client_->GetMetadata(kGceMetadataZonePath, + &response_buffer)); absl::string_view location(&response_buffer[0], response_buffer.size()); std::vector elems = str_util::Split(location, "/"); diff --git a/third_party/xla/xla/tsl/platform/cloud/curl_http_request.cc b/third_party/xla/xla/tsl/platform/cloud/curl_http_request.cc index 8045c4848eca9d..dee1439cc7e416 100644 --- a/third_party/xla/xla/tsl/platform/cloud/curl_http_request.cc +++ b/third_party/xla/xla/tsl/platform/cloud/curl_http_request.cc @@ -30,6 +30,7 @@ limitations under the License. #include "absl/strings/ascii.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/tsl/platform/env.h" #include "xla/tsl/platform/errors.h" #include "xla/tsl/platform/types.h" @@ -459,7 +460,7 @@ absl::Status CurlHttpRequest::Send() { } const CURLcode curl_result = libcurl_->curl_easy_perform(curl_); - TF_RETURN_IF_ERROR(CURLcodeToStatus(curl_result, error_buffer)); + RETURN_IF_ERROR(CURLcodeToStatus(curl_result, error_buffer)); double written_size = 0; CHECK_CURL_OK(libcurl_->curl_easy_getinfo(curl_, CURLINFO_SIZE_DOWNLOAD, diff --git a/third_party/xla/xla/tsl/platform/cloud/gcs_file_system.cc b/third_party/xla/xla/tsl/platform/cloud/gcs_file_system.cc index 94300feb638e3c..92192bd1649558 100644 --- a/third_party/xla/xla/tsl/platform/cloud/gcs_file_system.cc +++ b/third_party/xla/xla/tsl/platform/cloud/gcs_file_system.cc @@ -73,6 +73,7 @@ limitations under the License. #ifdef _WIN32 #include // for _mktemp #endif +#include "xla/tsl/platform/status_macros.h" #include "json/json.h" #include "xla/tsl/platform/cloud/curl_http_request.h" #include "xla/tsl/platform/cloud/file_block_cache.h" @@ -273,7 +274,7 @@ absl::Status GetValue(const Json::Value& parent, const char* name, absl::Status GetStringValue(const Json::Value& parent, const char* name, std::string* result) { Json::Value result_value; - TF_RETURN_IF_ERROR(GetValue(parent, name, &result_value)); + RETURN_IF_ERROR(GetValue(parent, name, &result_value)); if (!result_value.isString()) { return absl::InternalError( absl::StrCat("The field '", name, @@ -287,7 +288,7 @@ absl::Status GetStringValue(const Json::Value& parent, const char* name, absl::Status GetInt64Value(const Json::Value& parent, const char* name, int64_t* result) { Json::Value result_value; - TF_RETURN_IF_ERROR(GetValue(parent, name, &result_value)); + RETURN_IF_ERROR(GetValue(parent, name, &result_value)); if (result_value.isNumeric()) { *result = result_value.asInt64(); return absl::OkStatus(); @@ -305,7 +306,7 @@ absl::Status GetInt64Value(const Json::Value& parent, const char* name, absl::Status GetBoolValue(const Json::Value& parent, const char* name, bool* result) { Json::Value result_value; - TF_RETURN_IF_ERROR(GetValue(parent, name, &result_value)); + RETURN_IF_ERROR(GetValue(parent, name, &result_value)); if (!result_value.isBool()) { return absl::InternalError( absl::StrCat("The field '", name, @@ -573,7 +574,7 @@ class GcsWritableFile : public WritableFile { } absl::Status Append(absl::string_view data) override { - TF_RETURN_IF_ERROR(CheckWritable()); + RETURN_IF_ERROR(CheckWritable()); VLOG(3) << "Append: " << GetGcsPath() << " size " << data.length(); sync_needed_ = true; outfile_ << data; @@ -608,7 +609,7 @@ class GcsWritableFile : public WritableFile { absl::Status Sync() override { VLOG(3) << "Sync started:" << GetGcsPath(); - TF_RETURN_IF_ERROR(CheckWritable()); + RETURN_IF_ERROR(CheckWritable()); if (!sync_needed_) { return absl::OkStatus(); } @@ -655,16 +656,16 @@ class GcsWritableFile : public WritableFile { io::Basename(object_), ".", start_offset_); } } - TF_RETURN_IF_ERROR(CreateNewUploadSession(start_offset, object_to_upload, - &session_handle)); + RETURN_IF_ERROR(CreateNewUploadSession(start_offset, object_to_upload, + &session_handle)); uint64_t already_uploaded = 0; bool first_attempt = true; const absl::Status upload_status = RetryingUtils::CallWithRetries( [&first_attempt, &already_uploaded, &session_handle, &start_offset, - this]() { + this]() -> absl::Status { if (session_handle.resumable && !first_attempt) { bool completed; - TF_RETURN_IF_ERROR(RequestUploadSessionStatus( + RETURN_IF_ERROR(RequestUploadSessionStatus( session_handle.session_uri, &completed, &already_uploaded)); LOG(INFO) << "### RequestUploadSessionStatus: completed = " << completed @@ -693,9 +694,9 @@ class GcsWritableFile : public WritableFile { } if (upload_status.ok()) { if (should_compose) { - TF_RETURN_IF_ERROR(AppendObject(object_to_upload)); + RETURN_IF_ERROR(AppendObject(object_to_upload)); } - TF_RETURN_IF_ERROR(GetCurrentFileSize(&start_offset_)); + RETURN_IF_ERROR(GetCurrentFileSize(&start_offset_)); } return upload_status; } @@ -723,7 +724,7 @@ class GcsWritableFile : public WritableFile { std::string object_to_upload, UploadSessionHandle* session_handle) { uint64_t file_size; - TF_RETURN_IF_ERROR(GetCurrentFileSize(&file_size)); + RETURN_IF_ERROR(GetCurrentFileSize(&file_size)); return session_creator_(start_offset, object_to_upload, bucket_, file_size, GetGcsPath(), session_handle); } @@ -735,13 +736,13 @@ class GcsWritableFile : public WritableFile { VLOG(3) << "AppendObject: " << append_object_path << " to " << GetGcsPath(); int64_t generation = 0; - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( generation_getter_(GetGcsPath(), bucket_, object_, &generation)); - TF_RETURN_IF_ERROR(RetryingUtils::CallWithRetries( - [&append_object, &generation, this]() { + RETURN_IF_ERROR(RetryingUtils::CallWithRetries( + [&append_object, &generation, this]() -> absl::Status { std::unique_ptr request; - TF_RETURN_IF_ERROR(filesystem_->CreateHttpRequest(&request)); + RETURN_IF_ERROR(filesystem_->CreateHttpRequest(&request)); request->SetUri(strings::StrCat(kGcsUriBase, "b/", bucket_, "/o/", request->EscapeString(object_), @@ -776,7 +777,7 @@ class GcsWritableFile : public WritableFile { absl::Status RequestUploadSessionStatus(const std::string& session_uri, bool* completed, uint64_t* uploaded) { uint64_t file_size; - TF_RETURN_IF_ERROR(GetCurrentFileSize(&file_size)); + RETURN_IF_ERROR(GetCurrentFileSize(&file_size)); return status_poller_(session_uri, file_size, GetGcsPath(), completed, uploaded); } @@ -786,7 +787,7 @@ class GcsWritableFile : public WritableFile { uint64_t start_offset, uint64_t already_uploaded) { uint64_t file_size; - TF_RETURN_IF_ERROR(GetCurrentFileSize(&file_size)); + RETURN_IF_ERROR(GetCurrentFileSize(&file_size)); absl::Status status = object_uploader_(session_uri, start_offset, already_uploaded, tmp_content_filename_, file_size, GetGcsPath()); @@ -1058,48 +1059,49 @@ GcsFileSystem::GcsFileSystem( absl::Status GcsFileSystem::NewRandomAccessFile( const std::string& fname, std::unique_ptr* result) { std::string bucket, object; - TF_RETURN_IF_ERROR(ParseGcsPath(fname, false, &bucket, &object)); - TF_RETURN_IF_ERROR(CheckBucketLocationConstraint(bucket)); + RETURN_IF_ERROR(ParseGcsPath(fname, false, &bucket, &object)); + RETURN_IF_ERROR(CheckBucketLocationConstraint(bucket)); if (cache_enabled_) { - result->reset(new GcsRandomAccessFile(fname, [this, bucket, object]( - const std::string& fname, - uint64_t offset, size_t n, - absl::string_view* result, - char* scratch) { - absl::ReaderMutexLock l(block_cache_lock_); - GcsFileStat stat; - TF_RETURN_IF_ERROR(stat_cache_->LookupOrCompute( - fname, &stat, - [this, bucket, object](absl::string_view fname, GcsFileStat* stat) { - return UncachedStatForObject(fname, bucket, object, stat); - })); - if (!file_block_cache_->ValidateAndUpdateFileSignature( - fname, stat.generation_number)) { - VLOG(1) - << "File signature has been changed. Refreshing the cache. Path: " - << fname; - } - *result = absl::string_view(); - size_t bytes_transferred; - TF_RETURN_IF_ERROR(file_block_cache_->Read(fname, offset, n, scratch, - &bytes_transferred)); - *result = absl::string_view(scratch, bytes_transferred); - if (bytes_transferred < n) { - return absl::OutOfRangeError( - absl::StrCat("EOF reached, ", result->size(), - " bytes were read out of ", n, " bytes requested.")); - } - return absl::OkStatus(); - })); + result->reset(new GcsRandomAccessFile( + fname, + [this, bucket, object](const std::string& fname, uint64_t offset, + size_t n, absl::string_view* result, + char* scratch) -> absl::Status { + absl::ReaderMutexLock l(block_cache_lock_); + GcsFileStat stat; + RETURN_IF_ERROR(stat_cache_->LookupOrCompute( + fname, &stat, + [this, bucket, object](absl::string_view fname, + GcsFileStat* stat) { + return UncachedStatForObject(fname, bucket, object, stat); + })); + if (!file_block_cache_->ValidateAndUpdateFileSignature( + fname, stat.generation_number)) { + VLOG(1) << "File signature has been changed. Refreshing the cache. " + "Path: " + << fname; + } + *result = absl::string_view(); + size_t bytes_transferred; + RETURN_IF_ERROR(file_block_cache_->Read(fname, offset, n, scratch, + &bytes_transferred)); + *result = absl::string_view(scratch, bytes_transferred); + if (bytes_transferred < n) { + return absl::OutOfRangeError(absl::StrCat( + "EOF reached, ", result->size(), " bytes were read out of ", n, + " bytes requested.")); + } + return absl::OkStatus(); + })); } else { result->reset(new BufferedGcsRandomAccessFile( fname, block_size_, [this, bucket, object](const std::string& fname, uint64_t offset, size_t n, absl::string_view* result, - char* scratch) { + char* scratch) -> absl::Status { *result = absl::string_view(); size_t bytes_transferred; - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( LoadBufferFromGCS(fname, offset, n, scratch, &bytes_transferred)); *result = absl::string_view(scratch, bytes_transferred); if (bytes_transferred < n) { @@ -1148,7 +1150,7 @@ absl::Status GcsFileSystem::LoadBufferFromGCS(const std::string& fname, *bytes_transferred = 0; std::string bucket, object; - TF_RETURN_IF_ERROR(ParseGcsPath(fname, false, &bucket, &object)); + RETURN_IF_ERROR(ParseGcsPath(fname, false, &bucket, &object)); profiler::TraceMe activity( [fname]() { return absl::StrCat("LoadBufferFromGCS ", fname); }); @@ -1208,7 +1210,7 @@ absl::Status GcsFileSystem::CreateNewUploadSession( UploadSessionHandle* session_handle) { std::vector output_buffer; std::unique_ptr request; - TF_RETURN_IF_ERROR(CreateHttpRequest(&request)); + RETURN_IF_ERROR(CreateHttpRequest(&request)); std::string uri = strings::StrCat( kGcsUploadUriBase, "b/", bucket, @@ -1238,7 +1240,7 @@ absl::Status GcsFileSystem::UploadToSession( uint64_t already_uploaded, const std::string& tmp_content_filename, uint64_t file_size, const std::string& file_path) { std::unique_ptr request; - TF_RETURN_IF_ERROR(CreateHttpRequest(&request)); + RETURN_IF_ERROR(CreateHttpRequest(&request)); request->SetUri(session_uri); if (file_size > 0) { request->AddHeader("Content-Range", @@ -1248,8 +1250,8 @@ absl::Status GcsFileSystem::UploadToSession( } request->SetTimeouts(timeouts_.connect, timeouts_.idle, timeouts_.write); - TF_RETURN_IF_ERROR(request->SetPutFromFile(tmp_content_filename, - start_offset + already_uploaded)); + RETURN_IF_ERROR(request->SetPutFromFile(tmp_content_filename, + start_offset + already_uploaded)); TF_RETURN_WITH_CONTEXT_IF_ERROR(request->Send(), " when uploading ", file_path); return absl::OkStatus(); @@ -1263,7 +1265,7 @@ absl::Status GcsFileSystem::RequestUploadSessionStatus( CHECK(uploaded != nullptr) << "RequestUploadSessionStatus() called with out " "param 'uploaded' == nullptr."; // Crash ok std::unique_ptr request; - TF_RETURN_IF_ERROR(CreateHttpRequest(&request)); + RETURN_IF_ERROR(CreateHttpRequest(&request)); request->SetUri(session_uri); request->SetTimeouts(timeouts_.connect, timeouts_.idle, timeouts_.metadata); request->AddHeader("Content-Range", absl::StrCat("bytes */", file_size)); @@ -1363,7 +1365,7 @@ void GcsFileSystem::ClearFileCaches(const std::string& fname) { absl::Status GcsFileSystem::NewWritableFile( const std::string& fname, std::unique_ptr* result) { std::string bucket, object; - TF_RETURN_IF_ERROR(ParseGcsPath(fname, false, &bucket, &object)); + RETURN_IF_ERROR(ParseGcsPath(fname, false, &bucket, &object)); auto session_creator = [this](uint64_t start_offset, const std::string& object_to_upload, @@ -1386,12 +1388,11 @@ absl::Status GcsFileSystem::NewWritableFile( completed, uploaded); }; - auto generation_getter = [this](const std::string& fname, - const std::string& bucket, - const std::string& object, - int64_t* generation) { + auto generation_getter = + [this](const std::string& fname, const std::string& bucket, + const std::string& object, int64_t* generation) -> absl::Status { GcsFileStat stat; - TF_RETURN_IF_ERROR(RetryingUtils::CallWithRetries( + RETURN_IF_ERROR(RetryingUtils::CallWithRetries( [&fname, &bucket, &object, &stat, this]() { return UncachedStatForObject(fname, bucket, object, &stat); }, @@ -1413,7 +1414,7 @@ absl::Status GcsFileSystem::NewWritableFile( absl::Status GcsFileSystem::NewAppendableFile( const std::string& fname, std::unique_ptr* result) { std::unique_ptr reader; - TF_RETURN_IF_ERROR(NewRandomAccessFile(fname, &reader)); + RETURN_IF_ERROR(NewRandomAccessFile(fname, &reader)); std::unique_ptr buffer(new char[kReadAppendableFileBufferSize]); absl::Status status; uint64_t offset = 0; @@ -1421,7 +1422,7 @@ absl::Status GcsFileSystem::NewAppendableFile( // Read the file from GCS in chunks and save it to a tmp file. std::string old_content_filename; - TF_RETURN_IF_ERROR(GetTmpFilename(&old_content_filename)); + RETURN_IF_ERROR(GetTmpFilename(&old_content_filename)); std::ofstream old_content(old_content_filename, std::ofstream::binary); while (true) { status = reader->Read( @@ -1465,12 +1466,11 @@ absl::Status GcsFileSystem::NewAppendableFile( completed, uploaded); }; - auto generation_getter = [this](const std::string& fname, - const std::string& bucket, - const std::string& object, - int64_t* generation) { + auto generation_getter = + [this](const std::string& fname, const std::string& bucket, + const std::string& object, int64_t* generation) -> absl::Status { GcsFileStat stat; - TF_RETURN_IF_ERROR(RetryingUtils::CallWithRetries( + RETURN_IF_ERROR(RetryingUtils::CallWithRetries( [&fname, &bucket, &object, &stat, this]() { return UncachedStatForObject(fname, bucket, object, &stat); }, @@ -1481,7 +1481,7 @@ absl::Status GcsFileSystem::NewAppendableFile( // Create a writable file and pass the old content to it. std::string bucket, object; - TF_RETURN_IF_ERROR(ParseGcsPath(fname, false, &bucket, &object)); + RETURN_IF_ERROR(ParseGcsPath(fname, false, &bucket, &object)); result->reset(new GcsWritableFile( bucket, object, this, old_content_filename, &timeouts_, [this, fname]() { ClearFileCaches(fname); }, retry_config_, @@ -1493,14 +1493,14 @@ absl::Status GcsFileSystem::NewAppendableFile( absl::Status GcsFileSystem::NewReadOnlyMemoryRegionFromFile( const std::string& fname, std::unique_ptr* result) { uint64_t size; - TF_RETURN_IF_ERROR(GetFileSize(fname, &size)); + RETURN_IF_ERROR(GetFileSize(fname, &size)); std::unique_ptr data(new char[size]); std::unique_ptr file; - TF_RETURN_IF_ERROR(NewRandomAccessFile(fname, &file)); + RETURN_IF_ERROR(NewRandomAccessFile(fname, &file)); absl::string_view piece; - TF_RETURN_IF_ERROR(file->Read(0, piece, absl::MakeSpan(data.get(), size))); + RETURN_IF_ERROR(file->Read(0, piece, absl::MakeSpan(data.get(), size))); result->reset(new GcsReadOnlyMemoryRegion(std::move(data), size)); return absl::OkStatus(); @@ -1508,10 +1508,10 @@ absl::Status GcsFileSystem::NewReadOnlyMemoryRegionFromFile( absl::Status GcsFileSystem::FileExists(absl::string_view fname) { std::string bucket, object; - TF_RETURN_IF_ERROR(ParseGcsPath(fname, true, &bucket, &object)); + RETURN_IF_ERROR(ParseGcsPath(fname, true, &bucket, &object)); if (object.empty()) { bool result; - TF_RETURN_IF_ERROR(BucketExists(bucket, &result)); + RETURN_IF_ERROR(BucketExists(bucket, &result)); if (result) { return absl::OkStatus(); } else { @@ -1529,7 +1529,7 @@ absl::Status GcsFileSystem::FileExists(absl::string_view fname) { // Check if the folder exists. bool result; - TF_RETURN_IF_ERROR(FolderExists(fname, &result)); + RETURN_IF_ERROR(FolderExists(fname, &result)); if (result) { return absl::OkStatus(); } @@ -1579,19 +1579,18 @@ absl::Status GcsFileSystem::UncachedStatForObject(absl::string_view fname, request->Send(), " when reading metadata of gs://", bucket, "/", object); Json::Value root; - TF_RETURN_IF_ERROR(ParseJson(output_buffer, &root)); + RETURN_IF_ERROR(ParseJson(output_buffer, &root)); // Parse file size. - TF_RETURN_IF_ERROR(GetInt64Value(root, "size", &stat->base.length)); + RETURN_IF_ERROR(GetInt64Value(root, "size", &stat->base.length)); // Parse generation number. - TF_RETURN_IF_ERROR( - GetInt64Value(root, "generation", &stat->generation_number)); + RETURN_IF_ERROR(GetInt64Value(root, "generation", &stat->generation_number)); // Parse file modification time. std::string updated; - TF_RETURN_IF_ERROR(GetStringValue(root, "updated", &updated)); - TF_RETURN_IF_ERROR(ParseRfc3339Time(updated, &(stat->base.mtime_nsec))); + RETURN_IF_ERROR(GetStringValue(root, "updated", &updated)); + RETURN_IF_ERROR(ParseRfc3339Time(updated, &(stat->base.mtime_nsec))); VLOG(1) << "Stat of: gs://" << bucket << "/" << object << " -- " << " length: " << stat->base.length @@ -1619,7 +1618,7 @@ absl::Status GcsFileSystem::StatForObject(absl::string_view fname, "'object' must be a non-empty string. (File: %s)", fname)); } - TF_RETURN_IF_ERROR(stat_cache_->LookupOrCompute( + RETURN_IF_ERROR(stat_cache_->LookupOrCompute( fname, stat, [this, &bucket, &object](absl::string_view fname, GcsFileStat* stat) { return UncachedStatForObject(fname, bucket, object, stat); @@ -1651,12 +1650,12 @@ absl::Status GcsFileSystem::CheckBucketLocationConstraint( // Avoid calling external API's in the constructor if (allowed_locations_.erase(kDetectZoneSentinelValue) == 1) { std::string zone; - TF_RETURN_IF_ERROR(zone_provider_->GetZone(&zone)); + RETURN_IF_ERROR(zone_provider_->GetZone(&zone)); allowed_locations_.insert(ZoneToRegion(&zone)); } std::string location; - TF_RETURN_IF_ERROR(GetBucketLocation(bucket, &location)); + RETURN_IF_ERROR(GetBucketLocation(bucket, &location)); if (allowed_locations_.find(location) != allowed_locations_.end()) { return absl::OkStatus(); } @@ -1669,20 +1668,21 @@ absl::Status GcsFileSystem::CheckBucketLocationConstraint( absl::Status GcsFileSystem::GetBucketLocation(const std::string& bucket, std::string* location) { - auto compute_func = [this](absl::string_view bucket, std::string* location) { + auto compute_func = [this](absl::string_view bucket, + std::string* location) -> absl::Status { std::vector result_buffer; absl::Status status = GetBucketMetadata(bucket, &result_buffer); Json::Value result; - TF_RETURN_IF_ERROR(ParseJson(result_buffer, &result)); + RETURN_IF_ERROR(ParseJson(result_buffer, &result)); std::string bucket_location; - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( GetStringValue(result, kBucketMetadataLocationKey, &bucket_location)); // Lowercase the GCS location to be case insensitive for allowed locations. *location = absl::AsciiStrToLower(bucket_location); return absl::OkStatus(); }; - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( bucket_location_cache_->LookupOrCompute(bucket, location, compute_func)); return absl::OkStatus(); @@ -1691,7 +1691,7 @@ absl::Status GcsFileSystem::GetBucketLocation(const std::string& bucket, absl::Status GcsFileSystem::GetBucketMetadata( absl::string_view bucket, std::vector* result_buffer) { std::unique_ptr request; - TF_RETURN_IF_ERROR(CreateHttpRequest(&request)); + RETURN_IF_ERROR(CreateHttpRequest(&request)); request->SetUri(absl::StrCat(kGcsUriBase, "b/", bucket)); if (result_buffer != nullptr) { @@ -1705,7 +1705,7 @@ absl::Status GcsFileSystem::GetBucketMetadata( absl::Status GcsFileSystem::GetStorageLayout(absl::string_view bucket, std::vector* result_buffer) { std::unique_ptr request; - TF_RETURN_IF_ERROR(CreateHttpRequest(&request)); + RETURN_IF_ERROR(CreateHttpRequest(&request)); request->SetUri(absl::StrCat(kGcsUriBase, "b/", bucket, "/storageLayout")); @@ -1726,7 +1726,7 @@ absl::Status GcsFileSystem::ParseIsHnsEnabled( if (!hns_node.isNull() && hns_node.isObject()) { bool enabled = false; if (hns_node.isMember("enabled")) { - TF_RETURN_IF_ERROR(GetBoolValue(hns_node, "enabled", &enabled)); + RETURN_IF_ERROR(GetBoolValue(hns_node, "enabled", &enabled)); *is_hns = enabled; } @@ -1749,7 +1749,7 @@ absl::Status GcsFileSystem::IsBucketHnsEnabled(const std::string& bucket, }; // Look up the full JSON object in the new cache. - TF_RETURN_IF_ERROR(storage_layout_cache_->LookupOrCompute( + RETURN_IF_ERROR(storage_layout_cache_->LookupOrCompute( bucket, &storage_layout, compute_func)); return ParseIsHnsEnabled(storage_layout, is_hns); @@ -1757,10 +1757,10 @@ absl::Status GcsFileSystem::IsBucketHnsEnabled(const std::string& bucket, absl::Status GcsFileSystem::FolderExists(absl::string_view dirname, bool* result) { - StatCache::ComputeFunc compute_func = [this](absl::string_view dirname, - GcsFileStat* stat) { + StatCache::ComputeFunc compute_func = + [this](absl::string_view dirname, GcsFileStat* stat) -> absl::Status { std::vector children; - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( GetChildrenBounded(dirname, 1, &children, true /* recursively */, true /* include_self_directory_marker */)); if (!children.empty()) { @@ -1794,43 +1794,44 @@ absl::Status GcsFileSystem::GetChildren(const std::string& dirname, absl::Status GcsFileSystem::GetMatchingPaths( const std::string& pattern, std::vector* results) { MatchingPathsCache::ComputeFunc compute_func = - [this](absl::string_view pattern, std::vector* results) { - results->clear(); - // Find the fixed prefix by looking for the first wildcard. - const absl::string_view fixed_prefix = - pattern.substr(0, pattern.find_first_of("*?[\\")); - const absl::string_view dir = this->Dirname(fixed_prefix); - if (dir.empty()) { - return absl::InvalidArgumentError(absl::StrCat( - "A GCS pattern doesn't have a bucket name: ", pattern)); - } - std::vector all_files; - TF_RETURN_IF_ERROR(GetChildrenBounded( - dir, UINT64_MAX, &all_files, true /* recursively */, - false /* include_self_directory_marker */)); - - const auto& files_and_folders = AddAllSubpaths(all_files); - - // To handle `/` in the object names, we need to remove it from `dir` - // and then use `StrCat` to insert it back. - const absl::string_view dir_no_slash = absl::StripSuffix(dir, "/"); - - // Match all obtained paths to the input pattern. - for (const auto& path : files_and_folders) { - // Manually construct the path instead of using `JoinPath` for the - // cases where `path` starts with a `/` (which is a valid character in - // the filenames of GCS objects). `JoinPath` canonicalizes the result, - // removing duplicate slashes. We know that `dir_no_slash` does not - // end in `/`, so we are safe inserting the new `/` here as the path - // separator. - const std::string full_path = absl::StrCat(dir_no_slash, "/", path); - if (this->Match(full_path, pattern)) { - results->push_back(full_path); - } - } - return absl::OkStatus(); - }; - TF_RETURN_IF_ERROR( + [this](absl::string_view pattern, + std::vector* results) -> absl::Status { + results->clear(); + // Find the fixed prefix by looking for the first wildcard. + const absl::string_view fixed_prefix = + pattern.substr(0, pattern.find_first_of("*?[\\")); + const absl::string_view dir = this->Dirname(fixed_prefix); + if (dir.empty()) { + return absl::InvalidArgumentError( + absl::StrCat("A GCS pattern doesn't have a bucket name: ", pattern)); + } + std::vector all_files; + RETURN_IF_ERROR( + GetChildrenBounded(dir, UINT64_MAX, &all_files, true /* recursively */, + false /* include_self_directory_marker */)); + + const auto& files_and_folders = AddAllSubpaths(all_files); + + // To handle `/` in the object names, we need to remove it from `dir` + // and then use `StrCat` to insert it back. + const absl::string_view dir_no_slash = absl::StripSuffix(dir, "/"); + + // Match all obtained paths to the input pattern. + for (const auto& path : files_and_folders) { + // Manually construct the path instead of using `JoinPath` for the + // cases where `path` starts with a `/` (which is a valid character in + // the filenames of GCS objects). `JoinPath` canonicalizes the result, + // removing duplicate slashes. We know that `dir_no_slash` does not + // end in `/`, so we are safe inserting the new `/` here as the path + // separator. + const std::string full_path = absl::StrCat(dir_no_slash, "/", path); + if (this->Match(full_path, pattern)) { + results->push_back(full_path); + } + } + return absl::OkStatus(); + }; + RETURN_IF_ERROR( matching_paths_cache_->LookupOrCompute(pattern, results, compute_func)); return absl::OkStatus(); } @@ -1843,7 +1844,7 @@ absl::Status GcsFileSystem::GetChildrenBounded( return absl::InvalidArgumentError("'result' cannot be null"); } std::string bucket, object_prefix; - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( ParseGcsPath(MaybeAppendSlash(dirname), true, &bucket, &object_prefix)); std::string nextPageToken; @@ -1851,7 +1852,7 @@ absl::Status GcsFileSystem::GetChildrenBounded( while (true) { // A loop over multiple result pages. std::vector output_buffer; std::unique_ptr request; - TF_RETURN_IF_ERROR(CreateHttpRequest(&request)); + RETURN_IF_ERROR(CreateHttpRequest(&request)); auto uri = absl::StrCat(kGcsUriBase, "b/", bucket, "/o"); if (recursive) { uri = absl::StrCat(uri, "?fields=items%2Fname%2CnextPageToken"); @@ -1878,7 +1879,7 @@ absl::Status GcsFileSystem::GetChildrenBounded( TF_RETURN_WITH_CONTEXT_IF_ERROR(request->Send(), " when reading ", dirname); Json::Value root; - TF_RETURN_IF_ERROR(ParseJson(output_buffer, &root)); + RETURN_IF_ERROR(ParseJson(output_buffer, &root)); const auto items = root.get("items", Json::Value::null); if (!items.isNull()) { if (!items.isArray()) { @@ -1892,7 +1893,7 @@ absl::Status GcsFileSystem::GetChildrenBounded( "Unexpected JSON format: 'items' should be a list of objects."); } std::string name; - TF_RETURN_IF_ERROR(GetStringValue(item, "name", &name)); + RETURN_IF_ERROR(GetStringValue(item, "name", &name)); // The names should be relative to the 'dirname'. That means the // 'object_prefix', which is part of 'dirname', should be removed from // the beginning of 'name'. @@ -1955,10 +1956,10 @@ absl::Status GcsFileSystem::Stat(const std::string& fname, return absl::InternalError("'stat' cannot be nullptr."); } std::string bucket, object; - TF_RETURN_IF_ERROR(ParseGcsPath(fname, true, &bucket, &object)); + RETURN_IF_ERROR(ParseGcsPath(fname, true, &bucket, &object)); if (object.empty()) { bool is_bucket; - TF_RETURN_IF_ERROR(BucketExists(bucket, &is_bucket)); + RETURN_IF_ERROR(BucketExists(bucket, &is_bucket)); if (is_bucket) { *stat = DIRECTORY_STAT; return absl::OkStatus(); @@ -1977,7 +1978,7 @@ absl::Status GcsFileSystem::Stat(const std::string& fname, return status; } bool is_folder; - TF_RETURN_IF_ERROR(FolderExists(fname, &is_folder)); + RETURN_IF_ERROR(FolderExists(fname, &is_folder)); if (is_folder) { *stat = DIRECTORY_STAT; return absl::OkStatus(); @@ -1988,10 +1989,10 @@ absl::Status GcsFileSystem::Stat(const std::string& fname, absl::Status GcsFileSystem::DeleteFile(const std::string& fname) { std::string bucket, object; - TF_RETURN_IF_ERROR(ParseGcsPath(fname, false, &bucket, &object)); + RETURN_IF_ERROR(ParseGcsPath(fname, false, &bucket, &object)); std::unique_ptr request; - TF_RETURN_IF_ERROR(CreateHttpRequest(&request)); + RETURN_IF_ERROR(CreateHttpRequest(&request)); request->SetUri(strings::StrCat(kGcsUriBase, "b/", bucket, "/o/", request->EscapeString(object))); request->SetTimeouts(timeouts_.connect, timeouts_.idle, timeouts_.metadata); @@ -2007,11 +2008,11 @@ absl::Status GcsFileSystem::CreateDir(const std::string& dirname) { VLOG(3) << "CreateDir: creating directory with dirname: " << dirname << " and dirname_with_slash: " << dirname_with_slash; std::string bucket, object; - TF_RETURN_IF_ERROR(ParseGcsPath(dirname_with_slash, /*empty_object_ok=*/true, - &bucket, &object)); + RETURN_IF_ERROR(ParseGcsPath(dirname_with_slash, /*empty_object_ok=*/true, + &bucket, &object)); if (object.empty()) { bool is_bucket; - TF_RETURN_IF_ERROR(BucketExists(bucket, &is_bucket)); + RETURN_IF_ERROR(BucketExists(bucket, &is_bucket)); return is_bucket ? absl::OkStatus() : absl::NotFoundError(absl::StrCat("The specified bucket ", dirname_with_slash, @@ -2025,7 +2026,7 @@ absl::Status GcsFileSystem::CreateDir(const std::string& dirname) { } std::unique_ptr request; - TF_RETURN_IF_ERROR(CreateHttpRequest(&request)); + RETURN_IF_ERROR(CreateHttpRequest(&request)); request->SetUri(strings::StrCat( kGcsUploadUriBase, "b/", bucket, @@ -2058,9 +2059,9 @@ absl::Status GcsFileSystem::DeleteDir(const std::string& dirname) { // with the corresponding name prefix or if there is exactly one matching // object and it is the directory marker. Therefore we need to retrieve // at most two children for the prefix to detect if a directory is empty. - TF_RETURN_IF_ERROR( - GetChildrenBounded(dirname, 2, &children, true /* recursively */, - true /* include_self_directory_marker */)); + RETURN_IF_ERROR(GetChildrenBounded(dirname, 2, &children, + true /* recursively */, + true /* include_self_directory_marker */)); if (children.size() > 1 || (children.size() == 1 && !children[0].empty())) { return absl::FailedPreconditionError( @@ -2081,10 +2082,10 @@ absl::Status GcsFileSystem::GetFileSize(const std::string& fname, // Only validate the name. std::string bucket, object; - TF_RETURN_IF_ERROR(ParseGcsPath(fname, false, &bucket, &object)); + RETURN_IF_ERROR(ParseGcsPath(fname, false, &bucket, &object)); FileStatistics stat; - TF_RETURN_IF_ERROR(Stat(fname, &stat)); + RETURN_IF_ERROR(Stat(fname, &stat)); *file_size = stat.length; return absl::OkStatus(); } @@ -2097,16 +2098,15 @@ absl::Status GcsFileSystem::RenameFile(const std::string& src, // It's a directory. Parse both source and target to check the buckets. std::string src_bucket, src_object; - TF_RETURN_IF_ERROR(ParseGcsPath(src, true, &src_bucket, &src_object)); + RETURN_IF_ERROR(ParseGcsPath(src, true, &src_bucket, &src_object)); std::string target_bucket, target_object; - TF_RETURN_IF_ERROR( - ParseGcsPath(target, true, &target_bucket, &target_object)); + RETURN_IF_ERROR(ParseGcsPath(target, true, &target_bucket, &target_object)); // If buckets are the same, we can check for HNS and use the fast rename API. if (src_bucket == target_bucket) { bool hns_enabled = false; - TF_RETURN_IF_ERROR(IsBucketHnsEnabled(src_bucket, &hns_enabled)); + RETURN_IF_ERROR(IsBucketHnsEnabled(src_bucket, &hns_enabled)); if (hns_enabled) { return RenameFolderHns(src, target); @@ -2118,11 +2118,11 @@ absl::Status GcsFileSystem::RenameFile(const std::string& src, // 2. The buckets are the same, but HNS is not enabled. VLOG(1) << "Falling back to iterative rename for directory " << src; std::vector children; - TF_RETURN_IF_ERROR( - GetChildrenBounded(src, UINT64_MAX, &children, true /* recursively */, - true /* include_self_directory_marker */)); + RETURN_IF_ERROR(GetChildrenBounded(src, UINT64_MAX, &children, + true /* recursively */, + true /* include_self_directory_marker */)); for (const std::string& subpath : children) { - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( RenameObject(JoinGcsPath(src, subpath), JoinGcsPath(target, subpath))); } return absl::OkStatus(); @@ -2133,12 +2133,11 @@ absl::Status GcsFileSystem::RenameObject(const std::string& src, const std::string& target) { VLOG(3) << "RenameObject: started gs://" << src << " to " << target; std::string src_bucket, src_object, target_bucket, target_object; - TF_RETURN_IF_ERROR(ParseGcsPath(src, false, &src_bucket, &src_object)); - TF_RETURN_IF_ERROR( - ParseGcsPath(target, false, &target_bucket, &target_object)); + RETURN_IF_ERROR(ParseGcsPath(src, false, &src_bucket, &src_object)); + RETURN_IF_ERROR(ParseGcsPath(target, false, &target_bucket, &target_object)); std::unique_ptr request; - TF_RETURN_IF_ERROR(CreateHttpRequest(&request)); + RETURN_IF_ERROR(CreateHttpRequest(&request)); request->SetUri(strings::StrCat(kGcsUriBase, "b/", src_bucket, "/o/", request->EscapeString(src_object), "/rewriteTo/b/", target_bucket, "/o/", @@ -2153,9 +2152,9 @@ absl::Status GcsFileSystem::RenameObject(const std::string& src, // DeleteFile call below. ClearFileCaches(target); Json::Value root; - TF_RETURN_IF_ERROR(ParseJson(output_buffer, &root)); + RETURN_IF_ERROR(ParseJson(output_buffer, &root)); bool done; - TF_RETURN_IF_ERROR(GetBoolValue(root, "done", &done)); + RETURN_IF_ERROR(GetBoolValue(root, "done", &done)); if (!done) { // If GCS didn't complete rewrite in one call, this means that a large file // is being copied to a bucket with a different storage class or location, @@ -2177,10 +2176,10 @@ absl::Status GcsFileSystem::RenameObject(const std::string& src, absl::Status GcsFileSystem::IsDirectory(const std::string& fname) { std::string bucket, object; - TF_RETURN_IF_ERROR(ParseGcsPath(fname, true, &bucket, &object)); + RETURN_IF_ERROR(ParseGcsPath(fname, true, &bucket, &object)); if (object.empty()) { bool is_bucket; - TF_RETURN_IF_ERROR(BucketExists(bucket, &is_bucket)); + RETURN_IF_ERROR(BucketExists(bucket, &is_bucket)); if (is_bucket) { return absl::OkStatus(); } @@ -2188,12 +2187,12 @@ absl::Status GcsFileSystem::IsDirectory(const std::string& fname) { absl::StrCat("The specified bucket gs://", bucket, " was not found.")); } bool is_folder; - TF_RETURN_IF_ERROR(FolderExists(fname, &is_folder)); + RETURN_IF_ERROR(FolderExists(fname, &is_folder)); if (is_folder) { return absl::OkStatus(); } bool is_object; - TF_RETURN_IF_ERROR(ObjectExists(fname, bucket, object, &is_object)); + RETURN_IF_ERROR(ObjectExists(fname, bucket, object, &is_object)); if (is_object) { return absl::FailedPreconditionError( absl::StrCat("The specified path ", fname, " is not a directory.")); @@ -2219,9 +2218,9 @@ absl::Status GcsFileSystem::DeleteRecursively(const std::string& dirname, } std::vector all_objects; // Get all children in the directory recursively. - TF_RETURN_IF_ERROR(GetChildrenBounded( - dirname, UINT64_MAX, &all_objects, true /* recursively */, - true /* include_self_directory_marker */)); + RETURN_IF_ERROR(GetChildrenBounded(dirname, UINT64_MAX, &all_objects, + true /* recursively */, + true /* include_self_directory_marker */)); for (const std::string& object : all_objects) { const std::string& full_path = JoinGcsPath(dirname, object); // Delete all objects including directory markers for subfolders. @@ -2248,12 +2247,11 @@ absl::Status GcsFileSystem::RenameFolderHns(const std::string& src, << "' to: '" << target << "'"; std::string src_bucket, src_object, target_bucket, target_object; - TF_RETURN_IF_ERROR(ParseGcsPath(src, false, &src_bucket, &src_object)); - TF_RETURN_IF_ERROR( - ParseGcsPath(target, false, &target_bucket, &target_object)); + RETURN_IF_ERROR(ParseGcsPath(src, false, &src_bucket, &src_object)); + RETURN_IF_ERROR(ParseGcsPath(target, false, &target_bucket, &target_object)); std::unique_ptr request; - TF_RETURN_IF_ERROR(CreateHttpRequest(&request)); + RETURN_IF_ERROR(CreateHttpRequest(&request)); const std::string uri_to_send = absl::StrCat(kGcsUriBase, "b/", src_bucket, "/folders/", @@ -2273,11 +2271,11 @@ absl::Status GcsFileSystem::RenameFolderHns(const std::string& src, // Parse the long-running operation object from the response. Json::Value operation_response; - TF_RETURN_IF_ERROR(ParseJson(output_buffer, &operation_response)); + RETURN_IF_ERROR(ParseJson(output_buffer, &operation_response)); bool done = false; if (operation_response.isMember("done")) { - TF_RETURN_IF_ERROR(GetBoolValue(operation_response, "done", &done)); + RETURN_IF_ERROR(GetBoolValue(operation_response, "done", &done)); if (done) { if (operation_response.isMember("error")) { return absl::InternalError( @@ -2291,8 +2289,7 @@ absl::Status GcsFileSystem::RenameFolderHns(const std::string& src, } std::string operation_name; - TF_RETURN_IF_ERROR( - GetStringValue(operation_response, "name", &operation_name)); + RETURN_IF_ERROR(GetStringValue(operation_response, "name", &operation_name)); absl::string_view operation_id = io::Basename(operation_name); @@ -2303,7 +2300,7 @@ absl::Status GcsFileSystem::RenameFolderHns(const std::string& src, while (true) { absl::SleepFor(kPollingInterval); std::unique_ptr poll_request; - TF_RETURN_IF_ERROR(CreateHttpRequest(&poll_request)); + RETURN_IF_ERROR(CreateHttpRequest(&poll_request)); poll_request->SetUri(absl::StrCat(kGcsUriBase, "b/", src_bucket, "/operations/", operation_id)); @@ -2315,7 +2312,7 @@ absl::Status GcsFileSystem::RenameFolderHns(const std::string& src, TF_RETURN_WITH_CONTEXT_IF_ERROR(poll_request->Send(), " when polling operation ", operation_id); - TF_RETURN_IF_ERROR(ParseJson(poll_output_buffer, &operation_response)); + RETURN_IF_ERROR(ParseJson(poll_output_buffer, &operation_response)); if (operation_response.isMember("error")) { return absl::InternalError( @@ -2325,7 +2322,7 @@ absl::Status GcsFileSystem::RenameFolderHns(const std::string& src, if (operation_response.isMember("done")) { bool done = false; - TF_RETURN_IF_ERROR(GetBoolValue(operation_response, "done", &done)); + RETURN_IF_ERROR(GetBoolValue(operation_response, "done", &done)); if (done) { break; } @@ -2387,8 +2384,7 @@ absl::Status GcsFileSystem::CreateHttpRequest( std::string auth_token; { absl::ReaderMutexLock l(mu_); - TF_RETURN_IF_ERROR( - AuthProvider::GetToken(auth_provider_.get(), &auth_token)); + RETURN_IF_ERROR(AuthProvider::GetToken(auth_provider_.get(), &auth_token)); } new_request->AddAuthBearerHeader(auth_token); diff --git a/third_party/xla/xla/tsl/platform/cloud/google_auth_provider.cc b/third_party/xla/xla/tsl/platform/cloud/google_auth_provider.cc index 9d4d29bd70340b..c428560c690dfd 100644 --- a/third_party/xla/xla/tsl/platform/cloud/google_auth_provider.cc +++ b/third_party/xla/xla/tsl/platform/cloud/google_auth_provider.cc @@ -39,6 +39,7 @@ limitations under the License. #include "absl/strings/match.h" #include "absl/strings/numbers.h" #include "absl/synchronization/mutex.h" +#include "xla/tsl/platform/status_macros.h" #include "json/json.h" #include "xla/tsl/platform/env.h" #include "xla/tsl/platform/errors.h" @@ -207,10 +208,10 @@ absl::Status GoogleAuthProvider::GetToken(std::string* t) { kNoGceCheck, " environment variable.")); } else { int max_requests; - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( ParseNonNegativeIntEnvVar(kGcsAuthMaxRequests, 1, &max_requests)); int retry_delay_sec; - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( ParseNonNegativeIntEnvVar(kGcsAuthRetryDelaySec, 5, &retry_delay_sec)); for (int i = 0; i < max_requests; ++i) { @@ -278,10 +279,10 @@ absl::Status GoogleAuthProvider::GetTokenFromFiles() { "Couldn't parse the JSON credentials file."); } if (json.isMember("refresh_token")) { - TF_RETURN_IF_ERROR(oauth_client_->GetTokenFromRefreshTokenJson( + RETURN_IF_ERROR(oauth_client_->GetTokenFromRefreshTokenJson( json, kOAuthV3Url, ¤t_token_, &expiration_timestamp_sec_)); } else if (json.isMember("private_key")) { - TF_RETURN_IF_ERROR(oauth_client_->GetTokenFromServiceAccountJson( + RETURN_IF_ERROR(oauth_client_->GetTokenFromServiceAccountJson( json, kOAuthV4Url, kOAuthScope, ¤t_token_, &expiration_timestamp_sec_)); } else { @@ -295,12 +296,12 @@ absl::Status GoogleAuthProvider::GetTokenFromGce() { std::vector response_buffer; const uint64_t request_timestamp_sec = env_->NowSeconds(); - TF_RETURN_IF_ERROR(compute_engine_metadata_client_->GetMetadata( + RETURN_IF_ERROR(compute_engine_metadata_client_->GetMetadata( kGceTokenPath, &response_buffer)); absl::string_view response = absl::string_view(&response_buffer[0], response_buffer.size()); - TF_RETURN_IF_ERROR(oauth_client_->ParseOAuthResponse( + RETURN_IF_ERROR(oauth_client_->ParseOAuthResponse( response, request_timestamp_sec, ¤t_token_, &expiration_timestamp_sec_)); diff --git a/third_party/xla/xla/tsl/platform/cloud/oauth_client.cc b/third_party/xla/xla/tsl/platform/cloud/oauth_client.cc index 1ef7e0eb1e5299..18e1a6cc6fead3 100644 --- a/third_party/xla/xla/tsl/platform/cloud/oauth_client.cc +++ b/third_party/xla/xla/tsl/platform/cloud/oauth_client.cc @@ -25,6 +25,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/strings/str_cat.h" +#include "xla/tsl/platform/status_macros.h" #include #include #include @@ -67,7 +68,7 @@ absl::Status ReadJsonValue(const Json::Value& json, const std::string& name, absl::Status ReadJsonString(const Json::Value& json, const std::string& name, std::string* value) { Json::Value json_value; - TF_RETURN_IF_ERROR(ReadJsonValue(json, name, &json_value)); + RETURN_IF_ERROR(ReadJsonValue(json, name, &json_value)); if (!json_value.isString()) { return absl::FailedPreconditionError( absl::StrCat("JSON value '", name, "' is not string.")); @@ -79,7 +80,7 @@ absl::Status ReadJsonString(const Json::Value& json, const std::string& name, absl::Status ReadJsonInt(const Json::Value& json, const std::string& name, int64_t* value) { Json::Value json_value; - TF_RETURN_IF_ERROR(ReadJsonValue(json, name, &json_value)); + RETURN_IF_ERROR(ReadJsonValue(json, name, &json_value)); if (!json_value.isIntegral()) { return absl::FailedPreconditionError( absl::StrCat("JSON value '", name, "' is not integer.")); @@ -193,11 +194,10 @@ absl::Status OAuthClient::GetTokenFromServiceAccountJson( "'token' and 'expiration_timestamp_sec' cannot be nullptr."); } std::string private_key_serialized, private_key_id, client_id, client_email; - TF_RETURN_IF_ERROR( - ReadJsonString(json, "private_key", &private_key_serialized)); - TF_RETURN_IF_ERROR(ReadJsonString(json, "private_key_id", &private_key_id)); - TF_RETURN_IF_ERROR(ReadJsonString(json, "client_id", &client_id)); - TF_RETURN_IF_ERROR(ReadJsonString(json, "client_email", &client_email)); + RETURN_IF_ERROR(ReadJsonString(json, "private_key", &private_key_serialized)); + RETURN_IF_ERROR(ReadJsonString(json, "private_key_id", &private_key_id)); + RETURN_IF_ERROR(ReadJsonString(json, "client_id", &client_id)); + RETURN_IF_ERROR(ReadJsonString(json, "client_email", &client_email)); std::unique_ptr> bio( BIO_new(BIO_s_mem()), [](BIO* ptr) { BIO_free_all(ptr); }); @@ -215,12 +215,12 @@ absl::Status OAuthClient::GetTokenFromServiceAccountJson( const uint64_t request_timestamp_sec = env_->NowSeconds(); std::string encoded_claim, encoded_header; - TF_RETURN_IF_ERROR(EncodeJwtHeader(private_key_id, &encoded_header)); - TF_RETURN_IF_ERROR(EncodeJwtClaim(client_email, scope, oauth_server_uri, - request_timestamp_sec, &encoded_claim)); + RETURN_IF_ERROR(EncodeJwtHeader(private_key_id, &encoded_header)); + RETURN_IF_ERROR(EncodeJwtClaim(client_email, scope, oauth_server_uri, + request_timestamp_sec, &encoded_claim)); const std::string to_sign = encoded_header + "." + encoded_claim; std::string signature; - TF_RETURN_IF_ERROR(CreateSignature(private_key.get(), to_sign, &signature)); + RETURN_IF_ERROR(CreateSignature(private_key.get(), to_sign, &signature)); const std::string jwt = to_sign + "." + signature; const std::string request_body = absl::StrCat("grant_type=", kGrantType, "&assertion=", jwt); @@ -231,12 +231,12 @@ absl::Status OAuthClient::GetTokenFromServiceAccountJson( request->SetUri(std::string(oauth_server_uri)); request->SetPostFromBuffer(request_body.c_str(), request_body.size()); request->SetResultBuffer(&response_buffer); - TF_RETURN_IF_ERROR(request->Send()); + RETURN_IF_ERROR(request->Send()); absl::string_view response = absl::string_view(response_buffer.data(), response_buffer.size()); - TF_RETURN_IF_ERROR(ParseOAuthResponse(response, request_timestamp_sec, token, - expiration_timestamp_sec)); + RETURN_IF_ERROR(ParseOAuthResponse(response, request_timestamp_sec, token, + expiration_timestamp_sec)); return absl::OkStatus(); } @@ -248,9 +248,9 @@ absl::Status OAuthClient::GetTokenFromRefreshTokenJson( "'token' and 'expiration_timestamp_sec' cannot be nullptr."); } std::string client_id, client_secret, refresh_token; - TF_RETURN_IF_ERROR(ReadJsonString(json, "client_id", &client_id)); - TF_RETURN_IF_ERROR(ReadJsonString(json, "client_secret", &client_secret)); - TF_RETURN_IF_ERROR(ReadJsonString(json, "refresh_token", &refresh_token)); + RETURN_IF_ERROR(ReadJsonString(json, "client_id", &client_id)); + RETURN_IF_ERROR(ReadJsonString(json, "client_secret", &client_secret)); + RETURN_IF_ERROR(ReadJsonString(json, "refresh_token", &refresh_token)); const auto request_body = absl::StrCat( "client_id=", client_id, "&client_secret=", client_secret, @@ -263,12 +263,12 @@ absl::Status OAuthClient::GetTokenFromRefreshTokenJson( request->SetUri(std::string(oauth_server_uri)); request->SetPostFromBuffer(request_body.c_str(), request_body.size()); request->SetResultBuffer(&response_buffer); - TF_RETURN_IF_ERROR(request->Send()); + RETURN_IF_ERROR(request->Send()); absl::string_view response = absl::string_view(response_buffer.data(), response_buffer.size()); - TF_RETURN_IF_ERROR(ParseOAuthResponse(response, request_timestamp_sec, token, - expiration_timestamp_sec)); + RETURN_IF_ERROR(ParseOAuthResponse(response, request_timestamp_sec, token, + expiration_timestamp_sec)); return absl::OkStatus(); } @@ -287,15 +287,15 @@ absl::Status OAuthClient::ParseOAuthResponse( } std::string token_type; - TF_RETURN_IF_ERROR(ReadJsonString(root, "token_type", &token_type)); + RETURN_IF_ERROR(ReadJsonString(root, "token_type", &token_type)); if (token_type != "Bearer") { return absl::FailedPreconditionError("Unexpected Oauth token type: " + token_type); } int64_t expires_in = 0; - TF_RETURN_IF_ERROR(ReadJsonInt(root, "expires_in", &expires_in)); + RETURN_IF_ERROR(ReadJsonInt(root, "expires_in", &expires_in)); *expiration_timestamp_sec = request_timestamp_sec + expires_in; - TF_RETURN_IF_ERROR(ReadJsonString(root, "access_token", token)); + RETURN_IF_ERROR(ReadJsonString(root, "access_token", token)); return absl::OkStatus(); } diff --git a/third_party/xla/xla/tsl/platform/cloud/ram_file_block_cache.cc b/third_party/xla/xla/tsl/platform/cloud/ram_file_block_cache.cc index 6aa094e88dc352..62af02b5fd8e8d 100644 --- a/third_party/xla/xla/tsl/platform/cloud/ram_file_block_cache.cc +++ b/third_party/xla/xla/tsl/platform/cloud/ram_file_block_cache.cc @@ -23,6 +23,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/synchronization/mutex.h" #include "absl/time/time.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/tsl/platform/env.h" namespace tsl { @@ -197,8 +198,8 @@ absl::Status RamFileBlockCache::Read(const std::string& filename, size_t offset, // LRU iterator for the key and block. std::shared_ptr block = Lookup(key); DCHECK(block) << "No block for key " << key.first << "@" << key.second; - TF_RETURN_IF_ERROR(MaybeFetch(key, block)); - TF_RETURN_IF_ERROR(UpdateLRU(key, block)); + RETURN_IF_ERROR(MaybeFetch(key, block)); + RETURN_IF_ERROR(UpdateLRU(key, block)); // Copy the relevant portion of the block into the result buffer. const auto& data = block->data; if (offset >= pos + data.size()) { diff --git a/third_party/xla/xla/tsl/platform/default/BUILD b/third_party/xla/xla/tsl/platform/default/BUILD index 81a8d03fdc40ce..88d7ca867f7ce0 100644 --- a/third_party/xla/xla/tsl/platform/default/BUILD +++ b/third_party/xla/xla/tsl/platform/default/BUILD @@ -572,6 +572,7 @@ cc_library( deps = [ "//xla/tsl/platform:macros", "//xla/tsl/platform:status", + "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/status:statusor", ], ) diff --git a/third_party/xla/xla/tsl/platform/default/statusor.h b/third_party/xla/xla/tsl/platform/default/statusor.h index babd52ed96d7b7..947096f37d924b 100644 --- a/third_party/xla/xla/tsl/platform/default/statusor.h +++ b/third_party/xla/xla/tsl/platform/default/statusor.h @@ -15,19 +15,26 @@ limitations under the License. #ifndef XLA_TSL_PLATFORM_DEFAULT_STATUSOR_H_ #define XLA_TSL_PLATFORM_DEFAULT_STATUSOR_H_ +#include "absl/base/attributes.h" #include "absl/status/statusor.h" #include "xla/tsl/platform/macros.h" #include "xla/tsl/platform/status.h" +namespace tsl { +ABSL_DEPRECATED( + "TF_ASSIGN_OR_RETURN is deprecated. Use ASSIGN_OR_RETURN instead") +inline void TfAssignOrReturnDeprecationMarker() {} +} // namespace tsl + #define TF_ASSIGN_OR_RETURN(lhs, rexpr) \ TF_ASSIGN_OR_RETURN_IMPL( \ TF_STATUS_MACROS_CONCAT_NAME(_status_or_value, __COUNTER__), lhs, rexpr) -#define TF_ASSIGN_OR_RETURN_IMPL(statusor, lhs, rexpr) \ - auto statusor = (rexpr); \ - if (TF_PREDICT_FALSE(!statusor.ok())) { \ - return statusor.status(); \ - } \ +#define TF_ASSIGN_OR_RETURN_IMPL(statusor, lhs, rexpr) \ + auto statusor = (::tsl::TfAssignOrReturnDeprecationMarker(), (rexpr)); \ + if (TF_PREDICT_FALSE(!statusor.ok())) { \ + return statusor.status(); \ + } \ lhs = std::move(statusor).value() #endif // XLA_TSL_PLATFORM_DEFAULT_STATUSOR_H_ diff --git a/third_party/xla/xla/tsl/platform/errors.h b/third_party/xla/xla/tsl/platform/errors.h index 1dab8ff6469d60..16448b47747d8b 100644 --- a/third_party/xla/xla/tsl/platform/errors.h +++ b/third_party/xla/xla/tsl/platform/errors.h @@ -36,6 +36,7 @@ limitations under the License. #include "tsl/platform/platform.h" namespace tsl { + namespace error { // NOLINTBEGIN(misc-unused-using-decls) // TODO(aminim): figure out the protobuf migration story. @@ -189,14 +190,19 @@ void AppendToMessage(absl::Status* status, Args... args) { *status = std::move(new_status); } +ABSL_DEPRECATED( + "TF_RETURN_IF_ERROR is deprecated. Call RETURN_IF_ERROR instead") +inline void TfReturnIfErrorDeprecationMarker() {} + // For propagating errors when calling a function. -#define TF_RETURN_IF_ERROR(...) \ - do { \ - absl::Status _status = (__VA_ARGS__); \ - if (TF_PREDICT_FALSE(!_status.ok())) { \ - MAYBE_ADD_SOURCE_LOCATION(_status) \ - return _status; \ - } \ +#define TF_RETURN_IF_ERROR(...) \ + do { \ + ::tsl::errors::TfReturnIfErrorDeprecationMarker(); \ + absl::Status _status = (__VA_ARGS__); \ + if (TF_PREDICT_FALSE(!_status.ok())) { \ + MAYBE_ADD_SOURCE_LOCATION(_status) \ + return _status; \ + } \ } while (0) #define TF_RETURN_WITH_CONTEXT_IF_ERROR(expr, ...) \ diff --git a/third_party/xla/xla/tsl/platform/windows/BUILD b/third_party/xla/xla/tsl/platform/windows/BUILD index e873f26d3fdd5f..20bb4d0a1e256a 100644 --- a/third_party/xla/xla/tsl/platform/windows/BUILD +++ b/third_party/xla/xla/tsl/platform/windows/BUILD @@ -52,6 +52,7 @@ cc_library( "//xla/tsl/platform:logging", "//xla/tsl/platform:macros", "//xla/tsl/platform:status", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:threadpool_interface", "//xla/tsl/platform:types", "//xla/tsl/protobuf:error_codes_proto_impl_cc", diff --git a/third_party/xla/xla/tsl/platform/windows/windows_file_system.cc b/third_party/xla/xla/tsl/platform/windows/windows_file_system.cc index fcd3b40fc53d50..5b36b953f06790 100644 --- a/third_party/xla/xla/tsl/platform/windows/windows_file_system.cc +++ b/third_party/xla/xla/tsl/platform/windows/windows_file_system.cc @@ -38,6 +38,7 @@ limitations under the License. #include "absl/strings/cord.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/tsl/platform/env.h" #include "xla/tsl/platform/errors.h" #include "xla/tsl/platform/file_statistics.h" @@ -657,7 +658,7 @@ absl::Status WindowsFileSystem::GetFileSize(const std::string& fname, absl::Status WindowsFileSystem::IsDirectory(const std::string& fname) { std::wstring ws_final_fname = GetUncPathName(TranslateName(fname)); std::string str_final_fname(ws_final_fname.begin(), ws_final_fname.end()); - TF_RETURN_IF_ERROR(FileExists(str_final_fname)); + RETURN_IF_ERROR(FileExists(str_final_fname)); if (PathIsDirectoryW(ws_final_fname.c_str())) { return absl::OkStatus(); } @@ -707,8 +708,8 @@ absl::Status WindowsFileSystem::GetMatchingPaths( // but no code appears to rely on this behavior. std::string converted_pattern(pattern); std::replace(converted_pattern.begin(), converted_pattern.end(), '\\', '/'); - TF_RETURN_IF_ERROR(internal::GetMatchingPaths(this, Env::Default(), - converted_pattern, results)); + RETURN_IF_ERROR(internal::GetMatchingPaths(this, Env::Default(), + converted_pattern, results)); for (std::string& result : *results) { std::replace(result.begin(), result.end(), '/', '\\'); } diff --git a/third_party/xla/xla/tsl/profiler/rpc/BUILD b/third_party/xla/xla/tsl/profiler/rpc/BUILD index c9c02d77749870..376f4e4ba91939 100644 --- a/third_party/xla/xla/tsl/profiler/rpc/BUILD +++ b/third_party/xla/xla/tsl/profiler/rpc/BUILD @@ -35,6 +35,7 @@ cc_library( "//xla/tsl/platform:errors", "//xla/tsl/platform:logging", "//xla/tsl/platform:macros", + "//xla/tsl/platform:status_macros", "//xla/tsl/profiler/rpc/client:save_profile", "//xla/tsl/profiler/utils:math_utils", "//xla/tsl/profiler/utils:time_utils", diff --git a/third_party/xla/xla/tsl/profiler/rpc/client/BUILD b/third_party/xla/xla/tsl/profiler/rpc/client/BUILD index 40fba67114eed0..05e869e20cd1b7 100644 --- a/third_party/xla/xla/tsl/profiler/rpc/client/BUILD +++ b/third_party/xla/xla/tsl/profiler/rpc/client/BUILD @@ -37,6 +37,7 @@ cc_library( ":save_profile", "//xla/tsl/platform:errors", "//xla/tsl/platform:status", + "//xla/tsl/platform:status_macros", "//xla/tsl/profiler/convert:trace_events_to_json", "//xla/tsl/profiler/convert:xplane_to_trace_events", "//xla/tsl/profiler/utils:session_manager", @@ -75,6 +76,7 @@ cc_library( "//xla/tsl/platform:errors", "//xla/tsl/platform:logging", "//xla/tsl/platform:status", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:types", "//xla/tsl/profiler/utils:file_system_utils", "@com_google_absl//absl/strings", @@ -131,6 +133,7 @@ cc_library( deps = [ "//xla/tsl/platform:errors", "//xla/tsl/platform:logging", + "//xla/tsl/platform:status_macros", "//xla/tsl/protobuf:error_codes_proto_impl_cc", "@com_github_grpc_grpc//:grpc++", "@com_google_absl//absl/memory", @@ -199,6 +202,7 @@ cc_library( "//xla/tsl/platform:logging", "//xla/tsl/platform:macros", "//xla/tsl/platform:status", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "//xla/tsl/platform:types", "@com_google_absl//absl/memory", diff --git a/third_party/xla/xla/tsl/profiler/rpc/client/capture_profile.cc b/third_party/xla/xla/tsl/profiler/rpc/client/capture_profile.cc index 33fa94d29cf8bc..67f23653580f2a 100644 --- a/third_party/xla/xla/tsl/profiler/rpc/client/capture_profile.cc +++ b/third_party/xla/xla/tsl/profiler/rpc/client/capture_profile.cc @@ -30,6 +30,7 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/time/clock.h" #include "absl/time/time.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/tsl/platform/errors.h" #include "xla/tsl/platform/status.h" #include "xla/tsl/profiler/convert/trace_events_to_json.h" @@ -136,7 +137,7 @@ absl::Status Profile(const std::string& repository_root, ProfileRequest request = PopulateProfileRequest(repository_root, session_id, /*host_name=*/"", opts); auto session = RemoteProfilerSessionManager::Create(opts, request, status); - TF_RETURN_IF_ERROR(status); + RETURN_IF_ERROR(status); // Expect one or more service addresses. DCHECK_GT(opts.service_addresses_size(), 0); std::vector responses = session->WaitForCompletion(); @@ -154,9 +155,9 @@ absl::Status Profile(const std::string& repository_root, // If server side returns tool data in the response, saves that into the // repository. This improves backward compatibility by reducing assumption // of what server side does. - TF_RETURN_IF_ERROR(SaveProfile(repository_root, session_id, - client_response.service_address, response, - &std::cout)); + RETURN_IF_ERROR(SaveProfile(repository_root, session_id, + client_response.service_address, response, + &std::cout)); } if (!client_response.status.ok()) { LOG(WARNING) << client_response.service_address << " returned " @@ -182,7 +183,7 @@ absl::Status NewSession(absl::string_view repository_root, NewProfileSessionRequest request = PopulateNewProfileSessionRequest(repository_root, session_id, opts); NewProfileSessionResponse response; - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( NewSessionGrpc(opts.service_addresses(0), request, &response)); std::cout << "Profile session succeed for host(s):" @@ -205,14 +206,14 @@ absl::Status StartContinuousProfiling( /*repository_root=*/"", session_id, /*host_name=*/service_addr, opts); tensorflow::ContinuousProfilingResponse response; - TF_RETURN_IF_ERROR(ContinuousProfilingGrpc(service_addr, request, &response)); + RETURN_IF_ERROR(ContinuousProfilingGrpc(service_addr, request, &response)); return absl::OkStatus(); } absl::Status StopContinuousProfiling(const char* service_addr) { tensorflow::StopContinuousProfilingRequest request; tensorflow::StopContinuousProfilingResponse response; - TF_RETURN_IF_ERROR(tsl::profiler::StopContinuousProfilingGrpc( + RETURN_IF_ERROR(tsl::profiler::StopContinuousProfilingGrpc( service_addr, request, &response)); return absl::OkStatus(); } @@ -220,7 +221,7 @@ absl::Status StopContinuousProfiling(const char* service_addr) { absl::Status GetSnapshot(const char* service_addr, const char* logdir) { tensorflow::GetSnapshotRequest request; ProfileResponse response; - TF_RETURN_IF_ERROR(GetSnapshotGrpc(service_addr, request, &response)); + RETURN_IF_ERROR(GetSnapshotGrpc(service_addr, request, &response)); if (response.empty_trace()) { return absl::OkStatus(); @@ -228,11 +229,11 @@ absl::Status GetSnapshot(const char* service_addr, const char* logdir) { std::string repository_root = GetTensorBoardProfilePluginDir(logdir); std::string snapshot_session_id = GetCurrentTimeStampAsString(); - TF_RETURN_IF_ERROR(SaveProfile(repository_root, snapshot_session_id, - service_addr, response, &std::cout)); + RETURN_IF_ERROR(SaveProfile(repository_root, snapshot_session_id, + service_addr, response, &std::cout)); if (response.has_xspace()) { - TF_RETURN_IF_ERROR(SaveXSpace(repository_root, snapshot_session_id, - service_addr, response.xspace())); + RETURN_IF_ERROR(SaveXSpace(repository_root, snapshot_session_id, + service_addr, response.xspace())); } return absl::OkStatus(); @@ -294,7 +295,7 @@ absl::Status Monitor(const std::string& service_addr, int duration_ms, MonitorRequest request = PopulateMonitorRequest(duration_ms, monitoring_level, display_timestamp); MonitorResponse response; - TF_RETURN_IF_ERROR(MonitorGrpc(service_addr, request, &response)); + RETURN_IF_ERROR(MonitorGrpc(service_addr, request, &response)); *result = response.data(); return absl::OkStatus(); } @@ -306,7 +307,7 @@ absl::Status ExportToTensorBoard(const XSpace& xspace, std::string repository_root = tsl::profiler::GetTensorBoardProfilePluginDir(logdir); std::string host = tsl::port::Hostname(); - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( tsl::profiler::SaveXSpace(repository_root, run, host, xspace)); if (also_export_trace_json) { tsl::profiler::TraceContainer container = @@ -337,11 +338,11 @@ absl::Status CaptureRemoteTrace( GetRemoteSessionManagerOptionsLocked(service_addr, logdir, worker_list, include_dataset_ops, duration_ms, options, &is_cloud_tpu_session); - TF_RETURN_IF_ERROR(ValidateRemoteProfilerSessionManagerOptions(opts)); + RETURN_IF_ERROR(ValidateRemoteProfilerSessionManagerOptions(opts)); { - TF_RETURN_IF_ERROR(CaptureRemoteTrace(logdir, num_tracing_attempts, opts, - is_cloud_tpu_session)); + RETURN_IF_ERROR(CaptureRemoteTrace(logdir, num_tracing_attempts, opts, + is_cloud_tpu_session)); } return absl::OkStatus(); } diff --git a/third_party/xla/xla/tsl/profiler/rpc/client/profiler_client.cc b/third_party/xla/xla/tsl/profiler/rpc/client/profiler_client.cc index 2ff22822b7f199..38f1ff2432548b 100644 --- a/third_party/xla/xla/tsl/profiler/rpc/client/profiler_client.cc +++ b/third_party/xla/xla/tsl/profiler/rpc/client/profiler_client.cc @@ -24,6 +24,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/time/clock.h" #include "absl/time/time.h" +#include "xla/tsl/platform/status_macros.h" #include "grpcpp/client_context.h" #include "grpcpp/create_channel.h" #include "grpcpp/grpcpp.h" // IWYU pragma: keep @@ -80,8 +81,7 @@ absl::Status ProfileGrpc(const std::string& service_address, ::grpc::ClientContext context; std::unique_ptr stub = CreateStub(service_address); - TF_RETURN_IF_ERROR( - FromGrpcStatus(stub->Profile(&context, request, response))); + RETURN_IF_ERROR(FromGrpcStatus(stub->Profile(&context, request, response))); return absl::OkStatus(); } @@ -91,7 +91,7 @@ absl::Status ContinuousProfilingGrpc(const std::string& service_address, ::grpc::ClientContext context; std::unique_ptr stub = CreateStub(service_address); - TF_RETURN_IF_ERROR(FromGrpcStatus( + RETURN_IF_ERROR(FromGrpcStatus( stub->StartContinuousProfiling(&context, request, response))); return absl::OkStatus(); } @@ -103,7 +103,7 @@ absl::Status StopContinuousProfilingGrpc( ::grpc::ClientContext context; std::unique_ptr stub = CreateStub(service_address); - TF_RETURN_IF_ERROR(FromGrpcStatus( + RETURN_IF_ERROR(FromGrpcStatus( stub->StopContinuousProfiling(&context, request, response))); return absl::OkStatus(); } @@ -114,7 +114,7 @@ absl::Status GetSnapshotGrpc(const std::string& service_address, ::grpc::ClientContext context; std::unique_ptr stub = CreateStub(service_address); - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( FromGrpcStatus(stub->GetSnapshot(&context, request, response))); return absl::OkStatus(); } @@ -125,7 +125,7 @@ absl::Status NewSessionGrpc(const std::string& service_address, ::grpc::ClientContext context; std::unique_ptr stub = CreateStub(service_address); - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( FromGrpcStatus(stub->NewSession(&context, request, response))); return absl::OkStatus(); } @@ -136,8 +136,7 @@ absl::Status MonitorGrpc(const std::string& service_address, ::grpc::ClientContext context; std::unique_ptr stub = CreateStub(service_address); - TF_RETURN_IF_ERROR( - FromGrpcStatus(stub->Monitor(&context, request, response))); + RETURN_IF_ERROR(FromGrpcStatus(stub->Monitor(&context, request, response))); return absl::OkStatus(); } diff --git a/third_party/xla/xla/tsl/profiler/rpc/client/remote_profiler_session_manager.cc b/third_party/xla/xla/tsl/profiler/rpc/client/remote_profiler_session_manager.cc index 5c929d18a77497..f85d7daf8e7897 100644 --- a/third_party/xla/xla/tsl/profiler/rpc/client/remote_profiler_session_manager.cc +++ b/third_party/xla/xla/tsl/profiler/rpc/client/remote_profiler_session_manager.cc @@ -27,6 +27,7 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" #include "absl/time/time.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/tsl/lib/gtl/map_util.h" #include "xla/tsl/platform/logging.h" #include "xla/tsl/platform/statusor.h" @@ -117,7 +118,7 @@ absl::Status RemoteProfilerSessionManager::Init() { clients_.reserve(options_.service_addresses().size()); ProfileRequest request_template = request_; - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( std::vector override_hostnames_list, ParseAndValidateOverrideHostnames(options_, request_template)); diff --git a/third_party/xla/xla/tsl/profiler/rpc/client/save_profile.cc b/third_party/xla/xla/tsl/profiler/rpc/client/save_profile.cc index 0f13135050ab47..341c45a821bad5 100644 --- a/third_party/xla/xla/tsl/profiler/rpc/client/save_profile.cc +++ b/third_party/xla/xla/tsl/profiler/rpc/client/save_profile.cc @@ -28,6 +28,7 @@ limitations under the License. #include "absl/strings/strip.h" #include "absl/time/clock.h" #include "absl/time/time.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/tsl/lib/io/zlib_compression_options.h" #include "xla/tsl/lib/io/zlib_outputbuffer.h" #include "xla/tsl/platform/env.h" @@ -57,7 +58,7 @@ absl::Status DumpToolData(absl::string_view run_dir, absl::string_view host, std::string host_prefix = host.empty() ? "" : absl::StrCat(host, "."); std::string path = ProfilerJoinPath(run_dir, absl::StrCat(host_prefix, tool.name())); - TF_RETURN_IF_ERROR(WriteStringToFile(Env::Default(), path, tool.data())); + RETURN_IF_ERROR(WriteStringToFile(Env::Default(), path, tool.data())); if (os) { *os << "Dumped tool data for " << tool.name() << " to " << path << '\n'; } @@ -67,14 +68,14 @@ absl::Status DumpToolData(absl::string_view run_dir, absl::string_view host, absl::Status WriteGzippedDataToFile(const std::string& filepath, const std::string& data) { std::unique_ptr file; - TF_RETURN_IF_ERROR(Env::Default()->NewWritableFile(filepath, &file)); + RETURN_IF_ERROR(Env::Default()->NewWritableFile(filepath, &file)); io::ZlibCompressionOptions options = io::ZlibCompressionOptions::GZIP(); io::ZlibOutputBuffer buffer(file.get(), options.input_buffer_size, options.output_buffer_size, options); - TF_RETURN_IF_ERROR(buffer.Init()); - TF_RETURN_IF_ERROR(buffer.Append(data)); - TF_RETURN_IF_ERROR(buffer.Close()); - TF_RETURN_IF_ERROR(file->Close()); + RETURN_IF_ERROR(buffer.Init()); + RETURN_IF_ERROR(buffer.Append(data)); + RETURN_IF_ERROR(buffer.Close()); + RETURN_IF_ERROR(file->Close()); return absl::OkStatus(); } @@ -84,7 +85,7 @@ absl::Status GetOrCreateRunDir(const std::string& repository_root, // Creates a directory to //. *run_dir = ProfilerJoinPath(repository_root, run); *os << "Creating directory: " << *run_dir << '\n'; - TF_RETURN_IF_ERROR(Env::Default()->RecursivelyCreateDir(*run_dir)); + RETURN_IF_ERROR(Env::Default()->RecursivelyCreateDir(*run_dir)); return absl::OkStatus(); } } // namespace @@ -103,11 +104,11 @@ absl::Status SaveProfile(const std::string& repository_root, return absl::OkStatus(); } std::string run_dir; - TF_RETURN_IF_ERROR(GetOrCreateRunDir(repository_root, run, &run_dir, os)); + RETURN_IF_ERROR(GetOrCreateRunDir(repository_root, run, &run_dir, os)); // Windows file names do not support colons. std::string hostname = absl::StrReplaceAll(host, {{":", "_"}}); for (const auto& tool_data : response.tool_data()) { - TF_RETURN_IF_ERROR(DumpToolData(run_dir, hostname, tool_data, os)); + RETURN_IF_ERROR(DumpToolData(run_dir, hostname, tool_data, os)); } return absl::OkStatus(); } @@ -121,11 +122,11 @@ absl::Status SaveGzippedToolData(const std::string& repository_root, std::stringstream ss; absl::Status status = GetOrCreateRunDir(repository_root, run, &run_dir, &ss); LOG(INFO) << ss.str(); - TF_RETURN_IF_ERROR(status); + RETURN_IF_ERROR(status); std::string host_prefix = host.empty() ? "" : absl::StrCat(host, "."); std::string path = ProfilerJoinPath(run_dir, absl::StrCat(host_prefix, tool_name)); - TF_RETURN_IF_ERROR(WriteGzippedDataToFile(path, data)); + RETURN_IF_ERROR(WriteGzippedDataToFile(path, data)); LOG(INFO) << "Dumped gzipped tool data for " << tool_name << " to " << path; return absl::OkStatus(); } @@ -140,7 +141,7 @@ absl::Status SaveXSpace(const std::string& repository_root, const tensorflow::profiler::XSpace& xspace) { std::string log_dir = ProfilerJoinPath(repository_root, run); VLOG(1) << "Creating " << log_dir; - TF_RETURN_IF_ERROR(Env::Default()->RecursivelyCreateDir(log_dir)); + RETURN_IF_ERROR(Env::Default()->RecursivelyCreateDir(log_dir)); std::string file_name = absl::StrCat(host, ".", kXPlanePb); // Windows file names do not support colons. absl::StrReplaceAll({{":", "_"}}, &file_name); diff --git a/third_party/xla/xla/tsl/profiler/rpc/profiler_service_impl.cc b/third_party/xla/xla/tsl/profiler/rpc/profiler_service_impl.cc index f8ba0f63d81197..d6696e8a64ae21 100644 --- a/third_party/xla/xla/tsl/profiler/rpc/profiler_service_impl.cc +++ b/third_party/xla/xla/tsl/profiler/rpc/profiler_service_impl.cc @@ -26,6 +26,7 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/status/status.h" #include "absl/synchronization/mutex.h" +#include "xla/tsl/platform/status_macros.h" #include "grpcpp/server_context.h" #include "grpcpp/support/status.h" #include "xla/tsl/platform/env.h" @@ -73,7 +74,7 @@ absl::Status CollectData(const ProfileRequest& request, tensorflow::profiler::XSpace xspace; tensorflow::profiler::XSpace* xspace_ptr = request.emit_xspace() ? response->mutable_xspace() : &xspace; - TF_RETURN_IF_ERROR(profiler->CollectData(xspace_ptr)); + RETURN_IF_ERROR(profiler->CollectData(xspace_ptr)); VLOG(3) << "Collected XSpace to " << (request.emit_xspace() ? "response" : "repository") << "."; response->set_empty_trace(IsEmpty(*xspace_ptr)); diff --git a/third_party/xla/xla/tsl/profiler/utils/BUILD b/third_party/xla/xla/tsl/profiler/utils/BUILD index c228179e919b06..838145b0aea77b 100644 --- a/third_party/xla/xla/tsl/profiler/utils/BUILD +++ b/third_party/xla/xla/tsl/profiler/utils/BUILD @@ -498,6 +498,7 @@ cc_library( deps = [ "//xla/tsl/platform:errors", "//xla/tsl/platform:status", + "//xla/tsl/platform:status_macros", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/log", "@com_google_absl//absl/status", diff --git a/third_party/xla/xla/tsl/profiler/utils/session_manager.cc b/third_party/xla/xla/tsl/profiler/utils/session_manager.cc index abdfba96f0ab12..906a6391a35713 100644 --- a/third_party/xla/xla/tsl/profiler/utils/session_manager.cc +++ b/third_party/xla/xla/tsl/profiler/utils/session_manager.cc @@ -30,6 +30,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/str_split.h" #include "absl/strings/string_view.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/tsl/platform/errors.h" #include "tsl/profiler/lib/profiler_session.h" #include "tsl/profiler/protobuf/profiler_options.pb.h" @@ -276,7 +277,7 @@ absl::Status ValidateRemoteProfilerSessionManagerOptions( } for (absl::string_view host_port : options.service_addresses()) { - TF_RETURN_IF_ERROR(ValidateHostPortPair(host_port)); + RETURN_IF_ERROR(ValidateHostPortPair(host_port)); } if (options.max_session_duration_ms() < diff --git a/third_party/xla/xla/tsl/testing/BUILD b/third_party/xla/xla/tsl/testing/BUILD index 9d57a6296794e8..c7077514ef46f7 100644 --- a/third_party/xla/xla/tsl/testing/BUILD +++ b/third_party/xla/xla/tsl/testing/BUILD @@ -18,6 +18,7 @@ cc_library( deps = [ "//xla/tsl/platform:env", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "@com_google_absl//absl/log", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", diff --git a/third_party/xla/xla/tsl/testing/temporary_directory.cc b/third_party/xla/xla/tsl/testing/temporary_directory.cc index a4524fa78fdb9d..8948bd1c8807c8 100644 --- a/third_party/xla/xla/tsl/testing/temporary_directory.cc +++ b/third_party/xla/xla/tsl/testing/temporary_directory.cc @@ -23,6 +23,7 @@ limitations under the License. #include "absl/log/log.h" #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/tsl/platform/env.h" #include "xla/tsl/platform/errors.h" #include "tsl/platform/path.h" @@ -35,7 +36,7 @@ absl::StatusOr TemporaryDirectory::CreateForTestcase( std::string path = tsl::io::JoinPath(::testing::TempDir(), "xla_testing_tmp", test_info.test_suite_name(), test_info.name()); - TF_RETURN_IF_ERROR(tsl::Env::Default()->RecursivelyCreateDir(path)); + RETURN_IF_ERROR(tsl::Env::Default()->RecursivelyCreateDir(path)); return TemporaryDirectory(std::move(path)); } diff --git a/third_party/xla/xla/tsl/util/BUILD b/third_party/xla/xla/tsl/util/BUILD index e0e2bd2acfc601..b3b312d030fa59 100644 --- a/third_party/xla/xla/tsl/util/BUILD +++ b/third_party/xla/xla/tsl/util/BUILD @@ -226,6 +226,7 @@ cc_library( "//xla/tsl/platform:errors", "//xla/tsl/platform:logging", "//xla/tsl/platform:status", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:types", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", @@ -248,6 +249,7 @@ cc_library( "//xla/tsl/platform:env", "//xla/tsl/platform:errors", "//xla/tsl/platform:macros", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:types", "//xla/tsl/protobuf:test_log_proto_cc", "@com_google_absl//absl/status", @@ -459,6 +461,7 @@ cc_binary( deps = [ "//xla/tsl/platform:env", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/flags:flag", "@com_google_absl//absl/flags:parse", @@ -545,6 +548,7 @@ cc_library( "//xla/tsl/platform:embedded_filesystem", "//xla/tsl/platform:env", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/base:no_destructor", "@com_google_absl//absl/base:nullability", diff --git a/third_party/xla/xla/tsl/util/env_var.cc b/third_party/xla/xla/tsl/util/env_var.cc index c64c467e9c84a3..75a48d983f36cb 100644 --- a/third_party/xla/xla/tsl/util/env_var.cc +++ b/third_party/xla/xla/tsl/util/env_var.cc @@ -23,6 +23,7 @@ limitations under the License. #include "absl/strings/ascii.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_split.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/tsl/platform/errors.h" #include "xla/tsl/platform/logging.h" #include "tsl/platform/numbers.h" @@ -99,7 +100,7 @@ absl::Status ReadStringsFromEnvVar(absl::string_view env_var_name, std::vector* value, absl::string_view delimiters) { std::string str_val; - TF_RETURN_IF_ERROR(ReadStringFromEnvVar(env_var_name, default_val, &str_val)); + RETURN_IF_ERROR(ReadStringFromEnvVar(env_var_name, default_val, &str_val)); std::vector parts = absl::StrSplit( str_val, absl::ByAnyChar(delimiters), absl::SkipWhitespace()); value->clear(); diff --git a/third_party/xla/xla/tsl/util/filewrapper.cc b/third_party/xla/xla/tsl/util/filewrapper.cc index d2b0dfbc199bff..c42815eacc48da 100755 --- a/third_party/xla/xla/tsl/util/filewrapper.cc +++ b/third_party/xla/xla/tsl/util/filewrapper.cc @@ -84,6 +84,7 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/strings/strip.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/Support/MD5.h" #include "xla/tsl/platform/env.h" @@ -172,7 +173,7 @@ absl::StatusOr> ExpandDirs( return absl::FailedPreconditionError(absl::StrCat( "filewrapper: refusing to process dir '", filename, "'")); } else if (s.ok()) { - TF_RETURN_IF_ERROR(env.GetChildren(filename, &to_process)); + RETURN_IF_ERROR(env.GetChildren(filename, &to_process)); } else if (absl::IsFailedPrecondition(s)) { allfiles.push_back(filename); } @@ -558,9 +559,9 @@ absl::Status WriteCpp(tsl::Env* env, const std::string& cc_filename, // simple cross-platform way to truncate files, so we just read and write // again. std::string contents; - TF_RETURN_IF_ERROR(tsl::ReadFileToString(env, cc_filename, &contents)); + RETURN_IF_ERROR(tsl::ReadFileToString(env, cc_filename, &contents)); contents.resize(end_pos); - TF_RETURN_IF_ERROR(tsl::WriteStringToFile(env, cc_filename, contents)); + RETURN_IF_ERROR(tsl::WriteStringToFile(env, cc_filename, contents)); } return absl::OkStatus(); diff --git a/third_party/xla/xla/tsl/util/memfile_builtin.cc b/third_party/xla/xla/tsl/util/memfile_builtin.cc index b82ef58e59db87..1f35c146afee45 100644 --- a/third_party/xla/xla/tsl/util/memfile_builtin.cc +++ b/third_party/xla/xla/tsl/util/memfile_builtin.cc @@ -24,6 +24,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/tsl/platform/embedded_filesystem.h" #include "xla/tsl/platform/env.h" #include "xla/tsl/platform/errors.h" @@ -55,7 +56,7 @@ absl::Status RegisterBuiltInFiles(const char* absl_nonnull name, // conditionally. We're running at global-init time, before flags have been // parsed, so VLOG is out, and any standard log level will result in RAW_LOG // on stderr. - TF_RETURN_IF_ERROR(global_file_system().EmbedFile(path, contents)); + RETURN_IF_ERROR(global_file_system().EmbedFile(path, contents)); } return absl::OkStatus(); diff --git a/third_party/xla/xla/tsl/util/reporter.cc b/third_party/xla/xla/tsl/util/reporter.cc index 0b4eb6107690d2..6fe54fb383f591 100644 --- a/third_party/xla/xla/tsl/util/reporter.cc +++ b/third_party/xla/xla/tsl/util/reporter.cc @@ -20,6 +20,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" +#include "xla/tsl/platform/status_macros.h" #include "xla/tsl/platform/env.h" #include "xla/tsl/platform/errors.h" #include "xla/tsl/platform/types.h" @@ -57,8 +58,8 @@ absl::Status TestReportFile::Initialize() { return absl::InvalidArgumentError(absl::StrCat( "Cannot create TestReportFile, file exists: ", mangled_fname)); } - TF_RETURN_IF_ERROR(env->NewWritableFile(mangled_fname, &log_file_)); - TF_RETURN_IF_ERROR(log_file_->Flush()); + RETURN_IF_ERROR(env->NewWritableFile(mangled_fname, &log_file_)); + RETURN_IF_ERROR(log_file_->Flush()); closed_ = false; return absl::OkStatus(); @@ -77,7 +78,7 @@ absl::Status TestReporter::Close() { tensorflow::BenchmarkEntries entries; *entries.add_entry() = benchmark_entry_; - TF_RETURN_IF_ERROR(report_file_.Append(entries.SerializeAsString())); + RETURN_IF_ERROR(report_file_.Append(entries.SerializeAsString())); benchmark_entry_.Clear(); return report_file_.Close(); diff --git a/third_party/xla/xla/util/BUILD b/third_party/xla/xla/util/BUILD index 2a3bb9db2f523f..5ff87718290b65 100644 --- a/third_party/xla/xla/util/BUILD +++ b/third_party/xla/xla/util/BUILD @@ -66,6 +66,7 @@ cc_library( "//xla/backends/cpu:alignment", "//xla/service/llvm_ir:llvm_type_conversion_util", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "@com_google_absl//absl/memory", "@com_google_absl//absl/status", diff --git a/third_party/xla/xla/util/embedded_constant_buffers.cc b/third_party/xla/xla/util/embedded_constant_buffers.cc index c9d7502c0f6c76..0e6afe19f73583 100644 --- a/third_party/xla/xla/util/embedded_constant_buffers.cc +++ b/third_party/xla/xla/util/embedded_constant_buffers.cc @@ -31,6 +31,7 @@ limitations under the License. #include "absl/strings/str_format.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "xla/tsl/platform/status_macros.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/SmallVector.h" #include "llvm/IR/Constants.h" @@ -132,8 +133,8 @@ GetTargetMachineFromTriple(absl::string_view target_triple) { absl::StatusOr CreateEmbeddedConstantBuffers( absl::string_view target_triple, absl::Span constants_to_embed) { - TF_ASSIGN_OR_RETURN(std::unique_ptr target_machine, - GetTargetMachineFromTriple(target_triple)); + ASSIGN_OR_RETURN(std::unique_ptr target_machine, + GetTargetMachineFromTriple(target_triple)); llvm::LLVMContext llvm_context; auto module_with_serialized_proto = std::make_unique( @@ -144,7 +145,7 @@ absl::StatusOr CreateEmbeddedConstantBuffers( for (const ConstantToEmbed& constant_to_embed : constants_to_embed) { std::string constant_array_symbol_name; - TF_RETURN_IF_ERROR(AddBufferToLlvmModule( + RETURN_IF_ERROR(AddBufferToLlvmModule( module_with_serialized_proto.get(), constant_to_embed, constant_to_embed.symbol_prefix, constant_array_symbol_name)); @@ -171,9 +172,9 @@ absl::StatusOr CreateEmbeddedConstantBuffers( {constant_array_symbol_name, cpp_variable_decl, cpp_access_shim}); } - TF_ASSIGN_OR_RETURN(result.object_file_data, - CodegenModule(target_machine.get(), - std::move(module_with_serialized_proto))); + ASSIGN_OR_RETURN(result.object_file_data, + CodegenModule(target_machine.get(), + std::move(module_with_serialized_proto))); return result; } diff --git a/third_party/xla/xla/util/split_proto/BUILD b/third_party/xla/xla/util/split_proto/BUILD index 874c3bf3a89d40..9a3b71b7c26d70 100644 --- a/third_party/xla/xla/util/split_proto/BUILD +++ b/third_party/xla/xla/util/split_proto/BUILD @@ -33,6 +33,7 @@ cc_library( ":split_proto_cc", "//xla/service/gpu:gpu_executable_proto_cc", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "@com_google_absl//absl/log", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -113,6 +114,7 @@ cc_library( "//xla/service:hlo_proto_util", "//xla/service/gpu:gpu_executable_proto_cc", "//xla/tsl/platform:errors", + "//xla/tsl/platform:status_macros", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_protobuf//:protobuf", diff --git a/third_party/xla/xla/util/split_proto/split_gpu_executable_writer.cc b/third_party/xla/xla/util/split_proto/split_gpu_executable_writer.cc index 4a813f88f3f3fe..f9645afc6ad0e0 100644 --- a/third_party/xla/xla/util/split_proto/split_gpu_executable_writer.cc +++ b/third_party/xla/xla/util/split_proto/split_gpu_executable_writer.cc @@ -22,6 +22,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "xla/tsl/platform/status_macros.h" #include "google/protobuf/message.h" #include "riegeli/bytes/writer.h" #include "riegeli/records/record_writer.h" @@ -77,7 +78,7 @@ absl::Status NormalizeBackendConfig(gpu::GpuExecutableProto& executable) { ->mutable_computations()) { for (HloInstructionProto& instruction : *computation.mutable_instructions()) { - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( std::string backend_config_str, GetBackendConfigString( instruction, &executable.hlo_module_with_config().hlo_module())); @@ -119,33 +120,33 @@ absl::Status WriteSplitGpuExecutable(gpu::GpuExecutableProto executable, riegeli::RecordWriter record_writer(std::move(writer), GetSplitProtoRiegeliOptions()); SplitProtoManifest manifest = BuildManifest(executable.constants_size()); - TF_RETURN_IF_ERROR(WriteRecord(record_writer, manifest)); + RETURN_IF_ERROR(WriteRecord(record_writer, manifest)); - TF_RETURN_IF_ERROR(WriteRecord(record_writer, executable.asm_text())); + RETURN_IF_ERROR(WriteRecord(record_writer, executable.asm_text())); executable.clear_asm_text(); - TF_RETURN_IF_ERROR(WriteRecord(record_writer, executable.binary())); + RETURN_IF_ERROR(WriteRecord(record_writer, executable.binary())); executable.clear_binary(); gpu::GpuExecutableProto dnn_graphs_wrapper; *dnn_graphs_wrapper.mutable_dnn_compiled_graphs() = std::move(executable.dnn_compiled_graphs()); executable.clear_dnn_compiled_graphs(); - TF_RETURN_IF_ERROR(WriteRecord(record_writer, dnn_graphs_wrapper)); + RETURN_IF_ERROR(WriteRecord(record_writer, dnn_graphs_wrapper)); for (gpu::GpuExecutableProto::ConstantInfoProto& constant : *executable.mutable_constants()) { gpu::GpuExecutableProto constant_wrapper; *constant_wrapper.add_constants() = std::move(constant); - TF_RETURN_IF_ERROR(WriteRecord(record_writer, constant_wrapper)); + RETURN_IF_ERROR(WriteRecord(record_writer, constant_wrapper)); } executable.clear_constants(); // The rest of the fields (i.e. the non-offloaded fields) - TF_RETURN_IF_ERROR(NormalizeBackendConfig(executable)); + RETURN_IF_ERROR(NormalizeBackendConfig(executable)); // Module IDs are created via a static counter when deserializing, and they // can cause non-determinism, so we don't preserve them. executable.mutable_hlo_module_with_config()->mutable_hlo_module()->clear_id(); - TF_RETURN_IF_ERROR(WriteRecord(record_writer, executable)); + RETURN_IF_ERROR(WriteRecord(record_writer, executable)); if (!record_writer.Close()) { return record_writer.status(); diff --git a/third_party/xla/xla/util/split_proto/split_proto_reader.cc b/third_party/xla/xla/util/split_proto/split_proto_reader.cc index 166c1913c6e791..b8874ddd97b76f 100644 --- a/third_party/xla/xla/util/split_proto/split_proto_reader.cc +++ b/third_party/xla/xla/util/split_proto/split_proto_reader.cc @@ -23,6 +23,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/strings/str_format.h" #include "absl/strings/string_view.h" +#include "xla/tsl/platform/status_macros.h" #include "google/protobuf/descriptor.h" #include "google/protobuf/message.h" #include "google/protobuf/reflection.h" @@ -52,7 +53,7 @@ template absl::Status HandleProtoMergeRecord(riegeli::RecordReader& record_reader, google::protobuf::Message& proto) { absl::string_view record_data; - TF_RETURN_IF_ERROR(ReadRecord(record_reader, record_data)); + RETURN_IF_ERROR(ReadRecord(record_reader, record_data)); if (!proto.MergeFromString(record_data)) { return absl::InternalError("Failed to parse proto merge record"); @@ -106,9 +107,9 @@ absl::Status HandleFieldOverrideRecord( riegeli::RecordReader& record_reader, google::protobuf::Message& proto, const Record& record) { std::string record_data; - TF_RETURN_IF_ERROR(ReadRecord(record_reader, record_data)); + RETURN_IF_ERROR(ReadRecord(record_reader, record_data)); - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( ReadOverrideFieldRecord(proto, std::move(record_data), record)); return absl::OkStatus(); } @@ -139,11 +140,11 @@ absl::Status ReadSplitProto(std::unique_ptr reader, for (const Record& record : manifest.records()) { switch (record.record_type_case()) { case Record::kProtoMergeRecord: { - TF_RETURN_IF_ERROR(HandleProtoMergeRecord<>(record_reader, proto)); + RETURN_IF_ERROR(HandleProtoMergeRecord<>(record_reader, proto)); break; } case Record::kFieldOverrideRecord: { - TF_RETURN_IF_ERROR( + RETURN_IF_ERROR( HandleFieldOverrideRecord<>(record_reader, proto, record)); break; } diff --git a/third_party/xla/xla/xla.proto b/third_party/xla/xla/xla.proto index 5a4f71c68086d9..8918274e908ab9 100644 --- a/third_party/xla/xla/xla.proto +++ b/third_party/xla/xla/xla.proto @@ -742,10 +742,17 @@ message DebugOptions { optional bool xla_gpu_executable_embed_debug_info = 437; - // Timeout to terminate on stuck rendez-vous. + // Number of additional communication streams to allocate for a GPU + // executable. + optional int32 xla_gpu_executable_num_communication_streams = 487; + + // Number of additional compute streams to allocate for a GPU executable. + optional int32 xla_gpu_executable_num_compute_streams = 486; + + // Timeout to terminate on stuck rendezvous. optional int32 xla_gpu_executable_terminate_timeout_seconds = 328; - // Timeout to issue a warning on stuck rendez-vous. + // Timeout to issue a warning on stuck rendezvous. optional int32 xla_gpu_executable_warn_stuck_timeout_seconds = 327; // Number of thunks to track for execution progress reporting. When this @@ -1572,7 +1579,7 @@ message DebugOptions { // Note: when adding a new flag, please add it to one of the hardware-specific // or hardware-agnostic sections at the top of this proto message. - // Next id: 485 + // Next id: 489 // Extra options to pass to the compilation backend (e.g. LLVM); specific // interpretation of these values is left to the backend.