Adds AdEMAMix Optimizer to contrib #1104

Open
wants to merge 33 commits into base: main

Changes from all commits (33 commits)
daed085
getting ademamix docs and notebook running
mathDR Oct 13, 2024
c47261c
fixed imports
mathDR Oct 13, 2024
bf2d4a8
fixed linting errors
mathDR Oct 14, 2024
7cb270a
ran notebook in order
mathDR Oct 14, 2024
1c8f9de
Merge branch 'google-deepmind:main' into main
mathDR Oct 14, 2024
c252b51
implementing pr feedback
mathDR Oct 18, 2024
f9b6559
updated ademamix with author docstrings
mathDR Oct 21, 2024
4b621aa
added docstrings and matched adamw api
mathDR Oct 21, 2024
1a51332
removed unneeded alpha scheduler
mathDR Oct 21, 2024
4eb6065
added alpha as a scheduler
mathDR Oct 21, 2024
4eb618a
removed b3_scheduler
mathDR Oct 21, 2024
9442085
removed b3_scheduler
mathDR Oct 21, 2024
420771f
fixing tests with new docstrings
mathDR Oct 21, 2024
9690311
fixed docstring
mathDR Oct 21, 2024
f892e29
updated notebook
mathDR Oct 22, 2024
81da0e4
fixed import ordering
mathDR Oct 22, 2024
94e3f0a
updated references using rst format
mathDR Oct 24, 2024
40c0e6e
updated docstrings
mathDR Oct 24, 2024
fb095e1
fixed linting
mathDR Oct 24, 2024
af8f22d
synced ademamix api to adamw
mathDR Oct 24, 2024
47e4248
added defaults to scale_by_ademamix
mathDR Oct 24, 2024
934761a
fixed syntaxerror
mathDR Oct 24, 2024
1b72988
updated docstrings
mathDR Oct 24, 2024
3e0699b
fixed typo
mathDR Oct 24, 2024
3ff8aba
reformatting note
mathDR Oct 24, 2024
c933d17
fixing formatting issues
mathDR Oct 24, 2024
bdeb3e2
reformatting note
mathDR Oct 24, 2024
8d38058
reformatting note
mathDR Oct 24, 2024
6f4ec8a
fixed ademamix docstring
mathDR Oct 24, 2024
e8ba763
fixed notebook ordering and line lengths
mathDR Oct 24, 2024
c84ce49
added docs image for ademamix
mathDR Oct 30, 2024
dac4260
added ademamix example to gallery
mathDR Oct 30, 2024
7c57abd
reran notebook with colab link
mathDR Oct 30, 2024
7 changes: 7 additions & 0 deletions docs/api/contrib.rst
@@ -8,6 +8,7 @@ Experimental features and algorithms that don't meet the

.. autosummary::
acprop
ademamix
cocob
COCOBState
dadapt_adamw
@@ -37,6 +38,12 @@ Experimental features and algorithms that don't meet the
split_real_and_imaginary
SplitRealAndImaginaryState

AdEMAMix
~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: ademamix
.. autofunction:: scale_by_ademamix
.. autoclass:: ScaleByAdemamixState

Asynchronous-centering-Prop
~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: acprop
16 changes: 16 additions & 0 deletions docs/gallery.rst
@@ -284,6 +284,22 @@ Examples that make use of the :doc:`api/contrib` module.
<div class="sphx-glr-thumbnail-title">Sharpness-Aware Minimization (SAM).</div>
</div>

.. raw:: html

<div class="sphx-glr-thumbcontainer" tooltip="AdEMAMix.">

.. only:: html

.. image:: /images/examples/contrib/ademamix_rosenbrock.png
:alt:

:doc:`_collections/examples/contrib/ademamix_rosenbrock`

.. raw:: html

<div class="sphx-glr-thumbnail-title">AdEMAMix.</div>
</div>


.. raw:: html

Binary file not shown (docs image for ademamix).
435 changes: 435 additions & 0 deletions examples/contrib/rosenbrock_ademamix.ipynb

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions optax/contrib/__init__.py
@@ -18,6 +18,9 @@

from optax.contrib._acprop import acprop
from optax.contrib._acprop import scale_by_acprop
from optax.contrib._ademamix import ScaleByAdemamixState
from optax.contrib._ademamix import scale_by_ademamix
from optax.contrib._ademamix import ademamix
from optax.contrib._cocob import cocob
from optax.contrib._cocob import COCOBState
from optax.contrib._cocob import scale_by_cocob
259 changes: 259 additions & 0 deletions optax/contrib/_ademamix.py
@@ -0,0 +1,259 @@
"""AdEMAMix.

Implementation of
"THE ADEMAMIX OPTIMIZER: BETTER, FASTER, OLDER"
(https://arxiv.org/pdf/2409.03137) by Matteo Pagliardini,
Pierre Ablin and David Grangier.
"""

from typing import Any, Callable, NamedTuple, Optional, Tuple, Union
import chex
import jax.numpy as jnp
import jax.tree_util as jtu
from optax._src import base
from optax._src import combine
from optax._src import numerics
from optax._src import transform
from optax._src import utils
import optax.tree_utils as otu

class ScaleByAdemamixState(NamedTuple):
"""State for the Ademamix algorithm.

Attributes:
count: iteration of the algorithm used to update the fast EMA and
second moment.
count_m2: iteration of the algorithm used to update the slow EMA and alpha.
m1: fast EMA of the first moment
m2: slow EMA of the first moment
nu: estimate of the second moment
"""

count: chex.Array # shape=(), dtype=jnp.int32.
count_m2: chex.Array # shape=(), dtype=jnp.int32.
m1: base.Updates
m2: base.Updates
nu: base.Updates


def scale_by_ademamix(
b1: float = 0.9,
b2: float = 0.999,
b3: base.ScalarOrSchedule = 0.9999,
alpha: base.ScalarOrSchedule = 5.0,
eps: float = 1e-8,
eps_root: float = 0.0,
mu_dtype: Optional[chex.ArrayDType] = None,
) -> base.GradientTransformation:
"""Scale updates according to the Ademamix algorithm.

See :func:`optax.contrib.ademamix` for a full description of the algorithm.
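
A minimal usage sketch, chaining this transform with a learning rate
(:func:`optax.contrib.ademamix` additionally adds decoupled weight decay):

>>> import optax
>>> opt = optax.chain(
...   optax.contrib.scale_by_ademamix(b1=0.9, b2=0.999, b3=0.9999),
...   optax.scale_by_learning_rate(1e-3))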

References:
Pagliardini et al, `The AdEMAMix Optimizer: Better, Faster, Older
<https://arxiv.org/abs/2409.03137>`_, 2024

Args:
b1: Exponential decay rate to track the fast EMA.
b2: Exponential decay rate to track the second moment of past gradients.
b3: Exponential decay rate to track the slow EMA.
alpha: Mixing coefficient in the linear combination of the fast and
slow EMAs.
eps: A small constant applied to denominator outside of the square root
(as in the Adam paper) to avoid dividing by zero when rescaling.
eps_root: A small constant applied to denominator inside the square root (as
in RMSProp), to avoid dividing by zero when rescaling. This is needed for
instance when computing (meta-)gradients through Adam.
mu_dtype: Optional `dtype` to be used for the first order accumulator; if
`None` then the `dtype` is inferred from `params` and `updates`.

Returns:
The corresponding `GradientTransformation`.
"""

mu_dtype = utils.canonicalize_dtype(mu_dtype)

def init_fn(params):
m1 = otu.tree_zeros_like(params) # fast EMA
m2 = otu.tree_zeros_like(params) # slow EMA
nu = otu.tree_zeros_like(params) # second moment estimate
return ScaleByAdemamixState(
count=jnp.zeros([], jnp.int32),
count_m2=jnp.zeros([], jnp.int32),
m1=m1,
m2=m2,
nu=nu,
)

def update_fn(
updates: base.Updates, state, params=None
) -> Tuple[base.Updates, ScaleByAdemamixState]:
del params
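# b3 and alpha may be schedules; evaluate them at the slow-EMA step count.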
c_b3 = b3(state.count_m2) if callable(b3) else b3
c_alpha = (
alpha(state.count_m2) if callable(alpha) else alpha
)
m1 = otu.tree_update_moment(
updates, state.m1, b1, order=1
) # m1 = b1 * m1 + (1-b1) * updates
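# m2 = c_b3 * m2 + (1-c_b3) * updates  (slow EMA of the first moment)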
m2 = otu.tree_update_moment(updates, state.m2, c_b3, order=1)
nu = otu.tree_update_moment_per_elem_norm(updates, state.nu, b2, order=2)
count_inc = numerics.safe_int32_increment(state.count)
count_m2_inc = numerics.safe_int32_increment(state.count_m2)
m1_hat = otu.tree_bias_correction(m1, b1, count_inc)
# NOTE: AdEMAMix does not perform bias correction on the slow EMA (b3) to
# let the slow EMA momentum buffer fill itself slowly.
nu_hat = otu.tree_bias_correction(nu, b2, count_inc)
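# Combine the bias-corrected fast EMA with the raw slow EMA, then
# normalize by the bias-corrected second-moment estimate, as in Adam.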
updates = jtu.tree_map(
lambda m1_, m2_, v_: ((m1_ + c_alpha * m2_) / (jnp.sqrt(v_+eps_root)
+ eps)),
m1_hat,
m2,
nu_hat,
)
return updates, ScaleByAdemamixState(
count=count_inc, count_m2=count_m2_inc, m1=m1, m2=m2, nu=nu
)

return base.GradientTransformation(init_fn, update_fn)


def ademamix(
learning_rate: base.ScalarOrSchedule,
b1: float = 0.9,
b2: float = 0.999,
b3: base.ScalarOrSchedule = 0.9999,
alpha: base.ScalarOrSchedule = 5.0,
eps: float = 1e-8,
eps_root: float = 0.0,
mu_dtype: Optional[Any] = None,
weight_decay: float = 0.0,
mask: Optional[Union[Any, Callable[[base.Params], Any]]] = None,
) -> base.GradientTransformation:
r"""AdEMAMix.

AdEMAMix (Adaptive EMA Mixture) is AdamW with a mixture of two momentum
terms to better take advantage of historical gradients.

Both SGD with momentum (SGD+M) and Adam incorporate momentum using
Exponential Moving Averages (EMAs) of past gradients.

Let :math:`\eta` represent the learning rate and :math:`\beta_1, \beta_2`,
:math:`\beta_3, \alpha, \varepsilon, \bar{\varepsilon}`, represent the
arguments ``b1``, ``b2``, ``b3``, ``alpha``, ``eps`` and ``eps_root``
respectively. Let :math:`\lambda` be the weight decay and :math:`\theta_t`
the parameter vector at time :math:`t`.

The ``init`` function of this optimizer initializes an internal state
:math:`S_0 := (m_1^{(0)}, m_2^{(0)}, \nu^{(0)}) = (0, 0, 0)`, representing
initial estimates for the fast and slow EMAs of the first moment along with
the second moment estimate. In practice, these values are stored as pytrees
containing all zeros, with the same shape as the model updates. At step
:math:`t`, the ``update`` function of this optimizer takes as arguments the
incoming gradients :math:`g^{(t)}`, the optimizer state :math:`S^{(t-1)}`
and the parameters :math:`\theta^{(t-1)}`. It then computes the updated
parameters :math:`\theta^{(t)}` and the new state :math:`S^{(t)}`. Thus,
for :math:`t > 0`, we have,

.. math::

\begin{align*}
m_1^{(t)} &\leftarrow \beta_1 \cdot m_1^{(t-1)} + (1-\beta_1)
\cdot g^{(t)} \\
m_2^{(t)} &\leftarrow \beta_3 \cdot m_2^{(t-1)} + (1-\beta_3) \cdot
g^{(t)} \\
\nu^{(t)} &\leftarrow \beta_2 \cdot \nu^{(t-1)} + (1-\beta_2) \cdot
{g^{(t)}}^2 \\
\hat{m_1}^{(t)} &\leftarrow m_1^{(t)} / {(1-\beta_1^t)} \\
\hat{\nu}^{(t)} &\leftarrow \nu^{(t)} / {(1-\beta_2^t)} \\
\theta^{(t)} &\leftarrow \theta^{(t-1)} - \eta \cdot \left(
\frac{(\hat{m_1}^{(t)} + \alpha m_2^{(t)})}{\left(\sqrt{\hat{\nu}^{(t)}
+ \bar{\varepsilon}} + \varepsilon\right)} + \lambda \theta^{(t-1)}
\right).\\
S^{(t)} &\leftarrow (m_1^{(t)}, m_2^{(t)}, \nu^{(t)}).
\end{align*}

.. note::

AdEMAMix relies on leveraging very old gradients, so the method is best
suited to settings with a large number of iterations. The paper reports
on this effect in Appendix C.1.5, showing how smaller values of ``b3``
(e.g. ``b3 = 0.999``) can be better in low-iteration scenarios. Moreover,
retaining gradient information over many thousands of steps can be a
problem in domains requiring fast adaptation to a sudden distribution
shift, or more generally when the distribution is non-stationary.

Examples:
>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(jnp.square(x)) # simple quadratic function
>>> solver = optax.contrib.ademamix(learning_rate=0.01)
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function: 14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
... grad = jax.grad(f)(params)
... updates, opt_state = solver.update(grad, opt_state, params)
... params = optax.apply_updates(params, updates)
... print('Objective function: {:.2E}'.format(f(params)))
Objective function: 1.39E+01
Objective function: 1.38E+01
Objective function: 1.36E+01
Objective function: 1.35E+01
Objective function: 1.34E+01
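
Since ``b3`` and ``alpha`` accept schedules, warm-up values can be passed
directly. A minimal sketch, with illustrative schedule values rather than
the exact schedulers proposed in the paper:

>>> alpha_sched = optax.linear_schedule(0.0, 5.0, transition_steps=1_000)
>>> b3_sched = optax.linear_schedule(0.9, 0.9999, transition_steps=1_000)
>>> solver = optax.contrib.ademamix(
...   learning_rate=1e-3, b3=b3_sched, alpha=alpha_sched)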

References:
Pagliardini et al, `The AdEMAMix Optimizer: Better, Faster, Older
<https://arxiv.org/abs/2409.03137>`_, 2024

Args:
learning_rate: A global scaling factor, either fixed or evolving along
iterations with a scheduler, see :func:`optax.scale_by_learning_rate`.
b1: Exponential decay rate to track the fast EMA.
b2: Exponential decay rate to track the second moment of past gradients.
b3: Exponential decay rate to track the slow EMA.
alpha: Mixing coefficient in the linear combination of the fast and
slow EMAs.
eps: A small constant applied to denominator outside of the square root
(as in the Adam paper) to avoid dividing by zero when rescaling.
eps_root: A small constant applied to denominator inside the square root (as
in RMSProp), to avoid dividing by zero when rescaling. This is needed for
instance when computing (meta-)gradients through Adam.
mu_dtype: Optional `dtype` to be used for the first order accumulator; if
`None` then the `dtype` is inferred from `params` and `updates`.
weight_decay: Strength of the weight decay regularization. Note that this
weight decay is multiplied with the learning rate. This is consistent
with other frameworks such as PyTorch, but different from
(Loshchilov et al, 2019) where the weight decay is only multiplied with
the "schedule multiplier", but not the base learning rate.
mask: A tree with same structure as (or a prefix of) the params PyTree,
or a Callable that returns such a pytree given the params/updates.
The leaves should be booleans, `True` for leaves/subtrees you want to
apply the weight decay to, and `False` for those you want to skip. Note
that the AdEMAMix gradient transformations are applied to all parameters.
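
For example, a minimal sketch of excluding a bias leaf from weight decay,
assuming a hypothetical two-leaf params pytree:

>>> params = {'w': jnp.ones((2, 2)), 'b': jnp.zeros((2,))}
>>> solver = optax.contrib.ademamix(
...   learning_rate=1e-3, weight_decay=1e-4,
...   mask={'w': True, 'b': False})
>>> opt_state = solver.init(params)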

Returns:
The corresponding `GradientTransformation`.

.. seealso::
See the related functions :func:`optax.adam`, :func:`optax.nadamw`, as well
as the example :doc:`../_collections/examples/contrib/rosenbrock_ademamix`
for a use case.
"""
return combine.chain(
scale_by_ademamix(
b1=b1,
b2=b2,
b3=b3,
alpha=alpha,
eps=eps,
eps_root=eps_root,
mu_dtype=mu_dtype
),
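# Decoupled weight decay (AdamW-style); scaled by the learning rate below.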
transform.add_decayed_weights(weight_decay, mask),
transform.scale_by_learning_rate(learning_rate),
)
1 change: 1 addition & 0 deletions optax/contrib/_common_test.py
@@ -37,6 +37,7 @@
# Testing contributions coded as GradientTransformations
_MAIN_OPTIMIZERS_UNDER_TEST = [
dict(opt_name='acprop', opt_kwargs=dict(learning_rate=1e-3)),
dict(opt_name='ademamix', opt_kwargs=dict(learning_rate=1e-3)),
dict(opt_name='cocob', opt_kwargs={}),
dict(opt_name='cocob', opt_kwargs=dict(weight_decay=1e-2)),
dict(opt_name='dadapt_adamw', opt_kwargs=dict(learning_rate=1e-1)),