import torch
import torch.nn.functional as F  # needed for F.one_hot in random_generation
import numpy as np
from torch import nn, optim
from torch.utils.data.sampler import SubsetRandomSampler
from model.GMVAENet import GMVAENet
from loss_function import GMVAELossFunctions
from metric import Metrics
import matplotlib.pyplot as plt


class GMVAE:

    def __init__(self, args):
        self.num_epochs = args.epochs
        self.cuda = args.cuda
        self.verbose = args.verbose

        self.batch_size = args.batch_size
        self.batch_size_val = args.batch_size_val
        self.learning_rate = args.learning_rate
        self.decay_epoch = args.decay_epoch
        self.lr_decay = args.lr_decay
        self.w_cat = args.w_categ
        self.w_gauss = args.w_gauss
        self.w_rec = args.w_rec
        self.rec_type = args.rec_type

        self.num_classes = args.num_classes
        self.gaussian_size = args.gaussian_size
        self.input_size = args.input_size

        # gumbel
        self.init_temp = args.init_temp
        self.decay_temp = args.decay_temp
        self.hard_gumbel = args.hard_gumbel
        self.min_temp = args.min_temp
        self.decay_temp_rate = args.decay_temp_rate
        self.gumbel_temp = self.init_temp

        self.network = GMVAENet(self.input_size, self.gaussian_size, self.num_classes)
        self.losses = GMVAELossFunctions()
        self.metrics = Metrics()

        if self.cuda:
            self.network = self.network.cuda()

    def unlabeled_loss(self, data, out_net):
        """Compute the loss terms derived from the variational lower bound.

        Args:
            data: (array) array containing the input data
            out_net: (dict) output nodes of the network

        Returns:
            loss_dic: (dict) values of each loss term and the predicted labels
        """
        # obtain network variables
        z, data_recon = out_net['gaussian'], out_net['x_rec']
        logits, prob_cat = out_net['logits'], out_net['prob_cat']
        y_mu, y_var = out_net['y_mean'], out_net['y_var']
        mu, var = out_net['mean'], out_net['var']

        # reconstruction loss
        loss_rec = self.losses.reconstruction_loss(data, data_recon, self.rec_type)

        # gaussian loss
        loss_gauss = self.losses.gaussian_loss(z, mu, var, y_mu, y_var)

        # categorical loss
        loss_cat = -self.losses.entropy(logits, prob_cat) - np.log(1.0 / self.num_classes)

        # total loss
        loss_total = self.w_rec * loss_rec + self.w_gauss * loss_gauss + self.w_cat * loss_cat

        # obtain predictions
        _, predicted_labels = torch.max(logits, dim=1)

        loss_dic = {'total': loss_total,
                    'predicted_labels': predicted_labels,
                    'reconstruction': loss_rec,
                    'gaussian': loss_gauss,
                    'categorical': loss_cat}
        return loss_dic

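    # Note on the categorical term above: assuming self.losses.entropy returns
    # the Shannon entropy H(q(y|x)) of the cluster posterior, then
    #   loss_cat = -H(q(y|x)) - log(1/K) = KL(q(y|x) || Uniform(K)),
    # i.e. it pulls the inferred cluster posterior toward a uniform prior over
    # the K = num_classes mixture components, and the total loss is a weighted
    # negative ELBO with per-term weights w_rec, w_gauss, w_cat.
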
    def train_epoch(self, optimizer, data_loader):
        """Train the model for one epoch.

        Args:
            optimizer: (Optim) optimizer to use in backpropagation
            data_loader: (DataLoader) loader containing the training data

        Returns:
            average of all loss values, accuracy, nmi
        """
        self.network.train()
        total_loss = 0.
        recon_loss = 0.
        cat_loss = 0.
        gauss_loss = 0.

        accuracy = 0.
        nmi = 0.
        num_batches = 0.

        true_labels_list = []
        predicted_labels_list = []

        # iterate over the dataset
        for (data, labels) in data_loader:
            if self.cuda == 1:
                data = data.cuda()
            labels = labels.long()

            optimizer.zero_grad()

            # flatten data
            # data = data.view(data.size(0), -1)

            # forward call
            out_net = self.network(data, self.gumbel_temp, self.hard_gumbel)
            unlab_loss_dic = self.unlabeled_loss(data, out_net)
            total = unlab_loss_dic['total']

            # accumulate values
            total_loss += total.item()
            recon_loss += unlab_loss_dic['reconstruction'].item()
            gauss_loss += unlab_loss_dic['gaussian'].item()
            cat_loss += unlab_loss_dic['categorical'].item()

            # perform backpropagation
            total.backward()
            optimizer.step()

            # save predicted and true labels
            predicted = unlab_loss_dic['predicted_labels']
            true_labels_list.append(labels)
            predicted_labels_list.append(predicted)
            num_batches += 1.

        # average per batch
        total_loss /= num_batches
        recon_loss /= num_batches
        gauss_loss /= num_batches
        cat_loss /= num_batches

        # concat all true and predicted labels
        true_labels = torch.cat(true_labels_list, dim=0).cpu().numpy()
        predicted_labels = torch.cat(predicted_labels_list, dim=0).cpu().numpy()

        # compute metrics
        accuracy = 100.0 * self.metrics.cluster_acc(predicted_labels, true_labels)
        nmi = 100.0 * self.metrics.nmi(predicted_labels, true_labels)
        return total_loss, recon_loss, gauss_loss, cat_loss, accuracy, nmi

    def test(self, data_loader, return_loss=False):
        """Test the model with new data.

        Args:
            data_loader: (DataLoader) loader containing the test/validation data
            return_loss: (boolean) whether to return the average loss values

        Returns:
            accuracy and nmi for the given test data
        """
        self.network.eval()
        total_loss = 0.
        recon_loss = 0.
        cat_loss = 0.
        gauss_loss = 0.

        accuracy = 0.
        nmi = 0.
        num_batches = 0.

        true_labels_list = []
        predicted_labels_list = []

        with torch.no_grad():
            for data, labels in data_loader:
                if self.cuda == 1:
                    data = data.cuda()
                labels = labels.long()

                # flatten data
                # data = data.view(data.size(0), -1)

                # forward call
                out_net = self.network(data, self.gumbel_temp, self.hard_gumbel)
                unlab_loss_dic = self.unlabeled_loss(data, out_net)

                # accumulate values
                total_loss += unlab_loss_dic['total'].item()
                recon_loss += unlab_loss_dic['reconstruction'].item()
                gauss_loss += unlab_loss_dic['gaussian'].item()
                cat_loss += unlab_loss_dic['categorical'].item()

                # save predicted and true labels
                predicted = unlab_loss_dic['predicted_labels']
                true_labels_list.append(labels)
                predicted_labels_list.append(predicted)
                num_batches += 1.

        # average per batch
        if return_loss:
            total_loss /= num_batches
            recon_loss /= num_batches
            gauss_loss /= num_batches
            cat_loss /= num_batches

        # concat all true and predicted labels
        true_labels = torch.cat(true_labels_list, dim=0).cpu().numpy()
        predicted_labels = torch.cat(predicted_labels_list, dim=0).cpu().numpy()

        # compute metrics
        # accuracy = 100.0 * self.metrics.cluster_acc(predicted_labels, true_labels)
        accuracy = 100.0 * (predicted_labels == true_labels).sum() / predicted_labels.size
        nmi = 100.0 * self.metrics.nmi(predicted_labels, true_labels)

        if return_loss:
            return total_loss, recon_loss, gauss_loss, cat_loss, accuracy, nmi
        else:
            return accuracy, nmi

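    # Note on the accuracy in test() above: the direct label match assumes the
    # predicted cluster ids already align with the true class ids. The
    # commented-out cluster_acc call would instead score under the best
    # cluster-to-class assignment (typically found via the Hungarian
    # algorithm), which is the usual convention for unsupervised clustering;
    # keeping the raw match here is the author's choice and will underreport
    # accuracy whenever clusters are permuted relative to the labels.
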
    def train(self, train_loader, val_loader):
        """Train the model.

        Args:
            train_loader: (DataLoader) loader containing the training data
            val_loader: (DataLoader) loader containing the validation data

        Returns:
            output: (dict) history of train/val accuracy and NMI
        """
        optimizer = optim.Adam(self.network.parameters(), lr=self.learning_rate)
        train_history_acc, val_history_acc = [], []
        train_history_nmi, val_history_nmi = [], []

        for epoch in range(1, self.num_epochs + 1):
            train_loss, train_rec, train_gauss, train_cat, train_acc, train_nmi = \
                self.train_epoch(optimizer, train_loader)
            val_loss, val_rec, val_gauss, val_cat, val_acc, val_nmi = \
                self.test(val_loader, True)

            # if verbose then print detailed information about training
            if self.verbose == 1:
                print("(Epoch %d / %d)" % (epoch, self.num_epochs))
                print("Train - REC: %.5lf; Gauss: %.5lf; Cat: %.5lf;" %
                      (train_rec, train_gauss, train_cat))
                print("Valid - REC: %.5lf; Gauss: %.5lf; Cat: %.5lf;" %
                      (val_rec, val_gauss, val_cat))
                print("Accuracy=Train: %.5lf; Val: %.5lf NMI=Train: %.5lf; Val: %.5lf Total Loss=Train: %.5lf; Val: %.5lf" %
                      (train_acc, val_acc, train_nmi, val_nmi, train_loss, val_loss))
            else:
                print('(Epoch %d / %d) Train_Loss: %.3lf; Val_Loss: %.3lf Train_ACC: %.3lf; Val_ACC: %.3lf Train_NMI: %.3lf; Val_NMI: %.3lf' %
                      (epoch, self.num_epochs, train_loss, val_loss, train_acc, val_acc, train_nmi, val_nmi))

            # decay gumbel temperature
            if self.decay_temp == 1:
                self.gumbel_temp = np.maximum(self.init_temp * np.exp(-self.decay_temp_rate * epoch), self.min_temp)
                if self.verbose == 1:
                    print("Gumbel Temperature: %.3lf" % self.gumbel_temp)

            train_history_acc.append(train_acc)
            val_history_acc.append(val_acc)
            train_history_nmi.append(train_nmi)
            val_history_nmi.append(val_nmi)

        return {'train_history_nmi': train_history_nmi, 'val_history_nmi': val_history_nmi,
                'train_history_acc': train_history_acc, 'val_history_acc': val_history_acc}

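    # The annealing schedule in train() above is
    #   temp(epoch) = max(init_temp * exp(-decay_temp_rate * epoch), min_temp),
    # an exponential decay clamped from below. As a worked example with the
    # values init_temp=1.0, decay_temp_rate=0.013862944 (= ln(2)/50) and
    # min_temp=0.5 (illustrative assumptions; the real values come from args):
    #   epoch 1:  max(exp(-0.0139), 0.5) ~ 0.986
    #   epoch 50: max(exp(-0.693), 0.5) = 0.5, after which it stays clamped,
    # so the Gumbel-Softmax gradually sharpens toward discrete cluster choices.
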
    def latent_features(self, data_loader, return_labels=False):
        """Obtain the latent features learnt by the model.

        Args:
            data_loader: (DataLoader) loader containing the data
            return_labels: (boolean) whether to also return the true labels

        Returns:
            features: (array) array containing the features from the data
        """
        self.network.eval()
        N = len(data_loader.dataset)
        features = np.zeros((N, self.gaussian_size))
        if return_labels:
            true_labels = np.zeros(N, dtype=np.int64)
        start_ind = 0
        with torch.no_grad():
            for (data, labels) in data_loader:
                if self.cuda == 1:
                    data = data.cuda()

                # flatten data
                data = data.view(data.size(0), -1)
                out = self.network.inference(data, self.gumbel_temp, self.hard_gumbel)
                latent_feat = out['mean']
                end_ind = min(start_ind + data.size(0), N)

                # save the true labels of the current batch
                if return_labels:
                    true_labels[start_ind:end_ind] = labels.cpu().numpy()
                features[start_ind:end_ind] = latent_feat.cpu().detach().numpy()
                start_ind += data.size(0)
        if return_labels:
            return features, true_labels
        return features

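    # Note: latent_features() above uses out['mean'] (the encoder's posterior
    # mean) rather than a stochastic sample of z, which presumably gives
    # deterministic, lower-variance features for downstream analysis such as
    # plot_latent_space below.
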
    def reconstruct_data(self, data_loader, sample_size=-1):
        """Reconstruct data.

        Args:
            data_loader: (DataLoader) loader containing the data
            sample_size: (int) number of random samples to draw from data_loader

        Returns:
            original and reconstructed data arrays
        """
        self.network.eval()

        # sample random data from loader
        indices = np.random.randint(0, len(data_loader.dataset), size=sample_size)
        test_random_loader = torch.utils.data.DataLoader(data_loader.dataset,
                                                         batch_size=sample_size,
                                                         sampler=SubsetRandomSampler(indices))

        # obtain values
        it = iter(test_random_loader)
        test_batch_data, _ = next(it)  # Python 3: use next(it), not it.next()
        original = test_batch_data.data.numpy()
        if self.cuda:
            test_batch_data = test_batch_data.cuda()

        # obtain reconstructed data
        out = self.network(test_batch_data, self.gumbel_temp, self.hard_gumbel)
        reconstructed = out['x_rec']
        return original, reconstructed.data.cpu().numpy()

    def plot_latent_space(self, data_loader, labels, save=True, suffix=None):
        """Plot the latent space learnt by the model.

        Args:
            data_loader: (DataLoader) loader containing the data
            labels: (array) array containing the labels
            save: (bool) whether to save the latent space plot
            suffix: (str) optional suffix for the saved file name

        Returns:
            fig: (figure) plot of the latent space
        """
        # obtain the latent features
        features = self.latent_features(data_loader)

        # plot only the first 2 dimensions
        fig = plt.figure(figsize=(8, 6))
        plt.scatter(features[:, 0], features[:, 1], c=labels, marker='o',
                    edgecolor='none', cmap=plt.cm.get_cmap('jet', 10), s=10)
        plt.colorbar()
        if save:
            if suffix is not None:
                fig.savefig('results/latent_space_' + str(suffix) + '.png')
            else:
                fig.savefig('latent_space.png')
        return fig

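    # Note: plot_latent_space() above scatters only the first two latent
    # dimensions. When gaussian_size > 2 these coordinates are an arbitrary
    # slice of the latent space, so a 2-D projection (e.g. PCA or t-SNE run
    # over the output of latent_features) may separate the clusters more
    # faithfully.
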
    def random_generation(self, num_elements=1):
        """Generate random samples for each category.

        Args:
            num_elements: (int) number of elements to generate per category

        Returns:
            generated data, num_elements samples for each of the num_classes categories
        """
        # build num_elements category indices per class: [0,...,0, 1,...,1, ...]
        arr = np.array([])
        for i in range(self.num_classes):
            arr = np.hstack([arr, np.ones(num_elements) * i])
        indices = arr.astype(int).tolist()

        categorical = F.one_hot(torch.tensor(indices), self.num_classes).float()
        if self.cuda:
            categorical = categorical.cuda()

        # infer the gaussian distribution according to the category
        mean, var = self.network.generative.pzy(categorical)

        # gaussian random sample by using the mean and variance (reparameterization)
        noise = torch.randn_like(var)
        std = torch.sqrt(var)
        gaussian = mean + noise * std

        # generate new samples with the given gaussian
        generated = self.network.generative.pxz(gaussian)
        return generated.cpu().detach().numpy()
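

# Minimal usage sketch (an illustration, not part of the original training
# script): build the argument namespace that __init__ reads, instantiate the
# model, and draw unconditional samples per cluster via random_generation.
# All hyperparameter values below, including rec_type='bce', are assumptions
# chosen for illustration; real runs would parse them with argparse.
if __name__ == '__main__':
    from argparse import Namespace

    args = Namespace(
        epochs=10, cuda=0, verbose=0, batch_size=64, batch_size_val=128,
        learning_rate=1e-3, decay_epoch=-1, lr_decay=0.5,
        w_categ=1.0, w_gauss=1.0, w_rec=1.0, rec_type='bce',
        num_classes=10, gaussian_size=64, input_size=784,
        init_temp=1.0, decay_temp=1, hard_gumbel=0,
        min_temp=0.5, decay_temp_rate=0.013862944,
    )
    gmvae = GMVAE(args)

    # draw 5 samples from each of the 10 mixture components
    samples = gmvae.random_generation(num_elements=5)
    print(samples.shape)  # expected: (num_classes * 5, input_size)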