diff --git a/src/lib.rs b/src/lib.rs
index e4d84a0..34c7cc4 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -225,7 +225,9 @@ struct Lighthouse {
 #[pymethods]
 impl Lighthouse {
     #[new]
-    fn new(py: Python<'_>, bind: String, min_replicas: u64) -> PyResult<Self> {
+    fn new(py: Python<'_>, bind: String, min_replicas: u64, join_timeout_ms: Option<u64>) -> PyResult<Self> {
+        let join_timeout_ms = join_timeout_ms.unwrap_or(100);
+
         py.allow_threads(move || {
             let rt = Runtime::new()?;
 
@@ -233,7 +235,7 @@ impl Lighthouse {
                 .block_on(lighthouse::Lighthouse::new(lighthouse::LighthouseOpt {
                     bind: bind,
                     min_replicas: min_replicas,
-                    join_timeout_ms: 100,
+                    join_timeout_ms: join_timeout_ms,
                     quorum_tick_ms: 100,
                 }))
                 .map_err(|e| PyRuntimeError::new_err(e.to_string()))?;
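With the change above, the new argument is optional on the Python side and falls back to the previous hard-coded default. A minimal usage sketch, assuming a build of the extension containing this diff and that the binding keeps the trailing `Option` argument optional; the 1000 ms value is illustrative:

```python
from torchft.torchft import Lighthouse

# Explicit timeout: per the LighthouseOpt field this argument feeds, this is
# how long the lighthouse waits for additional replicas to join before
# finalizing a quorum.
lighthouse = Lighthouse(bind="[::]:0", min_replicas=2, join_timeout_ms=1000)
print(f"lighthouse listening on {lighthouse.address()}")
lighthouse.shutdown()

# Omitted: the Rust side substitutes the old default of 100 ms
# (join_timeout_ms.unwrap_or(100)).
lighthouse = Lighthouse(bind="[::]:0", min_replicas=2)
lighthouse.shutdown()
```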
diff --git a/torchft/manager_integ_test.py b/torchft/manager_integ_test.py
index 2bb961e..4aefc3b 100644
--- a/torchft/manager_integ_test.py
+++ b/torchft/manager_integ_test.py
@@ -20,6 +20,7 @@
 from torchft.optim import OptimizerWrapper
 from torchft.process_group import ProcessGroupGloo
 from torchft.torchft import Lighthouse
+import time
 
 logger: logging.Logger = logging.getLogger(__name__)
 
@@ -252,6 +253,76 @@ def state_dict() -> Dict[str, Dict[str, object]]:
         return state_dict()
 
 
+def manual_manager_management_train_loop(
+    rank: int,
+    store_port: int,
+    runner: Runner,
+) -> List[Dict[str, object]]:
+    with ExitStack() as stack:
+
+        def load_state_dict(state_dict: Dict[str, Dict[str, object]]) -> None:
+            m.load_state_dict(state_dict["model"])
+            optimizer.load_state_dict(state_dict["optim"])
+
+        def state_dict() -> Dict[str, Dict[str, object]]:
+            return {
+                "model": m.state_dict(),
+                "optim": optimizer.state_dict(),
+            }
+
+        print(f"worker {runner.replica_id=} {rank=} {runner.world_size=} starting")
+
+        pg = ProcessGroupGloo()
+        manager = Manager(
+            pg=pg,
+            min_replica_size=2,
+            load_state_dict=load_state_dict,
+            state_dict=state_dict,
+            replica_id=str(runner.replica_id),
+            store_addr="localhost",
+            store_port=store_port,
+            rank=rank,
+            world_size=runner.world_size,
+            lighthouse_addr=runner.lighthouse_address,
+            port=19530 + runner.replica_id,
+            # pyre-fixme[6]: Incompatible parameter type
+            **runner.manager_args,
+        )
+        stack.callback(manager.shutdown)
+
+        m: nn.Module = MyModel()
+        optimizer: optim.Optimizer = optim.Adam(m.parameters())
+        criterion = nn.CrossEntropyLoss()
+
+        model_states = []
+
+        while True:
+            manager.start_quorum("start", allow_heal=True)
+            model_states.append({k: v.detach().clone() for k, v in state_dict()["model"].items()})
+            inputs = torch.rand(2, 3)
+            labels = torch.randint(4, (2,))
+
+            optimizer.zero_grad()
+            out = m(inputs)
+            time.sleep(1)
+            loss = criterion(out, labels)
+
+            loss.backward()
+
+            optimizer.step()
+            manager.start_quorum("end", allow_heal=False)
+            if manager.should_commit():
+                for p in m.parameters():
+                    manager.allreduce(p).wait()
+
+            if manager.current_step() >= 5:
+                break
+
+            runner.failure_injector.check(rank, manager.current_step())
+
+        # return the recorded per-step model states so we can check consistency
+        return model_states
+
 class ManagerIntegTest(TestCase):
     def test_ddp_healthy(self) -> None:
         lighthouse = Lighthouse(
@@ -423,3 +494,56 @@ def test_local_sgd_recovery(self) -> None:
         )
 
         self.assertEqual(failure_injectors[1].count, 1)
+
+    def test_join(self) -> None:
+        lighthouse = Lighthouse(
+            bind="[::]:0",
+            min_replicas=2,
+            join_timeout_ms=100,
+        )
+        num_replicas = 2
+        futures = []
+
+        with ThreadPoolExecutor(max_workers=3) as executor:
+            for replica_id in range(num_replicas):
+                runner = Runner(
+                    replica_id=replica_id,
+                    lighthouse_address=lighthouse.address(),
+                    failure_injector=FailureInjector(),
+                    train_loop=manual_manager_management_train_loop,
+                    manager_args={
+                        "use_async_quorum": False,
+                    },
+                )
+                futures.append(executor.submit(runner.run_replica))
+
+            time.sleep(3)
+            runner = Runner(
+                replica_id=2,
+                lighthouse_address=lighthouse.address(),
+                failure_injector=FailureInjector(),
+                train_loop=manual_manager_management_train_loop,
+                manager_args={
+                    "use_async_quorum": False,
+                },
+            )
+            futures.append(executor.submit(runner.run_replica))
+
+            state_dicts = []
+
+            for fut in as_completed(futures):
+                try:
+                    state_dicts.append(fut.result())
+                    print(state_dicts[-1])
+                except Exception as e:
+                    print(e)
+                    raise
+
+        lighthouse.shutdown()
+
+        # for state_dict in state_dicts:
+        #     # LocalSGD only guarantees that the model is consistent across
+        #     # replicas but uses separate optimizer states.
+        #     torch.testing.assert_close(
+        #         state_dict[0]["model"], state_dicts[0][0]["model"]
+        #     )
diff --git a/torchft/torchft.pyi b/torchft/torchft.pyi
index aee2947..4249ac3 100644
--- a/torchft/torchft.pyi
+++ b/torchft/torchft.pyi
@@ -23,6 +23,6 @@ class Manager:
     def shutdown(self) -> None: ...
 
 class Lighthouse:
-    def __init__(self, bind: str, min_replicas: int) -> None: ...
+    def __init__(self, bind: str, min_replicas: int, join_timeout_ms: Optional[int] = None) -> None: ...
    def address(self) -> str: ...
    def shutdown(self) -> None: ...
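The heart of the new test's manual loop is the two-phase quorum. A condensed sketch follows; the function name `train_step` and its parameter passing are hypothetical, but the Manager calls are the ones the test above makes:

```python
import torch
import torch.nn as nn
import torch.optim as optim

from torchft.manager import Manager  # import path assumed from the torchft package layout


def train_step(
    manager: Manager,
    m: nn.Module,
    optimizer: optim.Optimizer,
    criterion: nn.Module,
    inputs: torch.Tensor,
    labels: torch.Tensor,
) -> None:
    # Phase 1: healing allowed, so a late joiner can be sent state before
    # this replica touches its weights.
    manager.start_quorum("start", allow_heal=True)

    optimizer.zero_grad()
    loss = criterion(m(inputs), labels)
    loss.backward()
    optimizer.step()

    # Phase 2: no healing while deciding whether the step counts.
    manager.start_quorum("end", allow_heal=False)
    if manager.should_commit():
        # Average parameters across replicas, as the test does.
        for p in m.parameters():
            manager.allreduce(p).wait()
```

With `join_timeout_ms=100` and `min_replicas=2`, `test_join` lets the first two replicas form a quorum almost immediately; the third replica, started three seconds later, therefore has to join an already-running job, which appears to be the path the healing quorum in phase 1 is meant to exercise.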