From 448316c7ae83ba181d416f262d060e5c8ea55e43 Mon Sep 17 00:00:00 2001 From: "Ryan M. Lefever" Date: Tue, 14 Jan 2025 15:40:22 -0800 Subject: [PATCH] Fix a bug when we have 2 valid live AllocationValues for an HloValue. This comes up with asynchronous operations. PiperOrigin-RevId: 715556884 --- xla/service/memory_space_assignment/BUILD | 2 + .../memory_space_assignment/algorithm.cc | 89 +++++++- .../memory_space_assignment/algorithm.h | 23 +- .../allocation_value.h | 11 + .../memory_space_assignment_test.cc | 209 ++++++++++++++++++ .../memory_space_assignment_test_base.h | 38 ++++ 6 files changed, 363 insertions(+), 9 deletions(-) diff --git a/xla/service/memory_space_assignment/BUILD b/xla/service/memory_space_assignment/BUILD index 40fcfd5140515..17c47eac99d3f 100644 --- a/xla/service/memory_space_assignment/BUILD +++ b/xla/service/memory_space_assignment/BUILD @@ -187,7 +187,9 @@ cc_library( "//xla/service:hlo_value", "//xla/tests:hlo_test_base", "//xla/tsl/lib/core:status_test_util", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/types:span", "@tsl//tsl/platform:status", "@tsl//tsl/platform:statusor", "@tsl//tsl/platform:test", diff --git a/xla/service/memory_space_assignment/algorithm.cc b/xla/service/memory_space_assignment/algorithm.cc index 952a064b72620..b27bdad2b522b 100644 --- a/xla/service/memory_space_assignment/algorithm.cc +++ b/xla/service/memory_space_assignment/algorithm.cc @@ -27,6 +27,7 @@ limitations under the License. #include #include #include +#include #include #include #include @@ -627,6 +628,17 @@ void MsaAlgorithm::FindAliases( } } +std::string MsaAlgorithm::RequiredMemoryAssignment::ToString() const { + std::string memory_space_str = + memory_space == MemorySpace::kDefault ? "def" : "alt"; + std::string offset_str = + offset == nullptr ? 
"null" : absl::StrCat(offset->offset); + + return absl::StrCat( + "RequiredMemoryAssignment(memory_space=", memory_space_str, + ", time=", time, ", offset=", offset_str, ")"); +} + std::vector MsaAlgorithm::GetSortedColocatedIntervals( const MsaBufferInterval& interval) const { std::vector colocated_intervals; @@ -2577,7 +2589,8 @@ absl::StatusOr MsaAlgorithm::AllocateAllocationValues( preferred_offset_for_allocation_value.at(&allocation_value_to_update), definition_time_for_allocation_value.at(&allocation_value_to_update), RequiresNoCopyAlternateMemAllocation(allocation_value_to_update), - all_use_times, entry.only_extend_existing_allocation); + all_use_times, entry.only_extend_existing_allocation, + allocation_values.subspan(0, alloc_value_idx)); if (options_.allocation_request_modifier_testing_fn) { options_.allocation_request_modifier_testing_fn(request); } @@ -2750,7 +2763,8 @@ AllocationRequest MsaAlgorithm::CreateAllocationRequest( AliasedOffset* preferred_offset, int64_t definition_time, bool require_no_copy_alternate_mem_allocation, const std::vector& all_use_times, - bool only_extend_existing_allocation) { + bool only_extend_existing_allocation, + absl::Span processed_allocation_values) { const HloUse& hlo_use = use.hlo_use; const auto& instruction_schedule = hlo_live_range_.instruction_schedule(); bool require_copy_allocation = false; @@ -2984,6 +2998,7 @@ AllocationRequest MsaAlgorithm::CreateAllocationRequest( request.end_time = use_time; request.only_extend_existing_allocation = only_extend_existing_allocation; + request.processed_allocation_values = processed_allocation_values; return request; } @@ -4483,6 +4498,58 @@ std::string MsaAlgorithm::ResultToString(const AllocationResult& result) { return result_str; } +void MsaAlgorithm::CheckAndUpdateForDualLiveAllocationValues( + const std::optional& + required_memory_assignment_at_start, + AllocationRequest& request) { + if (!request.allocation_value->requires_contiguous_allocation()) { + return; + } + 
if (!required_memory_assignment_at_start.has_value()) { + return; + } + if (required_memory_assignment_at_start->memory_space != + MemorySpace::kAlternate) { + return; + } + // Go through previous allocations, for the same HloValue, and check if they + // have already allocated alternate memory at the beginning of the current + // AllocationValue, such that we are required to use the same heap offset. + std::vector overlapping_allocations; + Chunk required_chunk = Chunk::FromOffsetSize( + required_memory_assignment_at_start->offset->offset, request.size); + for (const AllocationValue& processed_allocation_value : + request.processed_allocation_values) { + for (const std::unique_ptr& allocation : + *processed_allocation_value.allocation_sequence()) { + if (allocation->is_in_alternate_mem() && + allocation->start_time() <= request.inclusive_start_time && + request.inclusive_start_time <= allocation->end_time() && + allocation->chunk() == required_chunk) { + overlapping_allocations.push_back(allocation.get()); + } + } + } + absl::c_sort(overlapping_allocations, + [](const Allocation* a, const Allocation* b) { + return a->start_time() < b->start_time(); + }); + int64_t chunk_start_time = request.inclusive_start_time; + for (const Allocation* allocation : overlapping_allocations) { + chunk_start_time = std::max(chunk_start_time, allocation->end_time() + 1); + } + + // Note, we don't have to set request.preferred_offset, or do anything special + // to handle aliasing. This is done for us. Specifically, before calling + // CheckAndUpdateForDualLiveAllocationValues(), AllocateSegment() inserts a + // PinnedAllocation with no associated heap chunk, at the beginning of + // request.allocation_value. It aliases that PinnedAllocation with any + // overlapping allocations calculated above. In + // AllocateInAlternateMemoryNoCopy(), we will find that PinnedAllocation and + // realize we need to use the same alternate memory offset. 
+ request.no_copy_chunk_inclusive_start_time = chunk_start_time; +} + AllocationResult MsaAlgorithm::AllocateSegment(AllocationRequest& request) { auto allocation_sequence = request.allocation_value->mutable_allocation_sequence(); @@ -4535,17 +4602,19 @@ AllocationResult MsaAlgorithm::AllocateSegment(AllocationRequest& request) { // we're allowed to prefetch. If the use expects the output to be in default // memory, we cannot prefetch it because if we did, it would be in alternate // memory instead. - auto required_assignment_at_start = RequiredMemoryAssignmentAt( - request.allocation_value->value(), request.inclusive_start_time); + std::optional required_assignment_at_start = + RequiredMemoryAssignmentAt(request.allocation_value->value(), + request.inclusive_start_time); std::optional required_memory_space_at_start; if (required_assignment_at_start) { required_memory_space_at_start = required_assignment_at_start->memory_space; } // Find required assignment both for the use and its aliases. If they are both // non-nullopt, then make sure they require the same assignment. 
- auto required_assignment_at_end = RequiredMemoryAssignmentAt( - request.allocation_value_to_update->value(), request.end_time); - auto aliased_required_assignment_at_end = + std::optional required_assignment_at_end = + RequiredMemoryAssignmentAt(request.allocation_value_to_update->value(), + request.end_time); + std::optional aliased_required_assignment_at_end = AliasedRequiredAssignmentForUse(*request.use); if (required_assignment_at_end != aliased_required_assignment_at_end) { if (required_assignment_at_end == std::nullopt) { @@ -4600,6 +4669,8 @@ AllocationResult MsaAlgorithm::AllocateSegment(AllocationRequest& request) { required_memory_space_at_end != MemorySpace::kDefault && request.allow_no_copy_alternate_mem_allocation && !request.require_copy_allocation) { + CheckAndUpdateForDualLiveAllocationValues(required_assignment_at_start, + request); allocation_result = AllocateInAlternateMemoryNoCopy(request); if (allocation_result == AllocationResult::kSuccess) { return AllocationResult::kSuccess; @@ -4996,6 +5067,10 @@ AllocationResult MsaAlgorithm::AllocateInAlternateMemoryNoCopy( // If there is a previous allocation, set the start time one after the end // of the previous allocation's end. alternate_mem_interval.start = prev_allocation->end_time() + 1; + if (request.no_copy_chunk_inclusive_start_time.has_value()) { + alternate_mem_interval.start = + *request.no_copy_chunk_inclusive_start_time; + } } if (request.preferred_offset) { diff --git a/xla/service/memory_space_assignment/algorithm.h b/xla/service/memory_space_assignment/algorithm.h index b74439d25f42b..4cc41e8e004ec 100644 --- a/xla/service/memory_space_assignment/algorithm.h +++ b/xla/service/memory_space_assignment/algorithm.h @@ -23,6 +23,7 @@ limitations under the License. 
#include #include #include +#include #include #include #include @@ -340,7 +341,7 @@ class MsaAlgorithm : public GlobalDecreasingSizeBestFitHeap { struct RequiredMemoryAssignment { MemorySpace memory_space; int64_t time; - AliasedOffset* offset; + AliasedOffset* offset = nullptr; bool equals_ignoring_time(const RequiredMemoryAssignment& other) const { return memory_space == other.memory_space && offset == other.offset; @@ -354,6 +355,8 @@ class MsaAlgorithm : public GlobalDecreasingSizeBestFitHeap { bool operator!=(const RequiredMemoryAssignment& other) const { return !(*this == other); } + + std::string ToString() const; }; // A struct that contains a pointer to loop-optimized allocation along with @@ -622,6 +625,9 @@ class MsaAlgorithm : public GlobalDecreasingSizeBestFitHeap { // only_extend_existing_allocation is true, no new Allocations will be created // while processing the resulting AllocationRequest, and we only need to // extend an existing Allocation's end_time. + // + // * processed_allocation_values: The AllocationValues that have already been + // processed for the same parent HloValue as is used in the request. AllocationRequest CreateAllocationRequest( AllocationValue& allocation_value, AllocationValue& allocation_value_to_update, @@ -629,7 +635,8 @@ class MsaAlgorithm : public GlobalDecreasingSizeBestFitHeap { AliasedOffset* preferred_offset, int64_t definition_time, bool require_no_copy_alternate_mem_allocation, const std::vector& all_use_times, - bool only_extend_existing_allocation); + bool only_extend_existing_allocation, + absl::Span processed_allocation_values); // Returns true, if the allocation value requires a pinned allocation in the // alternate memory space. 
@@ -658,6 +665,18 @@ class MsaAlgorithm : public GlobalDecreasingSizeBestFitHeap { absl::StatusOr AllocateAllocationValues( absl::Span allocation_values); + // Checks for a situation in which an HloValue has more than one live + // AllocationValue at the same time, and the already processed AllocationValue + // has been given alternate memory at the start of the second AllocationValue. + // If such a case is detected, we set + // request.no_copy_chunk_inclusive_start_time with the time where the first + // AllocationValue left off. AllocateInAlternateMemoryNoCopy() takes advantage + // of that information. + void CheckAndUpdateForDualLiveAllocationValues( + const std::optional& + required_memory_assignment_at_start, + AllocationRequest& request); + // Finds an allocation for an allocation request for a segment (see the // documentation for AllocationRequest above how a segment is defined). // diff --git a/xla/service/memory_space_assignment/allocation_value.h b/xla/service/memory_space_assignment/allocation_value.h index e59733adabbbc..9712534be6e36 100644 --- a/xla/service/memory_space_assignment/allocation_value.h +++ b/xla/service/memory_space_assignment/allocation_value.h @@ -264,6 +264,17 @@ struct AllocationRequest { // Data structure that contains the options for making window prefetched // allocations. const WindowPrefetchedAllocation::Options* window_prefetch_options = nullptr; + // Previously processed AllocationValues, with the same parent HloValue as the + // request. + absl::Span processed_allocation_values; + // An optional override starting time for the placement of a chunk on the MSA + // heap, for a no-copy allocation (see + // MsaAlgorithm::AllocateInAlternateMemoryNoCopy() for more details). + // + // Note, this override is used when an aliased AllocationValue has already + // done some of the heap allocation for us. So this request picks up where it + // left off. 
+ std::optional no_copy_chunk_inclusive_start_time; }; // Result of an allocation, prefetch, eviction etc. request. The result is diff --git a/xla/service/memory_space_assignment/memory_space_assignment_test.cc b/xla/service/memory_space_assignment/memory_space_assignment_test.cc index 15694b878bc39..1cd00af978967 100644 --- a/xla/service/memory_space_assignment/memory_space_assignment_test.cc +++ b/xla/service/memory_space_assignment/memory_space_assignment_test.cc @@ -6219,6 +6219,215 @@ TEST_F(MemorySpaceAssignmentTest, DisallowedUseBugInWhile) { AssignMemorySpace(module.get(), options); } +TEST_F(MemorySpaceAssignmentTest, TwoLiveAllocationValuesBase) { + // In this example, we have enough space to give negate.0 alternate memory, + // and we put negate.0 at the top of MSA's sort order. So, we expect that + // it will get alternate memory. + // + // We are testing a fix for dual live AllocationValues, with the following + // setup: + // - HloValue H containing the following positions: negate.0, cp-start.0{0} + // - AllocationValue A0 defined at negate.0 + // - Segment A0.S0 defined during [negate.0, cp-start.0] + // - Segment A0.S1 defined during [cp-start.0, add.0] + // - AllocationValue A1 defined at cp-start.0{0} + // - Segment A1.S0 defined during [cp-start.0, cp-done.0] + // + // A0 and A1 are both live for more than 1 instruction. 
+ absl::string_view hlo_string = R"( + HloModule module, is_scheduled=true + + ENTRY entry { + /*00:*/ p.0 = f32[10,10,10,10] parameter(0) + /*01:*/ p.1 = f32[10,10,10,10] parameter(1) + /*02:*/ v.0 = f32[10,10,10,10] add(p.1, p.1) + /*03:*/ negate.0 = f32[10,10,10,10] negate(p.0) + /*04:*/ cp-start.0 = (f32[10,10,10,10], f32[10,10,10,10], u32[], u32[]) collective-permute-start(negate.0), source_target_pairs={{0,1},{2,3}} + /*05:*/ v.1 = f32[10,10,10,10] add(v.0, v.0) + /*06:*/ add.0 = f32[10,10,10,10] add(negate.0, negate.0) + /*07:*/ v.2 = f32[10,10,10,10] add(v.1, v.1) + /*08:*/ cp-done.0 = f32[10,10,10,10] collective-permute-done(cp-start.0) + /*09:*/ v.3 = f32[10,10,10,10] add(v.2, v.2) + /*10:*/ ROOT tuple.0 = (f32[10,10,10,10], f32[10,10,10,10], f32[10,10,10,10]) tuple(add.0, cp-done.0, v.3) + })"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + Options options = DefaultMemorySpaceOptions(); + options.max_size_in_bytes = 4 * 10 * 10 * 10 * 10; + MsaBufferIntervalCompare buffer_interval_compare = + CreateBufferIntervalCompareFnFromInstructionNames({"negate.0"}); + InstructionCountPrefetchIntervalPicker prefetch_interval_picker(1, 10); + std::unique_ptr preset_assignments = + AssignMemorySpace(module.get(), options, buffer_interval_compare, + &prefetch_interval_picker); + VLOG(1) << "Module after MSA:\n" << module->ToString(); + + HloInstruction* copy0 = FindInstruction(module.get(), "negate.0"); + ASSERT_NE(copy0, nullptr); + EXPECT_EQ(copy0->shape().layout().memory_space(), kAlternateMemorySpace); +} + +TEST_F(MemorySpaceAssignmentTest, + TwoLiveAllocationValuesTwoInstructionOverlap) { + // In this example, we have enough space to give negate.0 alternate memory, + // and we put negate.0 at the top of MSA's sort order. So, we expect that + // it will get alternate memory. 
+ // + // We are testing a fix for dual live AllocationValues, with the following + // setup: + // - HloValue H containing the following positions: negate.0, cp-start.0{0} + // - AllocationValue A0 defined at negate.0 + // - Segment A0.S0 defined during [negate.0, cp-start.0] + // - Segment A0.S1 defined during [cp-start.0, add.0] + // - AllocationValue A1 defined at cp-start.0{0} + // - Segment A1.S0 defined during [cp-start.0, cp-done.0] + // + // A0 and A1 are both live for 2 instructions + absl::string_view hlo_string = R"( + HloModule module, is_scheduled=true + + ENTRY entry { + /*00:*/ p.0 = f32[10,10,10,10] parameter(0) + /*01:*/ p.1 = f32[10,10,10,10] parameter(1) + /*02:*/ v.0 = f32[10,10,10,10] add(p.1, p.1) + /*03:*/ negate.0 = f32[10,10,10,10] negate(p.0) + /*04:*/ cp-start.0 = (f32[10,10,10,10], f32[10,10,10,10], u32[], u32[]) collective-permute-start(negate.0), source_target_pairs={{0,1},{2,3}} + /*05:*/ add.0 = f32[10,10,10,10] add(negate.0, negate.0) + /*06:*/ v.1 = f32[10,10,10,10] add(v.0, v.0) + /*07:*/ v.2 = f32[10,10,10,10] add(v.1, v.1) + /*08:*/ cp-done.0 = f32[10,10,10,10] collective-permute-done(cp-start.0) + /*09:*/ v.3 = f32[10,10,10,10] add(v.2, v.2) + /*10:*/ ROOT tuple.0 = (f32[10,10,10,10], f32[10,10,10,10], f32[10,10,10,10]) tuple(add.0, cp-done.0, v.3) + })"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + Options options = DefaultMemorySpaceOptions(); + options.max_size_in_bytes = 4 * 10 * 10 * 10 * 10; + MsaBufferIntervalCompare buffer_interval_compare = + CreateBufferIntervalCompareFnFromInstructionNames({"negate.0"}); + InstructionCountPrefetchIntervalPicker prefetch_interval_picker(1, 10); + std::unique_ptr preset_assignments = + AssignMemorySpace(module.get(), options, buffer_interval_compare, + &prefetch_interval_picker); + VLOG(1) << "Module after MSA:\n" << module->ToString(); + + HloInstruction* copy0 = FindInstruction(module.get(), "negate.0"); + ASSERT_NE(copy0, nullptr); + 
EXPECT_EQ(copy0->shape().layout().memory_space(), kAlternateMemorySpace); +} + +TEST_F(MemorySpaceAssignmentTest, + TwoLiveAllocationValuesFirstLiveAllocationValueOutlastsSecond) { + // In this example, we have enough space to give negate.0 alternate memory, + // and we put negate.0 at the top of MSA's sort order. So, we expect that + // it will get alternate memory. + // + // We are testing a fix for dual live AllocationValues, with the following + // setup: + // - HloValue H containing the following positions: negate.0, cp-start.0{0} + // - AllocationValue A0 defined at negate.0 + // - Segment A0.S0 defined during [negate.0, cp-start.0] + // - Segment A0.S1 defined during [cp-start.0, add.0] + // - Segment A0.S2 defined during [add.0, add.1] + // - AllocationValue A1 defined at cp-start.0{0} + // - Segment A1.S0 defined during [cp-start.0, cp-done.0] + // + // A0 and A1 are both live for more than 1 instruction. A0 is live beyond the + // end of A1. + absl::string_view hlo_string = R"( + HloModule module, is_scheduled=true + + ENTRY entry { + /*00:*/ p.0 = f32[10,10,10,10] parameter(0) + /*01:*/ p.1 = f32[10,10,10,10] parameter(1) + /*02:*/ v.0 = f32[10,10,10,10] add(p.1, p.1) + /*03:*/ negate.0 = f32[10,10,10,10] negate(p.0) + /*04:*/ cp-start.0 = (f32[10,10,10,10], f32[10,10,10,10], u32[], u32[]) collective-permute-start(negate.0), source_target_pairs={{0,1},{2,3}} + /*05:*/ v.1 = f32[10,10,10,10] add(v.0, v.0) + /*06:*/ add.0 = f32[10,10,10,10] add(negate.0, negate.0) + /*07:*/ v.2 = f32[10,10,10,10] add(v.1, v.1) + /*08:*/ cp-done.0 = f32[10,10,10,10] collective-permute-done(cp-start.0) + /*09:*/ v.3 = f32[10,10,10,10] add(v.2, v.2) + /*10:*/ add.1 = f32[10,10,10,10] add(add.0, negate.0) + /*11:*/ ROOT tuple.0 = (f32[10,10,10,10], f32[10,10,10,10], f32[10,10,10,10]) tuple(add.1, cp-done.0, v.3) + })"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + Options options = DefaultMemorySpaceOptions(); + 
options.max_size_in_bytes = 4 * 10 * 10 * 10 * 10; + MsaBufferIntervalCompare buffer_interval_compare = + CreateBufferIntervalCompareFnFromInstructionNames({"negate.0"}); + InstructionCountPrefetchIntervalPicker prefetch_interval_picker(1, 10); + std::unique_ptr preset_assignments = + AssignMemorySpace(module.get(), options, buffer_interval_compare, + &prefetch_interval_picker); + VLOG(1) << "Module after MSA:\n" << module->ToString(); + + HloInstruction* copy0 = FindInstruction(module.get(), "negate.0"); + ASSERT_NE(copy0, nullptr); + EXPECT_EQ(copy0->shape().layout().memory_space(), kAlternateMemorySpace); +} + +TEST_F(MemorySpaceAssignmentTest, + TwoLiveAllocationValuesUnableToAllocateContiguousAltMem) { + // In this example, we have enough space to give v.2 alternate memory, + // and we put v.2 at the top of MSA's sort order. So, we expect that + // it will get alternate memory. Second, we try to give negate.0 alternate + // memory, but we can't. In order to give negate.0 alternate memory, we need + // to give it contiguous alternate memory during cp-start.0 to cp-done.0. + // (negate.0 and cp-start.0 {0} alias.) However, v.2 is taking too much + // alternate memory to accommodate that request. + // + // We are testing a fix for dual live AllocationValues, with the following + // setup: + // - HloValue H containing the following positions: negate.0, cp-start.0{0} + // - AllocationValue A0 defined at negate.0 + // - Segment A0.S0 defined during [negate.0, cp-start.0] + // - Segment A0.S1 defined during [cp-start.0, add.0] + // - AllocationValue A1 defined at cp-start.0{0} + // - Segment A1.S0 defined during [cp-start.0, cp-done.0] + // + // A0 and A1 are both live for more than 1 instruction. 
+ absl::string_view hlo_string = R"( + HloModule module, is_scheduled=true + + ENTRY entry { + /*00:*/ p.0 = f32[10,10,10,10] parameter(0) + /*01:*/ p.1 = f32[10,10,10,10] parameter(1) + /*02:*/ v.0 = f32[10,10,10,10] add(p.1, p.1) + /*03:*/ negate.0 = f32[10,10,10,10] negate(p.0) + /*04:*/ cp-start.0 = (f32[10,10,10,10], f32[10,10,10,10], u32[], u32[]) collective-permute-start(negate.0), source_target_pairs={{0,1},{2,3}} + /*05:*/ v.1 = f32[10,10,10,10] add(v.0, v.0) + /*06:*/ add.0 = f32[10,10,10,10] add(negate.0, negate.0) + /*07:*/ v.2 = f32[10,10,10,10] add(v.1, v.1) + /*08:*/ cp-done.0 = f32[10,10,10,10] collective-permute-done(cp-start.0) + /*09:*/ v.3 = f32[10,10,10,10] add(v.2, v.2) + /*10:*/ ROOT tuple.0 = (f32[10,10,10,10], f32[10,10,10,10], f32[10,10,10,10]) tuple(add.0, cp-done.0, v.3) + })"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + Options options = DefaultMemorySpaceOptions(); + options.max_size_in_bytes = 4 * 10 * 10 * 10 * 10; + MsaBufferIntervalCompare buffer_interval_compare = + CreateBufferIntervalCompareFnFromInstructionNames({"v.2", "negate.0"}); + InstructionCountPrefetchIntervalPicker prefetch_interval_picker(1, 10); + std::unique_ptr preset_assignments = + AssignMemorySpace(module.get(), options, buffer_interval_compare, + &prefetch_interval_picker); + VLOG(1) << "Module after MSA:\n" << module->ToString(); + + HloInstruction* v2 = FindInstruction(module.get(), "v.2"); + ASSERT_NE(v2, nullptr); + EXPECT_EQ(v2->shape().layout().memory_space(), kAlternateMemorySpace); + HloInstruction* copy0 = FindInstruction(module.get(), "negate.0"); + ASSERT_NE(copy0, nullptr); + EXPECT_NE(copy0->shape().layout().memory_space(), kAlternateMemorySpace); +} + TEST_F(MemorySpaceAssignmentTest, AvoidRedundantEvictionInWhile) { absl::string_view hlo_string = R"( HloModule module, is_scheduled=true diff --git a/xla/service/memory_space_assignment/memory_space_assignment_test_base.h 
b/xla/service/memory_space_assignment/memory_space_assignment_test_base.h index 66515e8b4c875..ff9d8dc4053ca 100644 --- a/xla/service/memory_space_assignment/memory_space_assignment_test_base.h +++ b/xla/service/memory_space_assignment/memory_space_assignment_test_base.h @@ -17,14 +17,19 @@ limitations under the License. #define XLA_SERVICE_MEMORY_SPACE_ASSIGNMENT_MEMORY_SPACE_ASSIGNMENT_TEST_BASE_H_ #include +#include #include #include #include #include #include +#include #include +#include +#include "absl/container/flat_hash_map.h" #include "absl/status/statusor.h" +#include "absl/types/span.h" #include "xla/hlo/analysis/hlo_alias_analysis.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -129,6 +134,39 @@ class MemorySpaceAssignmentTestBase : public HloTestBase { return options; } + // Creates an MsaBufferIntervalCompare function that prioritizes the + // instructions named in prioritized_instruction_names, in the order + // specified. We default to alphabetical instruction name order for the + // remaining instructions. + static MsaBufferIntervalCompare + CreateBufferIntervalCompareFnFromInstructionNames( + std::vector prioritized_instruction_names) { + absl::flat_hash_map instruction_name_to_priority; + // A lower priority value means it's higher on the Msa sort list. 
+ for (size_t i = 0; i < prioritized_instruction_names.size(); ++i) { + instruction_name_to_priority[prioritized_instruction_names[i]] = i; + } + return [instruction_name_to_priority = + std::move(instruction_name_to_priority)]( + const MsaBufferInterval& a, const MsaBufferInterval& b) { + auto get_sort_tuple = [&instruction_name_to_priority]( + const MsaBufferInterval& buffer_interval) { + auto it = instruction_name_to_priority.find( + buffer_interval.buffer->defining_instruction()->name()); + if (it != instruction_name_to_priority.end()) { + return std::make_tuple( + it->second, + buffer_interval.buffer->defining_instruction()->name()); + } + return std::make_tuple( + instruction_name_to_priority.size(), + buffer_interval.buffer->defining_instruction()->name()); + }; + + return get_sort_tuple(a) < get_sort_tuple(b); + }; + } + std::unique_ptr AssignMemorySpaceUsingCostAnalysis( HloModule* module, std::optional memory_space_options_override = std::nullopt,