diff --git a/Cargo.lock b/Cargo.lock index 5b40204fb..8d0cde95e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2950,11 +2950,11 @@ dependencies = [ [[package]] name = "lazy_static" -version = "1.4.0" +version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" +checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" dependencies = [ - "spin 0.5.2", + "spin", ] [[package]] @@ -3532,9 +3532,9 @@ dependencies = [ [[package]] name = "parking_lot" -version = "0.12.1" +version = "0.12.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3742b2c103b9f06bc9fff0a37ff4912935851bee6d36f3c02bcc755bcfec228f" +checksum = "f1bf18183cf54e8d6059647fc3063646a1801cf30896933ec2311622cc4b9a27" dependencies = [ "lock_api", "parking_lot_core", @@ -4426,7 +4426,7 @@ dependencies = [ "cfg-if", "getrandom", "libc", - "spin 0.9.8", + "spin", "untrusted", "windows-sys 0.52.0", ] @@ -4980,12 +4980,6 @@ dependencies = [ "unicode-xid", ] -[[package]] -name = "spin" -version = "0.5.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6e63cff320ae2c57904679ba7cb63280a3dc4613885beafb148ee7bf9aa9042d" - [[package]] name = "spin" version = "0.9.8" diff --git a/examples/notebooks/ezkl_demo_batch.ipynb b/examples/notebooks/ezkl_demo_batch.ipynb new file mode 100644 index 000000000..1eef1d531 --- /dev/null +++ b/examples/notebooks/ezkl_demo_batch.ipynb @@ -0,0 +1,771 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "n8QlFzjPRIGN" + }, + "source": [ + "# EZKL DEMO (BATCHED)\n", + "\n", + "This is mostly similar to the original EZKL demo but includes an example of how batching is handled.\n", + "\n", + "**Learning Objectives**\n", + "1. Learn some basic AI/ML techniques by training a toy model in pytorch to perform classification\n", + "2. 
Convert the toy model into zk circuit with ezkl to do provable inference\n", + "3. Create a solidity verifier and deploy it on Remix (you can deploy it however you like but we will use Remix as it's quite easy to setup)\n", + "4. Learn how to use batch inputs and outputs\n", + "\n", + "**Important Note**: You might want to avoid calling \"Run All\". There's some file locking issue with Colab which can cause weird bugs. To mitigate this issue you should run cell by cell on Colab." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "dx81GOIySIpa" + }, + "source": [ + "# Step 1: Training a toy model\n", + "\n", + "For this demo we will use a toy data set called the Iris dataset to demonstrate how training can be performed. The Iris dataset is a collection of Iris flowers and is one of the earliest dataset used to validate classification methodologies.\n", + "\n", + "[More info in the dataset](https://archive.ics.uci.edu/dataset/53/iris)\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "JhHE2WMvS9NP" + }, + "source": [ + "First, we will need to import all the various dependencies required to train the model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "gvQ5HL1bTDWF" + }, + "outputs": [], + "source": [ + "import pandas as pd\n", + "from sklearn.datasets import load_iris\n", + "from sklearn.model_selection import train_test_split\n", + "from sklearn.metrics import accuracy_score, precision_score, recall_score\n", + "import numpy as np\n", + "import torch\n", + "import torch.nn as nn\n", + "import torch.nn.functional as F\n", + "from torch.autograd import Variable\n", + "import tqdm" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Op9SHfZHUkaR" + }, + "source": [ + "Inspect the dataset. 
Note that for the Iris dataset we have 3 targets.\n", + "\n", + "0 = Iris-setosa\n", + "\n", + "1 = Iris-versicolor\n", + "\n", + "2 = Iris-virginica" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 424 + }, + "id": "C4XXA1hoU30c", + "outputId": "b92f7a06-ace9-4bcc-bd6a-d8e9a5550bd2" + }, + "outputs": [], + "source": [ + "iris = load_iris()\n", + "dataset = pd.DataFrame(\n", + " data= np.c_[iris['data'], iris['target']],\n", + " columns= iris['feature_names'] + ['target'])\n", + "dataset" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "I8RargmGTWN2" + }, + "source": [ + "Next, we can begin defining the neural net model. For this dataset we will use a small fully connected neural net.\n", + "\n", + "
\n", + "\n", + "**Note:**\n", + "For the 1st layer we use 4x20, because there are 4 features we want as inputs. After which we add a ReLU.\n", + "\n", + "For the 2nd layer we use 20x20, then add a ReLU.\n", + "\n", + "And for the last layer we use 20x3, because there are 3 classes we want to classify, then add a ReLU.\n", + "\n", + "The last ReLU function gives us an array of 3 elements where the position of the largest value gives us the target that we want to classify.\n", + "\n", + "For example, if we get [0, 0.001, 0.002] as the output of the last ReLU. As, 0.002 is the largest value, the inferred value is 2.\n", + "\n", + "\n", + "![image.png]()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "dIdQ9U3yTKtP" + }, + "outputs": [], + "source": [ + "class Model(nn.Module):\n", + " # define nn\n", + " def __init__(self):\n", + " super(Model, self).__init__()\n", + " self.fc1 = nn.Linear(4, 20)\n", + " self.fc2 = nn.Linear(20, 20)\n", + " self.fc3 = nn.Linear(20, 3)\n", + " self.relu = nn.ReLU()\n", + "\n", + " def forward(self, x):\n", + " x = self.fc1(x)\n", + " x = self.relu(x)\n", + " x = self.fc2(x)\n", + " x = self.relu(x)\n", + " x = self.fc3(x)\n", + " x = self.relu(x)\n", + "\n", + " return x\n", + "\n", + "# Initialize Model\n", + "model = Model()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "SfC03XLNXDPZ" + }, + "source": [ + "We will now need to split the dataset into a training set and testing set for ML. This is done fairly easily with the `train_test_split` helper function from sklearn." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "agmbEdmfUO1-", + "outputId": "e7de179f-4150-4e03-d04e-2f0ce4f9cea5" + }, + "outputs": [], + "source": [ + "train_X, test_X, train_y, test_y = train_test_split(\n", + " dataset[dataset.columns[0:4]].values, # use columns 0-4 as X\n", + " dataset.target, # use target as y\n", + " test_size=0.2 # use 20% of data for testing\n", + ")\n", + "\n", + "# Uncomment for sanity checks\n", + "# print(\"train_X: \", train_X)\n", + "# print(\"test_X: \", test_X)\n", + "print(\"train_y: \", train_y)\n", + "print(\"test_y: \", test_y)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "_FrQXhAGZGS3" + }, + "source": [ + "We can now define the parameters for training, we will use the [Cross Entropy Loss](https://machinelearningmastery.com/cross-entropy-for-machine-learning/) and [Stochastic Gradient Descent Optimizer](https://en.wikipedia.org/wiki/Stochastic_gradient_descent)." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "9PjADXnuXbdk", + "outputId": "31588262-5e9e-4db4-82dc-ea3215a638c9" + }, + "outputs": [], + "source": [ + "# our loss function\n", + "loss_fn = nn.CrossEntropyLoss()\n", + "\n", + "# our optimizer\n", + "optimizer = torch.optim.SGD(model.parameters(), lr=0.01)\n", + "\n", + "\n", + "# use 800 EPOCHS\n", + "EPOCHS = 800\n", + "\n", + "# Convert training data to pytorch variables\n", + "train_X = Variable(torch.Tensor(train_X).float())\n", + "test_X = Variable(torch.Tensor(test_X).float())\n", + "train_y = Variable(torch.Tensor(train_y.values).long())\n", + "test_y = Variable(torch.Tensor(test_y.values).long())\n", + "\n", + "\n", + "loss_list = np.zeros((EPOCHS,))\n", + "accuracy_list = np.zeros((EPOCHS,))\n", + "\n", + "\n", + "# we use tqdm for nice loading bars\n", + "for epoch in tqdm.trange(EPOCHS):\n", + "\n", + " # To train, we get a prediction from the current network\n", + " predicted_y = model(train_X)\n", + "\n", + " # Compute the loss to see how bad or good we are doing\n", + " loss = loss_fn(predicted_y, train_y)\n", + "\n", + " # Append the loss to keep track of our performance\n", + " loss_list[epoch] = loss.item()\n", + "\n", + " # Afterwards, we will need to zero the gradients to reset\n", + " optimizer.zero_grad()\n", + " loss.backward()\n", + " optimizer.step()\n", + "\n", + " # Calculate the accuracy, call torch.no_grad() to prevent updating gradients\n", + " # while calculating accuracy\n", + " with torch.no_grad():\n", + " y_pred = model(test_X)\n", + " correct = (torch.argmax(y_pred, dim=1) == test_y).type(torch.FloatTensor)\n", + " accuracy_list[epoch] = correct.mean()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 432 + }, + "id": "2fHJAgvwboCe", + "outputId": "e9b49c61-b3d9-4a61-e439-dcb239e1342f" + 
}, + "outputs": [], + "source": [ + "# Plot the Accuracy and Loss\n", + "\n", + "# import matplotlib\n", + "import matplotlib.pyplot as plt\n", + "\n", + "plt.style.use('ggplot')\n", + "\n", + "\n", + "fig, (ax1, ax2) = plt.subplots(2, figsize=(12, 6), sharex=True)\n", + "\n", + "ax1.plot(accuracy_list)\n", + "ax1.set_ylabel(\"Accuracy\")\n", + "ax2.plot(loss_list)\n", + "ax2.set_ylabel(\"Loss\")\n", + "ax2.set_xlabel(\"epochs\");" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "djB-UtvgYbF2" + }, + "source": [ + "## Congratulations! You've just trained a neural network\n", + "\n", + "**Exercise:** The model provided is very simplistic, what are other ways the model can be improved upon?" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "JgtwrbMZcgla" + }, + "source": [ + "# Step 2: ZK the Neural Network\n", + "\n", + "Now that we have the Neural Network trained, we can use ezkl to easily ZK our model.\n", + "\n", + "To proceed we will now need to install `ezkl`\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "C_YiqknhdDwN" + }, + "outputs": [], + "source": [ + "# check if notebook is in colab\n", + "try:\n", + " import google.colab\n", + " import subprocess\n", + " import sys\n", + " subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"ezkl\"])\n", + " subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"onnx\"])\n", + "\n", + "# rely on local installation of ezkl if the notebook is not in colab\n", + "except:\n", + " pass\n", + "\n", + "import os\n", + "import json\n", + "import ezkl" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "-b_z_d2FdVTB" + }, + "source": [ + "Next, we will need to export the neural network to a `.onnx` file. 
ezkl reads this `.onnx` file and converts it into a circuit which then allows you to generate proofs as well as verify proofs" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "YeKWP0tFeCpq" + }, + "outputs": [], + "source": [ + "# Specify all the files we need\n", + "\n", + "model_path = os.path.join('network.onnx')\n", + "data_path = os.path.join('input.json')\n", + "cal_data_path = os.path.join('calibration.json')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "cQeNw_qndQ8g", + "outputId": "c293814a-9d98-4ea6-ff96-f32742bca9a5" + }, + "outputs": [], + "source": [ + "# After training, export to onnx (network.onnx) and create a data file (input.json)\n", + "\n", + "# create a random input\n", + "# note that given that we want to use a batch of 2 we can provide a model input as follows\n", + "x = test_X[:2].reshape(2, 4)\n", + "\n", + "# Flips the neural net into inference mode\n", + "model.eval()\n", + "\n", + "# Export the model\n", + "torch.onnx.export(model, # model being run\n", + " x, # model input (or a tuple for multiple inputs)\n", + " model_path, # where to save the model (can be a file or file-like object)\n", + " export_params=True, # store the trained parameter weights inside the model file\n", + " opset_version=10, # the ONNX version to export the model to\n", + " do_constant_folding=True, # whether to execute constant folding for optimization\n", + " input_names = ['input'], # the model's input names\n", + " output_names = ['output'], # the model's output names\n", + " dynamic_axes={'input' : {0 : 'batch_size'}, # variable length axes\n", + " 'output' : {0 : 'batch_size'}})\n", + "\n", + "data_array = ((x).detach().numpy()).reshape([-1]).tolist()\n", + "\n", + "print(data_array)\n", + "\n", + "data = dict(input_data = [data_array])\n", + "\n", + " # Serialize data into file:\n", + "json.dump(data, open(data_path, 
'w'))\n", + "\n", + "\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "9P4x79hIeiLO" + }, + "source": [ + "After which we can proceed to generate the settings file for `ezkl` and run calibrate settings to find the optimal settings for `ezkl`.\n", + "\n", + "Instantiate a PyRunArgs object and set the `batch_size` accordingly." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "cY25BIyreIX8" + }, + "outputs": [], + "source": [ + "!RUST_LOG=trace\n", + "# TODO: Dictionary outputs\n", + "py_run_args = ezkl.PyRunArgs()\n", + "py_run_args.variables = [(\"batch_size\", 2)]\n", + "res = ezkl.gen_settings(model=\"network.onnx\", output=\"settings.json\", py_run_args=py_run_args)\n", + "assert res == True\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "pe0Zoe-ri_l4", + "outputId": "ed0fe6fe-67ca-4da8-ba9e-d6a8ee451ded" + }, + "outputs": [], + "source": [ + "# use the test set to calibrate the circuit\n", + "cal_data = dict(input_data = test_X.flatten().tolist())\n", + "\n", + "# Serialize calibration data into file:\n", + "json.dump(data, open(cal_data_path, 'w'))\n", + "\n", + "# Optimize for resources, we cap logrows at 12 to reduce setup and proving time, at the expense of accuracy\n", + "# You may want to increase the max logrows if accuracy is a concern\n", + "res = await ezkl.calibrate_settings(target = \"resources\", max_logrows = 12, scales = [2])" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "MFmPMBQ1jYao" + }, + "source": [ + "Next, we will compile the model. The compilation step allow us to generate proofs faster." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "De5XtpGUerkZ" + }, + "outputs": [], + "source": [ + "res = ezkl.compile_circuit()\n", + "assert res == True" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "UbkuSVKljmhA" + }, + "source": [ + "Before we can set up the circuit params, we need an SRS (Structured Reference String). The SRS is used to generate the proofs." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "amaTcWG6f2GI" + }, + "outputs": [], + "source": [ + "res = await ezkl.get_srs()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Y92p3GhVj1Jd" + }, + "source": [ + "Now run setup, this will generate a proving key (pk) and verification key (vk). The proving key is used for proving while the verification key is used for verification." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "fdsteit9jzfK", + "outputId": "5292ced0-d79b-46fb-df13-fd16355dac90" + }, + "outputs": [], + "source": [ + "res = ezkl.setup()\n", + "\n", + "\n", + "assert res == True" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "QYlqpP3jkExm" + }, + "source": [ + "Now, we can generate a proof and verify the proof as a sanity check. We will use the \"evm\" transcript. This will allow us to provide proofs to the EVM." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "yoz5Vks5kaHI" + }, + "outputs": [], + "source": [ + "# Generate the Witness for the proof\n", + "\n", + "# now generate the witness file\n", + "witness_path = os.path.join('witness.json')\n", + "\n", + "res = await ezkl.gen_witness()\n", + "assert os.path.isfile(witness_path)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "tmPmQrL1pF9p" + }, + "source": [ + "Note: Instead of having 3 instance variables which correspond to the outputs of our neural net, we have 6 instance variables since we have 2 inputs in a batch." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "eKkFBZX1kBdE", + "outputId": "59db6c3f-22d5-4258-c70c-6b353a5173b4" + }, + "outputs": [], + "source": [ + "# Generate the proof\n", + "\n", + "proof_path = os.path.join('proof.json')\n", + "\n", + "proof = ezkl.prove(proof_type=\"single\", proof_path=proof_path)\n", + "\n", + "print(proof)\n", + "assert os.path.isfile(proof_path)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "DuuH-qcOkQf1", + "outputId": "a5656e5f-1f81-4236-d9b9-8cf80e020a05" + }, + "outputs": [], + "source": [ + "# verify our proof\n", + "\n", + "res = ezkl.verify()\n", + "\n", + "assert res == True\n", + "print(\"verified\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "TOSRigalkwH-" + }, + "source": [ + "## Congratulations! You have just turned your Neural Network into a Halo2 Circuit!\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "flrg3NOGwsJh" + }, + "source": [ + "\n", + "# Step 3: Deploying the Verifier\n", + "Now that we have the circuit set up, we can proceed to deploy the verifier onchain.\n", + "\n", + "We will need to set up `solc=0.8.20` for this." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "CVqMeMYqktvl", + "outputId": "5287ca64-dd25-47b1-8d09-e74d9233f9a6" + }, + "outputs": [], + "source": [ + "# check if notebook is in colab\n", + "try:\n", + " import google.colab\n", + " import subprocess\n", + " import sys\n", + " subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"solc-select\"])\n", + " !solc-select install 0.8.20\n", + " !solc-select use 0.8.20\n", + " !solc --version\n", + "\n", + "# rely on local installation if the notebook is not in colab\n", + "except:\n", + " pass" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "HRHvkMjVlfWU" + }, + "source": [ + "With solc in our environment we can now create the evm verifier." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "gYlw20VZkva7", + "outputId": "4b0a9e46-6c60-45cb-f4d5-5fdfa17335f0" + }, + "outputs": [], + "source": [ + "sol_code_path = os.path.join('Verifier.sol')\n", + "abi_path = os.path.join('Verifier.abi')\n", + "\n", + "res = await ezkl.create_evm_verifier(\n", + " sol_code_path=sol_code_path,\n", + " abi_path=abi_path,\n", + " )\n", + "\n", + "assert res == True\n", + "assert os.path.isfile(sol_code_path)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "jQSAVMvxrBQD", + "outputId": "e2cf297c-85ed-432c-ad11-7b5350220939" + }, + "outputs": [], + "source": [ + "onchain_input_array = []\n", + "\n", + "# using a loop\n", + "# avoiding printing last comma\n", + "formatted_output = \"[\"\n", + "for i, value in enumerate(proof[\"instances\"]):\n", + " for j, field_element in enumerate(value):\n", + " onchain_input_array.append(ezkl.felt_to_big_endian(field_element))\n", + " formatted_output += '\"' + 
str(onchain_input_array[-1]) + '\"'\n", + " if j != len(value) - 1:\n", + " formatted_output += \", \"\n", + " if i != len(proof[\"instances\"]) - 1:\n", + " formatted_output += \", \"\n", + "formatted_output += \"]\"\n", + "\n", + "# This will be the values you use onchain\n", + "# copy them over to remix and see if they verify\n", + "# What happens when you change a value?\n", + "print(\"pubInputs: \", formatted_output)\n", + "print(\"proof: \", proof[\"proof\"])" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "zrzPxPvZmX9b" + }, + "source": [ + "We will exit colab for the next steps. At the left of colab you can see a folder icon. Click on that.\n", + "\n", + "\n", + "You should see a `Verifier.sol`. Right-click and save it locally.\n", + "\n", + "Now go to [https://remix.ethereum.org](https://remix.ethereum.org).\n", + "\n", + "Create a new file within remix and copy the verifier code over.\n", + "\n", + "Finally, compile the code and deploy. For the demo you can deploy to the test environment within remix.\n", + "\n", + "If everything works, you will have deployed your verifier onchain! Copy the values in the cell above to the respective fields to test if the verifier is working.\n", + "\n", + "**Note that right now this setup accepts random values!**\n", + "\n", + "This might not be great for some applications. For that we will want to use a data attested verifier instead. [See this tutorial.](https://github.com/zkonduit/ezkl/blob/main/examples/notebooks/data_attest.ipynb)\n", + "\n", + "## Congratulations on making it this far!\n", + "\n", + "If you have followed the whole tutorial, you will have deployed a neural network inference model onchain! That's no mean feat!" 
+ ] + } + ], + "metadata": { + "colab": { + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.2" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/examples/onnx/boolean/gen.py b/examples/onnx/boolean/gen.py index 0a434ccf4..c6564fd87 100644 --- a/examples/onnx/boolean/gen.py +++ b/examples/onnx/boolean/gen.py @@ -9,7 +9,9 @@ def __init__(self): super(MyModel, self).__init__() def forward(self, w, x, y, z): - return [((x & y)) == (x & (y | (z ^ w)))] + a = (x & y) + b = (y & (z ^ w)) + return [a & b] circuit = MyModel() diff --git a/examples/onnx/boolean/input.json b/examples/onnx/boolean/input.json index f4c5b9b85..18d044f3c 100644 --- a/examples/onnx/boolean/input.json +++ b/examples/onnx/boolean/input.json @@ -1 +1 @@ -{"input_data": [[false, true, false], [true, false, false], [true, false, false], [false, false, false]]} \ No newline at end of file +{"input_data": [[false, true, true], [false, true, true], [true, false, false], [false, true, true]]} \ No newline at end of file diff --git a/examples/onnx/boolean/network.onnx b/examples/onnx/boolean/network.onnx index b3b9a95e8..c16e16602 100644 --- a/examples/onnx/boolean/network.onnx +++ b/examples/onnx/boolean/network.onnx @@ -1,21 +1,17 @@ -pytorch1.12.1:« -+ +pytorch2.2.2:„ +* input1 -input2 onnx::Equal_4And_0"And -' +input2 /And_output_0/And"And +) input3 -input -onnx::Or_5Xor_1"Xor -+ +input /Xor_output_0/Xor"Xor +5 input2 - -onnx::Or_5 onnx::And_6Or_2"Or -0 -input1 - onnx::And_6 onnx::Equal_7And_3"And -6 - onnx::Equal_4 - onnx::Equal_7outputEqual_4"Equal torch_jitZ! + /Xor_output_0/And_1_output_0/And_1"And +5 + /And_output_0 +/And_1_output_0output/And_2"And +main_graphZ! 
input    diff --git a/src/graph/node.rs b/src/graph/node.rs index 539f10d79..a46654752 100644 --- a/src/graph/node.rs +++ b/src/graph/node.rs @@ -125,6 +125,7 @@ impl RebaseScale { if (op_out_scale > (global_scale * scale_rebase_multiplier as i32)) && !inner.is_constant() && !inner.is_input() + && !inner.is_identity() { let multiplier = scale_to_multiplier(op_out_scale - global_scale * scale_rebase_multiplier as i32); @@ -326,6 +327,19 @@ impl SupportedOp { SupportedOp::RebaseScale(op) => op, } } + + /// check if is the identity operation + /// # Returns + /// * `true` if the operation is the identity operation + /// * `false` otherwise + pub fn is_identity(&self) -> bool { + match self { + SupportedOp::Linear(op) => matches!(op, PolyOp::Identity { .. }), + SupportedOp::Rescaled(op) => op.inner.is_identity(), + SupportedOp::RebaseScale(op) => op.inner.is_identity(), + _ => false, + } + } } impl From>> for SupportedOp { diff --git a/src/graph/utilities.rs b/src/graph/utilities.rs index 75b8610c5..8bb2eb193 100644 --- a/src/graph/utilities.rs +++ b/src/graph/utilities.rs @@ -41,7 +41,7 @@ use tract_onnx::tract_hir::{ ops::konst::Const, ops::nn::DataFormat, tract_core::ops::cast::Cast, - tract_core::ops::cnn::{conv::KernelFormat, MaxPool, PaddingSpec, SumPool}, + tract_core::ops::cnn::{conv::KernelFormat, MaxPool, SumPool}, }; /// Quantizes an iterable of f32s to a [Tensor] of i32s using a fixed point representation. @@ -94,17 +94,18 @@ pub fn multiplier_to_scale(mult: f64) -> crate::Scale { /// extract padding from a onnx node. 
pub fn extract_padding( pool_spec: &PoolSpec, - num_dims: usize, + image_size: &[usize], ) -> Result, GraphError> { - let padding = match &pool_spec.padding { - PaddingSpec::Explicit(b, a) | PaddingSpec::ExplicitOnnxPool(b, a, _) => { - b.iter().zip(a.iter()).map(|(b, a)| (*b, *a)).collect() - } - PaddingSpec::Valid => vec![(0, 0); num_dims], - _ => { - return Err(GraphError::MissingParams("padding".to_string())); - } - }; + let num_relevant_dims = pool_spec.kernel_shape.len(); + + // get the last num_relevant_dims of the image size + let image_size = &image_size[image_size.len() - num_relevant_dims..]; + + let dims = pool_spec.computed_padding(image_size); + let mut padding = Vec::new(); + for dim in dims { + padding.push((dim.pad_before, dim.pad_after)); + } Ok(padding) } @@ -1016,8 +1017,13 @@ pub fn new_op_from_onnx( if raw_values.log2().fract() == 0.0 { inputs[const_idx].decrement_use(); deleted_indices.push(const_idx); + // get the non constant index + let non_const_idx = if const_idx == 0 { 1 } else { 0 }; + op = SupportedOp::Linear(PolyOp::Identity { - out_scale: Some(input_scales[0] + raw_values.log2() as i32), + out_scale: Some( + input_scales[non_const_idx] + raw_values.log2() as i32, + ), }); } } @@ -1108,7 +1114,7 @@ pub fn new_op_from_onnx( } let stride = extract_strides(pool_spec)?; - let padding = extract_padding(pool_spec, input_dims[0].len())?; + let padding = extract_padding(pool_spec, &input_dims[0])?; let kernel_shape = &pool_spec.kernel_shape; SupportedOp::Hybrid(HybridOp::MaxPool { @@ -1178,7 +1184,7 @@ pub fn new_op_from_onnx( let pool_spec = &conv_node.pool_spec; let stride = extract_strides(pool_spec)?; - let padding = extract_padding(pool_spec, input_dims[0].len())?; + let padding = extract_padding(pool_spec, &input_dims[0])?; // if bias exists then rescale it to the input + kernel scale if input_scales.len() == 3 { @@ -1236,7 +1242,7 @@ pub fn new_op_from_onnx( let pool_spec = &deconv_node.pool_spec; let stride = 
extract_strides(pool_spec)?; - let padding = extract_padding(pool_spec, input_dims[0].len())?; + let padding = extract_padding(pool_spec, &input_dims[0])?; // if bias exists then rescale it to the input + kernel scale if input_scales.len() == 3 { let bias_scale = input_scales[2]; @@ -1349,7 +1355,7 @@ pub fn new_op_from_onnx( } let stride = extract_strides(pool_spec)?; - let padding = extract_padding(pool_spec, input_dims[0].len())?; + let padding = extract_padding(pool_spec, &input_dims[0])?; SupportedOp::Hybrid(HybridOp::SumPool { padding, @@ -1358,11 +1364,6 @@ pub fn new_op_from_onnx( normalized: sumpool_node.normalize, }) } - // "GlobalAvgPool" => SupportedOp::Linear(PolyOp::SumPool { - // padding: [(0, 0); 2], - // stride: (1, 1), - // kernel_shape: (inputs[0].out_dims()[0][1], inputs[0].out_dims()[0][2]), - // }), "Pad" => { let pad_node: &Pad = match node.op().downcast_ref::() { Some(b) => b, diff --git a/tests/py_integration_tests.rs b/tests/py_integration_tests.rs index 70ea5433a..555e63600 100644 --- a/tests/py_integration_tests.rs +++ b/tests/py_integration_tests.rs @@ -123,7 +123,8 @@ mod py_tests { } } - const TESTS: [&str; 33] = [ + const TESTS: [&str; 34] = [ + "ezkl_demo_batch.ipynb", "proof_splitting.ipynb", // 0 "variance.ipynb", "mnist_gan.ipynb",