From 05716ab77cd358f3b51d6e9357ae18b7293ac1b4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aur=C3=A9lien=20Nicolas?= Date: Tue, 17 Sep 2024 13:43:38 +0200 Subject: [PATCH 1/2] u8-table --- ceno_zkvm/src/chip_handler/general.rs | 6 +- ceno_zkvm/src/scheme/mock_prover.rs | 13 ++++ ceno_zkvm/src/structs.rs | 1 + ceno_zkvm/src/tables/range.rs | 23 ++++++- ceno_zkvm/src/tables/range/u8_table.rs | 83 ++++++++++++++++++++++++++ ceno_zkvm/src/witness.rs | 5 +- 6 files changed, 124 insertions(+), 7 deletions(-) create mode 100644 ceno_zkvm/src/tables/range/u8_table.rs diff --git a/ceno_zkvm/src/chip_handler/general.rs b/ceno_zkvm/src/chip_handler/general.rs index 29c86e463..30265af78 100644 --- a/ceno_zkvm/src/chip_handler/general.rs +++ b/ceno_zkvm/src/chip_handler/general.rs @@ -216,7 +216,10 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> { NR: Into, N: FnOnce() -> NR, { - self.assert_u16(name_fn, expr * Expression::from(1 << 8)) + let items: Vec> = vec![(ROMType::U8 as usize).into(), expr]; + let rlc_record = self.rlc_chip_record(items); + self.lk_record(name_fn, rlc_record)?; + Ok(()) } pub(crate) fn assert_bit( @@ -228,6 +231,7 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> { NR: Into, N: FnOnce() -> NR, { + // TODO: Replace with `x * (1 - x)` or a multi-bit lookup similar to assert_u8_pair. self.assert_u16(name_fn, expr * Expression::from(1 << 15)) } diff --git a/ceno_zkvm/src/scheme/mock_prover.rs b/ceno_zkvm/src/scheme/mock_prover.rs index 8dea82b70..826342c88 100644 --- a/ceno_zkvm/src/scheme/mock_prover.rs +++ b/ceno_zkvm/src/scheme/mock_prover.rs @@ -239,6 +239,18 @@ fn load_tables(cb: &CircuitBuilder, challenge: [E; 2]) -> } } + fn load_u8_table( + t_vec: &mut Vec>, + cb: &CircuitBuilder, + challenge: [E; 2], + ) { + for i in 0..=u8::MAX as usize { + let rlc_record = cb.rlc_chip_record(vec![(ROMType::U8 as usize).into(), i.into()]); + let rlc_record = eval_by_expr(&[], &challenge, &rlc_record); + t_vec.push(rlc_record.to_repr().as_ref().to_vec()); + } + } + fn load_u16_table( t_vec: &mut Vec>, cb: &CircuitBuilder, @@ -338,6 +350,7 @@ fn load_tables(cb: &CircuitBuilder, challenge: [E; 2]) -> let mut table_vec = vec![]; // TODO load more tables here load_u5_table(&mut table_vec, cb, challenge); + load_u8_table(&mut table_vec, cb, challenge); load_u16_table(&mut table_vec, cb, challenge); load_lt_table(&mut table_vec, cb, challenge); load_and_table(&mut table_vec, cb, challenge); diff --git a/ceno_zkvm/src/structs.rs b/ceno_zkvm/src/structs.rs index 1a522df88..cdf16c0cf 100644 --- a/ceno_zkvm/src/structs.rs +++ b/ceno_zkvm/src/structs.rs @@ -42,6 +42,7 @@ pub type ChallengeId = u16; #[derive(Debug)] pub enum ROMType { U5 = 0, // 2^5 = 32 + U8, // 2^8 = 256 U16, // 2^16 = 65,536 And, // a ^ b where a, b are bytes Ltu, // a <(usign) b where a, b are bytes diff --git a/ceno_zkvm/src/tables/range.rs b/ceno_zkvm/src/tables/range.rs index 2b195ea63..c0b1dd9a9 100644 --- a/ceno_zkvm/src/tables/range.rs +++ b/ceno_zkvm/src/tables/range.rs @@ -1,3 +1,6 @@ +mod u8_table; +use u8_table::{U8TableCircuit, U8TableConfig}; + use std::{collections::HashMap, marker::PhantomData, mem::MaybeUninit}; use crate::{ @@ -18,6 +21,8 @@ use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterato pub struct RangeTableConfig { u16_tbl: Fixed, u16_mlt: WitIn, + + u8_config: U8TableConfig, } pub struct RangeTableCircuit(PhantomData); @@ -42,7 +47,16 @@ impl TableCircuit for RangeTableCircuit { cb.lk_table_record(|| "u16 table", u16_table_values, u16_mlt.expr())?; - Ok(RangeTableConfig { u16_tbl, u16_mlt }) + let u8_config = cb.namespace( + || "u8_table", + |cb| U8TableCircuit::::construct_circuit(cb), + )?; + + Ok(RangeTableConfig { + u16_tbl, + u16_mlt, + u8_config, + }) } fn generate_fixed_traces( @@ -60,6 +74,8 @@ impl TableCircuit for RangeTableCircuit { set_fixed_val!(row, config.u16_tbl, E::BaseField::from(i as u64)); }); + U8TableCircuit::::generate_fixed_traces(&config.u8_config, &mut fixed); + fixed } @@ -69,9 +85,8 @@ impl TableCircuit for RangeTableCircuit { multiplicity: &[HashMap], _input: &(), ) -> Result, ZKVMError> { - let multiplicity = &multiplicity[ROMType::U16 as usize]; let mut u16_mlt = vec![0; 1 << RANGE_CHIP_BIT_WIDTH]; - for (limb, mlt) in multiplicity { + for (limb, mlt) in &multiplicity[ROMType::U16 as usize] { u16_mlt[*limb as usize] = *mlt; } @@ -84,6 +99,8 @@ impl TableCircuit for RangeTableCircuit { set_val!(row, config.u16_mlt, E::BaseField::from(mlt as u64)); }); + U8TableCircuit::::assign_instances(&config.u8_config, multiplicity, &mut witness); + Ok(witness) } } diff --git a/ceno_zkvm/src/tables/range/u8_table.rs b/ceno_zkvm/src/tables/range/u8_table.rs new file mode 100644 index 000000000..25d08adb5 --- /dev/null +++ b/ceno_zkvm/src/tables/range/u8_table.rs @@ -0,0 +1,83 @@ +use ff::Field; +use ff_ext::ExtensionField; +use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator}; +use std::{collections::HashMap, marker::PhantomData, mem::MaybeUninit}; + +use crate::{ + circuit_builder::CircuitBuilder, + error::ZKVMError, + expression::{Expression, Fixed, ToExpr, WitIn}, + scheme::constants::MIN_PAR_SIZE, + set_fixed_val, set_val, + structs::ROMType, + witness::RowMajorMatrix, +}; + +const NUM_U8: usize = 1 << 8; + +#[derive(Clone, Debug)] +pub struct U8TableConfig { + u8_fixed: Fixed, + mlt: WitIn, +} + +pub struct U8TableCircuit(PhantomData); + +impl U8TableCircuit { + pub fn construct_circuit(cb: &mut CircuitBuilder) -> Result { + let u8_fixed = cb.create_fixed(|| "u8_fixed")?; + let mlt = cb.create_witin(|| "mlt")?; + + let rlc_record = cb.rlc_chip_record(vec![ + Expression::Constant(E::BaseField::from(ROMType::U8 as u64)), + Expression::Fixed(u8_fixed.clone()), + ]); + + cb.lk_table_record(|| "u8_record", rlc_record, mlt.expr())?; + + Ok(U8TableConfig { u8_fixed, mlt }) + } + + pub fn generate_fixed_traces(config: &U8TableConfig, fixed: &mut RowMajorMatrix) { + assert!(fixed.num_instances() >= NUM_U8); + + fixed + .par_iter_mut() + .with_min_len(MIN_PAR_SIZE) + .zip((0..NUM_U8).into_par_iter()) + .for_each(|(row, i)| { + set_fixed_val!(row, config.u8_fixed, E::BaseField::from(i as u64)); + }); + + // Fill the rest with zeros, if any. + fixed.par_iter_mut().skip(NUM_U8).for_each(|row| { + set_fixed_val!(row, config.u8_fixed, E::BaseField::ZERO); + }); + } + + pub fn assign_instances( + config: &U8TableConfig, + multiplicity: &[HashMap], + witness: &mut RowMajorMatrix, + ) { + assert!(witness.num_instances() >= NUM_U8); + + let mut mlts = vec![0; NUM_U8]; + for (idx, mlt) in &multiplicity[ROMType::U8 as usize] { + mlts[*idx as usize] = *mlt; + } + + witness + .par_iter_mut() + .with_min_len(MIN_PAR_SIZE) + .zip(mlts.into_par_iter()) + .for_each(|(row, mlt)| { + set_val!(row, config.mlt, E::BaseField::from(mlt as u64)); + }); + + // Fill the rest with zeros, if any. + witness.par_iter_mut().skip(NUM_U8).for_each(|row| { + set_val!(row, config.mlt, E::BaseField::ZERO); + }); + } +} diff --git a/ceno_zkvm/src/witness.rs b/ceno_zkvm/src/witness.rs index 263b6645c..b49694874 100644 --- a/ceno_zkvm/src/witness.rs +++ b/ceno_zkvm/src/witness.rs @@ -126,11 +126,10 @@ impl LkMultiplicity { } fn assert_byte(&mut self, v: u64) { - let v = v * (1 << u8::BITS); let multiplicity = self .multiplicity .get_or(|| RefCell::new(array::from_fn(|_| HashMap::new()))); - (*multiplicity.borrow_mut()[ROMType::U16 as usize] + (*multiplicity.borrow_mut()[ROMType::U8 as usize] .entry(v) .or_default()) += 1; } @@ -189,6 +188,6 @@ mod tests { } let res = lkm.into_finalize_result(); // check multiplicity counts of assert_byte - assert_eq!(res[ROMType::U16 as usize][&(8 << 8)], thread_count); + assert_eq!(res[ROMType::U8 as usize][&8], thread_count); } } From a83f37dcee660bf70c9ecbaf368808b8807a041b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aur=C3=A9lien=20Nicolas?= Date: Tue, 17 Sep 2024 14:30:13 +0200 Subject: [PATCH 2/2] u8-table: simplify types --- ceno_zkvm/src/tables/range.rs | 13 ++++++----- ceno_zkvm/src/tables/range/u8_table.rs | 30 +++++++++++++------------- 2 files changed, 21 insertions(+), 22 deletions(-) diff --git a/ceno_zkvm/src/tables/range.rs b/ceno_zkvm/src/tables/range.rs index c0b1dd9a9..a3b706177 100644 --- a/ceno_zkvm/src/tables/range.rs +++ b/ceno_zkvm/src/tables/range.rs @@ -1,5 +1,5 @@ mod u8_table; -use u8_table::{U8TableCircuit, U8TableConfig}; +use u8_table::U8TableConfig; use std::{collections::HashMap, marker::PhantomData, mem::MaybeUninit}; @@ -47,10 +47,7 @@ impl TableCircuit for RangeTableCircuit { cb.lk_table_record(|| "u16 table", u16_table_values, u16_mlt.expr())?; - let u8_config = cb.namespace( - || "u8_table", - |cb| U8TableCircuit::::construct_circuit(cb), - )?; + let u8_config = cb.namespace(|| "u8_table", |cb| U8TableConfig::construct_circuit(cb))?; Ok(RangeTableConfig { u16_tbl, @@ -74,7 +71,7 @@ impl TableCircuit for RangeTableCircuit { set_fixed_val!(row, config.u16_tbl, E::BaseField::from(i as u64)); }); - U8TableCircuit::::generate_fixed_traces(&config.u8_config, &mut fixed); + config.u8_config.generate_fixed_traces(&mut fixed); fixed } @@ -99,7 +96,9 @@ impl TableCircuit for RangeTableCircuit { set_val!(row, config.u16_mlt, E::BaseField::from(mlt as u64)); }); - U8TableCircuit::::assign_instances(&config.u8_config, multiplicity, &mut witness); + config + .u8_config + .assign_instances(multiplicity, &mut witness); Ok(witness) } diff --git a/ceno_zkvm/src/tables/range/u8_table.rs b/ceno_zkvm/src/tables/range/u8_table.rs index 25d08adb5..e1e026db9 100644 --- a/ceno_zkvm/src/tables/range/u8_table.rs +++ b/ceno_zkvm/src/tables/range/u8_table.rs @@ -1,7 +1,7 @@ -use ff::Field; use ff_ext::ExtensionField; +use goldilocks::SmallField; use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator}; -use std::{collections::HashMap, marker::PhantomData, mem::MaybeUninit}; +use std::{collections::HashMap, mem::MaybeUninit}; use crate::{ circuit_builder::CircuitBuilder, @@ -21,10 +21,10 @@ pub struct U8TableConfig { mlt: WitIn, } -pub struct U8TableCircuit(PhantomData); - -impl U8TableCircuit { - pub fn construct_circuit(cb: &mut CircuitBuilder) -> Result { +impl U8TableConfig { + pub fn construct_circuit( + cb: &mut CircuitBuilder, + ) -> Result { let u8_fixed = cb.create_fixed(|| "u8_fixed")?; let mlt = cb.create_witin(|| "mlt")?; @@ -35,10 +35,10 @@ impl U8TableCircuit { cb.lk_table_record(|| "u8_record", rlc_record, mlt.expr())?; - Ok(U8TableConfig { u8_fixed, mlt }) + Ok(Self { u8_fixed, mlt }) } - pub fn generate_fixed_traces(config: &U8TableConfig, fixed: &mut RowMajorMatrix) { + pub fn generate_fixed_traces(&self, fixed: &mut RowMajorMatrix) { assert!(fixed.num_instances() >= NUM_U8); fixed @@ -46,19 +46,19 @@ impl U8TableCircuit { .with_min_len(MIN_PAR_SIZE) .zip((0..NUM_U8).into_par_iter()) .for_each(|(row, i)| { - set_fixed_val!(row, config.u8_fixed, E::BaseField::from(i as u64)); + set_fixed_val!(row, self.u8_fixed, F::from(i as u64)); }); // Fill the rest with zeros, if any. fixed.par_iter_mut().skip(NUM_U8).for_each(|row| { - set_fixed_val!(row, config.u8_fixed, E::BaseField::ZERO); + set_fixed_val!(row, self.u8_fixed, F::ZERO); }); } - pub fn assign_instances( - config: &U8TableConfig, + pub fn assign_instances( + &self, multiplicity: &[HashMap], - witness: &mut RowMajorMatrix, + witness: &mut RowMajorMatrix, ) { assert!(witness.num_instances() >= NUM_U8); @@ -72,12 +72,12 @@ impl U8TableCircuit { .with_min_len(MIN_PAR_SIZE) .zip(mlts.into_par_iter()) .for_each(|(row, mlt)| { - set_val!(row, config.mlt, E::BaseField::from(mlt as u64)); + set_val!(row, self.mlt, F::from(mlt as u64)); }); // Fill the rest with zeros, if any. witness.par_iter_mut().skip(NUM_U8).for_each(|row| { - set_val!(row, config.mlt, E::BaseField::ZERO); + set_val!(row, self.mlt, F::ZERO); }); } }