diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index c78102306..2e7665d55 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -584,3 +584,5 @@ jobs: run: source .env/bin/activate; cargo nextest run py_tests::tests::run_notebook_::tests_6_expects - name: Gan tutorial run: source .env/bin/activate; cargo nextest run py_tests::tests::run_notebook_::tests_0_expects + - name: GCN tutorial + run: source .env/bin/activate; cargo nextest run py_tests::tests::run_notebook_::tests_21_expects diff --git a/examples/notebooks/gcn.ipynb b/examples/notebooks/gcn.ipynb new file mode 100644 index 000000000..b9c1e1fa4 --- /dev/null +++ b/examples/notebooks/gcn.ipynb @@ -0,0 +1,622 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "5fe9feb6-2b35-414a-be9d-771eabdbb0dc", + "metadata": { + "id": "5fe9feb6-2b35-414a-be9d-771eabdbb0dc" + }, + "source": [ + "## EZKL GCN Notebook" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "nGcl_1sltpRq", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "nGcl_1sltpRq", + "outputId": "642693ac-970f-4ad9-80f5-e58c69f04ee9" + }, + "outputs": [], + "source": [ + "!pip install torch-scatter torch-sparse torch-geometric" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1005303a-cd48-4766-9c43-2116f94ed381", + "metadata": { + "id": "1005303a-cd48-4766-9c43-2116f94ed381" + }, + "outputs": [], + "source": [ + "import numpy as np\n", + "\n", + "import torch\n", + "from torch import nn\n", + "import torch.nn.functional as F\n", + "\n", + "# check if notebook is in colab\n", + "try:\n", + " # install ezkl\n", + " import google.colab\n", + " import subprocess\n", + " import sys\n", + " for e in [\"ezkl\", \"onnx\", \"torch\", \"torchvision\", \"torch-scatter\", \"torch-sparse\", \"torch-geometric\"]:\n", + " subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", e])\n", + "\n", + "# rely on local installation of ezkl if the notebook is not in colab\n", + "except:\n", + " pass" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "89e5732e-a97b-445e-9174-69689e37e72c", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "89e5732e-a97b-445e-9174-69689e37e72c", + "outputId": "24049b0a-439b-4327-a829-4b4045490f0f" + }, + "outputs": [], + "source": [ + "import torch\n", + "from torch_geometric.data import Data\n", + "\n", + "edge_index = torch.tensor([[2, 1, 3],\n", + " [0, 0, 2]], dtype=torch.long)\n", + "x = torch.tensor([[1], [1], [1]], dtype=torch.float)\n", + "\n", + "data = Data(x=x, edge_index=edge_index)\n", + "data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "73b34e81-63cb-44b0-9f95-f8490e844676", + "metadata": { + "id": "73b34e81-63cb-44b0-9f95-f8490e844676" + }, + "outputs": [], + "source": [ + "import torch\n", + "import math\n", + "from torch_geometric.nn import MessagePassing\n", + "from torch.nn.modules.module import Module\n", + "\n", + "def glorot(tensor):\n", + " if tensor is not None:\n", + " stdv = math.sqrt(6.0 / (tensor.size(-2) + tensor.size(-1)))\n", + " tensor.data.uniform_(-stdv, stdv)\n", + "\n", + "\n", + "def zeros(tensor):\n", + " if tensor is not None:\n", + " tensor.data.fill_(0)\n", + "\n", + "class GCNConv(Module):\n", + " def __init__(self, in_channels, out_channels):\n", + " super(GCNConv, self).__init__() # \"Add\" aggregation.\n", + " self.lin = torch.nn.Linear(in_channels, out_channels)\n", + "\n", + " self.reset_parameters()\n", + "\n", + " def reset_parameters(self):\n", + " glorot(self.lin.weight)\n", + " zeros(self.lin.bias)\n", + "\n", + " def forward(self, x, adj_t, deg):\n", + " x = self.lin(x)\n", + " adj_t = self.normalize_adj(adj_t, deg)\n", + " x = adj_t @ x\n", + "\n", + " return x\n", + "\n", + " def normalize_adj(self, adj_t, deg):\n", + " deg.masked_fill_(deg == 0, 1.)\n", + " deg_inv_sqrt = deg.pow_(-0.5)\n", + " deg_inv_sqrt.masked_fill_(deg_inv_sqrt == 1, 0.)\n", + " adj_t = adj_t * deg_inv_sqrt.view(-1, 1) # N, 1\n", + " adj_t = adj_t * deg_inv_sqrt.view(1, -1) # 1, N\n", + "\n", + " return adj_t" + ] + }, + { + "cell_type": "markdown", + "id": "ae70bc34-def7-40fd-9558-2500c6f29323", + "metadata": { + "id": "ae70bc34-def7-40fd-9558-2500c6f29323" + }, + "source": [ + "## Train Pipeline" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7ca117a1-7473-42a6-be95-dc314eb3e251", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "7ca117a1-7473-42a6-be95-dc314eb3e251", + "outputId": "edacee52-8a88-4c02-9a71-fd094e89c7b9" + }, + "outputs": [], + "source": [ + "import os\n", + "import os.path as osp\n", + "import torch\n", + "import torch.nn.functional as F\n", + "from torch_geometric.datasets import Planetoid\n", + "import torch_geometric.transforms as T\n", + "\n", + "path = osp.join(os.getcwd(), 'data', 'Cora')\n", + "dataset = Planetoid(path, 'Cora')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "807f4d87-6acc-4cbb-80e4-8eb09feb994c", + "metadata": { + "id": "807f4d87-6acc-4cbb-80e4-8eb09feb994c" + }, + "outputs": [], + "source": [ + "import time\n", + "\n", + "from torch import tensor\n", + "from torch.optim import Adam\n", + "\n", + "# define num feat to use for training here\n", + "num_feat = 10\n", + "\n", + "def run(dataset, model, runs, epochs, lr, weight_decay, early_stopping):\n", + "\n", + " val_losses, accs, durations = [], [], []\n", + " for _ in range(runs):\n", + " data = dataset[0]\n", + " data = data.to(device)\n", + "\n", + " model.to(device).reset_parameters()\n", + " optimizer = Adam(model.parameters(), lr=lr, weight_decay=weight_decay)\n", + "\n", + " if torch.cuda.is_available():\n", + " torch.cuda.synchronize()\n", + "\n", + " t_start = time.perf_counter()\n", + "\n", + " best_val_loss = float('inf')\n", + " test_acc = 0\n", + " val_loss_history = []\n", + "\n", + " for epoch in range(1, epochs + 1):\n", + " train(model, optimizer, data)\n", + " eval_info = evaluate(model, data)\n", + " eval_info['epoch'] = epoch\n", + "\n", + " if eval_info['val_loss'] < best_val_loss:\n", + " best_val_loss = eval_info['val_loss']\n", + " test_acc = eval_info['test_acc']\n", + "\n", + " val_loss_history.append(eval_info['val_loss'])\n", + " if early_stopping > 0 and epoch > epochs // 2:\n", + " tmp = tensor(val_loss_history[-(early_stopping + 1):-1])\n", + " if eval_info['val_loss'] > tmp.mean().item():\n", + " break\n", + "\n", + " if torch.cuda.is_available():\n", + " torch.cuda.synchronize()\n", + "\n", + " t_end = time.perf_counter()\n", + "\n", + " val_losses.append(best_val_loss)\n", + " accs.append(test_acc)\n", + " durations.append(t_end - t_start)\n", + "\n", + " loss, acc, duration = tensor(val_losses), tensor(accs), tensor(durations)\n", + "\n", + " print('Val Loss: {:.4f}, Test Accuracy: {:.3f} ± {:.3f}, Duration: {:.3f}'.\n", + " format(loss.mean().item(),\n", + " acc.mean().item(),\n", + " acc.std().item(),\n", + " duration.mean().item()))\n", + "\n", + "\n", + "def train(model, optimizer, data):\n", + " model.train()\n", + " optimizer.zero_grad()\n", + "\n", + " E = data.edge_index.size(1)\n", + " N = data.x.size(0)\n", + " x = data.x[:, :num_feat]\n", + " adj_t = torch.sparse_coo_tensor(data.edge_index, torch.ones(E), size=(N, N)).to_dense().T\n", + " deg = torch.sum(adj_t, dim=1)\n", + " out = model(x, adj_t, deg)\n", + " loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])\n", + " loss.backward()\n", + " optimizer.step()\n", + "\n", + "\n", + "def evaluate(model, data):\n", + " model.eval()\n", + "\n", + " with torch.no_grad():\n", + "\n", + " E = data.edge_index.size(1)\n", + " N = data.x.size(0)\n", + " x = data.x[:, :num_feat]\n", + " adj_t = torch.sparse_coo_tensor(data.edge_index, torch.ones(E), size=(N, N)).to_dense().T\n", + " deg = torch.sum(adj_t, dim=1)\n", + " logits = model(x, adj_t, deg)\n", + "\n", + " outs = {}\n", + " for key in ['train', 'val', 'test']:\n", + " mask = data['{}_mask'.format(key)]\n", + " loss = F.nll_loss(logits[mask], data.y[mask]).item()\n", + " pred = logits[mask].max(1)[1]\n", + " acc = pred.eq(data.y[mask]).sum().item() / mask.sum().item()\n", + "\n", + " outs['{}_loss'.format(key)] = loss\n", + " outs['{}_acc'.format(key)] = acc\n", + "\n", + " return outs" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "28b3605e-e6fd-45ff-ae4b-607065f4849c", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "28b3605e-e6fd-45ff-ae4b-607065f4849c", + "outputId": "b3ea504c-b57c-46d4-b382-aa54c9a4786f" + }, + "outputs": [], + "source": [ + "runs = 1\n", + "epochs = 200\n", + "lr = 0.01\n", + "weight_decay = 0.0005\n", + "early_stopping = 10\n", + "hidden = 16\n", + "dropout = 0.5\n", + "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", + "\n", + "\n", + "class Net(torch.nn.Module):\n", + " def __init__(self, dataset, num_feat):\n", + " super(Net, self).__init__()\n", + " # self.conv1 = GCNConv(dataset.num_features, hidden)\n", + " self.conv1 = GCNConv(num_feat, hidden)\n", + " self.conv2 = GCNConv(hidden, dataset.num_classes)\n", + "\n", + "\n", + " def reset_parameters(self):\n", + " self.conv1.reset_parameters()\n", + " self.conv2.reset_parameters()\n", + "\n", + " def forward(self, x, adj_t, deg):\n", + " x = F.relu(self.conv1(x, adj_t, deg))\n", + " x = F.dropout(x, p=dropout, training=self.training)\n", + " x = self.conv2(x, adj_t, deg)\n", + " return F.log_softmax(x, dim=1)\n", + "\n", + "model = Net(dataset, num_feat)\n", + "run(dataset, model, runs, epochs, lr, weight_decay, early_stopping)" + ] + }, + { + "cell_type": "markdown", + "id": "4cc3ffed-74c2-48e3-86bc-a5e51f44a09a", + "metadata": { + "id": "4cc3ffed-74c2-48e3-86bc-a5e51f44a09a" + }, + "source": [ + "## EZKL Setup" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "92585631-ff39-402e-bd1c-aaebdce682e5", + "metadata": { + "id": "92585631-ff39-402e-bd1c-aaebdce682e5" + }, + "outputs": [], + "source": [ + "import os\n", + "import ezkl\n", + "\n", + "\n", + "model_path = os.path.join('network.onnx')\n", + "compiled_model_path = os.path.join('network.compiled')\n", + "pk_path = os.path.join('test.pk')\n", + "vk_path = os.path.join('test.vk')\n", + "settings_path = os.path.join('settings.json')\n", + "srs_path = os.path.join('kzg.srs')\n", + "witness_path = os.path.join('witness.json')\n", + "data_path = os.path.join('input.json')\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d80d3169-cc70-4aee-bdc2-df9a435b3116", + "metadata": { + "id": "d80d3169-cc70-4aee-bdc2-df9a435b3116" + }, + "outputs": [], + "source": [ + "# Downsample graph\n", + "num_node = 5\n", + "\n", + "# filter edges so that we only bring adjacencies among downsampled node\n", + "filter_row = []\n", + "filter_col = []\n", + "row, col = dataset[0].edge_index\n", + "for idx in range(row.size(0)):\n", + " if row[idx] < num_node and col[idx] < num_node:\n", + " filter_row.append(row[idx])\n", + " filter_col.append(col[idx])\n", + "filter_edge_index = torch.stack([torch.tensor(filter_row), torch.tensor(filter_col)])\n", + "num_edge = len(filter_row)\n", + "\n", + "\n", + "x = dataset[0].x[:num_node, :num_feat]\n", + "edge_index = filter_edge_index\n", + "\n", + "adj_t = torch.sparse_coo_tensor(edge_index, torch.ones(num_edge), size=(num_node, num_node)).to_dense().T\n", + "deg = torch.sum(adj_t, dim=1)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "46367b2f-951d-403b-9346-e689de0bee3f", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "46367b2f-951d-403b-9346-e689de0bee3f", + "outputId": "f063bf1b-e518-4fdb-b8ad-507c521acaa3" + }, + "outputs": [], + "source": [ + "import json\n", + "\n", + "# Flips the neural net into inference mode\n", + "model.eval()\n", + "model.to('cpu')\n", + "\n", + "# No dynamic axis for GNN batch\n", + "torch.onnx.export(model, # model being run\n", + " (x, adj_t, deg), # model input (or a tuple for multiple inputs)\n", + " model_path, # where to save the model (can be a file or file-like object)\n", + " export_params=True, # store the trained parameter weights inside the model file\n", + " opset_version=11, # the ONNX version to export the model to\n", + " do_constant_folding=True, # whether to execute constant folding for optimization\n", + " input_names = ['x', 'edge_index'], # the model's input names\n", + " output_names = ['output']) # the model's output names" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9e6da242-540e-48dc-bc20-d08fcd192af4", + "metadata": { + "id": "9e6da242-540e-48dc-bc20-d08fcd192af4" + }, + "outputs": [], + "source": [ + "torch_out = model(x, adj_t, deg)\n", + "x_shape = x.shape\n", + "adj_t_shape=adj_t.shape\n", + "deg_shape=deg.shape\n", + "\n", + "x = ((x).detach().numpy()).reshape([-1]).tolist()\n", + "adj_t = ((adj_t).detach().numpy()).reshape([-1]).tolist()\n", + "deg = ((deg).detach().numpy()).reshape([-1]).tolist()\n", + "\n", + "data = dict(input_shapes=[x_shape, adj_t_shape, deg_shape],\n", + " input_data=[x, adj_t, deg],\n", + " output_data=[((torch_out).detach().numpy()).reshape([-1]).tolist()])\n", + "json.dump(data, open(data_path, 'w'))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3393a884-7a14-435e-bb9e-4fa4fcbdc76b", + "metadata": { + "id": "3393a884-7a14-435e-bb9e-4fa4fcbdc76b", + "tags": [] + }, + "outputs": [], + "source": [ + "!RUST_LOG=trace\n", + "import ezkl\n", + "\n", + "run_args = ezkl.PyRunArgs()\n", + "run_args.input_scale = 5\n", + "run_args.param_scale = 5\n", + "# TODO: Dictionary outputs\n", + "res = ezkl.gen_settings(model_path, settings_path, py_run_args=run_args)\n", + "assert res == True\n", + "\n", + "res = await ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n", + "assert res == True" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8f86fceb", + "metadata": { + "id": "8f86fceb" + }, + "outputs": [], + "source": [ + "res = ezkl.compile_circuit(model_path, compiled_model_path, settings_path)\n", + "assert res == True" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3b55c925", + "metadata": { + "id": "3b55c925" + }, + "outputs": [], + "source": [ + "# srs path\n", + "res = ezkl.get_srs(srs_path, settings_path)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d6478bab", + "metadata": { + "id": "d6478bab" + }, + "outputs": [], + "source": [ + "# now generate the witness file\n", + "\n", + "res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n", + "assert os.path.isfile(witness_path)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b500c1ba", + "metadata": { + "id": "b500c1ba" + }, + "outputs": [], + "source": [ + "# HERE WE SETUP THE CIRCUIT PARAMS\n", + "# WE GOT KEYS\n", + "# WE GOT CIRCUIT PARAMETERS\n", + "# EVERYTHING ANYONE HAS EVER NEEDED FOR ZK\n", + "\n", + "\n", + "\n", + "res = ezkl.setup(\n", + " compiled_model_path,\n", + " vk_path,\n", + " pk_path,\n", + " srs_path,\n", + " )\n", + "\n", + "assert res == True\n", + "assert os.path.isfile(vk_path)\n", + "assert os.path.isfile(pk_path)\n", + "assert os.path.isfile(settings_path)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ae152a64", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "ae152a64", + "outputId": "599cc9b8-ee85-407e-f0da-b2360634d2a8" + }, + "outputs": [], + "source": [ + "# GENERATE A PROOF\n", + "\n", + "\n", + "proof_path = os.path.join('test.pf')\n", + "\n", + "res = ezkl.prove(\n", + " witness_path,\n", + " compiled_model_path,\n", + " pk_path,\n", + " proof_path,\n", + " srs_path,\n", + " \"single\",\n", + " )\n", + "\n", + "print(res)\n", + "assert os.path.isfile(proof_path)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a2548b00", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "a2548b00", + "outputId": "e2972113-c079-4cb2-bfc5-6f7ad2842195" + }, + "outputs": [], + "source": [ + "# VERIFY IT\n", + "\n", + "res = ezkl.verify(\n", + " proof_path,\n", + " settings_path,\n", + " vk_path,\n", + " srs_path,\n", + " )\n", + "\n", + "assert res == True\n", + "print(\"verified\")" + ] + } + ], + "metadata": { + "colab": { + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3.11.4 ('.env': venv)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.4" + }, + "vscode": { + "interpreter": { + "hash": "af2b032f4d5a009ff33cd3ba5ac25dedfd7d71c9736fbe82aa90983ec2fc3628" + } + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/test_snark_serialization_roundtrip.json b/test_snark_serialization_roundtrip.json new file mode 100644 index 000000000..43af0afc6 --- /dev/null +++ b/test_snark_serialization_roundtrip.json @@ -0,0 +1 @@ +{"protocol":null,"instances":[[[12436184717236109307,3962172157175319849,7381016538464732718,1011752739694698287]],[[6425625360762666998,7924344314350639699,14762033076929465436,2023505479389396574]]],"proof":[1,2,3,4,5,6,7,8],"transcript_type":"EVM"} \ No newline at end of file diff --git a/tests/py_integration_tests.rs b/tests/py_integration_tests.rs index 5d36f29c4..eb1f1c6e3 100644 --- a/tests/py_integration_tests.rs +++ b/tests/py_integration_tests.rs @@ -115,7 +115,7 @@ mod py_tests { } } - const TESTS: [&str; 21] = [ + const TESTS: [&str; 22] = [ "mnist_gan.ipynb", // "mnist_vae.ipynb", "keras_simple_demo.ipynb", @@ -138,6 +138,7 @@ mod py_tests { "svm.ipynb", "simple_demo_public_input_output.ipynb", "simple_demo_public_network_output.ipynb", + "gcn.ipynb", ]; macro_rules! test_func { @@ -150,7 +151,7 @@ mod py_tests { use super::*; - seq!(N in 0..=20 { + seq!(N in 0..=21 { #(#[test_case(TESTS[N])])* fn run_notebook_(test: &str) {