[Minor update] Use Pytorch DDP in ppo_atari_multigpu
realAsma committed Jan 14, 2025
1 parent e648ee2 commit e757c8d
Showing 1 changed file with 13 additions and 20 deletions.
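Note (not part of the commit): the change swaps the script's hand-rolled gradient all-reduce for PyTorch's DistributedDataParallel (DDP) wrapper, which averages gradients across processes during loss.backward(). A minimal sketch of that pattern, illustrative only, assuming the script is launched with torchrun --standalone --nnodes=1 --nproc_per_node=2 so the process-group environment variables are set:

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

dist.init_process_group(backend="gloo")  # "nccl" when each rank owns a GPU
model = nn.Linear(4, 2)
if dist.get_world_size() > 1:
    model = DDP(model)  # registers hooks that all-reduce grads inside backward()

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss = model(torch.randn(8, 4)).pow(2).mean()
optimizer.zero_grad()
loss.backward()   # gradients are averaged across ranks here, no manual all_reduce
optimizer.step()
dist.destroy_process_group()

With the wrapper in place, the explicit dist.all_reduce loop after backward() becomes unnecessary, which is what the diff below removes.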
33 changes: 13 additions & 20 deletions cleanrl/ppo_atari_multigpu.py
@@ -15,6 +15,7 @@
import tyro
from rich.pretty import pprint
from torch.distributions.categorical import Categorical
+from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter

from stable_baselines3.common.atari_wrappers import ( # isort:skip
@@ -159,6 +160,10 @@ def get_action_and_value(self, x, action=None):
return action, probs.log_prob(action), probs.entropy(), self.critic(hidden)


+def unwrap_ddp(model) -> Agent:
+    return model.module if isinstance(model, DDP) else model
+
+
if __name__ == "__main__":
# torchrun --standalone --nnodes=1 --nproc_per_node=2 ppo_atari_multigpu.py
# taken from https://pytorch.org/docs/stable/elastic/run.html
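Context note (illustrative, not code from the repository): DistributedDataParallel only proxies forward(), so custom methods such as get_action_and_value or get_value must be reached through the wrapper's .module attribute, which is what unwrap_ddp does. A self-contained sketch, with a hypothetical ToyAgent standing in for the real Agent and a single-process gloo group so it runs without torchrun:

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP


class ToyAgent(nn.Module):  # hypothetical stand-in for cleanrl's Agent
    def __init__(self):
        super().__init__()
        self.critic = nn.Linear(4, 1)

    def forward(self, x):
        return self.critic(x)

    def get_value(self, x):  # custom method the DDP wrapper knows nothing about
        return self.critic(x)


def unwrap_ddp(model):
    # same pattern as the helper added in this commit
    return model.module if isinstance(model, DDP) else model


# single-process group so the example runs without torchrun
dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29501", rank=0, world_size=1)
agent = DDP(ToyAgent())
x = torch.randn(2, 4)
print(agent(x).shape)                        # forward() is proxied by DDP
print(unwrap_ddp(agent).get_value(x).shape)  # custom methods go through .module
dist.destroy_process_group()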
@@ -228,6 +233,8 @@ def get_action_and_value(self, x, action=None):
assert isinstance(envs.single_action_space, gym.spaces.Discrete), "only discrete action space is supported"

agent = Agent(envs).to(device)
+if args.world_size > 1:
+    agent = DDP(agent)
torch.manual_seed(args.seed)
optimizer = optim.Adam(agent.parameters(), lr=args.learning_rate, eps=1e-5)

@@ -260,7 +267,7 @@ def get_action_and_value(self, x, action=None):

# ALGO LOGIC: action logic
with torch.no_grad():
-action, logprob, _, value = agent.get_action_and_value(next_obs)
+action, logprob, _, value = unwrap_ddp(agent).get_action_and_value(next_obs)
values[step] = value.flatten()
actions[step] = action
logprobs[step] = logprob
@@ -282,11 +289,11 @@ def get_action_and_value(self, x, action=None):
writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step)

print(
f"local_rank: {local_rank}, action.sum(): {action.sum()}, iteration: {iteration}, agent.actor.weight.sum(): {agent.actor.weight.sum()}"
f"local_rank: {local_rank}, action.sum(): {action.sum()}, iteration: {iteration}, agent.actor.weight.sum(): {unwrap_ddp(agent).actor.weight.sum()}"
)
# bootstrap value if not done
with torch.no_grad():
-next_value = agent.get_value(next_obs).reshape(1, -1)
+next_value = unwrap_ddp(agent).get_value(next_obs).reshape(1, -1)
advantages = torch.zeros_like(rewards).to(device)
lastgaelam = 0
for t in reversed(range(args.num_steps)):
@@ -317,7 +324,9 @@ def get_action_and_value(self, x, action=None):
end = start + args.local_minibatch_size
mb_inds = b_inds[start:end]

-_, newlogprob, entropy, newvalue = agent.get_action_and_value(b_obs[mb_inds], b_actions.long()[mb_inds])
+_, newlogprob, entropy, newvalue = unwrap_ddp(agent).get_action_and_value(
+    b_obs[mb_inds], b_actions.long()[mb_inds]
+)
logratio = newlogprob - b_logprobs[mb_inds]
ratio = logratio.exp()

@@ -357,22 +366,6 @@ def get_action_and_value(self, x, action=None):
optimizer.zero_grad()
loss.backward()

-if args.world_size > 1:
-    # batch allreduce ops: see https://github.com/entity-neural-network/incubator/pull/220
-    all_grads_list = []
-    for param in agent.parameters():
-        if param.grad is not None:
-            all_grads_list.append(param.grad.view(-1))
-    all_grads = torch.cat(all_grads_list)
-    dist.all_reduce(all_grads, op=dist.ReduceOp.SUM)
-    offset = 0
-    for param in agent.parameters():
-        if param.grad is not None:
-            param.grad.data.copy_(
-                all_grads[offset : offset + param.numel()].view_as(param.grad.data) / args.world_size
-            )
-            offset += param.numel()
-
nn.utils.clip_grad_norm_(agent.parameters(), args.max_grad_norm)
optimizer.step()

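The block deleted above summed param.grad across ranks with dist.all_reduce and divided by args.world_size; DDP performs that same averaging inside backward(). A sketch (illustrative only, not part of the commit, assuming a multi-process torchrun launch) of one way to check that every rank ends up with identical, already-averaged gradients:

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

dist.init_process_group(backend="gloo")
rank, world_size = dist.get_rank(), dist.get_world_size()

model = DDP(nn.Linear(4, 1))                  # DDP broadcasts rank 0's weights
loss = model(torch.randn(8, 4) + rank).sum()  # different data on every rank
loss.backward()                               # DDP averages grads here

# Every rank should now hold the same gradients, which is what the removed
# manual all-reduce used to produce explicitly.
grad = model.module.weight.grad.clone()
gathered = [torch.empty_like(grad) for _ in range(world_size)]
dist.all_gather(gathered, grad)
if rank == 0:
    print(all(torch.allclose(g, gathered[0]) for g in gathered))  # True
dist.destroy_process_group()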
