Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix a bug when we have 2 valid live AllocationValues for an HloValue. This comes up with asychronous operations. #21833

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions xla/service/memory_space_assignment/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,9 @@ cc_library(
"//xla/service:hlo_value",
"//xla/tests:hlo_test_base",
"//xla/tsl/lib/core:status_test_util",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/types:span",
"@tsl//tsl/platform:status",
"@tsl//tsl/platform:statusor",
"@tsl//tsl/platform:test",
Expand Down
89 changes: 82 additions & 7 deletions xla/service/memory_space_assignment/algorithm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ limitations under the License.
#include <map>
#include <memory>
#include <optional>
#include <ostream>
#include <set>
#include <string>
#include <utility>
Expand Down Expand Up @@ -627,6 +628,17 @@ void MsaAlgorithm::FindAliases(
}
}

std::string MsaAlgorithm::RequiredMemoryAssignment::ToString() const {
std::string memory_space_str =
memory_space == MemorySpace::kDefault ? "def" : "alt";
std::string offset_str =
offset == nullptr ? "null" : absl::StrCat(offset->offset);

return absl::StrCat(
"RequiredMemoryAssignment(memory_space=", memory_space_str,
", time=", time, ", offset=", offset_str, ")");
}

std::vector<const MsaBufferInterval*> MsaAlgorithm::GetSortedColocatedIntervals(
const MsaBufferInterval& interval) const {
std::vector<const MsaBufferInterval*> colocated_intervals;
Expand Down Expand Up @@ -2577,7 +2589,8 @@ absl::StatusOr<AllocationResult> MsaAlgorithm::AllocateAllocationValues(
preferred_offset_for_allocation_value.at(&allocation_value_to_update),
definition_time_for_allocation_value.at(&allocation_value_to_update),
RequiresNoCopyAlternateMemAllocation(allocation_value_to_update),
all_use_times, entry.only_extend_existing_allocation);
all_use_times, entry.only_extend_existing_allocation,
allocation_values.subspan(0, alloc_value_idx));
if (options_.allocation_request_modifier_testing_fn) {
options_.allocation_request_modifier_testing_fn(request);
}
Expand Down Expand Up @@ -2750,7 +2763,8 @@ AllocationRequest MsaAlgorithm::CreateAllocationRequest(
AliasedOffset* preferred_offset, int64_t definition_time,
bool require_no_copy_alternate_mem_allocation,
const std::vector<int64_t>& all_use_times,
bool only_extend_existing_allocation) {
bool only_extend_existing_allocation,
absl::Span<AllocationValue> processed_allocation_values) {
const HloUse& hlo_use = use.hlo_use;
const auto& instruction_schedule = hlo_live_range_.instruction_schedule();
bool require_copy_allocation = false;
Expand Down Expand Up @@ -2984,6 +2998,7 @@ AllocationRequest MsaAlgorithm::CreateAllocationRequest(

request.end_time = use_time;
request.only_extend_existing_allocation = only_extend_existing_allocation;
request.processed_allocation_values = processed_allocation_values;

return request;
}
Expand Down Expand Up @@ -4483,6 +4498,58 @@ std::string MsaAlgorithm::ResultToString(const AllocationResult& result) {
return result_str;
}

void MsaAlgorithm::CheckAndUpdateForDualLiveAllocationValues(
const std::optional<RequiredMemoryAssignment>&
required_memory_assignment_at_start,
AllocationRequest& request) {
if (!request.allocation_value->requires_contiguous_allocation()) {
return;
}
if (!required_memory_assignment_at_start.has_value()) {
return;
}
if (required_memory_assignment_at_start->memory_space !=
MemorySpace::kAlternate) {
return;
}
// Go through previous allocations, for the same HloValue, and check if they
// have already allocated alternate memory at the beginning of the current
// AllocationValue, such that we are required to use the same heap offset.
std::vector<Allocation*> overlapping_allocations;
Chunk required_chunk = Chunk::FromOffsetSize(
required_memory_assignment_at_start->offset->offset, request.size);
for (const AllocationValue& processed_allocation_value :
request.processed_allocation_values) {
for (const std::unique_ptr<Allocation>& allocation :
*processed_allocation_value.allocation_sequence()) {
if (allocation->is_in_alternate_mem() &&
allocation->start_time() <= request.inclusive_start_time &&
request.inclusive_start_time <= allocation->end_time() &&
allocation->chunk() == required_chunk) {
overlapping_allocations.push_back(allocation.get());
}
}
}
absl::c_sort(overlapping_allocations,
[](const Allocation* a, const Allocation* b) {
return a->start_time() < b->start_time();
});
int64_t chunk_start_time = request.inclusive_start_time;
for (const Allocation* allocation : overlapping_allocations) {
chunk_start_time = std::max(chunk_start_time, allocation->end_time() + 1);
}

// Note, we don't have to set request.preferred_offset, or do anything special
// to handle aliasing. This is done for us. Specifically, before calling
// CheckAndUpdateForDualLiveAllocationValues(), AllocateSegment() inserts a
// PinnedAllocation with no associated heap chunk, at the beginning of
// request.allocation_value. It aliases that PinnedAllocation with any
// overlapping allocations calculated above. In
// AllocateInAlternateMemoryNoCopy(), we will find that PinnedAllocation and
// realize we need to use the same alternate memory offset.
request.no_copy_chunk_inclusive_start_time = chunk_start_time;
}

AllocationResult MsaAlgorithm::AllocateSegment(AllocationRequest& request) {
auto allocation_sequence =
request.allocation_value->mutable_allocation_sequence();
Expand Down Expand Up @@ -4535,17 +4602,19 @@ AllocationResult MsaAlgorithm::AllocateSegment(AllocationRequest& request) {
// we're allowed to prefetch. If the use expects the output to be in default
// memory, we cannot prefetch it because if we did, it would be in alternate
// memory instead.
auto required_assignment_at_start = RequiredMemoryAssignmentAt(
request.allocation_value->value(), request.inclusive_start_time);
std::optional<RequiredMemoryAssignment> required_assignment_at_start =
RequiredMemoryAssignmentAt(request.allocation_value->value(),
request.inclusive_start_time);
std::optional<MemorySpace> required_memory_space_at_start;
if (required_assignment_at_start) {
required_memory_space_at_start = required_assignment_at_start->memory_space;
}
// Find required assignment both for the use and its aliases. If they are both
// non-nullopt, then make sure they require the same assignment.
auto required_assignment_at_end = RequiredMemoryAssignmentAt(
request.allocation_value_to_update->value(), request.end_time);
auto aliased_required_assignment_at_end =
std::optional<RequiredMemoryAssignment> required_assignment_at_end =
RequiredMemoryAssignmentAt(request.allocation_value_to_update->value(),
request.end_time);
std::optional<RequiredMemoryAssignment> aliased_required_assignment_at_end =
AliasedRequiredAssignmentForUse(*request.use);
if (required_assignment_at_end != aliased_required_assignment_at_end) {
if (required_assignment_at_end == std::nullopt) {
Expand Down Expand Up @@ -4600,6 +4669,8 @@ AllocationResult MsaAlgorithm::AllocateSegment(AllocationRequest& request) {
required_memory_space_at_end != MemorySpace::kDefault &&
request.allow_no_copy_alternate_mem_allocation &&
!request.require_copy_allocation) {
CheckAndUpdateForDualLiveAllocationValues(required_assignment_at_start,
request);
allocation_result = AllocateInAlternateMemoryNoCopy(request);
if (allocation_result == AllocationResult::kSuccess) {
return AllocationResult::kSuccess;
Expand Down Expand Up @@ -4996,6 +5067,10 @@ AllocationResult MsaAlgorithm::AllocateInAlternateMemoryNoCopy(
// If there is a previous allocation, set the start time one after the end
// of the previous allocation's end.
alternate_mem_interval.start = prev_allocation->end_time() + 1;
if (request.no_copy_chunk_inclusive_start_time.has_value()) {
alternate_mem_interval.start =
*request.no_copy_chunk_inclusive_start_time;
}
}

if (request.preferred_offset) {
Expand Down
23 changes: 21 additions & 2 deletions xla/service/memory_space_assignment/algorithm.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ limitations under the License.
#include <map>
#include <memory>
#include <optional>
#include <ostream>
#include <set>
#include <string>
#include <tuple>
Expand Down Expand Up @@ -340,7 +341,7 @@ class MsaAlgorithm : public GlobalDecreasingSizeBestFitHeap<HloValue> {
struct RequiredMemoryAssignment {
MemorySpace memory_space;
int64_t time;
AliasedOffset* offset;
AliasedOffset* offset = nullptr;

bool equals_ignoring_time(const RequiredMemoryAssignment& other) const {
return memory_space == other.memory_space && offset == other.offset;
Expand All @@ -354,6 +355,8 @@ class MsaAlgorithm : public GlobalDecreasingSizeBestFitHeap<HloValue> {
bool operator!=(const RequiredMemoryAssignment& other) const {
return !(*this == other);
}

std::string ToString() const;
};

// A struct that contains a pointer to loop-optimized allocation along with
Expand Down Expand Up @@ -622,14 +625,18 @@ class MsaAlgorithm : public GlobalDecreasingSizeBestFitHeap<HloValue> {
// only_extend_existing_allocation is true, no new Allocations will be created
// while processing the resulting AllocationRequest, and we only need to
// extend an existing Allocation's end_time.
//
// * processed_allocation_values: The AllocationValues that have already been
// processed for the same parent HloValue as is used in the request.
AllocationRequest CreateAllocationRequest(
AllocationValue& allocation_value,
AllocationValue& allocation_value_to_update,
const AllocationValue::Use& use, const AllocationValue::Use* previous_use,
AliasedOffset* preferred_offset, int64_t definition_time,
bool require_no_copy_alternate_mem_allocation,
const std::vector<int64_t>& all_use_times,
bool only_extend_existing_allocation);
bool only_extend_existing_allocation,
absl::Span<AllocationValue> processed_allocation_values);

// Returns true, if the allocation value requires a pinned allocation in the
// alternate memory space.
Expand Down Expand Up @@ -658,6 +665,18 @@ class MsaAlgorithm : public GlobalDecreasingSizeBestFitHeap<HloValue> {
absl::StatusOr<AllocationResult> AllocateAllocationValues(
absl::Span<AllocationValue> allocation_values);

// Checks for a situation in which an HloValue has more than one live
// AllocationValue at the same time, and the already processed AllocationValue
// has been given alternate memory at the start of the second AllocationValue.
// If such a case is detected, we set
// request.no_copy_chunk_inclusive_start_time with the time where the first
// AllocationValue left off. AllocateInAlternateMemoryNoCopy() takes advantage
// of that information.
void CheckAndUpdateForDualLiveAllocationValues(
const std::optional<RequiredMemoryAssignment>&
required_memory_assignment_at_start,
AllocationRequest& request);

// Finds an allocation for an allocation request for a segment (see the
// documentation for AllocationRequest above how a segment is defined).
//
Expand Down
11 changes: 11 additions & 0 deletions xla/service/memory_space_assignment/allocation_value.h
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,17 @@ struct AllocationRequest {
// Data structure that contains the options for making window prefetched
// allocations.
const WindowPrefetchedAllocation::Options* window_prefetch_options = nullptr;
// Previously processed AllocationValues, with the same parent HloValue as the
// request.
absl::Span<AllocationValue> processed_allocation_values;
// An optional override starting time for the placement of a chunk on the MSA
// heap, for a no-copy allocation (see
// MsaAlgorithm::AllocateInAlternateMemoryNoCopy() for more details).
//
// Note, this override is used when an aliased AllocationValue has already
// done some of the heap allocation for us. So this request picks up where it
// left off.
std::optional<int64_t> no_copy_chunk_inclusive_start_time;
};

// Result of an allocation, prefetch, eviction etc. request. The result is
Expand Down
Loading