Test manager join #62

Draft · wants to merge 2 commits into base: main
6 changes: 4 additions & 2 deletions src/lib.rs
@@ -225,15 +225,17 @@ struct Lighthouse {
#[pymethods]
impl Lighthouse {
#[new]
-    fn new(py: Python<'_>, bind: String, min_replicas: u64) -> PyResult<Self> {
+    fn new(py: Python<'_>, bind: String, min_replicas: u64, join_timeout_ms: Option<u64>) -> PyResult<Self> {
+        let join_timeout_ms = join_timeout_ms.unwrap_or(100);

py.allow_threads(move || {
let rt = Runtime::new()?;

let lighthouse = rt
.block_on(lighthouse::Lighthouse::new(lighthouse::LighthouseOpt {
bind: bind,
min_replicas: min_replicas,
-                    join_timeout_ms: 100,
+                    join_timeout_ms: join_timeout_ms,
quorum_tick_ms: 100,
}))
.map_err(|e| PyRuntimeError::new_err(e.to_string()))?;
124 changes: 124 additions & 0 deletions torchft/manager_integ_test.py
@@ -20,6 +20,7 @@
from torchft.optim import OptimizerWrapper
from torchft.process_group import ProcessGroupGloo
from torchft.torchft import Lighthouse
+import time

logger: logging.Logger = logging.getLogger(__name__)

@@ -252,6 +253,76 @@ def state_dict() -> Dict[str, Dict[str, object]]:
return state_dict()


def manual_manager_management_train_loop(
rank: int,
store_port: int,
runner: Runner,
) -> List[Dict[str, object]]:
with ExitStack() as stack:

def load_state_dict(state_dict: Dict[str, Dict[str, object]]) -> None:
m.load_state_dict(state_dict["model"])
optimizer.load_state_dict(state_dict["optim"])

def state_dict() -> Dict[str, Dict[str, object]]:
return {
"model": m.state_dict(),
"optim": optimizer.state_dict(),
}

print(f"worker {runner.replica_id=} {rank=} {runner.world_size=} starting")

pg = ProcessGroupGloo()
manager = Manager(
pg=pg,
min_replica_size=2,
load_state_dict=load_state_dict,
state_dict=state_dict,
replica_id=str(runner.replica_id),
store_addr="localhost",
store_port=store_port,
rank=rank,
world_size=runner.world_size,
lighthouse_addr=runner.lighthouse_address,
port=19530 + runner.replica_id,
# pyre-fixme[6]: Incompatible parameter type
**runner.manager_args,
)
stack.callback(manager.shutdown)

m: nn.Module = MyModel()
optimizer: optim.Optimizer = optim.Adam(m.parameters())
criterion = nn.CrossEntropyLoss()

model_states = []

while True:
manager.start_quorum("start", allow_heal=True)
model_states.append({k: v.detach().clone() for k, v in state_dict()["model"].items()})
inputs = torch.rand(2, 3)
labels = torch.randint(4, (2,))

optimizer.zero_grad()
out = m(inputs)
time.sleep(1)
loss = criterion(out, labels)

loss.backward()

optimizer.step()
manager.start_quorum("end", allow_heal=False)
if manager.should_commit():
for p in m.parameters():
manager.allreduce(p).wait()

if manager.current_step() >= 5:
break

runner.failure_injector.check(rank, manager.current_step())

        # return the recorded per-step model states so we can check consistency
return model_states

class ManagerIntegTest(TestCase):
def test_ddp_healthy(self) -> None:
lighthouse = Lighthouse(
@@ -423,3 +494,56 @@ def test_local_sgd_recovery(self) -> None:
)

self.assertEqual(failure_injectors[1].count, 1)

def test_join(self) -> None:
lighthouse = Lighthouse(
bind="[::]:0",
min_replicas=2,
join_timeout_ms=100,
)
num_replicas = 2
futures = []

with ThreadPoolExecutor(max_workers=3) as executor:
for replica_id in range(num_replicas):
runner = Runner(
replica_id=replica_id,
lighthouse_address=lighthouse.address(),
failure_injector=FailureInjector(),
train_loop=manual_manager_management_train_loop,
manager_args={
"use_async_quorum": False,
},
)
futures.append(executor.submit(runner.run_replica))

time.sleep(3)
runner = Runner(
replica_id=2,
lighthouse_address=lighthouse.address(),
failure_injector=FailureInjector(),
train_loop=manual_manager_management_train_loop,
manager_args={
"use_async_quorum": False,
},
)
futures.append(executor.submit(runner.run_replica))

state_dicts = []

for fut in as_completed(futures):
try:
state_dicts.append(fut.result())
print(state_dicts[-1])
except Exception as e:
print(e)
raise

lighthouse.shutdown()

        # TODO: enable once the join/heal behaviour is verified. Each entry of
        # state_dicts is a list of per-step model states, and the replicas are
        # expected to converge after the late joiner heals, e.g.:
        # for state_dict in state_dicts:
        #     torch.testing.assert_close(state_dict[-1], state_dicts[0][-1])
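Not part of the diff: a minimal sketch of what the disabled consistency check above could look like once the join/heal behaviour is verified. It assumes each entry of state_dicts is the per-step list of model states returned by manual_manager_management_train_loop, and that the parameter allreduce keeps replicas in sync once every replica has joined.

# Illustrative only: compare each replica's last recorded model state against
# replica 0's. torch.testing.assert_close accepts mappings of tensors, so a
# whole state dict can be compared in one call.
reference = state_dicts[0][-1]
for other in state_dicts[1:]:
    torch.testing.assert_close(other[-1], reference)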
2 changes: 1 addition & 1 deletion torchft/torchft.pyi
@@ -23,6 +23,6 @@ class Manager:
def shutdown(self) -> None: ...

class Lighthouse:
-    def __init__(self, bind: str, min_replicas: int) -> None: ...
+    def __init__(self, bind: str, min_replicas: int, join_timeout_ms: Optional[int]) -> None: ...
def address(self) -> str: ...
def shutdown(self) -> None: ...
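For context, a usage sketch (not part of the diff) of the extended binding: join_timeout_ms appears to bound how long the lighthouse waits for additional replicas before finalizing a quorum, and the Rust side above falls back to the previous hard-coded 100 ms when no value is supplied.

# Sketch only; mirrors the stub above and the constructor change in src/lib.rs.
from torchft.torchft import Lighthouse

lighthouse = Lighthouse(
    bind="[::]:0",
    min_replicas=2,
    join_timeout_ms=100,  # explicit value; the Rust fallback is also 100 ms
)
try:
    print(lighthouse.address())
finally:
    lighthouse.shutdown()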