diff --git a/ceno_emul/src/addr.rs b/ceno_emul/src/addr.rs index dfd9056b7..78b01563e 100644 --- a/ceno_emul/src/addr.rs +++ b/ceno_emul/src/addr.rs @@ -197,17 +197,17 @@ impl ops::AddAssign for ByteAddr { } pub trait IterAddresses { - fn iter_addresses(&self) -> impl Iterator; + fn iter_addresses(&self) -> impl ExactSizeIterator; } impl IterAddresses for Range { - fn iter_addresses(&self) -> impl Iterator { + fn iter_addresses(&self) -> impl ExactSizeIterator { self.clone().step_by(WORD_SIZE) } } impl<'a, T: GetAddr> IterAddresses for &'a [T] { - fn iter_addresses(&self) -> impl Iterator { + fn iter_addresses(&self) -> impl ExactSizeIterator { self.iter().map(T::get_addr) } } diff --git a/ceno_emul/src/platform.rs b/ceno_emul/src/platform.rs index fb11d166d..b0fa0afe7 100644 --- a/ceno_emul/src/platform.rs +++ b/ceno_emul/src/platform.rs @@ -12,6 +12,7 @@ pub struct Platform { pub rom: Range, pub ram: Range, pub public_io: Range, + pub private_io: Range, pub stack_top: Addr, /// If true, ecall instructions are no-op instead of trap. Testing only. pub unsafe_ecall_nop: bool, @@ -21,6 +22,7 @@ pub const CENO_PLATFORM: Platform = Platform { rom: 0x2000_0000..0x3000_0000, ram: 0x8000_0000..0xFFFF_0000, public_io: 0x3000_1000..0x3000_2000, + private_io: 0x4000_0000..0x5000_0000, stack_top: 0xC0000000, unsafe_ecall_nop: false, }; @@ -40,6 +42,10 @@ impl Platform { self.public_io.contains(&addr) } + pub fn is_priv_io(&self, addr: Addr) -> bool { + self.private_io.contains(&addr) + } + /// Virtual address of a register. pub const fn register_vma(index: RegIdx) -> Addr { // Register VMAs are aligned, cannot be confused with indices, and readable in hex. @@ -60,7 +66,7 @@ impl Platform { // Permissions. pub fn can_read(&self, addr: Addr) -> bool { - self.is_rom(addr) || self.is_ram(addr) || self.is_pub_io(addr) + self.is_rom(addr) || self.is_ram(addr) || self.is_pub_io(addr) || self.is_priv_io(addr) } pub fn can_write(&self, addr: Addr) -> bool { diff --git a/ceno_zkvm/examples/riscv_opcodes.rs b/ceno_zkvm/examples/riscv_opcodes.rs index 0711bd58f..67a314683 100644 --- a/ceno_zkvm/examples/riscv_opcodes.rs +++ b/ceno_zkvm/examples/riscv_opcodes.rs @@ -275,6 +275,7 @@ fn main() { ®_final, &mem_final, &public_io_final, + &[], ) .unwrap(); diff --git a/ceno_zkvm/src/bin/e2e.rs b/ceno_zkvm/src/bin/e2e.rs index a4a2b6087..9da848b7e 100644 --- a/ceno_zkvm/src/bin/e2e.rs +++ b/ceno_zkvm/src/bin/e2e.rs @@ -1,6 +1,6 @@ use ceno_emul::{ - ByteAddr, CENO_PLATFORM, EmuContext, InsnKind::EANY, Platform, StepRecord, Tracer, VMState, - WORD_SIZE, WordAddr, + ByteAddr, CENO_PLATFORM, EmuContext, InsnKind::EANY, IterAddresses, Platform, StepRecord, + Tracer, VMState, WORD_SIZE, Word, WordAddr, }; use ceno_zkvm::{ instructions::riscv::{DummyExtraConfig, MemPadder, MmuConfig, Rv32imConfig}, @@ -19,7 +19,9 @@ use itertools::{Itertools, MinMaxResult, chain, enumerate}; use mpcs::{Basefold, BasefoldRSParams, PolynomialCommitmentScheme}; use std::{ collections::{HashMap, HashSet}, - fs, panic, + fs, + iter::zip, + panic, time::Instant, }; use tracing::level_filters::LevelFilter; @@ -41,6 +43,11 @@ struct Args { /// The preset configuration to use. #[arg(short, long, value_enum, default_value_t = Preset::Ceno)] platform: Preset, + + /// The private input or hints. This is a raw file mounted as a memory segment. + /// Zero-padded to the next power-of-two size. + #[arg(long)] + private_input: Option, } #[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, ValueEnum)] @@ -94,6 +101,17 @@ fn main() { let elf_bytes = fs::read(&args.elf).expect("read elf file"); let mut vm = VMState::new_from_elf(platform.clone(), &elf_bytes).unwrap(); + tracing::info!("Loading private input file: {:?}", args.private_input); + let priv_io = memory_from_file(&args.private_input); + assert!( + priv_io.len() <= platform.private_io.iter_addresses().len(), + "private input must fit in {} bytes", + platform.private_io.len() + ); + for (addr, value) in zip(platform.private_io.iter_addresses(), &priv_io) { + vm.init_memory(addr.into(), *value); + } + // keygen let pcs_param = Pcs::setup(1 << MAX_NUM_VARIABLES).expect("Basefold PCS setup"); let (pp, vp) = Pcs::trim(pcs_param, 1 << MAX_NUM_VARIABLES).expect("Basefold trim"); @@ -249,6 +267,14 @@ fn main() { .map(|rec| *final_access.get(&rec.addr.into()).unwrap_or(&0)) .collect_vec(); + let priv_io_final = zip(platform.private_io.iter_addresses(), &priv_io) + .map(|(addr, &value)| MemFinalRecord { + addr, + value, + cycle: *final_access.get(&addr.into()).unwrap_or(&0), + }) + .collect_vec(); + // assign table circuits config .assign_table_circuit(&zkvm_cs, &mut zkvm_witness) @@ -260,6 +286,7 @@ fn main() { ®_final, &mem_final, &io_final, + &priv_io_final, ) .unwrap(); // assign program circuit @@ -332,6 +359,18 @@ fn main() { }; } +fn memory_from_file(path: &Option) -> Vec { + path.as_ref() + .map(|path| { + let mut buf = fs::read(path).expect("could not read file"); + buf.resize(buf.len().next_multiple_of(WORD_SIZE), 0); + buf.chunks_exact(WORD_SIZE) + .map(|word| Word::from_le_bytes(word.try_into().unwrap())) + .collect_vec() + }) + .unwrap_or_default() +} + fn debug_memory_ranges(vm: &VMState, mem_final: &[MemFinalRecord]) { let accessed_addrs = vm .tracer() diff --git a/ceno_zkvm/src/instructions/riscv/rv32im/mmu.rs b/ceno_zkvm/src/instructions/riscv/rv32im/mmu.rs index 66eafd891..fe1722728 100644 --- a/ceno_zkvm/src/instructions/riscv/rv32im/mmu.rs +++ b/ceno_zkvm/src/instructions/riscv/rv32im/mmu.rs @@ -8,8 +8,8 @@ use crate::{ error::ZKVMError, structs::{ProgramParams, ZKVMConstraintSystem, ZKVMFixedTraces, ZKVMWitnesses}, tables::{ - MemFinalRecord, MemInitRecord, NonVolatileTable, PubIOCircuit, PubIOTable, RegTable, - RegTableCircuit, StaticMemCircuit, StaticMemTable, TableCircuit, + MemFinalRecord, MemInitRecord, NonVolatileTable, PrivateIOCircuit, PubIOCircuit, + PubIOTable, RegTable, RegTableCircuit, StaticMemCircuit, StaticMemTable, TableCircuit, }, }; @@ -20,6 +20,8 @@ pub struct MmuConfig { pub static_mem_config: as TableCircuit>::TableConfig, /// Initialization of public IO. pub public_io_config: as TableCircuit>::TableConfig, + /// Initialization of private IO. + pub private_io_config: as TableCircuit>::TableConfig, pub params: ProgramParams, } @@ -30,11 +32,13 @@ impl MmuConfig { let static_mem_config = cs.register_table_circuit::>(); let public_io_config = cs.register_table_circuit::>(); + let private_io_config = cs.register_table_circuit::>(); Self { reg_config, static_mem_config, public_io_config, + private_io_config, params: cs.params.clone(), } } @@ -48,7 +52,13 @@ impl MmuConfig { io_addrs: &[Addr], ) { assert!( - chain!(static_mem_init.iter_addresses(), io_addrs.iter_addresses()).all_unique(), + chain!( + static_mem_init.iter_addresses(), + io_addrs.iter_addresses(), + // TODO: optimize with min_max and Range. + self.params.platform.private_io.iter_addresses(), + ) + .all_unique(), "memory addresses must be unique" ); @@ -61,6 +71,7 @@ impl MmuConfig { ); fixed.register_table_circuit::>(cs, &self.public_io_config, io_addrs); + fixed.register_table_circuit::>(cs, &self.private_io_config, &()); } pub fn assign_table_circuit( @@ -70,6 +81,7 @@ impl MmuConfig { reg_final: &[MemFinalRecord], static_mem_final: &[MemFinalRecord], io_cycles: &[Cycle], + private_io_final: &[MemFinalRecord], ) -> Result<(), ZKVMError> { witness.assign_table_circuit::>(cs, &self.reg_config, reg_final)?; @@ -81,6 +93,12 @@ impl MmuConfig { witness.assign_table_circuit::>(cs, &self.public_io_config, io_cycles)?; + witness.assign_table_circuit::>( + cs, + &self.private_io_config, + private_io_final, + )?; + Ok(()) } diff --git a/ceno_zkvm/src/tables/ram.rs b/ceno_zkvm/src/tables/ram.rs index 88e968e69..f6fcb0282 100644 --- a/ceno_zkvm/src/tables/ram.rs +++ b/ceno_zkvm/src/tables/ram.rs @@ -34,25 +34,25 @@ impl DynVolatileRamTable for DynMemTable { pub type DynMemCircuit = DynVolatileRamCircuit; #[derive(Clone)] -pub struct PrivateMemTable; -impl DynVolatileRamTable for PrivateMemTable { +pub struct PrivateIOTable; +impl DynVolatileRamTable for PrivateIOTable { const RAM_TYPE: RAMType = RAMType::Memory; const V_LIMBS: usize = 1; // See `MemoryExpr`. const ZERO_INIT: bool = false; fn offset_addr(params: &ProgramParams) -> Addr { - params.platform.ram.start + params.platform.private_io.start } fn end_addr(params: &ProgramParams) -> Addr { - params.platform.ram.end + params.platform.private_io.end } fn name() -> &'static str { - "PrivateMemTable" + "PrivateIOTable" } } -pub type PrivateMemCircuit = DynVolatileRamCircuit; +pub type PrivateIOCircuit = DynVolatileRamCircuit; /// RegTable, fix size without offset #[derive(Clone)] diff --git a/ceno_zkvm/src/tables/ram/ram_circuit.rs b/ceno_zkvm/src/tables/ram/ram_circuit.rs index 584489781..234ff7799 100644 --- a/ceno_zkvm/src/tables/ram/ram_circuit.rs +++ b/ceno_zkvm/src/tables/ram/ram_circuit.rs @@ -189,7 +189,7 @@ impl TableC type WitnessInput = [MemFinalRecord]; fn name() -> String { - format!("RAM_{:?}", DVRAM::RAM_TYPE) + format!("RAM_{:?}_{}", DVRAM::RAM_TYPE, DVRAM::name()) } fn construct_circuit(cb: &mut CircuitBuilder) -> Result { diff --git a/ceno_zkvm/src/tables/ram/ram_impl.rs b/ceno_zkvm/src/tables/ram/ram_impl.rs index d293e5d83..2f16c6bff 100644 --- a/ceno_zkvm/src/tables/ram/ram_impl.rs +++ b/ceno_zkvm/src/tables/ram/ram_impl.rs @@ -369,15 +369,17 @@ impl DynVolatileRamTableConfig ) -> Result, ZKVMError> { assert!(final_mem.len() <= DVRAM::max_len(&self.params)); assert!(DVRAM::max_len(&self.params).is_power_of_two()); - let mut final_table = - RowMajorMatrix::::new(final_mem.len().next_power_of_two(), num_witness); + let mut final_table = RowMajorMatrix::::new(final_mem.len(), num_witness); final_table .par_iter_mut() .with_min_len(MIN_PAR_SIZE) .zip(final_mem.into_par_iter()) - .for_each(|(row, rec)| { + .enumerate() + .for_each(|(i, (row, rec))| { + assert_eq!(rec.addr, DVRAM::addr(&self.params, i)); set_val!(row, self.addr, rec.addr as u64); + if self.final_v.len() == 1 { // Assign value directly. set_val!(row, self.final_v[0], rec.value as u64);