From c5ddd4e80bb9adc89b4874b5750c981e22376759 Mon Sep 17 00:00:00 2001 From: xla authors Date: Thu, 9 Jan 2025 13:12:59 -0800 Subject: [PATCH] Allows suboptimal solutions for partial mesh shapes when given a *hard* memory budget constraint. PiperOrigin-RevId: 713772425 --- .../auto_sharding/auto_sharding.cc | 8 +++- .../auto_sharding/auto_sharding_test.cc | 38 +++++++++++++++++++ 2 files changed, 45 insertions(+), 1 deletion(-) diff --git a/xla/hlo/experimental/auto_sharding/auto_sharding.cc b/xla/hlo/experimental/auto_sharding/auto_sharding.cc index 57a3533a4d950..1d19ce6757cbc 100644 --- a/xla/hlo/experimental/auto_sharding/auto_sharding.cc +++ b/xla/hlo/experimental/auto_sharding/auto_sharding.cc @@ -3521,6 +3521,7 @@ absl::StatusOr AutoShardingImplementation::RunAutoSharding( bool module_is_changed = false; bool set_to_memory_lower_bound = (option_.memory_budget_per_device == 0); + bool hard_memory_constraint = (option_.memory_budget_ratio < 0); // Remove CustomCalls with custom_call_target="Sharding" and move their // shardings to their input ops. @@ -3684,7 +3685,7 @@ absl::StatusOr AutoShardingImplementation::RunAutoSharding( option_.memory_budget_per_device = memory_lower_bound * std::abs(option_.memory_budget_ratio); // TODO(b/341299984): Document this flag syntax, or automate the behavior. - if (option_.memory_budget_ratio < 0) { + if (hard_memory_constraint) { option_.memory_overbudget_coeff = -1.0; // Disables the soft constraint } } else if (option_.memory_budget_per_device > 0) { @@ -3807,7 +3808,12 @@ absl::StatusOr AutoShardingImplementation::RunAutoSharding( option_, request_name, sharding_propagation_solution)); if (mesh_idx == partial_mesh_shapes.size() - 1) { this->solver_optimal_objective_value_ = output.cost; + } else if (hard_memory_constraint) { + // If the memory budget constraint is *hard*, we're already guaranteed + // that this intermediate solution honors the maximum value. } else { + // If the memory budget constraint is *soft*, we require the intermediate + // solution to be optimal (since otherwise, it's probably degenerate). TF_RET_CHECK(output.is_optimal) << "The solver did not find an optimal solution for a partial mesh " << "shape."; diff --git a/xla/hlo/experimental/auto_sharding/auto_sharding_test.cc b/xla/hlo/experimental/auto_sharding/auto_sharding_test.cc index c4065bf05066f..660b344b3bdb7 100644 --- a/xla/hlo/experimental/auto_sharding/auto_sharding_test.cc +++ b/xla/hlo/experimental/auto_sharding/auto_sharding_test.cc @@ -3124,6 +3124,44 @@ ENTRY %entry { op::Sharding("{devices=[8,16]<=[128] last_tile_dim_replicate}")); } +TEST_F(AutoShardingTest, NegativeMemoryBudgetRatioTest) { + constexpr absl::string_view kHloString = R"( +HloModule module + +region { + Arg_0 = s32[] parameter(0) + ROOT Arg_1 = s32[] parameter(1) +} + +ENTRY %Scatter { + call = s32[4,128]{1,0} parameter(0) + clamp = s32[4,2]{1,0} parameter(1) + broadcast = s32[4,8]{1,0} parameter(2) + ROOT scatter = s32[4,128]{1,0} scatter(call, clamp, broadcast), update_window_dims={1}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0,1}, index_vector_dim=1, indices_are_sorted=true, unique_indices=true, to_apply=region +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHloString)); + AutoShardingOption option; + option.enable = true; + option.device_mesh_shape = {2, 2}; + option.device_mesh_ids = {0, 1, 2, 3}; + option.device_mesh_alpha = {1.0, 1.0}; + option.device_mesh_beta = {0.01, 1.0}; + // Memory budget a tad higher than what would be required if the largest + // tensors are sharded 4-ways + option.memory_budget_per_device = 0; + option.memory_budget_ratio = -1.1; // Disables the soft memory constraint. + + TF_ASSERT_OK_AND_ASSIGN(bool changed, AutoSharding(option).Run(module.get())); + VLOG(10) << module->ToString(); + EXPECT_TRUE(changed); + const HloInstruction* scatter = FindInstruction(module.get(), "scatter"); + ASSERT_NE(scatter, nullptr); + EXPECT_EQ(scatter->sharding().NumTiles(), 4); + TF_EXPECT_OK(scatter->sharding().Validate(scatter->shape(), 4)); +} + TEST(NormalizeTest, NormalizeHandlesNegativeCosts) { EdgeReshardingCostMatrix edge_cost(2, 2); edge_cost(0, 0).communication_cost = -100;