forked from nanowell/AdEMAMix-Optimizer-Pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
AdEMAMix.py
120 lines (100 loc) · 4.72 KB
/
AdEMAMix.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import math
import torch
from torch.optim import Optimizer
class AdEMAMix(Optimizer):
    r"""AdEMAMix optimizer: Adam with an additional slow EMA of the gradients.

    Each parameter keeps three state buffers: a fast EMA of gradients
    (``exp_avg``, decay ``beta1``), an EMA of squared gradients
    (``exp_avg_sq``, decay ``beta2``) and a slow EMA of gradients
    (``exp_avg_slow``, decay ``beta3``). At step ``t`` the update is::

        denom = sqrt(exp_avg_sq) / sqrt(1 - beta2**t) + eps
        p <- p * (1 - lr * weight_decay)
        p <- p - lr * (exp_avg / (1 - beta1**t) + alpha_t * exp_avg_slow) / denom

    Only the fast EMA is bias-corrected; the slow EMA enters the update as is.

    Args:
        params: iterable of parameters or dicts defining parameter groups.
        lr: learning rate (must be >= 0).
        betas: ``(beta1, beta2, beta3)`` decay rates, each in ``[0, 1)``.
        eps: term added to the denominator for numerical stability (>= 0).
        weight_decay: decoupled weight decay coefficient (>= 0).
        alpha: weight of the slow EMA in the update.
        T_alpha_beta3: optional warmup length in steps (must be > 0). When
            set, ``alpha_t`` ramps linearly from 0 to ``alpha`` and the slow
            decay ramps from ``beta1`` to ``beta3`` (``1 / log(beta)`` is
            interpolated linearly) over the first ``T_alpha_beta3`` steps.
            When ``None``, ``alpha`` and ``beta3`` are used from the first step.

    Raises:
        ValueError: if a hyper-parameter is out of range.
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999, 0.9999), eps=1e-8,
                 weight_decay=0, alpha=5.0, T_alpha_beta3=None):
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= eps:
            raise ValueError(f"Invalid epsilon value: {eps}")
        # Raise instead of ``assert`` so the checks survive ``python -O`` and
        # match the exception type used by the other checks.
        if len(betas) != 3:
            raise ValueError(f"Invalid beta parameters: {betas}, expected 3")
        if not all(0.0 <= beta < 1.0 for beta in betas):
            raise ValueError(f"Invalid beta parameters: {betas}")
        if not 0.0 <= weight_decay:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
        # The warmup divides by T_alpha_beta3; zero or negative values would
        # divide by zero or produce a negative alpha_t.
        if T_alpha_beta3 is not None and not T_alpha_beta3 > 0:
            raise ValueError(f"Invalid T_alpha_beta3 value: {T_alpha_beta3}")
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay,
                        alpha=alpha, T_alpha_beta3=T_alpha_beta3)
        super().__init__(params, defaults)

    def __setstate__(self, state):
        super().__setstate__(state)

    @torch.no_grad()
    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure: optional callable that re-evaluates the model and returns
                the loss; it is invoked with gradients enabled.

        Returns:
            The value returned by ``closure``, or ``None`` without a closure.

        Raises:
            RuntimeError: if a parameter has a sparse gradient.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params_with_grad = []
            grads = []
            exp_avgs = []
            exp_avg_sqs = []
            exp_avg_slow = []
            state_steps = []

            for p in group['params']:
                if p.grad is None:
                    continue
                if p.grad.is_sparse:
                    raise RuntimeError('AdEMAMix does not support sparse gradients')
                params_with_grad.append(p)
                grads.append(p.grad)

                state = self.state[p]
                # Lazy state initialization on the first step that sees a gradient.
                if len(state) == 0:
                    state['step'] = 0
                    # Fast exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    # Slow exponential moving average of gradient values
                    state['exp_avg_slow'] = torch.zeros_like(p, memory_format=torch.preserve_format)

                exp_avgs.append(state['exp_avg'])
                exp_avg_sqs.append(state['exp_avg_sq'])
                exp_avg_slow.append(state['exp_avg_slow'])

                state['step'] += 1
                state_steps.append(state['step'])

            beta1, beta2, beta3 = group['betas']
            self._update_adamemix(
                params_with_grad,
                grads,
                exp_avgs,
                exp_avg_sqs,
                exp_avg_slow,
                state_steps,
                beta1=beta1,
                beta2=beta2,
                beta3=beta3,
                alpha=group['alpha'],
                T_alpha_beta3=group['T_alpha_beta3'],
                lr=group['lr'],
                weight_decay=group['weight_decay'],
                eps=group['eps'],
            )

        return loss

    @staticmethod
    def _beta3_schedule(step, beta_start, beta3, T_alpha_beta3):
        """Return the slow-EMA decay warmed up from ``beta_start`` to ``beta3``.

        With ``a = step / T_alpha_beta3`` this computes
        ``exp(log(b0) * log(b3) / ((1 - a) * log(b3) + a * log(b0)))``, written
        in the equivalent form ``exp(1 / ((1 - a) / log(b0) + a / log(b3)))``.
        In that form a zero beta (log = -inf) contributes its limit 0 instead
        of raising a math domain error. The result is capped at ``beta3``, and
        ``beta3`` is returned once ``step >= T_alpha_beta3``.
        """
        if step >= T_alpha_beta3:
            return beta3
        frac = step / T_alpha_beta3

        def inv_log(beta):
            # 1 / log(beta); its limit as beta -> 0 is 0.
            return 1.0 / math.log(beta) if beta > 0 else 0.0

        inv = (1.0 - frac) * inv_log(beta_start) + frac * inv_log(beta3)
        # inv is never positive; inv == 0 only when both betas are 0, in which
        # case the limit exp(-inf) is 0.
        beta3_t = math.exp(1.0 / inv) if inv < 0 else 0.0
        return min(beta3_t, beta3)

    def _update_adamemix(self, params, grads, exp_avgs, exp_avg_sqs, exp_avg_slow, state_steps,
                         beta1, beta2, beta3, alpha, T_alpha_beta3, lr, weight_decay, eps):
        """Apply one AdEMAMix update in place to every tensor in ``params``.

        The list arguments are parallel: entry ``i`` of each one belongs to
        ``params[i]``. ``state_steps[i]`` is the step count already incremented
        for the current step (1 on the first update).
        """
        for param, grad, exp_avg, exp_avg_sq, exp_avg_slow_i, step in zip(
                params, grads, exp_avgs, exp_avg_sqs, exp_avg_slow, state_steps):
            bias_correction1 = 1 - beta1 ** step
            bias_correction2 = 1 - beta2 ** step

            if T_alpha_beta3 is not None:
                alpha_t = min(step * alpha / T_alpha_beta3, alpha)
                beta3_t = self._beta3_schedule(step, beta1, beta3, T_alpha_beta3)
            else:
                alpha_t = alpha
                beta3_t = beta3

            # Decay the moment running averages.
            exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
            exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
            exp_avg_slow_i.mul_(beta3_t).add_(grad, alpha=1 - beta3_t)

            denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)

            # Only the fast EMA is bias-corrected; scaling the slow EMA by
            # 1 / bias_correction1 as well would inflate it early in training.
            update = exp_avg / bias_correction1
            update.add_(exp_avg_slow_i, alpha=alpha_t)

            if weight_decay != 0:
                # Decoupled weight decay: p <- p * (1 - lr * weight_decay).
                param.mul_(1 - lr * weight_decay)
            param.addcdiv_(update, denom, value=-lr)