Skip to content

Commit

Permalink
Merge branch 'main' into ac/numerical-accuracy-report
Browse files Browse the repository at this point in the history
  • Loading branch information
alexander-camuto authored Dec 21, 2023
2 parents 3ce98c2 + 5c52909 commit 9b0df1e
Show file tree
Hide file tree
Showing 37 changed files with 799 additions and 217 deletions.
64 changes: 62 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,8 @@ tabled = { version = "0.12.0", optional = true }
[target.'cfg(not(all(target_arch = "wasm32", target_os = "unknown")))'.dependencies]
colored = { version = "2.0.0", default_features = false, optional = true}
env_logger = { version = "0.10.0", default_features = false, optional = true}

chrono = "0.4.31"
sha256 = "1.4.0"

[target.'cfg(target_arch = "wasm32")'.dependencies]
getrandom = { version = "0.2.8", features = ["js"] }
Expand Down
23 changes: 21 additions & 2 deletions examples/notebooks/decision_tree.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@
"pk_path = os.path.join('test.pk')\n",
"vk_path = os.path.join('test.vk')\n",
"settings_path = os.path.join('settings.json')\n",
"",
"\n",
"witness_path = os.path.join('witness.json')\n",
"data_path = os.path.join('input.json')"
]
Expand Down Expand Up @@ -154,6 +154,25 @@
"assert res == True"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# generate a bunch of dummy calibration data\n",
"cal_data = {\n",
" \"input_data\": [(torch.rand(20, *shape)).flatten().tolist()],\n",
"}\n",
"\n",
"cal_path = os.path.join('val_data.json')\n",
"# save as json file\n",
"with open(cal_path, \"w\") as f:\n",
" json.dump(cal_data, f)\n",
"\n",
"res = ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand Down Expand Up @@ -284,4 +303,4 @@
},
"nbformat": 4,
"nbformat_minor": 5
}
}
24 changes: 17 additions & 7 deletions examples/notebooks/ezkl_demo.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -356,7 +356,7 @@
"pk_path = os.path.join('test.pk')\n",
"vk_path = os.path.join('test.vk')\n",
"settings_path = os.path.join('settings.json')\n",
"",
"\n",
"witness_path = os.path.join('witness.json')\n",
"data_path = os.path.join('input.json')\n",
"cal_data_path = os.path.join('cal_data.json')"
Expand Down Expand Up @@ -402,11 +402,7 @@
"json.dump(data, open(data_path, 'w'))\n",
"\n",
"\n",
"# use the test set to calibrate the circuit\n",
"cal_data = dict(input_data = test_X.flatten().tolist())\n",
"\n",
"# Serialize calibration data into file:\n",
"json.dump(data, open(cal_data_path, 'w'))\n"
"\n"
]
},
{
Expand All @@ -430,6 +426,20 @@
"# TODO: Dictionary outputs\n",
"res = ezkl.gen_settings(model_path, settings_path)\n",
"assert res == True\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# use the test set to calibrate the circuit\n",
"cal_data = dict(input_data = test_X.flatten().tolist())\n",
"\n",
"# Serialize calibration data into file:\n",
"json.dump(data, open(cal_data_path, 'w'))\n",
"\n",
"res = ezkl.calibrate_settings(cal_data_path, model_path, settings_path, \"resources\") # Optimize for resources"
]
Expand Down Expand Up @@ -751,4 +761,4 @@
},
"nbformat": 4,
"nbformat_minor": 0
}
}
9 changes: 9 additions & 0 deletions examples/notebooks/gcn.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -457,6 +457,15 @@
"# TODO: Dictionary outputs\n",
"res = ezkl.gen_settings(model_path, settings_path, py_run_args=run_args)\n",
"assert res == True\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n",
"res = ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n",
"assert res == True"
Expand Down
29 changes: 24 additions & 5 deletions examples/notebooks/generalized_inverse.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@
"pk_path = os.path.join('test.pk')\n",
"vk_path = os.path.join('test.vk')\n",
"settings_path = os.path.join('settings.json')\n",
"",
"\n",
"witness_path = os.path.join('witness.json')\n",
"data_path = os.path.join('input.json')"
]
Expand All @@ -111,7 +111,9 @@
"outputs": [],
"source": [
"# After training, export to onnx (network.onnx) and create a data file (input.json)\n",
"A = 0.1*torch.rand(1,*[10, 10], requires_grad=True)\n",
"shape = [10, 10]\n",
"\n",
"A = 0.1*torch.rand(1,*shape, requires_grad=True)\n",
"B = A.inverse()\n",
"\n",
"# Flips the neural net into inference mode\n",
Expand Down Expand Up @@ -174,10 +176,27 @@
"\n",
"res = ezkl.gen_settings(model_path, settings_path, py_run_args=gip_run_args)\n",
"\n",
"assert res == True\n",
"assert res == True"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"cal_path = os.path.join(\"calibration.json\")\n",
"\n",
"data_array = (0.1*torch.rand(20,*shape).detach().numpy()).reshape([-1]).tolist()\n",
"\n",
"data = dict(input_data = [data_array])\n",
"\n",
"# Serialize data into file:\n",
"json.dump(data, open(cal_path, 'w'))\n",
"\n",
"\n",
"res = ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n",
"assert res == True"
"assert res == True\n"
]
},
{
Expand Down Expand Up @@ -321,4 +340,4 @@
},
"nbformat": 4,
"nbformat_minor": 5
}
}
Loading

0 comments on commit 9b0df1e

Please sign in to comment.