-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathmi_INCE.py
executable file
·55 lines (43 loc) · 1.7 KB
/
mi_INCE.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import torch
import math
import torch.distributions as tdis
import matplotlib.pyplot as plt
import numpy as np
from torch import nn
from torch import optim
from torch.utils.data import TensorDataset, RandomSampler, BatchSampler, DataLoader
def InfoNCE(X, Y, batch_size=512, num_epochs=50, dev=torch.device("cpu"), model=None):
    """Estimate the mutual information I(X; Y) via the InfoNCE lower bound.

    Trains a small critic network T(x, y) so that
        I(X; Y) >= E[T(x_i, y_i)] - E[logsumexp_j T(x_i, y_j) - log N]
    over minibatches of size N, then plots the per-batch estimates and
    returns the mean of the last 50 of them as the final estimate.

    Args:
        X: 2-D float tensor of samples, shape (num_samples, x_dim).
        Y: 2-D float tensor of paired samples, shape (num_samples, y_dim).
        batch_size: minibatch size N used for the contrastive bound.
        num_epochs: number of passes over the data.
        dev: torch device to train on.
        model: optional critic taking cat([x, y], 1) -> scalar score per row;
            a small default MLP is built when None.

    Returns:
        0-dim tensor: mean of the last 50 per-batch InfoNCE estimates.

    Raises:
        ValueError: if the dataset has fewer than ``batch_size`` samples
            (``drop_last=True`` would otherwise produce no batches and a
            silent NaN result).

    Side effects: draws a matplotlib plot and prints the estimate.
    """
    # Plain Python float for the log(N) correction: the original built a CPU
    # tensor here, which fails to broadcast against CUDA tensors when
    # dev is a GPU device.
    log_N = math.log(batch_size)
    if model is None:
        model = nn.Sequential(
            nn.Linear(X.shape[1] + Y.shape[1], 16),
            nn.ReLU(),
            nn.Linear(16, 8),
            nn.ReLU(),
            nn.Linear(8, 1),
        )
    # Move data to device
    X = X.to(dev)
    # Tiny Gaussian jitter breaks exact ties among duplicated Y rows.
    Y = Y.to(dev) + torch.randn_like(Y) * 1e-4
    model = model.to(dev)
    opt = optim.Adam(model.parameters(), lr=0.01)
    td = TensorDataset(X, Y)
    if len(td) < batch_size:
        raise ValueError(
            f"need at least batch_size={batch_size} samples, got {len(td)}"
        )
    result = []
    for epoch in range(num_epochs):
        for x, y in DataLoader(td, batch_size, shuffle=True, drop_last=True):
            opt.zero_grad()
            # Critic scores of the N joint (correctly paired) samples.
            top = model(torch.cat([x, y], 1)).flatten()
            # All N*N cross pairs (x_i, y_j); row i of the reshaped score
            # matrix holds x_i against every y_j.
            xiyj = torch.cat(
                [x.repeat_interleave(batch_size, dim=0), y.repeat(batch_size, 1)], 1
            )
            bottom = torch.logsumexp(model(xiyj).reshape(batch_size, batch_size), 1) - log_N
            loss = -(top - bottom).mean()
            result.append(-loss.item())
            # A fresh graph is built each batch, so retain_graph=True (as in
            # the original) only wasted memory.
            loss.backward()
            opt.step()
    # Final estimate: average the last 50 per-batch bounds to smooth noise.
    r = torch.mean(torch.tensor(result[-50:]))
    plt.plot(result, label="Ince")
    plt.title('Ince')
    plt.xlabel('Number of Epochs')
    plt.ylabel('Mutual Information')
    plt.legend(loc='lower right')
    print(r)
    return r