From c07f98dceb3025af8fe10a6c15e02cff554be7c7 Mon Sep 17 00:00:00 2001 From: Alex Kuzmin <6849426+alxkzmn@users.noreply.github.com> Date: Thu, 22 Aug 2024 19:42:14 +0800 Subject: [PATCH] Implement compilation to new SBPIR with multiple machines (#288) `compiler_legacy.rs` contains unmodified legacy code. --- examples/blake2f.rs | 16 +- examples/factorial.rs | 7 +- examples/fibo_with_padding.rs | 7 +- examples/fibonacci.rs | 7 +- examples/keccak.rs | 17 +- examples/mimc7.rs | 8 +- examples/poseidon.rs | 10 +- src/compiler/compiler.rs | 608 +++++++++++++++------------ src/compiler/compiler_legacy.rs | 721 ++++++++++++++++++++++++++++++++ src/compiler/mod.rs | 7 +- src/compiler/setup_inter.rs | 47 ++- src/frontend/dsl/mod.rs | 36 +- src/frontend/dsl/sc.rs | 14 +- src/interpreter/mod.rs | 2 +- src/poly/mielim.rs | 4 +- src/poly/mod.rs | 4 +- src/sbpir/mod.rs | 96 ++++- src/sbpir/query.rs | 17 +- src/sbpir/sbpir_machine.rs | 84 ++-- 19 files changed, 1327 insertions(+), 385 deletions(-) create mode 100644 src/compiler/compiler_legacy.rs diff --git a/examples/blake2f.rs b/examples/blake2f.rs index 558461ce..187c3c45 100644 --- a/examples/blake2f.rs +++ b/examples/blake2f.rs @@ -4,7 +4,7 @@ use chiquito::{ lb::LookupTable, super_circuit, trace::DSLTraceGenerator, - CircuitContext, StepTypeSetupContext, StepTypeWGHandler, + CircuitContextLegacy, StepTypeSetupContext, StepTypeWGHandler, }, plonkish::{ backend::halo2_legacy::{chiquitoSuperCircuit2Halo2, ChiquitoHalo2SuperCircuit}, @@ -127,7 +127,10 @@ pub fn split_to_4bits_values(vec_values: &[u64]) -> Vec(ctx: &mut CircuitContext, _: usize) -> LookupTable { +fn blake2f_iv_table( + ctx: &mut CircuitContextLegacy, + _: usize, +) -> LookupTable { let lookup_iv_row: Queriable = ctx.fixed("iv row"); let lookup_iv_value: Queriable = ctx.fixed("iv value"); @@ -144,7 +147,10 @@ fn blake2f_iv_table(ctx: &mut CircuitContext, _: usize) } // For range checking -fn blake2f_4bits_table(ctx: &mut CircuitContext, _: usize) -> LookupTable { +fn blake2f_4bits_table( + ctx: &mut CircuitContextLegacy, + _: usize, +) -> LookupTable { let lookup_4bits_row: Queriable = ctx.fixed("4bits row"); let lookup_4bits_value: Queriable = ctx.fixed("4bits value"); @@ -160,7 +166,7 @@ fn blake2f_4bits_table(ctx: &mut CircuitContext, _: usi } fn blake2f_xor_4bits_table( - ctx: &mut CircuitContext, + ctx: &mut CircuitContextLegacy, _: usize, ) -> LookupTable { let lookup_xor_row: Queriable = ctx.fixed("xor row"); @@ -526,7 +532,7 @@ fn g_setup( } fn blake2f_circuit( - ctx: &mut CircuitContext>, + ctx: &mut CircuitContextLegacy>, params: CircuitParams, ) { let v_vec: Vec> = (0..V_LEN) diff --git a/examples/factorial.rs b/examples/factorial.rs index 9bef03aa..37196e11 100644 --- a/examples/factorial.rs +++ b/examples/factorial.rs @@ -2,8 +2,9 @@ use std::hash::Hash; use chiquito::{ field::Field, - frontend::dsl::{circuit, trace::DSLTraceGenerator}, /* main function for constructing an AST - * circuit */ + frontend::dsl::{circuit_legacy, trace::DSLTraceGenerator}, /* main function for constructing + * an AST + * circuit */ plonkish::{ backend::halo2_legacy::{chiquito2Halo2, ChiquitoHalo2Circuit}, compiler::{ @@ -42,7 +43,7 @@ fn generate + Hash>() -> PlonkishCompilationResult("factorial", |ctx| { + let factorial_circuit = circuit_legacy::("factorial", |ctx| { let i = ctx.shared("i"); let x = ctx.forward("x"); diff --git a/examples/fibo_with_padding.rs b/examples/fibo_with_padding.rs index d93c5bbd..2c6df8b3 100644 --- a/examples/fibo_with_padding.rs +++ b/examples/fibo_with_padding.rs @@ -2,8 +2,9 @@ use std::hash::Hash; use chiquito::{ field::Field, - frontend::dsl::{circuit, trace::DSLTraceGenerator}, /* main function for constructing an AST - * circuit */ + frontend::dsl::{circuit_legacy, trace::DSLTraceGenerator}, /* main function for constructing + * an AST + * circuit */ plonkish::{ backend::halo2_legacy::{chiquito2Halo2, ChiquitoHalo2Circuit}, compiler::{ @@ -37,7 +38,7 @@ fn fibo_circuit + Hash>( sbpir::ExposeOffset::*, // for exposing witnesses }; - let fibo = circuit::("fibonacci", |ctx| { + let fibo = circuit_legacy::("fibonacci", |ctx| { // Example table for 7 rounds: // | step_type | a | b | c | n | // --------------------------------------- diff --git a/examples/fibonacci.rs b/examples/fibonacci.rs index 7dc16302..e83c3fd4 100644 --- a/examples/fibonacci.rs +++ b/examples/fibonacci.rs @@ -2,8 +2,9 @@ use std::hash::Hash; use chiquito::{ field::Field, - frontend::dsl::{circuit, trace::DSLTraceGenerator}, /* main function for constructing an AST - * circuit */ + frontend::dsl::{circuit_legacy, trace::DSLTraceGenerator}, /* main function for constructing + * an AST + * circuit */ plonkish::{ backend::{ halo2_legacy::{chiquito2Halo2, ChiquitoHalo2Circuit}, @@ -50,7 +51,7 @@ fn fibo_circuit + Hash>() -> FiboReturn { use chiquito::frontend::dsl::cb::*; // functions for constraint building - let fibo = circuit::("fibonacci", |ctx| { + let fibo = circuit_legacy::("fibonacci", |ctx| { // the following objects (forward signals, steptypes) are defined on the circuit-level // forward signals can have constraints across different steps diff --git a/examples/keccak.rs b/examples/keccak.rs index 5a22f07e..10d66a37 100644 --- a/examples/keccak.rs +++ b/examples/keccak.rs @@ -1,6 +1,7 @@ use chiquito::{ frontend::dsl::{ - lb::LookupTable, super_circuit, trace::DSLTraceGenerator, CircuitContext, StepTypeWGHandler, + lb::LookupTable, super_circuit, trace::DSLTraceGenerator, CircuitContextLegacy, + StepTypeWGHandler, }, plonkish::{ backend::halo2_legacy::{chiquitoSuperCircuit2Halo2, ChiquitoHalo2SuperCircuit}, @@ -231,7 +232,7 @@ fn eval_keccak_f_to_bit_vec4>(value1: F, value2: } fn keccak_xor_table_batch2( - ctx: &mut CircuitContext, + ctx: &mut CircuitContextLegacy, lens: usize, ) -> LookupTable { use chiquito::frontend::dsl::cb::*; @@ -254,7 +255,7 @@ fn keccak_xor_table_batch2( } fn keccak_xor_table_batch3( - ctx: &mut CircuitContext, + ctx: &mut CircuitContextLegacy, lens: usize, ) -> LookupTable { use chiquito::frontend::dsl::cb::*; @@ -280,7 +281,7 @@ fn keccak_xor_table_batch3( } fn keccak_xor_table_batch4( - ctx: &mut CircuitContext, + ctx: &mut CircuitContextLegacy, lens: usize, ) -> LookupTable { use chiquito::frontend::dsl::cb::*; @@ -306,7 +307,7 @@ fn keccak_xor_table_batch4( } fn keccak_chi_table( - ctx: &mut CircuitContext, + ctx: &mut CircuitContextLegacy, lens: usize, ) -> LookupTable { use chiquito::frontend::dsl::cb::*; @@ -332,7 +333,7 @@ fn keccak_chi_table( } fn keccak_pack_table( - ctx: &mut CircuitContext, + ctx: &mut CircuitContextLegacy, _: usize, ) -> LookupTable { use chiquito::frontend::dsl::cb::*; @@ -362,7 +363,7 @@ fn keccak_pack_table( } fn keccak_round_constants_table( - ctx: &mut CircuitContext, + ctx: &mut CircuitContextLegacy, lens: usize, ) -> LookupTable { use chiquito::frontend::dsl::cb::*; @@ -722,7 +723,7 @@ fn eval_keccak_f_one_round + Eq + Hash>( } fn keccak_circuit + Eq + Hash>( - ctx: &mut CircuitContext>, + ctx: &mut CircuitContextLegacy>, param: CircuitParams, ) { use chiquito::frontend::dsl::cb::*; diff --git a/examples/mimc7.rs b/examples/mimc7.rs index 468d0868..2dab9976 100644 --- a/examples/mimc7.rs +++ b/examples/mimc7.rs @@ -6,7 +6,9 @@ use halo2_proofs::{ }; use chiquito::{ - frontend::dsl::{lb::LookupTable, super_circuit, trace::DSLTraceGenerator, CircuitContext}, + frontend::dsl::{ + lb::LookupTable, super_circuit, trace::DSLTraceGenerator, CircuitContextLegacy, + }, plonkish::{ backend::halo2_legacy::{chiquitoSuperCircuit2Halo2, ChiquitoHalo2SuperCircuit}, compiler::{ @@ -23,7 +25,7 @@ use mimc7_constants::ROUND_CONSTANTS; pub const ROUNDS: usize = 91; fn mimc7_constants( - ctx: &mut CircuitContext>, + ctx: &mut CircuitContextLegacy>, _: (), ) -> LookupTable { use chiquito::frontend::dsl::cb::*; @@ -49,7 +51,7 @@ fn mimc7_constants( } fn mimc7_circuit( - ctx: &mut CircuitContext>, + ctx: &mut CircuitContextLegacy>, constants: LookupTable, ) { use chiquito::frontend::dsl::cb::*; diff --git a/examples/poseidon.rs b/examples/poseidon.rs index 56b5f3e3..43e7174d 100644 --- a/examples/poseidon.rs +++ b/examples/poseidon.rs @@ -1,5 +1,7 @@ use chiquito::{ - frontend::dsl::{lb::LookupTable, super_circuit, trace::DSLTraceGenerator, CircuitContext}, + frontend::dsl::{ + lb::LookupTable, super_circuit, trace::DSLTraceGenerator, CircuitContextLegacy, + }, plonkish::{ backend::halo2_legacy::{chiquitoSuperCircuit2Halo2, ChiquitoHalo2SuperCircuit}, compiler::{ @@ -49,7 +51,7 @@ struct CircuitParams { } fn poseidon_constants_table( - ctx: &mut CircuitContext, + ctx: &mut CircuitContextLegacy, param_t: usize, ) -> LookupTable { use chiquito::frontend::dsl::cb::*; @@ -75,7 +77,7 @@ fn poseidon_constants_table( } fn poseidon_matrix_table( - ctx: &mut CircuitContext, + ctx: &mut CircuitContextLegacy, param_t: usize, ) -> LookupTable { use chiquito::frontend::dsl::cb::*; @@ -97,7 +99,7 @@ fn poseidon_matrix_table( } fn poseidon_circuit( - ctx: &mut CircuitContext>>, + ctx: &mut CircuitContextLegacy>>, param: CircuitParams, ) { use chiquito::frontend::dsl::cb::*; diff --git a/src/compiler/compiler.rs b/src/compiler/compiler.rs index 44bb6d50..f4d80d20 100644 --- a/src/compiler/compiler.rs +++ b/src/compiler/compiler.rs @@ -4,10 +4,6 @@ use num_bigint::BigInt; use crate::{ field::Field, - frontend::dsl::{ - cb::{Constraint, Typing}, - circuit, CircuitContext, StepTypeContext, - }, interpreter::InterpreterTraceGenerator, parser::{ ast::{ @@ -19,10 +15,9 @@ use crate::{ }, lang::TLDeclsParser, }, - plonkish::{self, compiler::PlonkishCompilationResult}, - poly::{self, mielim::mi_elimination, reduce::reduce_degree, Expr}, - sbpir::{query::Queriable, InternalSignal, SBPIRLegacy, SBPIR}, - wit_gen::{NullTraceGenerator, SymbolSignalMapping, TraceGenerator}, + poly::Expr, + sbpir::{query::Queriable, sbpir_machine::SBPIRMachine, Constraint, StepType, SBPIR}, + wit_gen::{NullTraceGenerator, SymbolSignalMapping}, }; use super::{ @@ -31,39 +26,20 @@ use super::{ Config, Message, Messages, }; +#[derive(Debug)] pub struct CompilerResult { pub messages: Vec, pub circuit: SBPIR, } -/// Contains the result of a single machine compilation (legacy). -#[derive(Debug)] -pub struct CompilerResultLegacy { - pub messages: Vec, - pub circuit: SBPIRLegacy, -} - -impl CompilerResultLegacy { - /// Compiles to the Plonkish IR, that then can be compiled to plonkish backends. - pub fn plonkish< - CM: plonkish::compiler::cell_manager::CellManager, - SSB: plonkish::compiler::step_selector::StepSelectorBuilder, - >( - &self, - config: plonkish::compiler::CompilerConfig, - ) -> PlonkishCompilationResult { - plonkish::compiler::compile(config, &self.circuit) - } -} - /// This compiler compiles from chiquito source code to the SBPIR. #[derive(Default)] pub(super) struct Compiler { pub(super) config: Config, - messages: Vec, + pub(super) messages: Vec, - mapping: SymbolSignalMapping, + pub(super) mapping: SymbolSignalMapping, _p: PhantomData, } @@ -80,41 +56,6 @@ impl Compiler { } } - /// Compile the source code containing a single machine (legacy). - pub(super) fn compile_legacy( - mut self, - source: &str, - debug_sym_ref_factory: &DebugSymRefFactory, - ) -> Result, Vec> { - let ast = self - .parse(source, debug_sym_ref_factory) - .map_err(|_| self.messages.clone())?; - assert!(ast.len() == 1, "Use `compile` to compile multiple machines"); - let ast = self.add_virtual(ast); - let symbols = self.semantic(&ast).map_err(|_| self.messages.clone())?; - let setup = Self::interpret(&ast, &symbols); - let setup = Self::map_consts(setup); - let circuit = self.build(&setup, &symbols); - let circuit = Self::mi_elim(circuit); - let circuit = if let Some(degree) = self.config.max_degree { - Self::reduce(circuit, degree) - } else { - circuit - }; - - let circuit = circuit.with_trace(InterpreterTraceGenerator::new( - ast, - symbols, - self.mapping, - self.config.max_steps, - )); - - Ok(CompilerResultLegacy { - messages: self.messages, - circuit, - }) - } - /// Compile the source code. pub(super) fn compile( mut self, @@ -126,32 +67,30 @@ impl Compiler { .map_err(|_| self.messages.clone())?; let ast = self.add_virtual(ast); let symbols = self.semantic(&ast).map_err(|_| self.messages.clone())?; - let setup = Self::interpret(&ast, &symbols); - let setup = Self::map_consts(setup); - - let machine_id = setup.iter().next().unwrap().0; + let machine_setups = Self::interpret(&ast, &symbols); + let machine_setups = machine_setups + .iter() + .map(|(k, v)| (k.clone(), v.map_consts())) + .collect(); - let circuit = self.build(&setup, &symbols); - let circuit = Self::mi_elim(circuit); + let circuit = self.build(&machine_setups, &symbols); + let circuit = circuit.eliminate_mul_inv(); let circuit = if let Some(degree) = self.config.max_degree { - Self::reduce(circuit, degree) + circuit.reduce(degree) } else { circuit }; - let circuit = circuit.with_trace(InterpreterTraceGenerator::new( + let circuit = circuit.with_trace(&InterpreterTraceGenerator::new( ast, symbols, self.mapping, self.config.max_steps, )); - // TODO perform real compilation for multiple machines - let sbpir = SBPIR::from_legacy(circuit, machine_id.as_str()); - Ok(CompilerResult { messages: self.messages, - circuit: sbpir, + circuit, }) } @@ -173,6 +112,7 @@ impl Compiler { } } + /// Adds "virtual" states to the AST (necessary to handle padding) fn add_virtual( &mut self, mut ast: Vec>, @@ -280,6 +220,8 @@ impl Compiler { } } + /// Semantic analysis of the AST + /// Returns the symbol table if successful fn semantic(&mut self, ast: &[TLDecl]) -> Result { let result = super::semantic::analyser::analyse(ast); let has_errors = result.messages.has_errors(); @@ -297,140 +239,71 @@ impl Compiler { interpret(ast, symbols) } - fn map_consts(setup: Setup) -> Setup { - setup - .iter() - .map(|(machine_id, machine)| { - let poly_constraints: HashMap>> = machine - .iter_states_poly_constraints() - .map(|(step_id, step)| { - let new_step: Vec> = - step.iter().map(|pi| Self::map_pi_consts(pi)).collect(); - - (step_id.clone(), new_step) - }) - .collect(); - - let new_machine: MachineSetup = - machine.replace_poly_constraints(poly_constraints); - (machine_id.clone(), new_machine) - }) - .collect() - } - - fn map_pi_consts(expr: &Expr) -> Expr { - use Expr::*; - match expr { - Const(v, _) => Const(F::from_big_int(v), ()), - Sum(ses, _) => Sum(ses.iter().map(|se| Self::map_pi_consts(se)).collect(), ()), - Mul(ses, _) => Mul(ses.iter().map(|se| Self::map_pi_consts(se)).collect(), ()), - Neg(se, _) => Neg(Box::new(Self::map_pi_consts(se)), ()), - Pow(se, exp, _) => Pow(Box::new(Self::map_pi_consts(se)), *exp, ()), - Query(q, _) => Query(q.clone(), ()), - Halo2Expr(_, _) => todo!(), - MI(se, _) => MI(Box::new(Self::map_pi_consts(se)), ()), - } - } + fn build(&mut self, setup: &Setup, symbols: &SymTable) -> SBPIR { + let mut sbpir = SBPIR::default(); - fn build( - &mut self, - setup: &Setup, - symbols: &SymTable, - ) -> SBPIRLegacy { - circuit::("circuit", |ctx| { - for (machine_id, machine) in setup { - self.add_forwards(ctx, symbols, machine_id); - self.add_step_type_handlers(ctx, symbols, machine_id); - - ctx.pragma_num_steps(self.config.max_steps); - ctx.pragma_first_step(self.mapping.get_step_type_handler(machine_id, "initial")); - ctx.pragma_last_step(self.mapping.get_step_type_handler(machine_id, "__padding")); - - for state_id in machine.states() { - ctx.step_type_def( - self.mapping.get_step_type_handler(machine_id, state_id), - |ctx| { - self.add_internals(ctx, symbols, machine_id, state_id); - - ctx.setup(|ctx| { - let poly_constraints = - self.translate_queries(symbols, setup, machine_id, state_id); - poly_constraints.iter().for_each(|poly| { - let constraint = Constraint { - annotation: format!("{:?}", poly), - expr: poly.clone(), - typing: Typing::AntiBooly, - }; - ctx.constr(constraint); - }); - }); - - ctx.wg(|_, _: ()| {}) - }, - ); - } - } + for (machine_name, machine_setup) in setup { + let mut sbpir_machine = SBPIRMachine::default(); + self.add_forward_signals(&mut sbpir_machine, symbols, machine_name); + self.add_step_type_handlers(&mut sbpir_machine, symbols, machine_name); - ctx.trace(|_, _| {}); - }) - .without_trace() - } - - fn mi_elim( - mut circuit: SBPIRLegacy, - ) -> SBPIRLegacy { - for (_, step_type) in circuit.step_types.iter_mut() { - let mut signal_factory = SignalFactory::default(); - - step_type.decomp_constraints(|expr| mi_elimination(expr.clone(), &mut signal_factory)); - } + sbpir_machine.num_steps = self.config.max_steps; + sbpir_machine.first_step = Some( + self.mapping + .get_step_type_handler(machine_name, "initial") + .uuid(), + ); + sbpir_machine.last_step = Some( + self.mapping + .get_step_type_handler(machine_name, "__padding") + .uuid(), + ); - circuit - } + for state_id in machine_setup.states() { + let step_type = + self.create_step_type(symbols, machine_name, machine_setup, state_id); - fn reduce( - mut circuit: SBPIRLegacy, - degree: usize, - ) -> SBPIRLegacy { - for (_, step_type) in circuit.step_types.iter_mut() { - let mut signal_factory = SignalFactory::default(); + sbpir_machine.add_step_type_def(step_type); + } - step_type.decomp_constraints(|expr| { - reduce_degree(expr.clone(), degree, &mut signal_factory) - }); + sbpir.machines.insert(machine_name.clone(), sbpir_machine); } - circuit + sbpir.without_trace() } #[allow(dead_code)] - fn cse(mut _circuit: SBPIRLegacy) -> SBPIRLegacy { + fn cse(mut _circuit: SBPIR) -> SBPIR { todo!() } - fn translate_queries( + /// Translate the queries to constraints + fn queries_into_constraints( &mut self, symbols: &SymTable, - setup: &Setup, - machine_id: &str, + setup: &MachineSetup, + machine_name: &str, state_id: &str, - ) -> Vec, ()>> { - let exprs = setup - .get(machine_id) - .unwrap() - .get_poly_constraints(state_id) - .unwrap(); + ) -> Vec> { + let exprs = setup.get_poly_constraints(state_id).unwrap(); exprs .iter() - .map(|expr| self.translate_queries_expr(symbols, machine_id, state_id, expr)) + .map(|expr| { + let translate_queries_expr = + self.translate_queries_expr(symbols, machine_name, state_id, expr); + Constraint { + annotation: format!("{:?}", translate_queries_expr), + expr: translate_queries_expr.clone(), + } + }) .collect() } fn translate_queries_expr( &mut self, symbols: &SymTable, - machine_id: &str, + machine_name: &str, state_id: &str, expr: &Expr, ) -> Expr, ()> { @@ -439,38 +312,41 @@ impl Compiler { Const(v, _) => Const(*v, ()), Sum(ses, _) => Sum( ses.iter() - .map(|se| self.translate_queries_expr(symbols, machine_id, state_id, se)) + .map(|se| self.translate_queries_expr(symbols, machine_name, state_id, se)) .collect(), (), ), Mul(ses, _) => Mul( ses.iter() - .map(|se| self.translate_queries_expr(symbols, machine_id, state_id, se)) + .map(|se| self.translate_queries_expr(symbols, machine_name, state_id, se)) .collect(), (), ), Neg(se, _) => Neg( - Box::new(self.translate_queries_expr(symbols, machine_id, state_id, se.as_ref())), + Box::new(self.translate_queries_expr(symbols, machine_name, state_id, se.as_ref())), (), ), Pow(se, exp, _) => Pow( - Box::new(self.translate_queries_expr(symbols, machine_id, state_id, se.as_ref())), + Box::new(self.translate_queries_expr(symbols, machine_name, state_id, se.as_ref())), *exp, (), ), MI(se, _) => MI( - Box::new(self.translate_queries_expr(symbols, machine_id, state_id, se.as_ref())), + Box::new(self.translate_queries_expr(symbols, machine_name, state_id, se.as_ref())), (), ), Halo2Expr(se, _) => Halo2Expr(se.clone(), ()), - Query(id, _) => Query(self.translate_query(symbols, machine_id, state_id, id), ()), + Query(id, _) => Query( + self.translate_query(symbols, machine_name, state_id, id), + (), + ), } } fn translate_query( &mut self, symbols: &SymTable, - machine_id: &str, + machine_name: &str, state_id: &str, id: &Identifier, ) -> Queriable { @@ -480,7 +356,7 @@ impl Compiler { .find_symbol( &[ "/".to_string(), - machine_id.to_string(), + machine_name.to_string(), state_id.to_string(), ], id.name(), @@ -489,17 +365,17 @@ impl Compiler { match symbol.symbol.category { InputSignal | OutputSignal | InoutSignal => { - self.translate_forward_queriable(machine_id, id) + self.translate_forward_queriable(machine_name, id) } Signal => match symbol.scope_cat { - ScopeCategory::Machine => self.translate_forward_queriable(machine_id, id), + ScopeCategory::Machine => self.translate_forward_queriable(machine_name, id), ScopeCategory::State => { if id.rotation() != 0 { unreachable!("semantic analyser should prevent this"); } let signal = self .mapping - .get_internal(&format!("//{}/{}", machine_id, state_id), &id.name()); + .get_internal(&format!("//{}/{}", machine_name, state_id), &id.name()); Queriable::Internal(signal) } @@ -507,16 +383,16 @@ impl Compiler { ScopeCategory::Global => unreachable!("no global signals"), }, - State => { - Queriable::StepTypeNext(self.mapping.get_step_type_handler(machine_id, &id.name())) - } + State => Queriable::StepTypeNext( + self.mapping.get_step_type_handler(machine_name, &id.name()), + ), _ => unreachable!("semantic analysis should prevent this"), } } - fn translate_forward_queriable(&mut self, machine_id: &str, id: &Identifier) -> Queriable { - let forward = self.mapping.get_forward(machine_id, &id.name()); + fn translate_forward_queriable(&mut self, machine_name: &str, id: &Identifier) -> Queriable { + let forward = self.mapping.get_forward(machine_name, &id.name()); let rot = if id.rotation() == 0 { false } else if id.rotation() == 1 { @@ -531,13 +407,13 @@ impl Compiler { fn get_all_internals( &mut self, symbols: &SymTable, - machine_id: &str, + machine_name: &str, state_id: &str, ) -> Vec { let symbols = symbols .get_scope(&[ "/".to_string(), - machine_id.to_string(), + machine_name.to_string(), state_id.to_string(), ]) .expect("scope not found") @@ -551,41 +427,56 @@ impl Compiler { .collect() } - fn add_internals( + fn create_step_type( &mut self, - ctx: &mut StepTypeContext, symbols: &SymTable, - machine_id: &str, + machine_name: &str, + machine_setup: &MachineSetup, + state_id: &str, + ) -> StepType { + let handler = self.mapping.get_step_type_handler(machine_name, state_id); + + let mut step_type: StepType = + StepType::new(handler.uuid(), handler.annotation.to_string()); + + self.add_internal_signals(symbols, machine_name, &mut step_type, state_id); + + let poly_constraints = + self.queries_into_constraints(symbols, machine_setup, machine_name, state_id); + + step_type.constraints = poly_constraints.clone(); + + step_type + } + + fn add_internal_signals( + &mut self, + symbols: &SymTable, + machine_name: &str, + step_type: &mut StepType, state_id: &str, ) { - let internal_ids = self.get_all_internals(symbols, machine_id, state_id); - let scope_name = format!("//{}/{}", machine_id, state_id); + let internal_ids = self.get_all_internals(symbols, machine_name, state_id); + let scope_name = format!("//{}/{}", machine_name, state_id); for internal_id in internal_ids { let name = format!("{}:{}", &scope_name, internal_id); + let signal = step_type.add_signal(name.as_str()); - let queriable = ctx.internal(name.as_str()); - if let Queriable::Internal(signal) = queriable { - self.mapping - .symbol_uuid - .insert((scope_name.clone(), internal_id), signal.uuid()); - self.mapping.internal_signals.insert(signal.uuid(), signal); - } else { - unreachable!("ctx.internal returns not internal signal"); - } + self.mapping + .symbol_uuid + .insert((scope_name.clone(), internal_id), signal.uuid()); + self.mapping.internal_signals.insert(signal.uuid(), signal); } } - fn add_step_type_handlers>( + fn add_step_type_handlers( &mut self, - ctx: &mut CircuitContext, + machine: &mut SBPIRMachine, symbols: &SymTable, - machine_id: &str, + machine_name: &str, ) { - let symbols = symbols - .get_scope(&["/".to_string(), machine_id.to_string()]) - .expect("scope not found") - .get_symbols(); + let symbols = get_symbols(symbols, machine_name); let state_ids: Vec<_> = symbols .iter() @@ -595,10 +486,11 @@ impl Compiler { .collect(); for state_id in state_ids { - let scope_name = format!("//{}", machine_id); + let scope_name = format!("//{}", machine_name); let name = format!("{}:{}", scope_name, state_id); - let handler = ctx.step_type(&name); + let handler = machine.add_step_type(name); + self.mapping .step_type_handler .insert(handler.uuid(), handler); @@ -608,16 +500,13 @@ impl Compiler { } } - fn add_forwards>( + fn add_forward_signals( &mut self, - ctx: &mut CircuitContext, + machine: &mut SBPIRMachine, symbols: &SymTable, - machine_id: &str, + machine_name: &str, ) { - let symbols = symbols - .get_scope(&["/".to_string(), machine_id.to_string()]) - .expect("scope not found") - .get_symbols(); + let symbols = get_symbols(symbols, machine_name); let forward_ids: Vec<_> = symbols .iter() @@ -627,17 +516,16 @@ impl Compiler { .collect(); for forward_id in forward_ids { - let scope_name = format!("//{}", machine_id); + let scope_name = format!("//{}", machine_name); let name = format!("{}:{}", scope_name, forward_id); - - let queriable = ctx.forward(name.as_str()); + let queriable = Queriable::::Forward(machine.add_forward(name.as_str(), 0), false); if let Queriable::Forward(signal, _) = queriable { self.mapping .symbol_uuid .insert((scope_name, forward_id), signal.uuid()); self.mapping.forward_signals.insert(signal.uuid(), signal); } else { - unreachable!("ctx.internal returns not internal signal"); + unreachable!("Forward queriable should return a forward signal"); } } } @@ -655,39 +543,77 @@ impl Compiler { } } -// Basic signal factory. -#[derive(Default)] -struct SignalFactory { - count: u64, - _p: PhantomData, -} - -impl poly::SignalFactory> for SignalFactory { - fn create>(&mut self, annotation: S) -> Queriable { - self.count += 1; - Queriable::Internal(InternalSignal::new(format!( - "{}-{}", - annotation.into(), - self.count - ))) - } +fn get_symbols<'a>( + symbols: &'a SymTable, + machine_name: &'a str, +) -> &'a HashMap { + let symbols = symbols + .get_scope(&["/".to_string(), machine_name.to_string()]) + .expect("scope not found") + .get_symbols(); + symbols } #[cfg(test)] mod test { + use std::collections::HashMap; + use halo2_proofs::halo2curves::bn256::Fr; + use itertools::Itertools; use crate::{ - compiler::{compile_file_legacy, compile_legacy}, + compiler::{compile, compile_file, compile_legacy}, parser::ast::debug_sym_factory::DebugSymRefFactory, + wit_gen::TraceGenerator, }; use super::Config; + // TODO rewrite the test after machines are able to call other machines #[test] - fn test_compiler_fibo() { + fn test_compiler_fibo_multiple_machines() { + // Source code containing two machines let circuit = " - machine fibo(signal n) (signal b: field) { + machine fibo1 (signal n) (signal b: field) { + // n and be are created automatically as shared + // signals + signal a: field, i; + + // there is always a state called initial + // input signals get bound to the signal + // in the initial state (first instance) + state initial { + signal c; + + i, a, b, c <== 1, 1, 1, 2; + + -> middle { + i', a', b', n' <== i + 1, b, c, n; + } + } + + state middle { + signal c; + + c <== a + b; + + if i + 1 == n { + -> final { + i', b', n' <== i + 1, c, n; + } + } else { + -> middle { + i', a', b', n' <== i + 1, b, c, n; + } + } + } + + // There is always a state called final. + // Output signals get automatically bound to the signals + // with the same name in the final step (last instance). + // This state can be implicit if there are no constraints in it. + } + machine fibo2 (signal n) (signal b: field) { // n and be are created automatically as shared // signals signal a: field, i; @@ -701,7 +627,7 @@ mod test { i, a, b, c <== 1, 1, 1, 2; -> middle { - a', b', n' <== b, c, n; + i', a', b', n' <== i + 1, b, c, n; } } @@ -729,29 +655,187 @@ mod test { "; let debug_sym_ref_factory = DebugSymRefFactory::new("", circuit); - let result = compile_legacy::( + let result = compile::( circuit, Config::default().max_degree(2), &debug_sym_ref_factory, ); match result { - Ok(result) => println!("{:#?}", result), + Ok(result) => { + assert_eq!(result.circuit.machines.len(), 2); + println!("{:#?}", result) + } Err(messages) => println!("{:#?}", messages), } } + #[test] + fn test_is_new_compiler_identical_to_legacy() { + let circuit = " + machine fibo(signal n) (signal b: field) { + // n and be are created automatically as shared + // signals + signal a: field, i; + + // there is always a state called initial + // input signals get bound to the signal + // in the initial state (first instance) + state initial { + signal c; + + i, a, b, c <== 1, 1, 1, 2; + + -> middle { + i', a', b', n' <== i + 1, b, c, n; + } + } + + state middle { + signal c; + + c <== a + b; + + if i + 1 == n { + -> final { + i', b', n' <== i + 1, c, n; + } + } else { + -> middle { + i', a', b', n' <== i + 1, b, c, n; + } + } + } + + // There is always a state called final. + // Output signals get automatically bound to the signals + // with the same name in the final step (last instance). + // This state can be implicit if there are no constraints in it. + } + "; + + let debug_sym_ref_factory = DebugSymRefFactory::new("", circuit); + let result = compile::( + circuit, + Config::default().max_degree(2), + &debug_sym_ref_factory, + ) + .unwrap(); + + let result_legacy = compile_legacy::( + circuit, + Config::default().max_degree(2), + &debug_sym_ref_factory, + ) + .unwrap(); + + let result = result.circuit.machines.get("fibo").unwrap(); + let result_legacy = result_legacy.circuit; + let exposed = &result.exposed; + let exposed_legacy = result_legacy.exposed; + + for exposed in exposed.iter().zip(exposed_legacy.iter()) { + assert_eq!(exposed.0 .0, exposed.1 .0); + assert_eq!(exposed.0 .1, exposed.1 .1); + } + assert_eq!(result.annotations.len(), result_legacy.annotations.len()); + for val in result_legacy.annotations.values() { + assert!(result.annotations.values().contains(val)); + } + + assert_eq!( + result.forward_signals.len(), + result_legacy.forward_signals.len() + ); + for val in result_legacy.forward_signals.iter() { + assert!(result + .forward_signals + .iter() + .any(|x| x.annotation() == val.annotation() && x.phase() == val.phase())); + } + + assert_eq!(result.shared_signals, result_legacy.shared_signals); + assert_eq!(result.fixed_signals, result_legacy.fixed_signals); + assert_eq!(result.halo2_advice, result_legacy.halo2_advice); + assert_eq!(result.halo2_fixed, result_legacy.halo2_fixed); + assert_eq!(result.step_types.len(), result_legacy.step_types.len()); + for step in result_legacy.step_types.values() { + let name = step.name(); + let step_new = result + .step_types + .iter() + .find(|x| x.1.name() == name) + .unwrap() + .1; + assert_eq!(step_new.signals.len(), step.signals.len()); + for signal in step.signals.iter() { + assert!(step_new + .signals + .iter() + .any(|x| x.annotation() == signal.annotation())); + } + assert_eq!(step_new.constraints.len(), step.constraints.len()); + for constraint in step.constraints.iter() { + assert!(step_new + .constraints + .iter() + .any(|x| x.annotation == constraint.annotation)); + } + assert_eq!(step_new.lookups.is_empty(), step.lookups.is_empty()); + assert_eq!( + step_new.auto_signals.is_empty(), + step.auto_signals.is_empty() + ); + assert_eq!( + step_new.transition_constraints.is_empty(), + step.transition_constraints.is_empty() + ); + assert_eq!(step_new.annotations.len(), step.annotations.len()); + } + + assert_eq!( + result.first_step.is_some(), + result_legacy.first_step.is_some() + ); + assert_eq!( + result.last_step.is_some(), + result_legacy.last_step.is_some() + ); + assert_eq!(result.num_steps, result_legacy.num_steps); + assert_eq!(result.q_enable, result_legacy.q_enable); + + let tg_new = result.trace_generator.as_ref().unwrap(); + let tg_legacy = result_legacy.trace_generator.unwrap(); + + // Check if the witness values of the new compiler are the same as the legacy compiler + let res = tg_new.generate(HashMap::from([("n".to_string(), Fr::from(12))])); + let res_legacy = tg_legacy.generate(HashMap::from([("n".to_string(), Fr::from(12))])); + assert_eq!(res.step_instances.len(), res_legacy.step_instances.len()); + for (step, step_legacy) in res.step_instances.iter().zip(res_legacy.step_instances) { + assert_eq!(step.assignments.len(), step_legacy.assignments.len()); + for assignment in step.assignments.iter() { + let assignment_legacy = step_legacy + .assignments + .iter() + .find(|x| x.0.annotation() == assignment.0.annotation()) + .unwrap(); + assert_eq!(assignment.0.annotation(), assignment_legacy.0.annotation()); + assert!(assignment.1.eq(assignment_legacy.1)); + } + } + } + #[test] fn test_compiler_fibo_file() { let path = "test/circuit.chiquito"; - let result = compile_file_legacy::(path, Config::default().max_degree(2)); + let result = compile_file::(path, Config::default().max_degree(2)); assert!(result.is_ok()); } #[test] fn test_compiler_fibo_file_err() { let path = "test/circuit_error.chiquito"; - let result = compile_file_legacy::(path, Config::default().max_degree(2)); + let result = compile_file::(path, Config::default().max_degree(2)); assert!(result.is_err()); diff --git a/src/compiler/compiler_legacy.rs b/src/compiler/compiler_legacy.rs new file mode 100644 index 00000000..c7bed275 --- /dev/null +++ b/src/compiler/compiler_legacy.rs @@ -0,0 +1,721 @@ +use std::{collections::HashMap, hash::Hash, marker::PhantomData}; + +use num_bigint::BigInt; + +use crate::{ + field::Field, + frontend::dsl::{ + cb::{Constraint, Typing}, + circuit_legacy, CircuitContextLegacy, StepTypeContext, + }, + interpreter::InterpreterTraceGenerator, + parser::{ + ast::{ + debug_sym_factory::DebugSymRefFactory, + expression::Expression, + statement::{Statement, TypedIdDecl}, + tl::TLDecl, + DebugSymRef, Identifiable, Identifier, + }, + lang::TLDeclsParser, + }, + plonkish::{self, compiler::PlonkishCompilationResult}, + poly::{self, mielim::mi_elimination, reduce::reduce_degree, Expr}, + sbpir::{query::Queriable, InternalSignal, SBPIRLegacy}, + wit_gen::{NullTraceGenerator, SymbolSignalMapping, TraceGenerator}, +}; + +use super::{ + semantic::{SymTable, SymbolCategory}, + setup_inter::{interpret, MachineSetup, Setup}, + Config, Message, Messages, +}; + +/// Contains the result of a single machine compilation (legacy). +#[derive(Debug)] +pub struct CompilerResultLegacy { + pub messages: Vec, + pub circuit: SBPIRLegacy, +} + +impl CompilerResultLegacy { + /// Compiles to the Plonkish IR, that then can be compiled to plonkish backends. + pub fn plonkish< + CM: plonkish::compiler::cell_manager::CellManager, + SSB: plonkish::compiler::step_selector::StepSelectorBuilder, + >( + &self, + config: plonkish::compiler::CompilerConfig, + ) -> PlonkishCompilationResult { + plonkish::compiler::compile(config, &self.circuit) + } +} + +/// This compiler compiles from chiquito source code to the SBPIR. +#[derive(Default)] +pub(super) struct CompilerLegacy { + pub(super) config: Config, + + messages: Vec, + + mapping: SymbolSignalMapping, + + _p: PhantomData, +} + +impl CompilerLegacy { + /// Creates a configured compiler. + pub fn new(mut config: Config) -> Self { + if config.max_steps == 0 { + config.max_steps = 1000; // TODO: organise this better + } + CompilerLegacy { + config, + ..CompilerLegacy::default() + } + } + + /// Compile the source code containing a single machine. + pub(super) fn compile( + mut self, + source: &str, + debug_sym_ref_factory: &DebugSymRefFactory, + ) -> Result, Vec> { + let ast = self + .parse(source, debug_sym_ref_factory) + .map_err(|_| self.messages.clone())?; + assert!(ast.len() == 1, "Use `compile` to compile multiple machines"); + let ast = self.add_virtual(ast); + let symbols = self.semantic(&ast).map_err(|_| self.messages.clone())?; + let setup = Self::interpret(&ast, &symbols); + let setup = Self::map_consts(setup); + let circuit = self.build(&setup, &symbols); + let circuit = Self::mi_elim(circuit); + let circuit = if let Some(degree) = self.config.max_degree { + Self::reduce(circuit, degree) + } else { + circuit + }; + + let circuit = circuit.with_trace(InterpreterTraceGenerator::new( + ast, + symbols, + self.mapping, + self.config.max_steps, + )); + + Ok(CompilerResultLegacy { + messages: self.messages, + circuit, + }) + } + + fn parse( + &mut self, + source: &str, + debug_sym_ref_factory: &DebugSymRefFactory, + ) -> Result>, ()> { + let result = TLDeclsParser::new().parse(debug_sym_ref_factory, source); + + match result { + Ok(ast) => Ok(ast), + Err(error) => { + self.messages.push(Message::ParseErr { + msg: error.to_string(), + }); + Err(()) + } + } + } + + /// Adds "virtual" states to the AST (necessary to handle padding) + fn add_virtual( + &mut self, + mut ast: Vec>, + ) -> Vec> { + for tldc in ast.iter_mut() { + match tldc { + TLDecl::MachineDecl { + dsym, + id: _, + input_params: _, + output_params, + block, + } => self.add_virtual_to_machine(dsym, output_params, block), + } + } + + ast + } + + fn add_virtual_to_machine( + &mut self, + dsym: &DebugSymRef, + output_params: &Vec>, + block: &mut Statement, + ) { + let dsym = DebugSymRef::into_virtual(dsym); + let output_params = Self::get_decls(output_params); + + if let Statement::Block(_, stmts) = block { + let mut has_final = false; + + for stmt in stmts.iter() { + if let Statement::StateDecl(_, id, _) = stmt + && id.name() == "final" + { + has_final = true + } + } + if !has_final { + stmts.push(Statement::StateDecl( + dsym.clone(), + Identifier::new("final", dsym.clone()), + Box::new(Statement::Block(dsym.clone(), vec![])), + )); + } + + let final_state = Self::find_state_mut("final", stmts).unwrap(); + + let mut padding_transitions = output_params + .iter() + .map(|output_signal| { + Statement::SignalAssignmentAssert( + dsym.clone(), + vec![output_signal.id.next()], + vec![Expression::Query::( + dsym.clone(), + output_signal.id.clone(), + )], + ) + }) + .collect::>(); + + padding_transitions.push(Statement::Transition( + dsym.clone(), + Identifier::new("__padding", dsym.clone()), + Box::new(Statement::Block(dsym.clone(), vec![])), + )); + + Self::add_virtual_to_state(final_state, padding_transitions.clone()); + + stmts.push(Statement::StateDecl( + dsym.clone(), + Identifier::new("__padding", dsym.clone()), + Box::new(Statement::Block(dsym.clone(), padding_transitions)), + )); + } // Semantic analyser must show an error in the else case + } + + fn find_state_mut>( + state_id: S, + stmts: &mut [Statement], + ) -> Option<&mut Statement> { + let state_id = state_id.into(); + let mut final_state: Option<&mut Statement> = None; + + for stmt in stmts.iter_mut() { + if let Statement::StateDecl(_, id, _) = stmt + && id.name() == state_id + { + final_state = Some(stmt) + } + } + + final_state + } + + fn add_virtual_to_state( + state: &mut Statement, + add_statements: Vec>, + ) { + if let Statement::StateDecl(_, _, final_state_stmts) = state { + if let Statement::Block(_, stmts) = final_state_stmts.as_mut() { + stmts.extend(add_statements) + } + } + } + + /// Semantic analysis of the AST + /// Returns the symbol table if successful + fn semantic(&mut self, ast: &[TLDecl]) -> Result { + let result = super::semantic::analyser::analyse(ast); + let has_errors = result.messages.has_errors(); + + self.messages.extend(result.messages); + + if has_errors { + Err(()) + } else { + Ok(result.symbols) + } + } + + fn interpret(ast: &[TLDecl], symbols: &SymTable) -> Setup { + interpret(ast, symbols) + } + + fn map_consts(setup: Setup) -> Setup { + setup + .iter() + .map(|(machine_id, machine)| { + let poly_constraints: HashMap>> = machine + .poly_constraints_iter() + .map(|(step_id, step)| { + let new_step: Vec> = + step.iter().map(|pi| Self::map_pi_consts(pi)).collect(); + + (step_id.clone(), new_step) + }) + .collect(); + + let new_machine: MachineSetup = + machine.replace_poly_constraints(poly_constraints); + (machine_id.clone(), new_machine) + }) + .collect() + } + + fn map_pi_consts(expr: &Expr) -> Expr { + use Expr::*; + match expr { + Const(v, _) => Const(F::from_big_int(v), ()), + Sum(ses, _) => Sum(ses.iter().map(|se| Self::map_pi_consts(se)).collect(), ()), + Mul(ses, _) => Mul(ses.iter().map(|se| Self::map_pi_consts(se)).collect(), ()), + Neg(se, _) => Neg(Box::new(Self::map_pi_consts(se)), ()), + Pow(se, exp, _) => Pow(Box::new(Self::map_pi_consts(se)), *exp, ()), + Query(q, _) => Query(q.clone(), ()), + Halo2Expr(_, _) => todo!(), + MI(se, _) => MI(Box::new(Self::map_pi_consts(se)), ()), + } + } + + fn build( + &mut self, + setup: &Setup, + symbols: &SymTable, + ) -> SBPIRLegacy { + circuit_legacy::("circuit", |ctx| { + for (machine_id, machine) in setup { + self.add_forwards(ctx, symbols, machine_id); + self.add_step_type_handlers(ctx, symbols, machine_id); + + ctx.pragma_num_steps(self.config.max_steps); + ctx.pragma_first_step(self.mapping.get_step_type_handler(machine_id, "initial")); + ctx.pragma_last_step(self.mapping.get_step_type_handler(machine_id, "__padding")); + + for state_id in machine.states() { + ctx.step_type_def( + self.mapping.get_step_type_handler(machine_id, state_id), + |ctx| { + self.add_internals(ctx, symbols, machine_id, state_id); + + ctx.setup(|ctx| { + let poly_constraints = + self.translate_queries(symbols, setup, machine_id, state_id); + poly_constraints.iter().for_each(|poly| { + let constraint = Constraint { + annotation: format!("{:?}", poly), + expr: poly.clone(), + typing: Typing::AntiBooly, + }; + ctx.constr(constraint); + }); + }); + + ctx.wg(|_, _: ()| {}) + }, + ); + } + } + + ctx.trace(|_, _| {}); + }) + .without_trace() + } + + fn mi_elim( + mut circuit: SBPIRLegacy, + ) -> SBPIRLegacy { + for (_, step_type) in circuit.step_types.iter_mut() { + let mut signal_factory = SignalFactory::default(); + + step_type.decomp_constraints(|expr| mi_elimination(expr.clone(), &mut signal_factory)); + } + + circuit + } + + fn reduce( + mut circuit: SBPIRLegacy, + degree: usize, + ) -> SBPIRLegacy { + for (_, step_type) in circuit.step_types.iter_mut() { + let mut signal_factory = SignalFactory::default(); + + step_type.decomp_constraints(|expr| { + reduce_degree(expr.clone(), degree, &mut signal_factory) + }); + } + + circuit + } + + #[allow(dead_code)] + fn cse(mut _circuit: SBPIRLegacy) -> SBPIRLegacy { + todo!() + } + + fn translate_queries( + &mut self, + symbols: &SymTable, + setup: &Setup, + machine_id: &str, + state_id: &str, + ) -> Vec, ()>> { + let exprs = setup + .get(machine_id) + .unwrap() + .get_poly_constraints(state_id) + .unwrap(); + + exprs + .iter() + .map(|expr| self.translate_queries_expr(symbols, machine_id, state_id, expr)) + .collect() + } + + fn translate_queries_expr( + &mut self, + symbols: &SymTable, + machine_id: &str, + state_id: &str, + expr: &Expr, + ) -> Expr, ()> { + use Expr::*; + match expr { + Const(v, _) => Const(*v, ()), + Sum(ses, _) => Sum( + ses.iter() + .map(|se| self.translate_queries_expr(symbols, machine_id, state_id, se)) + .collect(), + (), + ), + Mul(ses, _) => Mul( + ses.iter() + .map(|se| self.translate_queries_expr(symbols, machine_id, state_id, se)) + .collect(), + (), + ), + Neg(se, _) => Neg( + Box::new(self.translate_queries_expr(symbols, machine_id, state_id, se.as_ref())), + (), + ), + Pow(se, exp, _) => Pow( + Box::new(self.translate_queries_expr(symbols, machine_id, state_id, se.as_ref())), + *exp, + (), + ), + MI(se, _) => MI( + Box::new(self.translate_queries_expr(symbols, machine_id, state_id, se.as_ref())), + (), + ), + Halo2Expr(se, _) => Halo2Expr(se.clone(), ()), + Query(id, _) => Query(self.translate_query(symbols, machine_id, state_id, id), ()), + } + } + + fn translate_query( + &mut self, + symbols: &SymTable, + machine_id: &str, + state_id: &str, + id: &Identifier, + ) -> Queriable { + use super::semantic::{ScopeCategory, SymbolCategory::*}; + + let symbol = symbols + .find_symbol( + &[ + "/".to_string(), + machine_id.to_string(), + state_id.to_string(), + ], + id.name(), + ) + .unwrap_or_else(|| panic!("semantic analyser fail: undeclared id {}", id.name())); + + match symbol.symbol.category { + InputSignal | OutputSignal | InoutSignal => { + self.translate_forward_queriable(machine_id, id) + } + Signal => match symbol.scope_cat { + ScopeCategory::Machine => self.translate_forward_queriable(machine_id, id), + ScopeCategory::State => { + if id.rotation() != 0 { + unreachable!("semantic analyser should prevent this"); + } + let signal = self + .mapping + .get_internal(&format!("//{}/{}", machine_id, state_id), &id.name()); + + Queriable::Internal(signal) + } + + ScopeCategory::Global => unreachable!("no global signals"), + }, + + State => { + Queriable::StepTypeNext(self.mapping.get_step_type_handler(machine_id, &id.name())) + } + + _ => unreachable!("semantic analysis should prevent this"), + } + } + + fn translate_forward_queriable(&mut self, machine_id: &str, id: &Identifier) -> Queriable { + let forward = self.mapping.get_forward(machine_id, &id.name()); + let rot = if id.rotation() == 0 { + false + } else if id.rotation() == 1 { + true + } else { + unreachable!("semantic analyser should prevent this") + }; + + Queriable::Forward(forward, rot) + } + + fn get_all_internals( + &mut self, + symbols: &SymTable, + machine_id: &str, + state_id: &str, + ) -> Vec { + let symbols = symbols + .get_scope(&[ + "/".to_string(), + machine_id.to_string(), + state_id.to_string(), + ]) + .expect("scope not found") + .get_symbols(); + + symbols + .iter() + .filter(|(_, entry)| entry.category == SymbolCategory::Signal) + .map(|(id, _)| id) + .cloned() + .collect() + } + + fn add_internals( + &mut self, + ctx: &mut StepTypeContext, + symbols: &SymTable, + machine_id: &str, + state_id: &str, + ) { + let internal_ids = self.get_all_internals(symbols, machine_id, state_id); + let scope_name = format!("//{}/{}", machine_id, state_id); + + for internal_id in internal_ids { + let name = format!("{}:{}", &scope_name, internal_id); + + let queriable = ctx.internal(name.as_str()); + if let Queriable::Internal(signal) = queriable { + self.mapping + .symbol_uuid + .insert((scope_name.clone(), internal_id), signal.uuid()); + self.mapping.internal_signals.insert(signal.uuid(), signal); + } else { + unreachable!("ctx.internal returns not internal signal"); + } + } + } + + fn add_step_type_handlers>( + &mut self, + ctx: &mut CircuitContextLegacy, + symbols: &SymTable, + machine_id: &str, + ) { + let symbols = symbols + .get_scope(&["/".to_string(), machine_id.to_string()]) + .expect("scope not found") + .get_symbols(); + + let state_ids: Vec<_> = symbols + .iter() + .filter(|(_, entry)| entry.category == SymbolCategory::State) + .map(|(id, _)| id) + .cloned() + .collect(); + + for state_id in state_ids { + let scope_name = format!("//{}", machine_id); + let name = format!("{}:{}", scope_name, state_id); + + let handler = ctx.step_type(&name); + self.mapping + .step_type_handler + .insert(handler.uuid(), handler); + self.mapping + .symbol_uuid + .insert((scope_name, state_id), handler.uuid()); + } + } + + fn add_forwards>( + &mut self, + ctx: &mut CircuitContextLegacy, + symbols: &SymTable, + machine_id: &str, + ) { + let symbols = symbols + .get_scope(&["/".to_string(), machine_id.to_string()]) + .expect("scope not found") + .get_symbols(); + + let forward_ids: Vec<_> = symbols + .iter() + .filter(|(_, entry)| entry.is_signal()) + .map(|(id, _)| id) + .cloned() + .collect(); + + for forward_id in forward_ids { + let scope_name = format!("//{}", machine_id); + let name = format!("{}:{}", scope_name, forward_id); + + let queriable = ctx.forward(name.as_str()); + if let Queriable::Forward(signal, _) = queriable { + self.mapping + .symbol_uuid + .insert((scope_name, forward_id), signal.uuid()); + self.mapping.forward_signals.insert(signal.uuid(), signal); + } else { + unreachable!("ctx.internal returns not internal signal"); + } + } + } + + fn get_decls(stmts: &Vec>) -> Vec> { + let mut result: Vec> = vec![]; + + for stmt in stmts { + if let Statement::SignalDecl(_, ids) = stmt { + result.extend(ids.clone()) + } + } + + result + } +} + +// Basic signal factory. +#[derive(Default)] +struct SignalFactory { + count: u64, + _p: PhantomData, +} + +impl poly::SignalFactory> for SignalFactory { + fn create>(&mut self, annotation: S) -> Queriable { + self.count += 1; + Queriable::Internal(InternalSignal::new(format!( + "{}-{}", + annotation.into(), + self.count + ))) + } +} + +#[cfg(test)] +mod test { + use halo2_proofs::halo2curves::bn256::Fr; + + use crate::{ + compiler::{compile_file_legacy, compile_legacy}, + parser::ast::debug_sym_factory::DebugSymRefFactory, + }; + + use super::Config; + + #[test] + fn test_compiler_fibo() { + let circuit = " + machine fibo(signal n) (signal b: field) { + // n and be are created automatically as shared + // signals + signal a: field, i; + + // there is always a state called initial + // input signals get bound to the signal + // in the initial state (first instance) + state initial { + signal c; + + i, a, b, c <== 1, 1, 1, 2; + + -> middle { + a', b', n' <== b, c, n; + } + } + + state middle { + signal c; + + c <== a + b; + + if i + 1 == n { + -> final { + i', b', n' <== i + 1, c, n; + } + } else { + -> middle { + i', a', b', n' <== i + 1, b, c, n; + } + } + } + + // There is always a state called final. + // Output signals get automatically bound to the signals + // with the same name in the final step (last instance). + // This state can be implicit if there are no constraints in it. + } + "; + + let debug_sym_ref_factory = DebugSymRefFactory::new("", circuit); + let result = compile_legacy::( + circuit, + Config::default().max_degree(2), + &debug_sym_ref_factory, + ); + + match result { + Ok(result) => println!("{:#?}", result), + Err(messages) => println!("{:#?}", messages), + } + } + + #[test] + fn test_compiler_fibo_file() { + let path = "test/circuit.chiquito"; + let result = compile_file_legacy::(path, Config::default().max_degree(2)); + assert!(result.is_ok()); + } + + #[test] + fn test_compiler_fibo_file_err() { + let path = "test/circuit_error.chiquito"; + let result = compile_file_legacy::(path, Config::default().max_degree(2)); + + assert!(result.is_err()); + + assert_eq!( + format!("{:?}", result.unwrap_err()), + r#"[SemErr { msg: "use of undeclared variable c", dsym: test/circuit_error.chiquito:24:39 }, SemErr { msg: "use of undeclared variable c", dsym: test/circuit_error.chiquito:28:46 }]"# + ) + } +} diff --git a/src/compiler/mod.rs b/src/compiler/mod.rs index e590396c..9a6bfacd 100644 --- a/src/compiler/mod.rs +++ b/src/compiler/mod.rs @@ -4,9 +4,9 @@ use std::{ io::{self, Read}, }; -use compiler::CompilerResult; +use compiler::{Compiler, CompilerResult}; +use compiler_legacy::{CompilerLegacy, CompilerResultLegacy}; -use self::compiler::{Compiler, CompilerResultLegacy}; use crate::{ field::Field, parser::ast::{debug_sym_factory::DebugSymRefFactory, DebugSymRef}, @@ -15,6 +15,7 @@ use crate::{ pub mod abepi; #[allow(clippy::module_inception)] pub mod compiler; +pub mod compiler_legacy; pub mod semantic; mod setup_inter; @@ -74,7 +75,7 @@ pub fn compile_legacy( config: Config, debug_sym_ref_factory: &DebugSymRefFactory, ) -> Result, Vec> { - Compiler::new(config).compile_legacy(source, debug_sym_ref_factory) + CompilerLegacy::new(config).compile(source, debug_sym_ref_factory) } /// Compiles chiquito source code file into a SBPIR for a single machine, also returns messages diff --git a/src/compiler/setup_inter.rs b/src/compiler/setup_inter.rs index 53af8721..9b2b3e24 100644 --- a/src/compiler/setup_inter.rs +++ b/src/compiler/setup_inter.rs @@ -4,6 +4,7 @@ use itertools::Itertools; use num_bigint::BigInt; use crate::{ + field::Field, parser::ast::{ statement::{Statement, TypedIdDecl}, tl::TLDecl, @@ -46,6 +47,50 @@ impl Default for MachineSetup { } } } +impl MachineSetup { + pub(crate) fn map_consts(&self) -> MachineSetup { + let poly_constraints: HashMap>> = self + .poly_constraints_iter() + .map(|(step_id, step)| { + let new_step: Vec> = step + .iter() + .map(|pi| Self::convert_const_to_field(pi)) + .collect(); + + (step_id.clone(), new_step) + }) + .collect(); + + let new_machine: MachineSetup = self.replace_poly_constraints(poly_constraints); + new_machine + } + + fn convert_const_to_field( + expr: &Expr, + ) -> Expr { + use Expr::*; + match expr { + Const(v, _) => Const(F::from_big_int(v), ()), + Sum(ses, _) => Sum( + ses.iter() + .map(|se| Self::convert_const_to_field(se)) + .collect(), + (), + ), + Mul(ses, _) => Mul( + ses.iter() + .map(|se| Self::convert_const_to_field(se)) + .collect(), + (), + ), + Neg(se, _) => Neg(Box::new(Self::convert_const_to_field(se)), ()), + Pow(se, exp, _) => Pow(Box::new(Self::convert_const_to_field(se)), *exp, ()), + Query(q, _) => Query(q.clone(), ()), + Halo2Expr(_, _) => todo!(), + MI(se, _) => MI(Box::new(Self::convert_const_to_field(se)), ()), + } + } +} impl MachineSetup { fn new( @@ -88,7 +133,7 @@ impl MachineSetup { .extend(poly_constraints); } - pub(super) fn iter_states_poly_constraints( + pub(super) fn poly_constraints_iter( &self, ) -> std::collections::hash_map::Iter>> { self.poly_constraints.iter() diff --git a/src/frontend/dsl/mod.rs b/src/frontend/dsl/mod.rs index 6710b114..b6ee4f96 100644 --- a/src/frontend/dsl/mod.rs +++ b/src/frontend/dsl/mod.rs @@ -5,18 +5,18 @@ use crate::{ wit_gen::{FixedGenContext, StepInstance, TraceGenerator}, }; -use halo2_proofs::plonk::{Advice, Column as Halo2Column, Fixed}; -use trace::{DSLTraceGenerator, TraceContext}; - use core::{fmt::Debug, hash::Hash}; use std::marker::PhantomData; use self::{ cb::{eq, Constraint, Typing}, - lb::{LookupBuilder, LookupTable, LookupTableRegistry, LookupTableStore}, + lb::{LookupBuilder, LookupTableRegistry}, }; +use halo2_proofs::plonk::{Advice, Column as Halo2Column, Fixed}; +use lb::{LookupTable, LookupTableStore}; pub use sc::*; +use trace::{DSLTraceGenerator, TraceContext}; pub mod cb; pub mod lb; @@ -31,12 +31,12 @@ pub mod trace; /// ### Type parameters /// `F` is the field of the circuit. /// `TG` is the trace generator. -pub struct CircuitContext = DSLTraceGenerator> { +pub struct CircuitContextLegacy = DSLTraceGenerator> { circuit: SBPIRLegacy, tables: LookupTableRegistry, } -impl> CircuitContext { +impl> CircuitContextLegacy { /// Adds a forward signal to the circuit with a name string and zero rotation and returns a /// `Queriable` instance representing the added forward signal. pub fn forward(&mut self, name: &str) -> Queriable { @@ -159,7 +159,7 @@ impl> CircuitContext { } } -impl CircuitContext> { +impl CircuitContextLegacy> { /// Sets the trace function that builds the witness. The trace function is responsible for /// adding step instances defined in `step_type_def`. The function is entirely left for /// the user to implement and is Turing complete. Users typically use external parameters @@ -173,7 +173,7 @@ impl CircuitContext> CircuitContext { +impl> CircuitContextLegacy { /// Executes the fixed generation function provided by the user and sets the fixed assignments /// for the circuit. The fixed generation function is responsible for assigning fixed values to /// fixed columns. It is entirely left for the user to implement and is Turing complete. Users @@ -193,7 +193,6 @@ impl> CircuitContext { self.circuit.set_fixed_assignments(assignments); } } - pub enum StepTypeDefInput { Handler(StepTypeHandler), String(&'static str), @@ -355,7 +354,7 @@ pub struct StepTypeHandler { } impl StepTypeHandler { - fn new(annotation: String) -> Self { + pub(crate) fn new(annotation: String) -> Self { Self { id: uuid(), annotation: Box::leak(annotation.into_boxed_str()), @@ -421,15 +420,16 @@ impl, Args) + 'static> StepTypeWGHandler( +/// (LEGACY) +pub fn circuit_legacy( _name: &str, mut def: D, ) -> SBPIRLegacy> where - D: FnMut(&mut CircuitContext>), + D: FnMut(&mut CircuitContextLegacy>), { // TODO annotate circuit - let mut context = CircuitContext { + let mut context = CircuitContextLegacy { circuit: SBPIRLegacy::default(), tables: LookupTableRegistry::default(), }; @@ -442,17 +442,21 @@ where #[cfg(test)] mod tests { use halo2_proofs::halo2curves::bn256::Fr; + use trace::DSLTraceGenerator; - use crate::{sbpir::ForwardSignal, wit_gen::NullTraceGenerator}; + use crate::{ + sbpir::{ExposeOffset, ForwardSignal, SBPIRLegacy}, + wit_gen::{NullTraceGenerator, TraceGenerator}, + }; use super::*; - fn setup_circuit_context() -> CircuitContext + fn setup_circuit_context() -> CircuitContextLegacy where F: Default, TG: TraceGenerator, { - CircuitContext { + CircuitContextLegacy { circuit: SBPIRLegacy::default(), tables: Default::default(), } diff --git a/src/frontend/dsl/sc.rs b/src/frontend/dsl/sc.rs index 860010c7..31275d45 100644 --- a/src/frontend/dsl/sc.rs +++ b/src/frontend/dsl/sc.rs @@ -17,7 +17,7 @@ use crate::{ wit_gen::{NullTraceGenerator, TraceGenerator}, }; -use super::{lb::LookupTableRegistry, trace::DSLTraceGenerator, CircuitContext}; +use super::{lb::LookupTableRegistry, trace::DSLTraceGenerator, CircuitContextLegacy}; pub struct SuperCircuitContext { super_circuit: SuperCircuit, @@ -61,9 +61,9 @@ impl SuperCircuitContext { Exports, ) where - D: Fn(&mut CircuitContext>, Imports) -> Exports, + D: Fn(&mut CircuitContextLegacy>, Imports) -> Exports, { - let mut sub_circuit_context = CircuitContext { + let mut sub_circuit_context = CircuitContextLegacy { circuit: SBPIRLegacy::default(), tables: self.tables.clone(), }; @@ -148,6 +148,7 @@ mod tests { use halo2_proofs::halo2curves::{bn256::Fr, ff::PrimeField}; use crate::{ + frontend::dsl::circuit_legacy, plonkish::compiler::{ cell_manager::SingleRowCellManager, config, step_selector::SimpleStepSelectorBuilder, }, @@ -180,7 +181,7 @@ mod tests { let mut ctx = SuperCircuitContext::::default(); fn simple_circuit( - ctx: &mut CircuitContext>, + ctx: &mut CircuitContextLegacy>, _: (), ) { use crate::frontend::dsl::cb::*; @@ -239,7 +240,7 @@ mod tests { let mut ctx = SuperCircuitContext::::default(); fn simple_circuit( - ctx: &mut CircuitContext>, + ctx: &mut CircuitContextLegacy>, _: (), ) { use crate::frontend::dsl::cb::*; @@ -296,10 +297,9 @@ mod tests { #[test] fn test_super_circuit_sub_circuit_with_ast() { - use crate::frontend::dsl::circuit; let mut ctx = SuperCircuitContext::::default(); - let simple_circuit_with_ast = circuit("simple circuit", |ctx| { + let simple_circuit_with_ast = circuit_legacy("simple circuit", |ctx| { use crate::frontend::dsl::cb::*; let x = ctx.forward("x"); diff --git a/src/interpreter/mod.rs b/src/interpreter/mod.rs index e6177239..7cd05010 100644 --- a/src/interpreter/mod.rs +++ b/src/interpreter/mod.rs @@ -255,7 +255,7 @@ pub fn run( } /// A trace generator that interprets chiquito source -#[derive(Default, Clone)] +#[derive(Debug, Default, Clone)] pub struct InterpreterTraceGenerator { program: Vec>, symbols: SymTable, diff --git a/src/poly/mielim.rs b/src/poly/mielim.rs index 96e7cfeb..a9e1e888 100644 --- a/src/poly/mielim.rs +++ b/src/poly/mielim.rs @@ -3,8 +3,8 @@ use std::{fmt::Debug, hash::Hash}; use super::{ConstrDecomp, Expr, SignalFactory}; use crate::field::Field; -/// This function eliminates MI operators from the PI expression, by creating new signals that are -/// constraint to the MI sub-expressions. +/// This function eliminates MI operators from the PI expression by creating new signals that are +/// constrained to the MI sub-expressions. pub fn mi_elimination>( constr: Expr, signal_factory: &mut SF, diff --git a/src/poly/mod.rs b/src/poly/mod.rs index 2e5942f9..24fedf35 100644 --- a/src/poly/mod.rs +++ b/src/poly/mod.rs @@ -31,8 +31,8 @@ pub enum Expr { Pow(Box>, u32, M), Query(V, M), Halo2Expr(Expression, M), - - MI(Box>, M), // Multiplicative inverse, but MI(0) = 0 + /// Multiplicative inverse, but MI(0) = 0 + MI(Box>, M), } impl Expr { diff --git a/src/sbpir/mod.rs b/src/sbpir/mod.rs index 14f46c34..0335345f 100644 --- a/src/sbpir/mod.rs +++ b/src/sbpir/mod.rs @@ -1,7 +1,7 @@ pub mod query; pub mod sbpir_machine; -use std::{collections::HashMap, fmt::Debug, hash::Hash, rc::Rc}; +use std::{collections::HashMap, fmt::Debug, hash::Hash, marker::PhantomData, rc::Rc}; use crate::{ field::Field, @@ -9,7 +9,7 @@ use crate::{ trace::{DSLTraceGenerator, TraceContext}, StepTypeHandler, }, - poly::{ConstrDecomp, Expr}, + poly::{self, mielim::mi_elimination, reduce::reduce_degree, ConstrDecomp, Expr}, util::{uuid, UUID}, wit_gen::{FixedAssignment, FixedGenContext, NullTraceGenerator, TraceGenerator}, }; @@ -269,23 +269,95 @@ impl> SBPIRLegacy { } } -pub struct SBPIR = DSLTraceGenerator> { - pub machines: HashMap>, +#[derive(Debug)] +pub struct SBPIR = DSLTraceGenerator> { + pub machines: HashMap>, pub identifiers: HashMap, } -impl> SBPIR { - pub(crate) fn from_legacy(circuit: SBPIRLegacy, machine_id: &str) -> SBPIR { - let mut machines = HashMap::new(); - let circuit_id = circuit.id; - machines.insert(circuit_id, SBPIRMachine::from_legacy(circuit)); - let mut identifiers = HashMap::new(); - identifiers.insert(machine_id.to_string(), circuit_id); +impl> SBPIR { + pub(crate) fn default() -> SBPIR { + let machines = HashMap::new(); + let identifiers = HashMap::new(); SBPIR { machines, identifiers, } } + + pub(crate) fn with_trace + Clone>( + &self, + // TODO does it have to be the same trace across all the machines? + trace: &TG2, + ) -> SBPIR { + let mut machines_with_trace = HashMap::new(); + for (name, machine) in self.machines.iter() { + let machine_with_trace = machine.with_trace(trace.clone()); + machines_with_trace.insert(name.clone(), machine_with_trace); + } + SBPIR { + machines: machines_with_trace, + identifiers: self.identifiers.clone(), + } + } + + pub(crate) fn without_trace(&self) -> SBPIR { + let mut machines_without_trace = HashMap::new(); + for (name, machine) in self.machines.iter() { + let machine_without_trace = machine.without_trace(); + machines_without_trace.insert(name.clone(), machine_without_trace); + } + SBPIR { + machines: machines_without_trace, + identifiers: self.identifiers.clone(), + } + } + + /// Eliminate multiplicative inverses + pub(crate) fn eliminate_mul_inv(mut self) -> SBPIR { + for machine in self.machines.values_mut() { + for (_, step_type) in machine.step_types.iter_mut() { + let mut signal_factory = SignalFactory::default(); + + step_type + .decomp_constraints(|expr| mi_elimination(expr.clone(), &mut signal_factory)); + } + } + + self + } + + pub(crate) fn reduce(mut self, degree: usize) -> SBPIR { + for machine in self.machines.values_mut() { + for (_, step_type) in machine.step_types.iter_mut() { + let mut signal_factory = SignalFactory::default(); + + step_type.decomp_constraints(|expr| { + reduce_degree(expr.clone(), degree, &mut signal_factory) + }); + } + } + + self + } +} + +// Basic signal factory. +#[derive(Default)] +struct SignalFactory { + count: u64, + _p: PhantomData, +} + +impl poly::SignalFactory> for SignalFactory { + fn create>(&mut self, annotation: S) -> Queriable { + self.count += 1; + Queriable::Internal(InternalSignal::new(format!( + "{}-{}", + annotation.into(), + self.count + ))) + } } pub type FixedGen = dyn Fn(&mut FixedGenContext) + 'static; @@ -649,7 +721,7 @@ impl FixedSignal { } } -#[derive(Clone, Copy, Debug)] +#[derive(Clone, Copy, Debug, PartialEq)] pub enum ExposeOffset { First, Last, diff --git a/src/sbpir/query.rs b/src/sbpir/query.rs index 0f0b727d..ab6d7efb 100644 --- a/src/sbpir/query.rs +++ b/src/sbpir/query.rs @@ -21,11 +21,26 @@ use super::PIR; #[derive(Clone, Copy, PartialEq, Eq, Hash)] pub enum Queriable { Internal(InternalSignal), + /// Forward signal + /// - `ForwardSignal` is the signal to be queried + /// - `bool` is the rotation state of the signal (true if rotated) Forward(ForwardSignal, bool), + /// Shared signal + /// - `SharedSignal` is the signal to be queried + /// - `i32` is the rotation value Shared(SharedSignal, i32), + /// Fixed signal + /// - `FixedSignal` is the signal to be queried + /// - `i32` is the rotation value Fixed(FixedSignal, i32), StepTypeNext(StepTypeHandler), + /// Imported Halo2 advice query + /// - `ImportedHalo2Advice` is the signal to be queried + /// - `i32` is the rotation value Halo2AdviceQuery(ImportedHalo2Advice, i32), + /// Imported Halo2 fixed query + /// - `ImportedHalo2Fixed` is the signal to be queried + /// - `i32` is the rotation value Halo2FixedQuery(ImportedHalo2Fixed, i32), #[allow(non_camel_case_types)] _unaccessible(PhantomData), @@ -38,7 +53,7 @@ impl Debug for Queriable { } impl Queriable { - /// Call `next` function on a `Querible` forward signal to build constraints for forward + /// Call `next` function on a `Queriable` forward signal to build constraints for forward /// signal with rotation. Cannot be called on an internal signal and must be used within a /// `transition` constraint. Returns a new `Queriable` forward signal with rotation. pub fn next(&self) -> Queriable { diff --git a/src/sbpir/sbpir_machine.rs b/src/sbpir/sbpir_machine.rs index b5489b4d..2b21104b 100644 --- a/src/sbpir/sbpir_machine.rs +++ b/src/sbpir/sbpir_machine.rs @@ -16,12 +16,13 @@ use super::{ ImportedHalo2Fixed, SharedSignal, StepType, StepTypeUUID, }; -/// Circuit (Step-Based Polynomial Identity Representation) +/// Step-Based Polynomial Identity Representation (SBPIR) of a single machine. #[derive(Clone)] -pub struct SBPIRMachine = DSLTraceGenerator> { +pub struct SBPIRMachine = DSLTraceGenerator> { pub step_types: HashMap>, pub forward_signals: Vec, + // TODO currently not used pub shared_signals: Vec, pub fixed_signals: Vec, pub halo2_advice: Vec, @@ -31,6 +32,7 @@ pub struct SBPIRMachine = DSLTraceGenerator> { pub annotations: HashMap, pub trace_generator: Option, + // TODO currently not used pub fixed_assignments: Option>, pub first_step: Option, @@ -41,7 +43,7 @@ pub struct SBPIRMachine = DSLTraceGenerator> { pub id: UUID, } -impl> Debug for SBPIRMachine { +impl> Debug for SBPIRMachine { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("Circuit") .field("step_types", &self.step_types) @@ -61,7 +63,7 @@ impl> Debug for SBPIRMachine { } } -impl> Default for SBPIRMachine { +impl> Default for SBPIRMachine { fn default() -> Self { Self { step_types: Default::default(), @@ -88,7 +90,7 @@ impl> Default for SBPIRMachine { } } -impl> SBPIRMachine { +impl> SBPIRMachine { pub fn add_forward>(&mut self, name: N, phase: usize) -> ForwardSignal { let name = name.into(); let signal = ForwardSignal::new_with_phase(phase, name.clone()); @@ -109,6 +111,7 @@ impl> SBPIRMachine { signal } + // TODO currently not used pub fn add_fixed>(&mut self, name: N) -> FixedSignal { let name = name.into(); let signal = FixedSignal::new(name.clone()); @@ -119,6 +122,7 @@ impl> SBPIRMachine { signal } + // TODO currently not used pub fn expose(&mut self, signal: Queriable, offset: ExposeOffset) { match signal { Queriable::Forward(..) | Queriable::Shared(..) => { @@ -167,8 +171,11 @@ impl> SBPIRMachine { advice } - pub fn add_step_type>(&mut self, handler: StepTypeHandler, name: N) { - self.annotations.insert(handler.uuid(), name.into()); + pub fn add_step_type>(&mut self, name: N) -> StepTypeHandler { + let annotation = name.into(); + let handler = StepTypeHandler::new(annotation.clone()); + self.annotations.insert(handler.uuid(), annotation); + handler } pub fn add_step_type_def(&mut self, step: StepType) -> StepTypeUUID { @@ -187,18 +194,18 @@ impl> SBPIRMachine { } } - pub fn without_trace(self) -> SBPIRMachine { + pub fn without_trace(&self) -> SBPIRMachine { SBPIRMachine { - step_types: self.step_types, - forward_signals: self.forward_signals, - shared_signals: self.shared_signals, - fixed_signals: self.fixed_signals, - halo2_advice: self.halo2_advice, - halo2_fixed: self.halo2_fixed, - exposed: self.exposed, - annotations: self.annotations, + step_types: self.step_types.clone(), + forward_signals: self.forward_signals.clone(), + shared_signals: self.shared_signals.clone(), + fixed_signals: self.fixed_signals.clone(), + halo2_advice: self.halo2_advice.clone(), + halo2_fixed: self.halo2_fixed.clone(), + exposed: self.exposed.clone(), + annotations: self.annotations.clone(), trace_generator: None, // Remove the trace. - fixed_assignments: self.fixed_assignments, + fixed_assignments: self.fixed_assignments.clone(), first_step: self.first_step, last_step: self.last_step, num_steps: self.num_steps, @@ -207,19 +214,18 @@ impl> SBPIRMachine { } } - #[allow(dead_code)] // TODO: Copy of the legacy SBPIR code. Remove if not used in the new compilation - pub(crate) fn with_trace>(self, trace: TG2) -> SBPIRMachine { + pub(crate) fn with_trace>(&self, clone: TG2) -> SBPIRMachine { SBPIRMachine { - trace_generator: Some(trace), // Change trace - step_types: self.step_types, - forward_signals: self.forward_signals, - shared_signals: self.shared_signals, - fixed_signals: self.fixed_signals, - halo2_advice: self.halo2_advice, - halo2_fixed: self.halo2_fixed, - exposed: self.exposed, - annotations: self.annotations, - fixed_assignments: self.fixed_assignments, + trace_generator: Some(clone), // Set trace + step_types: self.step_types.clone(), + forward_signals: self.forward_signals.clone(), + shared_signals: self.shared_signals.clone(), + fixed_signals: self.fixed_signals.clone(), + halo2_advice: self.halo2_advice.clone(), + halo2_fixed: self.halo2_fixed.clone(), + exposed: self.exposed.clone(), + annotations: self.annotations.clone(), + fixed_assignments: self.fixed_assignments.clone(), first_step: self.first_step, last_step: self.last_step, num_steps: self.num_steps, @@ -227,26 +233,6 @@ impl> SBPIRMachine { id: self.id, } } - - pub(crate) fn from_legacy(circuit: super::SBPIRLegacy) -> SBPIRMachine { - SBPIRMachine { - step_types: circuit.step_types, - forward_signals: circuit.forward_signals, - shared_signals: circuit.shared_signals, - fixed_signals: circuit.fixed_signals, - halo2_advice: circuit.halo2_advice, - halo2_fixed: circuit.halo2_fixed, - exposed: circuit.exposed, - annotations: circuit.annotations, - trace_generator: circuit.trace_generator, - fixed_assignments: circuit.fixed_assignments, - first_step: circuit.first_step, - last_step: circuit.last_step, - num_steps: circuit.num_steps, - q_enable: circuit.q_enable, - id: circuit.id, - } - } } impl SBPIRMachine> {