diff --git a/cleanrl/ppo_atari_multigpu.py b/cleanrl/ppo_atari_multigpu.py
index 7a37b4ee..683a89f3 100644
--- a/cleanrl/ppo_atari_multigpu.py
+++ b/cleanrl/ppo_atari_multigpu.py
@@ -15,6 +15,7 @@ import tyro
 from rich.pretty import pprint
 from torch.distributions.categorical import Categorical
+from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter
 
 from stable_baselines3.common.atari_wrappers import (  # isort:skip
@@ -159,6 +160,10 @@ def get_action_and_value(self, x, action=None):
         return action, probs.log_prob(action), probs.entropy(), self.critic(hidden)
 
 
+def unwrap_ddp(model) -> Agent:
+    return model.module if isinstance(model, DDP) else model
+
+
 if __name__ == "__main__":
     # torchrun --standalone --nnodes=1 --nproc_per_node=2 ppo_atari_multigpu.py
     # taken from https://pytorch.org/docs/stable/elastic/run.html
@@ -228,6 +233,8 @@ def get_action_and_value(self, x, action=None):
     assert isinstance(envs.single_action_space, gym.spaces.Discrete), "only discrete action space is supported"
 
     agent = Agent(envs).to(device)
+    if args.world_size > 1:
+        agent = DDP(agent)
     torch.manual_seed(args.seed)
     optimizer = optim.Adam(agent.parameters(), lr=args.learning_rate, eps=1e-5)
 
@@ -260,7 +267,7 @@ def get_action_and_value(self, x, action=None):
 
             # ALGO LOGIC: action logic
             with torch.no_grad():
-                action, logprob, _, value = agent.get_action_and_value(next_obs)
+                action, logprob, _, value = unwrap_ddp(agent).get_action_and_value(next_obs)
                 values[step] = value.flatten()
             actions[step] = action
             logprobs[step] = logprob
@@ -282,11 +289,11 @@ def get_action_and_value(self, x, action=None):
                         writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step)
         print(
-            f"local_rank: {local_rank}, action.sum(): {action.sum()}, iteration: {iteration}, agent.actor.weight.sum(): {agent.actor.weight.sum()}"
+            f"local_rank: {local_rank}, action.sum(): {action.sum()}, iteration: {iteration}, agent.actor.weight.sum(): {unwrap_ddp(agent).actor.weight.sum()}"
         )
 
         # bootstrap value if not done
         with torch.no_grad():
-            next_value = agent.get_value(next_obs).reshape(1, -1)
+            next_value = unwrap_ddp(agent).get_value(next_obs).reshape(1, -1)
             advantages = torch.zeros_like(rewards).to(device)
             lastgaelam = 0
             for t in reversed(range(args.num_steps)):
@@ -317,7 +324,9 @@ def get_action_and_value(self, x, action=None):
                 end = start + args.local_minibatch_size
                 mb_inds = b_inds[start:end]
 
-                _, newlogprob, entropy, newvalue = agent.get_action_and_value(b_obs[mb_inds], b_actions.long()[mb_inds])
+                _, newlogprob, entropy, newvalue = unwrap_ddp(agent).get_action_and_value(
+                    b_obs[mb_inds], b_actions.long()[mb_inds]
+                )
                 logratio = newlogprob - b_logprobs[mb_inds]
                 ratio = logratio.exp()
 
@@ -357,22 +366,6 @@ def get_action_and_value(self, x, action=None):
 
                 optimizer.zero_grad()
                 loss.backward()
-                if args.world_size > 1:
-                    # batch allreduce ops: see https://github.com/entity-neural-network/incubator/pull/220
-                    all_grads_list = []
-                    for param in agent.parameters():
-                        if param.grad is not None:
-                            all_grads_list.append(param.grad.view(-1))
-                    all_grads = torch.cat(all_grads_list)
-                    dist.all_reduce(all_grads, op=dist.ReduceOp.SUM)
-                    offset = 0
-                    for param in agent.parameters():
-                        if param.grad is not None:
-                            param.grad.data.copy_(
-                                all_grads[offset : offset + param.numel()].view_as(param.grad.data) / args.world_size
-                            )
-                            offset += param.numel()
-
                 nn.utils.clip_grad_norm_(agent.parameters(), args.max_grad_norm)
                 optimizer.step()
 
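
Note on the wrap/unwrap pattern above (a sketch, not part of the diff): the intent is that DDP averages gradients across ranks during backward(), which is the work the removed manual all_reduce block did by hand. Because DDP only proxies forward(), custom methods such as get_action_and_value() are reached through .module, which is what unwrap_ddp() does. A minimal, self-contained illustration, assuming a hypothetical ToyAgent with a get_action_logits() method and a single-process gloo group purely for demonstration (real runs go through torchrun):

# Sketch only: ToyAgent/get_action_logits are hypothetical, not from the PR.
import os

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP


def unwrap_ddp(model):
    # Return the underlying module when wrapped in DDP, otherwise the model itself.
    return model.module if isinstance(model, DDP) else model


class ToyAgent(nn.Module):
    def __init__(self):
        super().__init__()
        self.actor = nn.Linear(4, 2)

    def get_action_logits(self, x):
        # Custom method: DDP does not forward this, only forward().
        return self.actor(x)


if __name__ == "__main__":
    # Single-process "world" just to construct the DDP wrapper on CPU.
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)

    agent = DDP(ToyAgent())
    x = torch.zeros(1, 4)
    # Same call site works whether or not the model is wrapped.
    print(unwrap_ddp(agent).get_action_logits(x).shape)  # torch.Size([1, 2])

    dist.destroy_process_group()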