Skip to content

Commit

Permalink
feat: match accum constraints to cuboid (#595)
Browse files Browse the repository at this point in the history
  • Loading branch information
alexander-camuto authored Nov 7, 2023
1 parent ee4e64f commit ae88ffe
Show file tree
Hide file tree
Showing 16 changed files with 782 additions and 204 deletions.
8 changes: 8 additions & 0 deletions .github/workflows/rust.yml
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,14 @@ jobs:
- name: Replace memory definition in nodejs
run: |
sed -i "3s|.*|imports['env'] = {memory: new WebAssembly.Memory({initial:20,maximum:65536,shared:true})}|" tests/wasm/nodejs/ezkl.js
- name: KZG prove and verify tests double inner col
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_double_col
- name: KZG prove and verify tests triple inner col
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_triple_col
- name: KZG prove and verify tests quadruple inner col
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_quadruple_col
- name: KZG prove and verify tests octuple inner col
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_octuple_col
- name: KZG prove and verify tests (kzg outputs)
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_kzg_output
- name: KZG prove and verify tests (public outputs + column overflow)
Expand Down
84 changes: 9 additions & 75 deletions examples/notebooks/generalized_inverse.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"id": "95613ee9",
"metadata": {
"id": "95613ee9"
Expand Down Expand Up @@ -48,7 +48,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": null,
"id": "9LgqGF56Qcdz",
"metadata": {
"id": "9LgqGF56Qcdz"
Expand All @@ -69,7 +69,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": null,
"id": "YRQLvvsXVs9s",
"metadata": {
"id": "YRQLvvsXVs9s"
Expand All @@ -84,7 +84,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": null,
"id": "b37637c4",
"metadata": {
"id": "b37637c4"
Expand All @@ -103,7 +103,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": null,
"id": "82db373a",
"metadata": {
"id": "82db373a"
Expand Down Expand Up @@ -143,7 +143,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": null,
"id": "HOLcdGx4eQ9n",
"metadata": {
"colab": {
Expand All @@ -152,25 +152,14 @@
"id": "HOLcdGx4eQ9n",
"outputId": "cd0a4f10-251e-492e-9f05-d8af0d79c86a"
},
"outputs": [
{
"data": {
"text/plain": [
"tensor(True)"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"outputs": [],
"source": [
"circuit.forward(A,B)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": null,
"id": "d5e374a2",
"metadata": {
"colab": {
Expand All @@ -180,62 +169,7 @@
"id": "d5e374a2",
"outputId": "11ae5963-02d4-4939-9c98-d126071a9ba0"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"constant node with 1 use\n",
"constant node with 1 use\n",
"no verifying key provided for kzgcommit. processed value will be none\n",
"constant node with 1 use\n",
"constant node with 1 use\n",
"no verifying key provided for kzgcommit. processed value will be none\n",
"constant node with 1 use\n",
"constant node with 1 use\n",
"no verifying key provided for kzgcommit. processed value will be none\n",
"constant node with 1 use\n",
"constant node with 1 use\n",
"no verifying key provided for kzgcommit. processed value will be none\n",
"constant node with 1 use\n",
"constant node with 1 use\n",
"no verifying key provided for kzgcommit. processed value will be none\n",
"constant node with 1 use\n",
"constant node with 1 use\n",
"no verifying key provided for kzgcommit. processed value will be none\n",
"constant node with 1 use\n",
"constant node with 1 use\n",
"no verifying key provided for kzgcommit. processed value will be none\n",
"constant node with 1 use\n",
"constant node with 1 use\n",
"no verifying key provided for kzgcommit. processed value will be none\n",
"constant node with 1 use\n",
"constant node with 1 use\n",
"no verifying key provided for kzgcommit. processed value will be none\n",
"constant node with 1 use\n",
"constant node with 1 use\n",
"no verifying key provided for kzgcommit. processed value will be none\n",
"constant node with 1 use\n",
"constant node with 1 use\n",
"no verifying key provided for kzgcommit. processed value will be none\n",
"constant node with 1 use\n",
"constant node with 1 use\n",
"no verifying key provided for kzgcommit. processed value will be none\n",
"constant node with 1 use\n",
"constant node with 1 use\n",
"no verifying key provided for kzgcommit. processed value will be none\n",
"constant node with 1 use\n",
"constant node with 1 use\n",
"no verifying key provided for kzgcommit. processed value will be none\n",
"constant node with 1 use\n",
"constant node with 1 use\n",
"no verifying key provided for kzgcommit. processed value will be none\n",
"constant node with 1 use\n",
"constant node with 1 use\n",
"no verifying key provided for kzgcommit. processed value will be none\n"
]
}
],
"outputs": [],
"source": [
"\n",
"res = ezkl.gen_settings(model_path, settings_path, py_run_args=gip_run_args)\n",
Expand Down
55 changes: 49 additions & 6 deletions src/circuit/ops/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,14 @@ use std::{
#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub enum BaseOp {
Dot,
DotInit,
CumProdInit,
CumProd,
Identity,
Add,
Mult,
Sub,
SumInit,
Sum,
Neg,
Range { tol: i32 },
Expand All @@ -24,25 +27,53 @@ pub enum BaseOp {
/// Matches a [BaseOp] to an operation over inputs
impl BaseOp {
/// forward func
pub fn f<
pub fn nonaccum_f<
T: TensorType + Add<Output = T> + Sub<Output = T> + Mul<Output = T> + Neg<Output = T>,
>(
&self,
inputs: (T, T, T),
inputs: (T, T),
) -> T {
let (a, b, m) = inputs;
let (a, b) = inputs;
match &self {
BaseOp::Dot => a * b + m,
BaseOp::Add => a + b,
BaseOp::Identity => b,
BaseOp::Sum => b + m,
BaseOp::CumProd => b * m,
BaseOp::Neg => -b,
BaseOp::Sub => a - b,
BaseOp::Mult => a * b,
BaseOp::Range { .. } => b,
BaseOp::IsZero => b,
BaseOp::IsBoolean => b,
_ => panic!("nonaccum_f called on accumulating operation"),
}
}

/// forward func
pub fn accum_f<
T: TensorType + Add<Output = T> + Sub<Output = T> + Mul<Output = T> + Neg<Output = T>,
>(
&self,
prev_output: T,
a: Vec<T>,
b: Vec<T>,
) -> T {
match &self {
BaseOp::DotInit => a
.into_iter()
.zip(b.into_iter())
.fold(T::zero().unwrap(), |acc, (a, b)| acc + a * b),
BaseOp::Dot => {
prev_output
+ a.into_iter()
.zip(b.into_iter())
.fold(T::zero().unwrap(), |acc, (a, b)| acc + a * b)
}
BaseOp::CumProdInit => b.into_iter().fold(T::one().unwrap(), |acc, b| acc * b),
BaseOp::CumProd => {
prev_output * b.into_iter().fold(T::one().unwrap(), |acc, b| acc * b)
}
BaseOp::SumInit => b.into_iter().fold(T::zero().unwrap(), |acc, b| acc + b),
BaseOp::Sum => prev_output + b.into_iter().fold(T::zero().unwrap(), |acc, b| acc + b),
_ => panic!("accum_f called on non-accumulating operation"),
}
}

Expand All @@ -51,12 +82,15 @@ impl BaseOp {
match self {
BaseOp::Identity => "IDENTITY",
BaseOp::Dot => "DOT",
BaseOp::DotInit => "DOTINIT",
BaseOp::CumProdInit => "CUMPRODINIT",
BaseOp::CumProd => "CUMPROD",
BaseOp::Add => "ADD",
BaseOp::Neg => "NEG",
BaseOp::Sub => "SUB",
BaseOp::Mult => "MULT",
BaseOp::Sum => "SUM",
BaseOp::SumInit => "SUMINIT",
BaseOp::Range { .. } => "RANGE",
BaseOp::IsZero => "ISZERO",
BaseOp::IsBoolean => "ISBOOLEAN",
Expand All @@ -68,12 +102,15 @@ impl BaseOp {
match self {
BaseOp::Identity => (0, 1),
BaseOp::Neg => (0, 1),
BaseOp::DotInit => (0, 1),
BaseOp::Dot => (-1, 2),
BaseOp::CumProd => (-1, 2),
BaseOp::CumProdInit => (0, 1),
BaseOp::Add => (0, 1),
BaseOp::Sub => (0, 1),
BaseOp::Mult => (0, 1),
BaseOp::Sum => (-1, 2),
BaseOp::SumInit => (0, 1),
BaseOp::Range { .. } => (0, 1),
BaseOp::IsZero => (0, 1),
BaseOp::IsBoolean => (0, 1),
Expand All @@ -85,12 +122,15 @@ impl BaseOp {
match self {
BaseOp::Identity => 1,
BaseOp::Neg => 1,
BaseOp::DotInit => 2,
BaseOp::Dot => 2,
BaseOp::CumProdInit => 1,
BaseOp::CumProd => 1,
BaseOp::Add => 2,
BaseOp::Sub => 2,
BaseOp::Mult => 2,
BaseOp::Sum => 1,
BaseOp::SumInit => 1,
BaseOp::Range { .. } => 1,
BaseOp::IsZero => 1,
BaseOp::IsBoolean => 1,
Expand All @@ -102,13 +142,16 @@ impl BaseOp {
match self {
BaseOp::Identity => 0,
BaseOp::Neg => 0,
BaseOp::DotInit => 0,
BaseOp::Dot => 1,
BaseOp::Add => 0,
BaseOp::Sub => 0,
BaseOp::Mult => 0,
BaseOp::Range { .. } => 0,
BaseOp::Sum => 1,
BaseOp::SumInit => 0,
BaseOp::CumProd => 1,
BaseOp::CumProdInit => 0,
BaseOp::IsZero => 0,
BaseOp::IsBoolean => 0,
}
Expand Down
Loading

0 comments on commit ae88ffe

Please sign in to comment.