diff --git a/docs/conf.py b/docs/conf.py index e65c142bd4..cd90590d67 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -113,7 +113,7 @@ # href with no underline and white bold text color announcement = """ This site covers the old Flax Linen API. [Explore the new Flax NNX API ✨] diff --git a/docs/index.rst b/docs/index.rst index 2f0cfee614..c286d4f0a0 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -326,4 +326,4 @@ Notable examples in Flax include: philosophy contributing api_reference/index - Flax NNX + Flax NNX diff --git a/docs_nnx/conf.py b/docs_nnx/conf.py index 344010ac8b..c3deabe87c 100644 --- a/docs_nnx/conf.py +++ b/docs_nnx/conf.py @@ -144,7 +144,7 @@ # files that will not be executed. myst_enable_extensions = ['dollarmath'] nb_execution_excludepatterns = [ - 'quick_start.ipynb', # <-- times out + 'mnist_tutorial.ipynb', # <-- times out 'transfer_learning.ipynb', # <-- transformers requires flax<=0.7.0 'flax/nnx', # exclude nnx 'guides/demo.ipynb', # TODO(cgarciae): broken, remove or update diff --git a/docs_nnx/glossary.rst b/docs_nnx/glossary.rst index 1ed754a098..0a8a1fa34c 100644 --- a/docs_nnx/glossary.rst +++ b/docs_nnx/glossary.rst @@ -7,7 +7,7 @@ For additional terms, refer to the `JAX glossary ` out of a :term:`Module`. Usually done via calling :meth:`nnx.split ` upon the module. See the `Filter guide `__ to learn more. + A way to extract only certain :term:`Variables` out of a :term:`Module`. Usually done via calling :meth:`nnx.split ` upon the module. See the `Filter guide `__ to learn more. `Folding in `__ Generating a new PRNG key given an input PRNG key and integer. Typically used when you want to @@ -15,13 +15,13 @@ For additional terms, refer to the `JAX glossary `__ but this will effectively create two RNG keys, which is slower. See how Flax generates new PRNG keys automatically in our - `RNG guide `__. + `RNG guide `__. GraphDef :class:`nnx.GraphDef`, a class that represents all the static, stateless, Pythonic part of an :class:`nnx.Module` definition. Lifted transformation - A wrapped version of the `JAX transformations `__ that allows the transformed function to take Flax :term:`Modules` as input or output. For example, a lifted version of `jax.jit `__ will be :meth:`flax.nnx.jit `. See the `lifted transforms guide `__. + A wrapped version of the `JAX transformations `__ that allows the transformed function to take Flax :term:`Modules` as input or output. For example, a lifted version of `jax.jit `__ will be :meth:`flax.nnx.jit `. See the `lifted transforms guide `__. Merge See :term:`Split and merge`. @@ -37,7 +37,7 @@ For additional terms, refer to the `JAX glossary ` can keep a reference of an :class:`RNG state object ` that can generate new JAX `PRNG `__ keys. They keys are used to generate random JAX arrays through `JAX's functional random number generators `__. You can use an RNG state with different seeds to make more fine-grained control on your model (e.g., independent random numbers for parameters and dropout masks). - See the `RNG guide `__ + See the `RNG guide `__ for more details. Split and merge diff --git a/docs_nnx/guides/bridge_guide.ipynb b/docs_nnx/guides/bridge_guide.ipynb index e41836a93e..c24b76a573 100644 --- a/docs_nnx/guides/bridge_guide.ipynb +++ b/docs_nnx/guides/bridge_guide.ipynb @@ -17,9 +17,9 @@ "\n", "**Note**:\n", "\n", - "This guide is about glueing Linen and NNX modules. 
To migrate an existing Linen module to NNX, check out the [Migrate from Flax Linen to Flax NNX](https://flax-nnx.readthedocs.io/en/latest/guides/linen_to_nnx.html) guide. \n", + "This guide is about glueing Linen and NNX modules. To migrate an existing Linen module to NNX, check out the [Migrate from Flax Linen to Flax NNX](https://flax.readthedocs.io/en/latest/guides/linen_to_nnx.html) guide. \n", "\n", - "And all built-in Linen layers should have equivalent NNX versions! Check out the list of [Built-in NNX layers](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/nn/index.html)." + "And all built-in Linen layers should have equivalent NNX versions! Check out the list of [Built-in NNX layers](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/index.html)." ] }, { diff --git a/docs_nnx/guides/bridge_guide.md b/docs_nnx/guides/bridge_guide.md index 3f243ae2ab..3e2c9b4aa3 100644 --- a/docs_nnx/guides/bridge_guide.md +++ b/docs_nnx/guides/bridge_guide.md @@ -11,9 +11,9 @@ We hope this allows you to move and try out NNX at your own pace, and leverage t **Note**: -This guide is about glueing Linen and NNX modules. To migrate an existing Linen module to NNX, check out the [Migrate from Flax Linen to Flax NNX](https://flax-nnx.readthedocs.io/en/latest/guides/linen_to_nnx.html) guide. +This guide is about glueing Linen and NNX modules. To migrate an existing Linen module to NNX, check out the [Migrate from Flax Linen to Flax NNX](https://flax.readthedocs.io/en/latest/guides/linen_to_nnx.html) guide. -And all built-in Linen layers should have equivalent NNX versions! Check out the list of [Built-in NNX layers](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/nn/index.html). +And all built-in Linen layers should have equivalent NNX versions! Check out the list of [Built-in NNX layers](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/index.html). ```python diff --git a/docs_nnx/guides/flax_gspmd.ipynb b/docs_nnx/guides/flax_gspmd.ipynb index 168ae2bdb1..4011c56c22 100644 --- a/docs_nnx/guides/flax_gspmd.ipynb +++ b/docs_nnx/guides/flax_gspmd.ipynb @@ -6,7 +6,7 @@ "source": [ "# Scale up on multiple devices\n", "\n", - "This guide demonstrates how to scale up [Flax NNX `Module`s](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/module.html#flax.nnx.Module) on multiple devices and hosts, such as GPUs, Google TPUs, and CPUs, using [JAX just-in-time compilation machinery (`jax.jit`)](https://jax.readthedocs.io/en/latest/jax-101/02-jitting.html)." + "This guide demonstrates how to scale up [Flax NNX `Module`s](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html#flax.nnx.Module) on multiple devices and hosts, such as GPUs, Google TPUs, and CPUs, using [JAX just-in-time compilation machinery (`jax.jit`)](https://jax.readthedocs.io/en/latest/jax-101/02-jitting.html)." ] }, { @@ -16,13 +16,13 @@ "source": [ "## Overview\n", "\n", - "Flax relies on [JAX](https://jax.readthedocs.io) for numeric computations and scaling the computations up across multiple devices, such as GPU and TPUs. At the core of scaling up is the [JAX just-in-time compiler `jax.jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html). 
Throughout this guide, you will be using Flax’s own [`nnx.jit`](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.jit) which wraps around [`jax.jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html) and works more conveniently with Flax NNX `Module`s.\n", + "Flax relies on [JAX](https://jax.readthedocs.io) for numeric computations and scaling the computations up across multiple devices, such as GPU and TPUs. At the core of scaling up is the [JAX just-in-time compiler `jax.jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html). Throughout this guide, you will be using Flax’s own [`nnx.jit`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.jit) which wraps around [`jax.jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html) and works more conveniently with Flax NNX `Module`s.\n", "\n", "JAX compilation follows the [Single Program Multi Data (SPMD)](https://jax.readthedocs.io/en/latest/glossary.html#term-SPMD) paradigm. This means you write Python code as if it runs only on one device, and [`jax.jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html#jax.jit) will automatically compile and run it on multiple devices.\n", "\n", - "To ensure the compilation performance, you often need to instruct JAX how your model's variables need to be sharded across devices. This is where Flax NNX's Sharding Metadata API - [`flax.nnx.spmd`](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html) - comes in. It helps you annotate your model variables with this information.\n", + "To ensure the compilation performance, you often need to instruct JAX how your model's variables need to be sharded across devices. This is where Flax NNX's Sharding Metadata API - [`flax.nnx.spmd`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html) - comes in. It helps you annotate your model variables with this information.\n", "\n", - "> **Note to Flax Linen users**: The [`flax.nnx.spmd`](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html) API is similar to what is described in [the Linen Flax on `(p)jit` guide](https://flax.readthedocs.io/en/latest/guides/parallel_training/flax_on_pjit.html) on the model definition level. However, the top-level code in Flax NNX is simpler due to the benefits brought by Flax NNX, and some text explanations will be more updated and clearer.\n", + "> **Note to Flax Linen users**: The [`flax.nnx.spmd`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html) API is similar to what is described in [the Linen Flax on `(p)jit` guide](https://flax.readthedocs.io/en/latest/guides/parallel_training/flax_on_pjit.html) on the model definition level. However, the top-level code in Flax NNX is simpler due to the benefits brought by Flax NNX, and some text explanations will be more updated and clearer.\n", "\n", "If you are new parallelization in JAX, you can learn more about its APIs for scaling up in the following tutorials:\n", "\n", @@ -131,11 +131,11 @@ "source": [ "## Define a model with specified sharding\n", "\n", - "Next, create an example layer called `DotReluDot` that subclasses Flax [`nnx.Module`](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/module.html#flax.nnx.Module). 
This layer carries out two dot product multiplication upon the input `x`, and uses the `jax.nn.relu` (ReLU) activation function in-between.\n", + "Next, create an example layer called `DotReluDot` that subclasses Flax [`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html#flax.nnx.Module). This layer carries out two dot product multiplication upon the input `x`, and uses the `jax.nn.relu` (ReLU) activation function in-between.\n", "\n", - "To annotate a model variable with their ideal sharding, you can use [`flax.nnx.with_partitioning`](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html#flax.nnx.with_partitioning) to wrap over its initializer function. Essentially, this calls [`flax.nnx.with_metadata`](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.with_metadata) which adds a `.sharding` attribute field to the corresponding [`nnx.Variable`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Variable).\n", + "To annotate a model variable with their ideal sharding, you can use [`flax.nnx.with_partitioning`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html#flax.nnx.with_partitioning) to wrap over its initializer function. Essentially, this calls [`flax.nnx.with_metadata`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.with_metadata) which adds a `.sharding` attribute field to the corresponding [`nnx.Variable`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Variable).\n", "\n", - "> **Note:** This annotation will be [preserved and adjusted accordingly across lifted transformations in Flax NNX](https://flax-nnx.readthedocs.io/en/latest/guides/transforms.html#axes-metadata). This means if you use sharding annotations along with any transform that modifies axes (like [`nnx.vmap`](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html), [`nnx.scan`](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html)), you need to provide sharding of that additional axis via the `transform_metadata` arg. Check out the [Flax NNX transformations (transforms) guide](https://flax-nnx.readthedocs.io/en/latest/guides/transforms.html) to learn more." + "> **Note:** This annotation will be [preserved and adjusted accordingly across lifted transformations in Flax NNX](https://flax.readthedocs.io/en/latest/guides/transforms.html#axes-metadata). This means if you use sharding annotations along with any transform that modifies axes (like [`nnx.vmap`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html), [`nnx.scan`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html)), you need to provide sharding of that additional axis via the `transform_metadata` arg. Check out the [Flax NNX transformations (transforms) guide](https://flax.readthedocs.io/en/latest/guides/transforms.html) to learn more." ] }, { @@ -234,7 +234,7 @@ "source": [ "Here, you should leverage JAX's compilation mechanism, via [`nnx.jit`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.jit), to create the sharded model. The key is to initialize a model and assign shardings upon the model state within a `jit`ted function:\n", "\n", - "1. 
Use [`nnx.get_partition_spec`](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html#flax.nnx.get_partition_spec) to strip out the `.sharding` annotations attached upon model variables.\n", + "1. Use [`nnx.get_partition_spec`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html#flax.nnx.get_partition_spec) to strip out the `.sharding` annotations attached upon model variables.\n", "\n", "1. Call [`jax.lax.with_sharding_constraint`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.with_sharding_constraint.html) to bind the model state with the sharding annotations. This API tells the top-level `jit` how to shard a variable!\n", "\n", @@ -412,7 +412,7 @@ "\n", "Now you can initialize a sharded model without OOM, but what about loading it from a checkpoint on disk? JAX checkpointing libraries, such as [Orbax](https://orbax.readthedocs.io/en/latest/), usually support loading it sharded if a sharding pytree is given.\n", "\n", - "You can generate such as sharding pytree with [`nnx.get_named_sharding`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html#flax.nnx.get_named_sharding). To avoid any real memory allocation, use the [`nnx.eval_shape`](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.eval_shape) transform to generate a model of abstract JAX arrays, and only use its `.sharding` annotations to obtain the sharding tree.\n", + "You can generate such as sharding pytree with [`nnx.get_named_sharding`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html#flax.nnx.get_named_sharding). To avoid any real memory allocation, use the [`nnx.eval_shape`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.eval_shape) transform to generate a model of abstract JAX arrays, and only use its `.sharding` annotations to obtain the sharding tree.\n", "\n", "Below is an example demonstration using Orbax's `StandardCheckpointer` API. Check out [Orbax website](https://orbax.readthedocs.io/en/latest/) to learn their latest updates and recommended APIs." ] @@ -591,9 +591,9 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Now the rest of the training loop is pretty conventional - almost the same as the example in [NNX Basics](https://flax-nnx.readthedocs.io/en/latest/nnx_basics.html#transforms), except that the inputs and labels are also explicitly sharded.\n", + "Now the rest of the training loop is pretty conventional - almost the same as the example in [NNX Basics](https://flax.readthedocs.io/en/latest/nnx_basics.html#transforms), except that the inputs and labels are also explicitly sharded.\n", "\n", - "[`nnx.jit`](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.jit) will adjust and automatically choose the best layout based on how its inputs are already sharded, so try out different shardings for your own model and inputs." + "[`nnx.jit`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.jit) will adjust and automatically choose the best layout based on how its inputs are already sharded, so try out different shardings for your own model and inputs." 
] }, { @@ -723,7 +723,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "If you didn't provide all `sharding_rule` annotations in model definition, you can write a few lines to add it to [`nnx.State`](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State) of the model, before the call of [`nnx.get_partition_spec`](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html#flax.nnx.get_partition_spec) or [`nnx.get_named_sharding`](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html#flax.nnx.get_named_sharding)." + "If you didn't provide all `sharding_rule` annotations in model definition, you can write a few lines to add it to [`nnx.State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State) of the model, before the call of [`nnx.get_partition_spec`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html#flax.nnx.get_partition_spec) or [`nnx.get_named_sharding`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html#flax.nnx.get_named_sharding)." ] }, { diff --git a/docs_nnx/guides/flax_gspmd.md b/docs_nnx/guides/flax_gspmd.md index 342774d945..e13c4f2c4c 100644 --- a/docs_nnx/guides/flax_gspmd.md +++ b/docs_nnx/guides/flax_gspmd.md @@ -10,19 +10,19 @@ jupytext: # Scale up on multiple devices -This guide demonstrates how to scale up [Flax NNX `Module`s](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/module.html#flax.nnx.Module) on multiple devices and hosts, such as GPUs, Google TPUs, and CPUs, using [JAX just-in-time compilation machinery (`jax.jit`)](https://jax.readthedocs.io/en/latest/jax-101/02-jitting.html). +This guide demonstrates how to scale up [Flax NNX `Module`s](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html#flax.nnx.Module) on multiple devices and hosts, such as GPUs, Google TPUs, and CPUs, using [JAX just-in-time compilation machinery (`jax.jit`)](https://jax.readthedocs.io/en/latest/jax-101/02-jitting.html). +++ ## Overview -Flax relies on [JAX](https://jax.readthedocs.io) for numeric computations and scaling the computations up across multiple devices, such as GPU and TPUs. At the core of scaling up is the [JAX just-in-time compiler `jax.jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html). Throughout this guide, you will be using Flax’s own [`nnx.jit`](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.jit) which wraps around [`jax.jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html) and works more conveniently with Flax NNX `Module`s. +Flax relies on [JAX](https://jax.readthedocs.io) for numeric computations and scaling the computations up across multiple devices, such as GPU and TPUs. At the core of scaling up is the [JAX just-in-time compiler `jax.jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html). Throughout this guide, you will be using Flax’s own [`nnx.jit`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.jit) which wraps around [`jax.jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html) and works more conveniently with Flax NNX `Module`s. JAX compilation follows the [Single Program Multi Data (SPMD)](https://jax.readthedocs.io/en/latest/glossary.html#term-SPMD) paradigm. 
This means you write Python code as if it runs only on one device, and [`jax.jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html#jax.jit) will automatically compile and run it on multiple devices. -To ensure the compilation performance, you often need to instruct JAX how your model's variables need to be sharded across devices. This is where Flax NNX's Sharding Metadata API - [`flax.nnx.spmd`](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html) - comes in. It helps you annotate your model variables with this information. +To ensure the compilation performance, you often need to instruct JAX how your model's variables need to be sharded across devices. This is where Flax NNX's Sharding Metadata API - [`flax.nnx.spmd`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html) - comes in. It helps you annotate your model variables with this information. -> **Note to Flax Linen users**: The [`flax.nnx.spmd`](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html) API is similar to what is described in [the Linen Flax on `(p)jit` guide](https://flax.readthedocs.io/en/latest/guides/parallel_training/flax_on_pjit.html) on the model definition level. However, the top-level code in Flax NNX is simpler due to the benefits brought by Flax NNX, and some text explanations will be more updated and clearer. +> **Note to Flax Linen users**: The [`flax.nnx.spmd`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html) API is similar to what is described in [the Linen Flax on `(p)jit` guide](https://flax.readthedocs.io/en/latest/guides/parallel_training/flax_on_pjit.html) on the model definition level. However, the top-level code in Flax NNX is simpler due to the benefits brought by Flax NNX, and some text explanations will be more updated and clearer. If you are new parallelization in JAX, you can learn more about its APIs for scaling up in the following tutorials: @@ -79,11 +79,11 @@ print(mesh) ## Define a model with specified sharding -Next, create an example layer called `DotReluDot` that subclasses Flax [`nnx.Module`](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/module.html#flax.nnx.Module). This layer carries out two dot product multiplication upon the input `x`, and uses the `jax.nn.relu` (ReLU) activation function in-between. +Next, create an example layer called `DotReluDot` that subclasses Flax [`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html#flax.nnx.Module). This layer carries out two dot product multiplication upon the input `x`, and uses the `jax.nn.relu` (ReLU) activation function in-between. -To annotate a model variable with their ideal sharding, you can use [`flax.nnx.with_partitioning`](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html#flax.nnx.with_partitioning) to wrap over its initializer function. Essentially, this calls [`flax.nnx.with_metadata`](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.with_metadata) which adds a `.sharding` attribute field to the corresponding [`nnx.Variable`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Variable). +To annotate a model variable with their ideal sharding, you can use [`flax.nnx.with_partitioning`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html#flax.nnx.with_partitioning) to wrap over its initializer function. 
Essentially, this calls [`flax.nnx.with_metadata`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.with_metadata) which adds a `.sharding` attribute field to the corresponding [`nnx.Variable`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Variable). -> **Note:** This annotation will be [preserved and adjusted accordingly across lifted transformations in Flax NNX](https://flax-nnx.readthedocs.io/en/latest/guides/transforms.html#axes-metadata). This means if you use sharding annotations along with any transform that modifies axes (like [`nnx.vmap`](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html), [`nnx.scan`](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html)), you need to provide sharding of that additional axis via the `transform_metadata` arg. Check out the [Flax NNX transformations (transforms) guide](https://flax-nnx.readthedocs.io/en/latest/guides/transforms.html) to learn more. +> **Note:** This annotation will be [preserved and adjusted accordingly across lifted transformations in Flax NNX](https://flax.readthedocs.io/en/latest/guides/transforms.html#axes-metadata). This means if you use sharding annotations along with any transform that modifies axes (like [`nnx.vmap`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html), [`nnx.scan`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html)), you need to provide sharding of that additional axis via the `transform_metadata` arg. Check out the [Flax NNX transformations (transforms) guide](https://flax.readthedocs.io/en/latest/guides/transforms.html) to learn more. ```{code-cell} ipython3 class DotReluDot(nnx.Module): @@ -146,7 +146,7 @@ print(unsharded_model.w2.value.sharding) # SingleDeviceSharding Here, you should leverage JAX's compilation mechanism, via [`nnx.jit`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.jit), to create the sharded model. The key is to initialize a model and assign shardings upon the model state within a `jit`ted function: -1. Use [`nnx.get_partition_spec`](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html#flax.nnx.get_partition_spec) to strip out the `.sharding` annotations attached upon model variables. +1. Use [`nnx.get_partition_spec`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html#flax.nnx.get_partition_spec) to strip out the `.sharding` annotations attached upon model variables. 1. Call [`jax.lax.with_sharding_constraint`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.with_sharding_constraint.html) to bind the model state with the sharding annotations. This API tells the top-level `jit` how to shard a variable! @@ -209,7 +209,7 @@ This brings a question: Why use the Flax NNX Annotation API then? Why not just a Now you can initialize a sharded model without OOM, but what about loading it from a checkpoint on disk? JAX checkpointing libraries, such as [Orbax](https://orbax.readthedocs.io/en/latest/), usually support loading it sharded if a sharding pytree is given. -You can generate such as sharding pytree with [`nnx.get_named_sharding`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html#flax.nnx.get_named_sharding). 
To avoid any real memory allocation, use the [`nnx.eval_shape`](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.eval_shape) transform to generate a model of abstract JAX arrays, and only use its `.sharding` annotations to obtain the sharding tree. +You can generate such as sharding pytree with [`nnx.get_named_sharding`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html#flax.nnx.get_named_sharding). To avoid any real memory allocation, use the [`nnx.eval_shape`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.eval_shape) transform to generate a model of abstract JAX arrays, and only use its `.sharding` annotations to obtain the sharding tree. Below is an example demonstration using Orbax's `StandardCheckpointer` API. Check out [Orbax website](https://orbax.readthedocs.io/en/latest/) to learn their latest updates and recommended APIs. @@ -256,9 +256,9 @@ print(output.shape) jax.debug.visualize_array_sharding(output) # Also sharded as ('data', None) ``` -Now the rest of the training loop is pretty conventional - almost the same as the example in [NNX Basics](https://flax-nnx.readthedocs.io/en/latest/nnx_basics.html#transforms), except that the inputs and labels are also explicitly sharded. +Now the rest of the training loop is pretty conventional - almost the same as the example in [NNX Basics](https://flax.readthedocs.io/en/latest/nnx_basics.html#transforms), except that the inputs and labels are also explicitly sharded. -[`nnx.jit`](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.jit) will adjust and automatically choose the best layout based on how its inputs are already sharded, so try out different shardings for your own model and inputs. +[`nnx.jit`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.jit) will adjust and automatically choose the best layout based on how its inputs are already sharded, so try out different shardings for your own model and inputs. ```{code-cell} ipython3 optimizer = nnx.Optimizer(sharded_model, optax.adam(1e-3)) # reference sharing @@ -337,7 +337,7 @@ class LogicalDotReluDot(nnx.Module): return z ``` -If you didn't provide all `sharding_rule` annotations in model definition, you can write a few lines to add it to [`nnx.State`](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State) of the model, before the call of [`nnx.get_partition_spec`](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html#flax.nnx.get_partition_spec) or [`nnx.get_named_sharding`](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html#flax.nnx.get_named_sharding). +If you didn't provide all `sharding_rule` annotations in model definition, you can write a few lines to add it to [`nnx.State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State) of the model, before the call of [`nnx.get_partition_spec`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html#flax.nnx.get_partition_spec) or [`nnx.get_named_sharding`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html#flax.nnx.get_named_sharding). 
```{code-cell} ipython3 def add_sharding_rule(vs: nnx.VariableState) -> nnx.VariableState: diff --git a/docs_nnx/guides/linen_to_nnx.rst b/docs_nnx/guides/linen_to_nnx.rst index d0c20fd09e..b4fa6b3736 100644 --- a/docs_nnx/guides/linen_to_nnx.rst +++ b/docs_nnx/guides/linen_to_nnx.rst @@ -4,9 +4,9 @@ Evolution from Linen to NNX This guide will walk you through the differences between Flax Linen and Flax NNX models, and side-by-side comparisions to help you migrate your code from the Linen API to NNX. -Before this guide, it's highly recommended to read through `The Basics of Flax NNX `__ to learn about the core concepts and code examples of Flax NNX. +Before this guide, it's highly recommended to read through `The Basics of Flax NNX `__ to learn about the core concepts and code examples of Flax NNX. -This guide mainly covers converting arbitratry Linen code to NNX. If you want to play it safe and convert your codebase iteratively, check out the guide that allows you to `use NNX and Linen code together `__ +This guide mainly covers converting arbitratry Linen code to NNX. If you want to play it safe and convert your codebase iteratively, check out the guide that allows you to `use NNX and Linen code together `__ .. testsetup:: Linen, NNX @@ -90,7 +90,7 @@ To generate the model parameters for a Linen model, you call the ``init`` method In NNX, the model parameters are automatically initialized when the user instantiates the model, and the variables are stored inside the module (or its submodule) as attributes. You still need to give it an RNG key, but the key will be wrapped inside a ``nnx.Rngs`` class and will be stored inside, generating more RNG keys when needed. -If you want to access NNX model parameters in the stateless, dictionary-like fashion for checkpoint saving or model surgery, check out the `NNX split/merge API `__. +If you want to access NNX model parameters in the stateless, dictionary-like fashion for checkpoint saving or model surgery, check out the `NNX split/merge API `__. .. codediff:: :title: Linen, NNX @@ -122,15 +122,15 @@ Now we write a training step and compile it using JAX just-in-time compilation. * Linen uses ``@jax.jit`` to compile the training step, whereas NNX uses ``@nnx.jit``. ``jax.jit`` only accepts pure stateless arguments, but ``nnx.jit`` allows the arguments to be stateful NNX modules. This greatly reduced the number of lines needed for a train step. -* Similarly, Linen uses ``jax.grad()`` to return a raw dictionary of gradients, wheras NNX can use ``nnx.grad`` to return the gradients of Modules as NNX ``State`` dictionaries. To use regular ``jax.grad`` with NNX you need to use the `NNX split/merge API `__. +* Similarly, Linen uses ``jax.grad()`` to return a raw dictionary of gradients, wheras NNX can use ``nnx.grad`` to return the gradients of Modules as NNX ``State`` dictionaries. To use regular ``jax.grad`` with NNX you need to use the `NNX split/merge API `__. - * If you are already using Optax optimizers like ``optax.adamw`` (instead of the raw ``jax.tree.map`` computation shown here), check out `nnx.Optimizer example `__ for a much more concise way of training and updating your model. + * If you are already using Optax optimizers like ``optax.adamw`` (instead of the raw ``jax.tree.map`` computation shown here), check out `nnx.Optimizer example `__ for a much more concise way of training and updating your model. * The Linen train step needs to return a tree of parameters, as the input of the next step. 
On the other hand, NNX's step doesn't need to return anything, because the ``model`` was already in-place-updated within ``nnx.jit``. -* NNX modules are stateful and automatically tracks a few things within, such as RNG keys and BatchNorm stats. That's why you don't need to explicitly pass an RNG key in on every step. Note that you can use `nnx.reseed `__ to reset its underlying RNG state. +* NNX modules are stateful and automatically tracks a few things within, such as RNG keys and BatchNorm stats. That's why you don't need to explicitly pass an RNG key in on every step. Note that you can use `nnx.reseed `__ to reset its underlying RNG state. -* In Linen, you need to explicitly define and pass in an argument ``training`` to control the behavior of ``nn.Dropout`` (namely, its ``deterministic`` flag, which means random dropout only happens if ``training=True``). In NNX, you can call ``model.train()`` to automatically switch ``nnx.Dropout`` to training mode. Conversely, call ``model.eval()`` to turn off training mode. You can learn more about what this API does at its `API reference `__. +* In Linen, you need to explicitly define and pass in an argument ``training`` to control the behavior of ``nn.Dropout`` (namely, its ``deterministic`` flag, which means random dropout only happens if ``training=True``). In NNX, you can call ``model.train()`` to automatically switch ``nnx.Dropout`` to training mode. Conversely, call ``model.eval()`` to turn off training mode. You can learn more about what this API does at its `API reference `__. .. codediff:: @@ -257,7 +257,7 @@ For all the built-in Flax Linen layers and collections, NNX already created the model.batchnorm.mean # BatchStat(value=...) model.count # Counter(value=...) -If you want to extract certain arrays from the tree of variables, you can access the specific dictionary path in Linen, or use ``nnx.split`` to distinguish the types apart in NNX. The code below is an easier example, and check out `Filter API Guide `__ for more sophisticated filtering expressions. +If you want to extract certain arrays from the tree of variables, you can access the specific dictionary path in Linen, or use ``nnx.split`` to distinguish the types apart in NNX. The code below is an easier example, and check out `Filter API Guide `__ for more sophisticated filtering expressions. .. codediff:: :title: Linen, NNX @@ -525,7 +525,7 @@ But if you think closely, there actually isn't any need for ``jax.lax.scan`` ope In NNX we take advantage of the fact that model initialization and running code are completely decoupled, and instead use ``nnx.vmap`` to initialize the underlying blocks, and ``nnx.scan`` to run the model input through them. -For more information on NNX transforms, check out the `Transforms Guide `__. +For more information on NNX transforms, check out the `Transforms Guide `__. .. codediff:: :title: Linen, NNX diff --git a/docs_nnx/guides/randomness.ipynb b/docs_nnx/guides/randomness.ipynb index 517e3f22b4..500095530d 100644 --- a/docs_nnx/guides/randomness.ipynb +++ b/docs_nnx/guides/randomness.ipynb @@ -234,7 +234,7 @@ "source": [ "## Filtering random state\n", "\n", - "Random state can be manipulated using [Filters](https://flax-nnx.readthedocs.io/en/latest/guides/filters_guide.html) just like any other type of state. It can be filtered using types (`RngState`, `RngKey`, `RngCount`) or using strings corresponding to the stream names (see [The Filter DSL](https://flax-nnx.readthedocs.io/en/latest/guides/filters_guide.html#the-filter-dsl)). 
Here's an example using `nnx.state` with various filters to select different substates of the `Rngs` inside a `Model`:" + "Random state can be manipulated using [Filters](https://flax.readthedocs.io/en/latest/guides/filters_guide.html) just like any other type of state. It can be filtered using types (`RngState`, `RngKey`, `RngCount`) or using strings corresponding to the stream names (see [The Filter DSL](https://flax.readthedocs.io/en/latest/guides/filters_guide.html#the-filter-dsl)). Here's an example using `nnx.state` with various filters to select different substates of the `Rngs` inside a `Model`:" ] }, { diff --git a/docs_nnx/guides/randomness.md b/docs_nnx/guides/randomness.md index 815567b424..b8d8426dd4 100644 --- a/docs_nnx/guides/randomness.md +++ b/docs_nnx/guides/randomness.md @@ -99,7 +99,7 @@ As shown above, a key from the `default` stream can also be generated by calling ## Filtering random state -Random state can be manipulated using [Filters](https://flax-nnx.readthedocs.io/en/latest/guides/filters_guide.html) just like any other type of state. It can be filtered using types (`RngState`, `RngKey`, `RngCount`) or using strings corresponding to the stream names (see [The Filter DSL](https://flax-nnx.readthedocs.io/en/latest/guides/filters_guide.html#the-filter-dsl)). Here's an example using `nnx.state` with various filters to select different substates of the `Rngs` inside a `Model`: +Random state can be manipulated using [Filters](https://flax.readthedocs.io/en/latest/guides/filters_guide.html) just like any other type of state. It can be filtered using types (`RngState`, `RngKey`, `RngCount`) or using strings corresponding to the stream names (see [The Filter DSL](https://flax.readthedocs.io/en/latest/guides/filters_guide.html#the-filter-dsl)). Here's an example using `nnx.state` with various filters to select different substates of the `Rngs` inside a `Model`: ```{code-cell} ipython3 model = Model(nnx.Rngs(params=0, dropout=1)) diff --git a/docs_nnx/guides/transforms.ipynb b/docs_nnx/guides/transforms.ipynb index 35ed4d6f52..28287fe7ec 100644 --- a/docs_nnx/guides/transforms.ipynb +++ b/docs_nnx/guides/transforms.ipynb @@ -411,7 +411,7 @@ "\n", "> **Note:** * Flax NNX `shard_map` has not been implemented yet at the time of writing this version of the document.\n", "\n", - "To specify how to vectorize different sub-states of an object in `nnx.vmap`, the Flax team created a `nnx.StateAxes`. `StateAxes` maps a set of sub-states via Flax NNX [Filters](https://flax-nnx.readthedocs.io/en/latest/guides/filters_guide.html) to their corresponding axes, and you can pass the `nnx.StateAxes` to `in_axes` and `out_axes` as if it/they were a pytree prefix.\n", + "To specify how to vectorize different sub-states of an object in `nnx.vmap`, the Flax team created a `nnx.StateAxes`. `StateAxes` maps a set of sub-states via Flax NNX [Filters](https://flax.readthedocs.io/en/latest/guides/filters_guide.html) to their corresponding axes, and you can pass the `nnx.StateAxes` to `in_axes` and `out_axes` as if it/they were a pytree prefix.\n", "\n", "Let's use the previous `stateful_vector_dot` example and vectorize only the `nnx.Param` variables and broadcast the `count` variable so we only keep a single count for all the batch elements.\n", "To do this we will define a `nnx.StateAxes` with a filter that matches the `nnx.Param` variables and maps them to axis `0`, and all the `Count` variables to `None`, and pass this `nnx.StateAxes` to `in_axes` for the `Weights` object." 
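(Reviewer aid, not part of the patch: the `nnx.StateAxes` pattern described in the `transforms` hunks above can be hard to picture from a diff alone. Below is a minimal, self-contained sketch of that pattern. The names `Weights`, `Count`, and `stateful_vector_dot` come from the guide being edited; the array shapes and sizes are illustrative assumptions.)

```python
import jax
import jax.numpy as jnp
from flax import nnx


class Count(nnx.Variable):  # custom Variable type used only for the step counter
  pass


class Weights(nnx.Module):
  def __init__(self, kernel: jax.Array, bias: jax.Array):
    self.kernel = nnx.Param(kernel)
    self.bias = nnx.Param(bias)
    self.count = Count(jnp.array(0))


def stateful_vector_dot(weights: Weights, x: jax.Array):
  weights.count += 1  # shared state, updated once for the whole batch
  return x @ weights.kernel + weights.bias


# One (kernel, bias) pair per batch element -> axis 0; a single shared count -> None.
weights = Weights(
  kernel=jax.random.uniform(jax.random.key(0), (8, 2, 3)),
  bias=jnp.zeros((8, 3)),
)
x = jax.random.normal(jax.random.key(1), (8, 2))

state_axes = nnx.StateAxes({nnx.Param: 0, Count: None})
y = nnx.vmap(stateful_vector_dot, in_axes=(state_axes, 0))(weights, x)
print(y.shape)  # (8, 3): the Params were vectorized, the counter was broadcast
```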
diff --git a/docs_nnx/guides/transforms.md b/docs_nnx/guides/transforms.md index 4b12869b60..9df2c9abbf 100644 --- a/docs_nnx/guides/transforms.md +++ b/docs_nnx/guides/transforms.md @@ -205,7 +205,7 @@ Certain JAX transforms allow the use of pytree prefixes to specify how different > **Note:** * Flax NNX `shard_map` has not been implemented yet at the time of writing this version of the document. -To specify how to vectorize different sub-states of an object in `nnx.vmap`, the Flax team created a `nnx.StateAxes`. `StateAxes` maps a set of sub-states via Flax NNX [Filters](https://flax-nnx.readthedocs.io/en/latest/guides/filters_guide.html) to their corresponding axes, and you can pass the `nnx.StateAxes` to `in_axes` and `out_axes` as if it/they were a pytree prefix. +To specify how to vectorize different sub-states of an object in `nnx.vmap`, the Flax team created a `nnx.StateAxes`. `StateAxes` maps a set of sub-states via Flax NNX [Filters](https://flax.readthedocs.io/en/latest/guides/filters_guide.html) to their corresponding axes, and you can pass the `nnx.StateAxes` to `in_axes` and `out_axes` as if it/they were a pytree prefix. Let's use the previous `stateful_vector_dot` example and vectorize only the `nnx.Param` variables and broadcast the `count` variable so we only keep a single count for all the batch elements. To do this we will define a `nnx.StateAxes` with a filter that matches the `nnx.Param` variables and maps them to axis `0`, and all the `Count` variables to `None`, and pass this `nnx.StateAxes` to `in_axes` for the `Weights` object. diff --git a/docs_nnx/why.rst b/docs_nnx/why.rst index cfc1d9a941..6e0e5c9c2d 100644 --- a/docs_nnx/why.rst +++ b/docs_nnx/why.rst @@ -357,7 +357,7 @@ however Linen transforms have the following drawbacks: * They accepts other Modules as arguments but not as return values. 3. They can only be used inside ``apply``. -`Flax NNX transforms `_ on the other hand +`Flax NNX transforms `_ on the other hand are intented to be equivalent to JAX transforms with the exception that they can be used on Modules. This means they have the same API as JAX transforms, can accepts Modules on any argument and Modules can be returned from them, and they can be used anywhere including the training loop.
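(End-to-end context for the pages touched above, not part of the patch: the point repeated across the `linen_to_nnx.rst`, `nnx_basics`, and `why.rst` hunks is that lifted transforms such as `nnx.jit` and `nnx.value_and_grad` accept stateful Modules directly, so a training step needs no explicit parameter plumbing and returns nothing but the loss. The sketch below mirrors the `nnx.Optimizer` usage shown in the `flax_gspmd` hunks; the `MLP` model, sizes, and hyperparameters are placeholders.)

```python
import jax.numpy as jnp
import optax
from flax import nnx


class MLP(nnx.Module):
  def __init__(self, din: int, dmid: int, dout: int, *, rngs: nnx.Rngs):
    self.linear1 = nnx.Linear(din, dmid, rngs=rngs)
    self.dropout = nnx.Dropout(0.1, rngs=rngs)
    self.linear2 = nnx.Linear(dmid, dout, rngs=rngs)

  def __call__(self, x):
    return self.linear2(self.dropout(nnx.relu(self.linear1(x))))


model = MLP(4, 32, 2, rngs=nnx.Rngs(0))
optimizer = nnx.Optimizer(model, optax.adamw(1e-3))


@nnx.jit  # lifted jax.jit: stateful Module/Optimizer arguments are fine
def train_step(model: MLP, optimizer: nnx.Optimizer, x, y):
  def loss_fn(model: MLP):
    return jnp.mean((model(x) - y) ** 2)

  loss, grads = nnx.value_and_grad(loss_fn)(model)
  optimizer.update(grads)  # parameters are updated in place; no pytree is returned
  return loss


x = jnp.ones((8, 4))
y = jnp.zeros((8, 2))
model.train()  # switches nnx.Dropout to training mode
loss = train_step(model, optimizer, x, y)
```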