Skip to content

Commit

Permalink
Rollback breaking C API changes (TryGetKeyValue()).
Browse files Browse the repository at this point in the history
Reverts 926ef6a

PiperOrigin-RevId: 707888995
  • Loading branch information
Google-ML-Automation committed Dec 19, 2024
1 parent 03d02c5 commit 313d56f
Show file tree
Hide file tree
Showing 18 changed files with 14 additions and 214 deletions.
6 changes: 0 additions & 6 deletions xla/pjrt/c/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,10 +1,4 @@
# PJRT C API changelog

## 0.61
* Added ``PJRT_KeyValueTryGet`` to the KV store interface,
which is non-blocking and immediately returns an error if the
key is not found.

## 0.60
* Added ``PJRT_Client_CreateBuffersForAsyncHostToDevice`` and ``PJRT_AsyncHostToDeviceTransferManager_TransferRawDataToSubBuffer``.

Expand Down
40 changes: 2 additions & 38 deletions xla/pjrt/c/pjrt_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ PJRT_DEFINE_STRUCT_TRAITS(PJRT_Extension_Base, next);
// Changes include:
// * Adding a new field to the PJRT_Api or argument structs
// * Renaming a method or argument (doesn't affect ABI)
#define PJRT_API_MINOR 61
#define PJRT_API_MINOR 60

