From 38b0a5e0ff7d524e1b1c58a1ef6d889d4507bb92 Mon Sep 17 00:00:00 2001 From: Steve Wang Date: Sun, 10 Sep 2023 15:53:59 -0400 Subject: [PATCH] hooked up frontend and backend --- examples/mimc7.py | 10 ++++-- src/frontend/python/chiquito/dsl.py | 5 ++- src/frontend/python/mod.rs | 52 ++++++++++++++++++----------- 3 files changed, 42 insertions(+), 25 deletions(-) diff --git a/examples/mimc7.py b/examples/mimc7.py index c6db64eb..634fd43d 100644 --- a/examples/mimc7.py +++ b/examples/mimc7.py @@ -51,13 +51,13 @@ def trace(self, args): self.add(self.mimc7_first_step, (x_value, k_value, c_value, row_value)) - for i in range(1, ROUND_KEYS): + for i in range(1, ROUNDS): row_value += F(1) x_value += k_value + c_value x_value = F(x_value**7) c_value = F(ROUND_KEYS[i]) - self.add(self.mimc_step, (x_value, k_value, c_value, row_value)) + self.add(self.mimc7_step, (x_value, k_value, c_value, row_value)) row_value += F(1) x_value += k_value + c_value @@ -176,4 +176,8 @@ def mapping(self, args): self.map(self.mimc7_circuit, (x_in_value, k_value)) -Mimc7SuperCircuit() +mimc7 = Mimc7SuperCircuit() +mimc7_witnesses = mimc7.gen_witness((1, 2)) +# for key, value in mimc7_witnesses.items(): +# print(f"{key}: {str(value)}") +mimc7.halo2_mock_prover(mimc7_witnesses) \ No newline at end of file diff --git a/src/frontend/python/chiquito/dsl.py b/src/frontend/python/chiquito/dsl.py index a33e429b..e83e371e 100644 --- a/src/frontend/python/chiquito/dsl.py +++ b/src/frontend/python/chiquito/dsl.py @@ -94,9 +94,8 @@ def halo2_mock_prover(self: SuperCircuit, witnesses: Dict[int, TraceWitness]): raise ValueError( f"SuperCircuit.halo2_mock_prover(): TraceWitness with rust_id {rust_id} not found in sub_circuits." ) - rust_chiquito.add_witness_to_ast(witness_json, rust_id) - for sub_circuit_id in self.ast.sub_circuits: - pass + rust_chiquito.add_witness_to_rust_id(witness_json, rust_id) + rust_chiquito.super_circuit_halo2_mock_prover(list(self.ast.sub_circuits.keys())) # def halo2_mock_prover(self: Circuit, witness: TraceWitness): # if self.rust_id == 0: diff --git a/src/frontend/python/mod.rs b/src/frontend/python/mod.rs index 995ff507..dd813e04 100644 --- a/src/frontend/python/mod.rs +++ b/src/frontend/python/mod.rs @@ -1,6 +1,6 @@ use pyo3::{ prelude::*, - types::{PyLong, PyString}, + types::{PyLong, PyString, PyList}, }; use crate::{ @@ -34,6 +34,7 @@ thread_local! { pub static CIRCUIT_MAP: CircuitMap = RefCell::new(HashMap::new()); } +/// Parses JSON into `ast::Circuit` and compile. Generates a Rust UUID. Inserts tuple of (`ast::Circuit`, `ChiquitoHalo2`, `AssignmentGenerator`, _) to `CIRCUIT_MAP` with the Rust UUID as the key. Return the Rust UUID to Python. The last field of the tuple, `TraceWitness`, is left as None, for `chiquito_add_witness_to_rust_id` to insert. pub fn chiquito_ast_to_halo2(ast_json: &str) -> UUID { let circuit: Circuit = serde_json::from_str(ast_json).expect("Json deserialization to Circuit failed."); @@ -54,7 +55,8 @@ pub fn chiquito_ast_to_halo2(ast_json: &str) -> UUID { uuid } -pub fn chiquito_add_witness_to_ast(witness_json: &str, rust_id: UUID) { +/// Parses JSON into `TraceWitness` and insert it into `CIRCUIT_MAP` with `rust_id` as the key. +pub fn chiquito_add_witness_to_rust_id(witness_json: &str, rust_id: UUID) { let witness: TraceWitness = serde_json::from_str(witness_json).expect("Json deserialization to TraceWitness failed."); @@ -67,7 +69,7 @@ pub fn chiquito_add_witness_to_ast(witness_json: &str, rust_id: UUID) { println!("Added TraceWitness to rust_id: {:?}", rust_id); } -fn add_assignment_generator_to_ast(assignment_generator: AssignmentGenerator, rust_id: UUID) { +fn add_assignment_generator_to_rust_id(assignment_generator: AssignmentGenerator, rust_id: UUID) { CIRCUIT_MAP.with(|circuit_map| { let mut circuit_map = circuit_map.borrow_mut(); let circuit_map_store = circuit_map.get_mut(&rust_id).unwrap(); @@ -77,24 +79,25 @@ fn add_assignment_generator_to_ast(assignment_generator: AssignmentGenerator) { - let mut ctx = SuperCircuitContext::::default(); + let mut super_circuit_ctx = SuperCircuitContext::::default(); // super_circuit def let config = config(SingleRowCellManager {}, SimpleStepSelectorBuilder {}); for rust_id in rust_ids.clone() { - let circuit_map_store = uuid_to_halo2(rust_id); + let circuit_map_store = rust_id_to_halo2(rust_id); let (circuit, chiquito_halo2, assignment_generator, witness) = circuit_map_store; - let assignment = ctx.sub_circuit_with_ast(config.clone(), circuit); - add_assignment_generator_to_ast(assignment, rust_id); + let assignment = super_circuit_ctx.sub_circuit_with_ast(config.clone(), circuit); + add_assignment_generator_to_rust_id(assignment, rust_id); } - let super_circuit = ctx.compile(); + let super_circuit = super_circuit_ctx.compile(); let compiled = chiquitoSuperCircuit2Halo2(&super_circuit); let mut mapping_ctx = MappingContext::default(); for rust_id in rust_ids { - let circuit_map_store = uuid_to_halo2(rust_id); + let circuit_map_store = rust_id_to_halo2(rust_id); let (circuit, chiquito_halo2, assignment_generator, witness) = circuit_map_store; if witness.is_some() { mapping_ctx.map_with_witness(&assignment_generator.unwrap(), witness.unwrap()); @@ -121,17 +124,19 @@ pub fn chiquito_super_circuit_halo2_mock_prover(rust_ids: Vec) { } } -fn uuid_to_halo2(uuid: UUID) -> CircuitMapStore { +/// Returns the (`ast::Circuit`, `ChiquitoHalo2`, `AssignmentGenerator`, `TraceWitness`) tuple corresponding to `rust_id`. +fn rust_id_to_halo2(uuid: UUID) -> CircuitMapStore { CIRCUIT_MAP.with(|circuit_map| { let circuit_map = circuit_map.borrow(); circuit_map.get(&uuid).unwrap().clone() }) } +/// Runs `MockProver` for a single circuit given JSON of `TraceWitness` and `rust_id` of the circuit. pub fn chiquito_halo2_mock_prover(witness_json: &str, rust_id: UUID) { let trace_witness: TraceWitness = serde_json::from_str(witness_json).expect("Json deserialization to TraceWitness failed."); - let (_, compiled, assignment_generator, _) = uuid_to_halo2(rust_id); + let (_, compiled, assignment_generator, _) = rust_id_to_halo2(rust_id); let circuit: ChiquitoHalo2Circuit<_> = ChiquitoHalo2Circuit::new( compiled, assignment_generator.map(|g| g.generate_with_witness(trace_witness)), @@ -1816,24 +1821,32 @@ fn ast_to_halo2(json: &PyString) -> u128 { } #[pyfunction] -fn halo2_mock_prover(witness_json: &PyString, ast_uuid: &PyLong) { +fn halo2_mock_prover(witness_json: &PyString, rust_id: &PyLong) { chiquito_halo2_mock_prover( witness_json.to_str().expect("PyString convertion failed."), - ast_uuid.extract().expect("PyLong convertion failed."), + rust_id.extract().expect("PyLong convertion failed."), ); } #[pyfunction] -fn add_witness_to_ast(witness_json: &PyString, ast_uuid: &PyLong) { - chiquito_add_witness_to_ast( +fn add_witness_to_rust_id(witness_json: &PyString, rust_id: &PyLong) { + chiquito_add_witness_to_rust_id( witness_json.to_str().expect("PyString convertion failed."), - ast_uuid.extract().expect("PyLong convertion failed."), + rust_id.extract().expect("PyLong convertion failed."), ); } #[pyfunction] -fn super_circuit_halo2_mock_prover() { - +fn super_circuit_halo2_mock_prover(rust_ids: &PyList) { + let uuids = rust_ids.iter().map(|rust_id| { + rust_id + .downcast::() + .expect("PyAny downcast failed.") + .extract() + .expect("PyLong convertion failed.") + }).collect::>(); + + chiquito_super_circuit_halo2_mock_prover(uuids) } #[pymodule] @@ -1842,6 +1855,7 @@ fn rust_chiquito(_py: Python, m: &PyModule) -> PyResult<()> { m.add_function(wrap_pyfunction!(convert_and_print_trace_witness, m)?)?; m.add_function(wrap_pyfunction!(ast_to_halo2, m)?)?; m.add_function(wrap_pyfunction!(halo2_mock_prover, m)?)?; - m.add_function(wrap_pyfunction!(add_witness_to_ast, m)?)?; + m.add_function(wrap_pyfunction!(add_witness_to_rust_id, m)?)?; + m.add_function(wrap_pyfunction!(super_circuit_halo2_mock_prover, m)?)?; Ok(()) }