-
Notifications
You must be signed in to change notification settings - Fork 432
/
Copy pathtrain.py
208 lines (195 loc) · 8.84 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
"""
@author: Jun Wang
@date: 20201019
@contact: [email protected]
"""
import os
import sys
import shutil
import argparse
import logging as logger
import torch
import torch.distributed as dist
import torch.utils.data.distributed
from torch import optim
from torch.utils.data import DataLoader
from tensorboardX import SummaryWriter
from apex import amp
from optimizer import build_optimizer
from lr_scheduler import build_scheduler
sys.path.append('../../')
from utils.AverageMeter import AverageMeter
from data_processor.train_dataset import ImageDataset
from backbone.backbone_def import BackboneFactory
from head.head_def import HeadFactory
# Configure the root logger: INFO level, each record prefixed with its level,
# timestamp, source file name and line number.
logger.basicConfig(
    datefmt='%Y-%m-%d %H:%M:%S',
    format='%(levelname)s %(asctime)s %(filename)s: %(lineno)d] %(message)s',
    level=logger.INFO,
)

# Turn on cuDNN benchmark mode for the whole process.
torch.backends.cudnn.benchmark = True
class FaceModel(torch.nn.Module):
    """Define a traditional face model which contains a backbone and a head.

    Attributes:
        backbone(object): the backbone of face model.
        head(object): the head of face model.
    """

    def __init__(self, backbone_factory, head_factory):
        """Init face model by backbone factory and head factory.

        Args:
            backbone_factory(object): provides the backbone via ``get_backbone()``.
            head_factory(object): provides the head via ``get_head()``.
        """
        super(FaceModel, self).__init__()
        self.backbone = backbone_factory.get_backbone()
        self.head = head_factory.get_head()

    def forward(self, data, label):
        """Run the backbone on ``data`` and feed its features plus ``label`` to the head.

        Args:
            data: input batch handed to the backbone.
            label: labels handed to the head together with the backbone features.

        Returns:
            Whatever the head returns for ``(feat, label)``.
        """
        # Call the sub-modules themselves instead of their .forward() methods:
        # nn.Module.__call__ is what runs registered forward/pre-forward hooks.
        feat = self.backbone(data)
        pred = self.head(feat, label)
        return pred
def get_lr(optimizer):
    """Return the learning rate of the optimizer's first parameter group.

    The result is None when the optimizer has no parameter groups.
    """
    group_rates = (group['lr'] for group in optimizer.param_groups)
    return next(group_rates, None)
