-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.py
51 lines (42 loc) · 1.55 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import torch
from torch.utils.tensorboard.writer import SummaryWriter
from torchvision.utils import make_grid
# File that save_checkpoint() writes to and load_checkpoint() reads from.
CHECKPOINT_PATH: str = "checkpoint.tar"
# Directory passed to SummaryWriter as log_dir by TBManager.
TENSORBOARD_LOG_DIR: str = "tensorboard_logs"


def accuracy(outs, truth, threshold=0.5):
    """Return the fraction of binarized predictions that match binarized targets.

    Args:
        outs: Tensor of model outputs. An element counts as a positive
            prediction when it is ``>= threshold``. Presumably probabilities
            (e.g. sigmoid outputs) -- confirm against the callers.
        truth: Tensor of ground-truth labels; elements ``>= 0.5`` count as
            positive. Must hold the same number of elements as ``outs``.
        threshold: Decision threshold applied to ``outs``. Defaults to 0.5,
            the value that used to be hard-coded.

    Returns:
        Accuracy as a Python float in ``[0, 1]``.

    Raises:
        ValueError: If ``truth`` is empty (accuracy is undefined).
        RuntimeError: From ``torch.reshape`` if ``outs`` and ``truth`` hold
            different numbers of elements.
    """
    targets = truth >= 0.5
    total = targets.numel()
    if total == 0:
        # Without this guard an empty batch surfaces as a bare ZeroDivisionError.
        raise ValueError("accuracy() is undefined for an empty target tensor")
    # Align shapes element-wise (e.g. (N, 1) outputs vs (N,) labels). Comparing
    # them directly would broadcast to an (N, N) matrix and give a wrong result.
    preds = torch.reshape(outs >= threshold, targets.shape)
    return preds.eq(targets).sum().item() / total


def save_checkpoint(model, optimizer, epoch, loss, path=None):
    """Save model/optimizer state and training progress with ``torch.save``.

    The file holds a dict with the keys ``"model_state_dict"``,
    ``"optimizer_state_dict"``, ``"epoch"`` and ``"loss"``. This is the
    layout ``load_checkpoint`` reads back.

    Args:
        model: Object exposing ``state_dict()`` (e.g. an ``nn.Module``).
        optimizer: Object exposing ``state_dict()`` (e.g. a ``torch.optim``
            optimizer).
        epoch: Epoch number, stored as-is.
        loss: Loss value, stored as-is.
        path: Destination file. Defaults to the module-level
            ``CHECKPOINT_PATH``; an existing file there is overwritten.
    """
    if path is None:
        # Resolved at call time, so the module-level default stays patchable.
        path = CHECKPOINT_PATH
    checkpoint = {
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "epoch": epoch,
        "loss": loss,
    }
    torch.save(checkpoint, path)


def load_checkpoint(model, path=None, map_location=None):
    """Restore ``model`` in place from a checkpoint written by ``save_checkpoint``.

    Only the model weights are applied here. The optimizer's saved state is
    returned to the caller, because no optimizer is passed in.

    Args:
        model: Module whose ``load_state_dict`` receives the saved weights.
        path: Checkpoint file. Defaults to the module-level ``CHECKPOINT_PATH``.
        map_location: Forwarded to ``torch.load``. Pass e.g. ``"cpu"`` to load
            a checkpoint saved on a GPU onto a CPU-only machine. ``None`` (the
            default) matches ``torch.load``'s own default.

    Returns:
        Tuple ``(epoch, optimizer_state_dict, loss)``. Note this order differs
        from the order of ``save_checkpoint``'s arguments.

    Raises:
        KeyError: If the checkpoint lacks one of the expected keys.
    """
    if path is None:
        # Resolved at call time, so the module-level default stays patchable.
        path = CHECKPOINT_PATH
    checkpoint = torch.load(path, map_location=map_location)
    model.load_state_dict(checkpoint["model_state_dict"])
    optimizer_state_dict = checkpoint["optimizer_state_dict"]
    return checkpoint["epoch"], optimizer_state_dict, checkpoint["loss"]


class TBManager:
    """Thin wrapper around a TensorBoard ``SummaryWriter``.

    It can also be used as a context manager, so the writer is closed even if
    the body raises::

        with TBManager() as tb:
            tb.add_scalar("loss", value, epoch)
    """

    def __init__(self, log_dir=None):
        """Open a ``SummaryWriter`` on ``log_dir`` (default ``TENSORBOARD_LOG_DIR``)."""
        if log_dir is None:
            log_dir = TENSORBOARD_LOG_DIR
        self.writer = SummaryWriter(log_dir=log_dir, comment='', purge_step=None)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
        return False  # never suppress the caller's exception

    def add_scalar(self, name, scalar, epoch):
        """Log ``scalar`` under tag ``name`` with global step ``epoch``."""
        self.writer.add_scalar(name, scalar, epoch)

    def add_images(self, name, model, arg_images, step=0):
        """Log a batch of images as one grid image under tag ``name``.

        Args:
            name: TensorBoard tag.
            model: Unused. Kept for backward compatibility; graph logging via
                ``writer.add_graph(model, images)`` was disabled in the
                original code.
            arg_images: Batch handed to ``make_grid``. Assumed to be a
                (B, C, H, W) float tensor with values in [0, 1] -- TODO
                confirm against callers.
            step: Global step recorded with the image. Defaults to 0, the
                value that used to be hard-coded.
        """
        # Do NOT pre-multiply by 255. SummaryWriter.add_image already scales
        # non-uint8 tensors by 255 before clipping to uint8, so scaling here
        # as well saturated nearly every non-black pixel to white.
        grid = make_grid(arg_images)
        self.writer.add_image(name, grid, step)

    def close(self):
        """Close the underlying ``SummaryWriter``."""
        self.writer.close()