
modify layer_stack transparency map
PiperOrigin-RevId: 598005552
Haiku Contributor authored and copybara-github committed Jan 26, 2024
1 parent 0898b7b commit 1273b9c
Showing 2 changed files with 88 additions and 23 deletions.
33 changes: 11 additions & 22 deletions haiku/_src/layer_stack.py
@@ -89,22 +89,16 @@ def _split_params(
     name_map: LayerStackTransparencyMapping,
 ) -> base.Params:
   """Splits the stacked parameters."""
-
-  def _split(x):
-    return [jnp.squeeze(s, axis=0) for s in jnp.split(x, x.shape[0], axis=0)]
-
   params = {}
   for mod_name, mod_params in stacked_params.items():
-    split_mod_params = {k: _split(v) for k, v in mod_params.items()}
     for i in range(num_layers):
       new_mod_name = name_map.stacked_to_flat(mod_name, i)
       if new_mod_name in params:
         raise ValueError(
             f"Found conflicting unstacked module name for {mod_name} at"
             f" {new_mod_name}."
         )
-      params[new_mod_name] = {k: v[i] for k, v in split_mod_params.items()}
-
+      params[new_mod_name] = jax.tree_map(lambda x: x[i], mod_params)  # pylint:disable=cell-var-from-loop
   return params


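For reference, here is a minimal, self-contained sketch (not part of the commit) of the slicing idiom the new _split_params relies on: jax.tree_map(lambda x: x[i], ...) indexes the leading (layer) axis of every leaf while leaving the tree structure, including any custom pytree nodes, untouched. The module and parameter names below are hypothetical.

import jax
import jax.numpy as jnp

num_layers = 4
# A stacked parameter tree: every leaf carries a leading layer axis.
stacked = {
    "mlp/linear_0": {
        "w": jnp.ones((num_layers, 8, 16)),
        "b": jnp.zeros((num_layers, 16)),
    }
}

# Slice out layer i from every leaf; binding i=i avoids the
# cell-var-from-loop pitfall noted in the diff above.
per_layer = [
    jax.tree_map(lambda x, i=i: x[i], stacked["mlp/linear_0"])
    for i in range(num_layers)
]
assert per_layer[0]["w"].shape == (8, 16)
assert per_layer[0]["b"].shape == (16,)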
@@ -114,32 +108,27 @@ def _stack_params(
     name_map: LayerStackTransparencyMapping,
 ) -> base.Params:
   """Stacks the split parameters."""
-  params = {}
-  make_empty_param_stack = lambda: ([None] * num_layers)
-
+  # Construct a separate tree for each loop iteration, which we will then
+  # multimap over in a call to jnp.stack. This formulation preserves custom
+  # pytree node types.
+  param_trees = [{} for _ in range(num_layers)]
   for mod_name, mod_params in split_params.items():
     stacked_name_idx = name_map.flat_to_stacked(mod_name)
     # If the transparency map returns None, this param is not part of the stack.
     if stacked_name_idx is None:
       continue
     stacked_mod_name, idx = stacked_name_idx
-    if stacked_mod_name not in params:
-      params[stacked_mod_name] = collections.defaultdict(make_empty_param_stack)
-
+    if stacked_mod_name not in param_trees[idx]:
+      param_trees[idx][stacked_mod_name] = {}
     for k, v in mod_params.items():
-      if params[stacked_mod_name][k][idx] is not None:
+      if k in param_trees[idx][stacked_mod_name]:
         raise ValueError(
             f"Found conflicting values for param {stacked_mod_name}/{k} at"
             f" index {idx}."
         )
-      params[stacked_mod_name][k][idx] = v
-
-  for mod_name, mod_params in params.items():
-    for k, v in mod_params.items():
-      if None in v:
-        raise ValueError(f"Couldn't find all params for {mod_name}/{k}: {v}")
-      mod_params[k] = jnp.stack(v, axis=0)
+      param_trees[idx][stacked_mod_name][k] = v
 
-  return params
+  return jax.tree_map(lambda *args: jnp.stack(args, axis=0), *param_trees)
 
 
 class _LayerStack:
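And the reverse direction, as a hedged sketch of the idea behind the new _stack_params: build one small tree per layer, then multimap jnp.stack over them in a single jax.tree_map call, which reconstructs custom pytree nodes around the stacked leaves. The Scaled class below is a hypothetical stand-in for a user-defined pytree node, not something from Haiku.

import jax
import jax.numpy as jnp

@jax.tree_util.register_pytree_node_class
class Scaled:  # hypothetical custom pytree node
  def __init__(self, value, scale):
    self.value = value
    self.scale = scale

  def tree_flatten(self):
    return (self.value,), self.scale

  @classmethod
  def tree_unflatten(cls, scale, children):
    return cls(children[0], scale)

num_layers = 3
# One parameter tree per layer, mirroring the param_trees formulation above.
param_trees = [
    {"linear": {"w": Scaled(jnp.full((2, 2), float(i)), scale=0.5)}}
    for i in range(num_layers)
]

# Multimap jnp.stack over the per-layer trees: each leaf gains a leading
# layer axis, and the Scaled wrapper is rebuilt around the stacked result.
stacked = jax.tree_map(lambda *xs: jnp.stack(xs, axis=0), *param_trees)
assert stacked["linear"]["w"].value.shape == (num_layers, 2, 2)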
78 changes: 77 additions & 1 deletion haiku/_src/layer_stack_test.py
@@ -17,7 +17,6 @@
 import functools
 import re
 from typing import Optional
-
 from absl.testing import absltest
 from absl.testing import parameterized
 from haiku._src import base
@@ -598,6 +597,83 @@ def stacked(x: jax.Array) -> jax.Array:
         rtol=1e-6,
     )
 
+  def test_layer_stack_transparent_with_custom_pytrees(self):
+    class TransparencyMap(layer_stack.LayerStackTransparencyMapping):
+
+      def stacked_to_flat(self, stacked_module_name: str, scan_idx: int) -> str:
+        return stacked_module_name.replace("0", str(scan_idx))
+
+      def flat_to_stacked(
+          self, unstacked_module_name: str
+      ) -> Optional[tuple[str, int]]:
+        idx = int(re.findall(r"\d+", unstacked_module_name)[0])
+        return unstacked_module_name.replace(str(idx), "0"), idx
+
+    @jax.tree_util.register_pytree_node_class
+    class CustomParam:
+
+      def __init__(self, param, multiplier):
+        self.param = param
+        self.multiplier = multiplier
+
+      def tree_flatten(self):
+        return ((self.param, self.multiplier), None)
+
+      @classmethod
+      def tree_unflatten(cls, aux, values):
+        del aux
+        return cls(*values)
+
+      @property
+      def shape(self) -> tuple[int, ...]:
+        return self.param.shape
+
+    class CustomLinear:
+
+      def __init__(self, *args, **kwargs):
+        self.linear = basic.Linear(*args, **kwargs)
+
+      def __call__(self, x: CustomParam) -> CustomParam:
+        # Unwrap from CustomParam before invoking the linear layer.
+        return CustomParam(
+            self.linear(x.param * x.multiplier),
+            x.multiplier,
+        )
+
+    def block(x: CustomParam, i: int) -> CustomParam:
+      return CustomLinear(output_size=x.shape[-1], name=f"linear_{i}")(x)
+
+    def looped(x: CustomParam, num_layers: int = 1) -> CustomParam:
+      for i in range(num_layers):
+        x = block(x, i)
+      return x
+
+    def stacked(x: CustomParam) -> CustomParam:
+      return layer_stack.layer_stack(
+          num_layers=1, transparent=True, transparency_map=TransparencyMap()
+      )(lambda y: block(y, 0))(x)
+
+    looped = transform.transform(looped)
+    stacked = transform.transform(stacked)
+
+    x = CustomParam(jnp.ones((2, 2)), 0.3)
+    rng = jax.random.PRNGKey(0)
+    looped_params = looped.init(rng, x)
+    stacked_params = stacked.init(rng, x)
+
+    self.assertEqual(
+        jax.tree_util.tree_structure(looped_params),
+        jax.tree_util.tree_structure(stacked_params),
+    )
+
+    # Use the same set of params for both calls, since stacked_params differ
+    # from looped_params due to differences in RNG splitting.
+    np.testing.assert_allclose(
+        looped.apply(looped_params, rng, x).param,
+        stacked.apply(looped_params, rng, x).param,
+        rtol=1e-6,
+    )
+
 
 if __name__ == "__main__":
   jax.config.update("jax_check_tracer_leaks", True)
