Skip to content

Commit

Permalink
feat: add middleware check_witness (#356)
Browse files Browse the repository at this point in the history
  • Loading branch information
ed255 authored Sep 16, 2024
1 parent ee611ee commit c84fbb8
Show file tree
Hide file tree
Showing 5 changed files with 199 additions and 63 deletions.
178 changes: 178 additions & 0 deletions halo2_debug/src/check_witness.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
use crate::display::FDisp;
use halo2_middleware::circuit::{Any, CompiledCircuit, ExpressionMid, VarMid};
use halo2_middleware::ff::PrimeField;
use rand_chacha::ChaCha20Rng;
use rand_core::SeedableRng;
use std::collections::HashSet;

fn rotate(n: usize, offset: usize, rotation: i32) -> usize {
let offset = offset as i32 + rotation;
if offset < 0 {
(offset + n as i32) as usize
} else if offset >= n as i32 {
(offset - n as i32) as usize
} else {
offset as usize
}
}

struct Assignments<'a, F: PrimeField> {
public: &'a [Vec<F>],
witness: &'a [Vec<F>],
fixed: &'a [Vec<F>],
blinders: &'a [Vec<F>],
blinded: &'a [bool],
usable_rows: usize,
n: usize,
}

impl<'a, F: PrimeField> Assignments<'a, F> {
// Query a particular Column at an offset
fn query(&self, column_type: Any, column_index: usize, offset: usize) -> F {
match column_type {
Any::Instance => self.public[column_index][offset],
Any::Advice => {
if offset >= self.usable_rows && self.blinded[column_index] {
self.blinders[column_index][offset - self.usable_rows]
} else {
self.witness[column_index][offset]
}
}
Any::Fixed => self.fixed[column_index][offset],
}
}

// Evaluate an expression using the assingment data
fn eval(&self, expr: &ExpressionMid<F>, offset: usize) -> F {
expr.evaluate(
&|s| s,
&|v| match v {
VarMid::Query(q) => {
let offset = rotate(self.n, offset, q.rotation.0);
self.query(q.column_type, q.column_index, offset)
}
VarMid::Challenge(_c) => unimplemented!(),
},
&|ne| -ne,
&|a, b| a + b,
&|a, b| a * b,
)
}

// Evaluate multiple expressions and return the result as concatenated bytes from the field
// element representation.
fn eval_to_buf(&self, f_len: usize, exprs: &[ExpressionMid<F>], offset: usize) -> Vec<u8> {
let mut eval_buf = Vec::with_capacity(exprs.len() * f_len);
for eval in exprs.iter().map(|e| self.eval(e, offset)) {
eval_buf.extend_from_slice(eval.to_repr().as_ref())
}
eval_buf
}
}

