Skip to content
This repository has been archived by the owner on Nov 4, 2024. It is now read-only.

Refactor compiler to return a new structure for multiple machines #284

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 21 additions & 14 deletions examples/fibonacci.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ use chiquito::{
compile, // input for constructing the compiler
config,
step_selector::SimpleStepSelectorBuilder,
PlonkishCompilationResult,
},
ir::{assignments::AssignmentGenerator, Circuit},
}, /* compiles to
* Chiquito Halo2
* backend,
Expand All @@ -24,7 +24,7 @@ use chiquito::{
* Halo2
* circuit */
poly::ToField,
sbpir::SBPIR,
sbpir::SBPIRLegacy,
};
use halo2_proofs::dev::MockProver;

Expand All @@ -35,9 +35,8 @@ use halo2_proofs::dev::MockProver;
// 3. two witness generation arguments both of u64 type, i.e. (u64, u64)

type FiboReturn<F> = (
Circuit<F>,
Option<AssignmentGenerator<F>>,
SBPIR<F, DSLTraceGenerator<F>>,
PlonkishCompilationResult<F, DSLTraceGenerator<F>>,
SBPIRLegacy<F, DSLTraceGenerator<F>>,
);

fn fibo_circuit<F: Field + From<u64> + Hash>() -> FiboReturn<F> {
Expand Down Expand Up @@ -124,17 +123,20 @@ fn fibo_circuit<F: Field + From<u64> + Hash>() -> FiboReturn<F> {
&fibo,
);

(compiled.circuit, compiled.assignment_generator, fibo)
(compiled, fibo)
}

// After compiling Chiquito AST to an IR, it is further parsed by a Chiquito Halo2 backend and
// integrated into a Halo2 circuit, which is done by the boilerplate code below.

