Skip to content

Commit

Permalink
Merge pull request #4373 from 8bitmp3:add-nnx-remat-docstring
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 696676726
  • Loading branch information
Flax Authors committed Nov 14, 2024
2 parents 72f4971 + 8fbe99c commit 856b6b5
Showing 1 changed file with 15 additions and 0 deletions.
15 changes: 15 additions & 0 deletions flax/nnx/transforms/autodiff.py
Original file line number Diff line number Diff line change
Expand Up @@ -874,3 +874,18 @@ def remat(
),
)
)
"""A 'lifted' version of the
`jax.checkpoint <https://jax.readthedocs.io/en/latest/_autosummary/jax.checkpoint.html>`__
(a.k.a. ``jax.remat``).
``flax.nnx.remat``, similar to ``jax.checkpoint`` can provide control over, for
example, how ``flax.nnx.grad`` values are computed and saved during the forward pass versus
how they are recomputed during the backward pass, trading off memory and FLOPs.
Learn more in `Flax NNX vs JAX Transformations <https://flax.readthedocs.io/en/latest/guides/jax_and_nnx_transforms.html>`_.
To learn about ``jax.remat``, go to JAX's
`fundamentals of jax.checkpoint <https://jax.readthedocs.io/en/latest/notebooks/autodiff_remat.html#fundamentals-of-jax-checkpoint>`_
and `practical notes <https://jax.readthedocs.io/en/latest/notebooks/autodiff_remat.html#practical-notes>`_.
"""

0 comments on commit 856b6b5

Please sign in to comment.