diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index e866fb90f..61313e970 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -257,6 +257,8 @@ jobs: run: cargo nextest run --release --verbose tests_evm::kzg_evm_on_chain_output_prove_and_verify --test-threads 1 - name: KZG prove and verify tests (EVM + on chain inputs & outputs) run: cargo nextest run --release --verbose tests_evm::kzg_evm_on_chain_input_output_prove_and_verify --test-threads 1 + - name: KZG prove and verify tests (EVM + on chain inputs & outputs hashes) + run: cargo nextest run --release --verbose tests_evm::kzg_evm_on_chain_input_output_hashed_prove_and_verify --test-threads 1 - name: KZG prove and verify tests (EVM) run: cargo nextest run --release --verbose tests_evm::kzg_evm_prove_and_verify --test-threads 1 - name: KZG prove and verify tests (EVM + hashed inputs) @@ -445,7 +447,7 @@ jobs: run: cargo nextest run neg_tests::neg_examples_ python-tests: - runs-on: 256gb + runs-on: self-hosted needs: [build, library-tests, docs] steps: - uses: actions/checkout@v4 @@ -549,8 +551,10 @@ jobs: # # now dump the contents of the file into a file called kaggle.json # echo $KAGGLE_API_KEY > /home/ubuntu/.kaggle/kaggle.json # chmod 600 /home/ubuntu/.kaggle/kaggle.json + - name: Hashed DA tutorial + run: source .env/bin/activate; cargo nextest run py_tests::tests::run_notebook_::tests_24_expects - name: Little transformer tutorial - run: source .env/bin/activate; cargo nextest run py_tests::tests::run_notebook_::tests_8_expects --no-capture + run: source .env/bin/activate; cargo nextest run py_tests::tests::run_notebook_::tests_8_expects - name: Stacked Regression tutorial run: source .env/bin/activate; cargo nextest run py_tests::tests::run_notebook_::tests_23_expects - name: Linear Regression tutorial @@ -593,7 +597,7 @@ jobs: run: source .env/bin/activate; cargo nextest run py_tests::tests::run_notebook_::tests_2_expects - name: Hashed tutorial run: source .env/bin/activate; cargo nextest run py_tests::tests::run_notebook_::tests_3_expects - - name: Data attestation tutorial + - name: DA tutorial run: source .env/bin/activate; cargo nextest run py_tests::tests::run_notebook_::tests_5_expects - name: Variance tutorial run: source .env/bin/activate; cargo nextest run py_tests::tests::run_notebook_::tests_6_expects diff --git a/contracts/AttestData.sol b/contracts/AttestData.sol index 4d780aa28..e542075e5 100644 --- a/contracts/AttestData.sol +++ b/contracts/AttestData.sol @@ -181,10 +181,6 @@ contract DataAttestation is LoadInstances { if (mulmod(uint256(x), scale, decimals) * 2 >= decimals) { output += 1; } - // In the interest of keeping feature parity with the quantization done on the EZKL cli, - // we set the fixed point value type to be int128. Any value greater than that will throw an error - // as it does on the EZKL cli. - require(output <= uint128(type(int128).max), "Significant bit truncation"); quantized_data = neg ? -int256(output): int256(output); } /** diff --git a/contracts/QuantizeData.sol b/contracts/QuantizeData.sol index 6c8a4e020..ce832b42a 100644 --- a/contracts/QuantizeData.sol +++ b/contracts/QuantizeData.sol @@ -3,20 +3,26 @@ pragma solidity ^0.8.17; contract QuantizeData { - - /** * @notice EZKL P value * @dev In order to prevent the verifier from accepting two versions of the same instance, n and the quantity (n + P), where n + P <= 2^256, we require that all instances are strictly less than P.
 * @dev The reason for this is that the assembly code of the verifier performs all arithmetic operations modulo P and as a consequence can't distinguish between n and n + P. */ - uint256 constant ORDER = uint256(0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001); + uint256 constant ORDER = + uint256( + 0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001 + ); + /** * @notice Calculates floor(x * y / denominator) with full precision. Throws if result overflows a uint256 or denominator == 0 * @dev Original credit to Remco Bloemen under MIT license (https://xn--2-umb.com/21/muldiv) * with further edits by Uniswap Labs also under MIT license. */ - function mulDiv(uint256 x, uint256 y, uint256 denominator) internal pure returns (uint256 result) { + function mulDiv( + uint256 x, + uint256 y, + uint256 denominator + ) internal pure returns (uint256 result) { unchecked { // 512-bit multiply [prod1 prod0] = x * y. Compute the product mod 2^256 and mod 2^256 - 1, then use // the Chinese Remainder Theorem to reconstruct the 512 bit result. The result is stored in two 256 @@ -96,29 +102,34 @@ contract QuantizeData { return result; } } - function quantize_data(bytes[] memory data, uint256[] memory decimals, uint256[] memory scales) external pure returns (int128[] memory quantized_data) { - quantized_data = new int128[](data.length); - for(uint i; i < data.length; i++){ + + function quantize_data( + bytes[] memory data, + uint256[] memory decimals, + uint256[] memory scales + ) external pure returns (int256[] memory quantized_data) { + quantized_data = new int256[](data.length); + for (uint i; i < data.length; i++) { int x = abi.decode(data[i], (int256)); bool neg = x < 0; if (neg) x = -x; - uint denom = 10**decimals[i]; - uint output = mulDiv(uint256(x), scales[i], denom); - if (mulmod(uint256(x), scales[i], denom)*2 >= denom) { + uint denom = 10 ** decimals[i]; + uint scale = 1 << scales[i]; + uint output = mulDiv(uint256(x), scale, denom); + if (mulmod(uint256(x), scale, denom) * 2 >= denom) { output += 1; } - // In the interest of keeping feature parity with the quantization done on the EZKL cli, - // we set the fixed point value type to be int128. Any value greater than that will throw an error - // as it does on the EZKL cli. - require(output <= uint128(type(int128).max), "Significant bit truncation"); - quantized_data[i] = neg ? int128(-int256(output)): int128(int256(output)); + + quantized_data[i] = neg ? 
-int256(output) : int256(output); } } - function to_field_element(int128[] memory quantized_data) public pure returns(uint256[] memory output) { + function to_field_element( + int128[] memory quantized_data + ) public pure returns (uint256[] memory output) { output = new uint256[](quantized_data.length); - for(uint i; i < quantized_data.length; i++){ + for (uint i; i < quantized_data.length; i++) { output[i] = uint256(quantized_data[i] + int(ORDER)) % ORDER; } } -} \ No newline at end of file +} diff --git a/contracts/TestReads.sol b/contracts/TestReads.sol index 294a408f6..2a263f028 100644 --- a/contracts/TestReads.sol +++ b/contracts/TestReads.sol @@ -2,10 +2,10 @@ pragma solidity ^0.8.17; contract TestReads { - int[] public arr; + constructor(int256[] memory _numbers) { - for(uint256 i = 0; i < _numbers.length; i++) { + for (uint256 i = 0; i < _numbers.length; i++) { arr.push(_numbers[i]); } } diff --git a/examples/notebooks/data_attest_hashed.ipynb b/examples/notebooks/data_attest_hashed.ipynb new file mode 100644 index 000000000..8e693c471 --- /dev/null +++ b/examples/notebooks/data_attest_hashed.ipynb @@ -0,0 +1,659 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# data-attest-ezkl hashed\n", + "\n", + "Here's an example leveraging EZKL whereby the hashes of the outputs to the model are read and attested to from an on-chain source.\n", + "\n", + "In this setup:\n", + "- the hashes of outputs are publicly known to the prover and verifier\n" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "First we import the necessary dependencies and set up logging to be as informative as possible. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# check if notebook is in colab\n", + "try:\n", + " # install ezkl\n", + " import google.colab\n", + " import subprocess\n", + " import sys\n", + " subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"ezkl\"])\n", + " subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"onnx\"])\n", + "\n", + "# rely on local installation of ezkl if the notebook is not in colab\n", + "except:\n", + " pass\n", + "\n", + "\n", + "from torch import nn\n", + "import ezkl\n", + "import os\n", + "import json\n", + "import logging\n", + "\n", + "# uncomment for more descriptive logging \n", + "# FORMAT = '%(levelname)s %(name)s %(asctime)-15s %(filename)s:%(lineno)d %(message)s'\n", + "# logging.basicConfig(format=FORMAT)\n", + "# logging.getLogger().setLevel(logging.DEBUG)\n" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now we define our model. It is a very simple PyTorch model that has just one layer, an average pooling 2D layer. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "# Defines the model\n", + "\n", + "class MyModel(nn.Module):\n", + " def __init__(self):\n", + " super(MyModel, self).__init__()\n", + " self.layer = nn.AvgPool2d(2, 1, (1, 1))\n", + "\n", + " def forward(self, x):\n", + " return self.layer(x)[0]\n", + "\n", + "\n", + "circuit = MyModel()\n", + "\n", + "# this is where you'd train your model\n", + "\n", + "\n" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We omit training for purposes of this demonstration. We've marked where training would happen in the cell above. 
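To make the new quantization rule in `QuantizeData.sol` above concrete: `scales[i]` is now interpreted as a base-2 exponent (the multiplier is `1 << scales[i]`), results round half away from zero, and the `int128` truncation guard is dropped in favour of plain `int256`. Below is a minimal Python sketch of the same arithmetic, for illustration only (not part of the patch; `ORDER` is copied verbatim from the contract):

```python
ORDER = 0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001

def quantize(x_fixed: int, decimals: int, scale_bits: int) -> int:
    """x_fixed is the on-chain fixed-point integer, i.e. the real value times 10**decimals."""
    neg = x_fixed < 0
    x = -x_fixed if neg else x_fixed
    denom = 10 ** decimals
    scale = 1 << scale_bits            # scales[i] is a base-2 exponent, not a raw multiplier
    out, rem = divmod(x * scale, denom)
    if rem * 2 >= denom:               # round half away from zero, as in the contract
        out += 1
    return -out if neg else out

def to_field_element(q: int) -> int:
    # negatives wrap around the field order, mirroring to_field_element in the contract
    return (q + ORDER) % ORDER

assert quantize(15, 1, 3) == 12                          # 1.5 * 2^3 = 12
assert to_field_element(quantize(-15, 1, 3)) == ORDER - 12
```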
\n", + "Now we export the model to onnx and create a corresponding (randomly generated) input. This input data will eventually be stored on chain and read from according to the call_data field in the graph input.\n", + "\n", + "You can replace the random `x` with real data if you so wish. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "x = 0.1*torch.rand(1,*[3, 2, 2], requires_grad=True)\n", + "\n", + "# Flips the neural net into inference mode\n", + "circuit.eval()\n", + "\n", + " # Export the model\n", + "torch.onnx.export(circuit, # model being run\n", + " x, # model input (or a tuple for multiple inputs)\n", + " \"network.onnx\", # where to save the model (can be a file or file-like object)\n", + " export_params=True, # store the trained parameter weights inside the model file\n", + " opset_version=10, # the ONNX version to export the model to\n", + " do_constant_folding=True, # whether to execute constant folding for optimization\n", + " input_names = ['input'], # the model's input names\n", + " output_names = ['output'], # the model's output names\n", + " dynamic_axes={'input' : {0 : 'batch_size'}, # variable length axes\n", + " 'output' : {0 : 'batch_size'}})\n", + "\n", + "data_array = ((x).detach().numpy()).reshape([-1]).tolist()\n", + "\n", + "data = dict(input_data = [data_array])\n", + "\n", + " # Serialize data into file:\n", + "json.dump(data, open(\"input.json\", 'w' ))\n", + "\n" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We now define a function that will create a new anvil instance which we will deploy our test contract too. This contract will contain in its storage the data that we will read from and attest to. In production you would not need to set up a local anvil instance. Instead you would replace RPC_URL with the actual RPC endpoint of the chain you are deploying your verifiers too, reading from the data on said chain." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import subprocess\n", + "import time\n", + "import threading\n", + "\n", + "# make sure anvil is running locally\n", + "# $ anvil -p 3030\n", + "\n", + "RPC_URL = \"http://localhost:3030\"\n", + "\n", + "# Save process globally\n", + "anvil_process = None\n", + "\n", + "def start_anvil():\n", + " global anvil_process\n", + " if anvil_process is None:\n", + " anvil_process = subprocess.Popen([\"anvil\", \"-p\", \"3030\", \"--code-size-limit=41943040\"])\n", + " if anvil_process.returncode is not None:\n", + " raise Exception(\"failed to start anvil process\")\n", + " time.sleep(3)\n", + "\n", + "def stop_anvil():\n", + " global anvil_process\n", + " if anvil_process is not None:\n", + " anvil_process.terminate()\n", + " anvil_process = None\n" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We define our `PyRunArgs` objects which contains the visibility parameters for out model. 
\n", + "- `input_visibility` defines the visibility of the model inputs\n", + "- `param_visibility` defines the visibility of the model weights and constants and parameters \n", + "- `output_visibility` defines the visibility of the model outputs\n", + "\n", + "Here we create the following setup:\n", + "- `input_visibility`: \"private\"\n", + "- `param_visibility`: \"private\"\n", + "- `output_visibility`: hashed\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import ezkl\n", + "\n", + "model_path = os.path.join('network.onnx')\n", + "compiled_model_path = os.path.join('network.compiled')\n", + "pk_path = os.path.join('test.pk')\n", + "vk_path = os.path.join('test.vk')\n", + "settings_path = os.path.join('settings.json')\n", + "srs_path = os.path.join('kzg.srs')\n", + "data_path = os.path.join('input.json')\n", + "\n", + "run_args = ezkl.PyRunArgs()\n", + "run_args.input_visibility = \"private\"\n", + "run_args.param_visibility = \"private\"\n", + "run_args.output_visibility = \"hashed\"\n", + "run_args.variables = [(\"batch_size\", 1)]\n", + "\n", + "\n" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now we generate a settings file. This file basically instantiates a bunch of parameters that determine their circuit shape, size etc... Because of the way we represent nonlinearities in the circuit (using Halo2's [lookup tables](https://zcash.github.io/halo2/design/proving-system/lookup.html)), it is often best to _calibrate_ this settings file as some data can fall out of range of these lookups.\n", + "\n", + "You can pass a dataset for calibration that will be representative of real inputs you might find if and when you deploy the prover. Here we create a dummy calibration dataset for demonstration purposes. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!RUST_LOG=trace\n", + "# TODO: Dictionary outputs\n", + "res = ezkl.gen_settings(model_path, settings_path, py_run_args=run_args)\n", + "assert res == True" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# generate a bunch of dummy calibration data\n", + "cal_data = {\n", + " \"input_data\": [(0.1*torch.rand(2, *[3, 2, 2])).flatten().tolist()],\n", + "}\n", + "\n", + "cal_path = os.path.join('val_data.json')\n", + "# save as json file\n", + "with open(cal_path, \"w\") as f:\n", + " json.dump(cal_data, f)\n", + "\n", + "res = await ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "res = ezkl.compile_circuit(model_path, compiled_model_path, settings_path)\n", + "assert res == True" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "As we use Halo2 with KZG-commitments we need an SRS string from (preferably) a multi-party trusted setup ceremony. For an overview of the procedures for such a ceremony check out [this page](https://blog.ethereum.org/2023/01/16/announcing-kzg-ceremony). The `get_srs` command retrieves a correctly sized SRS given the calibrated settings file from [here](https://github.com/han0110/halo2-kzg-srs). \n", + "\n", + "These SRS were generated with [this](https://github.com/privacy-scaling-explorations/perpetualpowersoftau) ceremony. 
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "res = ezkl.get_srs(srs_path, settings_path)\n" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We now need to generate the circuit witness. These are the model outputs (and any hashes) that are generated when feeding the previously generated `input.json` through the circuit / model. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!export RUST_BACKTRACE=1\n", + "\n", + "witness_path = \"witness.json\"\n", + "\n", + "res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(ezkl.vecu64_to_felt(res['processed_outputs']['poseidon_hash'][0]))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We now post the hashes of the outputs to the chain. This is the data that will be read from and attested to." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from web3 import Web3, HTTPProvider, utils\n", + "from solcx import compile_standard\n", + "from decimal import Decimal\n", + "import json\n", + "import os\n", + "import torch\n", + "\n", + "\n", + "# setup web3 instance\n", + "w3 = Web3(HTTPProvider(RPC_URL))\n", + "\n", + "def test_on_chain_data(res):\n", + " # Step 0: Convert the tensor to a flat list\n", + " data = [int(ezkl.vecu64_to_felt(res['processed_outputs']['poseidon_hash'][0]), 0)]\n", + "\n", + " # Step 1: Prepare the data\n", + " # Step 2: Prepare and compile the contract.\n", + " # We are using a test contract here but in production you would\n", + " # use whatever contract you are fetching data from.\n", + " contract_source_code = '''\n", + " // SPDX-License-Identifier: UNLICENSED\n", + " pragma solidity ^0.8.17;\n", + "\n", + " contract TestReads {\n", + "\n", + " uint[] public arr;\n", + " constructor(uint256[] memory _numbers) {\n", + " for(uint256 i = 0; i < _numbers.length; i++) {\n", + " arr.push(_numbers[i]);\n", + " }\n", + " }\n", + " }\n", + " '''\n", + "\n", + " compiled_sol = compile_standard({\n", + " \"language\": \"Solidity\",\n", + " \"sources\": {\"testreads.sol\": {\"content\": contract_source_code}},\n", + " \"settings\": {\"outputSelection\": {\"*\": {\"*\": [\"metadata\", \"evm.bytecode\", \"abi\"]}}}\n", + " })\n", + "\n", + " # Get bytecode\n", + " bytecode = compiled_sol['contracts']['testreads.sol']['TestReads']['evm']['bytecode']['object']\n", + "\n", + " # Get ABI\n", + " # In production if you are reading from really large contracts you can just use\n", + " # a stripped down version of the ABI of the contract you are calling, containing only the view functions you will fetch data from.\n", + " abi = json.loads(compiled_sol['contracts']['testreads.sol']['TestReads']['metadata'])['output']['abi']\n", + "\n", + " # Step 3: Deploy the contract\n", + " TestReads = w3.eth.contract(abi=abi, bytecode=bytecode)\n", + " tx_hash = TestReads.constructor(data).transact()\n", + " tx_receipt = w3.eth.wait_for_transaction_receipt(tx_hash)\n", + " # If you are deploying to production you can skip the 3 lines of code above and just instantiate the contract like this,\n", + " # passing the address and abi of the contract you are fetching data from.\n", + " contract = 
w3.eth.contract(address=tx_receipt['contractAddress'], abi=abi)\n", "\n", " # Step 4: Interact with the contract\n", " calldata = []\n", " for i, _ in enumerate(data):\n", " call = contract.functions.arr(i).build_transaction()\n", " calldata.append((call['data'][2:], 0))\n", "\n", " # Prepare the calls_to_account object\n", " # If you were calling view functions across multiple contracts,\n", " # you would have multiple entries in the calls_to_account array,\n", " # one for each contract.\n", " calls_to_account = [{\n", " 'call_data': calldata,\n", " 'address': contract.address[2:], # remove the '0x' prefix\n", " }]\n", "\n", " print(f'calls_to_account: {calls_to_account}')\n", "\n", " return calls_to_account\n", "\n", "# Now let's start the Anvil process. You don't need to do this if you are deploying to a non-local chain.\n", "start_anvil()\n", "\n", "# Now let's call our function, passing in the witness output we generated above.\n", "calls_to_account = test_on_chain_data(res)\n", "\n", "data = dict(input_data = [data_array], output_data = {'rpc': RPC_URL, 'calls': calls_to_account })\n", "\n", "# Serialize on-chain data into file:\n", "json.dump(data, open(\"input.json\", 'w'))\n", "\n" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "Here we set up the verifying and proving keys for the circuit. As the name suggests, the proving key is needed for ... proving and the verifying key is needed for ... verifying. " ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# HERE WE SETUP THE CIRCUIT PARAMS\n", "# WE GOT KEYS\n", "# WE GOT CIRCUIT PARAMETERS\n", "# EVERYTHING ANYONE HAS EVER NEEDED FOR ZK\n", "res = ezkl.setup(\n", " compiled_model_path,\n", " vk_path,\n", " pk_path,\n", " srs_path,\n", " )\n", "\n", "assert res == True\n", "assert os.path.isfile(vk_path)\n", "assert os.path.isfile(pk_path)\n", "assert os.path.isfile(settings_path)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "Now we generate a full proof. " ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# GENERATE A PROOF\n", "\n", "proof_path = os.path.join('test.pf')\n", "\n", "res = ezkl.prove(\n", " witness_path,\n", " compiled_model_path,\n", " pk_path,\n", " proof_path,\n", " srs_path,\n", " \"single\",\n", " )\n", "\n", "print(res)\n", "assert os.path.isfile(proof_path)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "And verify it as a sanity check. " ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# VERIFY IT\n", "\n", "res = ezkl.verify(\n", " proof_path,\n", " settings_path,\n", " vk_path,\n", " srs_path,\n", " )\n", "\n", "assert res == True\n", "print(\"verified\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can now create and then deploy a vanilla EVM verifier."
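The data attestation contract we are about to create consumes the `input.json` written above. For reference, that file now has roughly the following shape; the hex calldata and address below are illustrative placeholders, not real values:

```python
import json

# Sketch of the graph input that create_evm_data_attestation will read.
# "call_data" holds (abi-encoded view call, decimals) pairs; decimals is 0 here
# because the attested value is already a field element, not a scaled float.
example = {
    "input_data": [[0.0123, 0.0456]],   # the original float inputs (still file-based)
    "output_data": {
        "rpc": "http://localhost:3030",
        "calls": [{
            "call_data": [["71e5ee5f00000000000000000000000000000000000000000000000000000000", 0]],
            "address": "5fbdb2315678afecb367f032d93f642f64180aa3",
        }],
    },
}
print(json.dumps(example, indent=2))
```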
+ ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "abi_path = 'test.abi'\n", "sol_code_path = 'test.sol'\n", "\n", "res = ezkl.create_evm_verifier(\n", " vk_path,\n", " srs_path,\n", " settings_path,\n", " sol_code_path,\n", " abi_path,\n", " )\n", "assert res == True" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import json\n", "\n", "addr_path_verifier = \"addr_verifier.txt\"\n", "\n", "res = ezkl.deploy_evm(\n", " addr_path_verifier,\n", " sol_code_path,\n", " 'http://127.0.0.1:3030'\n", ")\n", "\n", "assert res == True" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "With the vanilla verifier deployed, we can now create the data attestation contract, which will read in the instances from the calldata to the verifier, attest to them, call the verifier and then return the result. " ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "\n", "abi_path = 'test.abi'\n", "sol_code_path = 'test.sol'\n", "input_path = 'input.json'\n", "\n", "res = ezkl.create_evm_data_attestation(\n", " vk_path,\n", " srs_path,\n", " settings_path,\n", " sol_code_path,\n", " abi_path,\n", " input_path,\n", " )" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "Now we can deploy the data attestation verifier contract. For security reasons, this binding will only deploy to a local anvil instance, using accounts generated by anvil, \n", "so it should only be used for testing purposes." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "addr_path_da = \"addr_da.txt\"\n", "\n", "res = ezkl.deploy_da_evm(\n", " addr_path_da,\n", " input_path,\n", " settings_path,\n", " sol_code_path,\n", " RPC_URL,\n", " )\n" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "Call the view-only verify method on the contract to verify the proof. Since it is a view function, this is safe to use in production, as you don't have to pass your private key."
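Before calling `verify_evm`, a quick sanity check can catch a dead anvil instance or a bad address file. A sketch, assuming the `w3` handle from the earlier cell is still connected and using web3.py v6 naming (`Web3.to_checksum_address`); it also allows for the address files being written with or without a `0x` prefix, which is an assumption about `deploy_evm`'s output:

```python
from web3 import Web3

# sanity-check the two deployments before verifying (illustrative, not part of the notebook)
with open("addr_verifier.txt") as f:
    verifier_addr = f.read().strip()
with open("addr_da.txt") as f:
    da_addr = f.read().strip()

for name, a in [("verifier", verifier_addr), ("data attestation", da_addr)]:
    addr = a if a.startswith("0x") else "0x" + a
    code = w3.eth.get_code(Web3.to_checksum_address(addr))
    assert len(code) > 0, f"no bytecode at {addr} for the {name} contract"
    print(f"{name} contract live at {addr} ({len(code)} bytes of code)")
```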
+ ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# read the verifier address\n", "addr = None\n", "with open(addr_path_verifier, 'r') as f:\n", " addr = f.read()\n", "# read the data attestation address\n", "addr_da = None\n", "with open(addr_path_da, 'r') as f:\n", " addr_da = f.read()\n", "\n", "res = ezkl.verify_evm(\n", " proof_path,\n", " addr,\n", " RPC_URL,\n", " addr_da,\n", ")" ] } ], "metadata": { "kernelspec": { "display_name": "ezkl", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.15" }, "orig_nbformat": 4 }, "nbformat": 4, "nbformat_minor": 2 +} diff --git a/src/eth.rs b/src/eth.rs index fe0556656..1f3736d8d 100644 --- a/src/eth.rs +++ b/src/eth.rs @@ -1,4 +1,5 @@ -use crate::graph::input::{CallsToAccount, GraphData}; +use crate::graph::input::{CallsToAccount, FileSourceInner, GraphData}; +use crate::graph::modules::{ELGAMAL_INSTANCES, POSEIDON_INSTANCES}; use crate::graph::DataSource; #[cfg(not(target_arch = "wasm32"))] use crate::graph::GraphSettings; @@ -26,9 +27,9 @@ use ethers::{ prelude::{LocalWallet, Wallet}, utils::{Anvil, AnvilInstance}, }; +use halo2_solidity_verifier::encode_calldata; use halo2curves::bn256::{Fr, G1Affine}; use halo2curves::group::ff::PrimeField; -use halo2_solidity_verifier::encode_calldata; use log::{debug, info, warn}; use std::error::Error; use std::path::PathBuf; @@ -140,25 +141,64 @@ pub async fn deploy_da_verifier_via_solidity( // The data that will be stored in the test contracts that will eventually be read from. 
let mut calls_to_accounts = vec![]; - let instance_shapes = settings.model_instance_shapes; + let mut instance_shapes = vec![]; + let mut model_instance_offset = 0; + + if settings.run_args.input_visibility.is_hashed() { + instance_shapes.push(POSEIDON_INSTANCES) + } else if settings.run_args.input_visibility.is_encrypted() { + instance_shapes.push(ELGAMAL_INSTANCES) + } else if settings.run_args.input_visibility.is_public() { + for idx in 0..settings.model_input_scales.len() { + let shape = &settings.model_instance_shapes[idx]; + instance_shapes.push(shape.iter().product::()); + model_instance_offset += 1; + } + } + + if settings.run_args.param_visibility.is_hashed() + || settings.run_args.param_visibility.is_encrypted() + { + todo!() + } + + if settings.run_args.output_visibility.is_hashed() { + instance_shapes.push(POSEIDON_INSTANCES) + } else if settings.run_args.output_visibility.is_encrypted() { + instance_shapes.push(ELGAMAL_INSTANCES) + } else if settings.run_args.output_visibility.is_public() { + for idx in model_instance_offset..model_instance_offset + settings.model_output_scales.len() + { + let shape = &settings.model_instance_shapes[idx]; + instance_shapes.push(shape.iter().product::()); + } + } + + println!("instance_shapes: {:#?}", instance_shapes); let mut instance_idx = 0; let mut contract_instance_offset = 0; if let DataSource::OnChain(source) = input.input_data { - let input_scales = settings.model_input_scales; + if settings.run_args.input_visibility.is_hashed_public() { + // set scales 1.0 + scales.extend(vec![0; instance_shapes[instance_idx]]); + instance_idx += 1; + } else if settings.run_args.input_visibility.is_encrypted() { + // set scales 1.0 + scales.extend(vec![0; instance_shapes[instance_idx]]); + instance_idx += 1; + } else { + let input_scales = settings.model_input_scales; + // give each input a scale + for scale in input_scales { + scales.extend(vec![scale; instance_shapes[instance_idx]]); + instance_idx += 1; + } + } for call in source.calls { calls_to_accounts.push(call); } - - // give each input a scale - for scale in input_scales { - scales.extend(vec![ - scale; - instance_shapes[instance_idx].iter().product::() - ]); - instance_idx += 1; - } } else if let DataSource::File(source) = input.input_data { if settings.run_args.input_visibility.is_public() { instance_idx += source.len(); @@ -169,19 +209,23 @@ pub async fn deploy_da_verifier_via_solidity( } if let Some(DataSource::OnChain(source)) = input.output_data { - let output_scales = settings.model_output_scales; + if settings.run_args.output_visibility.is_hashed_public() { + // set scales 1.0 + scales.extend(vec![0; instance_shapes[instance_idx]]); + } else if settings.run_args.output_visibility.is_encrypted() { + // set scales 1.0 + scales.extend(vec![0; instance_shapes[instance_idx]]); + } else { + let input_scales = settings.model_output_scales; + // give each output a scale + for scale in input_scales { + scales.extend(vec![scale; instance_shapes[instance_idx]]); + instance_idx += 1; + } + } for call in source.calls { calls_to_accounts.push(call); } - - // give each input a scale - for scale in output_scales { - scales.extend(vec![ - scale; - instance_shapes[instance_idx].iter().product::() - ]); - instance_idx += 1; - } } let (contract_addresses, call_data, decimals) = if !calls_to_accounts.is_empty() { @@ -310,11 +354,7 @@ pub async fn verify_proof_via_solidity( ) -> Result> { let flattened_instances = proof.instances.into_iter().flatten(); - let encoded = encode_calldata( - None, - 
&proof.proof, - &flattened_instances.collect::>(), - ); + let encoded = encode_calldata(None, &proof.proof, &flattened_instances.collect::>()); info!("encoded: {:#?}", hex::encode(&encoded)); let (anvil, client) = setup_eth_backend(rpc_url, None).await?; @@ -374,7 +414,7 @@ fn count_decimal_places(num: f32) -> usize { /// pub async fn setup_test_contract( client: Arc, - data: &[Vec], + data: &[Vec], ) -> Result<(ContractInstance, M>, Vec), Box> { // save the abi to a tmp file let mut sol_path = std::env::temp_dir(); @@ -391,10 +431,18 @@ pub async fn setup_test_contract( let mut decimals = vec![]; let mut scaled_by_decimals_data = vec![]; for input in &data[0] { - let decimal_places = count_decimal_places(*input) as u8; - let scaled_by_decimals = input * f32::powf(10., decimal_places.into()); - scaled_by_decimals_data.push(scaled_by_decimals as i128); - decimals.push(decimal_places); + if input.is_float() { + let input = input.to_float() as f32; + let decimal_places = count_decimal_places(input) as u8; + let scaled_by_decimals = input * f32::powf(10., decimal_places.into()); + scaled_by_decimals_data.push(I256::from(scaled_by_decimals as i128)); + decimals.push(decimal_places); + } else if input.is_field() { + let input = input.to_field(0); + let hex_str_fr = format!("{:?}", input); + scaled_by_decimals_data.push(I256::from_raw(U256::from_str_radix(&hex_str_fr, 16)?)); + decimals.push(0); + } } let contract = factory.deploy(scaled_by_decimals_data)?.send().await?; @@ -421,11 +469,8 @@ pub async fn verify_proof_with_data_attestation( public_inputs.push(u); } - let encoded_verifier = encode_calldata( - None, - &proof.proof, - &flattened_instances.collect::>(), - ); + let encoded_verifier = + encode_calldata(None, &proof.proof, &flattened_instances.collect::>()); info!("encoded: {:#?}", hex::encode(&encoded_verifier)); @@ -504,7 +549,7 @@ pub fn get_provider(rpc_url: &str) -> Result, Box> { /// the number of decimals of the floating point value on chain. 
pub async fn test_on_chain_data( client: Arc, - data: &[Vec], + data: &[Vec], ) -> Result, Box> { let (contract, decimals) = setup_test_contract(client.clone(), data).await?; @@ -563,7 +608,7 @@ pub async fn read_on_chain_inputs( #[cfg(not(target_arch = "wasm32"))] pub async fn evm_quantize( client: Arc, - scales: Vec, + scales: Vec, data: &(Vec, Vec), ) -> Result, Box> { // save the sol to a tmp file @@ -680,13 +725,13 @@ pub fn fix_da_sol( let mut accounts_len = 0; let mut contract = ATTESTDATA_SOL.to_string(); let load_instances = LOADINSTANCES_SOL.to_string(); - // replace the import statment with the load_instances contract, not including the + // replace the import statment with the load_instances contract, not including the // `SPDX-License-Identifier: MIT pragma solidity ^0.8.20;` at the top of the file contract = contract.replace( "import './LoadInstances.sol';", &load_instances[load_instances.find("contract").unwrap()..], ); - + // fill in the quantization params and total calls // as constants to the contract to save on gas if let Some(input_data) = input_data { diff --git a/src/execute.rs b/src/execute.rs index d60bbe2d7..5b67f96c3 100644 --- a/src/execute.rs +++ b/src/execute.rs @@ -904,8 +904,8 @@ pub(crate) fn create_evm_data_attestation( let data = GraphData::from_path(input)?; let output_data = if let Some(DataSource::OnChain(source)) = data.output_data { - if !visibility.output.is_public() { - todo!("we currently don't support private output data on chain") + if visibility.output.is_private() { + todo!("private output data on chain is not supported on chain") } let mut on_chain_output_data = vec![]; for call in source.calls { @@ -917,7 +917,7 @@ pub(crate) fn create_evm_data_attestation( }; let input_data = if let DataSource::OnChain(source) = data.input_data { - if !visibility.input.is_public() { + if visibility.input.is_private() { todo!("we currently don't support private input data on chain") } let mut on_chain_input_data = vec![]; diff --git a/src/graph/input.rs b/src/graph/input.rs index cc38b7394..ffc5c1fd8 100644 --- a/src/graph/input.rs +++ b/src/graph/input.rs @@ -37,6 +37,21 @@ pub enum FileSourceInner { Field(Fp), } +impl FileSourceInner { + /// + pub fn is_float(&self) -> bool { + matches!(self, FileSourceInner::Float(_)) + } + /// + pub fn is_bool(&self) -> bool { + matches!(self, FileSourceInner::Bool(_)) + } + /// + pub fn is_field(&self) -> bool { + matches!(self, FileSourceInner::Field(_)) + } +} + impl Serialize for FileSourceInner { fn serialize(&self, serializer: S) -> Result where @@ -259,12 +274,10 @@ impl OnChainSource { pub async fn test_from_file_data( data: &FileSource, scales: Vec, - shapes: Vec>, + mut shapes: Vec>, rpc: Option<&str>, ) -> Result<(Vec>, Self), Box> { use crate::eth::{evm_quantize, read_on_chain_inputs, test_on_chain_data}; - use crate::graph::scale_to_multiplier; - use itertools::Itertools; use log::debug; // Set up local anvil instance for reading on-chain data @@ -272,15 +285,16 @@ impl OnChainSource { let address = client.address(); - let scales: Vec = scales.into_iter().map(scale_to_multiplier).collect(); - - // unquantize data - let float_data = data - .iter() - .map(|t| t.iter().map(|e| (e.to_float() as f32)).collect_vec()) - .collect::>>(); + let mut scales: Vec = scales; + // set scales to 1 where data is a field element + for (idx, i) in data.iter().enumerate() { + if i.iter().all(|e| e.is_field()) { + scales[idx] = 0; + shapes[idx] = vec![i.len()]; + } + } - let calls_to_accounts = test_on_chain_data(client.clone(), 
&float_data).await?; + let calls_to_accounts = test_on_chain_data(client.clone(), &data).await?; debug!("Calls to accounts: {:?}", calls_to_accounts); let inputs = read_on_chain_inputs(client.clone(), address, &calls_to_accounts).await?; debug!("Inputs: {:?}", inputs); @@ -348,6 +362,7 @@ pub enum DataSource { #[cfg(not(target_arch = "wasm32"))] DB(PostgresSource), } + impl Default for DataSource { fn default() -> Self { DataSource::File(vec![vec![]]) diff --git a/src/graph/mod.rs b/src/graph/mod.rs index a9d1dca2e..34161f2f3 100644 --- a/src/graph/mod.rs +++ b/src/graph/mod.rs @@ -705,7 +705,7 @@ impl GraphCircuit { // quantize the supplied data using the provided scale + QuantizeData.sol let quantized_evm_inputs = evm_quantize( client, - scales.into_iter().map(scale_to_multiplier).collect(), + scales, &inputs, ) .await?; @@ -1004,7 +1004,7 @@ impl GraphCircuit { TestDataSource::OnChain ) { // if not public then fail - if !self.settings().run_args.input_visibility.is_public() { + if self.settings().run_args.input_visibility.is_private() { return Err("Cannot use on-chain data source as private data".into()); } @@ -1016,11 +1016,11 @@ impl GraphCircuit { ), }; // Get the flatten length of input_data - let length = input_data.iter().map(|x| x.len()).sum(); - let scales = vec![self.settings().run_args.input_scale; length]; + // if the input source is a field then set scale to 0 + let datam: (Vec>, OnChainSource) = OnChainSource::test_from_file_data( input_data, - scales, + self.model().graph.get_input_scales(), self.model().graph.input_shapes(), test_on_chain_data.rpc.as_deref(), ) @@ -1032,7 +1032,7 @@ impl GraphCircuit { TestDataSource::OnChain ) { // if not public then fail - if !self.settings().run_args.output_visibility.is_public() { + if self.settings().run_args.output_visibility.is_private() { return Err("Cannot use on-chain data source as private data".into()); } diff --git a/src/graph/modules.rs b/src/graph/modules.rs index e84f97bbc..af3b2451d 100644 --- a/src/graph/modules.rs +++ b/src/graph/modules.rs @@ -15,6 +15,11 @@ use super::{VarVisibility, Visibility}; /// poseidon len to hash in tree pub const POSEIDON_LEN_GRAPH: usize = 32; +/// ElGamal number of instances +pub const ELGAMAL_INSTANCES: usize = 4; +/// Poseidon number of instancess +pub const POSEIDON_INSTANCES: usize = 1; + /// Poseidon module type pub type ModulePoseidon = PoseidonChip; diff --git a/src/graph/vars.rs b/src/graph/vars.rs index c839bd465..5104561e8 100644 --- a/src/graph/vars.rs +++ b/src/graph/vars.rs @@ -145,6 +145,11 @@ impl Visibility { pub fn is_fixed(&self) -> bool { matches!(&self, Visibility::Fixed) } + #[allow(missing_docs)] + pub fn is_private(&self) -> bool { + matches!(&self, Visibility::Private) || self.is_hashed_private() + } + #[allow(missing_docs)] pub fn is_public(&self) -> bool { matches!(&self, Visibility::Public) diff --git a/tests/integration_tests.rs b/tests/integration_tests.rs index af416ac37..c8433b1c2 100644 --- a/tests/integration_tests.rs +++ b/tests/integration_tests.rs @@ -4,8 +4,8 @@ mod native_tests { use core::panic; // use ezkl::circuit::table::RESERVED_BLINDING_ROWS_PAD; - use ezkl::graph::input::{FileSource, GraphData}; - use ezkl::graph::{DataSource, GraphSettings, Visibility}; + use ezkl::graph::input::{FileSource, FileSourceInner, GraphData}; + use ezkl::graph::{DataSource, GraphSettings, GraphWitness, Visibility}; use lazy_static::lazy_static; use rand::Rng; use std::env::var; @@ -870,7 +870,7 @@ mod native_tests { let test_dir = TempDir::new(test).unwrap(); let path 
= test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(test_dir.path().to_str().unwrap(), test); let _anvil_child = crate::native_tests::start_anvil(true); - kzg_evm_on_chain_input_prove_and_verify(path, test.to_string(), "on-chain", "file"); + kzg_evm_on_chain_input_prove_and_verify(path, test.to_string(), "on-chain", "file", "public", "private"); test_dir.close().unwrap(); } @@ -880,7 +880,7 @@ mod native_tests { let test_dir = TempDir::new(test).unwrap(); let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(test_dir.path().to_str().unwrap(), test); let _anvil_child = crate::native_tests::start_anvil(true); - kzg_evm_on_chain_input_prove_and_verify(path, test.to_string(), "file", "on-chain"); + kzg_evm_on_chain_input_prove_and_verify(path, test.to_string(), "file", "on-chain", "private", "public"); test_dir.close().unwrap(); } @@ -890,7 +890,17 @@ mod native_tests { let test_dir = TempDir::new(test).unwrap(); let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(test_dir.path().to_str().unwrap(), test); let _anvil_child = crate::native_tests::start_anvil(true); - kzg_evm_on_chain_input_prove_and_verify(path, test.to_string(), "on-chain", "on-chain"); + kzg_evm_on_chain_input_prove_and_verify(path, test.to_string(), "on-chain", "on-chain", "public", "public"); + test_dir.close().unwrap(); + } + + #(#[test_case(TESTS_ON_CHAIN_INPUT[N])])* + fn kzg_evm_on_chain_input_output_hashed_prove_and_verify_(test: &str) { + crate::native_tests::init_binary(); + let test_dir = TempDir::new(test).unwrap(); + let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(test_dir.path().to_str().unwrap(), test); + let _anvil_child = crate::native_tests::start_anvil(true); + kzg_evm_on_chain_input_prove_and_verify(path, test.to_string(), "on-chain", "on-chain", "hashed", "hashed"); test_dir.close().unwrap(); } }); @@ -2411,10 +2421,9 @@ mod native_tests { example_name: String, input_source: &str, output_source: &str, + input_visbility: &str, + output_visbility: &str, ) { - // set up the circuit - let input_visbility = "public"; - let output_visbility = "public"; let model_path = format!("{}/{}/network.onnx", test_dir, example_name); let settings_path = format!("{}/{}/settings.json", test_dir, example_name); @@ -2464,30 +2473,57 @@ mod native_tests { let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR)) .args([ - "setup-test-evm-data", + "gen-witness", "-D", data_path.as_str(), "-M", &model_path, - "--test-data", - test_on_chain_data_path.as_str(), - rpc_arg.as_str(), - test_input_source.as_str(), - test_output_source.as_str(), + "-O", + &witness_path, ]) .status() .expect("failed to execute process"); assert!(status.success()); + // load witness + let witness: GraphWitness = GraphWitness::from_path(witness_path.clone().into()).unwrap(); + let mut input: GraphData = GraphData::from_path(data_path.clone().into()).unwrap(); + + if input_visbility == "hashed" { + let hashes = witness.processed_inputs.unwrap().poseidon_hash.unwrap(); + input.input_data = DataSource::File( + hashes + .iter() + .map(|h| vec![FileSourceInner::Field(*h)]) + .collect(), + ); + } + if output_visbility == "hashed" { + let hashes = witness.processed_outputs.unwrap().poseidon_hash.unwrap(); + input.output_data = Some(DataSource::File( + hashes + .iter() + .map(|h| vec![FileSourceInner::Field(*h)]) + .collect(), + )); + } + + println!("input is {:?}", input); + + input.save(data_path.clone().into()).unwrap(); + let status = 
Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR)) .args([ - "gen-witness", + "setup-test-evm-data", "-D", - test_on_chain_data_path.as_str(), + data_path.as_str(), "-M", &model_path, - "-O", - &witness_path, + "--test-data", + test_on_chain_data_path.as_str(), + rpc_arg.as_str(), + test_input_source.as_str(), + test_output_source.as_str(), ]) .status() .expect("failed to execute process"); @@ -2653,7 +2689,7 @@ mod native_tests { let deployed_addr_arg = format!("--addr={}", addr_da); - let mut args = vec![ + let args = vec![ "test-update-account-calls", deployed_addr_arg.as_str(), "-D", @@ -2667,7 +2703,14 @@ mod native_tests { assert!(status.success()); // As sanity check, add example that should fail. - args[2] = PF_FAILURE; + let args = vec![ + "verify-evm", + "--proof-path", + PF_FAILURE, + deployed_addr_verifier_arg.as_str(), + deployed_addr_da_arg.as_str(), + rpc_arg.as_str(), + ]; let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR)) .args(args) .status() diff --git a/tests/py_integration_tests.rs b/tests/py_integration_tests.rs index 8bead26cd..cae8e91a9 100644 --- a/tests/py_integration_tests.rs +++ b/tests/py_integration_tests.rs @@ -115,7 +115,7 @@ mod py_tests { } } - const TESTS: [&str; 24] = [ + const TESTS: [&str; 25] = [ "mnist_gan.ipynb", // "mnist_vae.ipynb", "keras_simple_demo.ipynb", @@ -141,6 +141,7 @@ mod py_tests { "gcn.ipynb", "linear_regression.ipynb", "stacked_regression.ipynb", + "data_attest_hashed.ipynb", ]; macro_rules! test_func { @@ -153,7 +154,7 @@ mod py_tests { use super::*; - seq!(N in 0..=23 { + seq!(N in 0..=24 { #(#[test_case(TESTS[N])])* fn run_notebook_(test: &str) {
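For reviewers tracing the new hashed flow end to end, the integration test above boils down to: generate a witness, pull the Poseidon hashes out of it, rewrite the graph input as field elements, and only then stand up the test EVM data. In terms of the Python bindings used by the notebook, the equivalent steps look roughly like this (paths and the witness JSON layout follow the notebook above; field names may shift between ezkl versions):

```python
import json
import ezkl

# 1. generate the witness for the compiled circuit
ezkl.gen_witness("input.json", "network.compiled", "witness.json")

with open("witness.json") as f:
    witness = json.load(f)

# 2. the witness carries Poseidon hashes for any hashed inputs/outputs;
#    convert them to field elements, as the notebook does with vecu64_to_felt
hashes = witness["processed_outputs"]["poseidon_hash"]
felts = [ezkl.vecu64_to_felt(h) for h in hashes]

# 3. these field elements are what get posted on chain and attested to,
#    mirroring the FileSourceInner::Field rewrite in the Rust test above
print(felts)
```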