diff --git a/src/compiler/cse.rs b/src/compiler/cse.rs index b292bc90..882ba791 100644 --- a/src/compiler/cse.rs +++ b/src/compiler/cse.rs @@ -1,10 +1,6 @@ use rand_chacha::{rand_core::SeedableRng, ChaCha20Rng}; -use std::{ - collections::{HashMap, HashSet}, - hash::Hash, - marker::PhantomData, -}; +use std::{collections::HashMap, hash::Hash, marker::PhantomData}; use crate::{ field::Field, @@ -94,7 +90,6 @@ fn cse_for_step>( scorer: &S, ) { let mut signal_factory = SignalFactory::default(); - let mut replaced_hashes = HashSet::new(); for _ in 0..config.max_iterations { // Step 1: Collect all queriables (forward and internal signals) @@ -114,7 +109,7 @@ fn cse_for_step>( .collect(); // Step 5: Find the optimal subexpression to replace - if let Some(common_expr) = find_optimal_subexpression(&exprs, &replaced_hashes, scorer) { + if let Some(common_expr) = find_optimal_subexpression(&exprs, scorer) { // Step 6: Create a new signal for the common subexpression let (common_se, decomp) = create_common_ses_signal(&common_expr, &mut signal_factory); @@ -124,9 +119,6 @@ fn cse_for_step>( decomp, &common_se, ); - - // Step 8: Mark this subexpression as replaced - replaced_hashes.insert(common_expr.meta().hash); } else { // No more common subexpressions found, exit the loop break; @@ -217,7 +209,6 @@ impl SubexprInfo { /// Find the optimal subexpression to replace in a list of expressions. fn find_optimal_subexpression>( exprs: &[Expr, HashResult>], - replaced_hashes: &HashSet, scorer: &S, ) -> Option, HashResult>> { let mut count_map = HashMap::::new(); @@ -225,7 +216,7 @@ fn find_optimal_subexpression>( // Extract all subexpressions and count them for expr in exprs.iter() { - count_subexpressions(expr, &mut count_map, &mut hash_to_expr, replaced_hashes); + count_subexpressions(expr, &mut count_map, &mut hash_to_expr); } // Find the best common subexpression to replace based on the score @@ -248,35 +239,31 @@ fn count_subexpressions( expr: &Expr, HashResult>, count_map: &mut HashMap, hash_to_expr: &mut HashMap, HashResult>>, - replaced_hashes: &HashSet, ) { let degree = expr.degree(); let hash_result = expr.meta().hash; - // Only count and store if not already replaced - if !replaced_hashes.contains(&hash_result) { - // Store the expression with its hash - hash_to_expr.insert(hash_result, expr.clone()); + // Store the expression with its hash + hash_to_expr.insert(hash_result, expr.clone()); - count_map - .entry(hash_result) - .and_modify(|info| info.update(degree)) - .or_insert(SubexprInfo::new(1, degree)); - } + count_map + .entry(hash_result) + .and_modify(|info| info.update(degree)) + .or_insert(SubexprInfo::new(1, degree)); // Recurse into subexpressions match expr { Expr::Const(_, _) | Expr::Query(_, _) => {} Expr::Sum(exprs, _) | Expr::Mul(exprs, _) => { for subexpr in exprs { - count_subexpressions(subexpr, count_map, hash_to_expr, replaced_hashes); + count_subexpressions(subexpr, count_map, hash_to_expr); } } Expr::Neg(subexpr, _) | Expr::MI(subexpr, _) => { - count_subexpressions(subexpr, count_map, hash_to_expr, replaced_hashes); + count_subexpressions(subexpr, count_map, hash_to_expr); } Expr::Pow(subexpr, _, _) => { - count_subexpressions(subexpr, count_map, hash_to_expr, replaced_hashes); + count_subexpressions(subexpr, count_map, hash_to_expr); } _ => {} } @@ -302,7 +289,7 @@ impl poly::SignalFactory> for SignalFactory { #[cfg(test)] mod test { - use std::collections::{HashMap, HashSet}; + use std::collections::HashMap; use halo2_proofs::halo2curves::bn256::Fr; use rand_chacha::{rand_core::SeedableRng, ChaCha20Rng}; @@ -349,7 +336,7 @@ mod test { let scorer = Scorer::default(); - let best_expr = find_optimal_subexpression(&hashed_exprs, &HashSet::new(), &scorer); + let best_expr = find_optimal_subexpression(&hashed_exprs, &scorer); assert_eq!(format!("{:?}", best_expr.unwrap()), "(e * f * d)"); }