Skip to content

Commit

Permalink
feat: add root_balances as public input of MstInclusionCircuit
Browse files Browse the repository at this point in the history
  • Loading branch information
enricobottazzi committed Nov 10, 2023
1 parent 2d9d65c commit 444bcc7
Show file tree
Hide file tree
Showing 6 changed files with 89 additions and 20 deletions.
29 changes: 22 additions & 7 deletions zk_prover/src/circuits/merkle_sum_tree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use crate::chips::poseidon::hash::{PoseidonChip, PoseidonConfig};
use crate::chips::poseidon::poseidon_spec::PoseidonSpec;
use crate::chips::range::range_check::{RangeCheckChip, RangeCheckConfig};
use crate::circuits::traits::CircuitBase;
use crate::merkle_sum_tree::{big_uint_to_fp, Entry, MerkleProof};
use crate::merkle_sum_tree::{big_uint_to_fp, Entry, MerkleProof, Node};
use halo2_proofs::circuit::{AssignedCell, Layouter, SimpleFloorPlanner};
use halo2_proofs::halo2curves::bn256::Fr as Fp;
use halo2_proofs::plonk::{
Expand Down Expand Up @@ -31,7 +31,7 @@ pub struct MstInclusionCircuit<const LEVELS: usize, const N_ASSETS: usize, const
pub path_element_hashes: Vec<Fp>,
pub path_element_balances: Vec<[Fp; N_ASSETS]>,
pub path_indices: Vec<Fp>,
pub root_hash: Fp,
pub root: Node<N_ASSETS>,
}

impl<const LEVELS: usize, const N_ASSETS: usize, const N_BYTES: usize> CircuitExt<Fp>
Expand All @@ -40,13 +40,15 @@ where
[usize; 2 * (1 + N_ASSETS)]: Sized,
[usize; N_ASSETS + 1]: Sized,
{
/// Returns the number of public inputs of the circuit. It is 2, namely the laef hash to be verified inclusion of and the root hash of the merkle sum tree.
/// Returns the number of public inputs of the circuit. It is {2 + N_ASSETS}, namely the leaf hash to be verified inclusion of, the root hash of the merkle sum tree and the root balances of the merkle sum tree.
fn num_instance(&self) -> Vec<usize> {
vec![2]
vec![{ 2 + N_ASSETS }]
}
/// Returns the values of the public inputs of the circuit. Namely the leaf hash to be verified inclusion of and the root hash of the merkle sum tree.
fn instances(&self) -> Vec<Vec<Fp>> {
vec![vec![self.entry.compute_leaf().hash, self.root_hash]]
let mut instance = vec![self.entry.compute_leaf().hash, self.root.hash];
instance.extend_from_slice(&self.root.balances);
vec![instance]
}
}

Expand All @@ -57,14 +59,16 @@ impl<const LEVELS: usize, const N_ASSETS: usize, const N_BYTES: usize> CircuitBa

impl<const LEVELS: usize, const N_ASSETS: usize, const N_BYTES: usize>
MstInclusionCircuit<LEVELS, N_ASSETS, N_BYTES>
where
[usize; N_ASSETS + 1]: Sized,
{
pub fn init_empty() -> Self {
Self {
entry: Entry::init_empty(),
path_element_hashes: vec![Fp::zero(); LEVELS],
path_element_balances: vec![[Fp::zero(); N_ASSETS]; LEVELS],
path_indices: vec![Fp::zero(); LEVELS],
root_hash: Fp::zero(),
root: Node::init_empty(),
}
}

Expand All @@ -85,7 +89,7 @@ impl<const LEVELS: usize, const N_ASSETS: usize, const N_BYTES: usize>
path_element_hashes: merkle_proof.sibling_hashes,
path_element_balances: merkle_proof.sibling_sums,
path_indices: merkle_proof.path_indices,
root_hash: merkle_proof.root_hash,
root: merkle_proof.root,
}
}
}
Expand Down Expand Up @@ -191,6 +195,7 @@ where
impl<const LEVELS: usize, const N_ASSETS: usize, const N_BYTES: usize> Circuit<Fp>
for MstInclusionCircuit<LEVELS, N_ASSETS, N_BYTES>
where
[usize; N_ASSETS + 1]: Sized,
[usize; 2 * (1 + N_ASSETS)]: Sized,
{
type Config = MstInclusionConfig<N_ASSETS, N_BYTES>;
Expand Down Expand Up @@ -374,6 +379,16 @@ where
config.instance,
)?;

