From 15f441b8edd78b0b37a0ad925fd5b47d33e5b975 Mon Sep 17 00:00:00 2001 From: Steve Wang Date: Sun, 3 Mar 2024 22:05:09 -0500 Subject: [PATCH 1/8] PIL Backend with PIL IR (#165) Ready for review. @leolara # High-Level Approach Convert Chiquito AST to PIL IR and finally to PIL code. # Future TODOs - Add support for trace inputs to PIL. Currently, PIL supports user input via the command line or Rust APIs for the command line. However, it only supports input for one variable (column). Because Chiquito circuit might provide inputs for multiple variables via trace (e.g. the MiMC7 circuit), we can't use use PIL's Rust API as a generic solution for providing inputs. The only solution is to input statements in the format of `ISFIRST * ([[input_variable]] - [[input_value]]) = 0`. An issue with this approach, however, is that our `trace` is too general and doesn't associate input values to signals, so we might need to create an alternative `trace` function or mode that does this. All of the above is due to the fact that PIL only supports automatic witness inference and cannot feed in external witness. - Add Python PIL API support for super circuits. Currently we support single circuit in Python. Should be very do-able, just didn't get the time to code it up. --- Cargo.toml | 1 + examples/fibonacci.py | 2 + examples/fibonacci.rs | 29 +- examples/mimc7.rs | 25 +- src/frontend/dsl/mod.rs | 4 + src/frontend/dsl/sc.rs | 11 +- src/frontend/python/chiquito/dsl.py | 9 + src/frontend/python/mod.rs | 22 ++ src/lib.rs | 1 + src/pil/backend/mod.rs | 1 + src/pil/backend/powdr_pil.rs | 226 +++++++++++ src/pil/compiler/mod.rs | 578 ++++++++++++++++++++++++++++ src/pil/ir/mod.rs | 1 + src/pil/ir/powdr_pil.rs | 18 + src/pil/mod.rs | 3 + src/plonkish/ir/assignments.rs | 6 +- src/plonkish/ir/sc.rs | 52 ++- src/sbpir/mod.rs | 42 ++ 18 files changed, 1018 insertions(+), 13 deletions(-) create mode 100644 src/pil/backend/mod.rs create mode 100644 src/pil/backend/powdr_pil.rs create mode 100644 src/pil/compiler/mod.rs create mode 100644 src/pil/ir/mod.rs create mode 100644 src/pil/ir/powdr_pil.rs create mode 100644 src/pil/mod.rs diff --git a/Cargo.toml b/Cargo.toml index c002d68b..58d1fd0e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,6 +20,7 @@ num-bigint = { version = "0.4", features = ["rand"] } uuid = { version = "1.4.0", features = ["v1", "rng"] } serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" +regex = "1" [dev-dependencies] rand_chacha = "0.3" diff --git a/examples/fibonacci.py b/examples/fibonacci.py index 4ffa7b31..0c7aedfa 100644 --- a/examples/fibonacci.py +++ b/examples/fibonacci.py @@ -84,3 +84,5 @@ def trace(self, n): ) # 2^k specifies the number of PLONKish table rows in Halo2 another_fibo_witness = fibo.gen_witness(4) fibo.halo2_mock_prover(another_fibo_witness, k=7) + +fibo.to_pil(fibo_witness, "FiboCircuit") diff --git a/examples/fibonacci.rs b/examples/fibonacci.rs index 4021cdda..dbf7dbba 100644 --- a/examples/fibonacci.rs +++ b/examples/fibonacci.rs @@ -18,6 +18,7 @@ use chiquito::{ }, plonkish::ir::{assignments::AssignmentGenerator, Circuit}, // compiled circuit type poly::ToField, + sbpir::SBPIR, }; use halo2_proofs::{dev::MockProver, halo2curves::bn256::Fr}; @@ -26,7 +27,10 @@ use halo2_proofs::{dev::MockProver, halo2curves::bn256::Fr}; // 1. type that implements a field trait // 2. empty trace arguments, i.e. (), because there are no external inputs to the Chiquito circuit // 3. two witness generation arguments both of u64 type, i.e. 
(u64, u64) -fn fibo_circuit + Hash>() -> (Circuit, Option>) { + +type FiboReturn = (Circuit, Option>, SBPIR); + +fn fibo_circuit + Hash>() -> FiboReturn { // PLONKish table for the Fibonacci circuit: // | a | b | c | // | 1 | 1 | 2 | @@ -73,7 +77,7 @@ fn fibo_circuit + Hash>() -> (Circuit, Option + Hash>() -> (Circuit, Option + Hash>() -> (Circuit, Option(); + let (chiquito, wit_gen, _) = fibo_circuit::(); let compiled = chiquito2Halo2(chiquito); let circuit = ChiquitoHalo2Circuit::new(compiled, wit_gen.map(|g| g.generate(()))); @@ -137,7 +143,7 @@ fn main() { use polyexen::plaf::{backends::halo2::PlafH2Circuit, WitnessDisplayCSV}; // get Chiquito ir - let (circuit, wit_gen) = fibo_circuit::(); + let (circuit, wit_gen, _) = fibo_circuit::(); // get Plaf let (plaf, plaf_wit_gen) = chiquito2Plaf(circuit, 8, false); let wit = plaf_wit_gen.generate(wit_gen.map(|v| v.generate(()))); @@ -162,4 +168,15 @@ fn main() { println!("{}", failure); } } + + // pil boilerplate + use chiquito::pil::backend::powdr_pil::chiquito2Pil; + + let (_, wit_gen, circuit) = fibo_circuit::(); + let pil = chiquito2Pil( + circuit, + Some(wit_gen.unwrap().generate_trace_witness(())), + String::from("FiboCircuit"), + ); + print!("{}", pil); } diff --git a/examples/mimc7.rs b/examples/mimc7.rs index fa1b1d2e..1eafeb0e 100644 --- a/examples/mimc7.rs +++ b/examples/mimc7.rs @@ -175,7 +175,7 @@ fn mimc7_circuit( row_value += F::from(1); x_value += k_value + c_value; x_value = x_value.pow_vartime([7_u64]); - // Step 90: output the hash result as x + k in witness generation + // Step 91: output the hash result as x + k in witness generation // output is not displayed as a public column, which will be implemented in the future ctx.add(&mimc7_last_step, (x_value, k_value, c_value, row_value)); // c_value is not // used here but @@ -219,6 +219,29 @@ fn main() { println!("{}", failure); } } + + // pil boilerplate + use chiquito::pil::backend::powdr_pil::chiquitoSuperCircuit2Pil; + + let x_in_value = Fr::from_str_vartime("1").expect("expected a number"); + let k_value = Fr::from_str_vartime("2").expect("expected a number"); + + let super_circuit = mimc7_super_circuit::(); + + // `super_trace_witnesses` is a mapping from IR id to TraceWitness. However, not all ASTs have a + // corresponding TraceWitness. 
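+    // The map is keyed by IR id, so the PIL compiler matches each AST back to its witness (if any)
+    // through the AST-id-to-IR-id mapping passed below.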
+ let super_trace_witnesses = super_circuit + .get_mapping() + .generate_super_trace_witnesses((x_in_value, k_value)); + + let pil = chiquitoSuperCircuit2Pil::( + super_circuit.get_super_asts(), + super_trace_witnesses, + super_circuit.get_ast_id_to_ir_id_mapping(), + vec![String::from("Mimc7Constant"), String::from("Mimc7Circuit")], + ); + + print!("{}", pil); } mod mimc7_constants { diff --git a/src/frontend/dsl/mod.rs b/src/frontend/dsl/mod.rs index c90db77e..15def6ff 100644 --- a/src/frontend/dsl/mod.rs +++ b/src/frontend/dsl/mod.rs @@ -358,6 +358,10 @@ impl StepTypeHandler { pub fn next(&self) -> Queriable { Queriable::StepTypeNext(*self) } + + pub fn annotation(&self) -> String { + self.annotation.to_string() + } } impl, Args) + 'static> From<&StepTypeWGHandler> diff --git a/src/frontend/dsl/sc.rs b/src/frontend/dsl/sc.rs index 9b6daec9..f9ef60d5 100644 --- a/src/frontend/dsl/sc.rs +++ b/src/frontend/dsl/sc.rs @@ -34,6 +34,12 @@ impl Default for SuperCircuitContext { } } +impl SuperCircuitContext { + fn add_sub_circuit_ast(&mut self, ast: SBPIR) { + self.super_circuit.add_sub_circuit_ast(ast); + } +} + impl SuperCircuitContext { pub fn sub_circuit( &mut self, @@ -48,12 +54,13 @@ impl SuperCircuitContext { circuit: SBPIR::default(), tables: self.tables.clone(), }; - println!("super circuit table registry 2: {:?}", self.tables); let exports = sub_circuit_def(&mut sub_circuit_context, imports); - println!("super circuit table registry 3: {:?}", self.tables); let sub_circuit = sub_circuit_context.circuit; + // ast is used for PIL backend + self.add_sub_circuit_ast(sub_circuit.clone_without_trace()); + let (unit, assignment) = compile_phase1(config, &sub_circuit); let assignment = assignment.unwrap_or_else(|| AssignmentGenerator::empty(unit.uuid)); diff --git a/src/frontend/python/chiquito/dsl.py b/src/frontend/python/chiquito/dsl.py index aa355972..ee41004f 100644 --- a/src/frontend/python/chiquito/dsl.py +++ b/src/frontend/python/chiquito/dsl.py @@ -225,6 +225,15 @@ def halo2_mock_prover(self: Circuit, witness: TraceWitness, k: int = 16): witness_json: str = witness.get_witness_json() rust_chiquito.halo2_mock_prover(witness_json, self.rust_id, k) + def to_pil( + self: Circuit, witness: TraceWitness, circuit_name: str = "Circuit" + ) -> str: + if self.rust_id == 0: + ast_json: str = self.get_ast_json() + self.rust_id: int = rust_chiquito.ast_to_halo2(ast_json) + witness_json: str = witness.get_witness_json() + rust_chiquito.to_pil(witness_json, self.rust_id, circuit_name) + def __str__(self: Circuit) -> str: return self.ast.__str__() diff --git a/src/frontend/python/mod.rs b/src/frontend/python/mod.rs index e6da88aa..69b3abc3 100644 --- a/src/frontend/python/mod.rs +++ b/src/frontend/python/mod.rs @@ -5,6 +5,7 @@ use pyo3::{ use crate::{ frontend::dsl::{StepTypeHandler, SuperCircuitContext}, + pil::backend::powdr_pil::chiquito2Pil, plonkish::{ backend::halo2::{ chiquito2Halo2, chiquitoSuperCircuit2Halo2, ChiquitoHalo2, ChiquitoHalo2Circuit, @@ -81,6 +82,14 @@ pub fn chiquito_ast_map_store(ast_json: &str) -> UUID { uuid } +pub fn chiquito_ast_to_pil(witness_json: &str, rust_id: UUID, circuit_name: &str) -> String { + let trace_witness: TraceWitness = + serde_json::from_str(witness_json).expect("Json deserialization to TraceWitness failed."); + let (ast, _, _) = rust_id_to_halo2(rust_id); + + chiquito2Pil(ast, Some(trace_witness), circuit_name.to_string()) +} + fn add_assignment_generator_to_rust_id( assignment_generator: AssignmentGenerator, rust_id: UUID, @@ -1845,6 +1854,18 @@ fn 
ast_to_halo2(json: &PyString) -> u128 { uuid } +#[pyfunction] +fn to_pil(witness_json: &PyString, rust_id: &PyLong, circuit_name: &PyString) -> String { + let pil = chiquito_ast_to_pil( + witness_json.to_str().expect("PyString convertion failed."), + rust_id.extract().expect("PyLong convertion failed."), + circuit_name.to_str().expect("PyString convertion failed."), + ); + + println!("{}", pil); + pil +} + #[pyfunction] fn ast_map_store(json: &PyString) -> u128 { let uuid = chiquito_ast_map_store(json.to_str().expect("PyString conversion failed.")); @@ -1903,6 +1924,7 @@ fn rust_chiquito(_py: Python, m: &PyModule) -> PyResult<()> { m.add_function(wrap_pyfunction!(convert_and_print_ast, m)?)?; m.add_function(wrap_pyfunction!(convert_and_print_trace_witness, m)?)?; m.add_function(wrap_pyfunction!(ast_to_halo2, m)?)?; + m.add_function(wrap_pyfunction!(to_pil, m)?)?; m.add_function(wrap_pyfunction!(ast_map_store, m)?)?; m.add_function(wrap_pyfunction!(halo2_mock_prover, m)?)?; m.add_function(wrap_pyfunction!(super_circuit_halo2_mock_prover, m)?)?; diff --git a/src/lib.rs b/src/lib.rs index dd823581..853e71ca 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,5 +1,6 @@ pub mod field; pub mod frontend; +pub mod pil; pub mod plonkish; pub mod poly; pub mod sbpir; diff --git a/src/pil/backend/mod.rs b/src/pil/backend/mod.rs new file mode 100644 index 00000000..d57b058b --- /dev/null +++ b/src/pil/backend/mod.rs @@ -0,0 +1 @@ +pub mod powdr_pil; diff --git a/src/pil/backend/powdr_pil.rs b/src/pil/backend/powdr_pil.rs new file mode 100644 index 00000000..4b86c591 --- /dev/null +++ b/src/pil/backend/powdr_pil.rs @@ -0,0 +1,226 @@ +use crate::{ + field::Field, + pil::{ + compiler::{compile, compile_super_circuits, PILColumn, PILExpr, PILQuery}, + ir::powdr_pil::PILCircuit, + }, + sbpir::SBPIR, + util::UUID, + wit_gen::TraceWitness, +}; +use std::{ + collections::HashMap, + fmt::{Debug, Write}, +}; +extern crate regex; + +#[allow(non_snake_case)] +/// User generate PIL code using this function. User needs to supply AST, TraceWitness, and a name +/// string for the circuit. +pub fn chiquito2Pil( + ast: SBPIR, + witness: Option>, + circuit_name: String, +) -> String { + // generate PIL IR. + let pil_ir = compile::(&ast, witness, circuit_name, &None); + + // generate Powdr PIL code. + pil_ir_to_powdr_pil::(pil_ir) +} + +// Convert PIL IR to Powdr PIL code. +pub fn pil_ir_to_powdr_pil(pil_ir: PILCircuit) -> String { + let mut pil = String::new(); // The string to return. + + writeln!( + pil, + "// ===== START OF CIRCUIT: {} =====", + pil_ir.circuit_name + ) + .unwrap(); + + // Namespace is equivalent to a circuit in PIL. + writeln!( + pil, + "constant %NUM_STEPS_{} = {};", + pil_ir.circuit_name.to_uppercase(), + pil_ir.num_steps + ) + .unwrap(); + writeln!( + pil, + "namespace {}(%NUM_STEPS_{});", + pil_ir.circuit_name, + pil_ir.circuit_name.to_uppercase() + ) + .unwrap(); + + // Declare witness columns in PIL. + generate_pil_witness_columns(&mut pil, &pil_ir); + + // Declare fixed columns and their assignments in PIL. 
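+    // These cover user-defined fixed signals, one 0/1 selector column per step type, and the
+    // PIL-specific ISFIRST/ISLAST columns.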
+ generate_pil_fixed_columns(&mut pil, &pil_ir); + + // generate constraints + for expr in pil_ir.constraints { + // recursively convert expressions to PIL strings + let expr_string = convert_to_pil_expr_string(expr.clone()); + // each constraint is in the format of `constraint = 0` + writeln!(pil, "{} = 0;", expr_string).unwrap(); + } + + // generate lookups + for lookup in pil_ir.lookups { + let (selector, src_dest_tuples) = lookup; + let lookup_selector = selector.annotation(); + let mut lookup_source: Vec = Vec::new(); + let mut lookup_destination: Vec = Vec::new(); + for (src, dest) in src_dest_tuples { + lookup_source.push(src.annotation()); + lookup_destination.push(dest.annotation()); + } + // PIL lookups have the format of `selector { src1, src2, ... srcn } in {dest1, dest2, ..., + // destn}`. + writeln!( + pil, + "{} {{{}}} in {{{}}} ", + lookup_selector, + lookup_source.join(", "), + lookup_destination.join(", ") + ) + .unwrap(); + } + + writeln!( + pil, + "// ===== END OF CIRCUIT: {} =====", + pil_ir.circuit_name + ) + .unwrap(); + writeln!(pil).unwrap(); // Separator row for the circuit. + + pil +} + +#[allow(non_snake_case)] +/// User generate PIL code for super circuit using this function. +/// User needs to supply a Vec for `circuit_names`, the order of which should be the same as +/// the order of calling `sub_circuit()` function. +pub fn chiquitoSuperCircuit2Pil( + super_asts: Vec>, + super_trace_witnesses: HashMap>, + ast_id_to_ir_id_mapping: HashMap, + circuit_names: Vec, +) -> String { + let mut pil = String::new(); // The string to return. + + // Generate PIL IRs for each sub circuit in the super circuit. + let pil_irs = compile_super_circuits( + super_asts, + super_trace_witnesses, + ast_id_to_ir_id_mapping, + circuit_names, + ); + + // Generate Powdr PIL code for each sub circuit. + for pil_ir in pil_irs { + let pil_circuit = pil_ir_to_powdr_pil(pil_ir); + writeln!(pil, "{}", pil_circuit).unwrap(); + } + + pil +} + +fn generate_pil_witness_columns(pil: &mut String, pil_ir: &PILCircuit) { + if !pil_ir.col_witness.is_empty() { + writeln!(pil, "// === Witness Columns ===").unwrap(); + let mut col_witness = String::from("col witness "); + + let mut col_witness_vars = pil_ir + .col_witness + .iter() + .map(|col| match col { + PILColumn::Advice(_, annotation) => annotation.clone(), + _ => panic!("Witness column should be an advice column."), + }) + .collect::>(); + + // Get unique witness column annotations + col_witness_vars.sort(); + col_witness_vars.dedup(); + col_witness = col_witness + col_witness_vars.join(", ").as_str() + ";"; + writeln!(pil, "{}", col_witness).unwrap(); + } +} + +fn generate_pil_fixed_columns(pil: &mut String, pil_ir: &PILCircuit) { + if !pil_ir.col_fixed.is_empty() { + writeln!( + pil, + "// === Fixed Columns for Signals and Step Type Selectors ===" + ) + .unwrap(); + for (col, assignments) in pil_ir.col_fixed.iter() { + let fixed_name = match col { + PILColumn::Fixed(_, annotation) => annotation.clone(), + _ => panic!("Fixed column should be an advice or fixed column."), + }; + let mut assignments_string = String::new(); + let assignments_vec = assignments + .iter() + .map(|assignment| format!("{:?}", assignment)) + .collect::>(); + write!( + assignments_string, + "{}", + assignments_vec.join(", ").as_str() + ) + .unwrap(); + writeln!(pil, "col fixed {} = [{}];", fixed_name, assignments_string).unwrap(); + } + } +} + +// Convert PIL expression to Powdr PIL string recursively. 
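+// For example, a step selector `s` multiplied by the sum of queries `a` and `b'` renders as `s * (a + b')`.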
+fn convert_to_pil_expr_string(expr: PILExpr) -> String { + match expr { + PILExpr::Const(constant) => format!("{:?}", constant), + PILExpr::Sum(sum) => { + let mut expr_string = String::new(); + for (index, expr) in sum.iter().enumerate() { + expr_string += convert_to_pil_expr_string(expr.clone()).as_str(); + if index != sum.len() - 1 { + expr_string += " + "; + } + } + format!("({})", expr_string) + } + PILExpr::Mul(mul) => { + let mut expr_string = String::new(); + for (index, expr) in mul.iter().enumerate() { + expr_string += convert_to_pil_expr_string(expr.clone()).as_str(); + if index != mul.len() - 1 { + expr_string += " * "; + } + } + format!("{}", expr_string) + } + PILExpr::Neg(neg) => format!("(-{})", convert_to_pil_expr_string(*neg)), + PILExpr::Pow(pow, power) => { + format!("({})^{}", convert_to_pil_expr_string(*pow), power) + } + PILExpr::Query(queriable) => convert_to_pil_queriable_string(queriable), + } +} + +// Convert PIL query to Powdr PIL string recursively. +fn convert_to_pil_queriable_string(query: PILQuery) -> String { + let (col, rot) = query; + let annotation = col.annotation(); + if rot { + format!("{}'", annotation) + } else { + annotation + } +} diff --git a/src/pil/compiler/mod.rs b/src/pil/compiler/mod.rs new file mode 100644 index 00000000..0928f3ae --- /dev/null +++ b/src/pil/compiler/mod.rs @@ -0,0 +1,578 @@ +use crate::{ + field::Field, + pil::ir::powdr_pil::{PILCircuit, PILLookup}, + poly::Expr, + sbpir::{query::Queriable, SBPIR}, + util::{uuid, UUID}, + wit_gen::TraceWitness, +}; +use std::{collections::HashMap, fmt::Debug, hash::Hash}; +extern crate regex; + +pub fn compile( + ast: &SBPIR, + witness: Option>, + circuit_name: String, + super_circuit_annotations_map: &Option<&HashMap>, +) -> PILCircuit { + let col_witness = collect_witness_columns(ast); + + // HashMap of fixed column to fixed assignments + let mut col_fixed = HashMap::new(); + + if let Some(fixed_assignments) = &ast.fixed_assignments { + fixed_assignments + .iter() + .for_each(|(queriable, assignments)| { + let uuid = queriable.uuid(); + col_fixed.insert( + PILColumn::Fixed( + uuid, + clean_annotation(ast.annotations.get(&uuid).unwrap().clone()), + ), + assignments.clone(), + ); + }); + } + + // Get last step instance UUID, so that we can disable transition of that instance + let mut last_step_instance = 0; + + // Insert into col_fixed the map from step type fixed column to vector of {0,1} where 1 means + // the step type is instantiated whereas 0 not. Each vector should have the same length as the + // number of steps. + if !ast.step_types.is_empty() && witness.is_some() { + let step_instances = witness.as_ref().unwrap().step_instances.iter(); + + // Get last step instance, so that we can disable transition of that instance + last_step_instance = step_instances.clone().last().unwrap().step_type_uuid; + + for step_type in ast.step_types.values() { + let step_type_instantiation: Vec = step_instances + .clone() + .map(|step_instance| { + if step_instance.step_type_uuid == step_type.uuid() { + F::ONE + } else { + F::ZERO + } + }) + .collect(); + assert_eq!(step_type_instantiation.len(), ast.num_steps); + let uuid = step_type.uuid(); + col_fixed.insert( + PILColumn::Fixed( + uuid, + clean_annotation(ast.annotations.get(&uuid).unwrap().clone()), + ), + step_type_instantiation, + ); + } + } + + // Create new UUID for ISFIRST and ISLAST. These are fixed columns unique to PIL. 
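+    // ISFIRST is 1 only on the first row and ISLAST only on the last row. They are used below to
+    // enforce pragma_first_step/pragma_last_step and to disable transition constraints on the last
+    // step instance.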
+ let is_first_uuid = uuid(); + let is_last_uuid = uuid(); + + // ISFIRST and ISLAST are only relevant when there's non zero number of step instances. + let num_step_instances = witness + .as_ref() + .map(|w| w.step_instances.len()) + .unwrap_or(0); + if num_step_instances != 0 { + // 1 for first row and 0 for all other rows; number of rows equals to number of steps + let is_first_assignments = vec![F::ONE] + .into_iter() + .chain(std::iter::repeat(F::ZERO)) + .take(ast.num_steps) + .collect(); + col_fixed.insert( + PILColumn::Fixed(is_first_uuid, String::from("ISFIRST")), + is_first_assignments, + ); + + // 0 for all rows except the last row, which is 1; number of rows equals to number of steps + let is_last_assignments = std::iter::repeat(F::ZERO) + .take(ast.num_steps - 1) + .chain(std::iter::once(F::ONE)) + .collect(); + col_fixed.insert( + PILColumn::Fixed(is_last_uuid, String::from("ISLAST")), + is_last_assignments, + ); + } + + // Compile step type elements, i.e. constraints, transitions, and lookups. + let (mut constraints, lookups) = compile_steps( + ast, + last_step_instance, + is_last_uuid, + super_circuit_annotations_map, + ); + + // Insert pragma_first_step and pragma_last_step as constraints + if let Some(first_step) = ast.first_step { + // is_first * (1 - first_step) = 0 + constraints.push(PILExpr::Mul(vec![ + PILExpr::Query(( + PILColumn::Fixed(is_first_uuid, String::from("ISFIRST")), + false, + )), + PILExpr::Sum(vec![ + PILExpr::Const(F::ONE), + PILExpr::Neg(Box::new(PILExpr::Query(( + PILColumn::Fixed( + first_step, + clean_annotation(ast.annotations.get(&first_step).unwrap().clone()), + ), + false, + )))), + ]), + ])); + } + + if let Some(last_step) = ast.last_step { + // is_last * (1 - last_step) = 0 + constraints.push(PILExpr::Mul(vec![ + PILExpr::Query(( + PILColumn::Fixed(is_last_uuid, String::from("ISLAST")), + false, + )), + PILExpr::Sum(vec![ + PILExpr::Const(F::ONE), + PILExpr::Neg(Box::new(PILExpr::Query(( + PILColumn::Fixed( + last_step, + clean_annotation(ast.annotations.get(&last_step).unwrap().clone()), + ), + false, + )))), + ]), + ])); + } + + PILCircuit { + circuit_name, + num_steps: ast.num_steps, + col_witness, + col_fixed, + constraints, + lookups, + } +} + +pub fn compile_super_circuits( + super_asts: Vec>, + super_trace_witnesses: HashMap>, + ast_id_to_ir_id_mapping: HashMap, + circuit_names: Vec, +) -> Vec> { + assert!(super_asts.len() == circuit_names.len()); + + // Get annotations map for the super circuit, which is a HashMap of object UUID to object + // annotation. + let mut super_circuit_annotations_map: HashMap = HashMap::new(); + + // Loop over each AST. + for (ast, circuit_name) in super_asts.iter().zip(circuit_names.iter()) { + // Create `annotations_map` for each AST, to be added to `super_circuit_annotations_map`. + let mut annotations_map: HashMap = HashMap::new(); + + // First, get AST level annotations. + annotations_map.extend(ast.annotations.clone()); + + // Second, get step level annotations. + for step_type in ast.step_types.values() { + annotations_map.extend(step_type.annotations.clone()); + } + + // Convert annotation to circuit_name.annotation, because this is the general format of + // referring to variables in PIL if there are more than one circuit. + super_circuit_annotations_map.extend(annotations_map.into_iter().map( + |(uuid, annotation)| { + ( + uuid, + format!("{}.{}", circuit_name.clone(), clean_annotation(annotation)), + ) + }, + )); + + // Finally, get annotations for the circuit names. 
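+        // Unlike the signal annotations above, which are prefixed with "circuit_name.", the AST id
+        // maps to the bare circuit name.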
+ super_circuit_annotations_map.insert(ast.id, circuit_name.clone()); + } + + // For each AST, find its corresponding TraceWitness. Note that some AST might not have a + // corresponding TraceWitness, so witness is an Option. + let mut pil_irs = Vec::new(); + for (ast, circuit_name) in super_asts.iter().zip(circuit_names.iter()) { + let witness = super_trace_witnesses.get(ast_id_to_ir_id_mapping.get(&ast.id).unwrap()); + + // Create PIL IR + let pil_ir = compile( + ast, + witness.cloned(), + circuit_name.clone(), + &Some(&super_circuit_annotations_map), + ); + + pil_irs.push(pil_ir); + } + + pil_irs +} + +fn collect_witness_columns(ast: &SBPIR) -> Vec { + let mut col_witness = Vec::new(); + + // Collect internal signals to witness columns. + col_witness.extend( + ast.step_types + .values() + .flat_map(|step_type| { + step_type + .signals + .iter() + .map(|signal| { + PILColumn::Advice(signal.uuid(), clean_annotation(signal.annotation())) + }) + .collect::>() + }) + .collect::>(), + ); + + // Collect forward signals to witness columns. + col_witness.extend( + ast.forward_signals + .iter() + .map(|forward_signal| { + PILColumn::Advice( + forward_signal.uuid(), + clean_annotation(forward_signal.annotation()), + ) + }) + .collect::>(), + ); + + // Collect shared signals to witness columns. + col_witness.extend( + ast.shared_signals + .iter() + .map(|shared_signal| { + PILColumn::Advice( + shared_signal.uuid(), + clean_annotation(shared_signal.annotation()), + ) + }) + .collect::>(), + ); + + col_witness +} + +fn compile_steps( + ast: &SBPIR, + last_step_instance: UUID, + is_last_uuid: UUID, + super_circuit_annotations_map: &Option<&HashMap>, +) -> (Vec>, Vec) { + // transitions and constraints all become constraints in PIL + let mut constraints = Vec::new(); + let mut lookups = Vec::new(); + + if !ast.step_types.is_empty() { + ast.step_types.values().for_each(|step_type| { + // Create constraint statements. 
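+            // Each constraint is multiplied by the step type's 0/1 selector column, so it is only
+            // enforced on rows where that step type is instantiated.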
+ constraints.extend( + step_type + .constraints + .iter() + .map(|constraint| { + PILExpr::Mul(vec![ + PILExpr::Query(( + PILColumn::Fixed( + step_type.uuid(), + clean_annotation(step_type.name()), + ), + false, + )), + chiquito_expr_to_pil_expr( + constraint.expr.clone(), + super_circuit_annotations_map, + ), + ]) + }) + .collect::>>(), + ); + + // There's no distinction between constraint and transition in PIL + // However, we do need to identify constraints with rotation in the last row + // and disable them + constraints.extend( + step_type + .transition_constraints + .iter() + .map(|transition| { + let res = PILExpr::Mul(vec![ + PILExpr::Query(( + PILColumn::Fixed( + step_type.uuid(), + clean_annotation(step_type.name()), + ), + false, + )), + chiquito_expr_to_pil_expr( + transition.expr.clone(), + super_circuit_annotations_map, + ), + ]); + if step_type.uuid() == last_step_instance { + PILExpr::Mul(vec![ + PILExpr::Sum(vec![ + PILExpr::Const(F::ONE), + PILExpr::Neg(Box::new(PILExpr::Query(( + PILColumn::Fixed(is_last_uuid, String::from("ISLAST")), + false, + )))), + ]), + res, + ]) + } else { + res + } + }) + .collect::>>(), + ); + + lookups.extend( + step_type + .lookups + .iter() + .map(|lookup| { + ( + PILColumn::Fixed(step_type.uuid(), clean_annotation(step_type.name())), + lookup + .exprs + .iter() + .map(|(lhs, rhs)| { + ( + chiquito_lookup_column_to_pil_column( + lhs.expr.clone(), + super_circuit_annotations_map, + ), + chiquito_lookup_column_to_pil_column( + rhs.clone(), + super_circuit_annotations_map, + ), + ) + }) + .collect::>(), + ) + }) + .collect::>(), + ); + }); + } + + (constraints, lookups) +} + +// Convert lookup columns (src and dest) in Chiquito to PIL column. Note that Chiquito lookup +// columns have to be Expr::Query type. 
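+// Composite lookup expressions (e.g. `a + 1`) are not supported and cause a panic.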
+fn chiquito_lookup_column_to_pil_column( + src: Expr>, + super_circuit_annotations_map: &Option<&HashMap>, +) -> PILColumn { + match src { + Expr::Query(queriable) => { + chiquito_queriable_to_pil_query(queriable, super_circuit_annotations_map).0 + } + _ => panic!("Lookup source is not queriable."), + } +} + +// PIL expression and constraint +#[derive(Clone)] +pub enum PILExpr { + Const(F), + Sum(Vec>), + Mul(Vec>), + Neg(Box>), + Pow(Box>, u32), + Query(PILQuery), +} + +fn chiquito_expr_to_pil_expr( + expr: Expr>, + super_circuit_annotations_map: &Option<&HashMap>, +) -> PILExpr { + match expr { + Expr::Const(constant) => PILExpr::Const(constant), + Expr::Sum(sum) => { + let mut pil_sum = Vec::new(); + for expr in sum { + pil_sum.push(chiquito_expr_to_pil_expr( + expr, + super_circuit_annotations_map, + )); + } + PILExpr::Sum(pil_sum) + } + Expr::Mul(mul) => { + let mut pil_mul = Vec::new(); + for expr in mul { + pil_mul.push(chiquito_expr_to_pil_expr( + expr, + super_circuit_annotations_map, + )); + } + PILExpr::Mul(pil_mul) + } + Expr::Neg(neg) => PILExpr::Neg(Box::new(chiquito_expr_to_pil_expr( + *neg, + super_circuit_annotations_map, + ))), + Expr::Pow(pow, power) => PILExpr::Pow( + Box::new(chiquito_expr_to_pil_expr( + *pow, + super_circuit_annotations_map, + )), + power, + ), + Expr::Query(queriable) => PILExpr::Query(chiquito_queriable_to_pil_query( + queriable, + super_circuit_annotations_map, + )), + Expr::Halo2Expr(_) => { + panic!("Halo2 native expression not supported by PIL backend.") + } + Expr::MI(_) => { + panic!("MI not supported by PIL backend.") + } + } +} + +pub type PILQuery = (PILColumn, bool); // column, rotation + +#[derive(Clone, PartialEq, Eq, Hash)] +pub enum PILColumn { + Advice(UUID, String), // UUID, annotation + Fixed(UUID, String), +} + +impl PILColumn { + pub fn uuid(&self) -> UUID { + match self { + PILColumn::Advice(uuid, _) => *uuid, + PILColumn::Fixed(uuid, _) => *uuid, + } + } + + pub fn annotation(&self) -> String { + match self { + PILColumn::Advice(_, annotation) => annotation.clone(), + PILColumn::Fixed(_, annotation) => annotation.clone(), + } + } +} + +pub fn clean_annotation(annotation: String) -> String { + annotation.replace(' ', "_") +} + +// Convert queriable to PIL column recursively. Major differences are: 1. PIL doesn't distinguish +// internal, forward, or shared columns as they are all advice; 2. PIL only supports the next +// rotation, so there's no previous or arbitrary rotation. 
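+// The boolean in the returned PILQuery is the rotation flag: false for the current row, true for the
+// next row (rendered with a trailing apostrophe, e.g. `a'`).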
+fn chiquito_queriable_to_pil_query( + query: Queriable, + super_circuit_annotations_map: &Option<&HashMap>, +) -> PILQuery { + match query { + Queriable::Internal(s) => { + if super_circuit_annotations_map.is_none() { + ( + PILColumn::Advice(s.uuid(), clean_annotation(s.annotation())), + false, + ) + } else { + let annotation = super_circuit_annotations_map + .as_ref() + .unwrap() + .get(&s.uuid()) + .unwrap(); + ( + PILColumn::Advice(s.uuid(), clean_annotation(annotation.clone())), + false, + ) + } + } + Queriable::Forward(s, rot) => { + if super_circuit_annotations_map.is_none() { + ( + PILColumn::Advice(s.uuid(), clean_annotation(s.annotation())), + rot, + ) + } else { + let annotation = super_circuit_annotations_map + .as_ref() + .unwrap() + .get(&s.uuid()) + .unwrap(); + ( + PILColumn::Advice(s.uuid(), clean_annotation(annotation.clone())), + rot, + ) + } + } + Queriable::Shared(s, rot) => { + let annotation = if super_circuit_annotations_map.is_none() { + clean_annotation(s.annotation()) + } else { + super_circuit_annotations_map + .as_ref() + .unwrap() + .get(&s.uuid()) + .unwrap() + .clone() + }; + if rot == 0 { + (PILColumn::Advice(s.uuid(), annotation), false) + } else if rot == 1 { + (PILColumn::Advice(s.uuid(), annotation), true) + } else { + panic!( + "PIL backend does not support shared signal with rotation other than 0 or 1." + ) + } + } + Queriable::Fixed(s, rot) => { + let annotation = if super_circuit_annotations_map.is_none() { + clean_annotation(s.annotation()) + } else { + super_circuit_annotations_map + .as_ref() + .unwrap() + .get(&s.uuid()) + .unwrap() + .clone() + }; + if rot == 0 { + (PILColumn::Fixed(s.uuid(), annotation), false) + } else if rot == 1 { + (PILColumn::Fixed(s.uuid(), annotation), true) + } else { + panic!("PIL backend does not support fixed signal with rotation other than 0 or 1.") + } + } + Queriable::StepTypeNext(s) => ( + PILColumn::Fixed(s.uuid(), clean_annotation(s.annotation())), + true, + ), + Queriable::Halo2AdviceQuery(_, _) => { + panic!("Halo2 native advice query not supported by PIL backend.") + } + Queriable::Halo2FixedQuery(_, _) => { + panic!("Halo2 native fixed query not supported by PIL backend.") + } + Queriable::_unaccessible(_) => todo!(), + } +} diff --git a/src/pil/ir/mod.rs b/src/pil/ir/mod.rs new file mode 100644 index 00000000..d57b058b --- /dev/null +++ b/src/pil/ir/mod.rs @@ -0,0 +1 @@ +pub mod powdr_pil; diff --git a/src/pil/ir/powdr_pil.rs b/src/pil/ir/powdr_pil.rs new file mode 100644 index 00000000..576d147b --- /dev/null +++ b/src/pil/ir/powdr_pil.rs @@ -0,0 +1,18 @@ +use crate::pil::compiler::{PILColumn, PILExpr, PILQuery}; +use std::collections::HashMap; +extern crate regex; + +// PIL circuit IR +pub struct PILCircuit { + pub circuit_name: String, + pub num_steps: usize, + pub col_witness: Vec, + pub col_fixed: HashMap>, // column -> assignments + pub constraints: Vec>, + pub lookups: Vec, +} + +// lookup in PIL is the format of selector {src1, src2, ..., srcn} -> {dst1, dst2, ..., dstn} +// PILLookup is a tuple of (selector, Vec) tuples, where selector is converted from +// Chiquito step type to fixed column +pub type PILLookup = (PILColumn, Vec<(PILColumn, PILColumn)>); diff --git a/src/pil/mod.rs b/src/pil/mod.rs new file mode 100644 index 00000000..aab5b126 --- /dev/null +++ b/src/pil/mod.rs @@ -0,0 +1,3 @@ +pub mod backend; +pub mod compiler; +pub mod ir; diff --git a/src/plonkish/ir/assignments.rs b/src/plonkish/ir/assignments.rs index 9f605d7a..60a91a48 100644 --- a/src/plonkish/ir/assignments.rs +++ 
b/src/plonkish/ir/assignments.rs @@ -132,8 +132,12 @@ impl AssignmentGenerator { } } + pub fn generate_trace_witness(&self, args: TraceArgs) -> TraceWitness { + self.trace_gen.generate(args) + } + pub fn generate(&self, args: TraceArgs) -> Assignments { - let witness = self.trace_gen.generate(args); + let witness = self.generate_trace_witness(args); self.generate_with_witness(witness) } diff --git a/src/plonkish/ir/sc.rs b/src/plonkish/ir/sc.rs index 4398bc31..05cf4663 100644 --- a/src/plonkish/ir/sc.rs +++ b/src/plonkish/ir/sc.rs @@ -1,6 +1,6 @@ use std::{collections::HashMap, hash::Hash, rc::Rc}; -use crate::{field::Field, util::UUID, wit_gen::TraceWitness}; +use crate::{field::Field, sbpir::SBPIR, util::UUID, wit_gen::TraceWitness}; use super::{ assignments::{AssignmentGenerator, Assignments}, @@ -10,6 +10,7 @@ use super::{ pub struct SuperCircuit { sub_circuits: Vec>, mapping: MappingGenerator, + sub_circuit_asts: Vec>, } impl Default for SuperCircuit { @@ -17,6 +18,7 @@ impl Default for SuperCircuit { Self { sub_circuits: Default::default(), mapping: Default::default(), + sub_circuit_asts: Default::default(), } } } @@ -29,6 +31,30 @@ impl SuperCircuit { pub fn get_mapping(&self) -> MappingGenerator { self.mapping.clone() } + + // Needed for the PIL backend. + pub fn add_sub_circuit_ast(&mut self, sub_circuit_ast: SBPIR) { + self.sub_circuit_asts.push(sub_circuit_ast); + } + + // Mapping from AST id to IR id is needed for the PIL backend to match TraceWitness, which has + // IR id, to AST. + pub fn get_ast_id_to_ir_id_mapping(&self) -> HashMap { + let mut ast_id_to_ir_id_mapping: HashMap = HashMap::new(); + self.sub_circuits.iter().for_each(|circuit| { + let ir_id = circuit.id; + let ast_id = circuit.ast_id; + ast_id_to_ir_id_mapping.insert(ast_id, ir_id); + }); + ast_id_to_ir_id_mapping + } +} + +// Needed for the PIL backend. +impl SuperCircuit { + pub fn get_super_asts(&self) -> Vec> { + self.sub_circuit_asts.clone() + } } impl SuperCircuit { @@ -47,22 +73,29 @@ impl SuperCircuit { } pub type SuperAssignments = HashMap>; +pub type SuperTraceWitness = HashMap>; pub struct MappingContext { assignments: SuperAssignments, + trace_witnesses: SuperTraceWitness, } -impl Default for MappingContext { +impl Default for MappingContext { fn default() -> Self { Self { assignments: Default::default(), + trace_witnesses: Default::default(), } } } impl MappingContext { pub fn map(&mut self, gen: &AssignmentGenerator, args: TraceArgs) { - self.assignments.insert(gen.uuid(), gen.generate(args)); + let trace_witness = gen.generate_trace_witness(args); + self.trace_witnesses + .insert(gen.uuid(), trace_witness.clone()); + self.assignments + .insert(gen.uuid(), gen.generate_with_witness(trace_witness)); } pub fn map_with_witness( @@ -77,6 +110,10 @@ impl MappingContext { pub fn get_super_assignments(self) -> SuperAssignments { self.assignments } + + pub fn get_trace_witnesses(self) -> SuperTraceWitness { + self.trace_witnesses + } } pub type Mapping = dyn Fn(&mut MappingContext, MappingArgs) + 'static; @@ -113,4 +150,13 @@ impl MappingGenerator { ctx.get_super_assignments() } + + // Needed for the PIL backend. 
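+    // Like the super assignment generation above, but returns the raw TraceWitness of each
+    // sub-circuit, keyed by IR id.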
+ pub fn generate_super_trace_witnesses(&self, args: MappingArgs) -> SuperTraceWitness { + let mut ctx = MappingContext::default(); + + (self.mapping)(&mut ctx, args); + + ctx.get_trace_witnesses() + } } diff --git a/src/sbpir/mod.rs b/src/sbpir/mod.rs index b5a77029..217f2ba8 100644 --- a/src/sbpir/mod.rs +++ b/src/sbpir/mod.rs @@ -204,6 +204,28 @@ impl SBPIR { } } +impl SBPIR { + pub fn clone_without_trace(&self) -> SBPIR { + SBPIR { + step_types: self.step_types.clone(), + forward_signals: self.forward_signals.clone(), + shared_signals: self.shared_signals.clone(), + fixed_signals: self.fixed_signals.clone(), + halo2_advice: self.halo2_advice.clone(), + halo2_fixed: self.halo2_fixed.clone(), + exposed: self.exposed.clone(), + annotations: self.annotations.clone(), + trace: None, // Remove the trace. + fixed_assignments: self.fixed_assignments.clone(), + first_step: self.first_step, + last_step: self.last_step, + num_steps: self.num_steps, + q_enable: self.q_enable, + id: self.id, + } + } +} + pub type FixedGen = dyn Fn(&mut FixedGenContext) + 'static; pub type StepTypeUUID = UUID; @@ -253,6 +275,10 @@ impl StepType { self.id } + pub fn name(&self) -> String { + self.name.clone() + } + pub fn add_signal>(&mut self, name: N) -> InternalSignal { let name = name.into(); let signal = InternalSignal::new(name.clone()); @@ -417,6 +443,10 @@ impl ForwardSignal { pub fn phase(&self) -> usize { self.phase } + + pub fn annotation(&self) -> String { + self.annotation.to_string() + } } #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] @@ -450,6 +480,10 @@ impl SharedSignal { pub fn phase(&self) -> usize { self.phase } + + pub fn annotation(&self) -> String { + self.annotation.to_string() + } } #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] @@ -476,6 +510,10 @@ impl FixedSignal { pub fn uuid(&self) -> UUID { self.id } + + pub fn annotation(&self) -> String { + self.annotation.to_string() + } } #[derive(Clone, Copy, Debug)] @@ -508,6 +546,10 @@ impl InternalSignal { pub fn uuid(&self) -> UUID { self.id } + + pub fn annotation(&self) -> String { + self.annotation.to_string() + } } #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] From ad070771ee641c324c37cde967342605c45686a6 Mon Sep 17 00:00:00 2001 From: even <35983442+10to4@users.noreply.github.com> Date: Tue, 5 Mar 2024 14:42:53 +0800 Subject: [PATCH 2/8] Develop blake2f circuits (#201) Add blake2f circuit example. 
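The example constrains the Blake2b compression function F from RFC 7693, whose core is the G mixing
function. As a reference point (not part of the circuit), a plain-Rust sketch of G is shown below; the
rotation constants correspond to `R1`/`R2`/`R3`/`R4` in the example, and the circuit enforces the same
operations through 4-bit limb decompositions together with range and XOR lookup tables.

```rust
// Plain (non-circuit) reference of the Blake2b G mixing function.
fn g(v: &mut [u64; 16], a: usize, b: usize, c: usize, d: usize, x: u64, y: u64) {
    v[a] = v[a].wrapping_add(v[b]).wrapping_add(x); // v[a] := (v[a] + v[b] + x) mod 2^64
    v[d] = (v[d] ^ v[a]).rotate_right(32);          // v[d] := (v[d] ^ v[a]) >>> R1
    v[c] = v[c].wrapping_add(v[d]);                 // v[c] := (v[c] + v[d]) mod 2^64
    v[b] = (v[b] ^ v[c]).rotate_right(24);          // v[b] := (v[b] ^ v[c]) >>> R2
    v[a] = v[a].wrapping_add(v[b]).wrapping_add(y); // same four steps again, with y
    v[d] = (v[d] ^ v[a]).rotate_right(16);          // ... >>> R3
    v[c] = v[c].wrapping_add(v[d]);
    v[b] = (v[b] ^ v[c]).rotate_right(63);          // ... >>> R4
}
```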
References * [rfc7693](https://datatracker.ietf.org/doc/html/rfc7693#section-3.2) * [eip-152](https://github.com/ethereum/EIPs/blob/6572e92dccb2a581c0082befb953050f75d0ece5/EIPS/eip-152.md) --------- Co-authored-by: Leo Lara --- examples/blake2f.rs | 1499 +++++++++++++++++++++++++++++++++++++++ src/frontend/dsl/mod.rs | 6 + 2 files changed, 1505 insertions(+) create mode 100644 examples/blake2f.rs diff --git a/examples/blake2f.rs b/examples/blake2f.rs new file mode 100644 index 00000000..f4a2848f --- /dev/null +++ b/examples/blake2f.rs @@ -0,0 +1,1499 @@ +use chiquito::{ + frontend::dsl::{ + cb::{eq, select, table}, + lb::LookupTable, + super_circuit, CircuitContext, StepTypeSetupContext, StepTypeWGHandler, + }, + plonkish::{ + backend::halo2::{chiquitoSuperCircuit2Halo2, ChiquitoHalo2SuperCircuit}, + compiler::{ + cell_manager::{MaxWidthCellManager, SingleRowCellManager}, + config, + step_selector::SimpleStepSelectorBuilder, + }, + ir::sc::SuperCircuit, + }, + poly::ToExpr, + sbpir::query::Queriable, +}; +use halo2_proofs::{ + dev::MockProver, + halo2curves::{bn256::Fr, group::ff::PrimeField}, +}; +use std::{fmt::Write, hash::Hash}; + +pub const IV_LEN: usize = 8; +pub const SIGMA_VECTOR_LENGTH: usize = 16; +pub const SIGMA_VECTOR_NUMBER: usize = 10; +pub const R1: u64 = 32; +pub const R2: u64 = 24; +pub const R3: u64 = 16; +pub const R4: u64 = 63; +pub const MIXING_ROUNDS: u64 = 12; +pub const SPLIT_64BITS: u64 = 16; +pub const BASE_4BITS: u64 = 16; +pub const XOR_4SPLIT_64BITS: u64 = SPLIT_64BITS * SPLIT_64BITS; +pub const V_LEN: usize = 16; +pub const M_LEN: usize = 16; +pub const H_LEN: usize = 8; +pub const G_ROUNDS: u64 = 16; + +pub const IV_VALUES: [u64; IV_LEN] = [ + 0x6A09E667F3BCC908, + 0xBB67AE8584CAA73B, + 0x3C6EF372FE94F82B, + 0xA54FF53A5F1D36F1, + 0x510E527FADE682D1, + 0x9B05688C2B3E6C1F, + 0x1F83D9ABFB41BD6B, + 0x5BE0CD19137E2179, +]; + +pub const SIGMA_VALUES: [[usize; SIGMA_VECTOR_LENGTH]; SIGMA_VECTOR_NUMBER] = [ + [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15], + [14, 10, 4, 8, 9, 15, 13, 6, 1, 12, 0, 2, 11, 7, 5, 3], + [11, 8, 12, 0, 5, 2, 15, 13, 10, 14, 3, 6, 7, 1, 9, 4], + [7, 9, 3, 1, 13, 12, 11, 14, 2, 6, 5, 10, 4, 0, 15, 8], + [9, 0, 5, 7, 2, 4, 10, 15, 14, 1, 11, 12, 6, 8, 3, 13], + [2, 12, 6, 10, 0, 11, 8, 3, 4, 13, 7, 5, 15, 14, 1, 9], + [12, 5, 1, 15, 14, 13, 4, 10, 0, 7, 6, 3, 9, 2, 8, 11], + [13, 11, 7, 14, 12, 1, 3, 9, 5, 0, 15, 4, 8, 6, 2, 10], + [6, 15, 14, 9, 11, 3, 0, 8, 12, 2, 13, 7, 1, 4, 10, 5], + [10, 2, 8, 4, 7, 6, 1, 5, 15, 11, 9, 14, 3, 12, 13, 0], +]; + +pub const XOR_VALUES: [u8; XOR_4SPLIT_64BITS as usize] = [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 1, 0, 3, 2, 5, 4, 7, 6, 9, 8, 11, 10, 13, + 12, 15, 14, 2, 3, 0, 1, 6, 7, 4, 5, 10, 11, 8, 9, 14, 15, 12, 13, 3, 2, 1, 0, 7, 6, 5, 4, 11, + 10, 9, 8, 15, 14, 13, 12, 4, 5, 6, 7, 0, 1, 2, 3, 12, 13, 14, 15, 8, 9, 10, 11, 5, 4, 7, 6, 1, + 0, 3, 2, 13, 12, 15, 14, 9, 8, 11, 10, 6, 7, 4, 5, 2, 3, 0, 1, 14, 15, 12, 13, 10, 11, 8, 9, 7, + 6, 5, 4, 3, 2, 1, 0, 15, 14, 13, 12, 11, 10, 9, 8, 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, + 5, 6, 7, 9, 8, 11, 10, 13, 12, 15, 14, 1, 0, 3, 2, 5, 4, 7, 6, 10, 11, 8, 9, 14, 15, 12, 13, 2, + 3, 0, 1, 6, 7, 4, 5, 11, 10, 9, 8, 15, 14, 13, 12, 3, 2, 1, 0, 7, 6, 5, 4, 12, 13, 14, 15, 8, + 9, 10, 11, 4, 5, 6, 7, 0, 1, 2, 3, 13, 12, 15, 14, 9, 8, 11, 10, 5, 4, 7, 6, 1, 0, 3, 2, 14, + 15, 12, 13, 10, 11, 8, 9, 6, 7, 4, 5, 2, 3, 0, 1, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, + 2, 1, 0, +]; + +pub fn string_to_u64(inputs: [&str; 4]) -> [u64; 4] 
{ + inputs + .iter() + .map(|&input| { + assert_eq!(16, input.len()); + u64::from_le_bytes( + (0..input.len()) + .step_by(2) + .map(|i| u8::from_str_radix(&input[i..i + 2], 16).unwrap()) + .collect::>() + .try_into() + .unwrap(), + ) + }) + .collect::>() + .try_into() + .unwrap() +} + +pub fn u64_to_string(inputs: &[u64; 4]) -> [String; 4] { + inputs + .iter() + .map(|input| { + let mut s = String::new(); + for byte in input.to_le_bytes() { + write!(&mut s, "{:02x}", byte).expect("Unable to write"); + } + s + }) + .collect::>() + .try_into() + .unwrap() +} + +pub fn split_to_4bits_values(vec_values: &[u64]) -> Vec> { + vec_values + .iter() + .map(|&value| { + let mut value = value; + (0..SPLIT_64BITS) + .map(|_| { + let v = value % BASE_4BITS; + value >>= 4; + F::from(v) + }) + .collect() + }) + .collect() +} + +fn blake2f_iv_table( + ctx: &mut CircuitContext, + _: usize, +) -> LookupTable { + let lookup_iv_row: Queriable = ctx.fixed("iv row"); + let lookup_iv_value: Queriable = ctx.fixed("iv value"); + + let iv_values = IV_VALUES; + ctx.pragma_num_steps(IV_LEN); + ctx.fixed_gen(move |ctx| { + for (i, &value) in iv_values.iter().enumerate() { + ctx.assign(i, lookup_iv_row, F::from(i as u64)); + ctx.assign(i, lookup_iv_value, F::from(value)); + } + }); + + ctx.new_table(table().add(lookup_iv_row).add(lookup_iv_value)) +} + +// For range checking +fn blake2f_4bits_table( + ctx: &mut CircuitContext, + _: usize, +) -> LookupTable { + let lookup_4bits_row: Queriable = ctx.fixed("4bits row"); + let lookup_4bits_value: Queriable = ctx.fixed("4bits value"); + + ctx.pragma_num_steps(SPLIT_64BITS as usize); + ctx.fixed_gen(move |ctx| { + for i in 0..SPLIT_64BITS as usize { + ctx.assign(i, lookup_4bits_row, F::ONE); + ctx.assign(i, lookup_4bits_value, F::from(i as u64)); + } + }); + + ctx.new_table(table().add(lookup_4bits_row).add(lookup_4bits_value)) +} + +fn blake2f_xor_4bits_table( + ctx: &mut CircuitContext, + _: usize, +) -> LookupTable { + let lookup_xor_row: Queriable = ctx.fixed("xor row"); + let lookup_xor_value: Queriable = ctx.fixed("xor value"); + + ctx.pragma_num_steps((SPLIT_64BITS * SPLIT_64BITS) as usize); + let xor_values = XOR_VALUES; + ctx.fixed_gen(move |ctx| { + for (i, &value) in xor_values.iter().enumerate() { + ctx.assign(i, lookup_xor_row, F::from(i as u64)); + ctx.assign(i, lookup_xor_value, F::from(value as u64)); + } + }); + + ctx.new_table(table().add(lookup_xor_row).add(lookup_xor_value)) +} + +#[derive(Clone, Copy)] +struct CircuitParams { + pub iv_table: LookupTable, + pub bits_table: LookupTable, + pub xor_4bits_table: LookupTable, +} + +impl CircuitParams { + fn check_4bit( + self, + ctx: &mut StepTypeSetupContext, + bits: Queriable, + ) { + ctx.add_lookup(self.bits_table.apply(1).apply(bits)); + } + + fn check_3bit( + self, + ctx: &mut StepTypeSetupContext, + bits: Queriable, + ) { + ctx.add_lookup(self.bits_table.apply(1).apply(bits)); + ctx.add_lookup(self.bits_table.apply(1).apply(bits * 2)); + } + + fn check_xor( + self, + ctx: &mut StepTypeSetupContext, + b1: Queriable, + b2: Queriable, + xor: Queriable, + ) { + ctx.add_lookup(self.xor_4bits_table.apply(b1 * BASE_4BITS + b2).apply(xor)); + } + + fn check_not( + self, + ctx: &mut StepTypeSetupContext, + b1: Queriable, + xor: Queriable, + ) { + ctx.add_lookup(self.xor_4bits_table.apply(b1 * BASE_4BITS + 0xF).apply(xor)); + } + + fn check_iv( + self, + ctx: &mut StepTypeSetupContext, + i: usize, + iv: Queriable, + ) { + ctx.add_lookup(self.iv_table.apply(i).apply(iv)); + } +} + +struct PreInput { + round: F, + 
t0: F, + t1: F, + f: F, + v_vec: Vec, + h_vec: Vec, + m_vec: Vec, + h_split_4bits_vec: Vec>, + m_split_4bits_vec: Vec>, + t_split_4bits_vec: Vec>, + iv_split_4bits_vec: Vec>, + final_split_bits_vec: Vec>, +} + +struct GInput { + round: F, + v_vec: Vec, + h_vec: Vec, + m_vec: Vec, + v_mid1_vec: Vec, + v_mid2_vec: Vec, + v_mid3_vec: Vec, + v_mid4_vec: Vec, + v_mid_va_bit_vec: Vec>, + v_mid_vb_bit_vec: Vec>, + v_mid_vc_bit_vec: Vec>, + v_mid_vd_bit_vec: Vec>, + v_xor_d_bit_vec: Vec>, + v_xor_b_bit_vec: Vec>, + b_bit_vec: Vec, + b_3bits_vec: Vec, +} + +struct FinalInput { + round: F, + v_vec: Vec, + h_vec: Vec, + output_vec: Vec, + v_split_bit_vec: Vec>, + h_split_bit_vec: Vec>, + v_xor_split_bit_vec: Vec>, + final_split_bit_vec: Vec>, +} + +struct InputValues { + pub round: u32, // 32bit + pub h_vec: [u64; H_LEN], // 8 * 64bits + pub m_vec: [u64; M_LEN], // 16 * 64bits + pub t0: u64, // 64bits + pub t1: u64, // 64bits + pub f: bool, // 8bits +} + +struct GStepParams { + m_vec: Vec>, + v_mid_va_bit_vec: Vec>, + v_mid_vb_bit_vec: Vec>, + v_mid_vc_bit_vec: Vec>, + v_mid_vd_bit_vec: Vec>, + v_xor_b_bit_vec: Vec>, + v_xor_d_bit_vec: Vec>, + input_vec: Vec>, + output_vec: Vec>, + b_bit: Queriable, + b_3bits: Queriable, +} + +fn split_value_4bits(mut value: u128, n: u64) -> Vec { + (0..n) + .map(|_| { + let v = value % BASE_4BITS as u128; + value /= BASE_4BITS as u128; + + F::from(v as u64) + }) + .collect() +} + +fn split_xor_value(value1: u64, value2: u64) -> Vec { + let mut value1 = value1; + let mut value2 = value2; + let bit_values: Vec = (0..64) + .map(|_| { + let b1 = value1 % 2; + value1 /= 2; + let b2 = value2 % 2; + value2 /= 2; + b1 ^ b2 + }) + .collect(); + (0..SPLIT_64BITS as usize) + .map(|i| { + F::from( + bit_values[i * 4] + + bit_values[i * 4 + 1] * 2 + + bit_values[i * 4 + 2] * 4 + + bit_values[i * 4 + 3] * 8, + ) + }) + .collect() +} + +fn g_wg( + (v1_vec_values, v2_vec_values): (&mut [u64], &mut [u64]), + (a, b, c, d): (usize, usize, usize, usize), + (x, y): (u64, u64), + (va_bit_vec, vb_bit_vec): (&mut Vec>, &mut Vec>), + (vc_bit_vec, vd_bit_vec): (&mut Vec>, &mut Vec>), + (v_xor_d_bit_vec, v_xor_b_bit_vec): (&mut Vec>, &mut Vec>), + (b_bit_vec, b_3bits_vec): (&mut Vec, &mut Vec), +) { + va_bit_vec.push(split_value_4bits( + v1_vec_values[a] as u128 + v1_vec_values[b] as u128 + x as u128, + SPLIT_64BITS + 1, + )); + v1_vec_values[a] = (v1_vec_values[a] as u128 + v1_vec_values[b] as u128 + x as u128) as u64; + + vd_bit_vec.push(split_value_4bits(v1_vec_values[d] as u128, SPLIT_64BITS)); + v1_vec_values[d] = ((v1_vec_values[d] ^ v1_vec_values[a]) >> R1) + ^ (v1_vec_values[d] ^ v1_vec_values[a]) << (64 - R1); + v_xor_d_bit_vec.push(split_value_4bits(v1_vec_values[d] as u128, SPLIT_64BITS)); + + vc_bit_vec.push(split_value_4bits( + v1_vec_values[c] as u128 + v1_vec_values[d] as u128, + SPLIT_64BITS + 1, + )); + v1_vec_values[c] = (v1_vec_values[c] as u128 + v1_vec_values[d] as u128) as u64; + + vb_bit_vec.push(split_value_4bits(v1_vec_values[b] as u128, SPLIT_64BITS)); + v1_vec_values[b] = ((v1_vec_values[b] ^ v1_vec_values[c]) >> R2) + ^ (v1_vec_values[b] ^ v1_vec_values[c]) << (64 - R2); + v_xor_b_bit_vec.push(split_value_4bits(v1_vec_values[b] as u128, SPLIT_64BITS)); + + va_bit_vec.push(split_value_4bits( + v1_vec_values[a] as u128 + v1_vec_values[b] as u128 + y as u128, + SPLIT_64BITS + 1, + )); + v2_vec_values[a] = (v1_vec_values[a] as u128 + v1_vec_values[b] as u128 + y as u128) as u64; + + vd_bit_vec.push(split_value_4bits(v1_vec_values[d] as u128, SPLIT_64BITS)); + 
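+    // Second half of G: v[d] := (v[d] ^ v[a]) >>> R3.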
v2_vec_values[d] = ((v1_vec_values[d] ^ v2_vec_values[a]) >> R3) + ^ (v1_vec_values[d] ^ v2_vec_values[a]) << (64 - R3); + v_xor_d_bit_vec.push(split_value_4bits(v2_vec_values[d] as u128, SPLIT_64BITS)); + + vc_bit_vec.push(split_value_4bits( + v1_vec_values[c] as u128 + v2_vec_values[d] as u128, + SPLIT_64BITS + 1, + )); + v2_vec_values[c] = (v1_vec_values[c] as u128 + v2_vec_values[d] as u128) as u64; + + vb_bit_vec.push(split_value_4bits(v1_vec_values[b] as u128, SPLIT_64BITS)); + v2_vec_values[b] = ((v1_vec_values[b] ^ v2_vec_values[c]) >> R4) + ^ (v1_vec_values[b] ^ v2_vec_values[c]) << (64 - R4); + v_xor_b_bit_vec.push(split_value_4bits( + (v1_vec_values[b] ^ v2_vec_values[c]) as u128, + SPLIT_64BITS, + )); + let bits = (v1_vec_values[b] ^ v2_vec_values[c]) / 2u64.pow(60); + b_bit_vec.push(F::from(bits / 8)); + b_3bits_vec.push(F::from(bits % 8)) +} + +fn split_4bit_signals( + ctx: &mut StepTypeSetupContext, + params: &CircuitParams, + input: &[Queriable], + output: &[Vec>], +) { + for (i, split_vec) in output.iter().enumerate() { + let mut sum_value = 0.expr() * 1; + + for &bits in split_vec.iter().rev() { + params.check_4bit(ctx, bits); + sum_value = sum_value * BASE_4BITS + bits; + } + ctx.constr(eq(sum_value, input[i])) + } +} + +// We check G function one time by calling twice g_setup function.c +// Because the G function can be divided into two similar parts. +fn g_setup( + ctx: &mut StepTypeSetupContext<'_, F>, + params: CircuitParams, + q_params: GStepParams, + (a, b, c, d): (usize, usize, usize, usize), + (move1, move2): (u64, u64), + s: usize, + flag: bool, +) { + let mut a_bits_sum_value = 0.expr() * 1; + let mut a_bits_sum_mod_value = 0.expr() * 1; + for (j, &bits) in q_params.v_mid_va_bit_vec.iter().rev().enumerate() { + a_bits_sum_value = a_bits_sum_value * BASE_4BITS + bits; + if j != 0 { + a_bits_sum_mod_value = a_bits_sum_mod_value * BASE_4BITS + bits; + } + params.check_4bit(ctx, bits); + } + // check v_mid_va_bit_vec = 4bit split of v[a] + v[b] + x + ctx.constr(eq( + a_bits_sum_value, + q_params.input_vec[a] + q_params.input_vec[b] + q_params.m_vec[s], + )); + // check v[a] = (v[a] + v[b] + x) mod 2^64 + ctx.constr(eq(a_bits_sum_mod_value, q_params.output_vec[a])); + + // check d_bits_sum_value = 4bit split of v[b] + let mut d_bits_sum_value = 0.expr() * 1; + for &bits in q_params.v_mid_vd_bit_vec.iter().rev() { + d_bits_sum_value = d_bits_sum_value * BASE_4BITS + bits; + params.check_4bit(ctx, bits); + } + ctx.constr(eq(d_bits_sum_value, q_params.input_vec[d])); + + let mut ad_xor_sum_value = 0.expr() * 1; + for &bits in q_params.v_xor_d_bit_vec.iter().rev() { + ad_xor_sum_value = ad_xor_sum_value * BASE_4BITS + bits; + } + // check v_xor_d_bit_vec = 4bit split of v[d] + ctx.constr(eq(ad_xor_sum_value, q_params.output_vec[d])); + // check v_xor_d_bit_vec[i] = (v[d][i] ^ v[a][i]) >>> R1(or R3) + for j in 0..SPLIT_64BITS as usize { + params.check_xor( + ctx, + q_params.v_mid_va_bit_vec[j], + q_params.v_mid_vd_bit_vec[j], + q_params.v_xor_d_bit_vec + [(j + BASE_4BITS as usize - move1 as usize) % BASE_4BITS as usize], + ); + } + + // check v[c] = (v[c] + v[d]) mod 2^64 + let mut c_bits_sum_value = 0.expr() * 1; + let mut c_bits_sum_mod_value = 0.expr() * 1; + for (j, &bits) in q_params.v_mid_vc_bit_vec.iter().rev().enumerate() { + c_bits_sum_value = c_bits_sum_value * BASE_4BITS + bits; + if j != 0 { + c_bits_sum_mod_value = c_bits_sum_mod_value * BASE_4BITS + bits; + } + params.check_4bit(ctx, bits); + } + // check v_mid_vc_bit_vec = 4bit split of (v[c] + v[d]) + 
ctx.constr(eq( + c_bits_sum_value, + q_params.input_vec[c] + q_params.output_vec[d], + )); + // check v[c] = (v[c] + v[d] ) mod 2^64 + ctx.constr(eq(c_bits_sum_mod_value, q_params.output_vec[c])); + + let mut b_bits_sum_value = 0.expr() * 1; + for &bits in q_params.v_mid_vb_bit_vec.iter().rev() { + b_bits_sum_value = b_bits_sum_value * BASE_4BITS + bits; + params.check_4bit(ctx, bits); + } + + // v_mid_vb_bit_vec = 4bit split of v[b] + ctx.constr(eq(b_bits_sum_value, q_params.input_vec[b])); + let mut bc_xor_sum_value = 0.expr() * 1; + for (j, &bits) in q_params.v_xor_b_bit_vec.iter().rev().enumerate() { + if j == 0 && flag { + // b_bit * 8 + b_3bits = v_xor_b_bit_vec[0] + bc_xor_sum_value = q_params.b_3bits * 1; + ctx.constr(eq(q_params.b_bit * 8 + q_params.b_3bits, bits)); + } else { + bc_xor_sum_value = bc_xor_sum_value * BASE_4BITS + bits; + } + params.check_4bit(ctx, bits); + } + if flag { + bc_xor_sum_value = bc_xor_sum_value * 2 + q_params.b_bit; + + ctx.constr(eq(q_params.b_bit * (q_params.b_bit - 1), 0)); + // To constraint b_3bits_vec[i/2] \in [0..8) + params.check_3bit(ctx, q_params.b_3bits); + } + // check v_xor_b_bit_vec = v[b] + ctx.constr(eq(bc_xor_sum_value, q_params.output_vec[b])); + + // check v_xor_b_bit_vec[i] = (v[b][i] ^ v[c][i]) >>> R2(or R4) + for j in 0..SPLIT_64BITS as usize { + params.check_xor( + ctx, + q_params.v_mid_vb_bit_vec[j], + q_params.v_mid_vc_bit_vec[j], + q_params.v_xor_b_bit_vec + [(j + BASE_4BITS as usize - move2 as usize) % BASE_4BITS as usize], + ); + } +} + +fn blake2f_circuit( + ctx: &mut CircuitContext, + params: CircuitParams, +) { + let v_vec: Vec> = (0..V_LEN) + .map(|i| ctx.forward(format!("v_vec[{}]", i).as_str())) + .collect(); + let h_vec: Vec> = (0..H_LEN) + .map(|i| ctx.forward(format!("h_vec[{}]", i).as_str())) + .collect(); + let m_vec: Vec> = (0..M_LEN) + .map(|i| ctx.forward(format!("m_vec[{}]", i).as_str())) + .collect(); + let round = ctx.forward("round"); + + let blake2f_pre_step = ctx.step_type_def("blake2f_pre_step", |ctx| { + let v_vec = v_vec.clone(); + let wg_v_vec = v_vec.clone(); + + let h_vec = h_vec.clone(); + let wg_h_vec = h_vec.clone(); + + let m_vec = m_vec.clone(); + let wg_m_vec = m_vec.clone(); + + let t0 = ctx.internal("t0"); + let t1 = ctx.internal("t1"); + let f = ctx.internal("f"); + + // h_split_4bits_vec = 4bit split of h_vec + let h_split_4bits_vec: Vec>> = (0..H_LEN) + .map(|i| { + (0..SPLIT_64BITS) + .map(|j| ctx.internal(format!("h_split_4bits_vec[{}][{}]", i, j).as_str())) + .collect() + }) + .collect(); + let wg_h_split_4bits_vec = h_split_4bits_vec.clone(); + + // m_split_4bits_vec = 4bit split of m_vec + let m_split_4bits_vec: Vec>> = (0..M_LEN) + .map(|i| { + (0..SPLIT_64BITS) + .map(|j| ctx.internal(format!("m_split_4bits_vec[{}][{}]", i, j).as_str())) + .collect() + }) + .collect(); + let wg_m_split_4bits_vec = m_split_4bits_vec.clone(); + + // t_split_4bits_vec = 4bit split of t0 and t1 + let t_split_4bits_vec: Vec>> = (0..2) + .map(|i| { + (0..SPLIT_64BITS) + .map(|j| ctx.internal(format!("t_split_4bits_vec[{}][{}]", i, j).as_str())) + .collect() + }) + .collect(); + let wg_t_split_4bits_vec = t_split_4bits_vec.clone(); + + // iv_split_4bits_vec = 4bit split of IV[5], IV[6], IV[7] + let iv_split_4bits_vec: Vec>> = (0..3) + .map(|i| { + (0..SPLIT_64BITS) + .map(|j| ctx.internal(format!("iv_split_4bits_vec[{}][{}]", i, j).as_str())) + .collect() + }) + .collect(); + let wg_iv_split_4bits_vec = iv_split_4bits_vec.clone(); + + // final_split_bits_vec = 4bit split of IV[5] xor t0, IV[6] xor t1, 
IV[7] xor + // 0xFFFFFFFFFFFFFFFF, + let final_split_bits_vec: Vec>> = (0..3) + .map(|i| { + (0..SPLIT_64BITS) + .map(|j| ctx.internal(format!("final_split_bits_vec[{}][{}]", i, j).as_str())) + .collect() + }) + .collect(); + let wg_final_split_bits_vec = final_split_bits_vec.clone(); + + ctx.setup(move |ctx| { + // check inputs: h_vec + split_4bit_signals(ctx, ¶ms, &h_vec, &h_split_4bits_vec); + + // check inputs: m_vec + split_4bit_signals(ctx, ¶ms, &m_vec, &m_split_4bits_vec); + + // check inputs: t0,t1 + split_4bit_signals(ctx, ¶ms, &[t0, t1], &t_split_4bits_vec); + + // check input f + ctx.constr(eq(f * (f - 1), 0)); + + // check v_vec + for i in 0..H_LEN { + ctx.constr(eq(v_vec[i], h_vec[i])); + } + for (i, &iv) in v_vec[V_LEN / 2..V_LEN].iter().enumerate() { + params.check_iv(ctx, i, iv); + } + + // check the split-fields of v[12], v[13], v[14] + split_4bit_signals(ctx, ¶ms, &v_vec[12..15], &iv_split_4bits_vec); + + // check v[12] := v[12] ^ (t mod 2**w) + // check v[13] := v[13] ^ (t >> w) + for (i, (final_plit_bits_value, (iv_split_bits_value, t_split_bits_value))) in + final_split_bits_vec + .iter() + .zip(iv_split_4bits_vec.iter().zip(t_split_4bits_vec.iter())) + .enumerate() + .take(2) + { + let mut final_bits_sum_value = 0.expr() * 1; + for (&value, (&iv, &t)) in final_plit_bits_value.iter().rev().zip( + iv_split_bits_value + .iter() + .rev() + .zip(t_split_bits_value.iter().rev()), + ) { + params.check_xor(ctx, iv, t, value); + final_bits_sum_value = final_bits_sum_value * BASE_4BITS + value; + } + ctx.constr(eq(final_bits_sum_value, v_vec[12 + i].next())) + } + + // check if f, v[14] = v[14] ^ 0xffffffffffffffff else v[14] + let mut final_bits_sum_value = 0.expr() * 1; + for (&bits, &iv) in final_split_bits_vec[2] + .iter() + .rev() + .zip(iv_split_4bits_vec[2].iter().rev()) + { + params.check_not(ctx, iv, bits); + final_bits_sum_value = final_bits_sum_value * BASE_4BITS + bits; + } + + // check v_vec v_vec.next + for &v in v_vec.iter().take(12) { + ctx.transition(eq(v, v.next())); + } + ctx.transition(eq( + select(f, final_bits_sum_value, v_vec[14]), + v_vec[14].next(), + )); + ctx.transition(eq(v_vec[15], v_vec[15].next())); + // check h_vec h_vec.next + for &h in h_vec.iter() { + ctx.transition(eq(h, h.next())); + } + // check m_vec m_vec.next + for &m in m_vec.iter() { + ctx.transition(eq(m, m.next())); + } + + ctx.constr(eq(round, 0)); + ctx.transition(eq(round, round.next())); + }); + + ctx.wg(move |ctx, inputs: PreInput| { + ctx.assign(round, inputs.round); + ctx.assign(t0, inputs.t0); + ctx.assign(t1, inputs.t1); + ctx.assign(f, inputs.f); + for (&q, &v) in wg_v_vec.iter().zip(inputs.v_vec.iter()) { + ctx.assign(q, v) + } + for (&q, &v) in wg_h_vec.iter().zip(inputs.h_vec.iter()) { + ctx.assign(q, v) + } + for (&q, &v) in wg_m_vec.iter().zip(inputs.m_vec.iter()) { + ctx.assign(q, v) + } + for (q_vec, v_vec) in wg_h_split_4bits_vec + .iter() + .zip(inputs.h_split_4bits_vec.iter()) + { + for (&q, &v) in q_vec.iter().zip(v_vec.iter()) { + ctx.assign(q, v) + } + } + for (q_vec, v_vec) in wg_m_split_4bits_vec + .iter() + .zip(inputs.m_split_4bits_vec.iter()) + { + for (&q, &v) in q_vec.iter().zip(v_vec.iter()) { + ctx.assign(q, v) + } + } + for (q_vec, v_vec) in wg_t_split_4bits_vec + .iter() + .zip(inputs.t_split_4bits_vec.iter()) + { + for (&q, &v) in q_vec.iter().zip(v_vec.iter()) { + ctx.assign(q, v) + } + } + for (q_vec, v_vec) in wg_iv_split_4bits_vec + .iter() + .zip(inputs.iv_split_4bits_vec.iter()) + { + for (&q, &v) in q_vec.iter().zip(v_vec.iter()) { + 
ctx.assign(q, v) + } + } + for (q_vec, v_vec) in wg_final_split_bits_vec + .iter() + .zip(inputs.final_split_bits_vec.iter()) + { + for (&q, &v) in q_vec.iter().zip(v_vec.iter()) { + ctx.assign(q, v) + } + } + }) + }); + + let blake2f_g_setup_vec: Vec> = (0..MIXING_ROUNDS as usize) + .map(|r| { + ctx.step_type_def(format!("blake2f_g_setup_{}", r), |ctx| { + let v_vec = v_vec.clone(); + let wg_v_vec = v_vec.clone(); + let h_vec = h_vec.clone(); + let wg_h_vec = h_vec.clone(); + let m_vec = m_vec.clone(); + let wg_m_vec = m_vec.clone(); + + // v_mid1_vec is the new v_vec after the first round call to the g_setup function + let v_mid1_vec: Vec> = (0..V_LEN) + .map(|i| ctx.internal(format!("v_mid1_vec[{}]", i).as_str())) + .collect(); + let wg_v_mid1_vec = v_mid1_vec.clone(); + + // v_mid2_vec is the new v_vec after the second round call to the g_setup function + let v_mid2_vec: Vec> = (0..V_LEN) + .map(|i| ctx.internal(format!("v_mid2_vec[{}]", i).as_str())) + .collect(); + let wg_v_mid2_vec = v_mid2_vec.clone(); + + // v_mid3_vec is the new v_vec after the third round to the g_setup function + let v_mid3_vec: Vec> = (0..V_LEN) + .map(|i| ctx.internal(format!("v_mid3_vec[{}]", i).as_str())) + .collect(); + let wg_v_mid3_vec = v_mid3_vec.clone(); + + // v_mid4_vec is the new v_vec after the forth round to the g_setup function,as + // well as the final result of v_vec + let v_mid4_vec: Vec> = (0..V_LEN) + .map(|i| ctx.internal(format!("v_mid4_vec[{}]", i).as_str())) + .collect(); + let wg_v_mid4_vec = v_mid4_vec.clone(); + + // v_mid_va_bit_vec = 4bit split of v[a] + v[b] + x(or y) + let v_mid_va_bit_vec: Vec>> = (0..G_ROUNDS) + .map(|i| { + (0..SPLIT_64BITS + 1) + .map(|j| { + ctx.internal(format!("v_mid_va_bit_vec[{}][{}]", i, j).as_str()) + }) + .collect() + }) + .collect(); + let wg_v_mid_va_bit_vec = v_mid_va_bit_vec.clone(); + + // v_mid_vd_bit_vec = 4bit split of v[d] + let v_mid_vd_bit_vec: Vec>> = (0..G_ROUNDS) + .map(|i| { + (0..SPLIT_64BITS) + .map(|j| { + ctx.internal(format!("v_mid_vd_bit_vec[{}][{}]", i, j).as_str()) + }) + .collect() + }) + .collect(); + let wg_v_mid_vd_bit_vec = v_mid_vd_bit_vec.clone(); + + // v_mid_vc_bit_vec = 4bit split of v[c] + v[d] + let v_mid_vc_bit_vec: Vec>> = (0..G_ROUNDS) + .map(|i| { + (0..SPLIT_64BITS + 1) + .map(|j| { + ctx.internal(format!("v_mid_vc_bit_vec[{}][{}]", i, j).as_str()) + }) + .collect() + }) + .collect(); + let wg_v_mid_vc_bit_vec = v_mid_vc_bit_vec.clone(); + + // v_mid_vb_bit_vec = 4bit split of v[b] + let v_mid_vb_bit_vec: Vec>> = (0..G_ROUNDS) + .map(|i| { + (0..SPLIT_64BITS) + .map(|j| { + ctx.internal(format!("v_mid_vb_bit_vec[{}][{}]", i, j).as_str()) + }) + .collect() + }) + .collect(); + let wg_v_mid_vb_bit_vec = v_mid_vb_bit_vec.clone(); + + // v_xor_d_bit_vec = 4bit split of (v[d] ^ v[a]) >>> R1(or R3) + let v_xor_d_bit_vec: Vec>> = (0..G_ROUNDS) + .map(|i| { + (0..SPLIT_64BITS) + .map(|j| { + ctx.internal(format!("v_xor_d_bit_vec[{}][{}]", i, j).as_str()) + }) + .collect() + }) + .collect(); + let wg_v_xor_d_bit_vec = v_xor_d_bit_vec.clone(); + + // v_xor_b_bit_vec = 4bit split of (v[b] ^ v[c]) >>> R2(or R4) + let v_xor_b_bit_vec: Vec>> = (0..G_ROUNDS) + .map(|i| { + (0..SPLIT_64BITS) + .map(|j| { + ctx.internal(format!("v_xor_b_bit_vec[{}][{}]", i, j).as_str()) + }) + .collect() + }) + .collect(); + let wg_v_xor_b_bit_vec = v_xor_b_bit_vec.clone(); + + // b_bit_vec[i] * 8 + b_3bits_vec[i] = v_xor_b_bit_vec[i * 2 + 1][0] + // the step of v[b] := (v[b] ^ v[c]) >>> R4 needs to split a 4-bit value to a + // one-bit value 
and a 3-bit value + let b_bit_vec: Vec> = (0..G_ROUNDS / 2) + .map(|i| ctx.internal(format!("b_bit_vec[{}]", i).as_str())) + .collect(); + let wg_b_bit_vec = b_bit_vec.clone(); + let b_3bits_vec: Vec> = (0..G_ROUNDS / 2) + .map(|i| ctx.internal(format!("b_3bits_vec[{}]", i).as_str())) + .collect(); + let wg_b_3bits_vec = b_3bits_vec.clone(); + + ctx.setup(move |ctx| { + let s = SIGMA_VALUES[r % 10]; + + for i in 0..G_ROUNDS as usize { + let mut input_vec = v_vec.clone(); + let mut output_vec = v_mid1_vec.clone(); + if i >= 8 { + if i % 2 == 0 { + input_vec = v_mid2_vec.clone(); + output_vec = v_mid3_vec.clone(); + } else { + input_vec = v_mid3_vec.clone(); + output_vec = v_mid4_vec.clone(); + } + } else if i % 2 == 1 { + input_vec = v_mid1_vec.clone(); + output_vec = v_mid2_vec.clone(); + } + let (mut a, mut b, mut c, mut d) = + (i / 2, 4 + i / 2, 8 + i / 2, 12 + i / 2); + if i / 2 == 4 { + (a, b, c, d) = (0, 5, 10, 15); + } else if i / 2 == 5 { + (a, b, c, d) = (1, 6, 11, 12); + } else if i / 2 == 6 { + (a, b, c, d) = (2, 7, 8, 13); + } else if i / 2 == 7 { + (a, b, c, d) = (3, 4, 9, 14); + } + let mut move1 = R1 / 4; + let mut move2 = R2 / 4; + if i % 2 == 1 { + move1 = R3 / 4; + move2 = (R4 + 1) / 4; + } + let q_params = GStepParams { + input_vec, + output_vec, + m_vec: m_vec.clone(), + v_mid_va_bit_vec: v_mid_va_bit_vec[i].clone(), + v_mid_vb_bit_vec: v_mid_vb_bit_vec[i].clone(), + v_mid_vc_bit_vec: v_mid_vc_bit_vec[i].clone(), + v_mid_vd_bit_vec: v_mid_vd_bit_vec[i].clone(), + v_xor_b_bit_vec: v_xor_b_bit_vec[i].clone(), + v_xor_d_bit_vec: v_xor_d_bit_vec[i].clone(), + b_bit: b_bit_vec[i / 2], + b_3bits: b_3bits_vec[i / 2], + }; + g_setup( + ctx, + params, + q_params, + (a, b, c, d), + (move1, move2), + s[i], + i % 2 == 1, + ); + } + + // check v_vec v_vec.next() + for (&v, &new_v) in v_vec.iter().zip(v_mid4_vec.iter()) { + ctx.transition(eq(new_v, v.next())); + } + // check h_vec h_vec.next() + for &h in h_vec.iter() { + ctx.transition(eq(h, h.next())); + } + // check m_vec m_vec.next() + if r < MIXING_ROUNDS as usize - 1 { + for &m in m_vec.iter() { + ctx.transition(eq(m, m.next())); + } + } + ctx.transition(eq(round + 1, round.next())); + }); + + ctx.wg(move |ctx, inputs: GInput| { + ctx.assign(round, inputs.round); + for (&q, &v) in wg_v_vec.iter().zip(inputs.v_vec.iter()) { + ctx.assign(q, v) + } + for (&q, &v) in wg_h_vec.iter().zip(inputs.h_vec.iter()) { + ctx.assign(q, v) + } + for (&q, &v) in wg_m_vec.iter().zip(inputs.m_vec.iter()) { + ctx.assign(q, v) + } + for (&q, &v) in wg_v_mid1_vec.iter().zip(inputs.v_mid1_vec.iter()) { + ctx.assign(q, v) + } + for (&q, &v) in wg_v_mid2_vec.iter().zip(inputs.v_mid2_vec.iter()) { + ctx.assign(q, v) + } + for (&q, &v) in wg_v_mid3_vec.iter().zip(inputs.v_mid3_vec.iter()) { + ctx.assign(q, v) + } + for (&q, &v) in wg_v_mid4_vec.iter().zip(inputs.v_mid4_vec.iter()) { + ctx.assign(q, v) + } + for (q_vec, v_vec) in wg_v_mid_va_bit_vec + .iter() + .zip(inputs.v_mid_va_bit_vec.iter()) + { + for (&q, &v) in q_vec.iter().zip(v_vec.iter()) { + ctx.assign(q, v) + } + } + for (q_vec, v_vec) in wg_v_mid_vb_bit_vec + .iter() + .zip(inputs.v_mid_vb_bit_vec.iter()) + { + for (&q, &v) in q_vec.iter().zip(v_vec.iter()) { + ctx.assign(q, v) + } + } + for (q_vec, v_vec) in wg_v_mid_vc_bit_vec + .iter() + .zip(inputs.v_mid_vc_bit_vec.iter()) + { + for (&q, &v) in q_vec.iter().zip(v_vec.iter()) { + ctx.assign(q, v) + } + } + for (q_vec, v_vec) in wg_v_mid_vd_bit_vec + .iter() + .zip(inputs.v_mid_vd_bit_vec.iter()) + { + for (&q, &v) in 
q_vec.iter().zip(v_vec.iter()) { + ctx.assign(q, v) + } + } + for (q_vec, v_vec) in + wg_v_xor_d_bit_vec.iter().zip(inputs.v_xor_d_bit_vec.iter()) + { + for (&q, &v) in q_vec.iter().zip(v_vec.iter()) { + ctx.assign(q, v) + } + } + for (q_vec, v_vec) in + wg_v_xor_b_bit_vec.iter().zip(inputs.v_xor_b_bit_vec.iter()) + { + for (&q, &v) in q_vec.iter().zip(v_vec.iter()) { + ctx.assign(q, v) + } + } + for (&q, &v) in wg_b_bit_vec.iter().zip(inputs.b_bit_vec.iter()) { + ctx.assign(q, v) + } + for (&q, &v) in wg_b_3bits_vec.iter().zip(inputs.b_3bits_vec.iter()) { + ctx.assign(q, v) + } + }) + }) + }) + .collect(); + + let blake2f_final_step = ctx.step_type_def("blake2f_final_step", |ctx| { + let v_vec = v_vec.clone(); + let wg_v_vec = v_vec.clone(); + + let h_vec = h_vec.clone(); + let wg_h_vec = h_vec.clone(); + + let output_vec = m_vec.clone(); + let wg_output_vec = output_vec.clone(); + + // v_split_bit_vec = 4bit split of v_vec + let v_split_bit_vec: Vec>> = (0..V_LEN) + .map(|i| { + (0..SPLIT_64BITS) + .map(|j| ctx.internal(format!("v_split_bit_vec[{}][{}]", i, j).as_str())) + .collect() + }) + .collect(); + let wg_v_split_bit_vec = v_split_bit_vec.clone(); + + // h_split_bit_vec = 4bit split of h_vec + let h_split_bit_vec: Vec>> = (0..H_LEN) + .map(|i| { + (0..SPLIT_64BITS) + .map(|j| ctx.internal(format!("h_split_bit_vec[{}][{}]", i, j).as_str())) + .collect() + }) + .collect(); + let wg_h_split_bit_vec = h_split_bit_vec.clone(); + + // v_xor_split_bit_vec = 4bit split of v[i] ^ v[i + 8] + let v_xor_split_bit_vec: Vec>> = (0..8) + .map(|i| { + (0..SPLIT_64BITS) + .map(|j| ctx.internal(format!("v_xor_split_bit_vec[{}][{}]", i, j).as_str())) + .collect() + }) + .collect(); + let wg_v_xor_split_bit_vec = v_xor_split_bit_vec.clone(); + + // final_split_bit_vec = 4bit split of h[i] ^ v[i] ^ v[i + 8] + let final_split_bit_vec: Vec>> = (0..8) + .map(|i| { + (0..SPLIT_64BITS) + .map(|j| ctx.internal(format!("v_xor_split_bit_vec[{}][{}]", i, j).as_str())) + .collect() + }) + .collect(); + let wg_final_split_bit_vec = final_split_bit_vec.clone(); + + ctx.setup(move |ctx| { + // check split-fields of v_vec + for (&v, v_split) in v_vec.iter().zip(v_split_bit_vec.iter()) { + let mut v_4bits_sum_value = 0.expr() * 1; + for &bits in v_split.iter().rev() { + v_4bits_sum_value = v_4bits_sum_value * BASE_4BITS + bits; + params.check_4bit(ctx, bits); + } + ctx.constr(eq(v_4bits_sum_value, v)); + } + + // check split-fields of h_vec + for (&h, h_split) in h_vec.iter().zip(h_split_bit_vec.iter()) { + let mut h_4bits_sum_value = 0.expr() * 1; + for &bits in h_split.iter().rev() { + h_4bits_sum_value = h_4bits_sum_value * BASE_4BITS + bits; + params.check_4bit(ctx, bits); + } + ctx.constr(eq(h_4bits_sum_value, h)); + } + + // check split-fields of v[i] ^ v[i+8] + for (xor_vec, (v1_vec, v2_vec)) in v_xor_split_bit_vec.iter().zip( + v_split_bit_vec[0..V_LEN / 2] + .iter() + .zip(v_split_bit_vec[V_LEN / 2..V_LEN].iter()), + ) { + for (&xor, (&v1, &v2)) in xor_vec.iter().zip(v1_vec.iter().zip(v2_vec.iter())) { + params.check_xor(ctx, v1, v2, xor); + } + } + + // check split-fields of h[i] ^ v[i] ^ v[i+8] + for (final_vec, (xor_vec, h_vec)) in final_split_bit_vec + .iter() + .zip(v_xor_split_bit_vec.iter().zip(h_split_bit_vec.iter())) + { + for (&value, (&v1, &v2)) in final_vec.iter().zip(xor_vec.iter().zip(h_vec.iter())) { + params.check_xor(ctx, v1, v2, value); + } + } + + // check output = h[i] ^ v[i] ^ v[i+8] + for (final_vec, &output) in final_split_bit_vec.iter().zip(output_vec.iter()) { + let mut 
final_4bits_sum_value = 0.expr() * 1; + for &value in final_vec.iter().rev() { + final_4bits_sum_value = final_4bits_sum_value * BASE_4BITS + value; + } + ctx.constr(eq(output, final_4bits_sum_value)); + } + ctx.constr(eq(round, MIXING_ROUNDS)); + }); + + ctx.wg(move |ctx, inputs: FinalInput| { + ctx.assign(round, inputs.round); + for (&q, &v) in wg_v_vec.iter().zip(inputs.v_vec.iter()) { + ctx.assign(q, v) + } + for (&q, &v) in wg_h_vec.iter().zip(inputs.h_vec.iter()) { + ctx.assign(q, v) + } + for (&q, &v) in wg_output_vec.iter().zip(inputs.output_vec.iter()) { + ctx.assign(q, v) + } + for (q_vec, v_vec) in wg_v_split_bit_vec.iter().zip(inputs.v_split_bit_vec.iter()) { + for (&q, &v) in q_vec.iter().zip(v_vec.iter()) { + ctx.assign(q, v) + } + } + for (q_vec, v_vec) in wg_h_split_bit_vec.iter().zip(inputs.h_split_bit_vec.iter()) { + for (&q, &v) in q_vec.iter().zip(v_vec.iter()) { + ctx.assign(q, v) + } + } + for (q_vec, v_vec) in wg_v_xor_split_bit_vec + .iter() + .zip(inputs.v_xor_split_bit_vec.iter()) + { + for (&q, &v) in q_vec.iter().zip(v_vec.iter()) { + ctx.assign(q, v) + } + } + for (q_vec, v_vec) in wg_final_split_bit_vec + .iter() + .zip(inputs.final_split_bit_vec.iter()) + { + for (&q, &v) in q_vec.iter().zip(v_vec.iter()) { + ctx.assign(q, v) + } + } + }) + }); + + ctx.pragma_first_step(&blake2f_pre_step); + ctx.pragma_last_step(&blake2f_final_step); + ctx.pragma_num_steps(MIXING_ROUNDS as usize + 2); + + ctx.trace(move |ctx, values| { + let h_vec_values = values.h_vec.to_vec(); + let h_split_4bits_vec = split_to_4bits_values::(&h_vec_values); + + let m_vec_values = values.m_vec.to_vec(); + let m_split_4bits_vec = split_to_4bits_values::(&m_vec_values); + + let mut iv_vec_values = IV_VALUES.to_vec(); + let iv_split_4bits_vec: Vec> = split_to_4bits_values::(&iv_vec_values[4..7]); + + let mut v_vec_values = h_vec_values.clone(); + v_vec_values.append(&mut iv_vec_values); + + let t_split_4bits_vec = split_to_4bits_values::(&[values.t0, values.t1]); + + let final_values = vec![ + v_vec_values[12] ^ values.t0, + v_vec_values[13] ^ values.t1, + v_vec_values[14] ^ 0xFFFFFFFFFFFFFFFF, + ]; + let final_split_bits_vec = split_to_4bits_values::(&final_values); + + let pre_inputs = PreInput { + round: F::ZERO, + t0: F::from(values.t0), + t1: F::from(values.t1), + f: F::from(if values.f { 1 } else { 0 }), + h_vec: h_vec_values.iter().map(|&v| F::from(v)).collect(), + m_vec: m_vec_values.iter().map(|&v| F::from(v)).collect(), + v_vec: v_vec_values.iter().map(|&v| F::from(v)).collect(), + h_split_4bits_vec, + m_split_4bits_vec, + t_split_4bits_vec, + iv_split_4bits_vec, + final_split_bits_vec, + }; + ctx.add(&blake2f_pre_step, pre_inputs); + + v_vec_values[12] = final_values[0]; + v_vec_values[13] = final_values[1]; + if values.f { + v_vec_values[14] = final_values[2]; + } + + for r in 0..values.round { + let s = SIGMA_VALUES[(r as usize) % 10]; + + let mut v_mid1_vec_values = v_vec_values.clone(); + let mut v_mid2_vec_values = v_vec_values.clone(); + let mut v_mid_va_bit_vec = Vec::new(); + let mut v_mid_vb_bit_vec = Vec::new(); + let mut v_mid_vc_bit_vec = Vec::new(); + let mut v_mid_vd_bit_vec = Vec::new(); + let mut v_xor_d_bit_vec = Vec::new(); + let mut v_xor_b_bit_vec = Vec::new(); + let mut b_bit_vec = Vec::new(); + let mut b_3bits_vec = Vec::new(); + + g_wg( + (&mut v_mid1_vec_values, &mut v_mid2_vec_values), + (0, 4, 8, 12), + (m_vec_values[s[0]], m_vec_values[s[1]]), + (&mut v_mid_va_bit_vec, &mut v_mid_vb_bit_vec), + (&mut v_mid_vc_bit_vec, &mut v_mid_vd_bit_vec), + (&mut 
v_xor_d_bit_vec, &mut v_xor_b_bit_vec), + (&mut b_bit_vec, &mut b_3bits_vec), + ); + g_wg( + (&mut v_mid1_vec_values, &mut v_mid2_vec_values), + (1, 5, 9, 13), + (m_vec_values[s[2]], m_vec_values[s[3]]), + (&mut v_mid_va_bit_vec, &mut v_mid_vb_bit_vec), + (&mut v_mid_vc_bit_vec, &mut v_mid_vd_bit_vec), + (&mut v_xor_d_bit_vec, &mut v_xor_b_bit_vec), + (&mut b_bit_vec, &mut b_3bits_vec), + ); + g_wg( + (&mut v_mid1_vec_values, &mut v_mid2_vec_values), + (2, 6, 10, 14), + (m_vec_values[s[4]], m_vec_values[s[5]]), + (&mut v_mid_va_bit_vec, &mut v_mid_vb_bit_vec), + (&mut v_mid_vc_bit_vec, &mut v_mid_vd_bit_vec), + (&mut v_xor_d_bit_vec, &mut v_xor_b_bit_vec), + (&mut b_bit_vec, &mut b_3bits_vec), + ); + g_wg( + (&mut v_mid1_vec_values, &mut v_mid2_vec_values), + (3, 7, 11, 15), + (m_vec_values[s[6]], m_vec_values[s[7]]), + (&mut v_mid_va_bit_vec, &mut v_mid_vb_bit_vec), + (&mut v_mid_vc_bit_vec, &mut v_mid_vd_bit_vec), + (&mut v_xor_d_bit_vec, &mut v_xor_b_bit_vec), + (&mut b_bit_vec, &mut b_3bits_vec), + ); + + let mut v_mid3_vec_values = v_mid2_vec_values.clone(); + let mut v_mid4_vec_values = v_mid2_vec_values.clone(); + g_wg( + (&mut v_mid3_vec_values, &mut v_mid4_vec_values), + (0, 5, 10, 15), + (m_vec_values[s[8]], m_vec_values[s[9]]), + (&mut v_mid_va_bit_vec, &mut v_mid_vb_bit_vec), + (&mut v_mid_vc_bit_vec, &mut v_mid_vd_bit_vec), + (&mut v_xor_d_bit_vec, &mut v_xor_b_bit_vec), + (&mut b_bit_vec, &mut b_3bits_vec), + ); + g_wg( + (&mut v_mid3_vec_values, &mut v_mid4_vec_values), + (1, 6, 11, 12), + (m_vec_values[s[10]], m_vec_values[s[11]]), + (&mut v_mid_va_bit_vec, &mut v_mid_vb_bit_vec), + (&mut v_mid_vc_bit_vec, &mut v_mid_vd_bit_vec), + (&mut v_xor_d_bit_vec, &mut v_xor_b_bit_vec), + (&mut b_bit_vec, &mut b_3bits_vec), + ); + g_wg( + (&mut v_mid3_vec_values, &mut v_mid4_vec_values), + (2, 7, 8, 13), + (m_vec_values[s[12]], m_vec_values[s[13]]), + (&mut v_mid_va_bit_vec, &mut v_mid_vb_bit_vec), + (&mut v_mid_vc_bit_vec, &mut v_mid_vd_bit_vec), + (&mut v_xor_d_bit_vec, &mut v_xor_b_bit_vec), + (&mut b_bit_vec, &mut b_3bits_vec), + ); + g_wg( + (&mut v_mid3_vec_values, &mut v_mid4_vec_values), + (3, 4, 9, 14), + (m_vec_values[s[14]], m_vec_values[s[15]]), + (&mut v_mid_va_bit_vec, &mut v_mid_vb_bit_vec), + (&mut v_mid_vc_bit_vec, &mut v_mid_vd_bit_vec), + (&mut v_xor_d_bit_vec, &mut v_xor_b_bit_vec), + (&mut b_bit_vec, &mut b_3bits_vec), + ); + + let ginputs = GInput { + round: F::from(r as u64), + v_vec: v_vec_values.iter().map(|&v| F::from(v)).collect(), + h_vec: h_vec_values.iter().map(|&v| F::from(v)).collect(), + m_vec: m_vec_values.iter().map(|&v| F::from(v)).collect(), + v_mid1_vec: v_mid1_vec_values.iter().map(|&v| F::from(v)).collect(), + v_mid2_vec: v_mid2_vec_values.iter().map(|&v| F::from(v)).collect(), + v_mid3_vec: v_mid3_vec_values.iter().map(|&v| F::from(v)).collect(), + v_mid4_vec: v_mid4_vec_values.iter().map(|&v| F::from(v)).collect(), + v_mid_va_bit_vec, + v_mid_vb_bit_vec, + v_mid_vc_bit_vec, + v_mid_vd_bit_vec, + v_xor_d_bit_vec, + v_xor_b_bit_vec, + b_bit_vec, + b_3bits_vec, + }; + ctx.add(&blake2f_g_setup_vec[r as usize], ginputs); + v_vec_values = v_mid4_vec_values.clone(); + } + + let output_vec_values: Vec = h_vec_values + .iter() + .zip( + v_vec_values[0..8] + .iter() + .zip(v_vec_values[V_LEN / 2..V_LEN].iter()), + ) + .map(|(h, (v1, v2))| h ^ v1 ^ v2) + .collect(); + + let final_inputs = FinalInput { + round: F::from(values.round as u64), + v_vec: v_vec_values.iter().map(|&v| F::from(v)).collect(), + h_vec: h_vec_values.iter().map(|&v| 
F::from(v)).collect(), + output_vec: output_vec_values.iter().map(|&v| F::from(v)).collect(), + v_split_bit_vec: v_vec_values + .iter() + .map(|&v| split_value_4bits(v as u128, SPLIT_64BITS)) + .collect(), + h_split_bit_vec: h_vec_values + .iter() + .map(|&v| split_value_4bits(v as u128, SPLIT_64BITS)) + .collect(), + v_xor_split_bit_vec: v_vec_values[0..V_LEN / 2] + .iter() + .zip(v_vec_values[V_LEN / 2..V_LEN].iter()) + .map(|(&v1, &v2)| split_xor_value(v1, v2)) + .collect(), + final_split_bit_vec: output_vec_values + .iter() + .map(|&output| split_value_4bits(output as u128, SPLIT_64BITS)) + .collect(), + }; + ctx.add(&blake2f_final_step, final_inputs); + // ba80a53f981c4d0d, 6a2797b69f12f6e9, 4c212f14685ac4b7, 4b12bb6fdbffa2d1 + // 7d87c5392aab792d, c252d5de4533cc95, 18d38aa8dbf1925a,b92386edd4009923 + println!( + "output = {:?} \n {:?}", + u64_to_string(&output_vec_values[0..4].try_into().unwrap()), + u64_to_string(&output_vec_values[4..8].try_into().unwrap()) + ); + }) +} + +fn blake2f_super_circuit() -> SuperCircuit { + super_circuit::("blake2f", |ctx| { + let single_config = config(SingleRowCellManager {}, SimpleStepSelectorBuilder {}); + let (_, iv_table) = ctx.sub_circuit(single_config.clone(), blake2f_iv_table, IV_LEN); + let (_, bits_table) = ctx.sub_circuit( + single_config.clone(), + blake2f_4bits_table, + SPLIT_64BITS as usize, + ); + let (_, xor_4bits_table) = ctx.sub_circuit( + single_config, + blake2f_xor_4bits_table, + (SPLIT_64BITS * SPLIT_64BITS) as usize, + ); + + let maxwidth_config = config( + MaxWidthCellManager::new(250, true), + SimpleStepSelectorBuilder {}, + ); + + let params = CircuitParams { + iv_table, + bits_table, + xor_4bits_table, + }; + let (blake2f, _) = ctx.sub_circuit(maxwidth_config, blake2f_circuit, params); + + ctx.mapping(move |ctx, values| { + ctx.map(&blake2f, values); + }) + }) +} + +fn main() { + let super_circuit = blake2f_super_circuit::(); + let compiled = chiquitoSuperCircuit2Halo2(&super_circuit); + + // h[0] = hex"48c9bdf267e6096a 3ba7ca8485ae67bb 2bf894fe72f36e3c f1361d5f3af54fa5"; + // h[1] = hex"d182e6ad7f520e51 1f6c3e2b8c68059b 6bbd41fbabd9831f 79217e1319cde05b"; + let h0 = string_to_u64([ + "48c9bdf267e6096a", + "3ba7ca8485ae67bb", + "2bf894fe72f36e3c", + "f1361d5f3af54fa5", + ]); + let h1 = string_to_u64([ + "d182e6ad7f520e51", + "1f6c3e2b8c68059b", + "6bbd41fbabd9831f", + "79217e1319cde05b", + ]); + // m[0] = hex"6162630000000000 0000000000000000 0000000000000000 0000000000000000"; + // m[1] = hex"0000000000000000 0000000000000000 0000000000000000 0000000000000000"; + // m[2] = hex"0000000000000000 0000000000000000 0000000000000000 0000000000000000"; + // m[3] = hex"0000000000000000 0000000000000000 0000000000000000 0000000000000000"; + let m0 = string_to_u64([ + "6162630000000000", + "0000000000000000", + "0000000000000000", + "0000000000000000", + ]); + let m1 = string_to_u64([ + "0000000000000000", + "0000000000000000", + "0000000000000000", + "0000000000000000", + ]); + let m2 = string_to_u64([ + "0000000000000000", + "0000000000000000", + "0000000000000000", + "0000000000000000", + ]); + let m3 = string_to_u64([ + "0000000000000000", + "0000000000000000", + "0000000000000000", + "0000000000000000", + ]); + + let values = InputValues { + round: 12, + + h_vec: [ + h0[0], // 0x6a09e667f2bdc948, + h0[1], // 0xbb67ae8584caa73b, + h0[2], // 0x3c6ef372fe94f82b, + h0[3], // 0xa54ff53a5f1d36f1, + h1[0], // 0x510e527fade682d1, + h1[1], // 0x9b05688c2b3e6c1f, + h1[2], // 0x1f83d9abfb41bd6b, + h1[3], // 0x5be0cd19137e2179, + ], // 8 * 
64bits + + m_vec: [ + m0[0], // 0x636261, + m0[1], // 0, + m0[2], // 0, + m0[3], // 0, + m1[0], // 0, + m1[1], // 0, + m1[2], // 0, + m1[3], // 0, + m2[0], // 0, + m2[1], // 0, + m2[2], // 0, + m2[3], // 0, + m3[0], // 0, + m3[1], // 0, + m3[2], // 0, + m3[3], // 0, + ], // 16 * 64bits + t0: 3, // 64bits + t1: 0, // 64bits + f: true, // 8bits + }; + + let circuit = + ChiquitoHalo2SuperCircuit::new(compiled, super_circuit.get_mapping().generate(values)); + + let prover = MockProver::run(9, &circuit, Vec::new()).unwrap(); + let result = prover.verify_par(); + + println!("result = {:#?}", result); + + if let Err(failures) = &result { + for failure in failures.iter() { + println!("{}", failure); + } + } +} diff --git a/src/frontend/dsl/mod.rs b/src/frontend/dsl/mod.rs index 15def6ff..303e66b9 100644 --- a/src/frontend/dsl/mod.rs +++ b/src/frontend/dsl/mod.rs @@ -201,6 +201,12 @@ impl From<&'static str> for StepTypeDefInput { } } +impl From for StepTypeDefInput { + fn from(s: String) -> Self { + StepTypeDefInput::String(Box::leak(s.into_boxed_str())) + } +} + /// A generic structure designed to handle the context of a step type definition. The struct /// contains a `StepType` instance and implements methods to build the step type, add components, /// and manipulate the step type. `F` is a generic type representing the field of the step type. From d139e1b0ccd2759f14b2f88448c441010a23d10d Mon Sep 17 00:00:00 2001 From: Rute Figueiredo Date: Thu, 21 Mar 2024 18:23:40 +0000 Subject: [PATCH 3/8] Feature/improve unit tests coverage (#189) Added more unit tests to the following modules: - poly - ast/query - dsl - dsl/cb - compiler/step_selector - compiler - super_circuit Covers some of these issues: #157 #102 #105 --- src/frontend/dsl/cb.rs | 47 ++++++ src/frontend/dsl/mod.rs | 173 +++++++++++++-------- src/frontend/dsl/sc.rs | 203 +++++++++++++++++++++++++ src/plonkish/compiler/mod.rs | 78 +++++++++- src/plonkish/compiler/step_selector.rs | 93 +++++++++++ src/plonkish/ir/sc.rs | 176 ++++++++++++++++++++- src/poly/mod.rs | 65 ++++++++ src/sbpir/query.rs | 157 +++++++++++++++++++ 8 files changed, 924 insertions(+), 68 deletions(-) diff --git a/src/frontend/dsl/cb.rs b/src/frontend/dsl/cb.rs index 5b2f577f..d41b041b 100644 --- a/src/frontend/dsl/cb.rs +++ b/src/frontend/dsl/cb.rs @@ -716,4 +716,51 @@ mod tests { matches!(v[1], Expr::Const(c) if c == 40u64.field())) && matches!(v[1], Expr::Const(c) if c == 10u64.field()))); } + + #[test] + fn test_constraint_from_queriable() { + // Create a Queriable instance and convert it to a Constraint + let queriable = Queriable::StepTypeNext(StepTypeHandler::new("test_step".to_owned())); + let constraint: Constraint = Constraint::from(queriable); + + assert_eq!(constraint.annotation, "test_step"); + assert!( + matches!(constraint.expr, Expr::Query(Queriable::StepTypeNext(s)) if + matches!(s, StepTypeHandler {id: _id, annotation: "test_step"})) + ); + assert!(matches!(constraint.typing, Typing::Boolean)); + } + + #[test] + fn test_constraint_from_expr() { + // Create an expression and convert it to a Constraint + let expr = >>::expr(&10) * 20u64.expr(); + let constraint: Constraint = Constraint::from(expr); + + // returns "10 * 20" + assert!(matches!(constraint.expr, Expr::Mul(v) if v.len() == 2 && + matches!(v[0], Expr::Const(c) if c == 10u64.field()) && + matches!(v[1], Expr::Const(c) if c == 20u64.field()))); + assert!(matches!(constraint.typing, Typing::Unknown)); + } + + #[test] + fn test_constraint_from_int() { + // Create an integer and convert it to a 
Constraint + let constraint: Constraint = Constraint::from(10); + + // returns "10" + assert!(matches!(constraint.expr, Expr::Const(c) if c == 10u64.field())); + assert!(matches!(constraint.typing, Typing::Unknown)); + } + + #[test] + fn test_constraint_from_bool() { + // Create a boolean and convert it to a Constraint + let constraint: Constraint = Constraint::from(true); + + assert_eq!(constraint.annotation, "0x1"); + assert!(matches!(constraint.expr, Expr::Const(c) if c == 1u64.field())); + assert!(matches!(constraint.typing, Typing::Unknown)); + } } diff --git a/src/frontend/dsl/mod.rs b/src/frontend/dsl/mod.rs index 303e66b9..ce93ca46 100644 --- a/src/frontend/dsl/mod.rs +++ b/src/frontend/dsl/mod.rs @@ -154,6 +154,8 @@ impl CircuitContext { self.circuit.last_step = Some(step_type.into().uuid()); } + /// Enforce the number of step instances by adding a constraint to the circuit. Takes a `usize` + /// parameter that represents the total number of steps. pub fn pragma_num_steps(&mut self, num_steps: usize) { self.circuit.num_steps = num_steps; } @@ -231,6 +233,7 @@ impl StepTypeContext { } /// DEPRECATED + // #[deprecated(note = "use step types setup for constraints instead")] pub fn constr>>(&mut self, constraint: C) { println!("DEPRECATED constr: use setup for constraints in step types"); @@ -241,6 +244,7 @@ impl StepTypeContext { } /// DEPRECATED + #[deprecated(note = "use step types setup for constraints instead")] pub fn transition>>(&mut self, constraint: C) { println!("DEPRECATED transition: use setup for constraints in step types"); @@ -430,28 +434,49 @@ pub mod sc; #[cfg(test)] mod tests { + use crate::sbpir::ForwardSignal; + use super::*; + fn setup_circuit_context() -> CircuitContext + where + F: Default, + TraceArgs: Default, + { + CircuitContext { + circuit: SBPIR::default(), + tables: Default::default(), + } + } + #[test] - fn test_disable_q_enable() { + fn test_circuit_default_initialization() { let circuit: SBPIR = SBPIR::default(); - let mut context = CircuitContext { - circuit, - tables: Default::default(), - }; - context.pragma_disable_q_enable(); + // Assert default values + assert!(circuit.step_types.is_empty()); + assert!(circuit.forward_signals.is_empty()); + assert!(circuit.shared_signals.is_empty()); + assert!(circuit.fixed_signals.is_empty()); + assert!(circuit.exposed.is_empty()); + assert!(circuit.annotations.is_empty()); + assert!(circuit.trace.is_none()); + assert!(circuit.first_step.is_none()); + assert!(circuit.last_step.is_none()); + assert!(circuit.num_steps == 0); + assert!(circuit.q_enable); + } + #[test] + fn test_disable_q_enable() { + let mut context = setup_circuit_context::(); + context.pragma_disable_q_enable(); assert!(!context.circuit.q_enable); } #[test] fn test_set_num_steps() { - let circuit: SBPIR = SBPIR::default(); - let mut context = CircuitContext { - circuit, - tables: Default::default(), - }; + let mut context = setup_circuit_context::(); context.pragma_num_steps(3); assert_eq!(context.circuit.num_steps, 3); @@ -460,14 +485,29 @@ mod tests { assert_eq!(context.circuit.num_steps, 0); } + #[test] + fn test_set_first_step() { + let mut context = setup_circuit_context::(); + + let step_type: StepTypeHandler = context.step_type("step_type"); + + context.pragma_first_step(step_type); + assert_eq!(context.circuit.first_step, Some(step_type.uuid())); + } + + #[test] + fn test_set_last_step() { + let mut context = setup_circuit_context::(); + + let step_type: StepTypeHandler = context.step_type("step_type"); + + 
context.pragma_last_step(step_type); + assert_eq!(context.circuit.last_step, Some(step_type.uuid())); + } + #[test] fn test_forward() { - // create circuit context - let circuit: SBPIR = SBPIR::default(); - let mut context = CircuitContext { - circuit, - tables: Default::default(), - }; + let mut context = setup_circuit_context::(); // set forward signals let forward_a: Queriable = context.forward("forward_a"); @@ -479,14 +519,21 @@ mod tests { assert_eq!(context.circuit.forward_signals[1].uuid(), forward_b.uuid()); } + #[test] + fn test_adding_duplicate_signal_names() { + let mut context = setup_circuit_context::(); + context.forward("duplicate_name"); + context.forward("duplicate_name"); + // Assert how the system should behave. Does it override the previous signal, throw an + // error, or something else? + // TODO: Should we let the user know that they are adding a duplicate signal name? And let + // the circuit have two signals with the same name? + assert_eq!(context.circuit.forward_signals.len(), 2); + } + #[test] fn test_forward_with_phase() { - // create circuit context - let circuit: SBPIR = SBPIR::default(); - let mut context = CircuitContext { - circuit, - tables: Default::default(), - }; + let mut context = setup_circuit_context::(); // set forward signals with specified phase context.forward_with_phase("forward_a", 1); @@ -500,12 +547,7 @@ mod tests { #[test] fn test_shared() { - // create circuit context - let circuit: SBPIR = SBPIR::default(); - let mut context = CircuitContext { - circuit, - tables: Default::default(), - }; + let mut context = setup_circuit_context::(); // set shared signal let shared_a: Queriable = context.shared("shared_a"); @@ -517,12 +559,7 @@ mod tests { #[test] fn test_shared_with_phase() { - // create circuit context - let circuit: SBPIR = SBPIR::default(); - let mut context = CircuitContext { - circuit, - tables: Default::default(), - }; + let mut context = setup_circuit_context::(); // set shared signal with specified phase context.shared_with_phase("shared_a", 2); @@ -534,12 +571,7 @@ mod tests { #[test] fn test_fixed() { - // create circuit context - let circuit: SBPIR = SBPIR::default(); - let mut context = CircuitContext { - circuit, - tables: Default::default(), - }; + let mut context = setup_circuit_context::(); // set fixed signal context.fixed("fixed_a"); @@ -550,12 +582,7 @@ mod tests { #[test] fn test_expose() { - // create circuit context - let circuit: SBPIR = SBPIR::default(); - let mut context = CircuitContext { - circuit, - tables: Default::default(), - }; + let mut context = setup_circuit_context::(); // set forward signal and step to expose let forward_a: Queriable = context.forward("forward_a"); @@ -572,14 +599,21 @@ mod tests { ); } + #[test] + #[ignore] + #[should_panic(expected = "Signal not found")] + fn test_expose_non_existing_signal() { + let mut context = setup_circuit_context::(); + let non_existing_signal = + Queriable::Forward(ForwardSignal::new_with_phase(0, "".to_owned()), false); // Create a signal not added to the circuit + context.expose(non_existing_signal, ExposeOffset::First); + + todo!("remove the ignore after fixing the check for non existing signals") + } + #[test] fn test_step_type() { - // create circuit context - let circuit: SBPIR = SBPIR::default(); - let mut context = CircuitContext { - circuit, - tables: Default::default(), - }; + let mut context = setup_circuit_context::(); // create a step type let handler: StepTypeHandler = context.step_type("fibo_first_step"); @@ -593,12 +627,7 @@ mod tests { 
#[test] fn test_step_type_def() { - // create circuit context - let circuit: SBPIR = SBPIR::default(); - let mut context = CircuitContext { - circuit, - tables: Default::default(), - }; + let mut context = setup_circuit_context::(); // create a step type including its definition let simple_step = context.step_type_def("simple_step", |context| { @@ -619,12 +648,7 @@ mod tests { #[test] fn test_step_type_def_pass_handler() { - // create circuit context - let circuit: SBPIR = SBPIR::default(); - let mut context = CircuitContext { - circuit, - tables: Default::default(), - }; + let mut context = setup_circuit_context::(); // create a step type handler let handler: StepTypeHandler = context.step_type("simple_step"); @@ -645,4 +669,23 @@ mod tests { context.circuit.step_types[&simple_step.uuid()].uuid() ); } + + #[test] + fn test_trace() { + let mut context = setup_circuit_context::(); + + // set trace function + context.trace(|_, _: i32| {}); + + // assert trace function was set + assert!(context.circuit.trace.is_some()); + } + + #[test] + #[should_panic(expected = "circuit cannot have more than one trace generator")] + fn test_setting_trace_multiple_times() { + let mut context = setup_circuit_context::(); + context.trace(|_, _| {}); + context.trace(|_, _| {}); + } } diff --git a/src/frontend/dsl/sc.rs b/src/frontend/dsl/sc.rs index f9ef60d5..8e3fdcd2 100644 --- a/src/frontend/dsl/sc.rs +++ b/src/frontend/dsl/sc.rs @@ -18,6 +18,7 @@ use crate::{ use super::{lb::LookupTableRegistry, CircuitContext}; +#[derive(Debug)] pub struct SuperCircuitContext { super_circuit: SuperCircuit, sub_circuit_phase1: Vec>, @@ -120,3 +121,205 @@ where ctx.compile() } + +#[cfg(test)] +mod tests { + use halo2curves::{bn256::Fr, ff::PrimeField}; + + use crate::{ + plonkish::compiler::{ + cell_manager::SingleRowCellManager, config, step_selector::SimpleStepSelectorBuilder, + }, + poly::ToField, + }; + + use super::*; + + #[test] + fn test_super_circuit_context_default() { + let ctx = SuperCircuitContext::::default(); + + assert_eq!( + format!("{:#?}", ctx.super_circuit), + format!("{:#?}", SuperCircuit::::default()) + ); + assert_eq!( + format!("{:#?}", ctx.sub_circuit_phase1), + format!("{:#?}", Vec::>::default()) + ); + assert_eq!(ctx.sub_circuit_phase1.len(), 0); + assert_eq!( + format!("{:#?}", ctx.tables), + format!("{:#?}", LookupTableRegistry::::default()) + ); + } + + #[test] + fn test_super_circuit_context_sub_circuit() { + let mut ctx = SuperCircuitContext::::default(); + + fn simple_circuit(ctx: &mut CircuitContext, _: ()) { + use crate::frontend::dsl::cb::*; + + let x = ctx.forward("x"); + let y = ctx.forward("y"); + + let step_type = ctx.step_type_def("sum should be 10", |ctx| { + ctx.setup(move |ctx| { + ctx.constr(eq(x + y, 10)); + }); + + ctx.wg(move |ctx, (x_value, y_value): (u32, u32)| { + ctx.assign(x, x_value.field()); + ctx.assign(y, y_value.field()); + }) + }); + + ctx.pragma_num_steps(1); + + ctx.trace(move |ctx, ()| { + ctx.add(&step_type, (2, 8)); + }) + } + + // simple circuit to check if the sum of two inputs are 10 + ctx.sub_circuit( + config(SingleRowCellManager {}, SimpleStepSelectorBuilder {}), + simple_circuit, + (), + ); + + // ensure phase 1 was done correctly for the sub circuit + assert_eq!(ctx.sub_circuit_phase1.len(), 1); + assert_eq!(ctx.sub_circuit_phase1[0].columns.len(), 4); + assert_eq!( + ctx.sub_circuit_phase1[0].columns[0].annotation, + "srcm forward x" + ); + assert_eq!( + ctx.sub_circuit_phase1[0].columns[1].annotation, + "srcm forward y" + ); + 
assert_eq!(ctx.sub_circuit_phase1[0].columns[2].annotation, "q_enable"); + assert_eq!( + ctx.sub_circuit_phase1[0].columns[3].annotation, + "'step selector for sum should be 10'" + ); + assert_eq!(ctx.sub_circuit_phase1[0].forward_signals.len(), 2); + assert_eq!(ctx.sub_circuit_phase1[0].step_types.len(), 1); + assert_eq!(ctx.sub_circuit_phase1[0].compilation_phase, 1); + } + + #[test] + fn test_super_circuit_compile() { + let mut ctx = SuperCircuitContext::::default(); + + fn simple_circuit(ctx: &mut CircuitContext, _: ()) { + use crate::frontend::dsl::cb::*; + + let x = ctx.forward("x"); + let y = ctx.forward("y"); + + let step_type = ctx.step_type_def("sum should be 10", |ctx| { + ctx.setup(move |ctx| { + ctx.constr(eq(x + y, 10)); + }); + + ctx.wg(move |ctx, (x_value, y_value): (u32, u32)| { + ctx.assign(x, x_value.field()); + ctx.assign(y, y_value.field()); + }) + }); + + ctx.pragma_num_steps(1); + + ctx.trace(move |ctx, ()| { + ctx.add(&step_type, (2, 8)); + }) + } + + // simple circuit to check if the sum of two inputs are 10 + ctx.sub_circuit( + config(SingleRowCellManager {}, SimpleStepSelectorBuilder {}), + simple_circuit, + (), + ); + + let super_circuit = ctx.compile(); + + assert_eq!(super_circuit.get_sub_circuits().len(), 1); + assert_eq!(super_circuit.get_sub_circuits()[0].columns.len(), 4); + assert_eq!( + super_circuit.get_sub_circuits()[0].columns[0].annotation, + "srcm forward x" + ); + assert_eq!( + super_circuit.get_sub_circuits()[0].columns[1].annotation, + "srcm forward y" + ); + assert_eq!( + super_circuit.get_sub_circuits()[0].columns[2].annotation, + "q_enable" + ); + assert_eq!( + super_circuit.get_sub_circuits()[0].columns[3].annotation, + "'step selector for sum should be 10'" + ); + } + + #[test] + fn test_super_circuit_sub_circuit_with_ast() { + use crate::frontend::dsl::circuit; + let mut ctx = SuperCircuitContext::::default(); + + let simple_circuit_with_ast = circuit("simple circuit", |ctx| { + use crate::frontend::dsl::cb::*; + + let x = ctx.forward("x"); + let y = ctx.forward("y"); + + let step_type = ctx.step_type_def("sum should be 10", |ctx| { + ctx.setup(move |ctx| { + ctx.constr(eq(x + y, 10)); + }); + + ctx.wg(move |ctx, (x_value, y_value): (u32, u32)| { + ctx.assign(x, x_value.field()); + ctx.assign(y, y_value.field()); + }) + }); + + ctx.pragma_num_steps(1); + + ctx.trace(move |ctx, ()| { + ctx.add(&step_type, (2, 8)); + }); + }); + + ctx.sub_circuit_with_ast( + config(SingleRowCellManager {}, SimpleStepSelectorBuilder {}), + simple_circuit_with_ast, + ); + + let super_circuit = ctx.compile(); + + assert_eq!(super_circuit.get_sub_circuits().len(), 1); + assert_eq!(super_circuit.get_sub_circuits()[0].columns.len(), 4); + assert_eq!( + super_circuit.get_sub_circuits()[0].columns[0].annotation, + "srcm forward x" + ); + assert_eq!( + super_circuit.get_sub_circuits()[0].columns[1].annotation, + "srcm forward y" + ); + assert_eq!( + super_circuit.get_sub_circuits()[0].columns[2].annotation, + "q_enable" + ); + assert_eq!( + super_circuit.get_sub_circuits()[0].columns[3].annotation, + "'step selector for sum should be 10'" + ); + } +} diff --git a/src/plonkish/compiler/mod.rs b/src/plonkish/compiler/mod.rs index 1e40c710..d1a5c9d1 100644 --- a/src/plonkish/compiler/mod.rs +++ b/src/plonkish/compiler/mod.rs @@ -567,10 +567,69 @@ fn add_halo2_columns(unit: &mut CompilationUnit, ast: &astCircu } #[cfg(test)] -mod tests { - use super::*; +mod test { + use halo2_proofs::plonk::Any; use halo2curves::bn256::Fr; + use 
super::{cell_manager::SingleRowCellManager, step_selector::SimpleStepSelectorBuilder, *}; + + #[test] + fn test_compiler_config_initialization() { + let cell_manager = SingleRowCellManager::default(); + let step_selector_builder = SimpleStepSelectorBuilder::default(); + + let config = config(cell_manager.clone(), step_selector_builder.clone()); + + assert_eq!( + format!("{:#?}", config.cell_manager), + format!("{:#?}", cell_manager) + ); + assert_eq!( + format!("{:#?}", config.step_selector_builder), + format!("{:#?}", step_selector_builder) + ); + } + + #[test] + fn test_compile() { + let cell_manager = SingleRowCellManager::default(); + let step_selector_builder = SimpleStepSelectorBuilder::default(); + let config = config(cell_manager, step_selector_builder); + + let mock_ast_circuit = astCircuit::::default(); + + let (circuit, assignment_generator) = compile(config, &mock_ast_circuit); + + assert_eq!(circuit.columns.len(), 1); + assert_eq!(circuit.exposed.len(), 0); + assert_eq!(circuit.polys.len(), 0); + assert_eq!(circuit.lookups.len(), 0); + assert_eq!(circuit.fixed_assignments.len(), 1); + assert_eq!(circuit.ast_id, mock_ast_circuit.id); + + assert!(assignment_generator.is_none()); + } + + #[test] + fn test_compile_phase1() { + let cell_manager = SingleRowCellManager::default(); + let step_selector_builder = SimpleStepSelectorBuilder::default(); + let config = config(cell_manager, step_selector_builder); + + let mock_ast_circuit = astCircuit::::default(); + + let (unit, assignment_generator) = compile_phase1(config, &mock_ast_circuit); + + assert_eq!(unit.columns.len(), 1); + assert_eq!(unit.exposed.len(), 0); + assert_eq!(unit.polys.len(), 0); + assert_eq!(unit.lookups.len(), 0); + assert_eq!(unit.fixed_assignments.len(), 0); + assert_eq!(unit.ast_id, mock_ast_circuit.id); + + assert!(assignment_generator.is_none()); + } + #[test] #[should_panic] fn test_compile_phase2_before_phase1() { @@ -578,4 +637,19 @@ mod tests { compile_phase2(&mut unit); } + + #[test] + fn test_add_default_columns() { + let mock_ast_circuit = astCircuit::::default(); + + let mut unit = CompilationUnit::from(&mock_ast_circuit); + add_default_columns(&mut unit); + + assert_eq!(unit.columns.len(), 1); + assert_eq!(unit.exposed.len(), 0); + assert_eq!(unit.polys.len(), 0); + assert_eq!(unit.lookups.len(), 0); + assert_eq!(unit.fixed_assignments.len(), 0); + assert_eq!(unit.ast_id, mock_ast_circuit.id); + } } diff --git a/src/plonkish/compiler/step_selector.rs b/src/plonkish/compiler/step_selector.rs index b4d88a6f..fbe0eb53 100644 --- a/src/plonkish/compiler/step_selector.rs +++ b/src/plonkish/compiler/step_selector.rs @@ -289,6 +289,99 @@ mod tests { } } + #[test] + fn test_default_step_selector() { + let unit = mock_compilation_unit::(); + assert_eq!(unit.selector.columns.len(), 0); + assert_eq!(unit.selector.selector_expr.len(), 0); + assert_eq!(unit.selector.selector_expr_not.len(), 0); + assert_eq!(unit.selector.selector_assignment.len(), 0); + } + + #[test] + fn test_select_step_selector() { + let mut unit = mock_compilation_unit::(); + let step_type = Rc::new(StepType::new(Uuid::nil().as_u128(), "StepType".to_string())); + unit.step_types.insert(step_type.uuid(), step_type.clone()); + + let builder = SimpleStepSelectorBuilder {}; + builder.build(&mut unit); + + let selector = &unit.selector; + let constraint = PolyExpr::Const(Fr::ONE); + + let step_uuid = step_type.uuid(); + let selector_expr = selector + .selector_expr + .get(&step_uuid) + .expect("Step not found") + .clone(); + let expected_expr = 
PolyExpr::Mul(vec![selector_expr, constraint.clone()]); + + assert_eq!( + format!("{:#?}", selector.select(step_uuid, &constraint)), + format!("{:#?}", expected_expr) + ); + } + + #[test] + fn test_next_step_selector() { + let mut unit = mock_compilation_unit::(); + let step_type = Rc::new(StepType::new(Uuid::nil().as_u128(), "StepType".to_string())); + unit.step_types.insert(step_type.uuid(), step_type.clone()); + + let builder = SimpleStepSelectorBuilder {}; + builder.build(&mut unit); + + let selector = &unit.selector; + let step_uuid = step_type.uuid(); + let step_height = 1; + let expected_expr = selector + .selector_expr + .get(&step_uuid) + .expect("Step not found") + .clone() + .rotate(step_height); + + assert_eq!( + format!("{:#?}", selector.next_expr(step_uuid, step_height as u32)), + format!("{:#?}", expected_expr) + ); + } + + #[test] + fn test_unselect_step_selector() { + let mut unit = mock_compilation_unit::(); + let step_type = Rc::new(StepType::new(Uuid::nil().as_u128(), "StepType".to_string())); + unit.step_types.insert(step_type.uuid(), step_type.clone()); + + let builder = SimpleStepSelectorBuilder {}; + builder.build(&mut unit); + + let selector = &unit.selector; + let step_uuid = step_type.uuid(); + let expected_expr = selector + .selector_expr_not + .get(&step_uuid) + .expect("Step not found") + .clone(); + + assert_eq!( + format!("{:#?}", selector.unselect(step_uuid)), + format!("{:#?}", expected_expr) + ); + } + + #[test] + fn test_simple_step_selector_builder() { + let builder = SimpleStepSelectorBuilder {}; + let mut unit = mock_compilation_unit::(); + + add_step_types_to_unit(&mut unit, 2); + builder.build(&mut unit); + assert_common_tests(&unit, 2); + } + #[test] fn test_log_n_selector_builder_3_step_types() { let builder = LogNSelectorBuilder {}; diff --git a/src/plonkish/ir/sc.rs b/src/plonkish/ir/sc.rs index 05cf4663..9da33c11 100644 --- a/src/plonkish/ir/sc.rs +++ b/src/plonkish/ir/sc.rs @@ -1,4 +1,4 @@ -use std::{collections::HashMap, hash::Hash, rc::Rc}; +use std::{collections::HashMap, fmt::Debug, hash::Hash, rc::Rc}; use crate::{field::Field, sbpir::SBPIR, util::UUID, wit_gen::TraceWitness}; @@ -7,6 +7,7 @@ use super::{ Circuit, }; +#[derive(Debug)] pub struct SuperCircuit { sub_circuits: Vec>, mapping: MappingGenerator, @@ -75,6 +76,7 @@ impl SuperCircuit { pub type SuperAssignments = HashMap>; pub type SuperTraceWitness = HashMap>; +#[derive(Clone)] pub struct MappingContext { assignments: SuperAssignments, trace_witnesses: SuperTraceWitness, @@ -130,6 +132,12 @@ impl Clone for MappingGenerator { } } +impl std::fmt::Debug for MappingGenerator { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "MappingGenerator") + } +} + impl Default for MappingGenerator { fn default() -> Self { Self { @@ -160,3 +168,169 @@ impl MappingGenerator { ctx.get_trace_witnesses() } } + +#[cfg(test)] +mod test { + use halo2curves::bn256::Fr; + + use crate::{ + plonkish::{ + compiler::{cell_manager::Placement, step_selector::StepSelector}, + ir::Column, + }, + util::uuid, + wit_gen::{AutoTraceGenerator, TraceGenerator}, + }; + + use super::*; + + #[test] + fn test_default() { + let super_circuit: SuperCircuit = Default::default(); + + assert_eq!( + format!("{:#?}", super_circuit.sub_circuits), + format!("{:#?}", Vec::>::default()) + ); + assert_eq!( + format!("{:#?}", super_circuit.mapping), + format!("{:#?}", MappingGenerator::::default()) + ); + } + + #[test] + fn test_add_sub_circuit() { + let mut super_circuit: SuperCircuit = 
Default::default(); + + fn simple_circuit() -> Circuit { + let columns = vec![Column::advice('a', 0)]; + let exposed = vec![(Column::advice('a', 0), 2)]; + let polys = vec![]; + let lookups = vec![]; + let fixed_assignments = Default::default(); + + Circuit { + columns, + exposed, + polys, + lookups, + fixed_assignments, + id: uuid(), + ast_id: uuid(), + } + } + + let sub_circuit = simple_circuit(); + + super_circuit.add_sub_circuit(sub_circuit.clone()); + + assert_eq!(super_circuit.sub_circuits.len(), 1); + assert_eq!(super_circuit.sub_circuits[0].id, sub_circuit.id); + assert_eq!(super_circuit.sub_circuits[0].ast_id, sub_circuit.ast_id); + assert_eq!( + format!("{:#?}", super_circuit.sub_circuits[0].columns), + format!("{:#?}", sub_circuit.columns) + ); + assert_eq!( + format!("{:#?}", super_circuit.sub_circuits[0].exposed), + format!("{:#?}", sub_circuit.exposed) + ); + assert_eq!( + format!("{:#?}", super_circuit.sub_circuits[0].polys), + format!("{:#?}", sub_circuit.polys) + ); + assert_eq!( + format!("{:#?}", super_circuit.sub_circuits[0].lookups), + format!("{:#?}", sub_circuit.lookups) + ); + } + + #[test] + fn test_get_sub_circuits() { + fn simple_circuit() -> Circuit { + let columns = vec![Column::advice('a', 0)]; + let exposed = vec![(Column::advice('a', 0), 2)]; + let polys = vec![]; + let lookups = vec![]; + let fixed_assignments = Default::default(); + + Circuit { + columns, + exposed, + polys, + lookups, + fixed_assignments, + id: uuid(), + ast_id: uuid(), + } + } + + let super_circuit: SuperCircuit = SuperCircuit { + sub_circuits: vec![simple_circuit()], + mapping: Default::default(), + sub_circuit_asts: Default::default(), + }; + + let sub_circuits = super_circuit.get_sub_circuits(); + + assert_eq!(sub_circuits.len(), 1); + assert_eq!(sub_circuits[0].id, super_circuit.sub_circuits[0].id); + } + + #[test] + fn test_mapping_context_default() { + let ctx = MappingContext::::default(); + + assert_eq!( + format!("{:#?}", ctx.assignments), + format!("{:#?}", SuperAssignments::::default()) + ); + } + + fn simple_assignment_generator() -> AssignmentGenerator { + AssignmentGenerator::new( + vec![Column::advice('a', 0)], + Placement { + forward: HashMap::new(), + shared: HashMap::new(), + fixed: HashMap::new(), + steps: HashMap::new(), + columns: vec![], + base_height: 0, + }, + StepSelector::default(), + TraceGenerator::default(), + AutoTraceGenerator::default(), + 1, + uuid(), + ) + } + + #[test] + fn test_mapping_context_map() { + let mut ctx = MappingContext::::default(); + + assert_eq!(ctx.assignments.len(), 0); + + let gen = simple_assignment_generator(); + + ctx.map(&gen, ()); + + assert_eq!(ctx.assignments.len(), 1); + } + + #[test] + fn test_mapping_context_map_with_witness() { + let mut ctx = MappingContext::::default(); + + let gen = simple_assignment_generator(); + + let witness = TraceWitness:: { + step_instances: vec![], + }; + + ctx.map_with_witness(&gen, witness); + + assert_eq!(ctx.assignments.len(), 1); + } +} diff --git a/src/poly/mod.rs b/src/poly/mod.rs index fbc61bd0..01c12582 100644 --- a/src/poly/mod.rs +++ b/src/poly/mod.rs @@ -322,4 +322,69 @@ mod test { assert_eq!(experiment.eval(&assignments), None) } + + #[test] + fn test_degree_expr() { + use super::Expr::*; + + let expr: Expr = + (Query("a") * Query("a")) + (Query("c") * Query("d")) - Const(Fr::ONE); + + assert_eq!(expr.degree(), 2); + + let expr: Expr = + (Query("a") * Query("a")) + (Query("c") * Query("d")) * Query("e"); + + assert_eq!(expr.degree(), 3); + } + + #[test] + fn test_expr_sum() { + use 
super::Expr::*; + + let lhs: Expr = Query("a") + Query("b"); + + let rhs: Expr = Query("c") + Query("d"); + + assert_eq!( + format!("({:?} + {:?})", lhs, rhs), + format!("{:?}", Sum(vec![lhs, rhs])) + ); + } + + #[test] + fn test_expr_mul() { + use super::Expr::*; + + let lhs: Expr = Query("a") * Query("b"); + + let rhs: Expr = Query("c") * Query("d"); + + assert_eq!( + format!("({:?} * {:?})", lhs, rhs), + format!("{:?}", Mul(vec![lhs, rhs])) + ); + } + + #[test] + fn test_expr_neg() { + use super::Expr::*; + + let expr: Expr = Query("a") + Query("b"); + + assert_eq!( + format!("(-{:?})", expr), + format!("{:?}", Neg(Box::new(expr))) + ); + + let lhs: Expr = Query("a") * Query("b"); + let rhs: Expr = Query("c") + Query("d"); + + let expr: Expr = lhs.clone() - rhs.clone(); + + assert_eq!( + format!("{:?}", Sum(vec![lhs, Neg(Box::new(rhs))])), + format!("{:?}", expr) + ); + } } diff --git a/src/sbpir/query.rs b/src/sbpir/query.rs index 82cf8f45..5701b0d1 100644 --- a/src/sbpir/query.rs +++ b/src/sbpir/query.rs @@ -211,4 +211,161 @@ mod tests { let expr5: Expr> = Expr::Pow(Box::new(Expr::Const(a)), 2); assert_eq!(format!("{:?}", expr5), "(0xa)^2"); } + + #[test] + fn test_next_for_forward_signal() { + let forward_signal = ForwardSignal { + id: 0, + phase: 0, + annotation: "", + }; + let queriable: Queriable = Queriable::Forward(forward_signal, false); + let next_queriable = queriable.next(); + + assert_eq!(next_queriable, Queriable::Forward(forward_signal, true)); + } + + #[test] + #[should_panic(expected = "jarrl: cannot rotate next(forward)")] + fn test_next_for_forward_signal_panic() { + let forward_signal = ForwardSignal { + id: 0, + phase: 0, + annotation: "", + }; + let queriable: Queriable = Queriable::Forward(forward_signal, true); + let _ = queriable.next(); // This should panic + } + + #[test] + fn test_next_for_shared_signal() { + let shared_signal = SharedSignal { + id: 0, + phase: 0, + annotation: "", + }; + let queriable: Queriable = Queriable::Shared(shared_signal, 0); + let next_queriable = queriable.next(); + + assert_eq!(next_queriable, Queriable::Shared(shared_signal, 1)); + } + + #[test] + fn test_next_for_fixed_signal() { + let fixed_signal = FixedSignal { + id: 0, + annotation: "", + }; + let queriable: Queriable = Queriable::Fixed(fixed_signal, 0); + let next_queriable = queriable.next(); + + assert_eq!(next_queriable, Queriable::Fixed(fixed_signal, 1)); + } + + #[test] + #[should_panic(expected = "can only next a forward, shared, fixed, or halo2 column")] + fn test_next_for_internal_signal_panic() { + let internal_signal = InternalSignal { + id: 0, + annotation: "", + }; + let queriable: Queriable = Queriable::Internal(internal_signal); + let _ = queriable.next(); // This should panic + } + + #[test] + fn test_prev_for_shared_signal() { + let shared_signal = SharedSignal { + id: 0, + phase: 0, + annotation: "", + }; + let queriable: Queriable = Queriable::Shared(shared_signal, 1); + let prev_queriable = queriable.prev(); + + assert_eq!(prev_queriable, Queriable::Shared(shared_signal, 0)); + } + + #[test] + fn test_prev_for_fixed_signal() { + let fixed_signal = FixedSignal { + id: 0, + annotation: "", + }; + let queriable: Queriable = Queriable::Fixed(fixed_signal, 1); + let prev_queriable = queriable.prev(); + + assert_eq!(prev_queriable, Queriable::Fixed(fixed_signal, 0)); + } + + #[test] + #[should_panic(expected = "can only prev a shared or fixed column")] + fn test_prev_for_forward_signal_panic() { + let forward_signal = ForwardSignal { + id: 0, + phase: 0, + 
annotation: "", + }; + let queriable: Queriable = Queriable::Forward(forward_signal, true); + let _ = queriable.prev(); // This should panic + } + + #[test] + #[should_panic(expected = "can only prev a shared or fixed column")] + fn test_prev_for_internal_signal_panic() { + let internal_signal = InternalSignal { + id: 0, + annotation: "", + }; + let queriable: Queriable = Queriable::Internal(internal_signal); + let _ = queriable.prev(); // This should panic + } + + #[test] + fn test_rot_for_shared_signal() { + let shared_signal = SharedSignal { + id: 0, + phase: 0, + annotation: "", + }; + let queriable: Queriable = Queriable::Shared(shared_signal, 1); + let rot_queriable = queriable.rot(2); + + assert_eq!(rot_queriable, Queriable::Shared(shared_signal, 3)); + } + + #[test] + fn test_rot_for_fixed_signal() { + let fixed_signal = FixedSignal { + id: 0, + annotation: "", + }; + let queriable: Queriable = Queriable::Fixed(fixed_signal, 1); + let rot_queriable = queriable.rot(2); + + assert_eq!(rot_queriable, Queriable::Fixed(fixed_signal, 3)); + } + + #[test] + #[should_panic(expected = "can only rot a shared or fixed column")] + fn test_rot_for_forward_signal_panic() { + let forward_signal = ForwardSignal { + id: 0, + phase: 0, + annotation: "", + }; + let queriable: Queriable = Queriable::Forward(forward_signal, true); + let _ = queriable.rot(2); // This should panic + } + + #[test] + #[should_panic(expected = "can only rot a shared or fixed column")] + fn test_rot_for_internal_signal_panic() { + let internal_signal = InternalSignal { + id: 0, + annotation: "", + }; + let queriable: Queriable = Queriable::Internal(internal_signal); + let _ = queriable.rot(2); // This should panic + } } From 8365674f81c724ccfc613312bede6f729ceebfad Mon Sep 17 00:00:00 2001 From: Steve Wang Date: Tue, 26 Mar 2024 07:03:41 -0400 Subject: [PATCH 4/8] HyperPlonk backend (#184) Ready for review @leolara. I asked Han about the best approach to testing, for which he suggested exposing functions from his repo, for which I did over my fork of his repo in its main branch: https://github.com/qwang98/plonkish. 
--------- Co-authored-by: Leo Lara --- .github/workflows/rust.yml | 2 +- Cargo.toml | 2 + examples/fibonacci.rs | 46 ++- rust-toolchain | 2 +- src/pil/backend/powdr_pil.rs | 2 +- src/plonkish/backend/halo2.rs | 23 +- src/plonkish/backend/hyperplonk.rs | 379 +++++++++++++++++++++++++ src/plonkish/backend/mod.rs | 1 + src/plonkish/compiler/step_selector.rs | 2 +- src/plonkish/ir/mod.rs | 23 +- src/poly/mod.rs | 4 +- 11 files changed, 447 insertions(+), 39 deletions(-) create mode 100644 src/plonkish/backend/hyperplonk.rs diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index 40df3e2b..a9f59a2e 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -49,7 +49,7 @@ jobs: uses: actions-rs/toolchain@v1 with: profile: minimal - toolchain: nightly-2023-04-24 + toolchain: nightly-2024-02-14 components: clippy override: true - name: Run Clippy diff --git a/Cargo.toml b/Cargo.toml index 58d1fd0e..5ddc4de0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,6 +20,8 @@ num-bigint = { version = "0.4", features = ["rand"] } uuid = { version = "1.4.0", features = ["v1", "rng"] } serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" +hyperplonk_benchmark = { git = "https://github.com/qwang98/plonkish.git", branch = "main", package = "benchmark" } +plonkish_backend = { git = "https://github.com/qwang98/plonkish.git", branch = "main", package = "plonkish_backend" } regex = "1" [dev-dependencies] diff --git a/examples/fibonacci.rs b/examples/fibonacci.rs index dbf7dbba..2fee596f 100644 --- a/examples/fibonacci.rs +++ b/examples/fibonacci.rs @@ -3,24 +3,29 @@ use std::hash::Hash; use chiquito::{ field::Field, frontend::dsl::circuit, // main function for constructing an AST circuit - plonkish::backend::halo2::{chiquito2Halo2, ChiquitoHalo2Circuit}, /* compiles to + plonkish::{ + backend::{ + halo2::{chiquito2Halo2, ChiquitoHalo2Circuit}, + hyperplonk::ChiquitoHyperPlonkCircuit, + }, + compiler::{ + cell_manager::SingleRowCellManager, // input for constructing the compiler + compile, // input for constructing the compiler + config, + step_selector::SimpleStepSelectorBuilder, + }, + ir::{assignments::AssignmentGenerator, Circuit}, + }, /* compiles to * Chiquito Halo2 * backend, * which can be * integrated into * Halo2 * circuit */ - plonkish::compiler::{ - cell_manager::SingleRowCellManager, // input for constructing the compiler - compile, // input for constructing the compiler - config, - step_selector::SimpleStepSelectorBuilder, - }, - plonkish::ir::{assignments::AssignmentGenerator, Circuit}, // compiled circuit type poly::ToField, sbpir::SBPIR, }; -use halo2_proofs::{dev::MockProver, halo2curves::bn256::Fr}; +use halo2_proofs::dev::MockProver; // the main circuit function: returns the compiled IR of a Chiquito circuit // Generic types F, (), (u64, 64) stand for: @@ -86,7 +91,7 @@ fn fibo_circuit + Hash>() -> FiboReturn { }) }); - ctx.pragma_num_steps(11); + ctx.pragma_num_steps(16); // trace function is responsible for adding step instantiations defined in step_type_def // function above trace function is Turing complete and allows arbitrary user @@ -99,7 +104,7 @@ fn fibo_circuit + Hash>() -> FiboReturn { let mut a = 1; let mut b = 2; - for _i in 1..11 { + for _i in 1..16 { ctx.add(&fibo_step, (a, b)); let prev_a = a; @@ -169,6 +174,25 @@ fn main() { } } + // hyperplonk boilerplate + use hyperplonk_benchmark::proof_system::{bench_plonkish_backend, System}; + use plonkish_backend::{ + backend, + halo2_curves::bn256::{Bn256, Fr}, + pcs::{multilinear, 
univariate}, + }; + // get Chiquito ir + let (circuit, assignment_generator, _) = fibo_circuit::(); + // get assignments + let assignments = assignment_generator.unwrap().generate(()); + // get hyperplonk circuit + let mut hyperplonk_circuit = ChiquitoHyperPlonkCircuit::new(4, circuit); + hyperplonk_circuit.set_assignment(assignments); + + type GeminiKzg = multilinear::Gemini>; + type HyperPlonk = backend::hyperplonk::HyperPlonk; + bench_plonkish_backend::(System::HyperPlonk, 4, &hyperplonk_circuit); + // pil boilerplate use chiquito::pil::backend::powdr_pil::chiquito2Pil; diff --git a/rust-toolchain b/rust-toolchain index b6ce6a50..c5f61037 100644 --- a/rust-toolchain +++ b/rust-toolchain @@ -1 +1 @@ -nightly-2023-04-24 +nightly-2024-02-14 diff --git a/src/pil/backend/powdr_pil.rs b/src/pil/backend/powdr_pil.rs index 4b86c591..0585c45c 100644 --- a/src/pil/backend/powdr_pil.rs +++ b/src/pil/backend/powdr_pil.rs @@ -204,7 +204,7 @@ fn convert_to_pil_expr_string(expr: PILExpr) -> S expr_string += " * "; } } - format!("{}", expr_string) + expr_string.to_string() } PILExpr::Neg(neg) => format!("(-{})", convert_to_pil_expr_string(*neg)), PILExpr::Pow(pow, power) => { diff --git a/src/plonkish/backend/halo2.rs b/src/plonkish/backend/halo2.rs index 7cc0082d..d8dfd6d2 100644 --- a/src/plonkish/backend/halo2.rs +++ b/src/plonkish/backend/halo2.rs @@ -215,25 +215,6 @@ impl + Hash> ChiquitoHalo2 { Ok(()) } - fn instance(&self, witness: &Assignments) -> Vec { - let mut instance_values = Vec::new(); - for (column, rotation) in &self.circuit.exposed { - let values = witness - .get(column) - .unwrap_or_else(|| panic!("exposed column not found: {}", column.annotation)); - - if let Some(value) = values.get(*rotation as usize) { - instance_values.push(*value); - } else { - panic!( - "assignment index out of bounds for column: {}", - column.annotation - ); - } - } - instance_values - } - fn annotate_circuit(&self, region: &mut Region) { for column in self.circuit.columns.iter() { match column.ctype { @@ -379,7 +360,7 @@ impl + Hash> ChiquitoHalo2Circuit { pub fn instance(&self) -> Vec> { if !self.compiled.circuit.exposed.is_empty() { if let Some(witness) = &self.witness { - return vec![self.compiled.instance(witness)]; + return vec![self.compiled.circuit.instance(witness)]; } } Vec::new() @@ -444,7 +425,7 @@ impl + Hash> ChiquitoHalo2SuperCircuit { for sub_circuit in &self.sub_circuits { if !sub_circuit.circuit.exposed.is_empty() { - let instance_values = sub_circuit.instance( + let instance_values = sub_circuit.circuit.instance( self.witness .get(&sub_circuit.ir_id) .expect("No matching witness found for given UUID."), diff --git a/src/plonkish/backend/hyperplonk.rs b/src/plonkish/backend/hyperplonk.rs new file mode 100644 index 00000000..622c1c35 --- /dev/null +++ b/src/plonkish/backend/hyperplonk.rs @@ -0,0 +1,379 @@ +use crate::{ + plonkish::ir::{assignments::Assignments, Circuit, Column, ColumnType, PolyExpr}, + util::UUID, +}; +use halo2_proofs::arithmetic::Field; +use plonkish_backend::{ + backend::{PlonkishCircuit, PlonkishCircuitInfo}, + util::expression::{rotate::Rotation, Expression, Query}, +}; +use std::{collections::HashMap, hash::Hash}; + +// get max phase number + 1 to get number of phases +// for example, if the phases slice is [0, 1, 0, 1, 2, 2], then the output will be 3 +fn num_phases(phases: &[usize]) -> usize { + phases.iter().max().copied().unwrap_or_default() + 1 +} + +// get number of columns for each phase given a vector of phases +// for example, if the phases slice is [0, 1, 0, 
1, 2, 2], then the output vector will be +// [2, 2, 2] +fn num_by_phase(phases: &[usize]) -> Vec { + phases.iter().copied().fold( + vec![0usize; num_phases(phases)], + |mut num_by_phase, phase| { + num_by_phase[phase] += 1; + num_by_phase + }, + ) +} + +// This function maps each element in the phases slice to its index within the circuit, given an +// offset For example, if the phases slice is [0, 1, 0, 1, 2, 2], and the offset is 3, then the +// output vector will be [3, 5, 4, 6, 7, 8], i.e. [3+0+0, 3+2+0, 3+0+1, 3+2+1, 3+4+0, 3+4+1], i.e. +// [offset+phase_offset+index] +fn idx_order_by_phase(phases: &[usize], offset: usize) -> Vec { + phases + .iter() + .copied() + .scan(phase_offsets(phases), |state, phase| { + let index = state[phase]; + state[phase] += 1; + Some(offset + index) + }) + .collect() +} + +// get vector of advice column phases +fn advice_phases(circuit: &Circuit) -> Vec { + circuit + .columns + .iter() + .filter(|column| column.ctype == ColumnType::Advice) + .map(|column| column.phase) + .collect::>() +} + +// This function computes the offsets for each phase. +// For example, if the phases slice is [0, 1, 0, 1, 2, 2], then the output vector will be +// [0, 2, 4]. +fn phase_offsets(phases: &[usize]) -> Vec { + num_by_phase(phases) + .into_iter() + .scan(0, |state, num| { + let offset = *state; + *state += num; + Some(offset) + }) + .collect() +} + +pub struct ChiquitoHyperPlonkCircuit { + circuit: ChiquitoHyperPlonk, + assignments: Option>, +} + +pub struct ChiquitoHyperPlonk { + k: usize, + instances: Vec>, /* outter vec has length 1, inner vec has length equal to number of + * exposed signals */ + chiquito_ir: Circuit, + num_witness_polys: Vec, + all_uuids: Vec, // the same order as self.chiquito_ir.columns + fixed_uuids: Vec, // the same order as self.chiquito_ir.columns + advice_uuids: Vec, // the same order as self.chiquito_ir.columns + advice_uuids_by_phase: HashMap>, +} + +impl + Hash> ChiquitoHyperPlonk { + fn new(k: usize, circuit: Circuit) -> Self { + // get all column uuids + let all_uuids = circuit + .columns + .iter() + .map(|column| column.id) + .collect::>(); + + // get fixed column uuids + let fixed_uuids = circuit + .columns + .iter() + .filter(|column| column.ctype == ColumnType::Fixed) + .map(|column| column.id) + .collect::>(); + + // get advice column uuids (including step selectors) + let advice_uuids = circuit + .columns + .iter() + .filter(|column| column.ctype == ColumnType::Advice) + .map(|column| column.id) + .collect::>(); + + // check that length of all uuid vectors equals length of all columns + assert_eq!( + fixed_uuids.len() + advice_uuids.len(), + circuit.columns.len() + ); + + // get phase number for all advice columns + let advice_phases = advice_phases(&circuit); + // get number of witness polynomials for each phase + let num_witness_polys = num_by_phase(&advice_phases); + + // given non_selector_advice_phases and non_selector_advice_uuids, which have equal lengths, + // create hashmap of phase to vector of uuids if phase doesn't exist in map, create + // a new vector and insert it into map if phase exists in map, insert the uuid to + // the vector associated with the phase + assert_eq!(advice_phases.len(), advice_uuids.len()); + let advice_uuids_by_phase = advice_phases.iter().zip(advice_uuids.iter()).fold( + HashMap::new(), + |mut map: HashMap>, (phase, uuid)| { + map.entry(*phase).or_default().push(*uuid); + map + }, + ); + + Self { + k, + instances: Vec::default(), + chiquito_ir: circuit, + num_witness_polys, + all_uuids, + 
fixed_uuids, + advice_uuids, + advice_uuids_by_phase, + } + } + + fn set_instance(&mut self, instance: Vec>) { + self.instances = instance; + } +} + +impl + Hash> ChiquitoHyperPlonkCircuit { + pub fn new(k: usize, circuit: Circuit) -> Self { + let chiquito_hyper_plonk = ChiquitoHyperPlonk::new(k, circuit); + Self { + circuit: chiquito_hyper_plonk, + assignments: None, + } + } + + pub fn set_assignment(&mut self, assignments: Assignments) { + let instances = vec![self.circuit.chiquito_ir.instance(&assignments)]; + self.circuit.set_instance(instances); + self.assignments = Some(assignments); + } +} + +// given column uuid and the vector of all column uuids, get the index or position of the uuid +// has no offset +fn column_idx(column_uuid: UUID, column_uuids: &[UUID]) -> usize { + column_uuids + .iter() + .position(|&uuid| uuid == column_uuid) + .unwrap() +} + +impl + Hash> PlonkishCircuit for ChiquitoHyperPlonkCircuit { + fn circuit_info_without_preprocess( + &self, + ) -> Result, plonkish_backend::Error> { + // there's only one instance column whose length is equal to the number of exposed signals + // in chiquito circuit `num_instances` is a vector of length 1, because we only have + // one instance column + let num_instances = self.circuit.instances.iter().map(Vec::len).collect(); + + // a vector of zero vectors, each zero vector with 2^k length + // number of preprocess is equal to number of fixed columns + let preprocess_polys = + vec![vec![F::ZERO; 1 << self.circuit.k]; self.circuit.fixed_uuids.len()]; + + let advice_idx = self.circuit.advice_idx(); + let constraints: Vec> = self + .circuit + .chiquito_ir + .polys + .iter() + .map(|poly| { + self.circuit + .convert_expression(poly.expr.clone(), &advice_idx) + }) + .collect(); + + let lookups = self + .circuit + .chiquito_ir + .lookups + .iter() + .map(|lookup| { + lookup + .exprs + .iter() + .map(|(input, table)| { + ( + self.circuit.convert_expression(input.clone(), &advice_idx), + self.circuit.convert_expression(table.clone(), &advice_idx), + ) + }) + .collect() + }) + .collect(); + + let max_degree = constraints + .iter() + .map(|constraint| constraint.degree()) + .max(); + + Ok(PlonkishCircuitInfo { + k: self.circuit.k, + num_instances, + preprocess_polys, + num_witness_polys: self.circuit.num_witness_polys.clone(), + num_challenges: vec![0; self.circuit.num_witness_polys.len()], + constraints, + lookups, + permutations: Default::default(), // Chiquito doesn't have permutations + max_degree, + }) + } + + // preprocess fixed assignments + fn circuit_info( + &self, + ) -> Result, plonkish_backend::Error> { + let mut circuit_info = self.circuit_info_without_preprocess()?; + // make sure all fixed assignments are for fixed column type + self.circuit + .chiquito_ir + .fixed_assignments + .iter() + .for_each(|(column, _)| match column.ctype { + ColumnType::Fixed => (), + _ => panic!("fixed assignments must be for fixed column type"), + }); + // get assignments Vec by looking up from fixed_assignments and reorder assignment + // vectors according to self.fixed_uuids. finally bind all Vec to a Vec>. 
+ // here, get Vec from fixed_assigments: HashMap> by looking up the Column + // with uuid + let fixed_assignments = self + .circuit + .fixed_uuids + .iter() + .map(|uuid| { + self.circuit + .chiquito_ir + .fixed_assignments + .get( + &self.circuit.chiquito_ir.columns + [column_idx(*uuid, &self.circuit.all_uuids)], + ) + .unwrap() + .clone() + }) + .collect::>>(); + + circuit_info.preprocess_polys = fixed_assignments; + + Ok(circuit_info) + } + + fn instances(&self) -> &[Vec] { + &self.circuit.instances + } + + fn synthesize( + &self, + phase: usize, + _challenges: &[F], + ) -> Result>, plonkish_backend::Error> { + let assignments = self.assignments.clone().unwrap(); + + let advice_assignments = self + .circuit + .advice_uuids_by_phase + .get(&phase) + .expect("synthesize: phase not found") + .iter() + .map(|uuid| { + assignments + .get( + &self.circuit.chiquito_ir.columns + [column_idx(*uuid, &self.circuit.all_uuids)], + ) + .unwrap() + .clone() + }) + .collect::>>(); + Ok(advice_assignments) + } +} + +impl ChiquitoHyperPlonk { + fn advice_idx(self: &ChiquitoHyperPlonk) -> Vec { + let advice_offset = self.fixed_uuids.len(); + idx_order_by_phase(&advice_phases(&self.chiquito_ir), advice_offset) + } + + fn convert_query( + self: &ChiquitoHyperPlonk, + column: Column, + rotation: i32, + advice_indx: &[usize], + ) -> Expression { + // if column type is fixed, query column will be determined by column_idx function and + // self.fixed_uuids + // if column type is advice, query column will be + // determined by column_idx function and self.advice_uuids + // advice columns come after fixed columns + if column.ctype == ColumnType::Fixed { + let column_idx = column_idx(column.id, &self.fixed_uuids); + Query::new(column_idx, Rotation(rotation)).into() + } else if column.ctype == ColumnType::Advice { + // advice_idx already takes into account of the offset of fixed columns + let column_idx = advice_indx[column_idx(column.id, &self.advice_uuids)]; + Query::new(column_idx, Rotation(rotation)).into() + } else { + panic!("convert_query: column type not supported") + } + } + + fn convert_expression( + self: &ChiquitoHyperPlonk, + poly: PolyExpr, + advice_idx: &Vec, + ) -> Expression { + match poly { + PolyExpr::Const(constant) => Expression::Constant(constant), + PolyExpr::Query((column, rotation, _)) => { + self.convert_query(column, rotation, advice_idx) + } + PolyExpr::Sum(expressions) => { + let mut iter = expressions.iter(); + let first = self.convert_expression(iter.next().unwrap().clone(), advice_idx); + iter.fold(first, |acc, expression| { + acc + self.convert_expression(expression.clone(), advice_idx) + }) + } + PolyExpr::Mul(expressions) => { + let mut iter = expressions.iter(); + let first = self.convert_expression(iter.next().unwrap().clone(), advice_idx); + iter.fold(first, |acc, expression| { + acc * self.convert_expression(expression.clone(), advice_idx) + }) + } + PolyExpr::Neg(expression) => -self.convert_expression(*expression, advice_idx), /* might need to convert to Expression::Negated */ + PolyExpr::Pow(expression, pow) => { + if pow == 0 { + Expression::Constant(F::ONE) + } else { + let expression = self.convert_expression(*expression, advice_idx); + (1..pow).fold(expression.clone(), |acc, _| acc * expression.clone()) + } + } + PolyExpr::Halo2Expr(_) => panic!("halo2 expressions not supported"), + PolyExpr::MI(_) => panic!("MI expressions not supported"), + } + } +} diff --git a/src/plonkish/backend/mod.rs b/src/plonkish/backend/mod.rs index 84f7c9f2..8d078b29 100644 --- 
a/src/plonkish/backend/mod.rs +++ b/src/plonkish/backend/mod.rs @@ -1,2 +1,3 @@ pub mod halo2; +pub mod hyperplonk; pub mod plaf; diff --git a/src/plonkish/compiler/step_selector.rs b/src/plonkish/compiler/step_selector.rs index fbe0eb53..9207c976 100644 --- a/src/plonkish/compiler/step_selector.rs +++ b/src/plonkish/compiler/step_selector.rs @@ -205,7 +205,7 @@ impl StepSelectorBuilder for LogNSelectorBuilder { let mut annotation; for index in 0..n_cols { - annotation = format!("'binary selector column {}'", index); + annotation = format!("'step selector for binary column {}'", index); let column = Column::advice(annotation.clone(), 0); selector.columns.push(column.clone()); diff --git a/src/plonkish/ir/mod.rs b/src/plonkish/ir/mod.rs index d012f883..a332a1bc 100644 --- a/src/plonkish/ir/mod.rs +++ b/src/plonkish/ir/mod.rs @@ -36,7 +36,28 @@ impl Debug for Circuit { } } -#[derive(Clone, Debug, Hash)] +impl Circuit { + pub(crate) fn instance(&self, witness: &Assignments) -> Vec { + let mut instance_values = Vec::new(); + for (column, rotation) in &self.exposed { + let values = witness + .get(column) + .unwrap_or_else(|| panic!("exposed column not found: {}", column.annotation)); + + if let Some(value) = values.get(*rotation as usize) { + instance_values.push(value.clone()); + } else { + panic!( + "assignment index out of bounds for column: {}", + column.annotation + ); + } + } + instance_values + } +} + +#[derive(Clone, Debug, Hash, PartialEq)] pub enum ColumnType { Advice, Fixed, diff --git a/src/poly/mod.rs b/src/poly/mod.rs index 01c12582..f9adac97 100644 --- a/src/poly/mod.rs +++ b/src/poly/mod.rs @@ -97,10 +97,10 @@ impl Expr { Expr::Const(v) => Some(*v), Expr::Sum(ses) => ses .iter() - .fold(Some(F::ZERO), |acc, se| Some(acc? + se.eval(assignments)?)), + .try_fold(F::ZERO, |acc, se| Some(acc + se.eval(assignments)?)), Expr::Mul(ses) => ses .iter() - .fold(Some(F::ONE), |acc, se| Some(acc? 
* se.eval(assignments)?)), + .try_fold(F::ONE, |acc, se| Some(acc * se.eval(assignments)?)), Expr::Neg(se) => Some(F::ZERO - se.eval(assignments)?), Expr::Pow(se, exp) => Some(se.eval(assignments)?.pow([*exp as u64])), Expr::Query(q) => assignments.get(q).copied(), From 1788d7e7ba4f06e6ba8a4404bd6c43328f7e5e4f Mon Sep 17 00:00:00 2001 From: even <35983442+10to4@users.noreply.github.com> Date: Wed, 27 Mar 2024 17:34:42 +0800 Subject: [PATCH 5/8] Develop keccak (#89) Co-authored-by: Leo Lara --- examples/keccak.rs | 2365 ++++++++++++++++++++++++ examples/poseidon.rs | 2 - src/plonkish/compiler/step_selector.rs | 2 +- 3 files changed, 2366 insertions(+), 3 deletions(-) create mode 100644 examples/keccak.rs diff --git a/examples/keccak.rs b/examples/keccak.rs new file mode 100644 index 00000000..da3bb7eb --- /dev/null +++ b/examples/keccak.rs @@ -0,0 +1,2365 @@ +use chiquito::{ + frontend::dsl::{lb::LookupTable, super_circuit, CircuitContext, StepTypeWGHandler}, + plonkish::{ + backend::halo2::{chiquitoSuperCircuit2Halo2, ChiquitoHalo2SuperCircuit}, + compiler::{ + cell_manager::{MaxWidthCellManager, SingleRowCellManager}, + config, + step_selector::SimpleStepSelectorBuilder, + }, + ir::sc::SuperCircuit, + }, + poly::ToExpr, + sbpir::query::Queriable, +}; +use std::{hash::Hash, ops::Neg}; + +use halo2_proofs::{ + dev::MockProver, + halo2curves::{bn256::Fr, group::ff::PrimeField}, +}; + +use std::{ + fs::File, + io::{self, Write}, +}; + +const BIT_COUNT: u64 = 3; +const PART_SIZE: u64 = 5; +const NUM_BYTES_PER_WORD: u64 = 8; +const NUM_BITS_PER_BYTE: u64 = 8; +const NUM_WORDS_TO_ABSORB: u64 = 17; +const RATE: u64 = NUM_WORDS_TO_ABSORB * NUM_BYTES_PER_WORD; +const NUM_BITS_PER_WORD: u64 = NUM_BYTES_PER_WORD * NUM_BITS_PER_BYTE; +const NUM_PER_WORD: u64 = NUM_BYTES_PER_WORD * NUM_BITS_PER_BYTE / 2; +const RATE_IN_BITS: u64 = RATE * NUM_BITS_PER_BYTE; +const NUM_ROUNDS: u64 = 24; +const BIT_SIZE: usize = 2usize.pow(BIT_COUNT as u32); + +const NUM_PER_WORD_BATCH3: u64 = 22; +const NUM_PER_WORD_BATCH4: u64 = 16; + +const SQUEEZE_VECTOR_NUM: u64 = 4; +const SQUEEZE_SPLIT_NUM: u64 = 16; + +const PART_SIZE_SQURE: u64 = PART_SIZE * PART_SIZE; + +pub const ROUND_CST: [u64; NUM_ROUNDS as usize + 1] = [ + 0x0000000000000001, + 0x0000000000008082, + 0x800000000000808a, + 0x8000000080008000, + 0x000000000000808b, + 0x0000000080000001, + 0x8000000080008081, + 0x8000000000008009, + 0x000000000000008a, + 0x0000000000000088, + 0x0000000080008009, + 0x000000008000000a, + 0x000000008000808b, + 0x800000000000008b, + 0x8000000000008089, + 0x8000000000008003, + 0x8000000000008002, + 0x8000000000000080, + 0x000000000000800a, + 0x800000008000000a, + 0x8000000080008081, + 0x8000000000008080, + 0x0000000080000001, + 0x8000000080008008, + 0x0000000000000000, +]; + +pub const XOR_VALUE_BATCH2: [u64; 36] = [ + 0x0, 0x1, 0x0, 0x1, 0x0, 0x1, 0x8, 0x9, 0x8, 0x9, 0x8, 0x9, 0x0, 0x1, 0x0, 0x1, 0x0, 0x1, 0x8, + 0x9, 0x8, 0x9, 0x8, 0x9, 0x0, 0x1, 0x0, 0x1, 0x0, 0x1, 0x8, 0x9, 0x8, 0x9, 0x8, 0x9, +]; + +pub const XOR_VALUE_BATCH3: [u64; 64] = [ + 0x0, 0x1, 0x0, 0x1, 0x8, 0x9, 0x8, 0x9, 0x0, 0x1, 0x0, 0x1, 0x8, 0x9, 0x8, 0x9, 0x40, 0x41, + 0x40, 0x41, 0x48, 0x49, 0x48, 0x49, 0x40, 0x41, 0x40, 0x41, 0x48, 0x49, 0x48, 0x49, 0x0, 0x1, + 0x0, 0x1, 0x8, 0x9, 0x8, 0x9, 0x0, 0x1, 0x0, 0x1, 0x8, 0x9, 0x8, 0x9, 0x40, 0x41, 0x40, 0x41, + 0x48, 0x49, 0x48, 0x49, 0x40, 0x41, 0x40, 0x41, 0x48, 0x49, 0x48, 0x49, +]; + +pub const XOR_VALUE_BATCH4: [u64; 81] = [ + 0x0, 0x1, 0x0, 0x8, 0x9, 0x8, 0x0, 0x1, 0x0, 0x40, 0x41, 0x40, 0x48, 0x49, 0x48, 
0x40, 0x41, + 0x40, 0x0, 0x1, 0x0, 0x8, 0x9, 0x8, 0x0, 0x1, 0x0, 0x200, 0x201, 0x200, 0x208, 0x209, 0x208, + 0x200, 0x201, 0x200, 0x240, 0x241, 0x240, 0x248, 0x249, 0x248, 0x240, 0x241, 0x240, 0x200, + 0x201, 0x200, 0x208, 0x209, 0x208, 0x200, 0x201, 0x200, 0x0, 0x1, 0x0, 0x8, 0x9, 0x8, 0x0, 0x1, + 0x0, 0x40, 0x41, 0x40, 0x48, 0x49, 0x48, 0x40, 0x41, 0x40, 0x0, 0x1, 0x0, 0x8, 0x9, 0x8, 0x0, + 0x1, 0x0, +]; + +pub const CHI_VALUE: [u64; 125] = [ + 0x0, 0x1, 0x1, 0x0, 0x0, 0x8, 0x9, 0x9, 0x8, 0x8, 0x8, 0x9, 0x9, 0x8, 0x8, 0x0, 0x1, 0x1, 0x0, + 0x0, 0x0, 0x1, 0x1, 0x0, 0x0, 0x40, 0x41, 0x41, 0x40, 0x40, 0x48, 0x49, 0x49, 0x48, 0x48, 0x48, + 0x49, 0x49, 0x48, 0x48, 0x40, 0x41, 0x41, 0x40, 0x40, 0x40, 0x41, 0x41, 0x40, 0x40, 0x40, 0x41, + 0x41, 0x40, 0x40, 0x48, 0x49, 0x49, 0x48, 0x48, 0x48, 0x49, 0x49, 0x48, 0x48, 0x40, 0x41, 0x41, + 0x40, 0x40, 0x40, 0x41, 0x41, 0x40, 0x40, 0x0, 0x1, 0x1, 0x0, 0x0, 0x8, 0x9, 0x9, 0x8, 0x8, + 0x8, 0x9, 0x9, 0x8, 0x8, 0x0, 0x1, 0x1, 0x0, 0x0, 0x0, 0x1, 0x1, 0x0, 0x0, 0x0, 0x1, 0x1, 0x0, + 0x0, 0x8, 0x9, 0x9, 0x8, 0x8, 0x8, 0x9, 0x9, 0x8, 0x8, 0x0, 0x1, 0x1, 0x0, 0x0, 0x0, 0x1, 0x1, + 0x0, 0x0, +]; + +/// Pack bits in the range [0,BIT_SIZE[ into a sparse keccak word +fn pack(bits: &[u8]) -> F { + pack_with_base(bits, BIT_SIZE) +} + +/// Pack bits in the range [0,BIT_SIZE[ into a sparse keccak word with the +/// specified bit base +fn pack_with_base(bits: &[u8], base: usize) -> F { + // \sum 8^i * bit_i + let base = F::from(base as u64); + bits.iter() + .rev() + .fold(F::ZERO, |acc, &bit| acc * base + F::from(bit as u64)) +} + +fn pack_u64(value: u64) -> F { + pack( + &((0..NUM_BITS_PER_WORD) + .map(|i| ((value >> i) & 1) as u8) + .collect::>()), + ) +} + +/// Calculates a ^ b with a and b field elements +fn field_xor>(a: F, b: F) -> F { + let mut bytes = [0u8; 32]; + for (idx, (a, b)) in a + .to_repr() + .as_ref() + .iter() + .zip(b.to_repr().as_ref().iter()) + .enumerate() + { + bytes[idx] = *a ^ *b; + } + F::from_repr(bytes).unwrap() +} + +fn convert_bytes_to_bits(bytes: Vec) -> Vec { + bytes + .iter() + .map(|&byte| { + let mut byte = byte; + (0..8) + .map(|_| { + let b = byte % 2; + byte /= 2; + b + }) + .collect() + }) + .collect::>>() + .concat() +} + +fn convert_field_to_vec_bits(value: F) -> Vec { + let mut v_vec = Vec::new(); + let mut left = 0; + for (idx, &v1) in value.to_repr().as_ref().iter().enumerate() { + if idx % 3 == 0 { + v_vec.push(v1 % 8); + v_vec.push((v1 / 8) % 8); + left = v1 / 64; + } else if idx % 3 == 1 { + v_vec.push((v1 % 2) * 4 + left); + v_vec.push((v1 / 2) % 8); + v_vec.push((v1 / 16) % 8); + left = v1 / 128; + } else { + v_vec.push((v1 % 4) * 2 + left); + v_vec.push((v1 / 4) % 8); + v_vec.push(v1 / 32); + left = 0; + } + } + v_vec[0..64].to_vec() +} + +fn convert_bits_to_f>(value_vec: &[u8]) -> F { + assert_eq!(value_vec.len(), NUM_BITS_PER_WORD as usize); + let mut sum_value_arr: Vec = (0..24) + .map(|t| { + if t % 3 == 0 { + value_vec[(t / 3) * 8] + + value_vec[(t / 3) * 8 + 1] * 8 + + (value_vec[(t / 3) * 8 + 2] % 4) * 64 + } else if t % 3 == 1 { + value_vec[(t / 3) * 8 + 2] / 4 + + value_vec[(t / 3) * 8 + 3] * 2 + + (value_vec[(t / 3) * 8 + 4]) * 16 + + ((value_vec[(t / 3) * 8 + 5]) % 2) * 128 + } else { + value_vec[(t / 3) * 8 + 5] / 2 + + value_vec[(t / 3) * 8 + 6] * 4 + + (value_vec[(t / 3) * 8 + 7]) * 32 + } + }) + .collect(); + while sum_value_arr.len() < 32 { + sum_value_arr.push(0); + } + F::from_repr(sum_value_arr.try_into().unwrap()).unwrap() +} + +fn eval_keccak_f_to_bit_vec4>(value1: F, value2: F) -> Vec<(F, F)> { 
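+    // Split each 64-digit sparse word into NUM_PER_WORD_BATCH4 (= 16) limbs of four base-8 digits
+    // (weights 1, 8, 64, 512), pairing the corresponding limbs of value1 and value2; these pairs
+    // feed the batch-4 XOR lookup rows.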
+ let v1_vec = convert_field_to_vec_bits(value1); + let v2_vec = convert_field_to_vec_bits(value2); + assert_eq!(v1_vec.len(), NUM_BITS_PER_WORD as usize); + assert_eq!(v2_vec.len(), NUM_BITS_PER_WORD as usize); + (0..NUM_PER_WORD_BATCH4 as usize) + .map(|i| { + ( + F::from_u128( + v1_vec[4 * i] as u128 + + v1_vec[4 * i + 1] as u128 * 8 + + v1_vec[4 * i + 2] as u128 * 64 + + v1_vec[4 * i + 3] as u128 * 512, + ), + F::from_u128( + v2_vec[4 * i] as u128 + + v2_vec[4 * i + 1] as u128 * 8 + + v2_vec[4 * i + 2] as u128 * 64 + + v2_vec[4 * i + 3] as u128 * 512, + ), + ) + }) + .collect() +} + +fn keccak_xor_table_batch2( + ctx: &mut CircuitContext, + lens: usize, +) -> LookupTable { + use chiquito::frontend::dsl::cb::*; + + let lookup_xor_row: Queriable = ctx.fixed("xor row(batch 2)"); + let lookup_xor_c: Queriable = ctx.fixed("xor value(batch 2)"); + + let constants_value = XOR_VALUE_BATCH2; + assert_eq!(lens, constants_value.len()); + ctx.pragma_num_steps(lens); + + ctx.fixed_gen(move |ctx| { + for (i, &value) in constants_value.iter().enumerate().take(lens) { + ctx.assign(i, lookup_xor_row, F::from(((i / 6) * 8 + i % 6) as u64)); + ctx.assign(i, lookup_xor_c, F::from(value)); + } + }); + + ctx.new_table(table().add(lookup_xor_row).add(lookup_xor_c)) +} + +fn keccak_xor_table_batch3( + ctx: &mut CircuitContext, + lens: usize, +) -> LookupTable { + use chiquito::frontend::dsl::cb::*; + + let lookup_xor_row: Queriable = ctx.fixed("xor row(batch 3)"); + let lookup_xor_c: Queriable = ctx.fixed("xor value(batch 3)"); + + let constants_value = XOR_VALUE_BATCH3; + assert_eq!(lens, constants_value.len()); + ctx.pragma_num_steps(lens); + ctx.fixed_gen(move |ctx| { + for (i, &value) in constants_value.iter().enumerate().take(lens) { + ctx.assign( + i, + lookup_xor_row, + F::from(((i / 16) * 64 + (i % 16) / 4 * 8 + i % 4) as u64), + ); + ctx.assign(i, lookup_xor_c, F::from(value)); + } + }); + + ctx.new_table(table().add(lookup_xor_row).add(lookup_xor_c)) +} + +fn keccak_xor_table_batch4( + ctx: &mut CircuitContext, + lens: usize, +) -> LookupTable { + use chiquito::frontend::dsl::cb::*; + + let lookup_xor_row: Queriable = ctx.fixed("xor row(batch 4)"); + let lookup_xor_c: Queriable = ctx.fixed("xor value(batch 4)"); + + let constants_value = XOR_VALUE_BATCH4; + assert_eq!(lens, constants_value.len()); + ctx.pragma_num_steps(lens); + ctx.fixed_gen(move |ctx| { + for (i, &value) in constants_value.iter().enumerate().take(lens) { + ctx.assign( + i, + lookup_xor_row, + F::from((i / 27 * 512 + (i % 27) / 9 * 64 + (i % 9) / 3 * 8 + i % 3) as u64), + ); + ctx.assign(i, lookup_xor_c, F::from(value)); + } + }); + + ctx.new_table(table().add(lookup_xor_row).add(lookup_xor_c)) +} + +fn keccak_chi_table( + ctx: &mut CircuitContext, + lens: usize, +) -> LookupTable { + use chiquito::frontend::dsl::cb::*; + + let lookup_chi_row: Queriable = ctx.fixed("chi row"); + let lookup_chi_c: Queriable = ctx.fixed("chi value"); + + let constants_value = CHI_VALUE; + assert_eq!(lens, constants_value.len()); + ctx.pragma_num_steps(lens); + ctx.fixed_gen(move |ctx| { + for (i, &value) in constants_value.iter().enumerate().take(lens) { + ctx.assign( + i, + lookup_chi_row, + F::from(((i / 25) * 64 + (i % 25) / 5 * 8 + i % 5) as u64), + ); + ctx.assign(i, lookup_chi_c, F::from(value)); + } + }); + + ctx.new_table(table().add(lookup_chi_row).add(lookup_chi_c)) +} + +fn keccak_pack_table( + ctx: &mut CircuitContext, + _: usize, +) -> LookupTable { + use chiquito::frontend::dsl::cb::*; + + let lookup_pack_row: Queriable = 
ctx.fixed("pack row"); + let lookup_pack_c: Queriable = ctx.fixed("pack value"); + ctx.pragma_num_steps((SQUEEZE_SPLIT_NUM * SQUEEZE_SPLIT_NUM) as usize); + ctx.fixed_gen(move |ctx| { + for i in 0..SQUEEZE_SPLIT_NUM as usize { + let index = (i / 8) * 512 + (i % 8) / 4 * 64 + (i % 4) / 2 * 8 + i % 2; + for j in 0..SQUEEZE_SPLIT_NUM as usize { + let index_j = (j / 8) * 512 + (j % 8) / 4 * 64 + (j % 4) / 2 * 8 + j % 2; + ctx.assign( + i * SQUEEZE_SPLIT_NUM as usize + j, + lookup_pack_row, + F::from((index * 4096 + index_j) as u64), + ); + ctx.assign( + i * SQUEEZE_SPLIT_NUM as usize + j, + lookup_pack_c, + F::from((i * 16 + j) as u64), + ); + } + } + }); + ctx.new_table(table().add(lookup_pack_row).add(lookup_pack_c)) +} + +fn keccak_round_constants_table( + ctx: &mut CircuitContext, + lens: usize, +) -> LookupTable { + use chiquito::frontend::dsl::cb::*; + + let lookup_constant_row: Queriable = ctx.fixed("constant row"); + let lookup_constant_c: Queriable = ctx.fixed("constant value"); + + let constants_value = ROUND_CST; + ctx.pragma_num_steps(lens); + ctx.fixed_gen(move |ctx| { + for (i, &value) in constants_value.iter().enumerate().take(lens) { + ctx.assign(i, lookup_constant_row, F::from(i as u64)); + ctx.assign(i, lookup_constant_c, pack_u64::(value)); + } + }); + ctx.new_table(table().add(lookup_constant_row).add(lookup_constant_c)) +} + +struct PreValues { + s_vec: Vec, + absorb_rows: Vec, + round_value: F, + absorb_split_vec: Vec>, + absorb_split_input_vec: Vec>, + split_values: Vec>, + is_padding_vec: Vec>, + input_len: F, + data_rlc_vec: Vec>, + data_rlc: F, + input_acc: F, + padded: F, +} + +#[derive(Clone)] +struct SqueezeValues { + s_new_vec: Vec, + squeeze_split_vec: Vec>, + squeeze_split_output_vec: Vec>, + hash_rlc: F, +} + +#[derive(Clone)] +struct OneRoundValues { + round: F, + next_round: F, + round_cst: F, + input_len: F, + input_acc: F, + + s_vec: Vec, + s_new_vec: Vec, + + theta_split_vec: Vec>, + theta_split_xor_vec: Vec>, + theta_sum_split_vec: Vec>, + theta_sum_split_xor_vec: Vec>, + + rho_bit_0: Vec, + rho_bit_1: Vec, + + chi_split_value_vec: Vec>, + + final_sum_split_vec: Vec, + final_xor_split_vec: Vec, + + svalues: SqueezeValues, + data_rlc: F, + padded: F, +} + +fn eval_keccak_f_one_round + Eq + Hash>( + round: u64, + cst: u64, + s_vec: Vec, + input_len: u64, + data_rlc: F, + input_acc: F, + padded: F, +) -> OneRoundValues { + let mut s_new_vec = Vec::new(); + let mut theta_split_vec = Vec::new(); + let mut theta_split_xor_vec = Vec::new(); + let mut theta_sum_split_xor_value_vec = Vec::new(); + let mut theta_sum_split_xor_move_value_vec = Vec::new(); + let mut theta_sum_split_vec = Vec::new(); + let mut theta_sum_split_xor_vec = Vec::new(); + let mut rho_pi_s_new_vec = vec![F::ZERO; PART_SIZE_SQURE as usize]; + let mut rho_bit_0 = vec![F::ZERO; 15]; + let mut rho_bit_1 = vec![F::ZERO; 15]; + let mut chi_sum_value_vec = Vec::new(); + let mut chi_sum_split_value_vec = Vec::new(); + let mut chi_split_value_vec = Vec::new(); + let mut final_sum_split_vec = Vec::new(); + let mut final_xor_split_vec = Vec::new(); + + let mut t_vec = vec![0; PART_SIZE_SQURE as usize]; + { + let mut i: usize = 1; + let mut j: usize = 0; + for t in 0..PART_SIZE_SQURE as usize { + if t == 0 { + i = 0; + j = 0 + } else if t == 1 { + i = 1; + j = 0; + } else { + let m = j; + j = (2 * i + 3 * j) % PART_SIZE as usize; + i = m; + } + t_vec[i * PART_SIZE as usize + j] = t; + } + } + + for i in 0..PART_SIZE as usize { + let sum = s_vec[i * PART_SIZE as usize] + + s_vec[i * PART_SIZE as usize + 
1] + + s_vec[i * PART_SIZE as usize + 2] + + s_vec[i * PART_SIZE as usize + 3] + + s_vec[i * PART_SIZE as usize + 4]; + let sum_bits = convert_field_to_vec_bits(sum); + + let xor: F = field_xor( + field_xor( + field_xor( + field_xor( + s_vec[i * PART_SIZE as usize], + s_vec[i * PART_SIZE as usize + 1], + ), + s_vec[i * PART_SIZE as usize + 2], + ), + s_vec[i * PART_SIZE as usize + 3], + ), + s_vec[i * PART_SIZE as usize + 4], + ); + let xor_bits = convert_field_to_vec_bits(xor); + let mut xor_bits_move = xor_bits.clone(); + xor_bits_move.rotate_right(1); + let xor_rot: F = convert_bits_to_f(&xor_bits_move); + + let mut sum_split = Vec::new(); + let mut sum_split_xor = Vec::new(); + for k in 0..sum_bits.len() / 2 { + if k == sum_bits.len() / 2 - 1 { + sum_split.push(F::from_u128(sum_bits[2 * k] as u128)); + sum_split.push(F::from_u128(sum_bits[2 * k + 1] as u128)); + sum_split_xor.push(F::from_u128(xor_bits[2 * k] as u128)); + sum_split_xor.push(F::from_u128(xor_bits[2 * k + 1] as u128)); + } else { + sum_split.push( + F::from_u128(sum_bits[2 * k] as u128) + + F::from_u128(sum_bits[2 * k + 1] as u128) * F::from_u128(8), + ); + sum_split_xor.push( + F::from_u128(xor_bits[2 * k] as u128) + + F::from_u128(xor_bits[2 * k + 1] as u128) * F::from_u128(8), + ); + } + } + + theta_split_vec.push(sum_split); + theta_split_xor_vec.push(sum_split_xor); + theta_sum_split_xor_value_vec.push(xor); + theta_sum_split_xor_move_value_vec.push(xor_rot); + } + + let mut rho_index = 0; + for i in 0..PART_SIZE as usize { + let xor = theta_sum_split_xor_value_vec[(i + PART_SIZE as usize - 1) % PART_SIZE as usize]; + let xor_rot = theta_sum_split_xor_move_value_vec[(i + 1) % PART_SIZE as usize]; + for j in 0..PART_SIZE as usize { + let v = ((t_vec[i * PART_SIZE as usize + j] + 1) * t_vec[i * PART_SIZE as usize + j] + / 2) + % NUM_BITS_PER_WORD as usize; + let st = s_vec[i * PART_SIZE as usize + j] + xor + xor_rot; + let st_xor = field_xor(field_xor(s_vec[i * PART_SIZE as usize + j], xor), xor_rot); + let mut st_split = Vec::new(); + let mut st_split_xor = Vec::new(); + let mut st_bit_vec = convert_field_to_vec_bits(st); + let mut st_bit_xor_vec = convert_field_to_vec_bits(st_xor); + + // rho + // a[x][y][z] = a[x][y][z-(t+1)(t+2)/2] + if v % 3 == 1 { + rho_bit_0[rho_index] = + F::from(st_bit_vec[1] as u64) * F::from_u128(8) + F::from(st_bit_vec[0] as u64); + rho_bit_1[rho_index] = F::from(st_bit_vec[NUM_BITS_PER_WORD as usize - 1] as u64); + rho_index += 1 + } else if v % 3 == 2 { + rho_bit_0[rho_index] = F::from(st_bit_vec[0] as u64); + rho_bit_1[rho_index] = F::from(st_bit_vec[NUM_BITS_PER_WORD as usize - 1] as u64) + * F::from_u128(8) + + F::from(st_bit_vec[NUM_BITS_PER_WORD as usize - 2] as u64); + rho_index += 1 + } + + st_bit_vec.rotate_right(v); + st_bit_xor_vec.rotate_right(v); + + for i in 0..st_bit_vec.len() / 3 { + st_split.push( + F::from_u128(st_bit_vec[3 * i] as u128) + + F::from_u128(st_bit_vec[3 * i + 1] as u128) * F::from_u128(8) + + F::from_u128(st_bit_vec[3 * i + 2] as u128) * F::from_u128(64), + ); + st_split_xor.push( + F::from_u128(st_bit_xor_vec[3 * i] as u128) + + F::from_u128(st_bit_xor_vec[3 * i + 1] as u128) * F::from_u128(8) + + F::from_u128(st_bit_xor_vec[3 * i + 2] as u128) * F::from_u128(64), + ); + } + st_split.push(F::from_u128( + st_bit_vec[NUM_BITS_PER_WORD as usize - 1] as u128, + )); + st_split_xor.push(F::from_u128( + st_bit_xor_vec[NUM_BITS_PER_WORD as usize - 1] as u128, + )); + + theta_sum_split_vec.push(st_split); + theta_sum_split_xor_vec.push(st_split_xor); + + // pi + // 
a[y][2x + 3y] = a[x][y] + rho_pi_s_new_vec[j * PART_SIZE as usize + ((2 * i + 3 * j) % PART_SIZE as usize)] = + convert_bits_to_f(&st_bit_xor_vec); + } + } + + // chi + // a[x] = a[x] ^ (~a[x+1] & a[x+2]) + for i in 0..PART_SIZE as usize { + for j in 0..PART_SIZE as usize { + let a_vec = convert_field_to_vec_bits(rho_pi_s_new_vec[i * PART_SIZE as usize + j]); + let b_vec = + convert_field_to_vec_bits(rho_pi_s_new_vec[((i + 1) % 5) * PART_SIZE as usize + j]); + let c_vec = + convert_field_to_vec_bits(rho_pi_s_new_vec[((i + 2) % 5) * PART_SIZE as usize + j]); + let sum_vec: Vec = a_vec + .iter() + .zip(b_vec.iter().zip(c_vec.iter())) + .map(|(&a, (&b, &c))| 3 + b - 2 * a - c) + .collect(); + let sum: F = convert_bits_to_f(&sum_vec); + + let split_chi_value: Vec = sum_vec + .iter() + .map(|&v| if v == 1 || v == 2 { 1 } else { 0 }) + .collect(); + let sum_chi = convert_bits_to_f(&split_chi_value); + + let sum_split_vec: Vec = (0..NUM_PER_WORD_BATCH3 as usize) + .map(|i| { + if i == NUM_PER_WORD_BATCH3 as usize - 1 { + F::from_u128(sum_vec[3 * i] as u128) + } else { + F::from_u128( + sum_vec[3 * i] as u128 + + sum_vec[3 * i + 1] as u128 * 8 + + sum_vec[3 * i + 2] as u128 * 64, + ) + } + }) + .collect(); + let chi_split_vec: Vec = (0..NUM_PER_WORD_BATCH3 as usize) + .map(|i| { + if i == NUM_PER_WORD_BATCH3 as usize - 1 { + F::from_u128(split_chi_value[3 * i] as u128) + } else { + F::from_u128( + split_chi_value[3 * i] as u128 + + split_chi_value[3 * i + 1] as u128 * 8 + + split_chi_value[3 * i + 2] as u128 * 64, + ) + } + }) + .collect(); + + chi_sum_value_vec.push(sum); + s_new_vec.push(sum_chi); + chi_sum_split_value_vec.push(sum_split_vec); + chi_split_value_vec.push(chi_split_vec); + } + } + + let s_iota_vec = convert_field_to_vec_bits(s_new_vec[0]); + let cst_vec = convert_field_to_vec_bits(pack_u64::(cst)); + let split_xor_vec: Vec = s_iota_vec + .iter() + .zip(cst_vec.iter()) + .map(|(v1, v2)| v1 ^ v2) + .collect(); + let xor_rows: Vec<(F, F)> = s_iota_vec + .iter() + .zip(cst_vec.iter()) + .map(|(v1, v2)| { + ( + F::from_u128((v1 + v2) as u128), + F::from_u128((v1 ^ v2) as u128), + ) + }) + .collect(); + + for i in 0..NUM_PER_WORD_BATCH4 as usize { + final_sum_split_vec.push( + xor_rows[4 * i].0 + + xor_rows[4 * i + 1].0 * F::from_u128(8) + + xor_rows[4 * i + 2].0 * F::from_u128(64) + + xor_rows[4 * i + 3].0 * F::from_u128(512), + ); + final_xor_split_vec.push( + xor_rows[4 * i].1 + + xor_rows[4 * i + 1].1 * F::from_u128(8) + + xor_rows[4 * i + 2].1 * F::from_u128(64) + + xor_rows[4 * i + 3].1 * F::from_u128(512), + ); + } + + s_new_vec[0] = convert_bits_to_f(&split_xor_vec); + + let svalues = SqueezeValues { + s_new_vec: Vec::new(), + squeeze_split_vec: Vec::new(), + squeeze_split_output_vec: Vec::new(), + hash_rlc: F::ZERO, + }; + + let next_round = if round < NUM_ROUNDS - 1 { round + 1 } else { 0 }; + + OneRoundValues { + round: F::from(round), + round_cst: pack_u64::(cst), + input_len: F::from(input_len), + next_round: F::from(next_round), + + s_vec, + s_new_vec, + + theta_split_vec, + theta_split_xor_vec, + theta_sum_split_vec, + theta_sum_split_xor_vec, + + rho_bit_0, + rho_bit_1, + + chi_split_value_vec, + + final_sum_split_vec, + final_xor_split_vec, + + svalues, + data_rlc, + input_acc, + padded, + } +} + +fn keccak_circuit + Eq + Hash>( + ctx: &mut CircuitContext, + param: CircuitParams, +) { + use chiquito::frontend::dsl::cb::*; + + let s_vec: Vec> = (0..PART_SIZE_SQURE) + .map(|i| ctx.forward(&format!("s[{}][{}]", i / PART_SIZE, i % PART_SIZE))) + .collect(); + + let round = 
ctx.forward("round"); + let data_rlc = ctx.forward("data_rlc"); + + let input_len = ctx.forward("input_len"); + let input_acc = ctx.forward("input_acc"); + + let padded = ctx.forward("padded"); + + let keccak_first_step = ctx.step_type_def("keccak first step", |ctx| { + let s_vec = s_vec.clone(); + let setup_s_vec = s_vec.clone(); + + let absorb_vec: Vec> = (0..NUM_WORDS_TO_ABSORB) + .map(|i| ctx.internal(&format!("absorb_{}", i))) + .collect(); + let setup_absorb_vec = absorb_vec.clone(); + + let absorb_split_vec: Vec>> = (0..NUM_WORDS_TO_ABSORB) + .map(|i| { + (0..8) + .map(|j| ctx.internal(&format!("absorb_split_{}_{}", i, j))) + .collect() + }) + .collect(); + let setup_absorb_split_vec = absorb_split_vec.clone(); + + let absorb_split_input_vec: Vec>> = (0..NUM_WORDS_TO_ABSORB) + .map(|i| { + (0..8) + .map(|j| ctx.internal(&format!("absorb_split_{}_{}", i, j))) + .collect() + }) + .collect(); + let setup_absorb_split_input_vec = absorb_split_input_vec.clone(); + + let is_padding_vec: Vec>> = (0..NUM_WORDS_TO_ABSORB) + .map(|i| { + (0..8) + .map(|j| ctx.internal(&format!("is_padding_{}_{}", i, j))) + .collect() + }) + .collect(); + let setup_is_padding_vec = is_padding_vec.clone(); + + let data_rlc_vec: Vec>> = (0..NUM_WORDS_TO_ABSORB) + .map(|i| { + (0..8) + .map(|j| ctx.internal(&format!("is_padding_{}_{}", i, j))) + .collect() + }) + .collect(); + let setup_data_rlc_vec = data_rlc_vec.clone(); + + ctx.setup(move |ctx| { + for i in 0..PART_SIZE as usize { + for j in 0..PART_SIZE as usize { + ctx.constr(eq(setup_s_vec[i * PART_SIZE as usize + j], 0)); + if j * PART_SIZE as usize + i < NUM_WORDS_TO_ABSORB as usize { + // xor + // 000 xor 000/001 -> 000 + 000/001 + ctx.transition(eq( + setup_s_vec[i * PART_SIZE as usize + j] + + setup_absorb_vec[j * PART_SIZE as usize + i], + setup_s_vec[i * PART_SIZE as usize + j].next(), + )); + + let mut tmp_absorb_split_sum_vec = setup_absorb_split_vec + [j * PART_SIZE as usize + i][SQUEEZE_SPLIT_NUM as usize / 2 - 1] + * 1; + for k in 1..SQUEEZE_SPLIT_NUM as usize / 2 { + tmp_absorb_split_sum_vec = tmp_absorb_split_sum_vec * 4096 * 4096 + + setup_absorb_split_vec[j * PART_SIZE as usize + i] + [SQUEEZE_SPLIT_NUM as usize / 2 - k - 1]; + } + ctx.constr(eq( + setup_absorb_vec[j * PART_SIZE as usize + i], + tmp_absorb_split_sum_vec, + )); + + for k in 0..SQUEEZE_SPLIT_NUM as usize / 2 { + ctx.add_lookup( + param + .pack_table + .apply(setup_absorb_split_vec[j * PART_SIZE as usize + i][k]) + .apply( + setup_absorb_split_input_vec[j * PART_SIZE as usize + i][k], + ), + ); + ctx.constr(eq( + (setup_is_padding_vec[j * PART_SIZE as usize + i][k] - 1) + * setup_is_padding_vec[j * PART_SIZE as usize + i][k], + 0, + )); + } + } else { + ctx.transition(eq( + setup_s_vec[i * PART_SIZE as usize + j], + setup_s_vec[i * PART_SIZE as usize + j].next(), + )); + } + } + } + ctx.constr(eq(data_rlc, 0)); + ctx.transition(eq( + setup_data_rlc_vec[NUM_WORDS_TO_ABSORB as usize - 1] + [SQUEEZE_SPLIT_NUM as usize / 2 - 1], + data_rlc.next(), + )); + let mut acc_value = 0.expr() * 1; + for i in 0..NUM_WORDS_TO_ABSORB as usize { + if i == 0 { + // data_rlc_vec[0][0] = 0 * 256 + absorb_split_input_vec[0][0]; + ctx.constr(eq( + setup_data_rlc_vec[i][0], + (data_rlc * 256 + setup_absorb_split_input_vec[i][0]) + * (1.expr() - setup_is_padding_vec[i][0]) + + data_rlc * setup_is_padding_vec[i][0], + )); + } else { + // data_rlc_vec[0][0] = 0 * 256 + absorb_split_input_vec[0][0]; + ctx.constr(eq( + setup_data_rlc_vec[i][0], + (setup_data_rlc_vec[i - 1][SQUEEZE_SPLIT_NUM as usize / 2 - 
1] * 256 + + setup_absorb_split_input_vec[i][0]) + * (setup_is_padding_vec[i][0] - 1).neg() + + setup_data_rlc_vec[i - 1][SQUEEZE_SPLIT_NUM as usize / 2 - 1] + * setup_is_padding_vec[i][0], + )); + } + + for k in 1..SQUEEZE_SPLIT_NUM as usize / 2 { + ctx.constr(eq( + setup_data_rlc_vec[i][k], + (setup_data_rlc_vec[i][k - 1] * 256 + setup_absorb_split_input_vec[i][k]) + * (setup_is_padding_vec[i][k] - 1).neg() + + setup_data_rlc_vec[i][k - 1] * setup_is_padding_vec[i][k], + )); + } + acc_value = acc_value + (1.expr() - setup_is_padding_vec[i][0]); + if i == 0 { + ctx.constr(eq(setup_is_padding_vec[i][0], 0)); + } else { + ctx.constr(eq( + (setup_is_padding_vec[i][0] - setup_is_padding_vec[i - 1][7]) + * ((setup_is_padding_vec[i][0] - setup_is_padding_vec[i - 1][7]) - 1), + 0, + )); + ctx.constr(eq( + (setup_is_padding_vec[i][0] - setup_is_padding_vec[i - 1][7]) + * (setup_absorb_split_vec[i][0] - 1), + 0, + )); + } + for k in 1..8 { + ctx.constr(eq( + (setup_is_padding_vec[i][k] - setup_is_padding_vec[i][k - 1]) + * ((setup_is_padding_vec[i][k] - setup_is_padding_vec[i][k - 1]) - 1), + 0, + )); + ctx.constr(eq( + (setup_is_padding_vec[i][k] - setup_is_padding_vec[i][k - 1]) + * (setup_absorb_split_vec[i][k] - 1), + 0, + )); + // the last one + if k == 7 && i == NUM_WORDS_TO_ABSORB as usize - 1 { + // the padding length is equal than 1 byte + ctx.constr(eq( + (setup_is_padding_vec[i][k] - setup_is_padding_vec[i][k - 1]) + * (setup_absorb_split_vec[i][k] - 2097153), + 0, + )); + // the padding length is bigger than 1 byte + ctx.constr(eq( + setup_is_padding_vec[i][k - 1] + * (setup_absorb_split_vec[i][k] - 2097152), + 0, + )); + } else { + ctx.constr(eq( + (setup_is_padding_vec[i][k] - setup_is_padding_vec[i][k - 1]) + * (setup_absorb_split_vec[i][k] - 1), + 0, + )); + // the first padding byte = 1, other = 0 + ctx.constr(eq( + setup_is_padding_vec[i][k] + * (setup_is_padding_vec[i][k] + - setup_is_padding_vec[i][k - 1] + - setup_absorb_split_vec[i][k]), + 0, + )); + } + + acc_value = acc_value + (1.expr() - setup_is_padding_vec[i][k]); + } + } + ctx.constr(eq( + (input_len - input_acc - acc_value.clone()) + * setup_is_padding_vec[NUM_WORDS_TO_ABSORB as usize - 1][7], + 0, + )); + ctx.transition(eq(input_acc + acc_value, input_acc.next())); + + ctx.constr(eq(round, 0)); + ctx.transition(eq(round, round.next())); + ctx.transition(eq(input_len, input_len.next())); + ctx.constr(eq(padded, 0)); + ctx.transition(eq( + setup_is_padding_vec[NUM_WORDS_TO_ABSORB as usize - 1][7], + padded.next(), + )); + }); + + ctx.wg::, _>(move |ctx, values| { + for (q, v) in absorb_vec.iter().zip(values.absorb_rows.iter()) { + ctx.assign(*q, *v) + } + + for i in 0..PART_SIZE as usize { + for j in 0..PART_SIZE as usize { + ctx.assign(s_vec[i * PART_SIZE as usize + j], F::ZERO); + } + } + for (q_vec, v_vec) in absorb_split_vec.iter().zip(values.absorb_split_vec.iter()) { + for (q, v) in q_vec.iter().zip(v_vec.iter()) { + ctx.assign(*q, *v) + } + } + for (q_vec, v_vec) in absorb_split_input_vec + .iter() + .zip(values.absorb_split_input_vec.iter()) + { + for (q, v) in q_vec.iter().zip(v_vec.iter()) { + ctx.assign(*q, *v) + } + } + + for (q_vec, v_vec) in is_padding_vec.iter().zip(values.is_padding_vec.iter()) { + for (q, v) in q_vec.iter().zip(v_vec.iter()) { + ctx.assign(*q, *v) + } + } + + for (q_vec, v_vec) in data_rlc_vec.iter().zip(values.data_rlc_vec.iter()) { + for (q, v) in q_vec.iter().zip(v_vec.iter()) { + ctx.assign(*q, *v) + } + } + ctx.assign(round, values.round_value); + ctx.assign(input_len, 
values.input_len); + ctx.assign(data_rlc, values.data_rlc); + ctx.assign(input_acc, values.input_acc); + ctx.assign(padded, values.padded); + }) + }); + + let keccak_pre_step = ctx.step_type_def("keccak pre step", |ctx| { + let s_vec = s_vec.clone(); + let setup_s_vec = s_vec.clone(); + + let absorb_vec: Vec> = (0..NUM_WORDS_TO_ABSORB) + .map(|i| ctx.internal(&format!("absorb_{}", i))) + .collect(); + let setup_absorb_vec = absorb_vec.clone(); + + let absorb_split_vec: Vec>> = (0..NUM_WORDS_TO_ABSORB) + .map(|i| { + (0..8) + .map(|j| ctx.internal(&format!("absorb_split_{}_{}", i, j))) + .collect() + }) + .collect(); + let setup_absorb_split_vec = absorb_split_vec.clone(); + + let absorb_split_input_vec: Vec>> = (0..NUM_WORDS_TO_ABSORB) + .map(|i| { + (0..8) + .map(|j| ctx.internal(&format!("absorb_split_{}_{}", i, j))) + .collect() + }) + .collect(); + let setup_absorb_split_input_vec = absorb_split_input_vec.clone(); + + let sum_split_value_vec: Vec> = (0..PART_SIZE_SQURE) + .map(|i| ctx.internal(&format!("sum_split_value_{}", i))) + .collect(); + let setup_sum_split_value_vec = sum_split_value_vec.clone(); + + let split_vec: Vec>> = (0..NUM_WORDS_TO_ABSORB) + .map(|i| { + (0..NUM_PER_WORD_BATCH4) + .map(|j| ctx.internal(&format!("split_{}_{}", i, j))) + .collect() + }) + .collect(); + let setup_split_vec = split_vec.clone(); + + let split_xor_vec: Vec>> = (0..NUM_WORDS_TO_ABSORB) + .map(|i| { + (0..NUM_PER_WORD_BATCH4) + .map(|j| ctx.internal(&format!("split_xor_{}_{}", i, j))) + .collect() + }) + .collect(); + let setup_split_xor_vec = split_xor_vec.clone(); + + let is_padding_vec: Vec>> = (0..NUM_WORDS_TO_ABSORB) + .map(|i| { + (0..8) + .map(|j| ctx.internal(&format!("is_padding_{}_{}", i, j))) + .collect() + }) + .collect(); + let setup_is_padding_vec = is_padding_vec.clone(); + + let data_rlc_vec: Vec>> = (0..NUM_WORDS_TO_ABSORB) + .map(|i| { + (0..8) + .map(|j| ctx.internal(&format!("data_rlc_{}_{}", i, j))) + .collect() + }) + .collect(); + let setup_data_rlc_vec = data_rlc_vec.clone(); + + ctx.setup(move |ctx| { + for i in 0..PART_SIZE as usize { + for j in 0..PART_SIZE as usize { + if j * PART_SIZE as usize + i < NUM_WORDS_TO_ABSORB as usize { + // xor + ctx.constr(eq( + setup_s_vec[i * PART_SIZE as usize + j] + + setup_absorb_vec[j * PART_SIZE as usize + i], + setup_sum_split_value_vec[i * PART_SIZE as usize + j], + )); + + let mut tmp_absorb_split_sum_vec = setup_absorb_split_vec + [j * PART_SIZE as usize + i][SQUEEZE_SPLIT_NUM as usize / 2 - 1] + * 1; + for k in 1..SQUEEZE_SPLIT_NUM as usize / 2 { + tmp_absorb_split_sum_vec = tmp_absorb_split_sum_vec * 4096 * 4096 + + setup_absorb_split_vec[j * PART_SIZE as usize + i] + [SQUEEZE_SPLIT_NUM as usize / 2 - k - 1]; + } + ctx.constr(eq( + setup_absorb_vec[j * PART_SIZE as usize + i], + tmp_absorb_split_sum_vec, + )); + for k in 0..SQUEEZE_SPLIT_NUM as usize / 2 { + ctx.add_lookup( + param + .pack_table + .apply(setup_absorb_split_vec[j * PART_SIZE as usize + i][k]) + .apply( + setup_absorb_split_input_vec[j * PART_SIZE as usize + i][k], + ), + ); + } + + for k in 0..NUM_PER_WORD_BATCH4 as usize { + ctx.add_lookup( + param + .xor_table_batch4 + .apply(setup_split_vec[j * PART_SIZE as usize + i][k]) + .apply(setup_split_xor_vec[j * PART_SIZE as usize + i][k]), + ); + } + } else { + ctx.transition(eq( + setup_s_vec[i * PART_SIZE as usize + j], + setup_s_vec[i * PART_SIZE as usize + j].next(), + )); + } + } + } + + ctx.transition(eq( + setup_data_rlc_vec[NUM_WORDS_TO_ABSORB as usize - 1] + [SQUEEZE_SPLIT_NUM as usize / 2 - 1], + 
data_rlc.next(), + )); + + let mut acc_value = 0.expr() * 1; + for i in 0..NUM_WORDS_TO_ABSORB as usize { + if i == 0 { + // data_rlc_vec[0][0] = 0 * 256 + absorb_split_input_vec[0][0]; + ctx.constr(eq( + setup_data_rlc_vec[i][0], + (data_rlc * 256 + setup_absorb_split_input_vec[i][0]) + * (setup_is_padding_vec[i][0] - 1).neg() + + data_rlc * setup_is_padding_vec[i][0], + )); + } else { + // data_rlc_vec[0][0] = 0 * 256 + absorb_split_input_vec[0][0]; + ctx.constr(eq( + setup_data_rlc_vec[i][0], + (setup_data_rlc_vec[i - 1][SQUEEZE_SPLIT_NUM as usize / 2 - 1] * 256 + + setup_absorb_split_input_vec[i][0]) + * (setup_is_padding_vec[i][0] - 1).neg() + + setup_data_rlc_vec[i - 1][SQUEEZE_SPLIT_NUM as usize / 2 - 1] + * setup_is_padding_vec[i][0], + )); + } + for k in 1..SQUEEZE_SPLIT_NUM as usize / 2 { + ctx.constr(eq( + setup_data_rlc_vec[i][k], + (setup_data_rlc_vec[i][k - 1] * 256 + setup_absorb_split_input_vec[i][k]) + * (setup_is_padding_vec[i][k] - 1).neg() + + setup_data_rlc_vec[i][k - 1] * setup_is_padding_vec[i][k], + )); + } + + acc_value = acc_value + (1.expr() - setup_is_padding_vec[i][0]); + if i == 0 { + ctx.constr(eq( + setup_is_padding_vec[i][0] * (setup_is_padding_vec[i][0] - 1), + 0, + )); + } else { + ctx.constr(eq( + (setup_is_padding_vec[i][0] - setup_is_padding_vec[i - 1][7]) + * ((setup_is_padding_vec[i][0] - setup_is_padding_vec[i - 1][7]) - 1), + 0, + )); + ctx.constr(eq( + (setup_is_padding_vec[i][0] - setup_is_padding_vec[i - 1][7]) + * (setup_absorb_split_vec[i][0] - 1), + 0, + )); + } + for k in 1..8 { + ctx.constr(eq( + (setup_is_padding_vec[i][k] - setup_is_padding_vec[i][k - 1]) + * ((setup_is_padding_vec[i][k] - setup_is_padding_vec[i][k - 1]) - 1), + 0, + )); + ctx.constr(eq( + (setup_is_padding_vec[i][k] - setup_is_padding_vec[i][k - 1]) + * (setup_absorb_split_vec[i][k] - 1), + 0, + )); + + if k == 7 && i == NUM_WORDS_TO_ABSORB as usize - 1 { + ctx.constr(eq( + (setup_is_padding_vec[i][k] - setup_is_padding_vec[i][k - 1]) + * (setup_absorb_split_vec[i][k] - 2097153), + 0, + )); + ctx.constr(eq( + setup_is_padding_vec[i][k - 1] + * (setup_absorb_split_vec[i][k] - 2097152), + 0, + )); + } else { + ctx.constr(eq( + (setup_is_padding_vec[i][k] - setup_is_padding_vec[i][k - 1]) + * (setup_absorb_split_vec[i][k] - 1), + 0, + )); + ctx.constr(eq( + setup_is_padding_vec[i][k] + * (setup_is_padding_vec[i][k] + - setup_is_padding_vec[i][k - 1] + - setup_absorb_split_vec[i][k]), + 0, + )); + } + acc_value = acc_value + (1.expr() - setup_is_padding_vec[i][k]); + } + } + + for s in 0..NUM_WORDS_TO_ABSORB as usize { + let mut sum_split_vec = setup_split_vec[s][NUM_PER_WORD_BATCH4 as usize - 1] * 1; + let mut sum_split_xor_vec = + setup_split_xor_vec[s][NUM_PER_WORD_BATCH4 as usize - 1] * 1; + for (&value, &xor_value) in setup_split_vec[s] + .iter() + .rev() + .zip(setup_split_xor_vec[s].iter().rev()) + .skip(1) + { + sum_split_vec = sum_split_vec * 64 * 64 + value; + sum_split_xor_vec = sum_split_xor_vec * 64 * 64 + xor_value; + } + ctx.constr(eq( + sum_split_vec, + setup_sum_split_value_vec + [(s % PART_SIZE as usize) * PART_SIZE as usize + s / PART_SIZE as usize], + )); + ctx.transition(eq( + sum_split_xor_vec, + setup_s_vec + [(s % PART_SIZE as usize) * PART_SIZE as usize + s / PART_SIZE as usize] + .next(), + )); + } + + ctx.constr(eq( + (input_len - input_acc - acc_value.clone()) + * setup_is_padding_vec[NUM_WORDS_TO_ABSORB as usize - 1][7], + 0, + )); + ctx.transition(eq(input_acc + acc_value, input_acc.next())); + + ctx.transition(eq(round, round.next())); + 
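+            // absorbing a block does not advance the round counter; round and input_len are
+            // simply carried forward to the next step.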
ctx.transition(eq(input_len, input_len.next())); + + ctx.constr(eq(padded, 0)); + ctx.transition(eq( + setup_is_padding_vec[NUM_WORDS_TO_ABSORB as usize - 1][7], + padded.next(), + )); + }); + + ctx.wg::, _>(move |ctx, values| { + ctx.assign(round, F::ZERO); + for (q, v) in absorb_vec.iter().zip(values.absorb_rows.iter()) { + ctx.assign(*q, *v) + } + + for i in 0..PART_SIZE as usize { + for j in 0..PART_SIZE as usize { + ctx.assign(s_vec[i * PART_SIZE as usize + j], F::ZERO); + } + } + for (q_vec, v_vec) in absorb_split_vec.iter().zip(values.absorb_split_vec.iter()) { + for (q, v) in q_vec.iter().zip(v_vec.iter()) { + ctx.assign(*q, *v) + } + } + for (q_vec, v_vec) in absorb_split_input_vec + .iter() + .zip(values.absorb_split_input_vec.iter()) + { + for (q, v) in q_vec.iter().zip(v_vec.iter()) { + ctx.assign(*q, *v) + } + } + + for (q_vec, v_vec) in is_padding_vec.iter().zip(values.is_padding_vec.iter()) { + for (q, v) in q_vec.iter().zip(v_vec.iter()) { + ctx.assign(*q, *v) + } + } + + for (q_vec, v_vec) in data_rlc_vec.iter().zip(values.data_rlc_vec.iter()) { + for (q, v) in q_vec.iter().zip(v_vec.iter()) { + ctx.assign(*q, *v) + } + } + + for i in 0..PART_SIZE as usize { + for j in 0..PART_SIZE as usize { + if j * PART_SIZE as usize + i < NUM_WORDS_TO_ABSORB as usize { + ctx.assign( + sum_split_value_vec[i * PART_SIZE as usize + j], + values.s_vec[i * PART_SIZE as usize + j] + + values.absorb_rows[j * PART_SIZE as usize + i], + ); + ctx.assign( + absorb_vec[j * PART_SIZE as usize + i], + values.absorb_rows[j * PART_SIZE as usize + i], + ); + } else { + ctx.assign( + sum_split_value_vec[i * PART_SIZE as usize + j], + values.s_vec[i * PART_SIZE as usize + j], + ); + } + ctx.assign( + s_vec[i * PART_SIZE as usize + j], + values.s_vec[i * PART_SIZE as usize + j], + ); + } + } + + for i in 0..NUM_WORDS_TO_ABSORB as usize { + for j in 0..NUM_PER_WORD_BATCH4 as usize { + ctx.assign(split_vec[i][j], values.split_values[i][j].0); + ctx.assign(split_xor_vec[i][j], values.split_values[i][j].1); + } + } + ctx.assign(input_len, values.input_len); + ctx.assign(data_rlc, values.data_rlc); + ctx.assign(input_acc, values.input_acc); + ctx.assign(padded, values.padded); + }) + }); + + let keccak_one_round_step_vec: Vec, _>> = (0..2) + .map(|last| { + ctx.step_type_def("keccak one round", |ctx| { + let s_vec = s_vec.clone(); + let setup_s_vec = s_vec.clone(); + + let theta_split_vec: Vec>> = (0..PART_SIZE) + .map(|i| { + (0..NUM_PER_WORD + 1) + .map(|j| ctx.internal(&format!("theta_split_{}_{}", i, j))) + .collect() + }) + .collect(); + let setup_theta_split_vec = theta_split_vec.clone(); + + let theta_split_xor_vec: Vec>> = (0..PART_SIZE) + .map(|i| { + (0..NUM_PER_WORD + 1) + .map(|j| ctx.internal(&format!("theta_split_xor_{}_{}", i, j))) + .collect() + }) + .collect(); + let setup_theta_split_xor_vec = theta_split_xor_vec.clone(); + + let theta_sum_split_vec: Vec>> = (0..PART_SIZE_SQURE) + .map(|i| { + (0..NUM_PER_WORD_BATCH3) + .map(|j| ctx.internal(&format!("theta_sum_split_{}_{}", i, j))) + .collect() + }) + .collect(); + let setup_theta_sum_split_vec = theta_sum_split_vec.clone(); + + let theta_sum_split_xor_vec: Vec>> = (0..PART_SIZE_SQURE) + .map(|i| { + (0..NUM_PER_WORD_BATCH3) + .map(|j| ctx.internal(&format!("theta_sum_split_xor_{}_{}", i, j))) + .collect() + }) + .collect(); + let setup_theta_sum_split_xor_vec = theta_sum_split_xor_vec.clone(); + + let rho_bit_0: Vec> = (0..15) + .map(|i| ctx.internal(&format!("rho_bit0_{}", i))) + .collect(); + let setup_rho_bit_0 = rho_bit_0.clone(); + + 
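+                // rho_bit_0/rho_bit_1 (15 each) witness the sparse digits that straddle a batch-3
+                // limb boundary for the lanes whose rho rotation offset is not a multiple of 3,
+                // so the rotated lane can be reassembled from its batch-3 limbs
+                // (see eval_keccak_f_one_round).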
let rho_bit_1: Vec> = (0..15) + .map(|i| ctx.internal(&format!("rho_bit1_{}", i))) + .collect(); + let setup_rho_bit_1 = rho_bit_1.clone(); + + let chi_split_value_vec: Vec>> = (0..PART_SIZE_SQURE) + .map(|i| { + (0..NUM_PER_WORD_BATCH3) + .map(|j| ctx.internal(&format!("chi_split_value_{}_{}", i, j))) + .collect() + }) + .collect(); + let setup_chi_split_value_vec: Vec>> = chi_split_value_vec.clone(); + + let final_xor_split_vec: Vec> = (0..NUM_PER_WORD_BATCH4) + .map(|i| ctx.internal(&format!("final_xor_split_{}", i))) + .collect(); + let setup_final_xor_split_vec = final_xor_split_vec.clone(); + + let final_sum_split_vec: Vec> = (0..NUM_PER_WORD_BATCH4) + .map(|i| ctx.internal(&format!("final_sum_split_{}", i))) + .collect(); + let setup_final_sum_split_vec = final_sum_split_vec.clone(); + let round_cst: Queriable = ctx.internal("round constant"); + + let mut hash_rlc = data_rlc; + let mut next_round = round; + if last == 0 { + next_round = ctx.internal("next_round"); + } else { + hash_rlc = ctx.internal("hash_rlc"); + } + + let mut squeeze_split_vec: Vec>> = Vec::new(); + if last == 1 { + squeeze_split_vec = (0..SQUEEZE_VECTOR_NUM) + .map(|i| { + (0..SQUEEZE_SPLIT_NUM / 2) + .map(|j| ctx.internal(&format!("squeeze_split_vec_{}_{}", i, j))) + .collect() + }) + .collect(); + } + let setup_squeeze_split_vec = squeeze_split_vec.clone(); + + let mut squeeze_split_output_vec: Vec>> = Vec::new(); + if last == 1 { + squeeze_split_output_vec = (0..SQUEEZE_VECTOR_NUM) + .map(|i| { + (0..SQUEEZE_SPLIT_NUM / 2) + .map(|j| { + ctx.internal(&format!("squeeze_split_output_vec_{}_{}", i, j)) + }) + .collect() + }) + .collect(); + } + let setup_squeeze_split_output_vec = squeeze_split_output_vec.clone(); + + let mut s_new_vec: Vec> = Vec::new(); + if last == 1 { + s_new_vec = (0..PART_SIZE_SQURE) + .map(|i| { + ctx.internal(&format!("s_new[{}][{}]", i / PART_SIZE, i % PART_SIZE)) + }) + .collect(); + } + let setup_s_new_vec = s_new_vec.clone(); + + ctx.setup(move |ctx| { + let mut t_vec = vec![0; PART_SIZE_SQURE as usize]; + { + let mut i: usize = 1; + let mut j: usize = 0; + for t in 0..PART_SIZE_SQURE { + if t == 0 { + i = 0; + j = 0 + } else if t == 1 { + i = 1; + j = 0; + } else { + let m = j; + j = (2 * i + 3 * j) % PART_SIZE as usize; + i = m; + } + t_vec[i * PART_SIZE as usize + j] = t; + } + } + + // Theta + let mut tmp_theta_sum_split_xor_vec = Vec::new(); + let mut tmp_theta_sum_move_split_xor_vec = Vec::new(); + for s in 0..PART_SIZE as usize { + // 1. \sum_y' a[x][y'][z] + // 2. 
xor(sum) + let mut sum_split_vec = setup_theta_split_vec[s][NUM_PER_WORD as usize] * 8 + + setup_theta_split_vec[s][NUM_PER_WORD as usize - 1]; + let mut sum_split_xor_vec = + setup_theta_split_xor_vec[s][NUM_PER_WORD as usize] * 8 + + setup_theta_split_xor_vec[s][NUM_PER_WORD as usize - 1]; + let mut sum_split_xor_move_value_vec = + setup_theta_split_xor_vec[s][NUM_PER_WORD as usize - 1] * 1; + for k in 1..NUM_PER_WORD as usize { + sum_split_vec = sum_split_vec * 64 + + setup_theta_split_vec[s][NUM_PER_WORD as usize - k - 1]; + sum_split_xor_vec = sum_split_xor_vec * 64 + + setup_theta_split_xor_vec[s][NUM_PER_WORD as usize - k - 1]; + sum_split_xor_move_value_vec = sum_split_xor_move_value_vec * 64 + + setup_theta_split_xor_vec[s][NUM_PER_WORD as usize - k - 1]; + } + sum_split_xor_move_value_vec = sum_split_xor_move_value_vec * 8 + + setup_theta_split_xor_vec[s][NUM_PER_WORD as usize]; + + for k in 0..NUM_PER_WORD as usize { + ctx.add_lookup( + param + .xor_table + .apply(setup_theta_split_vec[s][k]) + .apply(setup_theta_split_xor_vec[s][k]), + ); + } + + ctx.constr(eq( + setup_s_vec[s * PART_SIZE as usize] + + setup_s_vec[s * PART_SIZE as usize + 1] + + setup_s_vec[s * PART_SIZE as usize + 2] + + setup_s_vec[s * PART_SIZE as usize + 3] + + setup_s_vec[s * PART_SIZE as usize + 4], + sum_split_vec, + )); + + tmp_theta_sum_split_xor_vec.push(sum_split_xor_vec); + tmp_theta_sum_move_split_xor_vec.push(sum_split_xor_move_value_vec); + } + + let mut rho_index = 0; + for i in 0..PART_SIZE as usize { + for j in 0..PART_SIZE as usize { + // Theta + // 3. a[x][y][z] = a[x][y][z] + xor(\sum_y' a[x-1][y'][z]) + xor(\sum + // a[x+1][y'][z-1]) 4. a'[x][y][z'+(t+1)(t+2)/2] = + // xor(a[x][y][z'+(t+1)(t+2)/2]) rho + // a[x][y][z'] = a[x][y][z'] + let v = ((t_vec[i * PART_SIZE as usize + j] + 1) + * t_vec[i * PART_SIZE as usize + j] + / 2) + % NUM_BITS_PER_WORD; + + for k in 0..NUM_PER_WORD_BATCH3 as usize { + ctx.add_lookup( + param + .xor_table_batch3 + .apply( + setup_theta_sum_split_vec[i * PART_SIZE as usize + j] + [k], + ) + .apply( + setup_theta_sum_split_xor_vec + [i * PART_SIZE as usize + j][k], + ), + ); + } + + let mut tmp_theta_sum_split; + if v % 3 == 0 { + let st = (v / 3) as usize; + if st != 0 { + tmp_theta_sum_split = setup_theta_sum_split_vec + [i * PART_SIZE as usize + j][st - 1] + * 1; + for k in 1..st { + tmp_theta_sum_split = tmp_theta_sum_split * 512 + + setup_theta_sum_split_vec[i * PART_SIZE as usize + j] + [st - k - 1]; + } + tmp_theta_sum_split = tmp_theta_sum_split * 8 + + setup_theta_sum_split_vec[i * PART_SIZE as usize + j] + [NUM_PER_WORD_BATCH3 as usize - 1]; + for k in 1..NUM_PER_WORD_BATCH3 as usize - st { + tmp_theta_sum_split = tmp_theta_sum_split * 512 + + setup_theta_sum_split_vec[i * PART_SIZE as usize + j] + [NUM_PER_WORD_BATCH3 as usize - k - 1]; + } + } else { + tmp_theta_sum_split = setup_theta_sum_split_vec + [i * PART_SIZE as usize + j] + [NUM_PER_WORD_BATCH3 as usize - 1] + * 1; + for k in 1..NUM_PER_WORD_BATCH3 as usize { + tmp_theta_sum_split = tmp_theta_sum_split * 512 + + setup_theta_sum_split_vec[i * PART_SIZE as usize + j] + [NUM_PER_WORD_BATCH3 as usize - k - 1]; + } + } + } else if v % 3 == 1 { + let st = ((v - 1) / 3) as usize; + tmp_theta_sum_split = setup_rho_bit_1[rho_index] * 1; + for k in 0..st { + tmp_theta_sum_split = tmp_theta_sum_split * 512 + + setup_theta_sum_split_vec[i * PART_SIZE as usize + j] + [st - k - 1]; + } + for k in 0..NUM_PER_WORD_BATCH3 as usize - st - 1 { + if k == 0 { + tmp_theta_sum_split = tmp_theta_sum_split * 8 + + 
setup_theta_sum_split_vec[i * PART_SIZE as usize + j] + [NUM_PER_WORD_BATCH3 as usize - 1]; + } else { + tmp_theta_sum_split = tmp_theta_sum_split * 512 + + setup_theta_sum_split_vec[i * PART_SIZE as usize + j] + [NUM_PER_WORD_BATCH3 as usize - k - 1]; + } + } + tmp_theta_sum_split = + tmp_theta_sum_split * 64 + setup_rho_bit_0[rho_index]; + ctx.constr(eq( + setup_rho_bit_0[rho_index] * 8 + setup_rho_bit_1[rho_index], + setup_theta_sum_split_vec[i * PART_SIZE as usize + j][st], + )); + rho_index += 1; + } else { + let st = ((v - 2) / 3) as usize; + tmp_theta_sum_split = setup_rho_bit_1[rho_index] * 1; + for k in 0..st { + tmp_theta_sum_split = tmp_theta_sum_split * 512 + + setup_theta_sum_split_vec[i * PART_SIZE as usize + j] + [st - k - 1]; + } + for k in 0..NUM_PER_WORD_BATCH3 as usize - st - 1 { + if k == 0 { + tmp_theta_sum_split = tmp_theta_sum_split * 8 + + setup_theta_sum_split_vec[i * PART_SIZE as usize + j] + [NUM_PER_WORD_BATCH3 as usize - 1]; + } else { + tmp_theta_sum_split = tmp_theta_sum_split * 512 + + setup_theta_sum_split_vec[i * PART_SIZE as usize + j] + [NUM_PER_WORD_BATCH3 as usize - k - 1]; + } + } + tmp_theta_sum_split = + tmp_theta_sum_split * 8 + setup_rho_bit_0[rho_index]; + ctx.constr(eq( + setup_rho_bit_0[rho_index] * 64 + setup_rho_bit_1[rho_index], + setup_theta_sum_split_vec[i * PART_SIZE as usize + j][st], + )); + rho_index += 1; + } + + ctx.constr(eq( + tmp_theta_sum_split, + setup_s_vec[i * PART_SIZE as usize + j] + + tmp_theta_sum_split_xor_vec + [(i + PART_SIZE as usize - 1) % PART_SIZE as usize] + .clone() + + tmp_theta_sum_move_split_xor_vec + [(i + 1) % PART_SIZE as usize] + .clone(), + )); + } + } + + let mut tmp_pi_sum_split_xor_vec = setup_theta_sum_split_xor_vec.clone(); + for i in 0..PART_SIZE as usize { + for j in 0..PART_SIZE as usize { + tmp_pi_sum_split_xor_vec + [j * PART_SIZE as usize + ((2 * i + 3 * j) % PART_SIZE as usize)] = + setup_theta_sum_split_xor_vec[i * PART_SIZE as usize + j].clone(); + } + } + + // chi + // a[x] = a[x] ^ (~a[x+1] & a[x+2]) + // chi(3 - 2a[x] + a[x+1] - a[x+2]) + ctx.add_lookup(param.constants_table.apply(round).apply(round_cst)); + for i in 0..PART_SIZE as usize { + for j in 0..PART_SIZE as usize { + for k in 0..NUM_PER_WORD_BATCH3 as usize { + ctx.add_lookup( + param + .chi_table + .apply( + tmp_pi_sum_split_xor_vec[((i + 1) + % PART_SIZE as usize) + * PART_SIZE as usize + + j][k] + - tmp_pi_sum_split_xor_vec + [i * PART_SIZE as usize + j][k] + - tmp_pi_sum_split_xor_vec + [i * PART_SIZE as usize + j][k] + - tmp_pi_sum_split_xor_vec[((i + 2) + % PART_SIZE as usize) + * PART_SIZE as usize + + j][k] + + 219, + ) + .apply( + setup_chi_split_value_vec[i * PART_SIZE as usize + j] + [k], + ), + ); + } + + let mut tmp_sum_split_chi_vec = setup_chi_split_value_vec + [i * PART_SIZE as usize + j] + [NUM_PER_WORD_BATCH3 as usize - 1] + * 1; + for k in 1..NUM_PER_WORD_BATCH3 as usize { + tmp_sum_split_chi_vec = tmp_sum_split_chi_vec * 512 + + setup_chi_split_value_vec[i * PART_SIZE as usize + j] + [NUM_PER_WORD_BATCH3 as usize - k - 1]; + } + + if i != 0 || j != 0 { + if last == 1 { + ctx.transition(eq( + tmp_sum_split_chi_vec, + setup_s_new_vec[i * PART_SIZE as usize + j], + )); + } else { + ctx.transition(eq( + tmp_sum_split_chi_vec, + setup_s_vec[i * PART_SIZE as usize + j].next(), + )); + } + } else { + let mut tmp_sum_s_split_vec = + setup_final_sum_split_vec[NUM_PER_WORD_BATCH4 as usize - 1] * 1; + let mut tmp_sum_s_split_xor_vec = + setup_final_xor_split_vec[NUM_PER_WORD_BATCH4 as usize - 1] * 1; + ctx.add_lookup( + 
param + .xor_table_batch4 + .apply( + setup_final_sum_split_vec + [NUM_PER_WORD_BATCH4 as usize - 1], + ) + .apply( + setup_final_xor_split_vec + [NUM_PER_WORD_BATCH4 as usize - 1], + ), + ); + for (&value, &xor_value) in setup_final_sum_split_vec + .iter() + .zip(setup_final_xor_split_vec.iter()) + .rev() + .skip(1) + { + tmp_sum_s_split_vec = tmp_sum_s_split_vec * 64 * 64 + value; + tmp_sum_s_split_xor_vec = + tmp_sum_s_split_xor_vec * 64 * 64 + xor_value; + ctx.add_lookup( + param.xor_table_batch4.apply(value).apply(xor_value), + ); + } + + ctx.constr(eq( + tmp_sum_s_split_vec, + tmp_sum_split_chi_vec + round_cst, + )); + if last == 1 { + ctx.transition(eq( + tmp_sum_s_split_xor_vec, + setup_s_new_vec[i * PART_SIZE as usize + j], + )); + } else { + ctx.transition(eq( + tmp_sum_s_split_xor_vec, + setup_s_vec[i * PART_SIZE as usize + j].next(), + )); + } + } + } + } + + if last == 1 { + for i in 0..SQUEEZE_VECTOR_NUM as usize { + let mut tmp_squeeze_split_sum = + setup_squeeze_split_vec[i][SQUEEZE_SPLIT_NUM as usize / 2 - 1] * 1; + for j in 1..SQUEEZE_SPLIT_NUM as usize / 2 { + tmp_squeeze_split_sum = tmp_squeeze_split_sum * 4096 * 4096 + + setup_squeeze_split_vec[i] + [SQUEEZE_SPLIT_NUM as usize / 2 - j - 1]; + } + ctx.constr(eq( + tmp_squeeze_split_sum, + setup_s_new_vec[i * PART_SIZE as usize], + )); + for j in 0..SQUEEZE_SPLIT_NUM as usize / 2 { + ctx.add_lookup( + param + .pack_table + .apply(setup_squeeze_split_vec[i][j]) + .apply(setup_squeeze_split_output_vec[i][j]), + ); + } + // hash_rlc + let mut tmp_hash_rlc_value = setup_squeeze_split_output_vec[0][0] * 1; + + for (i, values) in setup_squeeze_split_output_vec.iter().enumerate() { + for (j, &value) in values + .iter() + .enumerate() + .take(SQUEEZE_SPLIT_NUM as usize / 2) + { + if i != 0 || j != 0 { + tmp_hash_rlc_value = tmp_hash_rlc_value * 256 + value; + } + } + } + } + } + + if last == 1 { + ctx.constr(eq(round + 1, NUM_ROUNDS)); + ctx.constr(eq(input_len, input_acc)); + ctx.constr(eq(padded, 1)); + } else { + ctx.constr(eq((round + 1 - next_round) * next_round, 0)); + // xor((round + 1 = next_round), (round + 1 = NUM_ROUNDS)) + // (round + 1 - next_round) / NUM_ROUNDS = 0, round < 23; 1, round = 23 + // (round + 1 - NUM_ROUNDS) / (NUM_ROUNDS - next_round) = 1,round < 23; 0, + // round = 23 (round + 1 - next_round) / NUM_ROUNDS + // + (round + 1 - NUM_ROUNDS) / (NUM_ROUNDS - next_round) + // - 2 * ((round + 1 - next_round) / NUM_ROUNDS) * ((round + 1 - NUM_ROUNDS) + // / (NUM_ROUNDS - next_round)) = 1 + // (round + 1 - next_round) * (NUM_ROUNDS - next_round) + (round + 1 - + // NUM_ROUNDS) * NUM_ROUNDS + 2 * (round + 1 - + // next_round) * (round + 1 - NUM_ROUNDS) = NUM_ROUNDS * (NUM_ROUNDS - + // next_round) + ctx.constr(eq( + (round + 1 - next_round) * (next_round - NUM_ROUNDS) + + (round + 1 - NUM_ROUNDS) * NUM_ROUNDS + - (round + 1 - next_round) * (round + 1 - NUM_ROUNDS) * 2, + (next_round - NUM_ROUNDS) * NUM_ROUNDS, + )); + ctx.transition(eq(next_round, round.next())); + ctx.transition(eq(input_len, input_len.next())); + ctx.transition(eq(data_rlc, data_rlc.next())); + ctx.transition(eq(padded, padded.next())); + } + }); + + ctx.wg::, _>(move |ctx, values| { + ctx.assign(round, values.round); + ctx.assign(round_cst, values.round_cst); + if last == 0 { + ctx.assign(next_round, values.next_round); + } + for (q, v) in s_vec.iter().zip(values.s_vec.iter()) { + ctx.assign(*q, *v) + } + for (q_vec, v_vec) in theta_split_vec.iter().zip(values.theta_split_vec.iter()) + { + for (q, v) in q_vec.iter().zip(v_vec.iter()) { + 
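+ // assign each theta split limb from the trace values into this step instance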
ctx.assign(*q, *v) + } + } + for (q_vec, v_vec) in theta_split_xor_vec + .iter() + .zip(values.theta_split_xor_vec.iter()) + { + for (q, v) in q_vec.iter().zip(v_vec.iter()) { + ctx.assign(*q, *v) + } + } + for (q_vec, v_vec) in theta_sum_split_vec + .iter() + .zip(values.theta_sum_split_vec.iter()) + { + for (q, v) in q_vec.iter().zip(v_vec.iter()) { + ctx.assign(*q, *v) + } + } + for (q_vec, v_vec) in theta_sum_split_xor_vec + .iter() + .zip(values.theta_sum_split_xor_vec.iter()) + { + for (q, v) in q_vec.iter().zip(v_vec.iter()) { + ctx.assign(*q, *v) + } + } + for (q, v) in rho_bit_0.iter().zip(values.rho_bit_0.iter()) { + ctx.assign(*q, *v) + } + for (q, v) in rho_bit_1.iter().zip(values.rho_bit_1.iter()) { + ctx.assign(*q, *v) + } + for (q_vec, v_vec) in chi_split_value_vec + .iter() + .zip(values.chi_split_value_vec.iter()) + { + for (q, v) in q_vec.iter().zip(v_vec.iter()) { + ctx.assign(*q, *v) + } + } + for (q, v) in final_sum_split_vec + .iter() + .zip(values.final_sum_split_vec.iter()) + { + ctx.assign(*q, *v) + } + for (q, v) in final_xor_split_vec + .iter() + .zip(values.final_xor_split_vec.iter()) + { + ctx.assign(*q, *v) + } + if last == 1 { + for (q, v) in s_new_vec.iter().zip(values.svalues.s_new_vec.iter()) { + ctx.assign(*q, *v) + } + for (q_vec, v_vec) in squeeze_split_vec + .iter() + .zip(values.svalues.squeeze_split_vec.iter()) + { + for (q, v) in q_vec.iter().zip(v_vec.iter()) { + ctx.assign(*q, *v) + } + } + for (q_vec, v_vec) in squeeze_split_output_vec + .iter() + .zip(values.svalues.squeeze_split_output_vec.iter()) + { + for (q, v) in q_vec.iter().zip(v_vec.iter()) { + ctx.assign(*q, *v) + } + } + ctx.assign(hash_rlc, values.svalues.hash_rlc); + } + ctx.assign(input_len, values.input_len); + ctx.assign(data_rlc, values.data_rlc); + ctx.assign(input_acc, values.input_acc); + ctx.assign(padded, values.padded); + }) + }) + }) + .collect(); + + ctx.pragma_first_step(&keccak_first_step); // keccak_pre_step + ctx.pragma_last_step(&keccak_one_round_step_vec[1]); // keccak_squeeze_step + ctx.pragma_num_steps(param.step_num); + + ctx.trace(move |ctx, params| { + let input_num = params.bytes.len(); + let mut bits = convert_bytes_to_bits(params.bytes); + println!("intput bits(without padding) = {:?}", bits); + // padding + bits.push(1); + while (bits.len() + 1) % RATE_IN_BITS as usize != 0 { + bits.push(0); + } + bits.push(1); + println!("intput bits(with padding) = {:?}", bits); + + let mut s_new = [F::ZERO; PART_SIZE_SQURE as usize]; + + // chunks + let chunks = bits.chunks(RATE_IN_BITS as usize); + let chunks_len = chunks.len(); + let mut data_rlc_value = F::ZERO; + let mut input_acc = F::ZERO; + // absorb + for (k, chunk) in chunks.enumerate() { + let s: Vec = s_new.to_vec(); + let absorbs: Vec = (0..PART_SIZE_SQURE as usize) + .map(|idx| { + let i = idx % PART_SIZE as usize; + let j = idx / PART_SIZE as usize; + let mut absorb = F::ZERO; + if idx < NUM_WORDS_TO_ABSORB as usize { + absorb = pack(&chunk[idx * 64..(idx + 1) * 64]); + s_new[i * PART_SIZE as usize + j] = + field_xor(s[i * PART_SIZE as usize + j], absorb); + } else { + s_new[i * PART_SIZE as usize + j] = s[i * PART_SIZE as usize + j]; + } + absorb + }) + .collect(); + + let absorb_split_vec: Vec> = (0..NUM_WORDS_TO_ABSORB as usize) + .map(|idx| { + let bits = chunk[idx * 64..(idx + 1) * 64].to_vec(); + (0..SQUEEZE_SPLIT_NUM as usize / 2) + .map(|k| { + F::from( + bits[k * 8] as u64 + + bits[k * 8 + 1] as u64 * 8 + + bits[k * 8 + 2] as u64 * 64 + + bits[k * 8 + 3] as u64 * 512 + + bits[k * 8 + 4] as u64 * 
4096 + + bits[k * 8 + 5] as u64 * 8 * 4096 + + bits[k * 8 + 6] as u64 * 64 * 4096 + + bits[k * 8 + 7] as u64 * 512 * 4096, + ) + }) + .collect() + }) + .collect(); + + let absorb_split_input_vec: Vec> = (0..NUM_WORDS_TO_ABSORB as usize) + .map(|idx| { + let bits = chunk[idx * 64..(idx + 1) * 64].to_vec(); + (0..SQUEEZE_SPLIT_NUM as usize / 2) + .map(|k| { + F::from( + bits[k * 8] as u64 + + bits[k * 8 + 1] as u64 * 2 + + bits[k * 8 + 2] as u64 * 4 + + bits[k * 8 + 3] as u64 * 8 + + bits[k * 8 + 4] as u64 * 16 + + bits[k * 8 + 5] as u64 * 32 + + bits[k * 8 + 6] as u64 * 64 + + bits[k * 8 + 7] as u64 * 128, + ) + }) + .collect() + }) + .collect(); + + let mut is_padding_vec = vec![vec![F::ONE; 8]; NUM_WORDS_TO_ABSORB as usize]; + is_padding_vec = is_padding_vec + .iter() + .enumerate() + .map(|(i, is_paddings)| { + is_paddings + .iter() + .enumerate() + .take(8) + .map(|(j, &is_padding)| { + if input_num > k * 8 * NUM_WORDS_TO_ABSORB as usize + i * 8 + j { + F::ZERO + } else { + is_padding + } + }) + .collect() + }) + .collect(); + + let mut padded = F::ZERO; + if k == 0 { + let data_rlc = data_rlc_value; + let data_rlc_vec: Vec> = absorb_split_input_vec + .iter() + .zip(is_padding_vec.iter()) + .map(|(v1_vec, v2_vec)| { + v1_vec + .iter() + .zip(v2_vec.iter()) + .map(|(&v1, &v2)| { + if v2 == F::ZERO { + data_rlc_value = data_rlc_value * F::from(256) + v1 + } + data_rlc_value + }) + .collect() + }) + .collect(); + + let values = PreValues { + s_vec: s, + absorb_rows: absorbs[0..NUM_WORDS_TO_ABSORB as usize].to_vec(), + round_value: F::ZERO, + absorb_split_vec, + absorb_split_input_vec, + split_values: Vec::new(), + is_padding_vec: is_padding_vec.clone(), + input_len: F::from(input_num as u64), + data_rlc_vec, + data_rlc, + input_acc, + padded, + }; + ctx.add(&keccak_first_step, values); + } else { + let data_rlc = data_rlc_value; + let split_values = (0..NUM_WORDS_TO_ABSORB as usize) + .map(|t| { + let i = t % PART_SIZE as usize; + let j = t / PART_SIZE as usize; + let v = i * PART_SIZE as usize + j; + eval_keccak_f_to_bit_vec4::( + s[v] + absorbs[(v % PART_SIZE as usize) * PART_SIZE as usize + + (v / PART_SIZE as usize)], + s_new[v], + ) + }) + .collect(); + + let data_rlc_vec: Vec> = absorb_split_input_vec + .iter() + .zip(is_padding_vec.iter()) + .map(|(v1_vec, v2_vec)| { + v1_vec + .iter() + .zip(v2_vec.iter()) + .map(|(&v1, &v2)| { + if v2 == F::ZERO { + data_rlc_value = data_rlc_value * F::from(256) + v1 + } + data_rlc_value + }) + .collect() + }) + .collect(); + let values = PreValues { + s_vec: s, + absorb_rows: absorbs[0..NUM_WORDS_TO_ABSORB as usize].to_vec(), + split_values, + absorb_split_vec, + absorb_split_input_vec, + round_value: F::ZERO, + is_padding_vec: is_padding_vec.clone(), + input_len: F::from(input_num as u64), + data_rlc_vec, + data_rlc, + input_acc, + padded, + }; + ctx.add(&keccak_pre_step, values); + } + padded = is_padding_vec[NUM_WORDS_TO_ABSORB as usize - 1][7]; + + input_acc = is_padding_vec.iter().fold(input_acc, |acc, is_paddings| { + let v = is_paddings + .iter() + .fold(F::ZERO, |acc, is_padding| acc + (F::ONE - is_padding)); + acc + v + }); + + for (round, &cst) in ROUND_CST.iter().enumerate().take(NUM_ROUNDS as usize) { + let mut values = eval_keccak_f_one_round( + round as u64, + cst, + s_new.to_vec(), + input_num as u64, + data_rlc_value, + input_acc, + padded, + ); + s_new = values.s_new_vec.clone().try_into().unwrap(); + + if k != chunks_len - 1 || round != NUM_ROUNDS as usize - 1 { + ctx.add(&keccak_one_round_step_vec[0], values.clone()); + } else 
{ + // squeezing + let mut squeeze_split_vec: Vec> = Vec::new(); + let mut squeeze_split_output_vec: Vec> = Vec::new(); + for i in 0..4 { + let bits = convert_field_to_vec_bits(s_new[(i * PART_SIZE) as usize]); + + squeeze_split_vec.push( + (0..SQUEEZE_SPLIT_NUM as usize / 2) + .map(|k| { + let value = bits[k * 8] as u64 + + bits[k * 8 + 1] as u64 * 8 + + bits[k * 8 + 2] as u64 * 64 + + bits[k * 8 + 3] as u64 * 512 + + bits[k * 8 + 4] as u64 * 4096 + + bits[k * 8 + 5] as u64 * 8 * 4096 + + bits[k * 8 + 6] as u64 * 64 * 4096 + + bits[k * 8 + 7] as u64 * 512 * 4096; + F::from(value) + }) + .collect(), + ); + squeeze_split_output_vec.push( + (0..SQUEEZE_SPLIT_NUM as usize / 2) + .map(|k| { + let value = bits[k * 8] as u64 + + bits[k * 8 + 1] as u64 * 2 + + bits[k * 8 + 2] as u64 * 4 + + bits[k * 8 + 3] as u64 * 8 + + bits[k * 8 + 4] as u64 * 16 + + bits[k * 8 + 5] as u64 * 32 + + bits[k * 8 + 6] as u64 * 64 + + bits[k * 8 + 7] as u64 * 128; + F::from(value) + }) + .collect(), + ); + } + let mut hash_rlc = F::ZERO; + for squeeze_split_output in squeeze_split_output_vec.iter().take(4) { + for output in squeeze_split_output + .iter() + .take(SQUEEZE_SPLIT_NUM as usize / 2) + { + hash_rlc = hash_rlc * F::from(256) + output; + } + } + values.svalues = SqueezeValues { + s_new_vec: s_new.to_vec(), + squeeze_split_vec, + squeeze_split_output_vec, + hash_rlc, + }; + ctx.add(&keccak_one_round_step_vec[1], values); + } + } + } + + let output2: Vec> = (0..4) + .map(|k| { + pack_with_base::( + &convert_field_to_vec_bits(s_new[(k * PART_SIZE) as usize]), + 2, + ) + .to_repr() + .into_iter() + .take(8) + .collect::>() + .to_vec() + }) + .collect(); + println!("output = {:x?}", output2.concat()); + }); +} + +#[derive(Default)] +struct KeccakCircuit { + // pub bits: Vec, + pub bytes: Vec, +} + +struct CircuitParams { + pub constants_table: LookupTable, + pub xor_table: LookupTable, + pub xor_table_batch3: LookupTable, + pub xor_table_batch4: LookupTable, + pub chi_table: LookupTable, + pub pack_table: LookupTable, + pub step_num: usize, +} + +fn keccak_super_circuit + Eq + Hash>( + input_len: usize, +) -> SuperCircuit { + super_circuit::("keccak", |ctx| { + let in_n = (input_len * 8 + 1 + RATE_IN_BITS as usize) / RATE_IN_BITS as usize; + let step_num = in_n * (1 + NUM_ROUNDS as usize); + + let single_config = config(SingleRowCellManager {}, SimpleStepSelectorBuilder {}); + // config(SingleRowCellManager {}, LogNSelectorBuilder {}); + + let (_, constants_table) = ctx.sub_circuit( + single_config.clone(), + keccak_round_constants_table, + NUM_ROUNDS as usize + 1, + ); + let (_, xor_table) = ctx.sub_circuit(single_config.clone(), keccak_xor_table_batch2, 36); + let (_, xor_table_batch3) = + ctx.sub_circuit(single_config.clone(), keccak_xor_table_batch3, 64); + let (_, xor_table_batch4) = + ctx.sub_circuit(single_config.clone(), keccak_xor_table_batch4, 81); + let (_, chi_table) = ctx.sub_circuit(single_config.clone(), keccak_chi_table, 125); + let (_, pack_table) = ctx.sub_circuit(single_config, keccak_pack_table, 0); + + let params = CircuitParams { + constants_table, + xor_table, + xor_table_batch3, + xor_table_batch4, + chi_table, + pack_table, + step_num, + }; + + let maxwidth_config = config( + MaxWidthCellManager::new(198, true), + SimpleStepSelectorBuilder {}, + ); + let (keccak, _) = ctx.sub_circuit(maxwidth_config, keccak_circuit, params); + + ctx.mapping(move |ctx, values| { + ctx.map(&keccak, values); + }) + }) +} + +use chiquito::plonkish::backend::plaf::chiquito2Plaf; +use polyexen::plaf::{Plaf, 
PlafDisplayBaseTOML, PlafDisplayFixedCSV, Witness, WitnessDisplayCSV};
+
+fn write_files(name: &str, plaf: &Plaf, wit: &Witness) -> Result<(), io::Error> {
+    let mut base_file = File::create(format!("{}.toml", name))?;
+    let mut fixed_file = File::create(format!("{}_fixed.csv", name))?;
+    let mut witness_file = File::create(format!("{}_witness.csv", name))?;
+
+    write!(base_file, "{}", PlafDisplayBaseTOML(plaf))?;
+    write!(fixed_file, "{}", PlafDisplayFixedCSV(plaf))?;
+    write!(witness_file, "{}", WitnessDisplayCSV(wit))?;
+    println!("write file success...{}", name);
+    Ok(())
+}
+
+fn keccak_plaf(circuit_param: KeccakCircuit, k: u32) {
+    let super_circuit = keccak_super_circuit::<Fr>(circuit_param.bytes.len());
+    let witness = super_circuit.get_mapping().generate(circuit_param);
+
+    for wit_gen in witness.values() {
+        let wit_gen = wit_gen.clone();
+
+        let mut circuit = super_circuit.get_sub_circuits()[0].clone();
+        circuit
+            .columns
+            .append(&mut super_circuit.get_sub_circuits()[1].columns);
+        circuit
+            .columns
+            .append(&mut super_circuit.get_sub_circuits()[2].columns);
+        circuit
+            .columns
+            .append(&mut super_circuit.get_sub_circuits()[3].columns);
+        circuit
+            .columns
+            .append(&mut super_circuit.get_sub_circuits()[4].columns);
+        circuit
+            .columns
+            .append(&mut super_circuit.get_sub_circuits()[5].columns);
+        circuit
+            .columns
+            .append(&mut super_circuit.get_sub_circuits()[6].columns);
+
+        for (key, value) in super_circuit.get_sub_circuits()[0].fixed_assignments.iter() {
+            circuit.fixed_assignments.insert(key.clone(), value.clone());
+        }
+        for (key, value) in super_circuit.get_sub_circuits()[1].fixed_assignments.iter() {
+            circuit.fixed_assignments.insert(key.clone(), value.clone());
+        }
+        for (key, value) in super_circuit.get_sub_circuits()[2].fixed_assignments.iter() {
+            circuit.fixed_assignments.insert(key.clone(), value.clone());
+        }
+        for (key, value) in super_circuit.get_sub_circuits()[3].fixed_assignments.iter() {
+            circuit.fixed_assignments.insert(key.clone(), value.clone());
+        }
+        for (key, value) in super_circuit.get_sub_circuits()[4].fixed_assignments.iter() {
+            circuit.fixed_assignments.insert(key.clone(), value.clone());
+        }
+        for (key, value) in super_circuit.get_sub_circuits()[5].fixed_assignments.iter() {
+            circuit.fixed_assignments.insert(key.clone(), value.clone());
+        }
+
+        let (plaf, plaf_wit_gen) = chiquito2Plaf(circuit, k, false);
+
+        let mut plaf = plaf;
+        plaf.set_challange_alias(0, "r_keccak".to_string());
+        let wit = plaf_wit_gen.generate(Some(wit_gen));
+        write_files("keccak_output", &plaf, &wit).unwrap();
+    }
+}
+
+fn keccak_run(circuit_param: KeccakCircuit, k: u32) -> bool {
+    let super_circuit = keccak_super_circuit::<Fr>(circuit_param.bytes.len());
+
+    let compiled = chiquitoSuperCircuit2Halo2(&super_circuit);
+
+    let circuit = ChiquitoHalo2SuperCircuit::new(
+        compiled,
+        super_circuit.get_mapping().generate(circuit_param),
+    );
+
+    let prover = MockProver::<Fr>::run(k, &circuit, Vec::new()).unwrap();
+    let result = prover.verify_par();
+
+    println!("result = {:#?}", result);
+
+    if let Err(failures) = &result {
+        for failure in failures.iter() {
+            println!("{}", failure);
+        }
+        false
+    } else {
+        true
+    }
+}
+
+fn main() {
+    let circuit_param = KeccakCircuit {
+        bytes: vec![0, 1, 2, 3, 4, 5, 6, 7],
+    };
+
+    let res = keccak_run(circuit_param, 9);
+
+    if res {
+        keccak_plaf(
+            KeccakCircuit {
+                bytes: vec![0, 1, 2, 3, 4, 5, 6, 7],
+            },
+            11,
+        );
+    }
+}
diff --git a/examples/poseidon.rs b/examples/poseidon.rs
index 0b626235..f5bdcc21 100644
---
a/examples/poseidon.rs +++ b/examples/poseidon.rs @@ -484,7 +484,6 @@ fn poseidon_circuit( x_value }) .collect(); - let mut sbox_values: Vec = x_values .iter() .map(|x_value| *x_value * x_value * x_value * x_value * x_value) @@ -513,7 +512,6 @@ fn poseidon_circuit( out_values: outputs.clone(), round: F::ZERO, }; - ctx.add(&poseidon_step_first_round, round_values); inputs = outputs; diff --git a/src/plonkish/compiler/step_selector.rs b/src/plonkish/compiler/step_selector.rs index 9207c976..0e528512 100644 --- a/src/plonkish/compiler/step_selector.rs +++ b/src/plonkish/compiler/step_selector.rs @@ -202,7 +202,7 @@ impl StepSelectorBuilder for LogNSelectorBuilder { let n_step_types = unit.step_types.len() as u64; let n_cols = (n_step_types as f64 + 1.0).log2().ceil() as u64; - + println!("n_step_types = {}, n_cols = {}", n_step_types, n_cols); let mut annotation; for index in 0..n_cols { annotation = format!("'step selector for binary column {}'", index); From 428f6db0c3b54cee0fede065e7f4a48a3aad84f2 Mon Sep 17 00:00:00 2001 From: Leo Lara Date: Tue, 2 Apr 2024 11:01:06 +0700 Subject: [PATCH 6/8] Update dependencies 2024-03-26 (#235) --- Cargo.toml | 19 ++++++++++--------- examples/blake2f.rs | 2 +- examples/factorial.rs | 4 ++-- examples/fibo_with_padding.rs | 4 ++-- examples/fibonacci.rs | 4 ++-- examples/mimc7.rs | 2 +- examples/poseidon.rs | 4 ++-- src/frontend/dsl/sc.rs | 2 +- src/frontend/python/mod.rs | 10 ++++++++-- src/plonkish/compiler/mod.rs | 3 +-- src/plonkish/compiler/step_selector.rs | 2 +- src/plonkish/ir/sc.rs | 2 +- src/poly/mielim.rs | 2 +- src/poly/mod.rs | 2 +- src/poly/reduce.rs | 3 +-- src/poly/simplify.rs | 2 +- 16 files changed, 36 insertions(+), 31 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 5ddc4de0..b5622a6a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,15 +7,21 @@ authors = ["Leo Lara "] # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html +[patch.crates-io] +halo2_proofs = { git = "https://github.com/privacy-scaling-explorations/halo2.git", tag = "v0.3.0" } + +[patch."https://github.com/scroll-tech/halo2.git"] +halo2_proofs = { git = "https://github.com/privacy-scaling-explorations/halo2.git", tag = "v0.3.0" } + + [dependencies] pyo3 = { version = "0.19.1", features = ["extension-module"] } halo2_proofs = { git = "https://github.com/privacy-scaling-explorations/halo2.git", features = [ "circuit-params", -], tag = "v2023_04_20" } -halo2curves = { git = 'https://github.com/privacy-scaling-explorations/halo2curves', tag = "0.3.2", features = [ "derive_serde", -] } -polyexen = { git = "https://github.com/Dhole/polyexen.git", rev = "4d128ad2ebd0094160ea77e30fb9ce56abb854e0" } +], tag = "v0.3.0" } + +polyexen = { git = "https://github.com/Dhole/polyexen.git", rev = "16a85c5411f804dc49bbf373d24ff9eedadedfbe" } num-bigint = { version = "0.4", features = ["rand"] } uuid = { version = "1.4.0", features = ["v1", "rng"] } serde = { version = "1.0", features = ["derive"] } @@ -26,8 +32,3 @@ regex = "1" [dev-dependencies] rand_chacha = "0.3" - -[patch."https://github.com/privacy-scaling-explorations/halo2.git"] -halo2_proofs = { git = "https://github.com/appliedzkp/halo2.git", rev = "d3746109d7d38be53afc8ddae8fdfaf1f02ad1d7", features = [ - "circuit-params", -] } diff --git a/examples/blake2f.rs b/examples/blake2f.rs index f4a2848f..1ce2db7f 100644 --- a/examples/blake2f.rs +++ b/examples/blake2f.rs @@ -1487,7 +1487,7 @@ fn main() { ChiquitoHalo2SuperCircuit::new(compiled, super_circuit.get_mapping().generate(values)); 
let prover = MockProver::run(9, &circuit, Vec::new()).unwrap(); - let result = prover.verify_par(); + let result = prover.verify(); println!("result = {:#?}", result); diff --git a/examples/factorial.rs b/examples/factorial.rs index c927ab49..e47702bf 100644 --- a/examples/factorial.rs +++ b/examples/factorial.rs @@ -136,7 +136,7 @@ fn main() { let prover = MockProver::::run(10, &circuit, circuit.instance()).unwrap(); - let result = prover.verify_par(); + let result = prover.verify(); println!("result = {:#?}", result); @@ -167,7 +167,7 @@ fn main() { // same as halo2 boilerplate above let prover_plaf = MockProver::::run(8, &plaf_circuit, Vec::new()).unwrap(); - let result_plaf = prover_plaf.verify_par(); + let result_plaf = prover_plaf.verify(); println!("result = {:#?}", result_plaf); diff --git a/examples/fibo_with_padding.rs b/examples/fibo_with_padding.rs index 403f2ae2..6a4fc481 100644 --- a/examples/fibo_with_padding.rs +++ b/examples/fibo_with_padding.rs @@ -206,7 +206,7 @@ fn main() { let prover = MockProver::::run(7, &circuit, circuit.instance()).unwrap(); - let result = prover.verify_par(); + let result = prover.verify(); println!("{:#?}", result); @@ -237,7 +237,7 @@ fn main() { // same as halo2 boilerplate above let prover_plaf = MockProver::::run(8, &plaf_circuit, plaf_circuit.instance()).unwrap(); - let result_plaf = prover_plaf.verify_par(); + let result_plaf = prover_plaf.verify(); println!("result = {:#?}", result_plaf); diff --git a/examples/fibonacci.rs b/examples/fibonacci.rs index 2fee596f..b01d00e3 100644 --- a/examples/fibonacci.rs +++ b/examples/fibonacci.rs @@ -133,7 +133,7 @@ fn main() { let prover = MockProver::::run(7, &circuit, circuit.instance()).unwrap(); - let result = prover.verify_par(); + let result = prover.verify(); println!("{:#?}", result); @@ -164,7 +164,7 @@ fn main() { // same as halo2 boilerplate above let prover_plaf = MockProver::::run(8, &plaf_circuit, plaf_circuit.instance()).unwrap(); - let result_plaf = prover_plaf.verify_par(); + let result_plaf = prover_plaf.verify(); println!("result = {:#?}", result_plaf); diff --git a/examples/mimc7.rs b/examples/mimc7.rs index 1eafeb0e..595e02b1 100644 --- a/examples/mimc7.rs +++ b/examples/mimc7.rs @@ -210,7 +210,7 @@ fn main() { let prover = MockProver::::run(10, &circuit, circuit.instance()).unwrap(); - let result = prover.verify_par(); + let result = prover.verify(); println!("result = {:#?}", result); diff --git a/examples/poseidon.rs b/examples/poseidon.rs index f5bdcc21..aeb476e3 100644 --- a/examples/poseidon.rs +++ b/examples/poseidon.rs @@ -11,7 +11,7 @@ use chiquito::{ }, sbpir::query::Queriable, }; -// use halo2curves::ff::Field; + use std::hash::Hash; use halo2_proofs::{ @@ -710,7 +710,7 @@ fn main() { let prover = MockProver::::run(12, &circuit, Vec::new()).unwrap(); - let result = prover.verify_par(); + let result = prover.verify(); println!("result = {:#?}", result); diff --git a/src/frontend/dsl/sc.rs b/src/frontend/dsl/sc.rs index 8e3fdcd2..3997ca11 100644 --- a/src/frontend/dsl/sc.rs +++ b/src/frontend/dsl/sc.rs @@ -124,7 +124,7 @@ where #[cfg(test)] mod tests { - use halo2curves::{bn256::Fr, ff::PrimeField}; + use halo2_proofs::halo2curves::{bn256::Fr, ff::PrimeField}; use crate::{ plonkish::compiler::{ diff --git a/src/frontend/python/mod.rs b/src/frontend/python/mod.rs index 69b3abc3..1a752878 100644 --- a/src/frontend/python/mod.rs +++ b/src/frontend/python/mod.rs @@ -142,7 +142,7 @@ pub fn chiquito_super_circuit_halo2_mock_prover( let prover = MockProver::::run(k as u32, 
&circuit, circuit.instance()).unwrap(); - let result = prover.verify_par(); + let result = prover.verify(); println!("result = {:#?}", result); @@ -175,7 +175,7 @@ pub fn chiquito_halo2_mock_prover(witness_json: &str, rust_id: UUID, k: usize) { let prover = MockProver::::run(k as u32, &circuit, circuit.instance()).unwrap(); - let result = prover.verify_par(); + let result = prover.verify(); println!("{:#?}", result); @@ -919,6 +919,8 @@ impl<'de> Deserialize<'de> for SBPIR { #[cfg(test)] mod tests { use super::*; + + #[ignore] #[test] fn test_trace_witness() { let json = r#" @@ -1063,6 +1065,7 @@ mod tests { let _: ExposeOffset = serde_json::from_str(json).unwrap(); } + #[ignore] #[test] fn test_circuit() { let json = r#" @@ -1534,6 +1537,7 @@ mod tests { println!("{:?}", circuit); } + #[ignore] #[test] fn test_step_type() { let json = r#" @@ -1669,6 +1673,7 @@ mod tests { println!("{:?}", step_type); } + #[ignore] #[test] fn test_constraint() { let json = r#" @@ -1752,6 +1757,7 @@ mod tests { println!("{:?}", transition_constraint); } + #[ignore] #[test] fn test_expr() { let json = r#" diff --git a/src/plonkish/compiler/mod.rs b/src/plonkish/compiler/mod.rs index d1a5c9d1..b36a2d3a 100644 --- a/src/plonkish/compiler/mod.rs +++ b/src/plonkish/compiler/mod.rs @@ -568,8 +568,7 @@ fn add_halo2_columns(unit: &mut CompilationUnit, ast: &astCircu #[cfg(test)] mod test { - use halo2_proofs::plonk::Any; - use halo2curves::bn256::Fr; + use halo2_proofs::{halo2curves::bn256::Fr, plonk::Any}; use super::{cell_manager::SingleRowCellManager, step_selector::SimpleStepSelectorBuilder, *}; diff --git a/src/plonkish/compiler/step_selector.rs b/src/plonkish/compiler/step_selector.rs index 0e528512..390c1c80 100644 --- a/src/plonkish/compiler/step_selector.rs +++ b/src/plonkish/compiler/step_selector.rs @@ -258,7 +258,7 @@ fn other_step_type(unit: &CompilationUnit, uuid: UUID) -> Option MappingGenerator { #[cfg(test)] mod test { - use halo2curves::bn256::Fr; + use halo2_proofs::halo2curves::bn256::Fr; use crate::{ plonkish::{ diff --git a/src/poly/mielim.rs b/src/poly/mielim.rs index 190cf08c..63b0a554 100644 --- a/src/poly/mielim.rs +++ b/src/poly/mielim.rs @@ -67,7 +67,7 @@ fn mi_elimination_recursive< #[cfg(test)] mod test { - use halo2curves::bn256::Fr; + use halo2_proofs::halo2curves::bn256::Fr; use crate::{ poly::{mielim::mi_elimination, Expr}, diff --git a/src/poly/mod.rs b/src/poly/mod.rs index f9adac97..0f8bb296 100644 --- a/src/poly/mod.rs +++ b/src/poly/mod.rs @@ -270,7 +270,7 @@ impl ConstrDecomp { #[cfg(test)] mod test { - use halo2curves::bn256::Fr; + use halo2_proofs::halo2curves::bn256::Fr; use crate::{field::Field, poly::VarAssignments}; diff --git a/src/poly/reduce.rs b/src/poly/reduce.rs index a928a26a..64625788 100644 --- a/src/poly/reduce.rs +++ b/src/poly/reduce.rs @@ -181,8 +181,7 @@ fn reduce_degree_mul( #[cfg(test)] mod test { - use halo2curves::bn256::Fr; + use halo2_proofs::halo2curves::bn256::Fr; use crate::{ poly::{ From c81ab2888dc25d5bb9f9b549e34e6eefe0e40a4b Mon Sep 17 00:00:00 2001 From: Leo Lara Date: Tue, 2 Apr 2024 04:18:14 +0000 Subject: [PATCH 7/8] Fix small issue in example --- examples/keccak.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/keccak.rs b/examples/keccak.rs index da3bb7eb..bc06c93b 100644 --- a/examples/keccak.rs +++ b/examples/keccak.rs @@ -2316,7 +2316,7 @@ fn keccak_plaf(circuit_param: KeccakCircuit, k: u32) { let (plaf, plaf_wit_gen) = chiquito2Plaf(circuit, k, false); let mut plaf = plaf; - plaf.set_challange_alias(0, 
"r_keccak".to_string()); + plaf.set_challenge_alias(0, "r_keccak".to_string()); let wit = plaf_wit_gen.generate(Some(wit_gen)); write_files("keccak_output", &plaf, &wit).unwrap(); } @@ -2333,7 +2333,7 @@ fn keccak_run(circuit_param: KeccakCircuit, k: u32) -> bool { ); let prover = MockProver::::run(k, &circuit, Vec::new()).unwrap(); - let result = prover.verify_par(); + let result = prover.verify(); println!("result = {:#?}", result); From 93282e00b925fd3d61ea32c0f48fd8385f385b51 Mon Sep 17 00:00:00 2001 From: Rute Figueiredo Date: Wed, 8 May 2024 22:32:35 +0100 Subject: [PATCH 8/8] Rutefig/237 fix python front end (#240) --- src/frontend/dsl/mod.rs | 3 - src/frontend/python/chiquito/cb.py | 4 +- src/frontend/python/chiquito/chiquito_ast.py | 2 +- src/frontend/python/chiquito/dsl.py | 4 +- src/frontend/python/chiquito/expr.py | 4 +- src/frontend/python/chiquito/query.py | 2 +- src/frontend/python/chiquito/util.py | 12 +- src/frontend/python/chiquito/wit_gen.py | 2 +- src/frontend/python/mod.rs | 292 ++++++++++--------- 9 files changed, 162 insertions(+), 163 deletions(-) diff --git a/src/frontend/dsl/mod.rs b/src/frontend/dsl/mod.rs index ce93ca46..e52313b5 100644 --- a/src/frontend/dsl/mod.rs +++ b/src/frontend/dsl/mod.rs @@ -600,15 +600,12 @@ mod tests { } #[test] - #[ignore] #[should_panic(expected = "Signal not found")] fn test_expose_non_existing_signal() { let mut context = setup_circuit_context::(); let non_existing_signal = Queriable::Forward(ForwardSignal::new_with_phase(0, "".to_owned()), false); // Create a signal not added to the circuit context.expose(non_existing_signal, ExposeOffset::First); - - todo!("remove the ignore after fixing the check for non existing signals") } #[test] diff --git a/src/frontend/python/chiquito/cb.py b/src/frontend/python/chiquito/cb.py index 6d25d920..5828fe09 100644 --- a/src/frontend/python/chiquito/cb.py +++ b/src/frontend/python/chiquito/cb.py @@ -1,7 +1,7 @@ from __future__ import annotations from dataclasses import dataclass, field from enum import Enum, auto -from typing import List, Dict, Optional +from typing import List, Dict, Optional, Union from chiquito.util import F, uuid from chiquito.expr import Expr, Const, Neg, to_expr, ToExpr @@ -205,7 +205,7 @@ def table() -> LookupTable: return LookupTable() -ToConstraint = Constraint | Expr | int | F +ToConstraint = Union[Constraint, Expr, int, F] def to_constraint(v: ToConstraint) -> Constraint: diff --git a/src/frontend/python/chiquito/chiquito_ast.py b/src/frontend/python/chiquito/chiquito_ast.py index fc6e6d0b..3b44505a 100644 --- a/src/frontend/python/chiquito/chiquito_ast.py +++ b/src/frontend/python/chiquito/chiquito_ast.py @@ -126,7 +126,7 @@ def __json__(self: ASTCircuit): "last_step": self.last_step, "num_steps": self.num_steps, "q_enable": self.q_enable, - "id": self.id, + "id": self.id.__str__(), } def add_forward(self: ASTCircuit, name: str, phase: int) -> ForwardSignal: diff --git a/src/frontend/python/chiquito/dsl.py b/src/frontend/python/chiquito/dsl.py index ee41004f..51f089c7 100644 --- a/src/frontend/python/chiquito/dsl.py +++ b/src/frontend/python/chiquito/dsl.py @@ -1,5 +1,5 @@ from __future__ import annotations -from typing import List, Dict +from typing import List, Dict, Union from enum import Enum from typing import Callable, Any @@ -295,4 +295,4 @@ def add_lookup(self: StepType, lookup_builder: LookupBuilder): self.step_type.lookups.append(lookup) -LookupBuilder = LookupTableBuilder | InPlaceLookupBuilder +LookupBuilder = Union[LookupTableBuilder, 
InPlaceLookupBuilder] diff --git a/src/frontend/python/chiquito/expr.py b/src/frontend/python/chiquito/expr.py index 22e11396..41f4d93f 100644 --- a/src/frontend/python/chiquito/expr.py +++ b/src/frontend/python/chiquito/expr.py @@ -1,5 +1,5 @@ from __future__ import annotations -from typing import List +from typing import List, Union from dataclasses import dataclass from chiquito.util import F @@ -141,7 +141,7 @@ def __json__(self): return {"Pow": [self.expr.__json__(), self.pow]} -ToExpr = Expr | int | F +ToExpr = Union[Expr, int, F] def to_expr(v: ToExpr) -> Expr: diff --git a/src/frontend/python/chiquito/query.py b/src/frontend/python/chiquito/query.py index 9dafb6c2..0eeb52e7 100644 --- a/src/frontend/python/chiquito/query.py +++ b/src/frontend/python/chiquito/query.py @@ -134,5 +134,5 @@ def __str__(self: ASTStepType) -> str: def __json__(self): return { - "StepTypeNext": {"id": self.step_type.id, "annotation": self.step_type.name} + "StepTypeNext": {"id": f"{self.step_type.id}", "annotation": self.step_type.name} } diff --git a/src/frontend/python/chiquito/util.py b/src/frontend/python/chiquito/util.py index 0533fd6f..a9621acb 100644 --- a/src/frontend/python/chiquito/util.py +++ b/src/frontend/python/chiquito/util.py @@ -14,11 +14,9 @@ def __json__(self: F): # Convert the integer to a byte array montgomery_form = self.n * R % F.field_modulus byte_array = montgomery_form.to_bytes(32, "little") - # Split into four 64-bit integers - ints = [ - int.from_bytes(byte_array[i * 8 : i * 8 + 8], "little") for i in range(4) - ] - return ints + + # return the hex string + return byte_array.hex() class CustomEncoder(json.JSONEncoder): @@ -29,5 +27,5 @@ def default(self, obj): # int field is the u128 version of uuid. -def uuid() -> int: - return uuid1(node=int.from_bytes([10, 10, 10, 10, 10, 10], byteorder="little")).int +def uuid() -> str: + return uuid1(node=int.from_bytes([10, 10, 10, 10, 10, 10], byteorder="little")).int.__str__() diff --git a/src/frontend/python/chiquito/wit_gen.py b/src/frontend/python/chiquito/wit_gen.py index cec59905..2ab1bf39 100644 --- a/src/frontend/python/chiquito/wit_gen.py +++ b/src/frontend/python/chiquito/wit_gen.py @@ -41,7 +41,7 @@ def __str__(self: StepInstance): # For assignments, return "uuid: (Queriable, F)" rather than "Queriable: F", because JSON doesn't accept Dict as key. def __json__(self: StepInstance): return { - "step_type_uuid": self.step_type_uuid, + "step_type_uuid": self.step_type_uuid.__str__(), "assignments": { lhs.uuid(): [lhs, rhs] for (lhs, rhs) in self.assignments.items() }, diff --git a/src/frontend/python/mod.rs b/src/frontend/python/mod.rs index 1a752878..32a25726 100644 --- a/src/frontend/python/mod.rs +++ b/src/frontend/python/mod.rs @@ -2,6 +2,7 @@ use pyo3::{ prelude::*, types::{PyDict, PyList, PyLong, PyString}, }; +use serde_json::{from_str, Value}; use crate::{ frontend::dsl::{StepTypeHandler, SuperCircuitContext}, @@ -47,8 +48,10 @@ thread_local! { /// as the key. Return the Rust UUID to Python. The last field of the tuple, `TraceWitness`, is left /// as None, for `chiquito_add_witness_to_rust_id` to insert. 
pub fn chiquito_ast_to_halo2(ast_json: &str) -> UUID { + let value: Value = from_str(ast_json).expect("Invalid JSON"); + // Attempt to convert `Value` into `SBPIR` let circuit: SBPIR = - serde_json::from_str(ast_json).expect("Json deserialization to Circuit failed."); + serde_json::from_value(value).expect("Deserialization to Circuit failed."); let config = config(SingleRowCellManager {}, SimpleStepSelectorBuilder {}); let (chiquito, assignment_generator) = compile(config, &circuit); @@ -212,13 +215,18 @@ impl<'de> Visitor<'de> for CircuitVisitor { let mut q_enable = None; let mut id = None; + println!("------ Visiting map -------"); + while let Some(key) = map.next_key::()? { + println!("key = {}", key); match key.as_str() { "step_types" => { + println!("------ Visiting step_types -------"); if step_types.is_some() { return Err(de::Error::duplicate_field("step_types")); } step_types = Some(map.next_value::>>()?); + println!("step_types = {:#?}", step_types); } "forward_signals" => { if forward_signals.is_some() { @@ -261,13 +269,33 @@ impl<'de> Visitor<'de> for CircuitVisitor { if first_step.is_some() { return Err(de::Error::duplicate_field("first_step")); } - first_step = Some(map.next_value::>()?); + let first_step_opt: Option = map.next_value()?; // Deserialize the value as an optional string + first_step = Some(first_step_opt.map_or(Ok(None), |first_step_str| { + StepTypeUUID::from_str_radix(&first_step_str, 10) + .map(Some) + .map_err(|e| { + de::Error::custom(format!( + "Failed to parse first_step '{}': {}", + first_step_str, e + )) + }) + })?); } "last_step" => { if last_step.is_some() { return Err(de::Error::duplicate_field("last_step")); } - last_step = Some(map.next_value::>()?); + let last_step_opt: Option = map.next_value()?; // Deserialize the value as an optional string + last_step = Some(last_step_opt.map_or(Ok(None), |last_step_str| { + StepTypeUUID::from_str_radix(&last_step_str, 10) + .map(Some) + .map_err(|e| { + de::Error::custom(format!( + "Failed to parse last_step '{}': {}", + last_step_str, e + )) + }) + })?); } "num_steps" => { if num_steps.is_some() { @@ -285,7 +313,10 @@ impl<'de> Visitor<'de> for CircuitVisitor { if id.is_some() { return Err(de::Error::duplicate_field("id")); } - id = Some(map.next_value()?); + let id_str: String = map.next_value()?; + id = Some(id_str.parse::().map_err(|e| { + de::Error::custom(format!("Failed to parse id '{}': {}", id_str, e)) + })?); } _ => { return Err(de::Error::unknown_field( @@ -376,7 +407,10 @@ impl<'de> Visitor<'de> for StepTypeVisitor { if id.is_some() { return Err(de::Error::duplicate_field("id")); } - id = Some(map.next_value()?); + let id_str: String = map.next_value()?; + id = Some(id_str.parse::().map_err(|e| { + de::Error::custom(format!("Failed to parse id '{}': {}", id_str, e)) + })?); } "name" => { if name.is_some() { @@ -629,6 +663,7 @@ impl<'de> Visitor<'de> for QueriableVisitor { let key: String = map .next_key()? 
.ok_or_else(|| de::Error::custom("map is empty"))?; + match key.as_str() { "Internal" => map.next_value().map(Queriable::Internal), "Forward" => map @@ -637,9 +672,11 @@ impl<'de> Visitor<'de> for QueriableVisitor { "Shared" => map .next_value() .map(|(signal, rotation)| Queriable::Shared(signal, rotation)), - "Fixed" => map - .next_value() - .map(|(signal, rotation)| Queriable::Fixed(signal, rotation)), + "Fixed" => { + println!("Processing Fixed"); + map.next_value() + .map(|(signal, rotation)| Queriable::Fixed(signal, rotation)) + } "StepTypeNext" => map.next_value().map(Queriable::StepTypeNext), _ => Err(de::Error::unknown_variant( &key, @@ -703,7 +740,10 @@ macro_rules! impl_visitor_internal_fixed_steptypehandler { if id.is_some() { return Err(de::Error::duplicate_field("id")); } - id = Some(map.next_value()?); + let id_str: String = map.next_value()?; // Get the UUID as a string + id = Some(id_str.parse::().map_err(|e| { + de::Error::custom(format!("Failed to parse id '{}': {}", id_str, e)) + })?); } "annotation" => { if annotation.is_some() { @@ -759,7 +799,10 @@ macro_rules! impl_visitor_forward_shared { if id.is_some() { return Err(de::Error::duplicate_field("id")); } - id = Some(map.next_value()?); + let id_str: String = map.next_value()?; // Get the UUID as a string + id = Some(id_str.parse::().map_err(|e| { + de::Error::custom(format!("Failed to parse id '{}': {}", id_str, e)) + })?); } "phase" => { if phase.is_some() { @@ -848,7 +891,12 @@ impl<'de> Visitor<'de> for StepInstanceVisitor { if step_type_uuid.is_some() { return Err(de::Error::duplicate_field("step_type_uuid")); } - step_type_uuid = Some(map.next_value()?); + let uuid_str: String = map.next_value()?; // Get the UUID as a string + step_type_uuid = Some( + uuid_str + .parse::() // Assuming the string is in decimal format + .map_err(de::Error::custom)?, + ); } "assignments" => { if assignments.is_some() { @@ -920,119 +968,89 @@ impl<'de> Deserialize<'de> for SBPIR { mod tests { use super::*; - #[ignore] #[test] + #[ignore] fn test_trace_witness() { let json = r#" { "step_instances": [ { - "step_type_uuid": 270606747459021742275781620564109167114, + "step_type_uuid": "270606747459021742275781620564109167114", "assignments": { "270606737951642240564318377467548666378": [ { "Forward": [ { - "id": 270606737951642240564318377467548666378, + "id": "270606737951642240564318377467548666378", "phase": 0, "annotation": "a" }, false ] }, - [ - 55, - 0, - 0, - 0 - ] + "0000000000000000000000000000000000000000000000000000000000000055" ], "270606743497613616562965561253747624458": [ { "Forward": [ { - "id": 270606743497613616562965561253747624458, + "id": "270606743497613616562965561253747624458", "phase": 0, "annotation": "b" }, false ] }, - [ - 89, - 0, - 0, - 0 - ] + "0000000000000000000000000000000000000000000000000000000000000089" ], "270606753004993118272949371872716917258": [ { "Internal": { - "id": 270606753004993118272949371872716917258, + "id": "270606753004993118272949371872716917258", "annotation": "c" } }, - [ - 144, - 0, - 0, - 0 - ] + "0000000000000000000000000000000000000000000000000000000000000144" ] } }, { - "step_type_uuid": 270606783111694873693576112554652600842, + "step_type_uuid": "270606783111694873693576112554652600842", "assignments": { "270606737951642240564318377467548666378": [ { "Forward": [ { - "id": 270606737951642240564318377467548666378, + "id": "270606737951642240564318377467548666378", "phase": 0, "annotation": "a" }, false ] }, - [ - 89, - 0, - 0, - 0 - ] + 
"0000000000000000000000000000000000000000000000000000000000000089" ], "270606743497613616562965561253747624458": [ { "Forward": [ { - "id": 270606743497613616562965561253747624458, + "id": "270606743497613616562965561253747624458", "phase": 0, "annotation": "b" }, false ] }, - [ - 144, - 0, - 0, - 0 - ] + "0000000000000000000000000000000000000000000000000000000000000144" ], "270606786280821374261518951164072823306": [ { "Internal": { - "id": 270606786280821374261518951164072823306, + "id": "270606786280821374261518951164072823306", "annotation": "c" } }, - [ - 233, - 0, - 0, - 0 - ] + "0000000000000000000000000000000000000000000000000000000000000233" ] } } @@ -1065,18 +1083,17 @@ mod tests { let _: ExposeOffset = serde_json::from_str(json).unwrap(); } - #[ignore] #[test] fn test_circuit() { let json = r#" { "step_types": { "258869595755756204079859764249309612554": { - "id": 258869595755756204079859764249309612554, + "id": "258869595755756204079859764249309612554", "name": "fibo_first_step", "signals": [ { - "id": 258869599717164329791616633222308956682, + "id": "258869599717164329791616633222308956682", "annotation": "c" } ], @@ -1088,7 +1105,7 @@ mod tests { { "Forward": [ { - "id": 258869580702405326369584955980151130634, + "id": "258869580702405326369584955980151130634", "phase": 0, "annotation": "a" }, @@ -1097,12 +1114,7 @@ mod tests { }, { "Neg": { - "Const": [ - 1, - 0, - 0, - 0 - ] + "Const": "0000000000000000000000000000000000000000000000000000000000000001" } } ] @@ -1115,7 +1127,7 @@ mod tests { { "Forward": [ { - "id": 258869587040658327507391136965088381450, + "id": "258869587040658327507391136965088381450", "phase": 0, "annotation": "b" }, @@ -1124,12 +1136,7 @@ mod tests { }, { "Neg": { - "Const": [ - 1, - 0, - 0, - 0 - ] + "Const": "0000000000000000000000000000000000000000000000000000000000000001" } } ] @@ -1142,7 +1149,7 @@ mod tests { { "Forward": [ { - "id": 258869580702405326369584955980151130634, + "id": "258869580702405326369584955980151130634", "phase": 0, "annotation": "a" }, @@ -1152,7 +1159,7 @@ mod tests { { "Forward": [ { - "id": 258869587040658327507391136965088381450, + "id": "258869587040658327507391136965088381450", "phase": 0, "annotation": "b" }, @@ -1162,7 +1169,7 @@ mod tests { { "Neg": { "Internal": { - "id": 258869599717164329791616633222308956682, + "id": "258869599717164329791616633222308956682", "annotation": "c" } } @@ -1179,7 +1186,7 @@ mod tests { { "Forward": [ { - "id": 258869587040658327507391136965088381450, + "id": "258869587040658327507391136965088381450", "phase": 0, "annotation": "b" }, @@ -1190,7 +1197,7 @@ mod tests { "Neg": { "Forward": [ { - "id": 258869580702405326369584955980151130634, + "id": "258869580702405326369584955980151130634", "phase": 0, "annotation": "a" }, @@ -1207,7 +1214,7 @@ mod tests { "Sum": [ { "Internal": { - "id": 258869599717164329791616633222308956682, + "id": "258869599717164329791616633222308956682", "annotation": "c" } }, @@ -1215,7 +1222,7 @@ mod tests { "Neg": { "Forward": [ { - "id": 258869587040658327507391136965088381450, + "id": "258869587040658327507391136965088381450", "phase": 0, "annotation": "b" }, @@ -1233,7 +1240,7 @@ mod tests { { "Forward": [ { - "id": 258869589417503202934383108674030275082, + "id": "258869589417503202934383108674030275082", "phase": 0, "annotation": "n" }, @@ -1244,7 +1251,7 @@ mod tests { "Neg": { "Forward": [ { - "id": 258869589417503202934383108674030275082, + "id": "258869589417503202934383108674030275082", "phase": 0, "annotation": "n" }, @@ -1262,11 +1269,11 @@ mod 
tests { } }, "258869628239302834927102989021255174666": { - "id": 258869628239302834927102989021255174666, + "id": "258869628239302834927102989021255174666", "name": "fibo_step", "signals": [ { - "id": 258869632200710960639812650790420089354, + "id": "258869632200710960639812650790420089354", "annotation": "c" } ], @@ -1278,7 +1285,7 @@ mod tests { { "Forward": [ { - "id": 258869580702405326369584955980151130634, + "id": "258869580702405326369584955980151130634", "phase": 0, "annotation": "a" }, @@ -1288,7 +1295,7 @@ mod tests { { "Forward": [ { - "id": 258869587040658327507391136965088381450, + "id": "258869587040658327507391136965088381450", "phase": 0, "annotation": "b" }, @@ -1298,7 +1305,7 @@ mod tests { { "Neg": { "Internal": { - "id": 258869632200710960639812650790420089354, + "id": "258869632200710960639812650790420089354", "annotation": "c" } } @@ -1315,7 +1322,7 @@ mod tests { { "Forward": [ { - "id": 258869587040658327507391136965088381450, + "id": "258869587040658327507391136965088381450", "phase": 0, "annotation": "b" }, @@ -1326,7 +1333,7 @@ mod tests { "Neg": { "Forward": [ { - "id": 258869580702405326369584955980151130634, + "id": "258869580702405326369584955980151130634", "phase": 0, "annotation": "a" }, @@ -1343,7 +1350,7 @@ mod tests { "Sum": [ { "Internal": { - "id": 258869632200710960639812650790420089354, + "id": "258869632200710960639812650790420089354", "annotation": "c" } }, @@ -1351,7 +1358,7 @@ mod tests { "Neg": { "Forward": [ { - "id": 258869587040658327507391136965088381450, + "id": "258869587040658327507391136965088381450", "phase": 0, "annotation": "b" }, @@ -1369,7 +1376,7 @@ mod tests { { "Forward": [ { - "id": 258869589417503202934383108674030275082, + "id": "258869589417503202934383108674030275082", "phase": 0, "annotation": "n" }, @@ -1380,7 +1387,7 @@ mod tests { "Neg": { "Forward": [ { - "id": 258869589417503202934383108674030275082, + "id": "258869589417503202934383108674030275082", "phase": 0, "annotation": "n" }, @@ -1398,7 +1405,7 @@ mod tests { } }, "258869646461780213207493341245063432714": { - "id": 258869646461780213207493341245063432714, + "id": "258869646461780213207493341245063432714", "name": "padding", "signals": [], "constraints": [], @@ -1410,7 +1417,7 @@ mod tests { { "Forward": [ { - "id": 258869587040658327507391136965088381450, + "id": "258869587040658327507391136965088381450", "phase": 0, "annotation": "b" }, @@ -1421,7 +1428,7 @@ mod tests { "Neg": { "Forward": [ { - "id": 258869587040658327507391136965088381450, + "id": "258869587040658327507391136965088381450", "phase": 0, "annotation": "b" }, @@ -1439,7 +1446,7 @@ mod tests { { "Forward": [ { - "id": 258869589417503202934383108674030275082, + "id": "258869589417503202934383108674030275082", "phase": 0, "annotation": "n" }, @@ -1450,7 +1457,7 @@ mod tests { "Neg": { "Forward": [ { - "id": 258869589417503202934383108674030275082, + "id": "258869589417503202934383108674030275082", "phase": 0, "annotation": "n" }, @@ -1468,17 +1475,17 @@ mod tests { }, "forward_signals": [ { - "id": 258869580702405326369584955980151130634, + "id": "258869580702405326369584955980151130634", "phase": 0, "annotation": "a" }, { - "id": 258869587040658327507391136965088381450, + "id": "258869587040658327507391136965088381450", "phase": 0, "annotation": "b" }, { - "id": 258869589417503202934383108674030275082, + "id": "258869589417503202934383108674030275082", "phase": 0, "annotation": "n" } @@ -1490,7 +1497,7 @@ mod tests { { "Forward": [ { - "id": 258869587040658327507391136965088381450, + "id": 
"258869587040658327507391136965088381450", "phase": 0, "annotation": "b" }, @@ -1505,7 +1512,7 @@ mod tests { { "Forward": [ { - "id": 258869589417503202934383108674030275082, + "id": "258869589417503202934383108674030275082", "phase": 0, "annotation": "n" }, @@ -1526,31 +1533,30 @@ mod tests { "258869646461780213207493341245063432714": "padding" }, "fixed_assignments": null, - "first_step": 258869595755756204079859764249309612554, - "last_step": 258869646461780213207493341245063432714, + "first_step": "258869595755756204079859764249309612554", + "last_step": "258869646461780213207493341245063432714", "num_steps": 10, "q_enable": true, - "id": 258867373405797678961444396351437277706 + "id": "258867373405797678961444396351437277706" } "#; let circuit: SBPIR = serde_json::from_str(json).unwrap(); println!("{:?}", circuit); } - #[ignore] #[test] fn test_step_type() { let json = r#" { - "id":1, + "id":"1", "name":"fibo", "signals":[ { - "id":1, + "id":"1", "annotation":"a" }, { - "id":2, + "id":"2", "annotation":"b" } ], @@ -1560,18 +1566,18 @@ mod tests { "expr":{ "Sum":[ { - "Const":[1, 0, 0, 0] + "Const": "0000000000000000000000000000000000000000000000000000000000000001" }, { "Mul":[ { "Internal":{ - "id":3, + "id":"3", "annotation":"c" } }, { - "Const":[3, 0, 0, 0] + "Const": "0000000000000000000000000000000000000000000000000000000000000003" } ] } @@ -1583,14 +1589,14 @@ mod tests { "expr":{ "Sum":[ { - "Const":[1, 0, 0, 0] + "Const": "0000000000000000000000000000000000000000000000000000000000000001" }, { "Mul":[ { "Shared":[ { - "id":4, + "id":"4", "phase":2, "annotation":"d" }, @@ -1598,7 +1604,7 @@ mod tests { ] }, { - "Const":[3, 0, 0, 0] + "Const": "0000000000000000000000000000000000000000000000000000000000000003" } ] } @@ -1612,14 +1618,14 @@ mod tests { "expr":{ "Sum":[ { - "Const":[1, 0, 0, 0] + "Const": "0000000000000000000000000000000000000000000000000000000000000001" }, { "Mul":[ { "Forward":[ { - "id":5, + "id":"5", "phase":1, "annotation":"e" }, @@ -1627,7 +1633,7 @@ mod tests { ] }, { - "Const":[3, 0, 0, 0] + "Const": "0000000000000000000000000000000000000000000000000000000000000003" } ] } @@ -1639,21 +1645,21 @@ mod tests { "expr":{ "Sum":[ { - "Const":[1, 0, 0, 0] + "Const": "0000000000000000000000000000000000000000000000000000000000000001" }, { "Mul":[ { "Fixed":[ { - "id":6, + "id":"6", "annotation":"e" }, 2 ] }, { - "Const":[3, 0, 0, 0] + "Const": "0000000000000000000000000000000000000000000000000000000000000003" } ] } @@ -1673,7 +1679,6 @@ mod tests { println!("{:?}", step_type); } - #[ignore] #[test] fn test_constraint() { let json = r#" @@ -1683,14 +1688,14 @@ mod tests { "Sum": [ { "Internal": { - "id": 27, + "id": "27", "annotation": "a" } }, { "Fixed": [ { - "id": 28, + "id": "28", "annotation": "b" }, 1 @@ -1699,7 +1704,7 @@ mod tests { { "Shared": [ { - "id": 29, + "id": "29", "phase": 1, "annotation": "c" }, @@ -1709,7 +1714,7 @@ mod tests { { "Forward": [ { - "id": 30, + "id": "30", "phase": 2, "annotation": "d" }, @@ -1718,32 +1723,32 @@ mod tests { }, { "StepTypeNext": { - "id": 31, + "id": "31", "annotation": "e" } }, { - "Const": [3, 0, 0, 0] + "Const": "0000000000000000000000000000000000000000000000000000000000000003" }, { "Mul": [ { - "Const": [4, 0, 0, 0] + "Const": "0000000000000000000000000000000000000000000000000000000000000004" }, { - "Const": [5, 0, 0, 0] + "Const": "0000000000000000000000000000000000000000000000000000000000000005" } ] }, { "Neg": { - "Const": [2, 0, 0, 0] + "Const": "0000000000000000000000000000000000000000000000000000000000000002" 
} }, { "Pow": [ { - "Const": [3, 0, 0, 0] + "Const": "0000000000000000000000000000000000000000000000000000000000000003" }, 4 ] @@ -1757,7 +1762,6 @@ mod tests { println!("{:?}", transition_constraint); } - #[ignore] #[test] fn test_expr() { let json = r#" @@ -1765,14 +1769,14 @@ mod tests { "Sum": [ { "Internal": { - "id": 27, + "id": "27", "annotation": "a" } }, { "Fixed": [ { - "id": 28, + "id": "28", "annotation": "b" }, 1 @@ -1781,7 +1785,7 @@ mod tests { { "Shared": [ { - "id": 29, + "id": "29", "phase": 1, "annotation": "c" }, @@ -1791,7 +1795,7 @@ mod tests { { "Forward": [ { - "id": 30, + "id": "30", "phase": 2, "annotation": "d" }, @@ -1800,32 +1804,32 @@ mod tests { }, { "StepTypeNext": { - "id": 31, + "id": "31", "annotation": "e" } }, { - "Const": [3, 0, 0, 0] + "Const": "0000000000000000000000000000000000000000000000000000000000000003" }, { "Mul": [ { - "Const": [4, 0, 0, 0] + "Const": "0000000000000000000000000000000000000000000000000000000000000004" }, { - "Const": [5, 0, 0, 0] + "Const": "0000000000000000000000000000000000000000000000000000000000000005" } ] }, { "Neg": { - "Const": [2, 0, 0, 0] + "Const": "0000000000000000000000000000000000000000000000000000000000000002" } }, { "Pow": [ { - "Const": [3, 0, 0, 0] + "Const": "0000000000000000000000000000000000000000000000000000000000000003" }, 4 ]