-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtransition_model.py
98 lines (74 loc) · 3.39 KB
/
transition_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import random
import torch
import torch.nn as nn
class DeterministicTransitionModel(nn.Module):
    """Point-estimate transition network: Linear -> LayerNorm -> ReLU -> Linear.

    The input width is ``encoder_feature_dim + action_shape[0]`` and the
    output width is ``encoder_feature_dim``.
    NOTE(review): the input is presumably a latent state concatenated with an
    action along the last dimension -- confirm against the caller.
    """

    def __init__(self, encoder_feature_dim, action_shape, layer_width):
        """Build the layers and print which model type was selected.

        Args:
            encoder_feature_dim: Width of the predicted output vector.
            action_shape: Sequence whose first entry is added to
                ``encoder_feature_dim`` to obtain the input width.
            layer_width: Width of the single hidden layer.
        """
        super().__init__()
        input_dim = encoder_feature_dim + action_shape[0]
        self.fc = nn.Linear(input_dim, layer_width)
        self.ln = nn.LayerNorm(layer_width)
        self.fc_mu = nn.Linear(layer_width, encoder_feature_dim)
        print("Deterministic transition model chosen.")

    def forward(self, x):
        """Return ``(mu, None)``.

        The second element is always ``None`` so the return shape matches the
        ``(mu, sigma)`` pair produced by the probabilistic model.
        """
        hidden = torch.relu(self.ln(self.fc(x)))
        return self.fc_mu(hidden), None

    def sample_prediction(self, x):
        """Return the predicted mean; there is no noise to sample."""
        prediction, _ = self(x)
        return prediction
class ProbabilisticTransitionModel(nn.Module):
    """Transition network that outputs a mean and a bounded standard deviation.

    A shared trunk (Linear -> LayerNorm -> ReLU) feeds two linear heads:
    ``fc_mu`` for the mean and ``fc_sigma`` for the spread. The input width is
    ``encoder_feature_dim + action_shape[0]``; both heads output
    ``encoder_feature_dim`` values.
    NOTE(review): the input is presumably a latent state concatenated with an
    action along the last dimension -- confirm against the caller.
    """

    def __init__(self, encoder_feature_dim, action_shape, layer_width, announce=True, max_sigma=1e1, min_sigma=1e-4):
        """Build the layers and validate the sigma bounds.

        Args:
            encoder_feature_dim: Width of the ``mu`` and ``sigma`` outputs.
            action_shape: Sequence whose first entry is added to
                ``encoder_feature_dim`` to obtain the input width.
            layer_width: Width of the single hidden layer.
            announce: When true, print which model type was selected.
            max_sigma: Upper bound of the sigma range.
            min_sigma: Lower bound of the sigma range.

        Raises:
            AssertionError: If ``max_sigma >= min_sigma`` does not hold.
        """
        super().__init__()
        self.fc = nn.Linear(encoder_feature_dim + action_shape[0], layer_width)
        self.ln = nn.LayerNorm(layer_width)
        self.fc_mu = nn.Linear(layer_width, encoder_feature_dim)
        self.fc_sigma = nn.Linear(layer_width, encoder_feature_dim)

        self.max_sigma = max_sigma
        self.min_sigma = min_sigma
        # Explicit check instead of an `assert` statement so the validation is
        # not stripped under `python -O`. AssertionError is kept as the
        # exception type for backward compatibility with existing callers.
        # `not (a >= b)` (rather than `a < b`) also rejects NaN bounds, exactly
        # as the original assert did.
        if not (self.max_sigma >= self.min_sigma):
            raise AssertionError(
                "max_sigma ({}) must be >= min_sigma ({})".format(max_sigma, min_sigma)
            )
        if announce:
            print("Probabilistic transition model chosen.")

    def forward(self, x):
        """Return ``(mu, sigma)`` with ``sigma`` scaled into the configured range."""
        x = self.fc(x)
        x = self.ln(x)
        x = torch.relu(x)
        mu = self.fc_mu(x)
        sigma = torch.sigmoid(self.fc_sigma(x))  # range (0, 1.)
        # Affine rescale of the unit interval to (min_sigma, max_sigma).
        sigma = self.min_sigma + (self.max_sigma - self.min_sigma) * sigma
        return mu, sigma

    def sample_prediction(self, x):
        """Draw one sample as ``mu + sigma * eps`` with ``eps`` standard normal."""
        mu, sigma = self(x)
        eps = torch.randn_like(sigma)
        return mu + sigma * eps
