Skip to content

Commit

Permalink
Improves docs for tree_at and tree_check.
Browse files Browse the repository at this point in the history
Supersedes #874.
Fixes #872.
  • Loading branch information
patrick-kidger committed Oct 12, 2024
1 parent 1c33c85 commit bf46427
Showing 1 changed file with 25 additions and 13 deletions.
38 changes: 25 additions & 13 deletions equinox/_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,18 @@ def tree_at(
new_mlp = eqx.tree_at(get_last_layer, mlp, new_linear)
```
See also the [Tricks](../../tricks) page.
!!! Info
Constructing analogous PyTrees, with the same structure but different leaves, is
very common in JAX: for example when constructing the `in_axes` argument to
`jax.vmap`.
To support this use-case, the returned PyTree is constructed without calling
`__init__`, `__post_init__`, or
[`__check_init__`](./module/advanced_fields.md#checking-invariants). This allows
for modifying leaves to be anything, regardless of the use of any custom
constructor or custom checks in the original PyTree.
""" # noqa: E501

# We need to specify a particular node in a PyTree.
Expand Down Expand Up @@ -406,15 +418,15 @@ def is_leaf(node):


def tree_check(pytree: Any) -> None:
"""Checks if the PyTree is well-formed: does it have no self-references, and does
it have no duplicate layers.
Precisely, a "duplicate layer" is any PyTree node with at least one child node.
!!! info
"""Checks if the PyTree has no self-references, and if all non-leaf nodes unique
Python objects. (For example something like `x = [1]; y = [x, x]` would fail as `x`
appears twice in the PyTree.)
This is automatically called when creating an `eqx.Module` instance, to help
avoid bugs from duplicating layers.
Having unique non-leaf nodes isn't actually a requirement that JAX imposes, but it
will become true after passing through an operation like `jax.{jit, grad, ...}` (as
JAX copies the PyTree without trying to preserve identity), so some users like to
assert that this invariant was already true prior to the transform, as a way to
avoid surprises.
!!! Example
Expand All @@ -423,18 +435,18 @@ def tree_check(pytree: Any) -> None:
eqx.tree_check([a, a]) # passes, duplicate is a leaf
b = eqx.nn.Linear(...)
eqx.tree_check([b, b]) # fails, duplicate is nontrivial!
eqx.tree_check([b, b]) # fails, duplicate is non-leaf!
c = [] # empty list
eqx.tree_check([c, c]) # passes, duplicate is trivial
eqx.tree_check([c, c]) # passes, duplicate is leaf
d = eqx.Module()
eqx.tree_check([d, d]) # passes, duplicate is trivial
eqx.tree_check([d, d]) # passes, duplicate is leaf
eqx.tree_check([None, None]) # passes, duplicate is trivial
eqx.tree_check([None, None]) # passes, duplicate is leaf
e = [1]
eqx.tree_check([e, e]) # fails, duplicate is nontrivial!
eqx.tree_check([e, e]) # fails, duplicate is non-leaf!
eqx.tree_check([[1], [1]]) # passes, not actually a duplicate: each `[1]`
# has the same structure, but they're different.
Expand Down

0 comments on commit bf46427

Please sign in to comment.