diff --git a/src/compiler/cse.rs b/src/compiler/cse.rs index 3cca6824..f7f0d9a9 100644 --- a/src/compiler/cse.rs +++ b/src/compiler/cse.rs @@ -20,6 +20,14 @@ use crate::{ wit_gen::NullTraceGenerator, }; + +/// Common Subexpression Elimination (CSE) optimization. +/// This optimization replaces common subexpressions with new internal signals for the step type. +/// This is done by each time finding the optimal subexpression to replace and creating a new signal +/// for it and replacing it in all constraints. +/// The process is repeated until no more common subexpressions are found. +/// Equivalent expressions are found by hashing the expressions with random assignments to the queriables. Using +/// the Schwartz-Zippel lemma, we can determine if two expressions are equivalent with high probability. pub(super) fn cse( mut circuit: SBPIR, ) -> SBPIR { @@ -67,6 +75,7 @@ pub(super) fn cse( let (common_se, decomp) = create_common_ses_signal(&common_expr, &mut signal_factory); + // Add the new signal to the step type and a constraint for it decomp.auto_signals.iter().for_each(|(q, expr)| { if let Queriable::Internal(signal) = q { step_type_with_hash.add_internal(signal.clone()); @@ -196,16 +205,6 @@ impl poly::SignalFactory> for SignalFactory { } } -// cse -// 0. collects a set of expressions Expr -// 1. turns all Expr into Expr -// 2. traverse all the expressions and find common subexpressions counting the number of times they -// appear -// 3. Sort the common subexpressions by the degree and number of times they appear -// 4. Replace the best common subexpression with a signal -// 5. Repeat until no common subexpressions are found -// - #[cfg(test)] mod test { use std::collections::HashSet; diff --git a/src/poly/cse.rs b/src/poly/cse.rs index 3daf4489..41f9908c 100644 --- a/src/poly/cse.rs +++ b/src/poly/cse.rs @@ -3,6 +3,7 @@ use crate::field::Field; use super::{ConstrDecomp, Expr, HashResult, SignalFactory}; use std::{fmt::Debug, hash::Hash}; +/// This function replaces a common subexpression in an expression with a new signal. pub fn replace_expr>( expr: &Expr, common_se: &Expr, @@ -15,6 +16,7 @@ pub fn replace_expr>( expr: &Expr, common_se: &Expr, @@ -44,13 +47,11 @@ fn replace_subexpr