Skip to content

Commit

Permalink
Extract a common helper function `HandleElementwiseWithDimsToReplicat…
Browse files Browse the repository at this point in the history
…e` in `SpmdPartitioningVisitor`.

Based on that, add `HandleCholesky` and `HandleTriangularSolve`. Before this change, we replicate all dimensions in these ops. With this cl, we only replicate the last two dimensions for these two operations.

PiperOrigin-RevId: 713827953
  • Loading branch information
ZixuanJiang authored and Google-ML-Automation committed Jan 10, 2025
1 parent ba29286 commit b1cc7dd
Show file tree
Hide file tree
Showing 5 changed files with 92 additions and 16 deletions.
2 changes: 2 additions & 0 deletions xla/hlo/utils/hlo_matchers.h
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,7 @@ HLO_MATCHER(BitcastConvert);
HLO_MATCHER(Broadcast);
HLO_MATCHER(Call);
HLO_MATCHER(Ceil);
HLO_MATCHER(Cholesky);
HLO_MATCHER(Clamp);
HLO_MATCHER(CollectiveBroadcast);
HLO_MATCHER(CollectivePermute);
Expand Down Expand Up @@ -353,6 +354,7 @@ HLO_MATCHER(Subtract);
HLO_MATCHER(Tan);
HLO_MATCHER(Tanh);
HLO_MATCHER(Transpose);
HLO_MATCHER(TriangularSolve);
HLO_MATCHER(Tuple);
HLO_MATCHER(While);
HLO_MATCHER(Xor);
Expand Down
4 changes: 2 additions & 2 deletions xla/service/spmd/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,8 @@ xla_cc_test(
"//xla/tests:hlo_test_base",
"//xla/tests:xla_internal_test_main",
"//xla/tsl/lib/core:status_test_util",
"//xla/tsl/platform:errors",
"//xla/tsl/platform:statusor",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/log",
"@com_google_absl//absl/log:check",
Expand All @@ -120,8 +122,6 @@ xla_cc_test(
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:string_view",
"@com_google_googletest//:gtest",
"@tsl//tsl/platform:errors",
"@tsl//tsl/platform:statusor",
],
)

Expand Down
41 changes: 29 additions & 12 deletions xla/service/spmd/spmd_partitioner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2758,19 +2758,18 @@ absl::Status SpmdPartitioningVisitor::HandleElementwise(HloInstruction* hlo) {
return absl::OkStatus();
}

absl::Status SpmdPartitioningVisitor::HandleConcatenate(HloInstruction* hlo) {
absl::Status SpmdPartitioningVisitor::HandleElementwiseWithDimsToReplicate(
HloInstruction* hlo, absl::Span<const int64_t> dims_to_replicate) {
const HloSharding& sharding = hlo->sharding();
if (sharding.IsTileMaximal()) {
return DefaultAction(hlo);
}

// 1. Replicate the final sharding along the concatenate dimension to get
// temp_sharding. If the final sharding is already replicated along the
// concatenate dimension, then temp_sharding will be the same as final
// sharding.
// 1. Replicate the final sharding along `dims_to_replicate` to get
// temp_sharding.
const HloSharding temp_sharding =
hlo_sharding_util::PartiallyReplicateTiledShardingOnDims(
sharding, {hlo->concatenate_dimension()});
sharding, dims_to_replicate);

// 2. Reshard the operands to temp_sharding.
std::vector<HloInstruction*> new_operands;
Expand All @@ -2780,18 +2779,36 @@ absl::Status SpmdPartitioningVisitor::HandleConcatenate(HloInstruction* hlo) {
GetPartitionedHlo(operand).Reshard(temp_sharding).hlo());
}

