diff --git a/megatron/arguments.py b/megatron/arguments.py
index b3ed06353e..2a0ac606ce 100644
--- a/megatron/arguments.py
+++ b/megatron/arguments.py
@@ -780,6 +780,15 @@ def _add_regularization_args(parser):
                        help='Weight decay increment function.')
     group.add_argument('--clip-grad', type=float, default=1.0,
                        help='Gradient clipping based on global L2 norm.')
+    group.add_argument('--sophiag-beta1', type=float, default=0.9,
+                       help='First coefficient for computing running averages '
+                       'of gradient and its Hessian')
+    group.add_argument('--sophiag-beta2', type=float, default=0.95,
+                       help='Second coefficient for computing running averages '
+                       'of gradient and its Hessian')
+    group.add_argument('--sophiag-rho', type=float, default=0.01,
+                       help='SophiaG clipping threshold')
+
     group.add_argument('--adam-beta1', type=float, default=0.9,
                        help='First coefficient for computing running averages '
                        'of gradient and its square')
@@ -946,6 +955,7 @@ def _add_training_args(parser):
                        choices=[
                            'adam',
                            'adamw',
+                           'sophiag',
                            'sgd',
                            'ds.fusedlamb',
                            'ipex.lamb',
diff --git a/megatron/optimizer/__init__.py b/megatron/optimizer/__init__.py
index 48f2737a06..99145ff4f4 100644
--- a/megatron/optimizer/__init__.py
+++ b/megatron/optimizer/__init__.py
@@ -315,6 +315,15 @@ def optimizer_hook(p):
             weight_decay=args.weight_decay,
             momentum=args.sgd_momentum
         )
+    elif str(args.optimizer).lower() == 'sophiag':
+        from .sophia import SophiaG
+        optimizer = SophiaG(
+            param_groups,
+            lr=args.lr,
+            betas=(args.sophiag_beta1, args.sophiag_beta2),
+            rho=args.sophiag_rho,
+            weight_decay=args.weight_decay
+        )
     else:
         raise TypeError(f'{args.optimizer} optimizer is not supported.')
     if args.deepspeed:
diff --git a/megatron/optimizer/sophia.py b/megatron/optimizer/sophia.py
new file mode 100644
index 0000000000..4c4e074790
--- /dev/null
+++ b/megatron/optimizer/sophia.py
@@ -0,0 +1,202 @@
+import math
+import torch
+from torch import Tensor
+from torch.optim.optimizer import Optimizer
+from typing import List, Optional
+
+
+# SophiaG implementation adapted from https://github.com/Liuhong99/Sophia/blob/main/sophia.py,
+# vendored here because upstream is not published on PyPI and a git submodule seemed heavier than needed.
+
+class SophiaG(Optimizer):
+    def __init__(self, params, lr=1e-4, betas=(0.965, 0.99), rho=0.04,
+                 weight_decay=1e-1, *, maximize: bool = False,
+                 capturable: bool = False):
+        if not 0.0 <= lr:
+            raise ValueError("Invalid learning rate: {}".format(lr))
+        if not 0.0 <= betas[0] < 1.0:
+            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
+        if not 0.0 <= betas[1] < 1.0:
+            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
+        if not 0.0 <= rho:
+            raise ValueError("Invalid rho parameter: {}".format(rho))
+        if not 0.0 <= weight_decay:
+            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
+        defaults = dict(lr=lr, betas=betas, rho=rho,
+                        weight_decay=weight_decay,
+                        maximize=maximize, capturable=capturable)
+        super(SophiaG, self).__init__(params, defaults)
+
+    def __setstate__(self, state):
+        super().__setstate__(state)
+        for group in self.param_groups:
+            group.setdefault('maximize', False)
+            group.setdefault('capturable', False)
+        state_values = list(self.state.values())
+        step_is_tensor = (len(state_values) != 0) and torch.is_tensor(state_values[0]['step'])
+        if not step_is_tensor:
+            for s in state_values:
+                s['step'] = torch.tensor(float(s['step']))
+
+    @torch.no_grad()
+    def update_hessian(self):
+        for group in self.param_groups:
+            beta1, beta2 = group['betas']
+            for p in group['params']:
+                if p.grad is None:
+                    continue
+                state = self.state[p]
+
+                if len(state) == 0:
+                    state['step'] = torch.zeros((1,), dtype=torch.float, device=p.device) \
+                        if self.defaults['capturable'] else torch.tensor(0.)
+                    state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
+                    state['hessian'] = torch.zeros_like(p, memory_format=torch.preserve_format)
+
+                if 'hessian' not in state.keys():
+                    state['hessian'] = torch.zeros_like(p, memory_format=torch.preserve_format)
+
+                # EMA of the squared gradient serves as the diagonal Hessian estimate.
+                state['hessian'].mul_(beta2).addcmul_(p.grad, p.grad, value=1 - beta2)
+
+    @torch.no_grad()
+    def step(self, closure=None, bs=5120):
+        loss = None
+        if closure is not None:
+            with torch.enable_grad():
+                loss = closure()
+
+        for group in self.param_groups:
+            params_with_grad = []
+            grads = []
+            exp_avgs = []
+            state_steps = []
+            hessian = []
+            beta1, beta2 = group['betas']
+
+            for p in group['params']:
+                if p.grad is None:
+                    continue
+                params_with_grad.append(p)
+
+                if p.grad.is_sparse:
+                    raise RuntimeError('SophiaG does not support sparse gradients')
+                grads.append(p.grad)
+                state = self.state[p]
+                # State initialization
+                if len(state) == 0:
+                    state['step'] = torch.zeros((1,), dtype=torch.float, device=p.device) \
+                        if self.defaults['capturable'] else torch.tensor(0.)
+                    state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
+                    state['hessian'] = torch.zeros_like(p, memory_format=torch.preserve_format)
+
+                if 'hessian' not in state.keys():
+                    state['hessian'] = torch.zeros_like(p, memory_format=torch.preserve_format)
+
+                exp_avgs.append(state['exp_avg'])
+                state_steps.append(state['step'])
+                hessian.append(state['hessian'])
+
+                if self.defaults['capturable']:
+                    bs = torch.ones((1,), dtype=torch.float, device=p.device) * bs
+
+            sophiag(params_with_grad,
+                    grads,
+                    exp_avgs,
+                    hessian,
+                    state_steps,
+                    bs=bs,
+                    beta1=beta1,
+                    beta2=beta2,
+                    rho=group['rho'],
+                    lr=group['lr'],
+                    weight_decay=group['weight_decay'],
+                    maximize=group['maximize'],
+                    capturable=group['capturable'])
+
+        return loss
+
+
+def sophiag(params: List[Tensor],
+            grads: List[Tensor],
+            exp_avgs: List[Tensor],
+            hessian: List[Tensor],
+            state_steps: List[Tensor],
+            capturable: bool = False,
+            *,
+            bs: int,
+            beta1: float,
+            beta2: float,
+            rho: float,
+            lr: float,
+            weight_decay: float,
+            maximize: bool):
+
+    if not all(isinstance(t, torch.Tensor) for t in state_steps):
+        raise RuntimeError("API has changed, `state_steps` argument must contain a list of singleton tensors")
+
+    func = _single_tensor_sophiag
+
+    func(params,
+         grads,
+         exp_avgs,
+         hessian,
+         state_steps,
+         bs=bs,
+         beta1=beta1,
+         beta2=beta2,
+         rho=rho,
+         lr=lr,
+         weight_decay=weight_decay,
+         maximize=maximize,
+         capturable=capturable)
+
+
+def _single_tensor_sophiag(params: List[Tensor],
+                           grads: List[Tensor],
+                           exp_avgs: List[Tensor],
+                           hessian: List[Tensor],
+                           state_steps: List[Tensor],
+                           *,
+                           bs: int,
+                           beta1: float,
+                           beta2: float,
+                           rho: float,
+                           lr: float,
+                           weight_decay: float,
+                           maximize: bool,
+                           capturable: bool):
+
+    for i, param in enumerate(params):
+        grad = grads[i] if not maximize else -grads[i]
+        exp_avg = exp_avgs[i]
+        hess = hessian[i]
+        step_t = state_steps[i]
+
+        if capturable:
+            assert param.is_cuda and step_t.is_cuda and bs.is_cuda
+
+        if torch.is_complex(param):
+            grad = torch.view_as_real(grad)
+            exp_avg = torch.view_as_real(exp_avg)
+            hess = torch.view_as_real(hess)
+            param = torch.view_as_real(param)
+
+        # update step
+        step_t += 1
+
+        # Perform stepweight decay
+        param.mul_(1 - lr * weight_decay)
+
+        # Decay the first moment running average (the Hessian EMA is maintained in update_hessian)
+        exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
+
+        # Element-wise clipped update: sign(exp_avg) scaled by |exp_avg| / (rho * bs * hess), capped at 1.
+        if capturable:
+            step_size = lr
+            step_size_neg = step_size.neg()
+
+            ratio = (exp_avg.abs() / (rho * bs * hess + 1e-15)).clamp(None, 1)
+            param.addcmul_(exp_avg.sign(), ratio, value=step_size_neg)
+        else:
+            step_size_neg = -lr
+
+            ratio = (exp_avg.abs() / (rho * bs * hess + 1e-15)).clamp(None, 1)
+            param.addcmul_(exp_avg.sign(), ratio, value=step_size_neg)
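
Note (not part of the diff above): SophiaG maintains its diagonal Hessian estimate through a separate update_hessian() pass that the training loop must invoke every k steps on labels sampled from the model's own logits; this patch adds the optimizer but does not show that call site. The sketch below illustrates the expected pattern, roughly following the upstream Sophia README. The toy model, dataloader, the interval k = 10, and bs = block_size * batch_size are placeholder assumptions for illustration, not values fixed by this patch.

import torch
import torch.nn.functional as F
from megatron.optimizer.sophia import SophiaG

# Toy stand-ins; in Megatron these come from the real model and dataloader.
model = torch.nn.Linear(16, 4)
data = [(torch.randn(8, 16), torch.randint(0, 4, (8,))) for _ in range(100)]

k = 10                          # Hessian refresh interval (placeholder)
block_size, batch_size = 1, 8   # upstream passes bs as total tokens per batch
optimizer = SophiaG(model.parameters(), lr=1e-4,
                    betas=(0.9, 0.95), rho=0.01, weight_decay=0.1)

for step, (x, y) in enumerate(data):
    # Regular step: the gradient EMA drives the clipped update.
    loss = F.cross_entropy(model(x), y)
    loss.backward()
    optimizer.step(bs=block_size * batch_size)
    optimizer.zero_grad(set_to_none=True)

    # Every k steps, refresh the diagonal Hessian EMA with a backward pass
    # through labels sampled from the model's own logits (Gauss-Newton-Bartlett).
    if step % k == k - 1:
        logits = model(x)
        sampled_y = torch.distributions.Categorical(logits=logits).sample()
        F.cross_entropy(logits, sampled_y).backward()
        optimizer.update_hessian()
        optimizer.zero_grad(set_to_none=True)

With the flags added in this patch, the optimizer would be selected via
--optimizer sophiag --sophiag-beta1 0.9 --sophiag-beta2 0.95 --sophiag-rho 0.01.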