Skip to content

Commit

Permalink
Fix coprocessors returning lists in SuperNova (#1186)
Browse files Browse the repository at this point in the history
* Fix coprocessors returning lists in SuperNova

In NIVC, the Lurk circuit calls out to the coprocessor circuit, and then
on the next fold step calls back into the main Lurk circuit with the
result of the coprocessor. If the coprocessor returns a single value,
there is no issue, but if the coprocessor returns a list, then the lurk
circuit attempts to evaluate that expression instead of simply returning
it. To fix it, we wrap the coprocessor return expr with a thunk.

* test correctness of a simple coprocessor that returns a list in the NIVC context

---------

Co-authored-by: Arthur Paulino <[email protected]>
  • Loading branch information
wwared and arthurpaulino authored Mar 1, 2024
1 parent 5de1e79 commit eaf996c
Show file tree
Hide file tree
Showing 8 changed files with 160 additions and 15 deletions.
15 changes: 5 additions & 10 deletions src/cli/repl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,7 @@ use crate::{
supernova::SuperNovaProver,
RecursiveSNARKTrait,
},
public_parameters::{
instance::{Instance, Kind},
public_params, supernova_public_params,
},
public_parameters::{instance::Instance, public_params, supernova_public_params},
state::State,
tag::{ContTag, ExprTag},
Symbol,
Expand Down Expand Up @@ -338,12 +335,11 @@ where
info!("Proof not cached");
let (proof, public_inputs, public_outputs) = match self.backend {
Backend::Nova => {
let prover = NovaProver::<_, C>::new(self.rc, self.lang.clone());
info!("Loading Nova public parameters");
let instance =
Instance::new(self.rc, self.lang.clone(), true, Kind::NovaPublicParams);
let instance = Instance::new_nova(&prover, true);
let pp = public_params(&instance)?;

let prover = NovaProver::<_, C>::new(self.rc, self.lang.clone());
info!("Proving with NovaProver");
let (proof, public_inputs, public_outputs, num_steps) =
prover.prove_from_frames(&pp, frames, &self.store)?;
Expand All @@ -354,12 +350,11 @@ where
(LurkProofWrapper::Nova(proof), public_inputs, public_outputs)
}
Backend::SuperNova => {
let prover = SuperNovaProver::<_, C>::new(self.rc, self.lang.clone());
info!("Loading SuperNova public parameters");
let instance =
Instance::new(self.rc, self.lang.clone(), true, Kind::SuperNovaAuxParams);
let instance = Instance::new_supernova(&prover, true);
let pp = supernova_public_params(&instance)?;

let prover = SuperNovaProver::<_, C>::new(self.rc, self.lang.clone());
info!("Proving with SuperNovaProver");
let (proof, public_inputs, public_outputs, _num_steps) =
prover.prove_from_frames(&pp, frames, &self.store)?;
Expand Down
55 changes: 55 additions & 0 deletions src/coprocessor/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,8 @@ pub(crate) mod test {
use bellpepper_core::num::AllocatedNum;
use serde::{Deserialize, Serialize};

use self::gadgets::construct_cons;

use super::*;
use crate::circuit::gadgets::constraints::{alloc_equal, mul};
use crate::lem::{pointers::RawPtr, tag::Tag as LEMTag};
Expand Down Expand Up @@ -348,4 +350,57 @@ pub(crate) mod test {
Self::intern_hello_world(s)
}
}

/// A coprocessor that simply returns the pair (nil . nil)
#[derive(Clone, Debug, Serialize, Deserialize)]
pub(crate) struct NilNil<F> {
_p: PhantomData<F>,
}

impl<F: LurkField> NilNil<F> {
pub(crate) fn new() -> Self {
Self {
_p: Default::default(),
}
}
}

impl<F: LurkField> CoCircuit<F> for NilNil<F> {
fn arity(&self) -> usize {
0
}

fn synthesize_simple<CS: ConstraintSystem<F>>(
&self,
cs: &mut CS,
g: &GlobalAllocator<F>,
s: &Store<F>,
_not_dummy: &Boolean,
_args: &[AllocatedPtr<F>],
) -> Result<AllocatedPtr<F>, SynthesisError> {
let nil = g.alloc_ptr(cs, &s.intern_nil(), s);
construct_cons(cs, g, s, &nil, &nil)
}

fn alloc_globals<CS: ConstraintSystem<F>>(
&self,
cs: &mut CS,
g: &GlobalAllocator<F>,
s: &Store<F>,
) {
g.alloc_ptr(cs, &s.intern_nil(), s);
g.alloc_tag(cs, &ExprTag::Cons);
}
}

impl<F: LurkField> Coprocessor<F> for NilNil<F> {
fn has_circuit(&self) -> bool {
true
}

fn evaluate_simple(&self, s: &Store<F>, _args: &[Ptr]) -> Ptr {
let nil = s.intern_nil();
s.cons(nil, nil)
}
}
}
6 changes: 5 additions & 1 deletion src/lem/eval.rs
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,7 @@ fn car_cdr() -> Func {
/// // there must be no remaining arguments
/// if is_nil {
/// Op::Cproc([expr, env, cont], x, [x0, x1, ..., x{n-1}, env, cont]);
/// let expr: Expr::Thunk = cons2(expr, cont);
/// return (expr, env, cont);
/// }
/// return (evaluated_args_cp, env, err);
Expand Down Expand Up @@ -390,7 +391,10 @@ fn run_cproc(cproc_sym: Symbol, arity: usize) -> Func {
cproc_inp.push(env.clone());
cproc_inp.push(cont.clone());
let mut block = Block {
ops: vec![Op::Cproc(cproc_out, cproc_sym.clone(), cproc_inp.clone())],
ops: vec![
Op::Cproc(cproc_out, cproc_sym.clone(), cproc_inp.clone()),
op!(let expr: Expr::Thunk = cons2(expr, cont)),
],
ctrl: Ctrl::Return(func_out),
};
for (i, cproc_arg) in cproc_inp[0..arity].iter().enumerate() {
Expand Down
3 changes: 2 additions & 1 deletion src/proof/nova.rs
Original file line number Diff line number Diff line change
Expand Up @@ -399,7 +399,8 @@ impl<'a, F: CurveCycleEquipped, C: Coprocessor<F> + 'a> NovaProver<'a, F, C> {
}

#[inline]
fn lang(&self) -> &Arc<Lang<F, C>> {
/// Returns the `Lang` wrapped with `Arc` for cheap cloning
pub fn lang(&self) -> &Arc<Lang<F, C>> {
&self.lang
}
}
Expand Down
3 changes: 2 additions & 1 deletion src/proof/supernova.rs
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,8 @@ impl<'a, F: CurveCycleEquipped, C: Coprocessor<F> + 'a> SuperNovaProver<'a, F, C
}

#[inline]
fn lang(&self) -> &Arc<Lang<F, C>> {
/// Returns the `Lang` wrapped with `Arc` for cheap cloning
pub fn lang(&self) -> &Arc<Lang<F, C>> {
&self.lang
}
}
Expand Down
2 changes: 2 additions & 0 deletions src/proof/tests/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
mod nova_tests;
mod supernova_tests;

use bellpepper::util_cs::{metric_cs::MetricCS, witness_cs::WitnessCS, Comparable};
use bellpepper_core::{test_cs::TestConstraintSystem, Circuit, ConstraintSystem, Delta};
use expect_test::Expect;
Expand Down
62 changes: 62 additions & 0 deletions src/proof/tests/supernova_tests.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
use halo2curves::bn256::Fr;
use std::sync::Arc;

use crate::{
eval::lang::Lang,
lem::{
eval::{evaluate, make_cprocs_funcs_from_lang, make_eval_step_from_config, EvalConfig},
store::Store,
},
proof::{supernova::SuperNovaProver, RecursiveSNARKTrait},
public_parameters::{instance::Instance, supernova_public_params},
state::user_sym,
};

#[test]
fn test_nil_nil_lang() {
use crate::coprocessor::test::NilNil;
let mut lang = Lang::<Fr, NilNil<Fr>>::new();
lang.add_coprocessor(user_sym("nil-nil"), NilNil::new());

let eval_config = EvalConfig::new_nivc(&lang);
let lurk_step = make_eval_step_from_config(&eval_config);
let cprocs = make_cprocs_funcs_from_lang(&lang);

let store = Store::default();
let expr = store.read_with_default_state("(nil-nil)").unwrap();
let frames = evaluate(Some((&lurk_step, &cprocs, &lang)), expr, &store, 50).unwrap();

// iteration 1: main circuit sets up a call to the coprocessor
// iteration 2: coprocessor does its job
// iteration 3: main circuit sets termination to terminal
assert_eq!(frames.len(), 3);

let first_frame = frames.first().unwrap();
let last_frame = frames.last().unwrap();
let output = &last_frame.output;

// the result is the (nil . nil) pair
let nil = store.intern_nil();
assert!(store.ptr_eq(&output[0], &store.cons(nil, nil)));

// computation must end with the terminal continuation
assert!(store.ptr_eq(&output[2], &store.cont_terminal()));

let supernova_prover = SuperNovaProver::new(5, Arc::new(lang));
let instance = Instance::new_supernova(&supernova_prover, true);
let pp = supernova_public_params(&instance).unwrap();

let (proof, ..) = supernova_prover
.prove_from_frames(&pp, &frames, &store)
.unwrap();

let input_scalar = store.to_scalar_vector(&first_frame.input);
let output_scalar = store.to_scalar_vector(output);

// uncompressed proof verifies
assert!(proof.verify(&pp, &input_scalar, &output_scalar).unwrap());

// compressed proof verifies
let proof = proof.compress(&pp).unwrap();
assert!(proof.verify(&pp, &input_scalar, &output_scalar).unwrap());
}
29 changes: 27 additions & 2 deletions src/public_parameters/instance.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,9 @@ use crate::{
coprocessor::Coprocessor,
eval::lang::Lang,
proof::{
nova::{self, CurveCycleEquipped},
supernova::{self},
nova::{self, CurveCycleEquipped, NovaProver},
supernova::{self, SuperNovaProver},
Prover,
},
};

Expand Down Expand Up @@ -134,6 +135,30 @@ impl<F: CurveCycleEquipped, C: Coprocessor<F>> Instance<F, C> {
}
}

/// Returns an `Instance` for Nova public parameters with the prover's
/// reduction count and lang
#[inline]
pub fn new_nova(prover: &NovaProver<'_, F, C>, abomonated: bool) -> Self {
Self::new(
prover.reduction_count(),
prover.lang().clone(),
abomonated,
Kind::NovaPublicParams,
)
}

/// Returns an `Instance` for SuperNova public parameters with the prover's
/// reduction count and lang
#[inline]
pub fn new_supernova(prover: &SuperNovaProver<'_, F, C>, abomonated: bool) -> Self {
Self::new(
prover.reduction_count(),
prover.lang().clone(),
abomonated,
Kind::SuperNovaAuxParams,
)
}

/// If this [Instance] is of [Kind::SuperNovaAuxParams], then generate the `num_circuits + 1`
/// circuit param instances that are determined by the internal [Lang].
pub fn circuit_param_instances(&self) -> Vec<Self> {
Expand Down

1 comment on commit eaf996c

@github-actions
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Benchmarks

Table of Contents

Overview

This benchmark report shows the Fibonacci GPU benchmark.
NVIDIA L4
Intel(R) Xeon(R) CPU @ 2.20GHz
32 vCPUs
125 GB RAM
Workflow run: https://github.com/lurk-lab/lurk-rs/actions/runs/8114883874

Benchmark Results

LEM Fibonacci Prove - rc = 100

ref=5de1e798ca3e8d21642556e135649cc0365bdba4 ref=eaf996c69d0c47c6f316b9359ff144ddcaf17f47
num-100 1.45 s (✅ 1.00x) 1.46 s (✅ 1.00x slower)
num-200 2.77 s (✅ 1.00x) 2.77 s (✅ 1.00x slower)

LEM Fibonacci Prove - rc = 600

ref=5de1e798ca3e8d21642556e135649cc0365bdba4 ref=eaf996c69d0c47c6f316b9359ff144ddcaf17f47
num-100 1.85 s (✅ 1.00x) 1.83 s (✅ 1.01x faster)
num-200 3.04 s (✅ 1.00x) 3.02 s (✅ 1.01x faster)

Made with criterion-table

Please sign in to comment.