Skip to content
This repository has been archived by the owner on Nov 4, 2024. It is now read-only.

Commit

Permalink
moved min_degree and min_occurrences from CseConfig to Scorer
Browse files Browse the repository at this point in the history
  • Loading branch information
rutefig committed Aug 13, 2024
1 parent 823bf4b commit 0720a88
Showing 1 changed file with 37 additions and 45 deletions.
82 changes: 37 additions & 45 deletions src/compiler/cse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,40 +21,52 @@ use crate::{

#[derive(Clone, Debug)]
pub struct CseConfig {
min_degree: usize,
min_occurrences: usize,

max_iterations: usize,
}

impl Default for CseConfig {
fn default() -> Self {
Self {
min_degree: 2,
min_occurrences: 2,

max_iterations: 100,
}
}
}

#[allow(dead_code)]
pub fn config(
min_degree: usize,
min_occurrences: usize,
max_iterations: Option<usize>,
) -> CseConfig {
pub fn config(max_iterations: Option<usize>) -> CseConfig {
CseConfig {
min_degree,
min_occurrences,
max_iterations: max_iterations.unwrap_or(100),
}
}

pub trait Scorer<F: Field + Hash> {
pub trait Scoring<F: Field + Hash> {
fn score(&self, expr: &Expr<F, Queriable<F>, HashResult>, info: &SubexprInfo) -> usize;
}

pub struct Scorer {
min_degree: usize,
min_occurrences: usize,
}

impl Default for Scorer {
fn default() -> Self {
Self {
min_degree: 2,
min_occurrences: 2,
}
}
}

impl<F: Field + Hash> Scoring<F> for Scorer {
fn score(&self, _expr: &Expr<F, Queriable<F>, HashResult>, info: &SubexprInfo) -> usize {
if info.degree < self.min_degree || info.count < self.min_occurrences {
0
} else {
info.count * info.degree
}
}
}

/// Common Subexpression Elimination (CSE) optimization.
/// This optimization replaces common subexpressions with new internal signals for the step type.
/// This is done by each time finding the optimal subexpression to replace and creating a new signal
Expand All @@ -64,7 +76,7 @@ pub trait Scorer<F: Field + Hash> {
/// queriables. Using the Schwartz-Zippel lemma, we can determine if two expressions are equivalent
/// with high probability.
#[allow(dead_code)]
pub(super) fn cse<F: Field + Hash, S: Scorer<F>>(
pub(super) fn cse<F: Field + Hash, S: Scoring<F>>(
mut circuit: SBPIRMachine<F, NullTraceGenerator>,
config: CseConfig,
scorer: &S,
Expand All @@ -75,7 +87,7 @@ pub(super) fn cse<F: Field + Hash, S: Scorer<F>>(
circuit
}

fn cse_for_step<F: Field + Hash, S: Scorer<F>>(
fn cse_for_step<F: Field + Hash, S: Scoring<F>>(
step_type: &mut StepType<F, ()>,
forward_signals: &[ForwardSignal],
config: &CseConfig,
Expand Down Expand Up @@ -118,7 +130,7 @@ fn cse_for_step<F: Field + Hash, S: Scorer<F>>(

// Find the optimal subexpression to replace
if let Some(common_expr) =
find_optimal_subexpression(&exprs, &replaced_hashes, config.clone(), scorer)
find_optimal_subexpression(&exprs, &replaced_hashes, scorer)
{
// Add the hash of the replaced expression to the set
replaced_hashes.insert(common_expr.meta().hash);
Expand Down Expand Up @@ -171,10 +183,9 @@ impl SubexprInfo {
}

/// Find the optimal subexpression to replace in a list of expressions.
fn find_optimal_subexpression<F: Field + Hash, S: Scorer<F>>(
fn find_optimal_subexpression<F: Field + Hash, S: Scoring<F>>(
exprs: &[Expr<F, Queriable<F>, HashResult>],
replaced_hashes: &HashSet<u64>,
config: CseConfig,
scorer: &S,
) -> Option<Expr<F, Queriable<F>, HashResult>> {
let mut count_map = HashMap::<u64, SubexprInfo>::new();
Expand All @@ -185,18 +196,8 @@ fn find_optimal_subexpression<F: Field + Hash, S: Scorer<F>>(
count_subexpressions(expr, &mut count_map, &mut hash_to_expr, replaced_hashes);
}

// Find the best common subexpression to replace
let common_ses = count_map
.into_iter()
.filter(|&(hash, info)| {
info.count >= config.min_occurrences
&& info.degree >= config.min_degree
&& !replaced_hashes.contains(&hash)
})
.collect::<HashMap<_, _>>();

// Find the best common subexpression to replace based on the score
let best_subexpr = common_ses
let best_subexpr = count_map
.iter()
.map(|(&hash, info)| {
let expr = hash_to_expr.get(&hash).unwrap();
Expand Down Expand Up @@ -269,29 +270,21 @@ impl<F> poly::SignalFactory<Queriable<F>> for SignalFactory<F> {

#[cfg(test)]
mod test {
use std::{collections::HashSet, hash::Hash};
use std::collections::HashSet;

use halo2_proofs::halo2curves::bn256::Fr;
use rand_chacha::{rand_core::SeedableRng, ChaCha20Rng};

use crate::{
compiler::cse::{cse, CseConfig},
compiler::cse::{cse, CseConfig, Scorer},
field::Field,
poly::{Expr, HashResult, VarAssignments},
poly::{Expr, VarAssignments},
sbpir::{query::Queriable, sbpir_machine::SBPIRMachine, InternalSignal, StepType},
util::uuid,
wit_gen::NullTraceGenerator,
};

use super::{find_optimal_subexpression, Scorer, SubexprInfo};

pub struct TestScorer;

impl<F: Field + Hash> Scorer<F> for TestScorer {
fn score(&self, _expr: &Expr<F, Queriable<F>, HashResult>, info: &SubexprInfo) -> usize {
2 * info.count + 3 * info.degree
}
}
use super::find_optimal_subexpression;

#[test]
fn test_find_optimal_subexpression() {
Expand Down Expand Up @@ -322,12 +315,11 @@ mod test {
hashed_exprs.push(hashed_expr);
}

let scorer = TestScorer;
let scorer = Scorer::default();

let best_expr = find_optimal_subexpression(
&hashed_exprs,
&HashSet::new(),
CseConfig::default(),
&scorer,
);

Expand Down Expand Up @@ -375,7 +367,7 @@ mod test {
let mut circuit: SBPIRMachine<Fr, NullTraceGenerator> = SBPIRMachine::default();
let step_uuid = circuit.add_step_type_def(step);

let scorer = TestScorer;
let scorer = Scorer::default();
let circuit = cse(circuit, CseConfig::default(), &scorer);

let common_ses_found_and_replaced = circuit
Expand Down

0 comments on commit 0720a88

Please sign in to comment.