Skip to content

Commit

Permalink
make haiku work with upcoming JAX change to tree_map (being more care…
Browse files Browse the repository at this point in the history
…ful about

Nones)

PiperOrigin-RevId: 641064728
  • Loading branch information
mattjj authored and copybara-github committed Jun 7, 2024
1 parent 0451621 commit 84c5eaa
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions haiku/_src/stateful.py
Original file line number Diff line number Diff line change
Expand Up @@ -699,8 +699,9 @@ def pure_body_fun(i, val):
return val


def maybe_get_axis(axis: int, arrays: Any) -> Optional[int]:
def maybe_get_axis(axis: Optional[int], arrays: Any) -> Optional[int]:
"""Returns `array.shape[axis]` for one of the arrays in the input."""
if axis is None: return None
shapes = [a.shape for a in jax.tree_util.tree_leaves(arrays)]
sizes = {s[axis] for s in shapes}
if len(sizes) != 1:
Expand All @@ -715,7 +716,8 @@ def maybe_get_axis(axis: int, arrays: Any) -> Optional[int]:

def get_mapped_axis_size(args: tuple[Any], in_axes: Any) -> int:
sizes = uniq(jax.tree_util.tree_leaves(
jax.tree_util.tree_map(maybe_get_axis, in_axes, args)))
jax.tree_util.tree_map(maybe_get_axis, in_axes, args,
is_leaf=lambda x: x is None)))
assert sizes, "hk.vmap should guarantee non-empty in_axes"
# NOTE: We use the first in_axes regardless of how many non-unique values
# there are to allow JAX to handle multiple conflicting sizes.
Expand Down

0 comments on commit 84c5eaa

Please sign in to comment.