Make tree_map_params and masked work together
PiperOrigin-RevId: 549924283
rosshemsley authored and OptaxDev committed Jul 21, 2023
1 parent 2a8a517 commit 2b5ff1b
Showing 3 changed files with 133 additions and 7 deletions.
6 changes: 5 additions & 1 deletion optax/_src/state_utils.py
@@ -39,6 +39,7 @@ def tree_map_params(
/,
*rest: Any,
transform_non_params: Optional[Callable[..., Any]] = None,
is_leaf: Optional[Callable[[base.Params], bool]] = None,
) -> base.OptState:
"""Apply a callable over all params in the given optimizer state.
@@ -71,6 +72,9 @@ def tree_map_params(
that will be passed to f.
transform_non_params: An optional function that will be called on all
non-parameter fields within the optimizer state.
is_leaf: Passed through to `jax.tree_map`. This makes it possible to ignore
parts of the parameter tree, for example when a gradient transformation
modifies the structure of the original pytree, as ``optax.masked`` does.
Returns:
The result of applying the function f on all trees in the optimizer's state
@@ -89,7 +93,7 @@ def tree_map_params(

def map_params(maybe_placeholder_value, value):
if isinstance(maybe_placeholder_value, _ParamsPlaceholder):
return jax.tree_map(f, value, *rest)
return jax.tree_map(f, value, *rest, is_leaf=is_leaf)
elif transform_non_params is not None:
return transform_non_params(value)
else:
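
The sketch below is not part of the commit; it illustrates the call pattern that the `is_leaf` argument above extends, assuming `tree_map_params` is exposed under the public alias `optax.tree_map_params` (the function itself lives in `optax/_src/state_utils.py`). The parameter names are illustrative only.

import jax.numpy as jnp
import optax

params = {'w': jnp.zeros((3, 2)), 'b': jnp.zeros((2,))}
opt = optax.adam(1e-3)
state = opt.init(params)

# The callable is applied to every parameter-shaped leaf stored in the state
# (Adam's mu and nu here); non-parameter fields such as the step count are
# only touched if transform_non_params is given, and the new is_leaf argument
# is simply forwarded to jax.tree_map.
new_state = optax.tree_map_params(
    opt,
    lambda p: p.astype(jnp.bfloat16),
    state,
)
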
22 changes: 22 additions & 0 deletions optax/_src/wrappers.py
@@ -27,6 +27,8 @@
import numpy as np
from optax._src import base
from optax._src import numerics
from optax._src import state_utils


Array = jnp.ndarray

@@ -473,6 +475,11 @@ def masked(
For the ``inner`` transform, state will only be stored for the parameters that
have a mask value of ``True``.
Note that, when using ``tree_map_params``, it may be necessary to pass
`is_leaf=lambda v: isinstance(v, optax.MaskedNode)` if the tree map takes
additional arguments with the same structure as the original parameter
tree.
Args:
inner: Inner transformation to mask.
mask: a PyTree with same structure as (or a prefix of) the params PyTree, or
@@ -490,6 +497,21 @@ def mask_pytree(pytree, mask_tree):
return tree_map(lambda m, p: p if m else MaskedNode(), mask_tree, pytree)

def init_fn(params):
# This is a workaround to make tree_map_params work with masking.
# The API of `masked` takes a mask on construction, instead of at init.
# This means that this gradient transformation can only work for parameter
# trees that match the structure of the mask. Strictly speaking this breaks
# the optax API, and in particular it breaks tree_map_params, because
# tree_map_params calls init with a placeholder in order to detect copies
# of the parameter tree. As a (slightly ugly) workaround, we detect when
# init is being called by tree_map_params and pass the placeholder down
# without masking. This is safe, since tree_map_params does not impose any
# particular constraints on the structure of the parameter tree, as long as
# tree_map_params is called on a tree with the correct structure.
# See wrappers_test for proof that this works!
if isinstance(params, state_utils._ParamsPlaceholder): # pylint:disable=protected-access
return MaskedState(inner_state=inner.init(params))

mask_tree = mask(params) if callable(mask) else mask
masked_params = mask_pytree(params, mask_tree)
return MaskedState(inner_state=inner.init(masked_params))
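
The following sketch is not part of the commit; it shows how the note and the placeholder workaround above fit together, assuming `optax.masked`, `optax.MaskedNode` and `optax.tree_map_params` are the public aliases for the code in this diff. The `axes` labels are purely illustrative.

import jax.numpy as jnp
import optax

params = {'w': jnp.zeros((3, 2)), 'b': jnp.zeros((2,))}
mask = {'w': True, 'b': False}       # only 'w' is handled by the inner transform
axes = {'w': 'data', 'b': 'unused'}  # extra labels with the params' structure

opt = optax.masked(optax.adam(1e-3), mask)
state = opt.init(params)             # masked leaves are stored as optax.MaskedNode()

# Attach a label to every unmasked parameter leaf in the state. The is_leaf
# callable stops jax.tree_map from recursing into MaskedNode (an empty pytree
# node), so the extra `axes` tree still lines up with the state. The call
# only succeeds because masked.init forwards the placeholder used by
# tree_map_params straight to the inner init, as implemented above.
labelled = optax.tree_map_params(
    opt,
    lambda p, axis: None if isinstance(p, optax.MaskedNode) else axis,
    state,
    axes,
    is_leaf=lambda v: isinstance(v, optax.MaskedNode),
    transform_non_params=lambda _: None,
)
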
112 changes: 106 additions & 6 deletions optax/_src/wrappers_test.py
@@ -15,10 +15,10 @@
"""Tests for `wrappers.py`."""

import copy
from typing import cast

from absl.testing import absltest
from absl.testing import parameterized

import chex
import haiku as hk
import jax
@@ -390,6 +390,107 @@ def test_multi_steps_skip_not_finite(self):
class MaskedTest(chex.TestCase):
"""Tests for the masked wrapper."""

def test_tree_map_params(self):
params = {
'a': {
'b': (jnp.zeros((1, 2)), jnp.zeros((2, 2))),
},
'c': {
'd': jnp.zeros((1, 2)),
'e': (jnp.zeros((1, 2)), jnp.zeros((1, 2))),
},
}

sharding_axes = {
'a': {
'b': (1, 2),
},
'c': {
'd': 1,
'e': (1, 2),
},
}

mask = {
'a': {
'b': (True, False),
},
'c': {
'd': True,
'e': (False, True),
},
}

expected = {
'a': {
'b': (jnp.ones((1, 2)), jnp.zeros((2, 2))),
},
'c': {
'd': jnp.ones((1, 2)),
'e': (jnp.ones((1, 2)), jnp.ones((1, 2))),
},
}

def init_fn(params):
return {'count': 1, 'params': params, 'params_copy': params}

def update_fn(updates, state, params=None):
del params
return updates, state

inner = base.GradientTransformation(init_fn, update_fn)
masked = wrappers.masked(inner, mask)

def increment_dim_1(v):
return v + 1 if v.shape[0] == 1 else v

# For this optimizer, tree_map_params should have the same effect on a
# masked optimizer state as it does on an unmasked optimizer state.
with self.subTest('inner'):
state = inner.init(params)
result = state_utils.tree_map_params(inner, increment_dim_1, state)
chex.assert_trees_all_equal(result, inner.init(expected))

with self.subTest('masked'):
state = masked.init(params)
result = state_utils.tree_map_params(masked, increment_dim_1, state)
chex.assert_trees_all_equal(result, masked.init(expected))

with self.subTest('masked_with_extra_args'):
# Users wishing to pass additional arguments with the same tree structure
# as the original params pytree need to supply an `is_leaf` callable, so
# that the masked parts of the pytree are treated as leaves rather than
# recursed into.

# Replace all non-masked parameters in the opt-state tree with the
# sharding axis values given in the tree above. Everything else is set to
# None.
new_state = state_utils.tree_map_params(
masked,
lambda p, axis: None if isinstance(p, wrappers.MaskedNode) else axis,
state,
sharding_axes,
is_leaf=lambda v: isinstance(v, wrappers.MaskedNode),
transform_non_params=lambda v: None,
)

sharded_params = {
'a': {
'b': (1, None),
},
'c': {
'd': 1,
'e': (None, 2),
},
}

# Required to make pytype happy
new_state = cast(wrappers.MaskedState, new_state)

chex.assert_equal(None, new_state.inner_state['count'])
chex.assert_equal(sharded_params, new_state.inner_state['params'])
chex.assert_equal(sharded_params, new_state.inner_state['params_copy'])

@chex.all_variants
@parameterized.named_parameters(
('sgd', _build_sgd, False),
@@ -416,10 +517,9 @@ def masked_negate(updates):
update_fn = self.variant(update_fn)
state = self.variant(init_fn)(params)

# Known issue: masked does not work with arbitrary parameter trees, and
# so does not work with tree_map_params.
with self.assertRaises(ValueError):
state_utils.tree_map_params(init_fn, lambda v: v, state)
with self.subTest('tree_map_params'):
result = state_utils.tree_map_params(init_fn, lambda v: v, state)
chex.assert_tree_all_equal_structs(result, state)

updates, state = update_fn(input_updates, state, params)
chex.assert_trees_all_close(updates, correct_updates)
@@ -456,7 +556,7 @@ def _masked_sgd_on_updates(m, upd):
# Check repeated application, this time with no params.
correct_updates = jax.tree_util.tree_map(
_masked_sgd_on_updates, mask, correct_updates)
updates, state = update_fn(updates, state)
updates, _ = update_fn(updates, state)
chex.assert_trees_all_close(updates, correct_updates)

@chex.all_variants
