Skip to content
This repository has been archived by the owner on Nov 4, 2024. It is now read-only.

Commit

Permalink
cse to work with the new SBPIR and fixed testto not assume a certain …
Browse files Browse the repository at this point in the history
…order for the expressions optimised
  • Loading branch information
rutefig committed Aug 17, 2024
1 parent 1bbdfbd commit 9e4e59f
Show file tree
Hide file tree
Showing 3 changed files with 131 additions and 78 deletions.
159 changes: 122 additions & 37 deletions src/compiler/cse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,12 @@ use crate::{
cse::{create_common_ses_signal, replace_expr},
Expr, HashResult, VarAssignments,
},
sbpir::{
query::Queriable, sbpir_machine::SBPIRMachine, ForwardSignal, InternalSignal, StepType,
},
sbpir::{query::Queriable, ForwardSignal, InternalSignal, StepType, SBPIR},
wit_gen::NullTraceGenerator,
};

#[derive(Clone, Debug)]
pub struct CseConfig {
pub(super) struct CseConfig {
max_iterations: usize,
}

Expand All @@ -39,11 +37,11 @@ pub fn config(max_iterations: Option<usize>) -> CseConfig {
}
}

pub trait Scoring<F: Field + Hash> {
pub(super) trait Scoring<F: Field + Hash> {
fn score(&self, expr: &Expr<F, Queriable<F>, HashResult>, info: &SubexprInfo) -> usize;
}

pub struct Scorer {
pub(super) struct Scorer {
min_degree: usize,
min_occurrences: usize,
}
Expand Down Expand Up @@ -77,12 +75,14 @@ impl<F: Field + Hash> Scoring<F> for Scorer {
/// with high probability.
#[allow(dead_code)]
pub(super) fn cse<F: Field + Hash, S: Scoring<F>>(
mut circuit: SBPIRMachine<F, NullTraceGenerator>,
mut circuit: SBPIR<F, NullTraceGenerator>,
config: CseConfig,
scorer: &S,
) -> SBPIRMachine<F, NullTraceGenerator> {
for (_, step_type) in circuit.step_types.iter_mut() {
cse_for_step(step_type, &circuit.forward_signals, &config, scorer)
) -> SBPIR<F, NullTraceGenerator> {
for (_, machine) in circuit.machines.iter_mut() {
for (_, step_type) in machine.step_types.iter_mut() {
cse_for_step(step_type, &machine.forward_signals, &config, scorer)
}
}
circuit
}
Expand Down Expand Up @@ -268,7 +268,7 @@ impl<F> poly::SignalFactory<Queriable<F>> for SignalFactory<F> {

#[cfg(test)]
mod test {
use std::collections::HashSet;
use std::collections::{HashMap, HashSet};

use halo2_proofs::halo2curves::bn256::Fr;
use rand_chacha::{rand_core::SeedableRng, ChaCha20Rng};
Expand All @@ -277,7 +277,7 @@ mod test {
compiler::cse::{cse, CseConfig, Scorer},
field::Field,
poly::{Expr, VarAssignments},
sbpir::{query::Queriable, sbpir_machine::SBPIRMachine, InternalSignal, StepType},
sbpir::{query::Queriable, sbpir_machine::SBPIRMachine, InternalSignal, StepType, SBPIR},
util::uuid,
wit_gen::NullTraceGenerator,
};
Expand Down Expand Up @@ -358,41 +358,126 @@ mod test {
step.add_constr("expr4".into(), expr4.clone());
step.add_constr("expr5".into(), expr5);

let mut circuit: SBPIRMachine<Fr, NullTraceGenerator> = SBPIRMachine::default();
let step_uuid = circuit.add_step_type_def(step);
let mut machine: SBPIRMachine<Fr, NullTraceGenerator> = SBPIRMachine::default();
let step_uuid = machine.add_step_type_def(step);
let mut machines = HashMap::new();
machines.insert(uuid(), machine);
let circuit = SBPIR {
machines,
identifiers: HashMap::new(),
};

let scorer = Scorer::default();
let circuit = cse(circuit, CseConfig::default(), &scorer);

let common_ses_found_and_replaced = circuit
.step_types
.get(&step_uuid)
.unwrap()
.auto_signals
.values();
let machine = circuit.machines.iter().next().unwrap().1;
let step = machine.step_types.get(&step_uuid).unwrap();

// Check if CSE was applied
assert!(
step.auto_signals.len() > 0,
"No common subexpressions were found"
);

// Helper function to check if an expression contains a CSE signal
fn contains_cse_signal(expr: &Expr<Fr, Queriable<Fr>, ()>) -> bool {
match expr {
Expr::Query(Queriable::Internal(signal), _) => {
signal.annotation().starts_with("cse-")
}
Expr::Sum(exprs, _) | Expr::Mul(exprs, _) => exprs.iter().any(contains_cse_signal),
Expr::Neg(sub_expr, _) => contains_cse_signal(sub_expr),
_ => false,
}
}

assert!(circuit
.step_types
.get(&step_uuid)
.unwrap()
// Check if at least one constraint contains a CSE signal
let has_cse_constraint = step
.constraints
.iter()
.any(|expr| format!("{:?}", expr.expr) == "((e * f * d) + (-cse-1))"));
.any(|constraint| contains_cse_signal(&constraint.expr));
assert!(has_cse_constraint, "No constraints with CSE signals found");

assert!(circuit
.step_types
.get(&step_uuid)
.unwrap()
// Check for specific optimizations without relying on exact CSE signal names
let has_optimized_efg = step
.constraints
.iter()
.any(|expr| format!("{:?}", expr.expr) == "((a * b) + (-cse-2))"));

assert!(common_ses_found_and_replaced
.clone()
.any(|expr| format!("{:?}", &expr) == "(a * b)"));
assert!(common_ses_found_and_replaced
.clone()
.any(|expr| format!("{:?}", &expr) == "(e * f * d)"));
.any(|constraint| match &constraint.expr {
Expr::Sum(terms, _) => {
terms.iter().any(|term| match term {
Expr::Mul(factors, _) => {
factors.len() == 3
&& factors
.iter()
.all(|f| matches!(f, Expr::Query(Queriable::Internal(_), _)))
}
_ => false,
}) && terms.iter().any(contains_cse_signal)
}
_ => false,
});
assert!(
has_optimized_efg,
"Expected optimization for (e * f * d) not found"
);

let has_optimized_ab = step
.constraints
.iter()
.any(|constraint| match &constraint.expr {
Expr::Sum(terms, _) => {
terms.iter().any(|term| match term {
Expr::Mul(factors, _) => {
factors.len() == 2
&& factors
.iter()
.all(|f| matches!(f, Expr::Query(Queriable::Internal(_), _)))
}
_ => false,
}) && terms.iter().any(contains_cse_signal)
}
_ => false,
});
assert!(
has_optimized_ab,
"Expected optimization for (a * b) not found"
);

// Check if the common subexpressions were actually created
let cse_signals: Vec<_> = step
.auto_signals
.values()
.filter(|expr| matches!(expr, Expr::Mul(_, _)))
.collect();

assert!(
cse_signals.len() >= 2,
"Expected at least two multiplication CSEs"
);

let has_ab_cse = cse_signals.iter().any(|expr| {
if let Expr::Mul(factors, _) = expr {
factors.len() == 2
&& factors
.iter()
.all(|f| matches!(f, Expr::Query(Queriable::Internal(_), _)))
} else {
false
}
});
assert!(has_ab_cse, "CSE for (a * b) not found in auto_signals");

let has_efg_cse = cse_signals.iter().any(|expr| {
if let Expr::Mul(factors, _) = expr {
factors.len() == 3
&& factors
.iter()
.all(|f| matches!(f, Expr::Query(Queriable::Internal(_), _)))
} else {
false
}
});
assert!(has_efg_cse, "CSE for (e * f * d) not found in auto_signals");
}

#[derive(Clone)]
Expand Down
38 changes: 3 additions & 35 deletions src/sbpir/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -231,38 +231,6 @@ impl<F, TG: TraceGenerator<F>, M> SBPIRLegacy<F, TG, M> {
}
}

impl<F: Field + Hash, TG: TraceGenerator<F> + Clone, M: Clone> SBPIRLegacy<F, TG, M> {
pub fn transform_meta<N: Clone, ApplyMetaFn>(
&self,
apply_meta: ApplyMetaFn,
) -> SBPIRLegacy<F, TG, N>
where
ApplyMetaFn: Fn(&Expr<F, Queriable<F>, M>) -> N + Clone,
{
SBPIRLegacy {
step_types: self
.step_types
.iter()
.map(|(k, v)| (*k, v.transform_meta(apply_meta.clone())))
.collect(),
forward_signals: self.forward_signals.clone(),
shared_signals: self.shared_signals.clone(),
fixed_signals: self.fixed_signals.clone(),
halo2_advice: self.halo2_advice.clone(),
halo2_fixed: self.halo2_fixed.clone(),
exposed: self.exposed.clone(),
annotations: self.annotations.clone(),
trace_generator: self.trace_generator.clone(),
fixed_assignments: self.fixed_assignments.clone(),
first_step: self.first_step,
last_step: self.last_step,
num_steps: self.num_steps,
q_enable: self.q_enable,
id: self.id,
}
}
}

impl<F: Field, TraceArgs: Clone> SBPIRLegacy<F, DSLTraceGenerator<F, TraceArgs>> {
pub fn set_trace<D>(&mut self, def: D)
where
Expand Down Expand Up @@ -301,16 +269,16 @@ impl<F: Clone + Field, TG: TraceGenerator<F>> SBPIRLegacy<F, TG> {
}
}

pub struct SBPIR<F, TG: TraceGenerator<F> = DSLTraceGenerator<F>> {
pub machines: HashMap<UUID, SBPIRMachine<F, TG>>,
pub struct SBPIR<F, TG: TraceGenerator<F> = DSLTraceGenerator<F>, M = ()> {
pub machines: HashMap<UUID, SBPIRMachine<F, TG, M>>,
pub identifiers: HashMap<String, UUID>,
}

impl<F, TG: TraceGenerator<F>> SBPIR<F, TG> {
pub(crate) fn from_legacy(circuit: SBPIRLegacy<F, TG>, machine_id: &str) -> SBPIR<F, TG> {
let mut machines = HashMap::new();
let circuit_id = circuit.id;
machines.insert(circuit_id, SBPIRMachine::from_legacy(circuit));
machines.insert(circuit_id, SBPIRMachine::<F, TG, ()>::from_legacy(circuit));
let mut identifiers = HashMap::new();
identifiers.insert(machine_id.to_string(), circuit_id);
SBPIR {
Expand Down
12 changes: 6 additions & 6 deletions src/sbpir/sbpir_machine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ pub struct SBPIRMachine<F, TG: TraceGenerator<F> = DSLTraceGenerator<F>, M = ()>
pub id: UUID,
}

impl<F: Debug, TG: TraceGenerator<F>> Debug for SBPIRMachine<F, TG> {
impl<F: Debug, TG: TraceGenerator<F>, M: Debug> Debug for SBPIRMachine<F, TG, M> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("Circuit")
.field("step_types", &self.step_types)
Expand All @@ -61,7 +61,7 @@ impl<F: Debug, TG: TraceGenerator<F>> Debug for SBPIRMachine<F, TG> {
}
}

impl<F, TG: TraceGenerator<F>> Default for SBPIRMachine<F, TG> {
impl<F, TG: TraceGenerator<F>, M> Default for SBPIRMachine<F, TG, M> {
fn default() -> Self {
Self {
step_types: Default::default(),
Expand All @@ -88,7 +88,7 @@ impl<F, TG: TraceGenerator<F>> Default for SBPIRMachine<F, TG> {
}
}

impl<F, TG: TraceGenerator<F>> SBPIRMachine<F, TG> {
impl<F, TG: TraceGenerator<F>, M> SBPIRMachine<F, TG, M> {
pub fn add_forward<N: Into<String>>(&mut self, name: N, phase: usize) -> ForwardSignal {
let name = name.into();
let signal = ForwardSignal::new_with_phase(phase, name.clone());
Expand Down Expand Up @@ -171,7 +171,7 @@ impl<F, TG: TraceGenerator<F>> SBPIRMachine<F, TG> {
self.annotations.insert(handler.uuid(), name.into());
}

pub fn add_step_type_def(&mut self, step: StepType<F, ()>) -> StepTypeUUID {
pub fn add_step_type_def(&mut self, step: StepType<F, M>) -> StepTypeUUID {
let uuid = step.uuid();
self.step_types.insert(uuid, step);

Expand All @@ -187,7 +187,7 @@ impl<F, TG: TraceGenerator<F>> SBPIRMachine<F, TG> {
}
}

pub fn without_trace(self) -> SBPIRMachine<F, NullTraceGenerator> {
pub fn without_trace(self) -> SBPIRMachine<F, NullTraceGenerator, M> {
SBPIRMachine {
step_types: self.step_types,
forward_signals: self.forward_signals,
Expand All @@ -208,7 +208,7 @@ impl<F, TG: TraceGenerator<F>> SBPIRMachine<F, TG> {
}

#[allow(dead_code)] // TODO: Copy of the legacy SBPIR code. Remove if not used in the new compilation
pub(crate) fn with_trace<TG2: TraceGenerator<F>>(self, trace: TG2) -> SBPIRMachine<F, TG2> {
pub(crate) fn with_trace<TG2: TraceGenerator<F>>(self, trace: TG2) -> SBPIRMachine<F, TG2, M> {
SBPIRMachine {
trace_generator: Some(trace), // Change trace
step_types: self.step_types,
Expand Down

0 comments on commit 9e4e59f

Please sign in to comment.