Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

enable state with layer stack #761

Merged
merged 1 commit into from
Jan 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 28 additions & 34 deletions haiku/_src/layer_stack.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,9 @@
import jax.numpy as jnp


class LayerStackStateError(Exception):
"""Raise if trying to use layer_stack with Haiku state."""

LayerStackCarry = collections.namedtuple("LayerStackCarry", ["x"])
LayerStackScanned = collections.namedtuple("LayerStackScanned",
["params", "rng", "args_ys"])
LayerStackScanned = collections.namedtuple(
"LayerStackScanned", ["params", "rng", "state", "args_ys"])

# WrappedFn should take in arbitrarily nested `jax.Array`, and return the
# exact same type. We cannot express this with `typing`. So we just use it
Expand Down Expand Up @@ -151,42 +148,37 @@ def __init__(

def __call__(self, x, *args_ys, reverse=False):
count = self._count
init_fn, apply_fn = transform.transform(self._call_wrapped)
init_fn, apply_fn = transform.transform_with_state(self._call_wrapped)

def per_layer_init_fn(c, a):
c, rng = c
if rng is not None:
rng, next_rng, apply_rng = jax.random.split(rng, 3)
else:
rng, next_rng, apply_rng = None, None, None
params = init_fn(rng, c, *a)
c, _ = apply_fn(params, apply_rng, c, *a)
return (c, next_rng), params
params, state = init_fn(rng, c, *a)
(c, _), state = apply_fn(params, state, apply_rng, c, *a)
return (c, next_rng), (params, state)

def scanned_init_fn(x, rng):
_, params = jax.lax.scan(per_layer_init_fn, (x, rng), args_ys,
length=self._count)
_, (params, state) = jax.lax.scan(per_layer_init_fn, (x, rng), args_ys,
length=self._count)
if self._transparency_map is not None:
return _split_params(params, self._count, self._transparency_map)
else:
return params
return (_split_params(params, self._count, self._transparency_map),
_split_params(state, self._count, self._transparency_map))
return params, state

rng = base.maybe_next_rng_key()

try:
if self._transparency_map is not None:
lifted_init_fn = lift.transparent_lift(
scanned_init_fn, allow_reuse=True
)
else:
lifted_init_fn = lift.lift(
scanned_init_fn, allow_reuse=True, name=self._name
)
params = lifted_init_fn(x, rng)
except base.NonEmptyStateError as e:
raise LayerStackStateError("LayerStack can only be used on Haiku "
"functions which do not make use of Haiku "
"state.") from e
if self._transparency_map is not None:
params_and_state_fn, updater = lift.transparent_lift_with_state(
scanned_init_fn, allow_reuse=True
)
else:
params_and_state_fn, updater = lift.lift_with_state(
scanned_init_fn, allow_reuse=True, name=self._name
)
params, state = params_and_state_fn(x, rng)

# Use scan during apply, threading through random seed so that it's
# unique for each layer.
Expand All @@ -195,26 +187,31 @@ def layer(
) -> tuple[LayerStackCarry, Any]:
rng = scanned.rng
params = scanned.params
state = scanned.state

kwargs = {}
if self._pass_reverse_to_layer_fn:
kwargs["reverse"] = reverse
out_x, z = apply_fn(params, rng, carry.x, *scanned.args_ys, **kwargs)
return LayerStackCarry(x=out_x), z
(out_x, z), state = apply_fn(
params, state, rng, carry.x, *scanned.args_ys, **kwargs)
return LayerStackCarry(x=out_x), (z, state)

rng = _get_rng_stack(count)

if self._transparency_map is not None:
params = _stack_params(params, self._count, self._transparency_map)
state = _stack_params(state, self._count, self._transparency_map)

carry = LayerStackCarry(x=x)
scanned = LayerStackScanned(params=params,
state=state,
rng=rng,
args_ys=args_ys)

carry, zs = jax.lax.scan(
carry, (zs, states) = jax.lax.scan(
layer, carry, scanned, length=count, unroll=self._unroll,
reverse=reverse)
updater.update(states)
return carry.x, zs

def _call_wrapped(
Expand Down Expand Up @@ -302,9 +299,6 @@ def layer_stack(
that kwargs are not supported, neither are functions with variable number
of parameters (specified by ``*args``).

Note that `layer_stack` cannot at the moment be used with functions that build
Haiku modules with state.

If ``with_per_layer_inputs=False`` then the new, wrapped function can be
understood as performing the following:

Expand Down
16 changes: 11 additions & 5 deletions haiku/_src/layer_stack_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,18 +171,24 @@ def stack_fn(x):
ValueError, "The function `f` should not have any `varargs`"):
build_and_init_stack(VarArgsModule)

def test_layer_stack_no_state_error(self):
def test_layer_stack_with_state(self):
def outer_fn_layer_stack(x):
stack = layer_stack.layer_stack(1)(lambda x: base.set_state("hi", x))
def simple_stateful_layer(x):
base.set_state("hi", x)
return x + 1
stack = layer_stack.layer_stack(
5, name="with_state")(simple_stateful_layer)
return stack(x)

layer_stack_fn = transform.transform_with_state(outer_fn_layer_stack)

x = jnp.ones((1,))

with self.assertRaisesRegex(layer_stack.LayerStackStateError,
"LayerStack.*state"):
layer_stack_fn.init(None, x)
params, state = layer_stack_fn.init(None, x)
_, state = layer_stack_fn.apply(params, state, None, x)

np.testing.assert_allclose(state["with_state/~"]["hi"],
np.array([[1.0, 2.0, 3.0, 4.0, 5.0]]).T)

@parameterized.parameters([1, 2, 4])
def test_layer_stack_grads(self, unroll):
Expand Down
Loading