Skip to content
This repository has been archived by the owner on Nov 4, 2024. It is now read-only.

Commit

Permalink
Refactor compiler to return a new structure for multiple machines (#284)
Browse files Browse the repository at this point in the history
  • Loading branch information
alxkzmn authored Aug 6, 2024
1 parent 7086456 commit ec0cd93
Show file tree
Hide file tree
Showing 16 changed files with 538 additions and 88 deletions.
35 changes: 21 additions & 14 deletions examples/fibonacci.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ use chiquito::{
compile, // input for constructing the compiler
config,
step_selector::SimpleStepSelectorBuilder,
PlonkishCompilationResult,
},
ir::{assignments::AssignmentGenerator, Circuit},
}, /* compiles to
* Chiquito Halo2
* backend,
Expand All @@ -24,7 +24,7 @@ use chiquito::{
* Halo2
* circuit */
poly::ToField,
sbpir::SBPIR,
sbpir::SBPIRLegacy,
};
use halo2_proofs::dev::MockProver;

Expand All @@ -35,9 +35,8 @@ use halo2_proofs::dev::MockProver;
// 3. two witness generation arguments both of u64 type, i.e. (u64, u64)

type FiboReturn<F> = (
Circuit<F>,
Option<AssignmentGenerator<F>>,
SBPIR<F, DSLTraceGenerator<F>>,
PlonkishCompilationResult<F, DSLTraceGenerator<F>>,
SBPIRLegacy<F, DSLTraceGenerator<F>>,
);

fn fibo_circuit<F: Field + From<u64> + Hash>() -> FiboReturn<F> {
Expand Down Expand Up @@ -124,17 +123,20 @@ fn fibo_circuit<F: Field + From<u64> + Hash>() -> FiboReturn<F> {
&fibo,
);

(compiled.circuit, compiled.assignment_generator, fibo)
(compiled, fibo)
}

// After compiling Chiquito AST to an IR, it is further parsed by a Chiquito Halo2 backend and
// integrated into a Halo2 circuit, which is done by the boilerplate code below.