// 3. Concatenate the operands to get result in temp_sharding.
auto concatenate = b_.AddInstruction(hlo->CloneWithNewOperands(
// 3. Apply the operation to get result in temp_sharding.
auto result_in_temp_sharding = b_.AddInstruction(hlo->CloneWithNewOperands(
MakePartitionedShape(hlo->shape(), temp_sharding), new_operands));
concatenate->set_sharding(temp_sharding);
result_in_temp_sharding->set_sharding(temp_sharding);

// 4. Reshard the result from temp_sharding to the final sharding.
SetPartitionedHlo(
hlo, PartitionedHlo(concatenate, hlo->shape(), MakePartitioningState())
.Reshard(sharding));
SetPartitionedHlo(hlo, PartitionedHlo(result_in_temp_sharding, hlo->shape(),
MakePartitioningState())
.Reshard(sharding));
return absl::OkStatus();
}

absl::Status SpmdPartitioningVisitor::HandleCholesky(HloInstruction* hlo) {
CHECK_GE(hlo->shape().rank(), 2);
return HandleElementwiseWithDimsToReplicate(
hlo, {hlo->shape().rank() - 2, hlo->shape().rank() - 1});
}

absl::Status SpmdPartitioningVisitor::HandleTriangularSolve(
HloInstruction* hlo) {
CHECK_GE(hlo->shape().rank(), 2);
return HandleElementwiseWithDimsToReplicate(
hlo, {hlo->shape().rank() - 2, hlo->shape().rank() - 1});
}

absl::Status SpmdPartitioningVisitor::HandleConcatenate(HloInstruction* hlo) {
return HandleElementwiseWithDimsToReplicate(hlo,
{hlo->concatenate_dimension()});
}

absl::Status SpmdPartitioningVisitor::HandleSlice(HloInstruction* hlo) {
const HloSharding& sharding = hlo->sharding();
if (sharding.IsTileMaximal()) {
Expand Down
7 changes: 7 additions & 0 deletions xla/service/spmd/spmd_partitioner.h
Original file line number Diff line number Diff line change
Expand Up @@ -594,6 +594,7 @@ class SpmdPartitioningVisitor : public DfsHloVisitorWithDefault {
absl::Status HandleBitcastConvert(HloInstruction* hlo) override;
absl::Status HandleBroadcast(HloInstruction* hlo) override;
absl::Status HandleCall(HloInstruction* hlo) override;
absl::Status HandleCholesky(HloInstruction* hlo) override;
absl::Status HandleConcatenate(HloInstruction* hlo) override;
absl::Status HandleConditional(HloInstruction* hlo) override;
absl::Status HandleConstant(HloInstruction* hlo) override;
Expand Down Expand Up @@ -622,6 +623,7 @@ class SpmdPartitioningVisitor : public DfsHloVisitorWithDefault {
absl::Status HandleSlice(HloInstruction* hlo) override;
absl::Status HandleSort(HloInstruction* hlo) override;
absl::Status HandleTranspose(HloInstruction* hlo) override;
absl::Status HandleTriangularSolve(HloInstruction* hlo) override;
absl::Status HandleTuple(HloInstruction* hlo) override;
absl::Status HandleWhile(HloInstruction* hlo) override;

Expand All @@ -637,6 +639,11 @@ class SpmdPartitioningVisitor : public DfsHloVisitorWithDefault {
// Common handle for elementwise HLOs.
absl::Status HandleElementwise(HloInstruction* hlo);

// All dimensions in the hlo are element-wise except that we replicate
// `dims_to_replicate`.
absl::Status HandleElementwiseWithDimsToReplicate(
HloInstruction* hlo, absl::Span<const int64_t> dims_to_replicate);

// Common handle for HLOs that runs on a single device.
absl::Status HandleSingleDevice(const HloInstruction* hlo);

Expand Down
54 changes: 52 additions & 2 deletions xla/service/spmd/spmd_partitioner_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,10 @@ limitations under the License.
#include "xla/shape.h"
#include "xla/tests/hlo_test_base.h"
#include "xla/tsl/lib/core/status_test_util.h"
#include "xla/tsl/platform/errors.h"
#include "xla/tsl/platform/statusor.h"
#include "xla/util.h"
#include "xla/xla_data.pb.h"
#include "tsl/platform/errors.h"
#include "tsl/platform/statusor.h"

namespace xla {
namespace spmd {
Expand Down Expand Up @@ -15491,6 +15491,56 @@ ENTRY entry {
AllOf(op::DynamicSlice(result, _, _), op::Shape("f32[2,1]")));
}

TEST_P(SpmdPartitioningTest, Cholesky) {
absl::string_view hlo_string = R"(
ENTRY entry {
%p0 = f32[32,32,32] parameter(0), sharding={devices=[2,2,2]<=[8]}
ROOT %cholesky = f32[32,32,32] cholesky(p0), lower=true, sharding={devices=[2,2,2]<=[8]}
})";

TF_ASSERT_OK_AND_ASSIGN(auto module,
PartitionComputation(hlo_string, /*num_devices=*/8));

auto param0 = AllOf(op::Parameter(0), op::Shape("f32[16,16,16]"));
auto param0_reshard =
AllOf(op::Shape("f32[16,32,32]"),
op::AllReduce(op::AllReduce(
op::DynamicUpdateSlice(op::Broadcast(), param0, _, _, _))));
auto cholesky =
AllOf(op::Cholesky(param0_reshard), op::Shape("f32[16,32,32]"));
EXPECT_THAT(
module->entry_computation()->root_instruction(),
AllOf(op::DynamicSlice(cholesky, _, _, _), op::Shape("f32[16,16,16]")));
}

TEST_P(SpmdPartitioningTest, TriangularSolve) {
absl::string_view hlo_string = R"(
ENTRY main {
a = f32[10,32,32] parameter(0), sharding={devices=[2,2,2]<=[8]}
b = f32[10,32,48] parameter(1), sharding={devices=[2,2,2]<=[8]}
ROOT triangular-solve = f32[10,32,48] triangular-solve(a, b), left_side=true, unit_diagonal=true, lower=true, transpose_a=NO_TRANSPOSE, sharding={devices=[2,2,2]<=[8]}
})";

TF_ASSERT_OK_AND_ASSIGN(auto module,
PartitionComputation(hlo_string, /*num_devices=*/8));

auto param0 = AllOf(op::Parameter(0), op::Shape("f32[5,16,16]"));
auto param0_reshard =
AllOf(op::Shape("f32[5,32,32]"),
op::AllReduce(op::AllReduce(
op::DynamicUpdateSlice(op::Broadcast(), param0, _, _, _))));
auto param1 = AllOf(op::Parameter(1), op::Shape("f32[5,16,24]"));
auto param1_reshard =
AllOf(op::Shape("f32[5,32,48]"),
op::AllReduce(op::AllReduce(
op::DynamicUpdateSlice(op::Broadcast(), param1, _, _, _))));

auto ts = AllOf(op::TriangularSolve(param0_reshard, param1_reshard),
op::Shape("f32[5,32,48]"));
EXPECT_THAT(module->entry_computation()->root_instruction(),
AllOf(op::DynamicSlice(ts, _, _, _), op::Shape("f32[5,16,24]")));
}

} // namespace
} // namespace spmd
} // namespace xla

0 comments on commit b1cc7dd

Please sign in to comment.