-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathkl_matching_postprocessor.py
69 lines (58 loc) · 2.38 KB
/
kl_matching_postprocessor.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
from typing import Any
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.metrics import pairwise_distances_argmin_min
import scipy
from tqdm import tqdm
from .base_postprocessor import BasePostprocessor
from .info import get_num_classes
class KLMatchingPostprocessor(BasePostprocessor):
    """OOD scorer that matches softmax outputs against per-class templates.

    ``setup`` runs the network over an in-distribution split and stores, for
    every class ``i``, the mean softmax vector of the samples the network
    predicts as ``i`` (a one-hot vector for ``i`` when no sample is predicted
    as ``i``). ``postprocess`` scores each input by the negated smallest KL
    divergence between its softmax vector and any of those templates, so an
    input that closely matches some class template gets a higher score.
    """

    def __init__(self, config):
        super().__init__(config)
        self.num_classes = get_num_classes(self.config.dataset.name)
        self.setup_flag = False
        self.has_data_based_setup = True

    def kl(self, p, q):
        """Return ``KL(p || q)`` as computed by ``scipy.stats.entropy``.

        scipy normalizes both inputs to sum to 1 before taking the
        divergence. Kept for API compatibility; ``postprocess`` uses a
        vectorized equivalent instead of calling this per pair.
        """
        # Import the submodule explicitly: a bare ``import scipy`` does not
        # load ``scipy.stats`` on older SciPy releases.
        from scipy.stats import entropy
        return entropy(p, q)

    def setup(self, net: nn.Module, id_loader_dict, ood_loader_dict,
              id_loader_split="val"):
        """Build the per-class mean-softmax templates from ID data (once).

        Args:
            net: classifier whose outputs are treated as logits.
            id_loader_dict: mapping from split name to an iterable of batches,
                each indexable with ``'data'``.
            ood_loader_dict: not used by this postprocessor.
            id_loader_split: key of ``id_loader_dict`` to extract from.

        Sets ``self.mean_softmax_val`` (a list of ``num_classes`` numpy
        vectors) and ``self.setup_flag``. Subsequent calls only print.
        """
        print(f"Setup on ID data - {id_loader_split} split")
        if self.setup_flag:
            return

        net.eval()
        print('Extracting id validation softmax posterior distributions')
        softmax_chunks = []
        pred_chunks = []
        with torch.no_grad():
            for batch in tqdm(id_loader_dict[id_loader_split],
                              desc='Setup: ',
                              position=0,
                              leave=True):
                logits = net(batch['data'].cuda())
                softmax_chunks.append(F.softmax(logits, 1).cpu())
                pred_chunks.append(logits.argmax(1).cpu())
        all_softmax = torch.cat(softmax_chunks)
        preds = torch.cat(pred_chunks)

        self.mean_softmax_val = []
        for i in tqdm(range(self.num_classes)):
            is_class_i = preds.eq(i)
            if is_class_i.any():
                self.mean_softmax_val.append(
                    all_softmax[is_class_i].mean(0).numpy())
            else:
                # No ID sample was predicted as class i, so there is nothing
                # to average: fall back to a one-hot template for class i.
                template = np.zeros((self.num_classes, ))
                template[i] = 1
                self.mean_softmax_val.append(template)
        self.setup_flag = True

    @torch.no_grad()
    def postprocess(self, net: nn.Module, data: Any):
        """Return ``(preds, scores)`` for one batch.

        ``preds`` is ``net(data).argmax(1)``. ``scores`` is a float64 tensor
        holding, per sample, ``-min_k KL(softmax || template_k)`` over the
        templates built by ``setup`` (which must have run first).
        """
        logits = net(data)
        preds = logits.argmax(1)
        softmax = F.softmax(logits, 1).cpu().numpy()
        scores = -self._min_kl_to_templates(softmax)
        return preds, torch.from_numpy(scores)

    def _min_kl_to_templates(self, softmax):
        """Return, per row of ``softmax``, its minimum KL to any template.

        Vectorized equivalent of ``min_k self.kl(row, template_k)``: both
        sides are normalized to sum to 1 and elementwise ``rel_entr`` terms
        are summed, which is the formula ``scipy.stats.entropy`` uses. It
        loops over templates (not over sample/template pairs), so the inner
        work stays in NumPy. The result is a float64 array of length
        ``softmax.shape[0]``.
        """
        from scipy.special import rel_entr

        probs = np.asarray(softmax, dtype=np.float64)
        templates = np.asarray(self.mean_softmax_val, dtype=np.float64)
        probs = probs / probs.sum(axis=1, keepdims=True)
        templates = templates / templates.sum(axis=1, keepdims=True)

        divergences = np.empty((probs.shape[0], templates.shape[0]))
        for k, template in enumerate(templates):
            # rel_entr gives 0 where p == 0 and inf where p > 0 but q == 0,
            # matching scipy.stats.entropy on the one-hot fallback templates.
            divergences[:, k] = rel_entr(probs, template).sum(axis=1)
        return divergences.min(axis=1)