Skip to content
This repository has been archived by the owner on Nov 4, 2024. It is now read-only.

Commit

Permalink
PIL Backend with PIL IR (#165)
Browse files Browse the repository at this point in the history
Ready for review. @leolara

# High-Level Approach
Convert Chiquito AST to PIL IR and finally to PIL code.

# Future TODOs
- Add support for trace inputs to PIL. Currently, PIL supports user
input via the command line or Rust APIs for the command line. However,
it only supports input for one variable (column). Because Chiquito
circuit might provide inputs for multiple variables via trace (e.g. the
MiMC7 circuit), we can't use use PIL's Rust API as a generic solution
for providing inputs. The only solution is to input statements in the
format of `ISFIRST * ([[input_variable]] - [[input_value]]) = 0`. An
issue with this approach, however, is that our `trace` is too general
and doesn't associate input values to signals, so we might need to
create an alternative `trace` function or mode that does this. All of
the above is due to the fact that PIL only supports automatic witness
inference and cannot feed in external witness.
- Add Python PIL API support for super circuits. Currently we support
single circuit in Python. Should be very do-able, just didn't get the
time to code it up.
  • Loading branch information
qwang98 authored Mar 4, 2024
1 parent 8817d86 commit 15f441b
Show file tree
Hide file tree
Showing 18 changed files with 1,018 additions and 13 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ num-bigint = { version = "0.4", features = ["rand"] }
uuid = { version = "1.4.0", features = ["v1", "rng"] }
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
regex = "1"

[dev-dependencies]
rand_chacha = "0.3"
Expand Down
2 changes: 2 additions & 0 deletions examples/fibonacci.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,3 +84,5 @@ def trace(self, n):
) # 2^k specifies the number of PLONKish table rows in Halo2
another_fibo_witness = fibo.gen_witness(4)
fibo.halo2_mock_prover(another_fibo_witness, k=7)

fibo.to_pil(fibo_witness, "FiboCircuit")
29 changes: 23 additions & 6 deletions examples/fibonacci.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ use chiquito::{
},
plonkish::ir::{assignments::AssignmentGenerator, Circuit}, // compiled circuit type
poly::ToField,
sbpir::SBPIR,
};
use halo2_proofs::{dev::MockProver, halo2curves::bn256::Fr};

