From 15f441b8edd78b0b37a0ad925fd5b47d33e5b975 Mon Sep 17 00:00:00 2001 From: Steve Wang Date: Sun, 3 Mar 2024 22:05:09 -0500 Subject: [PATCH] PIL Backend with PIL IR (#165) Ready for review. @leolara # High-Level Approach Convert Chiquito AST to PIL IR and finally to PIL code. # Future TODOs - Add support for trace inputs to PIL. Currently, PIL supports user input via the command line or Rust APIs for the command line. However, it only supports input for one variable (column). Because Chiquito circuit might provide inputs for multiple variables via trace (e.g. the MiMC7 circuit), we can't use use PIL's Rust API as a generic solution for providing inputs. The only solution is to input statements in the format of `ISFIRST * ([[input_variable]] - [[input_value]]) = 0`. An issue with this approach, however, is that our `trace` is too general and doesn't associate input values to signals, so we might need to create an alternative `trace` function or mode that does this. All of the above is due to the fact that PIL only supports automatic witness inference and cannot feed in external witness. - Add Python PIL API support for super circuits. Currently we support single circuit in Python. Should be very do-able, just didn't get the time to code it up. --- Cargo.toml | 1 + examples/fibonacci.py | 2 + examples/fibonacci.rs | 29 +- examples/mimc7.rs | 25 +- src/frontend/dsl/mod.rs | 4 + src/frontend/dsl/sc.rs | 11 +- src/frontend/python/chiquito/dsl.py | 9 + src/frontend/python/mod.rs | 22 ++ src/lib.rs | 1 + src/pil/backend/mod.rs | 1 + src/pil/backend/powdr_pil.rs | 226 +++++++++++ src/pil/compiler/mod.rs | 578 ++++++++++++++++++++++++++++ src/pil/ir/mod.rs | 1 + src/pil/ir/powdr_pil.rs | 18 + src/pil/mod.rs | 3 + src/plonkish/ir/assignments.rs | 6 +- src/plonkish/ir/sc.rs | 52 ++- src/sbpir/mod.rs | 42 ++ 18 files changed, 1018 insertions(+), 13 deletions(-) create mode 100644 src/pil/backend/mod.rs create mode 100644 src/pil/backend/powdr_pil.rs create mode 100644 src/pil/compiler/mod.rs create mode 100644 src/pil/ir/mod.rs create mode 100644 src/pil/ir/powdr_pil.rs create mode 100644 src/pil/mod.rs diff --git a/Cargo.toml b/Cargo.toml index c002d68b..58d1fd0e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,6 +20,7 @@ num-bigint = { version = "0.4", features = ["rand"] } uuid = { version = "1.4.0", features = ["v1", "rng"] } serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" +regex = "1" [dev-dependencies] rand_chacha = "0.3" diff --git a/examples/fibonacci.py b/examples/fibonacci.py index 4ffa7b31..0c7aedfa 100644 --- a/examples/fibonacci.py +++ b/examples/fibonacci.py @@ -84,3 +84,5 @@ def trace(self, n): ) # 2^k specifies the number of PLONKish table rows in Halo2 another_fibo_witness = fibo.gen_witness(4) fibo.halo2_mock_prover(another_fibo_witness, k=7) + +fibo.to_pil(fibo_witness, "FiboCircuit") diff --git a/examples/fibonacci.rs b/examples/fibonacci.rs index 4021cdda..dbf7dbba 100644 --- a/examples/fibonacci.rs +++ b/examples/fibonacci.rs @@ -18,6 +18,7 @@ use chiquito::{ }, plonkish::ir::{assignments::AssignmentGenerator, Circuit}, // compiled circuit type poly::ToField, + sbpir::SBPIR, }; use halo2_proofs::{dev::MockProver, halo2curves::bn256::Fr}; @@ -26,7 +27,10 @@ use halo2_proofs::{dev::MockProver, halo2curves::bn256::Fr}; // 1. type that implements a field trait // 2. empty trace arguments, i.e. (), because there are no external inputs to the Chiquito circuit // 3. two witness generation arguments both of u64 type, i.e. (u64, u64) -fn fibo_circuit + Hash>() -> (Circuit, Option>) { + +type FiboReturn = (Circuit, Option>, SBPIR); + +fn fibo_circuit + Hash>() -> FiboReturn { // PLONKish table for the Fibonacci circuit: // | a | b | c | // | 1 | 1 | 2 | @@ -73,7 +77,7 @@ fn fibo_circuit + Hash>() -> (Circuit, Option + Hash>() -> (Circuit, Option + Hash>() -> (Circuit, Option(); + let (chiquito, wit_gen, _) = fibo_circuit::(); let compiled = chiquito2Halo2(chiquito); let circuit = ChiquitoHalo2Circuit::new(compiled, wit_gen.map(|g| g.generate(()))); @@ -137,7 +143,7 @@ fn main() { use polyexen::plaf::{backends::halo2::PlafH2Circuit, WitnessDisplayCSV}; // get Chiquito ir - let (circuit, wit_gen) = fibo_circuit::(); + let (circuit, wit_gen, _) = fibo_circuit::(); // get Plaf let (plaf, plaf_wit_gen) = chiquito2Plaf(circuit, 8, false); let wit = plaf_wit_gen.generate(wit_gen.map(|v| v.generate(()))); @@ -162,4 +168,15 @@ fn main() { println!("{}", failure); } } + + // pil boilerplate + use chiquito::pil::backend::powdr_pil::chiquito2Pil; + + let (_, wit_gen, circuit) = fibo_circuit::(); + let pil = chiquito2Pil( + circuit, + Some(wit_gen.unwrap().generate_trace_witness(())), + String::from("FiboCircuit"), + ); + print!("{}", pil); } diff --git a/examples/mimc7.rs b/examples/mimc7.rs index fa1b1d2e..1eafeb0e 100644 --- a/examples/mimc7.rs +++ b/examples/mimc7.rs @@ -175,7 +175,7 @@ fn mimc7_circuit( row_value += F::from(1); x_value += k_value + c_value; x_value = x_value.pow_vartime([7_u64]); - // Step 90: output the hash result as x + k in witness generation + // Step 91: output the hash result as x + k in witness generation // output is not displayed as a public column, which will be implemented in the future ctx.add(&mimc7_last_step, (x_value, k_value, c_value, row_value)); // c_value is not // used here but @@ -219,6 +219,29 @@ fn main() { println!("{}", failure); } } + + // pil boilerplate + use chiquito::pil::backend::powdr_pil::chiquitoSuperCircuit2Pil; + + let x_in_value = Fr::from_str_vartime("1").expect("expected a number"); + let k_value = Fr::from_str_vartime("2").expect("expected a number"); + + let super_circuit = mimc7_super_circuit::(); + + // `super_trace_witnesses` is a mapping from IR id to TraceWitness. However, not all ASTs have a + // corresponding TraceWitness. + let super_trace_witnesses = super_circuit + .get_mapping() + .generate_super_trace_witnesses((x_in_value, k_value)); + + let pil = chiquitoSuperCircuit2Pil::( + super_circuit.get_super_asts(), + super_trace_witnesses, + super_circuit.get_ast_id_to_ir_id_mapping(), + vec![String::from("Mimc7Constant"), String::from("Mimc7Circuit")], + ); + + print!("{}", pil); } mod mimc7_constants { diff --git a/src/frontend/dsl/mod.rs b/src/frontend/dsl/mod.rs index c90db77e..15def6ff 100644 --- a/src/frontend/dsl/mod.rs +++ b/src/frontend/dsl/mod.rs @@ -358,6 +358,10 @@ impl StepTypeHandler { pub fn next(&self) -> Queriable { Queriable::StepTypeNext(*self) } + + pub fn annotation(&self) -> String { + self.annotation.to_string() + } } impl, Args) + 'static> From<&StepTypeWGHandler> diff --git a/src/frontend/dsl/sc.rs b/src/frontend/dsl/sc.rs index 9b6daec9..f9ef60d5 100644 --- a/src/frontend/dsl/sc.rs +++ b/src/frontend/dsl/sc.rs @@ -34,6 +34,12 @@ impl Default for SuperCircuitContext { } } +impl SuperCircuitContext { + fn add_sub_circuit_ast(&mut self, ast: SBPIR) { + self.super_circuit.add_sub_circuit_ast(ast); + } +} + impl SuperCircuitContext { pub fn sub_circuit( &mut self, @@ -48,12 +54,13 @@ impl SuperCircuitContext { circuit: SBPIR::default(), tables: self.tables.clone(), }; - println!("super circuit table registry 2: {:?}", self.tables); let exports = sub_circuit_def(&mut sub_circuit_context, imports); - println!("super circuit table registry 3: {:?}", self.tables); let sub_circuit = sub_circuit_context.circuit; + // ast is used for PIL backend + self.add_sub_circuit_ast(sub_circuit.clone_without_trace()); + let (unit, assignment) = compile_phase1(config, &sub_circuit); let assignment = assignment.unwrap_or_else(|| AssignmentGenerator::empty(unit.uuid)); diff --git a/src/frontend/python/chiquito/dsl.py b/src/frontend/python/chiquito/dsl.py index aa355972..ee41004f 100644 --- a/src/frontend/python/chiquito/dsl.py +++ b/src/frontend/python/chiquito/dsl.py @@ -225,6 +225,15 @@ def halo2_mock_prover(self: Circuit, witness: TraceWitness, k: int = 16): witness_json: str = witness.get_witness_json() rust_chiquito.halo2_mock_prover(witness_json, self.rust_id, k) + def to_pil( + self: Circuit, witness: TraceWitness, circuit_name: str = "Circuit" + ) -> str: + if self.rust_id == 0: + ast_json: str = self.get_ast_json() + self.rust_id: int = rust_chiquito.ast_to_halo2(ast_json) + witness_json: str = witness.get_witness_json() + rust_chiquito.to_pil(witness_json, self.rust_id, circuit_name) + def __str__(self: Circuit) -> str: return self.ast.__str__() diff --git a/src/frontend/python/mod.rs b/src/frontend/python/mod.rs index e6da88aa..69b3abc3 100644 --- a/src/frontend/python/mod.rs +++ b/src/frontend/python/mod.rs @@ -5,6 +5,7 @@ use pyo3::{ use crate::{ frontend::dsl::{StepTypeHandler, SuperCircuitContext}, + pil::backend::powdr_pil::chiquito2Pil, plonkish::{ backend::halo2::{ chiquito2Halo2, chiquitoSuperCircuit2Halo2, ChiquitoHalo2, ChiquitoHalo2Circuit, @@ -81,6 +82,14 @@ pub fn chiquito_ast_map_store(ast_json: &str) -> UUID { uuid } +pub fn chiquito_ast_to_pil(witness_json: &str, rust_id: UUID, circuit_name: &str) -> String { + let trace_witness: TraceWitness = + serde_json::from_str(witness_json).expect("Json deserialization to TraceWitness failed."); + let (ast, _, _) = rust_id_to_halo2(rust_id); + + chiquito2Pil(ast, Some(trace_witness), circuit_name.to_string()) +} + fn add_assignment_generator_to_rust_id( assignment_generator: AssignmentGenerator, rust_id: UUID, @@ -1845,6 +1854,18 @@ fn ast_to_halo2(json: &PyString) -> u128 { uuid } +#[pyfunction] +fn to_pil(witness_json: &PyString, rust_id: &PyLong, circuit_name: &PyString) -> String { + let pil = chiquito_ast_to_pil( + witness_json.to_str().expect("PyString convertion failed."), + rust_id.extract().expect("PyLong convertion failed."), + circuit_name.to_str().expect("PyString convertion failed."), + ); + + println!("{}", pil); + pil +} + #[pyfunction] fn ast_map_store(json: &PyString) -> u128 { let uuid = chiquito_ast_map_store(json.to_str().expect("PyString conversion failed.")); @@ -1903,6 +1924,7 @@ fn rust_chiquito(_py: Python, m: &PyModule) -> PyResult<()> { m.add_function(wrap_pyfunction!(convert_and_print_ast, m)?)?; m.add_function(wrap_pyfunction!(convert_and_print_trace_witness, m)?)?; m.add_function(wrap_pyfunction!(ast_to_halo2, m)?)?; + m.add_function(wrap_pyfunction!(to_pil, m)?)?; m.add_function(wrap_pyfunction!(ast_map_store, m)?)?; m.add_function(wrap_pyfunction!(halo2_mock_prover, m)?)?; m.add_function(wrap_pyfunction!(super_circuit_halo2_mock_prover, m)?)?; diff --git a/src/lib.rs b/src/lib.rs index dd823581..853e71ca 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,5 +1,6 @@ pub mod field; pub mod frontend; +pub mod pil; pub mod plonkish; pub mod poly; pub mod sbpir; diff --git a/src/pil/backend/mod.rs b/src/pil/backend/mod.rs new file mode 100644 index 00000000..d57b058b --- /dev/null +++ b/src/pil/backend/mod.rs @@ -0,0 +1 @@ +pub mod powdr_pil; diff --git a/src/pil/backend/powdr_pil.rs b/src/pil/backend/powdr_pil.rs new file mode 100644 index 00000000..4b86c591 --- /dev/null +++ b/src/pil/backend/powdr_pil.rs @@ -0,0 +1,226 @@ +use crate::{ + field::Field, + pil::{ + compiler::{compile, compile_super_circuits, PILColumn, PILExpr, PILQuery}, + ir::powdr_pil::PILCircuit, + }, + sbpir::SBPIR, + util::UUID, + wit_gen::TraceWitness, +}; +use std::{ + collections::HashMap, + fmt::{Debug, Write}, +}; +extern crate regex; + +#[allow(non_snake_case)] +/// User generate PIL code using this function. User needs to supply AST, TraceWitness, and a name +/// string for the circuit. +pub fn chiquito2Pil( + ast: SBPIR, + witness: Option>, + circuit_name: String, +) -> String { + // generate PIL IR. + let pil_ir = compile::(&ast, witness, circuit_name, &None); + + // generate Powdr PIL code. + pil_ir_to_powdr_pil::(pil_ir) +} + +// Convert PIL IR to Powdr PIL code. +pub fn pil_ir_to_powdr_pil(pil_ir: PILCircuit) -> String { + let mut pil = String::new(); // The string to return. + + writeln!( + pil, + "// ===== START OF CIRCUIT: {} =====", + pil_ir.circuit_name + ) + .unwrap(); + + // Namespace is equivalent to a circuit in PIL. + writeln!( + pil, + "constant %NUM_STEPS_{} = {};", + pil_ir.circuit_name.to_uppercase(), + pil_ir.num_steps + ) + .unwrap(); + writeln!( + pil, + "namespace {}(%NUM_STEPS_{});", + pil_ir.circuit_name, + pil_ir.circuit_name.to_uppercase() + ) + .unwrap(); + + // Declare witness columns in PIL. + generate_pil_witness_columns(&mut pil, &pil_ir); + + // Declare fixed columns and their assignments in PIL. + generate_pil_fixed_columns(&mut pil, &pil_ir); + + // generate constraints + for expr in pil_ir.constraints { + // recursively convert expressions to PIL strings + let expr_string = convert_to_pil_expr_string(expr.clone()); + // each constraint is in the format of `constraint = 0` + writeln!(pil, "{} = 0;", expr_string).unwrap(); + } + + // generate lookups + for lookup in pil_ir.lookups { + let (selector, src_dest_tuples) = lookup; + let lookup_selector = selector.annotation(); + let mut lookup_source: Vec = Vec::new(); + let mut lookup_destination: Vec = Vec::new(); + for (src, dest) in src_dest_tuples { + lookup_source.push(src.annotation()); + lookup_destination.push(dest.annotation()); + } + // PIL lookups have the format of `selector { src1, src2, ... srcn } in {dest1, dest2, ..., + // destn}`. + writeln!( + pil, + "{} {{{}}} in {{{}}} ", + lookup_selector, + lookup_source.join(", "), + lookup_destination.join(", ") + ) + .unwrap(); + } + + writeln!( + pil, + "// ===== END OF CIRCUIT: {} =====", + pil_ir.circuit_name + ) + .unwrap(); + writeln!(pil).unwrap(); // Separator row for the circuit. + + pil +} + +#[allow(non_snake_case)] +/// User generate PIL code for super circuit using this function. +/// User needs to supply a Vec for `circuit_names`, the order of which should be the same as +/// the order of calling `sub_circuit()` function. +pub fn chiquitoSuperCircuit2Pil( + super_asts: Vec>, + super_trace_witnesses: HashMap>, + ast_id_to_ir_id_mapping: HashMap, + circuit_names: Vec, +) -> String { + let mut pil = String::new(); // The string to return. + + // Generate PIL IRs for each sub circuit in the super circuit. + let pil_irs = compile_super_circuits( + super_asts, + super_trace_witnesses, + ast_id_to_ir_id_mapping, + circuit_names, + ); + + // Generate Powdr PIL code for each sub circuit. + for pil_ir in pil_irs { + let pil_circuit = pil_ir_to_powdr_pil(pil_ir); + writeln!(pil, "{}", pil_circuit).unwrap(); + } + + pil +} + +fn generate_pil_witness_columns(pil: &mut String, pil_ir: &PILCircuit) { + if !pil_ir.col_witness.is_empty() { + writeln!(pil, "// === Witness Columns ===").unwrap(); + let mut col_witness = String::from("col witness "); + + let mut col_witness_vars = pil_ir + .col_witness + .iter() + .map(|col| match col { + PILColumn::Advice(_, annotation) => annotation.clone(), + _ => panic!("Witness column should be an advice column."), + }) + .collect::>(); + + // Get unique witness column annotations + col_witness_vars.sort(); + col_witness_vars.dedup(); + col_witness = col_witness + col_witness_vars.join(", ").as_str() + ";"; + writeln!(pil, "{}", col_witness).unwrap(); + } +} + +fn generate_pil_fixed_columns(pil: &mut String, pil_ir: &PILCircuit) { + if !pil_ir.col_fixed.is_empty() { + writeln!( + pil, + "// === Fixed Columns for Signals and Step Type Selectors ===" + ) + .unwrap(); + for (col, assignments) in pil_ir.col_fixed.iter() { + let fixed_name = match col { + PILColumn::Fixed(_, annotation) => annotation.clone(), + _ => panic!("Fixed column should be an advice or fixed column."), + }; + let mut assignments_string = String::new(); + let assignments_vec = assignments + .iter() + .map(|assignment| format!("{:?}", assignment)) + .collect::>(); + write!( + assignments_string, + "{}", + assignments_vec.join(", ").as_str() + ) + .unwrap(); + writeln!(pil, "col fixed {} = [{}];", fixed_name, assignments_string).unwrap(); + } + } +} + +// Convert PIL expression to Powdr PIL string recursively. +fn convert_to_pil_expr_string(expr: PILExpr) -> String { + match expr { + PILExpr::Const(constant) => format!("{:?}", constant), + PILExpr::Sum(sum) => { + let mut expr_string = String::new(); + for (index, expr) in sum.iter().enumerate() { + expr_string += convert_to_pil_expr_string(expr.clone()).as_str(); + if index != sum.len() - 1 { + expr_string += " + "; + } + } + format!("({})", expr_string) + } + PILExpr::Mul(mul) => { + let mut expr_string = String::new(); + for (index, expr) in mul.iter().enumerate() { + expr_string += convert_to_pil_expr_string(expr.clone()).as_str(); + if index != mul.len() - 1 { + expr_string += " * "; + } + } + format!("{}", expr_string) + } + PILExpr::Neg(neg) => format!("(-{})", convert_to_pil_expr_string(*neg)), + PILExpr::Pow(pow, power) => { + format!("({})^{}", convert_to_pil_expr_string(*pow), power) + } + PILExpr::Query(queriable) => convert_to_pil_queriable_string(queriable), + } +} + +// Convert PIL query to Powdr PIL string recursively. +fn convert_to_pil_queriable_string(query: PILQuery) -> String { + let (col, rot) = query; + let annotation = col.annotation(); + if rot { + format!("{}'", annotation) + } else { + annotation + } +} diff --git a/src/pil/compiler/mod.rs b/src/pil/compiler/mod.rs new file mode 100644 index 00000000..0928f3ae --- /dev/null +++ b/src/pil/compiler/mod.rs @@ -0,0 +1,578 @@ +use crate::{ + field::Field, + pil::ir::powdr_pil::{PILCircuit, PILLookup}, + poly::Expr, + sbpir::{query::Queriable, SBPIR}, + util::{uuid, UUID}, + wit_gen::TraceWitness, +}; +use std::{collections::HashMap, fmt::Debug, hash::Hash}; +extern crate regex; + +pub fn compile( + ast: &SBPIR, + witness: Option>, + circuit_name: String, + super_circuit_annotations_map: &Option<&HashMap>, +) -> PILCircuit { + let col_witness = collect_witness_columns(ast); + + // HashMap of fixed column to fixed assignments + let mut col_fixed = HashMap::new(); + + if let Some(fixed_assignments) = &ast.fixed_assignments { + fixed_assignments + .iter() + .for_each(|(queriable, assignments)| { + let uuid = queriable.uuid(); + col_fixed.insert( + PILColumn::Fixed( + uuid, + clean_annotation(ast.annotations.get(&uuid).unwrap().clone()), + ), + assignments.clone(), + ); + }); + } + + // Get last step instance UUID, so that we can disable transition of that instance + let mut last_step_instance = 0; + + // Insert into col_fixed the map from step type fixed column to vector of {0,1} where 1 means + // the step type is instantiated whereas 0 not. Each vector should have the same length as the + // number of steps. + if !ast.step_types.is_empty() && witness.is_some() { + let step_instances = witness.as_ref().unwrap().step_instances.iter(); + + // Get last step instance, so that we can disable transition of that instance + last_step_instance = step_instances.clone().last().unwrap().step_type_uuid; + + for step_type in ast.step_types.values() { + let step_type_instantiation: Vec = step_instances + .clone() + .map(|step_instance| { + if step_instance.step_type_uuid == step_type.uuid() { + F::ONE + } else { + F::ZERO + } + }) + .collect(); + assert_eq!(step_type_instantiation.len(), ast.num_steps); + let uuid = step_type.uuid(); + col_fixed.insert( + PILColumn::Fixed( + uuid, + clean_annotation(ast.annotations.get(&uuid).unwrap().clone()), + ), + step_type_instantiation, + ); + } + } + + // Create new UUID for ISFIRST and ISLAST. These are fixed columns unique to PIL. + let is_first_uuid = uuid(); + let is_last_uuid = uuid(); + + // ISFIRST and ISLAST are only relevant when there's non zero number of step instances. + let num_step_instances = witness + .as_ref() + .map(|w| w.step_instances.len()) + .unwrap_or(0); + if num_step_instances != 0 { + // 1 for first row and 0 for all other rows; number of rows equals to number of steps + let is_first_assignments = vec![F::ONE] + .into_iter() + .chain(std::iter::repeat(F::ZERO)) + .take(ast.num_steps) + .collect(); + col_fixed.insert( + PILColumn::Fixed(is_first_uuid, String::from("ISFIRST")), + is_first_assignments, + ); + + // 0 for all rows except the last row, which is 1; number of rows equals to number of steps + let is_last_assignments = std::iter::repeat(F::ZERO) + .take(ast.num_steps - 1) + .chain(std::iter::once(F::ONE)) + .collect(); + col_fixed.insert( + PILColumn::Fixed(is_last_uuid, String::from("ISLAST")), + is_last_assignments, + ); + } + + // Compile step type elements, i.e. constraints, transitions, and lookups. + let (mut constraints, lookups) = compile_steps( + ast, + last_step_instance, + is_last_uuid, + super_circuit_annotations_map, + ); + + // Insert pragma_first_step and pragma_last_step as constraints + if let Some(first_step) = ast.first_step { + // is_first * (1 - first_step) = 0 + constraints.push(PILExpr::Mul(vec![ + PILExpr::Query(( + PILColumn::Fixed(is_first_uuid, String::from("ISFIRST")), + false, + )), + PILExpr::Sum(vec![ + PILExpr::Const(F::ONE), + PILExpr::Neg(Box::new(PILExpr::Query(( + PILColumn::Fixed( + first_step, + clean_annotation(ast.annotations.get(&first_step).unwrap().clone()), + ), + false, + )))), + ]), + ])); + } + + if let Some(last_step) = ast.last_step { + // is_last * (1 - last_step) = 0 + constraints.push(PILExpr::Mul(vec![ + PILExpr::Query(( + PILColumn::Fixed(is_last_uuid, String::from("ISLAST")), + false, + )), + PILExpr::Sum(vec![ + PILExpr::Const(F::ONE), + PILExpr::Neg(Box::new(PILExpr::Query(( + PILColumn::Fixed( + last_step, + clean_annotation(ast.annotations.get(&last_step).unwrap().clone()), + ), + false, + )))), + ]), + ])); + } + + PILCircuit { + circuit_name, + num_steps: ast.num_steps, + col_witness, + col_fixed, + constraints, + lookups, + } +} + +pub fn compile_super_circuits( + super_asts: Vec>, + super_trace_witnesses: HashMap>, + ast_id_to_ir_id_mapping: HashMap, + circuit_names: Vec, +) -> Vec> { + assert!(super_asts.len() == circuit_names.len()); + + // Get annotations map for the super circuit, which is a HashMap of object UUID to object + // annotation. + let mut super_circuit_annotations_map: HashMap = HashMap::new(); + + // Loop over each AST. + for (ast, circuit_name) in super_asts.iter().zip(circuit_names.iter()) { + // Create `annotations_map` for each AST, to be added to `super_circuit_annotations_map`. + let mut annotations_map: HashMap = HashMap::new(); + + // First, get AST level annotations. + annotations_map.extend(ast.annotations.clone()); + + // Second, get step level annotations. + for step_type in ast.step_types.values() { + annotations_map.extend(step_type.annotations.clone()); + } + + // Convert annotation to circuit_name.annotation, because this is the general format of + // referring to variables in PIL if there are more than one circuit. + super_circuit_annotations_map.extend(annotations_map.into_iter().map( + |(uuid, annotation)| { + ( + uuid, + format!("{}.{}", circuit_name.clone(), clean_annotation(annotation)), + ) + }, + )); + + // Finally, get annotations for the circuit names. + super_circuit_annotations_map.insert(ast.id, circuit_name.clone()); + } + + // For each AST, find its corresponding TraceWitness. Note that some AST might not have a + // corresponding TraceWitness, so witness is an Option. + let mut pil_irs = Vec::new(); + for (ast, circuit_name) in super_asts.iter().zip(circuit_names.iter()) { + let witness = super_trace_witnesses.get(ast_id_to_ir_id_mapping.get(&ast.id).unwrap()); + + // Create PIL IR + let pil_ir = compile( + ast, + witness.cloned(), + circuit_name.clone(), + &Some(&super_circuit_annotations_map), + ); + + pil_irs.push(pil_ir); + } + + pil_irs +} + +fn collect_witness_columns(ast: &SBPIR) -> Vec { + let mut col_witness = Vec::new(); + + // Collect internal signals to witness columns. + col_witness.extend( + ast.step_types + .values() + .flat_map(|step_type| { + step_type + .signals + .iter() + .map(|signal| { + PILColumn::Advice(signal.uuid(), clean_annotation(signal.annotation())) + }) + .collect::>() + }) + .collect::>(), + ); + + // Collect forward signals to witness columns. + col_witness.extend( + ast.forward_signals + .iter() + .map(|forward_signal| { + PILColumn::Advice( + forward_signal.uuid(), + clean_annotation(forward_signal.annotation()), + ) + }) + .collect::>(), + ); + + // Collect shared signals to witness columns. + col_witness.extend( + ast.shared_signals + .iter() + .map(|shared_signal| { + PILColumn::Advice( + shared_signal.uuid(), + clean_annotation(shared_signal.annotation()), + ) + }) + .collect::>(), + ); + + col_witness +} + +fn compile_steps( + ast: &SBPIR, + last_step_instance: UUID, + is_last_uuid: UUID, + super_circuit_annotations_map: &Option<&HashMap>, +) -> (Vec>, Vec) { + // transitions and constraints all become constraints in PIL + let mut constraints = Vec::new(); + let mut lookups = Vec::new(); + + if !ast.step_types.is_empty() { + ast.step_types.values().for_each(|step_type| { + // Create constraint statements. + constraints.extend( + step_type + .constraints + .iter() + .map(|constraint| { + PILExpr::Mul(vec![ + PILExpr::Query(( + PILColumn::Fixed( + step_type.uuid(), + clean_annotation(step_type.name()), + ), + false, + )), + chiquito_expr_to_pil_expr( + constraint.expr.clone(), + super_circuit_annotations_map, + ), + ]) + }) + .collect::>>(), + ); + + // There's no distinction between constraint and transition in PIL + // However, we do need to identify constraints with rotation in the last row + // and disable them + constraints.extend( + step_type + .transition_constraints + .iter() + .map(|transition| { + let res = PILExpr::Mul(vec![ + PILExpr::Query(( + PILColumn::Fixed( + step_type.uuid(), + clean_annotation(step_type.name()), + ), + false, + )), + chiquito_expr_to_pil_expr( + transition.expr.clone(), + super_circuit_annotations_map, + ), + ]); + if step_type.uuid() == last_step_instance { + PILExpr::Mul(vec![ + PILExpr::Sum(vec![ + PILExpr::Const(F::ONE), + PILExpr::Neg(Box::new(PILExpr::Query(( + PILColumn::Fixed(is_last_uuid, String::from("ISLAST")), + false, + )))), + ]), + res, + ]) + } else { + res + } + }) + .collect::>>(), + ); + + lookups.extend( + step_type + .lookups + .iter() + .map(|lookup| { + ( + PILColumn::Fixed(step_type.uuid(), clean_annotation(step_type.name())), + lookup + .exprs + .iter() + .map(|(lhs, rhs)| { + ( + chiquito_lookup_column_to_pil_column( + lhs.expr.clone(), + super_circuit_annotations_map, + ), + chiquito_lookup_column_to_pil_column( + rhs.clone(), + super_circuit_annotations_map, + ), + ) + }) + .collect::>(), + ) + }) + .collect::>(), + ); + }); + } + + (constraints, lookups) +} + +// Convert lookup columns (src and dest) in Chiquito to PIL column. Note that Chiquito lookup +// columns have to be Expr::Query type. +fn chiquito_lookup_column_to_pil_column( + src: Expr>, + super_circuit_annotations_map: &Option<&HashMap>, +) -> PILColumn { + match src { + Expr::Query(queriable) => { + chiquito_queriable_to_pil_query(queriable, super_circuit_annotations_map).0 + } + _ => panic!("Lookup source is not queriable."), + } +} + +// PIL expression and constraint +#[derive(Clone)] +pub enum PILExpr { + Const(F), + Sum(Vec>), + Mul(Vec>), + Neg(Box>), + Pow(Box>, u32), + Query(PILQuery), +} + +fn chiquito_expr_to_pil_expr( + expr: Expr>, + super_circuit_annotations_map: &Option<&HashMap>, +) -> PILExpr { + match expr { + Expr::Const(constant) => PILExpr::Const(constant), + Expr::Sum(sum) => { + let mut pil_sum = Vec::new(); + for expr in sum { + pil_sum.push(chiquito_expr_to_pil_expr( + expr, + super_circuit_annotations_map, + )); + } + PILExpr::Sum(pil_sum) + } + Expr::Mul(mul) => { + let mut pil_mul = Vec::new(); + for expr in mul { + pil_mul.push(chiquito_expr_to_pil_expr( + expr, + super_circuit_annotations_map, + )); + } + PILExpr::Mul(pil_mul) + } + Expr::Neg(neg) => PILExpr::Neg(Box::new(chiquito_expr_to_pil_expr( + *neg, + super_circuit_annotations_map, + ))), + Expr::Pow(pow, power) => PILExpr::Pow( + Box::new(chiquito_expr_to_pil_expr( + *pow, + super_circuit_annotations_map, + )), + power, + ), + Expr::Query(queriable) => PILExpr::Query(chiquito_queriable_to_pil_query( + queriable, + super_circuit_annotations_map, + )), + Expr::Halo2Expr(_) => { + panic!("Halo2 native expression not supported by PIL backend.") + } + Expr::MI(_) => { + panic!("MI not supported by PIL backend.") + } + } +} + +pub type PILQuery = (PILColumn, bool); // column, rotation + +#[derive(Clone, PartialEq, Eq, Hash)] +pub enum PILColumn { + Advice(UUID, String), // UUID, annotation + Fixed(UUID, String), +} + +impl PILColumn { + pub fn uuid(&self) -> UUID { + match self { + PILColumn::Advice(uuid, _) => *uuid, + PILColumn::Fixed(uuid, _) => *uuid, + } + } + + pub fn annotation(&self) -> String { + match self { + PILColumn::Advice(_, annotation) => annotation.clone(), + PILColumn::Fixed(_, annotation) => annotation.clone(), + } + } +} + +pub fn clean_annotation(annotation: String) -> String { + annotation.replace(' ', "_") +} + +// Convert queriable to PIL column recursively. Major differences are: 1. PIL doesn't distinguish +// internal, forward, or shared columns as they are all advice; 2. PIL only supports the next +// rotation, so there's no previous or arbitrary rotation. +fn chiquito_queriable_to_pil_query( + query: Queriable, + super_circuit_annotations_map: &Option<&HashMap>, +) -> PILQuery { + match query { + Queriable::Internal(s) => { + if super_circuit_annotations_map.is_none() { + ( + PILColumn::Advice(s.uuid(), clean_annotation(s.annotation())), + false, + ) + } else { + let annotation = super_circuit_annotations_map + .as_ref() + .unwrap() + .get(&s.uuid()) + .unwrap(); + ( + PILColumn::Advice(s.uuid(), clean_annotation(annotation.clone())), + false, + ) + } + } + Queriable::Forward(s, rot) => { + if super_circuit_annotations_map.is_none() { + ( + PILColumn::Advice(s.uuid(), clean_annotation(s.annotation())), + rot, + ) + } else { + let annotation = super_circuit_annotations_map + .as_ref() + .unwrap() + .get(&s.uuid()) + .unwrap(); + ( + PILColumn::Advice(s.uuid(), clean_annotation(annotation.clone())), + rot, + ) + } + } + Queriable::Shared(s, rot) => { + let annotation = if super_circuit_annotations_map.is_none() { + clean_annotation(s.annotation()) + } else { + super_circuit_annotations_map + .as_ref() + .unwrap() + .get(&s.uuid()) + .unwrap() + .clone() + }; + if rot == 0 { + (PILColumn::Advice(s.uuid(), annotation), false) + } else if rot == 1 { + (PILColumn::Advice(s.uuid(), annotation), true) + } else { + panic!( + "PIL backend does not support shared signal with rotation other than 0 or 1." + ) + } + } + Queriable::Fixed(s, rot) => { + let annotation = if super_circuit_annotations_map.is_none() { + clean_annotation(s.annotation()) + } else { + super_circuit_annotations_map + .as_ref() + .unwrap() + .get(&s.uuid()) + .unwrap() + .clone() + }; + if rot == 0 { + (PILColumn::Fixed(s.uuid(), annotation), false) + } else if rot == 1 { + (PILColumn::Fixed(s.uuid(), annotation), true) + } else { + panic!("PIL backend does not support fixed signal with rotation other than 0 or 1.") + } + } + Queriable::StepTypeNext(s) => ( + PILColumn::Fixed(s.uuid(), clean_annotation(s.annotation())), + true, + ), + Queriable::Halo2AdviceQuery(_, _) => { + panic!("Halo2 native advice query not supported by PIL backend.") + } + Queriable::Halo2FixedQuery(_, _) => { + panic!("Halo2 native fixed query not supported by PIL backend.") + } + Queriable::_unaccessible(_) => todo!(), + } +} diff --git a/src/pil/ir/mod.rs b/src/pil/ir/mod.rs new file mode 100644 index 00000000..d57b058b --- /dev/null +++ b/src/pil/ir/mod.rs @@ -0,0 +1 @@ +pub mod powdr_pil; diff --git a/src/pil/ir/powdr_pil.rs b/src/pil/ir/powdr_pil.rs new file mode 100644 index 00000000..576d147b --- /dev/null +++ b/src/pil/ir/powdr_pil.rs @@ -0,0 +1,18 @@ +use crate::pil::compiler::{PILColumn, PILExpr, PILQuery}; +use std::collections::HashMap; +extern crate regex; + +// PIL circuit IR +pub struct PILCircuit { + pub circuit_name: String, + pub num_steps: usize, + pub col_witness: Vec, + pub col_fixed: HashMap>, // column -> assignments + pub constraints: Vec>, + pub lookups: Vec, +} + +// lookup in PIL is the format of selector {src1, src2, ..., srcn} -> {dst1, dst2, ..., dstn} +// PILLookup is a tuple of (selector, Vec) tuples, where selector is converted from +// Chiquito step type to fixed column +pub type PILLookup = (PILColumn, Vec<(PILColumn, PILColumn)>); diff --git a/src/pil/mod.rs b/src/pil/mod.rs new file mode 100644 index 00000000..aab5b126 --- /dev/null +++ b/src/pil/mod.rs @@ -0,0 +1,3 @@ +pub mod backend; +pub mod compiler; +pub mod ir; diff --git a/src/plonkish/ir/assignments.rs b/src/plonkish/ir/assignments.rs index 9f605d7a..60a91a48 100644 --- a/src/plonkish/ir/assignments.rs +++ b/src/plonkish/ir/assignments.rs @@ -132,8 +132,12 @@ impl AssignmentGenerator { } } + pub fn generate_trace_witness(&self, args: TraceArgs) -> TraceWitness { + self.trace_gen.generate(args) + } + pub fn generate(&self, args: TraceArgs) -> Assignments { - let witness = self.trace_gen.generate(args); + let witness = self.generate_trace_witness(args); self.generate_with_witness(witness) } diff --git a/src/plonkish/ir/sc.rs b/src/plonkish/ir/sc.rs index 4398bc31..05cf4663 100644 --- a/src/plonkish/ir/sc.rs +++ b/src/plonkish/ir/sc.rs @@ -1,6 +1,6 @@ use std::{collections::HashMap, hash::Hash, rc::Rc}; -use crate::{field::Field, util::UUID, wit_gen::TraceWitness}; +use crate::{field::Field, sbpir::SBPIR, util::UUID, wit_gen::TraceWitness}; use super::{ assignments::{AssignmentGenerator, Assignments}, @@ -10,6 +10,7 @@ use super::{ pub struct SuperCircuit { sub_circuits: Vec>, mapping: MappingGenerator, + sub_circuit_asts: Vec>, } impl Default for SuperCircuit { @@ -17,6 +18,7 @@ impl Default for SuperCircuit { Self { sub_circuits: Default::default(), mapping: Default::default(), + sub_circuit_asts: Default::default(), } } } @@ -29,6 +31,30 @@ impl SuperCircuit { pub fn get_mapping(&self) -> MappingGenerator { self.mapping.clone() } + + // Needed for the PIL backend. + pub fn add_sub_circuit_ast(&mut self, sub_circuit_ast: SBPIR) { + self.sub_circuit_asts.push(sub_circuit_ast); + } + + // Mapping from AST id to IR id is needed for the PIL backend to match TraceWitness, which has + // IR id, to AST. + pub fn get_ast_id_to_ir_id_mapping(&self) -> HashMap { + let mut ast_id_to_ir_id_mapping: HashMap = HashMap::new(); + self.sub_circuits.iter().for_each(|circuit| { + let ir_id = circuit.id; + let ast_id = circuit.ast_id; + ast_id_to_ir_id_mapping.insert(ast_id, ir_id); + }); + ast_id_to_ir_id_mapping + } +} + +// Needed for the PIL backend. +impl SuperCircuit { + pub fn get_super_asts(&self) -> Vec> { + self.sub_circuit_asts.clone() + } } impl SuperCircuit { @@ -47,22 +73,29 @@ impl SuperCircuit { } pub type SuperAssignments = HashMap>; +pub type SuperTraceWitness = HashMap>; pub struct MappingContext { assignments: SuperAssignments, + trace_witnesses: SuperTraceWitness, } -impl Default for MappingContext { +impl Default for MappingContext { fn default() -> Self { Self { assignments: Default::default(), + trace_witnesses: Default::default(), } } } impl MappingContext { pub fn map(&mut self, gen: &AssignmentGenerator, args: TraceArgs) { - self.assignments.insert(gen.uuid(), gen.generate(args)); + let trace_witness = gen.generate_trace_witness(args); + self.trace_witnesses + .insert(gen.uuid(), trace_witness.clone()); + self.assignments + .insert(gen.uuid(), gen.generate_with_witness(trace_witness)); } pub fn map_with_witness( @@ -77,6 +110,10 @@ impl MappingContext { pub fn get_super_assignments(self) -> SuperAssignments { self.assignments } + + pub fn get_trace_witnesses(self) -> SuperTraceWitness { + self.trace_witnesses + } } pub type Mapping = dyn Fn(&mut MappingContext, MappingArgs) + 'static; @@ -113,4 +150,13 @@ impl MappingGenerator { ctx.get_super_assignments() } + + // Needed for the PIL backend. + pub fn generate_super_trace_witnesses(&self, args: MappingArgs) -> SuperTraceWitness { + let mut ctx = MappingContext::default(); + + (self.mapping)(&mut ctx, args); + + ctx.get_trace_witnesses() + } } diff --git a/src/sbpir/mod.rs b/src/sbpir/mod.rs index b5a77029..217f2ba8 100644 --- a/src/sbpir/mod.rs +++ b/src/sbpir/mod.rs @@ -204,6 +204,28 @@ impl SBPIR { } } +impl SBPIR { + pub fn clone_without_trace(&self) -> SBPIR { + SBPIR { + step_types: self.step_types.clone(), + forward_signals: self.forward_signals.clone(), + shared_signals: self.shared_signals.clone(), + fixed_signals: self.fixed_signals.clone(), + halo2_advice: self.halo2_advice.clone(), + halo2_fixed: self.halo2_fixed.clone(), + exposed: self.exposed.clone(), + annotations: self.annotations.clone(), + trace: None, // Remove the trace. + fixed_assignments: self.fixed_assignments.clone(), + first_step: self.first_step, + last_step: self.last_step, + num_steps: self.num_steps, + q_enable: self.q_enable, + id: self.id, + } + } +} + pub type FixedGen = dyn Fn(&mut FixedGenContext) + 'static; pub type StepTypeUUID = UUID; @@ -253,6 +275,10 @@ impl StepType { self.id } + pub fn name(&self) -> String { + self.name.clone() + } + pub fn add_signal>(&mut self, name: N) -> InternalSignal { let name = name.into(); let signal = InternalSignal::new(name.clone()); @@ -417,6 +443,10 @@ impl ForwardSignal { pub fn phase(&self) -> usize { self.phase } + + pub fn annotation(&self) -> String { + self.annotation.to_string() + } } #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] @@ -450,6 +480,10 @@ impl SharedSignal { pub fn phase(&self) -> usize { self.phase } + + pub fn annotation(&self) -> String { + self.annotation.to_string() + } } #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] @@ -476,6 +510,10 @@ impl FixedSignal { pub fn uuid(&self) -> UUID { self.id } + + pub fn annotation(&self) -> String { + self.annotation.to_string() + } } #[derive(Clone, Copy, Debug)] @@ -508,6 +546,10 @@ impl InternalSignal { pub fn uuid(&self) -> UUID { self.id } + + pub fn annotation(&self) -> String { + self.annotation.to_string() + } } #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]