Skip to content
This repository has been archived by the owner on Nov 4, 2024. It is now read-only.

Commit

Permalink
added trait for CSE Scorer
Browse files Browse the repository at this point in the history
  • Loading branch information
rutefig committed Aug 9, 2024
1 parent 6b5e4aa commit 823bf4b
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 32 deletions.
84 changes: 57 additions & 27 deletions src/compiler/cse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@ use crate::{
cse::{create_common_ses_signal, replace_expr},
Expr, HashResult, VarAssignments,
},
sbpir::{query::Queriable, sbpir_machine::SBPIRMachine, ForwardSignal, InternalSignal, StepType},
sbpir::{
query::Queriable, sbpir_machine::SBPIRMachine, ForwardSignal, InternalSignal, StepType,
},
wit_gen::NullTraceGenerator,
};

Expand All @@ -37,14 +39,22 @@ impl Default for CseConfig {
}

#[allow(dead_code)]
pub fn config(min_degree: usize, min_occurrences: usize, max_iterations: Option<usize>) -> CseConfig {
pub fn config(
min_degree: usize,
min_occurrences: usize,
max_iterations: Option<usize>,
) -> CseConfig {
CseConfig {
min_degree,
min_occurrences,
max_iterations: max_iterations.unwrap_or(100),
}
}

/// Strategy for ranking candidate common subexpressions.
///
/// `find_optimal_subexpression` calls [`Scorer::score`] for every candidate that passed the
/// `CseConfig` filters, drops candidates whose score is `0`, and picks the candidate with the
/// highest score.
pub trait Scorer<F: Field + Hash> {
/// Returns the score of `expr`; `info` carries the occurrence count and degree collected for
/// that subexpression. A score of `0` excludes the candidate from selection.
fn score(&self, expr: &Expr<F, Queriable<F>, HashResult>, info: &SubexprInfo) -> usize;
}

/// Common Subexpression Elimination (CSE) optimization.
/// This optimization replaces common subexpressions with new internal signals for the step type.
/// This is done by repeatedly finding the optimal subexpression to replace and creating a new signal
Expand All @@ -54,20 +64,22 @@ pub fn config(min_degree: usize, min_occurrences: usize, max_iterations: Option<
/// queriables. Using the Schwartz-Zippel lemma, we can determine if two expressions are equivalent
/// with high probability.
#[allow(dead_code)]
pub(super) fn cse<F: Field + Hash>(
pub(super) fn cse<F: Field + Hash, S: Scorer<F>>(
mut circuit: SBPIRMachine<F, NullTraceGenerator>,
config: CseConfig,
scorer: &S,
) -> SBPIRMachine<F, NullTraceGenerator> {
for (_, step_type) in circuit.step_types.iter_mut() {
cse_for_step(step_type, &circuit.forward_signals, &config)
cse_for_step(step_type, &circuit.forward_signals, &config, scorer)
}
circuit
}

fn cse_for_step<F: Field + Hash>(
fn cse_for_step<F: Field + Hash, S: Scorer<F>>(
step_type: &mut StepType<F, ()>,
forward_signals: &[ForwardSignal],
config: &CseConfig,
scorer: &S,
) {
let mut signal_factory = SignalFactory::default();
let mut replaced_hashes = HashSet::new();
Expand Down Expand Up @@ -105,7 +117,9 @@ fn cse_for_step<F: Field + Hash>(
}

// Find the optimal subexpression to replace
if let Some(common_expr) = find_optimal_subexpression(&exprs, &replaced_hashes, config.clone()) {
if let Some(common_expr) =
find_optimal_subexpression(&exprs, &replaced_hashes, config.clone(), scorer)
{
// Add the hash of the replaced expression to the set
replaced_hashes.insert(common_expr.meta().hash);
// Create a new signal for the common subexpression
Expand Down Expand Up @@ -139,7 +153,7 @@ fn cse_for_step<F: Field + Hash>(
}

#[derive(Debug, Clone, Copy)]
struct SubexprInfo {
pub(super) struct SubexprInfo {
count: usize,
degree: usize,
}
Expand All @@ -154,18 +168,14 @@ impl SubexprInfo {
self.count += 1;
self.degree = self.degree.max(degree);
}

fn get_score(&self) -> usize {
// TODO: Improve the scoring function and adjust the weights
2 * self.count + 3 * self.degree
}
}

/// Find the optimal subexpression to replace in a list of expressions.
fn find_optimal_subexpression<F: Field + Hash>(
fn find_optimal_subexpression<F: Field + Hash, S: Scorer<F>>(
exprs: &[Expr<F, Queriable<F>, HashResult>],
replaced_hashes: &HashSet<u64>,
config: CseConfig
config: CseConfig,
scorer: &S,
) -> Option<Expr<F, Queriable<F>, HashResult>> {
let mut count_map = HashMap::<u64, SubexprInfo>::new();
let mut hash_to_expr = HashMap::<u64, Expr<F, Queriable<F>, HashResult>>::new();
Expand All @@ -179,21 +189,25 @@ fn find_optimal_subexpression<F: Field + Hash>(
let common_ses = count_map
.into_iter()
.filter(|&(hash, info)| {
info.count >= config.min_occurrences && info.degree >= config.min_degree && !replaced_hashes.contains(&hash)
info.count >= config.min_occurrences
&& info.degree >= config.min_degree
&& !replaced_hashes.contains(&hash)
})
.collect::<HashMap<_, _>>();

// Find the best common subexpression to replace based on the score
let best_subexpr = common_ses
.iter()
.max_by_key(|&(_, info)| info.get_score())
.map(|(&hash, info)| (hash, info.count, info.degree));
.map(|(&hash, info)| {
let expr = hash_to_expr.get(&hash).unwrap();
let score = scorer.score(expr, info);
(hash, score)
})
.filter(|&(_, score)| score > 0)
.max_by_key(|&(_, score)| score)
.map(|(hash, _)| hash);

if let Some((hash, _count, _degree)) = best_subexpr {
hash_to_expr.get(&hash).cloned()
} else {
None
}
best_subexpr.and_then(|hash| hash_to_expr.get(&hash).cloned())
}

/// Count the subexpressions in an expression and store them in a map.
Expand Down Expand Up @@ -255,21 +269,29 @@ impl<F> poly::SignalFactory<Queriable<F>> for SignalFactory<F> {

#[cfg(test)]
mod test {
use std::collections::HashSet;
use std::{collections::HashSet, hash::Hash};

use halo2_proofs::halo2curves::bn256::Fr;
use rand_chacha::{rand_core::SeedableRng, ChaCha20Rng};

use crate::{
compiler::cse::{cse, CseConfig},
field::Field,
poly::{Expr, VarAssignments},
poly::{Expr, HashResult, VarAssignments},
sbpir::{query::Queriable, sbpir_machine::SBPIRMachine, InternalSignal, StepType},
util::uuid,
wit_gen::NullTraceGenerator,
};

use super::find_optimal_subexpression;
use super::{find_optimal_subexpression, Scorer, SubexprInfo};

/// Stateless [`Scorer`] used by the tests in this module.
pub struct TestScorer;

impl<F: Field + Hash> Scorer<F> for TestScorer {
// Ignores the expression itself and scores only on the collected statistics:
// 2 * occurrence count + 3 * degree.
fn score(&self, _expr: &Expr<F, Queriable<F>, HashResult>, info: &SubexprInfo) -> usize {
2 * info.count + 3 * info.degree
}
}

#[test]
fn test_find_optimal_subexpression() {
Expand Down Expand Up @@ -300,7 +322,14 @@ mod test {
hashed_exprs.push(hashed_expr);
}

let best_expr = find_optimal_subexpression(&hashed_exprs, &HashSet::new(), CseConfig::default());
let scorer = TestScorer;

let best_expr = find_optimal_subexpression(
&hashed_exprs,
&HashSet::new(),
CseConfig::default(),
&scorer,
);

assert_eq!(format!("{:?}", best_expr.unwrap()), "(e * f * d)");
}
Expand Down Expand Up @@ -346,7 +375,8 @@ mod test {
let mut circuit: SBPIRMachine<Fr, NullTraceGenerator> = SBPIRMachine::default();
let step_uuid = circuit.add_step_type_def(step);

let circuit = cse(circuit, CseConfig::default());
let scorer = TestScorer;
let circuit = cse(circuit, CseConfig::default(), &scorer);

let common_ses_found_and_replaced = circuit
.step_types
Expand Down
11 changes: 7 additions & 4 deletions src/poly/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,16 +111,19 @@ impl<F: Field + Hash, V: Debug + Clone + Eq + Hash, M: Clone> Expr<F, V, M> {
.collect(),
new_meta,
),
Expr::Neg(se, _) => Expr::Neg(Box::new(se.transform_meta(apply_meta.clone())), new_meta),
Expr::Pow(se, exp, _) => {
Expr::Pow(Box::new(se.transform_meta(apply_meta.clone())), *exp, new_meta)
Expr::Neg(se, _) => {
Expr::Neg(Box::new(se.transform_meta(apply_meta.clone())), new_meta)
}
Expr::Pow(se, exp, _) => Expr::Pow(
Box::new(se.transform_meta(apply_meta.clone())),
*exp,
new_meta,
),
Expr::Query(v, _) => Expr::Query(v.clone(), new_meta),
Expr::Halo2Expr(e, _) => Expr::Halo2Expr(e.clone(), new_meta),
Expr::MI(se, _) => Expr::MI(Box::new(se.transform_meta(apply_meta.clone())), new_meta),
}
}


pub fn apply_subexpressions<T>(&self, mut f: T) -> Self
where
Expand Down
5 changes: 4 additions & 1 deletion src/sbpir/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,10 @@ impl<F, TG: TraceGenerator<F>, M> SBPIRLegacy<F, TG, M> {
}

impl<F: Field + Hash, TG: TraceGenerator<F> + Clone, M: Clone> SBPIRLegacy<F, TG, M> {
pub fn transform_meta<N: Clone, ApplyMetaFn>(&self, apply_meta: ApplyMetaFn) -> SBPIRLegacy<F, TG, N>
pub fn transform_meta<N: Clone, ApplyMetaFn>(
&self,
apply_meta: ApplyMetaFn,
) -> SBPIRLegacy<F, TG, N>
where
ApplyMetaFn: Fn(&Expr<F, Queriable<F>, M>) -> N + Clone,
{
Expand Down

0 comments on commit 823bf4b

Please sign in to comment.