def train_one_epoch(data_loader, model, optimizer, lr_schedule, criterion, cur_epoch, loss_meter, args):
    """Train one epoch by traditional training.

    Args:
        data_loader: sized iterable yielding ``(images, labels)`` batches.
        model: wrapped model; its ``module`` attribute is what gets checkpointed.
        optimizer: optimizer used with ``amp.scale_loss`` for the backward pass.
        lr_schedule: scheduler stepped once per batch via ``step_update``.
        criterion: loss applied to ``(outputs, labels)``.
        cur_epoch: zero-based index of the current epoch.
        loss_meter: running average exposing ``update``, ``avg`` and ``reset``.
        args: namespace; reads ``local_rank``, ``head_type``, ``print_freq``,
            ``save_freq``, ``out_dir`` and (on local rank 0) ``writer``.
    """
    # Pre-bind so the end-of-epoch checkpoint stays well-defined even if the
    # loader yields no batches (it previously hit an unbound ``batch_idx``).
    batch_idx = -1
    num_batches = len(data_loader)
    for batch_idx, (images, labels) in enumerate(data_loader):
        images = images.to(args.local_rank)
        labels = labels.to(args.local_rank)
        labels = labels.squeeze()
        if labels.dim() == 0:
            # squeeze() drops every size-1 axis, so a batch holding a single
            # sample collapses to a 0-dim tensor; restore the batch axis so the
            # target still lines up with the (batch, classes) outputs.
            labels = labels.unsqueeze(0)
        # Call the model itself (not .forward) so nn.Module.__call__ runs.
        if args.head_type == 'AdaM-Softmax':
            # This head additionally returns a term that is averaged and added
            # to the classification loss.
            outputs, lamda_lm = model(images, labels)
            lamda_lm = torch.mean(lamda_lm)
            loss = criterion(outputs, labels) + lamda_lm
        else:
            outputs = model(images, labels)
            loss = criterion(outputs, labels)
        optimizer.zero_grad()
        # Backward goes through apex amp's loss scaling.
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()
        optimizer.step()
        # The scheduler is driven by the global (cross-epoch) batch counter.
        global_batch_idx = cur_epoch * num_batches + batch_idx
        lr_schedule.step_update(global_batch_idx)
        torch.cuda.synchronize()
        loss_meter.update(loss.item(), images.shape[0])
        if args.local_rank == 0 and batch_idx % args.print_freq == 0:
            loss_avg = loss_meter.avg
            lr = get_lr(optimizer)
            logger.info('Epoch %d, iter %d/%d, lr %f, loss %f' %
                        (cur_epoch, batch_idx, num_batches, lr, loss_avg))
            args.writer.add_scalar('Train_loss', loss_avg, global_batch_idx)
            args.writer.add_scalar('Train_lr', lr, global_batch_idx)
            # Start a fresh averaging window after each report.
            loss_meter.reset()
        if args.local_rank == 0 and (batch_idx + 1) % args.save_freq == 0:
            saved_name = 'Epoch_%d_batch_%d.pt' % (cur_epoch, batch_idx)
            state = {
                'state_dict': model.module.state_dict(),
                'epoch': cur_epoch,
                'batch_id': batch_idx
            }
            torch.save(state, os.path.join(args.out_dir, saved_name))
            logger.info('Save checkpoint %s to disk.' % saved_name)
    # NOTE(review): the nesting level of this call was lost with the file's
    # indentation; it is placed once per epoch here - TODO confirm.
    torch.cuda.empty_cache()
    if args.local_rank == 0:
        # End-of-epoch checkpoint, written by local rank 0 only.
        saved_name = 'Epoch_%d.pt' % cur_epoch
        state = {'state_dict': model.module.state_dict(),
                 'epoch': cur_epoch, 'batch_id': batch_idx}
        torch.save(state, os.path.join(args.out_dir, saved_name))
        logger.info('Save checkpoint %s to disk...' % saved_name)
def train(args):
    """Total training procedure.

    Sets up the process group, data pipeline, model, optimizer and scheduler,
    then runs ``train_one_epoch`` for ``args.epoches`` epochs.

    Args:
        args: parsed command-line namespace. Reads ``local_rank``,
            ``tensorboardx_logdir``, ``out_dir``, ``data_root``, ``train_file``,
            ``batch_size``, ``backbone_type``, ``backbone_conf_file``,
            ``head_type``, ``head_conf_file``, ``lr``, ``epoches`` and
            ``warm_up_epoches``; stores ``rank``, ``world_size`` and (on local
            rank 0) ``writer`` back onto it.
    """
    print("Use GPU: {} for training".format(args.local_rank))
    if args.local_rank == 0:
        # Only local rank 0 logs to tensorboard and writes checkpoints.
        args.writer = SummaryWriter(log_dir=args.tensorboardx_logdir)
        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs(args.out_dir, exist_ok=True)
    torch.cuda.set_device(args.local_rank)
    dist.init_process_group(backend='nccl', init_method='env://')
    args.rank = dist.get_rank()
    args.world_size = dist.get_world_size()
    trainset = ImageDataset(args.data_root, args.train_file)
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        trainset, shuffle=True)
    train_loader = DataLoader(dataset=trainset,
                              batch_size=args.batch_size,
                              sampler=train_sampler,
                              num_workers=0,
                              pin_memory=True,
                              drop_last=False)
    backbone_factory = BackboneFactory(args.backbone_type, args.backbone_conf_file)
    head_factory = HeadFactory(args.head_type, args.head_conf_file)
    model = FaceModel(backbone_factory, head_factory)
    model = model.to(args.local_rank)
    model.train()
    # Start every process from rank 0's parameter values.
    for ps in model.parameters():
        dist.broadcast(ps, 0)
    optimizer = build_optimizer(model, args.lr)
    lr_schedule = build_scheduler(optimizer, len(train_loader), args.epoches, args.warm_up_epoches)
    # Mixed precision via apex amp; must wrap model/optimizer before DDP.
    model, optimizer = amp.initialize(model, optimizer, opt_level="O1")
    # DDP
    model = torch.nn.parallel.DistributedDataParallel(
        module=model,
        broadcast_buffers=False,
        device_ids=[args.local_rank]
    )
    criterion = torch.nn.CrossEntropyLoss().to(args.local_rank)
    loss_meter = AverageMeter()
    model.train()
    ori_epoch = 0
    for epoch in range(ori_epoch, args.epoches):
        # DistributedSampler seeds its shuffle from the epoch number; without
        # this call every epoch would iterate the data in the same order.
        train_sampler.set_epoch(epoch)
        train_one_epoch(train_loader, model, optimizer, lr_schedule,
                        criterion, epoch, loss_meter, args)
    if args.local_rank == 0:
        # Flush and release the tensorboard event file.
        args.writer.close()
    dist.destroy_process_group()
