# lr_scheduler.py
import math
from collections import Counter, defaultdict

import torch
from torch.optim.lr_scheduler import _LRScheduler

class MultiStepLR_Restart(_LRScheduler):
    """Multi-step LR decay with warm restarts.

    The learning rate is multiplied by `gamma` at every milestone and reset to
    `initial_lr * weight` at every restart step; `clear_state=True` additionally
    wipes the optimizer state (e.g. Adam moments) on restart.
    """

    def __init__(self, optimizer, milestones, restarts=None, weights=None, gamma=0.1,
                 clear_state=False, last_epoch=-1):
        self.milestones = Counter(milestones)
        self.gamma = gamma
        self.clear_state = clear_state
        self.restarts = restarts if restarts else [0]
        self.restart_weights = weights if weights else [1]
        assert len(self.restarts) == len(
            self.restart_weights), 'restarts and their weights do not match.'
        super(MultiStepLR_Restart, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        # On a restart step: optionally clear optimizer state, then reset the LR.
        if self.last_epoch in self.restarts:
            if self.clear_state:
                self.optimizer.state = defaultdict(dict)
            weight = self.restart_weights[self.restarts.index(self.last_epoch)]
            return [group['initial_lr'] * weight for group in self.optimizer.param_groups]
        # Between milestones the LR is unchanged.
        if self.last_epoch not in self.milestones:
            return [group['lr'] for group in self.optimizer.param_groups]
        # At a milestone: decay by gamma (once per occurrence of the milestone).
        return [
            group['lr'] * self.gamma**self.milestones[self.last_epoch]
            for group in self.optimizer.param_groups
        ]

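# Worked example (assumed values, for illustration only): with initial_lr=2e-4,
# gamma=0.5, and milestones [200000, 400000], the LR is 2e-4 until step 200000,
# 1e-4 until step 400000, then 5e-5; a restart at step 500000 with weight 1 would
# reset it to 2e-4, after which the remaining milestones decay it again.
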
class CosineAnnealingLR_Restart(_LRScheduler):
    """Cosine annealing with warm restarts.

    The LR follows a cosine curve from `initial_lr * weight` down to `eta_min`
    over each period in `T_period`; the i-th restart step starts period i + 1.
    """

    def __init__(self, optimizer, T_period, restarts=None, weights=None, eta_min=0, last_epoch=-1):
        self.T_period = T_period
        self.T_max = self.T_period[0]  # length of the current period
        self.eta_min = eta_min
        self.restarts = restarts if restarts else [0]
        self.restart_weights = weights if weights else [1]
        self.last_restart = 0
        assert len(self.restarts) == len(
            self.restart_weights), 'restarts and their weights do not match.'
        super(CosineAnnealingLR_Restart, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        if self.last_epoch == 0:
            return self.base_lrs
        elif self.last_epoch in self.restarts:
            # Begin the next period: reset the LR and switch to the next T_max.
            self.last_restart = self.last_epoch
            self.T_max = self.T_period[self.restarts.index(self.last_epoch) + 1]
            weight = self.restart_weights[self.restarts.index(self.last_epoch)]
            return [group['initial_lr'] * weight for group in self.optimizer.param_groups]
        elif (self.last_epoch - self.last_restart - 1 - self.T_max) % (2 * self.T_max) == 0:
            return [
                group['lr'] + (base_lr - self.eta_min) * (1 - math.cos(math.pi / self.T_max)) / 2
                for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups)
            ]
        return [(1 + math.cos(math.pi * (self.last_epoch - self.last_restart) / self.T_max)) /
                (1 + math.cos(math.pi * ((self.last_epoch - self.last_restart) - 1) / self.T_max)) *
                (group['lr'] - self.eta_min) + self.eta_min
                for group in self.optimizer.param_groups]

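# A minimal sanity-check sketch (not part of the original scheduler): per step,
# get_lr() applies the recursive cosine update used by PyTorch's CosineAnnealingLR,
# which tracks the closed form
#     eta_t = eta_min + (eta_max - eta_min) * (1 + cos(pi * t / T_max)) / 2
# with t counted from the most recent restart. The hypothetical helper below
# evaluates that closed form directly, e.g. to compare against the scheduler's
# output in a quick test.
def _cosine_lr_closed_form(eta_max, eta_min, t, T_max):
    """Closed-form cosine-annealed LR at step t of a T_max-step period."""
    return eta_min + (eta_max - eta_min) * (1 + math.cos(math.pi * t / T_max)) / 2
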
if __name__ == "__main__":
    optimizer = torch.optim.Adam([torch.zeros(3, 64, 3, 3)], lr=2e-4, weight_decay=0,
                                 betas=(0.9, 0.99))

    ##############################
    # MultiStepLR_Restart
    ##############################
    # Each config block below overwrites the previous one; keep only the variant you want.
    ## Original (no restarts)
    lr_steps = [200000, 400000, 600000, 800000]
    restarts = None
    restart_weights = None

    ## two periods (one restart)
    lr_steps = [100000, 200000, 300000, 400000, 490000, 600000, 700000, 800000, 900000, 990000]
    restarts = [500000]
    restart_weights = [1]

    ## four periods (three restarts)
    lr_steps = [
        50000, 100000, 150000, 200000, 240000, 300000, 350000, 400000, 450000, 490000, 550000,
        600000, 650000, 700000, 740000, 800000, 850000, 900000, 950000, 990000
    ]
    restarts = [250000, 500000, 750000]
    restart_weights = [1, 1, 1]

    scheduler = MultiStepLR_Restart(optimizer, lr_steps, restarts, restart_weights, gamma=0.5,
                                    clear_state=False)

    ##############################
    # CosineAnnealingLR_Restart
    ##############################
    ## two periods (one restart)
    T_period = [500000, 500000]
    restarts = [500000]
    restart_weights = [1]

    ## four periods (three restarts)
    T_period = [250000, 250000, 250000, 250000]
    restarts = [250000, 500000, 750000]
    restart_weights = [1, 1, 1]

    scheduler = CosineAnnealingLR_Restart(optimizer, T_period, eta_min=1e-7, restarts=restarts,
                                          weights=restart_weights)
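    # Note: T_period needs one more entry than restarts, and each restart step should
    # sit at a cumulative T_period boundary (here 250000/500000/750000 for four
    # 250000-step periods); otherwise get_lr() indexes T_period out of range or
    # restarts mid-period.
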
    ##############################
    # Draw figure
    ##############################
    # Step the (last-constructed) scheduler N_iter times and record the LR the
    # optimizer actually sees at each iteration.
    N_iter = 1000000
    lr_l = []
    for i in range(N_iter):
        scheduler.step()
        lr_l.append(optimizer.param_groups[0]['lr'])

    import matplotlib as mpl
    import matplotlib.ticker as mtick
    from matplotlib import pyplot as plt
    mpl.style.use('default')
    import seaborn
    seaborn.set(style='whitegrid')
    seaborn.set_context('paper')

    plt.figure(1)
    plt.subplot(111)
    plt.ticklabel_format(style='sci', axis='x', scilimits=(0, 0))
    plt.title('Learning rate scheme', fontsize=16, color='k')
    plt.plot(list(range(N_iter)), lr_l, linewidth=1.5, label='learning rate scheme')
    plt.legend(loc='upper right', shadow=False)
    ax = plt.gca()
    # Relabel x ticks in thousands of iterations (e.g. 200000 -> '200K').
    labels = ax.get_xticks().tolist()
    for k, v in enumerate(labels):
        labels[k] = str(int(v / 1000)) + 'K'
    ax.set_xticklabels(labels)
    ax.yaxis.set_major_formatter(mtick.FormatStrFormatter('%.1e'))
    ax.set_ylabel('Learning rate')
    ax.set_xlabel('Iteration')
    fig = plt.gcf()
    plt.show()
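    # To also save the plot, something like fig.savefig('lr_scheme.png', dpi=200)
    # could be added here (the filename is illustrative, not from the original script).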