Avoid multiplication of boolean arrays
Instead, use a jnp.where call. That avoids type promotion of boolean arrays to integer arrays.
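
As a rough illustration (not part of this commit): multiplying a boolean array by an integer promotes the result to an integer dtype, whereas jnp.where selects values and leaves the dtype alone.

import jax.numpy as jnp

old = jnp.array([True, False])   # e.g. a bool flag kept inside an optimizer state
new = jnp.array([False, False])
emit = jnp.array(1)              # 1 when a real update is emitted, 0 otherwise

blended = (1 - emit) * old + emit * new   # promoted to int32, no longer a bool array
selected = jnp.where(emit, new, old)      # stays bool

print(blended.dtype, selected.dtype)      # int32 bool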

Fixes issue #828

Also made cosmetic changes to the docstring

PiperOrigin-RevId: 612388782
fabianp authored and OptaxDev committed Mar 4, 2024
1 parent d1a958c commit 6de95bf
Showing 2 changed files with 18 additions and 5 deletions.
10 changes: 5 additions & 5 deletions optax/_src/wrappers.py
@@ -278,7 +278,7 @@ def skip_large_updates(updates: base.Updates,
class MultiSteps:
"""An optimizer wrapper to accumulate gradients over multiple steps.
- This wrapper collects together the updates passed to its `update` function
+ This wrapper collects together the updates passed to its ``update`` function
over consecutive steps until a given number of scheduled steps is reached.
In each of these intermediate steps, the returned value from the optimizer is
a tree of zeros of the same shape of the updates passed as input.
@@ -305,15 +305,15 @@ def __init__(
Args:
opt: the wrapped optimizer.
- every_k_schedule: an int or f a function.
+ every_k_schedule: an int or a function.
* As a function, it returns how many mini-steps should be accumulated
in a single gradient step. Its only argument is the current
gradient step count. By varying the returned value, users can vary the
overall training batch size.
- * If an `int`, this is the constant number of mini-steps per gradient
+ * If an ``int``, this is the constant number of mini-steps per gradient
update.
- use_grad_mean: if `True` (the default), gradients accumulated over
+ use_grad_mean: if ``True`` (the default), gradients accumulated over
multiple mini-steps are averaged. Otherwise, they are summed.
should_skip_update_fn: if provided, this function is used to decide when
to accept or reject the updates from a mini-step. When a mini-step is
@@ -405,7 +405,7 @@ def _do_update(updates, state, params):
* numerics.safe_int32_increment(state.gradient_step)
+ (1 - emit) * state.gradient_step,
inner_opt_state=jax.tree_util.tree_map(
- lambda st, nst: (1 - emit) * st + emit * nst,
+ lambda st, nst: jnp.where(emit, nst, st),
state.inner_opt_state,
new_inner_state,
),
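
To see what the changed lambda does in context, here is a small standalone sketch (with simplified stand-ins rather than the real MultiSteps inner state) of blending a pytree that contains a boolean leaf, such as the found_nan flag kept by zero_nans:

import jax
import jax.numpy as jnp

acc_state = {'found_nan': jnp.array(False), 'count': jnp.array(3)}
new_state = {'found_nan': jnp.array(True), 'count': jnp.array(4)}
emit = jnp.array(False)  # mini-step: the accumulated state should be kept

# Old behaviour: the boolean leaf is silently promoted to int32.
old_blend = jax.tree_util.tree_map(
    lambda st, nst: (1 - emit) * st + emit * nst, acc_state, new_state)

# New behaviour: jnp.where picks a leaf without changing its dtype.
new_blend = jax.tree_util.tree_map(
    lambda st, nst: jnp.where(emit, nst, st), acc_state, new_state)

print(old_blend['found_nan'].dtype)  # int32
print(new_blend['found_nan'].dtype)  # bool
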
13 changes: 13 additions & 0 deletions optax/_src/wrappers_test.py
@@ -270,6 +270,19 @@ def test_multi_steps_every_k_schedule(self):
_, opt_state = opt_update(grad, opt_state, params)
self.assertTrue(ms_opt.has_updated(opt_state))

def test_multi_steps_zero_nans(self):
  # Test that MultiStep is compatible with zero_nans
  # https://github.com/google-deepmind/optax/issues/828
  ms_opt = wrappers.MultiSteps(
      combine.chain(constrain.zero_nans(), alias.sgd(1e-4)),
      every_k_schedule=2
  )
  opt_init, opt_update = ms_opt.gradient_transformation()
  params = dict(a=jnp.zeros([]))
  opt_state = opt_init(params)
  grad = dict(a=jnp.zeros([]))
  opt_update(grad, opt_state, params)

def test_multi_steps_computes_mean(self):
k_steps = 4
ms_opt = wrappers.MultiSteps(
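
For completeness, a minimal usage sketch of the behaviour exercised by the new test, written against the public optax API (optax.MultiSteps, optax.zero_nans, optax.chain and optax.sgd are the public counterparts of the modules used above): intermediate mini-steps return all-zero updates, and the inner optimizer only fires on every second call.

import jax.numpy as jnp
import optax

ms = optax.MultiSteps(
    optax.chain(optax.zero_nans(), optax.sgd(1e-4)), every_k_schedule=2)

params = {'a': jnp.zeros([])}
grads = {'a': jnp.zeros([])}
state = ms.init(params)

for step in range(4):
  updates, state = ms.update(grads, state, params)
  # Every second call emits a real update; the others accumulate and return zeros.
  print(step, ms.has_updated(state))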
