Skip to content
This repository has been archived by the owner on Nov 4, 2024. It is now read-only.

Commit

Permalink
formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
rutefig committed Jul 31, 2024
1 parent b215605 commit b8328ca
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 21 deletions.
37 changes: 21 additions & 16 deletions src/compiler/cse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,19 @@

use rand_chacha::{rand_core::SeedableRng, ChaCha20Rng};

use std::{collections::{HashMap, HashSet}, hash::Hash, marker::PhantomData};
use std::{
collections::{HashMap, HashSet},
hash::Hash,
marker::PhantomData,
};

use crate::{
field::Field,
poly::{self, cse::{create_common_ses_signal, replace_expr}, Expr, HashResult, VarAssignments},
poly::{
self,
cse::{create_common_ses_signal, replace_expr},
Expr, HashResult, VarAssignments,
},
sbpir::{query::Queriable, InternalSignal, SBPIR},
wit_gen::NullTraceGenerator,
};
Expand Down Expand Up @@ -52,19 +60,20 @@ pub(super) fn cse<F: Field + Hash>(
}

// Find the optimal subexpression to replace
println!("Step type before CSE: {:#?}", step_type_with_hash);
if let Some(common_expr) = find_optimal_subexpression(&exprs, &replaced_hashes) {
println!("Common expression found: {:?}", common_expr);
// Add the hash of the replaced expression to the set
replaced_hashes.insert(common_expr.meta().hash);
// Add the hash of the replaced expression to the set
replaced_hashes.insert(common_expr.meta().hash);
// Create a new signal for the common subexpression
let (common_se, decomp) = create_common_ses_signal(&common_expr, &mut signal_factory);
let (common_se, decomp) =
create_common_ses_signal(&common_expr, &mut signal_factory);

decomp.auto_signals.iter().for_each(|(q, expr)| {
if let Queriable::Internal(signal) = q {
step_type_with_hash.add_internal(signal.clone());
}
step_type_with_hash.auto_signals.insert(q.clone(), expr.clone());
step_type_with_hash
.auto_signals
.insert(q.clone(), expr.clone());
step_type_with_hash.add_constr(format!("{:?}", q), expr.clone());
});

Expand Down Expand Up @@ -114,22 +123,18 @@ fn find_optimal_subexpression<F: Field + Hash>(
// Find the best common subexpression to replace
let common_ses = count_map
.into_iter()
.filter(|&(hash, info)| info.count > 1 && info.degree > 1 && !replaced_hashes.contains(&hash))
.filter(|&(hash, info)| {
info.count > 1 && info.degree > 1 && !replaced_hashes.contains(&hash)
})
.collect::<HashMap<_, _>>();

println!("Common subexpressions: {:#?}", common_ses);

let best_subexpr = common_ses
.iter()
.max_by_key(|&(_, info)| (info.degree, info.count))
.map(|(&hash, info)| (hash, info.count, info.degree));

println!("Best subexpression: {:#?}", best_subexpr);

if let Some((hash, _count, _degree)) = best_subexpr {
let best_subexpr = hash_to_expr.get(&hash).cloned();
println!("Best subexpression found: {:#?}", best_subexpr);
best_subexpr
hash_to_expr.get(&hash).cloned()
} else {
None
}
Expand Down
16 changes: 12 additions & 4 deletions src/poly/cse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,19 @@ pub fn replace_expr<F: Field + Hash, V: Clone + Eq + Hash + Debug, SF: SignalFac
expr: &Expr<F, V, HashResult>,
common_se: &Expr<F, V, HashResult>,
signal_factory: &mut SF,
decomp: ConstrDecomp<F, V, HashResult>
decomp: ConstrDecomp<F, V, HashResult>,
) -> (Expr<F, V, HashResult>, ConstrDecomp<F, V, HashResult>) {
let mut decomp = decomp;
let new_expr = replace_subexpr(expr, common_se, signal_factory, &mut decomp);

(new_expr, ConstrDecomp::default())
}

pub fn create_common_ses_signal<F: Field, V: Clone + PartialEq + Eq + Hash + Debug, SF: SignalFactory<V>>(
pub fn create_common_ses_signal<
F: Field,
V: Clone + PartialEq + Eq + Hash + Debug,
SF: SignalFactory<V>,
>(
common_se: &Expr<F, V, HashResult>,
signal_factory: &mut SF,
) -> (Expr<F, V, HashResult>, ConstrDecomp<F, V, HashResult>) {
Expand Down Expand Up @@ -58,7 +62,10 @@ mod tests {
use halo2_proofs::halo2curves::bn256::Fr;

use crate::{
poly::{cse::{create_common_ses_signal, replace_expr}, SignalFactory, ToExpr, VarAssignments},
poly::{
cse::{create_common_ses_signal, replace_expr},
SignalFactory, ToExpr, VarAssignments,
},
sbpir::{query::Queriable, InternalSignal},
};

Expand Down Expand Up @@ -91,7 +98,8 @@ mod tests {
let assignments: VarAssignments<Fr, Queriable<Fr>> =
vars.iter().cloned().map(|q| (q, Fr::from(2))).collect();

let (common_se, decomp) = create_common_ses_signal(&common_expr.hash(&assignments), &mut signal_factory);
let (common_se, decomp) =
create_common_ses_signal(&common_expr.hash(&assignments), &mut signal_factory);

let (new_expr, decomp) = replace_expr(
&expr.hash(&assignments),
Expand Down
5 changes: 4 additions & 1 deletion src/sbpir/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -574,7 +574,10 @@ pub struct TransitionConstraint<F, M> {
}

impl<F: Field + Hash, M: Clone> TransitionConstraint<F, M> {
pub fn with_meta<N: Clone, ApplyMetaFn>(&self, apply_meta: ApplyMetaFn) -> TransitionConstraint<F, N>
pub fn with_meta<N: Clone, ApplyMetaFn>(
&self,
apply_meta: ApplyMetaFn,
) -> TransitionConstraint<F, N>
where
ApplyMetaFn: Fn(&Expr<F, Queriable<F>, ()>) -> N + Clone,
{
Expand Down

0 comments on commit b8328ca

Please sign in to comment.