Skip to content

Commit

Permalink
feat: update MerkleProof
Browse files Browse the repository at this point in the history
  • Loading branch information
enricobottazzi committed Nov 21, 2023
1 parent d62f231 commit 63f89d0
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 71 deletions.
12 changes: 5 additions & 7 deletions zk_prover/src/merkle_sum_tree/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,15 @@ pub mod utils;
use halo2_proofs::halo2curves::bn256::Fr as Fp;

#[derive(Clone, Debug)]
pub struct MerkleProof<const N_ASSETS: usize, const N_BYTES: usize>
pub struct MerkleProof<const N_ASSETS: usize, const N_BYTES: usize>
where
[usize; N_ASSETS + 1]: Sized,
[usize; N_ASSETS + 2]: Sized,
[usize; N_ASSETS + 1]: Sized,
[usize; N_ASSETS + 2]: Sized,
{
pub leaf: Node<N_ASSETS>,
pub root: Node<N_ASSETS>,
pub sibling_hashes: Vec<Fp>,
pub sibling_sums: Vec<[Fp; N_ASSETS]>,
pub sibling_leaf_hash_preimage: [Fp; N_ASSETS + 1],
pub sibling_node_hash_preimages: Vec<[Fp; N_ASSETS + 2]>,
pub sibling_leaf_node_hash_preimage: [Fp; N_ASSETS + 1],
pub sibling_middle_node_hash_preimages: Vec<[Fp; N_ASSETS + 2]>,
pub path_indices: Vec<Fp>,
}

Expand Down
4 changes: 0 additions & 4 deletions zk_prover/src/merkle_sum_tree/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,6 @@ mod test {
let mut proof_invalid_2 = proof.clone();
proof_invalid_2.root.hash = 0.into();
assert!(!merkle_tree.verify_proof(&proof_invalid_2));

// shouldn't verify a proof with a wrong computed balance
let mut proof_invalid_3 = proof;
proof_invalid_3.sibling_sums[0] = [0.into(), 0.into()];
}

#[test]
Expand Down
93 changes: 33 additions & 60 deletions zk_prover/src/merkle_sum_tree/tree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,19 +103,11 @@ pub trait Tree<const N_ASSETS: usize, const N_BYTES: usize> {
return Err(Box::from("Index out of bounds"));
}

let mut sibling_hashes = vec![Fp::zero(); depth];
let mut sibling_sums = vec![[Fp::zero(); N_ASSETS]; depth];
let mut sibling_node_hash_preimages = Vec::with_capacity(depth - 1);
let mut sibling_middle_node_hash_preimages = Vec::with_capacity(depth - 1);

let sibling_leaf_index = if index % 2 == 0 {
// Leaf is a left child, sibling is the next node
index + 1
} else {
// Leaf is a right child, sibling is the previous node
index - 1
};
let sibling_leaf_index = if index % 2 == 0 { index + 1 } else { index - 1 };

let sibling_leaf_hash_preimage: [Fp; N_ASSETS + 1] =
let sibling_leaf_node_hash_preimage: [Fp; N_ASSETS + 1] =
self.get_leaf_node_hash_preimage(sibling_leaf_index)?;
let mut path_indices = vec![Fp::zero(); depth];
let mut current_index = index;
Expand All @@ -129,14 +121,11 @@ pub trait Tree<const N_ASSETS: usize, const N_BYTES: usize> {
if sibling_index < nodes[level].len() {
let sibling_node = &nodes[level][sibling_index];

sibling_hashes[level] = sibling_node.hash;
sibling_sums[level] = sibling_node.balances;

if level != 0 {
// Fetch hash preimage for sibling middle nodes
let sibling_node_preimage =
self.get_middle_node_hash_preimage(level, sibling_index)?;
sibling_node_hash_preimages.push(sibling_node_preimage);
sibling_middle_node_hash_preimages.push(sibling_node_preimage);
}
}

Expand All @@ -147,10 +136,8 @@ pub trait Tree<const N_ASSETS: usize, const N_BYTES: usize> {
Ok(MerkleProof {
leaf: leaf.clone(),
root: root.clone(),
sibling_hashes,
sibling_sums,
sibling_leaf_hash_preimage,
sibling_node_hash_preimages,
sibling_leaf_node_hash_preimage,
sibling_middle_node_hash_preimages,
path_indices,
})
}
Expand All @@ -163,10 +150,18 @@ pub trait Tree<const N_ASSETS: usize, const N_BYTES: usize> {
{
let mut node = proof.leaf.clone();

// Perform leaf level verification outside of the loop
let sibling_leaf_node_balances = proof.sibling_leaf_node_hash_preimage[1..]
.try_into()
.unwrap();

let sibling_leaf_node_hash = poseidon_entry::<N_ASSETS>(
proof.sibling_leaf_node_hash_preimage[0],
sibling_leaf_node_balances,
);

let sibling_leaf_node = Node {
hash: proof.sibling_hashes[0],
balances: proof.sibling_sums[0],
hash: sibling_leaf_node_hash,
balances: sibling_leaf_node_balances,
};

if proof.path_indices[0] == 0.into() {
Expand All @@ -175,49 +170,27 @@ pub trait Tree<const N_ASSETS: usize, const N_BYTES: usize> {
node = Node::middle(&sibling_leaf_node, &node);
}

// Verify that the balances of the sibling leaf node matches that ones in the sibling leaf hash preimage
for (i, balance) in sibling_leaf_node.balances.iter().enumerate() {
if *balance != proof.sibling_leaf_hash_preimage[i + 1] {
return false;
}
}
for i in 1..proof.path_indices.len() {
let sibling_middle_node_balances = proof.sibling_middle_node_hash_preimages[i - 1]
[0..N_ASSETS]
.try_into()
.unwrap();

// Verify that the hash of the sibling leaf node matches the result of hashing the sibling leaf hash preimage
if sibling_leaf_node.hash
!= poseidon_entry::<N_ASSETS>(
proof.sibling_leaf_hash_preimage[0],
proof.sibling_leaf_hash_preimage[1..].try_into().unwrap(),
)
{
return false;
}
let sibling_middle_node_child_left_hash =
proof.sibling_middle_node_hash_preimages[i - 1][N_ASSETS];

let sibling_middle_node_child_right_hash =
proof.sibling_middle_node_hash_preimages[i - 1][N_ASSETS + 1];

for i in 1..proof.sibling_hashes.len() {
let sibling_node = Node {
hash: proof.sibling_hashes[i],
balances: proof.sibling_sums[i],
hash: poseidon_node::<N_ASSETS>(
sibling_middle_node_balances,
sibling_middle_node_child_left_hash,
sibling_middle_node_child_right_hash,
),
balances: sibling_middle_node_balances,
};

// Verify that the balances of the sibling node matches that ones in the sibling node hash preimage
for (j, balance) in sibling_node.balances.iter().enumerate() {
if *balance != proof.sibling_node_hash_preimages[i - 1][j] {
return false;
}
}

// Verify that the hash of the sibling node matches the result of hashing the sibling node hash preimage
if sibling_node.hash
!= poseidon_node::<N_ASSETS>(
proof.sibling_node_hash_preimages[i - 1][0..N_ASSETS]
.try_into()
.unwrap(),
proof.sibling_node_hash_preimages[i - 1][N_ASSETS],
proof.sibling_node_hash_preimages[i - 1][N_ASSETS + 1],
)
{
return false;
}

if proof.path_indices[i] == 0.into() {
node = Node::middle(&node, &sibling_node);
} else {
Expand Down

0 comments on commit 63f89d0

Please sign in to comment.