# rmds_postprocessor.py
from copy import deepcopy
from typing import Any

import numpy as np
import torch
import torch.nn as nn
import sklearn.covariance
from tqdm import tqdm

from .base_postprocessor import BasePostprocessor
from .info import get_num_classes
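

# Relative Mahalanobis Distance Score (RMDS). A hedged summary, following
# Ren et al., "A Simple Fix to Mahalanobis Distance for Improving Near-OOD
# Detection" (arXiv:2106.09022), which this postprocessor appears to follow:
#
#   MD_c(z)  = (z - mu_c)^T Sigma^-1 (z - mu_c)      class-conditional distance
#   MD_0(z)  = (z - mu_0)^T Sigma_0^-1 (z - mu_0)    "background" distance, fit
#                                                    on all features jointly
#   RMD_c(z) = MD_c(z) - MD_0(z)
#   conf(x)  = max_c [ -RMD_c(z) ]
#
# Subtracting the background distance discounts directions in feature space
# that are equally far from every class, which helps on near-OOD data.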
class RMDSPostprocessor(BasePostprocessor):
    def __init__(self, config):
        self.config = config
        self.num_classes = get_num_classes(self.config.dataset.name)
        self.setup_flag = False
        self.has_data_based_setup = True

    def setup(self, net: nn.Module, id_loader_dict, ood_loader_dict,
              id_loader_split="train"):
        print(f"Setup on ID data - {id_loader_split} split")
        if not self.setup_flag:
            # estimate class means and covariance from the training set
            print('\n Estimating mean and covariance from training set...')
            all_feats = []
            all_labels = []
            all_preds = []
            with torch.no_grad():
                for batch in tqdm(id_loader_dict[id_loader_split],
                                  desc='Setup: ',
                                  position=0,
                                  leave=True):
                    data, labels = batch['data'].cuda(), batch['label']
                    logits, features = net(data, return_feature=True)
                    all_feats.append(features.cpu())
                    all_labels.append(deepcopy(labels))
                    all_preds.append(logits.argmax(1).cpu())

            all_feats = torch.cat(all_feats)
            all_labels = torch.cat(all_labels)
            all_preds = torch.cat(all_preds)

            # sanity check on train acc
            train_acc = all_preds.eq(all_labels).float().mean()
            print(f' Train acc: {train_acc:.2%}')
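
            # The estimation below is the tied-covariance setup popularized by
            # the Mahalanobis detector (Lee et al., 2018): one mean per class,
            # plus a single covariance fit on class-centered features pooled
            # across all classes.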
            # compute class-conditional statistics
            self.class_mean = []
            centered_data = []
            for c in range(self.num_classes):
                class_samples = all_feats[all_labels.eq(c)].data
                self.class_mean.append(class_samples.mean(0))
                centered_data.append(class_samples -
                                     self.class_mean[c].view(1, -1))

            self.class_mean = torch.stack(
                self.class_mean)  # shape [#classes, feature dim]

            # shared (tied) covariance over class-centered features
            group_lasso = sklearn.covariance.EmpiricalCovariance(
                assume_centered=False)
            group_lasso.fit(
                torch.cat(centered_data).cpu().numpy().astype(np.float32))
            # inverse of covariance
            self.precision = torch.from_numpy(group_lasso.precision_).float()

            # background Gaussian: mean and covariance of all features,
            # ignoring class labels
            self.whole_mean = all_feats.mean(0)
            centered_data = all_feats - self.whole_mean.view(1, -1)
            group_lasso = sklearn.covariance.EmpiricalCovariance(
                assume_centered=False)
            group_lasso.fit(centered_data.cpu().numpy().astype(np.float32))
            self.whole_precision = torch.from_numpy(
                group_lasso.precision_).float()

            self.setup_flag = True
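
    # Inference-side note (my reading of the code below): background_scores
    # holds -MD_0(z); each column of class_scores first holds -MD_c(z);
    # subtracting background_scores therefore gives -(MD_c - MD_0) = -RMD_c,
    # and the confidence is its maximum over classes.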
    @torch.no_grad()
    def postprocess(self, net: nn.Module, data: Any):
        logits, features = net(data, return_feature=True)
        pred = logits.argmax(1)

        # negative Mahalanobis distance to the background Gaussian
        tensor1 = features.cpu() - self.whole_mean.view(1, -1)
        background_scores = -torch.matmul(
            torch.matmul(tensor1, self.whole_precision), tensor1.t()).diag()

        class_scores = torch.zeros((logits.shape[0], self.num_classes))
        for c in range(self.num_classes):
            # negative Mahalanobis distance to class c, minus background
            tensor = features.cpu() - self.class_mean[c].view(1, -1)
            class_scores[:, c] = -torch.matmul(
                torch.matmul(tensor, self.precision), tensor.t()).diag()
            class_scores[:, c] = class_scores[:, c] - background_scores

        conf = torch.max(class_scores, dim=1)[0]
        return pred, conf
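

if __name__ == '__main__':
    # Minimal smoke test -- a hedged sketch, not part of the original module.
    # Assumptions: a CUDA device is available (setup() hardcodes .cuda()), and
    # get_num_classes() maps the dataset name used below ('cifar10' is a
    # guess) to 10. _ToyNet and the fake loader are hypothetical stand-ins
    # for an OpenOOD-style network and dataloader.
    from types import SimpleNamespace

    class _ToyNet(nn.Module):
        """Stand-in net returning (logits, features) like the real ones."""

        def __init__(self, in_dim=16, feat_dim=8, num_classes=10):
            super().__init__()
            self.backbone = nn.Linear(in_dim, feat_dim)
            self.head = nn.Linear(feat_dim, num_classes)

        def forward(self, x, return_feature=False):
            feats = self.backbone(x)
            logits = self.head(feats)
            return (logits, feats) if return_feature else logits

    config = SimpleNamespace(dataset=SimpleNamespace(name='cifar10'))
    postprocessor = RMDSPostprocessor(config)

    net = _ToyNet(num_classes=postprocessor.num_classes).cuda().eval()
    # fake 'train' split: 10 batches of 32 random samples with random labels
    loader = [{'data': torch.randn(32, 16),
               'label': torch.randint(0, postprocessor.num_classes, (32,))}
              for _ in range(10)]
    postprocessor.setup(net, {'train': loader}, {})

    pred, conf = postprocessor.postprocess(net, torch.randn(4, 16).cuda())
    print(pred.shape, conf.shape)  # expect: torch.Size([4]) torch.Size([4])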