From 0ea9248dc6bffbb3e7d6c8aaad7039567e8860e1 Mon Sep 17 00:00:00 2001 From: naure Date: Tue, 26 Nov 2024 12:59:48 +0100 Subject: [PATCH] Revert "fix/program-size2: refactor padding_zero (#615)" (#638) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This reverts commit c548dc88fccac94fcb529105169ac44b32c66c0c. --------- Co-authored-by: Aurélien Nicolas --- ceno_zkvm/src/instructions.rs | 2 +- ceno_zkvm/src/tables/mod.rs | 49 ++++++-------- ceno_zkvm/src/tables/program.rs | 109 +++++++++++++++++++++++++------- ceno_zkvm/src/witness.rs | 9 ++- 4 files changed, 110 insertions(+), 59 deletions(-) diff --git a/ceno_zkvm/src/instructions.rs b/ceno_zkvm/src/instructions.rs index e87675dd6..63314cbee 100644 --- a/ceno_zkvm/src/instructions.rs +++ b/ceno_zkvm/src/instructions.rs @@ -94,7 +94,7 @@ pub trait Instruction { num_padding_instances }; raw_witin - .par_batch_iter_padding_mut(None, num_padding_instance_per_batch) + .par_batch_iter_padding_mut(num_padding_instance_per_batch) .with_min_len(MIN_PAR_SIZE) .for_each(|row| { row.chunks_mut(num_witin) diff --git a/ceno_zkvm/src/tables/mod.rs b/ceno_zkvm/src/tables/mod.rs index f498b868e..2ef7e293a 100644 --- a/ceno_zkvm/src/tables/mod.rs +++ b/ceno_zkvm/src/tables/mod.rs @@ -2,8 +2,8 @@ use crate::{ circuit_builder::CircuitBuilder, error::ZKVMError, scheme::constants::MIN_PAR_SIZE, witness::RowMajorMatrix, }; +use ff::Field; use ff_ext::ExtensionField; -use goldilocks::SmallField; use rayon::iter::{IndexedParallelIterator, ParallelIterator}; use std::{collections::HashMap, mem::MaybeUninit}; mod range; @@ -46,34 +46,25 @@ pub trait TableCircuit { table: &mut RowMajorMatrix, num_witin: usize, ) -> Result<(), ZKVMError> { - padding_zero(table, num_witin, None); + // Fill the padding with zeros, if any. + let num_padding_instances = table.num_padding_instances(); + if num_padding_instances > 0 { + let nthreads = + std::env::var("RAYON_NUM_THREADS").map_or(8, |s| s.parse::().unwrap_or(8)); + let padding_instance = vec![MaybeUninit::new(E::BaseField::ZERO); num_witin]; + let num_padding_instance_per_batch = if num_padding_instances > 256 { + num_padding_instances.div_ceil(nthreads) + } else { + num_padding_instances + }; + table + .par_batch_iter_padding_mut(num_padding_instance_per_batch) + .with_min_len(MIN_PAR_SIZE) + .for_each(|row| { + row.chunks_mut(num_witin) + .for_each(|instance| instance.copy_from_slice(padding_instance.as_slice())); + }); + } Ok(()) } } - -/// Fill the padding with zeros. Start after the given `num_instances`, or detect it from the table. -pub fn padding_zero( - table: &mut RowMajorMatrix, - num_cols: usize, - num_instances: Option, -) { - // Fill the padding with zeros, if any. - let num_padding_instances = table.num_padding_instances(); - if num_padding_instances > 0 { - let nthreads = - std::env::var("RAYON_NUM_THREADS").map_or(8, |s| s.parse::().unwrap_or(8)); - let padding_instance = vec![MaybeUninit::new(F::ZERO); num_cols]; - let num_padding_instance_per_batch = if num_padding_instances > 256 { - num_padding_instances.div_ceil(nthreads) - } else { - num_padding_instances - }; - table - .par_batch_iter_padding_mut(num_instances, num_padding_instance_per_batch) - .with_min_len(MIN_PAR_SIZE) - .for_each(|row| { - row.chunks_mut(num_cols) - .for_each(|instance| instance.copy_from_slice(padding_instance.as_slice())); - }); - } -} diff --git a/ceno_zkvm/src/tables/program.rs b/ceno_zkvm/src/tables/program.rs index f1d38ada8..da063545e 100644 --- a/ceno_zkvm/src/tables/program.rs +++ b/ceno_zkvm/src/tables/program.rs @@ -7,7 +7,7 @@ use crate::{ scheme::constants::MIN_PAR_SIZE, set_fixed_val, set_val, structs::ROMType, - tables::{TableCircuit, padding_zero}, + tables::TableCircuit, utils::i64_to_base, witness::RowMajorMatrix, }; @@ -136,7 +136,7 @@ impl TableCircuit for ProgramTableCircuit { cb.lk_table_record( || "prog table", - cb.params.program_size, + cb.params.program_size.next_power_of_two(), ROMType::Instruction, record_exprs, mlt.expr(), @@ -176,7 +176,15 @@ impl TableCircuit for ProgramTableCircuit { }); assert_eq!(INVALID as u64, 0, "0 padding must be invalid instructions"); - padding_zero(&mut fixed, num_fixed, Some(num_instructions)); + fixed + .par_iter_mut() + .with_min_len(MIN_PAR_SIZE) + .skip(num_instructions) + .for_each(|row| { + for col in config.record.as_slice() { + set_fixed_val!(row, *col, 0_u64.into()); + } + }); fixed } @@ -204,32 +212,85 @@ impl TableCircuit for ProgramTableCircuit { set_val!(row, config.mlt, E::BaseField::from(mlt as u64)); }); - padding_zero(&mut witness, num_witin, Some(program.instructions.len())); + witness + .par_iter_mut() + .with_min_len(MIN_PAR_SIZE) + .skip(program.instructions.len()) + .for_each(|row| { + set_val!(row, config.mlt, 0_u64); + }); Ok(witness) } } #[cfg(test)] -#[test] -#[allow(clippy::identity_op)] -fn test_decode_imm() { - for (i, expected) in [ - // Example of I-type: ADDI. - // imm | rs1 | funct3 | rd | opcode - (89 << 20 | 1 << 15 | 0b000 << 12 | 1 << 7 | 0x13, 89), - // Shifts get a precomputed power of 2: SLLI, SRLI, SRAI. - (31 << 20 | 1 << 15 | 0b001 << 12 | 1 << 7 | 0x13, 1 << 31), - (31 << 20 | 1 << 15 | 0b101 << 12 | 1 << 7 | 0x13, 1 << 31), - ( - 1 << 30 | 31 << 20 | 1 << 15 | 0b101 << 12 | 1 << 7 | 0x13, - 1 << 31, - ), - // Example of R-type with funct7: SUB. - // funct7 | rs2 | rs1 | funct3 | rd | opcode - (0x20 << 25 | 1 << 20 | 1 << 15 | 0 << 12 | 1 << 7 | 0x33, 0), - ] { - let imm = InsnRecord::imm_internal(&DecodedInstruction::new(i)); - assert_eq!(imm, expected); +mod tests { + use super::*; + use crate::{circuit_builder::ConstraintSystem, witness::LkMultiplicity}; + use ceno_emul::encode_rv32; + use ff::Field; + use goldilocks::{Goldilocks as F, GoldilocksExt2 as E}; + + #[test] + #[allow(clippy::identity_op)] + fn test_decode_imm() { + for (i, expected) in [ + // Example of I-type: ADDI. + // imm | rs1 | funct3 | rd | opcode + (89 << 20 | 1 << 15 | 0b000 << 12 | 1 << 7 | 0x13, 89), + // Shifts get a precomputed power of 2: SLLI, SRLI, SRAI. + (31 << 20 | 1 << 15 | 0b001 << 12 | 1 << 7 | 0x13, 1 << 31), + (31 << 20 | 1 << 15 | 0b101 << 12 | 1 << 7 | 0x13, 1 << 31), + ( + 1 << 30 | 31 << 20 | 1 << 15 | 0b101 << 12 | 1 << 7 | 0x13, + 1 << 31, + ), + // Example of R-type with funct7: SUB. + // funct7 | rs2 | rs1 | funct3 | rd | opcode + (0x20 << 25 | 1 << 20 | 1 << 15 | 0 << 12 | 1 << 7 | 0x33, 0), + ] { + let imm = InsnRecord::imm_internal(&DecodedInstruction::new(i)); + assert_eq!(imm, expected); + } + } + + #[test] + fn test_program_padding() { + let mut cs = ConstraintSystem::::new(|| "riscv"); + let mut cb = CircuitBuilder::new(&mut cs); + + let actual_len = 3; + let instructions = vec![encode_rv32(ADD, 1, 2, 3, 0); actual_len]; + let program = Program::new(0x2000_0000, 0x2000_0000, instructions, Default::default()); + + let config = ProgramTableCircuit::construct_circuit(&mut cb).unwrap(); + + let check = |matrix: &RowMajorMatrix| { + assert_eq!( + matrix.num_instances() + matrix.num_padding_instances(), + cb.params.program_size + ); + for row in matrix.iter_rows().skip(actual_len) { + for col in row.iter() { + assert_eq!(unsafe { col.assume_init() }, F::ZERO); + } + } + }; + + let fixed = + ProgramTableCircuit::::generate_fixed_traces(&config, cb.cs.num_fixed, &program); + check(&fixed); + + let lkm = LkMultiplicity::default().into_finalize_result(); + + let witness = ProgramTableCircuit::::assign_instances( + &config, + cb.cs.num_witin as usize, + &lkm, + &program, + ) + .unwrap(); + check(&witness); } } diff --git a/ceno_zkvm/src/witness.rs b/ceno_zkvm/src/witness.rs index e85360aee..7acc9ad50 100644 --- a/ceno_zkvm/src/witness.rs +++ b/ceno_zkvm/src/witness.rs @@ -87,13 +87,12 @@ impl RowMajorMatrix { pub fn par_batch_iter_padding_mut( &mut self, - num_instances: Option, - batch_size: usize, + num_rows: usize, ) -> rayon::slice::ChunksMut<'_, MaybeUninit> { - let num_instances = num_instances.unwrap_or(self.num_instances()); - self.values[num_instances * self.num_col..] + let valid_instance = self.num_instances(); + self.values[valid_instance * self.num_col..] .as_mut() - .par_chunks_mut(batch_size * self.num_col) + .par_chunks_mut(num_rows * self.num_col) } pub fn de_interleaving(mut self) -> Vec> {