From 84c5eaa6d9d2c8bcc3a58f22dc499578253e60d1 Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Thu, 6 Jun 2024 16:54:09 -0700 Subject: [PATCH] make haiku work with upcoming JAX change to tree_map (being more careful about Nones) PiperOrigin-RevId: 641064728 --- haiku/_src/stateful.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/haiku/_src/stateful.py b/haiku/_src/stateful.py index d2231c4c4..681736c9d 100644 --- a/haiku/_src/stateful.py +++ b/haiku/_src/stateful.py @@ -699,8 +699,9 @@ def pure_body_fun(i, val): return val -def maybe_get_axis(axis: int, arrays: Any) -> Optional[int]: +def maybe_get_axis(axis: Optional[int], arrays: Any) -> Optional[int]: """Returns `array.shape[axis]` for one of the arrays in the input.""" + if axis is None: return None shapes = [a.shape for a in jax.tree_util.tree_leaves(arrays)] sizes = {s[axis] for s in shapes} if len(sizes) != 1: @@ -715,7 +716,8 @@ def maybe_get_axis(axis: int, arrays: Any) -> Optional[int]: def get_mapped_axis_size(args: tuple[Any], in_axes: Any) -> int: sizes = uniq(jax.tree_util.tree_leaves( - jax.tree_util.tree_map(maybe_get_axis, in_axes, args))) + jax.tree_util.tree_map(maybe_get_axis, in_axes, args, + is_leaf=lambda x: x is None))) assert sizes, "hk.vmap should guarantee non-empty in_axes" # NOTE: We use the first in_axes regardless of how many non-unique values # there are to allow JAX to handle multiple conflicting sizes.