diff --git a/xla/hlo/transforms/host_offloader.cc b/xla/hlo/transforms/host_offloader.cc index 9255f0a6d8870..c05659eb5bb21 100644 --- a/xla/hlo/transforms/host_offloader.cc +++ b/xla/hlo/transforms/host_offloader.cc @@ -255,7 +255,7 @@ absl::StatusOr HostOffloader::WalkDownHostMemoryOffloadPaths( instruction_and_shape_index.shape_index); CHECK(output_shape.has_layout()) << "Expecting output shape of entry computation to have a layout."; - if (output_shape.layout().memory_space() == kHostMemorySpaceColor) { + if (output_shape.layout().memory_space() == Layout::kHostMemorySpace) { VLOG(2) << absl::StreamFormat( "Memory offloaded starting from %s is output streamed", starting_instruction_and_index.ToString()); @@ -280,7 +280,7 @@ absl::StatusOr HostOffloader::WalkDownHostMemoryOffloadPaths( // Finished walking all host memory paths. Now we'll make all the necessary // changes. const bool set_buffers_changed = SetBuffersToMemorySpaceColor( - buffers_to_set_to_host_memory, kHostMemorySpaceColor); + buffers_to_set_to_host_memory, Layout::kHostMemorySpace); changed = changed || set_buffers_changed; for (HloInstruction* dus : dynamic_update_slices) { @@ -349,7 +349,7 @@ absl::StatusOr HostOffloader::HandleInputStreaming( entry_computation_layout.parameter_shape(i), [&](const Shape& subshape, const ShapeIndex& index) { if (subshape.has_layout() && - subshape.layout().memory_space() == kHostMemorySpaceColor) { + subshape.layout().memory_space() == Layout::kHostMemorySpace) { HloInstruction* parameter_instruction = entry_computation->parameter_instruction(i); VLOG(1) << "Host parameter streamed into program with shape: " @@ -395,7 +395,7 @@ absl::StatusOr HostOffloader::HandleMoveToHostCustomCall( HloInstruction* copy_to_host = data_to_copy->parent()->AddInstruction(HloInstruction::CreateUnary( data_to_copy->shape(), HloOpcode::kCopy, data_to_copy)); - SetMemorySpace(copy_to_host->mutable_shape(), kHostMemorySpaceColor); + SetMemorySpace(copy_to_host->mutable_shape(), Layout::kHostMemorySpace); TF_RETURN_IF_ERROR( custom_call_instruction->ReplaceAllUsesWith(copy_to_host)); VLOG(2) << absl::StreamFormat( @@ -487,7 +487,7 @@ absl::StatusOr HostOffloader::InsertCopyBetween( copy_to_host = data_to_copy->parent()->AddInstruction(HloInstruction::CreateUnary( data_to_copy->shape(), HloOpcode::kCopy, data_to_copy)); - SetMemorySpace(copy_to_host->mutable_shape(), kHostMemorySpaceColor); + SetMemorySpace(copy_to_host->mutable_shape(), Layout::kHostMemorySpace); copies_created_after_[data_to_copy] = copy_to_host; } else { // We already have a copy which feeds into this instruction. @@ -619,7 +619,7 @@ absl::Status HostOffloader::CreateAllocateBufferForDynamicUpdateSlice( SetMemorySpace(ShapeUtil::GetMutableSubshape( instruction_and_shape.instruction->mutable_shape(), instruction_and_shape.shape_index), - kHostMemorySpaceColor); + Layout::kHostMemorySpace); HloInstruction* instruction = instruction_and_shape.instruction; if (instruction->opcode() == HloOpcode::kParameter) { // If this is a parameter of a while_body, we also need to find the @@ -645,7 +645,7 @@ absl::Status HostOffloader::CreateAllocateBufferForDynamicUpdateSlice( SetMemorySpace(ShapeUtil::GetMutableSubshape( while_condition_parameter->mutable_shape(), instruction_and_shape.shape_index), - kHostMemorySpaceColor); + Layout::kHostMemorySpace); // Walk further down the graph and set the memory spaces of all uses // too. This includes verifying that no compute is done on the buffer. // Another, better way, to do this, is to walk down the graph starting @@ -669,7 +669,7 @@ absl::Status HostOffloader::CreateAllocateBufferForDynamicUpdateSlice( ShapeUtil::GetMutableSubshape( nested_instruction_and_shape.instruction->mutable_shape(), nested_instruction_and_shape.shape_index), - kHostMemorySpaceColor); + Layout::kHostMemorySpace); TF_ASSIGN_OR_RETURN( const std::vector successors, host_offload_utils::GetSuccessors( @@ -711,7 +711,8 @@ absl::Status HostOffloader::CreateAllocateBufferForDynamicUpdateSlice( VLOG(1) << absl::StreamFormat( "Created new AllocateBuffer instruction \"%s\"", allocate_buffer->ToString()); - SetMemorySpace(allocate_buffer->mutable_shape(), kHostMemorySpaceColor); + SetMemorySpace(allocate_buffer->mutable_shape(), + Layout::kHostMemorySpace); for (int64_t index : operand_indices) { TF_RETURN_IF_ERROR( broadcast_user->ReplaceOperandWith(index, allocate_buffer)); @@ -793,7 +794,7 @@ absl::StatusOr HostOffloader::ApplySchedulingFix( continue; } if (instruction->shape().layout().memory_space() != - kHostMemorySpaceColor) { + Layout::kHostMemorySpace) { continue; } // Replace DynamicUpdateSlice's 1st operand with a copy in case it diff --git a/xla/hlo/transforms/host_offloader.h b/xla/hlo/transforms/host_offloader.h index 8e79a44926178..5055aa15f10a8 100644 --- a/xla/hlo/transforms/host_offloader.h +++ b/xla/hlo/transforms/host_offloader.h @@ -59,8 +59,7 @@ class HloCostAnalysis; // pass. class HostOffloader : public HloModulePass { public: - explicit HostOffloader(int64_t host_memory_space_color) - : kHostMemorySpaceColor(host_memory_space_color) {} + HostOffloader() = default; ~HostOffloader() override = default; absl::string_view name() const override { return "host-offloader"; } @@ -77,7 +76,6 @@ class HostOffloader : public HloModulePass { // instruction chain) are ignored. absl::StatusOr ProcessNextMoveToHostInstr(HloComputation* computation); - const int64_t kHostMemorySpaceColor; absl::flat_hash_set already_visited_move_to_host_custom_calls_; absl::flat_hash_set dynamic_update_slices_already_allocated_; diff --git a/xla/hlo/transforms/host_offloader_test.cc b/xla/hlo/transforms/host_offloader_test.cc index d38526e93178a..9eff4508838fd 100644 --- a/xla/hlo/transforms/host_offloader_test.cc +++ b/xla/hlo/transforms/host_offloader_test.cc @@ -63,7 +63,7 @@ class HostOffloaderTest : public HloHardwareIndependentTestBase { after_layout); TF_ASSIGN_OR_RETURN(bool legal_changed, host_offload_legalize.Run(module)); changed |= legal_changed; - HostOffloader host_offloader(Layout::kHostMemorySpace); + HostOffloader host_offloader; TF_ASSIGN_OR_RETURN(bool offload_changed, host_offloader.Run(module)); changed |= offload_changed; return changed; diff --git a/xla/service/gpu/gpu_compiler.cc b/xla/service/gpu/gpu_compiler.cc index 5c6a5ab6ca172..bb720566717bf 100755 --- a/xla/service/gpu/gpu_compiler.cc +++ b/xla/service/gpu/gpu_compiler.cc @@ -1649,8 +1649,7 @@ absl::Status GpuCompiler::OptimizeHloPostLayoutAssignment( // also have unsorted update_window_dims. pipeline.AddPass(); - pipeline.AddPass( - static_cast(stream_executor::MemoryType::kHost)); + pipeline.AddPass(); TF_RETURN_IF_ERROR( AddConvAndGemmAutotuningPasses(&pipeline, gpu_version, options,