"""
@author: Jun Wang
@date: 20210124
@contact: [email protected]
"""
import os
import sys
import shutil
import argparse
import logging as logger
import torch
from torch import optim
from torch.utils.data import DataLoader
from tensorboardX import SummaryWriter
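# NVIDIA apex supplies the `amp` module used below for mixed-precision training;
# it is not bundled with PyTorch and must be installed separately
# (https://github.com/NVIDIA/apex).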
from apex import amp
sys.path.append('../../')
from utils.AverageMeter import AverageMeter
from data_processor.train_dataset import ImageDataset
from backbone.backbone_def import BackboneFactory
from head.head_def import HeadFactory

logger.basicConfig(level=logger.INFO,
                   format='%(levelname)s %(asctime)s %(filename)s: %(lineno)d] %(message)s',
                   datefmt='%Y-%m-%d %H:%M:%S')


class FaceModel(torch.nn.Module):
    """Define a traditional face model which contains a backbone and a head.

    Attributes:
        backbone(object): the backbone of face model.
        head(object): the head of face model.
    """
    def __init__(self, backbone_factory, head_factory):
        """Init face model by backbone factory and head factory.

        Args:
            backbone_factory(object): produce a backbone according to config files.
            head_factory(object): produce a head according to config files.
        """
        super(FaceModel, self).__init__()
        self.backbone = backbone_factory.get_backbone()
        self.head = head_factory.get_head()

    def forward(self, data, label):
        feat = self.backbone.forward(data)
        pred = self.head.forward(feat, label)
        return pred


def get_lr(optimizer):
    """Get the current learning rate from optimizer.
    """
    for param_group in optimizer.param_groups:
        return param_group['lr']


def train_one_epoch(data_loader, model, optimizer, criterion, cur_epoch, loss_meter, conf):
    """Train one epoch by traditional training.
    """
    for batch_idx, (images, labels) in enumerate(data_loader):
        images = images.to(conf.device)
        labels = labels.to(conf.device)
        labels = labels.squeeze()
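        # The AdaM-Softmax head returns an extra adaptive-margin regularization
        # term (lamda_lm) alongside the logits; its mean is added to the loss.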
        if conf.head_type == 'AdaM-Softmax':
            outputs, lamda_lm = model.forward(images, labels)
            lamda_lm = torch.mean(lamda_lm)
            loss = criterion(outputs, labels) + lamda_lm
        else:
            outputs = model.forward(images, labels)
            loss = criterion(outputs, labels)
        optimizer.zero_grad()
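        # apex AMP scales the loss before backward() so fp16 gradients do not
        # underflow; the optimizer then steps on the unscaled gradients.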
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()
        optimizer.step()
        loss_meter.update(loss.item(), images.shape[0])
        if batch_idx % conf.print_freq == 0:
            loss_avg = loss_meter.avg
            lr = get_lr(optimizer)
            logger.info('Epoch %d, iter %d/%d, lr %f, loss %f' %
                        (cur_epoch, batch_idx, len(data_loader), lr, loss_avg))
            global_batch_idx = cur_epoch * len(data_loader) + batch_idx
            conf.writer.add_scalar('Train_loss', loss_avg, global_batch_idx)
            conf.writer.add_scalar('Train_lr', lr, global_batch_idx)
            loss_meter.reset()
        if (batch_idx + 1) % conf.save_freq == 0:
            saved_name = 'Epoch_%d_batch_%d.pt' % (cur_epoch, batch_idx)
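            # The model is wrapped in DataParallel, so the raw weights live
            # under model.module rather than model itself.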
            state = {
                'state_dict': model.module.state_dict(),
                'epoch': cur_epoch,
                'batch_id': batch_idx
            }
            torch.save(state, os.path.join(conf.out_dir, saved_name))
            logger.info('Save checkpoint %s to disk.' % saved_name)
    saved_name = 'Epoch_%d.pt' % cur_epoch
    state = {'state_dict': model.module.state_dict(),
             'epoch': cur_epoch, 'batch_id': batch_idx}
    torch.save(state, os.path.join(conf.out_dir, saved_name))
    logger.info('Save checkpoint %s to disk...' % saved_name)


def train(conf):
    """Total training procedure.
    """
    data_loader = DataLoader(ImageDataset(conf.data_root, conf.train_file),
                             conf.batch_size, True, num_workers = 4)
    conf.device = torch.device('cuda:0')
    criterion = torch.nn.CrossEntropyLoss().cuda(conf.device)
    backbone_factory = BackboneFactory(conf.backbone_type, conf.backbone_conf_file)
    head_factory = HeadFactory(conf.head_type, conf.head_conf_file)
    model = FaceModel(backbone_factory, head_factory)
    ori_epoch = 0
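    # Optionally resume: reload the epoch counter and weights from a checkpoint
    # previously written by this script (see --pretrain_model and --resume).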
    if conf.resume:
        ori_epoch = torch.load(conf.pretrain_model)['epoch'] + 1
        state_dict = torch.load(conf.pretrain_model)['state_dict']
        model.load_state_dict(state_dict)
    model = model.cuda()
    parameters = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.SGD(parameters, lr = conf.lr,
                          momentum = conf.momentum, weight_decay = 1e-4)
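    # opt_level "O1" patches Torch functions so whitelisted ops run in fp16
    # while model weights stay in fp32; apex applies dynamic loss scaling.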
    model, optimizer = amp.initialize(model, optimizer, opt_level="O1")
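    # Wrap with DataParallel only after amp.initialize (the order apex's docs
    # recommend), so the casts are applied to the underlying module.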
    model = torch.nn.DataParallel(model).cuda()
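    # Decay the learning rate by 10x at the epochs given via --step
    # (conf.milestones), e.g. the default '2,5,7'.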
    lr_schedule = optim.lr_scheduler.MultiStepLR(
        optimizer, milestones = conf.milestones, gamma = 0.1)
    loss_meter = AverageMeter()
    model.train()
    for epoch in range(ori_epoch, conf.epoches):
        train_one_epoch(data_loader, model, optimizer,
                        criterion, epoch, loss_meter, conf)
        lr_schedule.step()
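

# Example invocation (all paths and type names below are placeholders; use the
# backbone/head names and config files from your own setup):
#   python train_amp.py --data_root /path/to/images --train_file /path/to/train_list.txt \
#       --backbone_type MobileFaceNet --backbone_conf_file backbone_conf.yaml \
#       --head_type ArcFace --head_conf_file head_conf.yaml \
#       --out_dir out_dir --log_dir log --tensorboardx_logdir tb_log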
if __name__ == '__main__':
    conf = argparse.ArgumentParser(description='traditional_training for face recognition.')
    conf.add_argument("--data_root", type = str,
                      help = "The root folder of training set.")
    conf.add_argument("--train_file", type = str,
                      help = "The training file path.")
    conf.add_argument("--backbone_type", type = str,
                      help = "Mobilefacenets, Resnet.")
    conf.add_argument("--backbone_conf_file", type = str,
                      help = "The path of backbone_conf.yaml.")
    conf.add_argument("--head_type", type = str,
                      help = "mv-softmax, arcface, npc-face.")
    conf.add_argument("--head_conf_file", type = str,
                      help = "The path of head_conf.yaml.")
    conf.add_argument('--lr', type = float, default = 0.1,
                      help = 'The initial learning rate.')
    conf.add_argument("--out_dir", type = str,
                      help = "The folder to save models.")
    conf.add_argument('--epoches', type = int, default = 9,
                      help = 'The training epoches.')
    conf.add_argument('--step', type = str, default = '2,5,7',
                      help = 'Step for lr.')
    conf.add_argument('--print_freq', type = int, default = 10,
                      help = 'The print frequency for training state.')
    conf.add_argument('--save_freq', type = int, default = 10,
                      help = 'The save frequency for training state.')
    conf.add_argument('--batch_size', type = int, default = 128,
                      help = 'The training batch size over all gpus.')
    conf.add_argument('--momentum', type = float, default = 0.9,
                      help = 'The momentum for sgd.')
    conf.add_argument('--log_dir', type = str, default = 'log',
                      help = 'The directory to save log.log')
    conf.add_argument('--tensorboardx_logdir', type = str,
                      help = 'The directory to save tensorboardx logs')
    conf.add_argument('--pretrain_model', type = str, default = 'mv_epoch_8.pt',
                      help = 'The path of pretrained model')
    conf.add_argument('--resume', '-r', action = 'store_true', default = False,
                      help = 'Whether to resume from a checkpoint.')
    args = conf.parse_args()
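    # Convert the comma-separated --step string into the integer epoch
    # milestones consumed by MultiStepLR.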
    args.milestones = [int(num) for num in args.step.split(',')]
    if not os.path.exists(args.out_dir):
        os.makedirs(args.out_dir)
    if not os.path.exists(args.log_dir):
        os.makedirs(args.log_dir)
    tensorboardx_logdir = os.path.join(args.log_dir, args.tensorboardx_logdir)
    if os.path.exists(tensorboardx_logdir):
        shutil.rmtree(tensorboardx_logdir)
    writer = SummaryWriter(log_dir=tensorboardx_logdir)
    args.writer = writer
    logger.info('Start optimization.')
    logger.info(args)
    train(args)
    logger.info('Optimization done!')