// standard main function for a Halo2 circuit
fn main() {
let (chiquito, wit_gen, _) = fibo_circuit::<Fr>();
let compiled = chiquito2Halo2(chiquito);
let circuit = ChiquitoHalo2Circuit::new(compiled, wit_gen.map(|g| g.generate(())));
let (chiquito, _) = fibo_circuit::<Fr>();
let compiled = chiquito2Halo2(chiquito.circuit);
let circuit = ChiquitoHalo2Circuit::new(
compiled,
chiquito.assignment_generator.map(|g| g.generate(())),
);

let prover = MockProver::<Fr>::run(7, &circuit, circuit.instance()).unwrap();

Expand All @@ -156,11 +158,11 @@ fn main() {
pcs::{multilinear, univariate},
};
// get Chiquito ir
let (circuit, assignment_generator, _) = fibo_circuit::<Fr>();
let (plonkish, _) = fibo_circuit::<Fr>();
// get assignments
let assignments = assignment_generator.unwrap().generate(());
let assignments = plonkish.assignment_generator.unwrap().generate(());
// get hyperplonk circuit
let mut hyperplonk_circuit = ChiquitoHyperPlonkCircuit::new(4, circuit);
let mut hyperplonk_circuit = ChiquitoHyperPlonkCircuit::new(4, plonkish.circuit);
hyperplonk_circuit.set_assignment(assignments);

type GeminiKzg = multilinear::Gemini<univariate::UnivariateKzg<Bn256>>;
Expand All @@ -170,10 +172,15 @@ fn main() {
// pil boilerplate
use chiquito::pil::backend::powdr_pil::chiquito2Pil;

let (_, wit_gen, circuit) = fibo_circuit::<Fr>();
let (plonkish, circuit) = fibo_circuit::<Fr>();
let pil = chiquito2Pil(
circuit,
Some(wit_gen.unwrap().generate_trace_witness(())),
Some(
plonkish
.assignment_generator
.unwrap()
.generate_trace_witness(()),
),
String::from("FiboCircuit"),
);
print!("{}", pil);
Expand Down
81 changes: 66 additions & 15 deletions src/compiler/compiler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use crate::{
},
plonkish::{self, compiler::PlonkishCompilationResult},
poly::{self, mielim::mi_elimination, reduce::reduce_degree, Expr},
sbpir::{query::Queriable, InternalSignal, SBPIR},
sbpir::{query::Queriable, InternalSignal, SBPIRLegacy, SBPIR},
wit_gen::{NullTraceGenerator, SymbolSignalMapping, TraceGenerator},
};

Expand All @@ -31,15 +31,19 @@ use super::{
Config, Message, Messages,
};

/// Contains the result of a compilation.
#[derive(Debug)]
pub struct CompilerResult<F: Field + Hash> {
pub messages: Vec<Message>,
// pub wit_gen: WitnessGenerator,
pub circuit: SBPIR<F, InterpreterTraceGenerator>,
}

impl<F: Field + Hash> CompilerResult<F> {
/// Contains the result of a single machine compilation (legacy).
#[derive(Debug)]
pub struct CompilerResultLegacy<F: Field + Hash> {
pub messages: Vec<Message>,
pub circuit: SBPIRLegacy<F, InterpreterTraceGenerator>,
}

impl<F: Field + Hash> CompilerResultLegacy<F> {
/// Compiles to the Plonkish IR, that then can be compiled to plonkish backends.
pub fn plonkish<
CM: plonkish::compiler::cell_manager::CellManager,
Expand Down Expand Up @@ -76,6 +80,41 @@ impl<F: Field + Hash> Compiler<F> {
}
}

/// Compile the source code containing a single machine (legacy).
pub(super) fn compile_legacy(
mut self,
source: &str,
debug_sym_ref_factory: &DebugSymRefFactory,
) -> Result<CompilerResultLegacy<F>, Vec<Message>> {
let ast = self
.parse(source, debug_sym_ref_factory)
.map_err(|_| self.messages.clone())?;
assert!(ast.len() == 1, "Use `compile` to compile multiple machines");
let ast = self.add_virtual(ast);
let symbols = self.semantic(&ast).map_err(|_| self.messages.clone())?;
let setup = Self::interpret(&ast, &symbols);
let setup = Self::map_consts(setup);
let circuit = self.build(&setup, &symbols);
let circuit = Self::mi_elim(circuit);
let circuit = if let Some(degree) = self.config.max_degree {
Self::reduce(circuit, degree)
} else {
circuit
};

let circuit = circuit.with_trace(InterpreterTraceGenerator::new(
ast,
symbols,
self.mapping,
self.config.max_steps,
));

Ok(CompilerResultLegacy {
messages: self.messages,
circuit,
})
}

/// Compile the source code.
pub(super) fn compile(
mut self,
Expand All @@ -89,6 +128,9 @@ impl<F: Field + Hash> Compiler<F> {
let symbols = self.semantic(&ast).map_err(|_| self.messages.clone())?;
let setup = Self::interpret(&ast, &symbols);
let setup = Self::map_consts(setup);

let machine_id = setup.iter().next().unwrap().0;

let circuit = self.build(&setup, &symbols);
let circuit = Self::mi_elim(circuit);
let circuit = if let Some(degree) = self.config.max_degree {
Expand All @@ -104,9 +146,12 @@ impl<F: Field + Hash> Compiler<F> {
self.config.max_steps,
));

// TODO perform real compilation for multiple machines
let sbpir = SBPIR::from_legacy(circuit, machine_id.as_str());

Ok(CompilerResult {
messages: self.messages,
circuit,
circuit: sbpir,
})
}

Expand Down Expand Up @@ -287,7 +332,11 @@ impl<F: Field + Hash> Compiler<F> {
}
}

fn build(&mut self, setup: &Setup<F>, symbols: &SymTable) -> SBPIR<F, NullTraceGenerator> {
fn build(
&mut self,
setup: &Setup<F>,
symbols: &SymTable,
) -> SBPIRLegacy<F, NullTraceGenerator> {
circuit::<F, (), _>("circuit", |ctx| {
for (machine_id, machine) in setup {
self.add_forwards(ctx, symbols, machine_id);
Expand Down Expand Up @@ -327,7 +376,9 @@ impl<F: Field + Hash> Compiler<F> {
.without_trace()
}

fn mi_elim(mut circuit: SBPIR<F, NullTraceGenerator>) -> SBPIR<F, NullTraceGenerator> {
fn mi_elim(
mut circuit: SBPIRLegacy<F, NullTraceGenerator>,
) -> SBPIRLegacy<F, NullTraceGenerator> {
for (_, step_type) in circuit.step_types.iter_mut() {
let mut signal_factory = SignalFactory::default();

Expand All @@ -338,9 +389,9 @@ impl<F: Field + Hash> Compiler<F> {
}

fn reduce(
mut circuit: SBPIR<F, NullTraceGenerator>,
mut circuit: SBPIRLegacy<F, NullTraceGenerator>,
degree: usize,
) -> SBPIR<F, NullTraceGenerator> {
) -> SBPIRLegacy<F, NullTraceGenerator> {
for (_, step_type) in circuit.step_types.iter_mut() {
let mut signal_factory = SignalFactory::default();

Expand All @@ -353,7 +404,7 @@ impl<F: Field + Hash> Compiler<F> {
}

#[allow(dead_code)]
fn cse(mut _circuit: SBPIR<F, NullTraceGenerator>) -> SBPIR<F, NullTraceGenerator> {
fn cse(mut _circuit: SBPIRLegacy<F, NullTraceGenerator>) -> SBPIRLegacy<F, NullTraceGenerator> {
todo!()
}

Expand Down Expand Up @@ -627,7 +678,7 @@ mod test {
use halo2_proofs::halo2curves::bn256::Fr;

use crate::{
compiler::{compile, compile_file},
compiler::{compile_file_legacy, compile_legacy},
parser::ast::debug_sym_factory::DebugSymRefFactory,
};

Expand Down Expand Up @@ -678,7 +729,7 @@ mod test {
";

let debug_sym_ref_factory = DebugSymRefFactory::new("", circuit);
let result = compile::<Fr>(
let result = compile_legacy::<Fr>(
circuit,
Config::default().max_degree(2),
&debug_sym_ref_factory,
Expand All @@ -693,14 +744,14 @@ mod test {
#[test]
fn test_compiler_fibo_file() {
let path = "test/circuit.chiquito";
let result = compile_file::<Fr>(path, Config::default().max_degree(2));
let result = compile_file_legacy::<Fr>(path, Config::default().max_degree(2));
assert!(result.is_ok());
}

#[test]
fn test_compiler_fibo_file_err() {
let path = "test/circuit_error.chiquito";
let result = compile_file::<Fr>(path, Config::default().max_degree(2));
let result = compile_file_legacy::<Fr>(path, Config::default().max_degree(2));

assert!(result.is_err());

Expand Down
36 changes: 35 additions & 1 deletion src/compiler/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@ use std::{
io::{self, Read},
};

use self::compiler::{Compiler, CompilerResult};
use compiler::CompilerResult;

use self::compiler::{Compiler, CompilerResultLegacy};
use crate::{
field::Field,
parser::ast::{debug_sym_factory::DebugSymRefFactory, DebugSymRef},
Expand Down Expand Up @@ -65,6 +67,38 @@ impl Messages for Vec<Message> {
}
}

/// Compiles chiquito source code string into a SBPIR for a single machine, also returns messages
/// (legacy).
pub fn compile_legacy<F: Field + Hash>(
source: &str,
config: Config,
debug_sym_ref_factory: &DebugSymRefFactory,
) -> Result<CompilerResultLegacy<F>, Vec<Message>> {
Compiler::new(config).compile_legacy(source, debug_sym_ref_factory)
}

/// Compiles chiquito source code file into a SBPIR for a single machine, also returns messages
/// (legacy).
pub fn compile_file_legacy<F: Field + Hash>(
file_path: &str,
config: Config,
) -> Result<CompilerResultLegacy<F>, Vec<Message>> {
let contents = read_file(file_path);
match contents {
Ok(source) => {
let debug_sym_ref_factory = DebugSymRefFactory::new(file_path, source.as_str());
compile_legacy(source.as_str(), config, &debug_sym_ref_factory)
}
Err(e) => {
let msg = format!("Error reading file: {}", e);
let message = Message::ParseErr { msg };
let messages = vec![message];

Err(messages)
}
}
}

/// Compiles chiquito source code string into a SBPIR, also returns messages.
pub fn compile<F: Field + Hash>(
source: &str,
Expand Down
2 changes: 2 additions & 0 deletions src/compiler/setup_inter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ pub(super) fn interpret(ast: &[TLDecl<BigInt, Identifier>], _symbols: &SymTable)
interpreter.setup
}

/// Machine setup by machine name
pub(super) type Setup<F> = HashMap<String, MachineSetup<F>>;

pub(super) struct MachineSetup<F> {
Expand Down Expand Up @@ -119,6 +120,7 @@ impl<F: Clone> MachineSetup<F> {
struct SetupInterpreter {
abepi: CompilationUnit<BigInt, Identifier>,

/// Machine setup by machine name
setup: Setup<BigInt>,

current_machine: String,
Expand Down
12 changes: 6 additions & 6 deletions src/frontend/dsl/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::{
field::Field,
sbpir::{query::Queriable, ExposeOffset, StepType, StepTypeUUID, PIR, SBPIR},
sbpir::{query::Queriable, ExposeOffset, StepType, StepTypeUUID, PIR, SBPIRLegacy},
util::{uuid, UUID},
wit_gen::{FixedGenContext, StepInstance, TraceGenerator},
};
Expand Down Expand Up @@ -32,7 +32,7 @@ pub mod trace;
/// `F` is the field of the circuit.
/// `TG` is the trace generator.
pub struct CircuitContext<F, TG: TraceGenerator<F> = DSLTraceGenerator<F>> {
circuit: SBPIR<F, TG>,
circuit: SBPIRLegacy<F, TG>,
tables: LookupTableRegistry<F>,
}

Expand Down Expand Up @@ -424,13 +424,13 @@ impl<F, Args, D: Fn(&mut StepInstance<F>, Args) + 'static> StepTypeWGHandler<F,
pub fn circuit<F: Field, TraceArgs: Clone, D>(
_name: &str,
mut def: D,
) -> SBPIR<F, DSLTraceGenerator<F, TraceArgs>>
) -> SBPIRLegacy<F, DSLTraceGenerator<F, TraceArgs>>
where
D: FnMut(&mut CircuitContext<F, DSLTraceGenerator<F, TraceArgs>>),
{
// TODO annotate circuit
let mut context = CircuitContext {
circuit: SBPIR::default(),
circuit: SBPIRLegacy::default(),
tables: LookupTableRegistry::default(),
};

Expand All @@ -453,14 +453,14 @@ mod tests {
TG: TraceGenerator<F>,
{
CircuitContext {
circuit: SBPIR::default(),
circuit: SBPIRLegacy::default(),
tables: Default::default(),
}
}

#[test]
fn test_circuit_default_initialization() {
let circuit: SBPIR<i32, NullTraceGenerator> = SBPIR::default();
let circuit: SBPIRLegacy<i32, NullTraceGenerator> = SBPIRLegacy::default();

// Assert default values
assert!(circuit.step_types.is_empty());
Expand Down
Loading
Loading