diff --git a/xla/hlo/utils/hlo_matchers.h b/xla/hlo/utils/hlo_matchers.h
index 2c00ddb7b3edf..1235dcbdd6a0c 100644
--- a/xla/hlo/utils/hlo_matchers.h
+++ b/xla/hlo/utils/hlo_matchers.h
@@ -284,6 +284,7 @@ HLO_MATCHER(BitcastConvert);
 HLO_MATCHER(Broadcast);
 HLO_MATCHER(Call);
 HLO_MATCHER(Ceil);
+HLO_MATCHER(Cholesky);
 HLO_MATCHER(Clamp);
 HLO_MATCHER(CollectiveBroadcast);
 HLO_MATCHER(CollectivePermute);
@@ -353,6 +354,7 @@ HLO_MATCHER(Subtract);
 HLO_MATCHER(Tan);
 HLO_MATCHER(Tanh);
 HLO_MATCHER(Transpose);
+HLO_MATCHER(TriangularSolve);
 HLO_MATCHER(Tuple);
 HLO_MATCHER(While);
 HLO_MATCHER(Xor);
diff --git a/xla/service/spmd/BUILD b/xla/service/spmd/BUILD
index d97584970dee6..07781581968e1 100644
--- a/xla/service/spmd/BUILD
+++ b/xla/service/spmd/BUILD
@@ -112,6 +112,8 @@ xla_cc_test(
         "//xla/tests:hlo_test_base",
         "//xla/tests:xla_internal_test_main",
         "//xla/tsl/lib/core:status_test_util",
+        "//xla/tsl/platform:errors",
+        "//xla/tsl/platform:statusor",
         "@com_google_absl//absl/algorithm:container",
         "@com_google_absl//absl/log",
         "@com_google_absl//absl/log:check",
@@ -120,8 +122,6 @@ xla_cc_test(
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/strings:string_view",
         "@com_google_googletest//:gtest",
-        "@tsl//tsl/platform:errors",
-        "@tsl//tsl/platform:statusor",
     ],
 )
diff --git a/xla/service/spmd/spmd_partitioner.cc b/xla/service/spmd/spmd_partitioner.cc
index 46a6768bea87c..e43f92497ae61 100644
--- a/xla/service/spmd/spmd_partitioner.cc
+++ b/xla/service/spmd/spmd_partitioner.cc
@@ -2758,19 +2758,18 @@ absl::Status SpmdPartitioningVisitor::HandleElementwise(HloInstruction* hlo) {
   return absl::OkStatus();
 }
 
-absl::Status SpmdPartitioningVisitor::HandleConcatenate(HloInstruction* hlo) {
+absl::Status SpmdPartitioningVisitor::HandleElementwiseWithDimsToReplicate(
+    HloInstruction* hlo, absl::Span<const int64_t> dims_to_replicate) {
   const HloSharding& sharding = hlo->sharding();
   if (sharding.IsTileMaximal()) {
     return DefaultAction(hlo);
   }
 
-  // 1. Replicate the final sharding along the concatenate dimension to get
-  // temp_sharding. If the final sharding is already replicated along the
-  // concatenate dimension, then temp_sharding will be the same as final
-  // sharding.
+  // 1. Replicate the final sharding along `dims_to_replicate` to get
+  // temp_sharding.
   const HloSharding temp_sharding =
       hlo_sharding_util::PartiallyReplicateTiledShardingOnDims(
-          sharding, {hlo->concatenate_dimension()});
+          sharding, dims_to_replicate);
 
   // 2. Reshard the operands to temp_sharding.
   std::vector<HloInstruction*> new_operands;
@@ -2780,18 +2779,36 @@ absl::Status SpmdPartitioningVisitor::HandleConcatenate(HloInstruction* hlo) {
         GetPartitionedHlo(operand).Reshard(temp_sharding).hlo());
   }
 
