# she_postprocessor.py
from typing import Any
from copy import deepcopy

import torch
import torch.nn as nn
from tqdm import tqdm

from .base_postprocessor import BasePostprocessor
from .info import get_num_classes


def distance(penultimate, target, metric='inner_product'):
    """Score the similarity between penultimate features and stored patterns.

    Higher scores mean more similar under every metric, which is why the
    euclidean distance is negated.
    """
    if metric == 'inner_product':
        return torch.sum(torch.mul(penultimate, target), dim=1)
    elif metric == 'euclidean':
        return -torch.sqrt(torch.sum((penultimate - target)**2, dim=1))
    elif metric == 'cosine':
        return torch.cosine_similarity(penultimate, target, dim=1)
    else:
        raise ValueError('Unknown metric: {}'.format(metric))
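
# A minimal sketch of how `distance` behaves, assuming 2-D toy tensors; the
# tensors and expected values below are illustrative only, not from the
# original repository:
#
#   penultimate = torch.tensor([[1.0, 0.0], [0.0, 2.0]])
#   target = torch.tensor([[1.0, 0.0], [0.0, 1.0]])
#   distance(penultimate, target, 'inner_product')  # tensor([1., 2.])
#   distance(penultimate, target, 'euclidean')      # tensor([-0., -1.])
#   distance(penultimate, target, 'cosine')         # tensor([1., 1.])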

class SHEPostprocessor(BasePostprocessor):
    def __init__(self, config):
        super(SHEPostprocessor, self).__init__(config)
        self.args = self.config.postprocessor.postprocessor_args
        self.num_classes = get_num_classes(self.config.dataset.name)
        self.activation_log = None
        self.setup_flag = False
        self.has_data_based_setup = True

    def setup(self, net: nn.Module, id_loader_dict, ood_loader_dict,
              id_loader_split='train'):
        print(f'Setup on ID data - {id_loader_split} split')
        if not self.setup_flag:
            net.eval()
            all_activation_log = []
            all_labels = []
            all_preds = []
            # Collect features, labels, and predictions over the ID split.
            with torch.no_grad():
                for batch in tqdm(id_loader_dict[id_loader_split],
                                  desc='Eval: ',
                                  position=0,
                                  leave=True):
                    data = batch['data'].cuda()
                    labels = batch['label']
                    all_labels.append(deepcopy(labels))
                    logits, features = net(data, return_feature=True)
                    all_activation_log.append(features.cpu())
                    all_preds.append(logits.argmax(1).cpu())

            all_preds = torch.cat(all_preds)
            all_labels = torch.cat(all_labels)
            all_activation_log = torch.cat(all_activation_log)

            # Build one stored pattern (prototype) per class: the mean feature
            # of the correctly predicted ID samples of that class.
            self.activation_log = []
            for i in range(self.num_classes):
                mask = torch.logical_and(all_labels == i, all_preds == i)
                if not mask.any():
                    print(f'WARNING - no correctly predicted ID samples for '
                          f'class {i}; falling back to the union of samples '
                          f'labeled or predicted as class {i}')
                    mask = torch.logical_or(all_labels == i, all_preds == i)
                class_correct_activations = all_activation_log[mask]
                self.activation_log.append(
                    class_correct_activations.mean(0, keepdim=True))
            self.activation_log = torch.cat(self.activation_log).cuda()
            self.setup_flag = True
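    # A toy walk-through of the prototype computation above, assuming a
    # 3-class problem with 2-D features; all numbers are illustrative only.
    # For each class i, only samples both labeled i and predicted i contribute:
    #
    #   all_labels = tensor([0, 0, 1, 2])
    #   all_preds  = tensor([0, 1, 1, 2])
    #   features   = tensor([[2., 0.], [4., 0.], [0., 6.], [1., 1.]])
    #   class 0 mask -> [True, False, False, False] -> prototype [2., 0.]
    #   class 1 mask -> [False, False, True, False] -> prototype [0., 6.]
    #   class 2 mask -> [False, False, False, True] -> prototype [1., 1.]
    #   activation_log -> shape [3, 2], one stored pattern per class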

    @torch.no_grad()
    def postprocess(self, net: nn.Module, data: Any):
        output, feature = net(data, return_feature=True)
        pred = output.argmax(1)
        # Score each sample against the stored pattern of its predicted class.
        conf = distance(feature, self.activation_log[pred], self.args.metric)
        return pred, conf
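
# A minimal usage sketch, assuming an OpenOOD-style config object and a
# network whose forward supports `return_feature=True` (both assumptions;
# `config`, `net`, and the loader dicts are placeholders, not part of this
# file):
#
#   postprocessor = SHEPostprocessor(config)
#   postprocessor.setup(net, id_loader_dict, ood_loader_dict,
#                       id_loader_split='train')  # builds per-class prototypes
#   pred, conf = postprocessor.postprocess(net, batch['data'].cuda())
#
# `conf` is the similarity of each sample's feature to its predicted class
# prototype; larger values suggest more in-distribution samples.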