-
Notifications
You must be signed in to change notification settings - Fork 144
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
3c1b9d1
commit ac1aaa2
Showing
8 changed files
with
326 additions
and
29 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,303 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"attachments": {}, | ||
"cell_type": "markdown", | ||
"id": "cf69bb3f-94e6-4dba-92cd-ce08df117d67", | ||
"metadata": {}, | ||
"source": [ | ||
"## Support Vector Machines\n", | ||
"\n", | ||
"\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "95613ee9", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# check if notebook is in colab\n", | ||
"try:\n", | ||
" # install ezkl\n", | ||
" import google.colab\n", | ||
" import subprocess\n", | ||
" import sys\n", | ||
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"ezkl\"])\n", | ||
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"onnx\"])\n", | ||
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"sk2torch\"])\n", | ||
"\n", | ||
"# rely on local installation of ezkl if the notebook is not in colab\n", | ||
"except:\n", | ||
" pass\n", | ||
"\n", | ||
"\n", | ||
"# here we create and (potentially train a model)\n", | ||
"\n", | ||
"# make sure you have the dependencies required here already installed\n", | ||
"import json\n", | ||
"import numpy as np\n", | ||
"from sklearn.svm import SVC\n", | ||
"import sk2torch\n", | ||
"import torch\n", | ||
"import ezkl\n", | ||
"import os\n", | ||
"\n", | ||
"\n", | ||
"# Create a dataset of two Gaussians. There will be some overlap\n", | ||
"# between the two classes, which adds some uncertainty to the model.\n", | ||
"xs = np.concatenate(\n", | ||
" [\n", | ||
" np.random.random(size=(256, 2)) + [1, 0],\n", | ||
" np.random.random(size=(256, 2)) + [-1, 0],\n", | ||
" ],\n", | ||
" axis=0,\n", | ||
")\n", | ||
"ys = np.array([False] * 256 + [True] * 256)\n", | ||
"\n", | ||
"# Train an SVM on the data and wrap it in PyTorch.\n", | ||
"sk_model = SVC(probability=True)\n", | ||
"sk_model.fit(xs, ys)\n", | ||
"model = sk2torch.wrap(sk_model)\n", | ||
"\n", | ||
"\n", | ||
"\n", | ||
"\n", | ||
"\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "b37637c4", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"model_path = os.path.join('network.onnx')\n", | ||
"compiled_model_path = os.path.join('network.compiled')\n", | ||
"pk_path = os.path.join('test.pk')\n", | ||
"vk_path = os.path.join('test.vk')\n", | ||
"settings_path = os.path.join('settings.json')\n", | ||
"srs_path = os.path.join('kzg.srs')\n", | ||
"witness_path = os.path.join('witness.json')\n", | ||
"data_path = os.path.join('input.json')" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "7f0ca328", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import matplotlib.pyplot as plt\n", | ||
"# Create a coordinate grid to compute a vector field on.\n", | ||
"spaced = np.linspace(-2, 2, num=25)\n", | ||
"grid_xs = torch.tensor([[x, y] for x in spaced for y in spaced], requires_grad=True)\n", | ||
"\n", | ||
"\n", | ||
"# Compute the gradients of the SVM output.\n", | ||
"outputs = model.predict_proba(grid_xs)[:, 1]\n", | ||
"(input_grads,) = torch.autograd.grad(outputs.sum(), (grid_xs,))\n", | ||
"\n", | ||
"\n", | ||
"# Create a quiver plot of the vector field.\n", | ||
"plt.quiver(\n", | ||
" grid_xs[:, 0].detach().numpy(),\n", | ||
" grid_xs[:, 1].detach().numpy(),\n", | ||
" input_grads[:, 0].detach().numpy(),\n", | ||
" input_grads[:, 1].detach().numpy(),\n", | ||
")\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "82db373a", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"\n", | ||
"\n", | ||
"# export to onnx format\n", | ||
"# !!!!!!!!!!!!!!!!! This will flash a warning but it is fine !!!!!!!!!!!!!!!!!!!!!\n", | ||
"\n", | ||
"# Input to the model\n", | ||
"shape = xs.shape[1:]\n", | ||
"x = grid_xs[0:1]\n", | ||
"torch_out = model.predict(x)\n", | ||
"# Export the model\n", | ||
"torch.onnx.export(model, # model being run\n", | ||
" # model input (or a tuple for multiple inputs)\n", | ||
" x,\n", | ||
" # where to save the model (can be a file or file-like object)\n", | ||
" \"network.onnx\",\n", | ||
" export_params=True, # store the trained parameter weights inside the model file\n", | ||
" opset_version=10, # the ONNX version to export the model to\n", | ||
" do_constant_folding=True, # whether to execute constant folding for optimization\n", | ||
" input_names=['input'], # the model's input names\n", | ||
" output_names=['output'], # the model's output names\n", | ||
" dynamic_axes={'input': {0: 'batch_size'}, # variable length axes\n", | ||
" 'output': {0: 'batch_size'}})\n", | ||
"\n", | ||
"d = ((x).detach().numpy()).reshape([-1]).tolist()\n", | ||
"\n", | ||
"data = dict(input_shapes=[shape],\n", | ||
" input_data=[d],\n", | ||
" output_data=[o.reshape([-1]).tolist() for o in torch_out])\n", | ||
"\n", | ||
"# Serialize data into file:\n", | ||
"json.dump(data, open(\"input.json\", 'w'))\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "d5e374a2", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"!RUST_LOG=trace\n", | ||
"# TODO: Dictionary outputs\n", | ||
"res = ezkl.gen_settings(model_path, settings_path)\n", | ||
"assert res == True\n", | ||
"\n", | ||
"res = await ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n", | ||
"assert res == True" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "3aa4f090", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"res = ezkl.compile_model(model_path, compiled_model_path, settings_path)\n", | ||
"assert res == True" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "8b74dcee", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# srs path\n", | ||
"res = ezkl.get_srs(srs_path, settings_path)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "18c8b7c7", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# now generate the witness file \n", | ||
"\n", | ||
"res = ezkl.gen_witness(data_path, compiled_model_path, witness_path, settings_path = settings_path)\n", | ||
"assert os.path.isfile(witness_path)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "b1c561a8", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"\n", | ||
"# HERE WE SETUP THE CIRCUIT PARAMS\n", | ||
"# WE GOT KEYS\n", | ||
"# WE GOT CIRCUIT PARAMETERS\n", | ||
"# EVERYTHING ANYONE HAS EVER NEEDED FOR ZK\n", | ||
"\n", | ||
"\n", | ||
"\n", | ||
"res = ezkl.setup(\n", | ||
" compiled_model_path,\n", | ||
" vk_path,\n", | ||
" pk_path,\n", | ||
" srs_path,\n", | ||
" settings_path,\n", | ||
" )\n", | ||
"\n", | ||
"assert res == True\n", | ||
"assert os.path.isfile(vk_path)\n", | ||
"assert os.path.isfile(pk_path)\n", | ||
"assert os.path.isfile(settings_path)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "c384cbc8", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# GENERATE A PROOF\n", | ||
"\n", | ||
"\n", | ||
"proof_path = os.path.join('test.pf')\n", | ||
"\n", | ||
"res = ezkl.prove(\n", | ||
" witness_path,\n", | ||
" compiled_model_path,\n", | ||
" pk_path,\n", | ||
" proof_path,\n", | ||
" srs_path,\n", | ||
" \"evm\",\n", | ||
" \"single\",\n", | ||
" settings_path,\n", | ||
" )\n", | ||
"\n", | ||
"print(res)\n", | ||
"assert os.path.isfile(proof_path)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "76f00d41", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# VERIFY IT\n", | ||
"\n", | ||
"res = ezkl.verify(\n", | ||
" proof_path,\n", | ||
" settings_path,\n", | ||
" vk_path,\n", | ||
" srs_path,\n", | ||
" )\n", | ||
"\n", | ||
"assert res == True\n", | ||
"print(\"verified\")" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3 (ipykernel)", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.9.15" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 5 | ||
} |
Oops, something went wrong.