From 050998442390aabfb2ffc8d71086bc646f25f4db Mon Sep 17 00:00:00 2001 From: Alex Kuzmin Date: Fri, 24 Nov 2023 13:50:33 +0800 Subject: [PATCH] Add circuit audit comments --- zk_prover/src/chips/merkle_sum_tree.rs | 18 +++++++++++++++++- zk_prover/src/chips/range/range_check.rs | 5 +++-- zk_prover/src/circuits/merkle_sum_tree.rs | 7 +++---- zk_prover/src/circuits/solvency.rs | 6 +++--- zk_prover/src/circuits/utils.rs | 4 ++-- .../src/merkle_sum_tree/utils/csv_parser.rs | 4 ++-- 6 files changed, 30 insertions(+), 14 deletions(-) diff --git a/zk_prover/src/chips/merkle_sum_tree.rs b/zk_prover/src/chips/merkle_sum_tree.rs index b5cd463c..3686e97e 100644 --- a/zk_prover/src/chips/merkle_sum_tree.rs +++ b/zk_prover/src/chips/merkle_sum_tree.rs @@ -20,7 +20,7 @@ pub struct MerkleSumTreeConfig { /// Chip that performs various constraints related to a Merkle Sum Tree data structure such as: /// /// * `s * swap_bit * (1 - swap_bit) = 0` (if `bool_and_swap_selector` is toggled). It basically enforces that swap_bit is either a 0 or 1. -/// * `s * (swap_bit * 2 * (elelment_r_cur - elelment_l_cur) - (elelment_l_next - elelment_l_cur) - (elelment_r_cur - elelment_r_next)) = 0`. Enforces that if the swap_bit is equal to 1, the values will be swapped on the next row (if `bool_and_swap_selector` is toggled). +/// * `s * (swap_bit * 2 * (element_r_cur - element_l_cur) - (element_l_next - element_l_cur) - (element_r_cur - element_r_next)) = 0`. Enforces that if the swap_bit is equal to 1, the values will be swapped on the next row (if `bool_and_swap_selector` is toggled). /// If the swap_bit is equal to 0, the values will remain the same on the next row (if `bool_and_swap_selector` is toggled). /// * `s * (left_balance + right_balance - computed_sum)`. It constraints the computed sum to be equal to the sum of the left and right balances (if `sum_selector` is toggled). #[derive(Debug, Clone)] @@ -59,6 +59,22 @@ impl MerkleSumTreeChip { let element_l_next = meta.query_advice(col_a, Rotation::next()); let element_r_next = meta.query_advice(col_b, Rotation::next()); + // Audit: The constraint is aimed at checking the correct swap of the values. If swap_bit is 1, + // element_l_cur == element_r_next and element_r_cur == element_l_next. + // If swap_bit is 0, element_l_cur == element_l_next and element_r_cur == element_r_next. + // However, if we combine the two equations, there's a potential for unintended solutions + // that satisfy the composite equation without satisfying the intended individual constraints. + // For the case where swap_bit = 0: + // element_l_cur − element_l_next + element_r_next − element_r_cur = 0 + // Here, it's theoretically possible to cheat by picking element_l_cur, element_l_next, + // element_r_cur, and element_r_next such that they don't individually satisfy + // element_l_cur − element_l_next = 0 and element_l_cur − element_l_next = 0 + // but still satisfy the combined equation. + // For example, if element_l_cur = 3 and element_l_next = 1, and element_r_cur = 2 and element_r_next = 4, + // the composite equation would still be satisfied because: + // 3 − 1 + 4 − 2 = 0 + // even though neither element_l_cur = element_l_next nor element_r_cur = element_r_next. + // Splitting the constraint into two separate constraints would prevent this and also make the constraints more human-readable. let swap_constraint = s * ((swap_bit * Expression::Constant(Fp::from(2)) diff --git a/zk_prover/src/chips/range/range_check.rs b/zk_prover/src/chips/range/range_check.rs index 2d8bb7b7..58573e3b 100644 --- a/zk_prover/src/chips/range/range_check.rs +++ b/zk_prover/src/chips/range/range_check.rs @@ -28,7 +28,7 @@ pub struct RangeCheckConfig { lookup_enable_selector: Selector, } -/// Helper chip that verfiies that the value witnessed in a given cell lies within a given range defined by N_BYTES. +/// Helper chip that verifies that the value witnessed in a given cell lies within a given range defined by N_BYTES. /// For example, Let's say we want to constraint 0x1f2f3f4f to be within the range N_BYTES=4. /// /// `z(0) = 0x1f2f3f4f` @@ -47,7 +47,7 @@ pub struct RangeCheckConfig { /// /// The column z contains the witnessed value to be checked at offset 0 /// At offset i, the column z contains the value z(i+1) = (z(i) - k(i)) / 2^8 (shift right by 8 bits) where k(i) is the i-th decomposition big-endian of `value` -/// The contraints that are enforced are: +/// The constraints that are enforced are: /// - z(i) - 2^8⋅z(i+1) ∈ lookup_u8 (enabled by lookup_enable_selector at offset [0, N_BYTES - 1]) /// - z(N_BYTES) == 0 #[derive(Debug, Clone)] @@ -155,6 +155,7 @@ impl RangeCheckChip { } /// Loads the lookup table with values from `0` to `2^8 - 1` + // Audit: Is there a way to share the range check lookup table between all chip instantiations? pub fn load(&self, layouter: &mut impl Layouter) -> Result<(), Error> { let range = 1 << (8); diff --git a/zk_prover/src/circuits/merkle_sum_tree.rs b/zk_prover/src/circuits/merkle_sum_tree.rs index 3ed075c0..77ce5e4c 100644 --- a/zk_prover/src/circuits/merkle_sum_tree.rs +++ b/zk_prover/src/circuits/merkle_sum_tree.rs @@ -222,8 +222,7 @@ where config.poseidon_middle_config.clone(), ); - let range_check_chip = - RangeCheckChip::::construct(config.range_check_config.clone()); + let range_check_chip = RangeCheckChip::::construct(config.range_check_config); // Assign the entry username let username = self.assign_value_to_witness( @@ -302,7 +301,7 @@ where let mut right_balances = vec![]; // Within each level, assign the balances to the circuit per asset - for asset in 0..N_ASSETS { + for (asset, current_balance) in current_balances.iter().enumerate().take(N_ASSETS) { let (left_balance, right_balance, next_balance) = merkle_sum_tree_chip .assign_nodes_balance_per_asset( layouter.namespace(|| { @@ -311,7 +310,7 @@ where namespace_prefix, asset ) }), - ¤t_balances[asset], + current_balance, self.path_element_balances[level][asset], swap_bit_level.clone(), )?; diff --git a/zk_prover/src/circuits/solvency.rs b/zk_prover/src/circuits/solvency.rs index 3d5cde6b..8dc45886 100644 --- a/zk_prover/src/circuits/solvency.rs +++ b/zk_prover/src/circuits/solvency.rs @@ -300,11 +300,11 @@ where let mut right_balances = vec![]; // assign penultimate nodes balances per each asset according to the swap bit - for asset in 0..N_ASSETS { + for (asset, left_node_balance) in left_node_balances.iter().enumerate().take(N_ASSETS) { let (left_balance, right_balance, next_balance) = merkle_sum_tree_chip .assign_nodes_balance_per_asset( layouter.namespace(|| format!("asset {}: assign nodes balances", asset)), - &left_node_balances[asset], + left_node_balance, self.right_node_balances[asset], swap_bit.clone(), )?; @@ -335,7 +335,7 @@ where hash_input, )?; - // expose the root hash, as public input + // expose the root hash as public input self.expose_public( layouter.namespace(|| "public root hash"), &root_hash, diff --git a/zk_prover/src/circuits/utils.rs b/zk_prover/src/circuits/utils.rs index 96bb2bcb..ee1075e3 100644 --- a/zk_prover/src/circuits/utils.rs +++ b/zk_prover/src/circuits/utils.rs @@ -225,7 +225,7 @@ fn fix_verifier_sol(yul_code_path: PathBuf) -> Result()?; let transcript_addr = format!("{:#x}", addr_as_num); transcript_addrs.push(addr_as_num); line = line.replace( @@ -238,7 +238,7 @@ fn fix_verifier_sol(yul_code_path: PathBuf) -> Result()?; let transcript_addr = format!("{:#x}", addr_as_num); transcript_addrs.push(addr_as_num); line = line.replace( diff --git a/zk_prover/src/merkle_sum_tree/utils/csv_parser.rs b/zk_prover/src/merkle_sum_tree/utils/csv_parser.rs index 75bb39a6..e2fcfc61 100644 --- a/zk_prover/src/merkle_sum_tree/utils/csv_parser.rs +++ b/zk_prover/src/merkle_sum_tree/utils/csv_parser.rs @@ -21,7 +21,7 @@ pub fn parse_csv_to_entries, const N_ASSETS: usize, const N_BYTES .delimiter(b';') // The fields are separated by a semicolon .from_reader(file); - let mut balances_acc: Vec = vec![BigUint::from(0 as usize); N_ASSETS]; + let mut balances_acc: Vec = vec![BigUint::from(0_usize); N_ASSETS]; for result in rdr.deserialize() { let record: CsvEntry = result?; @@ -47,7 +47,7 @@ pub fn parse_csv_to_entries, const N_ASSETS: usize, const N_BYTES // Iterate through the balance accumulator and throw error if any balance is not in range 0, 2 ^ (8 * N_BYTES): for balance in balances_acc { - if balance >= BigUint::from(2 as usize).pow(8 * N_BYTES as u32) { + if balance >= BigUint::from(2_usize).pow(8 * N_BYTES as u32) { return Err( "Accumulated balance is not in the expected range, proof generation will fail!" .into(),