forked from cgumbsch/goal_anticipations_via_event-inference
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathgaussian_networks.py
69 lines (61 loc) · 2.22 KB
/
gaussian_networks.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
"""
Implementation of a single-layered Mixture Density Network
producing a Gaussian distribution
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy
from torch.autograd import Variable
import math
import torch.optim as optim
def weights_init(m, std=0.01):
    """
    Initialize the weight of a layer in place from a normal distribution
    with mean 0 and standard deviation ``std``.

    Only ``m.weight`` is modified; ``m.bias`` (if present) keeps whatever
    initialization the layer already had.

    :param m: module exposing a ``weight`` tensor (e.g. ``nn.Linear``)
    :param std: standard deviation of the normal distribution (default 0.01)
    :return: None; ``m.weight`` is overwritten in place
    """
    # nn.init.normal_ performs the in-place fill under torch.no_grad(), which
    # is the supported alternative to mutating ``.data`` directly; it draws
    # from the same RNG stream as ``tensor.normal_``.
    nn.init.normal_(m.weight, mean=0.0, std=std)
class Multivariate_Gaussian_Network(nn.Module):
    """
    Single-layer network mapping an input to the parameters of a Gaussian
    with diagonal covariance.

    Two independent linear layers produce the mean and the diagonal entries
    of the covariance matrix (the variances) of the output distribution.
    """

    # Lower bound added to the predicted variances. The previous code used
    # ``+ 1.00000000001``, but 1 + 1e-11 rounds to exactly 1.0 in float32, so
    # an ELU saturated at -1 yielded a variance of exactly 0 (singular
    # covariance). 1e-6 survives float32 rounding when added after the +1.
    # Kept as a class attribute so previously pickled modules still find it.
    SIGMA_EPS = 1e-6

    def __init__(self, input_dim, output_dim):
        """
        Initialization
        :param input_dim: dimensionality of input
        :param output_dim: dimensionality of output
        """
        super().__init__()
        self.fcMu = nn.Linear(input_dim, output_dim)
        weights_init(self.fcMu)
        self.fcSigma = nn.Linear(input_dim, output_dim)
        weights_init(self.fcSigma)

    def forward(self, x):
        """
        Forward pass of input
        :param x: input; presumably (..., input_dim), as required by nn.Linear
        :return: (mu, sigma) -- mean and per-dimension variances of the output
            distribution, both with trailing dimension output_dim; sigma > 0
        """
        mu = self.fcMu(x)
        # ELU(z) > -1, so ELU(z) + 1 > 0 in exact arithmetic, but in float32
        # it rounds to exactly 0 for very negative z. The epsilon is added
        # *after* the +1 so it is not absorbed by rounding.
        sigma = F.elu(self.fcSigma(x)) + 1.0 + self.SIGMA_EPS
        return mu, sigma

    def get_optimizer(self, learning_rate, momentum_term):
        """
        :param learning_rate: learning rate of SGD
        :param momentum_term: momentum term used for SGD
        :return: optimizer of the network
        """
        return optim.SGD(self.parameters(), lr=learning_rate, momentum=momentum_term)

    def loss_criterion(self, output, label, scale=100.0):
        """
        Loss function, i.e., squashed negative log likelihood
        :param output: output (mu, sigma) of the network; sigma holds the
            diagonal of the covariance matrix (variances)
        :param label: nominal output, same shape as mu
        :param scale: squashing constant c; the loss is c * tanh(NLL / c),
            which is bounded to (-c, c) (default 100)
        :return: squashed negative log likelihood of label under the output
            distribution; one value per sample for batched input
        """
        mu, sigma = output[0], output[1]
        # diag_embed turns the trailing dimension into a diagonal matrix, so
        # both a single sample (D,) and a batch (..., D) work. torch.diag would
        # instead *extract* a diagonal from a 2-D batch.
        covariance = torch.diag_embed(sigma)
        distr = torch.distributions.MultivariateNormal(mu, covariance)
        nll = -1 * distr.log_prob(label)
        # c * tanh(x / c) is close to x for |x| << c and saturates at +/-c.
        # This keeps outlier samples from dominating the gradient.
        # NLL of a continuous density can be negative, so the lower bound is -c.
        return torch.tanh(nll * (1.0 / scale)) * scale