diff --git a/Cargo.lock b/Cargo.lock index b415c6fde..adbad2da9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -77,6 +77,15 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e999941b234f3131b00bc13c22d06e8c5ff726d1b6318ac7eb276997bbb4fef0" +[[package]] +name = "android_system_properties" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "819e7219dbd41043ac279b19830f2efc897156490d7fd6ea916720117ee66311" +dependencies = [ + "libc", +] + [[package]] name = "ansi-str" version = "0.8.0" @@ -721,12 +730,16 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" [[package]] name = "chrono" -version = "0.4.26" +version = "0.4.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ec837a71355b28f6556dbd569b37b3f363091c0bd4b2e735674521b4c5fd9bc5" +checksum = "7f2c685bad3eb3d45a01354cedb7d5faa66194d1d58ba6e267a8de788f79db38" dependencies = [ "android-tzdata", + "iana-time-zone", + "js-sys", "num-traits", + "wasm-bindgen", + "windows-targets 0.48.0", ] [[package]] @@ -1778,6 +1791,7 @@ version = "0.0.0" dependencies = [ "ark-std 0.3.0", "bincode", + "chrono", "clap 4.3.8", "colored", "colored_json", @@ -1815,6 +1829,7 @@ dependencies = [ "serde", "serde-wasm-bindgen", "serde_json", + "sha256", "shellexpand", "snark-verifier", "tabled", @@ -2458,6 +2473,29 @@ dependencies = [ "tokio-native-tls", ] +[[package]] +name = "iana-time-zone" +version = "0.1.58" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8326b86b6cff230b97d0d312a6c40a60726df3332e721f72a1b035f451663b20" +dependencies = [ + "android_system_properties", + "core-foundation-sys", + "iana-time-zone-haiku", + "js-sys", + "wasm-bindgen", + "windows-core", +] + +[[package]] +name = "iana-time-zone-haiku" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f31827a206f56af32e590ba56d5d2d085f558508192593743f16b2306495269f" +dependencies = [ + "cc", +] + [[package]] name = "icicle" version = "0.1.0" @@ -4595,6 +4633,19 @@ dependencies = [ "digest 0.10.7", ] +[[package]] +name = "sha256" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7895c8ae88588ccead14ff438b939b0c569cd619116f14b4d13fdff7b8333386" +dependencies = [ + "async-trait", + "bytes", + "hex", + "sha2 0.10.7", + "tokio", +] + [[package]] name = "sha3" version = "0.9.1" @@ -5750,6 +5801,15 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" +[[package]] +name = "windows-core" +version = "0.50.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "af6041b3f84485c21b57acdc0fee4f4f0c93f426053dc05fa5d6fc262537bbff" +dependencies = [ + "windows-targets 0.48.0", +] + [[package]] name = "windows-sys" version = "0.42.0" diff --git a/Cargo.toml b/Cargo.toml index 3f6caa85b..ada036e20 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -59,7 +59,8 @@ tabled = { version = "0.12.0", optional = true } [target.'cfg(not(all(target_arch = "wasm32", target_os = "unknown")))'.dependencies] colored = { version = "2.0.0", default_features = false, optional = true} env_logger = { version = "0.10.0", default_features = false, optional = true} - +chrono = "0.4.31" +sha256 = "1.4.0" [target.'cfg(target_arch = "wasm32")'.dependencies] getrandom = { version = "0.2.8", features = ["js"] } diff --git 
a/examples/notebooks/decision_tree.ipynb b/examples/notebooks/decision_tree.ipynb index e962d2971..dbe52a358 100644 --- a/examples/notebooks/decision_tree.ipynb +++ b/examples/notebooks/decision_tree.ipynb @@ -93,7 +93,7 @@ "pk_path = os.path.join('test.pk')\n", "vk_path = os.path.join('test.vk')\n", "settings_path = os.path.join('settings.json')\n", - "", + "\n", "witness_path = os.path.join('witness.json')\n", "data_path = os.path.join('input.json')" ] @@ -154,6 +154,25 @@ "assert res == True" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# generate a bunch of dummy calibration data\n", + "cal_data = {\n", + " \"input_data\": [(torch.rand(20, *shape)).flatten().tolist()],\n", + "}\n", + "\n", + "cal_path = os.path.join('val_data.json')\n", + "# save as json file\n", + "with open(cal_path, \"w\") as f:\n", + " json.dump(cal_data, f)\n", + "\n", + "res = ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")" ] }, { "cell_type": "code", "execution_count": null, @@ -284,4 +303,4 @@ }, "nbformat": 4, "nbformat_minor": 5 -} \ No newline at end of file +} diff --git a/examples/notebooks/ezkl_demo.ipynb b/examples/notebooks/ezkl_demo.ipynb index d708093e7..79bff5fcf 100644 --- a/examples/notebooks/ezkl_demo.ipynb +++ b/examples/notebooks/ezkl_demo.ipynb @@ -356,7 +356,7 @@ "pk_path = os.path.join('test.pk')\n", "vk_path = os.path.join('test.vk')\n", "settings_path = os.path.join('settings.json')\n", - "", + "\n", "witness_path = os.path.join('witness.json')\n", "data_path = os.path.join('input.json')\n", "cal_data_path = os.path.join('cal_data.json')" ] @@ -402,11 +402,7 @@ "json.dump(data, open(data_path, 'w'))\n", "\n", "\n", - "# use the test set to calibrate the circuit\n", - "cal_data = dict(input_data = test_X.flatten().tolist())\n", - "\n", - "# Serialize calibration data into file:\n", - "json.dump(data, open(cal_data_path, 'w'))\n" + "\n" ] }, { @@ -430,6 +426,20 @@ "# TODO: Dictionary outputs\n", "res = ezkl.gen_settings(model_path, settings_path)\n", "assert res == True\n", + "\n" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# use the test set to calibrate the circuit\n", + "cal_data = dict(input_data = test_X.flatten().tolist())\n", + "\n", + "# Serialize calibration data into file:\n", + "json.dump(cal_data, open(cal_data_path, 'w'))\n", "\n", "res = ezkl.calibrate_settings(cal_data_path, model_path, settings_path, \"resources\") # Optimize for resources" ] }, { @@ -751,4 +761,4 @@ }, "nbformat": 4, "nbformat_minor": 0 -} \ No newline at end of file +} diff --git a/examples/notebooks/gcn.ipynb b/examples/notebooks/gcn.ipynb index 1ea4ca2c2..410342194 100644 --- a/examples/notebooks/gcn.ipynb +++ b/examples/notebooks/gcn.ipynb @@ -457,6 +457,15 @@ "# TODO: Dictionary outputs\n", "res = ezkl.gen_settings(model_path, settings_path, py_run_args=run_args)\n", "assert res == True\n", + "\n" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "\n", "res = ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n", "assert res == True" ] }, diff --git a/examples/notebooks/generalized_inverse.ipynb b/examples/notebooks/generalized_inverse.ipynb index 5336dd1aa..583b0038f 100644 --- a/examples/notebooks/generalized_inverse.ipynb +++ b/examples/notebooks/generalized_inverse.ipynb @@ -96,7 +96,7 @@ "pk_path = os.path.join('test.pk')\n", "vk_path = 
os.path.join('test.vk')\n", "settings_path = os.path.join('settings.json')\n", - "", + "\n", "witness_path = os.path.join('witness.json')\n", "data_path = os.path.join('input.json')" ] @@ -111,7 +111,9 @@ "outputs": [], "source": [ "# After training, export to onnx (network.onnx) and create a data file (input.json)\n", - "A = 0.1*torch.rand(1,*[10, 10], requires_grad=True)\n", + "shape = [10, 10]\n", + "\n", + "A = 0.1*torch.rand(1,*shape, requires_grad=True)\n", "B = A.inverse()\n", "\n", "# Flips the neural net into inference mode\n", @@ -174,10 +176,27 @@ "\n", "res = ezkl.gen_settings(model_path, settings_path, py_run_args=gip_run_args)\n", "\n", - "assert res == True\n", + "assert res == True" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "cal_path = os.path.join(\"calibration.json\")\n", + "\n", + "data_array = (0.1*torch.rand(20,*shape).detach().numpy()).reshape([-1]).tolist()\n", + "\n", + "data = dict(input_data = [data_array])\n", + "\n", + "# Serialize data into file:\n", + "json.dump(data, open(cal_path, 'w'))\n", + "\n", "\n", "res = ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n", - "assert res == True" + "assert res == True\n" ] }, { @@ -321,4 +340,4 @@ }, "nbformat": 4, "nbformat_minor": 5 -} \ No newline at end of file +} diff --git a/examples/notebooks/gradient_boosted_trees.ipynb b/examples/notebooks/gradient_boosted_trees.ipynb index 3adb838b2..a608f9974 100644 --- a/examples/notebooks/gradient_boosted_trees.ipynb +++ b/examples/notebooks/gradient_boosted_trees.ipynb @@ -33,18 +33,10 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "id": "95613ee9", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "num diff: [0]\n" - ] - } - ], + "outputs": [], "source": [ "# check if notebook is in colab\n", "try:\n", @@ -102,7 +94,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "id": "b37637c4", "metadata": {}, "outputs": [], @@ -112,28 +104,17 @@ "pk_path = os.path.join('test.pk')\n", "vk_path = os.path.join('test.vk')\n", "settings_path = os.path.join('settings.json')\n", - "", + "\n", "witness_path = os.path.join('witness.json')\n", "data_path = os.path.join('input.json')" ] }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "id": "82db373a", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "================ Diagnostic Run torch.onnx.export version 2.0.1 ================\n", - "verbose: False, log level: Level.ERROR\n", - "======================= 0 NONE 0 NOTE 0 WARNING 0 ERROR ========================\n", - "\n" - ] - } - ], + "outputs": [], "source": [ "# !!!!!!!!!!!!!!!!! 
This cell will flash a warning about onnx runtime compat but it is fine !!!!!!!!!!!!!!!!!!!!!\n", "\n", @@ -168,7 +149,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "id": "d5e374a2", "metadata": {}, "outputs": [], @@ -179,14 +160,32 @@ "# TODO: Dictionary outputs\n", "res = ezkl.gen_settings(model_path, settings_path, py_run_args=run_args)\n", "assert res == True\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "cal_path = os.path.join(\"calibration.json\")\n", + "\n", + "data_array = (torch.rand(20,*shape).detach().numpy()).reshape([-1]).tolist()\n", + "\n", + "data = dict(input_data = [data_array])\n", + "\n", + "# Serialize data into file:\n", + "json.dump(data, open(cal_path, 'w'))\n", + "\n", "\n", "res = ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n", - "assert res == True" + "assert res == True\n" ] }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, "id": "3aa4f090", "metadata": {}, "outputs": [], @@ -197,7 +196,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": null, "id": "8b74dcee", "metadata": {}, "outputs": [], @@ -208,7 +207,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": null, "id": "18c8b7c7", "metadata": {}, "outputs": [], @@ -249,7 +248,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 10, "id": "c384cbc8", "metadata": {}, "outputs": [], @@ -314,4 +313,4 @@ }, "nbformat": 4, "nbformat_minor": 5 -} \ No newline at end of file +} diff --git a/examples/notebooks/hashed_vis.ipynb b/examples/notebooks/hashed_vis.ipynb index dfcb00d8f..439c2ea86 100644 --- a/examples/notebooks/hashed_vis.ipynb +++ b/examples/notebooks/hashed_vis.ipynb @@ -117,7 +117,9 @@ "metadata": {}, "outputs": [], "source": [ - "x = torch.rand(1,*[3, 8, 8], requires_grad=True)\n", + "\n", + "shape = [3, 8, 8]\n", + "x = torch.rand(1,*shape, requires_grad=True)\n", "\n", "# Flips the neural net into inference mode\n", "circuit.eval()\n", diff --git a/examples/notebooks/keras_simple_demo.ipynb b/examples/notebooks/keras_simple_demo.ipynb index ff335a218..e302e2626 100644 --- a/examples/notebooks/keras_simple_demo.ipynb +++ b/examples/notebooks/keras_simple_demo.ipynb @@ -103,8 +103,10 @@ "import tf2onnx\n", "import tensorflow as tf\n", "\n", + "\n", + "shape = [1, 28, 28]\n", "# After training, export to onnx (network.onnx) and create a data file (input.json)\n", - "x = 0.1*np.random.rand(1,*[1, 28, 28])\n", + "x = 0.1*np.random.rand(1,*shape)\n", "\n", "spec = tf.TensorSpec([1, 28, 28, 1], tf.float32, name='input_0')\n", "\n", @@ -130,11 +132,27 @@ "\n", "# TODO: Dictionary outputs\n", "res = ezkl.gen_settings(model_path, settings_path)\n", - "assert res == True\n", + "assert res == True\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "cal_path = os.path.join(\"calibration.json\")\n", + "\n", + "data_array = (0.1*np.random.rand(20,*shape)).reshape([-1]).tolist()\n", + "\n", + "data = dict(input_data = [data_array])\n", + "\n", + "# Serialize data into file:\n", + "json.dump(data, open(cal_path, 'w'))\n", "\n", "\n", "res = ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n", - "assert res == True" + "assert res == True\n" ] }, { diff --git a/examples/notebooks/kmeans.ipynb b/examples/notebooks/kmeans.ipynb index 65e38252c..a2d96a6be 100644 --- 
a/examples/notebooks/kmeans.ipynb +++ b/examples/notebooks/kmeans.ipynb @@ -78,7 +78,7 @@ "pk_path = os.path.join('test.pk')\n", "vk_path = os.path.join('test.vk')\n", "settings_path = os.path.join('settings.json')\n", - "", + "\n", "witness_path = os.path.join('witness.json')\n", "data_path = os.path.join('input.json')" ] @@ -136,10 +136,27 @@ "!RUST_LOG=trace\n", "# TODO: Dictionary outputs\n", "res = ezkl.gen_settings(model_path, settings_path)\n", - "assert res == True\n", + "assert res == True\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "cal_path = os.path.join(\"calibration.json\")\n", + "\n", + "data_array = (grid_xs[0:20].detach().numpy()).reshape([-1]).tolist()\n", + "\n", + "data = dict(input_data = [data_array])\n", + "\n", + "# Serialize data into file:\n", + "json.dump(data, open(cal_path, 'w'))\n", + "\n", "\n", "res = ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n", - "assert res == True" + "assert res == True\n" ] }, { @@ -272,4 +289,4 @@ }, "nbformat": 4, "nbformat_minor": 5 -} \ No newline at end of file +} diff --git a/examples/notebooks/lightgbm.ipynb b/examples/notebooks/lightgbm.ipynb index ee764abfc..3bbbf9197 100644 --- a/examples/notebooks/lightgbm.ipynb +++ b/examples/notebooks/lightgbm.ipynb @@ -117,7 +117,7 @@ "pk_path = os.path.join('test.pk')\n", "vk_path = os.path.join('test.vk')\n", "settings_path = os.path.join('settings.json')\n", - "", + "\n", "witness_path = os.path.join('witness.json')\n", "data_path = os.path.join('input.json')" ] @@ -174,10 +174,27 @@ "\n", "# TODO: Dictionary outputs\n", "res = ezkl.gen_settings(model_path, settings_path, py_run_args=run_args)\n", - "assert res == True\n", + "assert res == True\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "cal_path = os.path.join(\"calibration.json\")\n", + "\n", + "data_array = (torch.randn(20, *shape).detach().numpy()).reshape([-1]).tolist()\n", + "\n", + "data = dict(input_data = [data_array])\n", + "\n", + "# Serialize data into file:\n", + "json.dump(data, open(cal_path, 'w'))\n", + "\n", "\n", "res = ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n", - "assert res == True" + "assert res == True\n" ] }, { @@ -330,4 +347,4 @@ }, "nbformat": 4, "nbformat_minor": 5 -} \ No newline at end of file +} diff --git a/examples/notebooks/linear_regression.ipynb b/examples/notebooks/linear_regression.ipynb index 067f0e21c..b5eeeb25b 100644 --- a/examples/notebooks/linear_regression.ipynb +++ b/examples/notebooks/linear_regression.ipynb @@ -69,7 +69,7 @@ "pk_path = os.path.join('test.pk')\n", "vk_path = os.path.join('test.vk')\n", "settings_path = os.path.join('settings.json')\n", - "", + "\n", "witness_path = os.path.join('witness.json')\n", "data_path = os.path.join('input.json')" ] @@ -124,10 +124,26 @@ "!RUST_LOG=trace\n", "# TODO: Dictionary outputs\n", "res = ezkl.gen_settings(model_path, settings_path)\n", - "assert res == True\n", + "assert res == True\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "cal_path = os.path.join(\"calibration.json\")\n", + "\n", + "data_array = (torch.randn(20, *shape).detach().numpy()).reshape([-1]).tolist()\n", + "\n", + "data = dict(input_data = [data_array])\n", + "\n", + "# Serialize data into file:\n", + "json.dump(data, open(cal_path, 'w'))\n", "\n", "res = ezkl.calibrate_settings(data_path, 
model_path, settings_path, \"resources\")\n", - "assert res == True" + "assert res == True\n" ] }, { @@ -260,4 +276,4 @@ }, "nbformat": 4, "nbformat_minor": 5 -} \ No newline at end of file +} diff --git a/examples/notebooks/little_transformer.ipynb b/examples/notebooks/little_transformer.ipynb index 9187fb0f8..9bd3deb78 100644 --- a/examples/notebooks/little_transformer.ipynb +++ b/examples/notebooks/little_transformer.ipynb @@ -266,7 +266,7 @@ "pk_path = os.path.join('test.pk')\n", "vk_path = os.path.join('test.vk')\n", "settings_path = os.path.join('settings.json')\n", - "", + "\n", "witness_path = os.path.join('witness.json')\n", "data_path = os.path.join('input.json')\n", "\n" @@ -280,10 +280,13 @@ "outputs": [], "source": [ "\n", - "import json \n", + "import json\n", + "\n", + "\n", + "shape = [1, 6]\n", "# After training, export to onnx (network.onnx) and create a data file (input.json)\n", - "x = torch.ones([1, 6], dtype=torch.long)\n", - "x = x.reshape([1, 6])\n", + "x = torch.zeros(shape, dtype=torch.long)\n", + "x = x.reshape(shape)\n", "\n", "print(x)\n", "\n", @@ -335,7 +338,14 @@ "metadata": {}, "outputs": [], "source": [ + "cal_path = os.path.join(\"calibration.json\")\n", + "\n", + "data_array = (torch.randn(20, *shape).detach().numpy()).reshape([-1]).tolist()\n", + "\n", + "data = dict(input_data = [data_array])\n", "\n", + "# Serialize data into file:\n", + "json.dump(data, open(cal_path, 'w'))\n", "\n", "res = ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n", "assert res == True\n" @@ -475,9 +485,9 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.13" + "version": "3.9.15" } }, "nbformat": 4, "nbformat_minor": 5 -} \ No newline at end of file +} diff --git a/examples/notebooks/lstm.ipynb b/examples/notebooks/lstm.ipynb index 4c4cf7e12..272d69a60 100644 --- a/examples/notebooks/lstm.ipynb +++ b/examples/notebooks/lstm.ipynb @@ -74,7 +74,7 @@ "pk_path = os.path.join('test.pk')\n", "vk_path = os.path.join('test.vk')\n", "settings_path = os.path.join('settings.json')\n", - "", + "\n", "witness_path = os.path.join('witness.json')\n", "data_path = os.path.join('input.json')\n", "\n" @@ -107,9 +107,11 @@ " dynamic_axes={'input' : {0 : 'batch_size'}, # variable length axes\n", " 'output' : {0 : 'batch_size'}})\n", "\n", + "\n", "SEQ_LEN = 10\n", + "shape = (SEQ_LEN, 3)\n", "# sequence of length 10\n", - "x = torch.randn(SEQ_LEN, 3)\n", + "x = torch.randn(*shape)\n", "\n", "data_array = ((x).detach().numpy()).reshape([-1]).tolist()\n", "\n", @@ -141,6 +143,24 @@ "assert res == True\n" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "cal_path = os.path.join(\"calibration.json\")\n", + "\n", + "data_array = (torch.randn(10, *shape).detach().numpy()).reshape([-1]).tolist()\n", + "\n", + "data = dict(input_data = [data_array])\n", + "\n", + "# Serialize data into file:\n", + "json.dump(data, open(cal_path, 'w'))\n", + "\n", + "ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")" + ] + }, { "cell_type": "code", "execution_count": null, @@ -280,4 +300,4 @@ }, "nbformat": 4, "nbformat_minor": 5 -} \ No newline at end of file +} diff --git a/examples/notebooks/mnist_gan.ipynb b/examples/notebooks/mnist_gan.ipynb index bf326da22..13d089b75 100644 --- a/examples/notebooks/mnist_gan.ipynb +++ b/examples/notebooks/mnist_gan.ipynb @@ -231,10 +231,11 @@ "import tensorflow as tf\n", "import json\n", "\n", + "shape = [1, 
ZDIM]\n", "# After training, export to onnx (network.onnx) and create a data file (input.json)\n", - "x = 0.1*np.random.rand(1,*[1, ZDIM])\n", + "x = 0.1*np.random.rand(1,*shape)\n", "\n", - "spec = tf.TensorSpec([1, ZDIM], tf.float32, name='input_0')\n", + "spec = tf.TensorSpec(shape, tf.float32, name='input_0')\n", "\n", "\n", "tf2onnx.convert.from_keras(gm, input_signature=[spec], inputs_as_nchw=['input_0'], opset=12, output_path=model_path)\n", @@ -264,11 +265,26 @@ "!RUST_LOG=trace\n", "# TODO: Dictionary outputs\n", "res = ezkl.gen_settings(model_path, settings_path, py_run_args=run_args)\n", - "assert res == True\n", + "assert res == True\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "cal_path = os.path.join(\"calibration.json\")\n", "\n", - "res = ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\", scales=[0,6])\n", - "assert res == True\n", - "print(\"verified\")" + "data_array = (0.2 * np.random.rand(20, *shape)).reshape([-1]).tolist()\n", + "\n", + "data = dict(input_data = [data_array])\n", + "\n", + "# Serialize data into file:\n", + "json.dump(data, open(cal_path, 'w'))\n", + "\n", + "\n", + "ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\", scales=[0,6])" ] }, { @@ -404,4 +420,4 @@ }, "nbformat": 4, "nbformat_minor": 2 -} \ No newline at end of file +} diff --git a/examples/notebooks/mnist_vae.ipynb b/examples/notebooks/mnist_vae.ipynb index 85d3b090c..69011e8fa 100644 --- a/examples/notebooks/mnist_vae.ipynb +++ b/examples/notebooks/mnist_vae.ipynb @@ -169,10 +169,11 @@ "import tensorflow as tf\n", "import json\n", "\n", + "shape = [1, ZDIM]\n", "# After training, export to onnx (network.onnx) and create a data file (input.json)\n", - "x = 0.1*np.random.rand(1,*[1, ZDIM])\n", + "x = 0.1*np.random.rand(1,*shape)\n", "\n", - "spec = tf.TensorSpec([1, ZDIM], tf.float32, name='input_0')\n", + "spec = tf.TensorSpec(shape, tf.float32, name='input_0')\n", "\n", "\n", "tf2onnx.convert.from_keras(dec, input_signature=[spec], inputs_as_nchw=['input_0'], opset=12, output_path=model_path)\n", @@ -195,11 +196,26 @@ "\n", "!RUST_LOG=trace\n", "res = ezkl.gen_settings(model_path, settings_path)\n", - "assert res == True\n", + "assert res == True\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "cal_path = os.path.join(\"calibration.json\")\n", "\n", - "res = ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n", - "assert res == True\n", - "print(\"verified\")" + "data_array = (0.1 * np.random.rand(20, *shape)).reshape([-1]).tolist()\n", + "\n", + "data = dict(input_data = [data_array])\n", + "\n", + "# Serialize data into file:\n", + "json.dump(data, open(cal_path, 'w'))\n", + "\n", + "\n", + "ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")" ] }, { @@ -572,4 +588,4 @@ }, "nbformat": 4, "nbformat_minor": 2 -} \ No newline at end of file +} diff --git a/examples/notebooks/random_forest.ipynb b/examples/notebooks/random_forest.ipynb index 0af8c242e..6c6d992bb 100644 --- a/examples/notebooks/random_forest.ipynb +++ b/examples/notebooks/random_forest.ipynb @@ -117,7 +117,7 @@ "pk_path = os.path.join('test.pk')\n", "vk_path = os.path.join('test.vk')\n", "settings_path = os.path.join('settings.json')\n", - "", + "\n", "witness_path = os.path.join('witness.json')\n", "data_path = os.path.join('input.json')" ] @@ -174,10 +174,26 @@ "!RUST_LOG=trace\n", "# 
TODO: Dictionary outputs\n", "res = ezkl.gen_settings(model_path, settings_path)\n", - "assert res == True\n", + "assert res == True\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "cal_path = os.path.join(\"calibration.json\")\n", "\n", - "res = ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n", - "assert res == True" + "data_array = (torch.rand(20, *shape, requires_grad=True).detach().numpy()).reshape([-1]).tolist()\n", + "\n", + "data = dict(input_data = [data_array])\n", + "\n", + "# Serialize data into file:\n", + "json.dump(data, open(cal_path, 'w'))\n", + "\n", + "\n", + "ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")" ] }, { @@ -310,4 +326,4 @@ }, "nbformat": 4, "nbformat_minor": 5 -} \ No newline at end of file +} diff --git a/examples/notebooks/simple_demo_aggregated_proofs.ipynb b/examples/notebooks/simple_demo_aggregated_proofs.ipynb index 3cb7664d4..f2dbb1cb9 100644 --- a/examples/notebooks/simple_demo_aggregated_proofs.ipynb +++ b/examples/notebooks/simple_demo_aggregated_proofs.ipynb @@ -114,9 +114,9 @@ "outputs": [], "source": [ "\n", - "\n", + "shape = [1, 28, 28]\n", "# After training, export to onnx (network.onnx) and create a data file (input.json)\n", - "x = 0.1*torch.rand(1,*[1, 28, 28], requires_grad=True)\n", + "x = 0.1*torch.rand(1,*shape, requires_grad=True)\n", "\n", "# Flips the neural net into inference mode\n", "circuit.eval()\n", @@ -152,9 +152,26 @@ "# TODO: Dictionary outputs\n", "res = ezkl.gen_settings(model_path, settings_path)\n", "assert res == True\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "cal_path = os.path.join(\"calibration.json\")\n", "\n", - "res = ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n", - "assert res == True" + "data_array = (torch.rand(20, *shape, requires_grad=True).detach().numpy()).reshape([-1]).tolist()\n", + "\n", + "data = dict(input_data = [data_array])\n", + "\n", + "# Serialize data into file:\n", + "json.dump(data, open(cal_path, 'w'))\n", + "\n", + "\n", + "ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")" ] }, { @@ -383,7 +400,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.13" + "version": "3.9.15" } }, "nbformat": 4, diff --git a/examples/notebooks/simple_demo_all_public.ipynb b/examples/notebooks/simple_demo_all_public.ipynb index e420b4168..c17239db4 100644 --- a/examples/notebooks/simple_demo_all_public.ipynb +++ b/examples/notebooks/simple_demo_all_public.ipynb @@ -97,7 +97,7 @@ "pk_path = os.path.join('test.pk')\n", "vk_path = os.path.join('test.vk')\n", "settings_path = os.path.join('settings.json')\n", - "", + "\n", "witness_path = os.path.join('witness.json')\n", "data_path = os.path.join('input.json')" ] @@ -110,9 +110,9 @@ "outputs": [], "source": [ "\n", - "\n", + "shape = [1, 28, 28]\n", "# After training, export to onnx (network.onnx) and create a data file (input.json)\n", - "x = 0.1*torch.rand(1,*[1, 28, 28], requires_grad=True)\n", + "x = 0.1*torch.rand(1,*shape, requires_grad=True)\n", "\n", "# Flips the neural net into inference mode\n", "circuit.eval()\n", @@ -151,10 +151,26 @@ "py_run_args.param_visibility = \"fixed\" # \"fixed\" for params means that the committed to params are used for all proofs\n", "\n", "res = ezkl.gen_settings(model_path, settings_path, 
py_run_args=py_run_args)\n", - "assert res == True\n", + "assert res == True\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "cal_path = os.path.join(\"calibration.json\")\n", "\n", - "res = ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n", - "assert res == True" + "data_array = (torch.rand(20, *shape, requires_grad=True).detach().numpy()).reshape([-1]).tolist()\n", + "\n", + "data = dict(input_data = [data_array])\n", + "\n", + "# Serialize data into file:\n", + "json.dump(data, open(cal_path, 'w'))\n", + "\n", + "\n", + "ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")" ] }, { @@ -287,4 +303,4 @@ }, "nbformat": 4, "nbformat_minor": 5 -} \ No newline at end of file +} diff --git a/examples/notebooks/simple_demo_public_input_output.ipynb b/examples/notebooks/simple_demo_public_input_output.ipynb index 9a6a20b5b..7a1e3a277 100644 --- a/examples/notebooks/simple_demo_public_input_output.ipynb +++ b/examples/notebooks/simple_demo_public_input_output.ipynb @@ -97,7 +97,7 @@ "pk_path = os.path.join('test.pk')\n", "vk_path = os.path.join('test.vk')\n", "settings_path = os.path.join('settings.json')\n", - "", + "\n", "witness_path = os.path.join('witness.json')\n", "data_path = os.path.join('input.json')" ] @@ -110,9 +110,9 @@ "outputs": [], "source": [ "\n", - "\n", + "shape = [1, 28, 28]\n", "# After training, export to onnx (network.onnx) and create a data file (input.json)\n", - "x = 0.1*torch.rand(1,*[1, 28, 28], requires_grad=True)\n", + "x = 0.1*torch.rand(1,*shape, requires_grad=True)\n", "\n", "# Flips the neural net into inference mode\n", "circuit.eval()\n", @@ -150,11 +150,26 @@ "py_run_args.param_visibility = \"private\" # private by default\n", "\n", "res = ezkl.gen_settings(model_path, settings_path, py_run_args=py_run_args)\n", + "assert res == True\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "cal_path = os.path.join(\"calibration.json\")\n", "\n", - "assert res == True\n", + "data_array = (torch.rand(20, *shape, requires_grad=True).detach().numpy()).reshape([-1]).tolist()\n", "\n", - "res = ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n", - "assert res == True" + "data = dict(input_data = [data_array])\n", + "\n", + "# Serialize data into file:\n", + "json.dump(data, open(cal_path, 'w'))\n", + "\n", + "\n", + "ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")" ] }, { @@ -287,4 +302,4 @@ }, "nbformat": 4, "nbformat_minor": 5 -} \ No newline at end of file +} diff --git a/examples/notebooks/simple_demo_public_network_output.ipynb b/examples/notebooks/simple_demo_public_network_output.ipynb index bbfe818f7..243d984fe 100644 --- a/examples/notebooks/simple_demo_public_network_output.ipynb +++ b/examples/notebooks/simple_demo_public_network_output.ipynb @@ -97,7 +97,7 @@ "pk_path = os.path.join('test.pk')\n", "vk_path = os.path.join('test.vk')\n", "settings_path = os.path.join('settings.json')\n", - "", + "\n", "witness_path = os.path.join('witness.json')\n", "data_path = os.path.join('input.json')" ] @@ -110,9 +110,9 @@ "outputs": [], "source": [ "\n", - "\n", + "shape = [1, 28, 28]\n", "# After training, export to onnx (network.onnx) and create a data file (input.json)\n", - "x = 0.1*torch.rand(1,*[1, 28, 28], requires_grad=True)\n", + "x = 0.1*torch.rand(1,*shape, requires_grad=True)\n", "\n", "# Flips the neural net into 
inference mode\n", "circuit.eval()\n", @@ -151,10 +151,26 @@ "\n", "res = ezkl.gen_settings(model_path, settings_path, py_run_args=py_run_args)\n", "\n", - "assert res == True\n", + "assert res == True\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "cal_path = os.path.join(\"calibration.json\")\n", "\n", - "res = ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n", - "assert res == True" + "data_array = (torch.rand(20, *shape, requires_grad=True).detach().numpy()).reshape([-1]).tolist()\n", + "\n", + "data = dict(input_data = [data_array])\n", + "\n", + "# Serialize data into file:\n", + "json.dump(data, open(cal_path, 'w'))\n", + "\n", + "\n", + "ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")" ] }, { @@ -287,4 +303,4 @@ }, "nbformat": 4, "nbformat_minor": 5 -} \ No newline at end of file +} diff --git a/examples/notebooks/sklearn_mlp.ipynb b/examples/notebooks/sklearn_mlp.ipynb index 1637242c9..dd7949253 100644 --- a/examples/notebooks/sklearn_mlp.ipynb +++ b/examples/notebooks/sklearn_mlp.ipynb @@ -75,7 +75,7 @@ "pk_path = os.path.join('test.pk')\n", "vk_path = os.path.join('test.vk')\n", "settings_path = os.path.join('settings.json')\n", - "", + "\n", "witness_path = os.path.join('witness.json')\n", "data_path = os.path.join('input.json')" ] @@ -130,10 +130,26 @@ "!RUST_LOG=trace\n", "# TODO: Dictionary outputs\n", "res = ezkl.gen_settings(model_path, settings_path)\n", - "assert res == True\n", + "assert res == True\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "cal_path = os.path.join(\"calibration.json\")\n", "\n", - "res = ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n", - "assert res == True" + "data_array = (torch.rand(20, *shape, requires_grad=True).detach().numpy()).reshape([-1]).tolist()\n", + "\n", + "data = dict(input_data = [data_array])\n", + "\n", + "# Serialize data into file:\n", + "json.dump(data, open(cal_path, 'w'))\n", + "\n", + "\n", + "ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")" ] }, { @@ -266,4 +282,4 @@ }, "nbformat": 4, "nbformat_minor": 5 -} \ No newline at end of file +} diff --git a/examples/notebooks/stacked_regression.ipynb b/examples/notebooks/stacked_regression.ipynb index e63fe9ec1..68167164a 100644 --- a/examples/notebooks/stacked_regression.ipynb +++ b/examples/notebooks/stacked_regression.ipynb @@ -110,7 +110,7 @@ "pk_path = os.path.join('test.pk')\n", "vk_path = os.path.join('test.vk')\n", "settings_path = os.path.join('settings.json')\n", - "", + "\n", "witness_path = os.path.join('witness.json')\n", "data_path = os.path.join('input.json')" ] @@ -171,6 +171,25 @@ "assert res == True" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "cal_path = os.path.join(\"calibration.json\")\n", + "\n", + "data_array = (torch.rand(20, *shape, requires_grad=True).detach().numpy()).reshape([-1]).tolist()\n", + "\n", + "data = dict(input_data = [data_array])\n", + "\n", + "# Serialize data into file:\n", + "json.dump(data, open(cal_path, 'w'))\n", + "\n", + "\n", + "ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")" + ] + }, { "cell_type": "code", "execution_count": null, @@ -301,4 +320,4 @@ }, "nbformat": 4, "nbformat_minor": 5 -} \ No newline at end of file +} diff --git a/examples/notebooks/svm.ipynb 
b/examples/notebooks/svm.ipynb index b9387d93d..977c87cd3 100644 --- a/examples/notebooks/svm.ipynb +++ b/examples/notebooks/svm.ipynb @@ -79,7 +79,7 @@ "pk_path = os.path.join('test.pk')\n", "vk_path = os.path.join('test.vk')\n", "settings_path = os.path.join('settings.json')\n", - "", + "\n", "witness_path = os.path.join('witness.json')\n", "data_path = os.path.join('input.json')" ] @@ -161,10 +161,26 @@ "!RUST_LOG=trace\n", "# TODO: Dictionary outputs\n", "res = ezkl.gen_settings(model_path, settings_path)\n", - "assert res == True\n", + "assert res == True\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "cal_path = os.path.join(\"calibration.json\")\n", "\n", - "res = ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n", - "assert res == True" + "data_array = ((grid_xs[0:20]).detach().numpy()).reshape([-1]).tolist()\n", + "\n", + "data = dict(input_data = [data_array])\n", + "\n", + "# Serialize data into file:\n", + "json.dump(data, open(cal_path, 'w'))\n", + "\n", + "\n", + "ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")" ] }, { @@ -387,7 +403,7 @@ "pk_path = os.path.join('test.pk')\n", "vk_path = os.path.join('test.vk')\n", "settings_path = os.path.join('settings.json')\n", - "", + "\n", "witness_path = os.path.join('witness.json')\n", "data_path = os.path.join('input.json')" ] @@ -430,4 +446,4 @@ }, "nbformat": 4, "nbformat_minor": 5 -} \ No newline at end of file +} diff --git a/examples/notebooks/tictactoe_autoencoder.ipynb b/examples/notebooks/tictactoe_autoencoder.ipynb index c413f49bd..9414e7f79 100644 --- a/examples/notebooks/tictactoe_autoencoder.ipynb +++ b/examples/notebooks/tictactoe_autoencoder.ipynb @@ -598,7 +598,7 @@ "pk_path = os.path.join('test.pk')\n", "vk_path = os.path.join('test.vk')\n", "settings_path = os.path.join('settings.json')\n", - "", + "\n", "witness_path = os.path.join('witness.json')\n", "proof_path = os.path.join('proof.json')" ] @@ -623,7 +623,17 @@ "# For testing we will just stick to resources to reduce computational costs\n", "# Example:\n", "# ezkl.calibrate_settings(data_path, model_path, settings_path, \"accuracy\", scales = [2,9])\n", - "ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")" + "cal_path = os.path.join(\"calibration.json\")\n", + "\n", + "data_array = ((x[0:20]).detach().numpy()).reshape([-1]).tolist()\n", + "\n", + "data = dict(input_data = [data_array])\n", + "\n", + "# Serialize data into file:\n", + "json.dump(data, open(cal_path, 'w'))\n", + "\n", + "\n", + "ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")" ] }, { @@ -739,9 +749,10 @@ "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", - "pygments_lexer": "ipython3" + "pygments_lexer": "ipython3", + "version": "3.9.15" } }, "nbformat": 4, "nbformat_minor": 4 -} \ No newline at end of file +} diff --git a/examples/notebooks/tictactoe_binary_classification.ipynb b/examples/notebooks/tictactoe_binary_classification.ipynb index c7aa7cdd5..ea9c4a8b4 100644 --- a/examples/notebooks/tictactoe_binary_classification.ipynb +++ b/examples/notebooks/tictactoe_binary_classification.ipynb @@ -490,7 +490,7 @@ "pk_path = os.path.join('test.pk')\n", "vk_path = os.path.join('test.vk')\n", "settings_path = os.path.join('settings.json')\n", - "", + "\n", "witness_path = os.path.join('witness.json')\n", "proof_path = os.path.join('proof.json')" ] @@ -510,7 +510,17 @@ "metadata": {}, "outputs": 
[], "source": [ - "ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")" + "cal_path = os.path.join(\"calibration.json\")\n", + "\n", + "data_array = ((x[0:20]).detach().numpy()).reshape([-1]).tolist()\n", + "\n", + "data = dict(input_data = [data_array])\n", + "\n", + "# Serialize data into file:\n", + "json.dump(data, open(cal_path, 'w'))\n", + "\n", + "\n", + "ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")" ] }, { @@ -519,7 +529,7 @@ "metadata": {}, "outputs": [], "source": [ - "ezkl.get_srs( settings_path)" + "ezkl.get_srs(settings_path)" ] }, { @@ -631,4 +641,4 @@ }, "nbformat": 4, "nbformat_minor": 4 -} \ No newline at end of file +} diff --git a/examples/notebooks/world_rotation.ipynb b/examples/notebooks/world_rotation.ipynb index 2d9d3e4ed..336049c51 100644 --- a/examples/notebooks/world_rotation.ipynb +++ b/examples/notebooks/world_rotation.ipynb @@ -529,7 +529,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.13" + "version": "3.9.15" } }, "nbformat": 4, diff --git a/examples/notebooks/xgboost.ipynb b/examples/notebooks/xgboost.ipynb index 4baa328d6..b3bd3a000 100644 --- a/examples/notebooks/xgboost.ipynb +++ b/examples/notebooks/xgboost.ipynb @@ -118,7 +118,7 @@ "pk_path = os.path.join('test.pk')\n", "vk_path = os.path.join('test.vk')\n", "settings_path = os.path.join('settings.json')\n", - "", + "\n", "witness_path = os.path.join('witness.json')\n", "data_path = os.path.join('input.json')" ] @@ -166,7 +166,6 @@ { "cell_type": "code", "execution_count": null, - "id": "d5e374a2", "metadata": {}, "outputs": [], "source": [ @@ -175,10 +174,26 @@ "\n", "# TODO: Dictionary outputs\n", "res = ezkl.gen_settings(model_path, settings_path, py_run_args=run_args)\n", - "assert res == True\n", + "assert res == True\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# generate a bunch of dummy calibration data\n", + "cal_data = {\n", + " \"input_data\": [(torch.rand(20, *shape)).flatten().tolist()],\n", + "}\n", "\n", - "res = ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n", - "assert res == True" + "cal_path = os.path.join('calibration.json')\n", + "# save as json file\n", + "with open(cal_path, \"w\") as f:\n", + " json.dump(cal_data, f)\n", + "\n", + "res = ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")" ] }, { @@ -331,4 +346,4 @@ }, "nbformat": 4, "nbformat_minor": 5 -} \ No newline at end of file +} diff --git a/src/commands.rs b/src/commands.rs index 19b64c74e..e0efdcdb5 100644 --- a/src/commands.rs +++ b/src/commands.rs @@ -261,10 +261,10 @@ pub enum Commands { /// The path to the .json data file #[arg(short = 'D', long, default_value = DEFAULT_DATA)] data: PathBuf, - /// The path to the compiled model file + /// The path to the compiled model file (generated using the compile-circuit command) #[arg(short = 'M', long, default_value = DEFAULT_COMPILED_CIRCUIT)] compiled_circuit: PathBuf, - /// Path to the witness (public and private inputs) .json file + /// Path to output the witness .json file #[arg(short = 'O', long, default_value = DEFAULT_WITNESS)] output: PathBuf, /// Path to the verification key file (optional - solely used to generate kzg commits) @@ -280,7 +280,7 @@ pub enum Commands { /// The path to the .onnx model file #[arg(short = 'M', long, default_value = DEFAULT_MODEL)] model: PathBuf, - /// Path to circuit_settings file to output + /// 
The path to output the circuit settings .json file to #[arg(short = 'O', long, default_value = DEFAULT_SETTINGS)] settings_path: PathBuf, /// proving arguments @@ -297,7 +297,7 @@ pub enum Commands { /// The path to the .onnx model file #[arg(short = 'M', long, default_value = DEFAULT_MODEL)] model: PathBuf, - /// Path to circuit_settings file to read in AND overwrite. + /// The path to the circuit settings .json file to load AND overwrite (generated using the gen-settings command). #[arg(short = 'O', long, default_value = DEFAULT_SETTINGS)] settings_path: PathBuf, #[arg(long = "target", default_value = DEFAULT_CALIBRATION_TARGET)] @@ -314,7 +314,7 @@ pub enum Commands { /// Generates a dummy SRS #[command(name = "gen-srs", arg_required_else_help = true)] GenSrs { - /// The path to output to the desired srs file + /// The path to output the generated SRS #[arg(long)] srs_path: PathBuf, /// number of logrows to use for srs @@ -326,22 +326,22 @@ /// Gets an SRS from a circuit settings file. #[command(name = "get-srs")] GetSrs { - /// The path to output to the desired srs file + /// The path to output the desired srs file, if set to None will save to $EZKL_REPO_PATH/srs #[arg(long)] srs_path: Option<PathBuf>, - /// Path to circuit_settings file to read in. Overriden by logrows if specified. + /// Path to the circuit settings .json file to read in logrows from. Overridden by logrows if specified. #[arg(short = 'S', long, default_value = DEFAULT_SETTINGS)] settings_path: Option<PathBuf>, /// Number of logrows to use for srs. Overrides settings_path if specified. #[arg(long, default_value = None)] logrows: Option<u32>, - /// Check mode for srs. verifies downloaded srs is valid. set to unsafe for speed. + /// Check mode for SRS. Verifies downloaded srs is valid. Set to unsafe for speed. 
#[arg(long, default_value = DEFAULT_CHECKMODE)] check: CheckMode, }, /// Loads model and input and runs mock prover (for testing) Mock { - /// The path to the .json witness file + /// The path to the .json witness file (generated using the gen-witness command) #[arg(short = 'W', long, default_value = DEFAULT_WITNESS)] witness: PathBuf, /// The path to the .onnx model file @@ -351,7 +351,7 @@ /// Mock aggregate proofs MockAggregate { - /// The path to the snarks to aggregate over + /// The path to the snarks to aggregate over (generated using the prove command with the --proof-type=for-aggr flag) #[arg(long, default_value = DEFAULT_PROOF, value_delimiter = ',', allow_hyphen_values = true)] aggregation_snarks: Vec<PathBuf>, /// logrows used for aggregation circuit @@ -364,16 +364,16 @@ /// setup aggregation circuit :) SetupAggregate { - /// The path to samples of snarks that will be aggregated over + /// The path to samples of snarks that will be aggregated over (generated using the prove command with the --proof-type=for-aggr flag) #[arg(long, default_value = DEFAULT_PROOF, value_delimiter = ',', allow_hyphen_values = true)] sample_snarks: Vec<PathBuf>, - /// The path to save the desired verification key file + /// The path to save the desired verification key file to #[arg(long, default_value = DEFAULT_VK_AGGREGATED)] vk_path: PathBuf, - /// The path to save the desired proving key file + /// The path to save the proving key to #[arg(long, default_value = DEFAULT_PK_AGGREGATED)] pk_path: PathBuf, - /// The path to SRS + /// The path to SRS, if None will use $EZKL_REPO_PATH/srs/kzg{logrows}.srs #[arg(long)] srs_path: Option<PathBuf>, /// logrows used for aggregation circuit @@ -385,16 +385,16 @@ }, /// Aggregates proofs :) Aggregate { - /// The path to the snarks to aggregate over + /// The path to the snarks to aggregate over (generated using the prove command with the --proof-type=for-aggr flag) #[arg(long, default_value = DEFAULT_PROOF, value_delimiter = ',', allow_hyphen_values = true)] aggregation_snarks: Vec<PathBuf>, - /// The path to load the desired proving key file + /// The path to load the desired proving key file (generated using the setup-aggregate command) #[arg(long, default_value = DEFAULT_PK_AGGREGATED)] pk_path: PathBuf, - /// The path to the desired output file + /// The path to output the proof file to #[arg(long, default_value = DEFAULT_PROOF_AGGREGATED)] proof_path: PathBuf, - /// The path to SRS + /// The path to SRS, if None will use $EZKL_REPO_PATH/srs/kzg{logrows}.srs #[arg(long)] srs_path: Option<PathBuf>, #[arg( @@ -411,7 +411,7 @@ /// run sanity checks during calculations (safe or unsafe) #[arg(long, default_value = DEFAULT_CHECKMODE)] check_mode: CheckMode, - /// whether the accumulated are segments of a larger proof + /// whether the accumulated proofs are segments of a larger circuit #[arg(long, default_value = DEFAULT_SPLIT)] split_proofs: bool, }, @@ -420,25 +420,25 @@ /// The path to the .onnx model file #[arg(short = 'M', long, default_value = DEFAULT_MODEL)] model: PathBuf, - /// The path to output the processed model + /// The path to output the compiled model file to #[arg(long, default_value = DEFAULT_COMPILED_CIRCUIT)] compiled_circuit: PathBuf, - /// The path to load circuit params from + /// The path to load the circuit settings .json file from (generated using the gen-settings command) #[arg(short = 'S', long, default_value = DEFAULT_SETTINGS)] settings_path: 
PathBuf, }, /// Creates pk and vk Setup { - /// The path to the compiled model file + /// The path to the compiled model file (generated using the compile-circuit command) #[arg(short = 'M', long, default_value = DEFAULT_COMPILED_CIRCUIT)] compiled_circuit: PathBuf, - /// The srs path + /// The path to SRS, if None will use $EZKL_REPO_PATH/srs/kzg{logrows}.srs #[arg(long)] srs_path: Option<PathBuf>, - /// The path to output the verification key file + /// The path to output the verification key file to #[arg(long, default_value = DEFAULT_VK)] vk_path: PathBuf, - /// The path to output the proving key file + /// The path to output the proving key file to #[arg(long, default_value = DEFAULT_PK)] pk_path: PathBuf, /// The graph witness (optional - used to override fixed values in the circuit) @@ -449,10 +449,10 @@ pub enum Commands { #[cfg(not(target_arch = "wasm32"))] /// Fuzzes the proof pipeline with random inputs, random parameters, and random keys Fuzz { - /// The path to the .json witness file, which should include both the network input (possibly private) and the network output (public input to the proof) + /// The path to the .json witness file (generated using the gen-witness command) #[arg(short = 'W', long, default_value = DEFAULT_WITNESS)] witness: PathBuf, - /// The path to the processed model file + /// The path to the compiled model file (generated using the compile-circuit command) #[arg(short = 'M', long, default_value = DEFAULT_COMPILED_CIRCUIT)] compiled_circuit: PathBuf, #[arg( @@ -473,7 +473,7 @@ pub enum Commands { /// The path to the .json data file, which should include both the network input (possibly private) and the network output (public input to the proof) #[arg(short = 'D', long)] data: PathBuf, - /// The path to the compiled model file + /// The path to the compiled model file (generated using the compile-circuit command) #[arg(short = 'M', long)] compiled_circuit: PathBuf, /// For testing purposes only. The optional path to the .json data file that will be generated that contains the OnChain data storage information @@ -484,20 +484,20 @@ /// RPC URL for an Ethereum node, if None will use Anvil but WON'T persist state #[arg(short = 'U', long)] rpc_url: Option<String>, - /// where does the input data come from + /// where the input data comes from #[arg(long, default_value = "on-chain")] input_source: TestDataSource, - /// where does the output data come from + /// where the output data comes from #[arg(long, default_value = "on-chain")] output_source: TestDataSource, }, #[cfg(not(target_arch = "wasm32"))] #[command(arg_required_else_help = true)] TestUpdateAccountCalls { - /// The path to verifier contract's address + /// The verifier contract's address #[arg(long)] addr: H160, - /// The path to the .json data file, which should include both the network input (possibly private) and the network output (public input to the proof) + /// The path to the .json data file. 
#[arg(short = 'D', long)] data: PathBuf, /// RPC URL for an Ethereum node, if None will use Anvil but WON'T persist state @@ -518,19 +518,19 @@ pub enum Commands { #[cfg(not(target_arch = "wasm32"))] /// Loads model, data, and creates proof Prove { - /// The path to the .json witness file, which should include both the network input (possibly private) and the network output (public input to the proof) + /// The path to the .json witness file (generated using the gen-witness command) #[arg(short = 'W', long, default_value = DEFAULT_WITNESS)] witness: PathBuf, - /// The path to the compiled model file + /// The path to the compiled model file (generated using the compile-circuit command) #[arg(short = 'M', long, default_value = DEFAULT_COMPILED_CIRCUIT)] compiled_circuit: PathBuf, - /// The path to load the desired proving key file + /// The path to load the desired proving key file (generated using the setup command) #[arg(long, default_value = DEFAULT_PK)] pk_path: PathBuf, - /// The path to the desired output file + /// The path to output the proof file to #[arg(long, default_value = DEFAULT_PROOF)] proof_path: PathBuf, - /// The parameter path + /// The path to SRS, if None will use $EZKL_REPO_PATH/srs/kzg{logrows}.srs #[arg(long)] srs_path: Option<PathBuf>, #[arg( @@ -549,10 +549,10 @@ /// Creates an EVM verifier for a single proof #[command(name = "create-evm-verifier")] CreateEVMVerifier { - /// The path to load the desired params file + /// The path to SRS, if None will use $EZKL_REPO_PATH/srs/kzg{logrows}.srs #[arg(long)] srs_path: Option<PathBuf>, - /// The path to load circuit settings from + /// The path to load the circuit settings .json file from (generated using the gen-settings command) #[arg(short = 'S', long, default_value = DEFAULT_SETTINGS)] settings_path: PathBuf, /// The path to load the desired verification key file @@ -569,10 +569,10 @@ /// Creates an EVM verifier that attests to on-chain inputs for a single proof #[command(name = "create-evm-da")] CreateEVMDataAttestation { - /// The path to load the desired srs file from + /// The path to SRS, if None will use $EZKL_REPO_PATH/srs/kzg{logrows}.srs #[arg(long)] srs_path: Option<PathBuf>, - /// The path to load circuit settings from + /// The path to load the circuit settings .json file from (generated using the gen-settings command) #[arg(short = 'S', long, default_value = DEFAULT_SETTINGS)] settings_path: PathBuf, /// The path to load the desired verification key file @@ -585,23 +585,22 @@ #[arg(long, default_value = DEFAULT_VERIFIER_DA_ABI)] abi_path: PathBuf, /// The path to the .json data file, which should - /// contain the necessary calldata and accoount addresses - /// needed need to read from all the on-chain + /// contain the necessary calldata and account addresses + /// needed to read from all the on-chain /// view functions that return the data that the network /// ingests as inputs. 
#[arg(short = 'D', long, default_value = DEFAULT_DATA)] data: PathBuf, - // todo, optionally allow supplying proving key }, #[cfg(not(target_arch = "wasm32"))] /// Creates an EVM verifier for an aggregate proof #[command(name = "create-evm-verifier-aggr")] CreateEVMVerifierAggr { - /// The path to load the desired srs file from + /// The path to SRS, if None will use $EZKL_REPO_PATH/srs/kzg{logrows}.srs #[arg(long)] srs_path: Option<PathBuf>, - /// The path to to load the desired verification key file + /// The path to load the desired verification key file #[arg(long, default_value = DEFAULT_VK_AGGREGATED)] vk_path: PathBuf, /// The path to the Solidity code @@ -619,28 +618,28 @@ }, /// Verifies a proof, returning accept or reject Verify { - /// The path to load circuit params from + /// The path to load the circuit settings .json file from (generated using the gen-settings command) #[arg(short = 'S', long, default_value = DEFAULT_SETTINGS)] settings_path: PathBuf, - /// The path to the proof file + /// The path to the proof file (generated using the prove command) #[arg(long, default_value = DEFAULT_PROOF)] proof_path: PathBuf, - /// The path to output the desired verification key file (optional) + /// The path to the verification key file (generated using the setup command) #[arg(long, default_value = DEFAULT_VK)] vk_path: PathBuf, - /// The kzg srs path + /// The path to SRS, if None will use $EZKL_REPO_PATH/srs/kzg{logrows}.srs #[arg(long)] srs_path: Option<PathBuf>, }, /// Verifies an aggregate proof, returning accept or reject VerifyAggr { - /// The path to the proof file + /// The path to the proof file (generated using the aggregate command) #[arg(long, default_value = DEFAULT_PROOF_AGGREGATED)] proof_path: PathBuf, - /// The path to output the desired verification key file (optional) + /// The path to the verification key file (generated using the setup-aggregate command) #[arg(long, default_value = DEFAULT_VK_AGGREGATED)] vk_path: PathBuf, - /// The srs path + /// The path to SRS, if None will use $EZKL_REPO_PATH/srs/kzg{logrows}.srs #[arg(long)] srs_path: Option<PathBuf>, /// logrows used for aggregation circuit @@ -649,7 +648,7 @@ }, #[cfg(not(target_arch = "wasm32"))] DeployEvmVerifier { - /// The path to the Solidity code + /// The path to the Solidity code (generated using the create-evm-verifier command) #[arg(long, default_value = DEFAULT_SOL_CODE)] sol_code_path: PathBuf, /// RPC URL for an Ethereum node, if None will use Anvil but WON'T persist state #[arg(short = 'U', long)] rpc_url: Option<String>, #[arg(long, default_value = DEFAULT_CONTRACT_ADDRESS)] /// The path to output the contract address addr_path: PathBuf, - /// The optimizer runs to set on the verifier. (Lower values optimize for deployment, while higher values optimize for execution) + /// The optimizer runs to set on the verifier. Lower values optimize for deployment cost, while higher values optimize for gas cost. #[arg(long, default_value = DEFAULT_OPTIMIZER_RUNS)] optimizer_runs: usize, /// Private secp256K1 key in hex format, 64 chars, no 0x prefix, of the account signing transactions. 
If None the private key will be generated by Anvil @@ -671,7 +670,7 @@ /// The path to the .json data file, which should include both the network input (possibly private) and the network output (public input to the proof) #[arg(short = 'D', long, default_value = DEFAULT_DATA)] data: PathBuf, - /// The path to load circuit params from + /// The path to load the circuit settings .json file from (generated using the gen-settings command) #[arg(long, default_value = DEFAULT_SETTINGS)] settings_path: PathBuf, /// The path to the Solidity code @@ -694,7 +693,7 @@ /// Verifies a proof using a local EVM executor, returning accept or reject #[command(name = "verify-evm")] VerifyEVM { - /// The path to the proof file + /// The path to the proof file (generated using the prove command) #[arg(long, default_value = DEFAULT_PROOF)] proof_path: PathBuf, /// The path to verifier contract's address diff --git a/src/execute.rs b/src/execute.rs index ad59ae3bf..8995a91b7 100644 --- a/src/execute.rs +++ b/src/execute.rs @@ -428,6 +428,29 @@ async fn fetch_srs(uri: &str) -> Result<Vec<u8>, Box<dyn Error>> { Ok(std::mem::take(&mut buf)) } +#[cfg(not(target_arch = "wasm32"))] +fn check_srs_hash(logrows: u32, srs_path: Option<PathBuf>) -> Result<String, Box<dyn Error>> { + let path = get_srs_path(logrows, srs_path); + let hash = sha256::digest(&std::fs::read(path.clone())?); + info!("SRS hash: {}", hash); + + let predefined_hash = match crate::srs_sha::PUBLIC_SRS_SHA256_HASHES.get(&logrows) { + Some(h) => h, + None => return Err(format!("SRS (k={}) hash not found in public set", logrows).into()), + }; + + if hash != predefined_hash.to_string() { + // delete file + warn!("removing SRS file at {}", path.display()); + std::fs::remove_file(path)?; + return Err( + "SRS hash does not match the expected hash. Remote SRS may have been tampered with." 
diff --git a/src/logger.rs b/src/logger.rs
--- a/src/logger.rs
+++ b/src/logger.rs
 /// initializes the logger
 pub fn init_logger() {
-    let start = Instant::now();
     let mut builder = Builder::new();
     builder.format(move |buf, record| {
         writeln!(
             buf,
-            "{} [{}s, {}] - {}",
+            "{} [{}, {}] - {}",
             prefix_token(&record.level()),
-            start.elapsed().as_secs(),
+            // pretty print UTC time
+            chrono::Utc::now()
+                .format("%Y-%m-%d %H:%M:%S")
+                .to_string()
+                .bright_magenta(),
             record.metadata().target(),
             level_text_color(&record.level(), &format!("{}", record.args()))
                 .replace('\n', &format!("\n{} ", " | ".white().bold()))
diff --git a/src/pfsys/mod.rs b/src/pfsys/mod.rs
index 8f015e50d..88c33cf65 100644
--- a/src/pfsys/mod.rs
+++ b/src/pfsys/mod.rs
@@ -234,6 +234,8 @@ where
     pub instances: Vec<Vec<F>>,
     /// the proof
     pub proof: Vec<u8>,
+    /// hex encoded proof
+    pub hex_proof: Option<String>,
     /// transcript type
     pub transcript_type: TranscriptType,
     /// the split proof
@@ -261,7 +263,7 @@ where
             .collect::<Vec<_>>();
         dict.set_item("instances", field_elems).unwrap();
         let hex_proof = hex::encode(&self.proof);
-        dict.set_item("proof", hex_proof).unwrap();
+        dict.set_item("proof", format!("0x{}", hex_proof)).unwrap();
         dict.set_item("transcript_type", self.transcript_type)
             .unwrap();
         dict.to_object(py)
@@ -281,6 +283,7 @@ where
         protocol: Option<Protocol<C>>,
         instances: Vec<Vec<F>>,
         proof: Vec<u8>,
+        hex_proof: Option<String>,
         transcript_type: TranscriptType,
         split: Option<ProofSplitCommit>,
         pretty_public_inputs: Option<PrettyElements>,
@@ -289,6 +292,7 @@ where
             protocol,
             instances,
             proof,
+            hex_proof,
             transcript_type,
             split,
             pretty_public_inputs,
@@ -523,8 +527,17 @@ where
         &mut transcript,
     )?;
     let proof = transcript.finalize();
-
-    let checkable_pf = Snark::new(protocol, instances, proof, transcript_type, split, None);
+    let hex_proof = format!("0x{}", hex::encode(&proof));
+
+    let checkable_pf = Snark::new(
+        protocol,
+        instances,
+        proof,
+        Some(hex_proof),
+        transcript_type,
+        split,
+        None,
+    );
 
     // sanity check that the generated proof is valid
     if check_mode == CheckMode::SAFE {
@@ -894,6 +907,7 @@ mod tests {
         instances: vec![vec![Fr::from(1)], vec![Fr::from(2)]],
         transcript_type: TranscriptType::EVM,
         protocol: None,
+        hex_proof: None,
         split: None,
         pretty_public_inputs: None,
         timestamp: None,
diff --git a/src/python.rs b/src/python.rs
index 059a496ed..61beedcd2 100644
--- a/src/python.rs
+++ b/src/python.rs
@@ -1015,7 +1015,8 @@
 fn print_proof_hex(proof_path: PathBuf) -> Result<String, PyErr> {
     let proof = Snark::load::<KZGCommitmentScheme<Bn256>>(&proof_path)
         .map_err(|_| PyIOError::new_err("Failed to load proof"))?;
-    Ok(hex::encode(proof.proof))
+    let hex_str = hex::encode(proof.proof);
+    Ok(format!("0x{}", hex_str))
 }
 
 // Python Module
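
Note the pattern repeated across pfsys and python.rs above: proofs are now serialized as 0x-prefixed hex, the convention EVM tooling generally expects. A round-trip sketch with the `hex` crate (the bytes here are illustrative):

```rust
fn main() {
    let proof: Vec<u8> = vec![0xde, 0xad, 0xbe, 0xef];

    // encode with the EVM-style prefix, as the new code paths do
    let hex_proof = format!("0x{}", hex::encode(&proof));
    assert_eq!(hex_proof, "0xdeadbeef");

    // consumers must strip the prefix before decoding
    let decoded = hex::decode(hex_proof.trim_start_matches("0x")).unwrap();
    assert_eq!(decoded, proof);
}
```
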
diff --git a/src/srs_sha.rs b/src/srs_sha.rs
new file mode 100644
index 000000000..ae7daa44d
--- /dev/null
+++ b/src/srs_sha.rs
@@ -0,0 +1,112 @@
+use lazy_static::lazy_static;
+use std::collections::HashMap;
+
+lazy_static! {
+    /// SRS SHA256 hashes
+    pub static ref PUBLIC_SRS_SHA256_HASHES: HashMap<u32, &'static str> = HashMap::from_iter([
+        (
+            1,
+            "cafb2aa72c200ddc4e28aacabb8066e829207e2484b8d17059a566232f8a297b",
+        ),
+        (
+            2,
+            "8194ec51da5d332d2e17283ade34920644774452c2fadf33742e8c739e275d8e",
+        ),
+        (
+            3,
+            "0729e815bce2ac4dfad7819982c6479c3b22c32b71f64dca05e8fdd90e8535ef",
+        ),
+        (
+            4,
+            "2c0785da20217fcafd3b12cc363a95eb2529037cc8a9bddf8fb15025cbc8cdc9",
+        ),
+        (
+            5,
+            "5b950e3b76e7a9923d69f6d6585ce6b5f9458e5ec57a71c9de5005d32d544692",
+        ),
+        (
+            6,
+            "85030b2924111fc60acaf4fb8a7bad89531fbe0271aeab0c21e545f71eee273d",
+        ),
+        (
+            7,
+            "e65f95150519fe01c2bedf8f832f5249822ef84c9c017307419e10374ff9eeb1",
+        ),
+        (
+            8,
+            "446092fd1d6030e5bb2f2a8368267d5ed0fbdb6a766f6c5e4a4841827ad3106f",
+        ),
+        (
+            9,
+            "493d088951882ad81af11e08c791a38a37c0ffff14578cf2c7fb9b7bca654d8b",
+        ),
+        (
+            10,
+            "9705d450e5dfd06adb673705f7bc34418ec86339203198beceb2ae7f1ffefedb",
+        ),
+        (
+            11,
+            "257fa566ed9bc0767d3e63e92b5e966829fa3347d320a32055dc31ee7d33f8a4",
+        ),
+        (
+            12,
+            "28b151069f41abc121baa6d2eaa8f9e4c4d8326ddbefee2bd9c0776b80ac6fad",
+        ),
+        (
+            13,
+            "d5d94bb25bdc024f649213593027d861042ee807cafd94b49b54f1663f8f267d",
+        ),
+        (
+            14,
+            "c09129f064c08ecb07ea3689a2247dcc177de6837e7d2f5f946e30453abbccef",
+        ),
+        (
+            15,
+            "90807800a1c3b248a452e1732c45ee5099f38b737356f5542c0584ec9c3ebb45",
+        ),
+        (
+            16,
+            "2a1a494630e71bc026dd5c0eab4c1b9a5dbc656228c1f0d48f5dbd3909b161d3",
+        ),
+        (
+            17,
+            "41509f380362a8d14401c5ae92073154922fe23e45459ce6f696f58607655db7",
+        ),
+        (
+            18,
+            "d0148475717a2ba269784a178cb0ab617bc77f16c58d4a3cbdfe785b591c7034",
+        ),
+        (
+            19,
+            "d1a1655b4366a766d1578beb257849a92bf91cb1358c1a2c37ab180c5d3a204d",
+        ),
+        (
+            20,
+            "54ef75911da76d7a6b7ea341998aaf66cb06c679c53e0a88a4fe070dd3add963",
+        ),
+        (
+            21,
+            "486e044cf98704e07f41137d2b89698dc03d1fbf34d13b60902fea19a6013b4b",
+        ),
+        (
+            22,
+            "1ee9b4396db3e4e2516ac5016626ab6ba967f091d5d23afbdb7df122a0bb9d0c",
+        ),
+        (
+            23,
+            "748e48b9b6d06f9c82d26bf551d0af43ee2e801e4be56d7ccb20312e267fd1d6",
+        ),
+        (
+            24,
+            "f94fa4afa2f5147680f907d4dd96a8826206c26bd3328cd379feaed614b234de",
+        ),
+        (
+            25,
+            "dec49a69893fbcd66cd06296b2d936a6aceb431c130b2e52675fe4274b504f57",
+        ),
+        (
+            26,
+            "b198a51d48b88181508d8e4ea9dea39db285e4585663b29b7e4ded0c22a94875",
+        ),
+    ]);
+}
diff --git a/src/tensor/ops.rs b/src/tensor/ops.rs
index cb84e246d..f77a99a4c 100644
--- a/src/tensor/ops.rs
+++ b/src/tensor/ops.rs
@@ -2262,6 +2262,7 @@ pub fn deconv<
 /// * `padding` - Tuple of padding values in x and y directions.
 /// * `stride` - Tuple of stride values in x and y directions.
 /// * `pool_dims` - Tuple of pooling window size in x and y directions.
+/// * `normalize` - Flag to normalize the output by the number of elements in the pooling window.
 /// # Examples
 /// ```
 /// use ezkl::tensor::Tensor;
@@ -2277,6 +2278,11 @@
 /// let pooled = sumpool(&x, [(0, 0); 2], (1, 1), (2, 2), false).unwrap().0;
 /// let expected: Tensor<i128> = Tensor::<i128>::new(Some(&[11, 8, 8, 10]), &[1, 1, 2, 2]).unwrap();
 /// assert_eq!(pooled, expected);
+///
+/// // This time with normalization
+/// let pooled = sumpool(&x, [(0, 0); 2], (1, 1), (2, 2), true).unwrap().0;
+/// let expected: Tensor<i128> = Tensor::<i128>::new(Some(&[3, 2, 2, 3]), &[1, 1, 2, 2]).unwrap();
+/// assert_eq!(pooled, expected);
 /// ```
 pub fn sumpool(
     image: &Tensor<T>,
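
The expected values in the new normalized doc-test follow from dividing each 2x2 window sum by the window area (4): 11 -> 3, 8 -> 2, 10 -> 3. That implies round-half-up integer division; this rounding mode is inferred from the expected outputs, not stated anywhere in the diff. A sketch of the arithmetic under that assumption:

```rust
// Round-half-up integer division, consistent with [11, 8, 8, 10] -> [3, 2, 2, 3]
// for a 2x2 (area 4) pooling window. Whether ezkl rounds exactly this way is
// an inference, not confirmed by this diff.
fn div_round(n: i128, d: i128) -> i128 {
    (2 * n + d) / (2 * d)
}

fn main() {
    let sums = [11, 8, 8, 10];
    let normalized: Vec<i128> = sums.iter().map(|&s| div_round(s, 4)).collect();
    assert_eq!(normalized, vec![3, 2, 2, 3]);
}
```
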
diff --git a/src/wasm.rs b/src/wasm.rs
index 6534a3cad..b58c1583b 100644
--- a/src/wasm.rs
+++ b/src/wasm.rs
@@ -388,7 +388,8 @@
 pub fn printProofHex(proof: wasm_bindgen::Clamped<Vec<u8>>) -> Result<String, JsError> {
     let proof: crate::pfsys::Snark<Fr, G1Affine> = serde_json::from_slice(&proof[..])
         .map_err(|e| JsError::new(&format!("Failed to deserialize proof: {}", e)))?;
-    Ok(hex::encode(proof.proof))
+    let hex_str = hex::encode(proof.proof);
+    Ok(format!("0x{}", hex_str))
 }
 
 // VALIDATION FUNCTIONS
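
Finally, circling back to the logger change earlier in this diff: the elapsed-seconds counter is replaced with a wall-clock UTC timestamp built from the newly added `chrono` dependency. The formatting call in isolation, minus the `colored` styling:

```rust
fn main() {
    // same format string the new logger uses
    let ts = chrono::Utc::now().format("%Y-%m-%d %H:%M:%S").to_string();
    println!("[{}]", ts); // e.g. [2023-10-17 12:34:56]
    assert_eq!(ts.len(), 19); // fixed width: date, space, time
}
```
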