diff --git a/xla/hlo/transforms/BUILD b/xla/hlo/transforms/BUILD index d895b3cc917306..84cf00b65702b1 100644 --- a/xla/hlo/transforms/BUILD +++ b/xla/hlo/transforms/BUILD @@ -1335,6 +1335,9 @@ cc_library( "//xla/hlo/analysis:hlo_dataflow_analysis", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", ], ) @@ -1346,6 +1349,10 @@ xla_cc_test( "//xla/hlo/parser:hlo_parser", "//xla/hlo/testlib:hlo_hardware_independent_test_base", "//xla/tsl/lib/core:status_test_util", + "@com_google_absl//absl/hash", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest", "@tsl//tsl/platform:test_main", ], ) @@ -1838,6 +1845,7 @@ cc_library( "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@tsl//tsl/platform:errors", @@ -1879,6 +1887,7 @@ cc_library( "//xla:side_effect_util", "//xla:status_macros", "//xla:util", + "//xla:xla_data_proto_cc", "//xla/hlo/analysis:hlo_alias_analysis", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", @@ -2274,10 +2283,12 @@ xla_cc_test( deps = [ ":operand_upcaster", "//xla:shape_util", + "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/testlib:hlo_hardware_independent_test_base", "//xla/hlo/utils:hlo_matchers", "@com_google_absl//absl/strings", + "@com_google_googletest//:gtest", "@tsl//tsl/platform:statusor", "@tsl//tsl/platform:test_main", ], diff --git a/xla/hlo/transforms/host_offload_legalize.cc b/xla/hlo/transforms/host_offload_legalize.cc index 639e37874ceb4b..5e70dbb26c7d21 100644 --- a/xla/hlo/transforms/host_offload_legalize.cc +++ b/xla/hlo/transforms/host_offload_legalize.cc @@ -20,7 +20,6 @@ limitations under the License. 
#include #include #include -#include #include #include "absl/algorithm/container.h" diff --git a/xla/hlo/transforms/host_offload_legalize.h b/xla/hlo/transforms/host_offload_legalize.h index a5d85fa40a8a5c..e08c842ee0bc68 100644 --- a/xla/hlo/transforms/host_offload_legalize.h +++ b/xla/hlo/transforms/host_offload_legalize.h @@ -17,8 +17,10 @@ #include #include +#include #include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "xla/hlo/analysis/hlo_alias_analysis.h" #include "xla/hlo/pass/hlo_pass_interface.h" diff --git a/xla/hlo/transforms/host_offload_legalize_test.cc b/xla/hlo/transforms/host_offload_legalize_test.cc index 4aedc40b8ca2be..a37a73fc149f9f 100644 --- a/xla/hlo/transforms/host_offload_legalize_test.cc +++ b/xla/hlo/transforms/host_offload_legalize_test.cc @@ -16,12 +16,9 @@ limitations under the License. #include "xla/hlo/transforms/host_offload_legalize.h" #include -#include #include -#include #include -#include "absl/container/flat_hash_set.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "xla/hlo/ir/hlo_computation.h" diff --git a/xla/hlo/transforms/host_offloader.cc b/xla/hlo/transforms/host_offloader.cc index 7b798fe38eef7b..833fa176b78b00 100644 --- a/xla/hlo/transforms/host_offloader.cc +++ b/xla/hlo/transforms/host_offloader.cc @@ -15,15 +15,10 @@ limitations under the License. #include "xla/hlo/transforms/host_offloader.h" -#include -#include #include #include #include -#include #include -#include -#include #include #include "absl/algorithm/container.h" @@ -35,7 +30,7 @@ limitations under the License. 
#include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" -#include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" #include "xla/hlo/analysis/hlo_alias_analysis.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" @@ -56,6 +51,7 @@ limitations under the License. #include "xla/side_effect_util.h" #include "xla/status_macros.h" #include "xla/util.h" +#include "xla/xla_data.pb.h" #include "tsl/platform/errors.h" #include "tsl/platform/status.h" #include "tsl/platform/statusor.h" diff --git a/xla/hlo/transforms/host_offloader.h b/xla/hlo/transforms/host_offloader.h index 765b3c2709856e..8e79a449261783 100644 --- a/xla/hlo/transforms/host_offloader.h +++ b/xla/hlo/transforms/host_offloader.h @@ -18,8 +18,11 @@ #include #include #include +#include +#include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "xla/hlo/analysis/hlo_alias_analysis.h" diff --git a/xla/hlo/transforms/host_offloader_test.cc b/xla/hlo/transforms/host_offloader_test.cc index 1452815127f1a7..d38526e93178af 100644 --- a/xla/hlo/transforms/host_offloader_test.cc +++ b/xla/hlo/transforms/host_offloader_test.cc @@ -17,7 +17,6 @@ limitations under the License. #include #include -#include #include #include diff --git a/xla/hlo/transforms/memory_space_propagation.cc b/xla/hlo/transforms/memory_space_propagation.cc index d0704df0e88af9..3dc14572dc408b 100644 --- a/xla/hlo/transforms/memory_space_propagation.cc +++ b/xla/hlo/transforms/memory_space_propagation.cc @@ -16,7 +16,11 @@ limitations under the License. 
#include "xla/hlo/transforms/memory_space_propagation.h" #include +#include +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "xla/shape.h" #include "xla/shape_util.h" diff --git a/xla/hlo/transforms/memory_space_propagation.h b/xla/hlo/transforms/memory_space_propagation.h index bb0da70bf1a7fc..b3998f542d39f5 100644 --- a/xla/hlo/transforms/memory_space_propagation.h +++ b/xla/hlo/transforms/memory_space_propagation.h @@ -16,6 +16,12 @@ limitations under the License. #ifndef XLA_HLO_TRANSFORMS_MEMORY_SPACE_PROPAGATION_H_ #define XLA_HLO_TRANSFORMS_MEMORY_SPACE_PROPAGATION_H_ +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "xla/hlo/analysis/hlo_dataflow_analysis.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/pass/hlo_pass_interface.h" diff --git a/xla/hlo/transforms/memory_space_propagation_test.cc b/xla/hlo/transforms/memory_space_propagation_test.cc index 15cd6c4cd4cbff..a1252d596ee281 100644 --- a/xla/hlo/transforms/memory_space_propagation_test.cc +++ b/xla/hlo/transforms/memory_space_propagation_test.cc @@ -15,6 +15,10 @@ limitations under the License. #include "xla/hlo/transforms/memory_space_propagation.h" +#include +#include "absl/hash/hash.h" +#include "absl/status/status.h" +#include "absl/strings/string_view.h" #include "xla/hlo/parser/hlo_parser.h" #include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" #include "xla/tsl/lib/core/status_test_util.h" diff --git a/xla/hlo/transforms/operand_upcaster_test.cc b/xla/hlo/transforms/operand_upcaster_test.cc index 8a143b365af618..ed61bb63d2dad6 100644 --- a/xla/hlo/transforms/operand_upcaster_test.cc +++ b/xla/hlo/transforms/operand_upcaster_test.cc @@ -18,12 +18,15 @@ limitations under the License. 
#include #include +#include +#include #include "absl/strings/string_view.h" #include "absl/strings/substitute.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" #include "xla/hlo/utils/hlo_matchers.h" #include "xla/primitive_util.h" +#include "xla/xla_data.pb.h" #include "tsl/platform/statusor.h" namespace xla { diff --git a/xla/pjrt/c/CHANGELOG.md b/xla/pjrt/c/CHANGELOG.md index 2d44c7dfac6fd2..c76806ef920d04 100644 --- a/xla/pjrt/c/CHANGELOG.md +++ b/xla/pjrt/c/CHANGELOG.md @@ -1,12 +1,7 @@ # PJRT C API changelog -## 0.62 -* Added ``context`` field of type ``PJRT_ExecuteContext *`` in ``PJRT_ExecuteOptions``. - ## 0.61 -* Added ``PJRT_KeyValueTryGet`` to the KV store interface, - which is non-blocking and immediately returns an error if the - key is not found. +* Added ``context`` field of type ``PJRT_ExecuteContext *`` in ``PJRT_ExecuteOptions``. ## 0.60 * Added ``PJRT_Client_CreateBuffersForAsyncHostToDevice`` and ``PJRT_AsyncHostToDeviceTransferManager_TransferRawDataToSubBuffer``. diff --git a/xla/pjrt/c/pjrt_c_api.h b/xla/pjrt/c/pjrt_c_api.h index 359c51a2e09676..f267db05fff1b7 100644 --- a/xla/pjrt/c/pjrt_c_api.h +++ b/xla/pjrt/c/pjrt_c_api.h @@ -80,7 +80,7 @@ PJRT_DEFINE_STRUCT_TRAITS(PJRT_Extension_Base, next); // Changes include: // * Adding a new field to the PJRT_Api or argument structs // * Renaming a method or argument (doesn't affect ABI) -#define PJRT_API_MINOR 62 +#define PJRT_API_MINOR 61 // The plugin should set the major_version and minor_version of // PJRT_Api.pjrt_api_version to be the `PJRT_API_MAJOR` and `PJRT_API_MINOR` in @@ -351,35 +351,6 @@ PJRT_DEFINE_STRUCT_TRAITS(PJRT_KeyValueGetCallback_Args, typedef PJRT_Error* (*PJRT_KeyValueGetCallback)( PJRT_KeyValueGetCallback_Args* args); -// Same as KeyValueGet, but returns `NotFoundError` immediately if the key is -// not found. 
-typedef void (*PJRT_KeyValueTryGetCallback_ValueDeleter)(char* value); - -struct PJRT_KeyValueTryGetCallback_Args { - size_t struct_size; - PJRT_Extension_Base* extension_start; - const char* key; - size_t key_size; - PJRT_CallbackError* callback_error; - void* user_arg; - char* value; // out - size_t value_size; // out - // The caller needs to set a PJRT_KeyValueTryGetCallback_ValueDeleter to - // delete the value returned by PJRT_KeyValueTryGetCallback. The - // implementation is responsible for copying `value` and then calling - // value_deleter_callback. - PJRT_KeyValueTryGetCallback_ValueDeleter value_deleter_callback; // out -}; -PJRT_DEFINE_STRUCT_TRAITS(PJRT_KeyValueTryGetCallback_Args, - value_deleter_callback); - -// Requirements for PJRT_KeyValueTryGetCallback implementation: (1) Thread-safe. -// (2) The caller that provides the two callbacks is responsible for avoiding -// key collisions between different users of key-value store (i.e. between -// different plugins, but not between different nodes in one plugin). -typedef PJRT_Error* (*PJRT_KeyValueTryGetCallback)( - PJRT_KeyValueTryGetCallback_Args* args); - struct PJRT_KeyValuePutCallback_Args { size_t struct_size; PJRT_Extension_Base* extension_start; @@ -418,15 +389,8 @@ struct PJRT_Client_Create_Args { void* kv_put_user_arg; PJRT_Client* client; // out - - // Key-value try-get callback provided by the caller of PJRT_Client_Create. - // Same as key-value get callback, but returns `NotFoundError` immediately if - // the key is not found. - PJRT_KeyValueTryGetCallback kv_try_get_callback; - // Will be passed to `kv_try_get_callback` as `user_arg` argument. - void* kv_try_get_user_arg; }; -PJRT_DEFINE_STRUCT_TRAITS(PJRT_Client_Create_Args, kv_try_get_user_arg); +PJRT_DEFINE_STRUCT_TRAITS(PJRT_Client_Create_Args, client); // Creates and initializes a new PJRT_Client and returns in `client`. 
typedef PJRT_Error* PJRT_Client_Create(PJRT_Client_Create_Args* args); diff --git a/xla/pjrt/c/pjrt_c_api_gpu_internal.cc b/xla/pjrt/c/pjrt_c_api_gpu_internal.cc index 68d36fdb7f5c86..4f53c640a6a3dc 100644 --- a/xla/pjrt/c/pjrt_c_api_gpu_internal.cc +++ b/xla/pjrt/c/pjrt_c_api_gpu_internal.cc @@ -154,9 +154,9 @@ PJRT_Error* PJRT_Client_Create(PJRT_Client_Create_Args* args) { options.num_nodes = num_nodes; options.allowed_devices = visible_devices; options.platform_name = platform_name; - options.kv_store = pjrt::ToCppKeyValueStore( - args->kv_get_callback, args->kv_get_user_arg, args->kv_try_get_callback, - args->kv_try_get_user_arg, args->kv_put_callback, args->kv_put_user_arg); + options.kv_store = + pjrt::ToCppKeyValueStore(args->kv_get_callback, args->kv_get_user_arg, + args->kv_put_callback, args->kv_put_user_arg); options.enable_mock_nccl = enable_mock_nccl; options.mock_gpu_topology = mock_gpu_topology; PJRT_ASSIGN_OR_RETURN(std::unique_ptr client, diff --git a/xla/pjrt/c/pjrt_c_api_helpers.cc b/xla/pjrt/c/pjrt_c_api_helpers.cc index ca094063c412aa..cf92041af497d5 100644 --- a/xla/pjrt/c/pjrt_c_api_helpers.cc +++ b/xla/pjrt/c/pjrt_c_api_helpers.cc @@ -795,25 +795,6 @@ static PJRT_KeyValueGetCFunc ToKVGetCFunc( }; } -static PJRT_KeyValueTryGetCFunc ToKVTryGetCFunc( - xla::KeyValueStoreInterface* kv_store) { - return [kv_store](PJRT_KeyValueTryGetCallback_Args* args) -> PJRT_Error* { - absl::StatusOr output = - kv_store->TryGet(absl::string_view(args->key, args->key_size)); - if (!output.ok()) { - absl::string_view message = output.status().message(); - return (*args->callback_error)( - StatusCodeToPjrtErrorCode(output.status().code()), message.data(), - message.size()); - } - args->value = new char[output->size()]; - std::copy(output->begin(), output->end(), args->value); - args->value_size = output->size(); - args->value_deleter_callback = &PjRtValueDeleterCallback; - return nullptr; - }; -} - static PJRT_KeyValuePutCFunc ToKVPutCFunc( 
xla::KeyValueStoreInterface* kv_store) { return [kv_store](PJRT_KeyValuePutCallback_Args* args) -> PJRT_Error* { @@ -845,22 +826,6 @@ static PJRT_KeyValueGetCallback ToCKVGetCallback( }; } -static PJRT_KeyValueTryGetCallback ToCKVTryGetCallback( - PJRT_KeyValueTryGetCFunc* kv_try_get_c_func) { - return [](PJRT_KeyValueTryGetCallback_Args* args) -> PJRT_Error* { - PJRT_KeyValueTryGetCFunc* kv_try_get_c_func = - reinterpret_cast(args->user_arg); - if (kv_try_get_c_func == nullptr) { - absl::Status status = xla::InvalidArgument( - "got nullptr for PJRT_KeyValueTryGet_Args.user_arg"); - return (*args->callback_error)(StatusCodeToPjrtErrorCode(status.code()), - status.message().data(), - status.message().size()); - } - return (*kv_try_get_c_func)(args); - }; -} - static PJRT_KeyValuePutCallback ToCKVPutCallback( PJRT_KeyValuePutCFunc* kv_put_c_func) { return [](PJRT_KeyValuePutCallback_Args* args) -> PJRT_Error* { @@ -881,12 +846,9 @@ std::unique_ptr ConvertToCKeyValueCallbacks( std::shared_ptr kv_store) { auto kv_callback_data = std::make_unique(); kv_callback_data->kv_get_c_func = ToKVGetCFunc(kv_store.get()); - kv_callback_data->kv_try_get_c_func = ToKVTryGetCFunc(kv_store.get()); kv_callback_data->kv_put_c_func = ToKVPutCFunc(kv_store.get()); kv_callback_data->c_kv_get = ToCKVGetCallback(&kv_callback_data->kv_get_c_func); - kv_callback_data->c_kv_try_get = - ToCKVTryGetCallback(&kv_callback_data->kv_try_get_c_func); kv_callback_data->c_kv_put = ToCKVPutCallback(&kv_callback_data->kv_put_c_func); kv_callback_data->kv_store = std::move(kv_store); diff --git a/xla/pjrt/c/pjrt_c_api_helpers.h b/xla/pjrt/c/pjrt_c_api_helpers.h index baae41fbeca28d..f530b82f423573 100644 --- a/xla/pjrt/c/pjrt_c_api_helpers.h +++ b/xla/pjrt/c/pjrt_c_api_helpers.h @@ -218,9 +218,6 @@ int GetId(const PJRT_Api* api, PJRT_DeviceDescription* device_desc); using PJRT_KeyValueGetCFunc = std::function; -using PJRT_KeyValueTryGetCFunc = - std::function; - using PJRT_KeyValuePutCFunc = 
std::function; @@ -231,21 +228,17 @@ struct PJRT_KeyValueCallbackData { std::shared_ptr kv_store; - // kv_get_c_func, kv_try_get_c_func and kv_put_c_func are holding pointers to - // kv_store. + // kv_get_c_func and kv_put_c_func are holding pointers to kv_store. pjrt::PJRT_KeyValueGetCFunc kv_get_c_func; pjrt::PJRT_KeyValuePutCFunc kv_put_c_func; - // c_kv_get, c_kv_try_get and c_kv_put are holding pointers to kv_get_c_func, - // kv_try_get_c_func and kv_put_c_func. + // c_kv_get and c_kv_put are holding pointers to kv_get_c_func and + // kv_put_c_func. PJRT_KeyValueGetCallback c_kv_get; PJRT_KeyValuePutCallback c_kv_put; - pjrt::PJRT_KeyValueTryGetCFunc kv_try_get_c_func; - PJRT_KeyValueTryGetCallback c_kv_try_get; }; -// The returned &kv_get_c_func, &kv_try_get_c_func and &kv_put_c_func must be -// set as PJRT_Client_Create_Args.kv_get_user_arg, -// PJRT_Client_Create_Args.kv_try_get_user_arg and +// The returned &kv_get_c_func and &kv_put_c_func must be set as +// PJRT_Client_Create_Args.kv_get_user_arg and // PJRT_Client_Create_Args.kv_put_user_arg, respectively. The entire // PJRT_KeyValueCallbackData must be kept alive as long as c_kv_get and c_kv_put // may be called. 
diff --git a/xla/pjrt/c/pjrt_c_api_helpers_test.cc b/xla/pjrt/c/pjrt_c_api_helpers_test.cc index 6dfce81a1e4514..4b8a59287589ed 100644 --- a/xla/pjrt/c/pjrt_c_api_helpers_test.cc +++ b/xla/pjrt/c/pjrt_c_api_helpers_test.cc @@ -108,22 +108,14 @@ TEST(PjRtCApiHelperTest, Callback) { auto kv_callback_data = ConvertToCKeyValueCallbacks(kv_store); auto converted_kv_store = ToCppKeyValueStore( kv_callback_data->c_kv_get, &kv_callback_data->kv_get_c_func, - kv_callback_data->c_kv_try_get, &kv_callback_data->kv_try_get_c_func, kv_callback_data->c_kv_put, &kv_callback_data->kv_put_c_func); - auto v_not_found = converted_kv_store->Get("key", absl::Seconds(1)); - EXPECT_TRUE(absl::IsNotFound(v_not_found.status())) << v_not_found.status(); - auto s = converted_kv_store->Set("key", "value"); TF_EXPECT_OK(s); auto v = converted_kv_store->Get("key", absl::Seconds(1)); TF_EXPECT_OK(v.status()); EXPECT_EQ(*v, "value"); - - auto v_2 = converted_kv_store->TryGet("key"); - TF_EXPECT_OK(v.status()); - EXPECT_EQ(*v, "value"); } TEST(PjRtCApiHelperTest, ConvertToCLayoutFromStrides) { diff --git a/xla/pjrt/c/pjrt_c_api_test_base.cc b/xla/pjrt/c/pjrt_c_api_test_base.cc index f867846ebcbd54..9602813c573c52 100644 --- a/xla/pjrt/c/pjrt_c_api_test_base.cc +++ b/xla/pjrt/c/pjrt_c_api_test_base.cc @@ -47,11 +47,9 @@ PJRT_Client* CreateClient(const PJRT_Api* api) { create_args.create_options = nullptr; create_args.num_options = 0; create_args.kv_get_callback = nullptr; - create_args.kv_get_user_arg = nullptr; create_args.kv_put_callback = nullptr; create_args.kv_put_user_arg = nullptr; - create_args.kv_try_get_callback = nullptr; - create_args.kv_try_get_user_arg = nullptr; + create_args.kv_get_user_arg = nullptr; PJRT_Error* error = api->PJRT_Client_Create(&create_args); CHECK_EQ(error, nullptr); CHECK_NE(create_args.client, nullptr); diff --git a/xla/pjrt/c/pjrt_c_api_wrapper_impl.cc b/xla/pjrt/c/pjrt_c_api_wrapper_impl.cc index 395d72ffbdb02e..932ccfd349cc90 100644 --- 
a/xla/pjrt/c/pjrt_c_api_wrapper_impl.cc +++ b/xla/pjrt/c/pjrt_c_api_wrapper_impl.cc @@ -235,13 +235,9 @@ static absl::Status PopulateExecutableOutputMemoryKinds( class CApiKeyValueStore : public xla::KeyValueStoreInterface { public: CApiKeyValueStore(PJRT_KeyValueGetCallback c_get_callback, void* get_user_arg, - PJRT_KeyValueTryGetCallback c_try_get_callback, - void* try_get_user_arg, PJRT_KeyValuePutCallback c_put_callback, void* put_user_arg) : c_get_callback_(c_get_callback), get_user_arg_(get_user_arg), - c_try_get_callback_(c_try_get_callback), - try_get_user_arg_(try_get_user_arg), c_put_callback_(c_put_callback), put_user_arg_(put_user_arg) {} @@ -268,27 +264,6 @@ class CApiKeyValueStore : public xla::KeyValueStoreInterface { return result; } - absl::StatusOr TryGet(absl::string_view key) override { - PJRT_CallbackError callback_error = [](PJRT_Error_Code code, - const char* message, - size_t message_size) { - return new PJRT_Error{absl::Status(static_cast(code), - std::string(message, message_size))}; - }; - PJRT_KeyValueTryGetCallback_Args args; - args.key = key.data(); - args.key_size = key.size(); - args.callback_error = &callback_error; - args.user_arg = try_get_user_arg_; - std::unique_ptr error(c_try_get_callback_(&args)); - if (error != nullptr) { - return error->status; - } - auto result = std::string(args.value, args.value_size); - args.value_deleter_callback(args.value); - return result; - } - absl::Status Set(absl::string_view key, absl::string_view value) override { PJRT_CallbackError callback_error = [](PJRT_Error_Code code, const char* message, @@ -313,23 +288,18 @@ class CApiKeyValueStore : public xla::KeyValueStoreInterface { private: PJRT_KeyValueGetCallback c_get_callback_; void* get_user_arg_; - PJRT_KeyValueTryGetCallback c_try_get_callback_; - void* try_get_user_arg_; PJRT_KeyValuePutCallback c_put_callback_; void* put_user_arg_; }; std::shared_ptr ToCppKeyValueStore( PJRT_KeyValueGetCallback c_get_callback, void* get_user_arg, - 
PJRT_KeyValueTryGetCallback c_try_get_callback, void* try_get_user_arg, PJRT_KeyValuePutCallback c_put_callback, void* put_user_arg) { - if (c_get_callback == nullptr || c_try_get_callback == nullptr || - c_put_callback == nullptr) { + if (c_get_callback == nullptr || c_put_callback == nullptr) { return nullptr; } - return std::make_shared( - c_get_callback, get_user_arg, c_try_get_callback, try_get_user_arg, - c_put_callback, put_user_arg); + return std::make_shared(c_get_callback, get_user_arg, + c_put_callback, put_user_arg); } // ---------------------------------- Errors ----------------------------------- diff --git a/xla/pjrt/c/pjrt_c_api_wrapper_impl.h b/xla/pjrt/c/pjrt_c_api_wrapper_impl.h index 873845d3ac815f..0ebecc0c251734 100644 --- a/xla/pjrt/c/pjrt_c_api_wrapper_impl.h +++ b/xla/pjrt/c/pjrt_c_api_wrapper_impl.h @@ -464,7 +464,6 @@ PJRT_Client* CreateWrapperClient(std::unique_ptr cpp_client); // Helper functions for converting C key-value store callbacks to C++ callbacks. std::shared_ptr ToCppKeyValueStore( PJRT_KeyValueGetCallback c_get_callback, void* get_user_arg, - PJRT_KeyValueTryGetCallback c_try_get_callback, void* try_get_user_arg, PJRT_KeyValuePutCallback c_put_callback, void* put_user_arg); // A method that does not nothing other than returning a nullptr. Can be used as diff --git a/xla/pjrt/distributed/client.cc b/xla/pjrt/distributed/client.cc index 305afe7ae4c6d4..280c60873e9d07 100644 --- a/xla/pjrt/distributed/client.cc +++ b/xla/pjrt/distributed/client.cc @@ -26,7 +26,6 @@ limitations under the License. 
#include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" -#include "absl/strings/string_view.h" #include "absl/time/time.h" #include "absl/types/span.h" #include "grpcpp/channel.h" @@ -54,7 +53,6 @@ class DistributedRuntimeCoordinationServiceClient absl::Status Shutdown() override; absl::StatusOr BlockingKeyValueGet( absl::string_view key, absl::Duration timeout) override; - absl::StatusOr KeyValueTryGet(absl::string_view key) override; absl::StatusOr>> KeyValueDirGet(absl::string_view key) override; absl::Status KeyValueSet(absl::string_view key, @@ -146,12 +144,6 @@ DistributedRuntimeCoordinationServiceClient::BlockingKeyValueGet( return coord_agent_->GetKeyValue(key, timeout); } -absl::StatusOr -DistributedRuntimeCoordinationServiceClient::KeyValueTryGet( - absl::string_view key) { - return coord_agent_->TryGetKeyValue(key); -} - absl::StatusOr>> DistributedRuntimeCoordinationServiceClient::KeyValueDirGet( absl::string_view key) { @@ -224,10 +216,6 @@ class DistributedKeyValueStore : public KeyValueStoreInterface { return client_->BlockingKeyValueGet(absl::StrCat(prefix_, key), timeout); } - absl::StatusOr TryGet(absl::string_view key) override { - return client_->KeyValueTryGet(absl::StrCat(prefix_, key)); - } - absl::Status Set(absl::string_view key, absl::string_view value) override { return client_->KeyValueSet(absl::StrCat(prefix_, key), value); } diff --git a/xla/pjrt/distributed/client.h b/xla/pjrt/distributed/client.h index 58f4fe367681d2..e597ff158cc674 100644 --- a/xla/pjrt/distributed/client.h +++ b/xla/pjrt/distributed/client.h @@ -27,7 +27,6 @@ limitations under the License. 
#include "absl/log/log.h" #include "absl/status/status.h" #include "absl/status/statusor.h" -#include "absl/strings/string_view.h" #include "absl/time/time.h" #include "absl/types/span.h" #include "grpcpp/channel.h" @@ -117,9 +116,6 @@ class DistributedRuntimeClient { virtual absl::StatusOr BlockingKeyValueGet( absl::string_view key, absl::Duration timeout) = 0; - // Returns `NotFoundError` immediately if the key is not found. - virtual absl::StatusOr KeyValueTryGet(absl::string_view key) = 0; - // Get all key-value pairs under a directory (key). // A value is considered to be in the directory if its key is prefixed with // the directory. diff --git a/xla/pjrt/distributed/client_server_test.cc b/xla/pjrt/distributed/client_server_test.cc index baec103eced933..f5b7e656fe69a2 100644 --- a/xla/pjrt/distributed/client_server_test.cc +++ b/xla/pjrt/distributed/client_server_test.cc @@ -1029,20 +1029,6 @@ TEST_F(ClientServerTest, KeyValueSet_Duplicate_Overwrites) { EXPECT_EQ(result.value(), "overwritten_value"); } -TEST_F(ClientServerTest, KeyValueTryGet) { - StartService(/*num_nodes=*/1); - auto client = GetClient(/*node_id=*/0); - TF_ASSERT_OK(client->Connect()); - - ASSERT_THAT(client->KeyValueTryGet("test_key").status(), - StatusIs(absl::StatusCode::kNotFound)); - - TF_ASSERT_OK(client->KeyValueSet("test_key", "value")); - auto result = client->KeyValueTryGet("test_key"); - TF_ASSERT_OK(result.status()); - EXPECT_EQ(result.value(), "value"); -} - TEST_F(ClientServerTest, KeyValueDelete) { StartService(/*num_nodes=*/1); auto client = GetClient(/*node_id=*/0); diff --git a/xla/pjrt/distributed/in_memory_key_value_store.cc b/xla/pjrt/distributed/in_memory_key_value_store.cc index 49fc73ec87f163..70cc5360ecf7b3 100644 --- a/xla/pjrt/distributed/in_memory_key_value_store.cc +++ b/xla/pjrt/distributed/in_memory_key_value_store.cc @@ -20,7 +20,6 @@ limitations under the License. 
#include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" -#include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" #include "absl/time/time.h" @@ -41,17 +40,6 @@ absl::StatusOr InMemoryKeyValueStore::Get(absl::string_view key, return kv_store_.find(key)->second; } -absl::StatusOr InMemoryKeyValueStore::TryGet( - absl::string_view key) { - absl::MutexLock lock(&mu_); - auto it = kv_store_.find(key); - if (it == kv_store_.end()) { - return absl::NotFoundError( - absl::StrCat(key, " is not found in the kv store.")); - } - return it->second; -} - absl::Status InMemoryKeyValueStore::Set(absl::string_view key, absl::string_view value) { absl::MutexLock lock(&mu_); diff --git a/xla/pjrt/distributed/in_memory_key_value_store.h b/xla/pjrt/distributed/in_memory_key_value_store.h index 13f50c722bd125..1530633a98b754 100644 --- a/xla/pjrt/distributed/in_memory_key_value_store.h +++ b/xla/pjrt/distributed/in_memory_key_value_store.h @@ -21,9 +21,7 @@ limitations under the License. 
#include "absl/container/flat_hash_map.h" #include "absl/status/status.h" #include "absl/status/statusor.h" -#include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" -#include "absl/time/time.h" #include "xla/pjrt/distributed/key_value_store_interface.h" namespace xla { @@ -33,8 +31,6 @@ class InMemoryKeyValueStore : public KeyValueStoreInterface { absl::StatusOr Get(absl::string_view key, absl::Duration timeout) override; - absl::StatusOr TryGet(absl::string_view key) override; - absl::Status Set(absl::string_view key, absl::string_view value) override; private: diff --git a/xla/pjrt/distributed/key_value_store_interface.h b/xla/pjrt/distributed/key_value_store_interface.h index 312ebb8abb6463..29580fb86847b1 100644 --- a/xla/pjrt/distributed/key_value_store_interface.h +++ b/xla/pjrt/distributed/key_value_store_interface.h @@ -38,18 +38,11 @@ class KeyValueStoreInterface { virtual ~KeyValueStoreInterface() = default; // Blocking Get(). - // Useful for listening for a key-value pair that may be set later on. // There are no concurrency guarantees. To avoid a race / impose an ordering // on potentially concurrent ops (e.g. set, delete), use WaitAtBarrier(). virtual absl::StatusOr Get(absl::string_view key, absl::Duration timeout) = 0; - // Returns `NotFoundError` immediately if the key is not found. - // Useful for checking key existence. - // There are no concurrency guarantees. To avoid a race / impose an ordering - // on potentially concurrent ops (e.g. set, delete), use WaitAtBarrier(). 
- virtual absl::StatusOr TryGet(absl::string_view key) = 0; - virtual absl::Status Set(absl::string_view key, absl::string_view value) = 0; }; diff --git a/xla/pjrt/pjrt_c_api_client.cc b/xla/pjrt/pjrt_c_api_client.cc index 1f65b13109afc6..8855ef33620e5f 100644 --- a/xla/pjrt/pjrt_c_api_client.cc +++ b/xla/pjrt/pjrt_c_api_client.cc @@ -2578,8 +2578,6 @@ absl::StatusOr> WrapClientAroundCApi( kv_callback_data = pjrt::ConvertToCKeyValueCallbacks(kv_store); init_args.kv_get_callback = kv_callback_data->c_kv_get; init_args.kv_get_user_arg = &kv_callback_data->kv_get_c_func; - init_args.kv_try_get_callback = kv_callback_data->c_kv_try_get; - init_args.kv_try_get_user_arg = &kv_callback_data->kv_try_get_c_func; init_args.kv_put_callback = kv_callback_data->c_kv_put; init_args.kv_put_user_arg = &kv_callback_data->kv_put_c_func; } diff --git a/xla/python/xla.cc b/xla/python/xla.cc index e30af5d4e5e43d..51c96229493e4c 100644 --- a/xla/python/xla.cc +++ b/xla/python/xla.cc @@ -672,21 +672,6 @@ NB_MODULE(xla_extension, m) { return nb::bytes(result.data(), result.size()); }, nb::arg("key"), nb::arg("timeout_in_ms")) - .def( - "key_value_try_get", - [](DistributedRuntimeClient& client, std::string key) { - nb::gil_scoped_release gil_release; - return xla::ValueOrThrow(client.KeyValueTryGet(key)); - }, - nb::arg("key")) - .def( - "key_value_try_get_bytes", - [](DistributedRuntimeClient& client, std::string key) -> nb::bytes { - nb::gil_scoped_release gil_release; - std::string result = xla::ValueOrThrow(client.KeyValueTryGet(key)); - return nb::bytes(result.data(), result.size()); - }, - nb::arg("key")) .def( "wait_at_barrier", [](DistributedRuntimeClient& client, std::string barrier_id, diff --git a/xla/python/xla_extension/__init__.pyi b/xla/python/xla_extension/__init__.pyi index 5fa885f9f92255..2e3862285898f2 100644 --- a/xla/python/xla_extension/__init__.pyi +++ b/xla/python/xla_extension/__init__.pyi @@ -830,8 +830,6 @@ class DistributedRuntimeClient: def 
blocking_key_value_get_bytes( self, key: str, timeout_in_ms: int ) -> _Status: ... - def key_value_try_get(self, key: str) -> _Status: ... - def key_value_try_get_bytes(self, key: str) -> _Status: ... def key_value_dir_get(self, key: str) -> _Status: ... def key_value_dir_get_bytes(self, key: str) -> _Status: ... def key_value_set(self, key: str, value: str, diff --git a/xla/service/cpu/ir_emitter.cc b/xla/service/cpu/ir_emitter.cc index a2498bb8b6e63a..bfafea513a3d69 100644 --- a/xla/service/cpu/ir_emitter.cc +++ b/xla/service/cpu/ir_emitter.cc @@ -224,7 +224,13 @@ absl::StatusOr IrEmitter::EmitComputation( std::string function_name = name_uniquer_.GetUniqueName(function_name_prefix); VLOG(2) << "Emitting IR for CPU function [" << function_name_prefix << "]"; is_top_level_computation_ = is_top_level_computation; + + auto cleanup = absl::MakeCleanup( + [saved_allow_reassociation = allow_reassociation_, this]() { + allow_reassociation_ = saved_allow_reassociation; + }); allow_reassociation_ = allow_reassociation; + num_dynamic_loop_bounds_ = 0; auto backend_config_or = computation->root_instruction()->backend_config(); diff --git a/xla/service/cpu/ir_emitter.h b/xla/service/cpu/ir_emitter.h index e56a57ff97789f..926f6b6461ba37 100644 --- a/xla/service/cpu/ir_emitter.h +++ b/xla/service/cpu/ir_emitter.h @@ -637,7 +637,9 @@ class IrEmitter : public DfsHloVisitorWithDefault, llvm::IRBuilderBase* current_builder_; std::stack compute_function_; mlir::MLIRContext* mlir_context_; - bool allow_reassociation_; + // The state of allow_reassociation_ is required so that that it is + // transitive to all nested computations. + bool allow_reassociation_ = false; // The buffer allocation slice for the root of the computation being compiled. // Only relevant for thread local computations. 
diff --git a/xla/service/cpu/ir_emitter2.cc b/xla/service/cpu/ir_emitter2.cc index ea63cb5a44a045..621fffbdfa3329 100644 --- a/xla/service/cpu/ir_emitter2.cc +++ b/xla/service/cpu/ir_emitter2.cc @@ -99,10 +99,8 @@ KernelApiIrBuilder::Options KernelApiIrBuilderOptionsFromHloModuleConfig( class IrEmitter2::ElementalIrEmitter : public CpuElementalIrEmitter { public: ElementalIrEmitter(llvm::Module* module, llvm::IRBuilderBase* b, - const HloModule* hlo_module, IrEmitter* nested_ir_emitter, - bool fast_min_max) + IrEmitter* nested_ir_emitter, bool fast_min_max) : CpuElementalIrEmitter(module, b, true, fast_min_max), - hlo_module_(hlo_module), nested_ir_emitter_(nested_ir_emitter), fast_min_max_(fast_min_max) {} @@ -110,43 +108,8 @@ class IrEmitter2::ElementalIrEmitter : public CpuElementalIrEmitter { absl::StatusOr> EmitThreadLocalCall( const HloComputation& callee, absl::Span parameters, absl::string_view name, bool is_reducer) override { - // Module must be scheduled to emit thread local computation. - if (!hlo_module_ || !hlo_module_->has_schedule()) { - return absl::InternalError( - "HLO module must be scheduled to emit thread local computation."); - } - - // Create a nested function for thread local computation(s) if it is not - // already created. Nested functions are created with internal linkage. 
- auto emit_computation = [&](const HloComputation* computation) { - if (!nested_ir_emitter_->is_computation_emitted(*computation, - is_reducer)) { - VLOG(2) << "Emit nested computation: " << computation->name(); - TF_RETURN_IF_ERROR( - nested_ir_emitter_ - ->EmitComputation( - const_cast(computation), name, false, - hlo_module_->schedule() - .sequence(computation) - .instructions(), - /*allow_reassociation=*/is_reducer, - /*function_attributes=*/{llvm::Attribute::AlwaysInline}) - .status()); - } - return absl::OkStatus(); - }; - - // We emit all embedded computations reachable through the `callee` to - // support nested thread local call, i.e., nested map computations. - for (HloComputation* embedded : callee.MakeEmbeddedComputationsList()) { - if (embedded->IsFusionComputation()) continue; - TF_RETURN_IF_ERROR(emit_computation(embedded)); - } - TF_RETURN_IF_ERROR(emit_computation(&callee)); - // Add a thread local call to the nested computation. VLOG(2) << "Emit thread local call to: " << callee.name(); - nested_ir_emitter_->b()->SetInsertPoint(b()->GetInsertPoint()); auto values = nested_ir_emitter_->EmitThreadLocalCall( callee, parameters, name, is_reducer, /*in_compute_function=*/false); @@ -156,7 +119,6 @@ class IrEmitter2::ElementalIrEmitter : public CpuElementalIrEmitter { bool fast_min_max() override { return fast_min_max_; } private: - const HloModule* hlo_module_; IrEmitter* nested_ir_emitter_; bool fast_min_max_; }; @@ -195,6 +157,8 @@ absl::StatusOr IrEmitter2::EmitElementalHostKernel( llvm::IRBuilder<> b(module_->getContext()); b.SetInsertPoint(kernel_prototype.function->getEntryBlock().getTerminator()); + IrEmitter::IRBuilderGuard builder_guard = nested_ir_emitter_->WithBuilder(b); + ElementalIrEmitter::HloToElementGeneratorMap operand_to_generator; for (int64_t i = 0; i < instr->operand_count(); ++i) { const HloInstruction* operand = instr->operand(i); @@ -203,8 +167,16 @@ absl::StatusOr IrEmitter2::EmitElementalHostKernel( }; } - ElementalIrEmitter 
elemental_emitter(module_, &b, &hlo_module_, - nested_ir_emitter_, fast_min_max()); + if (instr->has_to_apply()) { + HloComputation* nested_computation = instr->to_apply(); + bool is_reducer = instr->opcode() == HloOpcode::kReduce || + instr->opcode() == HloOpcode::kReduceWindow; + TF_RETURN_IF_ERROR(EmitNestedComputation( + *nested_computation, llvm_ir::IrName(instr), is_reducer)); + } + + ElementalIrEmitter elemental_emitter(module_, &b, nested_ir_emitter_, + fast_min_max()); llvm_ir::ElementGenerator element_generator = elemental_emitter.MakeElementGenerator(instr, operand_to_generator); @@ -266,8 +238,14 @@ absl::StatusOr IrEmitter2::EmitFusionHostKernel( llvm::IRBuilder<> b(module_->getContext()); b.SetInsertPoint(kernel_prototype.function->getEntryBlock().getTerminator()); - ElementalIrEmitter elemental_emitter(module_, &b, &hlo_module_, - nested_ir_emitter_, fast_min_max()); + IrEmitter::IRBuilderGuard builder_guard = nested_ir_emitter_->WithBuilder(b); + + HloComputation* nested_computation = fusion->fused_instructions_computation(); + TF_RETURN_IF_ERROR(EmitNestedComputation(*nested_computation, + llvm_ir::IrName(fusion), false)); + + ElementalIrEmitter elemental_emitter(module_, &b, nested_ir_emitter_, + fast_min_max()); FusedIrEmitter fused_emitter(elemental_emitter); for (int i = 0; i < fusion->operand_count(); i++) { @@ -911,6 +889,43 @@ absl::StatusOr IrEmitter2::EmitElementalLoops( return se::ThreadDim(); } +absl::Status IrEmitter2::EmitNestedComputation(const HloComputation& callee, + absl::string_view name, + bool is_reducer) { + // Module must be scheduled to emit thread local computation. 
+ if (!hlo_module_.has_schedule()) { + return absl::InternalError( + "HLO module must be scheduled to emit thread local computation."); + } + + if (nested_ir_emitter_->is_computation_emitted(callee, is_reducer)) { + return absl::OkStatus(); + } + + for (HloInstruction* instr : callee.instructions()) { + bool nested_is_reducer = instr->opcode() == HloOpcode::kReduce || + instr->opcode() == HloOpcode::kReduceWindow; + for (HloComputation* called_computation : instr->called_computations()) { + // reassociation is transitive so we "or" the caller and the callee. + TF_RETURN_IF_ERROR( + EmitNestedComputation(*called_computation, llvm_ir::IrName(instr), + is_reducer || nested_is_reducer)); + } + } + + if (callee.IsFusionComputation()) { + return absl::OkStatus(); + } + + VLOG(2) << "Emit nested computation: " << callee.name(); + return nested_ir_emitter_ + ->EmitComputation(const_cast(&callee), name, false, + hlo_module_.schedule().sequence(&callee).instructions(), + /*allow_reassociation=*/is_reducer, + /*function_attributes=*/{llvm::Attribute::AlwaysInline}) + .status(); +} + // This is a convenience function taken from IrEmitter, it uses module_ class // field. If there will be more functions that use module_, we should consider // refactoring (like we did for compute_function_ and builder_). diff --git a/xla/service/cpu/ir_emitter2.h b/xla/service/cpu/ir_emitter2.h index eafaa99e123006..be7048414de2b0 100644 --- a/xla/service/cpu/ir_emitter2.h +++ b/xla/service/cpu/ir_emitter2.h @@ -25,6 +25,7 @@ limitations under the License. 
#include "absl/container/flat_hash_set.h" #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/IRBuilder.h" @@ -228,6 +229,9 @@ class IrEmitter2 { const KernelPrototype& kernel_prototype, const llvm_ir::ElementGenerator& element_generator); + absl::Status EmitNestedComputation(const HloComputation& callee, + absl::string_view name, bool is_reducer); + bool fast_min_max() const; // Returns the number of bytes within the shape. diff --git a/xla/service/gpu/fusions/triton/BUILD b/xla/service/gpu/fusions/triton/BUILD index 8ed220d5d3e192..a0ae574269ca37 100644 --- a/xla/service/gpu/fusions/triton/BUILD +++ b/xla/service/gpu/fusions/triton/BUILD @@ -374,6 +374,7 @@ gentbl_cc_library( cc_library( name = "xla_triton_passes", srcs = [ + "xla_triton_int4_passes.cc", "xla_triton_prevent_mmav3_loop_unrolling_pass.cc", "xla_triton_sparse_passes.cc", ], @@ -383,9 +384,12 @@ cc_library( deps = [ ":xla_triton", ":xla_triton_passes_inc_gen", + "//xla/service/llvm_ir:llvm_util", + "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@llvm-project//llvm:Support", "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:GPUDialect", "@llvm-project//mlir:GPUToNVVMTransforms", "@llvm-project//mlir:IR", @@ -393,6 +397,7 @@ cc_library( "@llvm-project//mlir:NVVMDialect", "@llvm-project//mlir:Pass", "@llvm-project//mlir:SCFDialect", + "@llvm-project//mlir:SCFTransforms", "@llvm-project//mlir:Support", "@llvm-project//mlir:TransformUtils", "@triton//:TritonAnalysis", diff --git a/xla/service/gpu/fusions/triton/compilation_pipeline_cuda.cc b/xla/service/gpu/fusions/triton/compilation_pipeline_cuda.cc index 6bd49df697a7d9..2ce0a8039309b4 100644 --- a/xla/service/gpu/fusions/triton/compilation_pipeline_cuda.cc +++ b/xla/service/gpu/fusions/triton/compilation_pipeline_cuda.cc @@ -48,6 +48,8 @@ absl::Status 
CreateTritonPipeline( const int ccAsInt = cc.major * 10 + cc.minor; const int threadsPerWarp = 32; + pm->addPass(mt_xla::CreateInt4ToPackedInt4RewritePass()); + // Based on make_ttir() in // @triton//:third_party/nvidia/backend/compiler.py pm->addPass(mlir::createInlinerPass()); diff --git a/xla/service/gpu/fusions/triton/xla_triton_int4_passes.cc b/xla/service/gpu/fusions/triton/xla_triton_int4_passes.cc new file mode 100644 index 00000000000000..091970f645ee5d --- /dev/null +++ b/xla/service/gpu/fusions/triton/xla_triton_int4_passes.cc @@ -0,0 +1,324 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include +#include +#include +#include + +#include "absl/log/log.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/SCF/Transforms/Patterns.h" +#include "mlir/IR/Block.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/OperationSupport.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Types.h" +#include "mlir/IR/Value.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/DialectConversion.h" +#include "xla/service/llvm_ir/llvm_util.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Types.h" + +namespace mlir::triton::xla { + +using ::xla::llvm_ir::DumpToString; + +namespace mt = ::mlir::triton; +namespace ma = ::mlir::arith; + +#define GEN_PASS_DEF_LOADINT4REWRITEPASS +#include "xla/service/gpu/fusions/triton/xla_triton_passes.h.inc" + +class I4ToI8Converter : public TypeConverter { + public: + static Type convertIntegerType(IntegerType type) { + VLOG(10) << "I4ToI8Converter: converting IntegerType for " + << DumpToString(type); + if (type.getWidth() == 4) { + auto new_type = IntegerType::get(type.getContext(), 8); + VLOG(10) << " -> I4ToI8Converter: IntegerType converted to " + << DumpToString(new_type); + return new_type; + } + return type; + } + static Type convertRankedTensorType(RankedTensorType type) { + VLOG(10) << "I4ToI8Converter: RankedTensorType for " << DumpToString(type); + if (!type.getElementType().isInteger(4)) return type; + + auto shape = type.getShape(); + if (shape[0] == ShapedType::kDynamic) + return type; // Only handle static shapes for simplicity + + std::vector newShape(shape.begin(), shape.end()); + newShape[0] /= 2; + auto new_type = + RankedTensorType::get(newShape, 
IntegerType::get(type.getContext(), 8)); + VLOG(10) << " -> I4ToI8Converter: RankedTensorType converted to " + << DumpToString(new_type); + return new_type; + } + + PointerType convertPointerType(PointerType ptr_type) { + VLOG(10) << "I4ToI8Converter: converting PointerType for " + << DumpToString(ptr_type); + auto pointee_type = ptr_type.getPointeeType(); + auto new_pointee_type = convertType(pointee_type); + auto new_ptr_type = + PointerType::get(new_pointee_type, ptr_type.getAddressSpace()); + VLOG(10) << " -> I4ToI8Converter: converted PointerType to " + << DumpToString(new_ptr_type); + return new_ptr_type; + } + Type convertFunctionType(FunctionType func_type) { + VLOG(10) << "I4ToI8Converter: converting FunctionType " + << DumpToString(func_type); + + SmallVector inputs; + if (failed(convertTypes(func_type.getInputs(), inputs))) return func_type; + + SmallVector results; + if (failed(convertTypes(func_type.getResults(), results))) return func_type; + + auto new_func_type = + FunctionType::get(func_type.getContext(), inputs, results); + VLOG(10) << " -> I4ToI8Converter: converted FunctionType to " + << DumpToString(new_func_type); + return new_func_type; + } + + I4ToI8Converter() { + // Passthrough for other types. 
+ addConversion([](Type type) { + VLOG(10) << "I4ToI8Converter: passthrough for " << DumpToString(type); + return type; + }); + + // Convert i4 to i8 + addConversion( + [this](IntegerType type) { return this->convertIntegerType(type); }); + + // Convert tensor to tensor + addConversion([this](RankedTensorType type) { + return this->convertRankedTensorType(type); + }); + + // Convert !tt.ptr> to !tt.ptr> + addConversion( + [this](PointerType type) { return this->convertPointerType(type); }); + + // Convert function type to function type + addConversion( + [this](FunctionType type) { return this->convertFunctionType(type); }); + } +}; + +class MakeTensorPtrOpConversionPattern + : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + MakeTensorPtrOp op, + OpConversionPattern::OpAdaptor adaptor, + ConversionPatternRewriter &r) const override { + // Convert the tensor type using the TypeConverter + auto new_type = getTypeConverter()->convertType(op.getType()); + if (op.getType() == new_type) { + return r.notifyMatchFailure(op, "no conversion needed"); + } + + auto loc = op.getLoc(); + Value c2 = + r.create(loc, r.getIntegerAttr(r.getI64Type(), 2)); + SmallVector shape{adaptor.getShape().begin(), + adaptor.getShape().end()}; + // The packing dim is major and it should be twice smaller. + shape[0] = r.create(loc, shape[0], c2); + + // The packing dim is major and the other stride should be half of the + // original one. 
+ SmallVector new_strides = adaptor.getStrides(); + new_strides[1] = r.create(loc, new_strides[1], c2); + + r.replaceOpWithNewOp( + op, new_type, adaptor.getBase(), shape, new_strides, + adaptor.getOffsets(), adaptor.getOrderAttr()); + + return success(); + } +}; + +class AddPtrOpConversionPattern : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + AddPtrOp op, OpConversionPattern::OpAdaptor adaptor, + ConversionPatternRewriter &r) const override { + // Convert the tensor type using the TypeConverter + auto new_type = getTypeConverter()->convertType(op.getType()); + if (op.getType() == new_type) { + return r.notifyMatchFailure(op, "no conversion needed"); + } + + // The increment for the next stripe of tiles along K dimension should be + // twice smaller. + auto ptr = adaptor.getOperands()[0]; + auto offset = adaptor.getOperands()[1]; + auto offset_type = offset.getType(); + Value c2 = + r.create(op.getLoc(), r.getIntegerAttr(offset_type, 2)); + auto new_offset = + r.create(op.getLoc(), offset_type, offset, c2); + + r.replaceOpWithNewOp(op, new_type, ptr, new_offset); + + return success(); + } +}; + +template +class OpTypeConversionPattern : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + OpType op, typename OpConversionPattern::OpAdaptor adaptor, + ConversionPatternRewriter &r) const override { + VLOG(10) << "OpTypeConversionPattern: matching\n" + << DumpToString(static_cast(op.getOperation())); + // Convert the tensor type using the TypeConverter + auto new_type = + OpConversionPattern::getTypeConverter()->convertType( + op.getType()); + if (op.getType() == new_type) { + VLOG(10) << "OpTypeConversionPattern: no conversion needed for " + << DumpToString(op.getType()); + return r.notifyMatchFailure(op, "no conversion needed"); + } + + r.replaceOpWithNewOp(op, new_type, adaptor.getOperands(), + op->getAttrs()); + 
return success(); + } +}; + +// The pattern converts the ExtSIOp that converts i4 tensor to i8 tensor to the +// unpack sequence with ShLIOp, ShRSIOp, JoinOp, TransOp and ReshapeOp that does +// the same thing. +class ExtSIInt4ToInt8Pattern : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(ma::ExtSIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &r) const override { + auto i4_tensor = cast(op.getType()); + const auto &operand_type = cast(op.getIn().getType()); + + auto i4_type = r.getI4Type(); + auto i8_type = r.getI8Type(); + + if (operand_type.getElementType() != i4_type) { + return r.notifyMatchFailure(op, "not i4 operand"); + } + + // Make a new i8 tensor with the shape that is half of the int4 tensor. + SmallVector result_shape(i4_tensor.getShape()); + result_shape[0] /= 2; + auto i8_tensor = RankedTensorType::get(result_shape, i8_type); + + auto loc = op.getLoc(); + + Value shift4_const = + r.create(loc, r.getIntegerAttr(i8_type, 4)); + Value shift4 = r.create(loc, i8_tensor, shift4_const); + Value shifted_lo = + r.create(loc, i8_tensor, adaptor.getIn(), shift4); + Value lo = r.create(loc, i8_tensor, shifted_lo, shift4); + Value hi = r.create(loc, i8_tensor, adaptor.getIn(), shift4); + Value hi_lo = r.create(loc, hi, lo); + auto trans_attr = r.getDenseI32ArrayAttr({0, 2, 1}); + + Value trans_hi_lo = r.create(loc, hi_lo, trans_attr); + + r.replaceOpWithNewOp(op, i4_tensor, trans_hi_lo, + /*allow_reorder=*/false); + return success(); + } +}; + +struct PlainInt4ToPackedInt4RewritePass + : public impl::LoadInt4RewritePassBase { + void runOnOperation() override { + auto *ctx = &getContext(); + auto module = getOperation(); + + ConversionTarget target(*ctx); + + VLOG(10) << "before TypeRewrite rewrite"; + { + I4ToI8Converter converter; + ConversionTarget target(*ctx); + target.markUnknownOpDynamicallyLegal([&](Operation *op) { + if (auto func_op = dyn_cast(op)) { + VLOG(10) << "check 
funcOp: " << DumpToString(func_op); + if (func_op.getFunctionType() != + converter.convertType(func_op.getFunctionType())) { + VLOG(10) << "funcOp not legal: " << DumpToString(func_op); + return false; + } + } + bool is_legal = converter.isLegal(op); + VLOG(10) << "is_legal: " << is_legal << " for " << DumpToString(op); + return is_legal; + }); + RewritePatternSet patterns(ctx); + scf::populateSCFStructuralTypeConversions(converter, patterns); + patterns.add(ctx); + patterns.add>(converter, ctx); + patterns.add>(converter, ctx); + patterns.add(converter, ctx); + patterns.add(converter, ctx); + populateFunctionOpInterfaceTypeConversionPattern(patterns, + converter); + if (failed(applyPartialConversion(module, target, std::move(patterns)))) { + VLOG(10) << "failed to apply partial conversion"; + signalPassFailure(); + } + } + VLOG(10) << "after TypeRewrite Module: " << DumpToString(module); + } +}; + +// The pass converts the types like tensor to tensor in the +// Triton dialect and replaces the ExtSIOp with the unpack sequence that accepts +// twice smaller i8 tensor and convert it to the twice bigger i8 tensor where +// every i4 element uses i8 space. At the end the module accepts the tt.ptr +// to the packed i4 tensor, and unpacks it to the i8 tensor for the further +// processing. It expects that the i4 tensor is packed along the major +// dimension. 
+std::unique_ptr CreateInt4ToPackedInt4RewritePass() { + return std::make_unique(); +} + +} // namespace mlir::triton::xla diff --git a/xla/service/gpu/fusions/triton/xla_triton_passes.h b/xla/service/gpu/fusions/triton/xla_triton_passes.h index 10f5e684cb5516..67034fe1df1897 100644 --- a/xla/service/gpu/fusions/triton/xla_triton_passes.h +++ b/xla/service/gpu/fusions/triton/xla_triton_passes.h @@ -36,6 +36,7 @@ std::unique_ptr CreateSparseLocalLoadToLLVMPass(); std::unique_ptr CreateSparseDotOpToLLVMPass(); std::unique_ptr CreateSparseWGMMAOpToLLVMPass(); std::unique_ptr CreatePreventMmaV3LoopUnrollingPass(); +std::unique_ptr CreateInt4ToPackedInt4RewritePass(); // Returns true if the `op` contains an operation in it's regions that satisfies // the `fn`. diff --git a/xla/service/gpu/fusions/triton/xla_triton_passes.td b/xla/service/gpu/fusions/triton/xla_triton_passes.td index 49e003e392ed15..21db540475b390 100644 --- a/xla/service/gpu/fusions/triton/xla_triton_passes.td +++ b/xla/service/gpu/fusions/triton/xla_triton_passes.td @@ -95,4 +95,15 @@ def PreventMmaV3LoopUnrollingPass let constructor = "CreatePreventMmaV3LoopUnrollingPass()"; } +def LoadInt4RewritePass + : Pass<"int4-to-packed-int4-rewrite", "mlir::ModuleOp"> { + let summary = "Converts ops with int4 tensors to the ops with int4 packed to int8 tensors."; + let description = [{ + This pass replaces the int4 tensors with the int4 packed to int8 tensor of + the twice smaller size. It also replaces the plain ExtSIOp upcast to the + int8 tensor with the unpack sequence. 
+ }]; + let constructor = "CreateInt4ToPackedInt4RewritePass()"; +} + #endif // XLA_SERVICE_GPU_FUSIONS_TRITON_XLA_TRITON_PASSES_TD_ diff --git a/xla/service/gpu/gpu_compiler.cc b/xla/service/gpu/gpu_compiler.cc index f4e87d417eec5a..faeaa7a6c46679 100755 --- a/xla/service/gpu/gpu_compiler.cc +++ b/xla/service/gpu/gpu_compiler.cc @@ -1584,9 +1584,6 @@ absl::Status GpuCompiler::OptimizeHloPostLayoutAssignment( if ((cuda_cc != nullptr && cuda_cc->IsAtLeast(se::CudaComputeCapability::AMPERE)) || rocm_cc != nullptr) { - // Triton compilation needs normalized operations on bf16 (i.e. converted - // to f32). - add_float_normalization(pipeline); pipeline.AddPass>(simplifier_options, gpu_version); pipeline.AddPass(/*is_layout_sensitive=*/true); diff --git a/xla/service/gpu/tests/gpu_fused_mha_test.cc b/xla/service/gpu/tests/gpu_fused_mha_test.cc index e8d8a04f1a93ec..33214758e230fd 100644 --- a/xla/service/gpu/tests/gpu_fused_mha_test.cc +++ b/xla/service/gpu/tests/gpu_fused_mha_test.cc @@ -1471,8 +1471,7 @@ XLA_TEST_F(FlashAttentionBMMScaleSoftmaxBMMF8, XlaBuilder builder(TestName()); std::string ref_bnth = R"( custom-call.4.0 = ( - bf16[4,4,16,16]{3,1,2,0}, - u8[0]{0} + bf16[4,4,16,16]{3,1,2,0} ) custom-call( convert.19, convert.31, @@ -1546,8 +1545,7 @@ XLA_TEST_F(FlashAttentionBMMScaleSoftmaxBMMF8, custom-call.21.0 = ( f8e4m3fn[4,4,16,16]{3,1,2,0}, f32[1,1,1,1]{3,2,1,0}, - f32[1,1,1,1]{3,2,1,0}, - u8[16]{0} + f32[1,1,1,1]{3,2,1,0} ) custom-call( convert.18, convert.30, @@ -1652,8 +1650,7 @@ XLA_TEST_F(FlashAttentionBMMScaleSoftmaxBMMF8, std::string ref_btnh = R"( custom-call.4.0 = ( - bf16[4,16,4,16]{3,2,1,0}, - u8[0]{0} + bf16[4,16,4,16]{3,2,1,0} ) custom-call( convert.19, convert.31, @@ -1726,8 +1723,7 @@ XLA_TEST_F(FlashAttentionBMMScaleSoftmaxBMMF8, custom-call.21.0 = ( f8e4m3fn[4,16,4,16]{3,2,1,0}, f32[1,1,1,1]{3,2,1,0}, - f32[1,1,1,1]{3,2,1,0}, - u8[16]{0} + f32[1,1,1,1]{3,2,1,0} ) custom-call( convert.18, convert.30, diff --git 
a/xla/service/gpu/tests/int4_to_packed_int4.mlir b/xla/service/gpu/tests/int4_to_packed_int4.mlir new file mode 100644 index 00000000000000..29cdd45524d57c --- /dev/null +++ b/xla/service/gpu/tests/int4_to_packed_int4.mlir @@ -0,0 +1,110 @@ +// RUN: xla-opt --int4-to-packed-int4-rewrite %s --mlir-print-ir-after-all + +module { + tt.func @gemm_fusion_dot_2_impl(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}) { + %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32> + %0 = tt.get_program_id x : i32 + %c16_i32 = arith.constant 16 : i32 + %1 = arith.divsi %0, %c16_i32 : i32 + %c8_i32 = arith.constant 8 : i32 + %2 = arith.muli %1, %c8_i32 : i32 + %c1_i32 = arith.constant 1 : i32 + %3 = arith.subi %c1_i32, %2 : i32 + %4 = arith.cmpi slt, %3, %c8_i32 : i32 + %5 = arith.select %4, %3, %c8_i32 : i32 + %6 = arith.remsi %0, %5 : i32 + %7 = arith.addi %2, %6 : i32 + %c16_i32_0 = arith.constant 16 : i32 + %8 = arith.remsi %0, %c16_i32_0 : i32 + %9 = arith.divsi %8, %5 : i32 + %c128_i32 = arith.constant 128 : i32 + %10 = arith.muli %7, %c128_i32 : i32 + %c1_i64 = arith.constant 1 : i64 + %c0_i32 = arith.constant 0 : i32 + %11 = arith.addi %10, %c0_i32 : i32 + %c128_i64 = arith.constant 128 : i64 + %c0_i32_1 = arith.constant 0 : i32 + %c128_i64_2 = arith.constant 128 : i64 + %c0_i32_3 = arith.constant 0 : i32 + %c128_i64_4 = arith.constant 128 : i64 + %c0_i32_5 = arith.constant 0 : i32 + %12 = arith.addi %c0_i32_3, %c0_i32_5 : i32 + %c64_i64 = arith.constant 64 : i64 + %c0_i32_6 = arith.constant 0 : i32 + %c64_i64_7 = arith.constant 64 : i64 + %c8192_i32 = arith.constant 8192 : i32 + %13 = tt.get_program_id y : i32 + %c0_i32_8 = arith.constant 0 : i32 + %14 = arith.addi %c0_i32_8, %13 : i32 + %15 = arith.muli %14, %c8192_i32 : i32 + %16 = tt.addptr %arg0, %15 : !tt.ptr, i32 + %17 = tt.make_tensor_ptr %16, [%c128_i64_2, %c64_i64_7], [%c1_i64, %c128_i64_4], [%c0_i32_1, %c0_i32_6] {order 
= array} : > + %18 = tt.advance %17, [%10, %c0_i32_3] : > + %c0_i32_9 = arith.constant 0 : i32 + %c256_i64 = arith.constant 256 : i64 + %c0_i32_10 = arith.constant 0 : i32 + %19 = arith.addi %c0_i32_9, %c0_i32_10 : i32 + %c64_i64_11 = arith.constant 64 : i64 + %c0_i32_12 = arith.constant 0 : i32 + %c64_i64_13 = arith.constant 64 : i64 + %c128_i32_14 = arith.constant 128 : i32 + %20 = arith.muli %9, %c128_i32_14 : i32 + %c1_i64_15 = arith.constant 1 : i64 + %c0_i32_16 = arith.constant 0 : i32 + %21 = arith.addi %20, %c0_i32_16 : i32 + %c256_i64_17 = arith.constant 256 : i64 + %c0_i32_18 = arith.constant 0 : i32 + %c256_i64_19 = arith.constant 256 : i64 + %c16384_i32 = arith.constant 16384 : i32 + %22 = tt.get_program_id y : i32 + %c0_i32_20 = arith.constant 0 : i32 + %23 = arith.addi %c0_i32_20, %22 : i32 + %24 = arith.muli %23, %c16384_i32 : i32 + %25 = tt.addptr %arg1, %24 : !tt.ptr, i32 + %26 = tt.make_tensor_ptr %25, [%c64_i64_13, %c256_i64_19], [%c256_i64, %c1_i64_15], [%c0_i32_12, %c0_i32_18] {order = array} : > + %27 = tt.advance %26, [%c0_i32_9, %20] : > + %c0_i32_21 = arith.constant 0 : i32 + %c64_i32 = arith.constant 64 : i32 + %c32_i32 = arith.constant 32 : i32 + %28:3 = scf.for %arg3 = %c0_i32_21 to %c64_i32 step %c32_i32 iter_args(%arg4 = %18, %arg5 = %27, %arg6 = %cst) -> (!tt.ptr>, !tt.ptr>, tensor<128x128xf32>) : i32 { + %39 = tt.load %arg4 : !tt.ptr> + %c0_i32_35 = arith.constant 0 : i32 + %c32_i32_36 = arith.constant 32 : i32 + %40 = tt.advance %arg4, [%c0_i32_35, %c32_i32_36] : > + %41 = tt.load %arg5 : !tt.ptr> + %c32_i32_37 = arith.constant 32 : i32 + %c0_i32_38 = arith.constant 0 : i32 + %42 = tt.advance %arg5, [%c32_i32_37, %c0_i32_38] : > + %43 = arith.extsi %39 : tensor<128x32xi4> to tensor<128x32xi8> + %44 = arith.sitofp %43 : tensor<128x32xi8> to tensor<128x32xf32> + %45 = tt.dot %44, %41, %arg6 : tensor<128x32xf32> * tensor<32x128xf32> -> tensor<128x128xf32> + scf.yield %40, %42, %45 : !tt.ptr>, !tt.ptr>, tensor<128x128xf32> + } + 
%c128_i32_22 = arith.constant 128 : i32 + %29 = arith.muli %7, %c128_i32_22 : i32 + %c256_i64_23 = arith.constant 256 : i64 + %c0_i32_24 = arith.constant 0 : i32 + %30 = arith.addi %29, %c0_i32_24 : i32 + %c128_i64_25 = arith.constant 128 : i64 + %c0_i32_26 = arith.constant 0 : i32 + %c128_i64_27 = arith.constant 128 : i64 + %c128_i32_28 = arith.constant 128 : i32 + %31 = arith.muli %9, %c128_i32_28 : i32 + %c1_i64_29 = arith.constant 1 : i64 + %c0_i32_30 = arith.constant 0 : i32 + %32 = arith.addi %31, %c0_i32_30 : i32 + %c256_i64_31 = arith.constant 256 : i64 + %c0_i32_32 = arith.constant 0 : i32 + %c256_i64_33 = arith.constant 256 : i64 + %c32768_i32 = arith.constant 32768 : i32 + %33 = tt.get_program_id y : i32 + %c0_i32_34 = arith.constant 0 : i32 + %34 = arith.addi %c0_i32_34, %33 : i32 + %35 = arith.muli %34, %c32768_i32 : i32 + %36 = tt.addptr %arg2, %35 : !tt.ptr, i32 + %37 = tt.make_tensor_ptr %36, [%c128_i64_27, %c256_i64_33], [%c256_i64_23, %c1_i64_29], [%c0_i32_26, %c0_i32_32] {order = array} : > + %38 = tt.advance %37, [%29, %31] : > + tt.store %38, %28#2 : !tt.ptr> + tt.return + } +} diff --git a/xla/service/gpu/tests/int4_to_packed_int4_small.mlir b/xla/service/gpu/tests/int4_to_packed_int4_small.mlir new file mode 100644 index 00000000000000..a7323a4afaed8b --- /dev/null +++ b/xla/service/gpu/tests/int4_to_packed_int4_small.mlir @@ -0,0 +1,12 @@ +// RUN: xla-opt --int4-to-packed-int4-rewrite %s + +module { + tt.func @dot_test(%arg0: !tt.ptr {tt.divisibility = 16 : i32}) -> tensor<16x16xi8> { + %c0 = arith.constant 0 : i32 + %c16 = arith.constant 16: i64 + %0 = tt.make_tensor_ptr %arg0, [%c16, %c16], [%c16, %c16], [%c0, %c0] {order = array} : > + %1 = tt.load %0 : !tt.ptr> + %2 = arith.extsi %1 : tensor<16x16xi4> to tensor<16x16xi8> + tt.return %2 : tensor<16x16xi8> + } +} diff --git a/xla/service/gpu/transforms/cudnn_custom_call_compiler.cc b/xla/service/gpu/transforms/cudnn_custom_call_compiler.cc index b711f3142f3328..0dc92c47d2cb55 100644 --- 
a/xla/service/gpu/transforms/cudnn_custom_call_compiler.cc +++ b/xla/service/gpu/transforms/cudnn_custom_call_compiler.cc @@ -393,6 +393,9 @@ class CuDnnCustomCallVisitor : public DfsHloRewriteVisitor { : dnn_support_(dnn_support), compilation_results_(compilation_results) {} void AddWorkspace(HloInstruction &hlo, int64_t workspace_size) { + if (workspace_size == 0) { + return; + } VLOG(4) << "Applying workspace size " << workspace_size << " to " << hlo.ToString(); Shape *shape = hlo.mutable_shape();