diff --git a/optax/_src/wrappers.py b/optax/_src/wrappers.py index f395ed32c..339274b9f 100644 --- a/optax/_src/wrappers.py +++ b/optax/_src/wrappers.py @@ -14,16 +14,12 @@ # ============================================================================== """Transformation wrappers.""" -import functools from typing import Any, Callable, NamedTuple, Optional, Protocol, Tuple, Union import chex import jax from jax import lax import jax.numpy as jnp -from jax.tree_util import tree_flatten -from jax.tree_util import tree_map -from jax.tree_util import tree_unflatten import numpy as np from optax._src import base from optax._src import numerics @@ -52,12 +48,12 @@ def flatten( def _flatten(params): """Flattens and concatenates all tensors in params to a single vector.""" - params, _ = tree_flatten(params) + params, _ = jax.tree_util.tree_flatten(params) return jnp.concatenate([jnp.reshape(param, [-1]) for param in params]) def _unflatten(updates, flat): """Extracts tensors from flat, using the structure and shapes of params.""" - updates_flat, treedef = tree_flatten(updates) + updates_flat, treedef = jax.tree_util.tree_flatten(updates) offsets = [] for update in updates_flat: size = np.prod(update.shape) @@ -71,7 +67,7 @@ def _unflatten(updates, flat): jnp.reshape(flat_update, update.shape) for flat_update, update in zip(flat_split, updates_flat) ] - return tree_unflatten(treedef, reshaped) + return jax.tree_util.tree_unflatten(treedef, reshaped) def init_fn(params): flat = _flatten(params) @@ -144,7 +140,7 @@ def init(params): def update(updates, state, params=None, **extra_args): inner_state = state.inner_state - flat_updates = tree_flatten(updates)[0] + flat_updates = jax.tree_util.tree_flatten(updates)[0] isfinite = jnp.all( jnp.array([jnp.all(jnp.isfinite(p)) for p in flat_updates])) notfinite_count = jnp.where( @@ -154,7 +150,7 @@ def update(updates, state, params=None, **extra_args): def do_update(_): return inner.update(updates, inner_state, params, **extra_args) def reject_update(_): - return (tree_map(jnp.zeros_like, updates), inner_state) + return (jax.tree_util.tree_map(jnp.zeros_like, updates), inner_state) updates, new_inner_state = lax.cond( jnp.logical_or(isfinite, notfinite_count > max_consecutive_errors), @@ -171,7 +167,7 @@ def reject_update(_): return base.GradientTransformationExtraArgs(init=init, update=update) -def _zeros_tree_like(inp_tree): +def _zeros_tree_like(inp_tree: chex.ArrayTree) -> chex.ArrayTree: return jax.tree_util.tree_map(jnp.zeros_like, inp_tree) @@ -376,15 +372,18 @@ def update(self, ) -> Tuple[base.Updates, MultiStepsState]: """Accumulates gradients and proposes non-zero updates every `k_steps`.""" k_steps = self._every_k_schedule(state.gradient_step) - acc_grads = jax.tree_util.tree_map( - functools.partial(self._acc_update, n_acc=state.mini_step), - updates, state.acc_grads) - should_skip_update, skip_state = self._should_skip_update_fn( updates, state.gradient_step, params) + if (should_skip_update.dtype, should_skip_update.shape) != (jnp.bool_, ()): + raise ValueError( + 'The `should_skip_update_fn` function should return a boolean scalar ' + f'array, but it returned an array of dtype {should_skip_update.dtype}' + f' and shape {should_skip_update.shape}' + ) - def final_step(args): - del args + # Note: we do not enclose variables to allow JAX to re-use memory buffers. + + def _final_step(state, params, acc_grads): final_updates, new_inner_state = self._opt.update( acc_grads, state.inner_opt_state, params=params, **extra_args) new_state = MultiStepsState( @@ -395,8 +394,7 @@ def final_step(args): skip_state=skip_state) return final_updates, new_state - def mid_step(args): - del args + def _mid_step(state, params, acc_grads): updates_shape_dtype, _ = jax.eval_shape( self._opt.update, acc_grads, state.inner_opt_state, params=params) mid_updates = jax.tree_util.tree_map( @@ -409,27 +407,29 @@ def mid_step(args): skip_state=skip_state) return mid_updates, new_state - new_updates, new_state = jax.lax.cond( - state.mini_step < k_steps - 1, (), mid_step, (), final_step) - - if (should_skip_update.dtype, should_skip_update.shape) != (jnp.bool_, ()): - raise ValueError( - 'The `should_skip_update_fn` function should return a boolean scalar ' - f'array, but it returned an array of dtype {should_skip_update.dtype}' - f' and shape {should_skip_update.shape}') + def _do_update(updates, state, params): + acc_grads = jax.tree_util.tree_map( + lambda upd, acc: self._acc_update(upd, acc, n_acc=state.mini_step), + updates, state.acc_grads) + new_updates, new_state = jax.lax.cond( + state.mini_step < k_steps - 1, + _mid_step, _final_step, *(state, params, acc_grads)) + return new_updates, new_state + + def _skip_update(updates, state, params): + del updates, params + multi_state_when_skip = MultiStepsState( + mini_step=state.mini_step, + gradient_step=state.gradient_step, + inner_opt_state=state.inner_opt_state, + acc_grads=state.acc_grads, + skip_state=skip_state, + ) + zero_updates = _zeros_tree_like(state.acc_grads) + return zero_updates, multi_state_when_skip - multi_state_when_skip = MultiStepsState( - mini_step=state.mini_step, - gradient_step=state.gradient_step, - inner_opt_state=state.inner_opt_state, - acc_grads=state.acc_grads, - skip_state=skip_state) - zero_updates = jax.tree_util.tree_map(jnp.zeros_like, updates) new_updates, new_state = jax.lax.cond( - should_skip_update, - (), lambda args: (zero_updates, multi_state_when_skip), - (), lambda args: (new_updates, new_state)) - + should_skip_update, _skip_update, _do_update, *(updates, state, params)) return new_updates, new_state def has_updated(self, state: Union[MultiStepsState, chex.ArrayTree]) -> Array: @@ -497,7 +497,9 @@ def masked( inner = base.with_extra_args_support(inner) def mask_pytree(pytree, mask_tree): - return tree_map(lambda m, p: p if m else MaskedNode(), mask_tree, pytree) + return jax.tree_util.tree_map( + lambda m, p: p if m else MaskedNode(), mask_tree, pytree + ) def init_fn(params): # This is a workaround to make tree_map_params work with masking. @@ -527,7 +529,7 @@ def update_fn(updates, state, params=None, **extra_args): new_masked_updates, new_inner_state = inner.update( masked_updates, state.inner_state, masked_params, **extra_args) - new_updates = tree_map( + new_updates = jax.tree_util.tree_map( lambda m, new_u, old_u: new_u if m else old_u, mask_tree, new_masked_updates, updates) return new_updates, MaskedState(inner_state=new_inner_state)