diff --git a/onmt/opts.py b/onmt/opts.py index f763e4f254..b73527a3ec 100644 --- a/onmt/opts.py +++ b/onmt/opts.py @@ -429,7 +429,7 @@ def train_opts(parser): nargs="*", default=None, help='Criteria to use for early stopping.') group.add('--optim', '-optim', default='sgd', - choices=['sgd', 'adagrad', 'adadelta', 'adam', + choices=['sgd', 'adagrad', 'adadelta', 'adam', 'lamb', 'sparseadam', 'adafactor', 'fusedadam'], help="Optimization method.") group.add('--adagrad_accumulator_init', '-adagrad_accumulator_init', @@ -466,6 +466,14 @@ def train_opts(parser): 'suggested a value of 0.98 for beta2, this parameter may ' 'not work well for normal models / default ' 'baselines.') + group.add('--lamb_beta1', '-lamb_beta1', type=float, default=0.9, + help="The beta1 parameter used by Lamb.") + group.add('--lamb_beta2', '-lamb_beta2', type=float, default=0.999, + help="The beta2 parameter used by Lamb.") + group.add('--lamb_eps', '-lamb_eps', type=float, default=1e-8, + help="The epsilon parameter used by Lamb.") + group.add('--lamb_wd', '-lamb_wd', type=float, default=0.0, + help="The weight decay parameter used by Lamb.") group.add('--label_smoothing', '-label_smoothing', type=float, default=0.0, help="Label smoothing value epsilon. " "Probabilities of all non-true labels " diff --git a/onmt/utils/optimizers.py b/onmt/utils/optimizers.py index 36039c33ee..603f6212de 100644 --- a/onmt/utils/optimizers.py +++ b/onmt/utils/optimizers.py @@ -6,6 +6,7 @@ import functools from copy import copy from math import sqrt +import math from onmt.utils.misc import fn_args @@ -82,6 +83,13 @@ def build_torch_optimizer(model, opt): params, lr=opt.learning_rate, betas=betas) + elif opt.optim == 'lamb': + optimizer = Lamb( + params, + lr=opt.learning_rate, + betas=(opt.lamb_beta1, opt.lamb_beta2), + eps=opt.lamb_eps, + weight_decay=opt.lamb_wd) else: raise ValueError('Invalid optimizer type: ' + opt.optim) @@ -517,3 +525,111 @@ def step(self, closure=None): p.data.add_(-group['weight_decay'] * lr_t, p.data) return loss + + +class Lamb(torch.optim.Optimizer): + """Implements Lamb algorithm. + Based on https://github.com/cybertronai/pytorch-lamb + which is itself based on `torch.optimizers.Adam`. + It has been proposed in `Reducing BERT Pre-Training Time + from 3 Days to 76 Minutes`_. + Arguments: + params (iterable): iterable of parameters to optimize or + dicts defining parameter groups + lr (float, optional): learning rate (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used + for computing running averages of gradient and + its square (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + adam (bool, optional): always use trust ratio = 1, + which turns this into Adam. Useful for comparison purposes. + .. _Reducing BERT Pre-Training Time from 3 Days to 76 Minutes: + https://arxiv.org/abs/1904.00962 + """ + + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, + weight_decay=0, adam=False): + if not 0.0 <= lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 <= eps: + raise ValueError("Invalid epsilon value: {}".format(eps)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError("Invalid beta parameter at index 0: {}". + format(betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError("Invalid beta parameter at index 1: {}". + format(betas[1])) + defaults = dict(lr=lr, betas=betas, eps=eps, + weight_decay=weight_decay) + self.adam = adam + super(Lamb, self).__init__(params, defaults) + + def step(self, closure=None): + """Performs a single optimization step. + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + for p in group['params']: + if p.grad is None: + continue + grad = p.grad.data + if grad.is_sparse: + raise RuntimeError( + "Lamb does not support sparse gradients," + "consider SparseAdam instead.") + + state = self.state[p] + + # State initialization + if len(state) == 0: + state['step'] = 0 + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like(p.data) + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = torch.zeros_like(p.data) + + # in the paper, exp_avg is m_t and exp_avg_sq is v_t + exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] + beta1, beta2 = group['betas'] + + state['step'] += 1 + + if group['weight_decay'] != 0: + grad.add_(group['weight_decay'], p.data) + + # m = beta1 * m + (1 - beta1) * grad + exp_avg.mul_(beta1).add_(1 - beta1, grad) + # v = beta2 * m + (1 - beta2) * grad**2 + exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) + + denom = exp_avg_sq.sqrt().add_(group['eps']) + + bias_correction1 = 1 - beta1 ** state['step'] + bias_correction2 = 1 - beta2 ** state['step'] + # Apply bias to lr to avoid broadcast. + step_size = group['lr'] * \ + math.sqrt(bias_correction2) / bias_correction1 + + adam_step = exp_avg / denom + # L2 norm uses sum, but here since we're dividing, + # use mean to avoid overflow. + r1 = p.data.pow(2).mean().sqrt() + r2 = adam_step.pow(2).mean().sqrt() + r = 1 if r1 == 0 or r2 == 0 else min(r1/r2, 10) + state['r1'] = r1 + state['r2'] = r2 + state['r'] = r + if self.adam: + r = 1 + + p.data.add_(-step_size * r, adam_step) + + return loss