train_ddp, process_group: fixes so CUDA works e2e #5

Merged
1 commit merged on Nov 3, 2024
26 changes: 23 additions & 3 deletions torchft/manager.py
@@ -197,6 +197,9 @@ def step(self) -> None:
if not self._use_async_quorum:
self._quorum_future.result()

# eagerly apply pending state_dict so we can run the forward pass
self._apply_pending_state_dict()

# we are forcing healing at the beginning so we're in a good state
# and don't need to zero_grad
self._healing = False
@@ -236,14 +239,27 @@ def _async_quorum(self) -> None:
primary_client = ManagerClient(address, timeout=self._timeout)
checkpoint_server_address = primary_client.checkpoint_address(self._rank)

state_dict = CheckpointServer.load_from_address(checkpoint_server_address)
self._load_state_dict(state_dict["user"])
self.load_state_dict(state_dict["torchft"])
self._state_dict = CheckpointServer.load_from_address(
checkpoint_server_address
)
self.load_state_dict(self._state_dict["torchft"])
# we apply the user state dict later, from the main thread, once it is safe to do so

# This isn't strictly needed as loading the state_dict above should
# restore the correct step but it makes writing tests simpler.
self._step = max_step

def _apply_pending_state_dict(self) -> None:
assert self._healing, "must be in healing state"

# synchronize on future
self._quorum_future.result()

assert self._state_dict is not None, "checkpoint was not staged"

self._load_state_dict(self._state_dict["user"])
self._state_dict = None

def should_commit(self) -> bool:
for work in self._pending_work:
# check at the beginning since .wait() may trigger errors
@@ -256,6 +272,10 @@ def should_commit(self) -> bool:

self._pending_work = []

# apply state_dict if healing
if self._healing:
self._apply_pending_state_dict()

enough_replicas = self._participating_replicas >= self._min_replica_size
local_should_commit = enough_replicas and not self._errored
should_commit = self._client.should_commit(
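Note on the manager.py change above: _async_quorum no longer applies the user state_dict itself. It only stages the fetched checkpoint, and _apply_pending_state_dict later applies it from the main thread, either eagerly in step() when the quorum is synchronous or from should_commit() while healing. Below is a minimal, self-contained sketch of that stage-then-apply pattern; TinyManager and its toy checkpoint are illustrative stand-ins, not torchft APIs.

# Sketch: a background thread stages the checkpoint, the main thread applies it.
from concurrent.futures import ThreadPoolExecutor

class TinyManager:
    def __init__(self) -> None:
        self._executor = ThreadPoolExecutor(max_workers=1)
        self._quorum_future = None
        self._state_dict = None
        self._healing = False

    def step(self) -> None:
        self._healing = True  # pretend the quorum decided this replica must heal
        self._quorum_future = self._executor.submit(self._async_quorum)
        # synchronous-quorum path: wait, then apply before the forward pass
        self._quorum_future.result()
        self._apply_pending_state_dict()

    def _async_quorum(self) -> None:
        # background thread: fetch and stage the checkpoint, never touch the model
        self._state_dict = {"user": {"weight": 42}, "torchft": {"step": 7}}

    def _apply_pending_state_dict(self) -> None:
        assert self._healing, "must be in healing state"
        self._quorum_future.result()  # synchronize with the staging thread
        assert self._state_dict is not None, "checkpoint was not staged"
        print("applying user state_dict:", self._state_dict["user"])
        self._state_dict = None

TinyManager().step()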
101 changes: 94 additions & 7 deletions torchft/process_group.py
@@ -8,6 +8,7 @@
import logging
from typing import Type, List, Optional, Callable, Tuple
from datetime import timedelta
import threading

from torch.futures import Future
from torch.distributed import (
@@ -26,6 +27,11 @@

logger = logging.getLogger(__name__)

# TODO: use non-string values, which are cheaper
_QUEUE_CLOSE = "queue_close"
_FUTURE_RESULT = "fut_result"
_FUTURE_EXCEPTION = "fut_exception"


def _get(queue: mp.Queue, timeout) -> object:
v = queue.get(timeout=timeout)
@@ -208,9 +214,17 @@ def getBackendName(self):


class BabyWork(Work):
def __init__(self, tx: mp.Queue, rx: mp.Queue, op_id: int, timeout: float):
def __init__(
self,
pg: "ProcessGroupBaby",
tx: mp.Queue,
rx: mp.Queue,
op_id: int,
timeout: float,
):
super().__init__()

self._pg = pg
self._tx = tx
self._rx = rx
self._op_id = op_id
@@ -221,6 +235,9 @@ def wait(self) -> bool:
assert _get(self._rx, self._timeout) == self._op_id
return True

def get_future(self) -> Future:
return self._pg._get_future(self._op_id)


class BabyWorkNCCL(BabyWork):
def wait(self) -> bool:
@@ -255,6 +272,8 @@ def __init__(self, timeout: float = 60.0) -> None:
self._p = None
self._tx = None
self._rx = None
self._future_queue = None
self._future_thread = None

self._timeout = timeout

@@ -264,20 +283,46 @@ def configure(self, store_addr: str, rank: int, world_size: int) -> None:

self._world_size = world_size

if self._tx is not None:
self._tx.close()
if self._rx is not None:
self._rx.close()
if self._future_queue is not None:
self._future_queue.put(_QUEUE_CLOSE)
self._future_queue.close()

ctx = mp.get_context("spawn")
self._tx = ctx.Queue()
self._rx = ctx.Queue()

# futures need a thread to fire callbacks
self._future_queue = ctx.Queue()
# this lock needs to be held when manipulating _futures
self._futures_lock = threading.Lock()
self._futures = {}
self._future_thread = threading.Thread(
target=self._future_handler,
args=(self._future_queue,),
daemon=True,
)
self._future_thread.start()

self._p = ctx.Process(
target=self._worker,
args=(store_addr, rank, world_size, self._tx, self._rx),
args=(store_addr, rank, world_size, self._tx, self._rx, self._future_queue),
daemon=True,
)
self._p.start()

@classmethod
def _worker(
cls, store_addr: str, rank: int, world_size: int, rx: mp.Queue, tx: mp.Queue
cls,
store_addr: str,
rank: int,
world_size: int,
rx: mp.Queue,
tx: mp.Queue,
future_queue: mp.Queue,
) -> None:
try:
store = create_store(store_addr)
@@ -291,15 +336,28 @@ def _worker(
op = rx.get()
cmd = op[0]
if cmd == "func":
func, args, kwargs = op[1:]
work[next_op_id] = getattr(pg, func)(*args, **kwargs)
func_name, args, kwargs = op[1:]
fn = getattr(pg, func_name)
work[next_op_id] = fn(*args, **kwargs)
tx.put(next_op_id)
next_op_id += 1
elif cmd == "wait":
op_id = op[1]
work[op_id].wait()
del work[op_id]
tx.put(op_id)
elif cmd == "future":
op_id = op[1]

def callback(fut: Future):
try:
fut.wait()
future_queue.put((op_id, _FUTURE_RESULT, None))
except Exception as e:
future_queue.put((op_id, _FUTURE_EXCEPTION, e))

work[op_id].get_future().add_done_callback(callback)
tx.put(op_id)
elif cmd == "synchronize":
# CUDA only, use events instead of waiting on CPU
op_id = op[1]
@@ -322,12 +380,41 @@ def _worker(
logger.exception("worker errored")
tx.put(e)

def _future_handler(self, future_queue: mp.Queue) -> None:
try:
while True:
cmd = future_queue.get()
if cmd == _QUEUE_CLOSE:
break
op_id, mode, data = cmd
with self._futures_lock:
fut = self._futures[op_id]
del self._futures[op_id]
if mode == _FUTURE_RESULT:
fut.set_result(data)
elif mode == _FUTURE_EXCEPTION:
fut.set_exception(data)
else:
raise ValueError(f"unknown mode {mode}")
except Exception as e:
logger.exception(f"got unexpected error in future handler: {e}")

def _get_future(self, op_id: int) -> Future:
with self._futures_lock:
fut = Future()
self._futures[op_id] = fut
self._tx.put(("future", op_id), timeout=self._timeout)

assert _get(self._rx, self._timeout) == op_id
# TODO: return correct tensor instead of None
return fut

def _run_func(self, func: str, *args: object, **kwargs: object) -> Work:
self._tx.put(("func", func, args, kwargs), timeout=self._timeout)
op_id = _get(self._rx, self._timeout)
assert isinstance(op_id, int), f"invalid return {op_id}"
return self.WORK_CLASS(
tx=self._tx, rx=self._rx, op_id=op_id, timeout=self._timeout
pg=self, tx=self._tx, rx=self._rx, op_id=op_id, timeout=self._timeout
)

def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work:
@@ -366,7 +453,7 @@ class ProcessGroupBabyNCCL(ProcessGroupBaby):
tensors may leak in the current PyTorch implementation. TODO fix
"""

PG_CLASS = BaseProcessGroupGloo
PG_CLASS = BaseProcessGroupNCCL
WORK_CLASS = BabyWorkNCCL

def getBackendName(self):
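Note on the process_group.py change above: BabyWork.get_future() registers a torch.futures.Future in the parent keyed by op_id, the worker subprocess attaches a done callback that reports completion over future_queue, and the parent's _future_handler thread resolves the matching Future. Below is a stripped-down, runnable sketch of that cross-process plumbing; it omits the ProcessGroup entirely and the single hard-coded op_id is illustrative.

# Sketch: child signals completion over a queue, a parent thread resolves the Future.
import threading

import torch.multiprocessing as mp
from torch.futures import Future

_FUTURE_RESULT = "fut_result"
_FUTURE_EXCEPTION = "fut_exception"

def _worker(future_queue: mp.Queue) -> None:
    # child process: pretend op 0 completed successfully
    future_queue.put((0, _FUTURE_RESULT, None))

def main() -> None:
    ctx = mp.get_context("spawn")
    future_queue = ctx.Queue()

    fut = Future()
    futures = {0: fut}
    futures_lock = threading.Lock()

    def handler() -> None:
        # parent-side thread: resolve the local Future from the queue message
        op_id, mode, data = future_queue.get()
        with futures_lock:
            pending = futures.pop(op_id)
        if mode == _FUTURE_RESULT:
            pending.set_result(data)
        elif mode == _FUTURE_EXCEPTION:
            pending.set_exception(data)

    threading.Thread(target=handler, daemon=True).start()

    p = ctx.Process(target=_worker, args=(future_queue,), daemon=True)
    p.start()
    fut.wait()  # resolved by the handler thread once the child reports in
    p.join()
    print("future resolved")

if __name__ == "__main__":
    main()

As in the diff, the child only ever writes to the queue; all Future state lives in the parent and is guarded by a lock.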
33 changes: 20 additions & 13 deletions torchft/process_group_test.py
@@ -5,6 +5,7 @@
# LICENSE file in the root directory of this source tree.

from unittest import TestCase, skipUnless
from concurrent.futures import ThreadPoolExecutor

import torch
from torch.distributed import TCPStore, ReduceOp
@@ -37,6 +38,7 @@ def test_gloo(self) -> None:

a_work = pg.allreduce([at], ReduceOp.SUM)
a_work.wait()
a_work.get_future().wait()

m = nn.Linear(3, 4)
m = torch.nn.parallel.DistributedDataParallel(m, process_group=pg)
@@ -58,6 +60,7 @@ def test_nccl(self) -> None:
at = torch.tensor([2], device=device)
a_work = pg.allreduce([at], ReduceOp.SUM)
a_work.wait()
a_work.get_future().wait()

m = nn.Linear(3, 4).to(device)
m = torch.nn.parallel.DistributedDataParallel(m, process_group=pg)
@@ -95,7 +98,9 @@ def test_baby_gloo(self) -> None:
b_work = b.allreduce([bt], ReduceOp.SUM)

a_work.wait()
b_work.wait()
fut = b_work.get_future()

fut.wait()

torch.testing.assert_close(at, bt)

@@ -113,23 +118,25 @@ def test_baby_nccl(self) -> None:

store_addr = f"localhost:{store.port}/prefix"

device = "cuda"
def run(rank: int) -> None:
a = ProcessGroupBabyNCCL()
a.configure(store_addr, rank, 2)

a = ProcessGroupBabyNCCL()
b = ProcessGroupBabyNCCL()
self.assertEqual(a.size(), 2)

a.configure(store_addr, 0, 2)
b.configure(store_addr, 1, 2)
at = torch.tensor([rank + 1], device=f"cuda:{rank}")

self.assertEqual(a.size(), 2)
a_work = a.allreduce([at], ReduceOp.SUM)
return at, a_work

at = torch.tensor([1], device=device)
bt = torch.tensor([2], device=device)
with ThreadPoolExecutor(max_workers=2) as executor:
a_fut = executor.submit(run, 0)
b_fut = executor.submit(run, 1)

a_work = a.allreduce([at], ReduceOp.SUM)
b_work = b.allreduce([bt], ReduceOp.SUM)
at, a_work = a_fut.result()
bt, b_work = b_fut.result()

a_work.wait()
b_work.wait()
b_work.get_future().wait()

torch.testing.assert_close(at, bt)
torch.testing.assert_close(at.cpu(), bt.cpu())
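Note on the test change above: test_baby_nccl now drives each rank from its own thread via ThreadPoolExecutor, so both ranks configure their process group and issue the allreduce concurrently, each on its own cuda:{rank} device. Below is the same thread-per-rank pattern sketched with ProcessGroupBabyGloo (already exercised by test_baby_gloo in this file) so it can run without GPUs; the tensor values and store prefix are illustrative.

# Sketch: one thread per rank, shared rendezvous store, future used for completion.
from concurrent.futures import ThreadPoolExecutor

import torch
from torch.distributed import TCPStore, ReduceOp

from torchft.process_group import ProcessGroupBabyGloo

store = TCPStore(host_name="localhost", port=0, is_master=True, wait_for_workers=False)
store_addr = f"localhost:{store.port}/prefix"

def run(rank: int) -> torch.Tensor:
    pg = ProcessGroupBabyGloo()
    pg.configure(store_addr, rank, 2)  # each rank owns its PG and worker process
    t = torch.tensor([rank + 1])
    pg.allreduce([t], ReduceOp.SUM).get_future().wait()
    return t

with ThreadPoolExecutor(max_workers=2) as executor:
    a, b = executor.map(run, [0, 1])

torch.testing.assert_close(a, b)  # both ranks should end up with [3]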