Skip to content

Commit

Permalink
Merge branch 'main' into corendos/pjrt-execution-context
Browse files Browse the repository at this point in the history
  • Loading branch information
Corendos committed Dec 19, 2024
2 parents d0b15ef + 313d56f commit 034e6cc
Show file tree
Hide file tree
Showing 43 changed files with 593 additions and 280 deletions.
11 changes: 11 additions & 0 deletions xla/hlo/transforms/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1335,6 +1335,9 @@ cc_library(
"//xla/hlo/analysis:hlo_dataflow_analysis",
"//xla/hlo/ir:hlo",
"//xla/hlo/pass:hlo_pass",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:string_view",
],
)

Expand All @@ -1346,6 +1349,10 @@ xla_cc_test(
"//xla/hlo/parser:hlo_parser",
"//xla/hlo/testlib:hlo_hardware_independent_test_base",
"//xla/tsl/lib/core:status_test_util",
"@com_google_absl//absl/hash",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings:string_view",
"@com_google_googletest//:gtest",
"@tsl//tsl/platform:test_main",
],
)
Expand Down Expand Up @@ -1838,6 +1845,7 @@ cc_library(
"@com_google_absl//absl/log",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@tsl//tsl/platform:errors",
Expand Down Expand Up @@ -1879,6 +1887,7 @@ cc_library(
"//xla:side_effect_util",
"//xla:status_macros",
"//xla:util",
"//xla:xla_data_proto_cc",
"//xla/hlo/analysis:hlo_alias_analysis",
"//xla/hlo/ir:hlo",
"//xla/hlo/pass:hlo_pass",
Expand Down Expand Up @@ -2274,10 +2283,12 @@ xla_cc_test(
deps = [
":operand_upcaster",
"//xla:shape_util",
"//xla:xla_data_proto_cc",
"//xla/hlo/ir:hlo",
"//xla/hlo/testlib:hlo_hardware_independent_test_base",
"//xla/hlo/utils:hlo_matchers",
"@com_google_absl//absl/strings",
"@com_google_googletest//:gtest",
"@tsl//tsl/platform:statusor",
"@tsl//tsl/platform:test_main",
],
Expand Down
1 change: 0 additions & 1 deletion xla/hlo/transforms/host_offload_legalize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ limitations under the License.
#include <memory>
#include <queue>
#include <string>
#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
Expand Down
2 changes: 2 additions & 0 deletions xla/hlo/transforms/host_offload_legalize.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@

#include <cstdint>
#include <memory>
#include <vector>

#include "absl/container/flat_hash_set.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "xla/hlo/analysis/hlo_alias_analysis.h"
#include "xla/hlo/pass/hlo_pass_interface.h"
Expand Down
3 changes: 0 additions & 3 deletions xla/hlo/transforms/host_offload_legalize_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,9 @@ limitations under the License.
#include "xla/hlo/transforms/host_offload_legalize.h"

#include <cstdint>
#include <stack>
#include <string>

#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "absl/container/flat_hash_set.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "xla/hlo/ir/hlo_computation.h"
Expand Down
8 changes: 2 additions & 6 deletions xla/hlo/transforms/host_offloader.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,10 @@ limitations under the License.

#include "xla/hlo/transforms/host_offloader.h"

#include <array>
#include <cstddef>
#include <cstdint>
#include <iomanip>
#include <memory>
#include <optional>
#include <queue>
#include <string>
#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
Expand All @@ -35,7 +30,7 @@ limitations under the License.
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "absl/strings/string_view.h"
#include "xla/hlo/analysis/hlo_alias_analysis.h"
#include "xla/hlo/ir/hlo_casting_utils.h"
#include "xla/hlo/ir/hlo_computation.h"
Expand All @@ -56,6 +51,7 @@ limitations under the License.
#include "xla/side_effect_util.h"
#include "xla/status_macros.h"
#include "xla/util.h"
#include "xla/xla_data.pb.h"
#include "tsl/platform/errors.h"
#include "tsl/platform/status.h"
#include "tsl/platform/statusor.h"
Expand Down
3 changes: 3 additions & 0 deletions xla/hlo/transforms/host_offloader.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,11 @@
#include <cstdint>
#include <memory>
#include <string>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "xla/hlo/analysis/hlo_alias_analysis.h"
Expand Down
1 change: 0 additions & 1 deletion xla/hlo/transforms/host_offloader_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ limitations under the License.

#include <cstdint>
#include <memory>
#include <stack>
#include <string>
#include <vector>

Expand Down
4 changes: 4 additions & 0 deletions xla/hlo/transforms/memory_space_propagation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,11 @@ limitations under the License.
#include "xla/hlo/transforms/memory_space_propagation.h"

#include <cstdint>
#include <utility>

#include "absl/container/flat_hash_set.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "xla/shape.h"
#include "xla/shape_util.h"

Expand Down
6 changes: 6 additions & 0 deletions xla/hlo/transforms/memory_space_propagation.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,12 @@ limitations under the License.
#ifndef XLA_HLO_TRANSFORMS_MEMORY_SPACE_PROPAGATION_H_
#define XLA_HLO_TRANSFORMS_MEMORY_SPACE_PROPAGATION_H_

#include <cstdint>
#include <memory>

#include "absl/container/flat_hash_set.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "xla/hlo/analysis/hlo_dataflow_analysis.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/hlo/pass/hlo_pass_interface.h"
Expand Down
4 changes: 4 additions & 0 deletions xla/hlo/transforms/memory_space_propagation_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@ limitations under the License.

#include "xla/hlo/transforms/memory_space_propagation.h"

#include <gtest/gtest.h>
#include "absl/hash/hash.h"
#include "absl/status/status.h"
#include "absl/strings/string_view.h"
#include "xla/hlo/parser/hlo_parser.h"
#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h"
#include "xla/tsl/lib/core/status_test_util.h"
Expand Down
3 changes: 3 additions & 0 deletions xla/hlo/transforms/operand_upcaster_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,15 @@ limitations under the License.
#include <memory>
#include <tuple>

#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "absl/strings/string_view.h"
#include "absl/strings/substitute.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h"
#include "xla/hlo/utils/hlo_matchers.h"
#include "xla/primitive_util.h"
#include "xla/xla_data.pb.h"
#include "tsl/platform/statusor.h"

namespace xla {
Expand Down
7 changes: 1 addition & 6 deletions xla/pjrt/c/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,12 +1,7 @@
# PJRT C API changelog

## 0.62
* Added ``context`` field of type ``PJRT_ExecuteContext *`` in ``PJRT_ExecuteOptions``.

## 0.61
* Added ``PJRT_KeyValueTryGet`` to the KV store interface,
which is non-blocking and immediately returns an error if the
key is not found.
* Added ``context`` field of type ``PJRT_ExecuteContext *`` in ``PJRT_ExecuteOptions``.

## 0.60
* Added ``PJRT_Client_CreateBuffersForAsyncHostToDevice`` and ``PJRT_AsyncHostToDeviceTransferManager_TransferRawDataToSubBuffer``.
Expand Down
40 changes: 2 additions & 38 deletions xla/pjrt/c/pjrt_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ PJRT_DEFINE_STRUCT_TRAITS(PJRT_Extension_Base, next);
// Changes include:
// * Adding a new field to the PJRT_Api or argument structs
// * Renaming a method or argument (doesn't affect ABI)
#define PJRT_API_MINOR 62
#define PJRT_API_MINOR 61

// The plugin should set the major_version and minor_version of
// PJRT_Api.pjrt_api_version to be the `PJRT_API_MAJOR` and `PJRT_API_MINOR` in
Expand Down Expand Up @@ -351,35 +351,6 @@ PJRT_DEFINE_STRUCT_TRAITS(PJRT_KeyValueGetCallback_Args,
typedef PJRT_Error* (*PJRT_KeyValueGetCallback)(
PJRT_KeyValueGetCallback_Args* args);

// Same as KeyValueGet, but returns `NotFoundError` immediately if the key is
// not found.
typedef void (*PJRT_KeyValueTryGetCallback_ValueDeleter)(char* value);

// Argument block handed to a PJRT_KeyValueTryGetCallback. Fields tagged
// "out" are written by the callback; all other fields are inputs supplied by
// the caller of the callback.
struct PJRT_KeyValueTryGetCallback_Args {
size_t struct_size;
PJRT_Extension_Base* extension_start;
// Key to look up. `key` is paired with an explicit length (`key_size`).
const char* key;
size_t key_size;
// Invoked by the callback implementation to build the PJRT_Error it returns
// when the lookup fails.
PJRT_CallbackError* callback_error;
// Opaque pointer forwarded unchanged to the callback implementation.
void* user_arg;
char* value; // out
size_t value_size; // out
// The caller needs to set a PJRT_KeyValueTryGetCallback_ValueDeleter to
// delete the value returned by PJRT_KeyValueTryGetCallback. The
// implementation is responsible for copying `value` and then calling
// value_deleter_callback.
PJRT_KeyValueTryGetCallback_ValueDeleter value_deleter_callback; // out
};
// Registers `value_deleter_callback` (the final member declared above) with
// the struct-traits macro.
// NOTE(review): presumably this must always name the last field so that
// struct_size checks stay correct — confirm against the macro definition.
PJRT_DEFINE_STRUCT_TRAITS(PJRT_KeyValueTryGetCallback_Args,
value_deleter_callback);

// Requirements for PJRT_KeyValueTryGetCallback implementation: (1) Thread-safe.
// (2) The caller that provides the two callbacks is responsible for avoiding
// key collisions between different users of key-value store (i.e. between
// different plugins, but not between different nodes in one plugin).
typedef PJRT_Error* (*PJRT_KeyValueTryGetCallback)(
PJRT_KeyValueTryGetCallback_Args* args);

struct PJRT_KeyValuePutCallback_Args {
size_t struct_size;
PJRT_Extension_Base* extension_start;
Expand Down Expand Up @@ -418,15 +389,8 @@ struct PJRT_Client_Create_Args {
void* kv_put_user_arg;

PJRT_Client* client; // out

// Key-value try-get callback provided by the caller of PJRT_Client_Create.
// Same as key-value get callback, but returns `NotFoundError` immediately if
// the key is not found.
PJRT_KeyValueTryGetCallback kv_try_get_callback;
// Will be passed to `kv_try_get_callback` as `user_arg` argument.
void* kv_try_get_user_arg;
};
PJRT_DEFINE_STRUCT_TRAITS(PJRT_Client_Create_Args, kv_try_get_user_arg);
PJRT_DEFINE_STRUCT_TRAITS(PJRT_Client_Create_Args, client);

// Creates and initializes a new PJRT_Client and returns in `client`.
typedef PJRT_Error* PJRT_Client_Create(PJRT_Client_Create_Args* args);
Expand Down
6 changes: 3 additions & 3 deletions xla/pjrt/c/pjrt_c_api_gpu_internal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -154,9 +154,9 @@ PJRT_Error* PJRT_Client_Create(PJRT_Client_Create_Args* args) {
options.num_nodes = num_nodes;
options.allowed_devices = visible_devices;
options.platform_name = platform_name;
options.kv_store = pjrt::ToCppKeyValueStore(
args->kv_get_callback, args->kv_get_user_arg, args->kv_try_get_callback,
args->kv_try_get_user_arg, args->kv_put_callback, args->kv_put_user_arg);
options.kv_store =
pjrt::ToCppKeyValueStore(args->kv_get_callback, args->kv_get_user_arg,
args->kv_put_callback, args->kv_put_user_arg);
options.enable_mock_nccl = enable_mock_nccl;
options.mock_gpu_topology = mock_gpu_topology;
PJRT_ASSIGN_OR_RETURN(std::unique_ptr<xla::PjRtClient> client,
Expand Down
38 changes: 0 additions & 38 deletions xla/pjrt/c/pjrt_c_api_helpers.cc
Original file line number Diff line number Diff line change
Expand Up @@ -795,25 +795,6 @@ static PJRT_KeyValueGetCFunc ToKVGetCFunc(
};
}

// Adapts `kv_store->TryGet` to the PJRT_KeyValueTryGetCFunc signature.
//
// The returned functor captures the raw `kv_store` pointer by value.
// NOTE(review): the store presumably has to outlive the functor — confirm at
// the call site that something keeps it alive.
//
// On failure the functor returns the PJRT_Error produced by
// `args->callback_error`; on success it fills the out-fields of `args` and
// returns nullptr.
static PJRT_KeyValueTryGetCFunc ToKVTryGetCFunc(
xla::KeyValueStoreInterface* kv_store) {
return [kv_store](PJRT_KeyValueTryGetCallback_Args* args) -> PJRT_Error* {
// The key is passed as pointer + length, so it is wrapped in a string_view.
absl::StatusOr<std::string> output =
kv_store->TryGet(absl::string_view(args->key, args->key_size));
if (!output.ok()) {
// Translate the absl status code and message into a PJRT error through the
// caller-provided error factory.
absl::string_view message = output.status().message();
return (*args->callback_error)(
StatusCodeToPjrtErrorCode(output.status().code()), message.data(),
message.size());
}
// Hand the value back in a freshly allocated buffer of exactly
// `value_size` bytes (no terminator is appended); the receiver releases it
// through `value_deleter_callback`.
args->value = new char[output->size()];
std::copy(output->begin(), output->end(), args->value);
args->value_size = output->size();
args->value_deleter_callback = &PjRtValueDeleterCallback;
return nullptr;
};
}

static PJRT_KeyValuePutCFunc ToKVPutCFunc(
xla::KeyValueStoreInterface* kv_store) {
return [kv_store](PJRT_KeyValuePutCallback_Args* args) -> PJRT_Error* {
Expand Down Expand Up @@ -845,22 +826,6 @@ static PJRT_KeyValueGetCallback ToCKVGetCallback(
};
}

// Produces the plain C callback that forwards a try-get request to a
// PJRT_KeyValueTryGetCFunc.
//
// The returned lambda captures nothing: the functor it dispatches to is
// recovered from `args->user_arg` on every call. The `kv_try_get_c_func`
// parameter of this function is therefore not read by the lambda; the local
// of the same name inside the lambda is a separate variable.
static PJRT_KeyValueTryGetCallback ToCKVTryGetCallback(
PJRT_KeyValueTryGetCFunc* kv_try_get_c_func) {
return [](PJRT_KeyValueTryGetCallback_Args* args) -> PJRT_Error* {
PJRT_KeyValueTryGetCFunc* kv_try_get_c_func =
reinterpret_cast<PJRT_KeyValueTryGetCFunc*>(args->user_arg);
if (kv_try_get_c_func == nullptr) {
// A missing user_arg is reported as an InvalidArgument error via the
// caller-provided error factory.
// NOTE(review): the message names `PJRT_KeyValueTryGet_Args`, while the
// struct is declared as `PJRT_KeyValueTryGetCallback_Args` — confirm
// whether the wording is intentional.
absl::Status status = xla::InvalidArgument(
"got nullptr for PJRT_KeyValueTryGet_Args.user_arg");
return (*args->callback_error)(StatusCodeToPjrtErrorCode(status.code()),
status.message().data(),
status.message().size());
}
return (*kv_try_get_c_func)(args);
};
}

static PJRT_KeyValuePutCallback ToCKVPutCallback(
PJRT_KeyValuePutCFunc* kv_put_c_func) {
return [](PJRT_KeyValuePutCallback_Args* args) -> PJRT_Error* {
Expand All @@ -881,12 +846,9 @@ std::unique_ptr<PJRT_KeyValueCallbackData> ConvertToCKeyValueCallbacks(
std::shared_ptr<xla::KeyValueStoreInterface> kv_store) {
auto kv_callback_data = std::make_unique<PJRT_KeyValueCallbackData>();
kv_callback_data->kv_get_c_func = ToKVGetCFunc(kv_store.get());
kv_callback_data->kv_try_get_c_func = ToKVTryGetCFunc(kv_store.get());
kv_callback_data->kv_put_c_func = ToKVPutCFunc(kv_store.get());
kv_callback_data->c_kv_get =
ToCKVGetCallback(&kv_callback_data->kv_get_c_func);
kv_callback_data->c_kv_try_get =
ToCKVTryGetCallback(&kv_callback_data->kv_try_get_c_func);
kv_callback_data->c_kv_put =
ToCKVPutCallback(&kv_callback_data->kv_put_c_func);
kv_callback_data->kv_store = std::move(kv_store);
Expand Down
17 changes: 5 additions & 12 deletions xla/pjrt/c/pjrt_c_api_helpers.h
Original file line number Diff line number Diff line change
Expand Up @@ -218,9 +218,6 @@ int GetId(const PJRT_Api* api, PJRT_DeviceDescription* device_desc);
using PJRT_KeyValueGetCFunc =
std::function<PJRT_Error*(PJRT_KeyValueGetCallback_Args* args)>;

using PJRT_KeyValueTryGetCFunc =
std::function<PJRT_Error*(PJRT_KeyValueTryGetCallback_Args* args)>;

using PJRT_KeyValuePutCFunc =
std::function<PJRT_Error*(PJRT_KeyValuePutCallback_Args* args)>;

Expand All @@ -231,21 +228,17 @@ struct PJRT_KeyValueCallbackData {

std::shared_ptr<xla::KeyValueStoreInterface> kv_store;

// kv_get_c_func, kv_try_get_c_func and kv_put_c_func are holding pointers to
// kv_store.
// kv_get_c_func and kv_put_c_func are holding pointers to kv_store.
pjrt::PJRT_KeyValueGetCFunc kv_get_c_func;
pjrt::PJRT_KeyValuePutCFunc kv_put_c_func;
// c_kv_get, c_kv_try_get and c_kv_put are holding pointers to kv_get_c_func,
// kv_try_get_c_func and kv_put_c_func.
// c_kv_get and c_kv_put are holding pointers to kv_get_c_func and
// kv_put_c_func.
PJRT_KeyValueGetCallback c_kv_get;
PJRT_KeyValuePutCallback c_kv_put;
pjrt::PJRT_KeyValueTryGetCFunc kv_try_get_c_func;
PJRT_KeyValueTryGetCallback c_kv_try_get;
};

// The returned &kv_get_c_func, &kv_try_get_c_func and &kv_put_c_func must be
// set as PJRT_Client_Create_Args.kv_get_user_arg,
// PJRT_Client_Create_Args.kv_try_get_user_arg and
// The returned &kv_get_c_func and &kv_put_c_func must be set as
// PJRT_Client_Create_Args.kv_get_user_arg and
// PJRT_Client_Create_Args.kv_put_user_arg, respectively. The entire
// PJRT_KeyValueCallbackData must be kept alive as long as c_kv_get and c_kv_put
// may be called.
Expand Down
8 changes: 0 additions & 8 deletions xla/pjrt/c/pjrt_c_api_helpers_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -108,22 +108,14 @@ TEST(PjRtCApiHelperTest, Callback) {
auto kv_callback_data = ConvertToCKeyValueCallbacks(kv_store);
auto converted_kv_store = ToCppKeyValueStore(
kv_callback_data->c_kv_get, &kv_callback_data->kv_get_c_func,
kv_callback_data->c_kv_try_get, &kv_callback_data->kv_try_get_c_func,
kv_callback_data->c_kv_put, &kv_callback_data->kv_put_c_func);

auto v_not_found = converted_kv_store->Get("key", absl::Seconds(1));
EXPECT_TRUE(absl::IsNotFound(v_not_found.status())) << v_not_found.status();

auto s = converted_kv_store->Set("key", "value");
TF_EXPECT_OK(s);

auto v = converted_kv_store->Get("key", absl::Seconds(1));
TF_EXPECT_OK(v.status());
EXPECT_EQ(*v, "value");

auto v_2 = converted_kv_store->TryGet("key");
TF_EXPECT_OK(v.status());
EXPECT_EQ(*v, "value");
}

TEST(PjRtCApiHelperTest, ConvertToCLayoutFromStrides) {
Expand Down
Loading

0 comments on commit 034e6cc

Please sign in to comment.