Expand All @@ -26,7 +27,10 @@ use halo2_proofs::{dev::MockProver, halo2curves::bn256::Fr};
// 1. type that implements a field trait
// 2. empty trace arguments, i.e. (), because there are no external inputs to the Chiquito circuit
// 3. two witness generation arguments both of u64 type, i.e. (u64, u64)
fn fibo_circuit<F: Field + From<u64> + Hash>() -> (Circuit<F>, Option<AssignmentGenerator<F, ()>>) {

type FiboReturn<F> = (Circuit<F>, Option<AssignmentGenerator<F, ()>>, SBPIR<F, ()>);

fn fibo_circuit<F: Field + From<u64> + Hash>() -> FiboReturn<F> {
// PLONKish table for the Fibonacci circuit:
// | a | b | c |
// | 1 | 1 | 2 |
Expand Down Expand Up @@ -73,7 +77,7 @@ fn fibo_circuit<F: Field + From<u64> + Hash>() -> (Circuit<F>, Option<Assignment
// logics for assigning witness values wg function is defined here but no
// witness value is assigned yet
ctx.wg(move |ctx, (a_value, b_value): (u32, u32)| {
println!("fib line wg: {} {} {}", a_value, b_value, a_value + b_value);
// println!("fib line wg: {} {} {}", a_value, b_value, a_value + b_value);
// assign arbitrary input values from witness generation function to witnesses
ctx.assign(a, a_value.field());
ctx.assign(b, b_value.field());
Expand Down Expand Up @@ -105,18 +109,20 @@ fn fibo_circuit<F: Field + From<u64> + Hash>() -> (Circuit<F>, Option<Assignment
})
});

compile(
let compiled = compile(
config(SingleRowCellManager {}, SimpleStepSelectorBuilder {}),
&fibo,
)
);

(compiled.0, compiled.1, fibo)
}

// After compiling Chiquito AST to an IR, it is further parsed by a Chiquito Halo2 backend and
// integrated into a Halo2 circuit, which is done by the boilerplate code below.

// standard main function for a Halo2 circuit
fn main() {
let (chiquito, wit_gen) = fibo_circuit::<Fr>();
let (chiquito, wit_gen, _) = fibo_circuit::<Fr>();
let compiled = chiquito2Halo2(chiquito);
let circuit = ChiquitoHalo2Circuit::new(compiled, wit_gen.map(|g| g.generate(())));

Expand All @@ -137,7 +143,7 @@ fn main() {
use polyexen::plaf::{backends::halo2::PlafH2Circuit, WitnessDisplayCSV};

// get Chiquito ir
let (circuit, wit_gen) = fibo_circuit::<Fr>();
let (circuit, wit_gen, _) = fibo_circuit::<Fr>();
// get Plaf
let (plaf, plaf_wit_gen) = chiquito2Plaf(circuit, 8, false);
let wit = plaf_wit_gen.generate(wit_gen.map(|v| v.generate(())));
Expand All @@ -162,4 +168,15 @@ fn main() {
println!("{}", failure);
}
}

// pil boilerplate
use chiquito::pil::backend::powdr_pil::chiquito2Pil;

let (_, wit_gen, circuit) = fibo_circuit::<Fr>();
let pil = chiquito2Pil(
circuit,
Some(wit_gen.unwrap().generate_trace_witness(())),
String::from("FiboCircuit"),
);
print!("{}", pil);
}
25 changes: 24 additions & 1 deletion examples/mimc7.rs
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ fn mimc7_circuit<F: PrimeField + Eq + Hash>(
row_value += F::from(1);
x_value += k_value + c_value;
x_value = x_value.pow_vartime([7_u64]);
// Step 90: output the hash result as x + k in witness generation
// Step 91: output the hash result as x + k in witness generation
// output is not displayed as a public column, which will be implemented in the future
ctx.add(&mimc7_last_step, (x_value, k_value, c_value, row_value)); // c_value is not
// used here but
Expand Down Expand Up @@ -219,6 +219,29 @@ fn main() {
println!("{}", failure);
}
}

// pil boilerplate
use chiquito::pil::backend::powdr_pil::chiquitoSuperCircuit2Pil;

let x_in_value = Fr::from_str_vartime("1").expect("expected a number");
let k_value = Fr::from_str_vartime("2").expect("expected a number");

let super_circuit = mimc7_super_circuit::<Fr>();

// `super_trace_witnesses` is a mapping from IR id to TraceWitness. However, not all ASTs have a
// corresponding TraceWitness.
let super_trace_witnesses = super_circuit
.get_mapping()
.generate_super_trace_witnesses((x_in_value, k_value));

let pil = chiquitoSuperCircuit2Pil::<Fr, (), ()>(
super_circuit.get_super_asts(),
super_trace_witnesses,
super_circuit.get_ast_id_to_ir_id_mapping(),
vec![String::from("Mimc7Constant"), String::from("Mimc7Circuit")],
);

print!("{}", pil);
}

mod mimc7_constants {
Expand Down
4 changes: 4 additions & 0 deletions src/frontend/dsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,10 @@ impl StepTypeHandler {
pub fn next<F>(&self) -> Queriable<F> {
Queriable::StepTypeNext(*self)
}

pub fn annotation(&self) -> String {
self.annotation.to_string()
}
}

impl<F, Args, D: Fn(&mut StepInstance<F>, Args) + 'static> From<&StepTypeWGHandler<F, Args, D>>
Expand Down
11 changes: 9 additions & 2 deletions src/frontend/dsl/sc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,12 @@ impl<F, MappingArgs> Default for SuperCircuitContext<F, MappingArgs> {
}
}

impl<F: Clone, MappingArgs> SuperCircuitContext<F, MappingArgs> {
fn add_sub_circuit_ast(&mut self, ast: SBPIR<F, ()>) {
self.super_circuit.add_sub_circuit_ast(ast);
}
}

