diff --git a/src/compiler/cse.rs b/src/compiler/cse.rs index 8c1fc5e1..e5742884 100644 --- a/src/compiler/cse.rs +++ b/src/compiler/cse.rs @@ -13,14 +13,12 @@ use crate::{ cse::{create_common_ses_signal, replace_expr}, Expr, HashResult, VarAssignments, }, - sbpir::{ - query::Queriable, sbpir_machine::SBPIRMachine, ForwardSignal, InternalSignal, StepType, - }, + sbpir::{query::Queriable, ForwardSignal, InternalSignal, StepType, SBPIR}, wit_gen::NullTraceGenerator, }; #[derive(Clone, Debug)] -pub struct CseConfig { +pub(super) struct CseConfig { max_iterations: usize, } @@ -39,11 +37,11 @@ pub fn config(max_iterations: Option) -> CseConfig { } } -pub trait Scoring { +pub(super) trait Scoring { fn score(&self, expr: &Expr, HashResult>, info: &SubexprInfo) -> usize; } -pub struct Scorer { +pub(super) struct Scorer { min_degree: usize, min_occurrences: usize, } @@ -77,12 +75,14 @@ impl Scoring for Scorer { /// with high probability. #[allow(dead_code)] pub(super) fn cse>( - mut circuit: SBPIRMachine, + mut circuit: SBPIR, config: CseConfig, scorer: &S, -) -> SBPIRMachine { - for (_, step_type) in circuit.step_types.iter_mut() { - cse_for_step(step_type, &circuit.forward_signals, &config, scorer) +) -> SBPIR { + for (_, machine) in circuit.machines.iter_mut() { + for (_, step_type) in machine.step_types.iter_mut() { + cse_for_step(step_type, &machine.forward_signals, &config, scorer) + } } circuit } @@ -268,7 +268,7 @@ impl poly::SignalFactory> for SignalFactory { #[cfg(test)] mod test { - use std::collections::HashSet; + use std::collections::{HashMap, HashSet}; use halo2_proofs::halo2curves::bn256::Fr; use rand_chacha::{rand_core::SeedableRng, ChaCha20Rng}; @@ -277,7 +277,7 @@ mod test { compiler::cse::{cse, CseConfig, Scorer}, field::Field, poly::{Expr, VarAssignments}, - sbpir::{query::Queriable, sbpir_machine::SBPIRMachine, InternalSignal, StepType}, + sbpir::{query::Queriable, sbpir_machine::SBPIRMachine, InternalSignal, StepType, SBPIR}, util::uuid, wit_gen::NullTraceGenerator, }; @@ -358,41 +358,126 @@ mod test { step.add_constr("expr4".into(), expr4.clone()); step.add_constr("expr5".into(), expr5); - let mut circuit: SBPIRMachine = SBPIRMachine::default(); - let step_uuid = circuit.add_step_type_def(step); + let mut machine: SBPIRMachine = SBPIRMachine::default(); + let step_uuid = machine.add_step_type_def(step); + let mut machines = HashMap::new(); + machines.insert(uuid(), machine); + let circuit = SBPIR { + machines, + identifiers: HashMap::new(), + }; let scorer = Scorer::default(); let circuit = cse(circuit, CseConfig::default(), &scorer); - let common_ses_found_and_replaced = circuit - .step_types - .get(&step_uuid) - .unwrap() - .auto_signals - .values(); + let machine = circuit.machines.iter().next().unwrap().1; + let step = machine.step_types.get(&step_uuid).unwrap(); + + // Check if CSE was applied + assert!( + step.auto_signals.len() > 0, + "No common subexpressions were found" + ); + + // Helper function to check if an expression contains a CSE signal + fn contains_cse_signal(expr: &Expr, ()>) -> bool { + match expr { + Expr::Query(Queriable::Internal(signal), _) => { + signal.annotation().starts_with("cse-") + } + Expr::Sum(exprs, _) | Expr::Mul(exprs, _) => exprs.iter().any(contains_cse_signal), + Expr::Neg(sub_expr, _) => contains_cse_signal(sub_expr), + _ => false, + } + } - assert!(circuit - .step_types - .get(&step_uuid) - .unwrap() + // Check if at least one constraint contains a CSE signal + let has_cse_constraint = step .constraints .iter() - .any(|expr| format!("{:?}", expr.expr) == "((e * f * d) + (-cse-1))")); + .any(|constraint| contains_cse_signal(&constraint.expr)); + assert!(has_cse_constraint, "No constraints with CSE signals found"); - assert!(circuit - .step_types - .get(&step_uuid) - .unwrap() + // Check for specific optimizations without relying on exact CSE signal names + let has_optimized_efg = step .constraints .iter() - .any(|expr| format!("{:?}", expr.expr) == "((a * b) + (-cse-2))")); - - assert!(common_ses_found_and_replaced - .clone() - .any(|expr| format!("{:?}", &expr) == "(a * b)")); - assert!(common_ses_found_and_replaced - .clone() - .any(|expr| format!("{:?}", &expr) == "(e * f * d)")); + .any(|constraint| match &constraint.expr { + Expr::Sum(terms, _) => { + terms.iter().any(|term| match term { + Expr::Mul(factors, _) => { + factors.len() == 3 + && factors + .iter() + .all(|f| matches!(f, Expr::Query(Queriable::Internal(_), _))) + } + _ => false, + }) && terms.iter().any(contains_cse_signal) + } + _ => false, + }); + assert!( + has_optimized_efg, + "Expected optimization for (e * f * d) not found" + ); + + let has_optimized_ab = step + .constraints + .iter() + .any(|constraint| match &constraint.expr { + Expr::Sum(terms, _) => { + terms.iter().any(|term| match term { + Expr::Mul(factors, _) => { + factors.len() == 2 + && factors + .iter() + .all(|f| matches!(f, Expr::Query(Queriable::Internal(_), _))) + } + _ => false, + }) && terms.iter().any(contains_cse_signal) + } + _ => false, + }); + assert!( + has_optimized_ab, + "Expected optimization for (a * b) not found" + ); + + // Check if the common subexpressions were actually created + let cse_signals: Vec<_> = step + .auto_signals + .values() + .filter(|expr| matches!(expr, Expr::Mul(_, _))) + .collect(); + + assert!( + cse_signals.len() >= 2, + "Expected at least two multiplication CSEs" + ); + + let has_ab_cse = cse_signals.iter().any(|expr| { + if let Expr::Mul(factors, _) = expr { + factors.len() == 2 + && factors + .iter() + .all(|f| matches!(f, Expr::Query(Queriable::Internal(_), _))) + } else { + false + } + }); + assert!(has_ab_cse, "CSE for (a * b) not found in auto_signals"); + + let has_efg_cse = cse_signals.iter().any(|expr| { + if let Expr::Mul(factors, _) = expr { + factors.len() == 3 + && factors + .iter() + .all(|f| matches!(f, Expr::Query(Queriable::Internal(_), _))) + } else { + false + } + }); + assert!(has_efg_cse, "CSE for (e * f * d) not found in auto_signals"); } #[derive(Clone)] diff --git a/src/sbpir/mod.rs b/src/sbpir/mod.rs index 254d18c9..ac7055f2 100644 --- a/src/sbpir/mod.rs +++ b/src/sbpir/mod.rs @@ -231,38 +231,6 @@ impl, M> SBPIRLegacy { } } -impl + Clone, M: Clone> SBPIRLegacy { - pub fn transform_meta( - &self, - apply_meta: ApplyMetaFn, - ) -> SBPIRLegacy - where - ApplyMetaFn: Fn(&Expr, M>) -> N + Clone, - { - SBPIRLegacy { - step_types: self - .step_types - .iter() - .map(|(k, v)| (*k, v.transform_meta(apply_meta.clone()))) - .collect(), - forward_signals: self.forward_signals.clone(), - shared_signals: self.shared_signals.clone(), - fixed_signals: self.fixed_signals.clone(), - halo2_advice: self.halo2_advice.clone(), - halo2_fixed: self.halo2_fixed.clone(), - exposed: self.exposed.clone(), - annotations: self.annotations.clone(), - trace_generator: self.trace_generator.clone(), - fixed_assignments: self.fixed_assignments.clone(), - first_step: self.first_step, - last_step: self.last_step, - num_steps: self.num_steps, - q_enable: self.q_enable, - id: self.id, - } - } -} - impl SBPIRLegacy> { pub fn set_trace(&mut self, def: D) where @@ -301,8 +269,8 @@ impl> SBPIRLegacy { } } -pub struct SBPIR = DSLTraceGenerator> { - pub machines: HashMap>, +pub struct SBPIR = DSLTraceGenerator, M = ()> { + pub machines: HashMap>, pub identifiers: HashMap, } @@ -310,7 +278,7 @@ impl> SBPIR { pub(crate) fn from_legacy(circuit: SBPIRLegacy, machine_id: &str) -> SBPIR { let mut machines = HashMap::new(); let circuit_id = circuit.id; - machines.insert(circuit_id, SBPIRMachine::from_legacy(circuit)); + machines.insert(circuit_id, SBPIRMachine::::from_legacy(circuit)); let mut identifiers = HashMap::new(); identifiers.insert(machine_id.to_string(), circuit_id); SBPIR { diff --git a/src/sbpir/sbpir_machine.rs b/src/sbpir/sbpir_machine.rs index 87cb681b..57e518de 100644 --- a/src/sbpir/sbpir_machine.rs +++ b/src/sbpir/sbpir_machine.rs @@ -41,7 +41,7 @@ pub struct SBPIRMachine = DSLTraceGenerator, M = ()> pub id: UUID, } -impl> Debug for SBPIRMachine { +impl, M: Debug> Debug for SBPIRMachine { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("Circuit") .field("step_types", &self.step_types) @@ -61,7 +61,7 @@ impl> Debug for SBPIRMachine { } } -impl> Default for SBPIRMachine { +impl, M> Default for SBPIRMachine { fn default() -> Self { Self { step_types: Default::default(), @@ -88,7 +88,7 @@ impl> Default for SBPIRMachine { } } -impl> SBPIRMachine { +impl, M> SBPIRMachine { pub fn add_forward>(&mut self, name: N, phase: usize) -> ForwardSignal { let name = name.into(); let signal = ForwardSignal::new_with_phase(phase, name.clone()); @@ -171,7 +171,7 @@ impl> SBPIRMachine { self.annotations.insert(handler.uuid(), name.into()); } - pub fn add_step_type_def(&mut self, step: StepType) -> StepTypeUUID { + pub fn add_step_type_def(&mut self, step: StepType) -> StepTypeUUID { let uuid = step.uuid(); self.step_types.insert(uuid, step); @@ -187,7 +187,7 @@ impl> SBPIRMachine { } } - pub fn without_trace(self) -> SBPIRMachine { + pub fn without_trace(self) -> SBPIRMachine { SBPIRMachine { step_types: self.step_types, forward_signals: self.forward_signals, @@ -208,7 +208,7 @@ impl> SBPIRMachine { } #[allow(dead_code)] // TODO: Copy of the legacy SBPIR code. Remove if not used in the new compilation - pub(crate) fn with_trace>(self, trace: TG2) -> SBPIRMachine { + pub(crate) fn with_trace>(self, trace: TG2) -> SBPIRMachine { SBPIRMachine { trace_generator: Some(trace), // Change trace step_types: self.step_types,