Skip to content
This repository has been archived by the owner on Nov 4, 2024. It is now read-only.

Commit

Permalink
Merge branch 'chiquito-2024' into language-server
Browse files Browse the repository at this point in the history
  • Loading branch information
alxkzmn committed Jun 7, 2024
2 parents 061330f + b9cc99c commit a3cecdd
Show file tree
Hide file tree
Showing 10 changed files with 351 additions and 37 deletions.
2 changes: 0 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,6 @@ lalrpop-util = { version = "0.20.0", features = ["lexer", "unicode"] }
lazy_static = "1.4.0"
itertools = "0.12.1"
codespan-reporting = "0.11.1"

[dev-dependencies]
rand_chacha = "0.3"

[build-dependencies]
Expand Down
7 changes: 6 additions & 1 deletion src/compiler/compiler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,11 @@ impl<F: Field + Hash> Compiler<F> {
circuit
}

#[allow(dead_code)]
fn cse(mut _circuit: SBPIR<F, ()>) -> SBPIR<F, ()> {
todo!()
}

fn translate_queries(
&mut self,
symbols: &SymTable,
Expand Down Expand Up @@ -522,7 +527,7 @@ mod test {
}
";

let debug_sym_ref_factory = DebugSymRefFactory::new("", &circuit);
let debug_sym_ref_factory = DebugSymRefFactory::new("", circuit);
let result = compile::<Fr>(
circuit,
Config::default().max_degree(2),
Expand Down
18 changes: 9 additions & 9 deletions src/compiler/semantic/rules.rs
Original file line number Diff line number Diff line change
Expand Up @@ -397,7 +397,7 @@ mod test {
}
";

let debug_sym_ref_factory = DebugSymRefFactory::new("", &circuit);
let debug_sym_ref_factory = DebugSymRefFactory::new("", circuit);
let decls = lang::TLDeclsParser::new()
.parse(&debug_sym_ref_factory, circuit)
.unwrap();
Expand Down Expand Up @@ -454,7 +454,7 @@ mod test {
}
";

let debug_sym_ref_factory = DebugSymRefFactory::new("", &circuit);
let debug_sym_ref_factory = DebugSymRefFactory::new("", circuit);
let decls = lang::TLDeclsParser::new()
.parse(&debug_sym_ref_factory, circuit)
.unwrap();
Expand Down Expand Up @@ -621,7 +621,7 @@ mod test {
}
";

let debug_sym_ref_factory = DebugSymRefFactory::new("", &circuit);
let debug_sym_ref_factory = DebugSymRefFactory::new("", circuit);
let decls = lang::TLDeclsParser::new()
.parse(&debug_sym_ref_factory, circuit)
.unwrap();
Expand Down Expand Up @@ -736,7 +736,7 @@ mod test {
}
";

let debug_sym_ref_factory = DebugSymRefFactory::new("", &circuit);
let debug_sym_ref_factory = DebugSymRefFactory::new("", circuit);
let decls = lang::TLDeclsParser::new()
.parse(&debug_sym_ref_factory, circuit)
.unwrap();
Expand Down Expand Up @@ -796,7 +796,7 @@ mod test {
}
";

let debug_sym_ref_factory = DebugSymRefFactory::new("", &circuit);
let debug_sym_ref_factory = DebugSymRefFactory::new("", circuit);
let decls = lang::TLDeclsParser::new()
.parse(&debug_sym_ref_factory, circuit)
.unwrap();
Expand Down Expand Up @@ -861,7 +861,7 @@ mod test {
}
";

let debug_sym_ref_factory = DebugSymRefFactory::new("", &circuit);
let debug_sym_ref_factory = DebugSymRefFactory::new("", circuit);
let decls = lang::TLDeclsParser::new()
.parse(&debug_sym_ref_factory, circuit)
.unwrap();
Expand Down Expand Up @@ -927,7 +927,7 @@ mod test {
}
";

let debug_sym_ref_factory = DebugSymRefFactory::new("", &circuit);
let debug_sym_ref_factory = DebugSymRefFactory::new("", circuit);
let decls = lang::TLDeclsParser::new()
.parse(&debug_sym_ref_factory, circuit)
.unwrap();
Expand Down Expand Up @@ -984,7 +984,7 @@ mod test {
}
";

let debug_sym_ref_factory = DebugSymRefFactory::new("", &circuit);
let debug_sym_ref_factory = DebugSymRefFactory::new("", circuit);
let decls = lang::TLDeclsParser::new()
.parse(&debug_sym_ref_factory, circuit)
.unwrap();
Expand Down Expand Up @@ -1050,7 +1050,7 @@ mod test {
}
";

let debug_sym_ref_factory = DebugSymRefFactory::new("", &circuit);
let debug_sym_ref_factory = DebugSymRefFactory::new("", circuit);
let decls = lang::TLDeclsParser::new()
.parse(&debug_sym_ref_factory, circuit)
.unwrap();
Expand Down
4 changes: 4 additions & 0 deletions src/field.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use core::{
};

use num_bigint::BigInt;
use rand_chacha::rand_core::RngCore;

pub trait Field:
Sized
Expand Down Expand Up @@ -46,6 +47,9 @@ pub trait Field:
/// FF invert that returns None if the element is zero.
fn mi(&self) -> Self;

/// Returns an element chosen uniformly at random using a user-provided RNG.
fn random(rng: impl RngCore) -> Self;

/// Exponentiates `self` by `exp`, where `exp` is a little-endian order integer
/// exponent.
fn pow<S: AsRef<[u64]>>(&self, exp: S) -> Self;
Expand Down
2 changes: 1 addition & 1 deletion src/parser/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ mod test {
}
";

let debug_sym_ref_factory = DebugSymRefFactory::new("", &circuit);
let debug_sym_ref_factory = DebugSymRefFactory::new("", circuit);
let decls = lang::TLDeclsParser::new()
.parse(&debug_sym_ref_factory, circuit)
.unwrap();
Expand Down
4 changes: 4 additions & 0 deletions src/plonkish/backend/halo2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@ impl<T: PrimeField + From<u64>> ChiquitoField for T {
fn from_big_int(value: &num_bigint::BigInt) -> Self {
PrimeField::from_str_vartime(value.to_string().as_str()).expect("field value")
}

fn random(rng: impl rand_chacha::rand_core::RngCore) -> Self {
Self::random(rng)
}
}

#[allow(non_snake_case)]
Expand Down
134 changes: 134 additions & 0 deletions src/poly/cse.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
use rand_chacha::{rand_core::SeedableRng, ChaCha20Rng};

use crate::field::Field;

use std::{
collections::HashMap,
fmt::Debug,
hash::Hash,
rc::{Rc, Weak},
};

use super::{Expr, HashResult, VarAssignments};

/// Common Subexpression Elimination - takes a collection of expr
pub fn cse<F: Field + Hash, V: Debug + Clone + Eq + Hash>(
exprs: Vec<Expr<F, V, ()>>,
queriables: &Vec<V>,
) -> Vec<Rc<Expr<F, V, HashResult>>> {
// generate random point in the field set
let assignments = generate_random_assignment(queriables);

// hash table with the hash of the expression as the key
// and the weak reference to the expression as the value
let mut seen_hashes: HashMap<u64, Weak<Expr<F, V, HashResult>>> = HashMap::new();

let mut result = Vec::new();

for expr in exprs {
let hashed_expr = Rc::new(expr.hash(&assignments));
let hash = hashed_expr.meta().hash;

// if the hash is already in the hash table,
// push the existing expression to the result
if let Some(existing_weak) = seen_hashes.get(&hash) {
if let Some(existing_expr) = existing_weak.upgrade() {
result.push(existing_expr);
continue;
}
}

// otherwise, insert the hash and the weak reference to the expression
seen_hashes.insert(hash, Rc::downgrade(&hashed_expr));
result.push(hashed_expr);
}

result
}

fn generate_random_assignment<F: Field + Hash, V: Debug + Clone + Eq + Hash>(
queriables: &Vec<V>,
) -> VarAssignments<F, V> {
let mut rng = ChaCha20Rng::seed_from_u64(0);

let mut assignments = HashMap::new();

for queriable in queriables {
let value = F::random(&mut rng);
assignments.insert(queriable.clone(), value);
}

assignments
}

#[cfg(test)]
mod test {
use super::*;
use std::collections::HashSet;

use halo2_proofs::halo2curves::bn256::Fr;

use crate::{
poly::{Expr::*, ToExpr},
sbpir::{query::Queriable, ForwardSignal, InternalSignal},
};

#[test]
fn test_generate_random_assignment() {
let internal = InternalSignal::new("internal");
let forward = ForwardSignal::new("forward");

let a: Queriable<Fr> = Queriable::Internal(internal);
let b: Queriable<Fr> = Queriable::Forward(forward, false);
let c: Queriable<Fr> = Queriable::Forward(forward, true);

let vars = vec![a, b, c];

let keys: HashSet<Queriable<Fr>> = vars.iter().cloned().collect();

let assignments: VarAssignments<Fr, Queriable<Fr>> = generate_random_assignment(&vars);

println!("Assignments: {:#?}", assignments);

for key in &keys {
assert!(assignments.contains_key(key));
}
}

#[test]
fn test_cse() {
let forward = ForwardSignal::new("forward");

let a: Queriable<Fr> = Queriable::Internal(InternalSignal::new("a"));
let b: Queriable<Fr> = Queriable::Internal(InternalSignal::new("b"));
let c: Queriable<Fr> = Queriable::Forward(forward, false);
let d: Queriable<Fr> = Queriable::Forward(forward, true);
let e: Queriable<Fr> = Queriable::Internal(InternalSignal::new("e"));
let f: Queriable<Fr> = Queriable::Internal(InternalSignal::new("f"));
let g: Queriable<Fr> = Queriable::Internal(InternalSignal::new("g"));
let vars = vec![a, b, c, d, e, f, g];

// Equivalent expressions
let expr1 = Pow(Box::new(e.expr()), 6, ()) * a * b + c * d - 1.expr();
let expr2 = d * c - 1.expr() + a * b * Pow(Box::new(e.expr()), 6, ());

// Equivalent expressions
let expr3 = f * b - c * d - 1.expr();
let expr4 = -(1.expr()) - c * d + b * f;

// Equivalent expressions
let expr5 = -(-f * g) * (-(-(-a)));
let expr6 = -(f * g * a);

let exprs = vec![expr1, expr2, expr3, expr4, expr5, expr6];

let result = cse(exprs, &vars);

assert!(Rc::ptr_eq(&result[0], &result[1]));
assert!(Rc::ptr_eq(&result[2], &result[3]));
assert!(Rc::ptr_eq(&result[4], &result[5]));
assert!(!Rc::ptr_eq(&result[0], &result[2]));
assert!(!Rc::ptr_eq(&result[0], &result[4]));
assert!(!Rc::ptr_eq(&result[2], &result[4]));
}
}
Loading

0 comments on commit a3cecdd

Please sign in to comment.