/// Check that the wintess passes all the constraints defined by the circuit. Panics if any
/// constraint is not satisfied.
pub fn check_witness<F: PrimeField>(
circuit: &CompiledCircuit<F>,
k: u32,
blinding_rows: usize,
witness: &[Vec<F>],
public: &[Vec<F>],
) {
let n = 2usize.pow(k);
let usable_rows = n - blinding_rows;
let cs = &circuit.cs;

// Calculate blinding values
let mut rng = ChaCha20Rng::seed_from_u64(0xdeadbeef);
let mut blinders = vec![vec![F::ZERO; blinding_rows]; cs.num_advice_columns];
for column_blinders in blinders.iter_mut() {
for v in column_blinders.iter_mut() {
*v = F::random(&mut rng);
}
}

let mut blinded = vec![true; cs.num_advice_columns];
for advice_column_index in &cs.unblinded_advice_columns {
blinded[*advice_column_index] = false;
}

let assignments = Assignments {
public,
witness,
fixed: &circuit.preprocessing.fixed,
blinders: &blinders,
blinded: &blinded,
usable_rows,
n,
};

// Verify all gates
for (i, gate) in cs.gates.iter().enumerate() {
for offset in 0..n {
let res = assignments.eval(&gate.poly, offset);
if !res.is_zero_vartime() {
panic!(
"Unsatisfied gate {} \"{}\" at offset {}",
i, gate.name, offset
);
}
}
}

// Verify all copy constraints
for (lhs, rhs) in &circuit.preprocessing.permutation.copies {
let value_lhs = assignments.query(lhs.column.column_type, lhs.column.index, lhs.row);
let value_rhs = assignments.query(rhs.column.column_type, rhs.column.index, rhs.row);
if value_lhs != value_rhs {
panic!(
"Unsatisfied copy constraint ({:?},{:?}): {} != {}",
lhs,
rhs,
FDisp(&value_lhs),
FDisp(&value_rhs)
)
}
}

// Verify all lookups
let f_len = F::Repr::default().as_ref().len();
for (i, lookup) in cs.lookups.iter().enumerate() {
let mut virtual_table = HashSet::new();
for offset in 0..usable_rows {
let table_eval_buf = assignments.eval_to_buf(f_len, &lookup.table_expressions, offset);
virtual_table.insert(table_eval_buf);
}
for offset in 0..usable_rows {
let input_eval_buf = assignments.eval_to_buf(f_len, &lookup.input_expressions, offset);
if !virtual_table.contains(&input_eval_buf) {
panic!(
"Unsatisfied lookup {} \"{}\" at offset {}",
i, lookup.name, offset
);
}
}
}

// Verify all shuffles
for (i, shuffle) in cs.shuffles.iter().enumerate() {
let mut virtual_shuffle = Vec::with_capacity(usable_rows);
for offset in 0..usable_rows {
let shuffle_eval_buf =
assignments.eval_to_buf(f_len, &shuffle.shuffle_expressions, offset);
virtual_shuffle.push(shuffle_eval_buf);
}
let mut virtual_input = Vec::with_capacity(usable_rows);
for offset in 0..usable_rows {
let input_eval_buf = assignments.eval_to_buf(f_len, &shuffle.input_expressions, offset);
virtual_input.push(input_eval_buf);
}

virtual_shuffle.sort_unstable();
virtual_input.sort_unstable();

if virtual_input != virtual_shuffle {
panic!("Unsatisfied shuffle {} \"{}\"", i, shuffle.name);
}
}
}
7 changes: 5 additions & 2 deletions halo2_debug/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
mod check_witness;
pub mod display;

pub use check_witness::check_witness;

use rand_chacha::ChaCha20Rng;
use rand_core::SeedableRng;
use tiny_keccak::Hasher;
Expand Down Expand Up @@ -34,5 +39,3 @@ pub fn test_result<F: FnOnce() -> Vec<u8> + Send>(test: F, _expected: &str) -> V

result
}

pub mod display;
1 change: 1 addition & 0 deletions p3_frontend/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,4 @@ p3-keccak-air = { git = "https://github.com/Plonky3/Plonky3", rev = "7b5b8a6" }
p3-keccak = { git = "https://github.com/Plonky3/Plonky3", rev = "7b5b8a6" }
p3-util = { git = "https://github.com/Plonky3/Plonky3", rev = "7b5b8a6" }
rand = "0.8.5"
halo2_debug = { path = "../halo2_debug" }
63 changes: 6 additions & 57 deletions p3_frontend/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
extern crate alloc;

use halo2_middleware::circuit::{
Any, Cell, ColumnMid, CompiledCircuit, ConstraintSystemMid, ExpressionMid, GateMid,
Preprocessing, QueryMid, VarMid,
Any, Cell, ColumnMid, ConstraintSystemMid, ExpressionMid, GateMid, Preprocessing, QueryMid,
VarMid,
};
use halo2_middleware::ff::{Field, PrimeField};
use halo2_middleware::permutation;
Expand Down Expand Up @@ -184,7 +184,7 @@ fn extract_copy_public<F: PrimeField + Hash>(
pub fn get_public_inputs<F: Field>(
preprocessing_info: &PreprocessingInfo,
size: usize,
witness: &[Option<Vec<F>>],
witness: &[Vec<F>],
) -> Vec<Vec<F>> {
if preprocessing_info.num_public_values == 0 {
return Vec::new();
Expand All @@ -196,7 +196,7 @@ pub fn get_public_inputs<F: Field>(
Location::LastRow => size - 1,
Location::Transition => unreachable!(),
};
public_inputs[*public_index] = witness[cell.0].as_ref().unwrap()[offset]
public_inputs[*public_index] = witness[cell.0][offset]
}
vec![public_inputs]
}
Expand Down Expand Up @@ -293,7 +293,7 @@ where
(cs, preprocessing_info)
}

