Skip to content

Commit

Permalink
feat: added method implementation to Tree trait and moved outside o…
Browse files Browse the repository at this point in the history
…f utils
  • Loading branch information
enricobottazzi committed Nov 6, 2023
1 parent b3990eb commit 35478ec
Show file tree
Hide file tree
Showing 9 changed files with 86 additions and 127 deletions.
2 changes: 1 addition & 1 deletion zk_prover/examples/gen_inclusion_verifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ fn main() {
let user_entry = merkle_sum_tree.get_entry(user_index);

// Generate the circuit with the actual inputs
let mut circuit =
let circuit =
MstInclusionCircuit::<LEVELS, N_ASSETS, N_BYTES>::init(merkle_proof, user_entry.clone());

let instances = circuit.instances();
Expand Down
2 changes: 1 addition & 1 deletion zk_prover/examples/nova_incremental_verifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ fn main() {
run_test(circuit_filepath.clone(), witness_gen_filepath);
}

use num_traits::{FromPrimitive, Num};
use num_traits::Num;
use poseidon_rs::{Fr, Poseidon};

// Note that we cannot reuse the MerkleSumTree implementation from zk_prover because it is not compatible with circom's Poseidon Hasher
Expand Down
4 changes: 2 additions & 2 deletions zk_prover/src/circuits/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ mod test {
let user_entry = merkle_sum_tree.get_entry(user_index);

// Only now we can instantiate the circuit with the actual inputs
let mut circuit = MstInclusionCircuit::<LEVELS, N_ASSETS, N_BYTES>::init(
let circuit = MstInclusionCircuit::<LEVELS, N_ASSETS, N_BYTES>::init(
merkle_proof,
user_entry.clone(),
);
Expand Down Expand Up @@ -741,7 +741,7 @@ mod test {
let merkle_proof = merkle_sum_tree.generate_proof(user_index).unwrap();
let user_entry = merkle_sum_tree.get_entry(user_index);

let mut circuit = MstInclusionCircuit::<LEVELS, N_ASSETS, N_BYTES>::init(
let circuit = MstInclusionCircuit::<LEVELS, N_ASSETS, N_BYTES>::init(
merkle_proof,
user_entry.clone(),
);
Expand Down
19 changes: 2 additions & 17 deletions zk_prover/src/merkle_sum_tree/mst.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::merkle_sum_tree::utils::{
build_merkle_tree_from_leaves, compute_leaves, create_proof, index_of, parse_csv_to_entries,
build_merkle_tree_from_leaves, compute_leaves, parse_csv_to_entries,
};
use crate::merkle_sum_tree::{Entry, MerkleProof, Node, Tree};
use crate::merkle_sum_tree::{Entry, Node, Tree};
use num_bigint::BigUint;

/// Merkle Sum Tree Data Structure.
Expand Down Expand Up @@ -43,14 +43,6 @@ impl<const N_ASSETS: usize, const N_BYTES: usize> Tree<N_ASSETS, N_BYTES>
fn nodes(&self) -> &[Vec<Node<N_ASSETS>>] {
&self.nodes
}

/// Generates a MerkleProof for the user with the given index. No mini tree index is required for a MerkleSumTree.
fn generate_proof(
&self,
user_index: usize,
) -> Result<MerkleProof<N_ASSETS, N_BYTES>, &'static str> {
create_proof(user_index, self.depth, &self.nodes, &self.root)
}
}

impl<const N_ASSETS: usize, const N_BYTES: usize> MerkleSumTree<N_ASSETS, N_BYTES> {
Expand Down Expand Up @@ -170,13 +162,6 @@ impl<const N_ASSETS: usize, const N_BYTES: usize> MerkleSumTree<N_ASSETS, N_BYTE
Ok((&penultimate_level[0], &penultimate_level[1]))
}

/// Returns the index of the user with the given username and balances in the tree
pub fn index_of(&self, username: &str, balances: [BigUint; N_ASSETS]) -> Option<usize>
where
[usize; N_ASSETS + 1]: Sized,
{
index_of(username, balances, &self.nodes)
}

/// Returns the index of the leaf with the matching username
pub fn index_of_username(&self, username: &str) -> Result<usize, Box<dyn std::error::Error>>
Expand Down
87 changes: 80 additions & 7 deletions zk_prover/src/merkle_sum_tree/tree.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use crate::merkle_sum_tree::utils::verify_proof;
use crate::merkle_sum_tree::{MerkleProof, Node};
use crate::merkle_sum_tree::{Entry, MerkleProof, Node};
use halo2_proofs::halo2curves::bn256::Fr as Fp;
use num_bigint::BigUint;

/// A trait representing the basic operations for a Merkle-Sum-like Tree.
pub trait Tree<const N_ASSETS: usize, const N_BYTES: usize> {
Expand All @@ -16,17 +17,89 @@ pub trait Tree<const N_ASSETS: usize, const N_BYTES: usize> {
fn nodes(&self) -> &[Vec<Node<N_ASSETS>>];

/// Generates a MerkleProof for the user with the given index.
fn generate_proof(
&self,
user_index: usize,
) -> Result<MerkleProof<N_ASSETS, N_BYTES>, &'static str>;
fn generate_proof(&self, index: usize) -> Result<MerkleProof<N_ASSETS, N_BYTES>, &'static str> {
let nodes = self.nodes();
let depth = *self.depth();
let root = self.root();

if index >= nodes[0].len() {
return Err("The leaf does not exist in this tree");
}

let mut sibling_hashes = vec![Fp::from(0); depth];
let mut sibling_sums = vec![[Fp::from(0); N_ASSETS]; depth];
let mut path_indices = vec![Fp::from(0); depth];
let mut current_index = index;

let leaf = &nodes[0][index];

for level in 0..depth {
let position = current_index % 2;
let level_start_index = current_index - position;
let level_end_index = level_start_index + 2;

path_indices[level] = Fp::from(position as u64);

for i in level_start_index..level_end_index {
if i != current_index {
sibling_hashes[level] = nodes[level][i].hash;
sibling_sums[level] = nodes[level][i].balances;
}
}
current_index /= 2;
}

Ok(MerkleProof {
leaf: leaf.clone(),
root_hash: root.hash,
sibling_hashes,
sibling_sums,
path_indices,
})
}

/// Verifies a MerkleProof.
fn verify_proof(&self, proof: &MerkleProof<N_ASSETS, N_BYTES>) -> bool
where
[usize; N_ASSETS + 1]: Sized,
[usize; 2 * (1 + N_ASSETS)]: Sized,
{
verify_proof(proof)
let mut node = proof.leaf.clone();

let mut balances = proof.leaf.balances;

for i in 0..proof.sibling_hashes.len() {
let sibling_node = Node {
hash: proof.sibling_hashes[i],
balances: proof.sibling_sums[i],
};

if proof.path_indices[i] == 0.into() {
node = Node::middle(&node, &sibling_node);
} else {
node = Node::middle(&sibling_node, &node);
}

for (balance, sibling_balance) in balances.iter_mut().zip(sibling_node.balances.iter())
{
*balance += sibling_balance;
}
}

proof.root_hash == node.hash && balances == node.balances
}

/// Returns the index of the user with the given username and balances in the tree
fn index_of(&self, username: &str, balances: [BigUint; N_ASSETS]) -> Option<usize>
where
[usize; N_ASSETS + 1]: Sized,
{
let entry: Entry<N_ASSETS> = Entry::new(username.to_string(), balances).unwrap();
let leaf = entry.compute_leaf();
let leaf_hash = leaf.hash;

self.nodes()[0]
.iter()
.position(|node| node.hash == leaf_hash)
}
}
44 changes: 0 additions & 44 deletions zk_prover/src/merkle_sum_tree/utils/create_proof.rs

This file was deleted.

17 changes: 0 additions & 17 deletions zk_prover/src/merkle_sum_tree/utils/index_of.rs

This file was deleted.

6 changes: 0 additions & 6 deletions zk_prover/src/merkle_sum_tree/utils/mod.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,11 @@
mod build_tree;
mod create_proof;
mod csv_parser;
mod generate_leaf_hash;
mod hash;
mod index_of;
mod operation_helpers;
mod proof_verification;

pub use build_tree::{build_merkle_tree_from_leaves, compute_leaves};
pub use create_proof::create_proof;
pub use csv_parser::parse_csv_to_entries;
pub use generate_leaf_hash::generate_leaf_hash;
pub use hash::{poseidon_entry, poseidon_node};
pub use index_of::index_of;
pub use operation_helpers::*;
pub use proof_verification::verify_proof;
32 changes: 0 additions & 32 deletions zk_prover/src/merkle_sum_tree/utils/proof_verification.rs

This file was deleted.

0 comments on commit 35478ec

Please sign in to comment.