diff --git a/docs/tutorials/jax-export.ipynb b/docs/tutorials/jax-export.ipynb index 627873b422a..01943534502 100644 --- a/docs/tutorials/jax-export.ipynb +++ b/docs/tutorials/jax-export.ipynb @@ -1,445 +1,450 @@ { - "nbformat": 4, - "nbformat_minor": 0, - "metadata": { - "colab": { - "provenance": [] - }, - "kernelspec": { - "name": "python3", - "display_name": "Python 3" - }, - "language_info": { - "name": "python" + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "_TuAgGNKt5HO" + }, + "source": [ + "# Tutorial: Exporting StableHLO from JAX\n", + "\n", + "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)][jax-tutorial-colab]\n", + "[![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)][jax-tutorial-kaggle]\n", + "\n", + "JAX is a Python library for high-performance numerical computing. This tutorial shows how to export JAX and Flax (JAX-powered neural network library) models to StableHLO, and directly to TensorFlow SavedModel.\n", + "\n", + "## Tutorial Setup\n", + "\n", + "### Install required dependencies\n", + "\n", + "We use `jax` and `jaxlib` (JAX's support library with compiled binaries), along with `flax` and `transformers` for some models to export.\n", + "We also need to install `tensorflow` to work with SavedModel, and recommend using `tensorflow-cpu` or `tf-nightly` for this tutorial.\n", + "\n", + "[jax-tutorial-colab]: https://colab.research.google.com/github/openxla/stablehlo/blob/main/docs/tutorials/jax-export.ipynb\n", + "[jax-tutorial-kaggle]: https://kaggle.com/kernels/welcome?src=https://github.com/openxla/stablehlo/blob/main/docs/tutorials/jax-export.ipynb" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": true, + "id": "ENUcO6aML-Nq", + "jupyter": { + "outputs_hidden": true } + }, + "outputs": [], + "source": [ + "!pip install -U jax jaxlib flax transformers tensorflow-cpu" + ] }, - "cells": [ - { - "cell_type": "markdown", - "source": [ - "# Tutorial: Exporting StableHLO from JAX\n", - "\n", - "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)][jax-tutorial-colab]\n", - "[![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)][jax-tutorial-kaggle]\n", - "\n", - "## Tutorial Setup\n", - "\n", - "### Install required dependencies\n", - "\n", - "We'll be using `jax` and `jaxlib` (JAX's XLA package), along with `flax` and `transformers` for some models to export.\n", - "\n", - "[jax-tutorial-colab]: https://colab.research.google.com/github/openxla/stablehlo/blob/main/docs/tutorials/jax-export.ipynb\n", - "[jax-tutorial-kaggle]: https://kaggle.com/kernels/welcome?src=https://github.com/openxla/stablehlo/blob/main/docs/tutorials/jax-export.ipynb" - ], - "metadata": { - "id": "_TuAgGNKt5HO" - } - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "ENUcO6aML-Nq", - "collapsed": true - }, - "outputs": [], - "source": [ - "!pip install -U jax jaxlib flax transformers" - ] - }, - { - "cell_type": "code", - "source": [ - "#@title Define `get_stablehlo_asm` to help with MLIR printing\n", - "from jax._src.interpreters import mlir as jax_mlir\n", - "from jax._src.lib.mlir import ir\n", - "\n", - "# Returns prettyprint of StableHLO module without large constants\n", - "def get_stablehlo_asm(module_str):\n", - " with jax_mlir.make_ir_context():\n", - " stablehlo_module = ir.Module.parse(module_str, context=jax_mlir.make_ir_context())\n", - " return 
stablehlo_module.operation.get_asm(large_elements_limit=20)\n", - "\n", - "# Disable logging for better tutorial rendering\n", - "import logging\n", - "logging.disable(logging.WARNING)" - ], - "metadata": { - "id": "HqjeC_QSugYj", - "cellView": "form" - }, - "execution_count": 2, - "outputs": [] - }, - { - "cell_type": "markdown", - "source": [ - "_Note: This helper uses a JAX internal API that may break at any time, but it serves no functional purpose in the tutorial aside from readability._" - ], - "metadata": { - "id": "LNlFj80UwX0D" - } - }, - { - "cell_type": "markdown", - "source": [ - "## Export JAX model to StableHLO using `jax.export`\n", - "\n", - "In this section we'll export a very basic JAX function to StableHLO.\n", - "\n", - "The preferred API for export is `jax.experimental.export`, which uses `jax.jit` under the hood. As a rule-of-thumb, a function must be `jit`-able to be exported to StableHLO." - ], - "metadata": { - "id": "TEfsW69IBp_V" - } - }, - { - "cell_type": "markdown", - "source": [ - "### Export basic JAX model to StableHLO\n", - "\n", - "Let's start by exporting a basic `plus` function to StableHLO, using `np.int32` argument types to trace the function.\n", - "\n", - "Export requires specifying shapes using `jax.ShapeDtypeStruct`, which are trivial to construct from numpy values." - ], - "metadata": { - "id": "3MzMjjf2iIk2" - } - }, - { - "cell_type": "code", - "source": [ - "import jax\n", - "from jax.experimental import export\n", - "import jax.numpy as jnp\n", - "import numpy as np\n", - "\n", - "def plus(x,y):\n", - " return jnp.add(x,y)\n", - "\n", - "# Create abstract input shapes:\n", - "inputs = (np.int32(1), np.int32(1),)\n", - "input_shapes = [jax.ShapeDtypeStruct(input.shape, input.dtype) for input in inputs]\n", - "stablehlo_add = export.export(plus)(*input_shapes).mlir_module()\n", - "print(get_stablehlo_asm(stablehlo_add))" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "v-GN3vPbvFoa", - "outputId": "b128c08f-3591-4142-fd67-561812bb3d4e" - }, - "execution_count": 3, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "module @jit_plus attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {\n", - " func.func public @main(%arg0: tensor {mhlo.layout_mode = \"default\"}, %arg1: tensor {mhlo.layout_mode = \"default\"}) -> (tensor {jax.result_info = \"\", mhlo.layout_mode = \"default\"}) {\n", - " %0 = stablehlo.add %arg0, %arg1 : tensor\n", - " return %0 : tensor\n", - " }\n", - "}\n", - "\n" - ] - } - ] - }, - { - "cell_type": "markdown", - "source": [ - "### Export Huggingface FlaxResNet18 to StableHLO\n", - "\n", - "Now let's look at a simple model that appears in the wild, `resnet18`.\n", - "\n", - "This example will export a `flax` model from the huggingface `transformers` ResNet page, [FlaxResNetModel](https://huggingface.co/docs/transformers/en/model_doc/resnet#transformers.FlaxResNetModel). Much of this steps setup was copied from the huggingface documentation.\n", - "\n", - "The documentation also states: _\"Finally, this model supports inherent JAX features such as: **Just-In-Time (JIT) compilation** ...\"_ which means it is perfect for export.\n", - "\n", - "Similar to our very basic example, our steps for export are:\n", - "1. Instantiate a callable (model/function) that can be JIT'ed.\n", - "2. Specify shapes for export using `jax.ShapeDtypeStruct` on numpy values\n", - "3. 
Use the JAX `export` API to get a StableHLO module" - ], - "metadata": { - "id": "xZWVAHQzBsEM" - } - }, - { - "cell_type": "code", - "source": [ - "from transformers import AutoImageProcessor, FlaxResNetModel\n", - "import jax\n", - "import numpy as np\n", - "\n", - "# Construct flax model with sample inputs\n", - "\n", - "resnet18 = FlaxResNetModel.from_pretrained(\"microsoft/resnet-18\", return_dict=False)\n", - "sample_input = np.random.randn(1, 3, 224, 224)\n", - "input_shape = jax.ShapeDtypeStruct(sample_input.shape, sample_input.dtype)\n", - "\n", - "# Export to StableHLO\n", - "stablehlo_resnet18_export = export.export(resnet18)(input_shape)\n", - "resnet18_stablehlo = get_stablehlo_asm(stablehlo_resnet18_export.mlir_module())\n", - "print(resnet18_stablehlo[:600], \"\\n...\\n\", resnet18_stablehlo[-345:])" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "53T7jO-v_6PC", - "outputId": "0536386d-09d2-49a3-b51c-951c20f6e49b" - }, - "execution_count": 5, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "module @jit__unnamed_wrapped_function_ attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {\n", - " func.func public @main(%arg0: tensor<1x3x224x224xf32> {mhlo.layout_mode = \"default\"}) -> (tensor<1x512x7x7xf32> {mhlo.layout_mode = \"default\"}, tensor<1x512x1x1xf32> {mhlo.layout_mode = \"default\"}) {\n", - " %0 = stablehlo.constant dense_resource<__elided__> : tensor<7x7x3x64xf32>\n", - " %1 = stablehlo.constant dense_resource<__elided__> : tensor<64xf32>\n", - " %2 = stablehlo.constant dense_resource<__elided__> : tensor<64xf32>\n", - " %3 = stablehlo.constan \n", - "...\n", - " }\n", - " func.func private @relu_3(%arg0: tensor<1x7x7x512xf32>) -> tensor<1x7x7x512xf32> {\n", - " %0 = stablehlo.constant dense<0.000000e+00> : tensor\n", - " %1 = stablehlo.broadcast_in_dim %0, dims = [] : (tensor) -> tensor<1x7x7x512xf32>\n", - " %2 = stablehlo.maximum %arg0, %1 : tensor<1x7x7x512xf32>\n", - " return %2 : tensor<1x7x7x512xf32>\n", - " }\n", - "}\n", - "\n" - ] - } - ] - }, - { - "cell_type": "markdown", - "source": [ - "### Export with dynamic batch size\n", - "\n", - "Not let's export that same model with a dynamic batch size!\n", - "\n", - "In the first example we used an input shape of `tensor<1x3x224x224xf32>`, specifying strict constraints on the input shape. If we wanted to defer the concerete shapes used in compilation until a later point, we can specify a `symbolic_shape`, in this case we'll export using `tensor`.\n", - "\n", - "Symbolic shapes are specified using `export.symbolic_shape`, with letters representing symint dimensions. For example, a valid 2-d matrix multiplication could use symbolic constraints of: `2,a * a,5` to ensure the refined program will have valid shapes. Symbolic integer names are kept track of by an `export.SymbolicScope` to avoid unintentional name clashes." 
- ], - "metadata": { - "id": "X2MC9F7Zlx6E" - } - }, - { - "cell_type": "code", - "source": [ - "dyn_scope = export.SymbolicScope()\n", - "dyn_input_shape = jax.ShapeDtypeStruct(export.symbolic_shape(\"a,3,224,224\", scope=dyn_scope), np.float32)\n", - "dyn_resnet18_export = export.export(resnet18)(dyn_input_shape)\n", - "dyn_resnet18_stablehlo = get_stablehlo_asm(dyn_resnet18_export.mlir_module())\n", - "print(dyn_resnet18_stablehlo[:1900], \"\\n...\\n\", dyn_resnet18_stablehlo[-1000:])" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "sIkbtViEMJ3T", - "outputId": "165019c7-e771-43d6-c324-6bb4222798ec" - }, - "execution_count": 6, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "module @jit__unnamed_wrapped_function_ attributes {jax.uses_shape_polymorphism = true, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {\n", - " func.func public @main(%arg0: tensor {mhlo.layout_mode = \"default\"}) -> (tensor {mhlo.layout_mode = \"default\"}, tensor {mhlo.layout_mode = \"default\"}) {\n", - " %0 = stablehlo.get_dimension_size %arg0, dim = 0 : (tensor) -> tensor\n", - " %1 = stablehlo.constant dense<1> : tensor\n", - " %2 = stablehlo.compare GE, %0, %1, SIGNED : (tensor, tensor) -> tensor\n", - " stablehlo.custom_call @shape_assertion(%2, %0) {api_version = 2 : i32, error_message = \"Input shapes do not match the polymorphic shapes specification. Expected value >= 1 for dimension variable 'a'. Using the following polymorphic shapes specifications: args[0].shape = (a, 3, 224, 224). Obtained dimension variables: 'a' = {0} from specification 'a' for dimension args[0].shape[0] (= {0}), . Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#shape-assertion-errors for more details.\", has_side_effect = true} : (tensor, tensor) -> ()\n", - " %3:2 = call @_wrapped_jax_export_main(%0, %arg0) : (tensor, tensor) -> (tensor, tensor)\n", - " return %3#0, %3#1 : tensor, tensor\n", - " }\n", - " func.func private @_wrapped_jax_export_main(%arg0: tensor {jax.global_constant = \"a\"}, %arg1: tensor {mhlo.layout_mode = \"default\"}) -> (tensor {mhlo.layout_mode = \"default\"}, tensor {mhlo.layout_mode = \"default\"}) {\n", - " %0 = stablehlo.constant dense_resource<__elided__> : tensor<7x7x3x64xf32>\n", - " %1 = stablehlo.constant dense_resource<__elided__> : tensor<64xf32>\n", - " %2 = stablehlo.constant dense_resource<__elided__> : tensor<64xf32>\n", - " %3 = stablehl \n", - "...\n", - " \n", - " return %10 : tensor\n", - " }\n", - " func.func private @relu_3(%arg0: tensor {jax.global_constant = \"a\"}, %arg1: tensor) -> tensor {\n", - " %0 = stablehlo.constant dense<0.000000e+00> : tensor\n", - " %1 = stablehlo.constant dense<7> : tensor\n", - " %2 = stablehlo.constant dense<7> : tensor\n", - " %3 = stablehlo.constant dense<512> : tensor\n", - " %4 = stablehlo.reshape %arg0 : (tensor) -> tensor<1xi32>\n", - " %5 = stablehlo.constant dense<7> : tensor<1xi32>\n", - " %6 = stablehlo.constant dense<7> : tensor<1xi32>\n", - " %7 = stablehlo.constant dense<512> : tensor<1xi32>\n", - " %8 = stablehlo.concatenate %4, %5, %6, %7, dim = 0 : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xi32>\n", - " %9 = stablehlo.dynamic_broadcast_in_dim %0, %8, dims = [] : (tensor, tensor<4xi32>) -> tensor\n", - " %10 = stablehlo.maximum %arg1, %9 : tensor\n", - " return %10 : tensor\n", - " }\n", - "}\n", - "\n" - ] - } - ] - }, - { - "cell_type": "markdown", - "source": [ - "A few things to note 
in the exported StableHLO:\n", - "\n", - "1. The exported program now has `tensor`. These input types can be refined in many ways: SavedModel execution takes care of refinement which we'll see in the next example, but StableHLO also has APIs to [refine shapes](https://github.com/openxla/stablehlo/blob/541db997e449dcfee8536043dfdd49bb13f9ed1a/stablehlo/transforms/Passes.td#L69-L99) and [canonicalize dynamic programs](https://github.com/openxla/stablehlo/blob/541db997e449dcfee8536043dfdd49bb13f9ed1a/stablehlo/transforms/Passes.td#L18-L28) to static programs.\n", - "2. JAX will generate guards to ensure the values of `a` are valid, in this case `a > 1` is checked. These can be washed away at compile time once refined." - ], - "metadata": { - "id": "dRIu7xlSoDUK" - } - }, - { - "cell_type": "markdown", - "source": [ - "## Export to SavedModel\n", - "\n", - "It is common to want to export a StableHLO model to SavedModel for interop with existing compilation pipelines, existing TF tooling, or serving via [TF Serving](https://github.com/tensorflow/serving).\n", - "\n", - "JAX makes it easy to pack StableHLO into a SavedModel, and load that SavedModel in the future. For this section we'll be using our dynamic model from the previous section." - ], - "metadata": { - "id": "xFU5M6Xm1U8_" - } - }, - { - "cell_type": "markdown", - "source": [ - "### Install latest TF\n", - "\n", - "SavedModel definition lives in TF, so we need to install the dependency. We recommend using `tensorflow-cpu` or `tf-nightly`." - ], - "metadata": { - "id": "Gt_bIJaSpsYf" - } - }, - { - "cell_type": "code", - "source": [ - "!pip install tensorflow-cpu" - ], - "metadata": { - "id": "KZEd7NavBoem", - "collapsed": true - }, - "execution_count": null, - "outputs": [] + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "cellView": "form", + "id": "HqjeC_QSugYj" + }, + "outputs": [], + "source": [ + "#@title Define `get_stablehlo_asm` to help with MLIR printing\n", + "from jax._src.interpreters import mlir as jax_mlir\n", + "from jax._src.lib.mlir import ir\n", + "\n", + "# Returns prettyprint of StableHLO module without large constants\n", + "def get_stablehlo_asm(module_str):\n", + " with jax_mlir.make_ir_context():\n", + " stablehlo_module = ir.Module.parse(module_str, context=jax_mlir.make_ir_context())\n", + " return stablehlo_module.operation.get_asm(large_elements_limit=20)\n", + "\n", + "# Disable logging for better tutorial rendering\n", + "import logging\n", + "logging.disable(logging.WARNING)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "LNlFj80UwX0D" + }, + "source": [ + "_Note: This helper uses a JAX internal API that may break at any time, but it serves no functional purpose in the tutorial aside from readability._" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "TEfsW69IBp_V" + }, + "source": [ + "## Export JAX model to StableHLO using `jax.export`\n", + "\n", + "In this section we'll export a basic JAX function and a Flax model to StableHLO.\n", + "\n", + "The preferred API for export is [`jax.export`](https://jax.readthedocs.io/en/latest/jax.export.html#module-jax.export). The function to export must be JIT transformed, specifically a result of `jax.jit`, to be exported to StableHLO." 
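,
 "\n",
 "As a minimal sketch of the whole flow (assuming a recent JAX release where `jax.export` also provides `serialize`, `deserialize`, and `Exported.call`):\n",
 "\n",
 "```python\n",
 "import jax\n",
 "import jax.numpy as jnp\n",
 "from jax import export\n",
 "\n",
 "# Export a jitted function, tracing with an abstract input shape\n",
 "exported = export.export(jax.jit(jnp.sin))(\n",
 "    jax.ShapeDtypeStruct((4,), jnp.float32))\n",
 "print(exported.mlir_module()[:80])  # StableHLO, as text\n",
 "\n",
 "# Round-trip through portable bytes, then call the rehydrated function\n",
 "blob = exported.serialize()\n",
 "print(export.deserialize(blob).call(jnp.arange(4.0)))\n",
 "```"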
+ ]
 },
 {
 "cell_type": "markdown",
 "metadata": {
 "id": "3MzMjjf2iIk2"
 },
 "source": [
 "### Export basic JAX model to StableHLO\n",
 "\n",
 "Let's start by exporting a basic `plus` function to StableHLO, using `np.int32` argument types to trace the function.\n",
 "\n",
 "Export requires specifying shapes using `jax.ShapeDtypeStruct`, which can be constructed from NumPy values."
 ]
 },
 {
 "cell_type": "code",
 "execution_count": 2,
 "metadata": {
 "colab": {
 "base_uri": "https://localhost:8080/"
 },
 "id": "v-GN3vPbvFoa",
 "outputId": "b128c08f-3591-4142-fd67-561812bb3d4e"
 },
 "outputs": [
 {
 "name": "stdout",
 "output_type": "stream",
 "text": [
 "module @jit_plus attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {\n",
 "  func.func public @main(%arg0: tensor<i32>, %arg1: tensor<i32>) -> (tensor<i32> {jax.result_info = \"\"}) {\n",
 "    %0 = stablehlo.add %arg0, %arg1 : tensor<i32>\n",
 "    return %0 : tensor<i32>\n",
 "  }\n",
 "}\n",
 "\n"
 ]
 }
 ],
 "source": [
 "import jax\n",
 "from jax import export\n",
 "import jax.numpy as jnp\n",
 "import numpy as np\n",
 "\n",
 "# Create a JIT-transformed function\n",
 "@jax.jit\n",
 "def plus(x,y):\n",
 "  return jnp.add(x,y)\n",
 "\n",
 "# Create abstract input shapes\n",
 "inputs = (np.int32(1), np.int32(1),)\n",
 "input_shapes = [jax.ShapeDtypeStruct(input.shape, input.dtype) for input in inputs]\n",
 "\n",
 "# Export the function to StableHLO\n",
 "stablehlo_add = export.export(plus)(*input_shapes).mlir_module()\n",
 "print(get_stablehlo_asm(stablehlo_add))"
 ]
 },
 {
 "cell_type": "markdown",
 "metadata": {
 "id": "xZWVAHQzBsEM"
 },
 "source": [
 "### Export Hugging Face FlaxResNet18 to StableHLO\n",
 "\n",
 "Now let's look at a simple model that appears in the wild, `resnet18`.\n",
 "\n",
 "We'll export a `flax` model from the Hugging Face `transformers` ResNet page, [FlaxResNetModel](https://huggingface.co/docs/transformers/en/model_doc/resnet#transformers.FlaxResNetModel). The setup below is copied from the Hugging Face documentation.\n",
 "\n",
 "The documentation also states: _\"Finally, this model supports inherent JAX features such as: **Just-In-Time (JIT) compilation** ...\"_ which means it is perfect for export.\n",
 "\n",
 "Similar to our very basic example, our steps for export are:\n",
 "\n",
 "1. Instantiate a callable (model/function)\n",
 "2. JIT-transform it with `jax.jit`\n",
 "3. Specify shapes for export using `jax.ShapeDtypeStruct` on NumPy values\n",
 "4. 
Use the JAX `export` API to get a StableHLO module"
 ]
 },
 {
 "cell_type": "code",
 "execution_count": 3,
 "metadata": {
 "colab": {
 "base_uri": "https://localhost:8080/"
 },
 "id": "53T7jO-v_6PC",
 "outputId": "0536386d-09d2-49a3-b51c-951c20f6e49b"
 },
 "outputs": [
 {
 "name": "stdout",
 "output_type": "stream",
 "text": [
 "module @jit__unnamed_wrapped_function_ attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {\n",
 "  func.func public @main(%arg0: tensor<1x3x224x224xf32>) -> (tensor<1x512x7x7xf32> {jax.result_info = \"[0]\"}, tensor<1x512x1x1xf32> {jax.result_info = \"[1]\"}) {\n",
 "    %c = stablehlo.constant dense<49> : tensor<i32>\n",
 "    %cst = stablehlo.constant dense<0.000000e+00> : tensor<f32>\n",
 "    %cst_0 = stablehlo.constant dense<0xFF800000> : tensor<f32>\n",
 "    %cst_1 = stablehlo.constant dense<9.99999974E-6> : tensor<f32>\n",
 "    %cst_2 = stablehlo.constant dense_reso \n",
 "...\n",
 "  func.func private @relu_3(%arg0: tensor<1x7x7x512xf32>) -> tensor<1x7x7x512xf32> {\n",
 "    %cst = stablehlo.constant dense<0.000000e+00> : tensor<f32>\n",
 "    %0 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<f32>) -> tensor<1x7x7x512xf32>\n",
 "    %1 = stablehlo.maximum %arg0, %0 : tensor<1x7x7x512xf32>\n",
 "    return %1 : tensor<1x7x7x512xf32>\n",
 "  }\n",
 "}\n",
 "\n"
 ]
 }
 ],
 "source": [
 "from transformers import AutoImageProcessor, FlaxResNetModel\n",
 "import jax\n",
 "import numpy as np\n",
 "\n",
 "# Construct jit-transformed flax model with sample inputs\n",
 "resnet18 = FlaxResNetModel.from_pretrained(\"microsoft/resnet-18\", return_dict=False)\n",
 "resnet18_jit = jax.jit(resnet18)\n",
 "sample_input = np.random.randn(1, 3, 224, 224)\n",
 "input_shape = jax.ShapeDtypeStruct(sample_input.shape, sample_input.dtype)\n",
 "\n",
 "# Export to StableHLO\n",
 "stablehlo_resnet18_export = export.export(resnet18_jit)(input_shape)\n",
 "resnet18_stablehlo = get_stablehlo_asm(stablehlo_resnet18_export.mlir_module())\n",
 "print(resnet18_stablehlo[:600], \"\\n...\\n\", resnet18_stablehlo[-345:])"
 ]
 },
 {
 "cell_type": "markdown",
 "metadata": {
 "id": "X2MC9F7Zlx6E"
 },
 "source": [
 "### Export with dynamic batch size\n",
 "\n",
 "Now let's export that same model with a dynamic batch size!\n",
 "\n",
 "In the first example, we used an input shape of `tensor<1x3x224x224xf32>`, specifying strict constraints on the input shape. If we want to defer the concrete shapes used in compilation until a later point, we can specify a `symbolic_shape`.
In this example, we'll export using `tensor<?x3x224x224xf32>`.\n",
 "\n",
 "Symbolic shapes are specified using `export.symbolic_shape`, with letters representing symbolic integer (symint) dimensions. For example, a valid 2-D matrix multiplication could use the symbolic constraints `2,a * a,5` to ensure the refined program will have valid shapes. Symbolic integer names are tracked by an `export.SymbolicScope` to avoid unintentional name clashes."
 ]
 },
 {
 "cell_type": "code",
 "execution_count": 4,
 "metadata": {
 "colab": {
 "base_uri": "https://localhost:8080/"
 },
 "id": "sIkbtViEMJ3T",
 "outputId": "165019c7-e771-43d6-c324-6bb4222798ec"
 },
 "outputs": [
 {
 "name": "stdout",
 "output_type": "stream",
 "text": [
 "module @jit__unnamed_wrapped_function_ attributes {jax.uses_shape_polymorphism = true, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {\n",
 "  func.func public @main(%arg0: tensor<?x3x224x224xf32>) -> (tensor<?x512x7x7xf32> {jax.result_info = \"[0]\"}, tensor<?x512x1x1xf32> {jax.result_info = \"[1]\"}) {\n",
 "    %c = stablehlo.constant dense<1> : tensor<i32>\n",
 "    %0 = stablehlo.get_dimension_size %arg0, dim = 0 : (tensor<?x3x224x224xf32>) -> tensor<i32>\n",
 "    %1 = stablehlo.compare GE, %0, %c, SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1>\n",
 "    stablehlo.custom_call @shape_assertion(%1, %0) {api_version = 2 : i32, error_message = \"Input shapes do not match the polymorphic shapes specification. Expected value >= 1 for dimension variable 'a'. Using the following polymorphic shapes specifications: args[0].shape = (a, 3, 224, 224). Obtained dimension variables: 'a' = {0} from specification 'a' for dimension args[0].shape[0] (= {0}), . Please see https://jax.readthedocs.io/en/latest/export/shape_poly.html#shape-assertion-errors for more details.\", has_side_effect = true} : (tensor<i1>, tensor<i32>) -> ()\n",
 "    %2:2 = call @_wrapped_jax_export_main(%0, %arg0) : (tensor<i32>, tensor<?x3x224x224xf32>) -> (tensor<?x512x7x7xf32>, tensor<?x512x1x1xf32>)\n",
 "    return %2#0, %2#1 : tensor<?x512x7x7xf32>, tensor<?x512x1x1xf32>\n",
 "  }\n",
 "  func.func private @_wrapped_jax_export_main(%arg0: tensor<i32> {jax.global_constant = \"a\"}, %arg1: tensor<?x3x224x224xf32>) -> (tensor<?x512x7x7xf32> {jax.result_info = \"[0]\"}, tensor<?x512x1x1xf32> {jax.result_info = \"[1]\"}) {\n",
 "    %c = stablehlo.constant dense<1> : tensor<1xi32>\n",
 "    %c_0 = stablehlo.constant dense<49> : tensor<i32>\n",
 "    %cst = stablehlo.constant dense<0.000000e+00> : tensor<f32>\n",
 "    %c_1 = stablehlo.constant dense<512> : tensor<1xi32>\n",
 "    %c_2 = stablehlo.constant dense<7> : tensor<1xi32>\n",
 "    %c_3 = stablehlo.constant dense<256> : tensor<1x \n",
 "...\n",
 " , tensor<1xi32>) -> tensor<4xi32>\n",
 "    %2 = stablehlo.dynamic_broadcast_in_dim %cst, %1, dims = [] : (tensor<f32>, tensor<4xi32>) -> tensor<?x14x14x256xf32>\n",
 "    %3 = stablehlo.maximum %arg1, %2 : tensor<?x14x14x256xf32>\n",
 "    return %3 : tensor<?x14x14x256xf32>\n",
 "  }\n",
 "  func.func private @relu_3(%arg0: tensor<i32> {jax.global_constant = \"a\"}, %arg1: tensor<?x7x7x512xf32>) -> tensor<?x7x7x512xf32> {\n",
 "    %c = stablehlo.constant dense<512> : tensor<1xi32>\n",
 "    %c_0 = stablehlo.constant dense<7> : tensor<1xi32>\n",
 "    %cst = stablehlo.constant dense<0.000000e+00> : tensor<f32>\n",
 "    %0 = stablehlo.reshape %arg0 : (tensor<i32>) -> tensor<1xi32>\n",
 "    %1 = stablehlo.concatenate %0, %c_0, %c_0, %c, dim = 0 : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xi32>\n",
 "    %2 = stablehlo.dynamic_broadcast_in_dim %cst, %1, dims = [] : (tensor<f32>, tensor<4xi32>) -> tensor<?x7x7x512xf32>\n",
 "    %3 = stablehlo.maximum %arg1, %2 : tensor<?x7x7x512xf32>\n",
 "    return %3 : tensor<?x7x7x512xf32>\n",
 "  }\n",
 "}\n",
 "\n"
 ]
 }
 ],
 "source": [
 "# Construct dynamic sample inputs\n",
 "dyn_scope = export.SymbolicScope()\n",
 "dyn_input_shape = jax.ShapeDtypeStruct(export.symbolic_shape(\"a,3,224,224\", scope=dyn_scope), np.float32)\n",
 "\n",
 "# Export to StableHLO\n",
 "dyn_resnet18_export = export.export(resnet18_jit)(dyn_input_shape)\n",
 "dyn_resnet18_stablehlo = get_stablehlo_asm(dyn_resnet18_export.mlir_module())\n",
 "print(dyn_resnet18_stablehlo[:1900], \"\\n...\\n\", dyn_resnet18_stablehlo[-1000:])"
 ]
 },
 {
 "cell_type": "markdown",
 "metadata": {
 "id": "dRIu7xlSoDUK"
 },
 "source": [
 "A few things to note in the exported StableHLO:\n",
 "\n",
 "1. The exported program now takes a `tensor<?x3x224x224xf32>` input. These input types can be refined in many ways: StableHLO has APIs to [refine shapes](https://github.com/openxla/stablehlo/blob/541db997e449dcfee8536043dfdd49bb13f9ed1a/stablehlo/transforms/Passes.td#L69-L99) and [canonicalize dynamic programs](https://github.com/openxla/stablehlo/blob/541db997e449dcfee8536043dfdd49bb13f9ed1a/stablehlo/transforms/Passes.td#L18-L28) to static programs. TensorFlow SavedModel execution also takes care of refinement, which we'll see in the next example.\n",
 "2. JAX will generate guards to ensure the values of `a` are valid, in this case `a >= 1` is checked. These can be washed away at compile time once refined."
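,
 "\n",
 "As a minimal sketch of note 1 (assuming `jax.export`'s `Exported.call`), calling the dynamic export with a concrete batch refines `a` at call time:\n",
 "\n",
 "```python\n",
 "# 'a' resolves to 4 here, and the a >= 1 guard passes\n",
 "batch = np.random.randn(4, 3, 224, 224).astype(np.float32)\n",
 "outputs = dyn_resnet18_export.call(batch)\n",
 "print(outputs[0].shape)  # (4, 512, 7, 7)\n",
 "```"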
+ ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "xFU5M6Xm1U8_" + }, + "source": [ + "## Export to TensorFlow SavedModel\n", + "\n", + "It is common to export a StableHLO model to SavedModel for interoperability with existing compilation pipelines, existing TensorFlow tooling, or serving via [TensorFlow Serving](https://github.com/tensorflow/serving).\n", + "\n", + "JAX makes it easy to pack StableHLO into a SavedModel, and load that SavedModel in the future. For this section, we'll be using our dynamic model from the previous section." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "lf7Fnsrop7BD" + }, + "source": [ + "### Export to SavedModel using `jax2tf`\n", + "\n", + "JAX provides a simple API for exporting StableHLO into a format that can be packaged in SavedModel in `jax.experimental.jax2tf`. This uses the `export` function under the hood, so the same `jit` requirements apply.\n", + "\n", + "Full details on `jax2tf` can be found in the [README](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#jax-and-tensorflow-interoperation-jax2tfcall_tf). For this example, we'll only need to know the `polymorphic_shapes` option to specify our dynamic batch dimension." + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" }, + "id": "GXkgtX7QEiWa", + "outputId": "9034836e-4ba1-4210-c2c1-316777d4ad89" + }, + "outputs": [ { - "cell_type": "code", - "source": [ - "restored_model = tf.saved_model.load('/tmp/resnet18_tf')\n", - "restored_result = restored_model.f(tf.constant(sample_input, tf.float32))\n", - "print(\"Result shape:\", restored_result[0].shape)" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "9Az3dXXWrVDM", - "outputId": "cac6d3b1-126e-4e66-dc4d-f2a798ede463" - }, - "execution_count": 8, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Result shape: (1, 512, 7, 7)\n" - ] - } - ] + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[34massets\u001b[m\u001b[m fingerprint.pb saved_model.pb \u001b[34mvariables\u001b[m\u001b[m\n" + ] + } + ], + "source": [ + "from jax.experimental import jax2tf\n", + "import tensorflow as tf\n", + "\n", + "exported_f = jax2tf.convert(resnet18, polymorphic_shapes=[\"(a,3,224,224)\"])\n", + "\n", + "# Copied from the jax2tf README.md > Usage: saved model\n", + "my_model = tf.Module()\n", + "my_model.f = tf.function(exported_f, autograph=False).get_concrete_function(tf.TensorSpec([None, 3, 224, 224], tf.float32))\n", + "tf.saved_model.save(my_model, '/tmp/resnet18_tf', options=tf.saved_model.SaveOptions(experimental_custom_gradients=True))\n", + "\n", + "!ls /tmp/resnet18_tf" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "CmKABmLdrS3C" + }, + "source": [ + "### Reload and call the SavedModel\n", + "\n", + "Now we can load that SavedModel and compile using our `sample_input` from a previous example.\n", + "\n", + "_Note: The restored model does *not* require JAX to run, just XLA._" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" }, + "id": "9Az3dXXWrVDM", + "outputId": "cac6d3b1-126e-4e66-dc4d-f2a798ede463" + }, + "outputs": [ { - "cell_type": "markdown", - "source": [ - "## Common Troubleshooting\n", - "\n", - "If the function can be JIT'ed, then it can be exported. 
Focus on making `jax.jit` work first, or look in desired project for uses of JIT already (ex: [AlphaFold's `apply`](https://github.com/google-deepmind/alphafold/blob/dbe2a438ebfc6289f960292f15dbf421a05e563d/alphafold/model/model.py#L89) can be exported easily).\n", - "\n", - "See [JAX JIT Examples](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html#Examples) for troubleshooting. The most common issue is control flow, which can often be resolved with `static_argnums` / `static_argnames` as in the linked example.\n", - "\n", - "For opening a ticket for help, include a repo using one of the above APIs, this will help get the issue resolved much quicker!" - ], - "metadata": { - "id": "a2Dsm2oF5jn4" - } + "name": "stdout", + "output_type": "stream", + "text": [ + "Result shape: (1, 512, 7, 7)\n" + ] } - ] + ], + "source": [ + "restored_model = tf.saved_model.load('/tmp/resnet18_tf')\n", + "restored_result = restored_model.f(tf.constant(sample_input, tf.float32))\n", + "print(\"Result shape:\", restored_result[0].shape)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "a2Dsm2oF5jn4" + }, + "source": [ + "## Troubleshooting\n", + "\n", + "### `jax.jit` issues\n", + "\n", + "If the function can be JIT'ed, then it can be exported. Ensure `jax.jit` works first, or look in desired project for uses of JIT already (for example, [AlphaFold's `apply`](https://github.com/google-deepmind/alphafold/blob/dbe2a438ebfc6289f960292f15dbf421a05e563d/alphafold/model/model.py#L89) can be exported easily). \n", + "\n", + "See [JAX's JIT compilation documentation](https://jax.readthedocs.io/en/latest/jit-compilation.html) and [`jax.jit` API reference and examples](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html) for troubleshooting JIT transformations. The most common issue is control flow, which can often be resolved with `static_argnums` / `static_argnames` as in the linked example.\n", + "\n", + "### Support tickets\n", + "\n", + "You can open an issue on GitHub for further help. Include a reproducible example using one of the above APIs in your issue report, this will help get the issue resolved much quicker!" + ] + } + ], + "metadata": { + "colab": { + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.10" + } + }, + "nbformat": 4, + "nbformat_minor": 4 }