Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[xla:pjrt] Add support for forwarding FFI context to C API client #21317

Merged
merged 1 commit into from
Jan 15, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions xla/ffi/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,15 @@ cc_library(
hdrs = ["execution_context.h"],
deps = [
":type_id_registry",
"//xla:util",
"//xla/tsl/platform:errors",
"//xla/tsl/platform:logging",
"//xla/tsl/platform:statusor",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/functional:function_ref",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@tsl//tsl/platform:logging",
"@tsl//tsl/platform:statusor",
],
)

Expand Down
30 changes: 23 additions & 7 deletions xla/ffi/execution_context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,11 @@ limitations under the License.
#include <utility>

#include "absl/container/flat_hash_map.h"
#include "absl/functional/function_ref.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "xla/tsl/platform/errors.h"
#include "xla/util.h"

namespace xla::ffi {

Expand All @@ -44,9 +46,9 @@ absl::Status ExecutionContext::InsertUserData(TypeId type_id,

auto emplaced = user_data_.emplace(type_id, std::move(data));
if (!emplaced.second) {
return absl::AlreadyExistsError(
absl::StrCat("User data with type id ", type_id.value(),
" already exists in execution context"));
return Internal(
"User data with type id %d already exists in execution context",
type_id.value());
}
return absl::OkStatus();
}
Expand All @@ -55,11 +57,25 @@ absl::StatusOr<ExecutionContext::UserData*> ExecutionContext::LookupUserData(
TypeId type_id) const {
auto it = user_data_.find(type_id);
if (it == user_data_.end()) {
return absl::NotFoundError(absl::StrCat("User data with type id ",
type_id.value(),
" not found in execution context"));
return NotFound("User data with type id %d not found in execution context",
type_id.value());
}
return it->second.get();
}

void ExecutionContext::ForEach(
absl::FunctionRef<void(TypeId type_id, void* data)> fn) const {
for (auto& [type_id, user_data] : user_data_) {
fn(type_id, user_data->data());
}
}

absl::Status ExecutionContext::ForEachWithStatus(
absl::FunctionRef<absl::Status(TypeId type_id, void* data)> fn) const {
for (auto& [type_id, user_data] : user_data_) {
TF_RETURN_IF_ERROR(fn(type_id, user_data->data()));
}
return absl::OkStatus();
}

} // namespace xla::ffi
10 changes: 8 additions & 2 deletions xla/ffi/execution_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,12 @@ limitations under the License.
#include <utility>

#include "absl/container/flat_hash_map.h"
#include "absl/functional/function_ref.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "xla/ffi/type_id_registry.h"
#include "tsl/platform/logging.h"
#include "tsl/platform/statusor.h"
#include "xla/tsl/platform/logging.h"
#include "xla/tsl/platform/statusor.h"

