diff --git a/Cargo.toml b/Cargo.toml index 3c42eab2..102a6fe1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -33,8 +33,6 @@ lalrpop-util = { version = "0.20.0", features = ["lexer", "unicode"] } lazy_static = "1.4.0" itertools = "0.12.1" codespan-reporting = "0.11.1" - -[dev-dependencies] rand_chacha = "0.3" [build-dependencies] diff --git a/src/compiler/compiler.rs b/src/compiler/compiler.rs index dd5c08af..0cb3959f 100644 --- a/src/compiler/compiler.rs +++ b/src/compiler/compiler.rs @@ -195,6 +195,11 @@ impl Compiler { circuit } + #[allow(dead_code)] + fn cse(mut _circuit: SBPIR) -> SBPIR { + todo!() + } + fn translate_queries( &mut self, symbols: &SymTable, @@ -522,7 +527,7 @@ mod test { } "; - let debug_sym_ref_factory = DebugSymRefFactory::new("", &circuit); + let debug_sym_ref_factory = DebugSymRefFactory::new("", circuit); let result = compile::( circuit, Config::default().max_degree(2), diff --git a/src/compiler/semantic/rules.rs b/src/compiler/semantic/rules.rs index fc8d6889..76263b1b 100644 --- a/src/compiler/semantic/rules.rs +++ b/src/compiler/semantic/rules.rs @@ -397,7 +397,7 @@ mod test { } "; - let debug_sym_ref_factory = DebugSymRefFactory::new("", &circuit); + let debug_sym_ref_factory = DebugSymRefFactory::new("", circuit); let decls = lang::TLDeclsParser::new() .parse(&debug_sym_ref_factory, circuit) .unwrap(); @@ -454,7 +454,7 @@ mod test { } "; - let debug_sym_ref_factory = DebugSymRefFactory::new("", &circuit); + let debug_sym_ref_factory = DebugSymRefFactory::new("", circuit); let decls = lang::TLDeclsParser::new() .parse(&debug_sym_ref_factory, circuit) .unwrap(); @@ -621,7 +621,7 @@ mod test { } "; - let debug_sym_ref_factory = DebugSymRefFactory::new("", &circuit); + let debug_sym_ref_factory = DebugSymRefFactory::new("", circuit); let decls = lang::TLDeclsParser::new() .parse(&debug_sym_ref_factory, circuit) .unwrap(); @@ -736,7 +736,7 @@ mod test { } "; - let debug_sym_ref_factory = DebugSymRefFactory::new("", &circuit); + let debug_sym_ref_factory = DebugSymRefFactory::new("", circuit); let decls = lang::TLDeclsParser::new() .parse(&debug_sym_ref_factory, circuit) .unwrap(); @@ -796,7 +796,7 @@ mod test { } "; - let debug_sym_ref_factory = DebugSymRefFactory::new("", &circuit); + let debug_sym_ref_factory = DebugSymRefFactory::new("", circuit); let decls = lang::TLDeclsParser::new() .parse(&debug_sym_ref_factory, circuit) .unwrap(); @@ -861,7 +861,7 @@ mod test { } "; - let debug_sym_ref_factory = DebugSymRefFactory::new("", &circuit); + let debug_sym_ref_factory = DebugSymRefFactory::new("", circuit); let decls = lang::TLDeclsParser::new() .parse(&debug_sym_ref_factory, circuit) .unwrap(); @@ -927,7 +927,7 @@ mod test { } "; - let debug_sym_ref_factory = DebugSymRefFactory::new("", &circuit); + let debug_sym_ref_factory = DebugSymRefFactory::new("", circuit); let decls = lang::TLDeclsParser::new() .parse(&debug_sym_ref_factory, circuit) .unwrap(); @@ -984,7 +984,7 @@ mod test { } "; - let debug_sym_ref_factory = DebugSymRefFactory::new("", &circuit); + let debug_sym_ref_factory = DebugSymRefFactory::new("", circuit); let decls = lang::TLDeclsParser::new() .parse(&debug_sym_ref_factory, circuit) .unwrap(); @@ -1050,7 +1050,7 @@ mod test { } "; - let debug_sym_ref_factory = DebugSymRefFactory::new("", &circuit); + let debug_sym_ref_factory = DebugSymRefFactory::new("", circuit); let decls = lang::TLDeclsParser::new() .parse(&debug_sym_ref_factory, circuit) .unwrap(); diff --git a/src/field.rs b/src/field.rs index 679cdd40..354d7b87 100644 --- a/src/field.rs +++ b/src/field.rs @@ -5,6 +5,7 @@ use core::{ }; use num_bigint::BigInt; +use rand_chacha::rand_core::RngCore; pub trait Field: Sized @@ -46,6 +47,9 @@ pub trait Field: /// FF invert that returns None if the element is zero. fn mi(&self) -> Self; + /// Returns an element chosen uniformly at random using a user-provided RNG. + fn random(rng: impl RngCore) -> Self; + /// Exponentiates `self` by `exp`, where `exp` is a little-endian order integer /// exponent. fn pow>(&self, exp: S) -> Self; diff --git a/src/parser/mod.rs b/src/parser/mod.rs index e0332db4..c4031960 100644 --- a/src/parser/mod.rs +++ b/src/parser/mod.rs @@ -156,7 +156,7 @@ mod test { } "; - let debug_sym_ref_factory = DebugSymRefFactory::new("", &circuit); + let debug_sym_ref_factory = DebugSymRefFactory::new("", circuit); let decls = lang::TLDeclsParser::new() .parse(&debug_sym_ref_factory, circuit) .unwrap(); diff --git a/src/plonkish/backend/halo2.rs b/src/plonkish/backend/halo2.rs index 89a6a163..96f21e8a 100644 --- a/src/plonkish/backend/halo2.rs +++ b/src/plonkish/backend/halo2.rs @@ -39,6 +39,10 @@ impl> ChiquitoField for T { fn from_big_int(value: &num_bigint::BigInt) -> Self { PrimeField::from_str_vartime(value.to_string().as_str()).expect("field value") } + + fn random(rng: impl rand_chacha::rand_core::RngCore) -> Self { + Self::random(rng) + } } #[allow(non_snake_case)] diff --git a/src/poly/cse.rs b/src/poly/cse.rs new file mode 100644 index 00000000..387ac5ef --- /dev/null +++ b/src/poly/cse.rs @@ -0,0 +1,134 @@ +use rand_chacha::{rand_core::SeedableRng, ChaCha20Rng}; + +use crate::field::Field; + +use std::{ + collections::HashMap, + fmt::Debug, + hash::Hash, + rc::{Rc, Weak}, +}; + +use super::{Expr, HashResult, VarAssignments}; + +/// Common Subexpression Elimination - takes a collection of expr +pub fn cse( + exprs: Vec>, + queriables: &Vec, +) -> Vec>> { + // generate random point in the field set + let assignments = generate_random_assignment(queriables); + + // hash table with the hash of the expression as the key + // and the weak reference to the expression as the value + let mut seen_hashes: HashMap>> = HashMap::new(); + + let mut result = Vec::new(); + + for expr in exprs { + let hashed_expr = Rc::new(expr.hash(&assignments)); + let hash = hashed_expr.meta().hash; + + // if the hash is already in the hash table, + // push the existing expression to the result + if let Some(existing_weak) = seen_hashes.get(&hash) { + if let Some(existing_expr) = existing_weak.upgrade() { + result.push(existing_expr); + continue; + } + } + + // otherwise, insert the hash and the weak reference to the expression + seen_hashes.insert(hash, Rc::downgrade(&hashed_expr)); + result.push(hashed_expr); + } + + result +} + +fn generate_random_assignment( + queriables: &Vec, +) -> VarAssignments { + let mut rng = ChaCha20Rng::seed_from_u64(0); + + let mut assignments = HashMap::new(); + + for queriable in queriables { + let value = F::random(&mut rng); + assignments.insert(queriable.clone(), value); + } + + assignments +} + +#[cfg(test)] +mod test { + use super::*; + use std::collections::HashSet; + + use halo2_proofs::halo2curves::bn256::Fr; + + use crate::{ + poly::{Expr::*, ToExpr}, + sbpir::{query::Queriable, ForwardSignal, InternalSignal}, + }; + + #[test] + fn test_generate_random_assignment() { + let internal = InternalSignal::new("internal"); + let forward = ForwardSignal::new("forward"); + + let a: Queriable = Queriable::Internal(internal); + let b: Queriable = Queriable::Forward(forward, false); + let c: Queriable = Queriable::Forward(forward, true); + + let vars = vec![a, b, c]; + + let keys: HashSet> = vars.iter().cloned().collect(); + + let assignments: VarAssignments> = generate_random_assignment(&vars); + + println!("Assignments: {:#?}", assignments); + + for key in &keys { + assert!(assignments.contains_key(key)); + } + } + + #[test] + fn test_cse() { + let forward = ForwardSignal::new("forward"); + + let a: Queriable = Queriable::Internal(InternalSignal::new("a")); + let b: Queriable = Queriable::Internal(InternalSignal::new("b")); + let c: Queriable = Queriable::Forward(forward, false); + let d: Queriable = Queriable::Forward(forward, true); + let e: Queriable = Queriable::Internal(InternalSignal::new("e")); + let f: Queriable = Queriable::Internal(InternalSignal::new("f")); + let g: Queriable = Queriable::Internal(InternalSignal::new("g")); + let vars = vec![a, b, c, d, e, f, g]; + + // Equivalent expressions + let expr1 = Pow(Box::new(e.expr()), 6, ()) * a * b + c * d - 1.expr(); + let expr2 = d * c - 1.expr() + a * b * Pow(Box::new(e.expr()), 6, ()); + + // Equivalent expressions + let expr3 = f * b - c * d - 1.expr(); + let expr4 = -(1.expr()) - c * d + b * f; + + // Equivalent expressions + let expr5 = -(-f * g) * (-(-(-a))); + let expr6 = -(f * g * a); + + let exprs = vec![expr1, expr2, expr3, expr4, expr5, expr6]; + + let result = cse(exprs, &vars); + + assert!(Rc::ptr_eq(&result[0], &result[1])); + assert!(Rc::ptr_eq(&result[2], &result[3])); + assert!(Rc::ptr_eq(&result[4], &result[5])); + assert!(!Rc::ptr_eq(&result[0], &result[2])); + assert!(!Rc::ptr_eq(&result[0], &result[4])); + assert!(!Rc::ptr_eq(&result[2], &result[4])); + } +} diff --git a/src/poly/mod.rs b/src/poly/mod.rs index 221ba71c..2e5942f9 100644 --- a/src/poly/mod.rs +++ b/src/poly/mod.rs @@ -1,7 +1,7 @@ use std::{ collections::HashMap, fmt::Debug, - hash::Hash, + hash::{DefaultHasher, Hash, Hasher}, ops::{Add, Mul, Neg, Sub}, }; @@ -9,6 +9,7 @@ use halo2_proofs::plonk::Expression; use crate::field::Field; +pub mod cse; pub mod mielim; pub mod reduce; pub mod simplify; @@ -47,6 +48,58 @@ impl Expr { Expr::MI(_, _) => panic!("not implemented"), } } + + pub fn meta(&self) -> &M { + match self { + Expr::Const(_, m) => m, + Expr::Sum(_, m) => m, + Expr::Mul(_, m) => m, + Expr::Neg(_, m) => m, + Expr::Pow(_, _, m) => m, + Expr::Query(_, m) => m, + Expr::Halo2Expr(_, m) => m, + Expr::MI(_, m) => m, + } + } +} + +#[derive(Debug, Clone)] +pub struct HashResult { + pub hash: u64, + pub degree: usize, +} + +impl Expr { + /// Uses Schwartz-Zippel Lemma to hash + pub fn hash(&self, assignments: &VarAssignments) -> Expr { + let mut hasher = DefaultHasher::new(); + + if let Some(result) = self.eval(assignments) { + result.hash(&mut hasher); + } + + let hash_result = HashResult { + hash: hasher.finish(), + degree: self.degree(), + }; + + match self { + Expr::Const(v, _) => Expr::Const(*v, hash_result), + Expr::Query(v, _) => Expr::Query(v.clone(), hash_result), + Expr::Sum(ses, _) => { + let new_ses = ses.iter().map(|se| se.hash(assignments)).collect(); + Expr::Sum(new_ses, hash_result) + } + Expr::Mul(ses, _) => { + let new_ses = ses.iter().map(|se| se.hash(assignments)).collect(); + Expr::Mul(new_ses, hash_result) + } + Expr::Neg(se, _) => Expr::Neg(Box::new(se.hash(assignments)), hash_result), + Expr::Pow(se, exp, _) => Expr::Pow(Box::new(se.hash(assignments)), *exp, hash_result), + Expr::Halo2Expr(_, _) => panic!("not implemented"), + Expr::MI(se, _) => Expr::MI(Box::new(se.hash(assignments)), hash_result), + } + } } impl Debug for Expr { @@ -91,7 +144,7 @@ impl Debug for Expr { pub type VarAssignments = HashMap; -impl Expr { +impl Expr { pub fn eval(&self, assignments: &VarAssignments) -> Option { match self { Expr::Const(v, _) => Some(*v), @@ -112,6 +165,21 @@ impl Expr Expr { + /// Returns all the keys of the queries + pub fn get_queries(&self) -> Vec { + match self { + Expr::Const(_, _) => Vec::new(), + Expr::Sum(ses, _) | Expr::Mul(ses, _) => { + ses.iter().flat_map(|se| se.get_queries()).collect() + } + Expr::Neg(se, _) | Expr::Pow(se, _, _) | Expr::MI(se, _) => se.get_queries(), + Expr::Query(q, _) => vec![q.clone()], + Expr::Halo2Expr(_, _) => Vec::new(), + } + } +} + impl ToExpr for Expr { fn expr(&self) -> Expr { self.clone() @@ -297,9 +365,13 @@ impl ConstrDecomp { #[cfg(test)] mod test { - use halo2_proofs::halo2curves::bn256::Fr; + use halo2_proofs::{arithmetic::Field, halo2curves::bn256::Fr}; + use rand_chacha::{rand_core::SeedableRng, ChaCha20Rng}; - use crate::{field::Field, poly::VarAssignments}; + use crate::{ + poly::{ToExpr, VarAssignments}, + sbpir::{query::Queriable, InternalSignal}, + }; use super::Expr; @@ -417,4 +489,101 @@ mod test { format!("{:?}", expr) ); } + + #[test] + fn test_hash() { + use super::Expr::*; + + let mut rng = ChaCha20Rng::seed_from_u64(0); + + let a: Queriable = Queriable::Internal(InternalSignal::new("a")); + let b: Queriable = Queriable::Internal(InternalSignal::new("b")); + let c: Queriable = Queriable::Internal(InternalSignal::new("c")); + let d: Queriable = Queriable::Internal(InternalSignal::new("d")); + let e: Queriable = Queriable::Internal(InternalSignal::new("e")); + let f: Queriable = Queriable::Internal(InternalSignal::new("f")); + let g: Queriable = Queriable::Internal(InternalSignal::new("g")); + let vars = vec![a, b, c, d, e, f, g]; + + let mut assignments = VarAssignments::new(); + for v in &vars { + assignments.insert(*v, Fr::random(&mut rng)); + } + + // Equivalent expressions + let expr1 = Pow(Box::new(e.expr()), 6, ()) * a * b + c * d - 1.expr(); + let expr2 = d * c - 1.expr() + a * b * Pow(Box::new(e.expr()), 6, ()); + + // Equivalent expressions + let expr3 = f * b - c * d - 1.expr(); + let expr4 = -(1.expr()) - c * d + b * f; + + // Equivalent expressions + let expr5 = -(-f * g) * (-(-(-a))); + let expr6 = -(f * g * a); + + let expressions = [expr1, expr2, expr3, expr4, expr5, expr6]; + let mut hashed_expressions = Vec::new(); + + for expr in expressions { + let hashed_expr = expr.hash(&assignments); + hashed_expressions.push(hashed_expr); + } + + assert_eq!( + hashed_expressions[0].meta().hash, + hashed_expressions[1].meta().hash + ); + + assert_eq!( + hashed_expressions[2].meta().hash, + hashed_expressions[3].meta().hash + ); + + assert_eq!( + hashed_expressions[4].meta().hash, + hashed_expressions[5].meta().hash + ); + + assert_ne!( + hashed_expressions[0].meta().hash, + hashed_expressions[2].meta().hash + ); + + assert_ne!( + hashed_expressions[0].meta().hash, + hashed_expressions[4].meta().hash + ); + + assert_ne!( + hashed_expressions[2].meta().hash, + hashed_expressions[5].meta().hash + ); + } + + #[test] + fn test_get_queries() { + use super::Expr::*; + + let a: Queriable = Queriable::Internal(InternalSignal::new("a")); + let b: Queriable = Queriable::Internal(InternalSignal::new("b")); + let c: Queriable = Queriable::Internal(InternalSignal::new("c")); + let d: Queriable = Queriable::Internal(InternalSignal::new("d")); + let e: Queriable = Queriable::Internal(InternalSignal::new("e")); + let f: Queriable = Queriable::Internal(InternalSignal::new("f")); + let g: Queriable = Queriable::Internal(InternalSignal::new("g")); + + let expr = Pow(Box::new(e.expr()), 6, ()) * a * b + c * d - 1.expr(); + + let queries = expr.get_queries(); + + assert_eq!(queries.len(), 5); + assert!(queries.contains(&a)); + assert!(queries.contains(&b)); + assert!(queries.contains(&c)); + assert!(queries.contains(&d)); + assert!(queries.contains(&e)); + assert!(!queries.contains(&f)); + assert!(!queries.contains(&g)); + } } diff --git a/src/sbpir/mod.rs b/src/sbpir/mod.rs index e4acc926..fd10b234 100644 --- a/src/sbpir/mod.rs +++ b/src/sbpir/mod.rs @@ -480,23 +480,23 @@ pub struct ForwardSignal { } impl ForwardSignal { - pub fn new(annotation: String) -> ForwardSignal { - Self::new_with_id(uuid(), 0, annotation) + pub fn new>(annotation: S) -> ForwardSignal { + Self::new_with_id(uuid(), 0, annotation.into()) } - pub fn new_with_phase(phase: usize, annotation: String) -> ForwardSignal { + pub fn new_with_phase>(phase: usize, annotation: S) -> ForwardSignal { ForwardSignal { id: uuid(), phase, - annotation: Box::leak(annotation.into_boxed_str()), + annotation: Box::leak(annotation.into().into_boxed_str()), } } - pub fn new_with_id(id: UUID, phase: usize, annotation: String) -> Self { + pub fn new_with_id>(id: UUID, phase: usize, annotation: S) -> Self { Self { id, phase, - annotation: Box::leak(annotation.into_boxed_str()), + annotation: Box::leak(annotation.into().into_boxed_str()), } } @@ -521,19 +521,19 @@ pub struct SharedSignal { } impl SharedSignal { - pub fn new_with_phase(phase: usize, annotation: String) -> SharedSignal { + pub fn new_with_phase>(phase: usize, annotation: S) -> SharedSignal { SharedSignal { id: uuid(), phase, - annotation: Box::leak(annotation.into_boxed_str()), + annotation: Box::leak(annotation.into().into_boxed_str()), } } - pub fn new_with_id(id: UUID, phase: usize, annotation: String) -> Self { + pub fn new_with_id>(id: UUID, phase: usize, annotation: S) -> Self { Self { id, phase, - annotation: Box::leak(annotation.into_boxed_str()), + annotation: Box::leak(annotation.into().into_boxed_str()), } } @@ -557,17 +557,17 @@ pub struct FixedSignal { } impl FixedSignal { - pub fn new(annotation: String) -> FixedSignal { + pub fn new>(annotation: S) -> FixedSignal { FixedSignal { id: uuid(), - annotation: Box::leak(annotation.into_boxed_str()), + annotation: Box::leak(annotation.into().into_boxed_str()), } } - pub fn new_with_id(id: UUID, annotation: String) -> Self { + pub fn new_with_id>(id: UUID, annotation: S) -> Self { Self { id, - annotation: Box::leak(annotation.into_boxed_str()), + annotation: Box::leak(annotation.into().into_boxed_str()), } } @@ -600,10 +600,10 @@ impl InternalSignal { } } - pub fn new_with_id(id: UUID, annotation: String) -> Self { + pub fn new_with_id>(id: UUID, annotation: S) -> Self { Self { id, - annotation: Box::leak(annotation.into_boxed_str()), + annotation: Box::leak(annotation.into().into_boxed_str()), } } diff --git a/src/wit_gen.rs b/src/wit_gen.rs index b3da3674..e970a810 100644 --- a/src/wit_gen.rs +++ b/src/wit_gen.rs @@ -324,15 +324,15 @@ mod tests { StepInstance { step_type_uuid: 9, assignments: HashMap::from([ - (Queriable::Fixed(FixedSignal::new("a".into()), 0), 1), - (Queriable::Fixed(FixedSignal::new("b".into()), 0), 2) + (Queriable::Fixed(FixedSignal::new("a"), 0), 1), + (Queriable::Fixed(FixedSignal::new("b"), 0), 2) ]), }, StepInstance { step_type_uuid: 10, assignments: HashMap::from([ - (Queriable::Fixed(FixedSignal::new("a".into()), 0), 1), - (Queriable::Fixed(FixedSignal::new("b".into()), 0), 2) + (Queriable::Fixed(FixedSignal::new("a"), 0), 1), + (Queriable::Fixed(FixedSignal::new("b"), 0), 2) ]), } ]