from collections import defaultdict

import catalyst
import torch
from torch.optim.optimizer import Optimizer


class Lookahead(Optimizer):
    r"""PyTorch implementation of the Lookahead optimizer wrapper.

    Lookahead Optimizer: https://arxiv.org/abs/1907.08610
    """

    def __init__(self, optimizer, la_steps=5, la_alpha=0.8, pullback_momentum="none"):
        """
        optimizer: inner optimizer
        la_steps (int): number of lookahead steps
        la_alpha (float): linear interpolation factor. 1.0 recovers the inner optimizer.
        pullback_momentum (str): change to inner optimizer momentum on interpolation update
        """
        self.optimizer = optimizer
        self._la_step = 0  # counter for inner optimizer steps
        self.la_alpha = la_alpha
        self._total_la_steps = la_steps

        pullback_momentum = pullback_momentum.lower()
        assert pullback_momentum in ["reset", "pullback", "none"]
        self.pullback_momentum = pullback_momentum

        self.state = defaultdict(dict)

        # Cache the current optimizer parameters (the "slow" weights)
        for group in optimizer.param_groups:
            for p in group['params']:
                param_state = self.state[p]
                param_state['cached_params'] = torch.zeros_like(p.data)
                param_state['cached_params'].copy_(p.data)
                if self.pullback_momentum == "pullback":
                    param_state['cached_mom'] = torch.zeros_like(p.data)

    def __getstate__(self):
        return {
            'state': self.state,
            'optimizer': self.optimizer,
            'la_alpha': self.la_alpha,
            '_la_step': self._la_step,
            '_total_la_steps': self._total_la_steps,
            'pullback_momentum': self.pullback_momentum,
        }

    def zero_grad(self):
        self.optimizer.zero_grad()

    def get_la_step(self):
        return self._la_step

    def state_dict(self):
        return self.optimizer.state_dict()

    def load_state_dict(self, state_dict):
        self.optimizer.load_state_dict(state_dict)

    def _backup_and_load_cache(self):
        """Useful for performing evaluation on the slow weights (which typically generalize better)."""
        for group in self.optimizer.param_groups:
            for p in group['params']:
                param_state = self.state[p]
                param_state['backup_params'] = torch.zeros_like(p.data)
                param_state['backup_params'].copy_(p.data)
                p.data.copy_(param_state['cached_params'])

    def _clear_and_load_backup(self):
        for group in self.optimizer.param_groups:
            for p in group['params']:
                param_state = self.state[p]
                p.data.copy_(param_state['backup_params'])
                del param_state['backup_params']

    @property
    def param_groups(self):
        return self.optimizer.param_groups

    def step(self, closure=None):
        """Performs a single Lookahead optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = self.optimizer.step(closure)
        self._la_step += 1

        if self._la_step >= self._total_la_steps:
            self._la_step = 0
            # Lookahead: interpolate towards the cached (slow) weights, then
            # cache the result as the new slow weights.
            for group in self.optimizer.param_groups:
                for p in group['params']:
                    param_state = self.state[p]
                    # crucial line: p = la_alpha * p + (1 - la_alpha) * slow
                    p.data.mul_(self.la_alpha).add_(param_state['cached_params'], alpha=1.0 - self.la_alpha)
                    param_state['cached_params'].copy_(p.data)
                    if self.pullback_momentum == "pullback":
                        internal_momentum = self.optimizer.state[p]["momentum_buffer"]
                        self.optimizer.state[p]["momentum_buffer"] = internal_momentum.mul_(self.la_alpha).add_(
                            param_state["cached_mom"], alpha=1.0 - self.la_alpha)
                        param_state["cached_mom"] = self.optimizer.state[p]["momentum_buffer"]
                    elif self.pullback_momentum == "reset":
                        self.optimizer.state[p]["momentum_buffer"] = torch.zeros_like(p.data)

        return loss
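
# Illustrative sketch (not part of the original module): how the Lookahead
# wrapper is typically used, including evaluation on the slow weights via the
# private _backup_and_load_cache()/_clear_and_load_backup() helpers. The model,
# data, and loss here are hypothetical placeholders chosen for the example.
def _lookahead_usage_example():
    model = torch.nn.Linear(4, 1)
    base_optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
    optimizer = Lookahead(base_optimizer, la_steps=5, la_alpha=0.8, pullback_momentum="pullback")

    x, y = torch.randn(8, 4), torch.randn(8, 1)
    for _ in range(10):
        optimizer.zero_grad()
        loss = ((model(x) - y) ** 2).mean()
        loss.backward()
        optimizer.step()  # every 5th call interpolates towards the slow weights

    # Evaluate on the slow weights, then restore the fast weights.
    optimizer._backup_and_load_cache()
    with torch.no_grad():
        eval_loss = ((model(x) - y) ** 2).mean()
    optimizer._clear_and_load_backup()
    return eval_loss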

def get_optim(optim_name, model, lr):
    """Return the optimizer selected by ``optim_name`` for ``model``."""
    # Factories are used so that only the requested optimizer is instantiated,
    # rather than building all of them just to index the dict.
    d = {
        'adam': lambda: torch.optim.Adam(model.parameters(), lr),
        'radam': lambda: catalyst.contrib.nn.optimizers.radam.RAdam(model.parameters(), lr),
        'lookahead': lambda: Lookahead(catalyst.contrib.nn.optimizers.radam.RAdam(model.parameters(), lr)),
    }
    return d[optim_name]()
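
# Minimal usage sketch (illustrative, not part of the original module): fetch
# an optimizer by name and run a single training step on a toy model. The
# 'radam' and 'lookahead' choices additionally rely on catalyst's RAdam.
if __name__ == "__main__":
    model = torch.nn.Linear(4, 1)
    optimizer = get_optim('adam', model, 1e-3)
    optimizer.zero_grad()
    loss = model(torch.randn(2, 4)).sum()
    loss.backward()
    optimizer.step()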