From 03d1a1c2fe4fef5b13911f05cbdda0c7f8850a49 Mon Sep 17 00:00:00 2001 From: Edward Beeching Date: Tue, 13 Feb 2024 10:07:17 +0100 Subject: [PATCH] Update cleanrl/ppo_atari_accelerate.py Co-authored-by: Costa Huang --- cleanrl/ppo_atari_accelerate.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/cleanrl/ppo_atari_accelerate.py b/cleanrl/ppo_atari_accelerate.py index 0084a1c3..f4eeac2c 100644 --- a/cleanrl/ppo_atari_accelerate.py +++ b/cleanrl/ppo_atari_accelerate.py @@ -278,12 +278,13 @@ def get_action_and_value(self, x, action=None): b_values = values.reshape(-1) # Optimizing the policy and value network - b_inds = np.arange(args.batch_size) + b_inds = np.arange(args.local_batch_size) clipfracs = [] for epoch in range(args.update_epochs): np.random.shuffle(b_inds) - for start in range(0, args.batch_size, args.minibatch_size): - end = start + args.minibatch_size + for start in range(0, args.local_batch_size, args.local_minibatch_size): + end = start + args.local_minibatch_size + mb_inds = b_inds[start:end] mb_inds = b_inds[start:end] _, newlogprob, entropy, newvalue = agent.get_action_and_value(b_obs[mb_inds], b_actions.long()[mb_inds])