Commit

documentation update
realAsma committed Jan 23, 2025
1 parent 3456c05 commit 57c0a1d
Showing 2 changed files with 13 additions and 41 deletions.
2 changes: 2 additions & 0 deletions cleanrl/ppo_atari_multigpu.py
@@ -234,7 +234,9 @@ def unwrap_ddp(model) -> Agent:

agent = Agent(envs).to(device)
if args.world_size > 1:
    # DDP syncs gradients (after each backward step), weights are sync'd at DDP initialization
    agent = DDP(agent)

optimizer = optim.Adam(agent.parameters(), lr=args.learning_rate, eps=1e-5)

# ALGO Logic: Storage setup
52 changes: 11 additions & 41 deletions docs/rl-algorithms/ppo.md
@@ -909,66 +909,36 @@ See [related docs](/rl-algorithms/ppo/#explanation-of-the-logged-metrics) for `p

[ppo_atari_multigpu.py](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/ppo_atari_multigpu.py) is based on `ppo_atari.py` (see its [related docs](/rl-algorithms/ppo/#implementation-details_1)).

We use [PyTorch's distributed API](https://pytorch.org/tutorials/intermediate/dist_tuto.html) to implement the data parallelism paradigm. The basic idea is that the user can spawn $N$ processes, each running a copy of `ppo_atari.py`, holding a copy of the model, stepping the environments, and averaging their gradients together for the backward pass. Here are a few noteworthy implementation details.
We use the PyTorch [distributed API](https://pytorch.org/tutorials/intermediate/dist_tuto.html) and the [DistributedDataParallel (DDP) module](https://pytorch.org/tutorials/intermediate/ddp_tutorial.html) to implement data parallelism. The basic idea is that the user can spawn $N$ processes, each running a copy of `ppo_atari.py`, holding a copy of the model, stepping its own environments, and jointly learning a better policy. Here are a few noteworthy implementation details.

1. **Local versus global parameters**: All of the parameters in `ppo_atari.py` are global (such as batch size), but in `ppo_atari_multigpu.py` we also have local (per-worker) parameters. Say we run `torchrun --standalone --nnodes=1 --nproc_per_node=2 cleanrl/ppo_atari_multigpu.py --env-id BreakoutNoFrameskip-v4 --local-num-envs=4`; here is how all multi-GPU related parameters are adjusted (a short sketch of this arithmetic follows the list):
* **number of environments**: `num_envs = local_num_envs * world_size = 4 * 2 = 8`
* **batch size**: `local_batch_size = local_num_envs * num_steps = 4 * 128 = 512`, `batch_size = num_envs * num_steps = 8 * 128 = 1024`
* **minibatch size**: `local_minibatch_size = int(args.local_batch_size // args.num_minibatches) = 512 // 4 = 128`, `minibatch_size = int(args.batch_size // args.num_minibatches) = 1024 // 4 = 256`
* **number of updates**: `num_iterations = args.total_timesteps // args.batch_size = 10000000 // 1024 = 9765`
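
    The snippet below just spells out this arithmetic for the example run above; the variable names mirror the script's flags, but the snippet is illustrative and not copied from `ppo_atari_multigpu.py`.

    ```python
    # Illustrative only: derive the global quantities from the per-worker (local) ones
    # for `torchrun --nproc_per_node=2 ... --local-num-envs=4` (values assumed from the example run).
    local_num_envs = 4          # --local-num-envs
    world_size = 2              # --nproc_per_node
    num_steps = 128             # rollout length per environment
    num_minibatches = 4
    total_timesteps = 10_000_000

    num_envs = local_num_envs * world_size                       # 8
    local_batch_size = local_num_envs * num_steps                # 512
    batch_size = num_envs * num_steps                            # 1024
    local_minibatch_size = local_batch_size // num_minibatches   # 128
    minibatch_size = batch_size // num_minibatches               # 256
    num_iterations = total_timesteps // batch_size               # 9765

    print(num_envs, batch_size, minibatch_size, num_iterations)
    ```
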
1. **Adjust seed per process**: we need be very careful with seeding: we could have used the exact same seed for each subprocess. To ensure this does not happen, we do the following
1. **Adjust seed per process**: we need to be very careful with seeding: we could have used the exact same seed for each subprocess. To ensure this does not happen, we do the following

```python hl_lines="2 5 16"
# CRUCIAL: note that we needed to pass a different seed for each data parallelism worker
args.seed += local_rank
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed - local_rank)
torch.backends.cudnn.deterministic = args.torch_deterministic

# ...

envs = gym.vector.SyncVectorEnv(
    [make_env(args.env_id, args.seed + i, i, args.capture_video, run_name) for i in range(args.num_envs)]
)
assert isinstance(envs.single_action_space, gym.spaces.Discrete), "only discrete action space is supported"

agent = Agent(envs).to(device)
torch.manual_seed(args.seed)
optimizer = optim.Adam(agent.parameters(), lr=args.learning_rate, eps=1e-5)
```

Notice that we adjust the seed with `args.seed += local_rank` (line 2), where `local_rank` is the index of the subprocess. This ensures we seed packages and envs with uncorrelated seeds. However, we do need to use the same `torch` seed for all processes to initialize the same weights for the `agent` (line 5), after which we can use a different seed for `torch` (line 16).
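
    As a quick sanity check, here is the seed arithmetic written out for two hypothetical workers with a base `--seed 1`; this snippet is illustrative and not part of the script:

    ```python
    base_seed = 1
    for local_rank in range(2):  # pretend we are worker 0 and worker 1
        seed = base_seed + local_rank    # seeds python/numpy and the envs: 1, 2 -> uncorrelated rollouts
        torch_seed = seed - local_rank   # always equals base_seed -> identical initial agent weights
        print(f"worker {local_rank}: env/np seed={seed}, torch init seed={torch_seed}")
    ```
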
1. **Efficient gradient averaging**: PyTorch recommends averaging the gradient across the whole world via the following (see [docs](https://pytorch.org/tutorials/intermediate/dist_tuto.html#distributed-training))

```python
for param in agent.parameters():
    dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
    param.grad.data /= world_size
```

However, [@cswinter](https://github.com/cswinter) introduces a more efficient gradient averaging scheme with proper batching (see :material-github: [entity-neural-network/incubator#220](https://github.com/entity-neural-network/incubator/pull/220)), which looks like:

```python
all_grads_list = []
for param in agent.parameters():
    if param.grad is not None:
        all_grads_list.append(param.grad.view(-1))
all_grads = torch.cat(all_grads_list)
dist.all_reduce(all_grads, op=dist.ReduceOp.SUM)
offset = 0
for param in agent.parameters():
    if param.grad is not None:
        param.grad.data.copy_(
            all_grads[offset : offset + param.numel()].view_as(param.grad.data) / world_size
        )
        offset += param.numel()
```

In our previous empirical testing (see :material-github: [vwxyzjn/cleanrl#162](https://github.com/vwxyzjn/cleanrl/pull/162#issuecomment-1107909696)), we have found [@cswinter](https://github.com/cswinter)'s implementation to be faster, hence we adopt it in our implementation.

1. **PyTorch DDP for weight and gradient synchronization**: We wrap the agent in PyTorch's [DistributedDataParallel (DDP)](https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html#torch.nn.parallel.DistributedDataParallel) module as shown below:

```python
from torch.nn.parallel import DistributedDataParallel as DDP
...
agent = Agent(envs).to(device)
if args.world_size > 1:
    # DDP syncs gradients (after each backward step), weights are sync'd at DDP initialization
    agent = DDP(agent)
```



`DDP` uses collective communications from the torch.distributed package to synchronize gradients across all processes after each backward pass. This means that each process will have its own copy of the model, but they’ll all work together to train the model as if it were on a single machine.
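
For completeness, here is a minimal sketch of the process-group setup that `DDP` relies on. It assumes the script is launched with `torchrun` (which exports `LOCAL_RANK`, `RANK`, and `WORLD_SIZE` for every worker); the backend choice and the toy model are illustrative assumptions, not the exact code in `ppo_atari_multigpu.py`.

```python
import os

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# torchrun exports these for every worker it spawns (assumed launcher).
local_rank = int(os.environ["LOCAL_RANK"])
world_size = int(os.environ["WORLD_SIZE"])

# NCCL for GPU training, gloo as a CPU fallback (illustrative choice).
backend = "nccl" if torch.cuda.is_available() else "gloo"
dist.init_process_group(backend=backend)

device = torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu")
model = torch.nn.Linear(4, 2).to(device)  # stand-in for the Agent

if world_size > 1:
    # After this, every loss.backward() all-reduces gradients across workers,
    # and all workers start from the same (rank-0) weights.
    model = DDP(model, device_ids=[local_rank] if torch.cuda.is_available() else None)

# ... training loop goes here ...
dist.destroy_process_group()
```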

### Experiment results