# Script entry point: parse the command-line options and launch training.
if __name__ == '__main__':
    conf = argparse.ArgumentParser(description='traditional_training for face recognition.')
    conf.add_argument('--local_rank', type=int, default=0, help='local_rank')
    conf.add_argument("--data_root", type=str,
                      help="The root folder of training set.")
    conf.add_argument("--train_file", type=str,
                      help="The training file path.")
    conf.add_argument("--backbone_type", type=str,
                      help="Mobilefacenets, Resnet.")
    conf.add_argument("--backbone_conf_file", type=str,
                      help="the path of backbone_conf.yaml.")
    conf.add_argument("--head_type", type=str,
                      help="mv-softmax, arcface, npc-face.")
    conf.add_argument("--head_conf_file", type=str,
                      help="the path of head_conf.yaml.")
    conf.add_argument('--lr', type=float, default=0.1,
                      help='The initial learning rate.')
    conf.add_argument("--out_dir", type=str,
                      help="The folder to save models.")
    conf.add_argument('--epoches', type=int, default=9,
                      help='The training epoches.')
    # Help text previously duplicated the one for --epoches.
    conf.add_argument('--warm_up_epoches', type=int, default=9,
                      help='The number of warm-up epoches.')
    # Disabled option, kept for reference:
    # conf.add_argument('--step', type = str, default = '2,5,7',
    #                   help = 'Step for lr.')
    conf.add_argument('--print_freq', type=int, default=10,
                      help='The print frequency for training state.')
    conf.add_argument('--save_freq', type=int, default=10,
                      help='The save frequency for training state.')
    # The value is handed unchanged to each process's DataLoader, so it is a
    # per-process batch size rather than a total across GPUs.
    conf.add_argument('--batch_size', type=int, default=128,
                      help='The training batch size of each process (GPU).')
    # Disabled option, kept for reference:
    # conf.add_argument('--momentum', type = float, default = 0.9,
    #                   help = 'The momentum for sgd.')
    conf.add_argument('--log_dir', type=str, default='log',
                      help='The directory to save log.log')
    conf.add_argument('--tensorboardx_logdir', type=str,
                      help='The directory to save tensorboardx logs')
    conf.add_argument('--pretrain_model', type=str, default='mv_epoch_8.pt',
                      help='The path of pretrained model')
    conf.add_argument('--resume', '-r', action='store_true', default=False,
                      help='Whether to resume from a checkpoint.')
    args = conf.parse_args()
    # args.milestones = [int(num) for num in args.step.split(',')]
    train(args)