class EnsembleOfProbabilisticTransitionModels(object):
    """Collection of independently initialised ProbabilisticTransitionModel members.

    This class is a plain Python object, not an ``nn.Module``: the members are
    kept in a regular list, and only ``to()`` and ``parameters()`` are
    re-exposed here for the members as a group.
    """

    def __init__(self, encoder_feature_dim, action_shape, layer_width, ensemble_size=5):
        """Create ``ensemble_size`` member models and print the model type.

        Args:
            encoder_feature_dim: Forwarded to every member model.
            action_shape: Forwarded to every member model.
            layer_width: Forwarded to every member model.
            ensemble_size: Number of member models; must be at least 1.

        Raises:
            ValueError: If ``ensemble_size`` is smaller than 1.
        """
        # Fail fast: an empty ensemble would otherwise only blow up later with
        # an opaque unpacking error in __call__ or an IndexError in
        # sample_prediction.
        if ensemble_size < 1:
            raise ValueError(
                "ensemble_size must be at least 1, got {!r}".format(ensemble_size)
            )
        self.models = [
            ProbabilisticTransitionModel(encoder_feature_dim, action_shape, layer_width, announce=False)
            for _ in range(ensemble_size)
        ]
        print("Ensemble of probabilistic transition models chosen.")

    def __call__(self, x):
        """Run every member on ``x`` and stack the results.

        Returns:
            ``(mus, sigmas)``: the per-member outputs stacked along a new
            leading dimension of length ``len(self.models)``.
        """
        # Call the module itself rather than `.forward` so that any hooks
        # registered on a member are honoured.
        mus, sigmas = zip(*(model(x) for model in self.models))
        return torch.stack(mus), torch.stack(sigmas)

    def sample_prediction(self, x):
        """Sample a prediction from one member chosen uniformly at random."""
        model = random.choice(self.models)
        return model.sample_prediction(x)

    def to(self, device):
        """Move every member to ``device`` and return ``self`` for chaining."""
        for model in self.models:
            model.to(device)
        return self

    def parameters(self):
        """Return a flat list of all members' parameters, in member order."""
        return [param for model in self.models for param in model.parameters()]
# Registry mapping a transition-model type name to the class that builds it.
_AVAILABLE_TRANSITION_MODELS = {
    # The empty string selects the deterministic model as well.
    '': DeterministicTransitionModel,
    'deterministic': DeterministicTransitionModel,
    'probabilistic': ProbabilisticTransitionModel,
    'ensemble': EnsembleOfProbabilisticTransitionModels,
}
def make_transition_model(transition_model_type, encoder_feature_dim, action_shape, layer_width=512):
    """Instantiate the transition model registered under ``transition_model_type``.

    Args:
        transition_model_type: Key into ``_AVAILABLE_TRANSITION_MODELS``.
        encoder_feature_dim: Forwarded to the model constructor.
        action_shape: Forwarded to the model constructor.
        layer_width: Forwarded to the model constructor.

    Returns:
        A new instance of the selected model class.

    Raises:
        AssertionError: If ``transition_model_type`` is not a registered key.
    """
    # Explicit check instead of an `assert` statement so the validation also
    # runs under `python -O`. AssertionError is kept as the exception type for
    # backward compatibility with existing callers.
    if transition_model_type not in _AVAILABLE_TRANSITION_MODELS:
        raise AssertionError(
            "Unknown transition_model_type {!r}; expected one of {}".format(
                transition_model_type, sorted(_AVAILABLE_TRANSITION_MODELS)
            )
        )
    model_cls = _AVAILABLE_TRANSITION_MODELS[transition_model_type]
    return model_cls(encoder_feature_dim, action_shape, layer_width)