train.py
import numpy as np
import time
import torch
import torch.nn as nn
import os
import visdom
import random
from tqdm import tqdm as tqdm
from cannet import CANNet
from my_dataset import CrowdDataset
if __name__=="__main__":
    # configuration
    train_image_root = './data_preparation/part_A/train_data/images'
    train_dmap_root = './data_preparation/part_A/train_data/ground_truth'
    test_image_root = './data_preparation/part_A/test_data/images'
    test_dmap_root = './data_preparation/part_A/test_data/ground_truth'
    gpu_or_cpu = 'cuda'          # use cuda or cpu
    lr = 1e-7
    batch_size = 1
    momentum = 0.95
    epochs = 20000
    steps = [-1, 1, 100, 150]    # lr schedule placeholders (not used below)
    scales = [1, 1, 1, 1]
    workers = 4
    seed = int(time.time())      # manual_seed expects an integer
    print_freq = 30

    vis = visdom.Visdom()
    device = torch.device(gpu_or_cpu)
    torch.cuda.manual_seed(seed)
    model = CANNet().to(device)
    # sum-reduced MSE between estimated and ground-truth density maps
    criterion = nn.MSELoss(reduction='sum').to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr,
                                momentum=momentum,
                                weight_decay=0)
    # optimizer = torch.optim.Adam(model.parameters(), lr)
    train_dataset = CrowdDataset(train_image_root, train_dmap_root, gt_downsample=8, phase='train')
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=workers)
    test_dataset = CrowdDataset(test_image_root, test_dmap_root, gt_downsample=8, phase='test')
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=workers)

    if not os.path.exists('./checkpoints'):
        os.mkdir('./checkpoints')

    min_mae = 10000
    min_epoch = 0
    train_loss_list = []
    epoch_list = []
    test_error_list = []
    for epoch in range(0, epochs):
        # training phase
        model.train()
        epoch_loss = 0
        for i, (img, gt_dmap) in enumerate(tqdm(train_loader)):
            img = img.to(device)
            gt_dmap = gt_dmap.to(device)
            # forward propagation
            et_dmap = model(img)
            # calculate loss
            loss = criterion(et_dmap, gt_dmap)
            epoch_loss += loss.item()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        # print("epoch:", epoch, "loss:", epoch_loss/len(train_loader))
        epoch_list.append(epoch)
        train_loss_list.append(epoch_loss / len(train_loader))
        torch.save(model.state_dict(), './checkpoints/epoch_' + str(epoch) + ".pth")

        # testing phase
        model.eval()
        mae = 0
        with torch.no_grad():
            for i, (img, gt_dmap) in enumerate(tqdm(test_loader)):
                img = img.to(device)
                gt_dmap = gt_dmap.to(device)
                # forward propagation
                et_dmap = model(img)
                # absolute error between predicted and ground-truth crowd counts
                mae += abs(et_dmap.data.sum() - gt_dmap.data.sum()).item()
                del img, gt_dmap, et_dmap
        if mae / len(test_loader) < min_mae:
            min_mae = mae / len(test_loader)
            min_epoch = epoch
        test_error_list.append(mae / len(test_loader))
        print("epoch:" + str(epoch) + " error:" + str(mae / len(test_loader)) + " min_mae:" + str(min_mae) + " min_epoch:" + str(min_epoch))
        vis.line(win=1, X=epoch_list, Y=train_loss_list, opts=dict(title='train_loss'))
        vis.line(win=2, X=epoch_list, Y=test_error_list, opts=dict(title='test_error'))
        # show a random test image, its ground-truth density map and the estimated one
        index = random.randint(0, len(test_loader) - 1)
        img, gt_dmap = test_dataset[index]
        vis.image(win=3, img=img, opts=dict(title='img'))
        vis.image(win=4, img=gt_dmap / (gt_dmap.max()) * 255, opts=dict(title='gt_dmap(' + str(gt_dmap.sum()) + ')'))
        img = img.unsqueeze(0).to(device)
        gt_dmap = gt_dmap.unsqueeze(0)
        et_dmap = model(img)
        et_dmap = et_dmap.squeeze(0).detach().cpu().numpy()
        vis.image(win=5, img=et_dmap / (et_dmap.max()) * 255, opts=dict(title='et_dmap(' + str(et_dmap.sum()) + ')'))

    # print the finishing time
    print(time.strftime('%Y.%m.%d %H:%M:%S', time.localtime(time.time())))
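
Usage note: train.py expects a running Visdom server (python -m visdom.server) before launch, and it writes one checkpoint per epoch to ./checkpoints. Below is a minimal sketch, not part of train.py, of loading one of those checkpoints for evaluation; the checkpoint file name and data paths are assumptions that mirror the script above, and CANNet / CrowdDataset are the same modules train.py imports.

# eval_sketch.py -- minimal inference sketch (assumed file name, not in the repo)
import torch
from cannet import CANNet
from my_dataset import CrowdDataset

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = CANNet().to(device)
# assumed checkpoint: first epoch saved by train.py
model.load_state_dict(torch.load('./checkpoints/epoch_0.pth', map_location=device))
model.eval()

dataset = CrowdDataset('./data_preparation/part_A/test_data/images',
                       './data_preparation/part_A/test_data/ground_truth',
                       gt_downsample=8, phase='test')
loader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False)

mae = 0
with torch.no_grad():
    for img, gt_dmap in loader:
        et_dmap = model(img.to(device))
        # the predicted count is the sum of the estimated density map
        mae += abs(et_dmap.sum().item() - gt_dmap.sum().item())
print('MAE:', mae / len(loader))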