optimizer.py
import torch


def build_optimizer(model, optim, lr, wd, momentum):
    def _no_bias_decay(model):
        # Split parameters into two groups: weight decay is applied to weight
        # matrices, but not to biases, 1-D parameters (e.g. norm scales), or
        # parameters named in skip_list.
        has_decay = []
        no_decay = []
        skip_list = ['relative_position_bias_table', 'pe']
        for name, param in model.named_parameters():
            if not param.requires_grad:
                continue
            if len(param.shape) == 1 or name.endswith(".bias") or (name in skip_list):
                no_decay.append(param)
            else:
                has_decay.append(param)
        assert len(list(model.parameters())) == len(has_decay) + len(no_decay), '{} vs. {}'.format(
            len(list(model.parameters())), len(has_decay) + len(no_decay))
        return [{'params': has_decay},
                {'params': no_decay, 'weight_decay': 0.}]

    parameters = _no_bias_decay(model)
    kwargs = dict(lr=lr, weight_decay=wd)
    # Momentum only applies to SGD; compare against the lowercase string so
    # the check actually matches (optim.lower() can never equal 'SGD').
    if optim.lower() == 'sgd':
        kwargs['momentum'] = momentum
    # Look up the optimizer class by name, e.g. 'SGD' or 'AdamW'.
    optimizer = getattr(torch.optim, optim)(params=parameters, **kwargs)
    return optimizer
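

# A minimal usage sketch, not part of the original file: the toy model,
# optimizer name, and hyperparameter values below are illustrative assumptions,
# chosen only to show how build_optimizer might be called.
if __name__ == '__main__':
    import torch.nn as nn

    model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 10))
    optimizer = build_optimizer(model, optim='SGD', lr=0.1, wd=1e-4, momentum=0.9)
    # Biases land in the no-decay group; Linear weights keep weight decay.
    print(optimizer)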