diff --git a/zk_prover/src/merkle_sum_tree/node.rs b/zk_prover/src/merkle_sum_tree/node.rs index 662041ae..19da8f0d 100644 --- a/zk_prover/src/merkle_sum_tree/node.rs +++ b/zk_prover/src/merkle_sum_tree/node.rs @@ -17,24 +17,13 @@ impl Node { where [usize; N_ASSETS + 1]: Sized, { - Node { - hash: Self::poseidon_hash_leaf( - big_uint_to_fp(username), - balances - .iter() - .map(big_uint_to_fp) - .collect::>() - .try_into() - .unwrap(), - ), - //Map the array of balances using big_int_to_fp: - balances: balances - .iter() - .map(big_uint_to_fp) - .collect::>() - .try_into() - .unwrap(), + let mut hash_preimage = [Fp::zero(); N_ASSETS + 1]; + hash_preimage[0] = big_uint_to_fp(username); + for (i, balance) in hash_preimage.iter_mut().enumerate().skip(1) { + *balance = big_uint_to_fp(&balances[i - 1]); } + + Node::leaf_node_from_preimage(&hash_preimage) } /// Builds a "middle" (non-leaf-level) node of the MST /// The middle node hash is equal to `H(LeftChild.balance[0] + RightChild.balance[0], LeftChild.balance[1] + RightChild.balance[1], ..., LeftChild.balance[N_ASSETS - 1] + RightChild.balance[N_ASSETS - 1], LeftChild.hash, RightChild.hash)` @@ -43,15 +32,14 @@ impl Node { where [(); N_ASSETS + 2]: Sized, { - let mut balances_sum = [Fp::zero(); N_ASSETS]; - for (i, balance) in balances_sum.iter_mut().enumerate() { + let mut hash_preimage = [Fp::zero(); N_ASSETS + 2]; + for (i, balance) in hash_preimage.iter_mut().enumerate().take(N_ASSETS) { *balance = child_l.balances[i] + child_r.balances[i]; } + hash_preimage[N_ASSETS] = child_l.hash; + hash_preimage[N_ASSETS + 1] = child_r.hash; - Node { - hash: Self::poseidon_hash_middle(balances_sum, child_l.hash, child_r.hash), - balances: balances_sum, - } + Node::middle_node_from_preimage(&hash_preimage) } pub fn init_empty() -> Node @@ -64,7 +52,7 @@ impl Node { } } - pub fn leaf_node_from_preimage(preimage: [Fp; N_ASSETS + 1]) -> Node + pub fn leaf_node_from_preimage(preimage: &[Fp; N_ASSETS + 1]) -> Node where [usize; N_ASSETS + 1]: Sized, { diff --git a/zk_prover/src/merkle_sum_tree/tests.rs b/zk_prover/src/merkle_sum_tree/tests.rs index f2a67068..c00b672b 100644 --- a/zk_prover/src/merkle_sum_tree/tests.rs +++ b/zk_prover/src/merkle_sum_tree/tests.rs @@ -241,7 +241,7 @@ mod test { // Fetch the hash preimage of the leaf let hash_preimage = merkle_tree.get_leaf_node_hash_preimage(index).unwrap(); - let computed_leaf = Node::::leaf_node_from_preimage(hash_preimage); + let computed_leaf = Node::::leaf_node_from_preimage(&hash_preimage); // The hash of the leaf should match the hash computed from the hash preimage assert_eq!(leaf.hash, computed_leaf.hash); diff --git a/zk_prover/src/merkle_sum_tree/tree.rs b/zk_prover/src/merkle_sum_tree/tree.rs index b443b429..7f7fb3ba 100644 --- a/zk_prover/src/merkle_sum_tree/tree.rs +++ b/zk_prover/src/merkle_sum_tree/tree.rs @@ -150,7 +150,7 @@ pub trait Tree { let mut node = proof.entry.compute_leaf(); let sibling_leaf_node = - Node::::leaf_node_from_preimage(proof.sibling_leaf_node_hash_preimage); + Node::::leaf_node_from_preimage(&proof.sibling_leaf_node_hash_preimage); let mut hash_preimage = [Fp::zero(); N_ASSETS + 2]; for (i, balance) in hash_preimage.iter_mut().enumerate().take(N_ASSETS) {