import shutil

import torch
import torch.nn as nn
from tensorboardX import SummaryWriter
from torch.optim import RMSprop
from tqdm import tqdm

from utils import AverageTracker


class Train:
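    """Training/evaluation harness: wires together the model, data loaders,
    loss, optimizer, checkpointing, and TensorBoard logging."""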
def __init__(self, model, trainloader, valloader, args):
self.model = model
self.trainloader = trainloader
self.valloader = valloader
self.args = args
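        # `args` is expected to provide: cuda, async_loading, num_epochs,
        # test_every, learning_rate, learning_rate_decay, momentum,
        # weight_decay, summary_dir, checkpoint_dir, pretrained_path,
        # and resume_from (all used below).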
self.start_epoch = 0
self.best_top1 = 0.0
# Loss function and Optimizer
self.loss = None
self.optimizer = None
self.create_optimization()
# Model Loading
self.load_pretrained_model()
self.load_checkpoint(self.args.resume_from)
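        # Note: called after load_pretrained_model(), so a resumed checkpoint
        # overrides the ImageNet weights when one is found.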
# Tensorboard Writer
self.summary_writer = SummaryWriter(log_dir=args.summary_dir)
def train(self):
for cur_epoch in range(self.start_epoch, self.args.num_epochs):
# Initialize tqdm
tqdm_batch = tqdm(self.trainloader,
desc="Epoch-" + str(cur_epoch) + "-")
# Learning rate adjustment
self.adjust_learning_rate(self.optimizer, cur_epoch)
# Meters for tracking the average values
loss, top1, top5 = AverageTracker(), AverageTracker(), AverageTracker()
# Set the model to be in training mode (for dropout and batchnorm)
self.model.train()
            for data, target in tqdm_batch:
                if self.args.cuda:
                    # non_blocking= replaces the old `async=` keyword, which
                    # became a syntax error once `async` was reserved in Python 3.7
                    data = data.cuda(non_blocking=self.args.async_loading)
                    target = target.cuda(non_blocking=self.args.async_loading)

                # Forward pass (Variable wrappers are no longer needed in modern PyTorch)
                output = self.model(data)
                cur_loss = self.loss(output, target)
# Optimization step
self.optimizer.zero_grad()
cur_loss.backward()
self.optimizer.step()
# Top-1 and Top-5 Accuracy Calculation
                cur_acc1, cur_acc5 = self.compute_accuracy(output.detach(), target, topk=(1, 5))
                # .item() replaces the deprecated `.data[0]` indexing
                loss.update(cur_loss.item())
                top1.update(cur_acc1.item())
                top5.update(cur_acc5.item())
# Summary Writing
self.summary_writer.add_scalar("epoch-loss", loss.avg, cur_epoch)
self.summary_writer.add_scalar("epoch-top-1-acc", top1.avg, cur_epoch)
self.summary_writer.add_scalar("epoch-top-5-acc", top5.avg, cur_epoch)
# Print in console
tqdm_batch.close()
print("Epoch-" + str(cur_epoch) + " | " + "loss: " + str(
loss.avg) + " - acc-top1: " + str(
top1.avg)[:7] + "- acc-top5: " + str(top5.avg)[:7])
# Evaluate on Validation Set
if cur_epoch % self.args.test_every == 0 and self.valloader:
self.test(self.valloader, cur_epoch)
# Checkpointing
is_best = top1.avg > self.best_top1
self.best_top1 = max(top1.avg, self.best_top1)
self.save_checkpoint({
'epoch': cur_epoch + 1,
'state_dict': self.model.state_dict(),
'best_top1': self.best_top1,
'optimizer': self.optimizer.state_dict(),
}, is_best)
def test(self, testloader, cur_epoch=-1):
loss, top1, top5 = AverageTracker(), AverageTracker(), AverageTracker()
# Set the model to be in testing mode (for dropout and batchnorm)
        self.model.eval()
        # no_grad() replaces the deprecated volatile=True Variable flag
        with torch.no_grad():
            for data, target in testloader:
                if self.args.cuda:
                    data = data.cuda(non_blocking=self.args.async_loading)
                    target = target.cuda(non_blocking=self.args.async_loading)
                # Forward pass
                output = self.model(data)
                cur_loss = self.loss(output, target)
                # Top-1 and Top-5 Accuracy Calculation
                cur_acc1, cur_acc5 = self.compute_accuracy(output, target, topk=(1, 5))
                loss.update(cur_loss.item())
                top1.update(cur_acc1.item())
                top5.update(cur_acc5.item())
if cur_epoch != -1:
# Summary Writing
self.summary_writer.add_scalar("test-loss", loss.avg, cur_epoch)
self.summary_writer.add_scalar("test-top-1-acc", top1.avg, cur_epoch)
self.summary_writer.add_scalar("test-top-5-acc", top5.avg, cur_epoch)
print("Test Results" + " | " + "loss: " + str(loss.avg) + " - acc-top1: " + str(
top1.avg)[:7] + "- acc-top5: " + str(top5.avg)[:7])
def save_checkpoint(self, state, is_best, filename='checkpoint.pth.tar'):
torch.save(state, self.args.checkpoint_dir + filename)
if is_best:
shutil.copyfile(self.args.checkpoint_dir + filename,
self.args.checkpoint_dir + 'model_best.pth.tar')
def compute_accuracy(self, output, target, topk=(1,)):
"""Computes the accuracy@k for the specified values of k"""
maxk = max(topk)
batch_size = target.size(0)
_, idx = output.topk(maxk, 1, True, True)
idx = idx.t()
correct = idx.eq(target.view(1, -1).expand_as(idx))
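        # `correct` is a (maxk, batch_size) boolean matrix; row i marks whether
        # the (i+1)-th ranked prediction matches the target for each sample.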
acc_arr = []
for k in topk:
            # reshape instead of view: the sliced tensor may be non-contiguous
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            # Scaled by 1/batch_size, so accuracies are fractions in [0, 1]
            acc_arr.append(correct_k.mul_(1.0 / batch_size))
return acc_arr
def adjust_learning_rate(self, optimizer, epoch):
"""Sets the learning rate to the initial LR multiplied by 0.98 every epoch"""
learning_rate = self.args.learning_rate * (self.args.learning_rate_decay ** epoch)
for param_group in optimizer.param_groups:
param_group['lr'] = learning_rate
def create_optimization(self):
self.loss = nn.CrossEntropyLoss()
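        # CrossEntropyLoss expects raw logits; it applies log-softmax internally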
if self.args.cuda:
self.loss.cuda()
self.optimizer = RMSprop(self.model.parameters(), self.args.learning_rate,
momentum=self.args.momentum,
weight_decay=self.args.weight_decay)
def load_pretrained_model(self):
try:
print("Loading ImageNet pretrained weights...")
pretrained_dict = torch.load(self.args.pretrained_path)
self.model.load_state_dict(pretrained_dict)
print("ImageNet pretrained weights loaded successfully.\n")
        except Exception:
            print("No ImageNet pretrained weights exist. Skipping...\n")
def load_checkpoint(self, filename):
filename = self.args.checkpoint_dir + filename
try:
print("Loading checkpoint '{}'".format(filename))
checkpoint = torch.load(filename)
self.start_epoch = checkpoint['epoch']
self.best_top1 = checkpoint['best_top1']
self.model.load_state_dict(checkpoint['state_dict'])
self.optimizer.load_state_dict(checkpoint['optimizer'])
print("Checkpoint loaded successfully from '{}' at (epoch {})\n"
.format(self.args.checkpoint_dir, checkpoint['epoch']))
        except Exception:
            print("No checkpoint exists from '{}'. Skipping...\n".format(self.args.checkpoint_dir))