From 8725a340e3484fc07daec13e0da76ab052c0eb64 Mon Sep 17 00:00:00 2001
From: gliese876b
Date: Fri, 10 Jan 2025 10:55:36 +0000
Subject: [PATCH] Prioritized Replay Buffer option is added

---
 benchmarl/algorithms/common.py                 | 10 ++++++++--
 .../conf/experiment/base_experiment.yaml       | 19 +++++++++++++------
 benchmarl/experiment/experiment.py             |  3 +++
 3 files changed, 24 insertions(+), 8 deletions(-)

diff --git a/benchmarl/algorithms/common.py b/benchmarl/algorithms/common.py
index 96fba1ab..134302cf 100644
--- a/benchmarl/algorithms/common.py
+++ b/benchmarl/algorithms/common.py
@@ -18,7 +18,7 @@
     ReplayBuffer,
     TensorDictReplayBuffer,
 )
-from torchrl.data.replay_buffers import RandomSampler, SamplerWithoutReplacement
+from torchrl.data.replay_buffers import RandomSampler, SamplerWithoutReplacement, PrioritizedSampler
 from torchrl.envs import Compose, EnvBase, Transform
 from torchrl.objectives import LossModule
 from torchrl.objectives.utils import HardUpdate, SoftUpdate, TargetNetUpdater
@@ -158,7 +158,13 @@ def get_replay_buffer(
             memory_size = -(-memory_size // sequence_length)
             sampling_size = -(-sampling_size // sequence_length)
 
-        sampler = SamplerWithoutReplacement() if self.on_policy else RandomSampler()
+        if self.on_policy:
+            sampler = SamplerWithoutReplacement()
+        elif self.experiment_config.off_policy_use_prioritized_replay_buffer:
+            sampler = PrioritizedSampler(memory_size, self.experiment_config.off_policy_prb_alpha, self.experiment_config.off_policy_prb_beta)
+        else:
+            sampler = RandomSampler()
+
         return TensorDictReplayBuffer(
             storage=LazyTensorStorage(
                 memory_size,
diff --git a/benchmarl/conf/experiment/base_experiment.yaml b/benchmarl/conf/experiment/base_experiment.yaml
index 014aae5f..648767fa 100644
--- a/benchmarl/conf/experiment/base_experiment.yaml
+++ b/benchmarl/conf/experiment/base_experiment.yaml
@@ -15,7 +15,7 @@ share_policy_params: True
 prefer_continuous_actions: True
 # If False collection is done using a collector (under no grad). If True, collection is done with gradients.
 collect_with_grad: False
-# In case of non-vectorized environments, weather to run collection of multiple processes
+# In case of non-vectorized environments, whether to run collection using multiple processes
 # If this is used, there will be n_envs_per_worker processes, collecting frames_per_batch/n_envs_per_worker frames each
 parallel_collection: False
 
@@ -34,7 +34,7 @@ clip_grad_val: 5
 soft_target_update: True
 # If soft_target_update is True, this is its polyak_tau
 polyak_tau: 0.005
-# If soft_target_update is False, this is the frequency of the hard trarget updates in terms of n_optimizer_steps
+# If soft_target_update is False, this is the frequency of the hard target updates in terms of n_optimizer_steps
 hard_target_update_frequency: 5
 
 # When an exploration wrapper is used. This is its initial epsilon for annealing
@@ -54,7 +54,7 @@ max_n_frames: 3_000_000
 on_policy_collected_frames_per_batch: 6000
 # Number of environments used for collection
 # If the environment is vectorized, this will be the number of batched environments.
-# Otherwise batching will be simulated and each env will be run sequentially or parallely depending on parallel_collection.
+# Otherwise batching will be simulated and each env will be run sequentially or in parallel depending on parallel_collection.
 on_policy_n_envs_per_worker: 10
 # This is the number of times collected_frames_per_batch will be split into minibatches and trained
 on_policy_n_minibatch_iters: 45
@@ -66,7 +66,7 @@ on_policy_minibatch_size: 400
 off_policy_collected_frames_per_batch: 6000
 # Number of environments used for collection
 # If the environment is vectorized, this will be the number of batched environments.
-# Otherwise batching will be simulated and each env will be run sequentially or parallely depending on parallel_collection.
+# Otherwise batching will be simulated and each env will be run sequentially or in parallel depending on parallel_collection.
 off_policy_n_envs_per_worker: 10
 # This is the number of times off_policy_train_batch_size will be sampled from the buffer and trained over.
 off_policy_n_optimizer_steps: 1000
@@ -76,11 +76,18 @@ off_policy_train_batch_size: 128
 off_policy_memory_size: 1_000_000
 # Number of random action frames to prefill the replay buffer with
 off_policy_init_random_frames: 0
+# Whether to use priorities when sampling from the replay buffer
+off_policy_use_prioritized_replay_buffer: False
+# Exponent that determines how much prioritization is used when off_policy_use_prioritized_replay_buffer = True
+# PRB reduces to uniform random sampling when alpha=0
+off_policy_prb_alpha: 0.6
+# Importance sampling negative exponent used when off_policy_use_prioritized_replay_buffer = True
+off_policy_prb_beta: 0.4
 
 
 evaluation: True
 # Whether to render the evaluation (if rendering is available)
-render: True
+render: False
 # Frequency of evaluation in terms of collected frames (this should be a multiple of on/off_policy_collected_frames_per_batch)
 evaluation_interval: 120_000
 # Number of episodes that evaluation is run on
@@ -108,7 +115,7 @@ restore_map_location: null
 # Interval for experiment saving in terms of collected frames (this should be a multiple of on/off_policy_collected_frames_per_batch).
 # Set it to 0 to disable checkpointing
 checkpoint_interval: 0
-# Wether to checkpoint when the experiment is done
+# Whether to checkpoint when the experiment is done
 checkpoint_at_end: False
 # How many checkpoints to keep. As new checkpoints are taken, temporally older checkpoints are deleted to keep this number of
 # checkpoints. The checkpoint at the end is included in this number. Set to `null` to keep all checkpoints.
diff --git a/benchmarl/experiment/experiment.py b/benchmarl/experiment/experiment.py
index 2ea0f92c..5b483553 100644
--- a/benchmarl/experiment/experiment.py
+++ b/benchmarl/experiment/experiment.py
@@ -90,6 +90,9 @@ class ExperimentConfig:
     off_policy_train_batch_size: int = MISSING
     off_policy_memory_size: int = MISSING
     off_policy_init_random_frames: int = MISSING
+    off_policy_use_prioritized_replay_buffer: bool = MISSING
+    off_policy_prb_alpha: float = MISSING
+    off_policy_prb_beta: float = MISSING
     evaluation: bool = MISSING
     render: bool = MISSING
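
Usage sketch (illustrative, not part of the patch): assuming BenchMARL's standard Experiment API, the new off-policy options could be enabled programmatically roughly as below. The IQL/VMAS/MLP choices and the alpha/beta values are placeholders; any off-policy algorithm and task should work the same way.

from benchmarl.algorithms import IqlConfig
from benchmarl.environments import VmasTask
from benchmarl.experiment import Experiment, ExperimentConfig
from benchmarl.models.mlp import MlpConfig

# Load the defaults from benchmarl/conf/experiment/base_experiment.yaml
experiment_config = ExperimentConfig.get_from_yaml()

# Turn on the prioritized replay buffer introduced by this patch
# (only used by off-policy algorithms; on-policy buffers are unaffected)
experiment_config.off_policy_use_prioritized_replay_buffer = True
experiment_config.off_policy_prb_alpha = 0.6  # prioritization exponent
experiment_config.off_policy_prb_beta = 0.4  # importance-sampling exponent

experiment = Experiment(
    task=VmasTask.BALANCE.get_from_yaml(),  # placeholder task
    algorithm_config=IqlConfig.get_from_yaml(),  # any off-policy algorithm
    model_config=MlpConfig.get_from_yaml(),
    seed=0,
    config=experiment_config,
)
experiment.run()

Since the new keys live in base_experiment.yaml and ExperimentConfig, they should also be overridable from the Hydra CLI like any other experiment field, e.g. experiment.off_policy_use_prioritized_replay_buffer=True.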