// standard main function for a Halo2 circuit
fn main() {
let (chiquito, wit_gen, _) = fibo_circuit::<Fr>();
let compiled = chiquito2Halo2(chiquito);
let circuit = ChiquitoHalo2Circuit::new(compiled, wit_gen.map(|g| g.generate(())));
let (chiquito, _) = fibo_circuit::<Fr>();
let compiled = chiquito2Halo2(chiquito.circuit);
let circuit = ChiquitoHalo2Circuit::new(
compiled,
chiquito.assignment_generator.map(|g| g.generate(())),
);

let prover = MockProver::<Fr>::run(7, &circuit, circuit.instance()).unwrap();

Expand All @@ -156,11 +158,11 @@ fn main() {
pcs::{multilinear, univariate},
};
// get Chiquito ir
let (circuit, assignment_generator, _) = fibo_circuit::<Fr>();
let (plonkish, _) = fibo_circuit::<Fr>();
// get assignments
let assignments = assignment_generator.unwrap().generate(());
let assignments = plonkish.assignment_generator.unwrap().generate(());
// get hyperplonk circuit
let mut hyperplonk_circuit = ChiquitoHyperPlonkCircuit::new(4, circuit);
let mut hyperplonk_circuit = ChiquitoHyperPlonkCircuit::new(4, plonkish.circuit);
hyperplonk_circuit.set_assignment(assignments);

type GeminiKzg = multilinear::Gemini<univariate::UnivariateKzg<Bn256>>;
Expand All @@ -170,10 +172,15 @@ fn main() {
// pil boilerplate
use chiquito::pil::backend::powdr_pil::chiquito2Pil;

let (_, wit_gen, circuit) = fibo_circuit::<Fr>();
let (plonkish, circuit) = fibo_circuit::<Fr>();
let pil = chiquito2Pil(
circuit,
Some(wit_gen.unwrap().generate_trace_witness(())),
Some(
plonkish
.assignment_generator
.unwrap()
.generate_trace_witness(()),
),
String::from("FiboCircuit"),
);
print!("{}", pil);
Expand Down
81 changes: 66 additions & 15 deletions src/compiler/compiler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use crate::{
},
plonkish::{self, compiler::PlonkishCompilationResult},
poly::{self, mielim::mi_elimination, reduce::reduce_degree, Expr},
sbpir::{query::Queriable, InternalSignal, SBPIR},
sbpir::{query::Queriable, InternalSignal, SBPIRLegacy, SBPIR},
wit_gen::{NullTraceGenerator, SymbolSignalMapping, TraceGenerator},
};

Expand All @@ -31,15 +31,19 @@ use super::{
Config, Message, Messages,
};

/// Contains the result of a compilation.
#[derive(Debug)]
pub struct CompilerResult<F: Field + Hash> {
pub messages: Vec<Message>,
// pub wit_gen: WitnessGenerator,
pub circuit: SBPIR<F, InterpreterTraceGenerator>,
}

impl<F: Field + Hash> CompilerResult<F> {
/// Contains the result of a single machine compilation (legacy).
#[derive(Debug)]
pub struct CompilerResultLegacy<F: Field + Hash> {
pub messages: Vec<Message>,
pub circuit: SBPIRLegacy<F, InterpreterTraceGenerator>,
}

impl<F: Field + Hash> CompilerResultLegacy<F> {
/// Compiles to the Plonkish IR, that then can be compiled to plonkish backends.
pub fn plonkish<
CM: plonkish::compiler::cell_manager::CellManager,
Expand Down Expand Up @@ -76,6 +80,41 @@ impl<F: Field + Hash> Compiler<F> {
}
}

/// Compile the source code containing a single machine (legacy).
pub(super) fn compile_legacy(
mut self,
source: &str,
debug_sym_ref_factory: &DebugSymRefFactory,
) -> Result<CompilerResultLegacy<F>, Vec<Message>> {
let ast = self
.parse(source, debug_sym_ref_factory)
.map_err(|_| self.messages.clone())?;
assert!(ast.len() == 1, "Use `compile` to compile multiple machines");
let ast = self.add_virtual(ast);
let symbols = self.semantic(&ast).map_err(|_| self.messages.clone())?;
let setup = Self::interpret(&ast, &symbols);
let setup = Self::map_consts(setup);
let circuit = self.build(&setup, &symbols);
let circuit = Self::mi_elim(circuit);
let circuit = if let Some(degree) = self.config.max_degree {
Self::reduce(circuit, degree)
} else {
circuit
};

let circuit = circuit.with_trace(InterpreterTraceGenerator::new(
ast,
symbols,
self.mapping,
self.config.max_steps,
));

Ok(CompilerResultLegacy {
messages: self.messages,
circuit,
})
}

/// Compile the source code.
pub(super) fn compile(
mut self,
Expand All @@ -89,6 +128,9 @@ impl<F: Field + Hash> Compiler<F> {
let symbols = self.semantic(&ast).map_err(|_| self.messages.clone())?;
let setup = Self::interpret(&ast, &symbols);
let setup = Self::map_consts(setup);

let machine_id = setup.iter().next().unwrap().0;

let circuit = self.build(&setup, &symbols);
let circuit = Self::mi_elim(circuit);
let circuit = if let Some(degree) = self.config.max_degree {
Expand All @@ -104,9 +146,12 @@ impl<F: Field + Hash> Compiler<F> {
self.config.max_steps,
));

// TODO perform real compilation for multiple machines
let sbpir = SBPIR::from_legacy(circuit, machine_id.as_str());

Ok(CompilerResult {
messages: self.messages,
circuit,
circuit: sbpir,
})
}

Expand Down Expand Up @@ -287,7 +332,11 @@ impl<F: Field + Hash> Compiler<F> {
}
}

fn build(&mut self, setup: &Setup<F>, symbols: &SymTable) -> SBPIR<F, NullTraceGenerator> {
fn build(
&mut self,
setup: &Setup<F>,
symbols: &SymTable,
) -> SBPIRLegacy<F, NullTraceGenerator> {
circuit::<F, (), _>("circuit", |ctx| {
for (machine_id, machine) in setup {
self.add_forwards(ctx, symbols, machine_id);
Expand Down Expand Up @@ -327,7 +376,9 @@ impl<F: Field + Hash> Compiler<F> {
.without_trace()
}

fn mi_elim(mut circuit: SBPIR<F, NullTraceGenerator>) -> SBPIR<F, NullTraceGenerator> {
fn mi_elim(
mut circuit: SBPIRLegacy<F, NullTraceGenerator>,
) -> SBPIRLegacy<F, NullTraceGenerator> {
for (_, step_type) in circuit.step_types.iter_mut() {
let mut signal_factory = SignalFactory::default();

Expand All @@ -338,9 +389,9 @@ impl<F: Field + Hash> Compiler<F> {
}

fn reduce(
mut circuit: SBPIR<F, NullTraceGenerator>,
mut circuit: SBPIRLegacy<F, NullTraceGenerator>,
degree: usize,
) -> SBPIR<F, NullTraceGenerator> {
) -> SBPIRLegacy<F, NullTraceGenerator> {
for (_, step_type) in circuit.step_types.iter_mut() {
let mut signal_factory = SignalFactory::default();

Expand All @@ -353,7 +404,7 @@ impl<F: Field + Hash> Compiler<F> {
}

#[allow(dead_code)]
fn cse(mut _circuit: SBPIR<F, NullTraceGenerator>) -> SBPIR<F, NullTraceGenerator> {
fn cse(mut _circuit: SBPIRLegacy<F, NullTraceGenerator>) -> SBPIRLegacy<F, NullTraceGenerator> {
todo!()
}

Expand Down Expand Up @@ -627,7 +678,7 @@ mod test {
use halo2_proofs::halo2curves::bn256::Fr;

use crate::{
compiler::{compile, compile_file},
compiler::{compile_file_legacy, compile_legacy},
parser::ast::debug_sym_factory::DebugSymRefFactory,
};

Expand Down Expand Up @@ -678,7 +729,7 @@ mod test {
";

let debug_sym_ref_factory = DebugSymRefFactory::new("", circuit);
let result = compile::<Fr>(
let result = compile_legacy::<Fr>(
circuit,
Config::default().max_degree(2),
&debug_sym_ref_factory,
Expand All @@ -693,14 +744,14 @@ mod test {
#[test]
fn test_compiler_fibo_file() {
let path = "test/circuit.chiquito";
let result = compile_file::<Fr>(path, Config::default().max_degree(2));
let result = compile_file_legacy::<Fr>(path, Config::default().max_degree(2));
assert!(result.is_ok());
}

#[test]
fn test_compiler_fibo_file_err() {
let path = "test/circuit_error.chiquito";
let result = compile_file::<Fr>(path, Config::default().max_degree(2));
let result = compile_file_legacy::<Fr>(path, Config::default().max_degree(2));

assert!(result.is_err());

Expand Down
36 changes: 35 additions & 1 deletion src/compiler/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@ use std::{
io::{self, Read},
};

use self::compiler::{Compiler, CompilerResult};
use compiler::CompilerResult;

use self::compiler::{Compiler, CompilerResultLegacy};
use crate::{
field::Field,
parser::ast::{debug_sym_factory::DebugSymRefFactory, DebugSymRef},
Expand Down Expand Up @@ -65,6 +67,38 @@ impl Messages for Vec<Message> {
}
}

/// Compiles chiquito source code string into a SBPIR for a single machine, also returns messages
/// (legacy).
pub fn compile_legacy<F: Field + Hash>(
source: &str,
config: Config,
debug_sym_ref_factory: &DebugSymRefFactory,
) -> Result<CompilerResultLegacy<F>, Vec<Message>> {
Compiler::new(config).compile_legacy(source, debug_sym_ref_factory)
}

/// Compiles chiquito source code file into a SBPIR for a single machine, also returns messages
/// (legacy).
pub fn compile_file_legacy<F: Field + Hash>(
file_path: &str,
config: Config,
) -> Result<CompilerResultLegacy<F>, Vec<Message>> {
let contents = read_file(file_path);
match contents {
Ok(source) => {
let debug_sym_ref_factory = DebugSymRefFactory::new(file_path, source.as_str());
compile_legacy(source.as_str(), config, &debug_sym_ref_factory)
}
Err(e) => {
let msg = format!("Error reading file: {}", e);
let message = Message::ParseErr { msg };
let messages = vec![message];

Err(messages)
}
}
}

/// Compiles chiquito source code string into a SBPIR, also returns messages.
pub fn compile<F: Field + Hash>(
source: &str,
Expand Down
2 changes: 2 additions & 0 deletions src/compiler/setup_inter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ pub(super) fn interpret(ast: &[TLDecl<BigInt, Identifier>], _symbols: &SymTable)
interpreter.setup
}

/// Machine setup by machine name
pub(super) type Setup<F> = HashMap<String, MachineSetup<F>>;

pub(super) struct MachineSetup<F> {
Expand Down Expand Up @@ -119,6 +120,7 @@ impl<F: Clone> MachineSetup<F> {
struct SetupInterpreter {
abepi: CompilationUnit<BigInt, Identifier>,

/// Machine setup by machine name
setup: Setup<BigInt>,

current_machine: String,
Expand Down
12 changes: 6 additions & 6 deletions src/frontend/dsl/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::{
field::Field,
sbpir::{query::Queriable, ExposeOffset, StepType, StepTypeUUID, PIR, SBPIR},
sbpir::{query::Queriable, ExposeOffset, StepType, StepTypeUUID, PIR, SBPIRLegacy},
util::{uuid, UUID},
wit_gen::{FixedGenContext, StepInstance, TraceGenerator},
};
Expand Down Expand Up @@ -32,7 +32,7 @@ pub mod trace;
/// `F` is the field of the circuit.
/// `TG` is the trace generator.
pub struct CircuitContext<F, TG: TraceGenerator<F> = DSLTraceGenerator<F>> {
circuit: SBPIR<F, TG>,
circuit: SBPIRLegacy<F, TG>,
tables: LookupTableRegistry<F>,
}

Expand Down Expand Up @@ -424,13 +424,13 @@ impl<F, Args, D: Fn(&mut StepInstance<F>, Args) + 'static> StepTypeWGHandler<F,
pub fn circuit<F: Field, TraceArgs: Clone, D>(
_name: &str,
mut def: D,
) -> SBPIR<F, DSLTraceGenerator<F, TraceArgs>>
) -> SBPIRLegacy<F, DSLTraceGenerator<F, TraceArgs>>
where
D: FnMut(&mut CircuitContext<F, DSLTraceGenerator<F, TraceArgs>>),
{
// TODO annotate circuit
let mut context = CircuitContext {
circuit: SBPIR::default(),
circuit: SBPIRLegacy::default(),
tables: LookupTableRegistry::default(),
};

Expand All @@ -453,14 +453,14 @@ mod tests {
TG: TraceGenerator<F>,
{
CircuitContext {
circuit: SBPIR::default(),
circuit: SBPIRLegacy::default(),
tables: Default::default(),
}
}

#[test]
fn test_circuit_default_initialization() {
let circuit: SBPIR<i32, NullTraceGenerator> = SBPIR::default();
let circuit: SBPIRLegacy<i32, NullTraceGenerator> = SBPIRLegacy::default();

// Assert default values
assert!(circuit.step_types.is_empty());
Expand Down
Loading

0 comments on commit ec0cd93

Please sign in to comment.