From 0720a88d4339eb5d9e640e98e8009cf0323f3f85 Mon Sep 17 00:00:00 2001 From: Rute Figueiredo Date: Tue, 13 Aug 2024 12:20:53 +0100 Subject: [PATCH] moved min_degree and min_occurrences from CseConfig to Scorer --- src/compiler/cse.rs | 82 ++++++++++++++++++++------------------------- 1 file changed, 37 insertions(+), 45 deletions(-) diff --git a/src/compiler/cse.rs b/src/compiler/cse.rs index 9cb45962..f8a2e75d 100644 --- a/src/compiler/cse.rs +++ b/src/compiler/cse.rs @@ -21,40 +21,52 @@ use crate::{ #[derive(Clone, Debug)] pub struct CseConfig { - min_degree: usize, - min_occurrences: usize, - max_iterations: usize, } impl Default for CseConfig { fn default() -> Self { Self { - min_degree: 2, - min_occurrences: 2, - max_iterations: 100, } } } #[allow(dead_code)] -pub fn config( - min_degree: usize, - min_occurrences: usize, - max_iterations: Option, -) -> CseConfig { +pub fn config(max_iterations: Option) -> CseConfig { CseConfig { - min_degree, - min_occurrences, max_iterations: max_iterations.unwrap_or(100), } } -pub trait Scorer { +pub trait Scoring { fn score(&self, expr: &Expr, HashResult>, info: &SubexprInfo) -> usize; } +pub struct Scorer { + min_degree: usize, + min_occurrences: usize, +} + +impl Default for Scorer { + fn default() -> Self { + Self { + min_degree: 2, + min_occurrences: 2, + } + } +} + +impl Scoring for Scorer { + fn score(&self, _expr: &Expr, HashResult>, info: &SubexprInfo) -> usize { + if info.degree < self.min_degree || info.count < self.min_occurrences { + 0 + } else { + info.count * info.degree + } + } +} + /// Common Subexpression Elimination (CSE) optimization. /// This optimization replaces common subexpressions with new internal signals for the step type. /// This is done by each time finding the optimal subexpression to replace and creating a new signal @@ -64,7 +76,7 @@ pub trait Scorer { /// queriables. Using the Schwartz-Zippel lemma, we can determine if two expressions are equivalent /// with high probability. #[allow(dead_code)] -pub(super) fn cse>( +pub(super) fn cse>( mut circuit: SBPIRMachine, config: CseConfig, scorer: &S, @@ -75,7 +87,7 @@ pub(super) fn cse>( circuit } -fn cse_for_step>( +fn cse_for_step>( step_type: &mut StepType, forward_signals: &[ForwardSignal], config: &CseConfig, @@ -118,7 +130,7 @@ fn cse_for_step>( // Find the optimal subexpression to replace if let Some(common_expr) = - find_optimal_subexpression(&exprs, &replaced_hashes, config.clone(), scorer) + find_optimal_subexpression(&exprs, &replaced_hashes, scorer) { // Add the hash of the replaced expression to the set replaced_hashes.insert(common_expr.meta().hash); @@ -171,10 +183,9 @@ impl SubexprInfo { } /// Find the optimal subexpression to replace in a list of expressions. -fn find_optimal_subexpression>( +fn find_optimal_subexpression>( exprs: &[Expr, HashResult>], replaced_hashes: &HashSet, - config: CseConfig, scorer: &S, ) -> Option, HashResult>> { let mut count_map = HashMap::::new(); @@ -185,18 +196,8 @@ fn find_optimal_subexpression>( count_subexpressions(expr, &mut count_map, &mut hash_to_expr, replaced_hashes); } - // Find the best common subexpression to replace - let common_ses = count_map - .into_iter() - .filter(|&(hash, info)| { - info.count >= config.min_occurrences - && info.degree >= config.min_degree - && !replaced_hashes.contains(&hash) - }) - .collect::>(); - // Find the best common subexpression to replace based on the score - let best_subexpr = common_ses + let best_subexpr = count_map .iter() .map(|(&hash, info)| { let expr = hash_to_expr.get(&hash).unwrap(); @@ -269,29 +270,21 @@ impl poly::SignalFactory> for SignalFactory { #[cfg(test)] mod test { - use std::{collections::HashSet, hash::Hash}; + use std::collections::HashSet; use halo2_proofs::halo2curves::bn256::Fr; use rand_chacha::{rand_core::SeedableRng, ChaCha20Rng}; use crate::{ - compiler::cse::{cse, CseConfig}, + compiler::cse::{cse, CseConfig, Scorer}, field::Field, - poly::{Expr, HashResult, VarAssignments}, + poly::{Expr, VarAssignments}, sbpir::{query::Queriable, sbpir_machine::SBPIRMachine, InternalSignal, StepType}, util::uuid, wit_gen::NullTraceGenerator, }; - use super::{find_optimal_subexpression, Scorer, SubexprInfo}; - - pub struct TestScorer; - - impl Scorer for TestScorer { - fn score(&self, _expr: &Expr, HashResult>, info: &SubexprInfo) -> usize { - 2 * info.count + 3 * info.degree - } - } + use super::find_optimal_subexpression; #[test] fn test_find_optimal_subexpression() { @@ -322,12 +315,11 @@ mod test { hashed_exprs.push(hashed_expr); } - let scorer = TestScorer; + let scorer = Scorer::default(); let best_expr = find_optimal_subexpression( &hashed_exprs, &HashSet::new(), - CseConfig::default(), &scorer, ); @@ -375,7 +367,7 @@ mod test { let mut circuit: SBPIRMachine = SBPIRMachine::default(); let step_uuid = circuit.add_step_type_def(step); - let scorer = TestScorer; + let scorer = Scorer::default(); let circuit = cse(circuit, CseConfig::default(), &scorer); let common_ses_found_and_replaced = circuit