train_agent1.py — 89 lines (76 loc) · 3.52 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import os
import torch
import gymnasium as gym
import gym_maze
from common.storage import Storage
from common.model import ImpalaModel
from common.policy import CategoricalPolicy
from agents.ppo import PPO as AGENT
from stable_baselines3.common.vec_env import DummyVecEnv
def make_env(rank):
    """Return a zero-argument factory that builds one maze environment.

    The base seed (42) is offset by *rank* so each parallel worker gets a
    distinct, deterministic environment. target=1 requests a random goal
    location per the original author's note — TODO confirm against gym_maze.
    """
    def _init():
        return gym.make("maze-random-10x10-v0", seed=42 + rank, target=1)
    return _init
if __name__=='__main__':
    # ---- Training configuration ----
    MAX_T = 3000000                      # total environment timesteps to train for
    N_ENVS = 2                           # number of parallel environments
    N_STEPS = 500                        # rollout length per env between updates
    RENDER_MAZE = True                   # Make False (NOTE(review): not read below — presumably consumed elsewhere; verify)
    CHECKPOINT_PATH = 'saved_models/agent1/'
    LOG_FILE = "log_agent1.csv"
    PERFORMANCE_FILE = "performance_agent1.csv"
    N_CHECKPOINTS = 5                    # checkpoint count passed through to the agent

    # Set Device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Using device:", device)

    # Make Vectorized Env
    env = DummyVecEnv([make_env(i) for i in range(N_ENVS)])

    # Set Env Spaces
    observation_space = env.observation_space
    observation_shape = observation_space.shape
    in_channels = observation_shape[2]   # assumes HWC image observations — TODO confirm
    action_space = env.action_space

    # Model & Policy: Impala backbone + categorical head for discrete actions.
    model = ImpalaModel(in_channels=in_channels).to(device)
    if isinstance(action_space, gym.spaces.Discrete):
        recurrent = False
        action_size = action_space.n
        policy = CategoricalPolicy(model, recurrent, action_size)
    else:
        raise NotImplementedError
    policy.to(device)

    # Initialize Storage (rollout buffers; storage_valid is created but only printed here)
    print('Initializing Storage...')
    hidden_state_dim = model.output_dim
    storage = Storage(observation_shape, hidden_state_dim, N_STEPS, N_ENVS, device=device)
    storage_valid = Storage(observation_shape, hidden_state_dim, N_STEPS, N_ENVS, device=device)
    print(f'Storage Valid...:{storage_valid}')

    agent = AGENT(env, policy, logger=None, storage=storage, device=device,
                  n_checkpoints=N_CHECKPOINTS, n_steps=N_STEPS, n_envs=N_ENVS)

    # Resume from the most recent checkpoint, if any exist.
    if os.path.exists(CHECKPOINT_PATH):
        # Only consider files whose stem is purely numeric ("<step>.pt").
        # The previous int(name[:-3]) crashed with ValueError on any other
        # .pt filename (e.g. "model_1000.pt").
        numbered = []
        for fname in os.listdir(CHECKPOINT_PATH):
            stem, ext = os.path.splitext(fname)
            if ext == '.pt' and stem.isdigit():
                numbered.append((int(stem), fname))
        if numbered:
            _, latest_checkpoint = max(numbered)
            checkpoint_path = os.path.join(CHECKPOINT_PATH, latest_checkpoint)
            print(f'Loading checkpoint from {checkpoint_path}')
            # map_location lets a GPU-saved checkpoint load on a CPU-only host.
            checkpoint = torch.load(checkpoint_path, map_location=device)
            agent.policy.load_state_dict(checkpoint['model_state_dict'])
            print("Model policy loaded.")
            agent.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            print("Model optimizer loaded.")
        else:
            print("No checkpoint files found.")

    print('START TRAINING...')
    agent.train(MAX_T, CHECKPOINT_PATH, LOG_FILE, PERFORMANCE_FILE)