Skip to content

Commit

Permalink
Merge pull request #4267 from google:remove-flax-nnx-urls
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 683629110
  • Loading branch information
Flax Authors committed Oct 8, 2024
2 parents 601a915 + b12c4b4 commit 2a2187b
Show file tree
Hide file tree
Showing 14 changed files with 49 additions and 49 deletions.
2 changes: 1 addition & 1 deletion docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@
# href with no underline and white bold text color
announcement = """
<a
href="https://flax-nnx.readthedocs.io/en/latest/index.html"
href="https://flax.readthedocs.io/en/latest/index.html"
style="text-decoration: none; color: white;"
>
This site covers the old Flax Linen API. <span style="color: lightgray;">[Explore the new <b>Flax NNX</b> API ✨]</span>
Expand Down
2 changes: 1 addition & 1 deletion docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -326,4 +326,4 @@ Notable examples in Flax include:
philosophy
contributing
api_reference/index
Flax NNX <https://flax-nnx.readthedocs.io/en/latest/index.html>
Flax NNX <https://flax.readthedocs.io/en/latest/index.html>
2 changes: 1 addition & 1 deletion docs_nnx/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@
# files that will not be executed.
myst_enable_extensions = ['dollarmath']
nb_execution_excludepatterns = [
'quick_start.ipynb', # <-- times out
'mnist_tutorial.ipynb', # <-- times out
'transfer_learning.ipynb', # <-- transformers requires flax<=0.7.0
'flax/nnx', # exclude nnx
'guides/demo.ipynb', # TODO(cgarciae): broken, remove or update
Expand Down
8 changes: 4 additions & 4 deletions docs_nnx/glossary.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,21 @@ For additional terms, refer to the `JAX glossary <https://jax.readthedocs.io/en/
.. glossary::

Filter
A way to extract only certain :term:`Variables<Variable>` out of a :term:`Module<Module>`. Usually done via calling :meth:`nnx.split <flax.nnx.split>` upon the module. See the `Filter guide <https://flax-nnx.readthedocs.io/en/latest/guides/filters_guide.html>`__ to learn more.
A way to extract only certain :term:`Variables<Variable>` out of a :term:`Module<Module>`. Usually done via calling :meth:`nnx.split <flax.nnx.split>` upon the module. See the `Filter guide <https://flax.readthedocs.io/en/latest/guides/filters_guide.html>`__ to learn more.

`Folding in <https://jax.readthedocs.io/en/latest/_autosummary/jax.random.fold_in.html>`__
Generating a new PRNG key given an input PRNG key and integer. Typically used when you want to
generate a new key but still be able to use the original rng key afterwards. You can also do this with
`jax.random.split <https://jax.readthedocs.io/en/latest/_autosummary/jax.random.split.html>`__
but this will effectively create two RNG keys, which is slower. See how Flax generates new PRNG keys
automatically in our
`RNG guide <https://flax-nnx.readthedocs.io/en/latest/guides/randomness.html>`__.
`RNG guide <https://flax.readthedocs.io/en/latest/guides/randomness.html>`__.

GraphDef
:class:`nnx.GraphDef<flax.nnx.GraphDef>`, a class that represents all the static, stateless, Pythonic part of an :class:`nnx.Module<flax.nnx.Module>` definition.

Lifted transformation
A wrapped version of the `JAX transformations <https://flax-nnx.readthedocs.io/en/latest/guides/transforms.html>`__ that allows the transformed function to take Flax :term:`Modules<Module>` as input or output. For example, a lifted version of `jax.jit <https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html#jax.jit>`__ will be :meth:`flax.nnx.jit <flax.nnx.jit>`. See the `lifted transforms guide <https://flax-nnx.readthedocs.io/en/latest/guides/transforms.html>`__.
A wrapped version of the `JAX transformations <https://flax.readthedocs.io/en/latest/guides/transforms.html>`__ that allows the transformed function to take Flax :term:`Modules<Module>` as input or output. For example, a lifted version of `jax.jit <https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html#jax.jit>`__ will be :meth:`flax.nnx.jit <flax.nnx.jit>`. See the `lifted transforms guide <https://flax.readthedocs.io/en/latest/guides/transforms.html>`__.

Merge
See :term:`Split and merge<Split and merge>`.
Expand All @@ -37,7 +37,7 @@ For additional terms, refer to the `JAX glossary <https://jax.readthedocs.io/en/
RNG states
A Flax :class:`module <flax.nnx.Module>` can keep a reference of an :class:`RNG state object <flax.nnx.Rngs>` that can generate new JAX `PRNG <https://en.wikipedia.org/wiki/Pseudorandom_number_generator>`__ keys. They keys are used to generate random JAX arrays through `JAX's functional random number generators <https://jax.readthedocs.io/en/latest/jax-101/05-random-numbers.html>`__.
You can use an RNG state with different seeds to make more fine-grained control on your model (e.g., independent random numbers for parameters and dropout masks).
See the `RNG guide <https://flax-nnx.readthedocs.io/en/latest/guides/randomness.html>`__
See the `RNG guide <https://flax.readthedocs.io/en/latest/guides/randomness.html>`__
for more details.

Split and merge
Expand Down
4 changes: 2 additions & 2 deletions docs_nnx/guides/bridge_guide.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@
"\n",
"**Note**:\n",
"\n",
"This guide is about glueing Linen and NNX modules. To migrate an existing Linen module to NNX, check out the [Migrate from Flax Linen to Flax NNX](https://flax-nnx.readthedocs.io/en/latest/guides/linen_to_nnx.html) guide. \n",
"This guide is about glueing Linen and NNX modules. To migrate an existing Linen module to NNX, check out the [Migrate from Flax Linen to Flax NNX](https://flax.readthedocs.io/en/latest/guides/linen_to_nnx.html) guide. \n",
"\n",
"And all built-in Linen layers should have equivalent NNX versions! Check out the list of [Built-in NNX layers](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/nn/index.html)."
"And all built-in Linen layers should have equivalent NNX versions! Check out the list of [Built-in NNX layers](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/index.html)."
]
},
{
Expand Down
4 changes: 2 additions & 2 deletions docs_nnx/guides/bridge_guide.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@ We hope this allows you to move and try out NNX at your own pace, and leverage t

**Note**:

This guide is about glueing Linen and NNX modules. To migrate an existing Linen module to NNX, check out the [Migrate from Flax Linen to Flax NNX](https://flax-nnx.readthedocs.io/en/latest/guides/linen_to_nnx.html) guide.
This guide is about glueing Linen and NNX modules. To migrate an existing Linen module to NNX, check out the [Migrate from Flax Linen to Flax NNX](https://flax.readthedocs.io/en/latest/guides/linen_to_nnx.html) guide.

And all built-in Linen layers should have equivalent NNX versions! Check out the list of [Built-in NNX layers](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/nn/index.html).
And all built-in Linen layers should have equivalent NNX versions! Check out the list of [Built-in NNX layers](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/index.html).


```python
Expand Down
24 changes: 12 additions & 12 deletions docs_nnx/guides/flax_gspmd.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
"source": [
"# Scale up on multiple devices\n",
"\n",
"This guide demonstrates how to scale up [Flax NNX `Module`s](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/module.html#flax.nnx.Module) on multiple devices and hosts, such as GPUs, Google TPUs, and CPUs, using [JAX just-in-time compilation machinery (`jax.jit`)](https://jax.readthedocs.io/en/latest/jax-101/02-jitting.html)."
"This guide demonstrates how to scale up [Flax NNX `Module`s](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html#flax.nnx.Module) on multiple devices and hosts, such as GPUs, Google TPUs, and CPUs, using [JAX just-in-time compilation machinery (`jax.jit`)](https://jax.readthedocs.io/en/latest/jax-101/02-jitting.html)."
]
},
{
Expand All @@ -16,13 +16,13 @@
"source": [
"## Overview\n",
"\n",
"Flax relies on [JAX](https://jax.readthedocs.io) for numeric computations and scaling the computations up across multiple devices, such as GPU and TPUs. At the core of scaling up is the [JAX just-in-time compiler `jax.jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html). Throughout this guide, you will be using Flax’s own [`nnx.jit`](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.jit) which wraps around [`jax.jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html) and works more conveniently with Flax NNX `Module`s.\n",
"Flax relies on [JAX](https://jax.readthedocs.io) for numeric computations and scaling the computations up across multiple devices, such as GPU and TPUs. At the core of scaling up is the [JAX just-in-time compiler `jax.jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html). Throughout this guide, you will be using Flax’s own [`nnx.jit`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.jit) which wraps around [`jax.jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html) and works more conveniently with Flax NNX `Module`s.\n",
"\n",
"JAX compilation follows the [Single Program Multi Data (SPMD)](https://jax.readthedocs.io/en/latest/glossary.html#term-SPMD) paradigm. This means you write Python code as if it runs only on one device, and [`jax.jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html#jax.jit) will automatically compile and run it on multiple devices.\n",
"\n",
"To ensure the compilation performance, you often need to instruct JAX how your model's variables need to be sharded across devices. This is where Flax NNX's Sharding Metadata API - [`flax.nnx.spmd`](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html) - comes in. It helps you annotate your model variables with this information.\n",
"To ensure the compilation performance, you often need to instruct JAX how your model's variables need to be sharded across devices. This is where Flax NNX's Sharding Metadata API - [`flax.nnx.spmd`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html) - comes in. It helps you annotate your model variables with this information.\n",
"\n",
"> **Note to Flax Linen users**: The [`flax.nnx.spmd`](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html) API is similar to what is described in [the Linen Flax on `(p)jit` guide](https://flax.readthedocs.io/en/latest/guides/parallel_training/flax_on_pjit.html) on the model definition level. However, the top-level code in Flax NNX is simpler due to the benefits brought by Flax NNX, and some text explanations will be more updated and clearer.\n",
"> **Note to Flax Linen users**: The [`flax.nnx.spmd`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html) API is similar to what is described in [the Linen Flax on `(p)jit` guide](https://flax.readthedocs.io/en/latest/guides/parallel_training/flax_on_pjit.html) on the model definition level. However, the top-level code in Flax NNX is simpler due to the benefits brought by Flax NNX, and some text explanations will be more updated and clearer.\n",
"\n",
"If you are new parallelization in JAX, you can learn more about its APIs for scaling up in the following tutorials:\n",
"\n",
Expand Down Expand Up @@ -131,11 +131,11 @@
"source": [
"## Define a model with specified sharding\n",
"\n",
"Next, create an example layer called `DotReluDot` that subclasses Flax [`nnx.Module`](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/module.html#flax.nnx.Module). This layer carries out two dot product multiplication upon the input `x`, and uses the `jax.nn.relu` (ReLU) activation function in-between.\n",
"Next, create an example layer called `DotReluDot` that subclasses Flax [`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html#flax.nnx.Module). This layer carries out two dot product multiplication upon the input `x`, and uses the `jax.nn.relu` (ReLU) activation function in-between.\n",
"\n",
"To annotate a model variable with their ideal sharding, you can use [`flax.nnx.with_partitioning`](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html#flax.nnx.with_partitioning) to wrap over its initializer function. Essentially, this calls [`flax.nnx.with_metadata`](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.with_metadata) which adds a `.sharding` attribute field to the corresponding [`nnx.Variable`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Variable).\n",
"To annotate a model variable with their ideal sharding, you can use [`flax.nnx.with_partitioning`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html#flax.nnx.with_partitioning) to wrap over its initializer function. Essentially, this calls [`flax.nnx.with_metadata`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.with_metadata) which adds a `.sharding` attribute field to the corresponding [`nnx.Variable`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Variable).\n",
"\n",
"> **Note:** This annotation will be [preserved and adjusted accordingly across lifted transformations in Flax NNX](https://flax-nnx.readthedocs.io/en/latest/guides/transforms.html#axes-metadata). This means if you use sharding annotations along with any transform that modifies axes (like [`nnx.vmap`](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html), [`nnx.scan`](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html)), you need to provide sharding of that additional axis via the `transform_metadata` arg. Check out the [Flax NNX transformations (transforms) guide](https://flax-nnx.readthedocs.io/en/latest/guides/transforms.html) to learn more."
"> **Note:** This annotation will be [preserved and adjusted accordingly across lifted transformations in Flax NNX](https://flax.readthedocs.io/en/latest/guides/transforms.html#axes-metadata). This means if you use sharding annotations along with any transform that modifies axes (like [`nnx.vmap`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html), [`nnx.scan`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html)), you need to provide sharding of that additional axis via the `transform_metadata` arg. Check out the [Flax NNX transformations (transforms) guide](https://flax.readthedocs.io/en/latest/guides/transforms.html) to learn more."
]
},
{
Expand Down Expand Up @@ -234,7 +234,7 @@
"source": [
"Here, you should leverage JAX's compilation mechanism, via [`nnx.jit`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.jit), to create the sharded model. The key is to initialize a model and assign shardings upon the model state within a `jit`ted function:\n",
"\n",
"1. Use [`nnx.get_partition_spec`](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html#flax.nnx.get_partition_spec) to strip out the `.sharding` annotations attached upon model variables.\n",
"1. Use [`nnx.get_partition_spec`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html#flax.nnx.get_partition_spec) to strip out the `.sharding` annotations attached upon model variables.\n",
"\n",
"1. Call [`jax.lax.with_sharding_constraint`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.with_sharding_constraint.html) to bind the model state with the sharding annotations. This API tells the top-level `jit` how to shard a variable!\n",
"\n",
Expand Down Expand Up @@ -412,7 +412,7 @@
"\n",
"Now you can initialize a sharded model without OOM, but what about loading it from a checkpoint on disk? JAX checkpointing libraries, such as [Orbax](https://orbax.readthedocs.io/en/latest/), usually support loading it sharded if a sharding pytree is given.\n",
"\n",
"You can generate such as sharding pytree with [`nnx.get_named_sharding`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html#flax.nnx.get_named_sharding). To avoid any real memory allocation, use the [`nnx.eval_shape`](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.eval_shape) transform to generate a model of abstract JAX arrays, and only use its `.sharding` annotations to obtain the sharding tree.\n",
"You can generate such as sharding pytree with [`nnx.get_named_sharding`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html#flax.nnx.get_named_sharding). To avoid any real memory allocation, use the [`nnx.eval_shape`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.eval_shape) transform to generate a model of abstract JAX arrays, and only use its `.sharding` annotations to obtain the sharding tree.\n",
"\n",
"Below is an example demonstration using Orbax's `StandardCheckpointer` API. Check out [Orbax website](https://orbax.readthedocs.io/en/latest/) to learn their latest updates and recommended APIs."
]
Expand Down Expand Up @@ -591,9 +591,9 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Now the rest of the training loop is pretty conventional - almost the same as the example in [NNX Basics](https://flax-nnx.readthedocs.io/en/latest/nnx_basics.html#transforms), except that the inputs and labels are also explicitly sharded.\n",
"Now the rest of the training loop is pretty conventional - almost the same as the example in [NNX Basics](https://flax.readthedocs.io/en/latest/nnx_basics.html#transforms), except that the inputs and labels are also explicitly sharded.\n",
"\n",
"[`nnx.jit`](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.jit) will adjust and automatically choose the best layout based on how its inputs are already sharded, so try out different shardings for your own model and inputs."
"[`nnx.jit`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.jit) will adjust and automatically choose the best layout based on how its inputs are already sharded, so try out different shardings for your own model and inputs."
]
},
{
Expand Down Expand Up @@ -723,7 +723,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"If you didn't provide all `sharding_rule` annotations in model definition, you can write a few lines to add it to [`nnx.State`](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State) of the model, before the call of [`nnx.get_partition_spec`](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html#flax.nnx.get_partition_spec) or [`nnx.get_named_sharding`](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html#flax.nnx.get_named_sharding)."
"If you didn't provide all `sharding_rule` annotations in model definition, you can write a few lines to add it to [`nnx.State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State) of the model, before the call of [`nnx.get_partition_spec`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html#flax.nnx.get_partition_spec) or [`nnx.get_named_sharding`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html#flax.nnx.get_named_sharding)."
]
},
{
Expand Down
Loading

0 comments on commit 2a2187b

Please sign in to comment.