From a5ba283bc14503f9c308950deb5730053e7cf945 Mon Sep 17 00:00:00 2001 From: Isha Arkatkar Date: Fri, 10 Jan 2025 16:31:38 -0800 Subject: [PATCH] [Coordination Service]Allow restartable tasks to connect back to cluster, as long as they have the same local topology as before. PiperOrigin-RevId: 714267850 --- xla/pjrt/distributed/BUILD | 3 +- xla/pjrt/distributed/topology_util.cc | 62 ++++++++++++++- xla/pjrt/distributed/topology_util_test.cc | 91 +++++++++++++++++++++- 3 files changed, 150 insertions(+), 6 deletions(-) diff --git a/xla/pjrt/distributed/BUILD b/xla/pjrt/distributed/BUILD index 963d5c697c58c..21e5c7bcfc6fb 100644 --- a/xla/pjrt/distributed/BUILD +++ b/xla/pjrt/distributed/BUILD @@ -49,8 +49,8 @@ xla_cc_test( ":in_memory_key_value_store", ":protocol_proto_cc", ":topology_util", - "//xla:test_helpers", "//xla/tsl/lib/core:status_test_util", + "@com_google_absl//absl/status", "@com_google_absl//absl/time", "@com_google_absl//absl/types:span", "@tsl//tsl/platform:env", @@ -115,7 +115,6 @@ cc_library( ":key_value_store_interface", ":protocol_proto_cc", "//xla:util", - "//xla/pjrt:pjrt_client", "//xla/pjrt:utils", "//xla/pjrt/gpu:gpu_topology_proto_cc", "@com_google_absl//absl/container:flat_hash_map", diff --git a/xla/pjrt/distributed/topology_util.cc b/xla/pjrt/distributed/topology_util.cc index d22446a663184..ca08bbb530f2c 100644 --- a/xla/pjrt/distributed/topology_util.cc +++ b/xla/pjrt/distributed/topology_util.cc @@ -16,6 +16,8 @@ limitations under the License. #include "xla/pjrt/distributed/topology_util.h" #include +#include +#include #include #include #include @@ -28,13 +30,13 @@ limitations under the License. 
#include "absl/strings/ascii.h"
 #include "absl/strings/str_cat.h"
 #include "absl/strings/str_join.h"
+#include "absl/strings/substitute.h"
 #include "absl/synchronization/blocking_counter.h"
 #include "absl/synchronization/mutex.h"
 #include "absl/time/time.h"
 #include "absl/types/span.h"
 #include "xla/pjrt/distributed/key_value_store_interface.h"
 #include "xla/pjrt/distributed/protocol.pb.h"
-#include "xla/pjrt/pjrt_client.h"
 #include "xla/pjrt/utils.h"
 #include "xla/util.h"
 #include "tsl/platform/env.h"
@@ -45,6 +47,34 @@ limitations under the License.
 
 namespace xla {
 
+namespace {
+// Returns true when two device descriptions agree on every compared field.
+bool SameDevice(const DeviceProto& a, const DeviceProto& b) {
+  // Global device ID might not be set for LocalTopologyProto; it is still
+  // compared, so unset IDs match on their default value.
+  return a.name() == b.name() && a.vendor() == b.vendor() &&
+         a.local_device_ordinal() == b.local_device_ordinal() &&
+         a.core_count() == b.core_count() &&
+         a.device_kind() == b.device_kind() &&
+         a.slice_index() == b.slice_index() &&
+         a.global_device_id() == b.global_device_id() &&
+         a.compute_capability() == b.compute_capability();
+}
+
+// Returns true when both topologies carry the same node id and the same
+// devices, compared pairwise in order.
+bool SameLocalTopology(const LocalTopologyProto& a,
+                       const LocalTopologyProto& b) {
+  bool same = a.node_id() == b.node_id() &&
+              a.devices_size() == b.devices_size();
+  for (int idx = 0; same && idx < a.devices_size(); ++idx) {
+    same = SameDevice(a.devices(idx), b.devices(idx));
+  }
+  return same;
+}
+
+}  // namespace
+
 // Exists on Linux systems. Unique per OS kernel restart.
static constexpr char kBootIdPath[] = "/proc/sys/kernel/random/boot_id";
 
@@ -179,8 +209,34 @@ absl::Status ExchangeTopologies(absl::string_view platform, int node_id,
     return absl::OkStatus();
   }
   CHECK(kv_store != nullptr);
-  TF_RETURN_IF_ERROR(kv_store->Set(GetLocalTopologyKey(platform, node_id),
-                                   local_topology.SerializeAsString()));
+  const std::string local_topology_key = GetLocalTopologyKey(platform, node_id);
+  const std::string serialized_local_topology =
+      local_topology.SerializeAsString();
+
+  absl::StatusOr<std::string> existing_local_topology =
+      kv_store->TryGet(local_topology_key);
+  if (existing_local_topology.ok()) {
+    // Local topology has been set previously from the same node before
+    // restart. The node may only rejoin if its topology is unchanged.
+    LocalTopologyProto existing_local_topology_proto;
+    if (!existing_local_topology_proto.ParseFromString(
+            *existing_local_topology)) {
+      return absl::InternalError(absl::Substitute(
+          "Failed to parse the stored local topology for node $0.", node_id));
+    }
+    if (!SameLocalTopology(existing_local_topology_proto, local_topology)) {
+      return absl::InternalError(absl::Substitute(
+          "Different local topology for node $0 has been set previously, "
+          "possibly before a restart.\nBefore: $1\nAfter: $2",
+          node_id, existing_local_topology_proto.DebugString(),
+          local_topology.DebugString()));
+    }
+  } else if (absl::IsNotFound(existing_local_topology.status())) {
+    TF_RETURN_IF_ERROR(
+        kv_store->Set(local_topology_key, serialized_local_topology));
+  } else {
+    return existing_local_topology.status();
+  }
 
   // The lead node gets all local topologies, builds the global topology and
   // puts it to the key-value store.
diff --git a/xla/pjrt/distributed/topology_util_test.cc b/xla/pjrt/distributed/topology_util_test.cc
index 1ad4dda2c01cd..06464dc9b1b1b 100644
--- a/xla/pjrt/distributed/topology_util_test.cc
+++ b/xla/pjrt/distributed/topology_util_test.cc
@@ -18,11 +18,11 @@ limitations under the License.
#include
 #include
 
+#include "absl/status/status.h"
 #include "absl/time/time.h"
 #include "absl/types/span.h"
 #include "xla/pjrt/distributed/in_memory_key_value_store.h"
 #include "xla/pjrt/distributed/protocol.pb.h"
-#include "xla/test_helpers.h"
 #include "xla/tsl/lib/core/status_test_util.h"
 #include "tsl/platform/env.h"
 #include "tsl/platform/statusor.h"
@@ -31,6 +31,7 @@ limitations under the License.
 
 namespace xla {
 namespace {
+using tsl::testing::StatusIs;
 
 TEST(TopologyTest, BuildGlobalTopology) {
   std::vector<LocalTopologyProto> locals(2);
@@ -86,6 +87,94 @@ TEST(TopologyTest, ExchangeTopology) {
   }
 }
 
+TEST(TopologyTest, ExchangeTopology_Twice_Succeeds) {
+  int num_nodes = 2;
+  std::vector<LocalTopologyProto> locals(num_nodes);
+  DeviceProto* d0 = locals[0].add_devices();
+  d0->set_local_device_ordinal(0);
+  DeviceProto* d1 = locals[0].add_devices();
+  d1->set_local_device_ordinal(0);
+  DeviceProto* d2 = locals[1].add_devices();
+  d2->set_local_device_ordinal(0);
+  DeviceProto* d3 = locals[1].add_devices();
+  d3->set_local_device_ordinal(1);
+
+  InMemoryKeyValueStore kv_store;
+  std::vector<GlobalTopologyProto> globals(num_nodes);
+  {  // The pool is destroyed at the closing brace, before globals is read.
+    tsl::thread::ThreadPool thread_pool(tsl::Env::Default(), "TestPool",
+                                        num_nodes);
+    for (int i = 0; i < num_nodes; i++) {
+      thread_pool.Schedule([&, i] {
+        TF_ASSERT_OK(ExchangeTopologies(
+            /*platform=*/"cuda", /*node_id=*/i, num_nodes,
+            /*get_local_topology_timeout=*/
+            absl::Seconds(10), /*get_global_topology_timeout=*/
+            absl::Seconds(10), &kv_store, locals[i], &globals[i],
+            /*assign_global_device_ids=*/true));
+        // Simulate node 1 restarting and exchanging topologies again.
+        if (i == 1) {
+          TF_ASSERT_OK(ExchangeTopologies(
+              /*platform=*/"cuda", /*node_id=*/i, num_nodes,
+              /*get_local_topology_timeout=*/
+              absl::Seconds(10), /*get_global_topology_timeout=*/
+              absl::Seconds(10), &kv_store, locals[i], &globals[i],
+              /*assign_global_device_ids=*/true));
+        }
+      });
+    }
+  }
+  for (const GlobalTopologyProto& global : globals) {
+    EXPECT_EQ(global.nodes_size(), 2);
+    EXPECT_EQ(global.nodes()[0].devices_size(), 2);
+    EXPECT_EQ(global.nodes()[1].devices_size(), 2);
+  }
+}
+
+TEST(TopologyTest, ExchangeTopology_TwiceWithDifferentLocalTopology_Fails) {
+  int num_nodes = 2;
+  std::vector<LocalTopologyProto> locals(num_nodes);
+  DeviceProto* d0 = locals[0].add_devices();
+  d0->set_local_device_ordinal(0);
+  DeviceProto* d1 = locals[0].add_devices();
+  d1->set_local_device_ordinal(0);
+  DeviceProto* d2 = locals[1].add_devices();
+  d2->set_local_device_ordinal(0);
+  DeviceProto* d3 = locals[1].add_devices();
+  d3->set_local_device_ordinal(1);
+
+  InMemoryKeyValueStore kv_store;
+  std::vector<GlobalTopologyProto> globals(num_nodes);
+  {  // The pool is destroyed at the closing brace, ending the test body.
+    tsl::thread::ThreadPool thread_pool(tsl::Env::Default(), "TestPool",
+                                        num_nodes);
+    for (int i = 0; i < num_nodes; i++) {
+      thread_pool.Schedule([&, i] {
+        TF_ASSERT_OK(ExchangeTopologies(
+            /*platform=*/"cuda", /*node_id=*/i, num_nodes,
+            /*get_local_topology_timeout=*/
+            absl::Seconds(10), /*get_global_topology_timeout=*/
+            absl::Seconds(10), &kv_store, locals[i], &globals[i],
+            /*assign_global_device_ids=*/true));
+        // Simulate node 1 restarting with different devices.
+        if (i == 1) {
+          DeviceProto* d4 = locals[1].add_devices();
+          d4->set_local_device_ordinal(2);
+          // This should fail because the local topology is unexpectedly
+          // different.
+          EXPECT_THAT(ExchangeTopologies(
+                          /*platform=*/"cuda", /*node_id=*/i, num_nodes,
+                          /*get_local_topology_timeout=*/
+                          absl::Seconds(10), /*get_global_topology_timeout=*/
+                          absl::Seconds(10), &kv_store, locals[i], &globals[i],
+                          /*assign_global_device_ids=*/true),
+                      StatusIs(absl::StatusCode::kInternal));
+        }
+      });
+    }
+  }
+}
+
 TEST(TopologyTest, BuildGpuTopology) {
   std::string slice_0_boot_id = "foo";
   std::string slice_1_boot_id = "bar";