From e7fa101f19a4cb57baf28394328b3eb4e46dff87 Mon Sep 17 00:00:00 2001
From: xla authors
Date: Fri, 10 Jan 2025 21:01:21 -0800
Subject: [PATCH] fix a bug when partitioning scatter instruction with same
 operands

PiperOrigin-RevId: 714328364
---
 xla/service/spmd/gather_scatter_handler.cc | 32 ++++++++++++----------
 xla/service/spmd/spmd_partitioner_test.cc  | 25 +++++++++++++++++++++
 2 files changed, 43 insertions(+), 14 deletions(-)

diff --git a/xla/service/spmd/gather_scatter_handler.cc b/xla/service/spmd/gather_scatter_handler.cc
index 57f13ca7d1c5f..5d1d93ab0b8f5 100644
--- a/xla/service/spmd/gather_scatter_handler.cc
+++ b/xla/service/spmd/gather_scatter_handler.cc
@@ -77,10 +77,11 @@ PartitionedHlo PerGroupPartitionedHlo(
 // Helper to get multiple per-group partitioned hlos.
 std::vector<PartitionedHlo> PerGroupPartitionedHlos(
     std::vector<PartitionedHlo>& phlos, const GroupedSharding& grouped_sharding,
-    SpmdBuilder* b, absl::InlinedVector<std::function<void()>, 3>& clean_ups) {
+    SpmdBuilder* b, absl::InlinedVector<std::function<void()>, 3>& clean_ups,
+    absl::flat_hash_map<HloInstruction*, HloInstruction*>&
+        cached_per_group_hlos) {
   // Cache per-group partitioned hlos to avoid group-partitioning it more than
   // once.
-  absl::flat_hash_map<HloInstruction*, HloInstruction*> cached_per_group_hlos;
   std::vector<HloInstruction*> hlos;
   absl::c_transform(phlos, std::back_inserter(hlos),
                     [&](PartitionedHlo phlo) { return phlo.hlo(); });
@@ -1230,10 +1231,11 @@ absl::StatusOr<HloInstruction*> PartitionScatterParallelDimensions(
           updates[0].sharding(), update_parallel_dims),
       new_indices_grouped);
   const GroupedSharding& output_grouped = operand_grouped;
-  std::vector<PartitionedHlo> per_group_operands =
-      PerGroupPartitionedHlos(operands, operand_grouped, b, clean_ups);
-  std::vector<PartitionedHlo> per_group_updates =
-      PerGroupPartitionedHlos(updates, update_grouped, b, clean_ups);
+  absl::flat_hash_map<HloInstruction*, HloInstruction*> cached_per_group_hlos;
+  std::vector<PartitionedHlo> per_group_operands = PerGroupPartitionedHlos(
+      operands, operand_grouped, b, clean_ups, cached_per_group_hlos);
+  std::vector<PartitionedHlo> per_group_updates = PerGroupPartitionedHlos(
+      updates, update_grouped, b, clean_ups, cached_per_group_hlos);
   PartitionedHlo per_group_new_indices =
       PerGroupPartitionedHlo(new_indices, new_indices_grouped, b, clean_ups);
   auto pshape = MaybeGetTuplePerGroupBaseShape(output_grouped, output_shape);
@@ -1367,10 +1369,11 @@ absl::StatusOr<HloInstruction*> PartitionScatterOperandPassthroughDimensions(
           ScatterIndexDimsByPriority(scatter)),
       update_grouped);
   const GroupedSharding& output_grouped = operand_grouped;
-  std::vector<PartitionedHlo> per_group_operands =
-      PerGroupPartitionedHlos(operands, operand_grouped, b, clean_ups);
-  std::vector<PartitionedHlo> per_group_updates =
-      PerGroupPartitionedHlos(updates, update_grouped, b, clean_ups);
+  absl::flat_hash_map<HloInstruction*, HloInstruction*> cached_per_group_hlos;
+  std::vector<PartitionedHlo> per_group_operands = PerGroupPartitionedHlos(
+      operands, operand_grouped, b, clean_ups, cached_per_group_hlos);
+  std::vector<PartitionedHlo> per_group_updates = PerGroupPartitionedHlos(
+      updates, update_grouped, b, clean_ups, cached_per_group_hlos);
   PartitionedHlo per_group_indices =
       PerGroupPartitionedHlo(indices, indices_grouped, b, clean_ups);
   auto pshape = MaybeGetTuplePerGroupBaseShape(output_grouped, output_shape);
