-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdilo_utils.py
159 lines (128 loc) · 4.83 KB
/
dilo_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
import csv
from datetime import datetime
import json
from pathlib import Path
import random
import string
import sys
import numpy as np
import torch
import torch.nn as nn
# Module-wide default device: CUDA when available, else CPU (used by torchify()).
DEFAULT_DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
class Squeeze(nn.Module):
    """Module wrapper around ``Tensor.squeeze`` so it can sit inside an nn.Sequential."""

    def __init__(self, dim=None):
        super().__init__()
        self.dim = dim  # dimension to squeeze; None squeezes all size-1 dims

    def forward(self, inputs):
        # Delegate directly to the tensor method with the stored dimension.
        return inputs.squeeze(dim=self.dim)
def mlp(dims, activation=nn.ReLU, output_activation=None, squeeze_output=False,layer_norm=False):
    """Construct a float32 feed-forward network from a list of layer widths.

    dims: [input_dim, *hidden_dims, output_dim]; needs at least two entries.
    activation: module class instantiated after every hidden Linear.
    output_activation: optional module class applied after the final Linear.
    squeeze_output: if True (requires dims[-1] == 1), drop the trailing dim.
    layer_norm: insert LayerNorm between each hidden Linear and its activation.
    """
    assert len(dims) >= 2, 'MLP requires at least two dims (input and output)'
    modules = []
    # Hidden stack: Linear (+ optional LayerNorm) followed by the activation.
    for in_width, out_width in zip(dims[:-2], dims[1:-1]):
        modules.append(nn.Linear(in_width, out_width))
        if layer_norm:
            modules.append(nn.LayerNorm(out_width))
        modules.append(activation())
    # Final projection has no activation unless one is explicitly requested.
    modules.append(nn.Linear(dims[-2], dims[-1]))
    if output_activation is not None:
        modules.append(output_activation())
    if squeeze_output:
        assert dims[-1] == 1
        modules.append(Squeeze(-1))
    network = nn.Sequential(*modules)
    network.to(dtype=torch.float32)
    return network
def compute_batched(f, xs):
    """Run ``f`` once on the concatenation of ``xs`` and split the result back
    into per-input chunks — one forward pass instead of len(xs) passes."""
    chunk_sizes = [len(x) for x in xs]
    stacked = torch.cat(xs, dim=0)
    return f(stacked).split(chunk_sizes)
def update_exponential_moving_average(target, source, alpha):
    """In-place Polyak averaging: target <- (1 - alpha) * target + alpha * source.

    Parameter order of the two modules must correspond one-to-one.
    """
    for tgt_param, src_param in zip(target.parameters(), source.parameters()):
        tgt_param.data.mul_(1. - alpha)
        tgt_param.data.add_(src_param.data, alpha=alpha)
def torchify(x):
    """Convert a numpy array to a tensor on DEFAULT_DEVICE, downcasting
    float64 to float32 (the dtype the networks in this file are built with)."""
    tensor = torch.from_numpy(x)
    if tensor.dtype is torch.float64:
        tensor = tensor.float()
    return tensor.to(device=DEFAULT_DEVICE)
def return_range(dataset, max_episode_steps):
    """Return (min, max) undiscounted return over the dataset's complete episodes.

    An episode ends on a terminal flag or after max_episode_steps transitions.
    A trailing incomplete episode contributes its length to the step-count
    sanity check but not its return.
    """
    returns, lengths = [], []
    episode_return, episode_len = 0., 0
    for reward, terminal in zip(dataset['rewards'], dataset['terminals']):
        episode_return += float(reward)
        episode_len += 1
        if terminal or episode_len == max_episode_steps:
            returns.append(episode_return)
            lengths.append(episode_len)
            episode_return, episode_len = 0., 0
    lengths.append(episode_len)  # partial trailing episode: count steps, not return
    assert sum(lengths) == len(dataset['rewards'])
    return min(returns), max(returns)
