Skip to content
This repository has been archived by the owner on Nov 4, 2024. It is now read-only.

Commit

Permalink
fixed test_cse and unnecessary constraint with the auto_signal
Browse files Browse the repository at this point in the history
  • Loading branch information
rutefig committed Aug 19, 2024
1 parent 0b42b48 commit 06627bc
Showing 1 changed file with 51 additions and 90 deletions.
141 changes: 51 additions & 90 deletions src/compiler/cse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,8 @@ fn cse_for_step<F: Field + Hash, S: Scoring<F>>(
// Step 6: Create a new signal for the common subexpression
let (common_se, decomp) = create_common_ses_signal(&common_expr, &mut signal_factory);

println!("Decomp: {:#?}", decomp);

// Step 7: Update the step type with the new common subexpression
update_step_type_with_common_subexpression(
&mut step_type_with_hash,
Expand Down Expand Up @@ -176,7 +178,6 @@ fn update_step_type_with_common_subexpression<F: Field + Hash>(
step_type.add_internal(*signal);
}
step_type.auto_signals.insert(*q, expr.clone());
step_type.add_constr(format!("{:?}", q), expr.clone());
}
for expr in &decomp.constrs {
step_type.add_constr(format!("{:?}", expr), expr.clone());
Expand Down Expand Up @@ -400,105 +401,65 @@ mod test {
"No common subexpressions were found"
);

// Helper function to check if an expression contains a CSE signal
fn contains_cse_signal(expr: &Expr<Fr, Queriable<Fr>, ()>) -> bool {
match expr {
Expr::Query(Queriable::Internal(signal), _) => {
signal.annotation().starts_with("cse-")
}
Expr::Sum(exprs, _) | Expr::Mul(exprs, _) => exprs.iter().any(contains_cse_signal),
Expr::Neg(sub_expr, _) => contains_cse_signal(sub_expr),
_ => false,
}
}

// Check if at least one constraint contains a CSE signal
let has_cse_constraint = step
.constraints
let cse_signals: Vec<_> = step
.auto_signals
.iter()
.any(|constraint| contains_cse_signal(&constraint.expr));
assert!(has_cse_constraint, "No constraints with CSE signals found");
.filter(|(_, expr)| matches!(expr, Expr::Mul(_, _)))
.collect();

// Check for specific optimizations without relying on exact CSE signal names
let has_optimized_efg = step
.constraints
// Check for (a * b) in auto_signals by cse
let (ab_signal, ab_expr) = cse_signals
.iter()
.any(|constraint| match &constraint.expr {
Expr::Sum(terms, _) => {
terms.iter().any(|term| match term {
Expr::Mul(factors, _) => {
factors.len() == 3
&& factors
.iter()
.all(|f| matches!(f, Expr::Query(Queriable::Internal(_), _)))
}
_ => false,
}) && terms.iter().any(contains_cse_signal)
.find(|(_, expr)| {
if let Expr::Mul(factors, _) = expr {
factors.len() == 2
&& factors
.iter()
.all(|f| matches!(f, Expr::Query(Queriable::Internal(_), _)))
} else {
false
}
_ => false,
});
assert!(
has_optimized_efg,
"Expected optimization for (e * f * d) not found"
);
})
.unwrap();

let has_optimized_ab = step
.constraints
// Check for (e * f * d) in auto_signals by cse
let (efd_signal, efd_expr) = cse_signals
.iter()
.any(|constraint| match &constraint.expr {
Expr::Sum(terms, _) => {
terms.iter().any(|term| match term {
Expr::Mul(factors, _) => {
factors.len() == 2
&& factors
.iter()
.all(|f| matches!(f, Expr::Query(Queriable::Internal(_), _)))
}
_ => false,
}) && terms.iter().any(contains_cse_signal)
.find(|(_, expr)| {
if let Expr::Mul(factors, _) = expr {
factors.len() == 3
&& factors
.iter()
.all(|f| matches!(f, Expr::Query(Queriable::Internal(_), _)))
} else {
false
}
_ => false,
});
assert!(
has_optimized_ab,
"Expected optimization for (a * b) not found"
);

// Check if the common subexpressions were actually created
let cse_signals: Vec<_> = step
.auto_signals
.values()
.filter(|expr| matches!(expr, Expr::Mul(_, _)))
.collect();
})
.unwrap();

assert!(
cse_signals.len() >= 2,
"Expected at least two multiplication CSEs"
);

let has_ab_cse = cse_signals.iter().any(|expr| {
if let Expr::Mul(factors, _) = expr {
factors.len() == 2
&& factors
.iter()
.all(|f| matches!(f, Expr::Query(Queriable::Internal(_), _)))
} else {
false
}
// Assert that step has efd_expr - efd_signal in constraints
let efd_expr_in_constraints = step.constraints.iter().any(|constraint| {
format!("{:?}", constraint.expr) == format!("({:?} + (-{:?}))", efd_expr, efd_signal)
});
assert!(has_ab_cse, "CSE for (a * b) not found in auto_signals");

let has_efg_cse = cse_signals.iter().any(|expr| {
if let Expr::Mul(factors, _) = expr {
factors.len() == 3
&& factors
.iter()
.all(|f| matches!(f, Expr::Query(Queriable::Internal(_), _)))
} else {
false
}
assert!(efd_expr_in_constraints);

// Assert that step has ab_signal - ab_expr in constraints
let ab_expr_in_constraints = step.constraints.iter().any(|constraint| {
format!("{:?}", constraint.expr) == format!("({:?} + (-{:?}))", ab_expr, ab_signal)
});
assert!(has_efg_cse, "CSE for (e * f * d) not found in auto_signals");
assert!(ab_expr_in_constraints);

// Assert that (a * b) only appears once in the constraints
let ab_expr_count = step.constraints.iter().filter(|constraint| {
format!("{:?}", constraint.expr) == format!("({:?} + (-{:?}))", ab_expr, ab_signal)
}).count();
assert_eq!(ab_expr_count, 1);

// Assert that (e * f * d) only appears once in the constraints
let efd_expr_count = step.constraints.iter().filter(|constraint| {
format!("{:?}", constraint.expr) == format!("({:?} + (-{:?}))", efd_expr, efd_signal)
}).count();
assert_eq!(efd_expr_count, 1);
}

#[derive(Clone)]
Expand Down

0 comments on commit 06627bc

Please sign in to comment.