-
Notifications
You must be signed in to change notification settings - Fork 20
/
Copy pathNpair_loss.py
93 lines (72 loc) · 3.3 KB
/
Npair_loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import tensorflow as tf
from tensorflow.python.platform import test
from tensorflow.contrib.losses.python.metric_learning import metric_loss_ops
from tensorflow.python.framework import ops
from tensorflow.python.ops import math_ops
def cross_entropy(logits, target, size_average=True):
if size_average:
return torch.mean(torch.sum(- target * F.log_softmax(logits, -1), -1))
else:
return torch.sum(torch.sum(- target * F.log_softmax(logits, -1), -1))
class NpairLoss(nn.Module):
"""the multi-class n-pair loss"""
def __init__(self, l2_reg=0.02):
super(NpairLoss, self).__init__()
self.l2_reg = l2_reg
def forward(self, anchor, positive, target):
batch_size = anchor.size(0)
target = target.view(target.size(0), 1)
target = (target == torch.transpose(target, 0, 1)).float()
target = target / torch.sum(target, dim=1, keepdim=True).float()
logit = torch.matmul(anchor, torch.transpose(positive, 0, 1))
loss_ce = cross_entropy(logit, target)
l2_loss = torch.sum(anchor**2) / batch_size + torch.sum(positive**2) / batch_size
loss = loss_ce + self.l2_reg*l2_loss*0.25
return loss
class NpairsLossTest(test.TestCase):
def testNpairs(self):
with self.test_session():
num_data = 16
feat_dim = 5
num_classes = 3
reg_lambda = 0.02
embeddings_anchor = np.random.rand(num_data, feat_dim).astype(np.float32)
embeddings_positive = np.random.rand(num_data, feat_dim).astype(np.float32)
labels = np.random.randint(0, num_classes, size=(num_data)).astype(np.float32)
# Reshape labels to compute adjacency matrix.
labels_reshaped = np.reshape(labels, (labels.shape[0], 1))
# Compute the loss in NP
reg_term = np.mean(np.sum(np.square(embeddings_anchor), 1))
reg_term += np.mean(np.sum(np.square(embeddings_positive), 1))
reg_term *= 0.25 * reg_lambda
similarity_matrix = np.matmul(embeddings_anchor, embeddings_positive.T)
labels_remapped = np.equal(labels_reshaped, labels_reshaped.T).astype(np.float32)
labels_remapped /= np.sum(labels_remapped, axis=1, keepdims=True)
xent_loss = math_ops.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(
logits=ops.convert_to_tensor(similarity_matrix),
labels=ops.convert_to_tensor(labels_remapped))).eval()
loss_np = xent_loss + reg_term
# Compute the loss in pytorch
npairloss = NpairLoss()
loss_tc = npairloss(
anchor=torch.tensor(embeddings_anchor),
positive=torch.tensor(embeddings_positive),
target=torch.from_numpy(labels)
)
# Compute the loss in TF
loss_tf = metric_loss_ops.npairs_loss(
labels=ops.convert_to_tensor(labels),
embeddings_anchor=ops.convert_to_tensor(embeddings_anchor),
embeddings_positive=ops.convert_to_tensor(embeddings_positive),
reg_lambda=reg_lambda)
loss_tf = loss_tf.eval()
print('pytorch version: ', loss_tc.numpy())
print('numpy version: ',loss_np)
print('tensorflow version: ',loss_tf)
# self.assertAllClose(loss_np, loss_tf)
if __name__ == '__main__':
NpairsLossTest().testNpairs()