Merge pull request #293 from huangshiyu13/main
update atari
huangshiyu13 authored Jan 10, 2024
2 parents 7ccecd2 + abf88d4 commit 8c89a8c
Showing 5 changed files with 24 additions and 11 deletions.
11 changes: 9 additions & 2 deletions examples/atari/README.md
@@ -8,12 +8,19 @@ Then install auto-rom via:
or:
```shell
pip install autorom

AutoROM --accept-license
```

or, if you cannot download the ROMs automatically, you can download them manually from [Google Drive](https://drive.google.com/file/d/1agerLX3fP2YqUCcAkMF7v_ZtABAOhlA7/view?usp=sharing).
Then, you can install the ROMs via:
```shell
pip install autorom
AutoROM --source-file <path-to-Roms.tar.gz>
```


## Usage

```shell
-python train_ppo.py --config atari_ppo.yaml
+python train_ppo.py
```
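For the manual-download route, it can be worth confirming that the archive handed to `AutoROM --source-file` is a readable gzip tarball before running the command. A standard-library sketch; the in-memory archive and the `ROM/pong.bin` entry are hypothetical stand-ins, not the real ROM set:

```python
import io
import tarfile

# Build a tiny in-memory tar.gz standing in for Roms.tar.gz (stand-in contents)
buf = io.BytesIO()
with tarfile.open(fileobj=buf, mode="w:gz") as tar:
    data = b"\x00" * 16
    info = tarfile.TarInfo(name="ROM/pong.bin")
    info.size = len(data)
    tar.addfile(info, io.BytesIO(data))
buf.seek(0)

# The same check works on a real file: tarfile.open("<path-to-Roms.tar.gz>", "r:gz")
with tarfile.open(fileobj=buf, mode="r:gz") as tar:
    names = tar.getnames()
```

If `tarfile.open` raises here, the download is corrupt and AutoROM will fail on it as well.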
13 changes: 9 additions & 4 deletions examples/atari/atari_ppo.yaml
@@ -2,22 +2,27 @@ seed: 0
lr: 2.5e-4
critic_lr: 2.5e-4
episode_length: 128
-ppo_epoch: 4
gamma: 0.99
+ppo_epoch: 3
gain: 0.01
use_linear_lr_decay: true
use_share_model: true
entropy_coef: 0.01
hidden_size: 512
-num_mini_batch: 4
-clip_param: 0.1
+num_mini_batch: 8
+clip_param: 0.2
value_loss_coef: 0.5
max_grad_norm: 10

run_dir: ./run_results/
experiment_name: atari_ppo

log_interval: 1
use_recurrent_policy: false
use_valuenorm: true
use_adv_normalize: true

wandb_entity: openrl-lab
experiment_name: atari_ppo

vec_info_class:
id: "EPS_RewardInfo"
6 changes: 3 additions & 3 deletions examples/atari/train_ppo.py
@@ -43,11 +43,11 @@

def train():
cfg_parser = create_config_parser()
-cfg = cfg_parser.parse_args()
+cfg = cfg_parser.parse_args(["--config", "atari_ppo.yaml"])

# create environment, set environment parallelism to 16
env = make(
"ALE/Pong-v5", env_num=9, cfg=cfg, asynchronous=True, env_wrappers=env_wrappers
"ALE/Pong-v5", env_num=16, cfg=cfg, asynchronous=True, env_wrappers=env_wrappers
)

# create the neural network
@@ -56,7 +56,7 @@ def train():
env, cfg=cfg, device="cuda" if "macOS" not in get_system_info()["OS"] else "cpu"
)
# initialize the trainer
-agent = Agent(net, use_wandb=True)
+agent = Agent(net, use_wandb=True, project_name="Pong-v5")
# start training, set total number of training steps to 5000000

agent.train(total_time_steps=5000000)
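The key change above is passing an explicit argument list to `parse_args`, which pins the config file in code instead of reading it from the command line. A minimal `argparse` stand-in for that pattern (OpenRL's `create_config_parser` exposes many more options; the single `--config` flag here is just what this sketch needs):

```python
import argparse

# Hypothetical stand-in for OpenRL's config parser: one --config option
parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, default=None)

# Passing [] ignores sys.argv entirely and falls back to the default
cfg_default = parser.parse_args([])
# Passing an explicit list pins the value in code, as the commit does
cfg_pinned = parser.parse_args(["--config", "atari_ppo.yaml"])
```

This makes the example script runnable as plain `python train_ppo.py`, at the cost of requiring a code edit to switch configs.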
4 changes: 2 additions & 2 deletions openrl/configs/config.py
@@ -726,8 +726,8 @@ def create_config_parser():
parser.add_argument(
"--max_grad_norm",
type=float,
-default=10.0,
-help="max norm of gradients (default: 0.5)",
+default=10,
+help="max norm of gradients (default: 10)",
)
parser.add_argument(
"--use_gae",
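`--max_grad_norm` caps the global L2 norm of the gradients at each update. A framework-free sketch of that clipping rule, assuming the gradients are flattened into a plain list of floats (real trainers use their framework's built-in clip-grad-norm utility):

```python
import math

def clip_grad_norm(grads, max_norm):
    """Rescale grads so their global L2 norm does not exceed max_norm.

    Returns the (possibly rescaled) gradients and the pre-clip norm.
    """
    total = math.sqrt(sum(g * g for g in grads))
    if total > max_norm:
        # Uniform rescale preserves the gradient direction
        scale = max_norm / total
        grads = [g * scale for g in grads]
    return grads, total

# A gradient of norm 50 is scaled down by 0.2 to hit the max_norm of 10
clipped, norm = clip_grad_norm([30.0, 40.0], 10.0)
```

The default of 10 is permissive; it only intervenes on unusually large updates rather than shaping every step.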
1 change: 1 addition & 0 deletions setup.py
@@ -41,6 +41,7 @@ def get_install_requires() -> list:
"mujoco",
"tqdm",
"Jinja2",
"pettingzoo",
]


