Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

u8 pair table #225

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 24 additions & 3 deletions ceno_zkvm/src/chip_handler/general.rs
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> {
{
match C {
16 => self.assert_u16(name_fn, expr),
8 => self.assert_byte(name_fn, expr),
8 => self.assert_byte_slow(name_fn, expr),
5 => self.assert_u5(name_fn, expr),
_ => panic!("Unsupported bit range"),
}
Expand All @@ -181,6 +181,24 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> {
)
}

/// Ensure that a and b are both bytes.
/// This is more efficient than two separate assert_byte.
pub(crate) fn assert_u8_pair<NR, N>(
&mut self,
name_fn: N,
a: Expression<E>,
b: Expression<E>,
) -> Result<(), ZKVMError>
where
NR: Into<String>,
N: FnOnce() -> NR,
{
let fields: Vec<Expression<E>> = vec![(ROMType::U8Pair as usize).into(), a, b];
let rlc_record = self.rlc_chip_record(fields);
self.lk_record(name_fn, rlc_record)?;
Ok(())
}

fn assert_u16<NR, N>(&mut self, name_fn: N, expr: Expression<E>) -> Result<(), ZKVMError>
where
NR: Into<String>,
Expand All @@ -207,7 +225,9 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> {
})
}

pub(crate) fn assert_byte<NR, N>(
/// Ensure that expr is a byte.
/// This is "slow" because `assert_u8_pair` is more efficient if you have multiple bytes.
pub(crate) fn assert_byte_slow<NR, N>(
&mut self,
name_fn: N,
expr: Expression<E>,
Expand All @@ -216,7 +236,7 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> {
NR: Into<String>,
N: FnOnce() -> NR,
{
self.assert_u16(name_fn, expr * Expression::from(1 << 8))
self.assert_u8_pair(name_fn, expr, Expression::ZERO)
}

pub(crate) fn assert_bit<NR, N>(
Expand All @@ -228,6 +248,7 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> {
NR: Into<String>,
N: FnOnce() -> NR,
{
// TODO: Replace with `x * (1 - x)` or a multi-bit lookup similar to assert_u8_pair.
self.assert_u16(name_fn, expr * Expression::from(1 << 15))
}

Expand Down
16 changes: 16 additions & 0 deletions ceno_zkvm/src/scheme/mock_prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,21 @@ fn load_tables<E: ExtensionField>(cb: &CircuitBuilder<E>, challenge: [E; 2]) ->
}
}

fn load_u8_pair_table<E: ExtensionField>(
t_vec: &mut Vec<Vec<u8>>,
cb: &CircuitBuilder<E>,
challenge: [E; 2],
) {
for i in 0..=u16::MAX as usize {
let a = i & 0xff;
let b = (i >> 8) & 0xff;
let rlc_record =
cb.rlc_chip_record(vec![(ROMType::U8Pair as usize).into(), a.into(), b.into()]);
let rlc_record = eval_by_expr(&[], &challenge, &rlc_record);
t_vec.push(rlc_record.to_repr().as_ref().to_vec());
}
}

fn load_u16_table<E: ExtensionField>(
t_vec: &mut Vec<Vec<u8>>,
cb: &CircuitBuilder<E>,
Expand Down Expand Up @@ -338,6 +353,7 @@ fn load_tables<E: ExtensionField>(cb: &CircuitBuilder<E>, challenge: [E; 2]) ->
let mut table_vec = vec![];
// TODO load more tables here
load_u5_table(&mut table_vec, cb, challenge);
load_u8_pair_table(&mut table_vec, cb, challenge);
load_u16_table(&mut table_vec, cb, challenge);
load_lt_table(&mut table_vec, cb, challenge);
load_and_table(&mut table_vec, cb, challenge);
Expand Down
1 change: 1 addition & 0 deletions ceno_zkvm/src/structs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ pub type ChallengeId = u16;
#[derive(Debug)]
pub enum ROMType {
U5 = 0, // 2^5 = 32
U8Pair, // 2^8 * 2^8 = 65,536
U16, // 2^16 = 65,536
And, // a ^ b where a, b are bytes
Ltu, // a <(usign) b where a, b are bytes
Expand Down
23 changes: 20 additions & 3 deletions ceno_zkvm/src/tables/range.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
mod u8_pair;
use u8_pair::{U8PairTableCircuit, U8PairTableConfig};

use std::{collections::HashMap, marker::PhantomData, mem::MaybeUninit};

use crate::{
Expand All @@ -18,6 +21,8 @@ use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterato
pub struct RangeTableConfig {
u16_tbl: Fixed,
u16_mlt: WitIn,

u8_pair: U8PairTableConfig,
}

pub struct RangeTableCircuit<E>(PhantomData<E>);
Expand All @@ -42,7 +47,16 @@ impl<E: ExtensionField> TableCircuit<E> for RangeTableCircuit<E> {

cb.lk_table_record(|| "u16 table", u16_table_values, u16_mlt.expr())?;

Ok(RangeTableConfig { u16_tbl, u16_mlt })
let u8_pair = cb.namespace(
|| "u8_pair",
|cb| U8PairTableCircuit::<E>::construct_circuit(cb),
)?;

Ok(RangeTableConfig {
u16_tbl,
u16_mlt,
u8_pair,
})
}

fn generate_fixed_traces(
Expand All @@ -60,6 +74,8 @@ impl<E: ExtensionField> TableCircuit<E> for RangeTableCircuit<E> {
set_fixed_val!(row, config.u16_tbl, E::BaseField::from(i as u64));
});

U8PairTableCircuit::<E>::generate_fixed_traces(&config.u8_pair, &mut fixed);

fixed
}

Expand All @@ -69,9 +85,8 @@ impl<E: ExtensionField> TableCircuit<E> for RangeTableCircuit<E> {
multiplicity: &[HashMap<u64, usize>],
_input: &(),
) -> Result<RowMajorMatrix<E::BaseField>, ZKVMError> {
let multiplicity = &multiplicity[ROMType::U16 as usize];
let mut u16_mlt = vec![0; 1 << RANGE_CHIP_BIT_WIDTH];
for (limb, mlt) in multiplicity {
for (limb, mlt) in &multiplicity[ROMType::U16 as usize] {
u16_mlt[*limb as usize] = *mlt;
}

Expand All @@ -84,6 +99,8 @@ impl<E: ExtensionField> TableCircuit<E> for RangeTableCircuit<E> {
set_val!(row, config.u16_mlt, E::BaseField::from(mlt as u64));
});

U8PairTableCircuit::<E>::assign_instances(&config.u8_pair, multiplicity, &mut witness);

Ok(witness)
}
}
93 changes: 93 additions & 0 deletions ceno_zkvm/src/tables/range/u8_pair.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
use ff::Field;
use ff_ext::ExtensionField;
use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator};
use std::{collections::HashMap, marker::PhantomData, mem::MaybeUninit};

use crate::{
circuit_builder::CircuitBuilder,
error::ZKVMError,
expression::{Expression, Fixed, ToExpr, WitIn},
scheme::constants::MIN_PAR_SIZE,
set_fixed_val, set_val,
structs::ROMType,
witness::RowMajorMatrix,
};

const NUM_U8_PAIRS: usize = 1 << 16;

#[derive(Clone, Debug)]
pub struct U8PairTableConfig {
tbl_a: Fixed,
tbl_b: Fixed,
mlt: WitIn,
}

pub struct U8PairTableCircuit<E>(PhantomData<E>);

impl<E: ExtensionField> U8PairTableCircuit<E> {
pub fn construct_circuit(cb: &mut CircuitBuilder<E>) -> Result<U8PairTableConfig, ZKVMError> {
let tbl_a = cb.create_fixed(|| "tbl_a")?;
let tbl_b = cb.create_fixed(|| "tbl_b")?;
let mlt = cb.create_witin(|| "mlt")?;

let rlc_record = cb.rlc_chip_record(vec![
Expression::Constant(E::BaseField::from(ROMType::U8Pair as u64)),
Expression::Fixed(tbl_a.clone()),
Expression::Fixed(tbl_b.clone()),
]);

cb.lk_table_record(|| "u8_pair_table", rlc_record, mlt.expr())?;

Ok(U8PairTableConfig { tbl_a, tbl_b, mlt })
}

pub fn generate_fixed_traces(
config: &U8PairTableConfig,
fixed: &mut RowMajorMatrix<E::BaseField>,
) {
assert!(fixed.num_instances() >= NUM_U8_PAIRS);

fixed
.par_iter_mut()
.with_min_len(MIN_PAR_SIZE)
.zip((0..NUM_U8_PAIRS).into_par_iter())
.for_each(|(row, i)| {
let a = i & 0xff;
let b = (i >> 8) & 0xff;
set_fixed_val!(row, config.tbl_a, E::BaseField::from(a as u64));
set_fixed_val!(row, config.tbl_b, E::BaseField::from(b as u64));
});

// Fill the rest with zeros.
fixed.par_iter_mut().skip(NUM_U8_PAIRS).for_each(|row| {
set_fixed_val!(row, config.tbl_a, E::BaseField::ZERO);
set_fixed_val!(row, config.tbl_b, E::BaseField::ZERO);
});
}

pub fn assign_instances(
config: &U8PairTableConfig,
multiplicity: &[HashMap<u64, usize>],
witness: &mut RowMajorMatrix<E::BaseField>,
) {
assert!(witness.num_instances() >= NUM_U8_PAIRS);

let mut mlts = vec![0; NUM_U8_PAIRS];
for (idx, mlt) in &multiplicity[ROMType::U8Pair as usize] {
mlts[*idx as usize] = *mlt;
}

witness
.par_iter_mut()
.with_min_len(MIN_PAR_SIZE)
.zip(mlts.into_par_iter())
.for_each(|(row, mlt)| {
set_val!(row, config.mlt, E::BaseField::from(mlt as u64));
});

// Fill the rest with zeros.
witness.par_iter_mut().skip(NUM_U8_PAIRS).for_each(|row| {
set_val!(row, config.mlt, E::BaseField::ZERO);
});
}
}
2 changes: 1 addition & 1 deletion ceno_zkvm/src/uint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ impl<const M: usize, const C: usize, E: ExtensionField> UInt<M, C, E> {
let limbs = (0..k)
.map(|_| {
let w = circuit_builder.create_witin(|| "").unwrap();
circuit_builder.assert_byte(|| "", w.expr()).unwrap();
circuit_builder.assert_byte_slow(|| "", w.expr()).unwrap();
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A bit concern is about this extra overhead on original "assert_byte" since it will occur a slow version.

To address this, an option is iterating limbs via chunk(2) therefore assert_u8_pair works.

There are other place to invoke assert_byte() directly

cb.assert_ux::<_, _, C>(|| format!("limb_{i}_in_{C}"), w.expr())?;

Which I believe for potiential performance pitfall we should refactor it traverse in chunk as well

w.expr()
})
.collect_vec();
Expand Down
20 changes: 12 additions & 8 deletions ceno_zkvm/src/witness.rs
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ impl LkMultiplicity {
pub fn assert_ux<const C: usize>(&mut self, v: u64) {
match C {
16 => self.assert_u16(v),
8 => self.assert_byte(v),
8 => self.assert_byte(v as u8),
5 => self.assert_u5(v),
_ => panic!("Unsupported bit range"),
}
Expand All @@ -116,17 +116,17 @@ impl LkMultiplicity {
.or_default()) += 1;
}

fn assert_u16(&mut self, v: u64) {
pub fn assert_u8_pair(&mut self, a: u8, b: u8) {
let key = a as u64 | (b as u64) << 8;
let multiplicity = self
.multiplicity
.get_or(|| RefCell::new(array::from_fn(|_| HashMap::new())));
(*multiplicity.borrow_mut()[ROMType::U16 as usize]
.entry(v)
(*multiplicity.borrow_mut()[ROMType::U8Pair as usize]
.entry(key)
.or_default()) += 1;
}

fn assert_byte(&mut self, v: u64) {
let v = v * (1 << u8::BITS);
fn assert_u16(&mut self, v: u64) {
let multiplicity = self
.multiplicity
.get_or(|| RefCell::new(array::from_fn(|_| HashMap::new())));
Expand All @@ -135,6 +135,10 @@ impl LkMultiplicity {
.or_default()) += 1;
}

fn assert_byte(&mut self, v: u8) {
self.assert_u8_pair(v, 0)
}

/// lookup a < b as unsigned byte
pub fn lookup_ltu_limb8(&mut self, a: u64, b: u64) {
let key = a.wrapping_mul(256) + b;
Expand Down Expand Up @@ -185,10 +189,10 @@ mod tests {
// each thread calling assert_byte once
for _ in 0..thread_count {
let mut lkm = lkm.clone();
thread::spawn(move || lkm.assert_byte(8u64)).join().unwrap();
thread::spawn(move || lkm.assert_byte(8)).join().unwrap();
}
let res = lkm.into_finalize_result();
// check multiplicity counts of assert_byte
assert_eq!(res[ROMType::U16 as usize][&(8 << 8)], thread_count);
assert_eq!(res[ROMType::U8Pair as usize][&8], thread_count);
}
}
Loading