-
Notifications
You must be signed in to change notification settings - Fork 8
/
Copy pathomniobject_train_pose_2D.py
147 lines (126 loc) · 5.75 KB
/
omniobject_train_pose_2D.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
import os
import pprint
import random
import numpy as np
import torch
import torch.nn.parallel
import torch.optim
import itertools
import torch.utils.data
import torch.utils.data.distributed
import torch.distributed as dist
import argparse
from config.config import config, update_config
from utils import exp_utils, train_utils
from scripts.kubric_trainer_pose2D import train_epoch, validate
from models.pose_estimator_2d import PoseEstimator2D
from dataset.omniobject3d import Omniobject3D
def parse_args():
    """Parse command-line arguments and load the experiment config.

    Unknown arguments are tolerated (``parse_known_args``), so extra flags
    injected by a launcher do not abort the run. As a side effect, the
    config file given by ``--cfg`` is loaded via ``update_config``.

    Returns:
        argparse.Namespace with ``cfg`` (str) and ``local_rank`` (int).
    """
    parser = argparse.ArgumentParser(description='Train pose estimator using 2D inputs')
    parser.add_argument(
        '--cfg', help='experiment configure file name', required=True, type=str)
    # Newer launchers pass ``--local-rank`` (hyphen), and torchrun only exports
    # the LOCAL_RANK env var. Accept both spellings and fall back to the env
    # var. Otherwise parse_known_args would silently swallow the flag and
    # leave the rank at -1.
    parser.add_argument(
        '--local_rank', '--local-rank', dest='local_rank',
        default=int(os.environ.get('LOCAL_RANK', -1)), type=int,
        help='local process rank (GPU index on this node) for distributed training')
    args, _unknown = parser.parse_known_args()
    update_config(args.cfg)
    return args
def _unwrap_model(model):
    """Return the wrapped module if ``model`` is a DistributedDataParallel, else ``model``."""
    if isinstance(model, torch.nn.parallel.DistributedDataParallel):
        return model.module
    return model


def main():
    """Train ``PoseEstimator2D`` on OmniObject3D and validate every 8 epochs.

    When CUDA is available, the run is distributed: an NCCL process group is
    initialized, BatchNorm is converted to SyncBatchNorm, the model is wrapped
    in DistributedDataParallel, and a DistributedSampler shards the training
    set. Otherwise, a single CPU process trains with a shuffled loader.

    Checkpoints are written by one process only (local rank 0, or the single
    process when not distributed):
    - ``cpt_last.pth.tar`` after every epoch;
    - ``cpt_best_rot_<rot>_trans_<trans>.pth.tar`` whenever the validation
      rotation error improves.
    """
    # Get args and config
    args = parse_args()
    logger, output_dir, _tb_log_dir = exp_utils.create_logger(config, args.cfg, phase='train')
    logger.info(pprint.pformat(args))
    logger.info(pprint.pformat(config))

    # set random seeds
    torch.cuda.manual_seed_all(config.seed)
    torch.manual_seed(config.seed)
    np.random.seed(config.seed)
    random.seed(config.seed)

    # set device; the CUDA path always runs one distributed process per GPU
    device = torch.device('cuda') if torch.cuda.device_count() > 0 else torch.device('cpu')
    distributed = device.type == 'cuda'
    if distributed:
        dist.init_process_group(backend='nccl', init_method='env://')
        torch.cuda.set_device(args.local_rank)
    # Only one process writes checkpoints/logs. A single non-distributed
    # process has local_rank == -1, so it must count as "main" too.
    is_main = args.local_rank == 0 or not distributed

    # get model
    model = PoseEstimator2D().to(device)
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=config.train.lr * config.train.accumulation_step,
                                 weight_decay=config.train.weight_decay)
    best_rot, ep_resume = float('inf'), None
    if config.train.resume:
        # NOTE(review): this checkpoint name is hard-coded and does not follow
        # the "cpt_best_rot_{}_trans_{}.pth.tar" pattern saved below — confirm.
        model, optimizer, ep_resume, _best_psnr, best_rot = exp_utils.resume_training(
            model, optimizer, output_dir,
            cpt_name='cpt_best_rot_11.268242299860368.pth.tar')

    # distributed training
    if distributed:
        # SyncBatchNorm only synchronizes on GPU with an initialized process
        # group, so convert only on the distributed path.
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
        torch.backends.cudnn.benchmark = True
        print("using {} cuda".format(torch.cuda.device_count()))
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.local_rank], find_unused_parameters=True)

    # get dataset and dataloader
    train_data = Omniobject3D(config, split='train')
    # DistributedSampler requires an initialized process group; fall back to
    # a plain shuffled loader for single-process runs.
    train_sampler = (torch.utils.data.distributed.DistributedSampler(train_data)
                     if distributed else None)
    train_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=config.train.batch_size,
                                               shuffle=train_sampler is None,
                                               num_workers=int(config.workers),
                                               pin_memory=distributed,
                                               drop_last=True,
                                               sampler=train_sampler)
    val_data = Omniobject3D(config, split='val')
    val_loader = torch.utils.data.DataLoader(val_data,
                                             batch_size=config.test.batch_size,
                                             shuffle=False,
                                             num_workers=int(config.workers),
                                             pin_memory=distributed,
                                             drop_last=False)

    # drop_last=True can leave zero batches; fail clearly instead of dividing by zero.
    if len(train_loader) == 0:
        raise ValueError('training loader is empty: dataset smaller than one batch '
                         '(batch_size={})'.format(config.train.batch_size))
    start_ep = ep_resume if ep_resume is not None else 0
    end_ep = int(config.train.total_iteration / len(train_loader)) + 1
    val_interval = 8  # run validation every `val_interval` epochs

    # train
    for epoch in range(start_ep, end_ep):
        if train_sampler is not None:
            # give the distributed sampler a different shuffle each epoch
            train_sampler.set_epoch(epoch)
        train_epoch(config,
                    loader=train_loader,
                    dataset=train_data,
                    model=model,
                    optimizer=optimizer,
                    epoch=epoch,
                    output_dir=output_dir,
                    device=device,
                    rank=args.local_rank)
        if is_main:
            train_utils.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': _unwrap_model(model).state_dict(),
                    'optimizer': optimizer.state_dict(),
                },
                checkpoint=output_dir, filename="cpt_last.pth.tar")

        if epoch % val_interval != 0:
            continue
        print('Testing..')
        cur_rot, cur_trans, return_dict = validate(config,
                                                   loader=val_loader,
                                                   dataset=val_data,
                                                   model=model,
                                                   epoch=epoch,
                                                   output_dir=output_dir,
                                                   device=device,
                                                   rank=args.local_rank)
        torch.cuda.empty_cache()
        if cur_rot < best_rot:
            best_rot = cur_rot
            if is_main:
                train_utils.save_checkpoint(
                    {
                        'epoch': epoch + 1,
                        'state_dict': _unwrap_model(model).state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'best_rot': best_rot,
                        'eval_dict': return_dict,
                    },
                    checkpoint=output_dir,
                    filename="cpt_best_rot_{}_trans_{}.pth.tar".format(best_rot, cur_trans))
        if is_main:
            logger.info('Best rot error: {} (current {})'.format(best_rot, cur_rot))

    if distributed:
        # release NCCL resources cleanly at the end of training
        dist.destroy_process_group()
# Script entry point: start training only when executed directly, not on import.
if __name__ == "__main__":
    main()