Skip to content

Commit

Permalink
fix(chainable): correct nesterov momentum
Browse files Browse the repository at this point in the history
  • Loading branch information
ClashLuke committed Dec 8, 2024
1 parent dfc3299 commit a7791f0
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions heavyball/chainable.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,13 +277,13 @@ def orthogonalize_update(group, update, grad, param):
@zero_guard("momentum")
@no_state
def nesterov_momentum(group, updates, grads, params, momentum):
utils.nesterov_momentum(momentum, updates, utils.get_beta1(group))
return utils.nesterov_momentum(momentum, updates, utils.get_beta1(group))


@zero_guard("momentum")
@no_state
def heavyball_momentum(group, updates, grads, params, momentum):
utils.heavyball_momentum(momentum, updates, utils.get_beta1(group))
return utils.heavyball_momentum(momentum, updates, utils.get_beta1(group))


@zero_guard("exp_avg", "exp_avg_sq")
Expand Down

0 comments on commit a7791f0

Please sign in to comment.