diff --git a/src/compiler/cse.rs b/src/compiler/cse.rs index 4cab009c..9cb45962 100644 --- a/src/compiler/cse.rs +++ b/src/compiler/cse.rs @@ -13,7 +13,9 @@ use crate::{ cse::{create_common_ses_signal, replace_expr}, Expr, HashResult, VarAssignments, }, - sbpir::{query::Queriable, sbpir_machine::SBPIRMachine, ForwardSignal, InternalSignal, StepType}, + sbpir::{ + query::Queriable, sbpir_machine::SBPIRMachine, ForwardSignal, InternalSignal, StepType, + }, wit_gen::NullTraceGenerator, }; @@ -37,7 +39,11 @@ impl Default for CseConfig { } #[allow(dead_code)] -pub fn config(min_degree: usize, min_occurrences: usize, max_iterations: Option) -> CseConfig { +pub fn config( + min_degree: usize, + min_occurrences: usize, + max_iterations: Option, +) -> CseConfig { CseConfig { min_degree, min_occurrences, @@ -45,6 +51,10 @@ pub fn config(min_degree: usize, min_occurrences: usize, max_iterations: Option< } } +pub trait Scorer { + fn score(&self, expr: &Expr, HashResult>, info: &SubexprInfo) -> usize; +} + /// Common Subexpression Elimination (CSE) optimization. /// This optimization replaces common subexpressions with new internal signals for the step type. /// This is done by each time finding the optimal subexpression to replace and creating a new signal @@ -54,20 +64,22 @@ pub fn config(min_degree: usize, min_occurrences: usize, max_iterations: Option< /// queriables. Using the Schwartz-Zippel lemma, we can determine if two expressions are equivalent /// with high probability. #[allow(dead_code)] -pub(super) fn cse( +pub(super) fn cse>( mut circuit: SBPIRMachine, config: CseConfig, + scorer: &S, ) -> SBPIRMachine { for (_, step_type) in circuit.step_types.iter_mut() { - cse_for_step(step_type, &circuit.forward_signals, &config) + cse_for_step(step_type, &circuit.forward_signals, &config, scorer) } circuit } -fn cse_for_step( +fn cse_for_step>( step_type: &mut StepType, forward_signals: &[ForwardSignal], config: &CseConfig, + scorer: &S, ) { let mut signal_factory = SignalFactory::default(); let mut replaced_hashes = HashSet::new(); @@ -105,7 +117,9 @@ fn cse_for_step( } // Find the optimal subexpression to replace - if let Some(common_expr) = find_optimal_subexpression(&exprs, &replaced_hashes, config.clone()) { + if let Some(common_expr) = + find_optimal_subexpression(&exprs, &replaced_hashes, config.clone(), scorer) + { // Add the hash of the replaced expression to the set replaced_hashes.insert(common_expr.meta().hash); // Create a new signal for the common subexpression @@ -139,7 +153,7 @@ fn cse_for_step( } #[derive(Debug, Clone, Copy)] -struct SubexprInfo { +pub(super) struct SubexprInfo { count: usize, degree: usize, } @@ -154,18 +168,14 @@ impl SubexprInfo { self.count += 1; self.degree = self.degree.max(degree); } - - fn get_score(&self) -> usize { - // TODO: Improve the scoring function and adjust the weights - 2 * self.count + 3 * self.degree - } } /// Find the optimal subexpression to replace in a list of expressions. -fn find_optimal_subexpression( +fn find_optimal_subexpression>( exprs: &[Expr, HashResult>], replaced_hashes: &HashSet, - config: CseConfig + config: CseConfig, + scorer: &S, ) -> Option, HashResult>> { let mut count_map = HashMap::::new(); let mut hash_to_expr = HashMap::, HashResult>>::new(); @@ -179,21 +189,25 @@ fn find_optimal_subexpression( let common_ses = count_map .into_iter() .filter(|&(hash, info)| { - info.count >= config.min_occurrences && info.degree >= config.min_degree && !replaced_hashes.contains(&hash) + info.count >= config.min_occurrences + && info.degree >= config.min_degree + && !replaced_hashes.contains(&hash) }) .collect::>(); // Find the best common subexpression to replace based on the score let best_subexpr = common_ses .iter() - .max_by_key(|&(_, info)| info.get_score()) - .map(|(&hash, info)| (hash, info.count, info.degree)); + .map(|(&hash, info)| { + let expr = hash_to_expr.get(&hash).unwrap(); + let score = scorer.score(expr, info); + (hash, score) + }) + .filter(|&(_, score)| score > 0) + .max_by_key(|&(_, score)| score) + .map(|(hash, _)| hash); - if let Some((hash, _count, _degree)) = best_subexpr { - hash_to_expr.get(&hash).cloned() - } else { - None - } + best_subexpr.and_then(|hash| hash_to_expr.get(&hash).cloned()) } /// Count the subexpressions in an expression and store them in a map. @@ -255,7 +269,7 @@ impl poly::SignalFactory> for SignalFactory { #[cfg(test)] mod test { - use std::collections::HashSet; + use std::{collections::HashSet, hash::Hash}; use halo2_proofs::halo2curves::bn256::Fr; use rand_chacha::{rand_core::SeedableRng, ChaCha20Rng}; @@ -263,13 +277,21 @@ mod test { use crate::{ compiler::cse::{cse, CseConfig}, field::Field, - poly::{Expr, VarAssignments}, + poly::{Expr, HashResult, VarAssignments}, sbpir::{query::Queriable, sbpir_machine::SBPIRMachine, InternalSignal, StepType}, util::uuid, wit_gen::NullTraceGenerator, }; - use super::find_optimal_subexpression; + use super::{find_optimal_subexpression, Scorer, SubexprInfo}; + + pub struct TestScorer; + + impl Scorer for TestScorer { + fn score(&self, _expr: &Expr, HashResult>, info: &SubexprInfo) -> usize { + 2 * info.count + 3 * info.degree + } + } #[test] fn test_find_optimal_subexpression() { @@ -300,7 +322,14 @@ mod test { hashed_exprs.push(hashed_expr); } - let best_expr = find_optimal_subexpression(&hashed_exprs, &HashSet::new(), CseConfig::default()); + let scorer = TestScorer; + + let best_expr = find_optimal_subexpression( + &hashed_exprs, + &HashSet::new(), + CseConfig::default(), + &scorer, + ); assert_eq!(format!("{:?}", best_expr.unwrap()), "(e * f * d)"); } @@ -346,7 +375,8 @@ mod test { let mut circuit: SBPIRMachine = SBPIRMachine::default(); let step_uuid = circuit.add_step_type_def(step); - let circuit = cse(circuit, CseConfig::default()); + let scorer = TestScorer; + let circuit = cse(circuit, CseConfig::default(), &scorer); let common_ses_found_and_replaced = circuit .step_types diff --git a/src/poly/mod.rs b/src/poly/mod.rs index 5acbfc36..97ad9c3c 100644 --- a/src/poly/mod.rs +++ b/src/poly/mod.rs @@ -111,16 +111,19 @@ impl Expr { .collect(), new_meta, ), - Expr::Neg(se, _) => Expr::Neg(Box::new(se.transform_meta(apply_meta.clone())), new_meta), - Expr::Pow(se, exp, _) => { - Expr::Pow(Box::new(se.transform_meta(apply_meta.clone())), *exp, new_meta) + Expr::Neg(se, _) => { + Expr::Neg(Box::new(se.transform_meta(apply_meta.clone())), new_meta) } + Expr::Pow(se, exp, _) => Expr::Pow( + Box::new(se.transform_meta(apply_meta.clone())), + *exp, + new_meta, + ), Expr::Query(v, _) => Expr::Query(v.clone(), new_meta), Expr::Halo2Expr(e, _) => Expr::Halo2Expr(e.clone(), new_meta), Expr::MI(se, _) => Expr::MI(Box::new(se.transform_meta(apply_meta.clone())), new_meta), } } - pub fn apply_subexpressions(&self, mut f: T) -> Self where diff --git a/src/sbpir/mod.rs b/src/sbpir/mod.rs index 441e7f88..254d18c9 100644 --- a/src/sbpir/mod.rs +++ b/src/sbpir/mod.rs @@ -232,7 +232,10 @@ impl, M> SBPIRLegacy { } impl + Clone, M: Clone> SBPIRLegacy { - pub fn transform_meta(&self, apply_meta: ApplyMetaFn) -> SBPIRLegacy + pub fn transform_meta( + &self, + apply_meta: ApplyMetaFn, + ) -> SBPIRLegacy where ApplyMetaFn: Fn(&Expr, M>) -> N + Clone, {