PR #20954: [XLA:GPU] migrate command buffer to use buffer_use.h
Imported from GitHub PR #20954

This PR does not introduce new functionality; it is a refactoring and is covered by the existing command buffer tests.
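For readers skimming the diff below, here is a minimal, self-contained sketch of the shape of the change: call sites move from the old `CommandBufferCmd::BufferUsage` struct, which exposed public `slice` and `access` fields, to `BufferUse` from `//xla/runtime:buffer_use`, which exposes them through `slice()` and `access()` accessors. The types below are simplified stand-ins reconstructed from the call sites in this diff, not the real XLA definitions.

```cpp
// Hypothetical, simplified sketch of the refactoring in this PR. The classes
// here are stand-ins, not the actual XLA types.
#include <cstdint>
#include <iostream>
#include <vector>

namespace sketch {

// Stand-in for BufferAllocation::Slice: just an allocation index.
class Slice {
 public:
  explicit Slice(int64_t index) : index_(index) {}
  int64_t index() const { return index_; }

 private:
  int64_t index_;
};

enum class MemoryAccess { kRead, kWrite };

// Old style (before this PR): a struct with public data members, accessed as
// buffer.slice / buffer.access.
struct BufferUsage {
  Slice slice;
  MemoryAccess access;
};

// New style (after this PR): a class with accessor methods, accessed as
// buffer.slice() / buffer.access(), mirroring xla::BufferUse.
class BufferUse {
 public:
  BufferUse(Slice slice, MemoryAccess access)
      : slice_(slice), access_(access) {}
  const Slice& slice() const { return slice_; }
  MemoryAccess access() const { return access_; }

 private:
  Slice slice_;
  MemoryAccess access_;
};

}  // namespace sketch

int main() {
  using sketch::BufferUse;
  using sketch::BufferUsage;
  using sketch::MemoryAccess;
  using sketch::Slice;

  // Old-style access (pre-PR): public fields.
  BufferUsage old_use{Slice(2), MemoryAccess::kRead};
  std::cout << "old-style alloc " << old_use.slice.index() << "\n";

  // New-style access (post-PR): accessor methods.
  std::vector<BufferUse> buffers = {{Slice(0), MemoryAccess::kRead},
                                    {Slice(1), MemoryAccess::kWrite}};
  for (const BufferUse& buffer : buffers) {
    std::cout << "alloc " << buffer.slice().index() << " is "
              << (buffer.access() == MemoryAccess::kWrite ? "written" : "read")
              << "\n";
  }
  return 0;
}
```

The rest of the diff is largely mechanical: `BufferUsageVector` becomes `BufferUseVector`, and field accesses such as `buffer.slice.index()` and `buffer.access` become `buffer.slice().index()` and `buffer.access()`.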

Copybara import of the project:

--
b1a0efc by Shawn Wang <[email protected]>:

migrate command buffer to use buffer_use.h

Merging this change closes #20954

FUTURE_COPYBARA_INTEGRATE_REVIEW=#20954 from shawnwang18:shawnw/migrate_buffer_use b1a0efc
PiperOrigin-RevId: 711762971
shawnwang18 authored and Google-ML-Automation committed Jan 11, 2025
1 parent 8cfcd2e commit 8e69fb5
Showing 7 changed files with 147 additions and 155 deletions.
5 changes: 4 additions & 1 deletion xla/python/pjit.cc
@@ -399,7 +399,10 @@ void PjitFunction::InitExecutables() {
}
}

PjitFunction::~PjitFunction() = default;
PjitFunction::~PjitFunction() {
nb::ft_object_guard lock(cache_);
executables_ = nullptr;
}

void CallShardArgFallback(
nb::handle arg, nb::handle sharding, nb::handle layout,
4 changes: 4 additions & 0 deletions xla/service/gpu/runtime/BUILD
@@ -71,6 +71,7 @@ cc_library(
"//xla/ffi:ffi_api",
"//xla/ffi/api:c_api",
"//xla/hlo/ir:hlo",
"//xla/runtime:buffer_use",
"//xla/service:buffer_assignment",
"//xla/service:collective_ops_utils",
"//xla/service:computation_placer",
@@ -138,6 +139,7 @@ cc_library(
":wait_for_streams_thunk",
":while_thunk",
"//xla:util",
"//xla/runtime:buffer_use",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
@@ -155,6 +157,7 @@ xla_test(
":command_buffer_cmd",
":thunk",
"//xla:types",
"//xla/runtime:buffer_use",
"//xla/service:buffer_assignment",
"//xla/service:executable",
"//xla/service:platform_util",
@@ -348,6 +351,7 @@ xla_test(
"//xla:shape_util",
"//xla:types",
"//xla:xla_data_proto_cc",
"//xla/runtime:buffer_use",
"//xla/service:buffer_assignment",
"//xla/service:executable",
"//xla/service:platform_util",
119 changes: 61 additions & 58 deletions xla/service/gpu/runtime/command_buffer_cmd.cc
@@ -89,7 +89,7 @@ limitations under the License.
namespace xla::gpu {

using ExecutionScopeId = se::CommandBuffer::ExecutionScopeId;
using MemoryAccess = CommandBufferCmd::MemoryAccess;
using MemoryAccess = BufferUse::MemoryAccess;

std::string CommandBufferCmdString(CommandBufferCmdType type) {
switch (type) {
@@ -195,13 +195,13 @@ CommandBufferCmdSequence::CommandBufferCmdSequence(
: synchronization_mode_(synchronization_mode) {}

void CommandBufferCmdSequence::Append(std::unique_ptr<CommandBufferCmd> cmd) {
for (const CommandBufferCmd::BufferUsage& buffer : cmd->buffers()) {
for (const BufferUse& buffer : cmd->buffers()) {
buffers_.insert(buffer);
allocs_indices_.insert(buffer.slice.index());
allocs_indices_.insert(buffer.slice().index());
}

ExecutionStreamId execution_stream_id = cmd->execution_stream_id();
CommandBufferCmd::BufferUsageVector buffers = cmd->buffers();
CommandBufferCmd::BufferUseVector buffers = cmd->buffers();
bool requires_barrier = HasConflicts(execution_stream_id, buffers);

// Always add barriers between commands if we want to serialize execution.
@@ -254,24 +254,26 @@ bool Overlaps(const BufferAllocation::Slice& slice,

bool CommandBufferCmdSequence::HasConflicts(
ExecutionStreamId execution_stream_id,
const CommandBufferCmd::BufferUsageVector& buffers) {
const CommandBufferCmd::BufferUseVector& buffers) {
auto& rwset = read_write_sets_[execution_stream_id];

return absl::c_any_of(buffers, [&](const auto& buffer) {
return buffer.access == MemoryAccess::kWrite
? Overlaps(buffer.slice, rwset.write) ||
Overlaps(buffer.slice, rwset.read)
: Overlaps(buffer.slice, rwset.write);
return buffer.access() == MemoryAccess::kWrite
? Overlaps(buffer.slice(), rwset.write) ||
Overlaps(buffer.slice(), rwset.read)
: Overlaps(buffer.slice(), rwset.write);
});
}

void CommandBufferCmdSequence::TrackBuffers(
ExecutionStreamId execution_stream_id,
const CommandBufferCmd::BufferUsageVector& buffers) {
const CommandBufferCmd::BufferUseVector& buffers) {
auto& rwset = read_write_sets_[execution_stream_id];
for (const CommandBufferCmd::BufferUsage& buffer : buffers) {
if (buffer.access == MemoryAccess::kWrite) rwset.write.insert(buffer.slice);
if (buffer.access == MemoryAccess::kRead) rwset.read.insert(buffer.slice);
for (const BufferUse& buffer : buffers) {
if (buffer.access() == MemoryAccess::kWrite)
rwset.write.insert(buffer.slice());
if (buffer.access() == MemoryAccess::kRead)
rwset.read.insert(buffer.slice());
}
}

@@ -346,8 +348,8 @@ absl::Status CommandBufferCmdSequence::Record(
return absl::OkStatus();
}

const absl::flat_hash_set<CommandBufferCmd::BufferUsage>&
CommandBufferCmdSequence::buffers() const {
const absl::flat_hash_set<BufferUse>& CommandBufferCmdSequence::buffers()
const {
return buffers_;
}

@@ -369,13 +371,13 @@ std::vector<bool> CommandBufferCmdSequence::barriers() const {

TracedCommandBuffer::TracedCommandBuffer(
const CommandBufferCmd* trace_cmd,
CommandBufferCmd::BufferUsageVector buffers, int64_t capacity)
CommandBufferCmd::BufferUseVector buffers, int64_t capacity)
: trace_cmd_(trace_cmd), capacity_(capacity), entries_(capacity) {
CHECK_GT(capacity, 0) << "capacity must be larger than 0"; // NOLINT
// Collect unique buffer allocation indices in a set first and convert to
// vector as flat hash set iteration has measurable overheads.
absl::flat_hash_set<BufferAllocation::Index> allocs_indices;
for (auto& buffer : buffers) allocs_indices.insert(buffer.slice.index());
for (auto& buffer : buffers) allocs_indices.insert(buffer.slice().index());
allocs_indices_.assign(allocs_indices.begin(), allocs_indices.end());
}

@@ -535,7 +537,7 @@ ComputationIdCmd::ComputationIdCmd(ExecutionStreamId execution_stream_id,
dest_(dest),
kind_(kind) {}

CommandBufferCmd::BufferUsageVector ComputationIdCmd::buffers() {
CommandBufferCmd::BufferUseVector ComputationIdCmd::buffers() {
return {{dest_, MemoryAccess::kWrite}};
}

@@ -674,8 +676,8 @@ absl::Status LaunchCmd::Record(const Thunk::ExecuteParams& execute_params,
dims_.block_counts(), *kernel, *kernel_args);
}

CommandBufferCmd::BufferUsageVector LaunchCmd::buffers() {
BufferUsageVector buffers;
CommandBufferCmd::BufferUseVector LaunchCmd::buffers() {
BufferUseVector buffers;
for (int32_t i = 0; i < args_.size(); ++i) {
buffers.emplace_back(args_[i], args_access_[i]);
}
@@ -746,8 +748,8 @@ absl::Status CustomKernelLaunchCmd::Record(
custom_kernel_.block_dims(), *kernel, kernel_args);
}

CommandBufferCmd::BufferUsageVector CustomKernelLaunchCmd::buffers() {
BufferUsageVector buffers;
CommandBufferCmd::BufferUseVector CustomKernelLaunchCmd::buffers() {
BufferUseVector buffers;
for (int32_t i = 0; i < args_.size(); ++i) {
buffers.emplace_back(args_[i], args_access_[i]);
}
@@ -790,7 +792,7 @@ absl::Status MemcpyDeviceToDeviceCmd::Record(
num_bytes_);
}

CommandBufferCmd::BufferUsageVector MemcpyDeviceToDeviceCmd::buffers() {
CommandBufferCmd::BufferUseVector MemcpyDeviceToDeviceCmd::buffers() {
return {{dst_, MemoryAccess::kWrite}, {src_, MemoryAccess::kRead}};
}

@@ -822,7 +824,7 @@ absl::Status MemzeroCmd::Record(const Thunk::ExecuteParams& execute_params,
/*num_elements=*/dst_.size());
}

CommandBufferCmd::BufferUsageVector MemzeroCmd::buffers() {
CommandBufferCmd::BufferUseVector MemzeroCmd::buffers() {
return {{dst_, MemoryAccess::kWrite}};
}

@@ -857,7 +859,7 @@ absl::Status Memset32Cmd::Record(const Thunk::ExecuteParams& execute_params,
/*num_elements=*/dst_.size() / sizeof(uint32_t));
}

CommandBufferCmd::BufferUsageVector Memset32Cmd::buffers() {
CommandBufferCmd::BufferUseVector Memset32Cmd::buffers() {
return {{dst_, MemoryAccess::kWrite}};
}

@@ -894,8 +896,8 @@ absl::Status IfCmd::Record(const Thunk::ExecuteParams& execute_params,

bool IfCmd::force_update() { return then_commands_.force_update(); }

CommandBufferCmd::BufferUsageVector IfCmd::buffers() {
absl::flat_hash_set<CommandBufferCmd::BufferUsage> buffers;
CommandBufferCmd::BufferUseVector IfCmd::buffers() {
absl::flat_hash_set<BufferUse> buffers;
buffers.emplace(pred_, MemoryAccess::kRead);
buffers.insert(then_commands_.buffers().begin(),
then_commands_.buffers().end());
@@ -942,8 +944,8 @@ bool IfElseCmd::force_update() {
return (then_commands_.force_update() || else_commands_.force_update());
}

CommandBufferCmd::BufferUsageVector IfElseCmd::buffers() {
absl::flat_hash_set<CommandBufferCmd::BufferUsage> buffers;
CommandBufferCmd::BufferUseVector IfElseCmd::buffers() {
absl::flat_hash_set<BufferUse> buffers;
buffers.emplace(pred_, MemoryAccess::kRead);
buffers.insert(then_commands_.buffers().begin(),
then_commands_.buffers().end());
@@ -992,8 +994,8 @@ bool CaseCmd::force_update() {
[](const auto& seq) { return seq.force_update(); });
}

CommandBufferCmd::BufferUsageVector CaseCmd::buffers() {
absl::flat_hash_set<CommandBufferCmd::BufferUsage> buffers;
CommandBufferCmd::BufferUseVector CaseCmd::buffers() {
absl::flat_hash_set<BufferUse> buffers;
buffers.emplace(index_, MemoryAccess::kRead);
for (auto& branch : branches_commands_) {
buffers.insert(branch.buffers().begin(), branch.buffers().end());
@@ -1039,8 +1041,8 @@ absl::Status ForCmd::Record(const Thunk::ExecuteParams& execute_params,

bool ForCmd::force_update() { return body_commands_.force_update(); }

CommandBufferCmd::BufferUsageVector ForCmd::buffers() {
absl::flat_hash_set<CommandBufferCmd::BufferUsage> buffers;
CommandBufferCmd::BufferUseVector ForCmd::buffers() {
absl::flat_hash_set<BufferUse> buffers;
buffers.emplace(loop_counter_, MemoryAccess::kWrite);
buffers.insert(body_commands_.buffers().begin(),
body_commands_.buffers().end());
@@ -1089,8 +1091,8 @@ bool WhileCmd::force_update() {
return (cond_commands_.force_update() || body_commands_.force_update());
}

CommandBufferCmd::BufferUsageVector WhileCmd::buffers() {
absl::flat_hash_set<CommandBufferCmd::BufferUsage> buffers;
CommandBufferCmd::BufferUseVector WhileCmd::buffers() {
absl::flat_hash_set<BufferUse> buffers;
buffers.emplace(pred_, MemoryAccess::kWrite);
buffers.insert(cond_commands_.buffers().begin(),
cond_commands_.buffers().end());
@@ -1152,7 +1154,7 @@ absl::Status GemmCmd::Record(const Thunk::ExecuteParams& execute_params,
});
}

CommandBufferCmd::BufferUsageVector GemmCmd::buffers() {
CommandBufferCmd::BufferUseVector GemmCmd::buffers() {
return {{lhs_buffer_, MemoryAccess::kRead},
{rhs_buffer_, MemoryAccess::kRead},
{output_buffer_, MemoryAccess::kWrite},
@@ -1292,8 +1294,8 @@ absl::Status CublasLtCmd::Record(const Thunk::ExecuteParams& execute_params,
});
}

CommandBufferCmd::BufferUsageVector CublasLtCmd::buffers() {
BufferUsageVector buffer_usage;
CommandBufferCmd::BufferUseVector CublasLtCmd::buffers() {
BufferUseVector buffer_usage;
buffer_usage.reserve(13);
buffer_usage.push_back({a_buffer_, MemoryAccess::kRead});
buffer_usage.push_back({b_buffer_, MemoryAccess::kRead});
@@ -1366,8 +1368,8 @@ absl::Status CuDnnCmd::Record(const Thunk::ExecuteParams& execute_params,
});
}

CommandBufferCmd::BufferUsageVector CuDnnCmd::buffers() {
CommandBufferCmd::BufferUsageVector buffer_usage;
CommandBufferCmd::BufferUseVector CuDnnCmd::buffers() {
CommandBufferCmd::BufferUseVector buffer_usage;
buffer_usage.reserve(args_.size());
for (int i = 0; i < args_.size() - 1; ++i) {
buffer_usage.push_back({args_[i], MemoryAccess::kRead});
@@ -1524,8 +1526,8 @@ absl::Status CustomCallCmd::RecordXlaFfiCall(
*nested_cmd);
}

CommandBufferCmd::BufferUsageVector CustomCallCmd::buffers() {
CommandBufferCmd::BufferUsageVector buffer_usage;
CommandBufferCmd::BufferUseVector CustomCallCmd::buffers() {
CommandBufferCmd::BufferUseVector buffer_usage;
for (auto& slices : {operands_, results_}) {
for (const std::optional<Slice>& slice : slices) {
if (!slice.has_value()) continue;
@@ -1558,7 +1560,7 @@ absl::Status BarrierCmd::Record(const Thunk::ExecuteParams& execute_params,
return absl::OkStatus();
}

BarrierCmd::BufferUsageVector BarrierCmd::buffers() { return {}; }
BarrierCmd::BufferUseVector BarrierCmd::buffers() { return {}; }

//===----------------------------------------------------------------------===//
// CollectiveCmd
Expand Down Expand Up @@ -1676,8 +1678,8 @@ absl::Status AllReduceCmd::Record(const Thunk::ExecuteParams& execute_params,
});
}

CommandBufferCmd::BufferUsageVector AllReduceCmd::buffers() {
BufferUsageVector buffer_usage;
CommandBufferCmd::BufferUseVector AllReduceCmd::buffers() {
BufferUseVector buffer_usage;
for (auto& buffer : buffers_) {
buffer_usage.emplace_back(buffer.source_buffer, MemoryAccess::kRead);
buffer_usage.emplace_back(buffer.destination_buffer, MemoryAccess::kWrite);
@@ -1743,8 +1745,8 @@ absl::Status ReduceScatterCmd::Record(
});
}

CommandBufferCmd::BufferUsageVector ReduceScatterCmd::buffers() {
BufferUsageVector buffer_usage;
CommandBufferCmd::BufferUseVector ReduceScatterCmd::buffers() {
BufferUseVector buffer_usage;
for (auto& buffer : buffers_) {
buffer_usage.emplace_back(buffer.source_buffer, MemoryAccess::kRead);
buffer_usage.emplace_back(buffer.destination_buffer, MemoryAccess::kWrite);
@@ -1807,8 +1809,8 @@ absl::Status AllToAllCmd::Record(const Thunk::ExecuteParams& execute_params,
});
}

CommandBufferCmd::BufferUsageVector AllToAllCmd::buffers() {
BufferUsageVector buffer_usage;
CommandBufferCmd::BufferUseVector AllToAllCmd::buffers() {
BufferUseVector buffer_usage;
for (auto& buffer : buffers_) {
buffer_usage.emplace_back(buffer.source_buffer, MemoryAccess::kRead);
buffer_usage.emplace_back(buffer.destination_buffer, MemoryAccess::kWrite);
@@ -1870,8 +1872,8 @@ absl::Status AllGatherCmd::Record(const Thunk::ExecuteParams& execute_params,
});
}

CommandBufferCmd::BufferUsageVector AllGatherCmd::buffers() {
BufferUsageVector buffer_usage;
CommandBufferCmd::BufferUseVector AllGatherCmd::buffers() {
BufferUseVector buffer_usage;
for (auto& buffer : buffers_) {
buffer_usage.emplace_back(buffer.source_buffer, MemoryAccess::kRead);
buffer_usage.emplace_back(buffer.destination_buffer, MemoryAccess::kWrite);
@@ -1935,8 +1937,8 @@ absl::Status CollectiveBroadcastCmd::Record(
});
}

CommandBufferCmd::BufferUsageVector CollectiveBroadcastCmd::buffers() {
BufferUsageVector buffer_usage;
CommandBufferCmd::BufferUseVector CollectiveBroadcastCmd::buffers() {
BufferUseVector buffer_usage;
for (auto& buffer : buffers_) {
buffer_usage.emplace_back(buffer.source_buffer, MemoryAccess::kRead);
buffer_usage.emplace_back(buffer.destination_buffer, MemoryAccess::kWrite);
@@ -2176,14 +2178,15 @@ absl::Status DynamicSliceFusionCmd::Record(
*nested_command_buffer);
}

CommandBufferCmd::BufferUsageVector DynamicSliceFusionCmd::buffers() {
CommandBufferCmd::BufferUsageVector buffers;
CommandBufferCmd::BufferUseVector DynamicSliceFusionCmd::buffers() {
CommandBufferCmd::BufferUseVector buffers;
auto embed_buffers = embedded_commands_->buffers();
for (auto buffer_usage : embed_buffers) {
CHECK(embeded_to_origin_slice_map_[buffer_usage.slice.index()].has_value());
CHECK(
embeded_to_origin_slice_map_[buffer_usage.slice().index()].has_value());
buffers.emplace_back(
embeded_to_origin_slice_map_[buffer_usage.slice.index()].value(),
buffer_usage.access);
embeded_to_origin_slice_map_[buffer_usage.slice().index()].value(),
buffer_usage.access());
}
return buffers;
}