Adds Eve Optimizer #475

Open

wants to merge 35 commits into base: main

Changes from all commits (35 commits)
8b3f732
Added Eve Optimizer
wglao Jan 21, 2023
1374b6e
renamed for testing
wglao Jan 21, 2023
5e63aee
reverted rename
wglao Jan 21, 2023
3f3b2a0
reverted rename
wglao Jan 21, 2023
8a0cccd
added eve to build
wglao Jan 21, 2023
b2db79a
reverted accidental deletion
wglao Jan 21, 2023
5787426
typo
wglao Jan 21, 2023
5c46d4e
typo
wglao Jan 21, 2023
b61e47c
typo
wglao Jan 21, 2023
923ae62
alphabetized format
wglao Jan 21, 2023
06cb64d
typo
wglao Jan 21, 2023
23211e1
conform to optax api
wglao Jan 21, 2023
01ad74f
tests for eve
wglao Jan 21, 2023
6f1b3f9
added custom update function for eve state
wglao Jan 21, 2023
895d20f
update init
wglao Jan 21, 2023
349516f
documentation
wglao Jan 21, 2023
118d669
clearer documentation
wglao Jan 21, 2023
d6dc2f0
clearer documentation
wglao Jan 21, 2023
46a089f
clearer documentation
wglao Jan 21, 2023
a3485a4
eve passes all tests
wglao Jan 23, 2023
b19077d
typo
wglao Jan 23, 2023
3099b6c
remove unnecessary import
wglao Jan 23, 2023
252be75
update documentation
wglao Jan 23, 2023
12ef13e
update documentation
wglao Jan 23, 2023
ea0d0ca
update doc strings
wglao Jan 23, 2023
c23520f
update doc string
wglao Jan 23, 2023
74e4a12
formatting
wglao Jan 23, 2023
58d2659
formatting and typo
wglao Jan 23, 2023
77df2f3
formatting
wglao Jan 23, 2023
13975db
test
wglao Jan 23, 2023
af275fb
docs
wglao Jan 23, 2023
608e03b
remove test artifacts
wglao Jan 23, 2023
296094a
correct version name
wglao Jan 25, 2023
fe00057
limit the injectable hyperparams
wglao Apr 25, 2023
aba8dd4
bug fix with None
wglao Apr 25, 2023
10 changes: 10 additions & 0 deletions docs/api.rst
@@ -13,6 +13,7 @@ Common Optimizers
adamax
adamaxw
amsgrad
eve
fromage
lamb
lars
@@ -67,6 +68,11 @@ AMSGrad

.. autofunction:: amsgrad

Eve
~~~

.. autofunction:: eve

Fromage
~~~~~~~

@@ -289,6 +295,7 @@ Optax Transforms and States
.. autofunction:: scale_by_adamax
.. autofunction:: scale_by_amsgrad
.. autofunction:: scale_by_belief
.. autofunction:: scale_by_eve
.. autofunction:: scale_by_factored_rms
.. autofunction:: scale_by_novograd
.. autofunction:: scale_by_param_block_norm
@@ -310,6 +317,9 @@ Optax Transforms and States
.. autoclass:: ScaleByNovogradState
:members:

.. autoclass:: ScaleByEveState
:members:

.. autoclass:: ScaleByRmsState
:members:

8 changes: 7 additions & 1 deletion optax/__init__.py
@@ -24,6 +24,7 @@
from optax._src.alias import adamw
from optax._src.alias import amsgrad
from optax._src.alias import dpsgd
from optax._src.alias import eve
from optax._src.alias import fromage
from optax._src.alias import lamb
from optax._src.alias import lars
@@ -130,6 +131,7 @@
from optax._src.transform import scale_by_adamax
from optax._src.transform import scale_by_amsgrad
from optax._src.transform import scale_by_belief
from optax._src.transform import scale_by_eve
from optax._src.transform import scale_by_novograd
from optax._src.transform import scale_by_optimistic_gradient
from optax._src.transform import scale_by_param_block_norm
@@ -145,6 +147,7 @@
from optax._src.transform import ScaleByAdamState
from optax._src.transform import ScaleByAmsgradState
from optax._src.transform import ScaleByBeliefState
from optax._src.transform import ScaleByEveState
from optax._src.transform import ScaleByNovogradState
from optax._src.transform import ScaleByRmsState
from optax._src.transform import ScaleByRssState
@@ -177,7 +180,7 @@
from optax._src.wrappers import skip_large_updates
from optax._src.wrappers import skip_not_finite

__version__ = "0.1.5.dev"
__version__ = "0.1.5.dev0"

__all__ = (
"adabelief",
@@ -223,6 +226,7 @@
"ema",
"EmaState",
"EmptyState",
"eve",
"exponential_decay",
"FactoredState",
"fisher_diag",
@@ -284,6 +288,7 @@
"scale_by_adamax",
"scale_by_amsgrad",
"scale_by_belief",
"scale_by_eve",
"scale_by_factored_rms",
"scale_by_novograd",
"scale_by_param_block_norm",
@@ -301,6 +306,7 @@
"ScaleByAdamState",
"ScaleByAmsgradState",
"ScaleByBeliefState",
"ScaleByEveState",
"ScaleByNovogradState",
"ScaleByRmsState",
"ScaleByRssState",
121 changes: 121 additions & 0 deletions optax/_src/alias.py
@@ -23,6 +23,7 @@
from optax._src import combine
from optax._src import factorized
from optax._src import privacy
from optax._src import schedule
from optax._src import transform
from optax._src import wrappers

@@ -339,6 +340,126 @@ def amsgrad(
_scale_by_learning_rate(learning_rate),
)

def _eve(
a1: float = 1e-3,
b1: float = 0.9,
b2: float = 0.999,
b3: float = 0.999,
c: float = 10.,
eps: float = 1e-8,
f: float = 1.,
f_star: float = 0.,
mu_dtype: Optional[Any] = None,
) -> base.GradientTransformation:
"""The Eve optimizer (uninjectable, see `eve()`).

Eve is an SGD variant with adaptive global and local learning rates.
The local learning rate for each weight is computed from estimates of the
first and second moments of the gradients (using exponential moving
averages), as in Adam. These are then scaled by the global learning rate
`a1`, which is adaptively modified by a measure of sub-optimality `d`: the
global rate is increased when far from the optimum and decreased when
approaching it. The sub-optimality `d` is itself tracked with an
exponential moving average, like the first and second moments.

References:
Hayashi et al, 2018: https://arxiv.org/abs/1611.01505

Args:
a1: the initial global scaling factor.
b1: the exponential decay rate to track the first moment of past gradients.
b2: the exponential decay rate to track the second moment of past gradients.
b3: the exponential decay rate to track the sub-optimality.
c: the clipping limit to prevent extreme global learning rate changes.
eps: a small constant applied to the denominator outside of the square root
(as in the Adam paper) to avoid dividing by zero when rescaling.
f: the current loss value (must be injected before `update` is called).
f_star: an estimate of the global minimum of the loss.
mu_dtype: optional `dtype` to be used for the first order accumulator; if
`None` then the `dtype` is inferred from `params` and `updates`.

Returns:
The corresponding `GradientTransformation`.

Note:
Eve requires an additional parameter: the loss for the current iteration::

f := f_t

ScaleByEveState also holds the loss from the previous iteration::

state.f_prev := f_{t-1}

Since the user must inject the current loss before calling the update
function, the `eve` alias wraps `_eve` in `inject_hyperparams` by default,
so that `f` can be set via `opt_state.hyperparams['f']`.
"""
return combine.chain(
transform.scale_by_eve(
b1=b1, b2=b2, b3=b3, c=c, eps=eps, f=f, f_star=f_star, mu_dtype=mu_dtype),
_scale_by_learning_rate(a1)
)


def eve(
a1: float = 1e-3,
b1: float = 0.9,
b2: float = 0.999,
b3: float = 0.999,
c: float = 10.,
eps: float = 1e-8,
f: float = 1.,
f_star: float = 0.,
mu_dtype: Optional[Any] = None,
) -> base.GradientTransformation:
"""Injectable Eve optimizer.

Eve requires an additional parameter: the loss for the current iteration::

f := f_t

ScaleByEveState also holds the loss from the previous iteration::

state.f_prev := f_{t-1}

Since the user must inject the current loss before calling the update
function, this alias wraps `_eve` in `inject_hyperparams` by default, so
that `f` can be set via `opt_state.hyperparams['f']`.

Args:
a1: the initial global scaling factor.
b1: the exponential decay rate to track the first moment of past gradients.
b2: the exponential decay rate to track the second moment of past gradients.
b3: the exponential decay rate to track the sub-optimality.
c: the clipping limit to prevent extreme global learning rate changes.
eps: a small constant applied to the denominator outside of the square root
(as in the Adam paper) to avoid dividing by zero when rescaling.
f: the current loss value (must be injected before `update` is called).
f_star: an estimate of the global minimum of the loss.
mu_dtype: optional `dtype` to be used for the first order accumulator; if
`None` then the `dtype` is inferred from `params` and `updates`.

Returns:
The corresponding `GradientTransformation`, wrapped in `inject_hyperparams`.

Inject the current loss as follows:
-----------------------------------

Initialize::

optimizer = optax.eve()
opt_state = optimizer.init(params)

Train::

while training:
loss, grads = jax.value_and_grad(loss_fn)(params, data)
opt_state.hyperparams['f'] = loss # <-- Update state here
updates, opt_state = optimizer.update(grads, opt_state)
params = optax.apply_updates(params, updates)
"""
# Forward all arguments so values passed to `eve` are not silently dropped.
return schedule.inject_hyperparams(_eve)(
a1=a1, b1=b1, b2=b2, b3=b3, c=c, eps=eps, f=f, f_star=f_star,
mu_dtype=mu_dtype)


def fromage(
learning_rate: float,
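A minimal end-to-end sketch of the loss-injection pattern described in the `eve` docstring above, assuming this branch is installed so that `optax.eve` is available; the quadratic `loss_fn`, `target`, and step count are invented for illustration:

import jax
import jax.numpy as jnp
import optax

def loss_fn(params, target):
  # Toy quadratic loss; Eve needs its value at every step.
  return jnp.mean(jnp.square(params - target))

params = jnp.zeros(3)
target = jnp.array([1., 2., 3.])

optimizer = optax.eve()            # hyperparameters are injectable by default
opt_state = optimizer.init(params)

for _ in range(100):
  loss, grads = jax.value_and_grad(loss_fn)(params, target)
  opt_state.hyperparams['f'] = loss  # inject the current loss before update
  updates, opt_state = optimizer.update(grads, opt_state)
  params = optax.apply_updates(params, updates)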
8 changes: 8 additions & 0 deletions optax/_src/alias_test.py
@@ -35,6 +35,7 @@
dict(opt_name='adamax', opt_kwargs=dict(learning_rate=1e-1)),
dict(opt_name='adamaxw', opt_kwargs=dict(learning_rate=1e-1)),
dict(opt_name='amsgrad', opt_kwargs=dict(learning_rate=1e-1)),
dict(opt_name='eve', opt_kwargs=dict(f=10)),
dict(opt_name='lars', opt_kwargs=dict(learning_rate=1.0)),
dict(opt_name='lamb', opt_kwargs=dict(learning_rate=1e-3)),
dict(opt_name='noisy_sgd', opt_kwargs=dict(learning_rate=1e-3, eta=1e-4)),
@@ -116,6 +117,9 @@ def step(params, state):
updates = get_updates(params)
if opt_name == 'dpsgd':
updates = updates[None]
elif opt_name == 'eve':
f = jnp.mean(jnp.square(params - final_params))
state.hyperparams['f'] = f
# Complex gradients need to be conjugated before being added to parameters
# https://gist.github.com/wdphy16/118aef6fb5f82c49790d7678cf87da29
updates = jax.tree_util.tree_map(lambda x: x.conj(), updates)
@@ -144,6 +148,10 @@ def test_optimizers_can_be_wrapped_in_inject_hyperparams(
# https://github.com/deepmind/optax/issues/412.
opt_inject = schedule.inject_hyperparams(
opt_factory, static_args=('min_dim_size_to_factor',))(**opt_kwargs)
elif opt_name == 'eve':
# Eve is injectable by default; use the uninjectable `_eve` alias as the baseline.
opt = alias._eve(**opt_kwargs)
opt_inject = opt_factory(**opt_kwargs)
else:
opt_inject = schedule.inject_hyperparams(opt_factory)(**opt_kwargs)

5 changes: 3 additions & 2 deletions optax/_src/schedule.py
@@ -576,8 +576,9 @@ def wrapped_transform(*args, **kwargs) -> base.GradientTransformation:
other_hps[name] = value
elif callable(value):
sched_hps[name] = value
elif isinstance(value, (int, float, chex.Array)):
numeric_hps[name] = value
elif value is not None:
if isinstance(value, (int, float, chex.Array)):
numeric_hps[name] = value
else:
other_hps[name] = value

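A small sketch of how a `None`-valued argument flows through `inject_hyperparams` with this guard in place, using `_eve`'s `mu_dtype=None` default. This reflects my reading of the change rather than the exact failure it fixes, and it imports the private `alias`/`schedule` modules the same way the tests do:

import jax.numpy as jnp
from optax._src import alias
from optax._src import schedule

# With the `None` guard, `mu_dtype=None` is left to `_eve`'s own default
# instead of being tracked as a hyperparameter.
opt = schedule.inject_hyperparams(alias._eve)(f=1.0, mu_dtype=None)
state = opt.init({'w': jnp.zeros(3)})
assert 'f' in state.hyperparams             # numeric arg: injectable
assert 'mu_dtype' not in state.hyperparams  # None arg: not a hyperparameter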
71 changes: 71 additions & 0 deletions optax/_src/transform.py
@@ -449,6 +449,77 @@ def update_fn(updates, state, params=None):
return base.GradientTransformation(init_fn, update_fn)


class ScaleByEveState(NamedTuple):
"""State for the Eve algorithm."""
count: chex.Array # shape=(), dtype=jnp.int32.
mu: base.Updates
nu: base.Updates
d: float
f_prev: float


def scale_by_eve(
b1: float = 0.9,
b2: float = 0.999,
b3: float = 0.999,
c: float = 10.,
eps: float = 1e-8,
f: float = 1.,
f_star: float = 0.,
mu_dtype: Optional[Any] = None,
) -> base.GradientTransformation:
"""Rescale updates according to the Eve algorithm.

References:
[Hayashi et al, 2018](https://arxiv.org/abs/1611.01505)

Args:
b1: the exponential decay rate to track the first moment of past gradients.
b2: the exponential decay rate to track the second moment of past gradients.
b3: the exponential decay rate to track the sub-optimality.
c: the clipping limit to prevent extreme global learning rate changes.
eps: a small constant applied to the denominator outside of the square root
(as in the Adam paper) to avoid dividing by zero when rescaling.
f: the current loss value (must be injected before `update` is called).
f_star: an estimate of the global minimum of the loss.
mu_dtype: optional `dtype` to be used for the first order accumulator; if
`None` then the `dtype` is inferred from `params` and `updates`.

Returns:
An (init_fn, update_fn) tuple.
"""
mu_dtype = utils.canonicalize_dtype(mu_dtype)

def init_fn(params):
mu = jax.tree_util.tree_map( # First moment
lambda t: jnp.zeros_like(t, dtype=mu_dtype), params)
nu = jax.tree_util.tree_map(jnp.zeros_like, params) # Second moment
return ScaleByEveState(
count=jnp.zeros([], jnp.int32), mu=mu, nu=nu, d=1., f_prev=10.
)


def update_fn(updates: base.Updates, state: ScaleByEveState, params=None):
del params
mu = update_moment(updates, state.mu, b1, 1)
nu = update_moment_per_elem_norm(updates, state.nu, b2, 2)
count_inc = numerics.safe_int32_increment(state.count)
# Adam-style bias correction of the moment estimates.
mu_hat = bias_correction(mu, b1, count_inc)
nu_hat = bias_correction(nu, b2, count_inc)
# Relative change of the loss, clipped to [1/c, c] and tracked with an EMA;
# the first step always uses d = 1.
d_new = jnp.abs(f - state.f_prev) / (jnp.minimum(f, state.f_prev) - f_star)
d_tilde = jnp.clip(d_new, 1. / c, c)
d = jnp.where(count_inc > 1, b3 * state.d + (1. - b3) * d_tilde, 1.)
updates = jax.tree_util.tree_map(
lambda m, v: m / (jnp.sqrt(v) + eps) / d, mu_hat, nu_hat)
mu = utils.cast_tree(mu, mu_dtype)
return updates, ScaleByEveState(
count=count_inc, mu=mu, nu=nu, d=d, f_prev=f)

return base.GradientTransformation(init_fn, update_fn)


ScaleState = base.EmptyState


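To make the sub-optimality tracking in `scale_by_eve` concrete, here is a standalone sketch of how `d` evolves over a few steps. The helper `eve_d_step` is hypothetical and simply mirrors the corresponding lines of `update_fn`; the loss values are made up:

import jax.numpy as jnp

def eve_d_step(d, f, f_prev, step, b3=0.999, c=10., f_star=0.):
  # One update of the sub-optimality EMA, mirroring `scale_by_eve`'s update_fn.
  d_new = jnp.abs(f - f_prev) / (jnp.minimum(f, f_prev) - f_star)
  d_tilde = jnp.clip(d_new, 1. / c, c)  # limit extreme changes
  return jnp.where(step > 1, b3 * d + (1. - b3) * d_tilde, 1.)

d, f_prev = 1., 10.  # matches the defaults in init_fn
for step, f in enumerate([5.0, 4.0, 3.9, 3.89], start=1):
  d = eve_d_step(d, f, f_prev, step)
  f_prev = f
  print(step, float(d))
# Here the loss is still far from `f_star`, so `d` drifts just below 1; since
# updates are divided by `d`, the effective global step grows slightly. Near
# `f_star`, `d` rises above 1 and shrinks the step instead.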
1 change: 1 addition & 0 deletions optax/_src/transform_test.py
@@ -44,6 +44,7 @@ def setUp(self):
@parameterized.named_parameters([
('adam', transform.scale_by_adam),
('adamax', transform.scale_by_adamax),
('eve', transform.scale_by_eve),
('rmsprop', transform.scale_by_rms),
('stddev', transform.scale_by_stddev),
('trust_ratio', transform.scale_by_trust_ratio),