impl<F: Field + Hash, MappingArgs> SuperCircuitContext<F, MappingArgs> {
pub fn sub_circuit<CM: CellManager, SSB: StepSelectorBuilder, TraceArgs, Imports, Exports, D>(
&mut self,
Expand All @@ -48,12 +54,13 @@ impl<F: Field + Hash, MappingArgs> SuperCircuitContext<F, MappingArgs> {
circuit: SBPIR::default(),
tables: self.tables.clone(),
};
println!("super circuit table registry 2: {:?}", self.tables);
let exports = sub_circuit_def(&mut sub_circuit_context, imports);
println!("super circuit table registry 3: {:?}", self.tables);

let sub_circuit = sub_circuit_context.circuit;

// ast is used for PIL backend
self.add_sub_circuit_ast(sub_circuit.clone_without_trace());

let (unit, assignment) = compile_phase1(config, &sub_circuit);
let assignment = assignment.unwrap_or_else(|| AssignmentGenerator::empty(unit.uuid));

Expand Down
9 changes: 9 additions & 0 deletions src/frontend/python/chiquito/dsl.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,15 @@ def halo2_mock_prover(self: Circuit, witness: TraceWitness, k: int = 16):
witness_json: str = witness.get_witness_json()
rust_chiquito.halo2_mock_prover(witness_json, self.rust_id, k)

def to_pil(
self: Circuit, witness: TraceWitness, circuit_name: str = "Circuit"
) -> str:
if self.rust_id == 0:
ast_json: str = self.get_ast_json()
self.rust_id: int = rust_chiquito.ast_to_halo2(ast_json)
witness_json: str = witness.get_witness_json()
rust_chiquito.to_pil(witness_json, self.rust_id, circuit_name)

def __str__(self: Circuit) -> str:
return self.ast.__str__()

Expand Down
22 changes: 22 additions & 0 deletions src/frontend/python/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use pyo3::{

use crate::{
frontend::dsl::{StepTypeHandler, SuperCircuitContext},
pil::backend::powdr_pil::chiquito2Pil,
plonkish::{
backend::halo2::{
chiquito2Halo2, chiquitoSuperCircuit2Halo2, ChiquitoHalo2, ChiquitoHalo2Circuit,
Expand Down Expand Up @@ -81,6 +82,14 @@ pub fn chiquito_ast_map_store(ast_json: &str) -> UUID {
uuid
}

pub fn chiquito_ast_to_pil(witness_json: &str, rust_id: UUID, circuit_name: &str) -> String {
let trace_witness: TraceWitness<Fr> =
serde_json::from_str(witness_json).expect("Json deserialization to TraceWitness failed.");
let (ast, _, _) = rust_id_to_halo2(rust_id);

chiquito2Pil(ast, Some(trace_witness), circuit_name.to_string())
}

fn add_assignment_generator_to_rust_id(
assignment_generator: AssignmentGenerator<Fr, ()>,
rust_id: UUID,
Expand Down Expand Up @@ -1845,6 +1854,18 @@ fn ast_to_halo2(json: &PyString) -> u128 {
uuid
}

#[pyfunction]
fn to_pil(witness_json: &PyString, rust_id: &PyLong, circuit_name: &PyString) -> String {
let pil = chiquito_ast_to_pil(
witness_json.to_str().expect("PyString convertion failed."),
rust_id.extract().expect("PyLong convertion failed."),
circuit_name.to_str().expect("PyString convertion failed."),
);

println!("{}", pil);
pil
}

#[pyfunction]
fn ast_map_store(json: &PyString) -> u128 {
let uuid = chiquito_ast_map_store(json.to_str().expect("PyString conversion failed."));
Expand Down Expand Up @@ -1903,6 +1924,7 @@ fn rust_chiquito(_py: Python, m: &PyModule) -> PyResult<()> {
m.add_function(wrap_pyfunction!(convert_and_print_ast, m)?)?;
m.add_function(wrap_pyfunction!(convert_and_print_trace_witness, m)?)?;
m.add_function(wrap_pyfunction!(ast_to_halo2, m)?)?;
m.add_function(wrap_pyfunction!(to_pil, m)?)?;
m.add_function(wrap_pyfunction!(ast_map_store, m)?)?;
m.add_function(wrap_pyfunction!(halo2_mock_prover, m)?)?;
m.add_function(wrap_pyfunction!(super_circuit_halo2_mock_prover, m)?)?;
Expand Down
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
pub mod field;
pub mod frontend;
pub mod pil;
pub mod plonkish;
pub mod poly;
pub mod sbpir;
Expand Down
1 change: 1 addition & 0 deletions src/pil/backend/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
pub mod powdr_pil;
Loading

0 comments on commit 15f441b

Please sign in to comment.