Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cleanly separate Circom1 and Circom2 traits #60

Merged
merged 5 commits into from
Jul 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 41 additions & 46 deletions src/witness/circom.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,15 @@ pub struct Wasm(Instance);
pub trait CircomBase {
fn init(&self, sanity_check: bool) -> Result<()>;
fn func(&self, name: &str) -> &Function;
fn get_ptr_witness_buffer(&self) -> Result<u32>;
fn get_ptr_witness(&self, w: u32) -> Result<u32>;
fn get_n_vars(&self) -> Result<u32>;
fn get_u32(&self, name: &str) -> Result<u32>;
// Only exists natively in Circom2, hardcoded for Circom
fn get_version(&self) -> Result<u32>;
}

pub trait Circom1 {
fn get_ptr_witness(&self, w: u32) -> Result<u32>;
fn get_fr_len(&self) -> Result<u32>;
fn get_signal_offset32(
&self,
p_sig_offset: u32,
Expand All @@ -18,13 +24,6 @@ pub trait CircomBase {
hash_lsb: u32,
) -> Result<()>;
fn set_signal(&self, c_idx: u32, component: u32, signal: u32, p_val: u32) -> Result<()>;
fn get_u32(&self, name: &str) -> Result<u32>;
// Only exists natively in Circom2, hardcoded for Circom
fn get_version(&self) -> Result<u32>;
}

pub trait Circom {
fn get_fr_len(&self) -> Result<u32>;
fn get_ptr_raw_prime(&self) -> Result<u32>;
}

Expand All @@ -38,14 +37,46 @@ pub trait Circom2 {
fn get_witness_size(&self) -> Result<u32>;
}

impl Circom for Wasm {
impl Circom1 for Wasm {
fn get_fr_len(&self) -> Result<u32> {
self.get_u32("getFrLen")
}

fn get_ptr_raw_prime(&self) -> Result<u32> {
self.get_u32("getPRawPrime")
}

fn get_ptr_witness(&self, w: u32) -> Result<u32> {
let func = self.func("getPWitness");
let res = func.call(&[w.into()])?;

Ok(res[0].unwrap_i32() as u32)
}

fn get_signal_offset32(
&self,
p_sig_offset: u32,
component: u32,
hash_msb: u32,
hash_lsb: u32,
) -> Result<()> {
let func = self.func("getSignalOffset32");
func.call(&[
p_sig_offset.into(),
component.into(),
hash_msb.into(),
hash_lsb.into(),
])?;

Ok(())
}

fn set_signal(&self, c_idx: u32, component: u32, signal: u32, p_val: u32) -> Result<()> {
let func = self.func("setSignal");
func.call(&[c_idx.into(), component.into(), signal.into(), p_val.into()])?;

Ok(())
}
}

#[cfg(feature = "circom-2")]
Expand Down Expand Up @@ -96,46 +127,10 @@ impl CircomBase for Wasm {
Ok(())
}

fn get_ptr_witness_buffer(&self) -> Result<u32> {
self.get_u32("getWitnessBuffer")
}

fn get_ptr_witness(&self, w: u32) -> Result<u32> {
let func = self.func("getPWitness");
let res = func.call(&[w.into()])?;

Ok(res[0].unwrap_i32() as u32)
}

fn get_n_vars(&self) -> Result<u32> {
self.get_u32("getNVars")
}

fn get_signal_offset32(
&self,
p_sig_offset: u32,
component: u32,
hash_msb: u32,
hash_lsb: u32,
) -> Result<()> {
let func = self.func("getSignalOffset32");
func.call(&[
p_sig_offset.into(),
component.into(),
hash_msb.into(),
hash_lsb.into(),
])?;

Ok(())
}

fn set_signal(&self, c_idx: u32, component: u32, signal: u32, p_val: u32) -> Result<()> {
let func = self.func("setSignal");
func.call(&[c_idx.into(), component.into(), signal.into(), p_val.into()])?;

Ok(())
}

// Default to version 1 if it isn't explicitly defined
fn get_version(&self) -> Result<u32> {
match self.0.exports.get_function("getVersion") {
Expand Down
2 changes: 1 addition & 1 deletion src/witness/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ pub(super) use circom::{CircomBase, Wasm};
#[cfg(feature = "circom-2")]
pub(super) use circom::Circom2;

pub(super) use circom::Circom;
pub(super) use circom::Circom1;

use fnv::FnvHasher;
use std::hash::Hasher;
Expand Down
58 changes: 25 additions & 33 deletions src/witness/witness_calculator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,22 @@ use super::{fnv, CircomBase, SafeMemory, Wasm};
use color_eyre::Result;
use num_bigint::BigInt;
use num_traits::Zero;
use std::cell::Cell;
use wasmer::{imports, Function, Instance, Memory, MemoryType, Module, RuntimeError, Store};

#[cfg(feature = "circom-2")]
use num::ToPrimitive;

use super::Circom1;
#[cfg(feature = "circom-2")]
use super::Circom2;

use super::Circom;

#[derive(Clone, Debug)]
pub struct WitnessCalculator {
pub instance: Wasm,
pub memory: SafeMemory,
pub memory: Option<SafeMemory>,
pub n64: u32,
pub circom_version: u32,
pub prime: BigInt,
}

// Error type to signal end of execution.
Expand Down Expand Up @@ -92,9 +91,8 @@ impl WitnessCalculator {

// Circom 2 feature flag with version 2
#[cfg(feature = "circom-2")]
fn new_circom2(instance: Wasm, memory: Memory, version: u32) -> Result<WitnessCalculator> {
fn new_circom2(instance: Wasm, version: u32) -> Result<WitnessCalculator> {
let n32 = instance.get_field_num_len32()?;
let mut safe_memory = SafeMemory::new(memory, n32 as usize, BigInt::zero());
instance.get_raw_prime()?;
let mut arr = vec![0; n32 as usize];
for i in 0..n32 {
Expand All @@ -104,13 +102,13 @@ impl WitnessCalculator {
let prime = from_array32(arr);

let n64 = ((prime.bits() - 1) / 64 + 1) as u32;
safe_memory.prime = prime;

Ok(WitnessCalculator {
instance,
memory: safe_memory,
memory: None,
n64,
circom_version: version,
prime,
})
}

Expand All @@ -122,13 +120,14 @@ impl WitnessCalculator {
let prime = safe_memory.read_big(ptr as usize, n32 as usize)?;

let n64 = ((prime.bits() - 1) / 64 + 1) as u32;
safe_memory.prime = prime;
safe_memory.prime = prime.clone();

Ok(WitnessCalculator {
instance,
memory: safe_memory,
memory: Some(safe_memory),
n64,
circom_version: version,
prime,
})
}

Expand All @@ -142,7 +141,7 @@ impl WitnessCalculator {
cfg_if::cfg_if! {
if #[cfg(feature = "circom-2")] {
match version {
2 => new_circom2(instance, memory, version),
2 => new_circom2(instance, version),
1 => new_circom1(instance, memory, version),
_ => panic!("Unknown Circom version")
}
Expand Down Expand Up @@ -180,9 +179,9 @@ impl WitnessCalculator {
) -> Result<Vec<BigInt>> {
self.instance.init(sanity_check)?;

let old_mem_free_pos = self.memory.free_pos();
let p_sig_offset = self.memory.alloc_u32();
let p_fr = self.memory.alloc_fr();
let old_mem_free_pos = self.memory.as_ref().unwrap().free_pos();
let p_sig_offset = self.memory.as_mut().unwrap().alloc_u32();
let p_fr = self.memory.as_mut().unwrap().alloc_fr();

// allocate the inputs
for (name, values) in inputs.into_iter() {
Expand All @@ -191,10 +190,17 @@ impl WitnessCalculator {
self.instance
.get_signal_offset32(p_sig_offset, 0, msb, lsb)?;

let sig_offset = self.memory.read_u32(p_sig_offset as usize) as usize;
let sig_offset = self
.memory
.as_ref()
.unwrap()
.read_u32(p_sig_offset as usize) as usize;

for (i, value) in values.into_iter().enumerate() {
self.memory.write_fr(p_fr as usize, &value)?;
self.memory
.as_mut()
.unwrap()
.write_fr(p_fr as usize, &value)?;
self.instance
.set_signal(0, 0, (sig_offset + i) as u32, p_fr)?;
}
Expand All @@ -205,11 +211,11 @@ impl WitnessCalculator {
let n_vars = self.instance.get_n_vars()?;
for i in 0..n_vars {
let ptr = self.instance.get_ptr_witness(i)? as usize;
let el = self.memory.read_fr(ptr)?;
let el = self.memory.as_ref().unwrap().read_fr(ptr)?;
w.push(el);
}

self.memory.set_free_pos(old_mem_free_pos);
self.memory.as_mut().unwrap().set_free_pos(old_mem_free_pos);

Ok(w)
}
Expand Down Expand Up @@ -283,20 +289,6 @@ impl WitnessCalculator {

Ok(witness)
}

pub fn get_witness_buffer(&self) -> Result<Vec<u8>> {
let ptr = self.instance.get_ptr_witness_buffer()? as usize;

let view = self.memory.memory.view::<u8>();

let len = self.instance.get_n_vars()? * self.n64 * 8;
let arr = view[ptr..ptr + len as usize]
.iter()
.map(Cell::get)
.collect::<Vec<_>>();

Ok(arr)
}
}

// callback hooks for debugging
Expand Down Expand Up @@ -463,7 +455,7 @@ mod tests {
fn run_test(case: TestCase) {
let mut wtns = WitnessCalculator::new(case.circuit_path).unwrap();
assert_eq!(
wtns.memory.prime.to_str_radix(16),
wtns.prime.to_str_radix(16),
"30644E72E131A029B85045B68181585D2833E84879B9709143E1F593F0000001".to_lowercase()
);
assert_eq!({ wtns.instance.get_n_vars().unwrap() }, case.n_vars);
Expand Down
Loading