Skip to content

Commit

Permalink
[JAX] Update users of jax.tree.map() to be more careful about how the…
Browse files Browse the repository at this point in the history
…y handle Nones.

Due to a bug in JAX, JAX previously permitted `jax.tree.map(f, None, x)` where `x` is not `None`, effectively treating `None` as if it were pytree-prefix of any value. But `None` is a pytree container, and it is only a prefix of `None` itself.

Fix user code that was relying on this bug. Most commonly, the fix is to write
`jax.tree.map(lambda a, b: (None if a is None else f(a, b)), x, y, is_leaf=lambda t: t is None)`.

PiperOrigin-RevId: 641681067
  • Loading branch information
hawkinsp authored and copybara-github committed Jun 9, 2024
1 parent e9d6270 commit 3292138
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion clrs/_src/baselines.py
Original file line number Diff line number Diff line change
Expand Up @@ -766,7 +766,9 @@ def _keep_in_algo(k, v):
masked_grads = grads
else:
masked_grads = {k: _keep_in_algo(k, v) for k, v in grads.items()}
flat_grads, treedef = jax.tree_util.tree_flatten(masked_grads)
flat_grads, treedef = jax.tree_util.tree_flatten(
masked_grads, is_leaf=lambda x: x is None
)
flat_opt_state = jax.tree_util.tree_map(
lambda _, x: x # pylint:disable=g-long-lambda
if isinstance(x, (np.ndarray, jax.Array))
Expand Down

0 comments on commit 3292138

Please sign in to comment.