From 57c0a1d7f39ab05e7ffde7b058dab0f2377e2fe1 Mon Sep 17 00:00:00 2001 From: akuriparambi Date: Thu, 23 Jan 2025 15:37:38 -0800 Subject: [PATCH] documentation update --- cleanrl/ppo_atari_multigpu.py | 2 ++ docs/rl-algorithms/ppo.md | 52 ++++++++--------------------------- 2 files changed, 13 insertions(+), 41 deletions(-) diff --git a/cleanrl/ppo_atari_multigpu.py b/cleanrl/ppo_atari_multigpu.py index f32279fe5..6a3885f44 100644 --- a/cleanrl/ppo_atari_multigpu.py +++ b/cleanrl/ppo_atari_multigpu.py @@ -234,7 +234,9 @@ def unwrap_ddp(model) -> Agent: agent = Agent(envs).to(device) if args.world_size > 1: + # DDP syncs gradients (after each backward step), weights are sync'd at DDP initialization agent = DDP(agent) + optimizer = optim.Adam(agent.parameters(), lr=args.learning_rate, eps=1e-5) # ALGO Logic: Storage setup diff --git a/docs/rl-algorithms/ppo.md b/docs/rl-algorithms/ppo.md index e83b38e63..3f908392e 100644 --- a/docs/rl-algorithms/ppo.md +++ b/docs/rl-algorithms/ppo.md @@ -909,66 +909,36 @@ See [related docs](/rl-algorithms/ppo/#explanation-of-the-logged-metrics) for `p [ppo_atari_multigpu.py](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/ppo_atari_multigpu.py) is based on `ppo_atari.py` (see its [related docs](/rl-algorithms/ppo/#implementation-details_1)). -We use [Pytorch's distributed API](https://pytorch.org/tutorials/intermediate/dist_tuto.html) to implement the data parallelism paradigm. The basic idea is that the user can spawn $N$ processes each running a copy of `ppo_atari.py`, holding a copy of the model, stepping the environments, and averaging their gradients together for the backward pass. Here are a few note-worthy implementation details. +We use Pytorch [distributed API](https://pytorch.org/tutorials/intermediate/dist_tuto.html) and [DistributedDataParallel module](https://pytorch.org/tutorials/intermediate/ddp_tutorial.html) to implement data parallelism. The basic idea is that the user can spawn $N$ processes each running a copy of `ppo_atari.py`, holding a copy of the model, stepping the environments and learn a better policy. Here are a few note-worthy implementation details. 1. **Local versus global parameters**: All of the parameters in `ppo_atari.py` are global (such as batch size), but in `ppo_atari_multigpu.py` we have local parameters as well. Say we run `torchrun --standalone --nnodes=1 --nproc_per_node=2 cleanrl/ppo_atari_multigpu.py --env-id BreakoutNoFrameskip-v4 --local-num-envs=4`; here are how all multi-gpu related parameters are adjusted: * **number of environments**: `num_envs = local_num_envs * world_size = 4 * 2 = 8` * **batch size**: `local_batch_size = local_num_envs * num_steps = 4 * 128 = 512`, `batch_size = num_envs * num_steps) = 8 * 128 = 1024` * **minibatch size**: `local_minibatch_size = int(args.local_batch_size // args.num_minibatches) = 512 // 4 = 128`, `minibatch_size = int(args.batch_size // args.num_minibatches) = 1024 // 4 = 256` * **number of updates**: `num_iterations = args.total_timesteps // args.batch_size = 10000000 // 1024 = 9765` -1. **Adjust seed per process**: we need be very careful with seeding: we could have used the exact same seed for each subprocess. To ensure this does not happen, we do the following +1. **Adjust seed per process**: we need to be very careful with seeding: we could have used the exact same seed for each subprocess. To ensure this does not happen, we do the following ```python hl_lines="2 5 16" # CRUCIAL: note that we needed to pass a different seed for each data parallelism worker args.seed += local_rank random.seed(args.seed) np.random.seed(args.seed) - torch.manual_seed(args.seed - local_rank) - torch.backends.cudnn.deterministic = args.torch_deterministic - - # ... - - envs = gym.vector.SyncVectorEnv( - [make_env(args.env_id, args.seed + i, i, args.capture_video, run_name) for i in range(args.num_envs)] - ) - assert isinstance(envs.single_action_space, gym.spaces.Discrete), "only discrete action space is supported" - - agent = Agent(envs).to(device) torch.manual_seed(args.seed) - optimizer = optim.Adam(agent.parameters(), lr=args.learning_rate, eps=1e-5) - ``` - - Notice that we adjust the seed with `args.seed += local_rank` (line 2), where `local_rank` is the index of the subprocesses. This ensures we seed packages and envs with uncorrealted seeds. However, we do need to use the same `torch` seed for all process to initialize same weights for the `agent` (line 5), after which we can use a different seed for `torch` (line 16). -1. **Efficient gradient averaging**: PyTorch recommends to average the gradient across the whole world via the following (see [docs](https://pytorch.org/tutorials/intermediate/dist_tuto.html#distributed-training)) - - ```python - for param in agent.parameters(): - dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM) - param.grad.data /= world_size + torch.backends.cudnn.deterministic = args.torch_deterministic ``` - However, [@cswinter](https://github.com/cswinter) introduces a more efficient gradient averaging scheme with proper batching (see :material-github: [entity-neural-network/incubator#220](https://github.com/entity-neural-network/incubator/pull/220)), which looks like: +1. **Pytorch DDP for weight and gradient synchronization**: We wrap the agent in Pytorch [DistributedDataParallel (DDP)](https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html#torch.nn.parallel.DistributedDataParallel) module as shown below: ```python - all_grads_list = [] - for param in agent.parameters(): - if param.grad is not None: - all_grads_list.append(param.grad.view(-1)) - all_grads = torch.cat(all_grads_list) - dist.all_reduce(all_grads, op=dist.ReduceOp.SUM) - offset = 0 - for param in agent.parameters(): - if param.grad is not None: - param.grad.data.copy_( - all_grads[offset : offset + param.numel()].view_as(param.grad.data) / world_size - ) - offset += param.numel() + from torch.nn.parallel import DistributedDataParallel as DDP + ... + agent = Agent(envs).to(device) + if args.world_size > 1: + # DDP syncs gradients (after each backward step), weights are sync'd at DDP initialization + agent = DDP(agent) ``` - In our previous empirical testing (see :material-github: [vwxyzjn/cleanrl#162](https://github.com/vwxyzjn/cleanrl/pull/162#issuecomment-1107909696)), we have found [@cswinter](https://github.com/cswinter)'s implementation to be faster, hence we adopt it in our implementation. - - - +`DDP` uses collective communications from the torch.distributed package to synchronize gradients across all processes after each backward pass. This means that each process will have its own copy of the model, but they’ll all work together to train the model as if it were on a single machine. ### Experiment results