Skip to content

Commit

Permalink
felt conversion nb
Browse files Browse the repository at this point in the history
  • Loading branch information
alexander-camuto committed Dec 5, 2024
1 parent 5c059e3 commit 742bc7a
Show file tree
Hide file tree
Showing 3 changed files with 125 additions and 0 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/rust.yml
Original file line number Diff line number Diff line change
Expand Up @@ -650,6 +650,8 @@ jobs:
run: python -m venv .env --clear; source .env/bin/activate; pip install -r requirements.txt; python -m ensurepip --upgrade
- name: Build python ezkl
run: source .env/bin/activate; unset CONDA_PREFIX; maturin develop --features python-bindings --release
- name: Postgres tutorials
run: source .env/bin/activate; cargo nextest run py_tests::tests::felt_conversion_test_ --no-capture
- name: Postgres tutorials
run: source .env/bin/activate; cargo nextest run py_tests::tests::postgres_ --no-capture
- name: Tictactoe tutorials
Expand Down
112 changes: 112 additions & 0 deletions examples/notebooks/felt_conversion_test.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n",
"\n",
"import torch\n",
"import ezkl\n",
"import json\n",
"import subprocess\n",
"from pathlib import Path\n",
"\n",
"\n",
"class Passthrough(torch.nn.Module):\n",
" def __init__(self, input_size=10):\n",
" super().__init__()\n",
"\n",
" def forward(self, x):\n",
" return x\n",
"\n",
"def generate_random_data(size=10, min_val=1, max_val=10):\n",
" return [min_val + (max_val - min_val) * torch.rand(1).item() for _ in range(size)]\n",
"\n",
"def save_json(data, filename):\n",
" with open(filename, 'w') as f:\n",
" json.dump(data, f)\n",
"\n",
"async def run_ezkl_pipeline():\n",
" gip_run_args = ezkl.PyRunArgs()\n",
" gip_run_args.input_visibility = \"public\"\n",
" gip_run_args.output_visibility = \"public\" # no parameters used\n",
" gip_run_args.param_visibility = \"fixed\"\n",
" gip_run_args.input_scale = 19\n",
" gip_run_args.param_scale = 19\n",
" gip_run_args.logrows = 8\n",
" run_args = ezkl.gen_settings(py_run_args=gip_run_args)\n",
" ezkl.compile_circuit()\n",
" await ezkl.gen_witness()\n",
" ezkl.setup()\n",
" ezkl.prove(proof_path=\"proof.json\")\n",
" ezkl.verify()\n",
"\n",
"def verify_proof_matches_input():\n",
" settings = json.load(open(\"settings.json\"))\n",
" inputs = json.load(open(\"input.json\"))\n",
" proof = json.load(open(\"proof.json\"))\n",
"\n",
" input_scale = settings[\"model_input_scales\"][0]\n",
" model_shapes = settings[\"model_instance_shapes\"]\n",
"\n",
" flat_inputs = [x for arr in inputs[\"input_data\"] for x in arr]\n",
" scaled_inputs = [ezkl.float_to_felt(x, input_scale) for x in flat_inputs]\n",
" proof_instances = proof[\"instances\"][0]\n",
"\n",
" def get_group_index(i):\n",
" pos = 0\n",
" for idx, (batch, length) in enumerate(model_shapes):\n",
" next_pos = pos + (batch * length)\n",
" if i < next_pos:\n",
" return idx\n",
" pos = next_pos\n",
" raise IndexError(\"Index out of bounds\")\n",
"\n",
" for i, (scaled, instance) in enumerate(zip(scaled_inputs, proof_instances)):\n",
" group_idx = get_group_index(i)\n",
" _, length = model_shapes[group_idx]\n",
"\n",
" descaled_instance = ezkl.felt_to_float(instance, input_scale)\n",
" descaled_input = ezkl.felt_to_float(scaled, input_scale)\n",
" pretty_value = proof[\"pretty_public_inputs\"][\"rescaled_inputs\"][group_idx][i % length]\n",
"\n",
" assert scaled == instance, f\"Input mismatch at index {i}: {scaled} != {instance} ({descaled_instance} != {descaled_input} OG {flat_inputs[i]} PRETTY {pretty_value})\"\n",
"\n",
"model = Passthrough()\n",
"torch.onnx.export(model, torch.randn(1, 10), \"network.onnx\")\n",
"\n",
"input_data = {\"input_data\": [generate_random_data()]}\n",
"save_json(input_data, \"input.json\")\n",
"save_json({\"input_data\": [generate_random_data()]}, \"calibration.json\")\n",
"\n",
"await run_ezkl_pipeline()\n",
"verify_proof_matches_input()\n",
"\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".env",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.7"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
11 changes: 11 additions & 0 deletions tests/py_integration_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,17 @@ mod py_tests {
anvil_child.kill().unwrap();
}
});

#[test]
fn felt_conversion_test_notebook() {
crate::py_tests::init_binary();
let test_dir: TempDir = TempDir::new("felt_conversion_test").unwrap();
let path = test_dir.path().to_str().unwrap();
crate::py_tests::mv_test_(path, "felt_conversion_test.ipynb");
run_notebook(path, "felt_conversion_test.ipynb");
test_dir.close().unwrap();
}

#[test]
fn voice_notebook_() {
crate::py_tests::init_binary();
Expand Down

0 comments on commit 742bc7a

Please sign in to comment.