run_baseline.py
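"""Baseline training entry point: trains a PoseNet 3D human-pose model with an
MSE criterion, evaluates every epoch on the Human3.6M and MPI-INF-3DHP test
splits, and checkpoints the best model by the H36M P1 error."""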
from __future__ import print_function, absolute_import, division
import datetime
import os
import os.path as path
import random
import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
from utils.loss import mpjpe
from function_baseline.config import get_parse_args
from function_baseline.data_preparation import data_preparation
from function_baseline.model_pos_preparation import model_pos_preparation
from function_baseline.model_pos_train import train
from function_adaptpose.model_pos_eval import evaluate
from utils.log import Logger, savefig
from utils.utils import save_ckpt


def main(args):
    print('==> Using settings {}'.format(args))
    # Fall back to CPU when CUDA is unavailable
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    print('==> Loading dataset...')
    data_dict = data_preparation(args)

    print("==> Creating PoseNet model...")
    model_pos = model_pos_preparation(args, data_dict['dataset'], device)

    print("==> Preparing optimizer...")
    criterion = nn.MSELoss(reduction='mean').to(device)
    # criterion = mpjpe
    optimizer = torch.optim.Adam(model_pos.parameters(), lr=args.lr)

    # Checkpoints go under <checkpoint>/<posenet_name>/<keypoints>/<timestamp>_<note>
    ckpt_dir_path = path.join(args.checkpoint, args.posenet_name, args.keypoints,
                              datetime.datetime.now().strftime('%m%d%H%M%S') + '_' + args.note)
    os.makedirs(ckpt_dir_path, exist_ok=True)
    print('==> Making checkpoint dir: {}'.format(ckpt_dir_path))

    logger = Logger(os.path.join(ckpt_dir_path, 'log.txt'), args)
    logger.set_names(['epoch', 'lr', 'loss_train', 'error_h36m_p1', 'error_h36m_p2',
                      'error_3dhp_p1', 'error_3dhp_p2'])

    #################################################
    # ########## start training here
    #################################################
    start_epoch = 0
    error_best = None
    glob_step = 0
    lr_now = args.lr

    for epoch in range(start_epoch, args.epochs):
        print('\nEpoch: %d | LR: %.8f' % (epoch + 1, lr_now))

        # Train for one epoch
        epoch_loss, lr_now, glob_step = train(data_dict['train_loader'], model_pos, criterion, optimizer,
                                              device, args.lr, lr_now, glob_step, args.lr_decay, args.lr_gamma,
                                              max_norm=args.max_norm, pad=args.pad)

        # Evaluate on the Human3.6M and MPI-INF-3DHP test sets
        error_h36m_p1, error_h36m_p2 = evaluate(data_dict['H36M_test'], model_pos, device,
                                                pad=args.pad, flipaug='_flip', scale=False)
        error_3dhp_p1, error_3dhp_p2 = evaluate(data_dict['3DHP_test'], model_pos, device,
                                                pad=args.pad, flipaug='_flip', tag='3dhp')

        # Update log file
        logger.append([epoch + 1, lr_now, epoch_loss, error_h36m_p1, error_h36m_p2,
                       error_3dhp_p1, error_3dhp_p2])

        # Update checkpoints: keep the best model by H36M P1 error, plus periodic snapshots
        if error_best is None or error_best > error_h36m_p1:
            error_best = error_h36m_p1
            save_ckpt({'state_dict': model_pos.state_dict(), 'epoch': epoch + 1}, ckpt_dir_path, suffix='best')
        if (epoch + 1) % args.snapshot == 0:
            save_ckpt({'state_dict': model_pos.state_dict(), 'epoch': epoch + 1}, ckpt_dir_path)

    logger.close()
    logger.plot(['loss_train', 'error_h36m_p1'])
    savefig(path.join(ckpt_dir_path, 'log.eps'))


if __name__ == '__main__':
    args = get_parse_args()

    # Fix random seeds for reproducibility
    random_seed = args.random_seed
    torch.manual_seed(random_seed)
    torch.cuda.manual_seed(random_seed)
    np.random.seed(random_seed)
    random.seed(random_seed)
    os.environ['PYTHONHASHSEED'] = str(random_seed)
    # Per https://pytorch.org/docs/stable/notes/randomness.html, cuDNN
    # benchmarking must be disabled for runs to be reproducible
    torch.backends.cudnn.deterministic = True
    cudnn.benchmark = False

    main(args)
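
# Example run (hypothetical flag values; the actual names and defaults are
# defined by get_parse_args in function_baseline/config.py):
#   python run_baseline.py --posenet_name videopose --keypoints gt \
#       --lr 1e-3 --epochs 50 --note baseline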