// expose the last current balances, namely the root balances, as public input
for (i, balance) in current_balances.iter().enumerate() {
self.expose_public(
layouter.namespace(|| format!("public root balance {}", i)),
balance,
2 + i,
config.instance,
)?;
}

// perform range check on the balances of the root to make sure these lie in the range defined by N_BYTES
for balance in current_balances.iter() {
range_check_chip.assign(layouter.namespace(|| "range check root balance"), balance)?;
Expand Down
53 changes: 52 additions & 1 deletion zk_prover/src/circuits/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ mod test {
let valid_prover = MockProver::run(K, &circuit, circuit.instances()).unwrap();

assert_eq!(circuit.instances()[0].len(), circuit.num_instance()[0]);
assert_eq!(circuit.instances()[0].len(), 2 + N_ASSETS);

valid_prover.assert_satisfied();
}
Expand Down Expand Up @@ -78,6 +79,21 @@ mod test {

// verify the proof to be true
assert!(full_verifier(&params, &vk, proof, circuit.instances()));

// the user should perform the check on the public inputs
// public input #0 is the leaf hash
let expected_leaf_hash = user_entry.compute_leaf().hash;
assert_eq!(circuit.instances()[0][0], expected_leaf_hash);

// public input #1 is the root hash
let expected_root_hash = merkle_sum_tree.root().hash;
assert_eq!(circuit.instances()[0][1], expected_root_hash);

// public inputs [2, 2+N_ASSETS - 1] are the root balances
let expected_root_balances = merkle_sum_tree.root().balances;
for i in 0..N_ASSETS {
assert_eq!(circuit.instances()[0][2 + i], expected_root_balances[i]);
}
}

// Passing an invalid root hash in the instance column should fail the permutation check between the computed root hash and the instance column root hash
Expand Down Expand Up @@ -157,6 +173,7 @@ mod test {
// Passing an invalid entry balance as input for the witness generation should fail:
// - the permutation check between the leaf hash and the instance column leaf hash
// - the permutation check between the computed root hash and the instance column root hash
// - the permutations checks between the computed root balances and the instance column root balances
#[test]
fn test_invalid_entry_balance_as_witness() {
let merkle_sum_tree =
Expand Down Expand Up @@ -202,6 +219,20 @@ mod test {
offset: 36
}
},
VerifyFailure::Permutation {
column: (Any::advice(), 0).into(),
location: FailureLocation::InRegion {
region: (95, "assign value to perform range check").into(),
offset: 0
}
},
VerifyFailure::Permutation {
column: (Any::advice(), 0).into(),
location: FailureLocation::InRegion {
region: (96, "assign value to perform range check").into(),
offset: 0
}
},
VerifyFailure::Permutation {
column: (Any::Instance, 0).into(),
location: FailureLocation::OutsideRegion { row: 0 }
Expand All @@ -210,6 +241,14 @@ mod test {
column: (Any::Instance, 0).into(),
location: FailureLocation::OutsideRegion { row: 1 }
},
VerifyFailure::Permutation {
column: (Any::Instance, 0).into(),
location: FailureLocation::OutsideRegion { row: 2 }
},
VerifyFailure::Permutation {
column: (Any::Instance, 0).into(),
location: FailureLocation::OutsideRegion { row: 3 }
},
])
);
}
Expand Down Expand Up @@ -424,7 +463,8 @@ mod test {
);
}