@@ -1623,10 +1626,11 @@ absl::StatusOr<HloInstruction*> PartitionScatterTrivialSlicedOperandDimensions(
           indices.hlo()->shape(), HloOpcode::kSubtract, indices.hlo(),
           indices_min));
   PartitionedHlo new_indices = indices.CloneWithNewHlo(adjusted_indices);
-  std::vector<PartitionedHlo> per_group_operands =
-      PerGroupPartitionedHlos(operands, operand_grouped, b, clean_ups);
-  std::vector<PartitionedHlo> per_group_updates =
-      PerGroupPartitionedHlos(updates, update_grouped, b, clean_ups);
+  absl::flat_hash_map<HloInstruction*, HloInstruction*> cached_per_group_hlos;
+  std::vector<PartitionedHlo> per_group_operands = PerGroupPartitionedHlos(
+      operands, operand_grouped, b, clean_ups, cached_per_group_hlos);
+  std::vector<PartitionedHlo> per_group_updates = PerGroupPartitionedHlos(
+      updates, update_grouped, b, clean_ups, cached_per_group_hlos);
   PartitionedHlo per_group_new_indices =
       PerGroupPartitionedHlo(new_indices, indices_grouped, b, clean_ups);
   auto pshape = MaybeGetTuplePerGroupBaseShape(output_grouped, output_shape);
diff --git a/xla/service/spmd/spmd_partitioner_test.cc b/xla/service/spmd/spmd_partitioner_test.cc
index d6fc45702bea5..445fae47dc636 100644
--- a/xla/service/spmd/spmd_partitioner_test.cc
+++ b/xla/service/spmd/spmd_partitioner_test.cc
@@ -14871,6 +14871,31 @@ ENTRY %main.21 {
   EXPECT_THAT(updates, op::Shape("bf16[4096,64]"));
 }
 
+TEST_P(SpmdPartitioningTest, ScatterSameInputSharding) {
+  const char* const hlo_string = R"(
+HloModule pjit
+
+%region_3.1507 {
+  %Arg_0.4576 = s32[] parameter(0)
+  %Arg_1.4577 = s32[] parameter(1)
+  ROOT %add.4578 = s32[] add(%Arg_0.4576, %Arg_1.4577)
+}
+
+ENTRY %main.21 {
+  broadcast.4498 = s32[8,3072]{1,0} parameter(0), sharding={devices=[4,1]<=[4]}
+  concatenate.103 = s32[8,3072,2]{2,1,0} parameter(1), sharding={devices=[4,1,1]<=[4]}
+  ROOT scatter.36 = s32[8,3072]{1,0} scatter(broadcast.4498, concatenate.103, broadcast.4498), update_window_dims={}, inserted_window_dims={0,1}, scatter_dims_to_operand_dims={0,1}, index_vector_dim=2, to_apply=region_3.1507, sharding={devices=[4,1]<=[4]}, metadata={op_name="pjit(_train_step)/jit(main)/root[Learner]/jvp(while)/body/jvp(model[Model])/decoder[SelectivePadDecoder]/padder[PadRandomFromList]/scatter" source_file="/combined-code/ajax-code/ajax/experiments/tkncomp/selpad.py" source_line=106}
+}
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          PartitionComputation(hlo_string, /*num_devices=*/4));
+
+  XLA_VLOG_LINES(0, module->ToString());
+  auto* scatter = FindInstruction(module.get(), HloOpcode::kScatter);
+  EXPECT_NE(scatter, nullptr);
+}
+
 TEST_P(SpmdPartitioningTest, ComplexReshardUnmerge) {
   const char* const hlo_string = R"(
HloModule Test