Skip to content

Commit

Permalink
test(chainable): capturable
Browse files Browse the repository at this point in the history
  • Loading branch information
ClashLuke committed Dec 15, 2024
1 parent 384863c commit df96188
Showing 1 changed file with 18 additions and 7 deletions.
25 changes: 18 additions & 7 deletions heavyball/chainable.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,23 +421,34 @@ def __init__(self, params, defaults, foreach: bool, *fns):
def _step(self, group):
if 'base_lr' not in group:
group['base_lr'] = group['lr']
step = group['step'] = group.get('step', 0) + 1
if group['warmup_steps'] and step < group['warmup_steps']:
group['lr'] = group['base_lr'] * step / group['warmup_steps']
else:
group['lr'] = group['base_lr']

vals = list(self.split_p_and_g_in_group(group, should_promote=False, beta1=utils.get_beta1(group)))
if not vals:
return
p, g = zip(*vals)

for param in p:
state = self.state_(param)
if 'step' not in state:
state['step'] = utils.scalar_guard(0, param)
step = state['step'].add_(1)
break

group['step'] = step

if group['warmup_steps'] and step < group['warmup_steps']:
group['lr'] = group['base_lr'] * step / group['warmup_steps']
else:
group['lr'] = group['base_lr']

if not group['foreach'] or len(p) == 1:
for param, grad in zip(p, g):
chain(self.state_, group, [grad], [param], *self.fns)
return
else:
chain(self.state_, group, g, p, *self.fns)

chain(self.state_, group, g, p, *self.fns)
group['lr'] = None
group['step'] = None


use_default = object()
Expand Down

0 comments on commit df96188

Please sign in to comment.