Skip to content

Commit

Permalink
Allows suboptimal solutions for partial mesh shapes when given a *hard* memory budget constraint.

Browse files Browse the repository at this point in the history

PiperOrigin-RevId: 713772425
  • Loading branch information
Google-ML-Automation committed Jan 11, 2025
1 parent 72069fd commit c5ddd4e
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 1 deletion.
8 changes: 7 additions & 1 deletion xla/hlo/experimental/auto_sharding/auto_sharding.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3521,6 +3521,7 @@ absl::StatusOr<bool> AutoShardingImplementation::RunAutoSharding(
bool module_is_changed = false;

bool set_to_memory_lower_bound = (option_.memory_budget_per_device == 0);
bool hard_memory_constraint = (option_.memory_budget_ratio < 0);

// Remove CustomCalls with custom_call_target="Sharding" and move their
// shardings to their input ops.
Expand Down Expand Up @@ -3684,7 +3685,7 @@ absl::StatusOr<bool> AutoShardingImplementation::RunAutoSharding(
option_.memory_budget_per_device =
memory_lower_bound * std::abs(option_.memory_budget_ratio);
// TODO(b/341299984): Document this flag syntax, or automate the behavior.
if (option_.memory_budget_ratio < 0) {
if (hard_memory_constraint) {
option_.memory_overbudget_coeff = -1.0; // Disables the soft constraint
}
} else if (option_.memory_budget_per_device > 0) {
Expand Down Expand Up @@ -3807,7 +3808,12 @@ absl::StatusOr<bool> AutoShardingImplementation::RunAutoSharding(
option_, request_name, sharding_propagation_solution));
if (mesh_idx == partial_mesh_shapes.size() - 1) {
this->solver_optimal_objective_value_ = output.cost;
} else if (hard_memory_constraint) {
// If the memory budget constraint is *hard*, we're already guaranteed
// that this intermediate solution honors the maximum value.
} else {
// If the memory budget constraint is *soft*, we require the intermediate
// solution to be optimal (since otherwise, it's probably degenerate).
TF_RET_CHECK(output.is_optimal)
<< "The solver did not find an optimal solution for a partial mesh "
<< "shape.";
Expand Down
38 changes: 38 additions & 0 deletions xla/hlo/experimental/auto_sharding/auto_sharding_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3124,6 +3124,44 @@ ENTRY %entry {
op::Sharding("{devices=[8,16]<=[128] last_tile_dim_replicate}"));
}

// Verifies that auto-sharding succeeds (and still fully shards the scatter op)
// when the memory budget is a *hard* constraint, signaled by a negative
// memory_budget_ratio. With a hard budget, intermediate solutions for partial
// mesh shapes are allowed to be suboptimal as long as they honor the budget.
TEST_F(AutoShardingTest, NegativeMemoryBudgetRatioTest) {
constexpr absl::string_view kHloString = R"(
HloModule module
region {
Arg_0 = s32[] parameter(0)
ROOT Arg_1 = s32[] parameter(1)
}
ENTRY %Scatter {
call = s32[4,128]{1,0} parameter(0)
clamp = s32[4,2]{1,0} parameter(1)
broadcast = s32[4,8]{1,0} parameter(2)
ROOT scatter = s32[4,128]{1,0} scatter(call, clamp, broadcast), update_window_dims={1}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0,1}, index_vector_dim=1, indices_are_sorted=true, unique_indices=true, to_apply=region
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
ParseAndReturnVerifiedModule(kHloString));
AutoShardingOption option;
option.enable = true;
option.device_mesh_shape = {2, 2};  // 4 devices total.
option.device_mesh_ids = {0, 1, 2, 3};
option.device_mesh_alpha = {1.0, 1.0};
option.device_mesh_beta = {0.01, 1.0};
// A budget of 0 means "derive the budget from the memory lower bound"; the
// effective budget becomes memory_lower_bound * |memory_budget_ratio|, i.e.
// a tad (10%) higher than what is required if the largest tensors are
// sharded 4-ways.
option.memory_budget_per_device = 0;
// A negative ratio makes the budget a *hard* solver constraint by disabling
// the soft memory (overbudget) term in the objective.
option.memory_budget_ratio = -1.1;  // Disables the soft memory constraint.

TF_ASSERT_OK_AND_ASSIGN(bool changed, AutoSharding(option).Run(module.get()));
VLOG(10) << module->ToString();
EXPECT_TRUE(changed);
const HloInstruction* scatter = FindInstruction(module.get(), "scatter");
ASSERT_NE(scatter, nullptr);
// The scatter output should be fully sharded across all 4 devices.
EXPECT_EQ(scatter->sharding().NumTiles(), 4);
TF_EXPECT_OK(scatter->sharding().Validate(scatter->shape(), 4));
}

TEST(NormalizeTest, NormalizeHandlesNegativeCosts) {
EdgeReshardingCostMatrix edge_cost(2, 2);
edge_cost(0, 0).communication_cost = -100;
Expand Down

0 comments on commit c5ddd4e

Please sign in to comment.