Skip to content

Commit

Permalink
feat: tree building rules
Browse files Browse the repository at this point in the history
  • Loading branch information
enricobottazzi committed Nov 21, 2023
1 parent e6cdd58 commit e0e8ce2
Show file tree
Hide file tree
Showing 5 changed files with 20 additions and 28 deletions.
12 changes: 6 additions & 6 deletions zk_prover/src/merkle_sum_tree/mst.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ use num_bigint::BigUint;
///
/// A Merkle Sum Tree is a binary Merkle Tree with the following properties:
/// * Each Entry of a Merkle Sum Tree is a pair of a username and #N_ASSETS balances.
/// * Each Leaf Node contains a hash and #N_ASSETS balances. The hash is equal to `H(username, balance[0], balance[1], ... balance[N_ASSETS])`.
/// * Each Middle Node contains a hash and #N_ASSETS balances. The hash is equal to `H(LeftChild.hash, LeftChild.balance[0], LeftChild.balance[1], LeftChild.balance[N_ASSETS], RightChild.hash, RightChild.balance[0], RightChild.balance[1], RightChild.balance[N_ASSETS])`. The balances are equal to the sum of the balances of the child nodes per each asset.
/// * Each Leaf Node contains a hash and #N_ASSETS balances. The hash is equal to `H(username, balance[0], balance[1], ... balance[N_ASSETS])`. The balances are equal to the balances associated to the entry
/// * Each Middle Node contains a hash and #N_ASSETS balances. The hash is equal to `H(LeftChild.balance[0] + RightChild.balance[0], LeftChild.balance[1] + RightChild.balance[1], ..., LeftChild.balance[N_ASSETS] + RightChild.balance[N_ASSETS], LeftChild.hash, RightChild.hash)`. The balances are equal to the sum of the balances of the child nodes per each asset.
/// * The Root Node represents the committed state of the Tree and contains the sum of all the entries' balances per each asset.
///
/// # Type Parameters
Expand Down Expand Up @@ -58,7 +58,7 @@ impl<const N_ASSETS: usize, const N_BYTES: usize> MerkleSumTree<N_ASSETS, N_BYTE
pub fn new(path: &str) -> Result<Self, Box<dyn std::error::Error>>
where
[usize; N_ASSETS + 1]: Sized,
[usize; 2 * (1 + N_ASSETS)]: Sized,
[usize; N_ASSETS + 2]: Sized,
{
let entries = parse_csv_to_entries::<&str, N_ASSETS, N_BYTES>(path)?;
Self::from_entries(entries, false)
Expand All @@ -72,7 +72,7 @@ impl<const N_ASSETS: usize, const N_BYTES: usize> MerkleSumTree<N_ASSETS, N_BYTE
pub fn new_sorted(path: &str) -> Result<Self, Box<dyn std::error::Error>>
where
[usize; N_ASSETS + 1]: Sized,
[usize; 2 * (1 + N_ASSETS)]: Sized,
[usize; N_ASSETS + 2]: Sized,
{
let mut entries = parse_csv_to_entries::<&str, N_ASSETS, N_BYTES>(path)?;

Expand All @@ -87,7 +87,7 @@ impl<const N_ASSETS: usize, const N_BYTES: usize> MerkleSumTree<N_ASSETS, N_BYTE
) -> Result<MerkleSumTree<N_ASSETS, N_BYTES>, Box<dyn std::error::Error>>
where
[usize; N_ASSETS + 1]: Sized,
[usize; 2 * (1 + N_ASSETS)]: Sized,
[usize; N_ASSETS + 2]: Sized,
{
let depth = (entries.len() as f64).log2().ceil() as usize;

Expand Down Expand Up @@ -123,7 +123,7 @@ impl<const N_ASSETS: usize, const N_BYTES: usize> MerkleSumTree<N_ASSETS, N_BYTE
) -> Result<Node<N_ASSETS>, Box<dyn std::error::Error>>
where
[usize; N_ASSETS + 1]: Sized,
[usize; 2 * (1 + N_ASSETS)]: Sized,
[usize; N_ASSETS + 2]: Sized,
{
let index = self.index_of_username(username)?;

Expand Down
9 changes: 2 additions & 7 deletions zk_prover/src/merkle_sum_tree/node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,20 +15,15 @@ impl<const N_ASSETS: usize> Node<N_ASSETS> {
/// Builds a "middle" (non-leaf-level) node of the MST
pub fn middle(child_l: &Node<N_ASSETS>, child_r: &Node<N_ASSETS>) -> Node<N_ASSETS>
where
[usize; 2 * (1 + N_ASSETS)]: Sized,
[(); N_ASSETS + 2]: Sized,
{
let mut balances_sum = [Fp::zero(); N_ASSETS];
for (i, balance) in balances_sum.iter_mut().enumerate() {
*balance = child_l.balances[i] + child_r.balances[i];
}

Node {
hash: poseidon_node(
child_l.hash,
child_l.balances,
child_r.hash,
child_r.balances,
),
hash: poseidon_node(balances_sum, child_l.hash, child_r.hash),
balances: balances_sum,
}
}
Expand Down
3 changes: 1 addition & 2 deletions zk_prover/src/merkle_sum_tree/tree.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use crate::merkle_sum_tree::{Entry, MerkleProof, Node};
use halo2_proofs::halo2curves::bn256::Fr as Fp;
use num_bigint::BigUint;

/// A trait representing the basic operations for a Merkle-Sum-like Tree.
pub trait Tree<const N_ASSETS: usize, const N_BYTES: usize> {
Expand Down Expand Up @@ -64,7 +63,7 @@ pub trait Tree<const N_ASSETS: usize, const N_BYTES: usize> {
fn verify_proof(&self, proof: &MerkleProof<N_ASSETS, N_BYTES>) -> bool
where
[usize; N_ASSETS + 1]: Sized,
[usize; 2 * (1 + N_ASSETS)]: Sized,
[usize; N_ASSETS + 2]: Sized,
{
let mut node = proof.leaf.clone();

Expand Down
4 changes: 2 additions & 2 deletions zk_prover/src/merkle_sum_tree/utils/build_tree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ pub fn build_merkle_tree_from_leaves<const N_ASSETS: usize>(
) -> Result<Node<N_ASSETS>, Box<dyn std::error::Error>>
where
[usize; N_ASSETS + 1]: Sized,
[usize; 2 * (1 + N_ASSETS)]: Sized,
[usize; N_ASSETS + 2]: Sized,
{
let n = leaves.len();

Expand Down Expand Up @@ -65,7 +65,7 @@ where

fn build_middle_level<const N_ASSETS: usize>(level: usize, tree: &mut [Vec<Node<N_ASSETS>>])
where
[usize; 2 * (1 + N_ASSETS)]: Sized,
[usize; N_ASSETS + 2]: Sized,
{
let results: Vec<Node<N_ASSETS>> = (0..tree[level - 1].len())
.into_par_iter()
Expand Down
20 changes: 9 additions & 11 deletions zk_prover/src/merkle_sum_tree/utils/hash.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,20 @@ use halo2_gadgets::poseidon::primitives::{self as poseidon, ConstantLength};
use halo2_proofs::halo2curves::bn256::Fr as Fp;

pub fn poseidon_node<const N_ASSETS: usize>(
l1: Fp,
l2: [Fp; N_ASSETS],
r1: Fp,
r2: [Fp; N_ASSETS],
balances_sum: [Fp; N_ASSETS],
hash_child_left: Fp,
hash_child_right: Fp,
) -> Fp
where
[usize; 2 * (1 + N_ASSETS)]: Sized,
[usize; N_ASSETS + 2]: Sized,
{
let mut hash_inputs: [Fp; 2 * (1 + N_ASSETS)] = [Fp::zero(); 2 * (1 + N_ASSETS)];
let mut hash_inputs: [Fp; N_ASSETS + 2] = [Fp::zero(); N_ASSETS + 2];

hash_inputs[0] = l1;
hash_inputs[1..N_ASSETS + 1].copy_from_slice(&l2);
hash_inputs[N_ASSETS + 1] = r1;
hash_inputs[N_ASSETS + 2..2 * N_ASSETS + 2].copy_from_slice(&r2);
hash_inputs[0..N_ASSETS].copy_from_slice(&balances_sum);
hash_inputs[N_ASSETS] = hash_child_left;
hash_inputs[N_ASSETS + 1] = hash_child_right;

poseidon::Hash::<Fp, PoseidonSpec, ConstantLength<{ 2 * (1 + N_ASSETS) }>, 2, 1>::init()
poseidon::Hash::<Fp, PoseidonSpec, ConstantLength<{ N_ASSETS + 2 }>, 2, 1>::init()
.hash(hash_inputs)
}

Expand Down

0 comments on commit e0e8ce2

Please sign in to comment.