optimizers.py (forked from gpauloski/kfac-pytorch)

import sys

import torch.optim as optim

import kfac

sys.path.append('..')
from utils import create_lr_schedule


def get_optimizer(model, args, batch_first=True):
    """Build the SGD optimizer, optional K-FAC preconditioner, and LR schedulers."""
    # K-FAC is enabled only when the inverse update frequency is positive.
    use_kfac = args.kfac_update_freq > 0

    optimizer = optim.SGD(
        model.parameters(),
        lr=args.base_lr,
        momentum=args.momentum,
        weight_decay=args.weight_decay,
    )

    # Map the CLI string to the corresponding K-FAC communication strategy.
    if args.kfac_comm_method == 'comm-opt':
        comm_method = kfac.CommMethod.COMM_OPT
    elif args.kfac_comm_method == 'mem-opt':
        comm_method = kfac.CommMethod.MEM_OPT
    elif args.kfac_comm_method == 'hybrid-opt':
        comm_method = kfac.CommMethod.HYBRID_OPT
    else:
        raise ValueError('Unknown KFAC comm method: {}'.format(
            args.kfac_comm_method))

    if use_kfac:
        # K-FAC preconditioner applied to the gradients before the SGD step.
        preconditioner = kfac.KFAC(
            model,
            damping=args.damping,
            factor_decay=args.stat_decay,
            factor_update_freq=args.kfac_cov_update_freq,
            inv_update_freq=args.kfac_update_freq,
            kl_clip=args.kl_clip,
            lr=args.base_lr,
            batch_first=batch_first,
            comm_method=comm_method,
            distribute_layer_factors=not args.coallocate_layer_factors,
            grad_scaler=args.grad_scaler if 'grad_scaler' in args else None,
            grad_worker_fraction=args.kfac_grad_worker_fraction,
            skip_layers=args.skip_layers,
            use_eigen_decomp=not args.use_inv_kfac,
        )
        # Schedules the decay of the damping and the K-FAC update frequencies.
        kfac_param_scheduler = kfac.KFACParamScheduler(
            preconditioner,
            damping_alpha=args.damping_alpha,
            damping_schedule=args.damping_decay,
            update_freq_alpha=args.kfac_update_freq_alpha,
            update_freq_schedule=args.kfac_update_freq_decay,
        )
    else:
        preconditioner = None

    if args.horovod:
        # Wrap the optimizer for distributed training with Horovod and
        # synchronize the initial model/optimizer state across workers.
        import horovod.torch as hvd
        optimizer = hvd.DistributedOptimizer(
            optimizer,
            named_parameters=model.named_parameters(),
            compression=hvd.Compression.none,
            op=hvd.Average,
            backward_passes_per_step=args.batches_per_allreduce,
        )
        hvd.broadcast_optimizer_state(optimizer, root_rank=0)
        hvd.broadcast_parameters(model.state_dict(), root_rank=0)

    # Warmup + decay schedule applied to both the optimizer and, when
    # K-FAC is enabled, the preconditioner's learning rate.
    lrs = create_lr_schedule(args.backend.size(), args.warmup_epochs, args.lr_decay)
    lr_scheduler = [optim.lr_scheduler.LambdaLR(optimizer, lrs)]
    if use_kfac:
        lr_scheduler.append(optim.lr_scheduler.LambdaLR(preconditioner, lrs))
        lr_scheduler.append(kfac_param_scheduler)

    return optimizer, preconditioner, lr_scheduler
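

A minimal usage sketch, not part of the original file: it assumes a kfac-pytorch version compatible with the keyword arguments above, that create_lr_schedule takes (world size, warmup epochs, decay milestones), and it stands in a hypothetical _SerialBackend for args.backend; the hyperparameter values are illustrative only.

# Hedged usage sketch (assumptions noted above; not part of the original file).
import argparse

import torch.nn as nn


class _SerialBackend:
    # Hypothetical stand-in for the distributed backend object; the real
    # training script supplies something whose .size() is the world size.
    def size(self):
        return 1


args = argparse.Namespace(
    base_lr=0.1, momentum=0.9, weight_decay=5e-4,
    kfac_update_freq=10, kfac_cov_update_freq=1,
    kfac_update_freq_alpha=10, kfac_update_freq_decay=None,
    kfac_comm_method='comm-opt', kfac_grad_worker_fraction=0.25,
    damping=0.003, damping_alpha=10, damping_decay=None,
    stat_decay=0.95, kl_clip=0.001,
    coallocate_layer_factors=True, use_inv_kfac=False, skip_layers=[],
    horovod=False, batches_per_allreduce=1,
    warmup_epochs=5, lr_decay=[25, 35, 40],
    backend=_SerialBackend(),
)

model = nn.Linear(32, 10)
optimizer, preconditioner, schedulers = get_optimizer(model, args)
# In the training loop: call preconditioner.step() after loss.backward() and
# before optimizer.step(); step each entry of `schedulers` once per epoch.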