Skip to content

Commit

Permalink
[JAX] Update users of jax.tree.map() to be more careful about how the…
Browse files Browse the repository at this point in the history
…y handle Nones.

Due to a bug in JAX, JAX previously permitted `jax.tree.map(f, None, x)` where `x` is not `None`, effectively treating `None` as if it were pytree-prefix of any value. But `None` is a pytree container, and it is only a prefix of `None` itself.

Fix user code that was relying on this bug. Most commonly, the fix is to write
`jax.tree.map(lambda a, b: (None if a is None else f(a, b)), x, y, is_leaf=lambda t: t is None)`.

PiperOrigin-RevId: 641832913
  • Loading branch information
hawkinsp authored and copybara-github committed Jun 10, 2024
1 parent 9ccaae1 commit a7b7e73
Showing 1 changed file with 10 additions and 3 deletions.
13 changes: 10 additions & 3 deletions haiku/_src/recurrent.py
Original file line number Diff line number Diff line change
Expand Up @@ -745,10 +745,16 @@ def __call__(self, inputs, state):
next_states = []
outputs = []
state_idx = 0
concat = lambda *args: jnp.concatenate(args, axis=-1)
for idx, layer in enumerate(self.layers):
if self.skip_connections and idx > 0:
current_inputs = jax.tree_util.tree_map(concat, inputs, current_inputs)
current_inputs = jax.tree_util.tree_map(
lambda x, *args: (
None if x is None else jnp.concatenate((x,) + args, axis=-1)
),
inputs,
current_inputs,
is_leaf=lambda x: x is None,
)

if isinstance(layer, RNNCore):
current_inputs, next_state = layer(current_inputs, state[state_idx])
Expand All @@ -759,7 +765,8 @@ def __call__(self, inputs, state):
current_inputs = layer(current_inputs)

if self.skip_connections:
out = jax.tree_util.tree_map(concat, *outputs)
out = jax.tree_util.tree_map(lambda *args: jnp.concatenate(args, axis=-1),
*outputs)
else:
out = current_inputs

Expand Down

0 comments on commit a7b7e73

Please sign in to comment.