diff --git a/haiku/_src/layer_stack.py b/haiku/_src/layer_stack.py index 43c34b8ec..ad62feec9 100644 --- a/haiku/_src/layer_stack.py +++ b/haiku/_src/layer_stack.py @@ -89,13 +89,8 @@ def _split_params( name_map: LayerStackTransparencyMapping, ) -> base.Params: """Splits the stacked parameters.""" - - def _split(x): - return [jnp.squeeze(s, axis=0) for s in jnp.split(x, x.shape[0], axis=0)] - params = {} for mod_name, mod_params in stacked_params.items(): - split_mod_params = {k: _split(v) for k, v in mod_params.items()} for i in range(num_layers): new_mod_name = name_map.stacked_to_flat(mod_name, i) if new_mod_name in params: @@ -103,8 +98,7 @@ def _split(x): f"Found conflicting unstacked module name for {mod_name} at" f" {new_mod_name}." ) - params[new_mod_name] = {k: v[i] for k, v in split_mod_params.items()} - + params[new_mod_name] = jax.tree_map(lambda x: x[i], mod_params) # pylint:disable=cell-var-from-loop return params @@ -114,32 +108,27 @@ def _stack_params( name_map: LayerStackTransparencyMapping, ) -> base.Params: """Stacks the split parameters.""" - params = {} - make_empty_param_stack = lambda: ([None] * num_layers) - + # Construct a separate tree for each loop iteration, which we will then + # multimap over in a call to jnp.stack. This formulation preserves custom + # pytree node types. + param_trees = [{} for _ in range(num_layers)] for mod_name, mod_params in split_params.items(): stacked_name_idx = name_map.flat_to_stacked(mod_name) + # If the transparency map returns None, this param is not part of the stack. 
if stacked_name_idx is None: continue stacked_mod_name, idx = stacked_name_idx - if stacked_mod_name not in params: - params[stacked_mod_name] = collections.defaultdict(make_empty_param_stack) - + if stacked_mod_name not in param_trees[idx]: + param_trees[idx][stacked_mod_name] = {} for k, v in mod_params.items(): - if params[stacked_mod_name][k][idx] is not None: + if k in param_trees[idx][stacked_mod_name]: raise ValueError( f"Found conflicting values for param {stacked_mod_name}/{k} at" f" index {idx}." ) - params[stacked_mod_name][k][idx] = v - - for mod_name, mod_params in params.items(): - for k, v in mod_params.items(): - if None in v: - raise ValueError(f"Couldn't find all params for {mod_name}/{k}: {v}") - mod_params[k] = jnp.stack(v, axis=0) + param_trees[idx][stacked_mod_name][k] = v - return params + return jax.tree_map(lambda *args: jnp.stack(args, axis=0), *param_trees) class _LayerStack: diff --git a/haiku/_src/layer_stack_test.py b/haiku/_src/layer_stack_test.py index 1fb9fc3b5..20a534e8c 100644 --- a/haiku/_src/layer_stack_test.py +++ b/haiku/_src/layer_stack_test.py @@ -17,7 +17,6 @@ import functools import re from typing import Optional - from absl.testing import absltest from absl.testing import parameterized from haiku._src import base @@ -598,6 +597,83 @@ def stacked(x: jax.Array) -> jax.Array: rtol=1e-6, ) + def test_layer_stack_transparent_with_custom_pytrees(self): + class TransparencyMap(layer_stack.LayerStackTransparencyMapping): + + def stacked_to_flat(self, stacked_module_name: str, scan_idx: int) -> str: + return stacked_module_name.replace("0", str(scan_idx)) + + def flat_to_stacked( + self, unstacked_module_name: str + ) -> Optional[tuple[str, int]]: + idx = int(re.findall(r"\d+", unstacked_module_name)[0]) + return unstacked_module_name.replace(str(idx), "0"), idx + + @jax.tree_util.register_pytree_node_class + class CustomParam: + + def __init__(self, param, name): + self.param = param + self.multiplier = name + + def 
tree_flatten(self): + return ((self.param, self.multiplier), None) + + @classmethod + def tree_unflatten(cls, aux, values): + del aux + return cls(*values) + + @property + def shape(self) -> list[int]: + return self.param.shape + + class CustomLinear: + + def __init__(self, *args, **kwargs): + self.linear = basic.Linear(*args, **kwargs) + + def __call__(self, x: CustomParam) -> CustomParam: + # Unwrap from CustomParam before invoking linear + return CustomParam( + self.linear(x.param * x.multiplier), + x.multiplier, + ) + + def block(x: CustomParam, i: int) -> CustomParam: + return CustomLinear(output_size=x.shape[-1], name=f"linear_{i}")(x) + + def looped(x: CustomParam, num_layers: int = 1) -> CustomParam: + for i in range(num_layers): + x = block(x, i) + return x + + def stacked(x: CustomParam) -> CustomParam: + return layer_stack.layer_stack( + num_layers=1, transparent=True, transparency_map=TransparencyMap() + )(lambda y: block(y, 0))(x) + + looped = transform.transform(looped) + stacked = transform.transform(stacked) + + x = CustomParam(jnp.ones((2, 2)), 0.3) + rng = jax.random.PRNGKey(0) + looped_params = looped.init(rng, x) + stacked_params = stacked.init(rng, x) + + self.assertEqual( + jax.tree_util.tree_structure(looped_params), + jax.tree_util.tree_structure(stacked_params), + ) + + # Use the same set of params for both calls since stacked_params have different + # values than looped params because of differences in RNG splitting. + np.testing.assert_allclose( + looped.apply(looped_params, rng, x).param, + stacked.apply(looped_params, rng, x).param, + rtol=1e-6, + ) + if __name__ == "__main__": jax.config.update("jax_check_tracer_leaks", True)