VEM.txt
# VEM negative-bank fragment from a SimCSE-style contrastive forward pass.
# Assumed from the enclosing scope: z1, z2 (the two views' embeddings, N x K),
# cos_sim (N x N in-batch similarities), labels (diagonal targets), and cls
# (the model object, carrying cls.model_args.temp).
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

q = F.normalize(z1, dim=-1)  # N x K
k = F.normalize(z2, dim=-1)  # N x K
N = q.size(0)
M = 4096  # number of virtual negatives in the bank
use_M = True
if not hasattr(cls, "B"):
    # Lazily create the bank B (M x K). Note F.normalize leaves the all-zero
    # init at zero (its norm is eps-clamped); the commented randn variant
    # starts the bank on the unit sphere instead.
    cls.B = torch.zeros(M, q.size(1), device=q.device, requires_grad=False)  # M x K
    # cls.B = torch.randn(M, q.size(1), device=q.device, requires_grad=False)  # M x K
    cls.B = F.normalize(cls.B, dim=-1)
else:
    M = cls.B.size(0)
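# Note: the bank lives as a plain attribute on the model object rather than a
# registered buffer, so it persists across batches within a run but is not
# part of the state_dict and never receives gradients.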
if use_M:
    alpha = 1.0  # base step size
    eta = 10  # number of Langevin sampling steps
    with torch.no_grad():
        for i in range(eta):
            # Bank-to-query similarities.
            L = cls.B @ q.transpose(0, 1)  # M x N
            L_norm = (L / cls.model_args.temp).softmax(dim=-1)  # M x N
            # Alternative (disabled): Gumbel-softmax responsibilities
            # (a tau=5.0 variant was also tried).
            # L_norm = F.gumbel_softmax(L, tau=cls.model_args.temp, dim=-1, hard=False)  # M x N
            # Ascent direction: pull each bank entry toward the queries it
            # currently matches, minus a radial term along the entry itself.
            delta_B = L_norm @ q / N - (L_norm * L).mean(dim=1, keepdim=True) * cls.B  # M x K
            delta_B1 = delta_B
            # Precondition the gradient through B B^T (plus B itself).
            delta_B = cls.B @ cls.B.transpose(0, 1) @ delta_B / M + cls.B
            # (Disabled variant: a symmetric update driven by k, mirroring the
            # lines above with L_k = cls.B @ k.transpose(0, 1) and added with
            # 0.25 weights per branch.)
            # Langevin step with step size decaying as alpha / (i + 1): half
            # weight on each gradient variant plus Gaussian noise.
            Q = torch.randn(M, q.size(1), device=q.device, requires_grad=False)
            B = (cls.B
                 + alpha * 0.5 / (i + 1) * delta_B
                 + alpha * 0.5 / (i + 1) * delta_B1
                 + math.sqrt(2 * alpha / (i + 1)) * Q)
            # Alternative (disabled): softmax-weight the two gradient terms
            # and the noise term element-wise instead of fixed 0.5 weights:
            # C = alpha * 0.5 / (i + 1) * delta_B
            # D = alpha * 0.5 / (i + 1) * delta_B1
            # E = math.sqrt(2 * alpha / (i + 1)) * Q
            # G = torch.cat([C.unsqueeze(dim=-1), D.unsqueeze(dim=-1), E.unsqueeze(dim=-1)], dim=-1)
            # W = G.softmax(dim=-1)
            # G = (G * W).sum(dim=-1)
            # B = cls.B + G
            cls.B = F.normalize(B, dim=-1)  # project back onto the unit sphere
logit_neg = q @ cls.B.transpose(0, 1) / cls.model_args.temp  # N x M

loss_fct = nn.CrossEntropyLoss()
if use_M:
    # Append the bank similarities as M extra negative columns; the diagonal
    # targets in `labels` are unchanged by appending columns on the right.
    cos_sim = torch.cat([cos_sim, logit_neg], dim=-1)  # N x (N + M)
    loss = loss_fct(cos_sim, labels)
else:
    loss = loss_fct(cos_sim, labels)
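
# ---------------------------------------------------------------------------
# Minimal self-contained sketch of the same update, for reference only. All
# names below (the demo sizes, temp=0.05, the random inputs) are hypothetical
# scaffolding, not part of the original model code; the sketch just traces
# the shapes and the loss assembly end to end under those assumptions.
# ---------------------------------------------------------------------------
def _vem_demo(N=8, K=32, M=64, temp=0.05, alpha=1.0, eta=10):
    torch.manual_seed(0)
    q = F.normalize(torch.randn(N, K), dim=-1)
    k = F.normalize(torch.randn(N, K), dim=-1)
    cos_sim = q @ k.t() / temp  # N x N in-batch similarities
    labels = torch.arange(N)    # positives on the diagonal

    B = F.normalize(torch.randn(M, K), dim=-1)  # random unit-norm bank
    with torch.no_grad():
        for i in range(eta):
            L = B @ q.t()                         # M x N
            P = (L / temp).softmax(dim=-1)        # M x N
            g = P @ q / N - (P * L).mean(dim=1, keepdim=True) * B
            g_pre = B @ B.t() @ g / M + B
            step = alpha / (i + 1)
            B = B + 0.5 * step * (g_pre + g) \
                + math.sqrt(2 * step) * torch.randn_like(B)
            B = F.normalize(B, dim=-1)

    logit_neg = q @ B.t() / temp                      # N x M
    logits = torch.cat([cos_sim, logit_neg], dim=-1)  # N x (N + M)
    return nn.CrossEntropyLoss()(logits, labels)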