diff --git a/Cargo.toml b/Cargo.toml index c002d68b..58d1fd0e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,6 +20,7 @@ num-bigint = { version = "0.4", features = ["rand"] } uuid = { version = "1.4.0", features = ["v1", "rng"] } serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" +regex = "1" [dev-dependencies] rand_chacha = "0.3" diff --git a/examples/fibonacci.py b/examples/fibonacci.py index 4ffa7b31..0c7aedfa 100644 --- a/examples/fibonacci.py +++ b/examples/fibonacci.py @@ -84,3 +84,5 @@ def trace(self, n): ) # 2^k specifies the number of PLONKish table rows in Halo2 another_fibo_witness = fibo.gen_witness(4) fibo.halo2_mock_prover(another_fibo_witness, k=7) + +fibo.to_pil(fibo_witness, "FiboCircuit") diff --git a/examples/fibonacci.rs b/examples/fibonacci.rs index 4021cdda..dbf7dbba 100644 --- a/examples/fibonacci.rs +++ b/examples/fibonacci.rs @@ -18,6 +18,7 @@ use chiquito::{ }, plonkish::ir::{assignments::AssignmentGenerator, Circuit}, // compiled circuit type poly::ToField, + sbpir::SBPIR, }; use halo2_proofs::{dev::MockProver, halo2curves::bn256::Fr}; @@ -26,7 +27,10 @@ use halo2_proofs::{dev::MockProver, halo2curves::bn256::Fr}; // 1. type that implements a field trait // 2. empty trace arguments, i.e. (), because there are no external inputs to the Chiquito circuit // 3. two witness generation arguments both of u64 type, i.e. (u64, u64) -fn fibo_circuit + Hash>() -> (Circuit, Option>) { + +type FiboReturn = (Circuit, Option>, SBPIR); + +fn fibo_circuit + Hash>() -> FiboReturn { // PLONKish table for the Fibonacci circuit: // | a | b | c | // | 1 | 1 | 2 | @@ -73,7 +77,7 @@ fn fibo_circuit + Hash>() -> (Circuit, Option + Hash>() -> (Circuit, Option + Hash>() -> (Circuit, Option(); + let (chiquito, wit_gen, _) = fibo_circuit::(); let compiled = chiquito2Halo2(chiquito); let circuit = ChiquitoHalo2Circuit::new(compiled, wit_gen.map(|g| g.generate(()))); @@ -137,7 +143,7 @@ fn main() { use polyexen::plaf::{backends::halo2::PlafH2Circuit, WitnessDisplayCSV}; // get Chiquito ir - let (circuit, wit_gen) = fibo_circuit::(); + let (circuit, wit_gen, _) = fibo_circuit::(); // get Plaf let (plaf, plaf_wit_gen) = chiquito2Plaf(circuit, 8, false); let wit = plaf_wit_gen.generate(wit_gen.map(|v| v.generate(()))); @@ -162,4 +168,15 @@ fn main() { println!("{}", failure); } } + + // pil boilerplate + use chiquito::pil::backend::powdr_pil::chiquito2Pil; + + let (_, wit_gen, circuit) = fibo_circuit::(); + let pil = chiquito2Pil( + circuit, + Some(wit_gen.unwrap().generate_trace_witness(())), + String::from("FiboCircuit"), + ); + print!("{}", pil); } diff --git a/examples/mimc7.rs b/examples/mimc7.rs index fa1b1d2e..1eafeb0e 100644 --- a/examples/mimc7.rs +++ b/examples/mimc7.rs @@ -175,7 +175,7 @@ fn mimc7_circuit( row_value += F::from(1); x_value += k_value + c_value; x_value = x_value.pow_vartime([7_u64]); - // Step 90: output the hash result as x + k in witness generation + // Step 91: output the hash result as x + k in witness generation // output is not displayed as a public column, which will be implemented in the future ctx.add(&mimc7_last_step, (x_value, k_value, c_value, row_value)); // c_value is not // used here but @@ -219,6 +219,29 @@ fn main() { println!("{}", failure); } } + + // pil boilerplate + use chiquito::pil::backend::powdr_pil::chiquitoSuperCircuit2Pil; + + let x_in_value = Fr::from_str_vartime("1").expect("expected a number"); + let k_value = Fr::from_str_vartime("2").expect("expected a number"); + + let super_circuit = mimc7_super_circuit::(); + + // `super_trace_witnesses` is a mapping from IR id to TraceWitness. However, not all ASTs have a + // corresponding TraceWitness. + let super_trace_witnesses = super_circuit + .get_mapping() + .generate_super_trace_witnesses((x_in_value, k_value)); + + let pil = chiquitoSuperCircuit2Pil::( + super_circuit.get_super_asts(), + super_trace_witnesses, + super_circuit.get_ast_id_to_ir_id_mapping(), + vec![String::from("Mimc7Constant"), String::from("Mimc7Circuit")], + ); + + print!("{}", pil); } mod mimc7_constants { diff --git a/src/frontend/dsl/mod.rs b/src/frontend/dsl/mod.rs index c90db77e..15def6ff 100644 --- a/src/frontend/dsl/mod.rs +++ b/src/frontend/dsl/mod.rs @@ -358,6 +358,10 @@ impl StepTypeHandler { pub fn next(&self) -> Queriable { Queriable::StepTypeNext(*self) } + + pub fn annotation(&self) -> String { + self.annotation.to_string() + } } impl, Args) + 'static> From<&StepTypeWGHandler> diff --git a/src/frontend/dsl/sc.rs b/src/frontend/dsl/sc.rs index 9b6daec9..f9ef60d5 100644 --- a/src/frontend/dsl/sc.rs +++ b/src/frontend/dsl/sc.rs @@ -34,6 +34,12 @@ impl Default for SuperCircuitContext { } } +impl SuperCircuitContext { + fn add_sub_circuit_ast(&mut self, ast: SBPIR) { + self.super_circuit.add_sub_circuit_ast(ast); + } +} + impl SuperCircuitContext { pub fn sub_circuit( &mut self, @@ -48,12 +54,13 @@ impl SuperCircuitContext { circuit: SBPIR::default(), tables: self.tables.clone(), }; - println!("super circuit table registry 2: {:?}", self.tables); let exports = sub_circuit_def(&mut sub_circuit_context, imports); - println!("super circuit table registry 3: {:?}", self.tables); let sub_circuit = sub_circuit_context.circuit; + // ast is used for PIL backend + self.add_sub_circuit_ast(sub_circuit.clone_without_trace()); + let (unit, assignment) = compile_phase1(config, &sub_circuit); let assignment = assignment.unwrap_or_else(|| AssignmentGenerator::empty(unit.uuid)); diff --git a/src/frontend/python/chiquito/dsl.py b/src/frontend/python/chiquito/dsl.py index aa355972..ee41004f 100644 --- a/src/frontend/python/chiquito/dsl.py +++ b/src/frontend/python/chiquito/dsl.py @@ -225,6 +225,15 @@ def halo2_mock_prover(self: Circuit, witness: TraceWitness, k: int = 16): witness_json: str = witness.get_witness_json() rust_chiquito.halo2_mock_prover(witness_json, self.rust_id, k) + def to_pil( + self: Circuit, witness: TraceWitness, circuit_name: str = "Circuit" + ) -> str: + if self.rust_id == 0: + ast_json: str = self.get_ast_json() + self.rust_id: int = rust_chiquito.ast_to_halo2(ast_json) + witness_json: str = witness.get_witness_json() + rust_chiquito.to_pil(witness_json, self.rust_id, circuit_name) + def __str__(self: Circuit) -> str: return self.ast.__str__() diff --git a/src/frontend/python/mod.rs b/src/frontend/python/mod.rs index e6da88aa..69b3abc3 100644 --- a/src/frontend/python/mod.rs +++ b/src/frontend/python/mod.rs @@ -5,6 +5,7 @@ use pyo3::{ use crate::{ frontend::dsl::{StepTypeHandler, SuperCircuitContext}, + pil::backend::powdr_pil::chiquito2Pil, plonkish::{ backend::halo2::{ chiquito2Halo2, chiquitoSuperCircuit2Halo2, ChiquitoHalo2, ChiquitoHalo2Circuit, @@ -81,6 +82,14 @@ pub fn chiquito_ast_map_store(ast_json: &str) -> UUID { uuid } +pub fn chiquito_ast_to_pil(witness_json: &str, rust_id: UUID, circuit_name: &str) -> String { + let trace_witness: TraceWitness = + serde_json::from_str(witness_json).expect("Json deserialization to TraceWitness failed."); + let (ast, _, _) = rust_id_to_halo2(rust_id); + + chiquito2Pil(ast, Some(trace_witness), circuit_name.to_string()) +} + fn add_assignment_generator_to_rust_id( assignment_generator: AssignmentGenerator, rust_id: UUID, @@ -1845,6 +1854,18 @@ fn ast_to_halo2(json: &PyString) -> u128 { uuid } +#[pyfunction] +fn to_pil(witness_json: &PyString, rust_id: &PyLong, circuit_name: &PyString) -> String { + let pil = chiquito_ast_to_pil( + witness_json.to_str().expect("PyString convertion failed."), + rust_id.extract().expect("PyLong convertion failed."), + circuit_name.to_str().expect("PyString convertion failed."), + ); + + println!("{}", pil); + pil +} + #[pyfunction] fn ast_map_store(json: &PyString) -> u128 { let uuid = chiquito_ast_map_store(json.to_str().expect("PyString conversion failed.")); @@ -1903,6 +1924,7 @@ fn rust_chiquito(_py: Python, m: &PyModule) -> PyResult<()> { m.add_function(wrap_pyfunction!(convert_and_print_ast, m)?)?; m.add_function(wrap_pyfunction!(convert_and_print_trace_witness, m)?)?; m.add_function(wrap_pyfunction!(ast_to_halo2, m)?)?; + m.add_function(wrap_pyfunction!(to_pil, m)?)?; m.add_function(wrap_pyfunction!(ast_map_store, m)?)?; m.add_function(wrap_pyfunction!(halo2_mock_prover, m)?)?; m.add_function(wrap_pyfunction!(super_circuit_halo2_mock_prover, m)?)?; diff --git a/src/lib.rs b/src/lib.rs index dd823581..853e71ca 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,5 +1,6 @@ pub mod field; pub mod frontend; +pub mod pil; pub mod plonkish; pub mod poly; pub mod sbpir; diff --git a/src/pil/backend/mod.rs b/src/pil/backend/mod.rs new file mode 100644 index 00000000..d57b058b --- /dev/null +++ b/src/pil/backend/mod.rs @@ -0,0 +1 @@ +pub mod powdr_pil; diff --git a/src/pil/backend/powdr_pil.rs b/src/pil/backend/powdr_pil.rs new file mode 100644 index 00000000..4b86c591 --- /dev/null +++ b/src/pil/backend/powdr_pil.rs @@ -0,0 +1,226 @@ +use crate::{ + field::Field, + pil::{ + compiler::{compile, compile_super_circuits, PILColumn, PILExpr, PILQuery}, + ir::powdr_pil::PILCircuit, + }, + sbpir::SBPIR, + util::UUID, + wit_gen::TraceWitness, +}; +use std::{ + collections::HashMap, + fmt::{Debug, Write}, +}; +extern crate regex; + +#[allow(non_snake_case)] +/// User generate PIL code using this function. User needs to supply AST, TraceWitness, and a name +/// string for the circuit. +pub fn chiquito2Pil( + ast: SBPIR, + witness: Option>, + circuit_name: String, +) -> String { + // generate PIL IR. + let pil_ir = compile::(&ast, witness, circuit_name, &None); + + // generate Powdr PIL code. + pil_ir_to_powdr_pil::(pil_ir) +} + +// Convert PIL IR to Powdr PIL code. +pub fn pil_ir_to_powdr_pil(pil_ir: PILCircuit) -> String { + let mut pil = String::new(); // The string to return. + + writeln!( + pil, + "// ===== START OF CIRCUIT: {} =====", + pil_ir.circuit_name + ) + .unwrap(); + + // Namespace is equivalent to a circuit in PIL. + writeln!( + pil, + "constant %NUM_STEPS_{} = {};", + pil_ir.circuit_name.to_uppercase(), + pil_ir.num_steps + ) + .unwrap(); + writeln!( + pil, + "namespace {}(%NUM_STEPS_{});", + pil_ir.circuit_name, + pil_ir.circuit_name.to_uppercase() + ) + .unwrap(); + + // Declare witness columns in PIL. + generate_pil_witness_columns(&mut pil, &pil_ir); + + // Declare fixed columns and their assignments in PIL. + generate_pil_fixed_columns(&mut pil, &pil_ir); + + // generate constraints + for expr in pil_ir.constraints { + // recursively convert expressions to PIL strings + let expr_string = convert_to_pil_expr_string(expr.clone()); + // each constraint is in the format of `constraint = 0` + writeln!(pil, "{} = 0;", expr_string).unwrap(); + } + + // generate lookups + for lookup in pil_ir.lookups { + let (selector, src_dest_tuples) = lookup; + let lookup_selector = selector.annotation(); + let mut lookup_source: Vec = Vec::new(); + let mut lookup_destination: Vec = Vec::new(); + for (src, dest) in src_dest_tuples { + lookup_source.push(src.annotation()); + lookup_destination.push(dest.annotation()); + } + // PIL lookups have the format of `selector { src1, src2, ... srcn } in {dest1, dest2, ..., + // destn}`. + writeln!( + pil, + "{} {{{}}} in {{{}}} ", + lookup_selector, + lookup_source.join(", "), + lookup_destination.join(", ") + ) + .unwrap(); + } + + writeln!( + pil, + "// ===== END OF CIRCUIT: {} =====", + pil_ir.circuit_name + ) + .unwrap(); + writeln!(pil).unwrap(); // Separator row for the circuit. + + pil +} + +#[allow(non_snake_case)] +/// User generate PIL code for super circuit using this function. +/// User needs to supply a Vec for `circuit_names`, the order of which should be the same as +/// the order of calling `sub_circuit()` function. +pub fn chiquitoSuperCircuit2Pil( + super_asts: Vec>, + super_trace_witnesses: HashMap>, + ast_id_to_ir_id_mapping: HashMap, + circuit_names: Vec, +) -> String { + let mut pil = String::new(); // The string to return. + + // Generate PIL IRs for each sub circuit in the super circuit. + let pil_irs = compile_super_circuits( + super_asts, + super_trace_witnesses, + ast_id_to_ir_id_mapping, + circuit_names, + ); + + // Generate Powdr PIL code for each sub circuit. + for pil_ir in pil_irs { + let pil_circuit = pil_ir_to_powdr_pil(pil_ir); + writeln!(pil, "{}", pil_circuit).unwrap(); + } + + pil +} + +fn generate_pil_witness_columns(pil: &mut String, pil_ir: &PILCircuit) { + if !pil_ir.col_witness.is_empty() { + writeln!(pil, "// === Witness Columns ===").unwrap(); + let mut col_witness = String::from("col witness "); + + let mut col_witness_vars = pil_ir + .col_witness + .iter() + .map(|col| match col { + PILColumn::Advice(_, annotation) => annotation.clone(), + _ => panic!("Witness column should be an advice column."), + }) + .collect::>(); + + // Get unique witness column annotations + col_witness_vars.sort(); + col_witness_vars.dedup(); + col_witness = col_witness + col_witness_vars.join(", ").as_str() + ";"; + writeln!(pil, "{}", col_witness).unwrap(); + } +} + +fn generate_pil_fixed_columns(pil: &mut String, pil_ir: &PILCircuit) { + if !pil_ir.col_fixed.is_empty() { + writeln!( + pil, + "// === Fixed Columns for Signals and Step Type Selectors ===" + ) + .unwrap(); + for (col, assignments) in pil_ir.col_fixed.iter() { + let fixed_name = match col { + PILColumn::Fixed(_, annotation) => annotation.clone(), + _ => panic!("Fixed column should be an advice or fixed column."), + }; + let mut assignments_string = String::new(); + let assignments_vec = assignments + .iter() + .map(|assignment| format!("{:?}", assignment)) + .collect::>(); + write!( + assignments_string, + "{}", + assignments_vec.join(", ").as_str() + ) + .unwrap(); + writeln!(pil, "col fixed {} = [{}];", fixed_name, assignments_string).unwrap(); + } + } +} + +// Convert PIL expression to Powdr PIL string recursively. +fn convert_to_pil_expr_string(expr: PILExpr) -> String { + match expr { + PILExpr::Const(constant) => format!("{:?}", constant), + PILExpr::Sum(sum) => { + let mut expr_string = String::new(); + for (index, expr) in sum.iter().enumerate() { + expr_string += convert_to_pil_expr_string(expr.clone()).as_str(); + if index != sum.len() - 1 { + expr_string += " + "; + } + } + format!("({})", expr_string) + } + PILExpr::Mul(mul) => { + let mut expr_string = String::new(); + for (index, expr) in mul.iter().enumerate() { + expr_string += convert_to_pil_expr_string(expr.clone()).as_str(); + if index != mul.len() - 1 { + expr_string += " * "; + } + } + format!("{}", expr_string) + } + PILExpr::Neg(neg) => format!("(-{})", convert_to_pil_expr_string(*neg)), + PILExpr::Pow(pow, power) => { + format!("({})^{}", convert_to_pil_expr_string(*pow), power) + } + PILExpr::Query(queriable) => convert_to_pil_queriable_string(queriable), + } +} + +// Convert PIL query to Powdr PIL string recursively. +fn convert_to_pil_queriable_string(query: PILQuery) -> String { + let (col, rot) = query; + let annotation = col.annotation(); + if rot { + format!("{}'", annotation) + } else { + annotation + } +} diff --git a/src/pil/compiler/mod.rs b/src/pil/compiler/mod.rs new file mode 100644 index 00000000..0928f3ae --- /dev/null +++ b/src/pil/compiler/mod.rs @@ -0,0 +1,578 @@ +use crate::{ + field::Field, + pil::ir::powdr_pil::{PILCircuit, PILLookup}, + poly::Expr, + sbpir::{query::Queriable, SBPIR}, + util::{uuid, UUID}, + wit_gen::TraceWitness, +}; +use std::{collections::HashMap, fmt::Debug, hash::Hash}; +extern crate regex; + +pub fn compile( + ast: &SBPIR, + witness: Option>, + circuit_name: String, + super_circuit_annotations_map: &Option<&HashMap>, +) -> PILCircuit { + let col_witness = collect_witness_columns(ast); + + // HashMap of fixed column to fixed assignments + let mut col_fixed = HashMap::new(); + + if let Some(fixed_assignments) = &ast.fixed_assignments { + fixed_assignments + .iter() + .for_each(|(queriable, assignments)| { + let uuid = queriable.uuid(); + col_fixed.insert( + PILColumn::Fixed( + uuid, + clean_annotation(ast.annotations.get(&uuid).unwrap().clone()), + ), + assignments.clone(), + ); + }); + } + + // Get last step instance UUID, so that we can disable transition of that instance + let mut last_step_instance = 0; + + // Insert into col_fixed the map from step type fixed column to vector of {0,1} where 1 means + // the step type is instantiated whereas 0 not. Each vector should have the same length as the + // number of steps. + if !ast.step_types.is_empty() && witness.is_some() { + let step_instances = witness.as_ref().unwrap().step_instances.iter(); + + // Get last step instance, so that we can disable transition of that instance + last_step_instance = step_instances.clone().last().unwrap().step_type_uuid; + + for step_type in ast.step_types.values() { + let step_type_instantiation: Vec = step_instances + .clone() + .map(|step_instance| { + if step_instance.step_type_uuid == step_type.uuid() { + F::ONE + } else { + F::ZERO + } + }) + .collect(); + assert_eq!(step_type_instantiation.len(), ast.num_steps); + let uuid = step_type.uuid(); + col_fixed.insert( + PILColumn::Fixed( + uuid, + clean_annotation(ast.annotations.get(&uuid).unwrap().clone()), + ), + step_type_instantiation, + ); + } + } + + // Create new UUID for ISFIRST and ISLAST. These are fixed columns unique to PIL. + let is_first_uuid = uuid(); + let is_last_uuid = uuid(); + + // ISFIRST and ISLAST are only relevant when there's non zero number of step instances. + let num_step_instances = witness + .as_ref() + .map(|w| w.step_instances.len()) + .unwrap_or(0); + if num_step_instances != 0 { + // 1 for first row and 0 for all other rows; number of rows equals to number of steps + let is_first_assignments = vec![F::ONE] + .into_iter() + .chain(std::iter::repeat(F::ZERO)) + .take(ast.num_steps) + .collect(); + col_fixed.insert( + PILColumn::Fixed(is_first_uuid, String::from("ISFIRST")), + is_first_assignments, + ); + + // 0 for all rows except the last row, which is 1; number of rows equals to number of steps + let is_last_assignments = std::iter::repeat(F::ZERO) + .take(ast.num_steps - 1) + .chain(std::iter::once(F::ONE)) + .collect(); + col_fixed.insert( + PILColumn::Fixed(is_last_uuid, String::from("ISLAST")), + is_last_assignments, + ); + } + + // Compile step type elements, i.e. constraints, transitions, and lookups. + let (mut constraints, lookups) = compile_steps( + ast, + last_step_instance, + is_last_uuid, + super_circuit_annotations_map, + ); + + // Insert pragma_first_step and pragma_last_step as constraints + if let Some(first_step) = ast.first_step { + // is_first * (1 - first_step) = 0 + constraints.push(PILExpr::Mul(vec![ + PILExpr::Query(( + PILColumn::Fixed(is_first_uuid, String::from("ISFIRST")), + false, + )), + PILExpr::Sum(vec![ + PILExpr::Const(F::ONE), + PILExpr::Neg(Box::new(PILExpr::Query(( + PILColumn::Fixed( + first_step, + clean_annotation(ast.annotations.get(&first_step).unwrap().clone()), + ), + false, + )))), + ]), + ])); + } + + if let Some(last_step) = ast.last_step { + // is_last * (1 - last_step) = 0 + constraints.push(PILExpr::Mul(vec![ + PILExpr::Query(( + PILColumn::Fixed(is_last_uuid, String::from("ISLAST")), + false, + )), + PILExpr::Sum(vec![ + PILExpr::Const(F::ONE), + PILExpr::Neg(Box::new(PILExpr::Query(( + PILColumn::Fixed( + last_step, + clean_annotation(ast.annotations.get(&last_step).unwrap().clone()), + ), + false, + )))), + ]), + ])); + } + + PILCircuit { + circuit_name, + num_steps: ast.num_steps, + col_witness, + col_fixed, + constraints, + lookups, + } +} + +pub fn compile_super_circuits( + super_asts: Vec>, + super_trace_witnesses: HashMap>, + ast_id_to_ir_id_mapping: HashMap, + circuit_names: Vec, +) -> Vec> { + assert!(super_asts.len() == circuit_names.len()); + + // Get annotations map for the super circuit, which is a HashMap of object UUID to object + // annotation. + let mut super_circuit_annotations_map: HashMap = HashMap::new(); + + // Loop over each AST. + for (ast, circuit_name) in super_asts.iter().zip(circuit_names.iter()) { + // Create `annotations_map` for each AST, to be added to `super_circuit_annotations_map`. + let mut annotations_map: HashMap = HashMap::new(); + + // First, get AST level annotations. + annotations_map.extend(ast.annotations.clone()); + + // Second, get step level annotations. + for step_type in ast.step_types.values() { + annotations_map.extend(step_type.annotations.clone()); + } + + // Convert annotation to circuit_name.annotation, because this is the general format of + // referring to variables in PIL if there are more than one circuit. + super_circuit_annotations_map.extend(annotations_map.into_iter().map( + |(uuid, annotation)| { + ( + uuid, + format!("{}.{}", circuit_name.clone(), clean_annotation(annotation)), + ) + }, + )); + + // Finally, get annotations for the circuit names. + super_circuit_annotations_map.insert(ast.id, circuit_name.clone()); + } + + // For each AST, find its corresponding TraceWitness. Note that some AST might not have a + // corresponding TraceWitness, so witness is an Option. + let mut pil_irs = Vec::new(); + for (ast, circuit_name) in super_asts.iter().zip(circuit_names.iter()) { + let witness = super_trace_witnesses.get(ast_id_to_ir_id_mapping.get(&ast.id).unwrap()); + + // Create PIL IR + let pil_ir = compile( + ast, + witness.cloned(), + circuit_name.clone(), + &Some(&super_circuit_annotations_map), + ); + + pil_irs.push(pil_ir); + } + + pil_irs +} + +fn collect_witness_columns(ast: &SBPIR) -> Vec { + let mut col_witness = Vec::new(); + + // Collect internal signals to witness columns. + col_witness.extend( + ast.step_types + .values() + .flat_map(|step_type| { + step_type + .signals + .iter() + .map(|signal| { + PILColumn::Advice(signal.uuid(), clean_annotation(signal.annotation())) + }) + .collect::>() + }) + .collect::>(), + ); + + // Collect forward signals to witness columns. + col_witness.extend( + ast.forward_signals + .iter() + .map(|forward_signal| { + PILColumn::Advice( + forward_signal.uuid(), + clean_annotation(forward_signal.annotation()), + ) + }) + .collect::>(), + ); + + // Collect shared signals to witness columns. + col_witness.extend( + ast.shared_signals + .iter() + .map(|shared_signal| { + PILColumn::Advice( + shared_signal.uuid(), + clean_annotation(shared_signal.annotation()), + ) + }) + .collect::>(), + ); + + col_witness +} + +fn compile_steps( + ast: &SBPIR, + last_step_instance: UUID, + is_last_uuid: UUID, + super_circuit_annotations_map: &Option<&HashMap>, +) -> (Vec>, Vec) { + // transitions and constraints all become constraints in PIL + let mut constraints = Vec::new(); + let mut lookups = Vec::new(); + + if !ast.step_types.is_empty() { + ast.step_types.values().for_each(|step_type| { + // Create constraint statements. + constraints.extend( + step_type + .constraints + .iter() + .map(|constraint| { + PILExpr::Mul(vec![ + PILExpr::Query(( + PILColumn::Fixed( + step_type.uuid(), + clean_annotation(step_type.name()), + ), + false, + )), + chiquito_expr_to_pil_expr( + constraint.expr.clone(), + super_circuit_annotations_map, + ), + ]) + }) + .collect::>>(), + ); + + // There's no distinction between constraint and transition in PIL + // However, we do need to identify constraints with rotation in the last row + // and disable them + constraints.extend( + step_type + .transition_constraints + .iter() + .map(|transition| { + let res = PILExpr::Mul(vec![ + PILExpr::Query(( + PILColumn::Fixed( + step_type.uuid(), + clean_annotation(step_type.name()), + ), + false, + )), + chiquito_expr_to_pil_expr( + transition.expr.clone(), + super_circuit_annotations_map, + ), + ]); + if step_type.uuid() == last_step_instance { + PILExpr::Mul(vec![ + PILExpr::Sum(vec![ + PILExpr::Const(F::ONE), + PILExpr::Neg(Box::new(PILExpr::Query(( + PILColumn::Fixed(is_last_uuid, String::from("ISLAST")), + false, + )))), + ]), + res, + ]) + } else { + res + } + }) + .collect::>>(), + ); + + lookups.extend( + step_type + .lookups + .iter() + .map(|lookup| { + ( + PILColumn::Fixed(step_type.uuid(), clean_annotation(step_type.name())), + lookup + .exprs + .iter() + .map(|(lhs, rhs)| { + ( + chiquito_lookup_column_to_pil_column( + lhs.expr.clone(), + super_circuit_annotations_map, + ), + chiquito_lookup_column_to_pil_column( + rhs.clone(), + super_circuit_annotations_map, + ), + ) + }) + .collect::>(), + ) + }) + .collect::>(), + ); + }); + } + + (constraints, lookups) +} + +// Convert lookup columns (src and dest) in Chiquito to PIL column. Note that Chiquito lookup +// columns have to be Expr::Query type. +fn chiquito_lookup_column_to_pil_column( + src: Expr>, + super_circuit_annotations_map: &Option<&HashMap>, +) -> PILColumn { + match src { + Expr::Query(queriable) => { + chiquito_queriable_to_pil_query(queriable, super_circuit_annotations_map).0 + } + _ => panic!("Lookup source is not queriable."), + } +} + +// PIL expression and constraint +#[derive(Clone)] +pub enum PILExpr { + Const(F), + Sum(Vec>), + Mul(Vec>), + Neg(Box>), + Pow(Box>, u32), + Query(PILQuery), +} + +fn chiquito_expr_to_pil_expr( + expr: Expr>, + super_circuit_annotations_map: &Option<&HashMap>, +) -> PILExpr { + match expr { + Expr::Const(constant) => PILExpr::Const(constant), + Expr::Sum(sum) => { + let mut pil_sum = Vec::new(); + for expr in sum { + pil_sum.push(chiquito_expr_to_pil_expr( + expr, + super_circuit_annotations_map, + )); + } + PILExpr::Sum(pil_sum) + } + Expr::Mul(mul) => { + let mut pil_mul = Vec::new(); + for expr in mul { + pil_mul.push(chiquito_expr_to_pil_expr( + expr, + super_circuit_annotations_map, + )); + } + PILExpr::Mul(pil_mul) + } + Expr::Neg(neg) => PILExpr::Neg(Box::new(chiquito_expr_to_pil_expr( + *neg, + super_circuit_annotations_map, + ))), + Expr::Pow(pow, power) => PILExpr::Pow( + Box::new(chiquito_expr_to_pil_expr( + *pow, + super_circuit_annotations_map, + )), + power, + ), + Expr::Query(queriable) => PILExpr::Query(chiquito_queriable_to_pil_query( + queriable, + super_circuit_annotations_map, + )), + Expr::Halo2Expr(_) => { + panic!("Halo2 native expression not supported by PIL backend.") + } + Expr::MI(_) => { + panic!("MI not supported by PIL backend.") + } + } +} + +pub type PILQuery = (PILColumn, bool); // column, rotation + +#[derive(Clone, PartialEq, Eq, Hash)] +pub enum PILColumn { + Advice(UUID, String), // UUID, annotation + Fixed(UUID, String), +} + +impl PILColumn { + pub fn uuid(&self) -> UUID { + match self { + PILColumn::Advice(uuid, _) => *uuid, + PILColumn::Fixed(uuid, _) => *uuid, + } + } + + pub fn annotation(&self) -> String { + match self { + PILColumn::Advice(_, annotation) => annotation.clone(), + PILColumn::Fixed(_, annotation) => annotation.clone(), + } + } +} + +pub fn clean_annotation(annotation: String) -> String { + annotation.replace(' ', "_") +} + +// Convert queriable to PIL column recursively. Major differences are: 1. PIL doesn't distinguish +// internal, forward, or shared columns as they are all advice; 2. PIL only supports the next +// rotation, so there's no previous or arbitrary rotation. +fn chiquito_queriable_to_pil_query( + query: Queriable, + super_circuit_annotations_map: &Option<&HashMap>, +) -> PILQuery { + match query { + Queriable::Internal(s) => { + if super_circuit_annotations_map.is_none() { + ( + PILColumn::Advice(s.uuid(), clean_annotation(s.annotation())), + false, + ) + } else { + let annotation = super_circuit_annotations_map + .as_ref() + .unwrap() + .get(&s.uuid()) + .unwrap(); + ( + PILColumn::Advice(s.uuid(), clean_annotation(annotation.clone())), + false, + ) + } + } + Queriable::Forward(s, rot) => { + if super_circuit_annotations_map.is_none() { + ( + PILColumn::Advice(s.uuid(), clean_annotation(s.annotation())), + rot, + ) + } else { + let annotation = super_circuit_annotations_map + .as_ref() + .unwrap() + .get(&s.uuid()) + .unwrap(); + ( + PILColumn::Advice(s.uuid(), clean_annotation(annotation.clone())), + rot, + ) + } + } + Queriable::Shared(s, rot) => { + let annotation = if super_circuit_annotations_map.is_none() { + clean_annotation(s.annotation()) + } else { + super_circuit_annotations_map + .as_ref() + .unwrap() + .get(&s.uuid()) + .unwrap() + .clone() + }; + if rot == 0 { + (PILColumn::Advice(s.uuid(), annotation), false) + } else if rot == 1 { + (PILColumn::Advice(s.uuid(), annotation), true) + } else { + panic!( + "PIL backend does not support shared signal with rotation other than 0 or 1." + ) + } + } + Queriable::Fixed(s, rot) => { + let annotation = if super_circuit_annotations_map.is_none() { + clean_annotation(s.annotation()) + } else { + super_circuit_annotations_map + .as_ref() + .unwrap() + .get(&s.uuid()) + .unwrap() + .clone() + }; + if rot == 0 { + (PILColumn::Fixed(s.uuid(), annotation), false) + } else if rot == 1 { + (PILColumn::Fixed(s.uuid(), annotation), true) + } else { + panic!("PIL backend does not support fixed signal with rotation other than 0 or 1.") + } + } + Queriable::StepTypeNext(s) => ( + PILColumn::Fixed(s.uuid(), clean_annotation(s.annotation())), + true, + ), + Queriable::Halo2AdviceQuery(_, _) => { + panic!("Halo2 native advice query not supported by PIL backend.") + } + Queriable::Halo2FixedQuery(_, _) => { + panic!("Halo2 native fixed query not supported by PIL backend.") + } + Queriable::_unaccessible(_) => todo!(), + } +} diff --git a/src/pil/ir/mod.rs b/src/pil/ir/mod.rs new file mode 100644 index 00000000..d57b058b --- /dev/null +++ b/src/pil/ir/mod.rs @@ -0,0 +1 @@ +pub mod powdr_pil; diff --git a/src/pil/ir/powdr_pil.rs b/src/pil/ir/powdr_pil.rs new file mode 100644 index 00000000..576d147b --- /dev/null +++ b/src/pil/ir/powdr_pil.rs @@ -0,0 +1,18 @@ +use crate::pil::compiler::{PILColumn, PILExpr, PILQuery}; +use std::collections::HashMap; +extern crate regex; + +// PIL circuit IR +pub struct PILCircuit { + pub circuit_name: String, + pub num_steps: usize, + pub col_witness: Vec, + pub col_fixed: HashMap>, // column -> assignments + pub constraints: Vec>, + pub lookups: Vec, +} + +// lookup in PIL is the format of selector {src1, src2, ..., srcn} -> {dst1, dst2, ..., dstn} +// PILLookup is a tuple of (selector, Vec) tuples, where selector is converted from +// Chiquito step type to fixed column +pub type PILLookup = (PILColumn, Vec<(PILColumn, PILColumn)>); diff --git a/src/pil/mod.rs b/src/pil/mod.rs new file mode 100644 index 00000000..aab5b126 --- /dev/null +++ b/src/pil/mod.rs @@ -0,0 +1,3 @@ +pub mod backend; +pub mod compiler; +pub mod ir; diff --git a/src/plonkish/ir/assignments.rs b/src/plonkish/ir/assignments.rs index 9f605d7a..60a91a48 100644 --- a/src/plonkish/ir/assignments.rs +++ b/src/plonkish/ir/assignments.rs @@ -132,8 +132,12 @@ impl AssignmentGenerator { } } + pub fn generate_trace_witness(&self, args: TraceArgs) -> TraceWitness { + self.trace_gen.generate(args) + } + pub fn generate(&self, args: TraceArgs) -> Assignments { - let witness = self.trace_gen.generate(args); + let witness = self.generate_trace_witness(args); self.generate_with_witness(witness) } diff --git a/src/plonkish/ir/sc.rs b/src/plonkish/ir/sc.rs index 4398bc31..05cf4663 100644 --- a/src/plonkish/ir/sc.rs +++ b/src/plonkish/ir/sc.rs @@ -1,6 +1,6 @@ use std::{collections::HashMap, hash::Hash, rc::Rc}; -use crate::{field::Field, util::UUID, wit_gen::TraceWitness}; +use crate::{field::Field, sbpir::SBPIR, util::UUID, wit_gen::TraceWitness}; use super::{ assignments::{AssignmentGenerator, Assignments}, @@ -10,6 +10,7 @@ use super::{ pub struct SuperCircuit { sub_circuits: Vec>, mapping: MappingGenerator, + sub_circuit_asts: Vec>, } impl Default for SuperCircuit { @@ -17,6 +18,7 @@ impl Default for SuperCircuit { Self { sub_circuits: Default::default(), mapping: Default::default(), + sub_circuit_asts: Default::default(), } } } @@ -29,6 +31,30 @@ impl SuperCircuit { pub fn get_mapping(&self) -> MappingGenerator { self.mapping.clone() } + + // Needed for the PIL backend. + pub fn add_sub_circuit_ast(&mut self, sub_circuit_ast: SBPIR) { + self.sub_circuit_asts.push(sub_circuit_ast); + } + + // Mapping from AST id to IR id is needed for the PIL backend to match TraceWitness, which has + // IR id, to AST. + pub fn get_ast_id_to_ir_id_mapping(&self) -> HashMap { + let mut ast_id_to_ir_id_mapping: HashMap = HashMap::new(); + self.sub_circuits.iter().for_each(|circuit| { + let ir_id = circuit.id; + let ast_id = circuit.ast_id; + ast_id_to_ir_id_mapping.insert(ast_id, ir_id); + }); + ast_id_to_ir_id_mapping + } +} + +// Needed for the PIL backend. +impl SuperCircuit { + pub fn get_super_asts(&self) -> Vec> { + self.sub_circuit_asts.clone() + } } impl SuperCircuit { @@ -47,22 +73,29 @@ impl SuperCircuit { } pub type SuperAssignments = HashMap>; +pub type SuperTraceWitness = HashMap>; pub struct MappingContext { assignments: SuperAssignments, + trace_witnesses: SuperTraceWitness, } -impl Default for MappingContext { +impl Default for MappingContext { fn default() -> Self { Self { assignments: Default::default(), + trace_witnesses: Default::default(), } } } impl MappingContext { pub fn map(&mut self, gen: &AssignmentGenerator, args: TraceArgs) { - self.assignments.insert(gen.uuid(), gen.generate(args)); + let trace_witness = gen.generate_trace_witness(args); + self.trace_witnesses + .insert(gen.uuid(), trace_witness.clone()); + self.assignments + .insert(gen.uuid(), gen.generate_with_witness(trace_witness)); } pub fn map_with_witness( @@ -77,6 +110,10 @@ impl MappingContext { pub fn get_super_assignments(self) -> SuperAssignments { self.assignments } + + pub fn get_trace_witnesses(self) -> SuperTraceWitness { + self.trace_witnesses + } } pub type Mapping = dyn Fn(&mut MappingContext, MappingArgs) + 'static; @@ -113,4 +150,13 @@ impl MappingGenerator { ctx.get_super_assignments() } + + // Needed for the PIL backend. + pub fn generate_super_trace_witnesses(&self, args: MappingArgs) -> SuperTraceWitness { + let mut ctx = MappingContext::default(); + + (self.mapping)(&mut ctx, args); + + ctx.get_trace_witnesses() + } } diff --git a/src/sbpir/mod.rs b/src/sbpir/mod.rs index b5a77029..217f2ba8 100644 --- a/src/sbpir/mod.rs +++ b/src/sbpir/mod.rs @@ -204,6 +204,28 @@ impl SBPIR { } } +impl SBPIR { + pub fn clone_without_trace(&self) -> SBPIR { + SBPIR { + step_types: self.step_types.clone(), + forward_signals: self.forward_signals.clone(), + shared_signals: self.shared_signals.clone(), + fixed_signals: self.fixed_signals.clone(), + halo2_advice: self.halo2_advice.clone(), + halo2_fixed: self.halo2_fixed.clone(), + exposed: self.exposed.clone(), + annotations: self.annotations.clone(), + trace: None, // Remove the trace. + fixed_assignments: self.fixed_assignments.clone(), + first_step: self.first_step, + last_step: self.last_step, + num_steps: self.num_steps, + q_enable: self.q_enable, + id: self.id, + } + } +} + pub type FixedGen = dyn Fn(&mut FixedGenContext) + 'static; pub type StepTypeUUID = UUID; @@ -253,6 +275,10 @@ impl StepType { self.id } + pub fn name(&self) -> String { + self.name.clone() + } + pub fn add_signal>(&mut self, name: N) -> InternalSignal { let name = name.into(); let signal = InternalSignal::new(name.clone()); @@ -417,6 +443,10 @@ impl ForwardSignal { pub fn phase(&self) -> usize { self.phase } + + pub fn annotation(&self) -> String { + self.annotation.to_string() + } } #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] @@ -450,6 +480,10 @@ impl SharedSignal { pub fn phase(&self) -> usize { self.phase } + + pub fn annotation(&self) -> String { + self.annotation.to_string() + } } #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] @@ -476,6 +510,10 @@ impl FixedSignal { pub fn uuid(&self) -> UUID { self.id } + + pub fn annotation(&self) -> String { + self.annotation.to_string() + } } #[derive(Clone, Copy, Debug)] @@ -508,6 +546,10 @@ impl InternalSignal { pub fn uuid(&self) -> UUID { self.id } + + pub fn annotation(&self) -> String { + self.annotation.to_string() + } } #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]