# dataset is a dict, values of which are tensors of same first dimension
def sample_batch(dataset, batch_size):
    """Uniformly sample ``batch_size`` rows (with replacement) from every
    tensor in ``dataset``, using the same indices for each key."""
    first_key = next(iter(dataset))
    n = len(dataset[first_key])
    device = dataset[first_key].device
    for value in dataset.values():
        assert len(value) == n, 'Dataset values must have same length'
    # Sample on the dataset's own device to avoid host/device index transfers.
    indices = torch.randint(low=0, high=n, size=(batch_size,), device=device)
    return {key: value[indices] for key, value in dataset.items()}
def evaluate_policy(env, policy, max_episode_steps, deterministic=True):
    """Roll out a single episode and return its total (undiscounted) reward.

    The rollout is truncated at max_episode_steps even if the environment
    never signals done. Assumes the old gym 4-tuple step API — TODO confirm.
    """
    obs = env.reset()
    total_reward = 0.
    for _ in range(max_episode_steps):
        with torch.no_grad():
            action = policy.act(torchify(obs), deterministic=deterministic).cpu().numpy()
        obs, reward, done, _ = env.step(action)
        total_reward += reward
        if done:
            break
    return total_reward
def set_seed(seed, env=None):
    """Seed Python's random, numpy, torch (CPU and all CUDA devices), and
    optionally a gym-style environment, for reproducible runs."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    if env is not None:
        env.seed(seed)
def _gen_dir_name():
now_str = datetime.now().strftime('%m-%d-%y_%H.%M.%S')
rand_str = ''.join(random.choices(string.ascii_lowercase, k=4))
return f'{now_str}_{rand_str}'
class Log:
    """Simple experiment logger.

    Creates a unique timestamped run directory under ``root_log_dir``, dumps
    the config dict to JSON there, mirrors timestamped text messages to both
    stdout and a log file, and appends metric rows to a CSV whose header is
    fixed by the keys of the first row logged.
    """

    def __init__(self, root_log_dir, cfg_dict,
                 txt_filename='log.txt',
                 csv_filename='progress.csv',
                 cfg_filename='config.json',
                 flush=True):
        # One fresh subdirectory per run; mkdir raises if it already exists.
        self.dir = Path(root_log_dir)/_gen_dir_name()
        self.dir.mkdir(parents=True)
        # Explicit UTF-8 so log contents don't depend on the platform locale.
        self.txt_file = open(self.dir/txt_filename, 'w', encoding='utf-8')
        self.csv_file = None  # created lazily on the first row()
        (self.dir/cfg_filename).write_text(json.dumps(cfg_dict), encoding='utf-8')
        self.txt_filename = txt_filename
        self.csv_filename = csv_filename
        self.cfg_filename = cfg_filename
        self.flush = flush

    def write(self, message, end='\n'):
        """Print a timestamped message to stdout and the run's log file."""
        now_str = datetime.now().strftime('%H:%M:%S')
        message = f'[{now_str}] ' + message
        for f in [sys.stdout, self.txt_file]:
            print(message, end=end, file=f, flush=self.flush)

    def __call__(self, *args, **kwargs):
        """Shorthand: calling the log writes a message."""
        self.write(*args, **kwargs)

    def row(self, dict):  # NOTE: parameter shadows builtin `dict`; name kept for interface compatibility
        """Log a metrics row: echoed as text and appended to the CSV.

        The CSV header is taken from the first row's keys; later rows must
        use the same keys (csv.DictWriter raises on extra keys).
        """
        if self.csv_file is None:
            self.csv_file = open(self.dir/self.csv_filename, 'w', newline='', encoding='utf-8')
            self.csv_writer = csv.DictWriter(self.csv_file, list(dict.keys()))
            self.csv_writer.writeheader()
        self(str(dict))
        self.csv_writer.writerow(dict)
        if self.flush:
            self.csv_file.flush()

    def close(self):
        """Close the text and (if opened) CSV file handles."""
        self.txt_file.close()
        if self.csv_file is not None:
            self.csv_file.close()