// Adding a balance at the verge of overflowing should fail the range check for any following computed sum and, because we are adding a fake balance, the root hash check should fail too
// Adding a balance at the verge of overflowing should fail the range check for any following computed sum and, because we are adding a fake balance.
// Furthermore, the public input check on the root hash and on root_balances[0] should fail too
#[test]
fn test_balance_not_in_range() {
let merkle_sum_tree =
Expand Down Expand Up @@ -495,6 +535,13 @@ mod test {
offset: 36
}
},
VerifyFailure::Permutation {
column: (Any::advice(), 0).into(),
location: FailureLocation::InRegion {
region: (95, "assign value to perform range check").into(),
offset: 0
}
},
VerifyFailure::Permutation {
column: (Any::advice(), 0).into(),
location: FailureLocation::InRegion {
Expand All @@ -506,6 +553,10 @@ mod test {
column: (Any::Instance, 0).into(),
location: FailureLocation::OutsideRegion { row: 1 }
},
VerifyFailure::Permutation {
column: (Any::Instance, 0).into(),
location: FailureLocation::OutsideRegion { row: 2 }
},
])
);
}
Expand Down
2 changes: 1 addition & 1 deletion zk_prover/src/merkle_sum_tree/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use halo2_proofs::halo2curves::bn256::Fr as Fp;
#[derive(Clone, Debug)]
pub struct MerkleProof<const N_ASSETS: usize, const N_BYTES: usize> {
pub leaf: Node<N_ASSETS>,
pub root_hash: Fp,
pub root: Node<N_ASSETS>,
pub sibling_hashes: Vec<Fp>,
pub sibling_sums: Vec<[Fp; N_ASSETS]>,
pub path_indices: Vec<Fp>,
Expand Down
10 changes: 10 additions & 0 deletions zk_prover/src/merkle_sum_tree/node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,16 @@ impl<const N_ASSETS: usize> Node<N_ASSETS> {
}
}

pub fn init_empty() -> Node<N_ASSETS>
where
[usize; N_ASSETS + 1]: Sized,
{
Node {
hash: Fp::zero(),
balances: [Fp::zero(); N_ASSETS],
}
}

/// Builds a leaf-level node of the MST
pub fn leaf(username: &BigUint, balances: &[BigUint; N_ASSETS]) -> Node<N_ASSETS>
where
Expand Down
4 changes: 2 additions & 2 deletions zk_prover/src/merkle_sum_tree/tests.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#[cfg(test)]
mod test {

use crate::merkle_sum_tree::utils::{big_uint_to_fp, poseidon_node};
use crate::merkle_sum_tree::utils::big_uint_to_fp;
use crate::merkle_sum_tree::{Entry, MerkleSumTree, Tree};
use num_bigint::{BigUint, ToBigUint};

Expand Down Expand Up @@ -81,7 +81,7 @@ mod test {

// shouldn't verify a proof with a wrong root hash
let mut proof_invalid_2 = proof.clone();
proof_invalid_2.root_hash = 0.into();
proof_invalid_2.root.hash = 0.into();
assert!(!merkle_tree.verify_proof(&proof_invalid_2));

// shouldn't verify a proof with a wrong computed balance
Expand Down
11 changes: 2 additions & 9 deletions zk_prover/src/merkle_sum_tree/tree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ pub trait Tree<const N_ASSETS: usize, const N_BYTES: usize> {

Ok(MerkleProof {
leaf: leaf.clone(),
root_hash: root.hash,
root: root.clone(),
sibling_hashes,
sibling_sums,
path_indices,
Expand All @@ -68,8 +68,6 @@ pub trait Tree<const N_ASSETS: usize, const N_BYTES: usize> {
{
let mut node = proof.leaf.clone();

let mut balances = proof.leaf.balances;

for i in 0..proof.sibling_hashes.len() {
let sibling_node = Node {
hash: proof.sibling_hashes[i],
Expand All @@ -81,14 +79,9 @@ pub trait Tree<const N_ASSETS: usize, const N_BYTES: usize> {
} else {
node = Node::middle(&sibling_node, &node);
}

for (balance, sibling_balance) in balances.iter_mut().zip(sibling_node.balances.iter())
{
*balance += sibling_balance;
}
}

proof.root_hash == node.hash && balances == node.balances
proof.root.hash == node.hash && proof.root.balances == node.balances
}

/// Returns the index of the user with the given username and balances in the tree
Expand Down

0 comments on commit 444bcc7

Please sign in to comment.