forked from 4uiiurz1/keras-arcface
-
Notifications
You must be signed in to change notification settings - Fork 1
/
scheduler.py
28 lines (23 loc) · 1 KB
/
scheduler.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import math
from keras.callbacks import Callback
from keras import backend as K
class CosineAnnealingScheduler(Callback):
    """Cosine annealing learning-rate scheduler.

    At the start of every epoch the optimizer's learning rate is set to::

        eta_min + (eta_max - eta_min) * (1 + cos(pi * epoch / T_max)) / 2

    so it equals ``eta_max`` at epoch index 0 and ``eta_min`` at epoch
    index ``T_max``. There is no restart and no clamping: for epochs past
    ``T_max`` the cosine keeps going, so the rate climbs back towards
    ``eta_max`` (the schedule is periodic with period ``2 * T_max``).

    # Arguments
        T_max: number of epochs over which the rate decays from ``eta_max``
            to ``eta_min`` (half of the cosine period). Must be non-zero.
        eta_max: learning rate at epoch index 0.
        eta_min: learning rate at epoch index ``T_max`` (default 0).
        verbose: if > 0, print the newly set rate at the start of each epoch.

    # Raises
        ValueError: if ``T_max`` is zero, or if the model's optimizer has
            neither an ``lr`` nor a ``learning_rate`` attribute.
    """

    def __init__(self, T_max, eta_max, eta_min=0, verbose=0):
        super(CosineAnnealingScheduler, self).__init__()
        if T_max == 0:
            # Fail at construction time rather than with a ZeroDivisionError
            # in the middle of training on the first epoch.
            raise ValueError('T_max must be non-zero.')
        self.T_max = T_max
        self.eta_max = eta_max
        self.eta_min = eta_min
        self.verbose = verbose

    def _lr_variable(self):
        """Return the optimizer's learning-rate variable.

        Older Keras optimizers expose it as ``lr``; later versions name it
        ``learning_rate``. ``lr`` is tried first to keep the original
        behavior wherever it exists.
        """
        optimizer = self.model.optimizer
        for attr in ('lr', 'learning_rate'):
            if hasattr(optimizer, attr):
                return getattr(optimizer, attr)
        raise ValueError(
            'Optimizer must have a "lr" or "learning_rate" attribute.')

    def _compute_lr(self, epoch):
        """Return the cosine-annealed learning rate for epoch index ``epoch``."""
        cosine = math.cos(math.pi * epoch / self.T_max)
        return self.eta_min + (self.eta_max - self.eta_min) * (1 + cosine) / 2

    def on_epoch_begin(self, epoch, logs=None):
        # Resolve the variable first so a missing attribute is reported
        # before any work is done.
        lr_var = self._lr_variable()
        lr = self._compute_lr(epoch)
        K.set_value(lr_var, lr)
        if self.verbose > 0:
            # epoch + 1: the message shows a 1-based epoch number.
            print('\nEpoch %05d: CosineAnnealingScheduler setting learning '
                  'rate to %s.' % (epoch + 1, lr))

    def on_epoch_end(self, epoch, logs=None):
        # Use an explicit None check: `logs or {}` would replace an empty
        # dict passed by the caller with a private one, silently dropping
        # the 'lr' entry written below.
        if logs is None:
            logs = {}
        logs['lr'] = K.get_value(self._lr_variable())