Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: use SymbolRef as the Lang map key #1195

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions benches/sha256.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ use lurk::{
instance::{Instance, Kind},
public_params, supernova_public_params,
},
state::{user_sym, State, StateRcCell},
state::{State, StateRcCell},
Symbol,
};

mod common;
Expand Down Expand Up @@ -103,7 +104,7 @@ fn sha256_ivc_prove<M: measurement::Measurement>(
let limit = 10000;

let store = &Store::<Bn>::default();
let cproc_sym = user_sym(&format!("sha256_ivc_{arity}"));
let cproc_sym = Symbol::interned(format!("sha256_ivc_{arity}"), state.clone()).unwrap();

let mut lang = Lang::<Bn, Sha256Coproc<Bn>>::new();
lang.add_coprocessor(cproc_sym, Sha256Coprocessor::new(arity));
Expand Down Expand Up @@ -191,7 +192,7 @@ fn sha256_ivc_prove_compressed<M: measurement::Measurement>(
let limit = 10000;

let store = &Store::<Bn>::default();
let cproc_sym = user_sym(&format!("sha256_ivc_{arity}"));
let cproc_sym = Symbol::interned(format!("sha256_ivc_{arity}"), state.clone()).unwrap();

let mut lang = Lang::<Bn, Sha256Coproc<Bn>>::new();
lang.add_coprocessor(cproc_sym, Sha256Coprocessor::new(arity));
Expand Down Expand Up @@ -281,7 +282,7 @@ fn sha256_nivc_prove<M: measurement::Measurement>(
let limit = 10000;

let store = &Store::<Bn>::default();
let cproc_sym = user_sym(&format!("sha256_ivc_{arity}"));
let cproc_sym = Symbol::interned(format!("sha256_ivc_{arity}"), state.clone()).unwrap();

let mut lang = Lang::<Bn, Sha256Coproc<Bn>>::new();
lang.add_coprocessor(cproc_sym, Sha256Coprocessor::new(arity));
Expand Down
10 changes: 6 additions & 4 deletions examples/circom.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ use lurk::public_parameters::{
instance::{Instance, Kind},
public_params,
};
use lurk::state::State;
use lurk::Symbol;
use lurk_macros::Coproc;

Expand Down Expand Up @@ -101,14 +102,15 @@ enum Sha256Coproc<F: LurkField> {
/// `cargo run --release --example circom`
fn main() {
let store = &Store::default();
let sym_str = Symbol::new(&[".circom_sha256_2"], false); // two inputs
let state = State::init_lurk_state().rccell();

let name = Symbol::interned(".circom_sha256_2", state.clone()).unwrap(); // two inputs
let circom_sha256: CircomSha256<Bn> = CircomSha256::new(0);
let mut lang = Lang::<Bn, Sha256Coproc<Bn>>::new();
lang.add_coprocessor(sym_str, CircomCoprocessor::new(circom_sha256));
lang.add_coprocessor(name, CircomCoprocessor::new(circom_sha256));
let lang_rc = Arc::new(lang);

let expr = "(.circom_sha256_2)".to_string();
let ptr = store.read_with_default_state(&expr).unwrap();
let ptr = store.read(state, "(.circom_sha256_2)").unwrap();

let nova_prover = NovaProver::<Bn, Sha256Coproc<Bn>>::new(REDUCTION_COUNT, lang_rc.clone());

Expand Down
17 changes: 12 additions & 5 deletions examples/sha256_ivc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,18 @@ use lurk::{
instance::{Instance, Kind},
public_params,
},
state::user_sym,
state::{State, StateRcCell},
Symbol,
};

const REDUCTION_COUNT: usize = 10;

fn sha256_ivc<F: LurkField>(store: &Store<F>, n: usize, input: &[usize]) -> Ptr {
fn sha256_ivc<F: LurkField>(
store: &Store<F>,
state: StateRcCell,
n: usize,
input: &[usize],
) -> Ptr {
assert_eq!(n, input.len());
let input = input
.iter()
Expand Down Expand Up @@ -47,7 +53,7 @@ fn sha256_ivc<F: LurkField>(store: &Store<F>, n: usize, input: &[usize]) -> Ptr
"#
);

store.read_with_default_state(&program).unwrap()
store.read(state, &program).unwrap()
}

/// Run the example in this file with
Expand All @@ -64,9 +70,10 @@ fn main() {
let n = args.get(1).unwrap_or(&"1".into()).parse().unwrap();

let store = &Store::default();
let cproc_sym = user_sym(&format!("sha256_ivc_{n}"));
let state = State::init_lurk_state().rccell();
let cproc_sym = Symbol::interned(format!("sha256_ivc_{n}"), state.clone()).unwrap();

let call = sha256_ivc(store, n, &(0..n).collect::<Vec<_>>());
let call = sha256_ivc(store, state, n, &(0..n).collect::<Vec<_>>());

let mut lang = Lang::<Bn, Sha256Coproc<Bn>>::new();
lang.add_coprocessor(cproc_sym, Sha256Coprocessor::new(n));
Expand Down
17 changes: 12 additions & 5 deletions examples/sha256_nivc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,18 @@ use lurk::{
instance::{Instance, Kind},
supernova_public_params,
},
state::user_sym,
state::{State, StateRcCell},
Symbol,
};

const REDUCTION_COUNT: usize = 10;

fn sha256_nivc<F: LurkField>(store: &Store<F>, n: usize, input: &[usize]) -> Ptr {
fn sha256_nivc<F: LurkField>(
store: &Store<F>,
state: StateRcCell,
n: usize,
input: &[usize],
) -> Ptr {
assert_eq!(n, input.len());
let input = input
.iter()
Expand Down Expand Up @@ -51,7 +57,7 @@ fn sha256_nivc<F: LurkField>(store: &Store<F>, n: usize, input: &[usize]) -> Ptr
"#
);

store.read_with_default_state(&program).unwrap()
store.read(state, &program).unwrap()
}

/// Run the example in this file with
Expand All @@ -68,9 +74,10 @@ fn main() {
let n = args.get(1).unwrap_or(&"1".into()).parse().unwrap();

let store = &Store::default();
let cproc_sym = user_sym(&format!("sha256_nivc_{n}"));
let state = State::init_lurk_state().rccell();
let cproc_sym = Symbol::interned(format!("sha256_nivc_{n}"), state.clone()).unwrap();

let call = sha256_nivc(store, n, &(0..n).collect::<Vec<_>>());
let call = sha256_nivc(store, state, n, &(0..n).collect::<Vec<_>>());

let mut lang = Lang::<Bn, Sha256Coproc<Bn>>::new();
lang.add_coprocessor(cproc_sym, Sha256Coprocessor::new(n));
Expand Down
7 changes: 3 additions & 4 deletions src/coprocessor/trie/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -318,10 +318,9 @@ pub fn install<F: LurkField>(state: &StateRcCell, lang: &mut Lang<F, TrieCoproc<
let lookup_sym = package.intern("lookup");
let insert_sym = package.intern("insert");
state_mut.add_package(package);
// TODO: should `Lang` hold `Arc<Symbol>` instead?
lang.add_coprocessor((*new_sym).clone(), NewCoprocessor::default());
lang.add_coprocessor((*lookup_sym).clone(), LookupCoprocessor::default());
lang.add_coprocessor((*insert_sym).clone(), InsertCoprocessor::default());
Comment on lines -321 to -324
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is less cumbersome now

lang.add_coprocessor(new_sym, NewCoprocessor::default());
lang.add_coprocessor(lookup_sym, LookupCoprocessor::default());
lang.add_coprocessor(insert_sym, InsertCoprocessor::default());
}

pub type ChildMap<F, const ARITY: usize> = InversePoseidonCache<F>;
Expand Down
23 changes: 11 additions & 12 deletions src/lang.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use crate::{
coprocessor::{CoCircuit, Coprocessor},
field::LurkField,
lem::{pointers::Ptr, store::Store},
package::SymbolRef,
symbol::Symbol,
};

Expand Down Expand Up @@ -69,12 +70,10 @@ pub enum Coproc<F: LurkField> {
/// - `F`: A field type that implements the [`crate::field::LurkField`] trait.
/// - `C`: A type that implements the [`crate::coprocessor::Coprocessor`] trait. This allows late-binding of the
/// exact set of coprocessors to be allowed in the `Lang` struct.
///
// TODO: Define a trait for the Hash and parameterize on that also.
#[derive(Debug, Default, Clone, Deserialize, Serialize)]
pub struct Lang<F, C = Coproc<F>> {
/// An IndexMap that stores coprocessors with their associated `Sym` keys.
coprocessors: IndexMap<Symbol, C>,
coprocessors: IndexMap<SymbolRef, C>,
_p: PhantomData<F>,
}

Expand Down Expand Up @@ -111,8 +110,8 @@ impl<F: LurkField, C> Lang<F, C> {
}

#[inline]
pub fn add_coprocessor<T: Into<C>, S: Into<Symbol>>(&mut self, name: S, cproc: T) {
self.coprocessors.insert(name.into(), cproc.into());
pub fn add_coprocessor<T: Into<C>>(&mut self, name: SymbolRef, cproc: T) {
self.coprocessors.insert(name, cproc.into());
}

pub fn add_binding<B: Into<Binding<F, C>>>(&mut self, binding: B) {
Expand All @@ -121,7 +120,7 @@ impl<F: LurkField, C> Lang<F, C> {
}

#[inline]
pub fn coprocessors(&self) -> &IndexMap<Symbol, C> {
pub fn coprocessors(&self) -> &IndexMap<SymbolRef, C> {
&self.coprocessors
}

Expand Down Expand Up @@ -155,21 +154,21 @@ impl<F: LurkField, C> Lang<F, C> {
/// modular construction of `Lang`s using `Coprocessor`s.
#[derive(Debug)]
pub struct Binding<F, C> {
name: Symbol,
name: SymbolRef,
coproc: C,
_p: PhantomData<F>,
}

impl<F: LurkField, C: Coprocessor<F>, S: Into<Symbol>> From<(S, C)> for Binding<F, C> {
fn from(pair: (S, C)) -> Self {
impl<F: LurkField, C: Coprocessor<F>> From<(SymbolRef, C)> for Binding<F, C> {
fn from(pair: (SymbolRef, C)) -> Self {
Self::new(pair.0, pair.1)
}
}

impl<F: LurkField, C> Binding<F, C> {
pub fn new<T: Into<C>, S: Into<Symbol>>(name: S, coproc: T) -> Self {
pub fn new<T: Into<C>>(name: SymbolRef, coproc: T) -> Self {
Self {
name: name.into(),
name,
coproc: coproc.into(),
_p: Default::default(),
}
Expand All @@ -191,7 +190,7 @@ pub(crate) mod test {
#[test]
fn dummy_lang() {
let _lang = Lang::<Fr>::new_with_bindings(vec![(
sym!("coproc", "dummy"),
sym!("coproc", "dummy").into(),
DummyCoprocessor::new().into(),
)]);
}
Expand Down
4 changes: 2 additions & 2 deletions src/lem/eval.rs
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,7 @@ pub fn make_eval_step_from_config<F: LurkField, C: Coprocessor<F>>(
&ec.lang
.coprocessors()
.iter()
.map(|(s, c)| (s, c.arity()))
.map(|(s, c)| (s as &Symbol, c.arity()))
.collect::<Vec<_>>(),
ec.is_ivc(),
)
Expand Down Expand Up @@ -489,7 +489,7 @@ pub fn make_cprocs_funcs_from_lang<F: LurkField, C: Coprocessor<F>>(
) -> Vec<Func> {
lang.coprocessors()
.iter()
.map(|(name, c)| run_cproc(name.clone(), c.arity()))
.map(|(name, c)| run_cproc((*name.clone()).clone(), c.arity()))
.collect()
}

Expand Down
Loading