From adc45f3048696e6a4dab229a1c60d2ac6bb4823a Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 15 May 2026 10:39:14 -0700 Subject: [PATCH] Added logic to securely initialize `Indvar` at the start of a while loop sequence specifically on the parallel test worker threading context. `Indvar` previously experienced race conditions as it evaluated on uninitialized thread-local memory on the parallel execution streams of `ExecuteOnStream`. PiperOrigin-RevId: 916076636 --- .../gpu/codegen/dynamic_slice_fusion_test.cc | 110 +++++++++++++++++- .../xla/xla/backends/gpu/runtime/BUILD | 1 + .../gpu/runtime/dynamic_slice_thunk.cc | 40 +++++-- 3 files changed, 140 insertions(+), 11 deletions(-) diff --git a/third_party/xla/xla/backends/gpu/codegen/dynamic_slice_fusion_test.cc b/third_party/xla/xla/backends/gpu/codegen/dynamic_slice_fusion_test.cc index cc7338f2b1ce44..8364f785d69701 100644 --- a/third_party/xla/xla/backends/gpu/codegen/dynamic_slice_fusion_test.cc +++ b/third_party/xla/xla/backends/gpu/codegen/dynamic_slice_fusion_test.cc @@ -3535,9 +3535,26 @@ TEST_F(DynamicSliceFusionTest, MultipleOffsetsAsFunctionOfInductionVariable) { ROOT while = (s32[], s32[16,32,32], s32[32,32]) while(tuple), body=body, condition=condition } )"; + (void)hlo_fused; + + HloModuleConfig config; + config.mutable_debug_options().set_xla_gpu_enable_dynamic_slice_fusion(true); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr unfused_module, + ParseAndReturnVerifiedModule(hlo_unfused)); + std::unique_ptr fused_module = unfused_module->Clone(); + fused_module->mutable_config() + .mutable_debug_options() + .set_xla_gpu_enable_dynamic_slice_fusion(true); + TF_ASSERT_OK_AND_ASSIGN(fused_module, + GetOptimizedModule(std::move(fused_module))); + unfused_module->mutable_config() + .mutable_debug_options() + .set_xla_gpu_enable_dynamic_slice_fusion(false); + TF_ASSERT_OK_AND_ASSIGN(unfused_module, + GetOptimizedModule(std::move(unfused_module))); EXPECT_TRUE(RunAndCompareTwoModulesReplicated( - /*module_0_str=*/hlo_unfused, /*module_1_str=*/hlo_fused, + std::move(unfused_module), std::move(fused_module), /*run_hlo_passes=*/false, /*use_threads=*/true, std::nullopt)); } @@ -3672,6 +3689,97 @@ TEST_F(DynamicSliceFusionTest, WhileLoopSliceWithNoInductionVariable) { /*use_threads=*/true, error_spec)); } +TEST_F(DynamicSliceFusionTest, TriggerMemoryCorruption) { + std::string hlo = R"( + HloModule test + + dynamic-slice-fusion { + p0 = s32[100] parameter(0) + p1 = s32[100] parameter(1) + p2 = s32[100] parameter(2) + p3 = s32[100] parameter(3) + p4 = s32[100] parameter(4) + p5 = s32[100] parameter(5) + p6 = s32[100] parameter(6) + p7 = s32[100] parameter(7) + p8 = s32[100] parameter(8) + p9 = s32[100] parameter(9) + o0 = s32[] parameter(10) + o1 = s32[] parameter(11) + o2 = s32[] parameter(12) + o3 = s32[] parameter(13) + o4 = s32[] parameter(14) + o5 = s32[] parameter(15) + o6 = s32[] parameter(16) + o7 = s32[] parameter(17) + o8 = s32[] parameter(18) + o9 = s32[] parameter(19) + + ds0 = s32[1] dynamic-slice(p0, o0), dynamic_slice_sizes={1} + ds1 = s32[1] dynamic-slice(p1, o1), dynamic_slice_sizes={1} + ds2 = s32[1] dynamic-slice(p2, o2), dynamic_slice_sizes={1} + ds3 = s32[1] dynamic-slice(p3, o3), dynamic_slice_sizes={1} + ds4 = s32[1] dynamic-slice(p4, o4), dynamic_slice_sizes={1} + ds5 = s32[1] dynamic-slice(p5, o5), dynamic_slice_sizes={1} + ds6 = s32[1] dynamic-slice(p6, o6), dynamic_slice_sizes={1} + ds7 = s32[1] dynamic-slice(p7, o7), dynamic_slice_sizes={1} + ds8 = s32[1] dynamic-slice(p8, o8), dynamic_slice_sizes={1} + ds9 = s32[1] dynamic-slice(p9, o9), dynamic_slice_sizes={1} + + add0 = s32[1] add(ds0, ds1) + add1 = s32[1] add(add0, ds2) + add2 = s32[1] add(add1, ds3) + add3 = s32[1] add(add2, ds4) + add4 = s32[1] add(add3, ds5) + add5 = s32[1] add(add4, ds6) + add6 = s32[1] add(add5, ds7) + add7 = s32[1] add(add6, ds8) + ROOT root = s32[1] add(add7, ds9) + } + + ENTRY main { + p0 = s32[100] parameter(0) + p1 = s32[100] parameter(1) + p2 = s32[100] parameter(2) + p3 = s32[100] parameter(3) + p4 = s32[100] parameter(4) + p5 = s32[100] parameter(5) + p6 = s32[100] parameter(6) + p7 = s32[100] parameter(7) + p8 = s32[100] parameter(8) + p9 = s32[100] parameter(9) + o0 = s32[] constant(0) + o1 = s32[] constant(0) + o2 = s32[] constant(0) + o3 = s32[] constant(0) + o4 = s32[] constant(0) + o5 = s32[] constant(0) + o6 = s32[] constant(0) + o7 = s32[] constant(0) + o8 = s32[] constant(0) + o9 = s32[] constant(0) + + ROOT fusion = s32[1] fusion(p0, p1, p2, p3, p4, p5, p6, p7, p8, p9, o0, o1, o2, o3, o4, o5, o6, o7, o8, o9), kind=kCustom, calls=dynamic-slice-fusion, backend_config={"fusion_backend_config":{"kind":"__custom_fusion","custom_fusion_config":{"name":"dynamic_address_computation"} }} + } + )"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr m, + ParseAndReturnVerifiedModule(hlo)); + std::unique_ptr m_fused = m->Clone(); + m_fused->mutable_config() + .mutable_debug_options() + .set_xla_gpu_enable_dynamic_slice_fusion(true); + TF_ASSERT_OK_AND_ASSIGN(m_fused, GetOptimizedModule(std::move(m_fused))); + + m->mutable_config() + .mutable_debug_options() + .set_xla_gpu_enable_dynamic_slice_fusion(false); + TF_ASSERT_OK_AND_ASSIGN(m, GetOptimizedModule(std::move(m))); + + EXPECT_TRUE(RunAndCompareTwoModulesReplicated( + std::move(m), std::move(m_fused), /*run_hlo_passes=*/false, + /*use_threads=*/true, std::nullopt)); +} } // namespace } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/backends/gpu/runtime/BUILD b/third_party/xla/xla/backends/gpu/runtime/BUILD index 6283fb6cf828de..95978582b4d036 100644 --- a/third_party/xla/xla/backends/gpu/runtime/BUILD +++ b/third_party/xla/xla/backends/gpu/runtime/BUILD @@ -536,6 +536,7 @@ cc_library( ":thunk", ":thunk_executor", ":thunk_proto_cc", + ":while_loop", "//xla:literal", "//xla:literal_util", "//xla:shape_util", diff --git a/third_party/xla/xla/backends/gpu/runtime/dynamic_slice_thunk.cc b/third_party/xla/xla/backends/gpu/runtime/dynamic_slice_thunk.cc index bed0d425f956dd..cf32e8a8bc7eff 100644 --- a/third_party/xla/xla/backends/gpu/runtime/dynamic_slice_thunk.cc +++ b/third_party/xla/xla/backends/gpu/runtime/dynamic_slice_thunk.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include +#include #include #include #include @@ -41,6 +42,7 @@ limitations under the License. #include "xla/backends/gpu/runtime/dynamic_slice_thunk.pb.h" #include "xla/backends/gpu/runtime/thunk.h" #include "xla/backends/gpu/runtime/thunk.pb.h" +#include "xla/backends/gpu/runtime/while_loop.h" #include "xla/hlo/evaluator/hlo_evaluator.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/literal.h" @@ -209,8 +211,7 @@ DynamicSliceThunk::DynamicSliceThunk( for (SliceDef& slice : slices_) { offsets_allocs_base_.push_back(offsets_allocs_size_); if (slice.sliced_shape.has_value()) { - offsets_allocs_size_ += - slice.sliced_shape->dimensions().size() * sizeof(int64_t); + offsets_allocs_size_ += slice.sliced_shape->dimensions().size(); } } } @@ -237,12 +238,11 @@ absl::Status DynamicSliceThunk::Prepare(const PrepareParams& params) { RETURN_IF_ERROR(embedded_executor_.Prepare(params)); if (offset_as_function_of_indvar_metadata_.has_value()) { - Indvar(this) = - HloEvaluator() - .Evaluate( - /*module=*/*offset_as_function_of_indvar_metadata_->indvar_init, - /*args=*/{}) - .value(); + ASSIGN_OR_RETURN( + Indvar(this), + HloEvaluator().Evaluate( + /*module=*/*offset_as_function_of_indvar_metadata_->indvar_init, + /*args=*/{})); VLOG(2) << "Indvar init module: " << offset_as_function_of_indvar_metadata_->indvar_init->ToString(); VLOG(2) @@ -260,11 +260,21 @@ absl::Status DynamicSliceThunk::Initialize(const InitializeParams& params) { if (offsets_allocs_.contains(params.executor)) { return absl::OkStatus(); } + if (offset_as_function_of_indvar_metadata_.has_value()) { + ASSIGN_OR_RETURN( + Indvar(this), + HloEvaluator().Evaluate( + *offset_as_function_of_indvar_metadata_->indvar_init, {})); + VLOG(2) << "Initialize Indvar on worker thread = " + << Indvar(this).ToString(); + } - VLOG(2) << "Allocate " << offsets_allocs_size_ + VLOG(2) << "Allocate " << offsets_allocs_size_ * sizeof(int64_t) << " bytes for transferring offsets on executor: " << params.executor; ASSIGN_OR_RETURN(std::unique_ptr allocation, - params.executor->HostMemoryAllocate(offsets_allocs_size_)); + params.executor->HostMemoryAllocate(offsets_allocs_size_ * + sizeof(int64_t))); + memset(allocation->opaque(), 0, offsets_allocs_size_ * sizeof(int64_t)); offsets_allocs_.emplace(params.executor, std::move(allocation)); return absl::OkStatus(); @@ -284,6 +294,16 @@ absl::Status DynamicSliceThunk::ExecuteOnStream(const ExecuteParams& params) { offsets_allocs_.at(stream.parent())->address().opaque()); }(); + if (offset_as_function_of_indvar_metadata_.has_value()) { + if (const WhileLoopState* state = IsInsideWhileLoop(); + state && state->loop_iteration == 0) { + ASSIGN_OR_RETURN( + Indvar(this), + HloEvaluator().Evaluate( + *offset_as_function_of_indvar_metadata_->indvar_init, {})); + } + } + auto offset_value = [&](int64_t arg_idx, int64_t offset_idx) -> int64_t& { return offsets_alloc[offsets_allocs_base_.at(arg_idx) + offset_idx]; };