From cba2ce85a6c1710a900ee5c767d409457105cc85 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aur=C3=A9lien=20Nicolas?= Date: Sat, 14 Sep 2024 00:08:33 +0200 Subject: [PATCH 1/4] byte-pair: Table for assert_byte and pairs of bytes --- ceno_zkvm/src/chip_handler/general.rs | 27 +++++++- ceno_zkvm/src/scheme/mock_prover.rs | 16 +++++ ceno_zkvm/src/structs.rs | 1 + ceno_zkvm/src/tables/mod.rs | 2 + ceno_zkvm/src/tables/u8_pair.rs | 91 +++++++++++++++++++++++++++ ceno_zkvm/src/uint.rs | 2 +- ceno_zkvm/src/witness.rs | 20 +++--- 7 files changed, 147 insertions(+), 12 deletions(-) create mode 100644 ceno_zkvm/src/tables/u8_pair.rs diff --git a/ceno_zkvm/src/chip_handler/general.rs b/ceno_zkvm/src/chip_handler/general.rs index c1ec9b2a9..23bdb2137 100644 --- a/ceno_zkvm/src/chip_handler/general.rs +++ b/ceno_zkvm/src/chip_handler/general.rs @@ -145,7 +145,7 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> { { match C { 16 => self.assert_u16(name_fn, expr), - 8 => self.assert_byte(name_fn, expr), + 8 => self.assert_byte_slow(name_fn, expr), 5 => self.assert_u5(name_fn, expr), _ => panic!("Unsupported bit range"), } @@ -169,6 +169,24 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> { ) } + /// Ensure that a and b are both bytes. + /// This is more efficient than two separate assert_byte. + pub(crate) fn assert_u8_pair( + &mut self, + name_fn: N, + a: Expression, + b: Expression, + ) -> Result<(), ZKVMError> + where + NR: Into, + N: FnOnce() -> NR, + { + let fields: Vec> = vec![(ROMType::U8Pair as usize).into(), a, b]; + let rlc_record = self.rlc_chip_record(fields); + self.lk_record(name_fn, rlc_record)?; + Ok(()) + } + fn assert_u16(&mut self, name_fn: N, expr: Expression) -> Result<(), ZKVMError> where NR: Into, @@ -195,7 +213,9 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> { }) } - pub(crate) fn assert_byte( + /// Ensure that expr is a byte. + /// This is "slow" because `assert_u8_pair` is more efficient if you have multiple bytes. + pub(crate) fn assert_byte_slow( &mut self, name_fn: N, expr: Expression, @@ -204,7 +224,7 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> { NR: Into, N: FnOnce() -> NR, { - self.assert_u16(name_fn, expr * Expression::from(1 << 8)) + self.assert_u8_pair(name_fn, expr, Expression::ZERO) } pub(crate) fn assert_bit( @@ -216,6 +236,7 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> { NR: Into, N: FnOnce() -> NR, { + // TODO: Replace with `x * (1 - x)` or a multi-bit lookup similar to assert_u8_pair. self.assert_u16(name_fn, expr * Expression::from(1 << 15)) } diff --git a/ceno_zkvm/src/scheme/mock_prover.rs b/ceno_zkvm/src/scheme/mock_prover.rs index 3b9a8afcc..912ae206b 100644 --- a/ceno_zkvm/src/scheme/mock_prover.rs +++ b/ceno_zkvm/src/scheme/mock_prover.rs @@ -224,6 +224,21 @@ fn load_tables(cb: &CircuitBuilder, challenge: [E; 2]) -> } } + fn load_u8_pair_table( + t_vec: &mut Vec>, + cb: &CircuitBuilder, + challenge: [E; 2], + ) { + for i in 0..=u16::MAX as usize { + let a = i & 0xff; + let b = (i >> 8) & 0xff; + let rlc_record = + cb.rlc_chip_record(vec![(ROMType::U8Pair as usize).into(), a.into(), b.into()]); + let rlc_record = eval_by_expr(&[], &challenge, &rlc_record); + t_vec.push(rlc_record.to_repr().as_ref().to_vec()); + } + } + fn load_u16_table( t_vec: &mut Vec>, cb: &CircuitBuilder, @@ -300,6 +315,7 @@ fn load_tables(cb: &CircuitBuilder, challenge: [E; 2]) -> let mut table_vec = vec![]; // TODO load more tables here load_u5_table(&mut table_vec, cb, challenge); + load_u8_pair_table(&mut table_vec, cb, challenge); load_u16_table(&mut table_vec, cb, challenge); load_lt_table(&mut table_vec, cb, challenge); load_and_table(&mut table_vec, cb, challenge); diff --git a/ceno_zkvm/src/structs.rs b/ceno_zkvm/src/structs.rs index c7c136d66..6e8ab6e6e 100644 --- a/ceno_zkvm/src/structs.rs +++ b/ceno_zkvm/src/structs.rs @@ -42,6 +42,7 @@ pub type ChallengeId = u16; #[derive(Debug)] pub enum ROMType { U5 = 0, // 2^5 = 32 + U8Pair, // 2^8 * 2^8 = 65,536 U16, // 2^16 = 65,536 And, // a ^ b where a, b are bytes Ltu, // a <(usign) b where a, b are bytes diff --git a/ceno_zkvm/src/tables/mod.rs b/ceno_zkvm/src/tables/mod.rs index a66094952..d7697ad9e 100644 --- a/ceno_zkvm/src/tables/mod.rs +++ b/ceno_zkvm/src/tables/mod.rs @@ -5,6 +5,8 @@ use std::collections::HashMap; mod range; pub use range::RangeTableCircuit; +mod u8_pair; + pub trait TableCircuit { type TableConfig: Send + Sync; type Input: Send + Sync; diff --git a/ceno_zkvm/src/tables/u8_pair.rs b/ceno_zkvm/src/tables/u8_pair.rs new file mode 100644 index 000000000..67d223537 --- /dev/null +++ b/ceno_zkvm/src/tables/u8_pair.rs @@ -0,0 +1,91 @@ +use std::{collections::HashMap, marker::PhantomData, mem::MaybeUninit}; + +use crate::{ + circuit_builder::CircuitBuilder, + error::ZKVMError, + expression::{Expression, Fixed, ToExpr, WitIn}, + scheme::constants::MIN_PAR_SIZE, + set_fixed_val, set_val, + structs::ROMType, + tables::TableCircuit, + witness::RowMajorMatrix, +}; +use ff_ext::ExtensionField; +use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator}; + +#[derive(Clone, Debug)] +pub struct U8PairTableConfig { + tbl_a: Fixed, + tbl_b: Fixed, + mlt: WitIn, +} + +pub struct U8PairTableCircuit(PhantomData); + +impl TableCircuit for U8PairTableCircuit { + type TableConfig = U8PairTableConfig; + type Input = u64; + + fn name() -> String { + "U8_PAIR".into() + } + + fn construct_circuit(cb: &mut CircuitBuilder) -> Result { + let tbl_a = cb.create_fixed(|| "tbl_a")?; + let tbl_b = cb.create_fixed(|| "tbl_b")?; + let mlt = cb.create_witin(|| "mlt")?; + + let rlc_record = cb.rlc_chip_record(vec![ + Expression::Constant(E::BaseField::from(ROMType::U8Pair as u64)), + Expression::Fixed(tbl_a.clone()), + Expression::Fixed(tbl_b.clone()), + ]); + + cb.lk_table_record(|| "u8_pair_table", rlc_record, mlt.expr())?; + + Ok(U8PairTableConfig { tbl_a, tbl_b, mlt }) + } + + fn generate_fixed_traces( + config: &U8PairTableConfig, + num_fixed: usize, + ) -> RowMajorMatrix { + let num_u16s = 1 << 16; + let mut fixed = RowMajorMatrix::::new(num_u16s, num_fixed); + fixed + .par_iter_mut() + .with_min_len(MIN_PAR_SIZE) + .zip((0..num_u16s).into_par_iter()) + .for_each(|(row, i)| { + let a = i & 0xff; + let b = (i >> 8) & 0xff; + set_fixed_val!(row, config.tbl_a.0, E::BaseField::from(a as u64)); + set_fixed_val!(row, config.tbl_b.0, E::BaseField::from(b as u64)); + }); + + fixed + } + + fn assign_instances( + config: &Self::TableConfig, + num_witin: usize, + multiplicity: &[HashMap], + ) -> Result, ZKVMError> { + let multiplicity = &multiplicity[ROMType::U8Pair as usize]; + let mut mlts = vec![0; 1 << 16]; + for (idx, mlt) in multiplicity { + mlts[*idx as usize] = *mlt; + } + + let mut witness = RowMajorMatrix::::new(mlts.len(), num_witin); + witness + .par_iter_mut() + .with_min_len(MIN_PAR_SIZE) + .zip(mlts.into_par_iter()) + .for_each(|(row, mlt)| { + set_val!(row, config.mlt, E::BaseField::from(mlt as u64)); + }); + + Ok(witness) + } +} diff --git a/ceno_zkvm/src/uint.rs b/ceno_zkvm/src/uint.rs index a256cac00..98120f4ae 100644 --- a/ceno_zkvm/src/uint.rs +++ b/ceno_zkvm/src/uint.rs @@ -242,7 +242,7 @@ impl UInt { let limbs = (0..k) .map(|_| { let w = circuit_builder.create_witin(|| "").unwrap(); - circuit_builder.assert_byte(|| "", w.expr()).unwrap(); + circuit_builder.assert_byte_slow(|| "", w.expr()).unwrap(); w.expr() }) .collect_vec(); diff --git a/ceno_zkvm/src/witness.rs b/ceno_zkvm/src/witness.rs index a77d9c4c3..d4761382c 100644 --- a/ceno_zkvm/src/witness.rs +++ b/ceno_zkvm/src/witness.rs @@ -97,7 +97,7 @@ impl LkMultiplicity { pub fn assert_ux(&mut self, v: u64) { match C { 16 => self.assert_u16(v), - 8 => self.assert_byte(v), + 8 => self.assert_byte(v as u8), 5 => self.assert_u5(v), _ => panic!("Unsupported bit range"), } @@ -112,17 +112,17 @@ impl LkMultiplicity { .or_default()) += 1; } - fn assert_u16(&mut self, v: u64) { + pub fn assert_u8_pair(&mut self, a: u8, b: u8) { + let key = a as u64 | (b as u64) << 8; let multiplicity = self .multiplicity .get_or(|| RefCell::new(array::from_fn(|_| HashMap::new()))); - (*multiplicity.borrow_mut()[ROMType::U16 as usize] - .entry(v) + (*multiplicity.borrow_mut()[ROMType::U8Pair as usize] + .entry(key) .or_default()) += 1; } - fn assert_byte(&mut self, v: u64) { - let v = v * (1 << u8::BITS); + fn assert_u16(&mut self, v: u64) { let multiplicity = self .multiplicity .get_or(|| RefCell::new(array::from_fn(|_| HashMap::new()))); @@ -131,6 +131,10 @@ impl LkMultiplicity { .or_default()) += 1; } + fn assert_byte(&mut self, v: u8) { + self.assert_u8_pair(v, 0) + } + /// lookup a < b as unsigned byte pub fn lookup_ltu_limb8(&mut self, a: u64, b: u64) { let key = a.wrapping_mul(256) + b; @@ -172,10 +176,10 @@ mod tests { // each thread calling assert_byte once for _ in 0..thread_count { let mut lkm = lkm.clone(); - thread::spawn(move || lkm.assert_byte(8u64)).join().unwrap(); + thread::spawn(move || lkm.assert_byte(8)).join().unwrap(); } let res = lkm.into_finalize_result(); // check multiplicity counts of assert_byte - assert_eq!(res[ROMType::U16 as usize][&(8 << 8)], thread_count); + assert_eq!(res[ROMType::U8Pair as usize][&8], thread_count); } } From c0dca74960279fdadde27ccc1cd74ddd72711f32 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aur=C3=A9lien=20Nicolas?= Date: Sun, 15 Sep 2024 20:30:15 +0200 Subject: [PATCH 2/4] byte-pair: fix after merge --- ceno_zkvm/src/tables/u8_pair.rs | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/ceno_zkvm/src/tables/u8_pair.rs b/ceno_zkvm/src/tables/u8_pair.rs index 67d223537..638f9c189 100644 --- a/ceno_zkvm/src/tables/u8_pair.rs +++ b/ceno_zkvm/src/tables/u8_pair.rs @@ -24,7 +24,8 @@ pub struct U8PairTableCircuit(PhantomData); impl TableCircuit for U8PairTableCircuit { type TableConfig = U8PairTableConfig; - type Input = u64; + type FixedInput = (); + type WitnessInput = (); fn name() -> String { "U8_PAIR".into() @@ -49,6 +50,7 @@ impl TableCircuit for U8PairTableCircuit { fn generate_fixed_traces( config: &U8PairTableConfig, num_fixed: usize, + _input: &(), ) -> RowMajorMatrix { let num_u16s = 1 << 16; let mut fixed = RowMajorMatrix::::new(num_u16s, num_fixed); @@ -59,8 +61,8 @@ impl TableCircuit for U8PairTableCircuit { .for_each(|(row, i)| { let a = i & 0xff; let b = (i >> 8) & 0xff; - set_fixed_val!(row, config.tbl_a.0, E::BaseField::from(a as u64)); - set_fixed_val!(row, config.tbl_b.0, E::BaseField::from(b as u64)); + set_fixed_val!(row, config.tbl_a, E::BaseField::from(a as u64)); + set_fixed_val!(row, config.tbl_b, E::BaseField::from(b as u64)); }); fixed @@ -70,6 +72,7 @@ impl TableCircuit for U8PairTableCircuit { config: &Self::TableConfig, num_witin: usize, multiplicity: &[HashMap], + _input: &(), ) -> Result, ZKVMError> { let multiplicity = &multiplicity[ROMType::U8Pair as usize]; let mut mlts = vec![0; 1 << 16]; From 33de71c0cb6e5065adb69d50ce9213b1242ee39a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aur=C3=A9lien=20Nicolas?= Date: Sun, 15 Sep 2024 20:49:25 +0200 Subject: [PATCH 3/4] byte-pair: merge U8Pair with RangeTable for easy usage --- ceno_zkvm/src/tables/mod.rs | 2 - ceno_zkvm/src/tables/range.rs | 23 ++++++++-- ceno_zkvm/src/tables/{ => range}/u8_pair.rs | 49 ++++++++------------- 3 files changed, 38 insertions(+), 36 deletions(-) rename ceno_zkvm/src/tables/{ => range}/u8_pair.rs (63%) diff --git a/ceno_zkvm/src/tables/mod.rs b/ceno_zkvm/src/tables/mod.rs index 6137834ec..7c9a29ed4 100644 --- a/ceno_zkvm/src/tables/mod.rs +++ b/ceno_zkvm/src/tables/mod.rs @@ -5,8 +5,6 @@ use std::collections::HashMap; mod range; pub use range::RangeTableCircuit; -mod u8_pair; - mod program; pub use program::{InsnRecord, ProgramTableCircuit}; diff --git a/ceno_zkvm/src/tables/range.rs b/ceno_zkvm/src/tables/range.rs index 2b195ea63..7da80a39d 100644 --- a/ceno_zkvm/src/tables/range.rs +++ b/ceno_zkvm/src/tables/range.rs @@ -1,3 +1,6 @@ +mod u8_pair; +use u8_pair::{U8PairTableCircuit, U8PairTableConfig}; + use std::{collections::HashMap, marker::PhantomData, mem::MaybeUninit}; use crate::{ @@ -18,6 +21,8 @@ use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterato pub struct RangeTableConfig { u16_tbl: Fixed, u16_mlt: WitIn, + + u8_pair: U8PairTableConfig, } pub struct RangeTableCircuit(PhantomData); @@ -42,7 +47,16 @@ impl TableCircuit for RangeTableCircuit { cb.lk_table_record(|| "u16 table", u16_table_values, u16_mlt.expr())?; - Ok(RangeTableConfig { u16_tbl, u16_mlt }) + let u8_pair = cb.namespace( + || "u8_pair", + |cb| U8PairTableCircuit::::construct_circuit(cb), + )?; + + Ok(RangeTableConfig { + u16_tbl, + u16_mlt, + u8_pair, + }) } fn generate_fixed_traces( @@ -60,6 +74,8 @@ impl TableCircuit for RangeTableCircuit { set_fixed_val!(row, config.u16_tbl, E::BaseField::from(i as u64)); }); + U8PairTableCircuit::::generate_fixed_traces(&config.u8_pair, &mut fixed); + fixed } @@ -69,9 +85,8 @@ impl TableCircuit for RangeTableCircuit { multiplicity: &[HashMap], _input: &(), ) -> Result, ZKVMError> { - let multiplicity = &multiplicity[ROMType::U16 as usize]; let mut u16_mlt = vec![0; 1 << RANGE_CHIP_BIT_WIDTH]; - for (limb, mlt) in multiplicity { + for (limb, mlt) in &multiplicity[ROMType::U16 as usize] { u16_mlt[*limb as usize] = *mlt; } @@ -84,6 +99,8 @@ impl TableCircuit for RangeTableCircuit { set_val!(row, config.u16_mlt, E::BaseField::from(mlt as u64)); }); + U8PairTableCircuit::::assign_instances(&config.u8_pair, multiplicity, &mut witness); + Ok(witness) } } diff --git a/ceno_zkvm/src/tables/u8_pair.rs b/ceno_zkvm/src/tables/range/u8_pair.rs similarity index 63% rename from ceno_zkvm/src/tables/u8_pair.rs rename to ceno_zkvm/src/tables/range/u8_pair.rs index 638f9c189..b16dc7116 100644 --- a/ceno_zkvm/src/tables/u8_pair.rs +++ b/ceno_zkvm/src/tables/range/u8_pair.rs @@ -7,12 +7,13 @@ use crate::{ scheme::constants::MIN_PAR_SIZE, set_fixed_val, set_val, structs::ROMType, - tables::TableCircuit, witness::RowMajorMatrix, }; use ff_ext::ExtensionField; use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator}; +const NUM_U8_PAIRS: usize = 1 << 16; + #[derive(Clone, Debug)] pub struct U8PairTableConfig { tbl_a: Fixed, @@ -22,16 +23,8 @@ pub struct U8PairTableConfig { pub struct U8PairTableCircuit(PhantomData); -impl TableCircuit for U8PairTableCircuit { - type TableConfig = U8PairTableConfig; - type FixedInput = (); - type WitnessInput = (); - - fn name() -> String { - "U8_PAIR".into() - } - - fn construct_circuit(cb: &mut CircuitBuilder) -> Result { +impl U8PairTableCircuit { + pub fn construct_circuit(cb: &mut CircuitBuilder) -> Result { let tbl_a = cb.create_fixed(|| "tbl_a")?; let tbl_b = cb.create_fixed(|| "tbl_b")?; let mlt = cb.create_witin(|| "mlt")?; @@ -47,40 +40,36 @@ impl TableCircuit for U8PairTableCircuit { Ok(U8PairTableConfig { tbl_a, tbl_b, mlt }) } - fn generate_fixed_traces( + pub fn generate_fixed_traces( config: &U8PairTableConfig, - num_fixed: usize, - _input: &(), - ) -> RowMajorMatrix { - let num_u16s = 1 << 16; - let mut fixed = RowMajorMatrix::::new(num_u16s, num_fixed); + fixed: &mut RowMajorMatrix, + ) { + assert!(fixed.num_instances() >= NUM_U8_PAIRS); + fixed .par_iter_mut() .with_min_len(MIN_PAR_SIZE) - .zip((0..num_u16s).into_par_iter()) + .zip((0..NUM_U8_PAIRS).into_par_iter()) .for_each(|(row, i)| { let a = i & 0xff; let b = (i >> 8) & 0xff; set_fixed_val!(row, config.tbl_a, E::BaseField::from(a as u64)); set_fixed_val!(row, config.tbl_b, E::BaseField::from(b as u64)); }); - - fixed } - fn assign_instances( - config: &Self::TableConfig, - num_witin: usize, + pub fn assign_instances( + config: &U8PairTableConfig, multiplicity: &[HashMap], - _input: &(), - ) -> Result, ZKVMError> { - let multiplicity = &multiplicity[ROMType::U8Pair as usize]; - let mut mlts = vec![0; 1 << 16]; - for (idx, mlt) in multiplicity { + witness: &mut RowMajorMatrix, + ) { + assert!(witness.num_instances() >= NUM_U8_PAIRS); + + let mut mlts = vec![0; NUM_U8_PAIRS]; + for (idx, mlt) in &multiplicity[ROMType::U8Pair as usize] { mlts[*idx as usize] = *mlt; } - let mut witness = RowMajorMatrix::::new(mlts.len(), num_witin); witness .par_iter_mut() .with_min_len(MIN_PAR_SIZE) @@ -88,7 +77,5 @@ impl TableCircuit for U8PairTableCircuit { .for_each(|(row, mlt)| { set_val!(row, config.mlt, E::BaseField::from(mlt as u64)); }); - - Ok(witness) } } From 7cd9cf1c96fbc7b2aea3ad529a323b5f61d571b0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aur=C3=A9lien=20Nicolas?= Date: Sun, 15 Sep 2024 21:03:24 +0200 Subject: [PATCH 4/4] byte-pair: zero padding --- ceno_zkvm/src/tables/range/u8_pair.rs | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/ceno_zkvm/src/tables/range/u8_pair.rs b/ceno_zkvm/src/tables/range/u8_pair.rs index b16dc7116..f3fd34eb4 100644 --- a/ceno_zkvm/src/tables/range/u8_pair.rs +++ b/ceno_zkvm/src/tables/range/u8_pair.rs @@ -1,3 +1,6 @@ +use ff::Field; +use ff_ext::ExtensionField; +use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator}; use std::{collections::HashMap, marker::PhantomData, mem::MaybeUninit}; use crate::{ @@ -9,8 +12,6 @@ use crate::{ structs::ROMType, witness::RowMajorMatrix, }; -use ff_ext::ExtensionField; -use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator}; const NUM_U8_PAIRS: usize = 1 << 16; @@ -56,6 +57,12 @@ impl U8PairTableCircuit { set_fixed_val!(row, config.tbl_a, E::BaseField::from(a as u64)); set_fixed_val!(row, config.tbl_b, E::BaseField::from(b as u64)); }); + + // Fill the rest with zeros. + fixed.par_iter_mut().skip(NUM_U8_PAIRS).for_each(|row| { + set_fixed_val!(row, config.tbl_a, E::BaseField::ZERO); + set_fixed_val!(row, config.tbl_b, E::BaseField::ZERO); + }); } pub fn assign_instances( @@ -77,5 +84,10 @@ impl U8PairTableCircuit { .for_each(|(row, mlt)| { set_val!(row, config.mlt, E::BaseField::from(mlt as u64)); }); + + // Fill the rest with zeros. + witness.par_iter_mut().skip(NUM_U8_PAIRS).for_each(|row| { + set_val!(row, config.mlt, E::BaseField::ZERO); + }); } }