diff --git a/optax/_src/wrappers.py b/optax/_src/wrappers.py index a3c156c4..7bc5123b 100644 --- a/optax/_src/wrappers.py +++ b/optax/_src/wrappers.py @@ -278,7 +278,7 @@ def skip_large_updates(updates: base.Updates, class MultiSteps: """An optimizer wrapper to accumulate gradients over multiple steps. - This wrapper collects together the updates passed to its `update` function + This wrapper collects together the updates passed to its ``update`` function over consecutive steps until a given number of scheduled steps is reached. In each of these intermediate steps, the returned value from the optimizer is a tree of zeros of the same shape of the updates passed as input. @@ -305,15 +305,15 @@ def __init__( Args: opt: the wrapped optimizer. - every_k_schedule: an int or f a function. + every_k_schedule: an int or a function. * As a function, it returns how many mini-steps should be accumulated in a single gradient step. Its only argument is the current gradient step count. By varying the returned value, users can vary the overall training batch size. - * If an `int`, this is the constant number of mini-steps per gradient + * If an ``int``, this is the constant number of mini-steps per gradient update. - use_grad_mean: if `True` (the default), gradients accumulated over + use_grad_mean: if ``True`` (the default), gradients accumulated over multiple mini-steps are averaged. Otherwise, they are summed. should_skip_update_fn: if provided, this function is used to decide when to accept or reject the updates from a mini-step. 
When a mini-step is @@ -405,7 +405,7 @@ def _do_update(updates, state, params): * numerics.safe_int32_increment(state.gradient_step) + (1 - emit) * state.gradient_step, inner_opt_state=jax.tree_util.tree_map( - lambda st, nst: (1 - emit) * st + emit * nst, + lambda st, nst: jnp.where(emit, nst, st), state.inner_opt_state, new_inner_state, ), diff --git a/optax/_src/wrappers_test.py b/optax/_src/wrappers_test.py index 1b7d3146..79e049b8 100644 --- a/optax/_src/wrappers_test.py +++ b/optax/_src/wrappers_test.py @@ -270,6 +270,19 @@ def test_multi_steps_every_k_schedule(self): _, opt_state = opt_update(grad, opt_state, params) self.assertTrue(ms_opt.has_updated(opt_state)) + def test_multi_steps_zero_nans(self): + # Test that MultiSteps is compatible with zero_nans + # https://github.com/google-deepmind/optax/issues/828 + ms_opt = wrappers.MultiSteps( + combine.chain(constrain.zero_nans(), alias.sgd(1e-4)), + every_k_schedule=2 + ) + opt_init, opt_update = ms_opt.gradient_transformation() + params = dict(a=jnp.zeros([])) + opt_state = opt_init(params) + grad = dict(a=jnp.zeros([])) + opt_update(grad, opt_state, params) + def test_multi_steps_computes_mean(self): k_steps = 4 ms_opt = wrappers.MultiSteps(