From c84fbb8b6dea1f3d59570b4921baa75f0f1c9345 Mon Sep 17 00:00:00 2001 From: "Eduard S." Date: Mon, 16 Sep 2024 09:28:14 +0200 Subject: [PATCH] feat: add middleware check_witness (#356) --- halo2_debug/src/check_witness.rs | 178 +++++++++++++++++++++++++++++++ halo2_debug/src/lib.rs | 7 +- p3_frontend/Cargo.toml | 1 + p3_frontend/src/lib.rs | 63 ++--------- p3_frontend/tests/common/mod.rs | 13 ++- 5 files changed, 199 insertions(+), 63 deletions(-) create mode 100644 halo2_debug/src/check_witness.rs diff --git a/halo2_debug/src/check_witness.rs b/halo2_debug/src/check_witness.rs new file mode 100644 index 0000000000..e4299e6bf4 --- /dev/null +++ b/halo2_debug/src/check_witness.rs @@ -0,0 +1,178 @@ +use crate::display::FDisp; +use halo2_middleware::circuit::{Any, CompiledCircuit, ExpressionMid, VarMid}; +use halo2_middleware::ff::PrimeField; +use rand_chacha::ChaCha20Rng; +use rand_core::SeedableRng; +use std::collections::HashSet; + +fn rotate(n: usize, offset: usize, rotation: i32) -> usize { + let offset = offset as i32 + rotation; + if offset < 0 { + (offset + n as i32) as usize + } else if offset >= n as i32 { + (offset - n as i32) as usize + } else { + offset as usize + } +} + +struct Assignments<'a, F: PrimeField> { + public: &'a [Vec], + witness: &'a [Vec], + fixed: &'a [Vec], + blinders: &'a [Vec], + blinded: &'a [bool], + usable_rows: usize, + n: usize, +} + +impl<'a, F: PrimeField> Assignments<'a, F> { + // Query a particular Column at an offset + fn query(&self, column_type: Any, column_index: usize, offset: usize) -> F { + match column_type { + Any::Instance => self.public[column_index][offset], + Any::Advice => { + if offset >= self.usable_rows && self.blinded[column_index] { + self.blinders[column_index][offset - self.usable_rows] + } else { + self.witness[column_index][offset] + } + } + Any::Fixed => self.fixed[column_index][offset], + } + } + + // Evaluate an expression using the assingment data + fn eval(&self, expr: &ExpressionMid, offset: usize) -> F { + expr.evaluate( + &|s| s, + &|v| match v { + VarMid::Query(q) => { + let offset = rotate(self.n, offset, q.rotation.0); + self.query(q.column_type, q.column_index, offset) + } + VarMid::Challenge(_c) => unimplemented!(), + }, + &|ne| -ne, + &|a, b| a + b, + &|a, b| a * b, + ) + } + + // Evaluate multiple expressions and return the result as concatenated bytes from the field + // element representation. + fn eval_to_buf(&self, f_len: usize, exprs: &[ExpressionMid], offset: usize) -> Vec { + let mut eval_buf = Vec::with_capacity(exprs.len() * f_len); + for eval in exprs.iter().map(|e| self.eval(e, offset)) { + eval_buf.extend_from_slice(eval.to_repr().as_ref()) + } + eval_buf + } +} + +/// Check that the wintess passes all the constraints defined by the circuit. Panics if any +/// constraint is not satisfied. +pub fn check_witness( + circuit: &CompiledCircuit, + k: u32, + blinding_rows: usize, + witness: &[Vec], + public: &[Vec], +) { + let n = 2usize.pow(k); + let usable_rows = n - blinding_rows; + let cs = &circuit.cs; + + // Calculate blinding values + let mut rng = ChaCha20Rng::seed_from_u64(0xdeadbeef); + let mut blinders = vec![vec![F::ZERO; blinding_rows]; cs.num_advice_columns]; + for column_blinders in blinders.iter_mut() { + for v in column_blinders.iter_mut() { + *v = F::random(&mut rng); + } + } + + let mut blinded = vec![true; cs.num_advice_columns]; + for advice_column_index in &cs.unblinded_advice_columns { + blinded[*advice_column_index] = false; + } + + let assignments = Assignments { + public, + witness, + fixed: &circuit.preprocessing.fixed, + blinders: &blinders, + blinded: &blinded, + usable_rows, + n, + }; + + // Verify all gates + for (i, gate) in cs.gates.iter().enumerate() { + for offset in 0..n { + let res = assignments.eval(&gate.poly, offset); + if !res.is_zero_vartime() { + panic!( + "Unsatisfied gate {} \"{}\" at offset {}", + i, gate.name, offset + ); + } + } + } + + // Verify all copy constraints + for (lhs, rhs) in &circuit.preprocessing.permutation.copies { + let value_lhs = assignments.query(lhs.column.column_type, lhs.column.index, lhs.row); + let value_rhs = assignments.query(rhs.column.column_type, rhs.column.index, rhs.row); + if value_lhs != value_rhs { + panic!( + "Unsatisfied copy constraint ({:?},{:?}): {} != {}", + lhs, + rhs, + FDisp(&value_lhs), + FDisp(&value_rhs) + ) + } + } + + // Verify all lookups + let f_len = F::Repr::default().as_ref().len(); + for (i, lookup) in cs.lookups.iter().enumerate() { + let mut virtual_table = HashSet::new(); + for offset in 0..usable_rows { + let table_eval_buf = assignments.eval_to_buf(f_len, &lookup.table_expressions, offset); + virtual_table.insert(table_eval_buf); + } + for offset in 0..usable_rows { + let input_eval_buf = assignments.eval_to_buf(f_len, &lookup.input_expressions, offset); + if !virtual_table.contains(&input_eval_buf) { + panic!( + "Unsatisfied lookup {} \"{}\" at offset {}", + i, lookup.name, offset + ); + } + } + } + + // Verify all shuffles + for (i, shuffle) in cs.shuffles.iter().enumerate() { + let mut virtual_shuffle = Vec::with_capacity(usable_rows); + for offset in 0..usable_rows { + let shuffle_eval_buf = + assignments.eval_to_buf(f_len, &shuffle.shuffle_expressions, offset); + virtual_shuffle.push(shuffle_eval_buf); + } + let mut virtual_input = Vec::with_capacity(usable_rows); + for offset in 0..usable_rows { + let input_eval_buf = assignments.eval_to_buf(f_len, &shuffle.input_expressions, offset); + virtual_input.push(input_eval_buf); + } + + virtual_shuffle.sort_unstable(); + virtual_input.sort_unstable(); + + if virtual_input != virtual_shuffle { + panic!("Unsatisfied shuffle {} \"{}\"", i, shuffle.name); + } + } +} diff --git a/halo2_debug/src/lib.rs b/halo2_debug/src/lib.rs index 911e90e6df..d9343e50d4 100644 --- a/halo2_debug/src/lib.rs +++ b/halo2_debug/src/lib.rs @@ -1,3 +1,8 @@ +mod check_witness; +pub mod display; + +pub use check_witness::check_witness; + use rand_chacha::ChaCha20Rng; use rand_core::SeedableRng; use tiny_keccak::Hasher; @@ -34,5 +39,3 @@ pub fn test_result Vec + Send>(test: F, _expected: &str) -> V result } - -pub mod display; diff --git a/p3_frontend/Cargo.toml b/p3_frontend/Cargo.toml index 34dceddde5..bc91918ae7 100644 --- a/p3_frontend/Cargo.toml +++ b/p3_frontend/Cargo.toml @@ -36,3 +36,4 @@ p3-keccak-air = { git = "https://github.com/Plonky3/Plonky3", rev = "7b5b8a6" } p3-keccak = { git = "https://github.com/Plonky3/Plonky3", rev = "7b5b8a6" } p3-util = { git = "https://github.com/Plonky3/Plonky3", rev = "7b5b8a6" } rand = "0.8.5" +halo2_debug = { path = "../halo2_debug" } diff --git a/p3_frontend/src/lib.rs b/p3_frontend/src/lib.rs index 5ba49e25c7..947b9bb832 100644 --- a/p3_frontend/src/lib.rs +++ b/p3_frontend/src/lib.rs @@ -5,8 +5,8 @@ extern crate alloc; use halo2_middleware::circuit::{ - Any, Cell, ColumnMid, CompiledCircuit, ConstraintSystemMid, ExpressionMid, GateMid, - Preprocessing, QueryMid, VarMid, + Any, Cell, ColumnMid, ConstraintSystemMid, ExpressionMid, GateMid, Preprocessing, QueryMid, + VarMid, }; use halo2_middleware::ff::{Field, PrimeField}; use halo2_middleware::permutation; @@ -184,7 +184,7 @@ fn extract_copy_public( pub fn get_public_inputs( preprocessing_info: &PreprocessingInfo, size: usize, - witness: &[Option>], + witness: &[Vec], ) -> Vec> { if preprocessing_info.num_public_values == 0 { return Vec::new(); @@ -196,7 +196,7 @@ pub fn get_public_inputs( Location::LastRow => size - 1, Location::Transition => unreachable!(), }; - public_inputs[*public_index] = witness[cell.0].as_ref().unwrap()[offset] + public_inputs[*public_index] = witness[cell.0][offset] } vec![public_inputs] } @@ -293,7 +293,7 @@ where (cs, preprocessing_info) } -pub fn trace_to_wit(k: u32, trace: RowMajorMatrix>) -> Vec>> { +pub fn trace_to_wit(k: u32, trace: RowMajorMatrix>) -> Vec> { let n = 2usize.pow(k); let num_columns = trace.width; let mut witness = vec![vec![F::ZERO; n]; num_columns]; @@ -302,56 +302,5 @@ pub fn trace_to_wit(k: u32, trace: RowMajorMatrix>) -> Vec( - circuit: &CompiledCircuit, - k: u32, - witness: &[Option>], - public: &[Vec], -) { - let n = 2usize.pow(k); - let cs = &circuit.cs; - let preprocessing = &circuit.preprocessing; - // TODO: Simulate blinding rows - // Verify all gates - for (i, gate) in cs.gates.iter().enumerate() { - for offset in 0..n { - let res = gate.poly.evaluate( - &|s| s, - &|v| match v { - VarMid::Query(q) => { - let offset = offset as i32 + q.rotation.0; - // TODO: Try to do mod n with a rust function - let offset = if offset < 0 { - (offset + n as i32) as usize - } else if offset >= n as i32 { - (offset - n as i32) as usize - } else { - offset as usize - }; - match q.column_type { - Any::Instance => public[q.column_index][offset], - Any::Advice => witness[q.column_index].as_ref().unwrap()[offset], - Any::Fixed => preprocessing.fixed[q.column_index][offset], - } - } - VarMid::Challenge(_c) => unimplemented!(), - }, - &|ne| -ne, - &|a, b| a + b, - &|a, b| a * b, - ); - if !res.is_zero_vartime() { - println!( - "Unsatisfied gate {} \"{}\" at offset {}", - i, gate.name, offset - ); - panic!("KO"); - } - } - } - println!("Check witness: OK"); + witness } diff --git a/p3_frontend/tests/common/mod.rs b/p3_frontend/tests/common/mod.rs index d762b5f2d5..5ea7e8682e 100644 --- a/p3_frontend/tests/common/mod.rs +++ b/p3_frontend/tests/common/mod.rs @@ -11,14 +11,15 @@ use halo2_backend::{ Blake2bRead, Blake2bWrite, Challenge255, TranscriptReadBuffer, TranscriptWriterBuffer, }, }; +use halo2_debug::check_witness; use halo2_debug::test_rng; use halo2_middleware::circuit::CompiledCircuit; use halo2_middleware::zal::impls::H2cEngine; use halo2curves::bn256::{Bn256, Fr, G1Affine}; use p3_air::Air; use p3_frontend::{ - check_witness, compile_circuit_cs, compile_preprocessing, get_public_inputs, trace_to_wit, - CompileParams, FWrap, SymbolicAirBuilder, + compile_circuit_cs, compile_preprocessing, get_public_inputs, trace_to_wit, CompileParams, + FWrap, SymbolicAirBuilder, }; use p3_matrix::dense::RowMajorMatrix; use std::time::Instant; @@ -50,8 +51,12 @@ where let witness = trace_to_wit(k, trace); let pis = get_public_inputs(&preprocessing_info, size, &witness); - check_witness(&compiled_circuit, k, &witness, &pis); - (compiled_circuit, witness, pis) + check_witness(&compiled_circuit, k, 5, &witness, &pis); + ( + compiled_circuit, + witness.into_iter().map(Some).collect(), + pis, + ) } pub(crate) fn setup_prove_verify(