-
Notifications
You must be signed in to change notification settings - Fork 8
/
optimizer.py
27 lines (23 loc) · 881 Bytes
/
optimizer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
# Copyright (c) Facebook, Inc. and its affiliates.
import torch
def build_optimizer(args, model):
    """Build an AdamW optimizer over the trainable parameters of ``model``.

    Parameters whose ``requires_grad`` is ``False`` are left out entirely.
    When ``args.filter_biases_wd`` is truthy, every parameter that is 1-D or
    whose name ends with ``"bias"`` is placed in a separate group with
    ``weight_decay`` fixed at ``0.0``; all remaining parameters use
    ``args.weight_decay``. When it is falsy, a single group holding every
    trainable parameter is used.

    Args:
        args: Object exposing ``filter_biases_wd``, ``weight_decay`` and
            ``base_lr`` attributes.
        model: Object exposing ``named_parameters()`` yielding
            ``(name, parameter)`` pairs.

    Returns:
        A ``torch.optim.AdamW`` instance with learning rate ``args.base_lr``.
    """
    exempt_biases = args.filter_biases_wd

    decayed, undecayed = [], []
    for param_name, tensor in model.named_parameters():
        # Frozen parameters are never handed to the optimizer.
        if tensor.requires_grad is False:
            continue
        skip_decay = exempt_biases and (
            len(tensor.shape) == 1 or param_name.endswith("bias")
        )
        bucket = undecayed if skip_decay else decayed
        bucket.append(tensor)

    groups = [{"params": decayed, "weight_decay": args.weight_decay}]
    if exempt_biases:
        # The no-decay group goes first, ahead of the decayed group.
        groups.insert(0, {"params": undecayed, "weight_decay": 0.0})

    return torch.optim.AdamW(groups, lr=args.base_lr)