From 87a03b0e769093bb8ac88c62f65c2a8f3b0d8c37 Mon Sep 17 00:00:00 2001 From: Alexander Camuto Date: Thu, 5 Oct 2023 17:56:54 +0100 Subject: [PATCH 1/6] refactor: replace small mult with `div lookup` --- src/graph/utilities.rs | 25 ++++++++++++++++++++++++- 1 file changed, 24 insertions(+), 1 deletion(-) diff --git a/src/graph/utilities.rs b/src/graph/utilities.rs index 882d75fd7..e50e5717b 100644 --- a/src/graph/utilities.rs +++ b/src/graph/utilities.rs @@ -709,7 +709,30 @@ pub fn new_op_from_onnx( } "Add" => SupportedOp::Linear(PolyOp::Add), "Sub" => SupportedOp::Linear(PolyOp::Sub), - "Mul" => SupportedOp::Linear(PolyOp::Mult), + "Mul" => { + let mut op = SupportedOp::Linear(PolyOp::Mult); + + let const_idx = inputs + .iter() + .enumerate() + .filter(|(_, n)| n.is_constant()) + .map(|(i, _)| i) + .collect::>(); + + assert_eq!(const_idx.len(), 1); + let const_idx = const_idx[0]; + + if let Some(c) = inputs[const_idx].opkind().get_mutable_constant() { + if c.raw_values.len() == 1 && c.raw_values[const_idx] < 1. { + inputs[const_idx].decrement_const(); + deleted_indices.push(const_idx); + op = SupportedOp::Nonlinear(LookupOp::Div { + denom: crate::circuit::utils::F32(c.raw_values[0]), + }) + } + } + op + } "Iff" => SupportedOp::Linear(PolyOp::Iff), "Less" => { if inputs.len() == 2 { From dd81e115c04e84bfc2f962dbe78d06057ce60ac8 Mon Sep 17 00:00:00 2001 From: Alexander Camuto Date: Thu, 5 Oct 2023 18:00:15 +0100 Subject: [PATCH 2/6] Update utilities.rs --- src/graph/utilities.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/graph/utilities.rs b/src/graph/utilities.rs index e50e5717b..94e7ccc84 100644 --- a/src/graph/utilities.rs +++ b/src/graph/utilities.rs @@ -727,7 +727,8 @@ pub fn new_op_from_onnx( inputs[const_idx].decrement_const(); deleted_indices.push(const_idx); op = SupportedOp::Nonlinear(LookupOp::Div { - denom: crate::circuit::utils::F32(c.raw_values[0]), + // we invert the constant for division + denom: crate::circuit::utils::F32(1. / c.raw_values[0]), }) } } From 62f103d5fbd0713b50c1597f6e4ac1316a9fb1d1 Mon Sep 17 00:00:00 2001 From: Alexander Camuto Date: Thu, 5 Oct 2023 18:06:03 +0100 Subject: [PATCH 3/6] Update utilities.rs --- src/graph/utilities.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/graph/utilities.rs b/src/graph/utilities.rs index 94e7ccc84..cc00db2dd 100644 --- a/src/graph/utilities.rs +++ b/src/graph/utilities.rs @@ -723,7 +723,7 @@ pub fn new_op_from_onnx( let const_idx = const_idx[0]; if let Some(c) = inputs[const_idx].opkind().get_mutable_constant() { - if c.raw_values.len() == 1 && c.raw_values[const_idx] < 1. { + if c.raw_values.len() == 1 && c.raw_values[0] < 1. { inputs[const_idx].decrement_const(); deleted_indices.push(const_idx); op = SupportedOp::Nonlinear(LookupOp::Div { From 0b951758dbd59f4e2805e50b0edb5e29b5a63f76 Mon Sep 17 00:00:00 2001 From: Alexander Camuto Date: Thu, 5 Oct 2023 18:09:37 +0100 Subject: [PATCH 4/6] Update utilities.rs --- src/graph/utilities.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/graph/utilities.rs b/src/graph/utilities.rs index cc00db2dd..898e8e86f 100644 --- a/src/graph/utilities.rs +++ b/src/graph/utilities.rs @@ -719,7 +719,7 @@ pub fn new_op_from_onnx( .map(|(i, _)| i) .collect::>(); - assert_eq!(const_idx.len(), 1); + assert!(const_idx.len() <= 1); let const_idx = const_idx[0]; if let Some(c) = inputs[const_idx].opkind().get_mutable_constant() { From 9af922d5d4cec1c107e755d9f8108087f60029c2 Mon Sep 17 00:00:00 2001 From: Alexander Camuto Date: Thu, 5 Oct 2023 18:13:28 +0100 Subject: [PATCH 5/6] Update utilities.rs --- src/graph/utilities.rs | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/src/graph/utilities.rs b/src/graph/utilities.rs index 898e8e86f..5376b1215 100644 --- a/src/graph/utilities.rs +++ b/src/graph/utilities.rs @@ -720,16 +720,18 @@ pub fn new_op_from_onnx( .collect::>(); assert!(const_idx.len() <= 1); - let const_idx = const_idx[0]; - if let Some(c) = inputs[const_idx].opkind().get_mutable_constant() { - if c.raw_values.len() == 1 && c.raw_values[0] < 1. { - inputs[const_idx].decrement_const(); - deleted_indices.push(const_idx); - op = SupportedOp::Nonlinear(LookupOp::Div { - // we invert the constant for division - denom: crate::circuit::utils::F32(1. / c.raw_values[0]), - }) + if const_idx.len() == 1 { + let const_idx = const_idx[0]; + if let Some(c) = inputs[const_idx].opkind().get_mutable_constant() { + if c.raw_values.len() == 1 && c.raw_values[0] < 1. { + inputs[const_idx].decrement_const(); + deleted_indices.push(const_idx); + op = SupportedOp::Nonlinear(LookupOp::Div { + // we invert the constant for division + denom: crate::circuit::utils::F32(1. / c.raw_values[0]), + }) + } } } op From d67fec23839a000c0f29d2b0ada5223b7456c16b Mon Sep 17 00:00:00 2001 From: Alexander Camuto <45801863+alexander-camuto@users.noreply.github.com> Date: Fri, 6 Oct 2023 00:18:20 +0100 Subject: [PATCH 6/6] Update variance.ipynb --- examples/notebooks/variance.ipynb | 17 +++-------------- 1 file changed, 3 insertions(+), 14 deletions(-) diff --git a/examples/notebooks/variance.ipynb b/examples/notebooks/variance.ipynb index b5535e722..8ae51e526 100644 --- a/examples/notebooks/variance.ipynb +++ b/examples/notebooks/variance.ipynb @@ -198,19 +198,6 @@ "**EZKL Workflow**" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "rNw0C9QL6W88" - }, - "outputs": [], - "source": [ - "# setup kzg params\n", - "params_path = os.path.join('kzg.params')\n", - "res = ezkl.gen_srs(params_path, 7)" - ] - }, { "cell_type": "code", "execution_count": null, @@ -244,12 +231,14 @@ "onnx_filename = os.path.join('lol.onnx')\n", "compiled_filename = os.path.join('lol.compiled')\n", "settings_filename = os.path.join('settings.json')\n", + "srs_path = os.path.join('kzg.params')\n", "\n", "\n", "\n", "ezkl.gen_settings(onnx_filename, settings_filename)\n", "await ezkl.calibrate_settings(\n", " input_filename, onnx_filename, settings_filename, \"resources\")\n", + "res = ezkl.get_srs(srs_path, settings_filename)\n", "ezkl.compile_circuit(onnx_filename, compiled_filename, settings_filename)\n", "\n", "# show the settings.json\n", @@ -468,7 +457,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.15" + "version": "3.9.13" } }, "nbformat": 4,