From 313d56fc66638fc32abdba49f2614b54df51f900 Mon Sep 17 00:00:00 2001 From: xla authors Date: Thu, 19 Dec 2024 06:08:08 -0800 Subject: [PATCH] Rollback breaking C API changes (TryGetKeyValue()). Reverts 926ef6acf5ef98b363127a115c4614ec817d4804 PiperOrigin-RevId: 707888995 --- xla/pjrt/c/CHANGELOG.md | 6 --- xla/pjrt/c/pjrt_c_api.h | 40 +------------------ xla/pjrt/c/pjrt_c_api_gpu_internal.cc | 6 +-- xla/pjrt/c/pjrt_c_api_helpers.cc | 38 ------------------ xla/pjrt/c/pjrt_c_api_helpers.h | 17 +++----- xla/pjrt/c/pjrt_c_api_helpers_test.cc | 8 ---- xla/pjrt/c/pjrt_c_api_test_base.cc | 4 +- xla/pjrt/c/pjrt_c_api_wrapper_impl.cc | 36 ++--------------- xla/pjrt/c/pjrt_c_api_wrapper_impl.h | 1 - xla/pjrt/distributed/client.cc | 12 ------ xla/pjrt/distributed/client.h | 4 -- xla/pjrt/distributed/client_server_test.cc | 14 ------- .../distributed/in_memory_key_value_store.cc | 12 ------ .../distributed/in_memory_key_value_store.h | 4 -- .../distributed/key_value_store_interface.h | 7 ---- xla/pjrt/pjrt_c_api_client.cc | 2 - xla/python/xla.cc | 15 ------- xla/python/xla_extension/__init__.pyi | 2 - 18 files changed, 14 insertions(+), 214 deletions(-) diff --git a/xla/pjrt/c/CHANGELOG.md b/xla/pjrt/c/CHANGELOG.md index d56741eb3500b..5852c9a54dcc0 100644 --- a/xla/pjrt/c/CHANGELOG.md +++ b/xla/pjrt/c/CHANGELOG.md @@ -1,10 +1,4 @@ # PJRT C API changelog - -## 0.61 -* Added ``PJRT_KeyValueTryGet`` to the KV store interface, - which is non-blocking and immediately returns an error if the - key is not found. - ## 0.60 * Added ``PJRT_Client_CreateBuffersForAsyncHostToDevice`` and ``PJRT_AsyncHostToDeviceTransferManager_TransferRawDataToSubBuffer``. diff --git a/xla/pjrt/c/pjrt_c_api.h b/xla/pjrt/c/pjrt_c_api.h index f2fc3b1c507a3..36d82b0787ba4 100644 --- a/xla/pjrt/c/pjrt_c_api.h +++ b/xla/pjrt/c/pjrt_c_api.h @@ -80,7 +80,7 @@ PJRT_DEFINE_STRUCT_TRAITS(PJRT_Extension_Base, next); // Changes include: // * Adding a new field to the PJRT_Api or argument structs // * Renaming a method or argument (doesn't affect ABI) -#define PJRT_API_MINOR 61 +#define PJRT_API_MINOR 60 // The plugin should set the major_version and minor_version of // PJRT_Api.pjrt_api_version to be the `PJRT_API_MAJOR` and `PJRT_API_MINOR` in @@ -351,35 +351,6 @@ PJRT_DEFINE_STRUCT_TRAITS(PJRT_KeyValueGetCallback_Args, typedef PJRT_Error* (*PJRT_KeyValueGetCallback)( PJRT_KeyValueGetCallback_Args* args); -// Same as KeyValueGet, but returns `NotFoundError` immediately if the key is -// not found. -typedef void (*PJRT_KeyValueTryGetCallback_ValueDeleter)(char* value); - -struct PJRT_KeyValueTryGetCallback_Args { - size_t struct_size; - PJRT_Extension_Base* extension_start; - const char* key; - size_t key_size; - PJRT_CallbackError* callback_error; - void* user_arg; - char* value; // out - size_t value_size; // out - // The caller needs to set a PJRT_KeyValueTryGetCallback_ValueDeleter to - // delete the value returned by PJRT_KeyValueTryGetCallback. The - // implementation is responsible for copying `value` and then calling - // value_deleter_callback. - PJRT_KeyValueTryGetCallback_ValueDeleter value_deleter_callback; // out -}; -PJRT_DEFINE_STRUCT_TRAITS(PJRT_KeyValueTryGetCallback_Args, - value_deleter_callback); - -// Requirements for PJRT_KeyValueTryGetCallback implementation: (1) Thread-safe. -// (2) The caller that provides the two callbacks is responsible for avoiding -// key collisions between different users of key-value store (i.e. between -// different plugins, but not between different nodes in one plugin). -typedef PJRT_Error* (*PJRT_KeyValueTryGetCallback)( - PJRT_KeyValueTryGetCallback_Args* args); - struct PJRT_KeyValuePutCallback_Args { size_t struct_size; PJRT_Extension_Base* extension_start; @@ -418,15 +389,8 @@ struct PJRT_Client_Create_Args { void* kv_put_user_arg; PJRT_Client* client; // out - - // Key-value try-get callback provided by the caller of PJRT_Client_Create. - // Same as key-value get callback, but returns `NotFoundError` immediately if - // the key is not found. - PJRT_KeyValueTryGetCallback kv_try_get_callback; - // Will be passed to `kv_try_get_callback` as `user_arg` argument. - void* kv_try_get_user_arg; }; -PJRT_DEFINE_STRUCT_TRAITS(PJRT_Client_Create_Args, kv_try_get_user_arg); +PJRT_DEFINE_STRUCT_TRAITS(PJRT_Client_Create_Args, client); // Creates and initializes a new PJRT_Client and returns in `client`. typedef PJRT_Error* PJRT_Client_Create(PJRT_Client_Create_Args* args); diff --git a/xla/pjrt/c/pjrt_c_api_gpu_internal.cc b/xla/pjrt/c/pjrt_c_api_gpu_internal.cc index 68d36fdb7f5c8..4f53c640a6a3d 100644 --- a/xla/pjrt/c/pjrt_c_api_gpu_internal.cc +++ b/xla/pjrt/c/pjrt_c_api_gpu_internal.cc @@ -154,9 +154,9 @@ PJRT_Error* PJRT_Client_Create(PJRT_Client_Create_Args* args) { options.num_nodes = num_nodes; options.allowed_devices = visible_devices; options.platform_name = platform_name; - options.kv_store = pjrt::ToCppKeyValueStore( - args->kv_get_callback, args->kv_get_user_arg, args->kv_try_get_callback, - args->kv_try_get_user_arg, args->kv_put_callback, args->kv_put_user_arg); + options.kv_store = + pjrt::ToCppKeyValueStore(args->kv_get_callback, args->kv_get_user_arg, + args->kv_put_callback, args->kv_put_user_arg); options.enable_mock_nccl = enable_mock_nccl; options.mock_gpu_topology = mock_gpu_topology; PJRT_ASSIGN_OR_RETURN(std::unique_ptr client, diff --git a/xla/pjrt/c/pjrt_c_api_helpers.cc b/xla/pjrt/c/pjrt_c_api_helpers.cc index ca094063c412a..cf92041af497d 100644 --- a/xla/pjrt/c/pjrt_c_api_helpers.cc +++ b/xla/pjrt/c/pjrt_c_api_helpers.cc @@ -795,25 +795,6 @@ static PJRT_KeyValueGetCFunc ToKVGetCFunc( }; } -static PJRT_KeyValueTryGetCFunc ToKVTryGetCFunc( - xla::KeyValueStoreInterface* kv_store) { - return [kv_store](PJRT_KeyValueTryGetCallback_Args* args) -> PJRT_Error* { - absl::StatusOr output = - kv_store->TryGet(absl::string_view(args->key, args->key_size)); - if (!output.ok()) { - absl::string_view message = output.status().message(); - return (*args->callback_error)( - StatusCodeToPjrtErrorCode(output.status().code()), message.data(), - message.size()); - } - args->value = new char[output->size()]; - std::copy(output->begin(), output->end(), args->value); - args->value_size = output->size(); - args->value_deleter_callback = &PjRtValueDeleterCallback; - return nullptr; - }; -} - static PJRT_KeyValuePutCFunc ToKVPutCFunc( xla::KeyValueStoreInterface* kv_store) { return [kv_store](PJRT_KeyValuePutCallback_Args* args) -> PJRT_Error* { @@ -845,22 +826,6 @@ static PJRT_KeyValueGetCallback ToCKVGetCallback( }; } -static PJRT_KeyValueTryGetCallback ToCKVTryGetCallback( - PJRT_KeyValueTryGetCFunc* kv_try_get_c_func) { - return [](PJRT_KeyValueTryGetCallback_Args* args) -> PJRT_Error* { - PJRT_KeyValueTryGetCFunc* kv_try_get_c_func = - reinterpret_cast(args->user_arg); - if (kv_try_get_c_func == nullptr) { - absl::Status status = xla::InvalidArgument( - "got nullptr for PJRT_KeyValueTryGet_Args.user_arg"); - return (*args->callback_error)(StatusCodeToPjrtErrorCode(status.code()), - status.message().data(), - status.message().size()); - } - return (*kv_try_get_c_func)(args); - }; -} - static PJRT_KeyValuePutCallback ToCKVPutCallback( PJRT_KeyValuePutCFunc* kv_put_c_func) { return [](PJRT_KeyValuePutCallback_Args* args) -> PJRT_Error* { @@ -881,12 +846,9 @@ std::unique_ptr ConvertToCKeyValueCallbacks( std::shared_ptr kv_store) { auto kv_callback_data = std::make_unique(); kv_callback_data->kv_get_c_func = ToKVGetCFunc(kv_store.get()); - kv_callback_data->kv_try_get_c_func = ToKVTryGetCFunc(kv_store.get()); kv_callback_data->kv_put_c_func = ToKVPutCFunc(kv_store.get()); kv_callback_data->c_kv_get = ToCKVGetCallback(&kv_callback_data->kv_get_c_func); - kv_callback_data->c_kv_try_get = - ToCKVTryGetCallback(&kv_callback_data->kv_try_get_c_func); kv_callback_data->c_kv_put = ToCKVPutCallback(&kv_callback_data->kv_put_c_func); kv_callback_data->kv_store = std::move(kv_store); diff --git a/xla/pjrt/c/pjrt_c_api_helpers.h b/xla/pjrt/c/pjrt_c_api_helpers.h index baae41fbeca28..f530b82f42357 100644 --- a/xla/pjrt/c/pjrt_c_api_helpers.h +++ b/xla/pjrt/c/pjrt_c_api_helpers.h @@ -218,9 +218,6 @@ int GetId(const PJRT_Api* api, PJRT_DeviceDescription* device_desc); using PJRT_KeyValueGetCFunc = std::function; -using PJRT_KeyValueTryGetCFunc = - std::function; - using PJRT_KeyValuePutCFunc = std::function; @@ -231,21 +228,17 @@ struct PJRT_KeyValueCallbackData { std::shared_ptr kv_store; - // kv_get_c_func, kv_try_get_c_func and kv_put_c_func are holding pointers to - // kv_store. + // kv_get_c_func and kv_put_c_func are holding pointers to kv_store. pjrt::PJRT_KeyValueGetCFunc kv_get_c_func; pjrt::PJRT_KeyValuePutCFunc kv_put_c_func; - // c_kv_get, c_kv_try_get and c_kv_put are holding pointers to kv_get_c_func, - // kv_try_get_c_func and kv_put_c_func. + // c_kv_get and c_kv_put are holding pointers to kv_get_c_func and + // kv_put_c_func. PJRT_KeyValueGetCallback c_kv_get; PJRT_KeyValuePutCallback c_kv_put; - pjrt::PJRT_KeyValueTryGetCFunc kv_try_get_c_func; - PJRT_KeyValueTryGetCallback c_kv_try_get; }; -// The returned &kv_get_c_func, &kv_try_get_c_func and &kv_put_c_func must be -// set as PJRT_Client_Create_Args.kv_get_user_arg, -// PJRT_Client_Create_Args.kv_try_get_user_arg and +// The returned &kv_get_c_func and &kv_put_c_func must be set as +// PJRT_Client_Create_Args.kv_get_user_arg and // PJRT_Client_Create_Args.kv_put_user_arg, respectively. The entire // PJRT_KeyValueCallbackData must be kept alive as long as c_kv_get and c_kv_put // may be called. diff --git a/xla/pjrt/c/pjrt_c_api_helpers_test.cc b/xla/pjrt/c/pjrt_c_api_helpers_test.cc index 6dfce81a1e451..4b8a59287589e 100644 --- a/xla/pjrt/c/pjrt_c_api_helpers_test.cc +++ b/xla/pjrt/c/pjrt_c_api_helpers_test.cc @@ -108,22 +108,14 @@ TEST(PjRtCApiHelperTest, Callback) { auto kv_callback_data = ConvertToCKeyValueCallbacks(kv_store); auto converted_kv_store = ToCppKeyValueStore( kv_callback_data->c_kv_get, &kv_callback_data->kv_get_c_func, - kv_callback_data->c_kv_try_get, &kv_callback_data->kv_try_get_c_func, kv_callback_data->c_kv_put, &kv_callback_data->kv_put_c_func); - auto v_not_found = converted_kv_store->Get("key", absl::Seconds(1)); - EXPECT_TRUE(absl::IsNotFound(v_not_found.status())) << v_not_found.status(); - auto s = converted_kv_store->Set("key", "value"); TF_EXPECT_OK(s); auto v = converted_kv_store->Get("key", absl::Seconds(1)); TF_EXPECT_OK(v.status()); EXPECT_EQ(*v, "value"); - - auto v_2 = converted_kv_store->TryGet("key"); - TF_EXPECT_OK(v.status()); - EXPECT_EQ(*v, "value"); } TEST(PjRtCApiHelperTest, ConvertToCLayoutFromStrides) { diff --git a/xla/pjrt/c/pjrt_c_api_test_base.cc b/xla/pjrt/c/pjrt_c_api_test_base.cc index f867846ebcbd5..9602813c573c5 100644 --- a/xla/pjrt/c/pjrt_c_api_test_base.cc +++ b/xla/pjrt/c/pjrt_c_api_test_base.cc @@ -47,11 +47,9 @@ PJRT_Client* CreateClient(const PJRT_Api* api) { create_args.create_options = nullptr; create_args.num_options = 0; create_args.kv_get_callback = nullptr; - create_args.kv_get_user_arg = nullptr; create_args.kv_put_callback = nullptr; create_args.kv_put_user_arg = nullptr; - create_args.kv_try_get_callback = nullptr; - create_args.kv_try_get_user_arg = nullptr; + create_args.kv_get_user_arg = nullptr; PJRT_Error* error = api->PJRT_Client_Create(&create_args); CHECK_EQ(error, nullptr); CHECK_NE(create_args.client, nullptr); diff --git a/xla/pjrt/c/pjrt_c_api_wrapper_impl.cc b/xla/pjrt/c/pjrt_c_api_wrapper_impl.cc index 222d689b3b68e..ec697b08af784 100644 --- a/xla/pjrt/c/pjrt_c_api_wrapper_impl.cc +++ b/xla/pjrt/c/pjrt_c_api_wrapper_impl.cc @@ -235,13 +235,9 @@ static absl::Status PopulateExecutableOutputMemoryKinds( class CApiKeyValueStore : public xla::KeyValueStoreInterface { public: CApiKeyValueStore(PJRT_KeyValueGetCallback c_get_callback, void* get_user_arg, - PJRT_KeyValueTryGetCallback c_try_get_callback, - void* try_get_user_arg, PJRT_KeyValuePutCallback c_put_callback, void* put_user_arg) : c_get_callback_(c_get_callback), get_user_arg_(get_user_arg), - c_try_get_callback_(c_try_get_callback), - try_get_user_arg_(try_get_user_arg), c_put_callback_(c_put_callback), put_user_arg_(put_user_arg) {} @@ -268,27 +264,6 @@ class CApiKeyValueStore : public xla::KeyValueStoreInterface { return result; } - absl::StatusOr TryGet(absl::string_view key) override { - PJRT_CallbackError callback_error = [](PJRT_Error_Code code, - const char* message, - size_t message_size) { - return new PJRT_Error{absl::Status(static_cast(code), - std::string(message, message_size))}; - }; - PJRT_KeyValueTryGetCallback_Args args; - args.key = key.data(); - args.key_size = key.size(); - args.callback_error = &callback_error; - args.user_arg = try_get_user_arg_; - std::unique_ptr error(c_try_get_callback_(&args)); - if (error != nullptr) { - return error->status; - } - auto result = std::string(args.value, args.value_size); - args.value_deleter_callback(args.value); - return result; - } - absl::Status Set(absl::string_view key, absl::string_view value) override { PJRT_CallbackError callback_error = [](PJRT_Error_Code code, const char* message, @@ -313,23 +288,18 @@ class CApiKeyValueStore : public xla::KeyValueStoreInterface { private: PJRT_KeyValueGetCallback c_get_callback_; void* get_user_arg_; - PJRT_KeyValueTryGetCallback c_try_get_callback_; - void* try_get_user_arg_; PJRT_KeyValuePutCallback c_put_callback_; void* put_user_arg_; }; std::shared_ptr ToCppKeyValueStore( PJRT_KeyValueGetCallback c_get_callback, void* get_user_arg, - PJRT_KeyValueTryGetCallback c_try_get_callback, void* try_get_user_arg, PJRT_KeyValuePutCallback c_put_callback, void* put_user_arg) { - if (c_get_callback == nullptr || c_try_get_callback == nullptr || - c_put_callback == nullptr) { + if (c_get_callback == nullptr || c_put_callback == nullptr) { return nullptr; } - return std::make_shared( - c_get_callback, get_user_arg, c_try_get_callback, try_get_user_arg, - c_put_callback, put_user_arg); + return std::make_shared(c_get_callback, get_user_arg, + c_put_callback, put_user_arg); } // ---------------------------------- Errors ----------------------------------- diff --git a/xla/pjrt/c/pjrt_c_api_wrapper_impl.h b/xla/pjrt/c/pjrt_c_api_wrapper_impl.h index 873845d3ac815..0ebecc0c25173 100644 --- a/xla/pjrt/c/pjrt_c_api_wrapper_impl.h +++ b/xla/pjrt/c/pjrt_c_api_wrapper_impl.h @@ -464,7 +464,6 @@ PJRT_Client* CreateWrapperClient(std::unique_ptr cpp_client); // Helper functions for converting C key-value store callbacks to C++ callbacks. std::shared_ptr ToCppKeyValueStore( PJRT_KeyValueGetCallback c_get_callback, void* get_user_arg, - PJRT_KeyValueTryGetCallback c_try_get_callback, void* try_get_user_arg, PJRT_KeyValuePutCallback c_put_callback, void* put_user_arg); // A method that does not nothing other than returning a nullptr. Can be used as diff --git a/xla/pjrt/distributed/client.cc b/xla/pjrt/distributed/client.cc index 305afe7ae4c6d..280c60873e9d0 100644 --- a/xla/pjrt/distributed/client.cc +++ b/xla/pjrt/distributed/client.cc @@ -26,7 +26,6 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" -#include "absl/strings/string_view.h" #include "absl/time/time.h" #include "absl/types/span.h" #include "grpcpp/channel.h" @@ -54,7 +53,6 @@ class DistributedRuntimeCoordinationServiceClient absl::Status Shutdown() override; absl::StatusOr BlockingKeyValueGet( absl::string_view key, absl::Duration timeout) override; - absl::StatusOr KeyValueTryGet(absl::string_view key) override; absl::StatusOr>> KeyValueDirGet(absl::string_view key) override; absl::Status KeyValueSet(absl::string_view key, @@ -146,12 +144,6 @@ DistributedRuntimeCoordinationServiceClient::BlockingKeyValueGet( return coord_agent_->GetKeyValue(key, timeout); } -absl::StatusOr -DistributedRuntimeCoordinationServiceClient::KeyValueTryGet( - absl::string_view key) { - return coord_agent_->TryGetKeyValue(key); -} - absl::StatusOr>> DistributedRuntimeCoordinationServiceClient::KeyValueDirGet( absl::string_view key) { @@ -224,10 +216,6 @@ class DistributedKeyValueStore : public KeyValueStoreInterface { return client_->BlockingKeyValueGet(absl::StrCat(prefix_, key), timeout); } - absl::StatusOr TryGet(absl::string_view key) override { - return client_->KeyValueTryGet(absl::StrCat(prefix_, key)); - } - absl::Status Set(absl::string_view key, absl::string_view value) override { return client_->KeyValueSet(absl::StrCat(prefix_, key), value); } diff --git a/xla/pjrt/distributed/client.h b/xla/pjrt/distributed/client.h index 58f4fe367681d..e597ff158cc67 100644 --- a/xla/pjrt/distributed/client.h +++ b/xla/pjrt/distributed/client.h @@ -27,7 +27,6 @@ limitations under the License. #include "absl/log/log.h" #include "absl/status/status.h" #include "absl/status/statusor.h" -#include "absl/strings/string_view.h" #include "absl/time/time.h" #include "absl/types/span.h" #include "grpcpp/channel.h" @@ -117,9 +116,6 @@ class DistributedRuntimeClient { virtual absl::StatusOr BlockingKeyValueGet( absl::string_view key, absl::Duration timeout) = 0; - // Returns `NotFoundError` immediately if the key is not found. - virtual absl::StatusOr KeyValueTryGet(absl::string_view key) = 0; - // Get all key-value pairs under a directory (key). // A value is considered to be in the directory if its key is prefixed with // the directory. diff --git a/xla/pjrt/distributed/client_server_test.cc b/xla/pjrt/distributed/client_server_test.cc index baec103eced93..f5b7e656fe69a 100644 --- a/xla/pjrt/distributed/client_server_test.cc +++ b/xla/pjrt/distributed/client_server_test.cc @@ -1029,20 +1029,6 @@ TEST_F(ClientServerTest, KeyValueSet_Duplicate_Overwrites) { EXPECT_EQ(result.value(), "overwritten_value"); } -TEST_F(ClientServerTest, KeyValueTryGet) { - StartService(/*num_nodes=*/1); - auto client = GetClient(/*node_id=*/0); - TF_ASSERT_OK(client->Connect()); - - ASSERT_THAT(client->KeyValueTryGet("test_key").status(), - StatusIs(absl::StatusCode::kNotFound)); - - TF_ASSERT_OK(client->KeyValueSet("test_key", "value")); - auto result = client->KeyValueTryGet("test_key"); - TF_ASSERT_OK(result.status()); - EXPECT_EQ(result.value(), "value"); -} - TEST_F(ClientServerTest, KeyValueDelete) { StartService(/*num_nodes=*/1); auto client = GetClient(/*node_id=*/0); diff --git a/xla/pjrt/distributed/in_memory_key_value_store.cc b/xla/pjrt/distributed/in_memory_key_value_store.cc index 49fc73ec87f16..70cc5360ecf7b 100644 --- a/xla/pjrt/distributed/in_memory_key_value_store.cc +++ b/xla/pjrt/distributed/in_memory_key_value_store.cc @@ -20,7 +20,6 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" -#include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" #include "absl/time/time.h" @@ -41,17 +40,6 @@ absl::StatusOr InMemoryKeyValueStore::Get(absl::string_view key, return kv_store_.find(key)->second; } -absl::StatusOr InMemoryKeyValueStore::TryGet( - absl::string_view key) { - absl::MutexLock lock(&mu_); - auto it = kv_store_.find(key); - if (it == kv_store_.end()) { - return absl::NotFoundError( - absl::StrCat(key, " is not found in the kv store.")); - } - return it->second; -} - absl::Status InMemoryKeyValueStore::Set(absl::string_view key, absl::string_view value) { absl::MutexLock lock(&mu_); diff --git a/xla/pjrt/distributed/in_memory_key_value_store.h b/xla/pjrt/distributed/in_memory_key_value_store.h index 13f50c722bd12..1530633a98b75 100644 --- a/xla/pjrt/distributed/in_memory_key_value_store.h +++ b/xla/pjrt/distributed/in_memory_key_value_store.h @@ -21,9 +21,7 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/status/status.h" #include "absl/status/statusor.h" -#include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" -#include "absl/time/time.h" #include "xla/pjrt/distributed/key_value_store_interface.h" namespace xla { @@ -33,8 +31,6 @@ class InMemoryKeyValueStore : public KeyValueStoreInterface { absl::StatusOr Get(absl::string_view key, absl::Duration timeout) override; - absl::StatusOr TryGet(absl::string_view key) override; - absl::Status Set(absl::string_view key, absl::string_view value) override; private: diff --git a/xla/pjrt/distributed/key_value_store_interface.h b/xla/pjrt/distributed/key_value_store_interface.h index 312ebb8abb646..29580fb86847b 100644 --- a/xla/pjrt/distributed/key_value_store_interface.h +++ b/xla/pjrt/distributed/key_value_store_interface.h @@ -38,18 +38,11 @@ class KeyValueStoreInterface { virtual ~KeyValueStoreInterface() = default; // Blocking Get(). - // Useful for listening for a key-value pair that may be set later on. // There are no concurrency guarantees. To avoid a race / impose an ordering // on potentially concurrent ops (e.g. set, delete), use WaitAtBarrier(). virtual absl::StatusOr Get(absl::string_view key, absl::Duration timeout) = 0; - // Returns `NotFoundError` immediately if the key is not found. - // Useful for checking key existence. - // There are no concurrency guarantees. To avoid a race / impose an ordering - // on potentially concurrent ops (e.g. set, delete), use WaitAtBarrier(). - virtual absl::StatusOr TryGet(absl::string_view key) = 0; - virtual absl::Status Set(absl::string_view key, absl::string_view value) = 0; }; diff --git a/xla/pjrt/pjrt_c_api_client.cc b/xla/pjrt/pjrt_c_api_client.cc index 1f65b13109afc..8855ef33620e5 100644 --- a/xla/pjrt/pjrt_c_api_client.cc +++ b/xla/pjrt/pjrt_c_api_client.cc @@ -2578,8 +2578,6 @@ absl::StatusOr> WrapClientAroundCApi( kv_callback_data = pjrt::ConvertToCKeyValueCallbacks(kv_store); init_args.kv_get_callback = kv_callback_data->c_kv_get; init_args.kv_get_user_arg = &kv_callback_data->kv_get_c_func; - init_args.kv_try_get_callback = kv_callback_data->c_kv_try_get; - init_args.kv_try_get_user_arg = &kv_callback_data->kv_try_get_c_func; init_args.kv_put_callback = kv_callback_data->c_kv_put; init_args.kv_put_user_arg = &kv_callback_data->kv_put_c_func; } diff --git a/xla/python/xla.cc b/xla/python/xla.cc index e30af5d4e5e43..51c96229493e4 100644 --- a/xla/python/xla.cc +++ b/xla/python/xla.cc @@ -672,21 +672,6 @@ NB_MODULE(xla_extension, m) { return nb::bytes(result.data(), result.size()); }, nb::arg("key"), nb::arg("timeout_in_ms")) - .def( - "key_value_try_get", - [](DistributedRuntimeClient& client, std::string key) { - nb::gil_scoped_release gil_release; - return xla::ValueOrThrow(client.KeyValueTryGet(key)); - }, - nb::arg("key")) - .def( - "key_value_try_get_bytes", - [](DistributedRuntimeClient& client, std::string key) -> nb::bytes { - nb::gil_scoped_release gil_release; - std::string result = xla::ValueOrThrow(client.KeyValueTryGet(key)); - return nb::bytes(result.data(), result.size()); - }, - nb::arg("key")) .def( "wait_at_barrier", [](DistributedRuntimeClient& client, std::string barrier_id, diff --git a/xla/python/xla_extension/__init__.pyi b/xla/python/xla_extension/__init__.pyi index 5fa885f9f9225..2e3862285898f 100644 --- a/xla/python/xla_extension/__init__.pyi +++ b/xla/python/xla_extension/__init__.pyi @@ -830,8 +830,6 @@ class DistributedRuntimeClient: def blocking_key_value_get_bytes( self, key: str, timeout_in_ms: int ) -> _Status: ... - def key_value_try_get(self, key: str) -> _Status: ... - def key_value_try_get_bytes(self, key: str) -> _Status: ... def key_value_dir_get(self, key: str) -> _Status: ... def key_value_dir_get_bytes(self, key: str) -> _Status: ... def key_value_set(self, key: str, value: str,