Skip to content

Commit

Permalink
[Coordination Service]Allow restartable tasks to connect back to clus…
Browse files Browse the repository at this point in the history
…ter, as long as they have the same local topology as before.

PiperOrigin-RevId: 714267850
  • Loading branch information
ishark authored and Google-ML-Automation committed Jan 11, 2025
1 parent 1ad2d32 commit a5ba283
Show file tree
Hide file tree
Showing 3 changed files with 150 additions and 6 deletions.
3 changes: 1 addition & 2 deletions xla/pjrt/distributed/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,8 @@ xla_cc_test(
":in_memory_key_value_store",
":protocol_proto_cc",
":topology_util",
"//xla:test_helpers",
"//xla/tsl/lib/core:status_test_util",
"@com_google_absl//absl/status",
"@com_google_absl//absl/time",
"@com_google_absl//absl/types:span",
"@tsl//tsl/platform:env",
Expand Down Expand Up @@ -115,7 +115,6 @@ cc_library(
":key_value_store_interface",
":protocol_proto_cc",
"//xla:util",
"//xla/pjrt:pjrt_client",
"//xla/pjrt:utils",
"//xla/pjrt/gpu:gpu_topology_proto_cc",
"@com_google_absl//absl/container:flat_hash_map",
Expand Down
62 changes: 59 additions & 3 deletions xla/pjrt/distributed/topology_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ limitations under the License.
#include "xla/pjrt/distributed/topology_util.h"

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <fstream>
#include <map>
#include <set>
Expand All @@ -28,13 +30,13 @@ limitations under the License.
#include "absl/strings/ascii.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/strings/substitute.h"
#include "absl/synchronization/blocking_counter.h"
#include "absl/synchronization/mutex.h"
#include "absl/time/time.h"
#include "absl/types/span.h"
#include "xla/pjrt/distributed/key_value_store_interface.h"
#include "xla/pjrt/distributed/protocol.pb.h"
#include "xla/pjrt/pjrt_client.h"
#include "xla/pjrt/utils.h"
#include "xla/util.h"
#include "tsl/platform/env.h"
Expand All @@ -45,6 +47,34 @@ limitations under the License.

namespace xla {

namespace {
// Returns true iff the two device protos agree on name, vendor,
// local_device_ordinal, core_count, device_kind, slice_index,
// global_device_id and compute_capability. No other fields are inspected.
bool SameDevice(const DeviceProto& a, const DeviceProto& b) {
  if (a.name() != b.name()) return false;
  if (a.vendor() != b.vendor()) return false;
  if (a.local_device_ordinal() != b.local_device_ordinal()) return false;
  if (a.core_count() != b.core_count()) return false;
  if (a.device_kind() != b.device_kind()) return false;
  if (a.slice_index() != b.slice_index()) return false;
  // The global device ID may be unset on devices of a LocalTopologyProto; it
  // is compared anyway, so two unset IDs match through their default value.
  if (a.global_device_id() != b.global_device_id()) return false;
  return a.compute_capability() == b.compute_capability();
}

// Returns true iff both local topologies have the same node_id and the same
// number of devices, and each device at a given index matches its counterpart
// according to SameDevice. Devices are compared positionally, so the same
// devices listed in a different order do not count as equal.
bool SameLocalTopology(const LocalTopologyProto& a,
                       const LocalTopologyProto& b) {
  if (a.node_id() != b.node_id()) {
    return false;
  }
  const int device_count = a.devices_size();
  if (device_count != b.devices_size()) {
    return false;
  }
  // Advance while the devices at the current position match; stopping early
  // means a mismatch was found.
  int idx = 0;
  while (idx < device_count && SameDevice(a.devices(idx), b.devices(idx))) {
    ++idx;
  }
  return idx == device_count;
}

} // namespace

// Exists on Linux systems. Unique per OS kernel restart.
static constexpr char kBootIdPath[] = "/proc/sys/kernel/random/boot_id";

Expand Down Expand Up @@ -179,8 +209,34 @@ absl::Status ExchangeTopologies(absl::string_view platform, int node_id,
return absl::OkStatus();
}
CHECK(kv_store != nullptr);
TF_RETURN_IF_ERROR(kv_store->Set(GetLocalTopologyKey(platform, node_id),
local_topology.SerializeAsString()));
const std::string local_topology_key = GetLocalTopologyKey(platform, node_id);
const std::string serialized_local_topology =
local_topology.SerializeAsString();

absl::StatusOr<std::string> existing_local_topology =
kv_store->TryGet(local_topology_key);
printf("existing_local_topology status: %s\n",
existing_local_topology.status().ToString().c_str());

if (existing_local_topology.ok()) {
printf("existing topology found");
// Local topology has been set previously from the same node before
// restart.
LocalTopologyProto existing_local_topology_proto;
existing_local_topology_proto.ParseFromString(*existing_local_topology);
if (!SameLocalTopology(existing_local_topology_proto, local_topology)) {
return absl::InternalError(absl::Substitute(
"Different local topology for node $0 has been set previously, "
"possibly before a restart.\nBefore: $1\nAfter: $2",
node_id, existing_local_topology_proto.DebugString(),
local_topology.DebugString()));
}
} else if (absl::IsNotFound(existing_local_topology.status())) {
TF_RETURN_IF_ERROR(kv_store->Set(GetLocalTopologyKey(platform, node_id),
serialized_local_topology));
} else {
return existing_local_topology.status();
}

// The lead node gets all local topologies, builds the global topology and
// puts it to the key-value store.
Expand Down
91 changes: 90 additions & 1 deletion xla/pjrt/distributed/topology_util_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,11 @@ limitations under the License.
#include <string>
#include <vector>

#include "absl/status/status.h"
#include "absl/time/time.h"
#include "absl/types/span.h"
#include "xla/pjrt/distributed/in_memory_key_value_store.h"
#include "xla/pjrt/distributed/protocol.pb.h"
#include "xla/test_helpers.h"
#include "xla/tsl/lib/core/status_test_util.h"
#include "tsl/platform/env.h"
#include "tsl/platform/statusor.h"
Expand All @@ -31,6 +31,7 @@ limitations under the License.

namespace xla {
namespace {
using tsl::testing::StatusIs;

TEST(TopologyTest, BuildGlobalTopology) {
std::vector<LocalTopologyProto> locals(2);
Expand Down Expand Up @@ -86,6 +87,94 @@ TEST(TopologyTest, ExchangeTopology) {
}
}

// Checks that a node can call ExchangeTopologies a second time with an
// unchanged local topology (simulating a restart) and still succeed, and that
// every node ends up with a two-node global topology of two devices each.
TEST(TopologyTest, ExchangeTopology_Twice_Succeeds) {
  const int num_nodes = 2;
  std::vector<LocalTopologyProto> locals(num_nodes);
  // Node 0 reports two devices, both with ordinal 0; node 1 reports ordinals
  // 0 and 1.
  locals[0].add_devices()->set_local_device_ordinal(0);
  locals[0].add_devices()->set_local_device_ordinal(0);
  locals[1].add_devices()->set_local_device_ordinal(0);
  locals[1].add_devices()->set_local_device_ordinal(1);

  InMemoryKeyValueStore kv_store;
  std::vector<GlobalTopologyProto> globals(num_nodes);

  // Runs one topology exchange for `node` against the shared key-value store,
  // reading locals[node] and writing globals[node].
  auto exchange = [&](int node) {
    return ExchangeTopologies(
        /*platform=*/"cuda", /*node_id=*/node, num_nodes,
        /*get_local_topology_timeout=*/absl::Seconds(10),
        /*get_global_topology_timeout=*/absl::Seconds(10), &kv_store,
        locals[node], &globals[node],
        /*assign_global_device_ids=*/true);
  };

  {
    // The pool is scoped so that it is destroyed before the checks below;
    // presumably its destructor waits for the scheduled closures (verify
    // against tsl::thread::ThreadPool).
    tsl::thread::ThreadPool thread_pool(tsl::Env::Default(), "TestPool",
                                        num_nodes);
    for (int node = 0; node < num_nodes; ++node) {
      thread_pool.Schedule([&, node] {
        TF_ASSERT_OK(exchange(node));
        if (node != 1) {
          return;
        }
        // Simulate node 1 restarting and exchanging topologies again.
        TF_ASSERT_OK(exchange(node));
      });
    }
  }

  for (const GlobalTopologyProto& global : globals) {
    EXPECT_EQ(global.nodes_size(), 2);
    EXPECT_EQ(global.nodes(0).devices_size(), 2);
    EXPECT_EQ(global.nodes(1).devices_size(), 2);
  }
}

// Checks that a node which calls ExchangeTopologies a second time with a
// local topology that differs from the one it published first (simulating a
// restart with different devices) gets a kInternal error.
TEST(TopologyTest, ExchangeTopology_TwiceWithDifferentLocalTopology_Fails) {
  const int num_nodes = 2;
  std::vector<LocalTopologyProto> locals(num_nodes);
  // Node 0 reports two devices, both with ordinal 0; node 1 reports ordinals
  // 0 and 1.
  locals[0].add_devices()->set_local_device_ordinal(0);
  locals[0].add_devices()->set_local_device_ordinal(0);
  locals[1].add_devices()->set_local_device_ordinal(0);
  locals[1].add_devices()->set_local_device_ordinal(1);

  InMemoryKeyValueStore kv_store;
  std::vector<GlobalTopologyProto> globals(num_nodes);

  // Runs one topology exchange for `node` against the shared key-value store,
  // reading locals[node] (as it is at call time) and writing globals[node].
  auto exchange = [&](int node) {
    return ExchangeTopologies(
        /*platform=*/"cuda", /*node_id=*/node, num_nodes,
        /*get_local_topology_timeout=*/absl::Seconds(10),
        /*get_global_topology_timeout=*/absl::Seconds(10), &kv_store,
        locals[node], &globals[node],
        /*assign_global_device_ids=*/true);
  };

  {
    tsl::thread::ThreadPool thread_pool(tsl::Env::Default(), "TestPool",
                                        num_nodes);
    for (int node = 0; node < num_nodes; ++node) {
      thread_pool.Schedule([&, node] {
        TF_ASSERT_OK(exchange(node));
        if (node != 1) {
          return;
        }
        // Simulate node 1 restarting with different devices: a third device
        // is appended to its local topology before the second exchange.
        locals[1].add_devices()->set_local_device_ordinal(2);
        // This should fail because the local topology is unexpectedly
        // different.
        EXPECT_THAT(exchange(node), StatusIs(absl::StatusCode::kInternal));
      });
    }
  }
}

TEST(TopologyTest, BuildGpuTopology) {
std::string slice_0_boot_id = "foo";
std::string slice_1_boot_id = "bar";
Expand Down

0 comments on commit a5ba283

Please sign in to comment.