Skip to content

Commit

Permalink
feat: add build_merkle_tree_from_leaves api
Browse files Browse the repository at this point in the history
  • Loading branch information
enricobottazzi committed Nov 1, 2023
1 parent fe19b76 commit e8f68d6
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 64 deletions.
5 changes: 3 additions & 2 deletions zk_prover/src/merkle_sum_tree/aggregation_mst.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::merkle_sum_tree::utils::{
build_merkle_tree_from_roots, create_proof, create_top_tree_proof, verify_proof,
build_merkle_tree_from_leaves, create_proof, create_top_tree_proof, verify_proof,
};
use crate::merkle_sum_tree::{MerkleProof, MerkleSumTree, Node};

Expand All @@ -22,6 +22,7 @@ pub struct AggregationMerkleSumTree<const N_ASSETS: usize, const N_BYTES: usize>

impl<const N_ASSETS: usize, const N_BYTES: usize> AggregationMerkleSumTree<N_ASSETS, N_BYTES> {
/// Builds a AggregationMerkleSumTree from a set of mini MerkleSumTrees
/// The leaves of the AggregationMerkleSumTree are the roots of the mini MerkleSumTrees
pub fn new(
mini_trees: Vec<MerkleSumTree<N_ASSETS, N_BYTES>>,
) -> Result<Self, Box<dyn std::error::Error>>
Expand Down Expand Up @@ -54,7 +55,7 @@ impl<const N_ASSETS: usize, const N_BYTES: usize> AggregationMerkleSumTree<N_ASS
let depth = (roots.len() as f64).log2().ceil() as usize;

let mut nodes = vec![];
let root = build_merkle_tree_from_roots(&roots, depth, &mut nodes)?;
let root = build_merkle_tree_from_leaves(&roots, depth, &mut nodes)?;

Ok(AggregationMerkleSumTree {
root,
Expand Down
9 changes: 7 additions & 2 deletions zk_prover/src/merkle_sum_tree/mst.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
use crate::merkle_sum_tree::utils::{
build_merkle_tree_from_entries, create_proof, index_of, parse_csv_to_entries, verify_proof,
compute_leaves, create_proof, index_of, parse_csv_to_entries, verify_proof,
};
use crate::merkle_sum_tree::{Entry, MerkleProof, Node};
use num_bigint::BigUint;

use super::utils::build_merkle_tree_from_leaves;

/// Merkle Sum Tree Data Structure.
///
/// A Merkle Sum Tree is a binary Merkle Tree with the following properties:
Expand Down Expand Up @@ -68,7 +70,10 @@ impl<const N_ASSETS: usize, const N_BYTES: usize> MerkleSumTree<N_ASSETS, N_BYTE
let depth = (entries.len() as f64).log2().ceil() as usize;

let mut nodes = vec![];
let root = build_merkle_tree_from_entries(&entries, depth, &mut nodes)?;

let leaves = compute_leaves(&entries);

let root = build_merkle_tree_from_leaves(&leaves, depth, &mut nodes)?;

Ok(MerkleSumTree {
root,
Expand Down
69 changes: 10 additions & 59 deletions zk_prover/src/merkle_sum_tree/utils/build_tree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,16 @@ use crate::merkle_sum_tree::{Entry, Node};
use halo2_proofs::halo2curves::bn256::Fr as Fp;
use rayon::prelude::*;

pub fn build_merkle_tree_from_entries<const N_ASSETS: usize>(
entries: &[Entry<N_ASSETS>],
pub fn build_merkle_tree_from_leaves<const N_ASSETS: usize>(
leaves: &[Node<N_ASSETS>],
depth: usize,
nodes: &mut Vec<Vec<Node<N_ASSETS>>>,
) -> Result<Node<N_ASSETS>, Box<dyn std::error::Error>>
where
[usize; N_ASSETS + 1]: Sized,
[usize; 2 * (1 + N_ASSETS)]: Sized,
{
let n = entries.len();
let n = leaves.len();

let mut tree: Vec<Vec<Node<N_ASSETS>>> = Vec::with_capacity(depth + 1);

Expand All @@ -36,7 +36,9 @@ where
]);
}

build_leaves_level(entries, &mut tree);
for (index, leaf) in leaves.iter().enumerate() {
tree[0][index] = leaf.clone();
}

for level in 1..=depth {
build_middle_level(level, &mut tree)
Expand All @@ -47,20 +49,16 @@ where
Ok(root)
}

fn build_leaves_level<const N_ASSETS: usize>(
entries: &[Entry<N_ASSETS>],
tree: &mut [Vec<Node<N_ASSETS>>],
) where
pub fn compute_leaves<const N_ASSETS: usize>(entries: &[Entry<N_ASSETS>]) -> Vec<Node<N_ASSETS>>
where
[usize; N_ASSETS + 1]: Sized,
{
let results = entries
let leaves = entries
.par_iter()
.map(|entry| entry.compute_leaf())
.collect::<Vec<_>>();

for (index, node) in results.iter().enumerate() {
tree[0][index] = node.clone();
}
leaves
}

fn build_middle_level<const N_ASSETS: usize>(level: usize, tree: &mut [Vec<Node<N_ASSETS>>])
Expand All @@ -77,50 +75,3 @@ where
tree[level][index] = new_node;
}
}

pub fn build_merkle_tree_from_roots<const N_ASSETS: usize>(
roots: &[Node<N_ASSETS>],
depth: usize,
nodes: &mut Vec<Vec<Node<N_ASSETS>>>,
) -> Result<Node<N_ASSETS>, Box<dyn std::error::Error>>
where
[usize; N_ASSETS + 1]: Sized,
[usize; 2 * (1 + N_ASSETS)]: Sized,
{
let n = roots.len();

let mut tree: Vec<Vec<Node<N_ASSETS>>> = Vec::with_capacity(depth + 1);

tree.push(vec![
Node {
hash: Fp::from(0),
balances: [Fp::from(0); N_ASSETS]
};
n
]);

for _ in 1..=depth {
let previous_level = tree.last().unwrap();
let nodes_in_level = (previous_level.len() + 1) / 2;

tree.push(vec![
Node {
hash: Fp::from(0),
balances: [Fp::from(0); N_ASSETS]
};
nodes_in_level
]);
}

for (index, node) in roots.iter().enumerate() {
tree[0][index] = node.clone();
}

for level in 1..=depth {
build_middle_level(level, &mut tree)
}

let root = tree[depth][0].clone();
*nodes = tree;
Ok(root)
}
2 changes: 1 addition & 1 deletion zk_prover/src/merkle_sum_tree/utils/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ mod index_of;
mod operation_helpers;
mod proof_verification;

pub use build_tree::{build_merkle_tree_from_entries, build_merkle_tree_from_roots};
pub use build_tree::{build_merkle_tree_from_leaves, compute_leaves};
pub use create_proof::{create_proof, create_top_tree_proof};
pub use csv_parser::parse_csv_to_entries;
pub use generate_leaf_hash::generate_leaf_hash;
Expand Down

0 comments on commit e8f68d6

Please sign in to comment.