Skip to content

Commit

Permalink
Feat: Private input integration (#622)
Browse files Browse the repository at this point in the history
_Issue #614_.
_Pending a solution in #625_

- Load the private input from a file.
- Configuration of the address range.
- Integrate the circuit  (#617) in MMU.

---------

Co-authored-by: Aurélien Nicolas <[email protected]>
  • Loading branch information
naure and Aurélien Nicolas authored Nov 25, 2024
1 parent 3a72862 commit a76d586
Show file tree
Hide file tree
Showing 8 changed files with 86 additions and 20 deletions.
6 changes: 3 additions & 3 deletions ceno_emul/src/addr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -197,17 +197,17 @@ impl ops::AddAssign<u32> for ByteAddr {
}

pub trait IterAddresses {
fn iter_addresses(&self) -> impl Iterator<Item = Addr>;
fn iter_addresses(&self) -> impl ExactSizeIterator<Item = Addr>;
}

impl IterAddresses for Range<Addr> {
fn iter_addresses(&self) -> impl Iterator<Item = Addr> {
fn iter_addresses(&self) -> impl ExactSizeIterator<Item = Addr> {
self.clone().step_by(WORD_SIZE)
}
}

impl<'a, T: GetAddr> IterAddresses for &'a [T] {
fn iter_addresses(&self) -> impl Iterator<Item = Addr> {
fn iter_addresses(&self) -> impl ExactSizeIterator<Item = Addr> {
self.iter().map(T::get_addr)
}
}
Expand Down
8 changes: 7 additions & 1 deletion ceno_emul/src/platform.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ pub struct Platform {
pub rom: Range<Addr>,
pub ram: Range<Addr>,
pub public_io: Range<Addr>,
pub private_io: Range<Addr>,
pub stack_top: Addr,
/// If true, ecall instructions are no-op instead of trap. Testing only.
pub unsafe_ecall_nop: bool,
Expand All @@ -21,6 +22,7 @@ pub const CENO_PLATFORM: Platform = Platform {
rom: 0x2000_0000..0x3000_0000,
ram: 0x8000_0000..0xFFFF_0000,
public_io: 0x3000_1000..0x3000_2000,
private_io: 0x4000_0000..0x5000_0000,
stack_top: 0xC0000000,
unsafe_ecall_nop: false,
};
Expand All @@ -40,6 +42,10 @@ impl Platform {
self.public_io.contains(&addr)
}

pub fn is_priv_io(&self, addr: Addr) -> bool {
self.private_io.contains(&addr)
}

/// Virtual address of a register.
pub const fn register_vma(index: RegIdx) -> Addr {
// Register VMAs are aligned, cannot be confused with indices, and readable in hex.
Expand All @@ -60,7 +66,7 @@ impl Platform {
// Permissions.

pub fn can_read(&self, addr: Addr) -> bool {
self.is_rom(addr) || self.is_ram(addr) || self.is_pub_io(addr)
self.is_rom(addr) || self.is_ram(addr) || self.is_pub_io(addr) || self.is_priv_io(addr)
}

pub fn can_write(&self, addr: Addr) -> bool {
Expand Down
1 change: 1 addition & 0 deletions ceno_zkvm/examples/riscv_opcodes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,7 @@ fn main() {
&reg_final,
&mem_final,
&public_io_final,
&[],
)
.unwrap();

Expand Down
45 changes: 42 additions & 3 deletions ceno_zkvm/src/bin/e2e.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use ceno_emul::{
ByteAddr, CENO_PLATFORM, EmuContext, InsnKind::EANY, Platform, StepRecord, Tracer, VMState,
WORD_SIZE, WordAddr,
ByteAddr, CENO_PLATFORM, EmuContext, InsnKind::EANY, IterAddresses, Platform, StepRecord,
Tracer, VMState, WORD_SIZE, Word, WordAddr,
};
use ceno_zkvm::{
instructions::riscv::{DummyExtraConfig, MemPadder, MmuConfig, Rv32imConfig},
Expand All @@ -19,7 +19,9 @@ use itertools::{Itertools, MinMaxResult, chain, enumerate};
use mpcs::{Basefold, BasefoldRSParams, PolynomialCommitmentScheme};
use std::{
collections::{HashMap, HashSet},
fs, panic,
fs,
iter::zip,
panic,
time::Instant,
};
use tracing::level_filters::LevelFilter;
Expand All @@ -41,6 +43,11 @@ struct Args {
/// The preset configuration to use.
#[arg(short, long, value_enum, default_value_t = Preset::Ceno)]
platform: Preset,

/// The private input or hints. This is a raw file mounted as a memory segment.
/// Zero-padded to the next power-of-two size.
#[arg(long)]
private_input: Option<String>,
}

#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, ValueEnum)]
Expand Down Expand Up @@ -94,6 +101,17 @@ fn main() {
let elf_bytes = fs::read(&args.elf).expect("read elf file");
let mut vm = VMState::new_from_elf(platform.clone(), &elf_bytes).unwrap();

tracing::info!("Loading private input file: {:?}", args.private_input);
let priv_io = memory_from_file(&args.private_input);
assert!(
priv_io.len() <= platform.private_io.iter_addresses().len(),
"private input must fit in {} bytes",
platform.private_io.len()
);
for (addr, value) in zip(platform.private_io.iter_addresses(), &priv_io) {
vm.init_memory(addr.into(), *value);
}

// keygen
let pcs_param = Pcs::setup(1 << MAX_NUM_VARIABLES).expect("Basefold PCS setup");
let (pp, vp) = Pcs::trim(pcs_param, 1 << MAX_NUM_VARIABLES).expect("Basefold trim");
Expand Down Expand Up @@ -249,6 +267,14 @@ fn main() {
.map(|rec| *final_access.get(&rec.addr.into()).unwrap_or(&0))
.collect_vec();

let priv_io_final = zip(platform.private_io.iter_addresses(), &priv_io)
.map(|(addr, &value)| MemFinalRecord {
addr,
value,
cycle: *final_access.get(&addr.into()).unwrap_or(&0),
})
.collect_vec();

// assign table circuits
config
.assign_table_circuit(&zkvm_cs, &mut zkvm_witness)
Expand All @@ -260,6 +286,7 @@ fn main() {
&reg_final,
&mem_final,
&io_final,
&priv_io_final,
)
.unwrap();
// assign program circuit
Expand Down Expand Up @@ -332,6 +359,18 @@ fn main() {
};
}

fn memory_from_file(path: &Option<String>) -> Vec<u32> {
path.as_ref()
.map(|path| {
let mut buf = fs::read(path).expect("could not read file");
buf.resize(buf.len().next_multiple_of(WORD_SIZE), 0);
buf.chunks_exact(WORD_SIZE)
.map(|word| Word::from_le_bytes(word.try_into().unwrap()))
.collect_vec()
})
.unwrap_or_default()
}

fn debug_memory_ranges(vm: &VMState, mem_final: &[MemFinalRecord]) {
let accessed_addrs = vm
.tracer()
Expand Down
24 changes: 21 additions & 3 deletions ceno_zkvm/src/instructions/riscv/rv32im/mmu.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ use crate::{
error::ZKVMError,
structs::{ProgramParams, ZKVMConstraintSystem, ZKVMFixedTraces, ZKVMWitnesses},
tables::{
MemFinalRecord, MemInitRecord, NonVolatileTable, PubIOCircuit, PubIOTable, RegTable,
RegTableCircuit, StaticMemCircuit, StaticMemTable, TableCircuit,
MemFinalRecord, MemInitRecord, NonVolatileTable, PrivateIOCircuit, PubIOCircuit,
PubIOTable, RegTable, RegTableCircuit, StaticMemCircuit, StaticMemTable, TableCircuit,
},
};

Expand All @@ -20,6 +20,8 @@ pub struct MmuConfig<E: ExtensionField> {
pub static_mem_config: <StaticMemCircuit<E> as TableCircuit<E>>::TableConfig,
/// Initialization of public IO.
pub public_io_config: <PubIOCircuit<E> as TableCircuit<E>>::TableConfig,
/// Initialization of private IO.
pub private_io_config: <PrivateIOCircuit<E> as TableCircuit<E>>::TableConfig,
pub params: ProgramParams,
}

Expand All @@ -30,11 +32,13 @@ impl<E: ExtensionField> MmuConfig<E> {
let static_mem_config = cs.register_table_circuit::<StaticMemCircuit<E>>();

let public_io_config = cs.register_table_circuit::<PubIOCircuit<E>>();
let private_io_config = cs.register_table_circuit::<PrivateIOCircuit<E>>();

Self {
reg_config,
static_mem_config,
public_io_config,
private_io_config,
params: cs.params.clone(),
}
}
Expand All @@ -48,7 +52,13 @@ impl<E: ExtensionField> MmuConfig<E> {
io_addrs: &[Addr],
) {
assert!(
chain!(static_mem_init.iter_addresses(), io_addrs.iter_addresses()).all_unique(),
chain!(
static_mem_init.iter_addresses(),
io_addrs.iter_addresses(),
// TODO: optimize with min_max and Range.
self.params.platform.private_io.iter_addresses(),
)
.all_unique(),
"memory addresses must be unique"
);

Expand All @@ -61,6 +71,7 @@ impl<E: ExtensionField> MmuConfig<E> {
);

fixed.register_table_circuit::<PubIOCircuit<E>>(cs, &self.public_io_config, io_addrs);
fixed.register_table_circuit::<PrivateIOCircuit<E>>(cs, &self.private_io_config, &());
}

pub fn assign_table_circuit(
Expand All @@ -70,6 +81,7 @@ impl<E: ExtensionField> MmuConfig<E> {
reg_final: &[MemFinalRecord],
static_mem_final: &[MemFinalRecord],
io_cycles: &[Cycle],
private_io_final: &[MemFinalRecord],
) -> Result<(), ZKVMError> {
witness.assign_table_circuit::<RegTableCircuit<E>>(cs, &self.reg_config, reg_final)?;

Expand All @@ -81,6 +93,12 @@ impl<E: ExtensionField> MmuConfig<E> {

witness.assign_table_circuit::<PubIOCircuit<E>>(cs, &self.public_io_config, io_cycles)?;

witness.assign_table_circuit::<PrivateIOCircuit<E>>(
cs,
&self.private_io_config,
private_io_final,
)?;

Ok(())
}

Expand Down
12 changes: 6 additions & 6 deletions ceno_zkvm/src/tables/ram.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,25 +34,25 @@ impl DynVolatileRamTable for DynMemTable {
pub type DynMemCircuit<E> = DynVolatileRamCircuit<E, DynMemTable>;

#[derive(Clone)]
pub struct PrivateMemTable;
impl DynVolatileRamTable for PrivateMemTable {
pub struct PrivateIOTable;
impl DynVolatileRamTable for PrivateIOTable {
const RAM_TYPE: RAMType = RAMType::Memory;
const V_LIMBS: usize = 1; // See `MemoryExpr`.
const ZERO_INIT: bool = false;

fn offset_addr(params: &ProgramParams) -> Addr {
params.platform.ram.start
params.platform.private_io.start
}

fn end_addr(params: &ProgramParams) -> Addr {
params.platform.ram.end
params.platform.private_io.end
}

fn name() -> &'static str {
"PrivateMemTable"
"PrivateIOTable"
}
}
pub type PrivateMemCircuit<E> = DynVolatileRamCircuit<E, PrivateMemTable>;
pub type PrivateIOCircuit<E> = DynVolatileRamCircuit<E, PrivateIOTable>;

/// RegTable, fix size without offset
#[derive(Clone)]
Expand Down
2 changes: 1 addition & 1 deletion ceno_zkvm/src/tables/ram/ram_circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ impl<E: ExtensionField, DVRAM: DynVolatileRamTable + Send + Sync + Clone> TableC
type WitnessInput = [MemFinalRecord];

fn name() -> String {
format!("RAM_{:?}", DVRAM::RAM_TYPE)
format!("RAM_{:?}_{}", DVRAM::RAM_TYPE, DVRAM::name())
}

fn construct_circuit(cb: &mut CircuitBuilder<E>) -> Result<Self::TableConfig, ZKVMError> {
Expand Down
8 changes: 5 additions & 3 deletions ceno_zkvm/src/tables/ram/ram_impl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -369,15 +369,17 @@ impl<DVRAM: DynVolatileRamTable + Send + Sync + Clone> DynVolatileRamTableConfig
) -> Result<RowMajorMatrix<F>, ZKVMError> {
assert!(final_mem.len() <= DVRAM::max_len(&self.params));
assert!(DVRAM::max_len(&self.params).is_power_of_two());
let mut final_table =
RowMajorMatrix::<F>::new(final_mem.len().next_power_of_two(), num_witness);
let mut final_table = RowMajorMatrix::<F>::new(final_mem.len(), num_witness);

final_table
.par_iter_mut()
.with_min_len(MIN_PAR_SIZE)
.zip(final_mem.into_par_iter())
.for_each(|(row, rec)| {
.enumerate()
.for_each(|(i, (row, rec))| {
assert_eq!(rec.addr, DVRAM::addr(&self.params, i));
set_val!(row, self.addr, rec.addr as u64);

if self.final_v.len() == 1 {
// Assign value directly.
set_val!(row, self.final_v[0], rec.value as u64);
Expand Down

0 comments on commit a76d586

Please sign in to comment.