Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

u8 table #234

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion ceno_zkvm/src/chip_handler/general.rs
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,10 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> {
NR: Into<String>,
N: FnOnce() -> NR,
{
self.assert_u16(name_fn, expr * Expression::from(1 << 8))
let items: Vec<Expression<E>> = vec![(ROMType::U8 as usize).into(), expr];
let rlc_record = self.rlc_chip_record(items);
self.lk_record(name_fn, rlc_record)?;
Ok(())
}

pub(crate) fn assert_bit<NR, N>(
Expand All @@ -228,6 +231,7 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> {
NR: Into<String>,
N: FnOnce() -> NR,
{
// TODO: Replace with `x * (1 - x)` or a multi-bit lookup similar to assert_u8_pair.
self.assert_u16(name_fn, expr * Expression::from(1 << 15))
}

Expand Down
13 changes: 13 additions & 0 deletions ceno_zkvm/src/scheme/mock_prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,18 @@ fn load_tables<E: ExtensionField>(cb: &CircuitBuilder<E>, challenge: [E; 2]) ->
}
}

fn load_u8_table<E: ExtensionField>(
t_vec: &mut Vec<Vec<u8>>,
cb: &CircuitBuilder<E>,
challenge: [E; 2],
) {
for i in 0..=u8::MAX as usize {
let rlc_record = cb.rlc_chip_record(vec![(ROMType::U8 as usize).into(), i.into()]);
let rlc_record = eval_by_expr(&[], &challenge, &rlc_record);
t_vec.push(rlc_record.to_repr().as_ref().to_vec());
}
}

fn load_u16_table<E: ExtensionField>(
t_vec: &mut Vec<Vec<u8>>,
cb: &CircuitBuilder<E>,
Expand Down Expand Up @@ -338,6 +350,7 @@ fn load_tables<E: ExtensionField>(cb: &CircuitBuilder<E>, challenge: [E; 2]) ->
let mut table_vec = vec![];
// TODO load more tables here
load_u5_table(&mut table_vec, cb, challenge);
load_u8_table(&mut table_vec, cb, challenge);
load_u16_table(&mut table_vec, cb, challenge);
load_lt_table(&mut table_vec, cb, challenge);
load_and_table(&mut table_vec, cb, challenge);
Expand Down
1 change: 1 addition & 0 deletions ceno_zkvm/src/structs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ pub type ChallengeId = u16;
#[derive(Debug)]
pub enum ROMType {
U5 = 0, // 2^5 = 32
U8, // 2^8 = 256
U16, // 2^16 = 65,536
And, // a ^ b where a, b are bytes
Ltu, // a <(usign) b where a, b are bytes
Expand Down
22 changes: 19 additions & 3 deletions ceno_zkvm/src/tables/range.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
mod u8_table;
use u8_table::U8TableConfig;

use std::{collections::HashMap, marker::PhantomData, mem::MaybeUninit};

use crate::{
Expand All @@ -18,6 +21,8 @@ use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterato
pub struct RangeTableConfig {
u16_tbl: Fixed,
u16_mlt: WitIn,

u8_config: U8TableConfig,
}

pub struct RangeTableCircuit<E>(PhantomData<E>);
Expand All @@ -42,7 +47,13 @@ impl<E: ExtensionField> TableCircuit<E> for RangeTableCircuit<E> {

cb.lk_table_record(|| "u16 table", u16_table_values, u16_mlt.expr())?;

Ok(RangeTableConfig { u16_tbl, u16_mlt })
let u8_config = cb.namespace(|| "u8_table", |cb| U8TableConfig::construct_circuit(cb))?;

Ok(RangeTableConfig {
u16_tbl,
u16_mlt,
u8_config,
})
}

fn generate_fixed_traces(
Expand All @@ -60,6 +71,8 @@ impl<E: ExtensionField> TableCircuit<E> for RangeTableCircuit<E> {
set_fixed_val!(row, config.u16_tbl, E::BaseField::from(i as u64));
});

config.u8_config.generate_fixed_traces(&mut fixed);

fixed
}

Expand All @@ -69,9 +82,8 @@ impl<E: ExtensionField> TableCircuit<E> for RangeTableCircuit<E> {
multiplicity: &[HashMap<u64, usize>],
_input: &(),
) -> Result<RowMajorMatrix<E::BaseField>, ZKVMError> {
let multiplicity = &multiplicity[ROMType::U16 as usize];
let mut u16_mlt = vec![0; 1 << RANGE_CHIP_BIT_WIDTH];
for (limb, mlt) in multiplicity {
for (limb, mlt) in &multiplicity[ROMType::U16 as usize] {
u16_mlt[*limb as usize] = *mlt;
}

Expand All @@ -84,6 +96,10 @@ impl<E: ExtensionField> TableCircuit<E> for RangeTableCircuit<E> {
set_val!(row, config.u16_mlt, E::BaseField::from(mlt as u64));
});

config
.u8_config
.assign_instances(multiplicity, &mut witness);

Ok(witness)
}
}
83 changes: 83 additions & 0 deletions ceno_zkvm/src/tables/range/u8_table.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
use ff_ext::ExtensionField;
use goldilocks::SmallField;
use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator};
use std::{collections::HashMap, mem::MaybeUninit};

use crate::{
circuit_builder::CircuitBuilder,
error::ZKVMError,
expression::{Expression, Fixed, ToExpr, WitIn},
scheme::constants::MIN_PAR_SIZE,
set_fixed_val, set_val,
structs::ROMType,
witness::RowMajorMatrix,
};

const NUM_U8: usize = 1 << 8;

#[derive(Clone, Debug)]
pub struct U8TableConfig {
u8_fixed: Fixed,
mlt: WitIn,
}

impl U8TableConfig {
pub fn construct_circuit<E: ExtensionField>(
cb: &mut CircuitBuilder<E>,
) -> Result<Self, ZKVMError> {
let u8_fixed = cb.create_fixed(|| "u8_fixed")?;
let mlt = cb.create_witin(|| "mlt")?;

let rlc_record = cb.rlc_chip_record(vec![
Expression::Constant(E::BaseField::from(ROMType::U8 as u64)),
Expression::Fixed(u8_fixed.clone()),
]);

cb.lk_table_record(|| "u8_record", rlc_record, mlt.expr())?;

Ok(Self { u8_fixed, mlt })
}

pub fn generate_fixed_traces<F: SmallField>(&self, fixed: &mut RowMajorMatrix<F>) {
assert!(fixed.num_instances() >= NUM_U8);

fixed
.par_iter_mut()
.with_min_len(MIN_PAR_SIZE)
.zip((0..NUM_U8).into_par_iter())
.for_each(|(row, i)| {
set_fixed_val!(row, self.u8_fixed, F::from(i as u64));
});

// Fill the rest with zeros, if any.
fixed.par_iter_mut().skip(NUM_U8).for_each(|row| {
set_fixed_val!(row, self.u8_fixed, F::ZERO);
});
}

pub fn assign_instances<F: SmallField>(
&self,
multiplicity: &[HashMap<u64, usize>],
witness: &mut RowMajorMatrix<F>,
) {
assert!(witness.num_instances() >= NUM_U8);

let mut mlts = vec![0; NUM_U8];
for (idx, mlt) in &multiplicity[ROMType::U8 as usize] {
mlts[*idx as usize] = *mlt;
}

witness
.par_iter_mut()
.with_min_len(MIN_PAR_SIZE)
.zip(mlts.into_par_iter())
.for_each(|(row, mlt)| {
set_val!(row, self.mlt, F::from(mlt as u64));
});

// Fill the rest with zeros, if any.
witness.par_iter_mut().skip(NUM_U8).for_each(|row| {
set_val!(row, self.mlt, F::ZERO);
});
}
}
5 changes: 2 additions & 3 deletions ceno_zkvm/src/witness.rs
Original file line number Diff line number Diff line change
Expand Up @@ -126,11 +126,10 @@ impl LkMultiplicity {
}

fn assert_byte(&mut self, v: u64) {
let v = v * (1 << u8::BITS);
let multiplicity = self
.multiplicity
.get_or(|| RefCell::new(array::from_fn(|_| HashMap::new())));
(*multiplicity.borrow_mut()[ROMType::U16 as usize]
(*multiplicity.borrow_mut()[ROMType::U8 as usize]
.entry(v)
.or_default()) += 1;
}
Expand Down Expand Up @@ -189,6 +188,6 @@ mod tests {
}
let res = lkm.into_finalize_result();
// check multiplicity counts of assert_byte
assert_eq!(res[ROMType::U16 as usize][&(8 << 8)], thread_count);
assert_eq!(res[ROMType::U8 as usize][&8], thread_count);
}
}
Loading