From 7c9cc77af80b620401d92bcd60bd6e75ff413cd8 Mon Sep 17 00:00:00 2001 From: Sizhi Tan Date: Thu, 2 Jan 2025 09:47:48 -0800 Subject: [PATCH] [PJRT:C] Implement PJRT_AsyncHostToDeviceTransferManager class. Introduce more of its member function to C Api. PiperOrigin-RevId: 711450094 --- xla/pjrt/c/BUILD | 3 + xla/pjrt/c/CHANGELOG.md | 2 + xla/pjrt/c/pjrt_c_api.h | 85 ++++++++++- xla/pjrt/c/pjrt_c_api_gpu_test.cc | 195 +++++++++++++++++++------- xla/pjrt/c/pjrt_c_api_helpers.cc | 27 ++++ xla/pjrt/c/pjrt_c_api_helpers.h | 4 + xla/pjrt/c/pjrt_c_api_test.cc | 42 +++++- xla/pjrt/c/pjrt_c_api_test_base.cc | 50 +++++++ xla/pjrt/c/pjrt_c_api_test_base.h | 10 ++ xla/pjrt/c/pjrt_c_api_wrapper_impl.cc | 95 +++++++++++++ xla/pjrt/c/pjrt_c_api_wrapper_impl.h | 12 ++ xla/pjrt/pjrt_c_api_client.cc | 144 ++++++++++--------- 12 files changed, 546 insertions(+), 123 deletions(-) diff --git a/xla/pjrt/c/BUILD b/xla/pjrt/c/BUILD index 01c1b03e3cb57..ef209e6453247 100644 --- a/xla/pjrt/c/BUILD +++ b/xla/pjrt/c/BUILD @@ -203,6 +203,7 @@ cc_library( "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/functional:any_invocable", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", @@ -371,6 +372,7 @@ cc_library( "//xla/pjrt:pjrt_executable", "//xla/pjrt:pjrt_future", "//xla/service:computation_placer_hdr", + "//xla/tsl/platform:status", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/types:span", @@ -416,6 +418,7 @@ xla_test( "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", "@com_google_googletest//:gtest_main", "@tsl//tsl/platform:platform_port", "@tsl//tsl/platform:status", diff --git a/xla/pjrt/c/CHANGELOG.md b/xla/pjrt/c/CHANGELOG.md index d56741eb3500b..a5509ac25ebd4 100644 --- a/xla/pjrt/c/CHANGELOG.md +++ b/xla/pjrt/c/CHANGELOG.md @@ -1,4 +1,6 @@ # PJRT C API changelog +## 0.62 +* Added more member functions for ``PJRT_AsyncHostToDeviceTransferManager``. ## 0.61 * Added ``PJRT_KeyValueTryGet`` to the KV store interface, diff --git a/xla/pjrt/c/pjrt_c_api.h b/xla/pjrt/c/pjrt_c_api.h index f2fc3b1c507a3..0df289b0aecf4 100644 --- a/xla/pjrt/c/pjrt_c_api.h +++ b/xla/pjrt/c/pjrt_c_api.h @@ -80,7 +80,7 @@ PJRT_DEFINE_STRUCT_TRAITS(PJRT_Extension_Base, next); // Changes include: // * Adding a new field to the PJRT_Api or argument structs // * Renaming a method or argument (doesn't affect ABI) -#define PJRT_API_MINOR 61 +#define PJRT_API_MINOR 62 // The plugin should set the major_version and minor_version of // PJRT_Api.pjrt_api_version to be the `PJRT_API_MAJOR` and `PJRT_API_MINOR` in @@ -661,6 +661,79 @@ PJRT_DEFINE_STRUCT_TRAITS( typedef PJRT_Error* PJRT_AsyncHostToDeviceTransferManager_TransferData( PJRT_AsyncHostToDeviceTransferManager_TransferData_Args* args); +struct PJRT_AsyncHostToDeviceTransferManager_RetrieveBuffer_Args { + size_t struct_size; + PJRT_Extension_Base* extension_start; + PJRT_AsyncHostToDeviceTransferManager* transfer_manager; + int buffer_index; + PJRT_Buffer* buffer_out; // out +}; +PJRT_DEFINE_STRUCT_TRAITS( + PJRT_AsyncHostToDeviceTransferManager_RetrieveBuffer_Args, buffer_out); +typedef PJRT_Error* PJRT_AsyncHostToDeviceTransferManager_RetrieveBuffer( + PJRT_AsyncHostToDeviceTransferManager_RetrieveBuffer_Args* args); + +struct PJRT_AsyncHostToDeviceTransferManager_Device_Args { + size_t struct_size; + PJRT_Extension_Base* extension_start; + PJRT_AsyncHostToDeviceTransferManager* transfer_manager; + PJRT_Device* device_out; // out +}; +PJRT_DEFINE_STRUCT_TRAITS(PJRT_AsyncHostToDeviceTransferManager_Device_Args, + device_out); +typedef PJRT_Error* PJRT_AsyncHostToDeviceTransferManager_Device( + PJRT_AsyncHostToDeviceTransferManager_Device_Args* args); + +struct PJRT_AsyncHostToDeviceTransferManager_BufferCount_Args { + size_t struct_size; + PJRT_Extension_Base* extension_start; + PJRT_AsyncHostToDeviceTransferManager* transfer_manager; + size_t buffer_count; // out +}; +PJRT_DEFINE_STRUCT_TRAITS( + PJRT_AsyncHostToDeviceTransferManager_BufferCount_Args, buffer_count); +typedef PJRT_Error* PJRT_AsyncHostToDeviceTransferManager_BufferCount( + PJRT_AsyncHostToDeviceTransferManager_BufferCount_Args* args); + +struct PJRT_AsyncHostToDeviceTransferManager_BufferSize_Args { + size_t struct_size; + PJRT_Extension_Base* extension_start; + PJRT_AsyncHostToDeviceTransferManager* transfer_manager; + int buffer_index; + size_t buffer_size; // out +}; +PJRT_DEFINE_STRUCT_TRAITS(PJRT_AsyncHostToDeviceTransferManager_BufferSize_Args, + buffer_size); +typedef PJRT_Error* PJRT_AsyncHostToDeviceTransferManager_BufferSize( + PJRT_AsyncHostToDeviceTransferManager_BufferSize_Args* args); + +struct PJRT_AsyncHostToDeviceTransferManager_SetBufferError_Args { + size_t struct_size; + PJRT_Extension_Base* extension_start; + PJRT_AsyncHostToDeviceTransferManager* transfer_manager; + int buffer_index; + PJRT_Error_Code error_code; + const char* error_message; + size_t error_message_size; +}; +PJRT_DEFINE_STRUCT_TRAITS( + PJRT_AsyncHostToDeviceTransferManager_SetBufferError_Args, + error_message_size); +typedef PJRT_Error* PJRT_AsyncHostToDeviceTransferManager_SetBufferError( + PJRT_AsyncHostToDeviceTransferManager_SetBufferError_Args* args); + +struct PJRT_AsyncHostToDeviceTransferManager_AddMetadata_Args { + size_t struct_size; + PJRT_Extension_Base* extension_start; + PJRT_AsyncHostToDeviceTransferManager* transfer_manager; + const PJRT_NamedValue* transfer_metadata; + size_t num_metadata; +}; +PJRT_DEFINE_STRUCT_TRAITS( + PJRT_AsyncHostToDeviceTransferManager_AddMetadata_Args, num_metadata); +typedef PJRT_Error* PJRT_AsyncHostToDeviceTransferManager_AddMetadata( + PJRT_AsyncHostToDeviceTransferManager_AddMetadata_Args* args); + typedef enum { // Invalid primitive type to serve as default. PJRT_Buffer_Type_INVALID, @@ -2362,11 +2435,17 @@ typedef struct PJRT_Api { _PJRT_API_STRUCT_FIELD(PJRT_AsyncHostToDeviceTransferManager_Destroy); _PJRT_API_STRUCT_FIELD(PJRT_AsyncHostToDeviceTransferManager_TransferData); _PJRT_API_STRUCT_FIELD(PJRT_Client_CreateBuffersForAsyncHostToDevice); + _PJRT_API_STRUCT_FIELD(PJRT_AsyncHostToDeviceTransferManager_RetrieveBuffer); + _PJRT_API_STRUCT_FIELD(PJRT_AsyncHostToDeviceTransferManager_Device); + _PJRT_API_STRUCT_FIELD(PJRT_AsyncHostToDeviceTransferManager_BufferCount); + _PJRT_API_STRUCT_FIELD(PJRT_AsyncHostToDeviceTransferManager_BufferSize); + _PJRT_API_STRUCT_FIELD(PJRT_AsyncHostToDeviceTransferManager_SetBufferError); + _PJRT_API_STRUCT_FIELD(PJRT_AsyncHostToDeviceTransferManager_AddMetadata); } PJRT_Api; enum { - PJRT_Api_STRUCT_SIZE = - PJRT_STRUCT_SIZE(PJRT_Api, PJRT_Client_CreateBuffersForAsyncHostToDevice) + PJRT_Api_STRUCT_SIZE = PJRT_STRUCT_SIZE( + PJRT_Api, PJRT_AsyncHostToDeviceTransferManager_AddMetadata) }; #undef _PJRT_API_STRUCT_FIELD diff --git a/xla/pjrt/c/pjrt_c_api_gpu_test.cc b/xla/pjrt/c/pjrt_c_api_gpu_test.cc index ae12a1684c23e..1d9e5ef95241c 100644 --- a/xla/pjrt/c/pjrt_c_api_gpu_test.cc +++ b/xla/pjrt/c/pjrt_c_api_gpu_test.cc @@ -35,6 +35,8 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "xla/client/client_library.h" #include "xla/ffi/api/ffi.h" #include "xla/ffi/execution_context.h" @@ -52,21 +54,23 @@ limitations under the License. #include "xla/pjrt/c/pjrt_c_api_wrapper_impl.h" #include "xla/pjrt/distributed/in_memory_key_value_store.h" #include "xla/pjrt/pjrt_common.h" +#include "xla/pjrt/pjrt_compiler.h" #include "xla/pjrt/pjrt_future.h" #include "xla/service/custom_call_target_registry.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/stream_executor/gpu/gpu_init.h" #include "xla/tests/literal_test_util.h" +#include "xla/tsl/framework/allocator.h" #include "xla/tsl/lib/core/status_test_util.h" +#include "xla/tsl/platform/status.h" +#include "xla/tsl/platform/statusor.h" #include "tsl/platform/mem.h" -#include "tsl/platform/status.h" -#include "tsl/platform/status_matchers.h" -#include "tsl/platform/statusor.h" namespace pjrt { namespace { +using ::testing::ElementsAreArray; using ::testing::HasSubstr; using ::testing::IsNull; using ::tsl::testing::StatusIs; @@ -166,6 +170,19 @@ TEST_F(PjrtCApiGpuTest, CreateViewOfDeviceBuffer) { EXPECT_TRUE(xla::LiteralTestUtil::Equal( xla::LiteralUtil::CreateR1(float_data), *literal)); } +class PjrtCApiGpuTransferManagerTest : public PjrtCApiGpuTest { + public: + ~PjrtCApiGpuTransferManagerTest() override { + transfer_manager_.reset(nullptr); + } + void CreateTransferManager(const xla::Shape& host_shape) { + transfer_manager_ = create_transfer_manager(host_shape); + } + + std::unique_ptr + transfer_manager_; +}; class PjrtCApiGpuBufferTest : public PjrtCApiGpuTest { public: @@ -274,69 +291,147 @@ TEST_F(PjrtCApiGpuTest, CreateAndDestroyExecuteContext) { api_->PJRT_ExecuteContext_Destroy(&destroy_args); } -TEST_F(PjrtCApiGpuTest, CreateBuffersWithMemorytForH2DAndTransfer) { - xla::Shape host_shape = xla::ShapeUtil::MakeShapeWithDenseLayout( - xla::F32, /*dimensions=*/{2, 2, 2}, /*minor_to_major=*/{1, 0, 2}); +TEST_F(PjrtCApiGpuTransferManagerTest, SetBufferError) { + xla::Shape host_shape = + xla::ShapeUtil::MakeShape(xla::F32, /*dimensions=*/{8}); std::vector float_data = {1, 2, 3, 4, 5, 6, 7, 8}; - PJRT_Client_CreateBuffersForAsyncHostToDevice_Args args; - args.struct_size = - PJRT_Client_CreateBuffersForAsyncHostToDevice_Args_STRUCT_SIZE; - args.extension_start = nullptr; - args.client = client_; - PJRT_ShapeSpec c_shape_spec; - c_shape_spec.element_type = - pjrt::ConvertToPjRtBufferType(xla::PrimitiveType::F32); - c_shape_spec.dims = host_shape.dimensions().data(); - c_shape_spec.num_dims = host_shape.dimensions().size(); - args.shape_specs = &c_shape_spec; - args.num_shape_specs = 1; - TF_ASSERT_OK_AND_ASSIGN(pjrt::BufferMemoryLayoutData c_layout_data, - ConvertToBufferMemoryLayoutData(host_shape.layout())); - std::vector device_layout_list(1); - device_layout_list[0] = &(c_layout_data.c_layout); - args.device_layouts = device_layout_list.data(); - args.num_device_layouts = device_layout_list.size(); - PJRT_Client_AddressableMemories_Args memory_args; - memory_args.struct_size = PJRT_Client_AddressableMemories_Args_STRUCT_SIZE; - memory_args.extension_start = nullptr; - memory_args.client = client_; - - PJRT_Error* memory_error = - api_->PJRT_Client_AddressableMemories(&memory_args); - ASSERT_EQ(memory_error, nullptr); - ASSERT_NE(memory_args.addressable_memories, nullptr); - ASSERT_GT(memory_args.num_addressable_memories, 0); - args.memory = memory_args.addressable_memories[0]; - PJRT_Error* error = - api_->PJRT_Client_CreateBuffersForAsyncHostToDevice(&args); - ASSERT_EQ(error, nullptr); + CreateTransferManager(host_shape); + + PJRT_AsyncHostToDeviceTransferManager_AddMetadata_Args add_metadata_args; + add_metadata_args.struct_size = + PJRT_AsyncHostToDeviceTransferManager_AddMetadata_Args_STRUCT_SIZE; + add_metadata_args.extension_start = nullptr; + add_metadata_args.transfer_manager = transfer_manager_.get(); + std::vector transfer_metadata; + transfer_metadata.reserve(1); + std::string test_key = "test_key"; + std::string test_value = "test_value"; + PJRT_NamedValue test_named_value; + test_named_value.name = test_key.c_str(); + test_named_value.name_size = test_key.size(); + test_named_value.type = PJRT_NamedValue_Type::PJRT_NamedValue_kString; + test_named_value.string_value = test_value.c_str(); + test_named_value.value_size = test_value.size(); + transfer_metadata.push_back(test_named_value); + add_metadata_args.transfer_metadata = transfer_metadata.data(); + add_metadata_args.num_metadata = transfer_metadata.size(); + PJRT_Error* add_metadata_error = + PJRT_AsyncHostToDeviceTransferManager_AddMetadata(&add_metadata_args); + ASSERT_EQ(add_metadata_error, nullptr); + + PJRT_AsyncHostToDeviceTransferManager_BufferCount_Args buffer_count_args; + buffer_count_args.struct_size = + PJRT_AsyncHostToDeviceTransferManager_BufferCount_Args_STRUCT_SIZE; + buffer_count_args.extension_start = nullptr; + buffer_count_args.transfer_manager = transfer_manager_.get(); + PJRT_Error* buffer_count_error = + PJRT_AsyncHostToDeviceTransferManager_BufferCount(&buffer_count_args); + ASSERT_EQ(buffer_count_error, nullptr); + EXPECT_EQ(buffer_count_args.buffer_count, 1); + + PJRT_AsyncHostToDeviceTransferManager_RetrieveBuffer_Args retrieve_args; + retrieve_args.struct_size = + PJRT_AsyncHostToDeviceTransferManager_RetrieveBuffer_Args_STRUCT_SIZE; + retrieve_args.extension_start = nullptr; + retrieve_args.transfer_manager = transfer_manager_.get(); + retrieve_args.buffer_index = 0; + PJRT_Error* retrieve_error = + PJRT_AsyncHostToDeviceTransferManager_RetrieveBuffer(&retrieve_args); + ASSERT_EQ(retrieve_error, nullptr); + PJRT_Buffer* buffer_out = retrieve_args.buffer_out; + + PJRT_AsyncHostToDeviceTransferManager_SetBufferError_Args + set_buffer_error_args; + set_buffer_error_args.struct_size = + PJRT_AsyncHostToDeviceTransferManager_SetBufferError_Args_STRUCT_SIZE; + set_buffer_error_args.extension_start = nullptr; + set_buffer_error_args.transfer_manager = transfer_manager_.get(); + set_buffer_error_args.buffer_index = 0; + set_buffer_error_args.error_code = PJRT_Error_Code_INTERNAL; + std::string error_message = "test error"; + set_buffer_error_args.error_message = error_message.data(); + set_buffer_error_args.error_message_size = error_message.size(); + PJRT_Error* set_buffer_error_error = + PJRT_AsyncHostToDeviceTransferManager_SetBufferError( + &set_buffer_error_args); + ASSERT_EQ(set_buffer_error_error, nullptr); + + EXPECT_THAT(buffer_out->buffer->ToLiteralSync(), + StatusIs(absl::StatusCode::kInternal, HasSubstr(error_message))); + + PJRT_BufferDeleter buffer_deleter = MakeBufferDeleter(api_); + buffer_deleter(buffer_out); +} + +TEST_F(PjrtCApiGpuTransferManagerTest, TransferRawDataToBufferIsSuccessful) { + xla::Shape host_shape = + xla::ShapeUtil::MakeShape(xla::U32, /*dimensions=*/{8}); + std::vector data = {1, 2, 3, 4, 5, 6, 7, 8}; + absl::Span raw_data_view = GetRawView(data); + CreateTransferManager(host_shape); + + PJRT_AsyncHostToDeviceTransferManager_RetrieveBuffer_Args retrieve_args; + retrieve_args.struct_size = + PJRT_AsyncHostToDeviceTransferManager_RetrieveBuffer_Args_STRUCT_SIZE; + retrieve_args.extension_start = nullptr; + retrieve_args.transfer_manager = transfer_manager_.get(); + retrieve_args.buffer_index = 0; + PJRT_Error* retrieve_error = + PJRT_AsyncHostToDeviceTransferManager_RetrieveBuffer(&retrieve_args); + ASSERT_EQ(retrieve_error, nullptr); + PJRT_Buffer* buffer_out = retrieve_args.buffer_out; + EXPECT_FALSE(buffer_out->buffer->GetReadyFuture().IsReady()); + + TF_ASSERT_OK_AND_ASSIGN(xla::Shape result_shape, + buffer_out->buffer->HostShape()); + EXPECT_EQ(result_shape, host_shape); + + PJRT_AsyncHostToDeviceTransferManager_BufferSize_Args buffer_size_args; + buffer_size_args.struct_size = + PJRT_AsyncHostToDeviceTransferManager_BufferSize_Args_STRUCT_SIZE; + buffer_size_args.extension_start = nullptr; + buffer_size_args.transfer_manager = transfer_manager_.get(); + buffer_size_args.buffer_index = 0; + PJRT_Error* buffer_size_error = + PJRT_AsyncHostToDeviceTransferManager_BufferSize(&buffer_size_args); + ASSERT_EQ(buffer_size_error, nullptr); + EXPECT_EQ(buffer_size_args.buffer_size, + buffer_out->buffer->GetOnDeviceSizeInBytes().value()); + + PJRT_AsyncHostToDeviceTransferManager_Device_Args device_args; + device_args.struct_size = + PJRT_AsyncHostToDeviceTransferManager_Device_Args_STRUCT_SIZE; + device_args.extension_start = nullptr; + device_args.transfer_manager = transfer_manager_.get(); + PJRT_Error* device_error = + PJRT_AsyncHostToDeviceTransferManager_Device(&device_args); + ASSERT_EQ(device_error, nullptr); + EXPECT_EQ(device_args.device_out, GetClientDevices()[0]); PJRT_AsyncHostToDeviceTransferManager_TransferData_Args transfer_args; transfer_args.struct_size = PJRT_AsyncHostToDeviceTransferManager_TransferData_Args_STRUCT_SIZE; transfer_args.extension_start = nullptr; - transfer_args.transfer_manager = args.transfer_manager; + transfer_args.transfer_manager = transfer_manager_.get(); transfer_args.buffer_index = 0; - transfer_args.data = float_data.data(); + transfer_args.data = raw_data_view.data(); transfer_args.offset = 0; - transfer_args.transfer_size = float_data.size(); + transfer_args.transfer_size = raw_data_view.size(); transfer_args.is_last_transfer = true; - PJRT_Error* transfer_error = PJRT_AsyncHostToDeviceTransferManager_TransferData(&transfer_args); ASSERT_EQ(transfer_error, nullptr); std::unique_ptr done_with_h2d_transfer_event( transfer_args.done_with_h2d_transfer, MakeEventDeleter(api_)); - // Destroy the transfer manager. - PJRT_AsyncHostToDeviceTransferManager_Destroy_Args destroy_args; - destroy_args.struct_size = - PJRT_AsyncHostToDeviceTransferManager_Destroy_Args_STRUCT_SIZE; - destroy_args.extension_start = nullptr; - destroy_args.transfer_manager = args.transfer_manager; - LogFatalIfPjrtError( - api_->PJRT_AsyncHostToDeviceTransferManager_Destroy(&destroy_args), api_); + TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr literal, + buffer_out->buffer->ToLiteralSync()); + EXPECT_EQ(literal->element_count(), 8); + EXPECT_THAT(literal->data(), ElementsAreArray(data)); + + PJRT_BufferDeleter buffer_deleter = MakeBufferDeleter(api_); + buffer_deleter(buffer_out); } absl::StatusOr BuildCreateArg( diff --git a/xla/pjrt/c/pjrt_c_api_helpers.cc b/xla/pjrt/c/pjrt_c_api_helpers.cc index c5113d1766ef6..f83ea69b522d0 100644 --- a/xla/pjrt/c/pjrt_c_api_helpers.cc +++ b/xla/pjrt/c/pjrt_c_api_helpers.cc @@ -27,6 +27,7 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/container/inlined_vector.h" +#include "absl/functional/any_invocable.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" @@ -1178,5 +1179,31 @@ std::vector GetMemorySpaceDescriptions( } return memory_space_descriptions; } +PJRT_Error* InvokePjRtEventWhenReady( + const PJRT_Api* api, PJRT_Event* event, + absl::AnyInvocable on_done_with_event) { + if (on_done_with_event) { + PJRT_Event_OnReady_Args event_args; + event_args.struct_size = PJRT_Event_OnReady_Args_STRUCT_SIZE; + event_args.extension_start = nullptr; + event_args.event = event; + event_args.user_arg = new absl::AnyInvocable( + [on_done_with_event = std::move(on_done_with_event), + c_api = api](PJRT_Error* error) mutable { + if (error) { + ::pjrt::MakeErrorDeleter(c_api)(error); + } + std::move(on_done_with_event)(); + }); + event_args.callback = [](PJRT_Error* error, void* args) { + auto* on_done_with_event = + reinterpret_cast*>(args); + (*on_done_with_event)(error); + delete on_done_with_event; + }; + return api->PJRT_Event_OnReady(&event_args); + } + return nullptr; +} } // namespace pjrt diff --git a/xla/pjrt/c/pjrt_c_api_helpers.h b/xla/pjrt/c/pjrt_c_api_helpers.h index 44b56cc1b7f4f..00f22132bcf25 100644 --- a/xla/pjrt/c/pjrt_c_api_helpers.h +++ b/xla/pjrt/c/pjrt_c_api_helpers.h @@ -361,6 +361,10 @@ std::vector GetMemorySpaceDescriptions( PJRT_DeviceDescription* device_description, const PJRT_Api* c_api, absl::StatusOr* default_memory); +PJRT_Error* InvokePjRtEventWhenReady( + const PJRT_Api* api, PJRT_Event* event, + absl::AnyInvocable on_done_with_event); + } // namespace pjrt #endif // XLA_PJRT_C_PJRT_C_API_HELPERS_H_ diff --git a/xla/pjrt/c/pjrt_c_api_test.cc b/xla/pjrt/c/pjrt_c_api_test.cc index 0d9030380f35b..1d8a127a85456 100644 --- a/xla/pjrt/c/pjrt_c_api_test.cc +++ b/xla/pjrt/c/pjrt_c_api_test.cc @@ -915,12 +915,24 @@ FieldOffsetsAndSizesForVersion(int major_version, int minor_version) { if (minor_version >= 57) { add_field("PJRT_Buffer_CopyRawToHost", kFnPtrSize); } - if (minor_version >= 58) { + if (minor_version >= 60) { add_field("PJRT_AsyncHostToDeviceTransferManager_Destroy", kFnPtrSize); add_field("PJRT_AsyncHostToDeviceTransferManager_TransferData", kFnPtrSize); add_field("PJRT_Client_CreateBuffersForAsyncHostToDevice", kFnPtrSize); } + if (minor_version >= 62) { + add_field("PJRT_AsyncHostToDeviceTransferManager_RetrieveBuffer", + kFnPtrSize); + add_field("PJRT_AsyncHostToDeviceTransferManager_Device", kFnPtrSize); + add_field("PJRT_AsyncHostToDeviceTransferManager_BufferCount", + kFnPtrSize); + add_field("PJRT_AsyncHostToDeviceTransferManager_BufferSize", kFnPtrSize); + add_field("PJRT_AsyncHostToDeviceTransferManager_SetBufferError", + kFnPtrSize); + add_field("PJRT_AsyncHostToDeviceTransferManager_AddMetadata", + kFnPtrSize); + } return version_offsets_and_sizes; } LOG(FATAL) << "Unsupported API version: " << major_version << "." @@ -1260,6 +1272,34 @@ TEST_F(PjrtCAbiTestBase, FieldOffsetsAndSizes) { {"PJRT_Client_CreateBuffersForAsyncHostToDevice", {offsetof(PJRT_Api, PJRT_Client_CreateBuffersForAsyncHostToDevice), sizeof(PJRT_Api::PJRT_Client_CreateBuffersForAsyncHostToDevice)}}, + {"PJRT_AsyncHostToDeviceTransferManager_RetrieveBuffer", + {offsetof(PJRT_Api, + PJRT_AsyncHostToDeviceTransferManager_RetrieveBuffer), + sizeof(PJRT_Api:: + PJRT_AsyncHostToDeviceTransferManager_RetrieveBuffer)}}, + {"PJRT_AsyncHostToDeviceTransferManager_Device", + {offsetof(PJRT_Api, PJRT_AsyncHostToDeviceTransferManager_Device), + sizeof(PJRT_Api::PJRT_AsyncHostToDeviceTransferManager_Device)}}, + {"PJRT_AsyncHostToDeviceTransferManager_BufferCount", + {offsetof(PJRT_Api, + PJRT_AsyncHostToDeviceTransferManager_BufferCount), + sizeof( + PJRT_Api::PJRT_AsyncHostToDeviceTransferManager_BufferCount)}}, + {"PJRT_AsyncHostToDeviceTransferManager_BufferSize", + {offsetof(PJRT_Api, + PJRT_AsyncHostToDeviceTransferManager_BufferSize), + sizeof( + PJRT_Api::PJRT_AsyncHostToDeviceTransferManager_BufferSize)}}, + {"PJRT_AsyncHostToDeviceTransferManager_SetBufferError", + {offsetof(PJRT_Api, + PJRT_AsyncHostToDeviceTransferManager_SetBufferError), + sizeof(PJRT_Api:: + PJRT_AsyncHostToDeviceTransferManager_SetBufferError)}}, + {"PJRT_AsyncHostToDeviceTransferManager_AddMetadata", + {offsetof(PJRT_Api, + PJRT_AsyncHostToDeviceTransferManager_AddMetadata), + sizeof( + PJRT_Api::PJRT_AsyncHostToDeviceTransferManager_AddMetadata)}}, }; ASSERT_EQ(api_->pjrt_api_version.major_version, PJRT_API_MAJOR); ASSERT_EQ(api_->pjrt_api_version.minor_version, PJRT_API_MINOR); diff --git a/xla/pjrt/c/pjrt_c_api_test_base.cc b/xla/pjrt/c/pjrt_c_api_test_base.cc index f867846ebcbd5..42d4484d9f09d 100644 --- a/xla/pjrt/c/pjrt_c_api_test_base.cc +++ b/xla/pjrt/c/pjrt_c_api_test_base.cc @@ -35,6 +35,7 @@ limitations under the License. #include "xla/service/computation_placer.h" #include "xla/shape.h" #include "xla/shape_util.h" +#include "xla/tsl/platform/status.h" #include "tsl/platform/status.h" namespace pjrt { @@ -215,4 +216,53 @@ PjrtCApiTestBase::ToUniquePtr(PJRT_Error* error) { error, ::pjrt::MakeErrorDeleter(api_)}; } +std::unique_ptr +PjrtCApiTestBase::create_transfer_manager(const xla::Shape& host_shape) { + PJRT_Client_CreateBuffersForAsyncHostToDevice_Args args; + args.struct_size = + PJRT_Client_CreateBuffersForAsyncHostToDevice_Args_STRUCT_SIZE; + args.extension_start = nullptr; + args.client = client_; + + PJRT_ShapeSpec c_shape_spec; + c_shape_spec.element_type = + pjrt::ConvertToPjRtBufferType(host_shape.element_type()); + c_shape_spec.dims = host_shape.dimensions().data(); + c_shape_spec.num_dims = host_shape.dimensions().size(); + + args.shape_specs = &c_shape_spec; + args.num_shape_specs = 1; + absl::StatusOr result = + ConvertToBufferMemoryLayoutData(host_shape.layout()); + CHECK_OK(result); + BufferMemoryLayoutData c_layout_data = result.value(); + std::vector device_layout_list(1); + device_layout_list[0] = &(c_layout_data.c_layout); + args.device_layouts = device_layout_list.data(); + args.num_device_layouts = device_layout_list.size(); + + PJRT_Client_AddressableMemories_Args memory_args; + memory_args.struct_size = PJRT_Client_AddressableMemories_Args_STRUCT_SIZE; + memory_args.extension_start = nullptr; + memory_args.client = client_; + + PJRT_Error* memory_error = + api_->PJRT_Client_AddressableMemories(&memory_args); + CHECK_EQ(memory_error, nullptr); + CHECK_NE(memory_args.addressable_memories, nullptr); + CHECK_GT(memory_args.num_addressable_memories, 0); + args.memory = memory_args.addressable_memories[0]; + + PJRT_Error* error = + api_->PJRT_Client_CreateBuffersForAsyncHostToDevice(&args); + CHECK_EQ(error, nullptr); + std::unique_ptr + transfer_manager_out( + args.transfer_manager, + ::pjrt::MakeAsyncHostToDeviceTransferManagerDeleter(api_)); + return transfer_manager_out; +} + } // namespace pjrt diff --git a/xla/pjrt/c/pjrt_c_api_test_base.h b/xla/pjrt/c/pjrt_c_api_test_base.h index f6b7c97fa48f2..f75069dec1f20 100644 --- a/xla/pjrt/c/pjrt_c_api_test_base.h +++ b/xla/pjrt/c/pjrt_c_api_test_base.h @@ -32,6 +32,12 @@ limitations under the License. namespace pjrt { +template +absl::Span GetRawView(const std::vector& v) { + return absl::Span(reinterpret_cast(v.data()), + v.size() * sizeof(T)); +} + class PjrtCApiTestBase : public ::testing::Test { public: explicit PjrtCApiTestBase(const PJRT_Api* api); @@ -70,6 +76,10 @@ class PjrtCApiTestBase : public ::testing::Test { std::unique_ptr ToUniquePtr( PJRT_Error* error); + std::unique_ptr + create_transfer_manager(const xla::Shape& host_shape); + private: PjrtCApiTestBase(const PjrtCApiTestBase&) = delete; void operator=(const PjrtCApiTestBase&) = delete; diff --git a/xla/pjrt/c/pjrt_c_api_wrapper_impl.cc b/xla/pjrt/c/pjrt_c_api_wrapper_impl.cc index 906223b315931..cf172f6986adf 100644 --- a/xla/pjrt/c/pjrt_c_api_wrapper_impl.cc +++ b/xla/pjrt/c/pjrt_c_api_wrapper_impl.cc @@ -652,6 +652,89 @@ PJRT_Error* PJRT_AsyncHostToDeviceTransferManager_TransferData( return nullptr; } +PJRT_Error* PJRT_AsyncHostToDeviceTransferManager_RetrieveBuffer( + PJRT_AsyncHostToDeviceTransferManager_RetrieveBuffer_Args* args) { + PJRT_RETURN_IF_ERROR(ActualStructSizeIsGreaterOrEqual( + "PJRT_AsyncHostToDeviceTransferManager_RetrieveBuffer_Args", + PJRT_AsyncHostToDeviceTransferManager_RetrieveBuffer_Args_STRUCT_SIZE, + args->struct_size)); + std::unique_ptr buffer_out = + args->transfer_manager->transfer_manager->RetrieveBuffer( + args->buffer_index); + args->buffer_out = + new PJRT_Buffer{std::move(buffer_out), args->transfer_manager->client}; + return nullptr; +} + +PJRT_Error* PJRT_AsyncHostToDeviceTransferManager_Device( + PJRT_AsyncHostToDeviceTransferManager_Device_Args* args) { + PJRT_RETURN_IF_ERROR(ActualStructSizeIsGreaterOrEqual( + "PJRT_AsyncHostToDeviceTransferManager_Device_Args", + PJRT_AsyncHostToDeviceTransferManager_Device_Args_STRUCT_SIZE, + args->struct_size)); + args->device_out = + FindDeviceWrapper(args->transfer_manager->transfer_manager->device(), + args->transfer_manager->client->addressable_devices); + CHECK(args->device_out != nullptr) + << "No PJRT_Device* found in the client's `addressable_devices` that " + "wraps this " + << args->transfer_manager->transfer_manager->device()->DebugString(); + return nullptr; +} + +PJRT_Error* PJRT_AsyncHostToDeviceTransferManager_BufferCount( + PJRT_AsyncHostToDeviceTransferManager_BufferCount_Args* args) { + PJRT_RETURN_IF_ERROR(ActualStructSizeIsGreaterOrEqual( + "PJRT_AsyncHostToDeviceTransferManager_BufferCount_Args", + PJRT_AsyncHostToDeviceTransferManager_BufferCount_Args_STRUCT_SIZE, + args->struct_size)); + args->buffer_count = args->transfer_manager->transfer_manager->buffer_count(); + return nullptr; +} + +PJRT_Error* PJRT_AsyncHostToDeviceTransferManager_BufferSize( + PJRT_AsyncHostToDeviceTransferManager_BufferSize_Args* args) { + PJRT_RETURN_IF_ERROR(ActualStructSizeIsGreaterOrEqual( + "PJRT_AsyncHostToDeviceTransferManager_BufferSize_Args", + PJRT_AsyncHostToDeviceTransferManager_BufferSize_Args_STRUCT_SIZE, + args->struct_size)); + args->buffer_size = + args->transfer_manager->transfer_manager->buffer_size(args->buffer_index); + return nullptr; +} + +PJRT_Error* PJRT_AsyncHostToDeviceTransferManager_SetBufferError( + PJRT_AsyncHostToDeviceTransferManager_SetBufferError_Args* args) { + PJRT_RETURN_IF_ERROR(ActualStructSizeIsGreaterOrEqual( + "PJRT_AsyncHostToDeviceTransferManager_SetBufferError_Args", + PJRT_AsyncHostToDeviceTransferManager_SetBufferError_Args_STRUCT_SIZE, + args->struct_size)); + auto error_message = + absl::string_view(args->error_message, args->error_message_size); + auto error = absl::Status(pjrt::PjrtErrorCodeToStatusCode(args->error_code), + error_message); + args->transfer_manager->transfer_manager->SetBufferError(args->buffer_index, + error); + return nullptr; +} + +PJRT_Error* PJRT_AsyncHostToDeviceTransferManager_AddMetadata( + PJRT_AsyncHostToDeviceTransferManager_AddMetadata_Args* args) { + PJRT_RETURN_IF_ERROR(ActualStructSizeIsGreaterOrEqual( + "PJRT_AsyncHostToDeviceTransferManager_AddMetadata_Args", + PJRT_AsyncHostToDeviceTransferManager_AddMetadata_Args_STRUCT_SIZE, + args->struct_size)); + + auto pjrt_metadata = ConvertFromPjRtNamedValueList(args->transfer_metadata, + args->num_metadata); + absl::flat_hash_map metadata; + for (const auto& [key, value] : pjrt_metadata) { + metadata[key] = std::get(value); + } + args->transfer_manager->transfer_manager->AddTransferMetadata(metadata); + return nullptr; +} + namespace { absl::StatusOr ParseCompileOptions( @@ -2696,6 +2779,18 @@ PJRT_Api CreatePjrtApi(PJRT_Client_Create* create_fn, pjrt::PJRT_AsyncHostToDeviceTransferManager_TransferData, /*PJRT_Client_CreateBuffersForAsyncHostToDevice=*/ pjrt::PJRT_Client_CreateBuffersForAsyncHostToDevice, + /*PJRT_AsyncHostToDeviceTransferManager_RetrieveBuffer=*/ + pjrt::PJRT_AsyncHostToDeviceTransferManager_RetrieveBuffer, + /*PJRT_AsyncHostToDeviceTransferManager_Device=*/ + pjrt::PJRT_AsyncHostToDeviceTransferManager_Device, + /*PJRT_AsyncHostToDeviceTransferManager_BufferCount=*/ + pjrt::PJRT_AsyncHostToDeviceTransferManager_BufferCount, + /*PJRT_AsyncHostToDeviceTransferManager_BufferSize=*/ + pjrt::PJRT_AsyncHostToDeviceTransferManager_BufferSize, + /*PJRT_AsyncHostToDeviceTransferManager_SetBufferError=*/ + pjrt::PJRT_AsyncHostToDeviceTransferManager_SetBufferError, + /*PJRT_AsyncHostToDeviceTransferManager_AddMetadata=*/ + pjrt::PJRT_AsyncHostToDeviceTransferManager_AddMetadata, }; } diff --git a/xla/pjrt/c/pjrt_c_api_wrapper_impl.h b/xla/pjrt/c/pjrt_c_api_wrapper_impl.h index 27b1cac051dbd..584754061491e 100644 --- a/xla/pjrt/c/pjrt_c_api_wrapper_impl.h +++ b/xla/pjrt/c/pjrt_c_api_wrapper_impl.h @@ -268,6 +268,18 @@ PJRT_Error* PJRT_AsyncHostToDeviceTransferManager_Destroy( PJRT_AsyncHostToDeviceTransferManager_Destroy_Args* args); PJRT_Error* PJRT_AsyncHostToDeviceTransferManager_TransferData( PJRT_AsyncHostToDeviceTransferManager_TransferData_Args* args); +PJRT_Error* PJRT_AsyncHostToDeviceTransferManager_RetrieveBuffer( + PJRT_AsyncHostToDeviceTransferManager_RetrieveBuffer_Args* args); +PJRT_Error* PJRT_AsyncHostToDeviceTransferManager_Device( + PJRT_AsyncHostToDeviceTransferManager_Device_Args* args); +PJRT_Error* PJRT_AsyncHostToDeviceTransferManager_BufferCount( + PJRT_AsyncHostToDeviceTransferManager_BufferCount_Args* args); +PJRT_Error* PJRT_AsyncHostToDeviceTransferManager_BufferSize( + PJRT_AsyncHostToDeviceTransferManager_BufferSize_Args* args); +PJRT_Error* PJRT_AsyncHostToDeviceTransferManager_SetBufferError( + PJRT_AsyncHostToDeviceTransferManager_SetBufferError_Args* args); +PJRT_Error* PJRT_AsyncHostToDeviceTransferManager_AddMetadata( + PJRT_AsyncHostToDeviceTransferManager_AddMetadata_Args* args); PJRT_Error* PJRT_DeviceDescription_Id(PJRT_DeviceDescription_Id_Args* args); PJRT_Error* PJRT_DeviceDescription_ProcessIndex( PJRT_DeviceDescription_ProcessIndex_Args* args); diff --git a/xla/pjrt/pjrt_c_api_client.cc b/xla/pjrt/pjrt_c_api_client.cc index 1ef07811cf53d..f22654fbde224 100644 --- a/xla/pjrt/pjrt_c_api_client.cc +++ b/xla/pjrt/pjrt_c_api_client.cc @@ -536,29 +536,10 @@ PjRtCApiClient::BufferFromHostBufferInternalImpl( std::unique_ptr event( args.done_with_host_buffer, ::pjrt::MakeEventDeleter(c_api_)); - if (on_done_with_host_buffer) { - PJRT_Event_OnReady_Args event_args; - event_args.struct_size = PJRT_Event_OnReady_Args_STRUCT_SIZE; - event_args.extension_start = nullptr; - event_args.event = event.get(); - event_args.user_arg = new absl::AnyInvocable( - [on_done_with_host_buffer = std::move(on_done_with_host_buffer), - c_api = c_api_](PJRT_Error* error) mutable { - if (error) { - ::pjrt::MakeErrorDeleter(c_api)(error); - } - std::move(on_done_with_host_buffer)(); - }); - event_args.callback = [](PJRT_Error* error, void* args) { - auto* on_done_with_host_buffer = - reinterpret_cast*>(args); - (*on_done_with_host_buffer)(error); - delete on_done_with_host_buffer; - }; - - RETURN_STATUS_IF_PJRT_ERROR(c_api_->PJRT_Event_OnReady(&event_args), - c_api_); - } + RETURN_STATUS_IF_PJRT_ERROR( + pjrt::InvokePjRtEventWhenReady(c_api_, event.get(), + std::move(on_done_with_host_buffer)), + c_api_); return buffer; } @@ -704,24 +685,40 @@ class PjRtCApiAsyncHostToDeviceTransferManager : c_client_(client), c_transfer_manager_(std::move(c_transfer_manager)) {} size_t buffer_count() const override { - LOG(FATAL) << "PJRT C API does not support buffer_count. Please " - "report an issue at https://github.com/google/jax/issues if " - "you need " - "this feature."; + PJRT_AsyncHostToDeviceTransferManager_BufferCount_Args args; + args.struct_size = + PJRT_AsyncHostToDeviceTransferManager_BufferCount_Args_STRUCT_SIZE; + args.extension_start = nullptr; + args.transfer_manager = c_transfer_manager_.get(); + const PJRT_Api* api = c_client_->pjrt_c_api(); + pjrt::LogFatalIfPjrtError( + api->PJRT_AsyncHostToDeviceTransferManager_BufferCount(&args), api); + return args.buffer_count; } PjRtDevice* device() const override { - LOG(FATAL) << "PJRT C API does not support device. Please " - "report an issue at https://github.com/google/jax/issues if " - "you need " - "this feature."; + PJRT_AsyncHostToDeviceTransferManager_Device_Args args; + args.struct_size = + PJRT_AsyncHostToDeviceTransferManager_Device_Args_STRUCT_SIZE; + args.extension_start = nullptr; + args.transfer_manager = c_transfer_manager_.get(); + const PJRT_Api* api = c_client_->pjrt_c_api(); + pjrt::LogFatalIfPjrtError( + api->PJRT_AsyncHostToDeviceTransferManager_Device(&args), api); + return c_client_->GetCppDevice(args.device_out); } std::unique_ptr RetrieveBuffer(int buffer_index) override { - LOG(FATAL) << "PJRT C API does not support RetrieveBuffer. Please " - "report an issue at https://github.com/google/jax/issues if " - "you need " - "this feature."; + PJRT_AsyncHostToDeviceTransferManager_RetrieveBuffer_Args args; + args.struct_size = + PJRT_AsyncHostToDeviceTransferManager_RetrieveBuffer_Args_STRUCT_SIZE; + args.extension_start = nullptr; + args.transfer_manager = c_transfer_manager_.get(); + args.buffer_index = buffer_index; + const PJRT_Api* api = c_client_->pjrt_c_api(); + pjrt::LogFatalIfPjrtError( + api->PJRT_AsyncHostToDeviceTransferManager_RetrieveBuffer(&args), api); + return std::make_unique(c_client_, args.buffer_out); } absl::Status TransferLiteralToBuffer( @@ -734,10 +731,16 @@ class PjRtCApiAsyncHostToDeviceTransferManager } size_t buffer_size(int buffer_index) const override { - LOG(FATAL) - << "PJRT C API does not support buffer_size. Please report an " - "issue at https://github.com/google/jax/issues if you need this " - "feature."; + PJRT_AsyncHostToDeviceTransferManager_BufferSize_Args args; + args.struct_size = + PJRT_AsyncHostToDeviceTransferManager_BufferSize_Args_STRUCT_SIZE; + args.extension_start = nullptr; + args.transfer_manager = c_transfer_manager_.get(); + args.buffer_index = buffer_index; + const PJRT_Api* api = c_client_->pjrt_c_api(); + pjrt::LogFatalIfPjrtError( + api->PJRT_AsyncHostToDeviceTransferManager_BufferSize(&args), api); + return args.buffer_size; } absl::Status TransferRawDataToBuffer( @@ -766,44 +769,47 @@ class PjRtCApiAsyncHostToDeviceTransferManager api->PJRT_AsyncHostToDeviceTransferManager_TransferData(&args), api); std::unique_ptr event( args.done_with_h2d_transfer, ::pjrt::MakeEventDeleter(api)); - if (on_done) { - PJRT_Event_OnReady_Args event_args; - event_args.struct_size = PJRT_Event_OnReady_Args_STRUCT_SIZE; - event_args.extension_start = nullptr; - event_args.event = event.get(); - event_args.user_arg = new absl::AnyInvocable( - [on_done = std::move(on_done), - c_api = api](PJRT_Error* error) mutable { - if (error) { - ::pjrt::MakeErrorDeleter(c_api)(error); - } - std::move(on_done)(); - }); - event_args.callback = [](PJRT_Error* error, void* args) { - auto* on_done_with_d2h_transfer = - reinterpret_cast*>(args); - (*on_done_with_d2h_transfer)(error); - delete on_done_with_d2h_transfer; - }; - - RETURN_STATUS_IF_PJRT_ERROR(api->PJRT_Event_OnReady(&event_args), api); - } + RETURN_STATUS_IF_PJRT_ERROR( + pjrt::InvokePjRtEventWhenReady(api, event.get(), std::move(on_done)), + api); return absl::OkStatus(); } void SetBufferError(int buffer_index, absl::Status error) override { - LOG(FATAL) << "PJRT C API does not support SetBufferError. Please " - "report an issue at https://github.com/google/jax/issues if " - "you need " - "this feature."; + PJRT_AsyncHostToDeviceTransferManager_SetBufferError_Args args; + args.struct_size = + PJRT_AsyncHostToDeviceTransferManager_SetBufferError_Args_STRUCT_SIZE; + args.extension_start = nullptr; + args.transfer_manager = c_transfer_manager_.get(); + args.buffer_index = buffer_index; + args.error_code = pjrt::StatusCodeToPjrtErrorCode(error.code()); + args.error_message = error.message().data(); + args.error_message_size = error.message().size(); + const PJRT_Api* api = c_client_->pjrt_c_api(); + pjrt::LogFatalIfPjrtError( + api->PJRT_AsyncHostToDeviceTransferManager_SetBufferError(&args), api); } using TransferMetadata = absl::flat_hash_map; void AddTransferMetadata(const TransferMetadata& metadata) override { - LOG(FATAL) << "PJRT C API does not support AddTransferMetadata. Please " - "report an issue at https://github.com/google/jax/issues if " - "you need " - "this feature."; + PJRT_AsyncHostToDeviceTransferManager_AddMetadata_Args args; + args.struct_size = + PJRT_AsyncHostToDeviceTransferManager_AddMetadata_Args_STRUCT_SIZE; + args.extension_start = nullptr; + args.transfer_manager = c_transfer_manager_.get(); + absl::flat_hash_map pjrt_metadata; + for (const auto& [key, value] : metadata) { + pjrt_metadata[key] = PjRtValueType(value); + }; + absl::StatusOr> result = + pjrt::ConvertToPjRtNamedValueList(pjrt_metadata); + TF_CHECK_OK(result.status()); + std::vector c_metadata = result.value(); + args.transfer_metadata = c_metadata.data(); + args.num_metadata = c_metadata.size(); + const PJRT_Api* api = c_client_->pjrt_c_api(); + pjrt::LogFatalIfPjrtError( + api->PJRT_AsyncHostToDeviceTransferManager_AddMetadata(&args), api); } private: