diff --git a/src/compiler/cse.rs b/src/compiler/cse.rs index 882ba791..a7f5f51e 100644 --- a/src/compiler/cse.rs +++ b/src/compiler/cse.rs @@ -113,6 +113,8 @@ fn cse_for_step>( // Step 6: Create a new signal for the common subexpression let (common_se, decomp) = create_common_ses_signal(&common_expr, &mut signal_factory); + println!("Decomp: {:#?}", decomp); + // Step 7: Update the step type with the new common subexpression update_step_type_with_common_subexpression( &mut step_type_with_hash, @@ -176,7 +178,6 @@ fn update_step_type_with_common_subexpression( step_type.add_internal(*signal); } step_type.auto_signals.insert(*q, expr.clone()); - step_type.add_constr(format!("{:?}", q), expr.clone()); } for expr in &decomp.constrs { step_type.add_constr(format!("{:?}", expr), expr.clone()); @@ -400,105 +401,65 @@ mod test { "No common subexpressions were found" ); - // Helper function to check if an expression contains a CSE signal - fn contains_cse_signal(expr: &Expr, ()>) -> bool { - match expr { - Expr::Query(Queriable::Internal(signal), _) => { - signal.annotation().starts_with("cse-") - } - Expr::Sum(exprs, _) | Expr::Mul(exprs, _) => exprs.iter().any(contains_cse_signal), - Expr::Neg(sub_expr, _) => contains_cse_signal(sub_expr), - _ => false, - } - } - - // Check if at least one constraint contains a CSE signal - let has_cse_constraint = step - .constraints + let cse_signals: Vec<_> = step + .auto_signals .iter() - .any(|constraint| contains_cse_signal(&constraint.expr)); - assert!(has_cse_constraint, "No constraints with CSE signals found"); + .filter(|(_, expr)| matches!(expr, Expr::Mul(_, _))) + .collect(); - // Check for specific optimizations without relying on exact CSE signal names - let has_optimized_efg = step - .constraints + // Check for (a * b) in auto_signals by cse + let (ab_signal, ab_expr) = cse_signals .iter() - .any(|constraint| match &constraint.expr { - Expr::Sum(terms, _) => { - terms.iter().any(|term| match term { - Expr::Mul(factors, _) => { - factors.len() == 3 - && factors - .iter() - .all(|f| matches!(f, Expr::Query(Queriable::Internal(_), _))) - } - _ => false, - }) && terms.iter().any(contains_cse_signal) + .find(|(_, expr)| { + if let Expr::Mul(factors, _) = expr { + factors.len() == 2 + && factors + .iter() + .all(|f| matches!(f, Expr::Query(Queriable::Internal(_), _))) + } else { + false } - _ => false, - }); - assert!( - has_optimized_efg, - "Expected optimization for (e * f * d) not found" - ); + }) + .unwrap(); - let has_optimized_ab = step - .constraints + // Check for (e * f * d) in auto_signals by cse + let (efd_signal, efd_expr) = cse_signals .iter() - .any(|constraint| match &constraint.expr { - Expr::Sum(terms, _) => { - terms.iter().any(|term| match term { - Expr::Mul(factors, _) => { - factors.len() == 2 - && factors - .iter() - .all(|f| matches!(f, Expr::Query(Queriable::Internal(_), _))) - } - _ => false, - }) && terms.iter().any(contains_cse_signal) + .find(|(_, expr)| { + if let Expr::Mul(factors, _) = expr { + factors.len() == 3 + && factors + .iter() + .all(|f| matches!(f, Expr::Query(Queriable::Internal(_), _))) + } else { + false } - _ => false, - }); - assert!( - has_optimized_ab, - "Expected optimization for (a * b) not found" - ); - - // Check if the common subexpressions were actually created - let cse_signals: Vec<_> = step - .auto_signals - .values() - .filter(|expr| matches!(expr, Expr::Mul(_, _))) - .collect(); + }) + .unwrap(); - assert!( - cse_signals.len() >= 2, - "Expected at least two multiplication CSEs" - ); - - let has_ab_cse = cse_signals.iter().any(|expr| { - if let Expr::Mul(factors, _) = expr { - factors.len() == 2 - && factors - .iter() - .all(|f| matches!(f, Expr::Query(Queriable::Internal(_), _))) - } else { - false - } + // Assert that step has efd_expr - efd_signal in constraints + let efd_expr_in_constraints = step.constraints.iter().any(|constraint| { + format!("{:?}", constraint.expr) == format!("({:?} + (-{:?}))", efd_expr, efd_signal) }); - assert!(has_ab_cse, "CSE for (a * b) not found in auto_signals"); - - let has_efg_cse = cse_signals.iter().any(|expr| { - if let Expr::Mul(factors, _) = expr { - factors.len() == 3 - && factors - .iter() - .all(|f| matches!(f, Expr::Query(Queriable::Internal(_), _))) - } else { - false - } + assert!(efd_expr_in_constraints); + + // Assert that step has ab_signal - ab_expr in constraints + let ab_expr_in_constraints = step.constraints.iter().any(|constraint| { + format!("{:?}", constraint.expr) == format!("({:?} + (-{:?}))", ab_expr, ab_signal) }); - assert!(has_efg_cse, "CSE for (e * f * d) not found in auto_signals"); + assert!(ab_expr_in_constraints); + + // Assert that (a * b) only appears once in the constraints + let ab_expr_count = step.constraints.iter().filter(|constraint| { + format!("{:?}", constraint.expr) == format!("({:?} + (-{:?}))", ab_expr, ab_signal) + }).count(); + assert_eq!(ab_expr_count, 1); + + // Assert that (e * f * d) only appears once in the constraints + let efd_expr_count = step.constraints.iter().filter(|constraint| { + format!("{:?}", constraint.expr) == format!("({:?} + (-{:?}))", efd_expr, efd_signal) + }).count(); + assert_eq!(efd_expr_count, 1); } #[derive(Clone)]