-
Notifications
You must be signed in to change notification settings - Fork 695
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'vwxyzjn:master' into master
- Loading branch information
Showing
21 changed files
with
4,156 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
# export WANDB_ENTITY=openrlbenchmark | ||
|
||
cd cleanrl/ppo_trxl | ||
poetry install | ||
OMP_NUM_THREADS=4 poetry run python -m cleanrl_utils.benchmark \ | ||
--env-ids MortarMayhem-Grid-v0 \ | ||
--command "python ./cleanrl/ppo_trxl/ppo_trxl.py --track --norm_adv --trxl_memory_length 119 --total_timesteps 100000000" \ | ||
--num-seeds 3 \ | ||
--workers 32 \ | ||
--slurm-template-path benchmark/cleanrl_1gpu.slurm_template | ||
|
||
OMP_NUM_THREADS=4 poetry run python -m cleanrl_utils.benchmark \ | ||
--env-ids MortarMayhem-v0 \ | ||
--command "python ./cleanrl/ppo_trxl/ppo_trxl.py --track --reconstruction_coef 0.1 --trxl_memory_length 275" \ | ||
--num-seeds 3 \ | ||
--workers 32 \ | ||
--slurm-template-path benchmark/cleanrl_1gpu.slurm_template | ||
|
||
OMP_NUM_THREADS=4 poetry run python -m cleanrl_utils.benchmark \ | ||
--env-ids MysteryPath-Grid-v0 \ | ||
--command "python ./cleanrl/ppo_trxl/ppo_trxl.py --track --trxl_memory_length 96 --total_timesteps 100000000" \ | ||
--num-seeds 3 \ | ||
--workers 32 \ | ||
--slurm-template-path benchmark/cleanrl_1gpu.slurm_template | ||
|
||
OMP_NUM_THREADS=4 poetry run python -m cleanrl_utils.benchmark \ | ||
--env-ids MysteryPath-v0 \ | ||
--command "python ./cleanrl/ppo_trxl/ppo_trxl.py --track --trxl_memory_length 256" \ | ||
--num-seeds 3 \ | ||
--workers 32 \ | ||
--slurm-template-path benchmark/cleanrl_1gpu.slurm_template | ||
|
||
OMP_NUM_THREADS=4 poetry run python -m cleanrl_utils.benchmark \ | ||
--env-ids SearingSpotlights-v0 \ | ||
--command "python ./cleanrl/ppo_trxl/ppo_trxl.py --track --reconstruction_coef 0.1 --trxl_memory_length 256" \ | ||
--num-seeds 3 \ | ||
--workers 32 \ | ||
--slurm-template-path benchmark/cleanrl_1gpu.slurm_template | ||
|
||
OMP_NUM_THREADS=4 poetry run python -m cleanrl_utils.benchmark \ | ||
--env-ids Endless-SearingSpotlights-v0 \ | ||
--command "python ./cleanrl/ppo_trxl/ppo_trxl.py --track --reconstruction_coef 0.1 --trxl_memory_length 256 --total_timesteps 350000000" \ | ||
--num-seeds 3 \ | ||
--workers 32 \ | ||
--slurm-template-path benchmark/cleanrl_1gpu.slurm_template | ||
|
||
OMP_NUM_THREADS=4 poetry run python -m cleanrl_utils.benchmark \ | ||
--env-ids Endless-MortarMayhem-v0 Endless-MysteryPath-v0 \ | ||
--command "python ./cleanrl/ppo_trxl/ppo_trxl.py --track --trxl_memory_length 256 --total_timesteps 350000000" \ | ||
--num-seeds 3 \ | ||
--workers 32 \ | ||
--slurm-template-path benchmark/cleanrl_1gpu.slurm_template |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
poetry install | ||
OMP_NUM_THREADS=1 xvfb-run -a poetry run python -m cleanrl_utils.benchmark \ | ||
--env-ids CartPole-v1 Acrobot-v1 MountainCar-v0 \ | ||
--command "poetry run python cleanrl/pqn.py --no_cuda --track" \ | ||
--num-seeds 3 \ | ||
--workers 9 \ | ||
--slurm-gpus-per-task 1 \ | ||
--slurm-ntasks 1 \ | ||
--slurm-total-cpus 10 \ | ||
--slurm-template-path benchmark/cleanrl_1gpu.slurm_template | ||
|
||
poetry install -E envpool | ||
poetry run python -m cleanrl_utils.benchmark \ | ||
--env-ids Breakout-v5 SpaceInvaders-v5 BeamRider-v5 Pong-v5 MsPacman-v5 \ | ||
--command "poetry run python cleanrl/pqn_atari_envpool.py --track" \ | ||
--num-seeds 3 \ | ||
--workers 9 \ | ||
--slurm-gpus-per-task 1 \ | ||
--slurm-ntasks 1 \ | ||
--slurm-total-cpus 10 \ | ||
--slurm-template-path benchmark/cleanrl_1gpu.slurm_template | ||
|
||
poetry install -E envpool | ||
poetry run python -m cleanrl_utils.benchmark \ | ||
--env-ids Breakout-v5 SpaceInvaders-v5 BeamRider-v5 Pong-v5 MsPacman-v5 \ | ||
--command "poetry run python cleanrl/pqn_atari_envpool_lstm.py --track" \ | ||
--num-seeds 3 \ | ||
--workers 9 \ | ||
--slurm-gpus-per-task 1 \ | ||
--slurm-ntasks 1 \ | ||
--slurm-total-cpus 10 \ | ||
--slurm-template-path benchmark/cleanrl_1gpu.slurm_template |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
|
||
python -m openrlbenchmark.rlops \ | ||
--filters '?we=rogercreus&wpn=cleanRL&ceik=env_id&cen=exp_name&metric=charts/episodic_return' \ | ||
'pqn?tag=pr-472&cl=CleanRL PQN' \ | ||
--env-ids CartPole-v1 Acrobot-v1 MountainCar-v0 \ | ||
--no-check-empty-runs \ | ||
--pc.ncols 3 \ | ||
--pc.ncols-legend 2 \ | ||
--output-filename benchmark/cleanrl/pqn \ | ||
--scan-history | ||
|
||
python -m openrlbenchmark.rlops \ | ||
--filters '?we=rogercreus&wpn=cleanRL&ceik=env_id&cen=exp_name&metric=charts/episodic_return' \ | ||
'pqn_atari_envpool?tag=pr-472&cl=CleanRL PQN' \ | ||
--env-ids Breakout-v5 SpaceInvaders-v5 BeamRider-v5 Pong-v5 MsPacman-v5 \ | ||
--no-check-empty-runs \ | ||
--pc.ncols 3 \ | ||
--pc.ncols-legend 3 \ | ||
--rliable \ | ||
--rc.score_normalization_method maxmin \ | ||
--rc.normalized_score_threshold 1.0 \ | ||
--rc.sample_efficiency_plots \ | ||
--rc.sample_efficiency_and_walltime_efficiency_method Median \ | ||
--rc.performance_profile_plots \ | ||
--rc.aggregate_metrics_plots \ | ||
--rc.sample_efficiency_num_bootstrap_reps 10 \ | ||
--rc.performance_profile_num_bootstrap_reps 10 \ | ||
--rc.interval_estimates_num_bootstrap_reps 10 \ | ||
--output-filename static/0compare \ | ||
--scan-history | ||
|
||
python -m openrlbenchmark.rlops \ | ||
--filters '?we=rogercreus&wpn=cleanRL&ceik=env_id&cen=exp_name&metric=charts/episodic_return' \ | ||
'pqn_atari_envpool_lstm?tag=pr-472&cl=CleanRL PQN' \ | ||
--env-ids Breakout-v5 SpaceInvaders-v5 BeamRider-v5 MsPacman-v5 \ | ||
--no-check-empty-runs \ | ||
--pc.ncols 3 \ | ||
--pc.ncols-legend 3 \ | ||
--rliable \ | ||
--rc.score_normalization_method maxmin \ | ||
--rc.normalized_score_threshold 1.0 \ | ||
--rc.sample_efficiency_plots \ | ||
--rc.sample_efficiency_and_walltime_efficiency_method Median \ | ||
--rc.performance_profile_plots \ | ||
--rc.aggregate_metrics_plots \ | ||
--rc.sample_efficiency_num_bootstrap_reps 10 \ | ||
--rc.performance_profile_num_bootstrap_reps 10 \ | ||
--rc.interval_estimates_num_bootstrap_reps 10 \ | ||
--output-filename static/0compare \ | ||
--scan-history |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
from dataclasses import dataclass | ||
|
||
import gymnasium as gym | ||
import torch | ||
import tyro | ||
from ppo_trxl import Agent, make_env | ||
|
||
|
||
@dataclass | ||
class Args: | ||
hub: bool = False | ||
"""whether to load the model from the huggingface hub or from the local disk""" | ||
name: str = "Endless-MortarMayhem-v0_12.nn" | ||
"""path to the model file""" | ||
|
||
|
||
if __name__ == "__main__": | ||
# Parse command line arguments and retrieve model path | ||
cli_args = tyro.cli(Args) | ||
if cli_args.hub: | ||
try: | ||
from huggingface_hub import hf_hub_download | ||
|
||
path = hf_hub_download(repo_id="LilHairdy/cleanrl_memory_gym", filename=cli_args.name) | ||
except: | ||
raise RuntimeError( | ||
"Cannot load model from the huggingface hub. Please install the huggingface_hub pypi package and verify the model name. You can also download the model from the hub manually and load it from disk." | ||
) | ||
else: | ||
path = cli_args.name | ||
|
||
# Load the pre-trained model and the original args used to train it | ||
checkpoint = torch.load(path) | ||
args = checkpoint["args"] | ||
args = type("Args", (), args) | ||
|
||
# Init environment and reset | ||
env = make_env(args.env_id, 0, False, "", "human")() | ||
obs, _ = env.reset() | ||
env.render() | ||
|
||
# Determine maximum episode steps | ||
max_episode_steps = env.spec.max_episode_steps | ||
if not max_episode_steps: | ||
max_episode_steps = env.max_episode_steps | ||
if max_episode_steps <= 0: | ||
max_episode_steps = 1024 # Memory Gym envs have max_episode_steps set to -1 | ||
# May episode impacts positional encoding, so make sure to set this accordingly | ||
|
||
# Setup agent and load its model parameters | ||
action_space_shape = ( | ||
(env.action_space.n,) if isinstance(env.action_space, gym.spaces.Discrete) else tuple(env.action_space.nvec) | ||
) | ||
agent = Agent(args, env.observation_space, action_space_shape, max_episode_steps) | ||
agent.load_state_dict(checkpoint["model_weights"]) | ||
|
||
# Setup Transformer-XL memory, mask and indices | ||
memory = torch.zeros((1, max_episode_steps, args.trxl_num_layers, args.trxl_dim), dtype=torch.float32) | ||
memory_mask = torch.tril(torch.ones((args.trxl_memory_length, args.trxl_memory_length)), diagonal=-1) | ||
repetitions = torch.repeat_interleave( | ||
torch.arange(0, args.trxl_memory_length).unsqueeze(0), args.trxl_memory_length - 1, dim=0 | ||
).long() | ||
memory_indices = torch.stack( | ||
[torch.arange(i, i + args.trxl_memory_length) for i in range(max_episode_steps - args.trxl_memory_length + 1)] | ||
).long() | ||
memory_indices = torch.cat((repetitions, memory_indices)) | ||
|
||
# Run episode | ||
done = False | ||
t = 0 | ||
while not done: | ||
# Prepare observation and memory | ||
obs = torch.Tensor(obs).unsqueeze(0) | ||
memory_window = memory[0, memory_indices[t].unsqueeze(0)] | ||
t_ = max(0, min(t, args.trxl_memory_length - 1)) | ||
mask = memory_mask[t_].unsqueeze(0) | ||
indices = memory_indices[t].unsqueeze(0) | ||
# Forward agent | ||
action, _, _, _, new_memory = agent.get_action_and_value(obs, memory_window, mask, indices) | ||
memory[:, t] = new_memory | ||
# Step | ||
obs, reward, termination, truncation, info = env.step(action.cpu().squeeze().numpy()) | ||
env.render() | ||
done = termination or truncation | ||
t += 1 | ||
|
||
if "r" in info["episode"].keys(): | ||
print(f"Episode return: {info['episode']['r'][0]}, Episode length: {info['episode']['l'][0]}") | ||
else: | ||
print(f"Episode return: {info['reward']}, Episode length: {info['length']}") | ||
env.close() |
Oops, something went wrong.