diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index eda30e81d..6132dd1a9 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -650,6 +650,8 @@ jobs: run: python -m venv .env --clear; source .env/bin/activate; pip install -r requirements.txt; python -m ensurepip --upgrade - name: Build python ezkl run: source .env/bin/activate; unset CONDA_PREFIX; maturin develop --features python-bindings --release + - name: Postgres tutorials + run: source .env/bin/activate; cargo nextest run py_tests::tests::felt_conversion_test_ --no-capture - name: Postgres tutorials run: source .env/bin/activate; cargo nextest run py_tests::tests::postgres_ --no-capture - name: Tictactoe tutorials diff --git a/examples/notebooks/felt_conversion_test.ipynb b/examples/notebooks/felt_conversion_test.ipynb new file mode 100644 index 000000000..e3bd7f6bb --- /dev/null +++ b/examples/notebooks/felt_conversion_test.ipynb @@ -0,0 +1,112 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "\n", + "import torch\n", + "import ezkl\n", + "import json\n", + "import subprocess\n", + "from pathlib import Path\n", + "\n", + "\n", + "class Passthrough(torch.nn.Module):\n", + " def __init__(self, input_size=10):\n", + " super().__init__()\n", + "\n", + " def forward(self, x):\n", + " return x\n", + "\n", + "def generate_random_data(size=10, min_val=1, max_val=10):\n", + " return [min_val + (max_val - min_val) * torch.rand(1).item() for _ in range(size)]\n", + "\n", + "def save_json(data, filename):\n", + " with open(filename, 'w') as f:\n", + " json.dump(data, f)\n", + "\n", + "async def run_ezkl_pipeline():\n", + " gip_run_args = ezkl.PyRunArgs()\n", + " gip_run_args.input_visibility = \"public\"\n", + " gip_run_args.output_visibility = \"public\" # no parameters used\n", + " gip_run_args.param_visibility = \"fixed\"\n", + " gip_run_args.input_scale = 19\n", + " gip_run_args.param_scale = 19\n", + " gip_run_args.logrows = 8\n", + " run_args = ezkl.gen_settings(py_run_args=gip_run_args)\n", + " ezkl.compile_circuit()\n", + " await ezkl.gen_witness()\n", + " ezkl.setup()\n", + " ezkl.prove(proof_path=\"proof.json\")\n", + " ezkl.verify()\n", + "\n", + "def verify_proof_matches_input():\n", + " settings = json.load(open(\"settings.json\"))\n", + " inputs = json.load(open(\"input.json\"))\n", + " proof = json.load(open(\"proof.json\"))\n", + "\n", + " input_scale = settings[\"model_input_scales\"][0]\n", + " model_shapes = settings[\"model_instance_shapes\"]\n", + "\n", + " flat_inputs = [x for arr in inputs[\"input_data\"] for x in arr]\n", + " scaled_inputs = [ezkl.float_to_felt(x, input_scale) for x in flat_inputs]\n", + " proof_instances = proof[\"instances\"][0]\n", + "\n", + " def get_group_index(i):\n", + " pos = 0\n", + " for idx, (batch, length) in enumerate(model_shapes):\n", + " next_pos = pos + (batch * length)\n", + " if i < next_pos:\n", + " return idx\n", + " pos = next_pos\n", + " raise IndexError(\"Index out of bounds\")\n", + "\n", + " for i, (scaled, instance) in enumerate(zip(scaled_inputs, proof_instances)):\n", + " group_idx = get_group_index(i)\n", + " _, length = model_shapes[group_idx]\n", + "\n", + " descaled_instance = ezkl.felt_to_float(instance, input_scale)\n", + " descaled_input = ezkl.felt_to_float(scaled, input_scale)\n", + " pretty_value = proof[\"pretty_public_inputs\"][\"rescaled_inputs\"][group_idx][i % length]\n", + "\n", + " assert scaled == instance, f\"Input mismatch at index {i}: {scaled} != {instance} ({descaled_instance} != {descaled_input} OG {flat_inputs[i]} PRETTY {pretty_value})\"\n", + "\n", + "model = Passthrough()\n", + "torch.onnx.export(model, torch.randn(1, 10), \"network.onnx\")\n", + "\n", + "input_data = {\"input_data\": [generate_random_data()]}\n", + "save_json(input_data, \"input.json\")\n", + "save_json({\"input_data\": [generate_random_data()]}, \"calibration.json\")\n", + "\n", + "await run_ezkl_pipeline()\n", + "verify_proof_matches_input()\n", + "\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".env", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.7" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/tests/py_integration_tests.rs b/tests/py_integration_tests.rs index 6a0bba604..abb6c9129 100644 --- a/tests/py_integration_tests.rs +++ b/tests/py_integration_tests.rs @@ -189,6 +189,17 @@ mod py_tests { anvil_child.kill().unwrap(); } }); + + #[test] + fn felt_conversion_test_notebook() { + crate::py_tests::init_binary(); + let test_dir: TempDir = TempDir::new("felt_conversion_test").unwrap(); + let path = test_dir.path().to_str().unwrap(); + crate::py_tests::mv_test_(path, "felt_conversion_test.ipynb"); + run_notebook(path, "felt_conversion_test.ipynb"); + test_dir.close().unwrap(); + } + #[test] fn voice_notebook_() { crate::py_tests::init_binary();