-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.py
188 lines (148 loc) · 6.5 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
"""
Utility functions for PyTorch model training: saving models, predicting on and
plotting images, seeding, creating TensorBoard writers and plotting loss curves.
"""
import torch
from pathlib import Path
import matplotlib.pyplot as plt
import torchvision
from PIL import Image
from torch.utils.tensorboard.writer import SummaryWriter
def save_model(model: torch.nn.Module,
               target_dir: str,
               model_name: str):
    """Saves a PyTorch model's state_dict to a target directory.

    Args:
        model: PyTorch model whose ``state_dict()`` is written to disk.
        target_dir: Path of the directory to store the saved model in. It is
            created (including missing parents) if it does not exist yet.
        model_name: Filename for the saved model. Must end with ".pth" or ".pt".

    Raises:
        AssertionError: If ``model_name`` does not end with ".pth" or ".pt".
    """
    # Validate the filename before touching the filesystem so that an invalid
    # name does not leave a freshly created directory behind. The check is an
    # explicit raise (not an `assert` statement) so it still runs under
    # `python -O`; AssertionError is kept so existing callers see the same type.
    if not model_name.endswith((".pth", ".pt")):
        raise AssertionError("model name should end with .pt or .pth")

    # Create the target directory and build the full save path
    target_dir_path = Path(target_dir)
    target_dir_path.mkdir(parents=True, exist_ok=True)
    model_save_path = target_dir_path / model_name

    # Save only the state_dict (the learned parameters), not the whole module
    print(f"[INFO] Saving model to: {model_save_path}")
    torch.save(obj=model.state_dict(), f=model_save_path)
def pred_and_plot_image(
    model: torch.nn.Module,
    image_path: str,
    class_names: list[str] = None,
    transform=None,
    device: torch.device = "cuda" if torch.cuda.is_available() else "cpu",
):
    """Makes a prediction on a target image with a trained model and plots the image.

    The image is drawn on the current matplotlib axes with the predicted class
    and its probability as the title. Nothing is returned, and ``plt.show()`` is
    not called here.

    Note: the model is moved to ``device`` and left in evaluation mode.

    Args:
        model (torch.nn.Module): trained PyTorch image classification model.
        image_path (str): filepath to target image.
        class_names (List[str], optional): class names indexed by predicted label.
            When omitted, the integer label is shown instead. Defaults to None.
        transform (callable, optional): callable that turns the opened PIL image
            into an image tensor (channels first). When None, the image is
            converted with ``torchvision.transforms.ToTensor()``. Defaults to None.
        device (torch.device, optional): target device to compute on.
            Defaults to "cuda" if torch.cuda.is_available() else "cpu".

    Example usage:
        pred_and_plot_image(model=model,
                            image_path="some_image.jpeg",
                            class_names=["class_1", "class_2", "class_3"],
                            transform=torchvision.transforms.ToTensor(),
                            device=device)
    """
    # 1. Without a caller-supplied transform, fall back to a plain
    #    PIL -> tensor conversion so the image is always defined below.
    if transform is None:
        transform = torchvision.transforms.ToTensor()

    # 2. Open the image and transform it while the file is still open (PIL
    #    loads pixel data lazily); the context manager closes the file handle.
    with Image.open(image_path) as pil_image:
        target_image = transform(pil_image)

    # 3. Make sure the model is on the target device
    model.to(device)

    # 4. Turn on model evaluation mode and inference mode
    model.eval()
    with torch.inference_mode():
        # Add a batch dimension and send the batch to the target device
        image_batch = target_image.unsqueeze(dim=0)
        target_image_pred = model(image_batch.to(device))

    # 5. Convert logits -> prediction probabilities (softmax for multi-class classification)
    target_image_pred_probs = torch.softmax(target_image_pred, dim=1)

    # 6. Convert prediction probabilities -> a plain Python label and probability
    pred_index = int(torch.argmax(target_image_pred_probs, dim=1).item())
    pred_prob = target_image_pred_probs.max().item()

    # 7. Plot the image alongside the prediction and prediction probability.
    #    Only the channel axis is moved (C, H, W -> H, W, C); no blanket
    #    squeeze(), so single-channel images keep a valid shape for matplotlib.
    plt.imshow(target_image.cpu().permute(1, 2, 0))

    # Show the class name when available, otherwise the integer label
    # (not a tensor repr).
    pred_text = class_names[pred_index] if class_names else pred_index
    plt.title(f"Pred: {pred_text} | Prob: {pred_prob:.3f}")
    plt.axis(False)
