Skip to content

Commit

Permalink
Remove host memory space as input to HostOffloader constructor.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 714261524
  • Loading branch information
SandSnip3r authored and Google-ML-Automation committed Jan 11, 2025
1 parent 0947e8e commit 03077ca
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 16 deletions.
21 changes: 11 additions & 10 deletions xla/hlo/transforms/host_offloader.cc
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ absl::StatusOr<bool> HostOffloader::WalkDownHostMemoryOffloadPaths(
instruction_and_shape_index.shape_index);
CHECK(output_shape.has_layout())
<< "Expecting output shape of entry computation to have a layout.";
if (output_shape.layout().memory_space() == kHostMemorySpaceColor) {
if (output_shape.layout().memory_space() == Layout::kHostMemorySpace) {
VLOG(2) << absl::StreamFormat(
"Memory offloaded starting from %s is output streamed",
starting_instruction_and_index.ToString());
Expand All @@ -280,7 +280,7 @@ absl::StatusOr<bool> HostOffloader::WalkDownHostMemoryOffloadPaths(
// Finished walking all host memory paths. Now we'll make all the necessary
// changes.
const bool set_buffers_changed = SetBuffersToMemorySpaceColor(
buffers_to_set_to_host_memory, kHostMemorySpaceColor);
buffers_to_set_to_host_memory, Layout::kHostMemorySpace);
changed = changed || set_buffers_changed;

for (HloInstruction* dus : dynamic_update_slices) {
Expand Down Expand Up @@ -349,7 +349,7 @@ absl::StatusOr<bool> HostOffloader::HandleInputStreaming(
entry_computation_layout.parameter_shape(i),
[&](const Shape& subshape, const ShapeIndex& index) {
if (subshape.has_layout() &&
subshape.layout().memory_space() == kHostMemorySpaceColor) {
subshape.layout().memory_space() == Layout::kHostMemorySpace) {
HloInstruction* parameter_instruction =
entry_computation->parameter_instruction(i);
VLOG(1) << "Host parameter streamed into program with shape: "
Expand Down Expand Up @@ -395,7 +395,7 @@ absl::StatusOr<bool> HostOffloader::HandleMoveToHostCustomCall(
HloInstruction* copy_to_host =
data_to_copy->parent()->AddInstruction(HloInstruction::CreateUnary(
data_to_copy->shape(), HloOpcode::kCopy, data_to_copy));
SetMemorySpace(copy_to_host->mutable_shape(), kHostMemorySpaceColor);
SetMemorySpace(copy_to_host->mutable_shape(), Layout::kHostMemorySpace);
TF_RETURN_IF_ERROR(
custom_call_instruction->ReplaceAllUsesWith(copy_to_host));
VLOG(2) << absl::StreamFormat(
Expand Down Expand Up @@ -487,7 +487,7 @@ absl::StatusOr<bool> HostOffloader::InsertCopyBetween(
copy_to_host =
data_to_copy->parent()->AddInstruction(HloInstruction::CreateUnary(
data_to_copy->shape(), HloOpcode::kCopy, data_to_copy));
SetMemorySpace(copy_to_host->mutable_shape(), kHostMemorySpaceColor);
SetMemorySpace(copy_to_host->mutable_shape(), Layout::kHostMemorySpace);
copies_created_after_[data_to_copy] = copy_to_host;
} else {
// We already have a copy which feeds into this instruction.
Expand Down Expand Up @@ -619,7 +619,7 @@ absl::Status HostOffloader::CreateAllocateBufferForDynamicUpdateSlice(
SetMemorySpace(ShapeUtil::GetMutableSubshape(
instruction_and_shape.instruction->mutable_shape(),
instruction_and_shape.shape_index),
kHostMemorySpaceColor);
Layout::kHostMemorySpace);
HloInstruction* instruction = instruction_and_shape.instruction;
if (instruction->opcode() == HloOpcode::kParameter) {
// If this is a parameter of a while_body, we also need to find the
Expand All @@ -645,7 +645,7 @@ absl::Status HostOffloader::CreateAllocateBufferForDynamicUpdateSlice(
SetMemorySpace(ShapeUtil::GetMutableSubshape(
while_condition_parameter->mutable_shape(),
instruction_and_shape.shape_index),
kHostMemorySpaceColor);
Layout::kHostMemorySpace);
// Walk further down the graph and set the memory spaces of all uses
// too. This includes verifying that no compute is done on the buffer.
// Another, better way, to do this, is to walk down the graph starting
Expand All @@ -669,7 +669,7 @@ absl::Status HostOffloader::CreateAllocateBufferForDynamicUpdateSlice(
ShapeUtil::GetMutableSubshape(
nested_instruction_and_shape.instruction->mutable_shape(),
nested_instruction_and_shape.shape_index),
kHostMemorySpaceColor);
Layout::kHostMemorySpace);
TF_ASSIGN_OR_RETURN(
const std::vector<InstructionAndShapeIndex> successors,
host_offload_utils::GetSuccessors(
Expand Down Expand Up @@ -711,7 +711,8 @@ absl::Status HostOffloader::CreateAllocateBufferForDynamicUpdateSlice(
VLOG(1) << absl::StreamFormat(
"Created new AllocateBuffer instruction \"%s\"",
allocate_buffer->ToString());
SetMemorySpace(allocate_buffer->mutable_shape(), kHostMemorySpaceColor);
SetMemorySpace(allocate_buffer->mutable_shape(),
Layout::kHostMemorySpace);
for (int64_t index : operand_indices) {
TF_RETURN_IF_ERROR(
broadcast_user->ReplaceOperandWith(index, allocate_buffer));
Expand Down Expand Up @@ -793,7 +794,7 @@ absl::StatusOr<bool> HostOffloader::ApplySchedulingFix(
continue;
}
if (instruction->shape().layout().memory_space() !=
kHostMemorySpaceColor) {
Layout::kHostMemorySpace) {
continue;
}
// Replace DynamicUpdateSlice's 1st operand with a copy in case it
Expand Down
4 changes: 1 addition & 3 deletions xla/hlo/transforms/host_offloader.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,7 @@ class HloCostAnalysis;
// pass.
class HostOffloader : public HloModulePass {
public:
explicit HostOffloader(int64_t host_memory_space_color)
: kHostMemorySpaceColor(host_memory_space_color) {}
HostOffloader() = default;
~HostOffloader() override = default;

absl::string_view name() const override { return "host-offloader"; }
Expand All @@ -77,7 +76,6 @@ class HostOffloader : public HloModulePass {
// instruction chain) are ignored.
absl::StatusOr<bool> ProcessNextMoveToHostInstr(HloComputation* computation);

const int64_t kHostMemorySpaceColor;
absl::flat_hash_set<HloInstruction*>
already_visited_move_to_host_custom_calls_;
absl::flat_hash_set<HloInstruction*> dynamic_update_slices_already_allocated_;
Expand Down
2 changes: 1 addition & 1 deletion xla/hlo/transforms/host_offloader_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ class HostOffloaderTest : public HloHardwareIndependentTestBase {
after_layout);
TF_ASSIGN_OR_RETURN(bool legal_changed, host_offload_legalize.Run(module));
changed |= legal_changed;
HostOffloader host_offloader(Layout::kHostMemorySpace);
HostOffloader host_offloader;
TF_ASSIGN_OR_RETURN(bool offload_changed, host_offloader.Run(module));
changed |= offload_changed;
return changed;
Expand Down
3 changes: 1 addition & 2 deletions xla/service/gpu/gpu_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1649,8 +1649,7 @@ absl::Status GpuCompiler::OptimizeHloPostLayoutAssignment(
// also have unsorted update_window_dims.
pipeline.AddPass<ScatterSimplifier>();

pipeline.AddPass<HostOffloader>(
static_cast<int64_t>(stream_executor::MemoryType::kHost));
pipeline.AddPass<HostOffloader>();

TF_RETURN_IF_ERROR(
AddConvAndGemmAutotuningPasses(&pipeline, gpu_version, options,
Expand Down

0 comments on commit 03077ca

Please sign in to comment.