diff --git a/Cargo.lock b/Cargo.lock index 3d36b3144..6cc7ba700 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -108,9 +108,9 @@ dependencies = [ [[package]] name = "anyhow" -version = "1.0.92" +version = "1.0.93" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "74f37166d7d48a0284b99dd824694c26119c700b53bf0d1540cdb147dbdaaf13" +checksum = "4c95c10ba0b00a02636238b814946408b1322d5ac4760326e6fb8ec956d85775" [[package]] name = "ark-std" @@ -1526,9 +1526,9 @@ checksum = "719b953e2095829ee67db738b3bfa9fa368c94900df327b3f07fe6e794d2fe1f" [[package]] name = "rustix" -version = "0.38.37" +version = "0.38.40" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8acb788b847c24f28525660c4d7758620a7210875711f79e7f663cc152726811" +checksum = "99e4ea3e1cdc4b559b8e5650f9c8e5998e3e5c1343b4eaf034565f32318d63c0" dependencies = [ "bitflags 2.6.0", "errno", @@ -1743,9 +1743,9 @@ checksum = "55937e1799185b12863d447f42597ed69d9928686b8d88a1df17376a097d8369" [[package]] name = "tempfile" -version = "3.13.0" +version = "3.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f0f2c9fc62d0beef6951ccffd757e241266a2c833136efbe35af6cd2567dca5b" +checksum = "28cce251fcbc87fac86a866eeb0d6c2d536fc16d06f184bb61aeae11aa4cee0c" dependencies = [ "cfg-if", "fastrand", diff --git a/ceno_emul/src/lib.rs b/ceno_emul/src/lib.rs index e60493cb8..d8ec28ab0 100644 --- a/ceno_emul/src/lib.rs +++ b/ceno_emul/src/lib.rs @@ -11,7 +11,7 @@ mod vm_state; pub use vm_state::VMState; mod rv32im; -pub use rv32im::{DecodedInstruction, EmuContext, InsnCodes, InsnFormat, InsnKind}; +pub use rv32im::{DecodedInstruction, EmuContext, InsnCategory, InsnCodes, InsnFormat, InsnKind}; mod elf; pub use elf::Program; diff --git a/ceno_emul/src/platform.rs b/ceno_emul/src/platform.rs index 511981ba0..12ccaea97 100644 --- a/ceno_emul/src/platform.rs +++ b/ceno_emul/src/platform.rs @@ -11,6 +11,8 @@ pub struct Platform { pub rom_end: Addr, pub ram_start: Addr, pub ram_end: Addr, + /// If true, ecall instructions are no-op instead of trap. Testing only. + pub unsafe_ecall_nop: bool, } pub const CENO_PLATFORM: Platform = Platform { @@ -18,6 +20,7 @@ pub const CENO_PLATFORM: Platform = Platform { rom_end: 0x3000_0000 - 1, ram_start: 0x8000_0000, ram_end: 0xFFFF_FFFF, + unsafe_ecall_nop: false, }; impl Platform { diff --git a/ceno_emul/src/rv32im.rs b/ceno_emul/src/rv32im.rs index cd3f24a74..bea4bd5c9 100644 --- a/ceno_emul/src/rv32im.rs +++ b/ceno_emul/src/rv32im.rs @@ -111,7 +111,7 @@ pub struct DecodedInstruction { } #[derive(Clone, Copy, Debug)] -enum InsnCategory { +pub enum InsnCategory { Compute, Branch, Load, @@ -196,7 +196,7 @@ impl InsnKind { pub struct InsnCodes { pub format: InsnFormat, pub kind: InsnKind, - category: InsnCategory, + pub category: InsnCategory, pub(crate) opcode: u32, pub(crate) func3: u32, pub(crate) func7: u32, diff --git a/ceno_emul/src/tracer.rs b/ceno_emul/src/tracer.rs index d17e410a5..7ed36369e 100644 --- a/ceno_emul/src/tracer.rs +++ b/ceno_emul/src/tracer.rs @@ -1,8 +1,9 @@ use std::{collections::HashMap, fmt, mem}; use crate::{ - CENO_PLATFORM, PC_STEP_SIZE, + CENO_PLATFORM, InsnKind, PC_STEP_SIZE, addr::{ByteAddr, Cycle, RegIdx, Word, WordAddr}, + encode_rv32, rv32im::DecodedInstruction, }; @@ -187,6 +188,28 @@ impl StepRecord { ) } + /// Create a test record for an ECALL instruction that can do anything. + pub fn new_ecall_any(cycle: Cycle, pc: ByteAddr) -> StepRecord { + let value = 1234; + Self::new_insn( + cycle, + Change::new(pc, pc + PC_STEP_SIZE), + encode_rv32(InsnKind::EANY, 0, 0, 0, 0), + Some(value), + Some(value), + Some(Change::new(value, value)), + Some(WriteOp { + addr: CENO_PLATFORM.ram_start().into(), + value: Change { + before: value, + after: value, + }, + previous_cycle: 0, + }), + 0, + ) + } + #[allow(clippy::too_many_arguments)] fn new_insn( cycle: Cycle, diff --git a/ceno_emul/src/vm_state.rs b/ceno_emul/src/vm_state.rs index 5ae155fc8..bb7a64ed0 100644 --- a/ceno_emul/src/vm_state.rs +++ b/ceno_emul/src/vm_state.rs @@ -2,7 +2,7 @@ use std::collections::HashMap; use super::rv32im::EmuContext; use crate::{ - Program, + PC_STEP_SIZE, Program, addr::{ByteAddr, RegIdx, Word, WordAddr}, platform::Platform, rv32im::{DecodedInstruction, Emulator, TrapCause}, @@ -117,12 +117,21 @@ impl EmuContext for VMState { // Expect an ecall to terminate the program: function HALT with argument exit_code. fn ecall(&mut self) -> Result { let function = self.load_register(self.platform.reg_ecall())?; + let arg0 = self.load_register(self.platform.reg_arg0())?; if function == self.platform.ecall_halt() { - let exit_code = self.load_register(self.platform.reg_arg0())?; - tracing::debug!("halt with exit_code={}", exit_code); + tracing::debug!("halt with exit_code={}", arg0); self.halt(); Ok(true) + } else if self.platform.unsafe_ecall_nop { + // Treat unknown ecalls as all powerful instructions: + // Read two registers, write one register, write one memory word, and branch. + tracing::warn!("ecall ignored: syscall_id={}", function); + self.store_register(DecodedInstruction::RD_NULL as RegIdx, 0)?; + let addr = self.platform.ram_start().into(); + self.store_memory(addr, self.peek_memory(addr))?; + self.set_pc(ByteAddr(self.pc) + PC_STEP_SIZE); + Ok(true) } else { self.trap(TrapCause::EcallError) } diff --git a/ceno_zkvm/Cargo.toml b/ceno_zkvm/Cargo.toml index 40f841be8..6e50e1d32 100644 --- a/ceno_zkvm/Cargo.toml +++ b/ceno_zkvm/Cargo.toml @@ -32,7 +32,7 @@ tracing-subscriber.workspace = true clap = { version = "4.5", features = ["derive"] } generic_static = "0.2" rand.workspace = true -tempfile = "3.13" +tempfile = "3.14" thread_local = "1.1" [dev-dependencies] diff --git a/ceno_zkvm/benches/riscv_add.rs b/ceno_zkvm/benches/riscv_add.rs index 16d5cfe67..8b948a48a 100644 --- a/ceno_zkvm/benches/riscv_add.rs +++ b/ceno_zkvm/benches/riscv_add.rs @@ -45,7 +45,7 @@ fn bench_add(c: &mut Criterion) { zkvm_fixed_traces.register_opcode_circuit::>(&zkvm_cs); let param = Pcs::setup(1 << MAX_NUM_VARIABLES).unwrap(); - let (pp, vp) = Pcs::trim(¶m, 1 << MAX_NUM_VARIABLES).unwrap(); + let (pp, vp) = Pcs::trim(param, 1 << MAX_NUM_VARIABLES).unwrap(); let pk = zkvm_cs .clone() diff --git a/ceno_zkvm/examples/riscv_opcodes.rs b/ceno_zkvm/examples/riscv_opcodes.rs index fbb0e0a83..14ec5b348 100644 --- a/ceno_zkvm/examples/riscv_opcodes.rs +++ b/ceno_zkvm/examples/riscv_opcodes.rs @@ -1,4 +1,4 @@ -use std::{panic, time::Instant}; +use std::{collections::BTreeMap, panic, time::Instant}; use ceno_zkvm::{ declare_program, @@ -19,16 +19,17 @@ use ceno_emul::{ }; use ceno_zkvm::{ scheme::{PublicValues, constants::MAX_NUM_VARIABLES, verifier::ZKVMVerifier}, + stats::{StaticReport, TraceReport}, structs::{ZKVMConstraintSystem, ZKVMFixedTraces, ZKVMWitnesses}, }; use ff_ext::ff::Field; use goldilocks::GoldilocksExt2; use itertools::Itertools; use mpcs::{Basefold, BasefoldRSParams, PolynomialCommitmentScheme}; +use sumcheck::{entered_span, exit_span}; use tracing_flame::FlameLayer; -use tracing_subscriber::{EnvFilter, Registry, fmt, layer::SubscriberExt}; +use tracing_subscriber::{EnvFilter, Registry, fmt, fmt::format::FmtSpan, layer::SubscriberExt}; use transcript::Transcript; - const PROGRAM_SIZE: usize = 16; // For now, we assume registers // - x0 is not touched, @@ -92,23 +93,36 @@ fn main() { .collect(), ); let (flame_layer, _guard) = FlameLayer::with_file("./tracing.folded").unwrap(); + let mut fmt_layer = fmt::layer() + .compact() + .with_span_events(FmtSpan::CLOSE) + .with_thread_ids(false) + .with_thread_names(false); + fmt_layer.set_ansi(false); + + // Take filtering directives from RUST_LOG env_var + // Directive syntax: https://docs.rs/tracing-subscriber/latest/tracing_subscriber/filter/struct.EnvFilter.html#directives + // Example: RUST_LOG="info" cargo run.. to get spans/events at info level; profiling spans are info + // Example: RUST_LOG="[sumcheck]" cargo run.. to get only events under the "sumcheck" span + let filter = EnvFilter::from_default_env(); + let subscriber = Registry::default() - .with( - fmt::layer() - .compact() - .with_thread_ids(false) - .with_thread_names(false), - ) - .with(EnvFilter::from_default_env()) + .with(fmt_layer) + .with(filter) .with(flame_layer.with_threads_collapsed(true)); tracing::subscriber::set_global_default(subscriber).unwrap(); + let top_level = entered_span!("TOPLEVEL"); + + let keygen = entered_span!("KEYGEN"); + // keygen let pcs_param = Pcs::setup(1 << MAX_NUM_VARIABLES).expect("Basefold PCS setup"); - let (pp, vp) = Pcs::trim(&pcs_param, 1 << MAX_NUM_VARIABLES).expect("Basefold trim"); + let (pp, vp) = Pcs::trim(pcs_param, 1 << MAX_NUM_VARIABLES).expect("Basefold trim"); let mut zkvm_cs = ZKVMConstraintSystem::default(); let config = Rv32imConfig::::construct_circuits(&mut zkvm_cs); + let prog_config = zkvm_cs.register_table_circuit::>(); zkvm_cs.register_global_state::(); @@ -120,6 +134,8 @@ fn main() { &program, ); + let static_report = StaticReport::new(&zkvm_cs); + let reg_init = initial_registers(); // Define program constant here let program_data: &[u32] = &[]; @@ -138,6 +154,7 @@ fn main() { .expect("keygen failed"); let vk = pk.get_vk(); + exit_span!(keygen); // proving let prover = ZKVMProver::new(pk); let verifier = ZKVMVerifier::new(vk); @@ -274,6 +291,15 @@ fn main() { .assign_table_circuit::>(&zkvm_cs, &prog_config, &program) .unwrap(); + // get instance counts from witness matrices + let trace_report = TraceReport::new_via_witnesses( + &static_report, + &zkvm_witness, + "EXAMPLE_PROGRAM in riscv_opcodes.rs", + ); + + trace_report.save_json("report.json"); + MockProver::assert_satisfied_full( zkvm_cs.clone(), zkvm_fixed_traces.clone(), @@ -284,6 +310,7 @@ fn main() { let timer = Instant::now(); let transcript = Transcript::new(b"riscv"); + let mut zkvm_proof = prover .create_proof(zkvm_witness, pi, transcript) .expect("create_proof failed"); @@ -291,7 +318,7 @@ fn main() { println!( "riscv_opcodes::create_proof, instance_num_vars = {}, time = {}", instance_num_vars, - timer.elapsed().as_secs_f64() + timer.elapsed().as_secs() ); let transcript = Transcript::new(b"riscv"); @@ -336,4 +363,5 @@ fn main() { } }; } + exit_span!(top_level); } diff --git a/ceno_zkvm/src/circuit_builder.rs b/ceno_zkvm/src/circuit_builder.rs index ec3b440db..6ae1380db 100644 --- a/ceno_zkvm/src/circuit_builder.rs +++ b/ceno_zkvm/src/circuit_builder.rs @@ -1,6 +1,6 @@ use ceno_emul::Addr; -use itertools::Itertools; -use std::{collections::HashMap, marker::PhantomData}; +use itertools::{Itertools, chain}; +use std::{collections::HashMap, iter::once, marker::PhantomData}; use ff_ext::ExtensionField; use mpcs::PolynomialCommitmentScheme; @@ -10,12 +10,12 @@ use crate::{ chip_handler::utils::rlc_chip_record, error::ZKVMError, expression::{Expression, Fixed, Instance, WitIn}, - structs::{ProvingKey, RAMType, VerifyingKey, WitnessId}, + structs::{ProvingKey, RAMType, VerifyingKey, WitnessId, ZKVMConstraintSystem}, witness::RowMajorMatrix, }; /// namespace used for annotation, preserve meta info during circuit construction -#[derive(Clone, Debug)] +#[derive(Clone, Debug, serde::Serialize)] pub struct NameSpace { namespace: Vec, } @@ -49,7 +49,7 @@ impl NameSpace { let mut name = String::new(); let mut needs_separation = false; - for ns in ns.iter().chain(Some(&this).into_iter()) { + for ns in chain!(ns, once(&this)) { if needs_separation { name += "/"; } diff --git a/ceno_zkvm/src/instructions/riscv.rs b/ceno_zkvm/src/instructions/riscv.rs index 81f8af33c..93b05e743 100644 --- a/ceno_zkvm/src/instructions/riscv.rs +++ b/ceno_zkvm/src/instructions/riscv.rs @@ -9,6 +9,7 @@ pub mod branch; pub mod config; pub mod constants; pub mod divu; +pub mod dummy; pub mod ecall; pub mod jump; pub mod logic; diff --git a/ceno_zkvm/src/instructions/riscv/divu.rs b/ceno_zkvm/src/instructions/riscv/divu.rs index c353a6caf..8d25f583f 100644 --- a/ceno_zkvm/src/instructions/riscv/divu.rs +++ b/ceno_zkvm/src/instructions/riscv/divu.rs @@ -4,6 +4,7 @@ use ff_ext::ExtensionField; use super::{ RIVInstruction, constants::{UINT_LIMBS, UInt}, + dummy::DummyInstruction, r_insn::RInstructionConfig, }; use crate::{ @@ -33,12 +34,30 @@ pub struct ArithConfig { pub struct ArithInstruction(PhantomData<(E, I)>); +pub struct DivOp; +impl RIVInstruction for DivOp { + const INST_KIND: InsnKind = InsnKind::DIV; +} +pub type DivDummy = DummyInstruction; // TODO: implement DivInstruction. + pub struct DivUOp; impl RIVInstruction for DivUOp { const INST_KIND: InsnKind = InsnKind::DIVU; } pub type DivUInstruction = ArithInstruction; +pub struct RemOp; +impl RIVInstruction for RemOp { + const INST_KIND: InsnKind = InsnKind::REM; +} +pub type RemDummy = DummyInstruction; // TODO: implement RemInstruction. + +pub struct RemuOp; +impl RIVInstruction for RemuOp { + const INST_KIND: InsnKind = InsnKind::REMU; +} +pub type RemuDummy = DummyInstruction; // TODO: implement RemuInstruction. + impl Instruction for ArithInstruction { type InstructionConfig = ArithConfig; diff --git a/ceno_zkvm/src/instructions/riscv/dummy/dummy_circuit.rs b/ceno_zkvm/src/instructions/riscv/dummy/dummy_circuit.rs new file mode 100644 index 000000000..1fa7dc4c2 --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/dummy/dummy_circuit.rs @@ -0,0 +1,266 @@ +use std::marker::PhantomData; + +use ceno_emul::{InsnCategory, InsnFormat, InsnKind, StepRecord}; +use ff_ext::ExtensionField; + +use super::super::{ + RIVInstruction, + constants::UInt, + insn_base::{ReadMEM, ReadRS1, ReadRS2, StateInOut, WriteMEM, WriteRD}, +}; +use crate::{ + circuit_builder::CircuitBuilder, + error::ZKVMError, + expression::{ToExpr, WitIn}, + instructions::Instruction, + set_val, + tables::InsnRecord, + uint::Value, + utils::i64_to_base, + witness::LkMultiplicity, +}; +use core::mem::MaybeUninit; + +/// DummyInstruction can handle any instruction and produce its side-effects. +pub struct DummyInstruction(PhantomData<(E, I)>); + +impl Instruction for DummyInstruction { + type InstructionConfig = DummyConfig; + + fn name() -> String { + format!("{:?}_DUMMY", I::INST_KIND) + } + + fn construct_circuit( + circuit_builder: &mut CircuitBuilder, + ) -> Result { + let codes = I::INST_KIND.codes(); + + // ECALL can do everything. + let is_ecall = matches!(codes.kind, InsnKind::EANY); + + // Regular instructions do what is implied by their format. + let (with_rs1, with_rs2, with_rd) = match codes.format { + _ if is_ecall => (true, true, true), + InsnFormat::R => (true, true, true), + InsnFormat::I => (true, false, true), + InsnFormat::S => (true, true, false), + InsnFormat::B => (true, true, false), + InsnFormat::U => (false, false, true), + InsnFormat::J => (false, false, true), + }; + let with_mem_write = matches!(codes.category, InsnCategory::Store) || is_ecall; + let with_mem_read = matches!(codes.category, InsnCategory::Load); + let branching = matches!(codes.category, InsnCategory::Branch) + || matches!(codes.kind, InsnKind::JAL | InsnKind::JALR) + || is_ecall; + + DummyConfig::construct_circuit( + circuit_builder, + I::INST_KIND, + with_rs1, + with_rs2, + with_rd, + with_mem_write, + with_mem_read, + branching, + ) + } + + fn assign_instance( + config: &Self::InstructionConfig, + instance: &mut [MaybeUninit<::BaseField>], + lk_multiplicity: &mut LkMultiplicity, + step: &StepRecord, + ) -> Result<(), ZKVMError> { + config.assign_instance(instance, lk_multiplicity, step) + } +} + +#[derive(Debug)] +pub struct DummyConfig { + vm_state: StateInOut, + + rs1: Option<(ReadRS1, UInt)>, + rs2: Option<(ReadRS2, UInt)>, + rd: Option<(WriteRD, UInt)>, + + mem_addr_val: Option<[WitIn; 3]>, + mem_read: Option>, + mem_write: Option, + + imm: WitIn, +} + +impl DummyConfig { + #[allow(clippy::too_many_arguments)] + fn construct_circuit( + circuit_builder: &mut CircuitBuilder, + kind: InsnKind, + with_rs1: bool, + with_rs2: bool, + with_rd: bool, + with_mem_write: bool, + with_mem_read: bool, + branching: bool, + ) -> Result { + // State in and out + let vm_state = StateInOut::construct_circuit(circuit_builder, branching)?; + + // Registers + let rs1 = if with_rs1 { + let rs1_read = UInt::new_unchecked(|| "rs1_read", circuit_builder)?; + let rs1_op = + ReadRS1::construct_circuit(circuit_builder, rs1_read.register_expr(), vm_state.ts)?; + Some((rs1_op, rs1_read)) + } else { + None + }; + + let rs2 = if with_rs2 { + let rs2_read = UInt::new_unchecked(|| "rs2_read", circuit_builder)?; + let rs2_op = + ReadRS2::construct_circuit(circuit_builder, rs2_read.register_expr(), vm_state.ts)?; + Some((rs2_op, rs2_read)) + } else { + None + }; + + let rd = if with_rd { + let rd_written = UInt::new_unchecked(|| "rd_written", circuit_builder)?; + let rd_op = WriteRD::construct_circuit( + circuit_builder, + rd_written.register_expr(), + vm_state.ts, + )?; + Some((rd_op, rd_written)) + } else { + None + }; + + // Memory + let mem_addr_val = if with_mem_read || with_mem_write { + Some([ + circuit_builder.create_witin(|| "mem_addr"), + circuit_builder.create_witin(|| "mem_before"), + circuit_builder.create_witin(|| "mem_after"), + ]) + } else { + None + }; + + let mem_read = if with_mem_read { + Some(ReadMEM::construct_circuit( + circuit_builder, + mem_addr_val.as_ref().unwrap()[0].expr(), + mem_addr_val.as_ref().unwrap()[1].expr(), + vm_state.ts, + )?) + } else { + None + }; + + let mem_write = if with_mem_write { + Some(WriteMEM::construct_circuit( + circuit_builder, + mem_addr_val.as_ref().unwrap()[0].expr(), + mem_addr_val.as_ref().unwrap()[1].expr(), + mem_addr_val.as_ref().unwrap()[2].expr(), + vm_state.ts, + )?) + } else { + None + }; + + // Fetch instruction + + // The register IDs of ECALL is fixed, not encoded. + let is_ecall = matches!(kind, InsnKind::EANY); + let rs1_id = match &rs1 { + Some((r, _)) if !is_ecall => r.id.expr(), + _ => 0.into(), + }; + let rs2_id = match &rs2 { + Some((r, _)) if !is_ecall => r.id.expr(), + _ => 0.into(), + }; + let rd_id = match &rd { + Some((r, _)) if !is_ecall => Some(r.id.expr()), + _ => None, + }; + + let imm = circuit_builder.create_witin(|| "imm"); + + circuit_builder.lk_fetch(&InsnRecord::new( + vm_state.pc.expr(), + kind.into(), + rd_id, + rs1_id, + rs2_id, + imm.expr(), + ))?; + + Ok(DummyConfig { + vm_state, + rs1, + rs2, + rd, + mem_addr_val, + mem_read, + mem_write, + imm, + }) + } + + fn assign_instance( + &self, + instance: &mut [MaybeUninit<::BaseField>], + lk_multiplicity: &mut LkMultiplicity, + step: &StepRecord, + ) -> Result<(), ZKVMError> { + // State in and out + self.vm_state.assign_instance(instance, step)?; + + // Fetch instruction + lk_multiplicity.fetch(step.pc().before.0); + + // Registers + if let Some((rs1_op, rs1_read)) = &self.rs1 { + rs1_op.assign_instance(instance, lk_multiplicity, step)?; + + let rs1_val = Value::new_unchecked(step.rs1().expect("rs1 value").value); + rs1_read.assign_value(instance, rs1_val); + } + if let Some((rs2_op, rs2_read)) = &self.rs2 { + rs2_op.assign_instance(instance, lk_multiplicity, step)?; + + let rs2_val = Value::new_unchecked(step.rs2().expect("rs2 value").value); + rs2_read.assign_value(instance, rs2_val); + } + if let Some((rd_op, rd_written)) = &self.rd { + rd_op.assign_instance(instance, lk_multiplicity, step)?; + + let rd_val = Value::new_unchecked(step.rd().expect("rd value").value.after); + rd_written.assign_value(instance, rd_val); + } + + // Memory + if let Some([mem_addr, mem_before, mem_after]) = &self.mem_addr_val { + let mem_op = step.memory_op().expect("memory operation"); + set_val!(instance, mem_addr, u64::from(mem_op.addr)); + set_val!(instance, mem_before, mem_op.value.before as u64); + set_val!(instance, mem_after, mem_op.value.after as u64); + } + if let Some(mem_read) = &self.mem_read { + mem_read.assign_instance(instance, lk_multiplicity, step)?; + } + if let Some(mem_write) = &self.mem_write { + mem_write.assign_instance::(instance, lk_multiplicity, step)?; + } + + let imm = i64_to_base::(InsnRecord::imm_internal(&step.insn())); + set_val!(instance, self.imm, imm); + + Ok(()) + } +} diff --git a/ceno_zkvm/src/instructions/riscv/dummy/mod.rs b/ceno_zkvm/src/instructions/riscv/dummy/mod.rs new file mode 100644 index 000000000..9c912f422 --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/dummy/mod.rs @@ -0,0 +1,16 @@ +//! Dummy instruction circuits for testing. +//! Support instructions that don’t have a complete implementation yet. +//! It connects all the state together (register writes, etc), but does not verify the values. +//! +//! Usage: +//! Specify an instruction with `trait RIVInstruction` and define a `DummyInstruction` like so: +//! +//! use ceno_zkvm::instructions::riscv::{arith::AddOp, dummy::DummyInstruction}; +//! +//! type AddDummy = DummyInstruction; + +mod dummy_circuit; +pub use dummy_circuit::DummyInstruction; + +#[cfg(test)] +mod test; diff --git a/ceno_zkvm/src/instructions/riscv/dummy/test.rs b/ceno_zkvm/src/instructions/riscv/dummy/test.rs new file mode 100644 index 000000000..df1eb0572 --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/dummy/test.rs @@ -0,0 +1,101 @@ +use ceno_emul::{Change, InsnKind, StepRecord, encode_rv32}; +use goldilocks::GoldilocksExt2; + +use super::*; +use crate::{ + circuit_builder::{CircuitBuilder, ConstraintSystem}, + instructions::{ + Instruction, + riscv::{arith::AddOp, branch::BeqOp, ecall::EcallDummy}, + }, + scheme::mock_prover::{MOCK_PC_START, MockProver}, +}; + +type AddDummy = DummyInstruction; +type BeqDummy = DummyInstruction; + +#[test] +fn test_dummy_ecall() { + let mut cs = ConstraintSystem::::new(|| "riscv"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = cb + .namespace( + || "ecall_dummy", + |cb| { + let config = EcallDummy::construct_circuit(cb); + Ok(config) + }, + ) + .unwrap() + .unwrap(); + + let step = StepRecord::new_ecall_any(4, MOCK_PC_START); + let insn_code = step.insn_code(); + let (raw_witin, lkm) = + EcallDummy::assign_instances(&config, cb.cs.num_witin as usize, vec![step]).unwrap(); + + MockProver::assert_satisfied_raw(&cb, raw_witin, &[insn_code], None, Some(lkm)); +} + +#[test] +fn test_dummy_r() { + let mut cs = ConstraintSystem::::new(|| "riscv"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = cb + .namespace( + || "add_dummy", + |cb| { + let config = AddDummy::construct_circuit(cb); + Ok(config) + }, + ) + .unwrap() + .unwrap(); + + let insn_code = encode_rv32(InsnKind::ADD, 2, 3, 4, 0); + let (raw_witin, lkm) = AddDummy::assign_instances(&config, cb.cs.num_witin as usize, vec![ + StepRecord::new_r_instruction( + 3, + MOCK_PC_START, + insn_code, + 11, + 0xfffffffe, + Change::new(0, 11_u32.wrapping_add(0xfffffffe)), + 0, + ), + ]) + .unwrap(); + + MockProver::assert_satisfied_raw(&cb, raw_witin, &[insn_code], None, Some(lkm)); +} + +#[test] +fn test_dummy_b() { + let mut cs = ConstraintSystem::::new(|| "riscv"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = cb + .namespace( + || "beq_dummy", + |cb| { + let config = BeqDummy::construct_circuit(cb); + Ok(config) + }, + ) + .unwrap() + .unwrap(); + + let insn_code = encode_rv32(InsnKind::BEQ, 2, 3, 0, 8); + let (raw_witin, lkm) = BeqDummy::assign_instances(&config, cb.cs.num_witin as usize, vec![ + StepRecord::new_b_instruction( + 3, + Change::new(MOCK_PC_START, MOCK_PC_START + 8_usize), + insn_code, + 0xbead1010, + 0xbead1010, + 0, + ), + ]) + .unwrap(); + + MockProver::assert_satisfied_raw(&cb, raw_witin, &[insn_code], None, Some(lkm)); +} diff --git a/ceno_zkvm/src/instructions/riscv/ecall.rs b/ceno_zkvm/src/instructions/riscv/ecall.rs index 76c1c04e6..0d7a3315a 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall.rs @@ -1,3 +1,13 @@ mod halt; +use ceno_emul::InsnKind; pub use halt::HaltInstruction; + +use super::{RIVInstruction, dummy::DummyInstruction}; + +pub struct EcallOp; +impl RIVInstruction for EcallOp { + const INST_KIND: InsnKind = InsnKind::EANY; +} +/// Unsafe. A dummy ecall circuit that ignores unimplemented functions. +pub type EcallDummy = DummyInstruction; diff --git a/ceno_zkvm/src/instructions/riscv/insn_base.rs b/ceno_zkvm/src/instructions/riscv/insn_base.rs index 79f07b947..9588cb34f 100644 --- a/ceno_zkvm/src/instructions/riscv/insn_base.rs +++ b/ceno_zkvm/src/instructions/riscv/insn_base.rs @@ -111,16 +111,15 @@ impl ReadRS1 { lk_multiplicity: &mut LkMultiplicity, step: &StepRecord, ) -> Result<(), ZKVMError> { - set_val!(instance, self.id, step.insn().rs1() as u64); - - // Register state - set_val!(instance, self.prev_ts, step.rs1().unwrap().previous_cycle); + let op = step.rs1().expect("rs1 op"); + set_val!(instance, self.id, op.register_index() as u64); + set_val!(instance, self.prev_ts, op.previous_cycle); // Register read self.lt_cfg.assign_instance( instance, lk_multiplicity, - step.rs1().unwrap().previous_cycle, + op.previous_cycle, step.cycle() + Tracer::SUBCYCLE_RS1, )?; @@ -166,16 +165,15 @@ impl ReadRS2 { lk_multiplicity: &mut LkMultiplicity, step: &StepRecord, ) -> Result<(), ZKVMError> { - set_val!(instance, self.id, step.insn().rs2() as u64); - - // Register state - set_val!(instance, self.prev_ts, step.rs2().unwrap().previous_cycle); + let op = step.rs2().expect("rs2 op"); + set_val!(instance, self.id, op.register_index() as u64); + set_val!(instance, self.prev_ts, op.previous_cycle); // Register read self.lt_cfg.assign_instance( instance, lk_multiplicity, - step.rs2().unwrap().previous_cycle, + op.previous_cycle, step.cycle() + Tracer::SUBCYCLE_RS2, )?; @@ -223,20 +221,21 @@ impl WriteRD { lk_multiplicity: &mut LkMultiplicity, step: &StepRecord, ) -> Result<(), ZKVMError> { - set_val!(instance, self.id, step.insn().rd_internal() as u64); - set_val!(instance, self.prev_ts, step.rd().unwrap().previous_cycle); + let op = step.rd().expect("rd op"); + set_val!(instance, self.id, op.register_index() as u64); + set_val!(instance, self.prev_ts, op.previous_cycle); // Register state self.prev_value.assign_limbs( instance, - Value::new_unchecked(step.rd().unwrap().value.before).as_u16_limbs(), + Value::new_unchecked(op.value.before).as_u16_limbs(), ); // Register write self.lt_cfg.assign_instance( instance, lk_multiplicity, - step.rd().unwrap().previous_cycle, + op.previous_cycle, step.cycle() + Tracer::SUBCYCLE_RD, )?; diff --git a/ceno_zkvm/src/instructions/riscv/test.rs b/ceno_zkvm/src/instructions/riscv/test.rs index 6513cd1f7..f4b8f8824 100644 --- a/ceno_zkvm/src/instructions/riscv/test.rs +++ b/ceno_zkvm/src/instructions/riscv/test.rs @@ -23,6 +23,6 @@ fn test_multiple_opcode() { |cs| SubInstruction::construct_circuit(&mut CircuitBuilder::::new(cs)), ); let param = Pcs::setup(1 << 10).unwrap(); - let (pp, _) = Pcs::trim(¶m, 1 << 10).unwrap(); + let (pp, _) = Pcs::trim(param, 1 << 10).unwrap(); cs.key_gen::(&pp, None); } diff --git a/ceno_zkvm/src/keygen.rs b/ceno_zkvm/src/keygen.rs index ba853a943..f97a1886d 100644 --- a/ceno_zkvm/src/keygen.rs +++ b/ceno_zkvm/src/keygen.rs @@ -14,7 +14,7 @@ impl ZKVMConstraintSystem { ) -> Result, ZKVMError> { let mut vm_pk = ZKVMProvingKey::new(pp, vp); - for (c_name, cs) in self.circuit_css.into_iter() { + for (c_name, cs) in self.circuit_css { // fixed_traces is optional // verifier will check it existent if cs.num_fixed > 0 let fixed_traces = if cs.num_fixed > 0 { diff --git a/ceno_zkvm/src/lib.rs b/ceno_zkvm/src/lib.rs index 75b4b377d..a3c2ff02f 100644 --- a/ceno_zkvm/src/lib.rs +++ b/ceno_zkvm/src/lib.rs @@ -14,6 +14,7 @@ pub mod expression; pub mod gadgets; mod keygen; pub mod state; +pub mod stats; pub mod structs; mod uint; mod utils; diff --git a/ceno_zkvm/src/scheme/mock_prover.rs b/ceno_zkvm/src/scheme/mock_prover.rs index 8b7a36d26..506f065f1 100644 --- a/ceno_zkvm/src/scheme/mock_prover.rs +++ b/ceno_zkvm/src/scheme/mock_prover.rs @@ -22,7 +22,7 @@ use ff::Field; use ff_ext::ExtensionField; use generic_static::StaticTypeMap; use goldilocks::SmallField; -use itertools::{Itertools, izip}; +use itertools::{Itertools, enumerate, izip}; use multilinear_extensions::{mle::IntoMLEs, virtual_poly_v2::ArcMultilinearExtension}; use rand::thread_rng; use std::{ @@ -504,7 +504,7 @@ impl<'a, E: ExtensionField + Hash> MockProver { let expr_evaluated = wit_infer_by_expr(&[], wits_in, pi, &challenge, expr); let expr_evaluated = expr_evaluated.get_base_field_vec(); - for (inst_id, element) in expr_evaluated.iter().enumerate() { + for (inst_id, element) in enumerate(expr_evaluated) { if *element != E::BaseField::ZERO { errors.push(MockProverError::AssertZeroError { expression: expr.clone(), @@ -528,7 +528,7 @@ impl<'a, E: ExtensionField + Hash> MockProver { let expr_evaluated = expr_evaluated.get_ext_field_vec(); // Check each lookup expr exists in t vec - for (inst_id, element) in expr_evaluated.iter().enumerate() { + for (inst_id, element) in enumerate(expr_evaluated) { if !table.contains(&element.to_canonical_u64_vec()) { errors.push(MockProverError::LookupError { expression: expr.clone(), @@ -883,12 +883,12 @@ Hints: num_instances.insert(circuit_name.clone(), num_rows); } - for (rom_type, inputs) in rom_inputs.into_iter() { + for (rom_type, inputs) in rom_inputs { let table = rom_tables.get_mut(&rom_type).unwrap(); for (lk_input_values, circuit_name, lk_input_annotation, input_value_exprs) in inputs { // counting multiplicity in rom_input let mut lk_input_values_multiplicity = HashMap::new(); - for (row, input_value) in lk_input_values.iter().enumerate() { + for (row, input_value) in enumerate(&lk_input_values) { // we only keep first row to restore debug information lk_input_values_multiplicity .entry(input_value) @@ -1009,7 +1009,7 @@ Hints: assert!(gs.insert(circuit_name.clone(), w).is_none()); }; let mut records = vec![]; - for (row, record_rlc) in write_rlc_records.into_iter().enumerate() { + for (row, record_rlc) in enumerate(write_rlc_records) { // TODO: report error assert_eq!(writes.insert(record_rlc), true); records.push((record_rlc, row)); @@ -1045,7 +1045,7 @@ Hints: .get_ext_field_vec()[..*num_rows] .to_vec(); let mut records = vec![]; - for (row, record) in read_records.into_iter().enumerate() { + for (row, record) in enumerate(read_records) { assert_eq!(reads.insert(record), true); records.push((record, row)); } diff --git a/ceno_zkvm/src/scheme/prover.rs b/ceno_zkvm/src/scheme/prover.rs index 93d7a77f2..671b52eda 100644 --- a/ceno_zkvm/src/scheme/prover.rs +++ b/ceno_zkvm/src/scheme/prover.rs @@ -52,6 +52,7 @@ impl> ZKVMProver { } /// create proof for zkvm execution + #[tracing::instrument(skip_all, name = "ZKVM_create_proof")] pub fn create_proof( &self, witnesses: ZKVMWitnesses, @@ -87,10 +88,11 @@ impl> ZKVMProver { let mut commitments = BTreeMap::new(); let mut wits = BTreeMap::new(); + let commit_to_traces_span = entered_span!("commit_to_traces"); // commit to opcode circuits first and then commit to table circuits, sorted by name for (circuit_name, witness) in witnesses.into_iter_sorted() { - let commit_dur = std::time::Instant::now(); let num_instances = witness.num_instances(); + let span = entered_span!("commit to iteration", circuit_name = circuit_name); let witness = match num_instances { 0 => vec![], _ => { @@ -100,16 +102,13 @@ impl> ZKVMProver { PCS::batch_commit_and_write(&self.pk.pp, &witness, &mut transcript) .map_err(ZKVMError::PCSError)?, ); - tracing::info!( - "commit to {} traces took {:?}", - circuit_name, - commit_dur.elapsed() - ); witness } }; + exit_span!(span); wits.insert(circuit_name, (witness, num_instances)); } + exit_span!(commit_to_traces_span); // squeeze two challenges from transcript let challenges = [ @@ -118,6 +117,7 @@ impl> ZKVMProver { ]; tracing::debug!("challenges in prover: {:?}", challenges); + let main_proofs_span = entered_span!("main_proofs"); let mut transcripts = transcript.fork(self.pk.circuit_pks.len()); for ((circuit_name, pk), (i, transcript)) in self .pk @@ -193,6 +193,7 @@ impl> ZKVMProver { } } } + exit_span!(main_proofs_span); Ok(vm_proof) } @@ -201,6 +202,7 @@ impl> ZKVMProver { /// 1: witness layer inferring from input -> output /// 2: proof (sumcheck reduce) from output to input #[allow(clippy::too_many_arguments)] + #[tracing::instrument(skip_all, name = "create_opcode_proof", fields(circuit_name=name))] pub fn create_opcode_proof( &self, name: &str, @@ -226,8 +228,9 @@ impl> ZKVMProver { .all(|v| { v.evaluations().len() == next_pow2_instances }) ); + let wit_inference_span = entered_span!("wit_inference"); // main constraint: read/write record witness inference - let span = entered_span!("wit_inference::record"); + let record_span = entered_span!("record"); let records_wit: Vec> = cs .r_expressions .par_iter() @@ -240,7 +243,7 @@ impl> ZKVMProver { .collect(); let (r_records_wit, w_lk_records_wit) = records_wit.split_at(cs.r_expressions.len()); let (w_records_wit, lk_records_wit) = w_lk_records_wit.split_at(cs.w_expressions.len()); - exit_span!(span); + exit_span!(record_span); // product constraint: tower witness inference let (r_counts_per_instance, w_counts_per_instance, lk_counts_per_instance) = ( @@ -255,7 +258,7 @@ impl> ZKVMProver { ); // process last layer by interleaving all the read/write record respectively // as last layer is the output of sel stage - let span = entered_span!("wit_inference::tower_witness_r_last_layer"); + let span = entered_span!("tower_witness_r_last_layer"); // TODO optimize last layer to avoid alloc new vector to save memory let r_records_last_layer = interleaving_mles_to_mles(r_records_wit, num_instances, NUM_FANIN, E::ONE); @@ -263,7 +266,7 @@ impl> ZKVMProver { exit_span!(span); // infer all tower witness after last layer - let span = entered_span!("wit_inference::tower_witness_r_layers"); + let span = entered_span!("tower_witness_r_layers"); let r_wit_layers = infer_tower_product_witness( log2_num_instances + log2_r_count, r_records_last_layer, @@ -271,14 +274,14 @@ impl> ZKVMProver { ); exit_span!(span); - let span = entered_span!("wit_inference::tower_witness_w_last_layer"); + let span = entered_span!("tower_witness_w_last_layer"); // TODO optimize last layer to avoid alloc new vector to save memory let w_records_last_layer = interleaving_mles_to_mles(w_records_wit, num_instances, NUM_FANIN, E::ONE); assert_eq!(w_records_last_layer.len(), NUM_FANIN); exit_span!(span); - let span = entered_span!("wit_inference::tower_witness_w_layers"); + let span = entered_span!("tower_witness_w_layers"); let w_wit_layers = infer_tower_product_witness( log2_num_instances + log2_w_count, w_records_last_layer, @@ -286,16 +289,17 @@ impl> ZKVMProver { ); exit_span!(span); - let span = entered_span!("wit_inference::tower_witness_lk_last_layer"); + let span = entered_span!("tower_witness_lk_last_layer"); // TODO optimize last layer to avoid alloc new vector to save memory let lk_records_last_layer = interleaving_mles_to_mles(lk_records_wit, num_instances, NUM_FANIN, chip_record_alpha); assert_eq!(lk_records_last_layer.len(), 2); exit_span!(span); - let span = entered_span!("wit_inference::tower_witness_lk_layers"); + let span = entered_span!("tower_witness_lk_layers"); let lk_wit_layers = infer_tower_logup_witness(None, lk_records_last_layer); exit_span!(span); + exit_span!(wit_inference_span); if cfg!(test) { // sanity check @@ -326,8 +330,9 @@ impl> ZKVMProver { })); } + let sumcheck_span = entered_span!("SUMCHECK"); // product constraint tower sumcheck - let span = entered_span!("sumcheck::tower"); + let tower_span = entered_span!("tower"); // final evals for verifier let record_r_out_evals: Vec = r_wit_layers[0] .iter() @@ -365,10 +370,10 @@ impl> ZKVMProver { .max() .unwrap() ); - exit_span!(span); + exit_span!(tower_span); // batch sumcheck: selector + main degree > 1 constraints - let span = entered_span!("sumcheck::main_sel"); + let main_sel_span = entered_span!("main_sel"); let (rt_r, rt_w, rt_lk, rt_non_lc_sumcheck): (Vec, Vec, Vec, Vec) = ( tower_proof.prod_specs_points[0] .last() @@ -581,7 +586,8 @@ impl> ZKVMProver { ); let input_open_point = main_sel_sumcheck_proofs.point.clone(); assert!(input_open_point.len() == log2_num_instances); - exit_span!(span); + exit_span!(main_sel_span); + exit_span!(sumcheck_span); let span = entered_span!("witin::evals"); let wits_in_evals: Vec = witnesses @@ -590,7 +596,7 @@ impl> ZKVMProver { .collect(); exit_span!(span); - let span = entered_span!("pcs_open"); + let pcs_open_span = entered_span!("pcs_open"); let opening_dur = std::time::Instant::now(); tracing::debug!( "[opcode {}]: build opening proof for {} polys at {:?}", @@ -612,7 +618,7 @@ impl> ZKVMProver { name, opening_dur.elapsed(), ); - exit_span!(span); + exit_span!(pcs_open_span); let wits_commit = PCS::get_pure_commitment(&wits_commit); Ok(ZKVMOpcodeProof { @@ -638,6 +644,7 @@ impl> ZKVMProver { /// support batch prove for logup + product arguments each with different num_vars() /// side effect: concurrency will be determine based on min(thread, num_vars()), /// so suggest dont batch too small table (size < threads) with large table together + #[tracing::instrument(skip_all, name = "create_table_proof", fields(table_name=name))] pub fn create_table_proof( &self, name: &str, @@ -681,8 +688,9 @@ impl> ZKVMProver { .all(|(r, w)| r.table_spec.len == w.table_spec.len) ); + let wit_inference_span = entered_span!("wit_inference"); // main constraint: lookup denominator and numerator record witness inference - let span = entered_span!("wit_inference::record"); + let record_span = entered_span!("record"); let mut records_wit: Vec> = cs .r_table_expressions .par_iter() @@ -707,10 +715,10 @@ impl> ZKVMProver { let (lk_d_wit, _empty) = remains.split_at_mut(cs.lk_table_expressions.len()); assert!(_empty.is_empty()); - exit_span!(span); + exit_span!(record_span); // infer all tower witness after last layer - let span = entered_span!("wit_inference::tower_witness_lk_last_layer"); + let span = entered_span!("tower_witness_lk_last_layer"); let mut r_set_last_layer = r_set_wit .iter() .chain(w_set_wit.iter()) @@ -758,7 +766,7 @@ impl> ZKVMProver { .collect::>(); exit_span!(span); - let span = entered_span!("wit_inference::tower_witness_lk_layers"); + let span = entered_span!("tower_witness_lk_layers"); let r_wit_layers = r_set_last_layer .into_iter() .zip(r_set_wit.iter()) @@ -779,6 +787,7 @@ impl> ZKVMProver { .map(|(lk_n, lk_d)| infer_tower_logup_witness(Some(lk_n), lk_d)) .collect_vec(); exit_span!(span); + exit_span!(wit_inference_span); if cfg!(test) { // sanity check @@ -831,8 +840,9 @@ impl> ZKVMProver { })); } + let sumcheck_span = entered_span!("sumcheck"); // product constraint tower sumcheck - let span = entered_span!("sumcheck::tower"); + let tower_span = entered_span!("tower"); // final evals for verifier let r_out_evals = r_wit_layers .iter() @@ -889,7 +899,7 @@ impl> ZKVMProver { rt_tower.len(), // num var length should equal to max_num_instance max_log2_num_instance ); - exit_span!(span); + exit_span!(tower_span); // same point sumcheck is optional when all witin + fixed are in same num_vars let is_skip_same_point_sumcheck = witnesses @@ -904,7 +914,7 @@ impl> ZKVMProver { } else { // one sumcheck to make them opening on same point r (with different prefix) // If all table length are the same, we can skip this sumcheck - let span = entered_span!("sumcheck::opening_same_point"); + let span = entered_span!("opening_same_point"); // NOTE: max concurrency will be dominated by smallest table since it will blo let num_threads = optimal_sumcheck_threads(min_log2_num_instance); let alpha_pow = get_challenge_pows( @@ -993,6 +1003,7 @@ impl> ZKVMProver { ) }; + exit_span!(sumcheck_span); let span = entered_span!("fixed::evals + witin::evals"); let mut evals = witnesses .par_iter() @@ -1025,7 +1036,7 @@ impl> ZKVMProver { .collect_vec(); // TODO implement mechanism to skip commitment - let span = entered_span!("pcs_opening"); + let pcs_opening = entered_span!("pcs_opening"); let (fixed_opening_proof, fixed_commit) = if !fixed.is_empty() { ( Some( @@ -1064,7 +1075,7 @@ impl> ZKVMProver { transcript, ) .map_err(ZKVMError::PCSError)?; - exit_span!(span); + exit_span!(pcs_opening); let wits_commit = PCS::get_pure_commitment(&wits_commit); tracing::debug!( "[table {}] build opening proof for {} polys at {:?}: values = {:?}, commit = {:?}", @@ -1132,6 +1143,7 @@ impl TowerProofs { /// Tower Prover impl TowerProver { + #[tracing::instrument(skip_all, name = "tower_prover_create_proof")] pub fn create_proof<'a, E: ExtensionField>( prod_specs: Vec>, logup_specs: Vec>, @@ -1226,11 +1238,17 @@ impl TowerProver { } } + let wrap_batch_span = entered_span!("wrap_batch"); + // NOTE: at the time of adding this span, visualizing it with the flamegraph layer + // shows it to be (inexplicably) much more time-consuming than the call to `prove_batch_polys` + // This is likely a bug in the tracing-flame crate. let (sumcheck_proofs, state) = IOPProverStateV2::prove_batch_polys( num_threads, virtual_polys.get_batched_polys(), transcript, ); + exit_span!(wrap_batch_span); + proofs.push_sumcheck_proofs(sumcheck_proofs.proofs); // rt' = r_merge || rt diff --git a/ceno_zkvm/src/scheme/tests.rs b/ceno_zkvm/src/scheme/tests.rs index 07ae5cb99..04edee440 100644 --- a/ceno_zkvm/src/scheme/tests.rs +++ b/ceno_zkvm/src/scheme/tests.rs @@ -89,7 +89,7 @@ fn test_rw_lk_expression_combination() { // pcs setup let param = Pcs::setup(1 << 13).unwrap(); - let (pp, vp) = Pcs::trim(¶m, 1 << 13).unwrap(); + let (pp, vp) = Pcs::trim(param, 1 << 13).unwrap(); // configure let name = TestCircuit::::name(); @@ -223,7 +223,7 @@ fn test_single_add_instance_e2e() { ); let pcs_param = Pcs::setup(1 << MAX_NUM_VARIABLES).expect("Basefold PCS setup"); - let (pp, vp) = Pcs::trim(&pcs_param, 1 << MAX_NUM_VARIABLES).expect("Basefold trim"); + let (pp, vp) = Pcs::trim(pcs_param, 1 << MAX_NUM_VARIABLES).expect("Basefold trim"); let mut zkvm_cs = ZKVMConstraintSystem::default(); // opcode circuits let add_config = zkvm_cs.register_opcode_circuit::>(); diff --git a/ceno_zkvm/src/scheme/verifier.rs b/ceno_zkvm/src/scheme/verifier.rs index 415a26872..75e91b576 100644 --- a/ceno_zkvm/src/scheme/verifier.rs +++ b/ceno_zkvm/src/scheme/verifier.rs @@ -42,7 +42,7 @@ impl> ZKVMVerifier pub fn new(vk: ZKVMVerifyingKey) -> Self { ZKVMVerifier { vk } } - + #[tracing::instrument(skip_all, name = "verify_proof")] pub fn verify_proof( &self, vm_proof: ZKVMProof, diff --git a/ceno_zkvm/src/stats.rs b/ceno_zkvm/src/stats.rs new file mode 100644 index 000000000..d88feb93b --- /dev/null +++ b/ceno_zkvm/src/stats.rs @@ -0,0 +1,172 @@ +use crate::{ + circuit_builder::{ConstraintSystem, NameSpace}, + expression::Expression, + structs::{ZKVMConstraintSystem, ZKVMWitnesses}, +}; +use ff_ext::ExtensionField; +use itertools::Itertools; +use serde_json::json; +use std::{collections::BTreeMap, fs::File, io::Write}; + +#[derive(Clone, Debug, serde::Serialize)] +pub struct OpCodeStats { + namespace: NameSpace, + witnesses: usize, + reads: usize, + writes: usize, + lookups: usize, + assert_zero_expr_degrees: Vec, + assert_zero_sumcheck_expr_degrees: Vec, +} + +#[derive(Clone, Debug, serde::Serialize)] +pub struct TableStats { + table_len: usize, +} + +#[derive(Clone, Debug, serde::Serialize)] +pub enum CircuitStats { + OpCode(OpCodeStats), + Table(TableStats), +} + +impl CircuitStats { + pub fn new(system: &ConstraintSystem) -> Self { + let just_degrees = + |exprs: &Vec>| exprs.iter().map(|e| e.degree()).collect_vec(); + let is_opcode = system.lk_table_expressions.is_empty() + && system.r_table_expressions.is_empty() + && system.w_table_expressions.is_empty(); + // distinguishing opcodes from tables as done in ZKVMProver::create_proof + if is_opcode { + CircuitStats::OpCode(OpCodeStats { + namespace: system.ns.clone(), + witnesses: system.num_witin as usize, + reads: system.r_expressions.len(), + writes: system.w_expressions.len(), + lookups: system.lk_expressions.len(), + assert_zero_expr_degrees: just_degrees(&system.assert_zero_expressions), + assert_zero_sumcheck_expr_degrees: just_degrees( + &system.assert_zero_sumcheck_expressions, + ), + }) + } else { + let table_len = if system.lk_table_expressions.len() > 0 { + system.lk_table_expressions[0].table_len + } else { + 0 + }; + CircuitStats::Table(TableStats { table_len }) + } + } +} + +pub struct Report { + metadata: BTreeMap, + circuits: Vec<(String, INFO)>, +} + +impl Report +where + INFO: serde::Serialize, +{ + pub fn get(&self, circuit_name: &str) -> Option<&INFO> { + self.circuits.iter().find_map(|(name, info)| { + if name == circuit_name { + Some(info) + } else { + None + } + }) + } + + pub fn save_json(&self, filename: &str) { + let json_data = json!({ + "metadata": self.metadata, + "circuits": self.circuits, + }); + + let mut file = File::create(filename).expect("Unable to create file"); + file.write_all(serde_json::to_string_pretty(&json_data).unwrap().as_bytes()) + .expect("Unable to write data"); + } +} +pub type StaticReport = Report; + +impl Report { + pub fn new(zkvm_system: &ZKVMConstraintSystem) -> Self { + Report { + metadata: BTreeMap::default(), + circuits: zkvm_system + .get_css() + .iter() + .map(|(k, v)| (k.clone(), CircuitStats::new(v))) + .collect_vec(), + } + } +} + +#[derive(Clone, Debug, serde::Serialize)] +pub struct CircuitStatsTrace { + static_stats: CircuitStats, + num_instances: usize, +} + +impl CircuitStatsTrace { + pub fn new(static_stats: CircuitStats, num_instances: usize) -> Self { + return CircuitStatsTrace { + static_stats, + num_instances, + }; + } +} + +pub type TraceReport = Report; + +impl Report { + pub fn new( + static_report: &Report, + num_instances: BTreeMap, + program_name: &str, + ) -> Self { + let mut metadata = static_report.metadata.clone(); + // Note where the num_instances are extracted from + metadata.insert("PROGRAM_NAME".to_owned(), program_name.to_owned()); + + // Ensure we recognize all circuits from the num_instances map + num_instances.keys().for_each(|key| { + assert!( + matches!(static_report.get(key), Some(_)), + r"unrecognized key {key}." + ); + }); + + // Stitch num instances to corresponding entries. Sort by num instances + let circuits = static_report + .circuits + .iter() + .map(|(key, value)| { + ( + key.to_owned(), + CircuitStatsTrace::new(value.clone(), *num_instances.get(key).unwrap_or(&0)), + ) + }) + .sorted_by(|lhs, rhs| rhs.1.num_instances.cmp(&lhs.1.num_instances)) + .collect_vec(); + Report { metadata, circuits } + } + + // Extract num_instances from witness data + pub fn new_via_witnesses( + static_report: &Report, + zkvm_witnesses: &ZKVMWitnesses, + program_name: &str, + ) -> Self { + let num_instances = zkvm_witnesses + .clone() + .into_iter_sorted() + .map(|(key, value)| (key, value.num_instances())) + .collect::>(); + Self::new::(static_report, num_instances, program_name) + } +} diff --git a/ceno_zkvm/src/structs.rs b/ceno_zkvm/src/structs.rs index 5a504a2c3..96d7f787e 100644 --- a/ceno_zkvm/src/structs.rs +++ b/ceno_zkvm/src/structs.rs @@ -170,6 +170,10 @@ impl ZKVMConstraintSystem { SC::finalize_global_state(&mut circuit_builder).expect("global_state_out failed"); } + pub fn get_css(&self) -> &BTreeMap> { + &self.circuit_css + } + pub fn get_cs(&self, name: &String) -> Option<&ConstraintSystem> { self.circuit_css.get(name) } diff --git a/ceno_zkvm/src/uint.rs b/ceno_zkvm/src/uint.rs index 452019930..afb231d8a 100644 --- a/ceno_zkvm/src/uint.rs +++ b/ceno_zkvm/src/uint.rs @@ -14,7 +14,6 @@ use crate::{ witness::LkMultiplicity, }; use ark_std::iterable::Iterable; -use constants::BYTE_BIT_WIDTH; use ff::Field; use ff_ext::ExtensionField; use goldilocks::SmallField; @@ -422,71 +421,6 @@ impl UIntLimbs { } } - /// Builds a `UIntLimbs` instance from a set of cells that represent `RANGE_VALUES` - /// assumes range_values are represented in little endian form - pub fn from_range_wits_in( - _circuit_builder: &mut CircuitBuilder, - _range_values: &[WitIn], - ) -> Result { - // Self::from_different_sized_cell_values( - // circuit_builder, - // range_values, - // RANGE_CHIP_BIT_WIDTH, - // true, - // ) - todo!() - } - - /// Builds a `UIntLimbs` instance from a set of cells that represent big-endian `BYTE_VALUES` - pub fn from_bytes_big_endian( - circuit_builder: &mut CircuitBuilder, - bytes: &[WitIn], - ) -> Result { - Self::from_bytes(circuit_builder, bytes, false) - } - - /// Builds a `UIntLimbs` instance from a set of cells that represent little-endian `BYTE_VALUES` - pub fn from_bytes_little_endian( - circuit_builder: &mut CircuitBuilder, - bytes: &[WitIn], - ) -> Result { - Self::from_bytes(circuit_builder, bytes, true) - } - - /// Builds a `UIntLimbs` instance from a set of cells that represent `BYTE_VALUES` - pub fn from_bytes( - circuit_builder: &mut CircuitBuilder, - bytes: &[WitIn], - is_little_endian: bool, - ) -> Result { - Self::from_different_sized_cell_values( - circuit_builder, - bytes, - BYTE_BIT_WIDTH, - is_little_endian, - ) - } - - /// Builds a `UIntLimbs` instance from a set of cell values of a certain `CELL_WIDTH` - fn from_different_sized_cell_values( - _circuit_builder: &mut CircuitBuilder, - _wits_in: &[WitIn], - _cell_width: usize, - _is_little_endian: bool, - ) -> Result { - todo!() - // let mut values = convert_decomp( - // circuit_builder, - // wits_in, - // cell_width, - // Self::MAX_CELL_BIT_WIDTH, - // is_little_endian, - // )?; - // debug_assert!(values.len() <= Self::NUM_CELLS); - // pad_cells(circuit_builder, &mut values, Self::NUM_CELLS); - // values.try_into() - } - /// Generate ((0)_{2^C}, (1)_{2^C}, ..., (size - 1)_{2^C}) pub fn counter_vector(size: usize) -> Vec> { let num_vars = ceil_log2(size); diff --git a/ceno_zkvm/src/uint/constants.rs b/ceno_zkvm/src/uint/constants.rs index c418213f0..991cf22e2 100644 --- a/ceno_zkvm/src/uint/constants.rs +++ b/ceno_zkvm/src/uint/constants.rs @@ -2,8 +2,6 @@ use crate::utils::const_min; use super::{UIntLimbs, util::max_carry_word_for_multiplication}; -pub const BYTE_BIT_WIDTH: usize = 8; - use ff_ext::ExtensionField; impl diff --git a/mpcs/benches/basecode.rs b/mpcs/benches/basecode.rs index 193d15f88..9ef1896f6 100644 --- a/mpcs/benches/basecode.rs +++ b/mpcs/benches/basecode.rs @@ -41,7 +41,7 @@ fn bench_encoding(c: &mut Criterion, is_base: bool) { let (pp, _) = { let poly_size = 1 << num_vars; let param = Pcs::setup(poly_size).unwrap(); - Pcs::trim(¶m, poly_size).unwrap() + Pcs::trim(param, poly_size).unwrap() }; let polys = (0..batch_size) .map(|_| { diff --git a/mpcs/benches/commit_open_verify_basecode.rs b/mpcs/benches/commit_open_verify_basecode.rs index 1a22e9171..91baa5f73 100644 --- a/mpcs/benches/commit_open_verify_basecode.rs +++ b/mpcs/benches/commit_open_verify_basecode.rs @@ -42,7 +42,7 @@ fn bench_commit_open_verify_goldilocks(c: &mut Criterion, is_base: bool) { Pcs::setup(poly_size).unwrap(); }) }); - Pcs::trim(¶m, poly_size).unwrap() + Pcs::trim(param, poly_size).unwrap() }; let mut transcript = T::new(b"BaseFold"); @@ -118,7 +118,7 @@ fn bench_batch_commit_open_verify_goldilocks(c: &mut Criterion, is_base: bool) { let (pp, vp) = { let poly_size = 1 << num_vars; let param = Pcs::setup(poly_size).unwrap(); - Pcs::trim(¶m, poly_size).unwrap() + Pcs::trim(param, poly_size).unwrap() }; // Batch commit and open let evals = chain![ @@ -258,7 +258,7 @@ fn bench_simple_batch_commit_open_verify_goldilocks(c: &mut Criterion, is_base: let (pp, vp) = { let poly_size = 1 << num_vars; let param = Pcs::setup(poly_size).unwrap(); - Pcs::trim(¶m, poly_size).unwrap() + Pcs::trim(param, poly_size).unwrap() }; let mut transcript = T::new(b"BaseFold"); let polys = (0..batch_size) diff --git a/mpcs/benches/commit_open_verify_rs.rs b/mpcs/benches/commit_open_verify_rs.rs index 686253218..1401f5127 100644 --- a/mpcs/benches/commit_open_verify_rs.rs +++ b/mpcs/benches/commit_open_verify_rs.rs @@ -46,7 +46,7 @@ fn bench_commit_open_verify_goldilocks(c: &mut Criterion, is_base: bool) { Pcs::setup(poly_size).unwrap(); }) }); - Pcs::trim(¶m, poly_size).unwrap() + Pcs::trim(param, poly_size).unwrap() }; let mut transcript = T::new(b"BaseFold"); @@ -125,7 +125,7 @@ fn bench_batch_commit_open_verify_goldilocks(c: &mut Criterion, is_base: bool) { let (pp, vp) = { let poly_size = 1 << num_vars; let param = Pcs::setup(poly_size).unwrap(); - Pcs::trim(¶m, poly_size).unwrap() + Pcs::trim(param, poly_size).unwrap() }; // Batch commit and open let evals = chain![ @@ -266,7 +266,7 @@ fn bench_simple_batch_commit_open_verify_goldilocks(c: &mut Criterion, is_base: let (pp, vp) = { let poly_size = 1 << num_vars; let param = Pcs::setup(poly_size).unwrap(); - Pcs::trim(¶m, poly_size).unwrap() + Pcs::trim(param, poly_size).unwrap() }; let mut transcript = T::new(b"BaseFold"); let polys = (0..batch_size) diff --git a/mpcs/benches/rscode.rs b/mpcs/benches/rscode.rs index ac9870d84..2d284d177 100644 --- a/mpcs/benches/rscode.rs +++ b/mpcs/benches/rscode.rs @@ -41,7 +41,7 @@ fn bench_encoding(c: &mut Criterion, is_base: bool) { let (pp, _) = { let poly_size = 1 << num_vars; let param = Pcs::setup(poly_size).unwrap(); - Pcs::trim(¶m, poly_size).unwrap() + Pcs::trim(param, poly_size).unwrap() }; let polys = (0..batch_size) .map(|_| { diff --git a/mpcs/src/basefold.rs b/mpcs/src/basefold.rs index 3f85022cc..5c225c75a 100644 --- a/mpcs/src/basefold.rs +++ b/mpcs/src/basefold.rs @@ -287,10 +287,10 @@ where /// Derive the proving key and verification key from the public parameter. /// This step simultaneously trims the parameter for the particular size. fn trim( - pp: &Self::Param, + pp: Self::Param, poly_size: usize, ) -> Result<(Self::ProverParam, Self::VerifierParam), Error> { - >::trim(&pp.params, log2_strict(poly_size)).map( + >::trim(pp.params, log2_strict(poly_size)).map( |(pp, vp)| { ( BasefoldProverParams { diff --git a/mpcs/src/basefold/encoding.rs b/mpcs/src/basefold/encoding.rs index 706ec9f9f..410d35970 100644 --- a/mpcs/src/basefold/encoding.rs +++ b/mpcs/src/basefold/encoding.rs @@ -35,7 +35,7 @@ pub trait EncodingScheme: std::fmt::Debug + Clone { fn setup(max_msg_size_log: usize) -> Self::PublicParameters; fn trim( - pp: &Self::PublicParameters, + pp: Self::PublicParameters, max_msg_size_log: usize, ) -> Result<(Self::ProverParameters, Self::VerifierParameters), Error>; @@ -177,7 +177,7 @@ pub(crate) mod test_util { let mut poly = FieldType::Ext(poly); let pp: Code::PublicParameters = Code::setup(num_vars); - let (pp, _) = Code::trim(&pp, num_vars).unwrap(); + let (pp, _) = Code::trim(pp, num_vars).unwrap(); let mut codeword = Code::encode(&pp, &poly); reverse_index_bits_in_place_field_type(&mut codeword); if Code::message_is_left_and_right_folding() { diff --git a/mpcs/src/basefold/encoding/basecode.rs b/mpcs/src/basefold/encoding/basecode.rs index 44b91ba76..9fbee84f1 100644 --- a/mpcs/src/basefold/encoding/basecode.rs +++ b/mpcs/src/basefold/encoding/basecode.rs @@ -117,7 +117,7 @@ where } fn trim( - pp: &Self::PublicParameters, + mut pp: Self::PublicParameters, max_msg_size_log: usize, ) -> Result<(Self::ProverParameters, Self::VerifierParameters), Error> { if pp.table.len() < Spec::get_rate_log() + max_msg_size_log { @@ -127,6 +127,9 @@ where max_msg_size_log, ))); } + pp.table_w_weights + .truncate(Spec::get_rate_log() + max_msg_size_log); + pp.table.truncate(Spec::get_rate_log() + max_msg_size_log); let mut key: [u8; 16] = [0u8; 16]; let mut iv: [u8; 16] = [0u8; 16]; let mut rng = ChaCha8Rng::from_seed(pp.rng_seed); @@ -135,8 +138,8 @@ where rng.fill_bytes(&mut iv); Ok(( Self::ProverParameters { - table_w_weights: pp.table_w_weights.clone(), - table: pp.table.clone(), + table_w_weights: pp.table_w_weights, + table: pp.table, rng_seed: pp.rng_seed, _phantom: PhantomData, }, @@ -430,7 +433,7 @@ mod tests { fn prover_verifier_consistency() { type Code = Basecode; let pp: BasecodeParameters = Code::setup(10); - let (pp, vp) = Code::trim(&pp, 10).unwrap(); + let (pp, vp) = Code::trim(pp, 10).unwrap(); for level in 0..(10 + >::get_rate_log()) { for index in 0..(1 << level) { assert_eq!( diff --git a/mpcs/src/basefold/encoding/rs.rs b/mpcs/src/basefold/encoding/rs.rs index 8535ce23e..2bcac0826 100644 --- a/mpcs/src/basefold/encoding/rs.rs +++ b/mpcs/src/basefold/encoding/rs.rs @@ -280,7 +280,7 @@ where } fn trim( - pp: &Self::PublicParameters, + mut pp: Self::PublicParameters, max_message_size_log: usize, ) -> Result<(Self::ProverParameters, Self::VerifierParameters), Error> { if pp.fft_root_table.len() < max_message_size_log + Spec::get_rate_log() { @@ -308,7 +308,6 @@ where }, )); } - let mut gamma_powers = Vec::with_capacity(max_message_size_log); let mut gamma_powers_inv = Vec::with_capacity(max_message_size_log); gamma_powers.push(E::BaseField::MULTIPLICATIVE_GENERATOR); @@ -319,26 +318,27 @@ where } let inv_of_two = E::BaseField::from(2).invert().unwrap(); gamma_powers_inv.iter_mut().for_each(|x| *x *= inv_of_two); + pp.fft_root_table + .truncate(max_message_size_log + Spec::get_rate_log()); + let verifier_fft_root_table = pp.fft_root_table + [..Spec::get_basecode_msg_size_log() + Spec::get_rate_log()] + .iter() + .cloned() + .chain( + pp.fft_root_table[Spec::get_basecode_msg_size_log() + Spec::get_rate_log()..] + .iter() + .map(|v| vec![v[1]]), + ) + .collect(); Ok(( Self::ProverParameters { - fft_root_table: pp.fft_root_table[..max_message_size_log + Spec::get_rate_log()] - .to_vec(), + fft_root_table: pp.fft_root_table, gamma_powers: gamma_powers.clone(), gamma_powers_inv_div_two: gamma_powers_inv.clone(), full_message_size_log: max_message_size_log, }, Self::VerifierParameters { - fft_root_table: pp.fft_root_table - [..Spec::get_basecode_msg_size_log() + Spec::get_rate_log()] - .iter() - .cloned() - .chain( - pp.fft_root_table - [Spec::get_basecode_msg_size_log() + Spec::get_rate_log()..] - .iter() - .map(|v| vec![v[1]]), - ) - .collect(), + fft_root_table: verifier_fft_root_table, full_message_size_log: max_message_size_log, gamma_powers, gamma_powers_inv_div_two: gamma_powers_inv, @@ -653,7 +653,7 @@ mod tests { fn prover_verifier_consistency() { type Code = RSCode; let pp: RSCodeParameters = Code::setup(10); - let (pp, vp) = Code::trim(&pp, 10).unwrap(); + let (pp, vp) = Code::trim(pp, 10).unwrap(); for level in 0..(10 + >::get_rate_log()) { for index in 0..(1 << level) { let (naive_x0, naive_x1, naive_w) = @@ -690,7 +690,7 @@ mod tests { let poly = FieldType::Ext(poly); let pp = >::setup(num_vars); - let (pp, _) = Code::trim(&pp, num_vars).unwrap(); + let (pp, _) = Code::trim(pp, num_vars).unwrap(); let mut codeword = Code::encode(&pp, &poly); reverse_index_bits_in_place_field_type(&mut codeword); let challenge = E::from(2); @@ -728,7 +728,7 @@ mod tests { let poly = FieldType::Ext(poly); let pp = >::setup(num_vars); - let (pp, _) = Code::trim(&pp, num_vars).unwrap(); + let (pp, _) = Code::trim(pp, num_vars).unwrap(); let mut codeword = Code::encode(&pp, &poly); check_low_degree(&codeword, "low degree check for original codeword"); let c0 = field_type_index_ext(&codeword, 0); diff --git a/mpcs/src/lib.rs b/mpcs/src/lib.rs index 46fbea0ff..19b3d16b6 100644 --- a/mpcs/src/lib.rs +++ b/mpcs/src/lib.rs @@ -24,7 +24,7 @@ pub fn pcs_setup>( } pub fn pcs_trim>( - param: &Pcs::Param, + param: Pcs::Param, poly_size: usize, ) -> Result<(Pcs::ProverParam, Pcs::VerifierParam), Error> { Pcs::trim(param, poly_size) @@ -119,7 +119,7 @@ pub trait PolynomialCommitmentScheme: Clone + Debug { fn setup(poly_size: usize) -> Result; fn trim( - param: &Self::Param, + param: Self::Param, poly_size: usize, ) -> Result<(Self::ProverParam, Self::VerifierParam), Error>; @@ -380,7 +380,7 @@ pub mod test_util { let (pp, vp) = { let poly_size = 1 << num_vars; let param = Pcs::setup(poly_size).unwrap(); - Pcs::trim(¶m, poly_size).unwrap() + Pcs::trim(param, poly_size).unwrap() }; // Commit and open let (comm, eval, proof, challenge) = { @@ -442,7 +442,7 @@ pub mod test_util { let (pp, vp) = { let poly_size = 1 << num_vars; let param = Pcs::setup(poly_size).unwrap(); - Pcs::trim(¶m, poly_size).unwrap() + Pcs::trim(param, poly_size).unwrap() }; // Batch commit and open let evals = chain![ @@ -556,7 +556,7 @@ pub mod test_util { let (pp, vp) = { let poly_size = 1 << num_vars; let param = Pcs::setup(poly_size).unwrap(); - Pcs::trim(¶m, poly_size).unwrap() + Pcs::trim(param, poly_size).unwrap() }; let (comm, evals, proof, challenge) = { diff --git a/mpcs/src/util/transcript.rs b/mpcs/src/util/transcript.rs deleted file mode 100644 index 2edf082cf..000000000 --- a/mpcs/src/util/transcript.rs +++ /dev/null @@ -1,325 +0,0 @@ -use crate::{util::Itertools, Error}; -use ff::Field; -use ff_ext::ExtensionField; -use multilinear_extensions::mle::FieldType; - -use std::fmt::Debug; - -use super::hash::{new_hasher, Digest, Hasher, DIGEST_WIDTH}; - -pub const OUTPUT_WIDTH: usize = 4; // Must be at least the degree of F - -pub trait FieldTranscript { - fn squeeze_challenge(&mut self) -> E; - - fn squeeze_challenges(&mut self, n: usize) -> Vec { - (0..n).map(|_| self.squeeze_challenge()).collect() - } - - fn common_field_element_base(&mut self, fe: &E::BaseField) -> Result<(), Error>; - - fn common_field_element_ext(&mut self, fe: &E) -> Result<(), Error>; - - fn common_field_elements(&mut self, fes: FieldType) -> Result<(), Error> { - match fes { - FieldType::Base(fes) => fes - .iter() - .try_for_each(|fe| self.common_field_element_base(fe))?, - FieldType::Ext(fes) => fes - .iter() - .try_for_each(|fe| self.common_field_element_ext(fe))?, - FieldType::Unreachable => unreachable!(), - }; - Ok(()) - } -} - -pub trait FieldTranscriptRead: FieldTranscript { - fn read_field_element_base(&mut self) -> Result; - - fn read_field_element_ext(&mut self) -> Result; - - fn read_field_elements_base(&mut self, n: usize) -> Result, Error> { - (0..n).map(|_| self.read_field_element_base()).collect() - } - - fn read_field_elements_ext(&mut self, n: usize) -> Result, Error> { - (0..n).map(|_| self.read_field_element_ext()).collect() - } -} - -pub trait FieldTranscriptWrite: FieldTranscript { - fn write_field_element_base(&mut self, fe: &E::BaseField) -> Result<(), Error>; - - fn write_field_element_ext(&mut self, fe: &E) -> Result<(), Error>; - - fn write_field_elements_base<'a>( - &mut self, - fes: impl IntoIterator, - ) -> Result<(), Error> - where - E::BaseField: 'a, - { - for fe in fes.into_iter() { - self.write_field_element_base(fe)?; - } - Ok(()) - } - - fn write_field_elements_ext<'a>( - &mut self, - fes: impl IntoIterator, - ) -> Result<(), Error> - where - E::BaseField: 'a, - { - for fe in fes.into_iter() { - self.write_field_element_ext(fe)?; - } - Ok(()) - } -} - -pub trait Transcript: FieldTranscript { - fn common_commitment(&mut self, comm: &C) -> Result<(), Error>; - - fn common_commitments(&mut self, comms: &[C]) -> Result<(), Error> { - comms - .iter() - .map(|comm| self.common_commitment(comm)) - .try_collect() - } -} - -pub trait TranscriptRead: Transcript + FieldTranscriptRead { - fn read_commitment(&mut self) -> Result; - - fn read_commitments(&mut self, n: usize) -> Result, Error> { - (0..n).map(|_| self.read_commitment()).collect() - } -} - -pub trait TranscriptWrite: - Transcript + FieldTranscriptWrite -{ - fn write_commitment(&mut self, comm: &C) -> Result<(), Error>; - - fn write_commitments<'a>(&mut self, comms: impl IntoIterator) -> Result<(), Error> - where - C: 'a, - { - for comm in comms.into_iter() { - self.write_commitment(comm)?; - } - Ok(()) - } -} - -pub trait InMemoryTranscript { - fn new() -> Self; - - fn into_proof(self) -> Vec; - - fn from_proof(proof: &[E::BaseField]) -> Self; -} - -#[derive(Debug, Clone, PartialEq, Eq, Default)] -struct Stream { - inner: Vec, - pointer: usize, -} - -impl Stream { - pub fn new(content: Vec) -> Self { - Self { - inner: content, - pointer: 0, - } - } - - pub fn into_inner(self) -> Vec { - self.inner - } - - fn left(&self) -> usize { - self.inner.len() - self.pointer - } - - pub fn read_exact(&mut self, output: &mut [T]) -> Result<(), Error> { - let left = self.left(); - if left < output.len() { - return Err(Error::Transcript( - "Insufficient data in transcript".to_string(), - )); - } - let len = output.len(); - output.copy_from_slice(&self.inner[self.pointer..(self.pointer + len)]); - self.pointer += output.len(); - Ok(()) - } - - pub fn write_all(&mut self, input: &[T]) -> Result<(), Error> { - self.inner.extend_from_slice(input); - Ok(()) - } -} - -#[derive(Debug, Clone)] -pub struct PoseidonTranscript { - state: Hasher, - stream: Stream, -} - -impl Default for PoseidonTranscript { - fn default() -> Self { - Self { - state: new_hasher::(), - stream: Stream::default(), - } - } -} - -impl InMemoryTranscript for PoseidonTranscript { - fn new() -> Self { - Self::default() - } - - fn into_proof(self) -> Vec { - self.stream.into_inner() - } - - fn from_proof(proof: &[E::BaseField]) -> Self { - Self { - state: new_hasher::(), - stream: Stream::new(proof.to_vec()), - } - } -} - -impl FieldTranscript for PoseidonTranscript { - fn squeeze_challenge(&mut self) -> E { - let hash: [E::BaseField; OUTPUT_WIDTH] = self.state.squeeze_vec()[0..OUTPUT_WIDTH] - .try_into() - .unwrap(); - E::from_limbs(&hash[..E::DEGREE]) - } - - fn common_field_element_base(&mut self, fe: &E::BaseField) -> Result<(), Error> { - self.state.update(&[*fe]); - Ok(()) - } - - fn common_field_element_ext(&mut self, fe: &E) -> Result<(), Error> { - self.state.update(fe.as_bases()); - Ok(()) - } -} - -impl FieldTranscriptRead for PoseidonTranscript { - fn read_field_element_ext(&mut self) -> Result { - let mut repr = vec![E::BaseField::ZERO; E::DEGREE]; - - self.stream.read_exact(&mut repr)?; - - let fe = E::from_limbs(&repr); - self.common_field_element_ext(&fe)?; - Ok(fe) - } - - fn read_field_element_base(&mut self) -> Result { - let mut repr = vec![E::BaseField::ZERO; 1]; - self.stream.read_exact(&mut repr)?; - self.common_field_element_base(&repr[0])?; - Ok(repr[0]) - } -} - -impl FieldTranscriptWrite for PoseidonTranscript { - fn write_field_element_ext(&mut self, fe: &E) -> Result<(), Error> { - self.common_field_element_ext(fe)?; - self.stream.write_all(fe.as_bases()) - } - - fn write_field_element_base(&mut self, fe: &E::BaseField) -> Result<(), Error> { - self.common_field_element_base(fe)?; - self.stream.write_all(&[*fe]) - } -} - -impl Transcript, E> for PoseidonTranscript { - fn common_commitment(&mut self, comm: &Digest) -> Result<(), Error> { - self.state.update(&comm.0); - Ok(()) - } - - fn common_commitments(&mut self, comms: &[Digest]) -> Result<(), Error> { - comms - .iter() - .map(|comm| self.common_commitment(comm)) - .try_collect() - } -} - -impl TranscriptRead, E> for PoseidonTranscript { - fn read_commitment(&mut self) -> Result, Error> { - let mut repr = vec![E::BaseField::ZERO; DIGEST_WIDTH]; - self.stream.read_exact(&mut repr)?; - let comm = Digest(repr.as_slice().try_into().unwrap()); - self.common_commitment(&comm)?; - Ok(comm) - } -} - -impl TranscriptWrite, E> for PoseidonTranscript { - fn write_commitment(&mut self, comm: &Digest) -> Result<(), Error> { - self.common_commitment(comm)?; - self.stream.write_all(&comm.0) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use goldilocks::{Goldilocks as F, GoldilocksExt2 as EF}; - - #[test] - fn test_transcript() { - let mut transcript = PoseidonTranscript::::new(); - transcript.write_field_element_base(&F::from(1)).unwrap(); - let a = transcript.squeeze_challenge(); - transcript.write_field_element_base(&F::from(2)).unwrap(); - transcript - .write_commitment(&Digest([F::from(3); DIGEST_WIDTH])) - .unwrap(); - let b = transcript.squeeze_challenge(); - let proof = transcript.into_proof(); - let mut transcript = PoseidonTranscript::::from_proof(&proof); - assert_eq!(transcript.read_field_element_base().unwrap(), F::from(1)); - assert_eq!(transcript.squeeze_challenge(), a); - assert_eq!(transcript.read_field_element_base().unwrap(), F::from(2)); - assert_eq!( - transcript.read_commitment().unwrap(), - Digest([F::from(3); DIGEST_WIDTH]) - ); - assert_eq!(transcript.squeeze_challenge(), b); - - let mut transcript = PoseidonTranscript::::new(); - transcript.write_field_element_ext(&EF::from(1)).unwrap(); - let a = transcript.squeeze_challenge(); - transcript.write_field_element_ext(&EF::from(2)).unwrap(); - transcript - .write_commitment(&Digest([F::from(3); DIGEST_WIDTH])) - .unwrap(); - let b = transcript.squeeze_challenge(); - let proof = transcript.into_proof(); - let mut transcript = PoseidonTranscript::::from_proof(&proof); - assert_eq!(transcript.read_field_element_ext().unwrap(), EF::from(1)); - assert_eq!(transcript.squeeze_challenge(), a); - assert_eq!(transcript.read_field_element_ext().unwrap(), EF::from(2)); - assert_eq!( - transcript.read_commitment().unwrap(), - Digest([F::from(3); DIGEST_WIDTH]) - ); - assert_eq!(transcript.squeeze_challenge(), b); - } -} diff --git a/multilinear_extensions/src/virtual_poly.rs b/multilinear_extensions/src/virtual_poly.rs index 7d889111e..72c5b2649 100644 --- a/multilinear_extensions/src/virtual_poly.rs +++ b/multilinear_extensions/src/virtual_poly.rs @@ -66,12 +66,6 @@ pub struct VPAuxInfo { pub phantom: PhantomData, } -impl AsRef<[u8]> for VPAuxInfo { - fn as_ref(&self) -> &[u8] { - todo!() - } -} - impl VirtualPolynomial { /// Creates an empty virtual polynomial with `num_variables`. pub fn new(num_variables: usize) -> Self { diff --git a/multilinear_extensions/src/virtual_poly_v2.rs b/multilinear_extensions/src/virtual_poly_v2.rs index dcf588baf..5d64d88bc 100644 --- a/multilinear_extensions/src/virtual_poly_v2.rs +++ b/multilinear_extensions/src/virtual_poly_v2.rs @@ -63,12 +63,6 @@ pub struct VPAuxInfo { pub phantom: PhantomData, } -impl AsRef<[u8]> for VPAuxInfo { - fn as_ref(&self) -> &[u8] { - todo!() - } -} - impl<'a, E: ExtensionField> VirtualPolynomialV2<'a, E> { /// Creates an empty virtual polynomial with `max_num_variables`. pub fn new(max_num_variables: usize) -> Self { diff --git a/sumcheck/benches/devirgo_sumcheck.rs b/sumcheck/benches/devirgo_sumcheck.rs index fd33b9d09..3a101fe53 100644 --- a/sumcheck/benches/devirgo_sumcheck.rs +++ b/sumcheck/benches/devirgo_sumcheck.rs @@ -103,7 +103,7 @@ fn prepare_input<'a, E: ExtensionField>( fn sumcheck_fn(c: &mut Criterion) { type E = GoldilocksExt2; - for nv in NV.into_iter() { + for nv in NV { // expand more input size once runtime is acceptable let mut group = c.benchmark_group(format!("sumcheck_nv_{}", nv)); group.sample_size(NUM_SAMPLES); @@ -148,7 +148,7 @@ fn devirgo_sumcheck_fn(c: &mut Criterion) { type E = GoldilocksExt2; let threads = max_usable_threads(); - for nv in NV.into_iter() { + for nv in NV { // expand more input size once runtime is acceptable let mut group = c.benchmark_group(format!("devirgo_nv_{}", nv)); group.sample_size(NUM_SAMPLES); diff --git a/sumcheck/src/macros.rs b/sumcheck/src/macros.rs index 470afdf18..a8c63206a 100644 --- a/sumcheck/src/macros.rs +++ b/sumcheck/src/macros.rs @@ -1,17 +1,21 @@ #[macro_export] macro_rules! entered_span { + ($first:expr, $($fields:tt)*) => { + $crate::tracing_span!($first, $($fields)*).entered() + }; ($first:expr $(,)*) => { $crate::tracing_span!($first).entered() }; } - #[macro_export] macro_rules! tracing_span { + ($first:expr, $($fields:tt)*) => { + tracing::span!(tracing::Level::INFO, $first, $($fields)*) + }; ($first:expr $(,)*) => { - tracing::span!(tracing::Level::DEBUG, $first) + tracing::span!(tracing::Level::INFO, $first) }; } - #[macro_export] macro_rules! exit_span { ($first:expr $(,)*) => { diff --git a/sumcheck/src/prover_v2.rs b/sumcheck/src/prover_v2.rs index f2ede8680..72a237e9c 100644 --- a/sumcheck/src/prover_v2.rs +++ b/sumcheck/src/prover_v2.rs @@ -90,43 +90,48 @@ impl<'a, E: ExtensionField> IOPProverStateV2<'a, E> { ); let tx_prover_state = tx_prover_state.clone(); let mut thread_based_transcript = thread_based_transcript.clone(); - + let current_span = tracing::Span::current(); + // NOTE: Apply the span.in_scope(||) pattern to record work of spawned thread inside + // span of parent thread. s.spawn(move |_| { - let mut challenge = None; - let span = entered_span!("prove_rounds"); - for _ in 0..num_variables { - let prover_msg = IOPProverStateV2::prove_round_and_update_state( - &mut prover_state, - &challenge, - ); - thread_based_transcript.append_field_element_exts(&prover_msg.evaluations); + current_span.in_scope(|| { + let mut challenge = None; + let span = entered_span!("prove_rounds"); + for _ in 0..num_variables { + let prover_msg = IOPProverStateV2::prove_round_and_update_state( + &mut prover_state, + &challenge, + ); + thread_based_transcript + .append_field_element_exts(&prover_msg.evaluations); - challenge = Some( - thread_based_transcript.get_and_append_challenge(b"Internal round"), - ); - thread_based_transcript.commit_rolling(); - } - exit_span!(span); - // pushing the last challenge point to the state - if let Some(p) = challenge { - prover_state.challenges.push(p); - // fix last challenge to collect final evaluation - prover_state - .poly - .flattened_ml_extensions - .iter_mut() - .for_each(|mle| { - let mle = Arc::get_mut(mle).unwrap(); - if mle.num_vars() > 0 { - mle.fix_variables_in_place(&[p.elements]); - } - }); - tx_prover_state - .send(Some((thread_id, prover_state))) - .unwrap(); - } else { - tx_prover_state.send(None).unwrap(); - } + challenge = Some( + thread_based_transcript.get_and_append_challenge(b"Internal round"), + ); + thread_based_transcript.commit_rolling(); + } + exit_span!(span); + // pushing the last challenge point to the state + if let Some(p) = challenge { + prover_state.challenges.push(p); + // fix last challenge to collect final evaluation + prover_state + .poly + .flattened_ml_extensions + .iter_mut() + .for_each(|mle| { + let mle = Arc::get_mut(mle).unwrap(); + if mle.num_vars() > 0 { + mle.fix_variables_in_place(&[p.elements]); + } + }); + tx_prover_state + .send(Some((thread_id, prover_state))) + .unwrap(); + } else { + tx_prover_state.send(None).unwrap(); + } + }) }); } @@ -139,7 +144,7 @@ impl<'a, E: ExtensionField> IOPProverStateV2<'a, E> { let tx_prover_state = tx_prover_state.clone(); let mut thread_based_transcript = thread_based_transcript.clone(); - let span = entered_span!("main_thread_prove_rounds"); + let main_thread_span = entered_span!("main_thread_prove_rounds"); // main thread also be one worker thread // NOTE inline main thread flow with worker thread to improve efficiency // refactor to shared closure cause to 5% throuput drop @@ -158,7 +163,7 @@ impl<'a, E: ExtensionField> IOPProverStateV2<'a, E> { evaluations += AdditiveVec(round_poly_coeffs); } - let span = entered_span!("main_thread_get_challenge"); + let get_challenge_span = entered_span!("main_thread_get_challenge"); transcript.append_field_element_exts(&evaluations.0); let next_challenge = transcript.get_and_append_challenge(b"Internal round"); @@ -166,7 +171,7 @@ impl<'a, E: ExtensionField> IOPProverStateV2<'a, E> { thread_based_transcript.send_challenge(next_challenge.elements); }); - exit_span!(span); + exit_span!(get_challenge_span); prover_msgs.push(IOPProverMessage { evaluations: evaluations.0, @@ -176,7 +181,7 @@ impl<'a, E: ExtensionField> IOPProverStateV2<'a, E> { Some(thread_based_transcript.get_and_append_challenge(b"Internal round")); thread_based_transcript.commit_rolling(); } - exit_span!(span); + exit_span!(main_thread_span); // pushing the last challenge point to the state if let Some(p) = challenge { prover_state.challenges.push(p);