[JAX] Update users of jax.tree.map() to be more careful about how they handle Nones. #784
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
[JAX] Update users of jax.tree.map() to be more careful about how they handle Nones.
Due to a bug in JAX, JAX previously permitted
jax.tree.map(f, None, x)
wherex
is notNone
, effectively treatingNone
as if it were pytree-prefix of any value. ButNone
is a pytree container, and it is only a prefix ofNone
itself.Fix user code that was relying on this bug. Most commonly, the fix is to write
jax.tree.map(lambda a, b: (None if a is None else f(a, b)), x, y, is_leaf=lambda t: t is None)
.