Prioritized Replay Buffer option is added
gliese876b committed Jan 10, 2025
1 parent e910a83 commit 8725a34
Showing 3 changed files with 24 additions and 8 deletions.
10 changes: 8 additions & 2 deletions benchmarl/algorithms/common.py
@@ -18,7 +18,7 @@
    ReplayBuffer,
    TensorDictReplayBuffer,
)
from torchrl.data.replay_buffers import RandomSampler, SamplerWithoutReplacement
from torchrl.data.replay_buffers import RandomSampler, SamplerWithoutReplacement, PrioritizedSampler
from torchrl.envs import Compose, EnvBase, Transform
from torchrl.objectives import LossModule
from torchrl.objectives.utils import HardUpdate, SoftUpdate, TargetNetUpdater
@@ -158,7 +158,13 @@ def get_replay_buffer(
            memory_size = -(-memory_size // sequence_length)
            sampling_size = -(-sampling_size // sequence_length)

        sampler = SamplerWithoutReplacement() if self.on_policy else RandomSampler()
        if self.on_policy:
            sampler = SamplerWithoutReplacement()
        elif self.experiment_config.off_policy_use_prioritized_replay_buffer:
            sampler = PrioritizedSampler(memory_size, self.experiment_config.off_policy_prb_alpha, self.experiment_config.off_policy_prb_beta)
        else:
            sampler = RandomSampler()

        return TensorDictReplayBuffer(
            storage=LazyTensorStorage(
                memory_size,
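For reference, the snippet below is a minimal, self-contained sketch (not repository code) of what the new branch assembles at the torchrl level: a `TensorDictReplayBuffer` backed by a `PrioritizedSampler` when prioritized replay is enabled, and a `RandomSampler` otherwise. The buffer size, alpha, beta, batch size, the `priority_key`/`update_tensordict_priority` usage, and the dummy data are illustrative assumptions, not part of the commit.

```python
# Minimal sketch (assumptions noted above), mirroring the sampler selection added in common.py.
import torch
from tensordict import TensorDict
from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer
from torchrl.data.replay_buffers import PrioritizedSampler, RandomSampler

use_prioritized = True                 # stand-in for off_policy_use_prioritized_replay_buffer
memory_size, alpha, beta = 10_000, 0.6, 0.4

sampler = (
    PrioritizedSampler(memory_size, alpha, beta)
    if use_prioritized
    else RandomSampler()
)
buffer = TensorDictReplayBuffer(
    storage=LazyTensorStorage(memory_size),
    sampler=sampler,
    batch_size=128,
    priority_key="td_error",           # key read back when priorities are updated
)

# Fill the buffer with dummy transitions and draw one (prioritized) batch.
data = TensorDict(
    {"obs": torch.randn(1000, 4), "reward": torch.randn(1000, 1)},
    batch_size=[1000],
)
buffer.extend(data)
batch = buffer.sample()

# After computing TD errors in the training step, write them back so that
# high-error transitions are sampled more often on the next draw.
batch.set("td_error", batch["reward"].abs().squeeze(-1))
buffer.update_tensordict_priority(batch)
```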
19 changes: 13 additions & 6 deletions benchmarl/conf/experiment/base_experiment.yaml
@@ -15,7 +15,7 @@ share_policy_params: True
prefer_continuous_actions: True
# If False collection is done using a collector (under no grad). If True, collection is done with gradients.
collect_with_grad: False
# In case of non-vectorized environments, weather to run collection of multiple processes
# In case of non-vectorized environments, whether to run collection using multiple processes
# If this is used, there will be n_envs_per_worker processes, collecting frames_per_batch/n_envs_per_worker frames each
parallel_collection: False

@@ -34,7 +34,7 @@ clip_grad_val: 5
soft_target_update: True
# If soft_target_update is True, this is its polyak_tau
polyak_tau: 0.005
# If soft_target_update is False, this is the frequency of the hard trarget updates in terms of n_optimizer_steps
# If soft_target_update is False, this is the frequency of the hard target updates in terms of n_optimizer_steps
hard_target_update_frequency: 5

# When an exploration wrapper is used. This is its initial epsilon for annealing
@@ -54,7 +54,7 @@ max_n_frames: 3_000_000
on_policy_collected_frames_per_batch: 6000
# Number of environments used for collection
# If the environment is vectorized, this will be the number of batched environments.
# Otherwise batching will be simulated and each env will be run sequentially or parallely depending on parallel_collection.
# Otherwise batching will be simulated and each env will be run sequentially or in parallel, depending on parallel_collection.
on_policy_n_envs_per_worker: 10
# This is the number of times collected_frames_per_batch will be split into minibatches and trained
on_policy_n_minibatch_iters: 45
@@ -66,7 +66,7 @@ on_policy_minibatch_size: 400
off_policy_collected_frames_per_batch: 6000
# Number of environments used for collection
# If the environment is vectorized, this will be the number of batched environments.
# Otherwise batching will be simulated and each env will be run sequentially or parallely depending on parallel_collection.
# Otherwise batching will be simulated and each env will be run sequentially or in parallel, depending on parallel_collection.
off_policy_n_envs_per_worker: 10
# This is the number of times off_policy_train_batch_size will be sampled from the buffer and trained over.
off_policy_n_optimizer_steps: 1000
@@ -76,11 +76,18 @@ off_policy_train_batch_size: 128
off_policy_memory_size: 1_000_000
# Number of random action frames to prefill the replay buffer with
off_policy_init_random_frames: 0
# Whether to use priorities when sampling from the replay buffer
off_policy_use_prioritized_replay_buffer: False
# Exponent that determines how much prioritization is used when off_policy_use_prioritized_replay_buffer = True
# (the PRB reduces to uniform random sampling when alpha = 0)
off_policy_prb_alpha: 0.6
# Importance-sampling exponent (weights are raised to the power -beta) when off_policy_use_prioritized_replay_buffer = True
off_policy_prb_beta: 0.4


evaluation: True
# Whether to render the evaluation (if rendering is available)
render: True
render: False
# Frequency of evaluation in terms of collected frames (this should be a multiple of on/off_policy_collected_frames_per_batch)
evaluation_interval: 120_000
# Number of episodes that evaluation is run on
@@ -108,7 +115,7 @@ restore_map_location: null
# Interval for experiment saving in terms of collected frames (this should be a multiple of on/off_policy_collected_frames_per_batch).
# Set it to 0 to disable checkpointing
checkpoint_interval: 0
# Wether to checkpoint when the experiment is done
# Whether to checkpoint when the experiment is done
checkpoint_at_end: False
# How many checkpoints to keep. As new checkpoints are taken, temporally older checkpoints are deleted to keep this number of
# checkpoints. The checkpoint at the end is included in this number. Set to `null` to keep all checkpoints.
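The two new exponents interact as in standard prioritized experience replay: transitions are drawn with probability proportional to priority^alpha, and the resulting bias is corrected with importance-sampling weights proportional to (N * P(i))^(-beta). The short sketch below is illustrative only (hand-picked priorities, not repository code) and just makes the yaml defaults concrete.

```python
# Illustrative only: how the alpha/beta defaults above shape prioritized sampling.
import torch

priorities = torch.tensor([0.1, 1.0, 2.0, 4.0])  # e.g. absolute TD errors
alpha, beta = 0.6, 0.4                           # the defaults in base_experiment.yaml

# alpha controls how strongly priorities skew sampling; alpha = 0 gives uniform sampling.
probs = priorities.pow(alpha)
probs = probs / probs.sum()

# beta controls how strongly importance-sampling weights undo that skew;
# normalizing by the max weight is the usual convention.
n = len(priorities)
weights = (n * probs).pow(-beta)
weights = weights / weights.max()

print(probs)    # sampling probabilities
print(weights)  # importance-sampling correction factors
```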
3 changes: 3 additions & 0 deletions benchmarl/experiment/experiment.py
@@ -90,6 +90,9 @@ class ExperimentConfig:
    off_policy_train_batch_size: int = MISSING
    off_policy_memory_size: int = MISSING
    off_policy_init_random_frames: int = MISSING
    off_policy_use_prioritized_replay_buffer: bool = MISSING
    off_policy_prb_alpha: float = MISSING
    off_policy_prb_beta: float = MISSING

    evaluation: bool = MISSING
    render: bool = MISSING
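With the new `ExperimentConfig` fields in place, enabling prioritized replay from Python looks roughly like the sketch below. It assumes BenchMARL's usual entry points (`ExperimentConfig.get_from_yaml`, `Experiment`, an off-policy algorithm such as IQL, a VMAS task, and the MLP model config); those choices are illustrative, not prescribed by this commit.

```python
# Sketch of enabling the new option; algorithm/task/model choices are illustrative.
from benchmarl.algorithms import IqlConfig
from benchmarl.environments import VmasTask
from benchmarl.experiment import Experiment, ExperimentConfig
from benchmarl.models.mlp import MlpConfig

experiment_config = ExperimentConfig.get_from_yaml()
# The new fields pick up their defaults from base_experiment.yaml and can be
# overridden per experiment.
experiment_config.off_policy_use_prioritized_replay_buffer = True
experiment_config.off_policy_prb_alpha = 0.6
experiment_config.off_policy_prb_beta = 0.4

experiment = Experiment(
    task=VmasTask.BALANCE.get_from_yaml(),
    algorithm_config=IqlConfig.get_from_yaml(),  # an off-policy algorithm
    model_config=MlpConfig.get_from_yaml(),
    critic_model_config=MlpConfig.get_from_yaml(),
    seed=0,
    config=experiment_config,
)
experiment.run()
```

For Hydra-based runs the same fields should be overridable from the command line (for example `experiment.off_policy_use_prioritized_replay_buffer=true`), assuming the standard `benchmarl/run.py` entry point.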