-  // 3. Concatenate the operands to get result in temp_sharding.
-  auto concatenate = b_.AddInstruction(hlo->CloneWithNewOperands(
+  // 3. Apply the operation to get the result in temp_sharding.
+  auto result_in_temp_sharding = b_.AddInstruction(hlo->CloneWithNewOperands(
       MakePartitionedShape(hlo->shape(), temp_sharding), new_operands));
-  concatenate->set_sharding(temp_sharding);
+  result_in_temp_sharding->set_sharding(temp_sharding);
 
   // 4. Reshard the result from temp_sharding to the final sharding.
-  SetPartitionedHlo(
-      hlo, PartitionedHlo(concatenate, hlo->shape(), MakePartitioningState())
-               .Reshard(sharding));
+  SetPartitionedHlo(hlo, PartitionedHlo(result_in_temp_sharding, hlo->shape(),
+                                        MakePartitioningState())
+                             .Reshard(sharding));
   return absl::OkStatus();
 }
 
+absl::Status SpmdPartitioningVisitor::HandleCholesky(HloInstruction* hlo) {
+  CHECK_GE(hlo->shape().rank(), 2);
+  return HandleElementwiseWithDimsToReplicate(
+      hlo, {hlo->shape().rank() - 2, hlo->shape().rank() - 1});
+}
+
+absl::Status SpmdPartitioningVisitor::HandleTriangularSolve(
+    HloInstruction* hlo) {
+  CHECK_GE(hlo->shape().rank(), 2);
+  return HandleElementwiseWithDimsToReplicate(
+      hlo, {hlo->shape().rank() - 2, hlo->shape().rank() - 1});
+}
+
+absl::Status SpmdPartitioningVisitor::HandleConcatenate(HloInstruction* hlo) {
+  return HandleElementwiseWithDimsToReplicate(hlo,
+                                              {hlo->concatenate_dimension()});
+}
+
 absl::Status SpmdPartitioningVisitor::HandleSlice(HloInstruction* hlo) {
   const HloSharding& sharding = hlo->sharding();
   if (sharding.IsTileMaximal()) {
diff --git a/xla/service/spmd/spmd_partitioner.h b/xla/service/spmd/spmd_partitioner.h
index e771f00d071be..f357ffcd62760 100644
--- a/xla/service/spmd/spmd_partitioner.h
+++ b/xla/service/spmd/spmd_partitioner.h
@@ -594,6 +594,7 @@ class SpmdPartitioningVisitor : public DfsHloVisitorWithDefault {
   absl::Status HandleBitcastConvert(HloInstruction* hlo) override;
   absl::Status HandleBroadcast(HloInstruction* hlo) override;
   absl::Status HandleCall(HloInstruction* hlo) override;
+  absl::Status HandleCholesky(HloInstruction* hlo) override;
   absl::Status HandleConcatenate(HloInstruction* hlo) override;
   absl::Status HandleConditional(HloInstruction* hlo) override;
   absl::Status HandleConstant(HloInstruction* hlo) override;
@@ -622,6 +623,7 @@ class SpmdPartitioningVisitor : public DfsHloVisitorWithDefault {
   absl::Status HandleSlice(HloInstruction* hlo) override;
   absl::Status HandleSort(HloInstruction* hlo) override;
   absl::Status HandleTranspose(HloInstruction* hlo) override;
+  absl::Status HandleTriangularSolve(HloInstruction* hlo) override;
   absl::Status HandleTuple(HloInstruction* hlo) override;
   absl::Status HandleWhile(HloInstruction* hlo) override;
 
@@ -637,6 +639,11 @@ class SpmdPartitioningVisitor : public DfsHloVisitorWithDefault {
   // Common handle for elementwise HLOs.
   absl::Status HandleElementwise(HloInstruction* hlo);
 
+  // Common handle for HLOs that are elementwise on all dimensions except
+  // `dims_to_replicate`, which are replicated before the op is applied.
+  absl::Status HandleElementwiseWithDimsToReplicate(
+      HloInstruction* hlo, absl::Span<const int64_t> dims_to_replicate);
+
   // Common handle for HLOs that runs on a single device.
   absl::Status HandleSingleDevice(const HloInstruction* hlo);
 
diff --git a/xla/service/spmd/spmd_partitioner_test.cc b/xla/service/spmd/spmd_partitioner_test.cc
index 723cbd0320b4b..d6fc45702bea5 100644
--- a/xla/service/spmd/spmd_partitioner_test.cc
+++ b/xla/service/spmd/spmd_partitioner_test.cc
@@ -51,10 +51,10 @@ limitations under the License.
#include "xla/shape.h" #include "xla/tests/hlo_test_base.h" #include "xla/tsl/lib/core/status_test_util.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/statusor.h" #include "xla/util.h" #include "xla/xla_data.pb.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/statusor.h" namespace xla { namespace spmd { @@ -15491,6 +15491,56 @@ ENTRY entry { AllOf(op::DynamicSlice(result, _, _), op::Shape("f32[2,1]"))); } +TEST_P(SpmdPartitioningTest, Cholesky) { + absl::string_view hlo_string = R"( +ENTRY entry { + %p0 = f32[32,32,32] parameter(0), sharding={devices=[2,2,2]<=[8]} + ROOT %cholesky = f32[32,32,32] cholesky(p0), lower=true, sharding={devices=[2,2,2]<=[8]} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/8)); + + auto param0 = AllOf(op::Parameter(0), op::Shape("f32[16,16,16]")); + auto param0_reshard = + AllOf(op::Shape("f32[16,32,32]"), + op::AllReduce(op::AllReduce( + op::DynamicUpdateSlice(op::Broadcast(), param0, _, _, _)))); + auto cholesky = + AllOf(op::Cholesky(param0_reshard), op::Shape("f32[16,32,32]")); + EXPECT_THAT( + module->entry_computation()->root_instruction(), + AllOf(op::DynamicSlice(cholesky, _, _, _), op::Shape("f32[16,16,16]"))); +} + +TEST_P(SpmdPartitioningTest, TriangularSolve) { + absl::string_view hlo_string = R"( +ENTRY main { + a = f32[10,32,32] parameter(0), sharding={devices=[2,2,2]<=[8]} + b = f32[10,32,48] parameter(1), sharding={devices=[2,2,2]<=[8]} + ROOT triangular-solve = f32[10,32,48] triangular-solve(a, b), left_side=true, unit_diagonal=true, lower=true, transpose_a=NO_TRANSPOSE, sharding={devices=[2,2,2]<=[8]} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/8)); + + auto param0 = AllOf(op::Parameter(0), op::Shape("f32[5,16,16]")); + auto param0_reshard = + AllOf(op::Shape("f32[5,32,32]"), + op::AllReduce(op::AllReduce( + op::DynamicUpdateSlice(op::Broadcast(), param0, _, _, _)))); + auto param1 = AllOf(op::Parameter(1), op::Shape("f32[5,16,24]")); + auto param1_reshard = + AllOf(op::Shape("f32[5,32,48]"), + op::AllReduce(op::AllReduce( + op::DynamicUpdateSlice(op::Broadcast(), param1, _, _, _)))); + + auto ts = AllOf(op::TriangularSolve(param0_reshard, param1_reshard), + op::Shape("f32[5,32,48]")); + EXPECT_THAT(module->entry_computation()->root_instruction(), + AllOf(op::DynamicSlice(ts, _, _, _), op::Shape("f32[5,16,24]"))); +} + } // namespace } // namespace spmd } // namespace xla