From 9676ea25a850530a411454ab4c6b8e6ceea88788 Mon Sep 17 00:00:00 2001 From: Sudeep Agarwal Date: Wed, 10 May 2023 23:33:35 -0400 Subject: [PATCH 1/3] Update usage of call_tir --- 4_Build_End_to_End_Model.ipynb | 391 ++++++++++++++++++++------------- 1 file changed, 241 insertions(+), 150 deletions(-) diff --git a/4_Build_End_to_End_Model.ipynb b/4_Build_End_to_End_Model.ipynb index feade76..aefd3ac 100644 --- a/4_Build_End_to_End_Model.ipynb +++ b/4_Build_End_to_End_Model.ipynb @@ -1,6 +1,7 @@ { "cells": [ { + "attachments": {}, "cell_type": "markdown", "metadata": { "colab_type": "text", @@ -11,6 +12,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": { "id": "Mpn1ti5Urdsv" @@ -20,6 +22,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": { "id": "qXysoqn-vZuF" @@ -65,6 +68,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": { "id": "i-14C4skxIrJ" @@ -75,6 +79,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": { "id": "EpXZsatqxnyz" @@ -84,6 +89,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": { "id": "8k7AtYC0x0jD" @@ -97,6 +103,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": { "id": "BBIuE2jc1DaU" @@ -110,7 +117,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 11, "metadata": { "id": "BVp0fHyRkYj6" }, @@ -126,6 +133,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": { "id": "wIIsUXGqpRqV" @@ -139,7 +147,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 55, "metadata": { "id": "NdWS5Jabq-DN" }, @@ -162,6 +170,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": { "id": "acOCCvmSPaR0" @@ -172,7 +181,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 56, "metadata": { "colab": { "base_uri": "https://localhost:8080/", @@ -183,18 +192,20 @@ }, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "Class: Sandal\n" - ] + 
"data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" }, { - "name": "stderr", + "name": "stdout", "output_type": "stream", "text": [ - "/tmp/ipykernel_30774/2167287842.py:7: UserWarning: Matplotlib is currently using agg, which is a non-GUI backend, so cannot show the figure.\n", - " plt.show()\n" + "Class: Sandal\n" ] } ], @@ -211,6 +222,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": { "id": "dPR_WrTZglbh" @@ -222,7 +234,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 58, "metadata": { "colab": { "base_uri": "https://localhost:8080/" @@ -235,21 +247,21 @@ "name": "stdout", "output_type": "stream", "text": [ - "--2022-09-19 18:26:12-- https://github.com/mlc-ai/web-data/raw/main/models/fasionmnist_mlp_params.pkl\n", - "Resolving github.com (github.com)... 192.30.255.112\n", - "Connecting to github.com (github.com)|192.30.255.112|:443... connected.\n", + "--2023-05-10 23:27:23-- https://github.com/mlc-ai/web-data/raw/main/models/fasionmnist_mlp_params.pkl\n", + "Resolving github.com (github.com)... 140.82.113.4\n", + "Connecting to github.com (github.com)|140.82.113.4|:443... connected.\n", "HTTP request sent, awaiting response... 302 Found\n", "Location: https://raw.githubusercontent.com/mlc-ai/web-data/main/models/fasionmnist_mlp_params.pkl [following]\n", - "--2022-09-19 18:26:13-- https://raw.githubusercontent.com/mlc-ai/web-data/main/models/fasionmnist_mlp_params.pkl\n", - "Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.111.133, 185.199.110.133, ...\n", - "Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.\n", + "--2023-05-10 23:27:24-- https://raw.githubusercontent.com/mlc-ai/web-data/main/models/fasionmnist_mlp_params.pkl\n", + "Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 
2606:50c0:8000::154, 2606:50c0:8001::154, 2606:50c0:8002::154, ...\n", + "Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|2606:50c0:8000::154|:443... connected.\n", "HTTP request sent, awaiting response... 200 OK\n", "Length: 407396 (398K) [application/octet-stream]\n", - "Saving to: ‘fasionmnist_mlp_params.pkl.1’\n", + "Saving to: ‘fasionmnist_mlp_params.pkl’\n", "\n", "fasionmnist_mlp_par 100%[===================>] 397.85K --.-KB/s in 0.05s \n", "\n", - "2022-09-19 18:26:13 (7.18 MB/s) - ‘fasionmnist_mlp_params.pkl.1’ saved [407396/407396]\n", + "2023-05-10 23:27:24 (7.70 MB/s) - ‘fasionmnist_mlp_params.pkl’ saved [407396/407396]\n", "\n" ] } @@ -259,6 +271,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": { "id": "Rk5jwmmDzddJ" @@ -271,6 +284,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": { "id": "hk64a3UllGIV" @@ -280,6 +294,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": { "id": "PUcCRU2IQPm-" @@ -290,7 +305,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 60, "metadata": { "id": "vvfOgcu-YdaB" }, @@ -305,7 +320,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 61, "metadata": { "colab": { "base_uri": "https://localhost:8080/" @@ -318,8 +333,8 @@ "name": "stdout", "output_type": "stream", "text": [ - "[[-10.968934 -13.400272 -7.7212744 -7.4016604 -7.5777307 8.872316\n", - " -6.1305714 -8.879843 -3.4321747 -2.1780372]]\n", + "[[ -8.505112 -19.33341 -5.5189652 -6.8927536 -14.0578785 11.494652\n", + " -11.22116 -9.992905 -2.6286726 -18.563715 ]]\n", "[5]\n", "Numpy-MLP Prediction: Sandal\n" ] @@ -340,6 +355,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": { "id": "IgYd_scF1Vjw" @@ -433,6 +449,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": { "id": "8OSO2aYs-hFD" @@ -445,7 +462,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 5, "metadata": { "id": 
"6wADqZ_TpGl7" }, @@ -454,8 +471,8 @@ "@tvm.script.ir_module\n", "class MyModule: \n", " @T.prim_func\n", - " def relu0(X: T.Buffer[(1, 128), \"float32\"], \n", - " Y: T.Buffer[(1, 128), \"float32\"]):\n", + " def relu0(X: T.Buffer((1, 128), \"float32\"), \n", + " Y: T.Buffer((1, 128), \"float32\")):\n", " # function attr dict\n", " T.func_attr({\"global_symbol\": \"relu0\", \"tir.noalias\": True})\n", " for i, j in T.grid(1, 128):\n", @@ -464,10 +481,10 @@ " Y[vi, vj] = T.max(X[vi, vj], T.float32(0))\n", "\n", " @T.prim_func\n", - " def linear0(X: T.Buffer[(1, 784), \"float32\"], \n", - " W: T.Buffer[(128, 784), \"float32\"], \n", - " B: T.Buffer[(128,), \"float32\"], \n", - " Z: T.Buffer[(1, 128), \"float32\"]):\n", + " def linear0(X: T.Buffer((1, 784), \"float32\"), \n", + " W: T.Buffer((128, 784), \"float32\"), \n", + " B: T.Buffer((128,), \"float32\"), \n", + " Z: T.Buffer((1, 128), \"float32\")):\n", " T.func_attr({\"global_symbol\": \"linear0\", \"tir.noalias\": True})\n", " Y = T.alloc_buffer((1, 128), \"float32\")\n", " for i, j, k in T.grid(1, 128, 784):\n", @@ -483,10 +500,10 @@ " Z[vi, vj] = Y[vi, vj] + B[vj]\n", "\n", " @T.prim_func\n", - " def linear1(X: T.Buffer[(1, 128), \"float32\"], \n", - " W: T.Buffer[(10, 128), \"float32\"], \n", - " B: T.Buffer[(10,), \"float32\"], \n", - " Z: T.Buffer[(1, 10), \"float32\"]):\n", + " def linear1(X: T.Buffer((1, 128), \"float32\"), \n", + " W: T.Buffer((10, 128), \"float32\"), \n", + " B: T.Buffer((10,), \"float32\"), \n", + " Z: T.Buffer((1, 10), \"float32\")):\n", " T.func_attr({\"global_symbol\": \"linear1\", \"tir.noalias\": True})\n", " Y = T.alloc_buffer((1, 10), \"float32\")\n", " for i, j, k in T.grid(1, 10, 128):\n", @@ -508,14 +525,15 @@ " w1: R.Tensor((10, 128), \"float32\"), \n", " b1: R.Tensor((10,), \"float32\")):\n", " with R.dataflow():\n", - " lv0 = R.call_tir(linear0, (x, w0, b0), (1, 128), dtype=\"float32\")\n", - " lv1 = R.call_tir(relu0, (lv0,), (1, 128), dtype=\"float32\")\n", - " out = 
R.call_tir(linear1, (lv1, w1, b1), (1, 10), dtype=\"float32\")\n", + " lv0 = R.call_dps_packed(\"linear0\", (x, w0, b0), out_sinfo=R.Tensor((1, 128), dtype=\"float32\"))\n", + " lv1 = R.call_dps_packed(\"relu0\", (lv0,), out_sinfo=R.Tensor((1, 128), dtype=\"float32\"))\n", + " out = R.call_dps_packed(\"linear1\", (lv1, w1, b1), out_sinfo=R.Tensor((1, 10), dtype=\"float32\"))\n", " R.output(out)\n", " return out" ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": { "id": "d5-LtDpRoia2" @@ -524,10 +542,11 @@ "The above code contains kinds of functions: the primitive tensor functions (`T.prim_func`) that we saw in the last lecture and a new `R.function` (relax function). \n", "Relax function is a new type of abstraction representing high-level neural network executions. \n", "\n", - "Again it is helpful to see the TVMScript code and low-level numpy code side-by-side and check the corresponding elements, and we are going to walk through each of them in detail. Since we already learned about primitive tensor functions, we are going to focus on the high-level execution part." + "Again it is helpful to see the TVMScript code and low-level numpy code side-by-side and check the corresponding elements, and we are going to walk through each of them in detail. Since we already learned about primitive tensor functions, we are going to focus on the high-level execution part. Note that the `call_tir` API below has been changed to `call_dps_packed`." 
] }, { + "attachments": {}, "cell_type": "markdown", "metadata": { "id": "Cq1FNpoNojNx" @@ -537,6 +556,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": { "id": "hhfr5u-2msNV" @@ -546,6 +566,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": { "id": "CQjGG5AxmdwJ" @@ -555,6 +576,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": { "id": "nYweMW1krbih" @@ -569,28 +591,30 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": { "id": "3WoknpyJYcph" }, "source": [ - "### call_tir construct\n", + "### call_dps_packed construct\n", "\n", - "One thing that you may have noticed is that each step of operations in the computational graph contains an `R.call_tir` operation. This is the operation that brings in the tensor primitive functions\n", + "One thing that you may have noticed is that each step of operations in the computational graph contains an `R.call_dps_packed` operation. This is the operation that brings in the tensor primitive functions\n", "\n", "\n", "```python\n", - "lv0 = R.call_tir(linear0, (x, w0, b0), (1, 128), dtype=\"float32\")\n", + "lv0 = R.call_dps_packed(\"linear0\", (x, w0, b0), out_sinfo=R.Tensor((1, 128), dtype=\"float32\"))\n", "```" ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": { "id": "LQ-48agqtixI" }, "source": [ - "To explain what does `R.call_tir` mean, let us review an equivalent low-level numpy implementation of the operation, as follows:" + "To explain what does `R.call_dps_packed` mean, let us review an equivalent low-level numpy implementation of the operation, as follows:" ] }, { @@ -601,23 +625,24 @@ }, "outputs": [], "source": [ - "def lnumpy_call_tir(prim_func, inputs, shape, dtype):\n", + "def lnumpy_call_dps_packed(prim_func, inputs, shape, dtype):\n", " res = np.empty(shape, dtype=dtype)\n", " prim_func(*inputs, res)\n", " return res" ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": { "id": "Juc_EGsIt7js" }, "source": [ - 
"Specifically, call_tir takes in a primitive function (`prim_func`) a list of inputs. Then what it does is allocate an output tensor `res`, then pass the inputs and the output to the `prim_func`. After executing `prim_func` the result is populated in `res`, then we can return the result.\n", + "Specifically, call_dps_packed takes in a primitive function (`prim_func`) and a list of inputs. Then what it does is allocate an output tensor `res`, then pass the inputs and the output to the `prim_func`. After executing `prim_func` the result is populated in `res`, then we can return the result.\n", "\n", - "Note that `lnumpy_call_tir` is only a reference implementation to show the meaning of `R.call_tir`. In practice, there can be different low-level ways to optimize the execution. For example, we might choose to allocate all the output memories ahead of time and then run the execution, which we will cover in future lectures. \n", + "Note that `lnumpy_call_dps_packed` is only a reference implementation to show the meaning of `R.call_dps_packed`. In practice, there can be different low-level ways to optimize the execution. For example, we might choose to allocate all the output memories ahead of time and then run the execution, which we will cover in future lectures. \n", "\n", - "A natural question that one could ask is why do we need `call_tir` construct? This is because our primitive tensor functions take the following calling convention.\n", + "A natural question that one could ask is why do we need the `call_dps_packed` construct? 
This is because our primitive tensor functions take the following calling convention.\n", "\n", "```python\n", "def low_level_prim_func(in0, in1, ..., out):\n", @@ -627,6 +652,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": { "id": "msKx_bChuLfU" @@ -649,6 +675,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": { "id": "UEjQ0O-20U7b" @@ -658,6 +685,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": { "id": "-hpjGSxf0j7D" @@ -672,25 +700,26 @@ "\n", "Of course, we can still generalize the graph definition by introducing the input edge and output edge, and that can complicate the possible transformations associated with the abstraction. \n", "\n", - "So coming back to `call_tir`, the key insight here is that we want to hide possible allocation or explicit writing to the functions. In a more formal term, we want the function to be **pure** or **side-effect free**.\n", + "So coming back to `call_dps_packed`, the key insight here is that we want to hide possible allocation or explicit writing to the functions. In more formal terms, we want the function to be **pure** or **side-effect free**.\n", "\n", "A function is **pure** or **side-effect free** if: it only reads from its inputs and returns the result via its output, it will not change other parts of the program (such as incrementing a global counter).\n", "\n", - "**call_tir** is a way for us to hide these details of calling into low-level primitive functions and expose them into a computational graph." + "**call_dps_packed** is a way for us to hide these details of calling into low-level primitive functions and expose them in a computational graph." ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": { "id": "qLVcODPTZj_v" }, "source": [ - "We can also see `call_tir` in action in the low-level numpy as well. 
Now we have defined the `lnumpy_call_tir`, we can rewrite the low-level numpy execution code as:" + "We can also see `call_dps_packed` in action in the low-level numpy as well. Now we have defined the `lnumpy_call_dps_packed`, we can rewrite the low-level numpy execution code as:" ] }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 4, "metadata": { "colab": { "base_uri": "https://localhost:8080/" @@ -698,23 +727,15 @@ "id": "nRj7MsLWYeO5", "outputId": "e4378e32-2706-491a-8f42-799199fac042" }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Low-level Numpy with CallTIR Prediction: Sandal\n" - ] - } - ], + "outputs": [], "source": [ - "def lnumpy_mlp_with_call_tir(data, w0, b0, w1, b1):\n", - " lv0 = lnumpy_call_tir(lnumpy_linear0, (data, w0, b0), (1, 128), dtype=\"float32\")\n", - " lv1 = lnumpy_call_tir(lnumpy_relu0, (lv0, ), (1, 128), dtype=\"float32\")\n", - " out = lnumpy_call_tir(lnumpy_linear1, (lv1, w1, b1), (1, 10), dtype=\"float32\")\n", + "def lnumpy_mlp_with_call_dps_packed(data, w0, b0, w1, b1):\n", + " lv0 = lnumpy_call_dps_packed(lnumpy_linear0, (data, w0, b0), (1, 128), dtype=\"float32\")\n", + " lv1 = lnumpy_call_dps_packed(lnumpy_relu0, (lv0, ), (1, 128), dtype=\"float32\")\n", + " out = lnumpy_call_dps_packed(lnumpy_linear1, (lv1, w1, b1), (1, 10), dtype=\"float32\")\n", " return out\n", "\n", - "result = lnumpy_mlp_with_call_tir(\n", + "result = lnumpy_mlp_with_call_dps_packed(\n", " img.reshape(1, 784), \n", " mlp_params[\"w0\"], \n", " mlp_params[\"b0\"], \n", @@ -722,19 +743,21 @@ " mlp_params[\"b1\"])\n", "\n", "pred_kind = np.argmax(result, axis=1)\n", - "print(\"Low-level Numpy with CallTIR Prediction:\", class_names[pred_kind[0]])" + "print(\"Low-level Numpy with Call DPS Packed Prediction:\", class_names[pred_kind[0]])" ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": { "id": "Fx0vj13N3Fro" }, "source": [ - "In practice, the lowest-level implementation will have 
explicit memory allocations, so `call_tir` mainly serves as a purpose for us to continue to do some high-level transformations before we generate the actual implementation." + "In practice, the lowest-level implementation will have explicit memory allocations, so `call_dps_packed` mainly serves as a purpose for us to continue to do some high-level transformations before we generate the actual implementation." ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": { "id": "5QIpOq6O3dfw" @@ -746,9 +769,9 @@ "\n", "```python\n", "with R.dataflow():\n", - " lv0 = R.call_tir(linear0, (x, w0, b0), (1, 128), dtype=\"float32\")\n", - " lv1 = R.call_tir(relu0, (lv0,), (1, 128), dtype=\"float32\")\n", - " out = R.call_tir(linear1, (lv1, w1, b1), (1, 10), dtype=\"float32\")\n", + " lv0 = R.call_dps_packed(\"linear0\", (x, w0, b0), out_sinfo=R.Tensor((1, 128), dtype=\"float32\"))\n", + " lv1 = R.call_dps_packed(\"relu0\", (lv0,), out_sinfo=R.Tensor((1, 128), dtype=\"float32\"))\n", + " out = R.call_dps_packed(\"linear1\", (lv1, w1, b1), out_sinfo=R.Tensor((1, 10), dtype=\"float32\"))\n", " R.output(out)\n", "```\n", "\n", @@ -765,14 +788,14 @@ " b1: Tensor((10,), \"float32\")):\n", "\n", " with R.dataflow():\n", - " lv0 = R.call_tir(linear0, (x, w0, b0), (1, 128), dtype=\"float32\")\n", - " gv0 = R.call_tir(relu0, (lv0,), (1, 128), dtype=\"float32\")\n", + " lv0 = R.call_dps_packed(\"linear0\", (x, w0, b0), out_sinfo=R.Tensor((1, 128), dtype=\"float32\"))\n", + " gv0 = R.call_dps_packed(\"relu0\", (lv0,), out_sinfo=R.Tensor((1, 128), dtype=\"float32\"))\n", " R.output(gv0)\n", "\n", " gv1 = R.alloc_tensor((1, 128), dtype=\"float32\")\n", "\n", " with R.dataflow():\n", - " out = R.call_tir(linear1, (gv0, gv1, b0), (1, 128), dtype=\"float32\")\n", + " out = R.call_dps_packed(\"linear1\", (gv0, gv1, b0), out_sinfo=R.Tensor((1, 128), dtype=\"float32\"))\n", " R.output(out)\n", " return out\n", "```\n", @@ -781,6 +804,7 @@ ] }, { + "attachments": {}, "cell_type": 
"markdown", "metadata": { "id": "Y9aYKOoiC47j" @@ -791,13 +815,14 @@ "\n", "So far, we have gone through one example instance of relax program and covered most of the elements, including:\n", "- computational graph view\n", - "- `call_tir` construct\n", + "- `call_dps_packed` construct\n", "- Dataflow block.\n", "\n", "These elements should get us started in the end to end model execution and compilation. we will also cover new concepts as we encounter them in later chapters." ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": { "id": "SXa8L7_OhGTX" @@ -811,7 +836,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 6, "metadata": { "colab": { "base_uri": "https://localhost:8080/", @@ -824,28 +849,23 @@ { "data": { "text/html": [ - "
@tvm.script.ir_module\n",
-       "class Module:\n",
+       "
# from tvm.script import ir as I\n",
+       "# from tvm.script import tir as T\n",
+       "# from tvm.script import relax as R\n",
+       "\n",
+       "\n",
+       "@I.ir_module\n",
+       "class Module:\n",
        "    @T.prim_func\n",
-       "    def relu0(X: T.Buffer[(1, 128), "float32"], Y: T.Buffer[(1, 128), "float32"]) -> None:\n",
-       "        # function attr dict\n",
-       "        T.func_attr({"global_symbol": "relu0", "tir.noalias": True})\n",
-       "        # body\n",
-       "        # with T.block("root")\n",
-       "        for i, j in T.grid(1, 128):\n",
-       "            with T.block("Y"):\n",
-       "                vi, vj = T.axis.remap("SS", [i, j])\n",
-       "                T.reads(X[vi, vj])\n",
-       "                T.writes(Y[vi, vj])\n",
-       "                Y[vi, vj] = T.max(X[vi, vj], T.float32(0))\n",
-       "    \n",
-       "    @T.prim_func\n",
-       "    def linear0(X: T.Buffer[(1, 784), "float32"], W: T.Buffer[(128, 784), "float32"], B: T.Buffer[128, "float32"], Z: T.Buffer[(1, 128), "float32"]) -> None:\n",
-       "        # function attr dict\n",
-       "        T.func_attr({"global_symbol": "linear0", "tir.noalias": True})\n",
-       "        # body\n",
-       "        # with T.block("root")\n",
-       "        Y = T.alloc_buffer([1, 128], dtype="float32")\n",
+       "    def linear0(\n",
+       "        X: T.Buffer((1, 784), "float32"),\n",
+       "        W: T.Buffer((128, 784), "float32"),\n",
+       "        B: T.Buffer((128,), "float32"),\n",
+       "        Z: T.Buffer((1, 128), "float32"),\n",
+       "    ):\n",
+       "        T.func_attr({"global_symbol": "linear0", "tir.noalias": T.bool(True)})\n",
+       "        # with T.block("root"):\n",
+       "        Y = T.alloc_buffer((1, 128))\n",
        "        for i, j, k in T.grid(1, 128, 784):\n",
        "            with T.block("Y"):\n",
        "                vi, vj, vk = T.axis.remap("SSR", [i, j, k])\n",
@@ -860,24 +880,17 @@
        "                T.reads(Y[vi, vj], B[vj])\n",
        "                T.writes(Z[vi, vj])\n",
        "                Z[vi, vj] = Y[vi, vj] + B[vj]\n",
-       "    \n",
-       "    @R.function\n",
-       "    def main(x: Tensor((1, 784), "float32"), w0: Tensor((128, 784), "float32"), b0: Tensor((128,), "float32"), w1: Tensor((10, 128), "float32"), b1: Tensor((10,), "float32")) -> Tensor(None, "float32", ndim = 2):\n",
-       "        # block 0\n",
-       "        with R.dataflow():\n",
-       "            lv0 = R.call_tir(linear0, (x, w0, b0), (1, 128), dtype="float32")\n",
-       "            lv1 = R.call_tir(relu0, (lv0,), (1, 128), dtype="float32")\n",
-       "            out = R.call_tir(linear1, (lv1, w1, b1), (1, 10), dtype="float32")\n",
-       "            R.output(out)\n",
-       "        return out\n",
-       "    \n",
+       "\n",
        "    @T.prim_func\n",
-       "    def linear1(X: T.Buffer[(1, 128), "float32"], W: T.Buffer[(10, 128), "float32"], B: T.Buffer[10, "float32"], Z: T.Buffer[(1, 10), "float32"]) -> None:\n",
-       "        # function attr dict\n",
-       "        T.func_attr({"global_symbol": "linear1", "tir.noalias": True})\n",
-       "        # body\n",
-       "        # with T.block("root")\n",
-       "        Y = T.alloc_buffer([1, 10], dtype="float32")\n",
+       "    def linear1(\n",
+       "        X: T.Buffer((1, 128), "float32"),\n",
+       "        W: T.Buffer((10, 128), "float32"),\n",
+       "        B: T.Buffer((10,), "float32"),\n",
+       "        Z: T.Buffer((1, 10), "float32"),\n",
+       "    ):\n",
+       "        T.func_attr({"global_symbol": "linear1", "tir.noalias": T.bool(True)})\n",
+       "        # with T.block("root"):\n",
+       "        Y = T.alloc_buffer((1, 10))\n",
        "        for i, j, k in T.grid(1, 10, 128):\n",
        "            with T.block("Y"):\n",
        "                vi, vj, vk = T.axis.remap("SSR", [i, j, k])\n",
@@ -892,7 +905,38 @@
        "                T.reads(Y[vi, vj], B[vj])\n",
        "                T.writes(Z[vi, vj])\n",
        "                Z[vi, vj] = Y[vi, vj] + B[vj]\n",
-       "    \n",
+       "\n",
+       "    @T.prim_func\n",
+       "    def relu0(X: T.Buffer((1, 128), "float32"), Y: T.Buffer((1, 128), "float32")):\n",
+       "        T.func_attr({"global_symbol": "relu0", "tir.noalias": T.bool(True)})\n",
+       "        # with T.block("root"):\n",
+       "        for i, j in T.grid(1, 128):\n",
+       "            with T.block("Y"):\n",
+       "                vi, vj = T.axis.remap("SS", [i, j])\n",
+       "                T.reads(X[vi, vj])\n",
+       "                T.writes(Y[vi, vj])\n",
+       "                Y[vi, vj] = T.max(X[vi, vj], T.float32(0))\n",
+       "\n",
+       "    @R.function\n",
+       "    def main(\n",
+       "        x: R.Tensor((1, 784), dtype="float32"),\n",
+       "        w0: R.Tensor((128, 784), dtype="float32"),\n",
+       "        b0: R.Tensor((128,), dtype="float32"),\n",
+       "        w1: R.Tensor((10, 128), dtype="float32"),\n",
+       "        b1: R.Tensor((10,), dtype="float32"),\n",
+       "    ) -> R.Tensor((1, 10), dtype="float32"):\n",
+       "        with R.dataflow():\n",
+       "            lv0 = R.call_dps_packed(\n",
+       "                "linear0", (x, w0, b0), out_sinfo=R.Tensor((1, 128), dtype="float32")\n",
+       "            )\n",
+       "            lv1 = R.call_dps_packed(\n",
+       "                "relu0", (lv0,), out_sinfo=R.Tensor((1, 128), dtype="float32")\n",
+       "            )\n",
+       "            out = R.call_dps_packed(\n",
+       "                "linear1", (lv1, w1, b1), out_sinfo=R.Tensor((1, 10), dtype="float32")\n",
+       "            )\n",
+       "            R.output(out)\n",
+       "        return out\n",
        "
\n" ], "text/plain": [ @@ -908,6 +952,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": { "id": "Ai-pjfobEpoi" @@ -944,6 +989,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": { "id": "uSyXxd3fE7rg" @@ -964,6 +1010,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": { "id": "lpIMjrVdFR0d" @@ -974,7 +1021,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 62, "metadata": { "id": "9mOk7BkxFRC9" }, @@ -985,6 +1032,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": { "id": "ksx1-hl1FrtA" @@ -1028,6 +1076,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": { "id": "0Cv6FEY1F2fx" @@ -1061,6 +1110,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": { "id": "g2pBn9vlTz3I" @@ -1091,29 +1141,31 @@ " b1: R.Tensor((10,), \"float32\")):\n", " # block 0\n", " with R.dataflow():\n", - " lv0 = R.call_tir(\"env.linear\", (x, w0, b0), (1, 128), dtype=\"float32\")\n", - " lv1 = R.call_tir(\"env.relu\", (lv0,), (1, 128), dtype=\"float32\")\n", - " out = R.call_tir(\"env.linear\", (lv1, w1, b1), (1, 10), dtype=\"float32\")\n", + " lv0 = R.call_dps_packed(\"env.linear\", (x, w0, b0), out_sinfo=R.Tensor((1, 128), dtype=\"float32\"))\n", + " lv1 = R.call_dps_packed(\"env.relu\", (lv0,), out_sinfo=R.Tensor((1, 128), dtype=\"float32\"))\n", + " out = R.call_dps_packed(\"env.linear\", (lv1, w1, b1), out_sinfo=R.Tensor((1, 10), dtype=\"float32\"))\n", " R.output(out)\n", " return out" ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": { "id": "6u42cRGpGk7Y" }, "source": [ - "Note that we now directly pass in strings in `call_tir`\n", + "Note that we now directly pass in strings in `call_dps_packed`\n", "\n", "```python\n", - "R.call_tir(\"env.linear\", (x, w0, b0), (1, 128), dtype=\"float32\")\n", + "R.call_dps_packed(\"env.linear\", (x, w0, b0), out_sinfo=R.Tensor((1, 128), dtype=\"float32\"))\n", "```\n", "\n", "These strings are names of 
runtime functions that we expect to exist during model execution. " ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": { "id": "6uBJoYjHHrlM" @@ -1153,6 +1205,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": { "id": "z-UeSjTeH9QN" @@ -1166,6 +1219,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": { "id": "LKkqyqATJboW" @@ -1210,6 +1264,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": { "id": "pi3g5dNdWHFr" @@ -1222,7 +1277,7 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 64, "metadata": { "id": "poavVsmOWPDW" }, @@ -1231,10 +1286,10 @@ "@tvm.script.ir_module\n", "class MyModuleMixture: \n", " @T.prim_func\n", - " def linear0(X: T.Buffer[(1, 784), \"float32\"], \n", - " W: T.Buffer[(128, 784), \"float32\"], \n", - " B: T.Buffer[(128,), \"float32\"], \n", - " Z: T.Buffer[(1, 128), \"float32\"]):\n", + " def linear0(X: T.Buffer((1, 784), \"float32\"), \n", + " W: T.Buffer((128, 784), \"float32\"), \n", + " B: T.Buffer((128,), \"float32\"), \n", + " Z: T.Buffer((1, 128), \"float32\")):\n", " T.func_attr({\"global_symbol\": \"linear0\", \"tir.noalias\": True})\n", " Y = T.alloc_buffer((1, 128), \"float32\")\n", " for i, j, k in T.grid(1, 128, 784):\n", @@ -1256,14 +1311,15 @@ " w1: R.Tensor((10, 128), \"float32\"), \n", " b1: R.Tensor((10,), \"float32\")):\n", " with R.dataflow():\n", - " lv0 = R.call_tir(linear0, (x, w0, b0), (1, 128), dtype=\"float32\")\n", - " lv1 = R.call_tir(\"env.relu\", (lv0,), (1, 128), dtype=\"float32\")\n", - " out = R.call_tir(\"env.linear\", (lv1, w1, b1), (1, 10), dtype=\"float32\")\n", + " lv0 = R.call_dps_packed(\"linear0\", (x, w0, b0), out_sinfo=R.Tensor((1, 128), dtype=\"float32\"))\n", + " lv1 = R.call_dps_packed(\"env.relu\", (lv0,), out_sinfo=R.Tensor((1, 128), dtype=\"float32\"))\n", + " out = R.call_dps_packed(\"env.linear\", (lv1, w1, b1), out_sinfo=R.Tensor((1, 10), dtype=\"float32\"))\n", " R.output(out)\n", " return out" ] 
}, { + "attachments": {}, "cell_type": "markdown", "metadata": { "id": "Sy5QigNPKIp0" @@ -1306,6 +1362,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": { "id": "wDDFMW1-Ksi3" @@ -1318,7 +1375,7 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 63, "metadata": { "colab": { "base_uri": "https://localhost:8080/", @@ -1331,25 +1388,23 @@ { "data": { "text/html": [ - "
@tvm.script.ir_module\n",
-       "class Module:\n",
-       "    @R.function\n",
-       "    def main(x: Tensor((1, 784), "float32")) -> Tensor(None, "float32", ndim = 2):\n",
-       "        # block 0\n",
-       "        with R.dataflow():\n",
-       "            lv0 = R.call_tir(linear0, (x, meta[relay.Constant][0], meta[relay.Constant][1]), (1, 128), dtype="float32")\n",
-       "            lv1 = R.call_tir("env.relu", (lv0,), (1, 128), dtype="float32")\n",
-       "            out = R.call_tir("env.linear", (lv1, meta[relay.Constant][2], meta[relay.Constant][3]), (1, 10), dtype="float32")\n",
-       "            R.output(out)\n",
-       "        return out\n",
-       "    \n",
+       "
# from tvm.script import ir as I\n",
+       "# from tvm.script import tir as T\n",
+       "# from tvm.script import relax as R\n",
+       "\n",
+       "\n",
+       "@I.ir_module\n",
+       "class Module:\n",
        "    @T.prim_func\n",
-       "    def linear0(X: T.Buffer[(1, 784), "float32"], W: T.Buffer[(128, 784), "float32"], B: T.Buffer[128, "float32"], Z: T.Buffer[(1, 128), "float32"]) -> None:\n",
-       "        # function attr dict\n",
-       "        T.func_attr({"global_symbol": "linear0", "tir.noalias": True})\n",
-       "        # body\n",
-       "        # with T.block("root")\n",
-       "        Y = T.alloc_buffer([1, 128], dtype="float32")\n",
+       "    def linear0(\n",
+       "        X: T.Buffer((1, 784), "float32"),\n",
+       "        W: T.Buffer((128, 784), "float32"),\n",
+       "        B: T.Buffer((128,), "float32"),\n",
+       "        Z: T.Buffer((1, 128), "float32"),\n",
+       "    ):\n",
+       "        T.func_attr({"global_symbol": "linear0", "tir.noalias": T.bool(True)})\n",
+       "        # with T.block("root"):\n",
+       "        Y = T.alloc_buffer((1, 128))\n",
        "        for i, j, k in T.grid(1, 128, 784):\n",
        "            with T.block("Y"):\n",
        "                vi, vj, vk = T.axis.remap("SSR", [i, j, k])\n",
@@ -1364,7 +1419,38 @@
        "                T.reads(Y[vi, vj], B[vj])\n",
        "                T.writes(Z[vi, vj])\n",
        "                Z[vi, vj] = Y[vi, vj] + B[vj]\n",
-       "    \n",
+       "\n",
+       "    @R.function\n",
+       "    def main(\n",
+       "        x: R.Tensor((1, 784), dtype="float32")\n",
+       "    ) -> R.Tensor((1, 10), dtype="float32"):\n",
+       "        with R.dataflow():\n",
+       "            lv0 = R.call_dps_packed(\n",
+       "                "linear0",\n",
+       "                (\n",
+       "                    x,\n",
+       "                    metadata["relax.expr.Constant"][0],\n",
+       "                    metadata["relax.expr.Constant"][1],\n",
+       "                ),\n",
+       "                out_sinfo=R.Tensor((1, 128), dtype="float32"),\n",
+       "            )\n",
+       "            lv1 = R.call_dps_packed(\n",
+       "                "env.relu", (lv0,), out_sinfo=R.Tensor((1, 128), dtype="float32")\n",
+       "            )\n",
+       "            out = R.call_dps_packed(\n",
+       "                "env.linear",\n",
+       "                (\n",
+       "                    lv1,\n",
+       "                    metadata["relax.expr.Constant"][2],\n",
+       "                    metadata["relax.expr.Constant"][3],\n",
+       "                ),\n",
+       "                out_sinfo=R.Tensor((1, 10), dtype="float32"),\n",
+       "            )\n",
+       "            R.output(out)\n",
+       "        return out\n",
+       "\n",
+       "\n",
+       "# Metadata omitted. Use show_meta=True in script() method to show it.\n",
        "
\n" ], "text/plain": [ @@ -1381,6 +1467,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": { "id": "j8knyt63L8c0" @@ -1419,6 +1506,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": { "id": "Y7UCWyFusX5X" @@ -1436,6 +1524,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": { "id": "y2KrBILMsNGf" @@ -1445,6 +1534,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": { "id": "VkB5oHttOY0U" @@ -1458,6 +1548,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": { "id": "pZGWq5EjJ0BA" @@ -1466,7 +1557,7 @@ "## Summary\n", "- Computational graph abstraction helps to stitch primitive tensor functions together for end-to-end execution.\n", "- Key elements of relax abstraction include\n", - " - call_tir construct that embeds destination passing style primitive function into the computational graph\n", + " - call_dps_packed construct that embeds destination passing style primitive function into the computational graph\n", " - dataflow block\n", "- Computational graph allows call into both environment library functions and TensorIR functions." ] @@ -1494,7 +1585,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.1" + "version": "3.9.16" }, "vscode": { "interpreter": { From 9298571bf8f7e6dfb677eadb3c99edd1a51a2d03 Mon Sep 17 00:00:00 2001 From: Sudeep Agarwal Date: Thu, 11 May 2023 09:25:58 -0400 Subject: [PATCH 2/3] Update relax build usage --- 4_Build_End_to_End_Model.ipynb | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/4_Build_End_to_End_Model.ipynb b/4_Build_End_to_End_Model.ipynb index aefd3ac..b822078 100644 --- a/4_Build_End_to_End_Model.ipynb +++ b/4_Build_End_to_End_Model.ipynb @@ -958,7 +958,7 @@ "id": "Ai-pjfobEpoi" }, "source": [ - "We call `relax.vm.build` to build this function. Relax is still under development, so some of the APIs may change. 
Our main goal, though, is to get familiar with the overall MLC flow (Construct, transform, build) for end-to-end models.\n" + "We call `relax.build` to build this function. Relax is still under development, so some of the APIs may change. Our main goal, though, is to get familiar with the overall MLC flow (Construct, transform, build) for end-to-end models.\n" ] }, { @@ -984,7 +984,7 @@ } ], "source": [ - "ex = relax.vm.build(MyModule, target=\"llvm\")\n", + "ex = relax.build(MyModule, target=\"llvm\")\n", "type(ex)" ] }, From 1c91ccf6afc51ef3033f971e16208ecbdbbc5b21 Mon Sep 17 00:00:00 2001 From: Sudeep Agarwal Date: Thu, 11 May 2023 10:14:07 -0400 Subject: [PATCH 3/3] Additional updates to relax build usage --- 4_Build_End_to_End_Model.ipynb | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/4_Build_End_to_End_Model.ipynb b/4_Build_End_to_End_Model.ipynb index b822078..ad2a0c0 100644 --- a/4_Build_End_to_End_Model.ipynb +++ b/4_Build_End_to_End_Model.ipynb @@ -1232,7 +1232,7 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" @@ -1250,7 +1250,7 @@ } ], "source": [ - "ex = relax.vm.build(MyModuleWithExternCall, target=\"llvm\")\n", + "ex = relax.build(MyModuleWithExternCall, target=\"llvm\")\n", "vm = relax.VirtualMachine(ex, tvm.cpu())\n", "\n", "nd_res = vm[\"main\"](data_nd, \n", @@ -1348,7 +1348,7 @@ } ], "source": [ - "ex = relax.vm.build(MyModuleMixture, target=\"llvm\")\n", + "ex = relax.build(MyModuleMixture, target=\"llvm\")\n", "vm = relax.VirtualMachine(ex, tvm.cpu())\n", "\n", "nd_res = vm[\"main\"](data_nd, \n", @@ -1496,7 +1496,7 @@ } ], "source": [ - "ex = relax.vm.build(MyModuleWithParams, target=\"llvm\")\n", + "ex = relax.build(MyModuleWithParams, target=\"llvm\")\n", "vm = relax.VirtualMachine(ex, tvm.cpu())\n", "\n", "nd_res = vm[\"main\"](data_nd)\n",