diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index ebc1ba5df..bc96a7b81 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -293,12 +293,7 @@ jobs: prove-and-verify-aggr-evm-tests: runs-on: large-self-hosted - needs: - [ - build, - library-tests, - python-tests, - ] + needs: [build, library-tests, python-tests] steps: - uses: actions/checkout@v4 - uses: actions-rs/toolchain@v1 @@ -460,6 +455,8 @@ jobs: # # now dump the contents of the file into a file called kaggle.json # echo $KAGGLE_API_KEY > /home/ubuntu/.kaggle/kaggle.json # chmod 600 /home/ubuntu/.kaggle/kaggle.json + - name: SVM + run: source .env/bin/activate; cargo nextest run py_tests::tests::run_notebook_::tests_18_expects - name: LightGBM run: source .env/bin/activate; cargo nextest run py_tests::tests::run_notebook_::tests_17_expects - name: XGBoost diff --git a/Cargo.lock b/Cargo.lock index dfcf0345f..4d7768fac 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5056,7 +5056,7 @@ dependencies = [ [[package]] name = "tract-core" version = "0.20.19-pre" -source = "git+https://github.com/sonos/tract/?rev=561614519e6cb49eea4d88dcee3b880f127813cb#561614519e6cb49eea4d88dcee3b880f127813cb" +source = "git+https://github.com/sonos/tract/?rev=2ea76c09678f092d00713ebbe6fdb046c0a9ad0f#2ea76c09678f092d00713ebbe6fdb046c0a9ad0f" dependencies = [ "anyhow", "bit-set", @@ -5079,7 +5079,7 @@ dependencies = [ [[package]] name = "tract-data" version = "0.20.19-pre" -source = "git+https://github.com/sonos/tract/?rev=561614519e6cb49eea4d88dcee3b880f127813cb#561614519e6cb49eea4d88dcee3b880f127813cb" +source = "git+https://github.com/sonos/tract/?rev=2ea76c09678f092d00713ebbe6fdb046c0a9ad0f#2ea76c09678f092d00713ebbe6fdb046c0a9ad0f" dependencies = [ "anyhow", "half 2.2.1", @@ -5098,7 +5098,7 @@ dependencies = [ [[package]] name = "tract-hir" version = "0.20.19-pre" -source = "git+https://github.com/sonos/tract/?rev=561614519e6cb49eea4d88dcee3b880f127813cb#561614519e6cb49eea4d88dcee3b880f127813cb" +source = "git+https://github.com/sonos/tract/?rev=2ea76c09678f092d00713ebbe6fdb046c0a9ad0f#2ea76c09678f092d00713ebbe6fdb046c0a9ad0f" dependencies = [ "derive-new", "log", @@ -5108,7 +5108,7 @@ dependencies = [ [[package]] name = "tract-linalg" version = "0.20.19-pre" -source = "git+https://github.com/sonos/tract/?rev=561614519e6cb49eea4d88dcee3b880f127813cb#561614519e6cb49eea4d88dcee3b880f127813cb" +source = "git+https://github.com/sonos/tract/?rev=2ea76c09678f092d00713ebbe6fdb046c0a9ad0f#2ea76c09678f092d00713ebbe6fdb046c0a9ad0f" dependencies = [ "cc", "derive-new", @@ -5132,7 +5132,7 @@ dependencies = [ [[package]] name = "tract-nnef" version = "0.20.19-pre" -source = "git+https://github.com/sonos/tract/?rev=561614519e6cb49eea4d88dcee3b880f127813cb#561614519e6cb49eea4d88dcee3b880f127813cb" +source = "git+https://github.com/sonos/tract/?rev=2ea76c09678f092d00713ebbe6fdb046c0a9ad0f#2ea76c09678f092d00713ebbe6fdb046c0a9ad0f" dependencies = [ "byteorder", "flate2", @@ -5146,7 +5146,7 @@ dependencies = [ [[package]] name = "tract-onnx" version = "0.20.19-pre" -source = "git+https://github.com/sonos/tract/?rev=561614519e6cb49eea4d88dcee3b880f127813cb#561614519e6cb49eea4d88dcee3b880f127813cb" +source = "git+https://github.com/sonos/tract/?rev=2ea76c09678f092d00713ebbe6fdb046c0a9ad0f#2ea76c09678f092d00713ebbe6fdb046c0a9ad0f" dependencies = [ "bytes", "derive-new", @@ -5163,7 +5163,7 @@ dependencies = [ [[package]] name = "tract-onnx-opl" version = "0.20.19-pre" -source = "git+https://github.com/sonos/tract/?rev=561614519e6cb49eea4d88dcee3b880f127813cb#561614519e6cb49eea4d88dcee3b880f127813cb" +source = "git+https://github.com/sonos/tract/?rev=2ea76c09678f092d00713ebbe6fdb046c0a9ad0f#2ea76c09678f092d00713ebbe6fdb046c0a9ad0f" dependencies = [ "getrandom", "log", diff --git a/Cargo.toml b/Cargo.toml index 238b38d61..1dbf19115 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -51,7 +51,7 @@ tokio = { version = "1.26.0", default_features = false, features = ["macros", " pyo3 = { version = "0.18.3", features = ["extension-module", "abi3-py37", "macros"], default_features = false, optional = true } pyo3-asyncio = { version = "0.18.0", features = ["attributes", "tokio-runtime"], default_features = false, optional = true } pyo3-log = { version = "0.8.1", default_features = false, optional = true } -tract-onnx = { git = "https://github.com/sonos/tract/", rev= "561614519e6cb49eea4d88dcee3b880f127813cb", default_features = false, optional = true } +tract-onnx = { git = "https://github.com/sonos/tract/", rev= "2ea76c09678f092d00713ebbe6fdb046c0a9ad0f", default_features = false, optional = true } tabled = { version = "0.12.0", optional = true } diff --git a/examples/notebooks/decision_tree.ipynb b/examples/notebooks/decision_tree.ipynb index a602d3d91..d3c4ca8cc 100644 --- a/examples/notebooks/decision_tree.ipynb +++ b/examples/notebooks/decision_tree.ipynb @@ -28,7 +28,7 @@ "![image-2.png](attachment:image-2.png)\n", "\n", "\n", - "This notebook showcases how to do that using the `sk2torch` python package ! " + "This notebook showcases how to do that using the `hummingbird-ml` python package ! " ] }, { @@ -46,7 +46,7 @@ " import sys\n", " subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"ezkl\"])\n", " subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"onnx\"])\n", - " subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"sk2torch\"])\n", + " subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"hummingbird-ml\"])\n", "\n", "# rely on local installation of ezkl if the notebook is not in colab\n", "except:\n", @@ -61,7 +61,7 @@ "from sklearn.datasets import load_iris\n", "from sklearn.model_selection import train_test_split\n", "from sklearn.tree import DecisionTreeClassifier as De\n", - "import sk2torch\n", + "from hummingbird.ml import convert\n", "import torch\n", "import ezkl\n", "import os\n", @@ -75,7 +75,7 @@ "clr = De()\n", "clr.fit(X_train, y_train)\n", "\n", - "circuit = sk2torch.wrap(clr)\n", + "circuit = convert(clr, \"torch\", X_test[:1]).model\n", "\n", "\n", "\n" @@ -282,7 +282,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.13" + "version": "3.9.15" } }, "nbformat": 4, diff --git a/examples/notebooks/svm.ipynb b/examples/notebooks/svm.ipynb new file mode 100644 index 000000000..a5a2ce6a4 --- /dev/null +++ b/examples/notebooks/svm.ipynb @@ -0,0 +1,303 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "id": "cf69bb3f-94e6-4dba-92cd-ce08df117d67", + "metadata": {}, + "source": [ + "## Support Vector Machines\n", + "\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "95613ee9", + "metadata": {}, + "outputs": [], + "source": [ + "# check if notebook is in colab\n", + "try:\n", + " # install ezkl\n", + " import google.colab\n", + " import subprocess\n", + " import sys\n", + " subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"ezkl\"])\n", + " subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"onnx\"])\n", + " subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"sk2torch\"])\n", + "\n", + "# rely on local installation of ezkl if the notebook is not in colab\n", + "except:\n", + " pass\n", + "\n", + "\n", + "# here we create and (potentially train a model)\n", + "\n", + "# make sure you have the dependencies required here already installed\n", + "import json\n", + "import numpy as np\n", + "from sklearn.svm import SVC\n", + "import sk2torch\n", + "import torch\n", + "import ezkl\n", + "import os\n", + "\n", + "\n", + "# Create a dataset of two Gaussians. There will be some overlap\n", + "# between the two classes, which adds some uncertainty to the model.\n", + "xs = np.concatenate(\n", + " [\n", + " np.random.random(size=(256, 2)) + [1, 0],\n", + " np.random.random(size=(256, 2)) + [-1, 0],\n", + " ],\n", + " axis=0,\n", + ")\n", + "ys = np.array([False] * 256 + [True] * 256)\n", + "\n", + "# Train an SVM on the data and wrap it in PyTorch.\n", + "sk_model = SVC(probability=True)\n", + "sk_model.fit(xs, ys)\n", + "model = sk2torch.wrap(sk_model)\n", + "\n", + "\n", + "\n", + "\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b37637c4", + "metadata": {}, + "outputs": [], + "source": [ + "model_path = os.path.join('network.onnx')\n", + "compiled_model_path = os.path.join('network.compiled')\n", + "pk_path = os.path.join('test.pk')\n", + "vk_path = os.path.join('test.vk')\n", + "settings_path = os.path.join('settings.json')\n", + "srs_path = os.path.join('kzg.srs')\n", + "witness_path = os.path.join('witness.json')\n", + "data_path = os.path.join('input.json')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7f0ca328", + "metadata": {}, + "outputs": [], + "source": [ + "import matplotlib.pyplot as plt\n", + "# Create a coordinate grid to compute a vector field on.\n", + "spaced = np.linspace(-2, 2, num=25)\n", + "grid_xs = torch.tensor([[x, y] for x in spaced for y in spaced], requires_grad=True)\n", + "\n", + "\n", + "# Compute the gradients of the SVM output.\n", + "outputs = model.predict_proba(grid_xs)[:, 1]\n", + "(input_grads,) = torch.autograd.grad(outputs.sum(), (grid_xs,))\n", + "\n", + "\n", + "# Create a quiver plot of the vector field.\n", + "plt.quiver(\n", + " grid_xs[:, 0].detach().numpy(),\n", + " grid_xs[:, 1].detach().numpy(),\n", + " input_grads[:, 0].detach().numpy(),\n", + " input_grads[:, 1].detach().numpy(),\n", + ")\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "82db373a", + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "\n", + "# export to onnx format\n", + "# !!!!!!!!!!!!!!!!! This will flash a warning but it is fine !!!!!!!!!!!!!!!!!!!!!\n", + "\n", + "# Input to the model\n", + "shape = xs.shape[1:]\n", + "x = grid_xs[0:1]\n", + "torch_out = model.predict(x)\n", + "# Export the model\n", + "torch.onnx.export(model, # model being run\n", + " # model input (or a tuple for multiple inputs)\n", + " x,\n", + " # where to save the model (can be a file or file-like object)\n", + " \"network.onnx\",\n", + " export_params=True, # store the trained parameter weights inside the model file\n", + " opset_version=10, # the ONNX version to export the model to\n", + " do_constant_folding=True, # whether to execute constant folding for optimization\n", + " input_names=['input'], # the model's input names\n", + " output_names=['output'], # the model's output names\n", + " dynamic_axes={'input': {0: 'batch_size'}, # variable length axes\n", + " 'output': {0: 'batch_size'}})\n", + "\n", + "d = ((x).detach().numpy()).reshape([-1]).tolist()\n", + "\n", + "data = dict(input_shapes=[shape],\n", + " input_data=[d],\n", + " output_data=[o.reshape([-1]).tolist() for o in torch_out])\n", + "\n", + "# Serialize data into file:\n", + "json.dump(data, open(\"input.json\", 'w'))\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d5e374a2", + "metadata": {}, + "outputs": [], + "source": [ + "!RUST_LOG=trace\n", + "# TODO: Dictionary outputs\n", + "res = ezkl.gen_settings(model_path, settings_path)\n", + "assert res == True\n", + "\n", + "res = await ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n", + "assert res == True" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3aa4f090", + "metadata": {}, + "outputs": [], + "source": [ + "res = ezkl.compile_model(model_path, compiled_model_path, settings_path)\n", + "assert res == True" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8b74dcee", + "metadata": {}, + "outputs": [], + "source": [ + "# srs path\n", + "res = ezkl.get_srs(srs_path, settings_path)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "18c8b7c7", + "metadata": {}, + "outputs": [], + "source": [ + "# now generate the witness file \n", + "\n", + "res = ezkl.gen_witness(data_path, compiled_model_path, witness_path, settings_path = settings_path)\n", + "assert os.path.isfile(witness_path)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b1c561a8", + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "# HERE WE SETUP THE CIRCUIT PARAMS\n", + "# WE GOT KEYS\n", + "# WE GOT CIRCUIT PARAMETERS\n", + "# EVERYTHING ANYONE HAS EVER NEEDED FOR ZK\n", + "\n", + "\n", + "\n", + "res = ezkl.setup(\n", + " compiled_model_path,\n", + " vk_path,\n", + " pk_path,\n", + " srs_path,\n", + " settings_path,\n", + " )\n", + "\n", + "assert res == True\n", + "assert os.path.isfile(vk_path)\n", + "assert os.path.isfile(pk_path)\n", + "assert os.path.isfile(settings_path)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c384cbc8", + "metadata": {}, + "outputs": [], + "source": [ + "# GENERATE A PROOF\n", + "\n", + "\n", + "proof_path = os.path.join('test.pf')\n", + "\n", + "res = ezkl.prove(\n", + " witness_path,\n", + " compiled_model_path,\n", + " pk_path,\n", + " proof_path,\n", + " srs_path,\n", + " \"evm\",\n", + " \"single\",\n", + " settings_path,\n", + " )\n", + "\n", + "print(res)\n", + "assert os.path.isfile(proof_path)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "76f00d41", + "metadata": {}, + "outputs": [], + "source": [ + "# VERIFY IT\n", + "\n", + "res = ezkl.verify(\n", + " proof_path,\n", + " settings_path,\n", + " vk_path,\n", + " srs_path,\n", + " )\n", + "\n", + "assert res == True\n", + "print(\"verified\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.15" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/onnx/hummingbird_decision_tree/gen.py b/examples/onnx/hummingbird_decision_tree/gen.py new file mode 100644 index 000000000..68b07387c --- /dev/null +++ b/examples/onnx/hummingbird_decision_tree/gen.py @@ -0,0 +1,50 @@ +# Train a model. +import json +import onnxruntime as rt +from skl2onnx import to_onnx +import numpy as np +from sklearn.datasets import load_iris +from sklearn.model_selection import train_test_split +from sklearn.tree import DecisionTreeClassifier as De +from hummingbird.ml import convert +import torch + +iris = load_iris() +X, y = iris.data, iris.target +X = X.astype(np.float32) +X_train, X_test, y_train, y_test = train_test_split(X, y) +clr = De() +clr.fit(X_train, y_train) + +torch_model = convert(clr, "pytorch").model + + +# Convert into ONNX format. +# export to onnx format + +# Input to the model +shape = X_train.shape[1:] +x = torch.rand(1, *shape, requires_grad=True) +torch_out = torch_model(x) +# Export the model +torch.onnx.export(torch_model, # model being run + # model input (or a tuple for multiple inputs) + x, + # where to save the model (can be a file or file-like object) + "network.onnx", + export_params=True, # store the trained parameter weights inside the model file + opset_version=10, # the ONNX version to export the model to + do_constant_folding=True, # whether to execute constant folding for optimization + input_names=['input'], # the model's input names + output_names=['output'], # the model's output names + dynamic_axes={'input': {0: 'batch_size'}, # variable length axes + 'output': {0: 'batch_size'}}) + +d = ((x).detach().numpy()).reshape([-1]).tolist() + +data = dict(input_shapes=[shape], + input_data=[d], + output_data=[((o).detach().numpy()).reshape([-1]).tolist() for o in torch_out]) + +# Serialize data into file: +json.dump(data, open("input.json", 'w')) diff --git a/examples/onnx/hummingbird_decision_tree/input.json b/examples/onnx/hummingbird_decision_tree/input.json new file mode 100644 index 000000000..884786fd5 --- /dev/null +++ b/examples/onnx/hummingbird_decision_tree/input.json @@ -0,0 +1 @@ +{"input_shapes": [[4]], "input_data": [[0.9813985824584961, 0.793540358543396, 0.548916757106781, 0.6483156681060791]], "output_data": [[0], [1.0, 0.0, 0.0]]} \ No newline at end of file diff --git a/examples/onnx/hummingbird_decision_tree/network.onnx b/examples/onnx/hummingbird_decision_tree/network.onnx new file mode 100644 index 000000000..1a1873e6a Binary files /dev/null and b/examples/onnx/hummingbird_decision_tree/network.onnx differ diff --git a/src/circuit/ops/hybrid.rs b/src/circuit/ops/hybrid.rs index 427cf9744..cc3727b02 100644 --- a/src/circuit/ops/hybrid.rs +++ b/src/circuit/ops/hybrid.rs @@ -49,6 +49,10 @@ pub enum HybridOp { dim: usize, k: usize, }, + OneHot { + dim: usize, + num_classes: usize, + }, GatherElements { dim: usize, constant_idx: Option>, @@ -129,6 +133,10 @@ impl Op for HybridOp { (res.clone(), inter_equals) } } + HybridOp::OneHot { dim, num_classes } => { + let res = tensor::ops::one_hot(&x, *num_classes, *dim)?; + (res.clone(), vec![]) + } HybridOp::TopK { dim, k } => { let res = tensor::ops::topk_axes(&x, *k, *dim)?; @@ -216,25 +224,36 @@ impl Op for HybridOp { } fn as_string(&self) -> String { - let name = match self { - HybridOp::Abs => "ABS", - HybridOp::ReduceMax { .. } => "REDUCEMAX", - HybridOp::ReduceArgMax { .. } => "REDUCEARGMAX", - HybridOp::MaxPool2d { .. } => "MAXPOOL2D", - HybridOp::ReduceMin { .. } => "REDUCEMIN", - HybridOp::ReduceArgMin { .. } => "REDUCEARGMIN", - HybridOp::Softmax { .. } => "SOFTMAX", - HybridOp::RangeCheck(..) => "RANGECHECK", - HybridOp::Greater { .. } => "GREATER", - HybridOp::GreaterEqual { .. } => "GREATEREQUAL", - HybridOp::Less { .. } => "LESS", - HybridOp::LessEqual { .. } => "LESSEQUAL", - HybridOp::Equals => "EQUALS", - HybridOp::Gather { .. } => "GATHER", - HybridOp::TopK { .. } => "TOPK", - HybridOp::GatherElements { .. } => "GATHERELEMENTS", - }; - name.into() + match self { + HybridOp::Abs => "ABS".into(), + HybridOp::ReduceMax { axes } => format!("REDUCEMAX (axes={:?})", axes), + HybridOp::ReduceArgMax { dim } => format!("REDUCEARGMAX (dim={})", dim), + HybridOp::MaxPool2d { + padding, + stride, + pool_dims, + } => format!( + "MAXPOOL2D (padding={:?}, stride={:?}, pool_dims={:?})", + padding, stride, pool_dims + ), + HybridOp::ReduceMin { axes } => format!("REDUCEMIN (axes={:?})", axes), + HybridOp::ReduceArgMin { dim } => format!("REDUCEARGMIN (dim={})", dim), + HybridOp::Softmax { scale, axes } => { + format!("SOFTMAX (scale={}, axes={:?})", scale, axes) + } + HybridOp::RangeCheck(p) => format!("RANGECHECK (tol={:?})", p), + HybridOp::Greater => "GREATER".into(), + HybridOp::GreaterEqual => "GREATEREQUAL".into(), + HybridOp::Less => "LESS".into(), + HybridOp::LessEqual => "LESSEQUAL".into(), + HybridOp::Equals => "EQUALS".into(), + HybridOp::Gather { dim, .. } => format!("GATHER (dim={})", dim), + HybridOp::TopK { k, dim } => format!("TOPK (k={}, dim={})", k, dim), + HybridOp::GatherElements { dim, .. } => format!("GATHERELEMENTS (dim={})", dim), + HybridOp::OneHot { dim, num_classes } => { + format!("ONEHOT (dim={}, num_classes={})", dim, num_classes) + } + } } fn layout( @@ -303,6 +322,9 @@ impl Op for HybridOp { HybridOp::TopK { dim, k } => { layouts::topk_axes(config, region, values[..].try_into()?, *k, *dim)? } + HybridOp::OneHot { dim, num_classes } => { + layouts::one_hot_axis(config, region, values[..].try_into()?, *num_classes, *dim)? + } })) } @@ -320,6 +342,7 @@ impl Op for HybridOp { | HybridOp::Less { .. } | HybridOp::LessEqual { .. } | HybridOp::ReduceArgMax { .. } + | HybridOp::OneHot { .. } | HybridOp::ReduceArgMin { .. } => 0, HybridOp::Softmax { .. } => 2 * in_scales[0], _ => in_scales[0], @@ -359,6 +382,7 @@ impl Op for HybridOp { | HybridOp::Less { .. } | HybridOp::Equals | HybridOp::Gather { .. } + | HybridOp::OneHot { .. } | HybridOp::TopK { .. } | HybridOp::GatherElements { .. } => { vec![LookupOp::GreaterThan { diff --git a/src/circuit/ops/layouts.rs b/src/circuit/ops/layouts.rs index 3e7a671b0..84c715019 100644 --- a/src/circuit/ops/layouts.rs +++ b/src/circuit/ops/layouts.rs @@ -620,6 +620,128 @@ fn select( Ok(assigned_output) } +fn one_hot( + config: &BaseConfig, + region: &mut RegionCtx, + values: &[ValTensor; 1], + num_classes: usize, +) -> Result, Box> { + // assert values is flat + assert_eq!(values[0].dims().len(), 1); + // assert its a single elelemnt + assert_eq!(values[0].len(), 1); + let input = values[0].clone(); + let is_assigned = !input.any_unknowns(); + + let output: ValTensor = if is_assigned { + let int_evals = input.get_int_evals()?; + let res = tensor::ops::one_hot(&int_evals, num_classes, 1)?; + res.iter() + .map(|x| Value::known(i128_to_felt(*x))) + .collect::>() + } else { + Tensor::new( + Some(&vec![Value::::unknown(); num_classes]), + &[num_classes], + )? + } + .into(); + + let assigned_input = region.assign(&config.inputs[0], &input)?; + + // now assert all elems are 0 or 1 + let assigned_output = region.assign(&config.inputs[1], &output)?; + for i in 0..assigned_output.len() { + let (x, y) = config.output.cartesian_coord(region.offset() + i); + let selector = config.selectors.get(&(BaseOp::IsBoolean, x)); + region.enable(selector, y)?; + } + region.increment(std::cmp::max(assigned_output.len(), assigned_input.len())); + + let sum = sum(config, region, &[assigned_output.clone()])?; + // assert sum is 1 + let mut unit = Tensor::from(vec![F::from(1)].into_iter()); + unit.set_visibility(crate::graph::Visibility::Public); + let unit = region.assign(&config.inputs[1], &unit.into())?; + region.assign(&config.output, &sum)?; + + let (x, y) = config.output.cartesian_coord(region.offset()); + let selector = config.selectors.get(&(BaseOp::Identity, x)); + region.enable(selector, y)?; + + region.increment(1); + + let gathered = gather( + config, + region, + &[assigned_output.clone(), assigned_input.clone()], + 0, + )?; + + region.assign(&config.inputs[1], &unit)?; + region.assign(&config.output, &gathered)?; + + let (x, y) = config.output.cartesian_coord(region.offset()); + let selector = config.selectors.get(&(BaseOp::Identity, x)); + region.enable(selector, y)?; + + region.increment(assigned_input.len()); + + Ok(assigned_output) +} + +/// One hot accumulated layout +pub fn one_hot_axis( + config: &BaseConfig, + region: &mut RegionCtx, + values: &[ValTensor; 1], + num_classes: usize, + dim: usize, +) -> Result, Box> { + let input = values[0].clone(); + let input_inner = input.get_inner_tensor()?; + + let mut output_dims = values[0].dims().to_vec(); + output_dims.insert(dim, num_classes); + + let op_tensors = input_inner.enum_map(|_: usize, inp| { + let tensor = Tensor::new(Some(&[inp.clone()]), &[1]).unwrap(); + let res = one_hot(config, region, &[tensor.into()], num_classes).map_err(|e| { + error!("{}", e); + halo2_proofs::plonk::Error::Synthesis + })?; + + Ok::<_, halo2_proofs::plonk::Error>(res) + })?; + + // Allocate memory for the output tensor + let cartesian_coord = output_dims + .iter() + .map(|x| 0..*x) + .multi_cartesian_product() + .collect::>(); + + let mut output = Tensor::>::new(None, &output_dims)?; + + output = output.enum_map(|i, _| { + let coord = cartesian_coord[i].clone(); + let mut op_idx = coord.clone(); + let coord_at_dims = vec![coord[dim]]; + op_idx.remove(dim); + + let op_tensor = op_tensors.get(&op_idx).get_inner_tensor().map_err(|e| { + error!("{}", e); + halo2_proofs::plonk::Error::Synthesis + })?; + + let one_hot_val = op_tensor.get(&coord_at_dims).clone(); + + Ok::<_, halo2_proofs::plonk::Error>(one_hot_val) + })?; + + Ok(output.into()) +} + /// Gather accumulated layout pub fn gather( config: &BaseConfig, @@ -641,6 +763,12 @@ pub fn gather( // Calculate the output tensor size let input_dims = input.dims(); let mut output_size = input_dims.to_vec(); + if index.dims().is_empty() { + output_size.remove(dim); + input.reshape(&output_size)?; + return Ok(input); + } + output_size[dim] = index.dims()[0]; // Allocate memory for the output tensor @@ -671,7 +799,7 @@ pub fn gather( let output = output?; let mut output: ValTensor = Tensor::new(Some(&output), &[output.len()])?.into(); - // Reshape the output tensor + output.reshape(&output_size)?; Ok(output) @@ -887,10 +1015,10 @@ fn axes_wise_op( prod_dims.push(*c..*c + 1); } } - res.set( - coord, - op(config, region, &[a.get_slice(&prod_dims)?])?.get_inner_tensor()?[0].clone(), - ); + let values = a.get_slice(&prod_dims)?; + let op = op(config, region, &[values])?; + + res.set(coord, op.get_inner_tensor()?[0].clone()); } Ok(res.into()) @@ -1863,7 +1991,9 @@ pub fn concat( ) -> Result, Box> { let collected_inner: Result>, _> = values.iter().map(|e| e.get_inner_tensor()).collect(); - Ok(tensor::ops::concat(&collected_inner?, *axis)?.into()) + let collected_inner = collected_inner?; + + Ok(tensor::ops::concat(&collected_inner, *axis)?.into()) } /// Identity constraint. Usually used to constrain an instance column to an advice so the returned cells / values can be operated upon. diff --git a/src/circuit/ops/poly.rs b/src/circuit/ops/poly.rs index 37db2c99d..580d3ad63 100644 --- a/src/circuit/ops/poly.rs +++ b/src/circuit/ops/poly.rs @@ -214,12 +214,7 @@ impl Deserialize< tensor::ops::prod_axes(&inputs[0], axes) } PolyOp::GlobalSumPool => unreachable!(), - PolyOp::Concat { axis } => { - if inputs.len() < 2 { - return Err(TensorError::DimMismatch("concat inputs".to_string())); - } - tensor::ops::concat(&inputs, *axis) - } + PolyOp::Concat { axis } => tensor::ops::concat(&inputs, *axis), PolyOp::Slice { axis, start, end } => { if 1 != inputs.len() { return Err(TensorError::DimMismatch("slice inputs".to_string())); @@ -332,12 +327,7 @@ impl Deserialize< layouts::pack(config, region, values[..].try_into()?, *base, *scale)? } PolyOp::GlobalSumPool => unreachable!(), - PolyOp::Concat { axis } => { - if values.len() < 2 { - return Err(Box::new(TensorError::DimError)); - } - layouts::concat(values[..].try_into()?, axis)? - } + PolyOp::Concat { axis } => layouts::concat(values[..].try_into()?, axis)?, PolyOp::Slice { axis, start, end } => { layouts::slice(config, region, values[..].try_into()?, axis, start, end)? } diff --git a/src/graph/utilities.rs b/src/graph/utilities.rs index fa47e1c9a..cca6d7ff9 100644 --- a/src/graph/utilities.rs +++ b/src/graph/utilities.rs @@ -22,7 +22,7 @@ use std::sync::Arc; use tract_onnx::prelude::{DatumType, Node as OnnxNode, TypedFact, TypedOp}; #[cfg(not(target_arch = "wasm32"))] use tract_onnx::tract_core::ops::{ - array::{Gather, GatherElements, Slice, Topk}, + array::{Gather, GatherElements, OneHot, Slice, Topk}, change_axes::AxisOp, cnn::DeconvUnary, einsum::EinSum, @@ -97,14 +97,13 @@ fn extract_tensor_value( input: Arc, ) -> Result, Box> { let dt = input.datum_type(); - let mut dims = input.shape().to_vec(); - if dims.is_empty() { - dims.push(1) - } else if dims.iter().product::() == 1 { - dims = vec![1]; - }; + let dims = input.shape().to_vec(); let mut const_value: Tensor; + if dims.is_empty() { + const_value = Tensor::::new(None, &dims)?; + return Ok(const_value); + } match dt { DatumType::F32 => { @@ -245,6 +244,16 @@ pub fn new_op_from_onnx( SupportedOp::Hybrid(crate::circuit::ops::hybrid::HybridOp::TopK { dim: axis, k }) } + "Onehot" => { + let op = load_op::(node.op(), idx, node.op().name().to_string())?; + let axis = op.axis; + let num_classes = op.dim; + + SupportedOp::Hybrid(crate::circuit::ops::hybrid::HybridOp::OneHot { + dim: axis, + num_classes, + }) + } "GatherElements" => { if inputs.len() != 2 { return Err(Box::new(GraphError::InvalidDims( @@ -333,7 +342,12 @@ pub fn new_op_from_onnx( quantize_tensor(raw_value.clone(), constant_scale, param_visibility)?; let mut c = crate::circuit::ops::Constant::new(quantized_value, raw_value); - c.num_uses += node.outputs.len(); + + c.num_uses += node + .outputs + .iter() + .map(|outlet| outlet.successors.len()) + .sum::(); // Create a constant op SupportedOp::Constant(c) } diff --git a/src/tensor/mod.rs b/src/tensor/mod.rs index 9eee0de5b..ffe40e777 100644 --- a/src/tensor/mod.rs +++ b/src/tensor/mod.rs @@ -448,7 +448,11 @@ impl<'data, T: Clone + TensorType + std::marker::Send + std::marker::Sync> impl Tensor { /// Sets (copies) the tensor values to the provided ones. pub fn new(values: Option<&[T]>, dims: &[usize]) -> Result { - let total_dims: usize = dims.iter().product(); + let total_dims: usize = if !dims.is_empty() { + dims.iter().product() + } else { + 0 + }; match values { Some(v) => { if total_dims != v.len() { @@ -492,7 +496,11 @@ impl Tensor { /// Returns the number of elements in the tensor. pub fn len(&self) -> usize { - self.dims().iter().product::() + if !self.dims().is_empty() && (self.dims() != &[0]) { + self.dims().iter().product::() + } else { + 0 + } } /// Checks if the number of elements in tensor is 0. pub fn is_empty(&self) -> bool { @@ -710,8 +718,19 @@ impl Tensor { /// assert_eq!(a.dims(), &[9, 3]); /// ``` pub fn reshape(&mut self, new_dims: &[usize]) { - assert!(self.len() == new_dims.iter().product::()); - self.dims = Vec::from(new_dims); + // in onnx parlance this corresponds to converting a tensor to a single element + if new_dims.is_empty() { + assert!(self.len() == 1 || self.len() == 0); + self.flatten(); + } else { + let product = if new_dims != &[0] { + new_dims.iter().product::() + } else { + 0 + }; + assert!(self.len() == product); + self.dims = Vec::from(new_dims); + } } /// Move axis of the tensor @@ -901,7 +920,9 @@ impl Tensor { /// assert_eq!(a.dims(), &[27]); /// ``` pub fn flatten(&mut self) { - self.dims = Vec::from([self.dims.iter().product::()]); + if !self.dims().is_empty() && (self.dims() != &[0]) { + self.dims = Vec::from([self.dims.iter().product::()]); + } } /// Maps a function to tensors diff --git a/src/tensor/ops.rs b/src/tensor/ops.rs index 31033c6c8..8c6ed502a 100644 --- a/src/tensor/ops.rs +++ b/src/tensor/ops.rs @@ -1237,6 +1237,14 @@ pub fn gather( // Calculate the output tensor size let mut output_size = input.dims().to_vec(); + // Reshape the output tensor + if index.len() == 0 { + output_size.remove(dim); + let mut input = input.clone(); + input.reshape(&output_size); + return Ok(input); + } + output_size[dim] = index.dims()[0]; // Allocate memory for the output tensor @@ -1259,7 +1267,6 @@ pub fn gather( Ok(input.get(&new_coord)) })?; - // Reshape the output tensor output.reshape(&output_size); Ok(output) @@ -1945,6 +1952,52 @@ pub fn intercalate_values( Ok(output) } +/// One hot encodes a tensor along a given axis. +/// ``` +/// use ezkl::tensor::Tensor; +/// use ezkl::tensor::ops::one_hot; +/// let tensor = Tensor::::new(Some(&[1, 2, 3, 4]), &[2, 2]).unwrap(); +/// let result = one_hot(&tensor, 5, 2).unwrap(); +/// let expected = Tensor::::new(Some(&[0, 1, 0, 0, 0, +/// 0, 0, 1, 0, 0, +/// 0, 0, 0, 1, 0, +/// 0, 0, 0, 0, 1]), &[2, 2, 5]).unwrap(); +/// assert_eq!(result, expected); +/// ``` +pub fn one_hot( + tensor: &Tensor, + num_classes: usize, + axis: usize, +) -> Result, TensorError> { + let mut output_dims = tensor.dims().to_vec(); + output_dims.insert(axis, num_classes); + + let mut output: Tensor = Tensor::new(None, &output_dims)?; + + let cartesian_coord = output + .dims() + .iter() + .map(|d| (0..*d)) + .multi_cartesian_product() + .collect::>(); + + output.iter_mut().enumerate().for_each(|(i, o)| { + let coord = &cartesian_coord[i]; + let coord_axis = coord[axis]; + + let mut coord_without_axis = coord.clone(); + coord_without_axis.remove(axis); + + if coord_axis == tensor.get(&coord_without_axis) as usize { + *o = 1; + } else { + *o = 0; + } + }); + + Ok(output) +} + /// Performs a 2D deconvolution on the given input tensor. /// # Examples /// ``` @@ -2522,6 +2575,10 @@ where /// Returns a TensorError if the tensors in `inputs` have incompatible dimensions for concatenation along the specified `axis`. pub fn concat(inputs: &[Tensor], axis: usize) -> Result, TensorError> { + if inputs.len() == 1 { + return Ok(inputs[0].clone()); + } + // Calculate the output tensor size let mut output_size = inputs[0].dims().to_vec(); output_size[axis] = inputs.iter().map(|x| x.dims()[axis]).sum(); diff --git a/src/tensor/val.rs b/src/tensor/val.rs index a5f9a4a8a..ad580bf8e 100644 --- a/src/tensor/val.rs +++ b/src/tensor/val.rs @@ -207,9 +207,13 @@ impl From>> f impl ValTensor { /// Allocate a new [ValTensor::Instance] from the ConstraintSystem with the given tensor `dims`, optionally enabling `equality`. - pub fn new_instance(cs: &mut ConstraintSystem, dims: Vec, scale: u32) -> Self { + pub fn new_instance(cs: &mut ConstraintSystem, mut dims: Vec, scale: u32) -> Self { let col = cs.instance_column(); cs.enable_equality(col); + // force there to be at least one dimension + if dims.is_empty() || dims == vec![0] { + dims = vec![1]; + } ValTensor::Instance { inner: col, dims, @@ -589,7 +593,11 @@ impl ValTensor { pub fn len(&self) -> usize { match self { ValTensor::Value { dims, .. } | ValTensor::Instance { dims, .. } => { - dims.iter().product::() + if !dims.is_empty() && (dims != &[0]) { + dims.iter().product::() + } else { + 0 + } } } } diff --git a/tests/integration_tests.rs b/tests/integration_tests.rs index d4609d0b1..708480f66 100644 --- a/tests/integration_tests.rs +++ b/tests/integration_tests.rs @@ -169,7 +169,7 @@ mod native_tests { "1l_prelu", ]; - const TESTS: [&str; 51] = [ + const TESTS: [&str; 52] = [ "1l_mlp", "1l_slice", "1l_concat", @@ -224,6 +224,7 @@ mod native_tests { "1l_topk", "xgboost", "lightgbm", + "hummingbird_decision_tree", ]; const TESTS_AGGR: [&str; 21] = [ @@ -379,7 +380,7 @@ mod native_tests { } }); - seq!(N in 0..=50 { + seq!(N in 0..=51 { #(#[test_case(TESTS[N])])* fn model_serialization_(test: &str) { diff --git a/tests/py_integration_tests.rs b/tests/py_integration_tests.rs index fe6a73498..52f7c928c 100644 --- a/tests/py_integration_tests.rs +++ b/tests/py_integration_tests.rs @@ -110,7 +110,7 @@ mod py_tests { } } - const TESTS: [&str; 18] = [ + const TESTS: [&str; 19] = [ "mnist_gan.ipynb", // "mnist_vae.ipynb", "keras_simple_demo.ipynb", @@ -130,6 +130,7 @@ mod py_tests { "gradient_boosted_trees.ipynb", "xgboost.ipynb", "lightgbm.ipynb", + "svm.ipynb", ]; macro_rules! test_func { @@ -142,7 +143,7 @@ mod py_tests { use super::*; - seq!(N in 0..=17 { + seq!(N in 0..=18 { #(#[test_case(TESTS[N])])* fn run_notebook_(test: &str) { crate::py_tests::init_binary(); diff --git a/tests/python/binding_tests.py b/tests/python/binding_tests.py index 9e5489fe5..7a1d5771b 100644 --- a/tests/python/binding_tests.py +++ b/tests/python/binding_tests.py @@ -69,6 +69,7 @@ def test_field_serialization(): roundtrip_input = ezkl.vecu64_to_float(felt, scale) assert input == roundtrip_input + def test_buffer_to_felts(): """ Test buffer_to_felt @@ -81,7 +82,8 @@ def test_buffer_to_felts(): buffer = bytearray("a sample string!"+"high", 'utf-8') felts = ezkl.buffer_to_felts(buffer) ref_felt_2 = "0x0000000000000000000000000000000000000000000000000000000068676968" - assert felts == [ref_felt_1,ref_felt_2] + assert felts == [ref_felt_1, ref_felt_2] + def test_table_1l_average(): """ @@ -96,17 +98,17 @@ def test_table_1l_average(): expected_table = ( " \n" - "┌─────┬─────────┬───────────┬──────────┬──────────────┬──────────────────┐\n" - "│ idx │ opkind │ out_scale │ inputs │ out_dims │ required_lookups │\n" - "├─────┼─────────┼───────────┼──────────┼──────────────┼──────────────────┤\n" - "│ 0 │ Input │ 7 │ │ [1, 3, 2, 2] │ [] │\n" - "├─────┼─────────┼───────────┼──────────┼──────────────┼──────────────────┤\n" - "│ 1 │ PAD │ 7 │ [(0, 0)] │ [1, 3, 4, 4] │ [] │\n" - "├─────┼─────────┼───────────┼──────────┼──────────────┼──────────────────┤\n" - "│ 2 │ SUMPOOL │ 7 │ [(1, 0)] │ [1, 3, 3, 3] │ [] │\n" - "├─────┼─────────┼───────────┼──────────┼──────────────┼──────────────────┤\n" - "│ 3 │ RESHAPE │ 7 │ [(2, 0)] │ [3, 3, 3] │ [] │\n" - "└─────┴─────────┴───────────┴──────────┴──────────────┴──────────────────┘" + "┌─────┬────────────────┬───────────┬──────────┬──────────────┬──────────────────┐\n" + "│ idx │ opkind │ out_scale │ inputs │ out_dims │ required_lookups │\n" + "├─────┼────────────────┼───────────┼──────────┼──────────────┼──────────────────┤\n" + "│ 0 │ Input │ 7 │ │ [1, 3, 2, 2] │ [] │\n" + "├─────┼────────────────┼───────────┼──────────┼──────────────┼──────────────────┤\n" + "│ 1 │ PAD │ 7 │ [(0, 0)] │ [1, 3, 4, 4] │ [] │\n" + "├─────┼────────────────┼───────────┼──────────┼──────────────┼──────────────────┤\n" + "│ 2 │ SUMPOOL │ 7 │ [(1, 0)] │ [1, 3, 3, 3] │ [] │\n" + "├─────┼────────────────┼───────────┼──────────┼──────────────┼──────────────────┤\n" + "│ 4 │ GATHER (dim=0) │ 7 │ [(2, 0)] │ [3, 3, 3] │ [\"GREATER_THAN\"] │\n" + "└─────┴────────────────┴───────────┴──────────┴──────────────┴──────────────────┘" ) assert ezkl.table(path) == expected_table