forked from Lin-Yijie/Graph-Matching-Networks
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathevaluation.py
59 lines (43 loc) · 1.76 KB
/
evaluation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
from sklearn import metrics
from loss import *
def exact_hamming_similarity(x, y):
    """Compute the binary Hamming similarity between paired vectors.

    Every entry is binarized with the test ``value > 0``; the similarity of a
    pair is the fraction of positions at which the two binary codes agree.

    Args:
      x: [n_examples, feature_dim] tensor.
      y: [n_examples, feature_dim] tensor.

    Returns:
      sim: [n_examples] float tensor, one value in [0, 1] per pair.
    """
    # Use equality, not a product: a product (logical AND) only counts
    # positions where both entries are positive and wrongly treats positions
    # where both are non-positive as mismatches.
    match = ((x > 0) == (y > 0)).float()
    return torch.mean(match, dim=1)
def compute_similarity(config, x, y):
    """Score how similar each (x, y) pair is, using the configured loss type.

    The measure is selected by ``config['training']['loss']``:
    ``'margin'`` gives the negated Euclidean distance and ``'hamming'`` gives
    the exact Hamming similarity, so larger values always mean "more similar".

    Args:
      config: a config dict; ``config['training']['loss']`` names the loss.
      x: [n_examples, feature_dim] float tensor.
      y: [n_examples, feature_dim] float tensor.

    Returns:
      sim: [n_examples] float tensor.

    Raises:
      ValueError: if loss type is not supported.
    """
    loss_type = config['training']['loss']

    if loss_type == 'margin':
        # A small distance means a high similarity, hence the sign flip.
        distance = euclidean_distance(x, y)
        return -distance

    if loss_type == 'hamming':
        return exact_hamming_similarity(x, y)

    raise ValueError('Unknown loss type %s' % loss_type)
def auc(scores, labels, **auc_args):
    """Compute the ROC AUC for pair classification.

    Scores are min-max rescaled, labels are mapped from {-1, +1} to {0, 1},
    and the ROC curve and its area are computed with ``sklearn.metrics``.

    Args:
      scores: [n_examples] float tensor. Higher scores mean higher preference
        of being assigned the label of +1.
      labels: [n_examples] int tensor. Labels are either +1 or -1.
      **auc_args: accepted for interface compatibility; this implementation
        does not use them.

    Returns:
      auc: the area under the ROC curve.
    """
    highest = torch.max(scores)
    lowest = torch.min(scores)

    # Min-max rescale; the small epsilon keeps the denominator non-zero when
    # every score is identical.
    span = highest - lowest + 1e-8
    rescaled = (scores - lowest) / span

    # Map -1 -> 0 and +1 -> 1.
    binary_labels = (labels + 1) / 2

    y_true = binary_labels.detach().cpu().numpy()
    y_score = rescaled.detach().cpu().numpy()
    fpr, tpr, _ = metrics.roc_curve(y_true, y_score)
    return metrics.auc(fpr, tpr)