// The plugin should set the major_version and minor_version of
// PJRT_Api.pjrt_api_version to be the `PJRT_API_MAJOR` and `PJRT_API_MINOR` in
Expand Down Expand Up @@ -351,35 +351,6 @@ PJRT_DEFINE_STRUCT_TRAITS(PJRT_KeyValueGetCallback_Args,
typedef PJRT_Error* (*PJRT_KeyValueGetCallback)(
PJRT_KeyValueGetCallback_Args* args);

// Same as KeyValueGet, but returns `NotFoundError` immediately if the key is
// not found.
typedef void (*PJRT_KeyValueTryGetCallback_ValueDeleter)(char* value);

// Argument struct for PJRT_KeyValueTryGetCallback, the non-blocking lookup on
// the key-value store. Fields marked "out" are written by the callback
// implementation; all other fields are filled in by the code that invokes the
// callback.
struct PJRT_KeyValueTryGetCallback_Args {
size_t struct_size;
PJRT_Extension_Base* extension_start;
// Key to look up, passed as a (pointer, length) pair.
const char* key;
size_t key_size;
// Invoked by the callback implementation to build the PJRT_Error it returns
// when the lookup fails.
PJRT_CallbackError* callback_error;
// Opaque pointer supplied together with the callback and handed back to it
// unchanged.
void* user_arg;
char* value; // out
size_t value_size; // out
// Set by the callback implementation to the function that releases `value`.
// The code that invoked the callback is responsible for copying `value` and
// then calling value_deleter_callback on it.
PJRT_KeyValueTryGetCallback_ValueDeleter value_deleter_callback; // out
};
PJRT_DEFINE_STRUCT_TRAITS(PJRT_KeyValueTryGetCallback_Args,
value_deleter_callback);

// Requirements for PJRT_KeyValueTryGetCallback implementation: (1) Thread-safe.
// (2) The caller that provides the two callbacks is responsible for avoiding
// key collisions between different users of key-value store (i.e. between
// different plugins, but not between different nodes in one plugin).
typedef PJRT_Error* (*PJRT_KeyValueTryGetCallback)(
PJRT_KeyValueTryGetCallback_Args* args);

struct PJRT_KeyValuePutCallback_Args {
size_t struct_size;
PJRT_Extension_Base* extension_start;
Expand Down Expand Up @@ -418,15 +389,8 @@ struct PJRT_Client_Create_Args {
void* kv_put_user_arg;

PJRT_Client* client; // out

// Key-value try-get callback provided by the caller of PJRT_Client_Create.
// Same as key-value get callback, but returns `NotFoundError` immediately if
// the key is not found.
PJRT_KeyValueTryGetCallback kv_try_get_callback;
// Will be passed to `kv_try_get_callback` as `user_arg` argument.
void* kv_try_get_user_arg;
};
PJRT_DEFINE_STRUCT_TRAITS(PJRT_Client_Create_Args, kv_try_get_user_arg);
PJRT_DEFINE_STRUCT_TRAITS(PJRT_Client_Create_Args, client);

// Creates and initializes a new PJRT_Client and returns in `client`.
typedef PJRT_Error* PJRT_Client_Create(PJRT_Client_Create_Args* args);
Expand Down
6 changes: 3 additions & 3 deletions xla/pjrt/c/pjrt_c_api_gpu_internal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -154,9 +154,9 @@ PJRT_Error* PJRT_Client_Create(PJRT_Client_Create_Args* args) {
options.num_nodes = num_nodes;
options.allowed_devices = visible_devices;
options.platform_name = platform_name;
options.kv_store = pjrt::ToCppKeyValueStore(
args->kv_get_callback, args->kv_get_user_arg, args->kv_try_get_callback,
args->kv_try_get_user_arg, args->kv_put_callback, args->kv_put_user_arg);
options.kv_store =
pjrt::ToCppKeyValueStore(args->kv_get_callback, args->kv_get_user_arg,
args->kv_put_callback, args->kv_put_user_arg);
options.enable_mock_nccl = enable_mock_nccl;
options.mock_gpu_topology = mock_gpu_topology;
PJRT_ASSIGN_OR_RETURN(std::unique_ptr<xla::PjRtClient> client,
Expand Down
38 changes: 0 additions & 38 deletions xla/pjrt/c/pjrt_c_api_helpers.cc
Original file line number Diff line number Diff line change
Expand Up @@ -795,25 +795,6 @@ static PJRT_KeyValueGetCFunc ToKVGetCFunc(
};
}

// Builds a std::function that services a C try-get request by calling
// `kv_store->TryGet`. On success the value is copied into a freshly allocated
// buffer that the consumer releases through `value_deleter_callback`; on
// failure the status is converted into a PJRT_Error via `callback_error`.
// `kv_store` is captured as a raw pointer.
static PJRT_KeyValueTryGetCFunc ToKVTryGetCFunc(
    xla::KeyValueStoreInterface* kv_store) {
  return [kv_store](PJRT_KeyValueTryGetCallback_Args* args) -> PJRT_Error* {
    const absl::string_view key(args->key, args->key_size);
    absl::StatusOr<std::string> lookup = kv_store->TryGet(key);
    if (lookup.ok()) {
      // Hand the bytes back in a heap buffer paired with its deleter.
      const std::string& found = *lookup;
      args->value = new char[found.size()];
      std::copy(found.begin(), found.end(), args->value);
      args->value_size = found.size();
      args->value_deleter_callback = &PjRtValueDeleterCallback;
      return nullptr;
    }
    // Report the failure through the error factory supplied in `args`.
    const absl::Status& failure = lookup.status();
    absl::string_view text = failure.message();
    return (*args->callback_error)(StatusCodeToPjrtErrorCode(failure.code()),
                                   text.data(), text.size());
  };
}

static PJRT_KeyValuePutCFunc ToKVPutCFunc(
xla::KeyValueStoreInterface* kv_store) {
return [kv_store](PJRT_KeyValuePutCallback_Args* args) -> PJRT_Error* {
Expand Down Expand Up @@ -845,22 +826,6 @@ static PJRT_KeyValueGetCallback ToCKVGetCallback(
};
}

// Returns a captureless C callback that forwards a try-get request to the
// PJRT_KeyValueTryGetCFunc stored in `args->user_arg`. The
// `kv_try_get_c_func` parameter itself is not read here; the functor is
// recovered from `user_arg` on every call.
static PJRT_KeyValueTryGetCallback ToCKVTryGetCallback(
    PJRT_KeyValueTryGetCFunc* kv_try_get_c_func) {
  return [](PJRT_KeyValueTryGetCallback_Args* args) -> PJRT_Error* {
    auto* forward_to = static_cast<PJRT_KeyValueTryGetCFunc*>(args->user_arg);
    if (forward_to != nullptr) {
      return (*forward_to)(args);
    }
    // No functor was supplied: surface an InvalidArgument error.
    absl::Status status = xla::InvalidArgument(
        "got nullptr for PJRT_KeyValueTryGet_Args.user_arg");
    absl::string_view text = status.message();
    return (*args->callback_error)(StatusCodeToPjrtErrorCode(status.code()),
                                   text.data(), text.size());
  };
}

static PJRT_KeyValuePutCallback ToCKVPutCallback(
PJRT_KeyValuePutCFunc* kv_put_c_func) {
return [](PJRT_KeyValuePutCallback_Args* args) -> PJRT_Error* {
Expand All @@ -881,12 +846,9 @@ std::unique_ptr<PJRT_KeyValueCallbackData> ConvertToCKeyValueCallbacks(
std::shared_ptr<xla::KeyValueStoreInterface> kv_store) {
auto kv_callback_data = std::make_unique<PJRT_KeyValueCallbackData>();
kv_callback_data->kv_get_c_func = ToKVGetCFunc(kv_store.get());
kv_callback_data->kv_try_get_c_func = ToKVTryGetCFunc(kv_store.get());
kv_callback_data->kv_put_c_func = ToKVPutCFunc(kv_store.get());
kv_callback_data->c_kv_get =
ToCKVGetCallback(&kv_callback_data->kv_get_c_func);
kv_callback_data->c_kv_try_get =
ToCKVTryGetCallback(&kv_callback_data->kv_try_get_c_func);
kv_callback_data->c_kv_put =
ToCKVPutCallback(&kv_callback_data->kv_put_c_func);
kv_callback_data->kv_store = std::move(kv_store);
Expand Down
17 changes: 5 additions & 12 deletions xla/pjrt/c/pjrt_c_api_helpers.h
Original file line number Diff line number Diff line change
Expand Up @@ -218,9 +218,6 @@ int GetId(const PJRT_Api* api, PJRT_DeviceDescription* device_desc);
using PJRT_KeyValueGetCFunc =
std::function<PJRT_Error*(PJRT_KeyValueGetCallback_Args* args)>;

using PJRT_KeyValueTryGetCFunc =
std::function<PJRT_Error*(PJRT_KeyValueTryGetCallback_Args* args)>;

using PJRT_KeyValuePutCFunc =
std::function<PJRT_Error*(PJRT_KeyValuePutCallback_Args* args)>;

Expand All @@ -231,21 +228,17 @@ struct PJRT_KeyValueCallbackData {

std::shared_ptr<xla::KeyValueStoreInterface> kv_store;

// kv_get_c_func, kv_try_get_c_func and kv_put_c_func are holding pointers to
// kv_store.
// kv_get_c_func and kv_put_c_func are holding pointers to kv_store.
pjrt::PJRT_KeyValueGetCFunc kv_get_c_func;
pjrt::PJRT_KeyValuePutCFunc kv_put_c_func;
// c_kv_get, c_kv_try_get and c_kv_put are holding pointers to kv_get_c_func,
// kv_try_get_c_func and kv_put_c_func.
// c_kv_get and c_kv_put are holding pointers to kv_get_c_func and
// kv_put_c_func.
PJRT_KeyValueGetCallback c_kv_get;
PJRT_KeyValuePutCallback c_kv_put;
pjrt::PJRT_KeyValueTryGetCFunc kv_try_get_c_func;
PJRT_KeyValueTryGetCallback c_kv_try_get;
};

// The returned &kv_get_c_func, &kv_try_get_c_func and &kv_put_c_func must be
// set as PJRT_Client_Create_Args.kv_get_user_arg,
// PJRT_Client_Create_Args.kv_try_get_user_arg and
// The returned &kv_get_c_func and &kv_put_c_func must be set as
// PJRT_Client_Create_Args.kv_get_user_arg and
// PJRT_Client_Create_Args.kv_put_user_arg, respectively. The entire
// PJRT_KeyValueCallbackData must be kept alive as long as c_kv_get and c_kv_put
// may be called.
Expand Down
8 changes: 0 additions & 8 deletions xla/pjrt/c/pjrt_c_api_helpers_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -108,22 +108,14 @@ TEST(PjRtCApiHelperTest, Callback) {
auto kv_callback_data = ConvertToCKeyValueCallbacks(kv_store);
auto converted_kv_store = ToCppKeyValueStore(
kv_callback_data->c_kv_get, &kv_callback_data->kv_get_c_func,
kv_callback_data->c_kv_try_get, &kv_callback_data->kv_try_get_c_func,
kv_callback_data->c_kv_put, &kv_callback_data->kv_put_c_func);

auto v_not_found = converted_kv_store->Get("key", absl::Seconds(1));
EXPECT_TRUE(absl::IsNotFound(v_not_found.status())) << v_not_found.status();

auto s = converted_kv_store->Set("key", "value");
TF_EXPECT_OK(s);

auto v = converted_kv_store->Get("key", absl::Seconds(1));
TF_EXPECT_OK(v.status());
EXPECT_EQ(*v, "value");

auto v_2 = converted_kv_store->TryGet("key");
TF_EXPECT_OK(v.status());
EXPECT_EQ(*v, "value");
}

TEST(PjRtCApiHelperTest, ConvertToCLayoutFromStrides) {
Expand Down
4 changes: 1 addition & 3 deletions xla/pjrt/c/pjrt_c_api_test_base.cc
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,9 @@ PJRT_Client* CreateClient(const PJRT_Api* api) {
create_args.create_options = nullptr;
create_args.num_options = 0;
create_args.kv_get_callback = nullptr;
create_args.kv_get_user_arg = nullptr;
create_args.kv_put_callback = nullptr;
create_args.kv_put_user_arg = nullptr;
create_args.kv_try_get_callback = nullptr;
create_args.kv_try_get_user_arg = nullptr;
create_args.kv_get_user_arg = nullptr;
PJRT_Error* error = api->PJRT_Client_Create(&create_args);
CHECK_EQ(error, nullptr);
CHECK_NE(create_args.client, nullptr);
Expand Down
36 changes: 3 additions & 33 deletions xla/pjrt/c/pjrt_c_api_wrapper_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -235,13 +235,9 @@ static absl::Status PopulateExecutableOutputMemoryKinds(
class CApiKeyValueStore : public xla::KeyValueStoreInterface {
public:
CApiKeyValueStore(PJRT_KeyValueGetCallback c_get_callback, void* get_user_arg,
PJRT_KeyValueTryGetCallback c_try_get_callback,
void* try_get_user_arg,
PJRT_KeyValuePutCallback c_put_callback, void* put_user_arg)
: c_get_callback_(c_get_callback),
get_user_arg_(get_user_arg),
c_try_get_callback_(c_try_get_callback),
try_get_user_arg_(try_get_user_arg),
c_put_callback_(c_put_callback),
put_user_arg_(put_user_arg) {}

Expand All @@ -268,27 +264,6 @@ class CApiKeyValueStore : public xla::KeyValueStoreInterface {
return result;
}

// Non-blocking lookup of `key` through the C try-get callback.
//
// Returns a copy of the value produced by the callback. If the callback
// returns an error, that error's status is returned instead.
absl::StatusOr<std::string> TryGet(absl::string_view key) override {
  // Passed to the callback so it can build a PJRT_Error from a code and a
  // message.
  PJRT_CallbackError callback_error = [](PJRT_Error_Code code,
                                         const char* message,
                                         size_t message_size) {
    return new PJRT_Error{absl::Status(static_cast<absl::StatusCode>(code),
                                       std::string(message, message_size))};
  };
  // Value-initialize and fill in the struct header so that a callback which
  // inspects `struct_size`/`extension_start`, or which leaves an out field
  // untouched, never exposes indeterminate values.
  PJRT_KeyValueTryGetCallback_Args args{};
  args.struct_size = PJRT_KeyValueTryGetCallback_Args_STRUCT_SIZE;
  args.extension_start = nullptr;
  args.key = key.data();
  args.key_size = key.size();
  args.callback_error = &callback_error;
  args.user_arg = try_get_user_arg_;
  // Take ownership of whatever error object the callback returns.
  std::unique_ptr<PJRT_Error> error(c_try_get_callback_(&args));
  if (error != nullptr) {
    return error->status;
  }
  // Copy the bytes out before releasing the callback-owned buffer.
  std::string result(args.value, args.value_size);
  if (args.value_deleter_callback != nullptr) {
    args.value_deleter_callback(args.value);
  }
  return result;
}

absl::Status Set(absl::string_view key, absl::string_view value) override {
PJRT_CallbackError callback_error = [](PJRT_Error_Code code,
const char* message,
Expand All @@ -313,23 +288,18 @@ class CApiKeyValueStore : public xla::KeyValueStoreInterface {
private:
PJRT_KeyValueGetCallback c_get_callback_;
void* get_user_arg_;
PJRT_KeyValueTryGetCallback c_try_get_callback_;
void* try_get_user_arg_;
PJRT_KeyValuePutCallback c_put_callback_;
void* put_user_arg_;
};

std::shared_ptr<xla::KeyValueStoreInterface> ToCppKeyValueStore(
PJRT_KeyValueGetCallback c_get_callback, void* get_user_arg,
PJRT_KeyValueTryGetCallback c_try_get_callback, void* try_get_user_arg,
PJRT_KeyValuePutCallback c_put_callback, void* put_user_arg) {
if (c_get_callback == nullptr || c_try_get_callback == nullptr ||
c_put_callback == nullptr) {
if (c_get_callback == nullptr || c_put_callback == nullptr) {
return nullptr;
}
return std::make_shared<CApiKeyValueStore>(
c_get_callback, get_user_arg, c_try_get_callback, try_get_user_arg,
c_put_callback, put_user_arg);
return std::make_shared<CApiKeyValueStore>(c_get_callback, get_user_arg,
c_put_callback, put_user_arg);
}

// ---------------------------------- Errors -----------------------------------
Expand Down
1 change: 0 additions & 1 deletion xla/pjrt/c/pjrt_c_api_wrapper_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -464,7 +464,6 @@ PJRT_Client* CreateWrapperClient(std::unique_ptr<xla::PjRtClient> cpp_client);
// Helper functions for converting C key-value store callbacks to C++ callbacks.
std::shared_ptr<xla::KeyValueStoreInterface> ToCppKeyValueStore(
PJRT_KeyValueGetCallback c_get_callback, void* get_user_arg,
PJRT_KeyValueTryGetCallback c_try_get_callback, void* try_get_user_arg,
PJRT_KeyValuePutCallback c_put_callback, void* put_user_arg);

// A method that does nothing other than returning a nullptr. Can be used as
Expand Down
12 changes: 0 additions & 12 deletions xla/pjrt/distributed/client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ limitations under the License.
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "absl/time/time.h"
#include "absl/types/span.h"
#include "grpcpp/channel.h"
Expand Down Expand Up @@ -54,7 +53,6 @@ class DistributedRuntimeCoordinationServiceClient
absl::Status Shutdown() override;
absl::StatusOr<std::string> BlockingKeyValueGet(
absl::string_view key, absl::Duration timeout) override;
absl::StatusOr<std::string> KeyValueTryGet(absl::string_view key) override;
absl::StatusOr<std::vector<std::pair<std::string, std::string>>>
KeyValueDirGet(absl::string_view key) override;
absl::Status KeyValueSet(absl::string_view key,
Expand Down Expand Up @@ -146,12 +144,6 @@ DistributedRuntimeCoordinationServiceClient::BlockingKeyValueGet(
return coord_agent_->GetKeyValue(key, timeout);
}

// Non-blocking key lookup: forwards `key` to the coordination agent's
// TryGetKeyValue and returns its result unchanged.
absl::StatusOr<std::string>
DistributedRuntimeCoordinationServiceClient::KeyValueTryGet(
    absl::string_view key) {
  absl::StatusOr<std::string> lookup = coord_agent_->TryGetKeyValue(key);
  return lookup;
}

absl::StatusOr<std::vector<std::pair<std::string, std::string>>>
DistributedRuntimeCoordinationServiceClient::KeyValueDirGet(
absl::string_view key) {
Expand Down Expand Up @@ -224,10 +216,6 @@ class DistributedKeyValueStore : public KeyValueStoreInterface {
return client_->BlockingKeyValueGet(absl::StrCat(prefix_, key), timeout);
}

// Non-blocking lookup: prepends this store's prefix to `key` and delegates to
// the distributed client's KeyValueTryGet.
absl::StatusOr<std::string> TryGet(absl::string_view key) override {
  const std::string prefixed_key = absl::StrCat(prefix_, key);
  return client_->KeyValueTryGet(prefixed_key);
}

absl::Status Set(absl::string_view key, absl::string_view value) override {
return client_->KeyValueSet(absl::StrCat(prefix_, key), value);
}
Expand Down
4 changes: 0 additions & 4 deletions xla/pjrt/distributed/client.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ limitations under the License.
#include "absl/log/log.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "absl/time/time.h"
#include "absl/types/span.h"
#include "grpcpp/channel.h"
Expand Down Expand Up @@ -117,9 +116,6 @@ class DistributedRuntimeClient {
virtual absl::StatusOr<std::string> BlockingKeyValueGet(
absl::string_view key, absl::Duration timeout) = 0;

// Returns `NotFoundError` immediately if the key is not found.
virtual absl::StatusOr<std::string> KeyValueTryGet(absl::string_view key) = 0;

// Get all key-value pairs under a directory (key).
// A value is considered to be in the directory if its key is prefixed with
// the directory.
Expand Down
14 changes: 0 additions & 14 deletions xla/pjrt/distributed/client_server_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1029,20 +1029,6 @@ TEST_F(ClientServerTest, KeyValueSet_Duplicate_Overwrites) {
EXPECT_EQ(result.value(), "overwritten_value");
}

// KeyValueTryGet reports NotFound for a key that has not been set, and
// returns the stored value once the key has been set.
TEST_F(ClientServerTest, KeyValueTryGet) {
  StartService(/*num_nodes=*/1);
  auto client = GetClient(/*node_id=*/0);
  TF_ASSERT_OK(client->Connect());

  // Before the key exists the lookup must fail with NotFound.
  auto missing = client->KeyValueTryGet("test_key");
  ASSERT_THAT(missing.status(), StatusIs(absl::StatusCode::kNotFound));

  // After a Set, the same lookup yields the stored value.
  TF_ASSERT_OK(client->KeyValueSet("test_key", "value"));
  auto found = client->KeyValueTryGet("test_key");
  TF_ASSERT_OK(found.status());
  EXPECT_EQ(*found, "value");
}

TEST_F(ClientServerTest, KeyValueDelete) {
StartService(/*num_nodes=*/1);
auto client = GetClient(/*node_id=*/0);
Expand Down
Loading

0 comments on commit 313d56f

Please sign in to comment.