Skip to content
This repository has been archived by the owner on Nov 4, 2024. It is now read-only.

Commit

Permalink
removed unnecessary HashSet to store replaced expression hashes
Browse files Browse the repository at this point in the history
  • Loading branch information
rutefig committed Aug 17, 2024
1 parent 6f5a220 commit 0b42b48
Showing 1 changed file with 14 additions and 27 deletions.
41 changes: 14 additions & 27 deletions src/compiler/cse.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,6 @@
use rand_chacha::{rand_core::SeedableRng, ChaCha20Rng};

use std::{
collections::{HashMap, HashSet},
hash::Hash,
marker::PhantomData,
};
use std::{collections::HashMap, hash::Hash, marker::PhantomData};

use crate::{
field::Field,
Expand Down Expand Up @@ -94,7 +90,6 @@ fn cse_for_step<F: Field + Hash, S: Scoring<F>>(
scorer: &S,
) {
let mut signal_factory = SignalFactory::default();
let mut replaced_hashes = HashSet::new();

for _ in 0..config.max_iterations {
// Step 1: Collect all queriables (forward and internal signals)
Expand All @@ -114,7 +109,7 @@ fn cse_for_step<F: Field + Hash, S: Scoring<F>>(
.collect();

// Step 5: Find the optimal subexpression to replace
if let Some(common_expr) = find_optimal_subexpression(&exprs, &replaced_hashes, scorer) {
if let Some(common_expr) = find_optimal_subexpression(&exprs, scorer) {
// Step 6: Create a new signal for the common subexpression
let (common_se, decomp) = create_common_ses_signal(&common_expr, &mut signal_factory);

Expand All @@ -124,9 +119,6 @@ fn cse_for_step<F: Field + Hash, S: Scoring<F>>(
decomp,
&common_se,
);

// Step 8: Mark this subexpression as replaced
replaced_hashes.insert(common_expr.meta().hash);
} else {
// No more common subexpressions found, exit the loop
break;
Expand Down Expand Up @@ -217,15 +209,14 @@ impl SubexprInfo {
/// Find the optimal subexpression to replace in a list of expressions.
fn find_optimal_subexpression<F: Field + Hash, S: Scoring<F>>(
exprs: &[Expr<F, Queriable<F>, HashResult>],
replaced_hashes: &HashSet<u64>,
scorer: &S,
) -> Option<Expr<F, Queriable<F>, HashResult>> {
let mut count_map = HashMap::<u64, SubexprInfo>::new();
let mut hash_to_expr = HashMap::<u64, Expr<F, Queriable<F>, HashResult>>::new();

// Extract all subexpressions and count them
for expr in exprs.iter() {
count_subexpressions(expr, &mut count_map, &mut hash_to_expr, replaced_hashes);
count_subexpressions(expr, &mut count_map, &mut hash_to_expr);
}

// Find the best common subexpression to replace based on the score
Expand All @@ -248,35 +239,31 @@ fn count_subexpressions<F: Field + Hash>(
expr: &Expr<F, Queriable<F>, HashResult>,
count_map: &mut HashMap<u64, SubexprInfo>,
hash_to_expr: &mut HashMap<u64, Expr<F, Queriable<F>, HashResult>>,
replaced_hashes: &HashSet<u64>,
) {
let degree = expr.degree();
let hash_result = expr.meta().hash;

// Only count and store if not already replaced
if !replaced_hashes.contains(&hash_result) {
// Store the expression with its hash
hash_to_expr.insert(hash_result, expr.clone());
// Store the expression with its hash
hash_to_expr.insert(hash_result, expr.clone());

count_map
.entry(hash_result)
.and_modify(|info| info.update(degree))
.or_insert(SubexprInfo::new(1, degree));
}
count_map
.entry(hash_result)
.and_modify(|info| info.update(degree))
.or_insert(SubexprInfo::new(1, degree));

// Recurse into subexpressions
match expr {
Expr::Const(_, _) | Expr::Query(_, _) => {}
Expr::Sum(exprs, _) | Expr::Mul(exprs, _) => {
for subexpr in exprs {
count_subexpressions(subexpr, count_map, hash_to_expr, replaced_hashes);
count_subexpressions(subexpr, count_map, hash_to_expr);
}
}
Expr::Neg(subexpr, _) | Expr::MI(subexpr, _) => {
count_subexpressions(subexpr, count_map, hash_to_expr, replaced_hashes);
count_subexpressions(subexpr, count_map, hash_to_expr);
}
Expr::Pow(subexpr, _, _) => {
count_subexpressions(subexpr, count_map, hash_to_expr, replaced_hashes);
count_subexpressions(subexpr, count_map, hash_to_expr);
}
_ => {}
}
Expand All @@ -302,7 +289,7 @@ impl<F> poly::SignalFactory<Queriable<F>> for SignalFactory<F> {

#[cfg(test)]
mod test {
use std::collections::{HashMap, HashSet};
use std::collections::HashMap;

use halo2_proofs::halo2curves::bn256::Fr;
use rand_chacha::{rand_core::SeedableRng, ChaCha20Rng};
Expand Down Expand Up @@ -349,7 +336,7 @@ mod test {

let scorer = Scorer::default();

let best_expr = find_optimal_subexpression(&hashed_exprs, &HashSet::new(), &scorer);
let best_expr = find_optimal_subexpression(&hashed_exprs, &scorer);

assert_eq!(format!("{:?}", best_expr.unwrap()), "(e * f * d)");
}
Expand Down

0 comments on commit 0b42b48

Please sign in to comment.