namespace xla::ffi {

Expand Down Expand Up @@ -76,6 +77,11 @@ class ExecutionContext {
return user_data->data();
}

// Visit all user data in the execution context.
void ForEach(absl::FunctionRef<void(TypeId type_id, void* data)> fn) const;
absl::Status ForEachWithStatus(
absl::FunctionRef<absl::Status(TypeId type_id, void* data)> fn) const;

private:
// An RAII wrapper for opaque user data. Optional deleter will be called when
// UserData is destroyed together with the execution context. If deleter is
Expand Down
14 changes: 11 additions & 3 deletions xla/pjrt/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -800,10 +800,12 @@ cc_library(
"//xla:util",
"//xla:xla_data_proto_cc",
"//xla:xla_proto_cc",
"//xla/ffi:execution_context",
"//xla/hlo/builder:xla_computation",
"//xla/hlo/ir:hlo",
"//xla/hlo/translate/mhlo_to_hlo:mlir_hlo_to_hlo",
"//xla/mlir_hlo:mhlo_passes",
"//xla/pjrt/c:pjrt_c_api_ffi_extension_hdrs",
"//xla/pjrt/c:pjrt_c_api_hdrs",
"//xla/pjrt/c:pjrt_c_api_helpers",
"//xla/pjrt/c:pjrt_c_api_layouts_extension_hdrs",
Expand All @@ -815,7 +817,9 @@ cc_library(
"//xla/service:hlo_cost_analysis",
"//xla/service:hlo_proto_cc",
"//xla/tsl/framework:allocator",
"//xla/tsl/platform:errors",
"//xla/tsl/platform:status",
"//xla/tsl/platform:statusor",
"@com_google_absl//absl/cleanup",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:inlined_vector",
Expand All @@ -833,10 +837,7 @@ cc_library(
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support",
"@tsl//tsl/platform:casts",
"@tsl//tsl/platform:errors",
"@tsl//tsl/platform:fingerprint",
"@tsl//tsl/platform:status",
"@tsl//tsl/platform:statusor",
],
)

Expand All @@ -855,13 +856,20 @@ xla_cc_test(
":pjrt_device_description",
":pjrt_executable",
"//xla:cpu_function_runtime",
"//xla:literal",
"//xla:literal_util",
"//xla:shape_util",
"//xla/ffi",
"//xla/ffi:ffi_api",
"//xla/hlo/builder:xla_builder",
"//xla/hlo/builder:xla_computation",
"//xla/hlo/parser:hlo_parser",
"//xla/pjrt/c:pjrt_c_api_cpu_internal",
"//xla/pjrt/c:pjrt_c_api_hdrs",
"//xla/tests:literal_test_util",
"//xla/tsl/lib/core:status_test_util",
"//xla/tsl/platform:statusor",
"//xla/tsl/platform:test",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/types:span",
Expand Down
4 changes: 4 additions & 0 deletions xla/pjrt/c/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -227,11 +227,15 @@ cc_library(
hdrs = ["pjrt_c_api_cpu_internal.h"],
visibility = ["//visibility:public"],
deps = [
":pjrt_c_api_ffi_extension_hdrs",
":pjrt_c_api_ffi_internal",
":pjrt_c_api_hdrs",
":pjrt_c_api_helpers",
":pjrt_c_api_layouts_extension_hdrs",
":pjrt_c_api_memory_descriptions_extension_hdrs",
":pjrt_c_api_wrapper_impl",
"//xla/pjrt:pjrt_client",
"//xla/pjrt:pjrt_executable",
"//xla/pjrt/plugin/xla_cpu:cpu_client_options",
"//xla/pjrt/plugin/xla_cpu:xla_cpu_pjrt_client",
"@com_google_absl//absl/status",
Expand Down
4 changes: 4 additions & 0 deletions xla/pjrt/c/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
# PJRT C API changelog

## 0.64
* Added ``context`` field of type ``PJRT_ExecuteContext *`` in ``PJRT_ExecuteOptions``.

## 0.63
* Added types F4E2M1FN and F8E8M0FNU.

Expand Down
5 changes: 3 additions & 2 deletions xla/pjrt/c/pjrt_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ PJRT_DEFINE_STRUCT_TRAITS(PJRT_Extension_Base, next);
// Changes include:
// * Adding a new field to the PJRT_Api or argument structs
// * Renaming a method or argument (doesn't affect ABI)
#define PJRT_API_MINOR 63
#define PJRT_API_MINOR 64

// The plugin should set the major_version and minor_version of
// PJRT_Api.pjrt_api_version to be the `PJRT_API_MAJOR` and `PJRT_API_MINOR` in
Expand Down Expand Up @@ -1535,8 +1535,9 @@ struct PJRT_ExecuteOptions {
// during the call.
const int64_t* non_donatable_input_indices;
size_t num_non_donatable_input_indices;
PJRT_ExecuteContext* context;
};
PJRT_DEFINE_STRUCT_TRAITS(PJRT_ExecuteOptions, num_non_donatable_input_indices);
PJRT_DEFINE_STRUCT_TRAITS(PJRT_ExecuteOptions, context);

struct PJRT_LoadedExecutable_Execute_Args {
size_t struct_size;
Expand Down
17 changes: 14 additions & 3 deletions xla/pjrt/c/pjrt_c_api_cpu_internal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,14 @@ limitations under the License.

#include "absl/status/status.h"
#include "xla/pjrt/c/pjrt_c_api.h"
#include "xla/pjrt/c/pjrt_c_api_ffi_extension.h"
#include "xla/pjrt/c/pjrt_c_api_ffi_internal.h"
#include "xla/pjrt/c/pjrt_c_api_helpers.h"
#include "xla/pjrt/c/pjrt_c_api_layouts_extension.h"
#include "xla/pjrt/c/pjrt_c_api_memory_descriptions_extension.h"
#include "xla/pjrt/c/pjrt_c_api_wrapper_impl.h"
#include "xla/pjrt/pjrt_client.h"
#include "xla/pjrt/pjrt_executable.h"
#include "xla/pjrt/plugin/xla_cpu/cpu_client_options.h"
#include "xla/pjrt/plugin/xla_cpu/xla_cpu_pjrt_client.h"

Expand All @@ -46,8 +50,12 @@ PJRT_Error* PJRT_Client_Create(PJRT_Client_Create_Args* args) {
}

PJRT_Error* PJRT_ExecuteContext_Create(PJRT_ExecuteContext_Create_Args* args) {
return new PJRT_Error{absl::UnimplementedError(
"ExecuteContext not supported for CPU execution.")};
PJRT_RETURN_IF_ERROR(ActualStructSizeIsGreaterOrEqual(
"PJRT_ExecuteContext_Create_Args",
PJRT_ExecuteContext_Create_Args_STRUCT_SIZE, args->struct_size));
auto execute_context = std::make_unique<xla::ExecuteContext>();
args->context = pjrt::CreateWrapperExecuteContext(std::move(execute_context));
return nullptr;
}

PJRT_Error* PJRT_CpuDeviceTopology_Create(
Expand All @@ -64,12 +72,15 @@ const PJRT_Api* GetCpuPjrtApi() {
pjrt::CreateMemoryDescriptionsExtension(
reinterpret_cast<PJRT_Extension_Base*>(&layouts_extension));

static PJRT_FFI_Extension ffi_extension = pjrt::CreateFfiExtension(
reinterpret_cast<PJRT_Extension_Base*>(&memory_descriptions_extension));

static const PJRT_Api pjrt_api = pjrt::CreatePjrtApi(
pjrt::cpu_plugin::PJRT_Client_Create,
pjrt::cpu_plugin::PJRT_ExecuteContext_Create,
pjrt::cpu_plugin::PJRT_CpuDeviceTopology_Create,
pjrt::PJRT_Plugin_Initialize_NoOp,
reinterpret_cast<PJRT_Extension_Base*>(&memory_descriptions_extension),
reinterpret_cast<PJRT_Extension_Base*>(&ffi_extension),
pjrt::PJRT_Plugin_Attributes_Xla);

return &pjrt_api;
Expand Down
4 changes: 3 additions & 1 deletion xla/pjrt/c/pjrt_c_api_wrapper_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1640,7 +1640,9 @@ PJRT_Error* PJRT_LoadedExecutable_Execute(
options.strict_shape_checking = true;
options.arguments_are_tupled = false;
options.untuple_result = true;
options.context = nullptr;
options.context = args->options->context
? args->options->context->execute_context.get()
: nullptr;
options.multi_slice_config = nullptr;
options.use_major_to_minor_data_layout_for_callbacks = true;
if (args->options->num_non_donatable_input_indices > 0) {
Expand Down
67 changes: 58 additions & 9 deletions xla/pjrt/pjrt_c_api_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ limitations under the License.
#include "mlir/IR/OwningOpRef.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/LogicalResult.h"
#include "xla/ffi/execution_context.h"
#include "xla/hlo/builder/xla_computation.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.h"
Expand All @@ -50,6 +51,7 @@ limitations under the License.
#include "xla/literal.h"
#include "xla/mlir_hlo/mhlo/transforms/passes.h"
#include "xla/pjrt/c/pjrt_c_api.h"
#include "xla/pjrt/c/pjrt_c_api_ffi_extension.h"
#include "xla/pjrt/c/pjrt_c_api_helpers.h"
#include "xla/pjrt/c/pjrt_c_api_layouts_extension.h"
#include "xla/pjrt/c/pjrt_c_api_memory_descriptions_extension.h"
Expand All @@ -71,15 +73,14 @@ limitations under the License.
#include "xla/shape.h"
#include "xla/shape_util.h"
#include "xla/tsl/framework/allocator.h"
#include "xla/tsl/platform/errors.h"
#include "xla/tsl/platform/status.h"
#include "xla/tsl/platform/statusor.h"
#include "xla/util.h"
#include "xla/xla.pb.h"
#include "xla/xla_data.pb.h"
#include "tsl/platform/casts.h"
#include "tsl/platform/errors.h"
#include "tsl/platform/fingerprint.h"
#include "tsl/platform/status.h"
#include "tsl/platform/statusor.h"

namespace xla {

Expand Down Expand Up @@ -1826,6 +1827,42 @@ PjRtCApiLoadedExecutable::GetCommonExecuteArgs(
return args;
}

static absl::StatusOr<PJRT_ExecuteContext*> ForwardExecuteContext(
const PJRT_Api* c_api, const ExecuteContext* context) {
// If the execute context is null, we don't have anything to forward.
if (context == nullptr) return nullptr;

// If we can't find the FFI extension, we can't forward anything from the
// execute context to the C API.
PJRT_FFI_Extension* ffi_extension = pjrt::FindExtension<PJRT_FFI_Extension>(
c_api, PJRT_Extension_Type::PJRT_Extension_Type_FFI);
if (ffi_extension == nullptr) return nullptr;

// Create a new instance of the PJRT_ExecuteContext.
PJRT_ExecuteContext_Create_Args create_args = {
PJRT_ExecuteContext_Create_Args_STRUCT_SIZE, nullptr, nullptr};
RETURN_STATUS_IF_PJRT_ERROR(c_api->PJRT_ExecuteContext_Create(&create_args),
c_api);

// Forward FFI user data to the C API execute context.
using TypeId = ffi::ExecutionContext::TypeId;
auto forward_user_data = [&](TypeId type_id, void* data) -> absl::Status {
PJRT_FFI_UserData_Add_Args add_args{
PJRT_FFI_UserData_Add_Args_STRUCT_SIZE,
nullptr,
create_args.context,
PJRT_FFI_UserData{type_id.value(), data, /*deleter=*/nullptr},
};
RETURN_STATUS_IF_PJRT_ERROR(ffi_extension->user_data_add(&add_args), c_api);
return absl::OkStatus();
};

TF_RETURN_IF_ERROR(
context->ffi_context().ForEachWithStatus(forward_user_data));

return create_args.context;
}

absl::StatusOr<std::vector<std::vector<std::unique_ptr<PjRtBuffer>>>>
PjRtCApiLoadedExecutable::Execute(
absl::Span<const std::vector<PjRtBuffer*>> argument_handles,
Expand All @@ -1835,15 +1872,28 @@ PjRtCApiLoadedExecutable::Execute(
std::vector<std::vector<PJRT_Buffer*>> c_output_lists_storage;
std::vector<PJRT_Buffer**> c_output_lists;
std::vector<int64_t> non_donatable_input_indices_storage;
PJRT_ExecuteOptions c_options;
c_options.num_send_ops = 0;
c_options.num_recv_ops = 0;
std::vector<PJRT_Buffer**> c_arguments;
std::optional<std::vector<PJRT_Event*>> device_complete_events;
if (returned_futures.has_value()) {
device_complete_events.emplace();
}

PJRT_ExecuteOptions c_options = {PJRT_ExecuteOptions_STRUCT_SIZE, nullptr};
TF_ASSIGN_OR_RETURN(c_options.context,
ForwardExecuteContext(pjrt_c_api(), options.context));

// Don't forget to destroy execute context if we created it.
auto destroy_context = absl::MakeCleanup([&]() {
if (c_options.context != nullptr) {
PJRT_ExecuteContext_Destroy_Args destroy_args = {
PJRT_ExecuteContext_Destroy_Args_STRUCT_SIZE, nullptr,
c_options.context};
pjrt::LogFatalIfPjrtError(
pjrt_c_api()->PJRT_ExecuteContext_Destroy(&destroy_args),
pjrt_c_api());
}
});

auto callback_data = std::make_shared<SendRecvCallbackData>();
TF_ASSIGN_OR_RETURN(
PJRT_LoadedExecutable_Execute_Args args,
Expand Down Expand Up @@ -1907,16 +1957,15 @@ PjRtCApiLoadedExecutable::ExecuteWithSingleDevice(
std::vector<std::vector<PJRT_Buffer*>> c_output_lists_storage;
std::vector<PJRT_Buffer**> c_output_lists;
std::vector<int64_t> non_donatable_input_indices_storage;
PJRT_ExecuteOptions c_options;
c_options.num_send_ops = 0;
c_options.num_recv_ops = 0;
std::vector<PJRT_Buffer**> c_arguments;
std::optional<std::vector<PJRT_Event*>> device_complete_events;
if (fill_future) {
device_complete_events.emplace();
}

auto callback_data = std::make_shared<SendRecvCallbackData>();

PJRT_ExecuteOptions c_options = {PJRT_ExecuteOptions_STRUCT_SIZE, nullptr};
TF_ASSIGN_OR_RETURN(
PJRT_LoadedExecutable_Execute_Args args,
GetCommonExecuteArgs(argument_handles_vec, options, c_options,
Expand Down
Loading
Loading