fix a bug when partitioning a scatter instruction with the same operands
PiperOrigin-RevId: 714328364
Google-ML-Automation committed Jan 11, 2025
1 parent 0cb8b51 commit a0acdf5
Showing 2 changed files with 43 additions and 14 deletions.
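At a glance: PerGroupPartitionedHlos already caches per-group results keyed by HloInstruction*, but that cache used to be local to each call, so the operand list and the update list of a scatter were group-partitioned independently even when they referred to the same instruction (as broadcast.4498 does in the new test). The change hoists the cache into the callers so one cache is shared across the operand and update calls. Below is a minimal, self-contained sketch of that caching pattern; it uses stand-in types and std::unordered_map instead of the real PartitionedHlo/HloInstruction API and absl::flat_hash_map, so names and behavior are illustrative only, not the XLA implementation.

#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

// Stand-ins for HloInstruction / PartitionedHlo, for illustration only.
struct Instr { std::string name; };
struct PerGroupHlo { const Instr* source; };

// Analogue of the new PerGroupPartitionedHlos signature: the caller owns the
// cache, so an instruction that appears in several input lists is
// "group-partitioned" only once across every call sharing that cache.
std::vector<PerGroupHlo> PerGroupAll(
    const std::vector<const Instr*>& instrs,
    std::unordered_map<const Instr*, PerGroupHlo>& cache) {
  std::vector<PerGroupHlo> result;
  for (const Instr* instr : instrs) {
    auto it = cache.find(instr);
    if (it == cache.end()) {
      // The expensive per-group partitioning would happen here, exactly once
      // per instruction for the lifetime of the shared cache.
      it = cache.emplace(instr, PerGroupHlo{instr}).first;
    }
    result.push_back(it->second);
  }
  return result;
}

int main() {
  Instr broadcast{"broadcast.4498"};
  std::unordered_map<const Instr*, PerGroupHlo> cache;
  // The same instruction is used as both the scatter operand and the update,
  // mirroring the scenario exercised by the new test.
  auto operands = PerGroupAll({&broadcast}, cache);
  auto updates = PerGroupAll({&broadcast}, cache);
  std::cout << "distinct partitioned inputs: " << cache.size() << "\n";  // 1
  return 0;
}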
32 changes: 18 additions & 14 deletions xla/service/spmd/gather_scatter_handler.cc
@@ -77,10 +77,11 @@ PartitionedHlo PerGroupPartitionedHlo(
 // Helper to get multiple per-group partitioned hlos.
 std::vector<PartitionedHlo> PerGroupPartitionedHlos(
     std::vector<PartitionedHlo>& phlos, const GroupedSharding& grouped_sharding,
-    SpmdBuilder* b, absl::InlinedVector<std::function<void()>, 3>& clean_ups) {
+    SpmdBuilder* b, absl::InlinedVector<std::function<void()>, 3>& clean_ups,
+    absl::flat_hash_map<HloInstruction*, PartitionedHlo>&
+        cached_per_group_hlos) {
   // Cache per-group partitioned hlos to avoid group-partitioning it more than
   // once.
-  absl::flat_hash_map<HloInstruction*, PartitionedHlo> cached_per_group_hlos;
   std::vector<HloInstruction*> hlos;
   absl::c_transform(phlos, std::back_inserter(hlos),
                     [&](PartitionedHlo phlo) { return phlo.hlo(); });
@@ -1230,10 +1231,11 @@ absl::StatusOr<HloInstruction*> PartitionScatterParallelDimensions(
           updates[0].sharding(), update_parallel_dims),
       new_indices_grouped);
   const GroupedSharding& output_grouped = operand_grouped;
-  std::vector<PartitionedHlo> per_group_operands =
-      PerGroupPartitionedHlos(operands, operand_grouped, b, clean_ups);
-  std::vector<PartitionedHlo> per_group_updates =
-      PerGroupPartitionedHlos(updates, update_grouped, b, clean_ups);
+  absl::flat_hash_map<HloInstruction*, PartitionedHlo> cached_per_group_hlos;
+  std::vector<PartitionedHlo> per_group_operands = PerGroupPartitionedHlos(
+      operands, operand_grouped, b, clean_ups, cached_per_group_hlos);
+  std::vector<PartitionedHlo> per_group_updates = PerGroupPartitionedHlos(
+      updates, update_grouped, b, clean_ups, cached_per_group_hlos);
   PartitionedHlo per_group_new_indices =
       PerGroupPartitionedHlo(new_indices, new_indices_grouped, b, clean_ups);
   auto pshape = MaybeGetTuplePerGroupBaseShape(output_grouped, output_shape);
@@ -1367,10 +1369,11 @@ absl::StatusOr<HloInstruction*> PartitionScatterOperandPassthroughDimensions(
           ScatterIndexDimsByPriority(scatter)),
       update_grouped);
   const GroupedSharding& output_grouped = operand_grouped;
-  std::vector<PartitionedHlo> per_group_operands =
-      PerGroupPartitionedHlos(operands, operand_grouped, b, clean_ups);
-  std::vector<PartitionedHlo> per_group_updates =
-      PerGroupPartitionedHlos(updates, update_grouped, b, clean_ups);
+  absl::flat_hash_map<HloInstruction*, PartitionedHlo> cached_per_group_hlos;
+  std::vector<PartitionedHlo> per_group_operands = PerGroupPartitionedHlos(
+      operands, operand_grouped, b, clean_ups, cached_per_group_hlos);
+  std::vector<PartitionedHlo> per_group_updates = PerGroupPartitionedHlos(
+      updates, update_grouped, b, clean_ups, cached_per_group_hlos);
   PartitionedHlo per_group_indices =
       PerGroupPartitionedHlo(indices, indices_grouped, b, clean_ups);
   auto pshape = MaybeGetTuplePerGroupBaseShape(output_grouped, output_shape);
@@ -1623,10 +1626,11 @@ absl::StatusOr<HloInstruction*> PartitionScatterTrivialSlicedOperandDimensions(
       indices.hlo()->shape(), HloOpcode::kSubtract, indices.hlo(),
       indices_min));
   PartitionedHlo new_indices = indices.CloneWithNewHlo(adjusted_indices);
-  std::vector<PartitionedHlo> per_group_operands =
-      PerGroupPartitionedHlos(operands, operand_grouped, b, clean_ups);
-  std::vector<PartitionedHlo> per_group_updates =
-      PerGroupPartitionedHlos(updates, update_grouped, b, clean_ups);
+  absl::flat_hash_map<HloInstruction*, PartitionedHlo> cached_per_group_hlos;
+  std::vector<PartitionedHlo> per_group_operands = PerGroupPartitionedHlos(
+      operands, operand_grouped, b, clean_ups, cached_per_group_hlos);
+  std::vector<PartitionedHlo> per_group_updates = PerGroupPartitionedHlos(
+      updates, update_grouped, b, clean_ups, cached_per_group_hlos);
   PartitionedHlo per_group_new_indices =
       PerGroupPartitionedHlo(new_indices, indices_grouped, b, clean_ups);
   auto pshape = MaybeGetTuplePerGroupBaseShape(output_grouped, output_shape);
25 changes: 25 additions & 0 deletions xla/service/spmd/spmd_partitioner_test.cc
@@ -14871,6 +14871,31 @@ ENTRY %main.21 {
   EXPECT_THAT(updates, op::Shape("bf16[4096,64]"));
 }
 
+TEST_P(SpmdPartitioningTest, ScatterSameInputSharding) {
+  const char* const hlo_string = R"(
+HloModule pjit
+
+%region_3.1507 {
+  %Arg_0.4576 = s32[] parameter(0)
+  %Arg_1.4577 = s32[] parameter(1)
+  ROOT %add.4578 = s32[] add(%Arg_0.4576, %Arg_1.4577)
+}
+
+ENTRY %main.21 {
+  broadcast.4498 = s32[8,3072]{1,0} parameter(0), sharding={devices=[4,1]<=[4]}
+  concatenate.103 = s32[8,3072,2]{2,1,0} parameter(1), sharding={devices=[4,1,1]<=[4]}
+  ROOT scatter.36 = s32[8,3072]{1,0} scatter(broadcast.4498, concatenate.103, broadcast.4498), update_window_dims={}, inserted_window_dims={0,1}, scatter_dims_to_operand_dims={0,1}, index_vector_dim=2, to_apply=region_3.1507, sharding={devices=[4,1]<=[4]}, metadata={op_name="pjit(_train_step)/jit(main)/root[Learner]/jvp(while)/body/jvp(model[Model])/decoder[SelectivePadDecoder]/padder[PadRandomFromList]/scatter" source_file="/combined-code/ajax-code/ajax/experiments/tkncomp/selpad.py" source_line=106}
+}
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          PartitionComputation(hlo_string, /*num_devices=*/4));
+
+  XLA_VLOG_LINES(0, module->ToString());
+  auto* scatter = FindInstruction(module.get(), HloOpcode::kScatter);
+  EXPECT_NE(scatter, nullptr);
+}
+
 TEST_P(SpmdPartitioningTest, ComplexReshardUnmerge) {
   const char* const hlo_string = R"(
 HloModule Test
