manager: added FIXED_WITH_SPARES mode (#24)
d4l3k authored Dec 6, 2024
1 parent 1d5464d commit ab66c7c
Showing 7 changed files with 151 additions and 33 deletions.
4 changes: 3 additions & 1 deletion build.rs
@@ -5,6 +5,8 @@
// LICENSE file in the root directory of this source tree.

fn main() -> Result<(), Box<dyn std::error::Error>> {
tonic_build::compile_protos("proto/torchft.proto")?;
tonic_build::configure()
.protoc_arg("--experimental_allow_proto3_optional")
.compile_protos(&["proto/torchft.proto"], &["proto"])?;
Ok(())
}
11 changes: 7 additions & 4 deletions proto/torchft.proto
@@ -78,11 +78,14 @@ message ManagerQuorumResponse {
int64 quorum_id = 1;
string address = 2;
string store_address = 3;
// These fields describe the replicas that are at the max step.
int64 max_step = 4;
int64 num_max = 5;
int64 replica_rank = 6;
int64 replica_world = 7;
bool heal = 8;
optional int64 max_rank = 5;
int64 max_world_size = 6;
// These fields describe all replicas, including replicas that are behind.
int64 replica_rank = 7;
int64 replica_world_size = 8;
bool heal = 9;
}

message CheckpointAddressRequest {
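For reference, a minimal Python sketch of how the two groups of fields are meant to be read (the QuorumView class and participation helper are hypothetical, not the generated protobuf bindings): the max_* fields describe only the replicas already at the max step, while the replica_* fields describe every replica in the quorum, including ones that still need to heal.

from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass
class QuorumView:
    """Hypothetical mirror of ManagerQuorumResponse, for illustration only."""

    quorum_id: int
    address: str
    store_address: str
    max_step: int
    max_rank: Optional[int]      # None if this replica is not at the max step
    max_world_size: int          # number of replicas at the max step
    replica_rank: int            # rank among all replicas, healthy or behind
    replica_world_size: int      # total number of replicas in the quorum
    heal: bool


def participation(view: QuorumView) -> Tuple[Optional[int], int]:
    # Process-group membership uses the replica_* fields (everyone joins the PG),
    # while step participation uses the max_* fields (only caught-up replicas).
    return view.max_rank, view.max_world_size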
7 changes: 4 additions & 3 deletions src/lib.rs
@@ -102,7 +102,7 @@ impl ManagerClient {
rank: i64,
step: i64,
checkpoint_server_addr: String,
) -> PyResult<(i64, i64, i64, String, String, i64, i64, bool)> {
) -> PyResult<(i64, i64, i64, String, String, i64, Option<i64>, i64, bool)> {
py.allow_threads(move || {
let mut request = tonic::Request::new(ManagerQuorumRequest {
rank: rank,
@@ -121,11 +121,12 @@ impl ManagerClient {
Ok((
resp.quorum_id,
resp.replica_rank,
resp.replica_world,
resp.replica_world_size,
resp.address,
resp.store_address,
resp.max_step,
resp.num_max,
resp.max_rank,
resp.max_world_size,
resp.heal,
))
})
13 changes: 11 additions & 2 deletions src/manager.rs
@@ -234,6 +234,14 @@ impl ManagerService for Arc<Manager> {

let primary = max_participants[rank as usize % max_participants.len()];

let mut max_rank = None;
for (i, p) in max_participants.iter().enumerate() {
if p.replica_id == self.replica_id {
max_rank = Some(i as i64);
break;
}
}

// Decide whether we should be healing:
// 1. if we're not at the max step
// 2. if everyone is at the first step and we're not the primary
@@ -251,9 +259,10 @@
address: primary.address.clone(),
store_address: primary.store_address.clone(),
max_step: max_step,
num_max: max_participants.len() as i64,
max_rank: max_rank,
max_world_size: max_participants.len() as i64,
replica_rank: replica_rank as i64,
replica_world: participants.len() as i64,
replica_world_size: participants.len() as i64,
heal: heal,
};

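The max_rank computed above is simply this replica's index within the group of participants at the max step, or None when the replica is not in that group. A small illustrative sketch of the same logic (in Python, with a hypothetical compute_max_rank helper):

from typing import List, Optional


def compute_max_rank(replica_id: str, max_participants: List[str]) -> Optional[int]:
    # Index of this replica within the max-step group, or None if it is behind.
    for i, candidate in enumerate(max_participants):
        if candidate == replica_id:
            return i
    return None


# With replicas "a" and "b" at the max step:
#   compute_max_rank("b", ["a", "b"]) == 1
#   compute_max_rank("c", ["a", "b"]) is None  # replica "c" must heal first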
64 changes: 54 additions & 10 deletions torchft/manager.py
@@ -28,10 +28,10 @@
import logging
import os
import socket
import time
import uuid
from concurrent.futures import ThreadPoolExecutor
from datetime import timedelta
from enum import Enum
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, TypeVar, cast

import torch
@@ -54,6 +54,24 @@
T = TypeVar("T")


class WorldSizeMode(Enum):
"""
This controls the numerics for the job when doing allreduces across replicas
when the world size is larger than ``min_replica_size``. The world size will
never be smaller than ``min_replica_size``.
DYNAMIC:
The world size will dynamically increase to use all available
replicas and the gradient will be normalized by the world size.
FIXED_WITH_SPARES:
The number of active replicas is ``min_replica_size`` and any spares
will contribute zero gradients.
"""

DYNAMIC = 0
FIXED_WITH_SPARES = 1

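# Illustrative sketch only (hypothetical helper, not part of this commit): with
# min_replica_size=2 and gradients averaged over num_participants(), the
# per-replica contribution under each mode looks roughly like this:
#   DYNAMIC, 3 healthy replicas:           every replica contributes grad / 3
#   FIXED_WITH_SPARES, 3 healthy replicas: ranks 0 and 1 contribute grad / 2,
#                                          rank 2 is a spare and contributes 0
def _contribution_sketch(
    grad: float, rank: int, mode: WorldSizeMode, healthy: int, min_replica_size: int
) -> float:
    if mode == WorldSizeMode.DYNAMIC:
        return grad / healthy
    world = min(healthy, min_replica_size)
    return grad / world if rank < world else 0.0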

class Manager:
"""
Manager manages the full fault tolerant training loop.
@@ -73,6 +91,7 @@ def __init__(
timeout: timedelta = timedelta(seconds=60),
rank: Optional[int] = None,
world_size: Optional[int] = None,
world_size_mode: WorldSizeMode = WorldSizeMode.DYNAMIC,
store_addr: Optional[str] = None,
store_port: Optional[int] = None,
lighthouse_addr: Optional[str] = None,
@@ -98,6 +117,7 @@ def __init__(
self._pending_state_dict: Optional[Dict[str, object]] = None
self._use_async_quorum = use_async_quorum
self._timeout = timeout
self._world_size_mode = world_size_mode

store_addr = store_addr or os.environ["MASTER_ADDR"]
store_port = store_port or int(os.environ["MASTER_PORT"])
@@ -150,12 +170,13 @@ def __init__(
self._quorum_id = -1
self._errored = False
self._healing = False
self._participating_replicas = 0
self._pending_work: List[torch.futures.Future[object]] = []
self._batches_committed = 0

# first step is 1
self._should_step = True
self._participating_rank: Optional[int] = None
self._participating_world_size: int = 0

def shutdown(self) -> None:
"""
@@ -287,7 +308,7 @@ def step(self) -> None:

if self._should_step:
self._step += 1
self._batches_committed += self._participating_replicas
self._batches_committed += self.num_participants()

self._errored = False
self._healing = False
@@ -311,25 +332,45 @@ def _async_quorum(self) -> None:
(
quorum_id,
replica_rank,
replica_world,
replica_world_size,
address,
store_address,
max_step,
num_max,
max_rank,
max_world_size,
heal,
) = self._client.quorum(
rank=self._rank,
step=self._step,
checkpoint_server_addr=self._ckpt_server.address(),
)
self._participating_replicas = (
num_max if self._use_async_quorum else replica_world

# When using async quorum, only the replicas already at the max step
# participate in this step, so we take max_rank/max_world_size. When not
# using async quorum, every replica heals before stepping, so all replicas
# participate and we take replica_rank/replica_world_size.
self._participating_rank, self._participating_world_size = (
(max_rank, max_world_size)
if self._use_async_quorum
else (replica_rank, replica_world_size)
)

# For fixed with spares we need to ensure that we don't have more
# participating replicas than the min replica size.
if self._world_size_mode == WorldSizeMode.FIXED_WITH_SPARES:
self._participating_world_size = min(
self._participating_world_size, self._min_replica_size
)
if (
self._participating_rank is not None
and self._participating_rank >= self._min_replica_size
):
self._participating_rank = None

if quorum_id != self._quorum_id:
logger.info(f"reconfiguring for quorum_id {quorum_id}")
store_prefixed_addr = f"{store_address}/torchft/{quorum_id}/{self._rank}"
self._pg.configure(store_prefixed_addr, replica_rank, replica_world)
# We use the replica rank and world as we want all replicas in the PG.
self._pg.configure(store_prefixed_addr, replica_rank, replica_world_size)
self._quorum_id = quorum_id

# See manager.rs for healing conditions
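As a concrete example of the bookkeeping above (illustrative numbers, assuming min_replica_size=2 and async quorum): with three replicas at the max step and this replica being the third one, FIXED_WITH_SPARES clamps the participating world size to 2 and turns the third replica into a non-participating spare.

min_replica_size = 2
use_async_quorum = True

max_rank, max_world_size = 2, 3          # third replica in the max-step group
replica_rank, replica_world_size = 2, 3  # third replica overall

participating_rank, participating_world_size = (
    (max_rank, max_world_size)
    if use_async_quorum
    else (replica_rank, replica_world_size)
)

# FIXED_WITH_SPARES: cap the world size at min_replica_size and mark any
# replica beyond it as a spare.
participating_world_size = min(participating_world_size, min_replica_size)  # -> 2
if participating_rank is not None and participating_rank >= min_replica_size:
    participating_rank = None  # -> this replica sits out as a spare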
@@ -396,7 +437,7 @@ def should_commit(self) -> bool:
if self._healing:
self._apply_pending_state_dict()

enough_replicas = self._participating_replicas >= self._min_replica_size
enough_replicas = self.num_participants() >= self._min_replica_size
local_should_commit = enough_replicas and not self._errored
should_commit = self._client.should_commit(
self._rank, self._step, local_should_commit
@@ -469,7 +510,8 @@ def num_participants(self) -> int:
Returns:
the number of participants in the current quorum
"""
return self._participating_replicas
assert self._participating_world_size >= 0, "internal error"
return self._participating_world_size

def is_participating(self) -> bool:
"""
@@ -478,6 +520,8 @@ def is_participating(self) -> bool:
Returns:
whether this replica is participating in the current quorum
"""
if self._participating_rank is None:
return False
if self._healing:
assert self._use_async_quorum
return False
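Putting it together, a hypothetical end-to-end sketch of how a training loop might consult this API. Only step(), should_commit(), is_participating() and num_participants() come from the file above; the loop structure, the gradient zeroing for spares, and the import path are assumptions for illustration.

from typing import Iterable

import torch

from torchft.manager import Manager


def train_loop_sketch(
    manager: Manager,
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    batches: Iterable[torch.Tensor],
) -> None:
    for batch in batches:
        manager.step()  # advance the step and resolve the quorum
        model(batch).sum().backward()

        if not manager.is_participating():
            # Spare (or healing) replica: contribute zero gradients so the
            # numerics match exactly min_replica_size participants (assumption;
            # the manager may handle this internally).
            for p in model.parameters():
                if p.grad is not None:
                    p.grad.zero_()

        # Gradients are assumed to be averaged over manager.num_participants()
        # by the manager's allreduce path.
        if manager.should_commit():  # commit only if enough healthy replicas agree
            optimizer.step()
        optimizer.zero_grad()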
(2 more changed files not shown)
