-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathbase_postprocessor.py
46 lines (36 loc) · 1.4 KB
/
base_postprocessor.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
from typing import Any
from tqdm import tqdm
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import openood.utils.comm as comm
class BasePostprocessor:
    """Default post-processor: scores each sample by its maximum softmax probability.

    Subclasses override ``setup`` and/or ``postprocess`` to implement other
    scoring rules, while reusing ``inference`` to run a whole data loader.
    """

    def __init__(self, config):
        # NOTE(review): the structure of ``config`` is not used here; presumably
        # subclasses read their hyper-parameters from it — confirm with callers.
        self.config = config

    def setup(self, net: nn.Module, id_loader_dict, ood_loader_dict, id_loader_split=None):
        """Hook for subclasses that need to fit statistics before scoring.

        The base implementation does nothing.
        """
        pass

    @torch.no_grad()
    def postprocess(self, net: nn.Module, data: Any):
        """Run ``net`` on ``data`` and return ``(pred, conf)``.

        ``pred`` is the argmax class index per sample and ``conf`` is the
        corresponding maximum softmax probability.
        Assumes ``net(data)`` returns logits shaped (batch, num_classes) —
        softmax and max are both taken over dim 1.
        """
        output = net(data)
        score = torch.softmax(output, dim=1)
        conf, pred = torch.max(score, dim=1)
        return pred, conf

    @staticmethod
    def _infer_device(net) -> torch.device:
        """Return the device of ``net``'s first parameter or buffer.

        Falls back to CUDA when available (the original hard-coded behaviour),
        otherwise CPU, if ``net`` exposes no tensors.
        """
        for accessor in ('parameters', 'buffers'):
            getter = getattr(net, accessor, None)
            if callable(getter):
                first = next(iter(getter()), None)
                if first is not None:
                    return first.device
        return torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    @staticmethod
    def _concat_to_numpy(chunks, dtype=None):
        """Concatenate 1-D CPU tensor chunks into a numpy array.

        Returns an empty array when ``chunks`` is empty; ``torch.cat`` itself
        rejects an empty list. ``dtype``, if given, is applied via ``astype``.
        """
        if chunks:
            array = torch.cat(chunks).numpy()
        else:
            array = torch.empty(0).numpy()
        return array.astype(dtype) if dtype is not None else array

    def inference(self,
                  net: nn.Module,
                  data_loader: DataLoader,
                  progress: bool = True):
        """Score every batch of ``data_loader`` with ``self.postprocess``.

        Each batch must be a mapping with ``'data'`` and ``'label'`` entries.
        Inputs are moved to the device ``net`` lives on.

        Returns:
            ``(pred, conf, label)`` numpy arrays concatenated over all batches;
            ``pred`` and ``label`` are cast to ``int``. All three are empty if
            the loader yields no batches.
        """
        device = self._infer_device(net)
        pred_list, conf_list, label_list = [], [], []
        # Show the progress bar only on the main process to avoid duplicated bars.
        for batch in tqdm(data_loader,
                          disable=not progress or not comm.is_main_process()):
            data = batch['data'].to(device)
            pred, conf = self.postprocess(net, data)
            # detach() so subclasses whose postprocess tracks gradients still
            # convert cleanly to numpy below.
            pred_list.append(pred.detach().cpu())
            conf_list.append(conf.detach().cpu())
            # Labels are never used on the device, so skip the GPU round trip.
            label_list.append(batch['label'].cpu())

        pred_arr = self._concat_to_numpy(pred_list, int)
        conf_arr = self._concat_to_numpy(conf_list)
        label_arr = self._concat_to_numpy(label_list, int)
        return pred_arr, conf_arr, label_arr