def set_seeds(seed: int=42):
    """Seeds the torch random number generators.

    Args:
        seed (int, optional): Value handed to each seeding call. Defaults to 42.
    """
    # Apply the same seed to general torch operations first, then to CUDA
    # torch operations (the ones that happen on the GPU).
    for apply_seed in (torch.manual_seed, torch.cuda.manual_seed):
        apply_seed(seed)
def create_writer(experiment_name: str, model_name: str, extra: str=None) -> SummaryWriter:
    """Creates a torch.utils.tensorboard.writer.SummaryWriter saving to a specific log_dir.

    log_dir is a combination of runs/timestamp/experiment_name/model_name/extra,
    where timestamp is the current date in YYYY-MM-DD format.

    Args:
        experiment_name (str): Name of experiment.
        model_name (str): Model name.
        extra (str, optional): Anything extra to add to the directory; left out
            of the path when None or empty. Defaults to None.

    Returns:
        SummaryWriter: Instance of a writer saving to log_dir.

    Example usage:
        # Creates a writer saving to "runs/2022-06-04/data_10_percent/effnetb2/5_epochs"
        writer = create_writer(experiment_name="data_10_percent",
                               model_name="effnetb2",
                               extra="5_epochs")
        # This is the same as:
        writer = SummaryWriter(log_dir="runs/2022-06-04/data_10_percent/effnetb2/5_epochs")
    """
    # NOTE: the return annotation is the SummaryWriter class itself. Annotating
    # with a call (``SummaryWriter()``) would construct a writer as a side
    # effect every time this function is defined, i.e. on module import.
    from datetime import datetime
    import os

    # Get the timestamp of the current date
    timestamp = datetime.now().strftime("%Y-%m-%d")

    # Build the log directory path, appending `extra` only when it is given
    path_parts = ["runs", timestamp, experiment_name, model_name]
    if extra:
        path_parts.append(extra)
    log_dir = os.path.join(*path_parts)

    print(f"[INFO] Created SummaryWriter(), saving to: {log_dir}")
    return SummaryWriter(log_dir=log_dir)
def plot_loss_curves(results):
    """Plots the loss and accuracy curves stored in a results dictionary.

    A single figure is created with two side-by-side panels: losses on the
    left, accuracies on the right, each showing a train and a test curve.

    Args:
        results (dict): dictionary containing list of values, e.g.
            {"train_loss": [...],
             "train_acc": [...],
             "test_loss": [...],
             "test_acc": [...]}
    """
    # Look up every series before any figure is created, so a missing key
    # fails before matplotlib is touched.
    panels = (
        ("Loss", (("train_loss", results["train_loss"]),
                  ("test_loss", results["test_loss"]))),
        ("Accuracy", (("train_accuracy", results["train_acc"]),
                      ("test_accuracy", results["test_acc"]))),
    )

    # One x position per recorded training-loss value
    x_values = range(len(results["train_loss"]))

    plt.figure(figsize=(15, 7))

    # Draw the panels left to right; subplot positions are 1-based
    for position, (heading, curves) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        for curve_label, curve_values in curves:
            plt.plot(x_values, curve_values, label=curve_label)
        plt.title(heading)
        plt.xlabel("Epochs")
        plt.legend()