From df96188dd23f71eef13f1422017959f1dd41f458 Mon Sep 17 00:00:00 2001
From: ClashLuke <39779310+ClashLuke@users.noreply.github.com>
Date: Sun, 15 Dec 2024 19:12:24 +0100
Subject: [PATCH] test(chainable): capturable

---
Review notes (not part of the commit message; this free-text area after the
three-dash line is the conventional place for patch commentary):

* The step counter is no longer the Python value built by
  `group.get('step', 0) + 1`. It is now kept in the state of the first
  parameter in `p` (`state['step']`), created with
  `utils.scalar_guard(0, param)` and advanced in place with `.add_(1)`.
  Presumably this keeps the counter a tensor for capturable execution, as
  the subject suggests; confirm against `utils.scalar_guard`.
* The counter update and the warmup LR computation now run after the
  `if not vals: return` guard, so a group that yields no parameter/gradient
  pairs no longer advances its step or rewrites `group['lr']`.
* The early `return` of the per-parameter branch became an `if`/`else`, so
  both branches reach the new trailing reset of `group['lr']` and
  `group['step']` to `None`.
* NOTE(review): once `_step` returns, `group['lr']` is `None`; anything that
  reads it between steps would have to use `group['base_lr']` instead.
  Verify against callers.

 heavyball/chainable.py | 25 ++++++++++++++++++-------
 1 file changed, 18 insertions(+), 7 deletions(-)

diff --git a/heavyball/chainable.py b/heavyball/chainable.py
index 7c8155d..2c1c5f0 100644
--- a/heavyball/chainable.py
+++ b/heavyball/chainable.py
@@ -421,23 +421,34 @@ def __init__(self, params, defaults, foreach: bool, *fns):
     def _step(self, group):
         if 'base_lr' not in group:
             group['base_lr'] = group['lr']
-        step = group['step'] = group.get('step', 0) + 1
-        if group['warmup_steps'] and step < group['warmup_steps']:
-            group['lr'] = group['base_lr'] * step / group['warmup_steps']
-        else:
-            group['lr'] = group['base_lr']
 
         vals = list(self.split_p_and_g_in_group(group, should_promote=False, beta1=utils.get_beta1(group)))
         if not vals:
             return
         p, g = zip(*vals)
 
+        for param in p:
+            state = self.state_(param)
+            if 'step' not in state:
+                state['step'] = utils.scalar_guard(0, param)
+            step = state['step'].add_(1)
+            break
+
+        group['step'] = step
+
+        if group['warmup_steps'] and step < group['warmup_steps']:
+            group['lr'] = group['base_lr'] * step / group['warmup_steps']
+        else:
+            group['lr'] = group['base_lr']
+
         if not group['foreach'] or len(p) == 1:
             for param, grad in zip(p, g):
                 chain(self.state_, group, [grad], [param], *self.fns)
-            return
+        else:
+            chain(self.state_, group, g, p, *self.fns)
 
-        chain(self.state_, group, g, p, *self.fns)
+        group['lr'] = None
+        group['step'] = None
 
 
 use_default = object()