From a7791f09ca75a01b8fd355382f624c98d85dc8df Mon Sep 17 00:00:00 2001 From: ClashLuke <39779310+ClashLuke@users.noreply.github.com> Date: Sun, 8 Dec 2024 23:47:40 +0100 Subject: [PATCH] fix(chainable): correct nesterov momentum --- heavyball/chainable.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/heavyball/chainable.py b/heavyball/chainable.py index f110c41..9caba15 100644 --- a/heavyball/chainable.py +++ b/heavyball/chainable.py @@ -277,13 +277,13 @@ def orthogonalize_update(group, update, grad, param): @zero_guard("momentum") @no_state def nesterov_momentum(group, updates, grads, params, momentum): - utils.nesterov_momentum(momentum, updates, utils.get_beta1(group)) + return utils.nesterov_momentum(momentum, updates, utils.get_beta1(group)) @zero_guard("momentum") @no_state def heavyball_momentum(group, updates, grads, params, momentum): - utils.heavyball_momentum(momentum, updates, utils.get_beta1(group)) + return utils.heavyball_momentum(momentum, updates, utils.get_beta1(group)) @zero_guard("exp_avg", "exp_avg_sq")