# policies.py
import random

import numpy as np
import torch
from torchvision import transforms

import networks
from envs import VectorEnv


class DQNPolicy:
    def __init__(self, cfg, train=False, random_seed=None):
        self.cfg = cfg
        # Each entry of cfg.robot_config maps a robot type to a count; keep the type
        self.robot_group_types = [next(iter(g.keys())) for g in self.cfg.robot_config]
        self.train = train

        if random_seed is not None:
            random.seed(random_seed)

        self.num_robot_groups = len(self.robot_group_types)
        self.transform = transforms.ToTensor()
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.policy_nets = self.build_policy_nets()

        # Resume if applicable. Note that cfg.checkpoint_path only gates the
        # load; the weights themselves are read from cfg.policy_path.
        if self.cfg.checkpoint_path is not None:
            self.policy_checkpoint = torch.load(self.cfg.policy_path, map_location=self.device)
            for i in range(self.num_robot_groups):
                self.policy_nets[i].load_state_dict(self.policy_checkpoint['state_dicts'][i])
                if self.train:
                    self.policy_nets[i].train()
                else:
                    self.policy_nets[i].eval()
            print("=> loaded policy '{}'".format(self.cfg.policy_path))

    def build_policy_nets(self):
        # One fully convolutional Q-network per robot group
        policy_nets = []
        for robot_type in self.robot_group_types:
            num_output_channels = VectorEnv.get_num_output_channels(robot_type)
            policy_nets.append(torch.nn.DataParallel(
                networks.FCN(num_input_channels=self.cfg.num_input_channels, num_output_channels=num_output_channels)
            ).to(self.device))
        return policy_nets

    def apply_transform(self, s):
        # HWC numpy array -> 1CHW tensor
        return self.transform(s).unsqueeze(0)

    def step(self, state, exploration_eps=None, debug=False):
        if exploration_eps is None:
            exploration_eps = self.cfg.final_exploration

        action = [[None for _ in g] for g in state]
        output = [[None for _ in g] for g in state]
        with torch.no_grad():
            for i, g in enumerate(state):
                robot_type = self.robot_group_types[i]
                self.policy_nets[i].eval()
                for j, s in enumerate(g):
                    if s is not None:
                        s = self.apply_transform(s).to(self.device)
                        o = self.policy_nets[i](s).squeeze(0)
                        # Epsilon-greedy: random action with probability
                        # exploration_eps, otherwise argmax over the Q-value map
                        if random.random() < exploration_eps:
                            a = random.randrange(VectorEnv.get_action_space(robot_type))
                        else:
                            a = o.view(1, -1).max(1)[1].item()
                        action[i][j] = a
                        output[i][j] = o.cpu().numpy()
                if self.train:
                    self.policy_nets[i].train()

        if debug:
            info = {'output': output}
            return action, info

        return action
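

# Usage sketch (illustrative; not part of the original file). Assumes `cfg` is a
# config object with the fields referenced above (robot_config, num_input_channels,
# final_exploration, checkpoint_path, policy_path) and `state` is a nested list of
# per-robot HWC observation arrays grouped by robot group, with None for robots
# that do not need an action this step:
#
#     policy = DQNPolicy(cfg, train=False, random_seed=0)
#     action = policy.step(state)  # epsilon defaults to cfg.final_exploration
#     action, info = policy.step(state, exploration_eps=0.1, debug=True)
#     q_maps = info['output']  # per-robot Q-value maps as numpy arrays
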

class DQNIntentionPolicy(DQNPolicy):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.intention_nets = self.build_intention_nets()

        if self.cfg.checkpoint_path is not None:
            for i in range(self.num_robot_groups):
                self.intention_nets[i].load_state_dict(self.policy_checkpoint['state_dicts_intention'][i])
                if self.train:
                    self.intention_nets[i].train()
                else:
                    self.intention_nets[i].eval()
            print("=> loaded intention network '{}'".format(self.cfg.policy_path))

    def build_intention_nets(self):
        # Intention networks see the state without the intention channel
        # (hence num_input_channels - 1) and predict a single-channel map
        intention_nets = []
        for _ in range(self.num_robot_groups):
            intention_nets.append(torch.nn.DataParallel(
                networks.FCN(num_input_channels=(self.cfg.num_input_channels - 1), num_output_channels=1)
            ).to(self.device))
        return intention_nets

    def step_intention(self, state, debug=False):
        state_intention = [[None for _ in g] for g in state]
        output_intention = [[None for _ in g] for g in state]
        with torch.no_grad():
            for i, g in enumerate(state):
                self.intention_nets[i].eval()
                for j, s in enumerate(g):
                    if s is not None:
                        s_copy = s.copy()
                        s = self.apply_transform(s).to(self.device)
                        o = torch.sigmoid(self.intention_nets[i](s)).squeeze(0).squeeze(0).cpu().numpy()
                        # Append the predicted intention map as a new last channel
                        state_intention[i][j] = np.concatenate((s_copy, np.expand_dims(o, 2)), axis=2)
                        output_intention[i][j] = o
                if self.train:
                    self.intention_nets[i].train()

        if debug:
            info = {'output_intention': output_intention}
            return state_intention, info

        return state_intention

    def step(self, state, exploration_eps=None, debug=False, use_ground_truth_intention=False):
        if self.train and use_ground_truth_intention:
            # Use the ground truth intention map
            return super().step(state, exploration_eps=exploration_eps, debug=debug)

        if self.train:
            # Remove ground truth intention map (the last state channel)
            state_copy = [[None for _ in g] for g in state]
            for i, g in enumerate(state):
                for j, s in enumerate(g):
                    if s is not None:
                        state_copy[i][j] = s[:, :, :-1]
            state = state_copy

        # Add predicted intention map to state
        state = self.step_intention(state, debug=debug)
        if debug:
            state, info_intention = state

        action = super().step(state, exploration_eps=exploration_eps, debug=debug)

        if debug:
            action, info = action
            info['state_intention'] = state
            info['output_intention'] = info_intention['output_intention']
            return action, info

        return action
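

# Minimal smoke test (illustrative; not part of the original file). The config
# fields mirror those accessed above, but the robot type, channel count, and
# observation size below are assumptions; substitute values supported by your
# envs.VectorEnv and config. In eval mode, DQNIntentionPolicy expects states
# without the intention channel and appends the predicted map itself.
if __name__ == '__main__':
    from types import SimpleNamespace

    cfg = SimpleNamespace(
        robot_config=[{'lifting_robot': 2}],  # hypothetical group of 2 robots
        num_input_channels=4,                 # assumed observation channel count
        final_exploration=0.01,
        checkpoint_path=None,                 # no checkpoint: random weights
    )
    policy = DQNIntentionPolicy(cfg, train=False, random_seed=0)

    # One random 96x96 observation per robot (placeholder size), with
    # num_input_channels - 1 channels since the intention map is predicted
    state = [[np.random.rand(96, 96, cfg.num_input_channels - 1).astype(np.float32)
              for _ in range(2)]]
    print(policy.step(state))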