forked from TinyZeaMays/CircleLoss
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmnist_example.py
81 lines (67 loc) · 2.5 KB
/
mnist_example.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import os
import torch
from torch import nn, Tensor
from torch.optim import SGD
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
from tqdm import tqdm
from circle_loss import convert_label_to_similarity, CircleLoss
def get_loader(is_train: bool, batch_size: int, *, root: str = "./data") -> DataLoader:
    """Build a DataLoader over the MNIST training or test split.

    Args:
        is_train: When True, load the training split and shuffle it every
            epoch; otherwise load the test split in dataset order.
        batch_size: Number of samples per batch.
        root: Directory holding the MNIST files. ``download=True`` is passed,
            so torchvision fetches the dataset into it when missing.

    Returns:
        A DataLoader yielding ``(images, labels)`` batches, with images
        converted to tensors by ``ToTensor``.
    """
    dataset = MNIST(root=root, train=is_train, transform=ToTensor(), download=True)
    # Shuffle only for training; evaluation order stays deterministic.
    return DataLoader(dataset=dataset, batch_size=batch_size, shuffle=is_train)
class Model(nn.Module):
    """Small CNN that embeds single-channel images as L2-normalized 32-d vectors.

    The network has three stages of Conv2d -> MaxPool2d(2) -> ReLU, with
    channels 1 -> 8 -> 16 -> 32. The final feature map is averaged over its
    spatial dimensions and then normalized.
    """

    def __init__(self) -> None:
        super().__init__()
        layers = []
        # (in_channels, out_channels, kernel_size) for each conv stage.
        for c_in, c_out, k in ((1, 8, 5), (8, 16, 5), (16, 32, 3)):
            layers += [
                nn.Conv2d(in_channels=c_in, out_channels=c_out, kernel_size=k),
                nn.MaxPool2d(kernel_size=2),
                nn.ReLU(),
            ]
        # A single flat Sequential keeps the state_dict keys
        # (feature_extractor.0 / .3 / .6) compatible with saved checkpoints.
        self.feature_extractor = nn.Sequential(*layers)

    def forward(self, inp: Tensor) -> Tensor:
        """Embed a batch of images.

        Args:
            inp: Image batch of shape (N, 1, H, W).

        Returns:
            Tensor of shape (N, 32) whose rows have unit L2 norm.
        """
        feature_map = self.feature_extractor(inp)
        # Global average over the two spatial axes (H, W).
        pooled = feature_map.mean(dim=[2, 3])
        return nn.functional.normalize(pooled)
def _train(model: nn.Module, optimizer: torch.optim.Optimizer, criterion, loader: DataLoader, epochs: int) -> None:
    """Train ``model`` in place with ``criterion`` for ``epochs`` passes over ``loader``.

    ``criterion`` is called on the unpacked output of
    ``convert_label_to_similarity(pred, label)`` and must return a scalar loss.
    """
    model.train()
    for _ in range(epochs):
        for img, label in tqdm(loader):
            optimizer.zero_grad()
            pred = model(img)
            loss = criterion(*convert_label_to_similarity(pred, label))
            loss.backward()
            optimizer.step()


def _safe_ratio(numerator: int, denominator: int) -> float:
    """Return ``numerator / denominator``, or 0.0 when the denominator is zero."""
    return numerator / denominator if denominator else 0.0


def _evaluate(model: nn.Module, loader: DataLoader, thresh: float) -> None:
    """Score each batch of ``loader`` as one image pair and print recall/precision.

    Only the first two samples of each batch are used, so ``loader`` is expected
    to use ``batch_size=2``. A pair is predicted "same class" when the dot
    product of the two embeddings exceeds ``thresh``. For the L2-normalized
    output of ``Model``, that dot product is their cosine similarity.
    An undefined ratio (zero denominator) is printed as 0.0000.
    """
    tp = fn = fp = 0
    model.eval()
    with torch.no_grad():  # inference only: don't build the autograd graph
        for img, label in loader:
            # A trailing single-sample batch (odd-sized split) cannot form a pair.
            if label.size(0) < 2:
                continue
            pred = model(img)
            gt_label = bool(label[0] == label[1])
            pred_label = bool(torch.sum(pred[0] * pred[1]) > thresh)
            if gt_label and pred_label:
                tp += 1
            elif gt_label:
                fn += 1
            elif pred_label:
                fp += 1
    print("Recall: {:.4f}".format(_safe_ratio(tp, tp + fn)))
    print("Precision: {:.4f}".format(_safe_ratio(tp, tp + fp)))


def main(resume: bool = True) -> None:
    """Train (or resume) the Circle Loss MNIST model, then report pair metrics.

    Args:
        resume: When True and ``resume.state`` exists, load the weights from it
            instead of training. Otherwise train for 20 epochs and save the
            weights to ``resume.state``.
    """
    model = Model()
    optimizer = SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-5)
    train_loader = get_loader(is_train=True, batch_size=64)
    # batch_size=2: each validation batch is scored as a single image pair.
    val_loader = get_loader(is_train=False, batch_size=2)
    criterion = CircleLoss(m=0.25, gamma=80)

    if resume and os.path.exists("resume.state"):
        # NOTE(security): torch.load unpickles the file; only resume from trusted checkpoints.
        model.load_state_dict(torch.load("resume.state"))
    else:
        _train(model, optimizer, criterion, train_loader, epochs=20)
        torch.save(model.state_dict(), "resume.state")

    _evaluate(model, val_loader, thresh=0.75)
if __name__ == "__main__":
    # Script entry point: resume from resume.state when it exists, otherwise train.
    main(resume=True)