Skip to content

Commit

Permalink
[PJRT:C] Implement PJRT_AsyncHostToDeviceTransferManager class. Intro…
Browse files Browse the repository at this point in the history
…duce more of its member function to C Api.

PiperOrigin-RevId: 711450094
  • Loading branch information
sizhit2 authored and Google-ML-Automation committed Jan 11, 2025
1 parent a5ba283 commit 7c9cc77
Show file tree
Hide file tree
Showing 12 changed files with 546 additions and 123 deletions.
3 changes: 3 additions & 0 deletions xla/pjrt/c/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,7 @@ cc_library(
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/functional:any_invocable",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
Expand Down Expand Up @@ -371,6 +372,7 @@ cc_library(
"//xla/pjrt:pjrt_executable",
"//xla/pjrt:pjrt_future",
"//xla/service:computation_placer_hdr",
"//xla/tsl/platform:status",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status",
"@com_google_absl//absl/types:span",
Expand Down Expand Up @@ -416,6 +418,7 @@ xla_test(
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/strings:string_view",
"@com_google_googletest//:gtest_main",
"@tsl//tsl/platform:platform_port",
"@tsl//tsl/platform:status",
Expand Down
2 changes: 2 additions & 0 deletions xla/pjrt/c/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
# PJRT C API changelog
## 0.62
* Added more member functions for ``PJRT_AsyncHostToDeviceTransferManager``.

## 0.61
* Added ``PJRT_KeyValueTryGet`` to the KV store interface,
Expand Down
85 changes: 82 additions & 3 deletions xla/pjrt/c/pjrt_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ PJRT_DEFINE_STRUCT_TRAITS(PJRT_Extension_Base, next);
// Changes include:
// * Adding a new field to the PJRT_Api or argument structs
// * Renaming a method or argument (doesn't affect ABI)
#define PJRT_API_MINOR 61
#define PJRT_API_MINOR 62

// The plugin should set the major_version and minor_version of
// PJRT_Api.pjrt_api_version to be the `PJRT_API_MAJOR` and `PJRT_API_MINOR` in
Expand Down Expand Up @@ -661,6 +661,79 @@ PJRT_DEFINE_STRUCT_TRAITS(
typedef PJRT_Error* PJRT_AsyncHostToDeviceTransferManager_TransferData(
PJRT_AsyncHostToDeviceTransferManager_TransferData_Args* args);

struct PJRT_AsyncHostToDeviceTransferManager_RetrieveBuffer_Args {
size_t struct_size;
PJRT_Extension_Base* extension_start;
PJRT_AsyncHostToDeviceTransferManager* transfer_manager;
int buffer_index;
PJRT_Buffer* buffer_out; // out
};
PJRT_DEFINE_STRUCT_TRAITS(
PJRT_AsyncHostToDeviceTransferManager_RetrieveBuffer_Args, buffer_out);
typedef PJRT_Error* PJRT_AsyncHostToDeviceTransferManager_RetrieveBuffer(
PJRT_AsyncHostToDeviceTransferManager_RetrieveBuffer_Args* args);

struct PJRT_AsyncHostToDeviceTransferManager_Device_Args {
size_t struct_size;
PJRT_Extension_Base* extension_start;
PJRT_AsyncHostToDeviceTransferManager* transfer_manager;
PJRT_Device* device_out; // out
};
PJRT_DEFINE_STRUCT_TRAITS(PJRT_AsyncHostToDeviceTransferManager_Device_Args,
device_out);
typedef PJRT_Error* PJRT_AsyncHostToDeviceTransferManager_Device(
PJRT_AsyncHostToDeviceTransferManager_Device_Args* args);

struct PJRT_AsyncHostToDeviceTransferManager_BufferCount_Args {
size_t struct_size;
PJRT_Extension_Base* extension_start;
PJRT_AsyncHostToDeviceTransferManager* transfer_manager;
size_t buffer_count; // out
};
PJRT_DEFINE_STRUCT_TRAITS(
PJRT_AsyncHostToDeviceTransferManager_BufferCount_Args, buffer_count);
typedef PJRT_Error* PJRT_AsyncHostToDeviceTransferManager_BufferCount(
PJRT_AsyncHostToDeviceTransferManager_BufferCount_Args* args);

struct PJRT_AsyncHostToDeviceTransferManager_BufferSize_Args {
size_t struct_size;
PJRT_Extension_Base* extension_start;
PJRT_AsyncHostToDeviceTransferManager* transfer_manager;
int buffer_index;
size_t buffer_size; // out
};
PJRT_DEFINE_STRUCT_TRAITS(PJRT_AsyncHostToDeviceTransferManager_BufferSize_Args,
buffer_size);
typedef PJRT_Error* PJRT_AsyncHostToDeviceTransferManager_BufferSize(
PJRT_AsyncHostToDeviceTransferManager_BufferSize_Args* args);

struct PJRT_AsyncHostToDeviceTransferManager_SetBufferError_Args {
size_t struct_size;
PJRT_Extension_Base* extension_start;
PJRT_AsyncHostToDeviceTransferManager* transfer_manager;
int buffer_index;
PJRT_Error_Code error_code;
const char* error_message;
size_t error_message_size;
};
PJRT_DEFINE_STRUCT_TRAITS(
PJRT_AsyncHostToDeviceTransferManager_SetBufferError_Args,
error_message_size);
typedef PJRT_Error* PJRT_AsyncHostToDeviceTransferManager_SetBufferError(
PJRT_AsyncHostToDeviceTransferManager_SetBufferError_Args* args);

struct PJRT_AsyncHostToDeviceTransferManager_AddMetadata_Args {
size_t struct_size;
PJRT_Extension_Base* extension_start;
PJRT_AsyncHostToDeviceTransferManager* transfer_manager;
const PJRT_NamedValue* transfer_metadata;
size_t num_metadata;
};
PJRT_DEFINE_STRUCT_TRAITS(
PJRT_AsyncHostToDeviceTransferManager_AddMetadata_Args, num_metadata);
typedef PJRT_Error* PJRT_AsyncHostToDeviceTransferManager_AddMetadata(
PJRT_AsyncHostToDeviceTransferManager_AddMetadata_Args* args);

typedef enum {
// Invalid primitive type to serve as default.
PJRT_Buffer_Type_INVALID,
Expand Down Expand Up @@ -2362,11 +2435,17 @@ typedef struct PJRT_Api {
_PJRT_API_STRUCT_FIELD(PJRT_AsyncHostToDeviceTransferManager_Destroy);
_PJRT_API_STRUCT_FIELD(PJRT_AsyncHostToDeviceTransferManager_TransferData);
_PJRT_API_STRUCT_FIELD(PJRT_Client_CreateBuffersForAsyncHostToDevice);
_PJRT_API_STRUCT_FIELD(PJRT_AsyncHostToDeviceTransferManager_RetrieveBuffer);
_PJRT_API_STRUCT_FIELD(PJRT_AsyncHostToDeviceTransferManager_Device);
_PJRT_API_STRUCT_FIELD(PJRT_AsyncHostToDeviceTransferManager_BufferCount);
_PJRT_API_STRUCT_FIELD(PJRT_AsyncHostToDeviceTransferManager_BufferSize);
_PJRT_API_STRUCT_FIELD(PJRT_AsyncHostToDeviceTransferManager_SetBufferError);
_PJRT_API_STRUCT_FIELD(PJRT_AsyncHostToDeviceTransferManager_AddMetadata);
} PJRT_Api;

enum {
PJRT_Api_STRUCT_SIZE =
PJRT_STRUCT_SIZE(PJRT_Api, PJRT_Client_CreateBuffersForAsyncHostToDevice)
PJRT_Api_STRUCT_SIZE = PJRT_STRUCT_SIZE(
PJRT_Api, PJRT_AsyncHostToDeviceTransferManager_AddMetadata)
};

#undef _PJRT_API_STRUCT_FIELD
Expand Down
195 changes: 145 additions & 50 deletions xla/pjrt/c/pjrt_c_api_gpu_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ limitations under the License.
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_format.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "xla/client/client_library.h"
#include "xla/ffi/api/ffi.h"
#include "xla/ffi/execution_context.h"
Expand All @@ -52,21 +54,23 @@ limitations under the License.
#include "xla/pjrt/c/pjrt_c_api_wrapper_impl.h"
#include "xla/pjrt/distributed/in_memory_key_value_store.h"
#include "xla/pjrt/pjrt_common.h"
#include "xla/pjrt/pjrt_compiler.h"
#include "xla/pjrt/pjrt_future.h"
#include "xla/service/custom_call_target_registry.h"
#include "xla/shape.h"
#include "xla/shape_util.h"
#include "xla/stream_executor/gpu/gpu_init.h"
#include "xla/tests/literal_test_util.h"
#include "xla/tsl/framework/allocator.h"
#include "xla/tsl/lib/core/status_test_util.h"
#include "xla/tsl/platform/status.h"
#include "xla/tsl/platform/statusor.h"
#include "tsl/platform/mem.h"
#include "tsl/platform/status.h"
#include "tsl/platform/status_matchers.h"
#include "tsl/platform/statusor.h"

namespace pjrt {
namespace {

using ::testing::ElementsAreArray;
using ::testing::HasSubstr;
using ::testing::IsNull;
using ::tsl::testing::StatusIs;
Expand Down Expand Up @@ -166,6 +170,19 @@ TEST_F(PjrtCApiGpuTest, CreateViewOfDeviceBuffer) {
EXPECT_TRUE(xla::LiteralTestUtil::Equal(
xla::LiteralUtil::CreateR1<float>(float_data), *literal));
}
class PjrtCApiGpuTransferManagerTest : public PjrtCApiGpuTest {
public:
~PjrtCApiGpuTransferManagerTest() override {
transfer_manager_.reset(nullptr);
}
void CreateTransferManager(const xla::Shape& host_shape) {
transfer_manager_ = create_transfer_manager(host_shape);
}

std::unique_ptr<PJRT_AsyncHostToDeviceTransferManager,
PJRT_AsyncHostToDeviceTransferManagerDeleter>
transfer_manager_;
};

class PjrtCApiGpuBufferTest : public PjrtCApiGpuTest {
public:
Expand Down Expand Up @@ -274,69 +291,147 @@ TEST_F(PjrtCApiGpuTest, CreateAndDestroyExecuteContext) {
api_->PJRT_ExecuteContext_Destroy(&destroy_args);
}

TEST_F(PjrtCApiGpuTest, CreateBuffersWithMemorytForH2DAndTransfer) {
xla::Shape host_shape = xla::ShapeUtil::MakeShapeWithDenseLayout(
xla::F32, /*dimensions=*/{2, 2, 2}, /*minor_to_major=*/{1, 0, 2});
TEST_F(PjrtCApiGpuTransferManagerTest, SetBufferError) {
xla::Shape host_shape =
xla::ShapeUtil::MakeShape(xla::F32, /*dimensions=*/{8});
std::vector<float> float_data = {1, 2, 3, 4, 5, 6, 7, 8};

PJRT_Client_CreateBuffersForAsyncHostToDevice_Args args;
args.struct_size =
PJRT_Client_CreateBuffersForAsyncHostToDevice_Args_STRUCT_SIZE;
args.extension_start = nullptr;
args.client = client_;
PJRT_ShapeSpec c_shape_spec;
c_shape_spec.element_type =
pjrt::ConvertToPjRtBufferType(xla::PrimitiveType::F32);
c_shape_spec.dims = host_shape.dimensions().data();
c_shape_spec.num_dims = host_shape.dimensions().size();
args.shape_specs = &c_shape_spec;
args.num_shape_specs = 1;
TF_ASSERT_OK_AND_ASSIGN(pjrt::BufferMemoryLayoutData c_layout_data,
ConvertToBufferMemoryLayoutData(host_shape.layout()));
std::vector<PJRT_Buffer_MemoryLayout*> device_layout_list(1);
device_layout_list[0] = &(c_layout_data.c_layout);
args.device_layouts = device_layout_list.data();
args.num_device_layouts = device_layout_list.size();
PJRT_Client_AddressableMemories_Args memory_args;
memory_args.struct_size = PJRT_Client_AddressableMemories_Args_STRUCT_SIZE;
memory_args.extension_start = nullptr;
memory_args.client = client_;

PJRT_Error* memory_error =
api_->PJRT_Client_AddressableMemories(&memory_args);
ASSERT_EQ(memory_error, nullptr);
ASSERT_NE(memory_args.addressable_memories, nullptr);
ASSERT_GT(memory_args.num_addressable_memories, 0);
args.memory = memory_args.addressable_memories[0];
PJRT_Error* error =
api_->PJRT_Client_CreateBuffersForAsyncHostToDevice(&args);
ASSERT_EQ(error, nullptr);
CreateTransferManager(host_shape);

PJRT_AsyncHostToDeviceTransferManager_AddMetadata_Args add_metadata_args;
add_metadata_args.struct_size =
PJRT_AsyncHostToDeviceTransferManager_AddMetadata_Args_STRUCT_SIZE;
add_metadata_args.extension_start = nullptr;
add_metadata_args.transfer_manager = transfer_manager_.get();
std::vector<PJRT_NamedValue> transfer_metadata;
transfer_metadata.reserve(1);
std::string test_key = "test_key";
std::string test_value = "test_value";
PJRT_NamedValue test_named_value;
test_named_value.name = test_key.c_str();
test_named_value.name_size = test_key.size();
test_named_value.type = PJRT_NamedValue_Type::PJRT_NamedValue_kString;
test_named_value.string_value = test_value.c_str();
test_named_value.value_size = test_value.size();
transfer_metadata.push_back(test_named_value);
add_metadata_args.transfer_metadata = transfer_metadata.data();
add_metadata_args.num_metadata = transfer_metadata.size();
PJRT_Error* add_metadata_error =
PJRT_AsyncHostToDeviceTransferManager_AddMetadata(&add_metadata_args);
ASSERT_EQ(add_metadata_error, nullptr);

PJRT_AsyncHostToDeviceTransferManager_BufferCount_Args buffer_count_args;
buffer_count_args.struct_size =
PJRT_AsyncHostToDeviceTransferManager_BufferCount_Args_STRUCT_SIZE;
buffer_count_args.extension_start = nullptr;
buffer_count_args.transfer_manager = transfer_manager_.get();
PJRT_Error* buffer_count_error =
PJRT_AsyncHostToDeviceTransferManager_BufferCount(&buffer_count_args);
ASSERT_EQ(buffer_count_error, nullptr);
EXPECT_EQ(buffer_count_args.buffer_count, 1);

PJRT_AsyncHostToDeviceTransferManager_RetrieveBuffer_Args retrieve_args;
retrieve_args.struct_size =
PJRT_AsyncHostToDeviceTransferManager_RetrieveBuffer_Args_STRUCT_SIZE;
retrieve_args.extension_start = nullptr;
retrieve_args.transfer_manager = transfer_manager_.get();
retrieve_args.buffer_index = 0;
PJRT_Error* retrieve_error =
PJRT_AsyncHostToDeviceTransferManager_RetrieveBuffer(&retrieve_args);
ASSERT_EQ(retrieve_error, nullptr);
PJRT_Buffer* buffer_out = retrieve_args.buffer_out;

PJRT_AsyncHostToDeviceTransferManager_SetBufferError_Args
set_buffer_error_args;
set_buffer_error_args.struct_size =
PJRT_AsyncHostToDeviceTransferManager_SetBufferError_Args_STRUCT_SIZE;
set_buffer_error_args.extension_start = nullptr;
set_buffer_error_args.transfer_manager = transfer_manager_.get();
set_buffer_error_args.buffer_index = 0;
set_buffer_error_args.error_code = PJRT_Error_Code_INTERNAL;
std::string error_message = "test error";
set_buffer_error_args.error_message = error_message.data();
set_buffer_error_args.error_message_size = error_message.size();
PJRT_Error* set_buffer_error_error =
PJRT_AsyncHostToDeviceTransferManager_SetBufferError(
&set_buffer_error_args);
ASSERT_EQ(set_buffer_error_error, nullptr);

EXPECT_THAT(buffer_out->buffer->ToLiteralSync(),
StatusIs(absl::StatusCode::kInternal, HasSubstr(error_message)));

PJRT_BufferDeleter buffer_deleter = MakeBufferDeleter(api_);
buffer_deleter(buffer_out);
}

TEST_F(PjrtCApiGpuTransferManagerTest, TransferRawDataToBufferIsSuccessful) {
xla::Shape host_shape =
xla::ShapeUtil::MakeShape(xla::U32, /*dimensions=*/{8});
std::vector<uint32_t> data = {1, 2, 3, 4, 5, 6, 7, 8};
absl::Span<const char> raw_data_view = GetRawView(data);
CreateTransferManager(host_shape);

PJRT_AsyncHostToDeviceTransferManager_RetrieveBuffer_Args retrieve_args;
retrieve_args.struct_size =
PJRT_AsyncHostToDeviceTransferManager_RetrieveBuffer_Args_STRUCT_SIZE;
retrieve_args.extension_start = nullptr;
retrieve_args.transfer_manager = transfer_manager_.get();
retrieve_args.buffer_index = 0;
PJRT_Error* retrieve_error =
PJRT_AsyncHostToDeviceTransferManager_RetrieveBuffer(&retrieve_args);
ASSERT_EQ(retrieve_error, nullptr);
PJRT_Buffer* buffer_out = retrieve_args.buffer_out;
EXPECT_FALSE(buffer_out->buffer->GetReadyFuture().IsReady());

TF_ASSERT_OK_AND_ASSIGN(xla::Shape result_shape,
buffer_out->buffer->HostShape());
EXPECT_EQ(result_shape, host_shape);

PJRT_AsyncHostToDeviceTransferManager_BufferSize_Args buffer_size_args;
buffer_size_args.struct_size =
PJRT_AsyncHostToDeviceTransferManager_BufferSize_Args_STRUCT_SIZE;
buffer_size_args.extension_start = nullptr;
buffer_size_args.transfer_manager = transfer_manager_.get();
buffer_size_args.buffer_index = 0;
PJRT_Error* buffer_size_error =
PJRT_AsyncHostToDeviceTransferManager_BufferSize(&buffer_size_args);
ASSERT_EQ(buffer_size_error, nullptr);
EXPECT_EQ(buffer_size_args.buffer_size,
buffer_out->buffer->GetOnDeviceSizeInBytes().value());

PJRT_AsyncHostToDeviceTransferManager_Device_Args device_args;
device_args.struct_size =
PJRT_AsyncHostToDeviceTransferManager_Device_Args_STRUCT_SIZE;
device_args.extension_start = nullptr;
device_args.transfer_manager = transfer_manager_.get();
PJRT_Error* device_error =
PJRT_AsyncHostToDeviceTransferManager_Device(&device_args);
ASSERT_EQ(device_error, nullptr);
EXPECT_EQ(device_args.device_out, GetClientDevices()[0]);

PJRT_AsyncHostToDeviceTransferManager_TransferData_Args transfer_args;
transfer_args.struct_size =
PJRT_AsyncHostToDeviceTransferManager_TransferData_Args_STRUCT_SIZE;
transfer_args.extension_start = nullptr;
transfer_args.transfer_manager = args.transfer_manager;
transfer_args.transfer_manager = transfer_manager_.get();
transfer_args.buffer_index = 0;
transfer_args.data = float_data.data();
transfer_args.data = raw_data_view.data();
transfer_args.offset = 0;
transfer_args.transfer_size = float_data.size();
transfer_args.transfer_size = raw_data_view.size();
transfer_args.is_last_transfer = true;

PJRT_Error* transfer_error =
PJRT_AsyncHostToDeviceTransferManager_TransferData(&transfer_args);
ASSERT_EQ(transfer_error, nullptr);
std::unique_ptr<PJRT_Event, PJRT_EventDeleter> done_with_h2d_transfer_event(
transfer_args.done_with_h2d_transfer, MakeEventDeleter(api_));

// Destroy the transfer manager.
PJRT_AsyncHostToDeviceTransferManager_Destroy_Args destroy_args;
destroy_args.struct_size =
PJRT_AsyncHostToDeviceTransferManager_Destroy_Args_STRUCT_SIZE;
destroy_args.extension_start = nullptr;
destroy_args.transfer_manager = args.transfer_manager;
LogFatalIfPjrtError(
api_->PJRT_AsyncHostToDeviceTransferManager_Destroy(&destroy_args), api_);
TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr<xla::Literal> literal,
buffer_out->buffer->ToLiteralSync());
EXPECT_EQ(literal->element_count(), 8);
EXPECT_THAT(literal->data<uint32_t>(), ElementsAreArray(data));

PJRT_BufferDeleter buffer_deleter = MakeBufferDeleter(api_);
buffer_deleter(buffer_out);
}

absl::StatusOr<PJRT_Client_Create_Args> BuildCreateArg(
Expand Down
Loading

0 comments on commit 7c9cc77

Please sign in to comment.