#!/usr/bin/env python
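"""Run a trained Stable-Baselines3 agent on an rlp puzzle environment.

Loads the best model saved by a training run and plays 1000 evaluation
episodes, printing how each episode ended.
"""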
import os

import gymnasium as gym
from gymnasium.wrappers import TimeLimit, FlattenObservation
from stable_baselines3.common.monitor import Monitor
from stable_baselines3 import PPO, DQN, A2C
from sb3_contrib import QRDQN, RecurrentPPO, MaskablePPO, TRPO

import rlp

if __name__ == "__main__":
    parser = rlp.puzzle.make_puzzle_parser()
    parser.add_argument(
        "-t", "--timesteps", type=int, help="Number of timesteps during training"
    )
    parser.add_argument("-r", "--numrun", type=int, help="Training run number")
    parser.add_argument(
        "-alg",
        "--algorithm",
        type=str,
        help="RL algorithm that was used for training",
        choices=[
            "PPO",
            "DQN",
            "A2C",
            "HER",
            "ARS",
            "QRDQN",
            "RecurrentPPO",
            "TRPO",
            "MaskablePPO",
        ],
    )
    parser.add_argument(
        "-ot",
        "--obs-type",
        type=str,
        help="Type of observation",
        choices=["rgb", "puzzle_state"],
        default="puzzle_state",
    )
    args = parser.parse_args()
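
    # Log and model paths encode the full training configuration
    # (algorithm, timesteps, puzzle, parameters, undo setting, observation type).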
    data_dir = "/tmp/rlp/monitor/"
    undo_prefix = "undo" if args.allowundo else "noundo"
    log_dir = (
        f"{data_dir}trained_runs/{args.algorithm}_{args.timesteps}/"
        f"{args.puzzle}_{args.arg}_{undo_prefix}_{args.obs_type}/"
    )
    os.makedirs(log_dir, exist_ok=True)
    render_mode = "rgb_array" if args.headless else "human"
    print(f"log_dir = {log_dir}")

    env_kwargs = dict(
        puzzle=args.puzzle,
        params=args.arg,
        render_mode=render_mode,
        window_width=128,
        window_height=128,
        allow_undo=args.allowundo,
        max_state_repeats=10000,
        include_cursor_in_state_info=True,
        obs_type=args.obs_type,
    )
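
    # Build the puzzle environment; structured puzzle_state observations are
    # flattened into a single vector so SB3 policies can consume them.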
    env = gym.make("rlp/Puzzle-v0", None, None, None, **env_kwargs)
    if args.obs_type == "puzzle_state":
        env = FlattenObservation(env)
    model_prefix = "best"
    model_suffix = f"_{args.puzzle}"
    env = Monitor(env, log_dir, override_existing=True)
    max_timesteps = 10000
    env = TimeLimit(env, max_timesteps)
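    # NOTE: TimeLimit wraps Monitor above, so episodes cut off by the
    # 10000-step limit may not be recorded in the Monitor log.
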
    model_file = (
        f"{data_dir}{args.algorithm}_{args.timesteps}/"
        f"{args.puzzle}_{args.arg}_{undo_prefix}_{args.obs_type}/"
        f"{model_prefix}_model{model_suffix}"
    )
    print(f"Loading model {model_file}")
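
    # Load the saved model with the algorithm class it was trained with.
    # HER and ARS are accepted by the parser above but have no loader here,
    # so they fall through to the ValueError below.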
    model: PPO | DQN | A2C | QRDQN | RecurrentPPO | TRPO | MaskablePPO
    if args.algorithm == "PPO":
        model = PPO.load(model_file, env=env)
    elif args.algorithm == "DQN":
        model = DQN.load(model_file, env=env)
    elif args.algorithm == "A2C":
        model = A2C.load(model_file, env=env)
    elif args.algorithm == "QRDQN":
        model = QRDQN.load(model_file, env=env)
    elif args.algorithm == "RecurrentPPO":
        model = RecurrentPPO.load(model_file, env=env)
    elif args.algorithm == "TRPO":
        model = TRPO.load(model_file, env=env)
    elif args.algorithm == "MaskablePPO":
        model = MaskablePPO.load(model_file, env=env)
    else:
        raise ValueError(f"{args.algorithm} is not supported")
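
    # Play 1000 evaluation episodes, sampling actions stochastically
    # (deterministic=False) and reporting how each episode ended.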
    num_episodes = 1000
    episodes = num_episodes
    timesteps = 0
    obs, info = env.reset()
    while episodes > 0:
        action, _ = model.predict(obs, deterministic=False)
        obs, reward, terminated, truncated, info = env.step(int(action))
        timesteps += 1
        if terminated or truncated:
            episodes -= 1
            obs, info = env.reset()
            if terminated:
                print(
                    f"episode {num_episodes - episodes} terminated after {timesteps} steps",
                    flush=True,
                )
            if truncated:
                print(
                    f"episode {num_episodes - episodes} truncated after {timesteps} steps",
                    flush=True,
                )
            timesteps = 0
    env.close()