-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtest.py
37 lines (30 loc) · 1.24 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import os
import torch as tr
from sklearn.metrics import accuracy_score
from dataset import ProtDataset, pad_batch
from torch.utils.data import DataLoader
from tlprotcnn import TLProtCNN
TEST_PATH = "data/Clustered_data/test/"
CACHE_PATH = "data/"
BATCH_SIZE = 32
DEVICE = "cuda"

# One category name per line. len(categories) sizes the model output layer
# below, so this file must match the one used at training time.
with open("data/categories.txt", encoding="utf-8") as cat_file:
    categories = [item.strip() for item in cat_file]

# Trained model weights: one "results_*" directory per training run.
# os.listdir returns entries in arbitrary order, so sort to make the
# k-ensemble sequence reproducible across runs and machines.
models = [
    f"{d}/weights.pk"
    for d in sorted(os.listdir("./"))
    if "results_" in d and os.path.isdir(d)
]
if not models:
    raise FileNotFoundError(
        "no 'results_*' directories found in the current directory; "
        "nothing to evaluate"
    )

test_data = ProtDataset(TEST_PATH, categories, CACHE_PATH)
# shuffle is left at DataLoader's default (False) so every model sees the
# samples in the same order and their predictions can be summed row by row.
test_loader = DataLoader(test_data, batch_size=BATCH_SIZE, collate_fn=pad_batch)

# Run predictions for each model and accumulate their per-class scores.
# The ensemble prediction is the argmax of the accumulated scores; summing
# gives the same argmax as averaging, so no division is needed.
pred_avg = tr.zeros((len(test_data), len(categories)))
for k, model in enumerate(models):
    print("load weights from", model)
    net = TLProtCNN(len(categories), device=DEVICE)
    # map_location lets weights saved on a different device load onto DEVICE.
    net.load_state_dict(tr.load(model, map_location=DEVICE))
    # NOTE(review): assumes `pred` is a CPU tensor broadcastable to
    # (len(test_data), len(categories)) and `ref` holds the true class
    # indices in loader order -- confirm against TLProtCNN.pred.
    _, test_errate, pred, ref, _ = net.pred(test_loader)
    pred_avg += pred

    msg = f"Model-{k+1:02d} error: {test_errate:.2f}"
    if k > 0:
        # k-ensemble score: error of the argmax over the summed scores of
        # the first k+1 models (only meaningful once there are 2+ models).
        pred_avg_bin = tr.argmax(pred_avg, dim=1)
        ensemble_errate = 1 - accuracy_score(ref, pred_avg_bin)
        msg += f", {k+1}-ensemble error: {ensemble_errate:.2f}"
    print(msg)