diff --git a/haiku/_src/layer_stack.py b/haiku/_src/layer_stack.py index cc93129a1..a2cd622b6 100644 --- a/haiku/_src/layer_stack.py +++ b/haiku/_src/layer_stack.py @@ -89,13 +89,8 @@ def _split_params( name_map: LayerStackTransparencyMapping, ) -> base.Params: """Splits the stacked parameters.""" - - def _split(x): - return [jnp.squeeze(s, axis=0) for s in jnp.split(x, x.shape[0], axis=0)] - params = {} for mod_name, mod_params in stacked_params.items(): - split_mod_params = {k: _split(v) for k, v in mod_params.items()} for i in range(num_layers): new_mod_name = name_map.stacked_to_flat(mod_name, i) if new_mod_name in params: @@ -103,8 +98,7 @@ def _split(x): f"Found conflicting unstacked module name for {mod_name} at" f" {new_mod_name}." ) - params[new_mod_name] = {k: v[i] for k, v in split_mod_params.items()} - + params[new_mod_name] = jax.tree_map(lambda x: x[i], mod_params) # pylint:disable=cell-var-from-loop return params @@ -114,32 +108,27 @@ def _stack_params( name_map: LayerStackTransparencyMapping, ) -> base.Params: """Stacks the split parameters.""" - params = {} - make_empty_param_stack = lambda: ([None] * num_layers) - + # Construct a separate tree for each loop iteration, which we will then + # multimap over in a call to jnp.stack. This formulation preserves custom + # pytree node types. + param_trees = [{} for _ in range(num_layers)] for mod_name, mod_params in split_params.items(): stacked_name_idx = name_map.flat_to_stacked(mod_name) + # If the transparency map returns None, this param is not part of the stack. if stacked_name_idx is None: continue stacked_mod_name, idx = stacked_name_idx - if stacked_mod_name not in params: - params[stacked_mod_name] = collections.defaultdict(make_empty_param_stack) - + if stacked_mod_name not in param_trees[idx]: + param_trees[idx][stacked_mod_name] = {} for k, v in mod_params.items(): - if params[stacked_mod_name][k][idx] is not None: + if k in param_trees[idx][stacked_mod_name]: raise ValueError( f"Found conflicting values for param {stacked_mod_name}/{k} at" f" index {idx}." ) - params[stacked_mod_name][k][idx] = v - - for mod_name, mod_params in params.items(): - for k, v in mod_params.items(): - if None in v: - raise ValueError(f"Couldn't find all params for {mod_name}/{k}: {v}") - mod_params[k] = jnp.stack(v, axis=0) + param_trees[idx][stacked_mod_name][k] = v - return params + return jax.tree_map(lambda *args: jnp.stack(args, axis=0), *param_trees) class _LayerStack: