From bf46427a0e0802e540832922a6063d22d119f779 Mon Sep 17 00:00:00 2001 From: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com> Date: Sat, 12 Oct 2024 13:07:17 +0200 Subject: [PATCH] Improves docs for `tree_at` and `tree_check`. Supersedes #874. Fixes #872. --- equinox/_tree.py | 38 +++++++++++++++++++++++++------------- 1 file changed, 25 insertions(+), 13 deletions(-) diff --git a/equinox/_tree.py b/equinox/_tree.py index a1e48c4c..1b959601 100644 --- a/equinox/_tree.py +++ b/equinox/_tree.py @@ -116,6 +116,18 @@ def tree_at( new_mlp = eqx.tree_at(get_last_layer, mlp, new_linear) ``` See also the [Tricks](../../tricks) page. + + !!! Info + + Constructing analogous PyTrees, with the same structure but different leaves, is + very common in JAX: for example when constructing the `in_axes` argument to + `jax.vmap`. + + To support this use-case, the returned PyTree is constructed without calling + `__init__`, `__post_init__`, or + [`__check_init__`](./module/advanced_fields.md#checking-invariants). This allows + for modifying leaves to be anything, regardless of the use of any custom + constructor or custom checks in the original PyTree. """ # noqa: E501 # We need to specify a particular node in a PyTree. @@ -406,15 +418,15 @@ def is_leaf(node): def tree_check(pytree: Any) -> None: - """Checks if the PyTree is well-formed: does it have no self-references, and does - it have no duplicate layers. - - Precisely, a "duplicate layer" is any PyTree node with at least one child node. - - !!! info + """Checks if the PyTree has no self-references, and if all non-leaf nodes unique + Python objects. (For example something like `x = [1]; y = [x, x]` would fail as `x` + appears twice in the PyTree.) - This is automatically called when creating an `eqx.Module` instance, to help - avoid bugs from duplicating layers. + Having unique non-leaf nodes isn't actually a requirement that JAX imposes, but it + will become true after passing through an operation like `jax.{jit, grad, ...}` (as + JAX copies the PyTree without trying to preserve identity), so some users like to + assert that this invariant was already true prior to the transform, as a way to + avoid surprises. !!! Example @@ -423,18 +435,18 @@ def tree_check(pytree: Any) -> None: eqx.tree_check([a, a]) # passes, duplicate is a leaf b = eqx.nn.Linear(...) - eqx.tree_check([b, b]) # fails, duplicate is nontrivial! + eqx.tree_check([b, b]) # fails, duplicate is non-leaf! c = [] # empty list - eqx.tree_check([c, c]) # passes, duplicate is trivial + eqx.tree_check([c, c]) # passes, duplicate is leaf d = eqx.Module() - eqx.tree_check([d, d]) # passes, duplicate is trivial + eqx.tree_check([d, d]) # passes, duplicate is leaf - eqx.tree_check([None, None]) # passes, duplicate is trivial + eqx.tree_check([None, None]) # passes, duplicate is leaf e = [1] - eqx.tree_check([e, e]) # fails, duplicate is nontrivial! + eqx.tree_check([e, e]) # fails, duplicate is non-leaf! eqx.tree_check([[1], [1]]) # passes, not actually a duplicate: each `[1]` # has the same structure, but they're different.