pub fn trace_to_wit<F: Field>(k: u32, trace: RowMajorMatrix<FWrap<F>>) -> Vec<Option<Vec<F>>> {
pub fn trace_to_wit<F: Field>(k: u32, trace: RowMajorMatrix<FWrap<F>>) -> Vec<Vec<F>> {
let n = 2usize.pow(k);
let num_columns = trace.width;
let mut witness = vec![vec![F::ZERO; n]; num_columns];
Expand All @@ -302,56 +302,5 @@ pub fn trace_to_wit<F: Field>(k: u32, trace: RowMajorMatrix<FWrap<F>>) -> Vec<Op
witness[column_index][row_offset] = row[column_index].0;
}
}
witness.into_iter().map(Some).collect()
}

// TODO: Move to middleware
pub fn check_witness<F: Field>(
circuit: &CompiledCircuit<F>,
k: u32,
witness: &[Option<Vec<F>>],
public: &[Vec<F>],
) {
let n = 2usize.pow(k);
let cs = &circuit.cs;
let preprocessing = &circuit.preprocessing;
// TODO: Simulate blinding rows
// Verify all gates
for (i, gate) in cs.gates.iter().enumerate() {
for offset in 0..n {
let res = gate.poly.evaluate(
&|s| s,
&|v| match v {
VarMid::Query(q) => {
let offset = offset as i32 + q.rotation.0;
// TODO: Try to do mod n with a rust function
let offset = if offset < 0 {
(offset + n as i32) as usize
} else if offset >= n as i32 {
(offset - n as i32) as usize
} else {
offset as usize
};
match q.column_type {
Any::Instance => public[q.column_index][offset],
Any::Advice => witness[q.column_index].as_ref().unwrap()[offset],
Any::Fixed => preprocessing.fixed[q.column_index][offset],
}
}
VarMid::Challenge(_c) => unimplemented!(),
},
&|ne| -ne,
&|a, b| a + b,
&|a, b| a * b,
);
if !res.is_zero_vartime() {
println!(
"Unsatisfied gate {} \"{}\" at offset {}",
i, gate.name, offset
);
panic!("KO");
}
}
}
println!("Check witness: OK");
witness
}
13 changes: 9 additions & 4 deletions p3_frontend/tests/common/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,15 @@ use halo2_backend::{
Blake2bRead, Blake2bWrite, Challenge255, TranscriptReadBuffer, TranscriptWriterBuffer,
},
};
use halo2_debug::check_witness;
use halo2_debug::test_rng;
use halo2_middleware::circuit::CompiledCircuit;
use halo2_middleware::zal::impls::H2cEngine;
use halo2curves::bn256::{Bn256, Fr, G1Affine};
use p3_air::Air;
use p3_frontend::{
check_witness, compile_circuit_cs, compile_preprocessing, get_public_inputs, trace_to_wit,
CompileParams, FWrap, SymbolicAirBuilder,
compile_circuit_cs, compile_preprocessing, get_public_inputs, trace_to_wit, CompileParams,
FWrap, SymbolicAirBuilder,
};
use p3_matrix::dense::RowMajorMatrix;
use std::time::Instant;
Expand Down Expand Up @@ -50,8 +51,12 @@ where
let witness = trace_to_wit(k, trace);
let pis = get_public_inputs(&preprocessing_info, size, &witness);

check_witness(&compiled_circuit, k, &witness, &pis);
(compiled_circuit, witness, pis)
check_witness(&compiled_circuit, k, 5, &witness, &pis);
(
compiled_circuit,
witness.into_iter().map(Some).collect(),
pis,
)
}

pub(crate) fn setup_prove_verify(
Expand Down

0 comments on commit c84fbb8

Please sign in to comment.