From 9dd45dbf88fb11406b9c459b33cd9f441074cc8e Mon Sep 17 00:00:00 2001 From: kunxian xia Date: Wed, 30 Oct 2024 01:14:54 +0800 Subject: [PATCH] include all opcodes in rv32im config --- ceno_zkvm/examples/fibonacci_elf.rs | 4 +- ceno_zkvm/src/instructions/riscv/branch.rs | 8 +- .../src/instructions/riscv/branch/blt.rs | 4 +- .../src/instructions/riscv/branch/bltu.rs | 4 +- ceno_zkvm/src/instructions/riscv/jump/jalr.rs | 4 + .../src/instructions/riscv/memory/store.rs | 3 - ceno_zkvm/src/instructions/riscv/rv32im.rs | 251 ++++++++++++++---- ceno_zkvm/src/instructions/riscv/shift.rs | 10 +- ceno_zkvm/src/instructions/riscv/shift_imm.rs | 3 + ceno_zkvm/src/scheme/prover.rs | 31 ++- ceno_zkvm/src/structs.rs | 14 +- 11 files changed, 245 insertions(+), 91 deletions(-) diff --git a/ceno_zkvm/examples/fibonacci_elf.rs b/ceno_zkvm/examples/fibonacci_elf.rs index 041032f89..9f483a5e1 100644 --- a/ceno_zkvm/examples/fibonacci_elf.rs +++ b/ceno_zkvm/examples/fibonacci_elf.rs @@ -11,6 +11,7 @@ use ceno_zkvm::{ structs::{ZKVMConstraintSystem, ZKVMFixedTraces, ZKVMWitnesses}, tables::{MemFinalRecord, ProgramTableCircuit, initial_memory, initial_registers}, }; +use ff_ext::ff::Field; use goldilocks::GoldilocksExt2; use itertools::Itertools; use mpcs::{Basefold, BasefoldRSParams, PolynomialCommitmentScheme}; @@ -19,7 +20,6 @@ use std::{panic, time::Instant}; use tracing_flame::FlameLayer; use tracing_subscriber::{EnvFilter, Registry, fmt, layer::SubscriberExt}; use transcript::Transcript; -use ff_ext::ff::Field; fn main() { type E = GoldilocksExt2; @@ -153,7 +153,7 @@ fn main() { .assign_table_circuit::>( &zkvm_cs, &prog_config, - &vm.program(), + vm.program(), ) .unwrap(); diff --git a/ceno_zkvm/src/instructions/riscv/branch.rs b/ceno_zkvm/src/instructions/riscv/branch.rs index f9cb2c311..6fe929980 100644 --- a/ceno_zkvm/src/instructions/riscv/branch.rs +++ b/ceno_zkvm/src/instructions/riscv/branch.rs @@ -25,22 +25,22 @@ pub struct BltuOp; impl RIVInstruction for BltuOp { const INST_KIND: InsnKind = InsnKind::BLTU; } -pub type BltuInstruction = bltu::BltuCircuit; +pub type BltuInstruction = bltu::BltuCircuit; pub struct BgeuOp; impl RIVInstruction for BgeuOp { const INST_KIND: InsnKind = InsnKind::BGEU; } -pub type BgeuInstruction = bltu::BltuCircuit; +pub type BgeuInstruction = bltu::BltuCircuit; pub struct BltOp; impl RIVInstruction for BltOp { const INST_KIND: InsnKind = InsnKind::BLT; } -pub type BltInstruction = blt::BltCircuit; +pub type BltInstruction = blt::BltCircuit; pub struct BgeOp; impl RIVInstruction for BgeOp { const INST_KIND: InsnKind = InsnKind::BGE; } -pub type BgeInstruction = blt::BltCircuit; +pub type BgeInstruction = blt::BltCircuit; diff --git a/ceno_zkvm/src/instructions/riscv/branch/blt.rs b/ceno_zkvm/src/instructions/riscv/branch/blt.rs index 43f07bbfb..47caafc9e 100644 --- a/ceno_zkvm/src/instructions/riscv/branch/blt.rs +++ b/ceno_zkvm/src/instructions/riscv/branch/blt.rs @@ -16,7 +16,7 @@ use crate::{ }; use ceno_emul::{InsnKind, SWord}; -pub struct BltCircuit(PhantomData); +pub struct BltCircuit(PhantomData<(E, I)>); pub struct InstructionConfig { pub b_insn: BInstructionConfig, @@ -25,7 +25,7 @@ pub struct InstructionConfig { pub signed_lt: SignedLtConfig, } -impl Instruction for BltCircuit { +impl Instruction for BltCircuit { fn name() -> String { format!("{:?}", I::INST_KIND) } diff --git a/ceno_zkvm/src/instructions/riscv/branch/bltu.rs b/ceno_zkvm/src/instructions/riscv/branch/bltu.rs index 3043e7563..896bf19da 100644 --- a/ceno_zkvm/src/instructions/riscv/branch/bltu.rs +++ b/ceno_zkvm/src/instructions/riscv/branch/bltu.rs @@ -20,7 +20,7 @@ use crate::{ }; use ceno_emul::InsnKind; -pub struct BltuCircuit(PhantomData); +pub struct BltuCircuit(PhantomData<(E, I)>); pub struct InstructionConfig { pub b_insn: BInstructionConfig, @@ -29,7 +29,7 @@ pub struct InstructionConfig { pub is_lt: IsLtConfig, } -impl Instruction for BltuCircuit { +impl Instruction for BltuCircuit { fn name() -> String { format!("{:?}", I::INST_KIND) } diff --git a/ceno_zkvm/src/instructions/riscv/jump/jalr.rs b/ceno_zkvm/src/instructions/riscv/jump/jalr.rs index a80e11c4a..9795f324b 100644 --- a/ceno_zkvm/src/instructions/riscv/jump/jalr.rs +++ b/ceno_zkvm/src/instructions/riscv/jump/jalr.rs @@ -109,6 +109,10 @@ impl Instruction for JalrInstruction { let rs1 = step.rs1().unwrap().value; let imm: i32 = insn.imm_or_funct7() as i32; + if step.rd().is_none() { + tracing::info!("step: {:?}", step.insn().kind()); + tracing::info!("step: {:?}", step); + } let rd = step.rd().unwrap().value.after; let (sum, overflowing) = rs1.overflowing_add_signed(imm); diff --git a/ceno_zkvm/src/instructions/riscv/memory/store.rs b/ceno_zkvm/src/instructions/riscv/memory/store.rs index 878777de6..8d1fcfc13 100644 --- a/ceno_zkvm/src/instructions/riscv/memory/store.rs +++ b/ceno_zkvm/src/instructions/riscv/memory/store.rs @@ -38,7 +38,6 @@ impl RIVInstruction for SWOp { const INST_KIND: InsnKind = InsnKind::SW; } -#[cfg(test)] pub type SwInstruction = StoreInstruction; pub struct SHOp; @@ -47,7 +46,6 @@ impl RIVInstruction for SHOp { const INST_KIND: InsnKind = InsnKind::SH; } -#[cfg(test)] pub type ShInstruction = StoreInstruction; pub struct SBOp; @@ -56,7 +54,6 @@ impl RIVInstruction for SBOp { const INST_KIND: InsnKind = InsnKind::SB; } -#[cfg(test)] pub type SbInstruction = StoreInstruction; impl Instruction diff --git a/ceno_zkvm/src/instructions/riscv/rv32im.rs b/ceno_zkvm/src/instructions/riscv/rv32im.rs index 14a65ec35..ffdfddc61 100644 --- a/ceno_zkvm/src/instructions/riscv/rv32im.rs +++ b/ceno_zkvm/src/instructions/riscv/rv32im.rs @@ -1,6 +1,23 @@ use crate::{ error::ZKVMError, - instructions::{Instruction, riscv::*}, + instructions::{ + Instruction, + riscv::{ + arith_imm::AddiInstruction, + branch::{ + BeqInstruction, BgeInstruction, BgeuInstruction, BltInstruction, BneInstruction, + }, + divu::DivUInstruction, + logic::{AndInstruction, OrInstruction, XorInstruction}, + logic_imm::{AndiInstruction, OriInstruction, XoriInstruction}, + mulh::MulhuInstruction, + shift::{SllInstruction, SrlInstruction}, + shift_imm::{SlliInstruction, SraiInstruction, SrliInstruction}, + slti::SltiInstruction, + sltu::SltuInstruction, + *, + }, + }, structs::{ZKVMConstraintSystem, ZKVMFixedTraces, ZKVMWitnesses}, tables::{ AndTableCircuit, LtuTableCircuit, MemFinalRecord, MemInitRecord, MemTableCircuit, @@ -11,6 +28,7 @@ use ceno_emul::{CENO_PLATFORM, InsnKind, InsnKind::*, StepRecord}; use ff_ext::ExtensionField; use itertools::Itertools; use num_traits::cast::ToPrimitive; +use std::collections::BTreeMap; use strum::IntoEnumIterator; use super::{ @@ -25,17 +43,40 @@ pub struct Rv32imConfig { // ALU Opcodes. pub add_config: as Instruction>::InstructionConfig, pub sub_config: as Instruction>::InstructionConfig, + pub and_config: as Instruction>::InstructionConfig, + pub or_config: as Instruction>::InstructionConfig, + pub xor_config: as Instruction>::InstructionConfig, + pub sll_config: as Instruction>::InstructionConfig, + pub srl_config: as Instruction>::InstructionConfig, + // TODO: sra / slt + pub sltu_config: as Instruction>::InstructionConfig, + pub mul_config: as Instruction>::InstructionConfig, + pub mulhu_config: as Instruction>::InstructionConfig, + pub divu_config: as Instruction>::InstructionConfig, - // Branching Opcodes - pub bltu_config: >::InstructionConfig, + // ALU with imm + pub addi_config: as Instruction>::InstructionConfig, + pub andi_config: as Instruction>::InstructionConfig, + pub ori_config: as Instruction>::InstructionConfig, + pub xori_config: as Instruction>::InstructionConfig, + pub slli_config: as Instruction>::InstructionConfig, + pub srli_config: as Instruction>::InstructionConfig, + pub srai_config: as Instruction>::InstructionConfig, + pub slti_config: as Instruction>::InstructionConfig, - // Imm - pub lui_config: as Instruction>::InstructionConfig, + // Branching Opcodes + pub beq_config: as Instruction>::InstructionConfig, + pub bne_config: as Instruction>::InstructionConfig, + pub blt_config: as Instruction>::InstructionConfig, + pub bltu_config: as Instruction>::InstructionConfig, + pub bge_config: as Instruction>::InstructionConfig, + pub bgeu_config: as Instruction>::InstructionConfig, // Jump Opcodes pub jal_config: as Instruction>::InstructionConfig, pub jalr_config: as Instruction>::InstructionConfig, pub auipc_config: as Instruction>::InstructionConfig, + pub lui_config: as Instruction>::InstructionConfig, // Memory Opcodes pub lw_config: as Instruction>::InstructionConfig, @@ -52,7 +93,7 @@ pub struct Rv32imConfig { // Tables. pub u16_range_config: as TableCircuit>::TableConfig, pub u14_range_config: as TableCircuit>::TableConfig, - pub and_config: as TableCircuit>::TableConfig, + pub and_table_config: as TableCircuit>::TableConfig, pub ltu_config: as TableCircuit>::TableConfig, // RW tables. @@ -66,9 +107,33 @@ impl Rv32imConfig { // alu opcodes let add_config = cs.register_opcode_circuit::>(); let sub_config = cs.register_opcode_circuit::>(); + let and_config = cs.register_opcode_circuit::>(); + let or_config = cs.register_opcode_circuit::>(); + let xor_config = cs.register_opcode_circuit::>(); + let sll_config = cs.register_opcode_circuit::>(); + let srl_config = cs.register_opcode_circuit::>(); + let sltu_config = cs.register_opcode_circuit::>(); + let mul_config = cs.register_opcode_circuit::>(); + let mulhu_config = cs.register_opcode_circuit::>(); + let divu_config = cs.register_opcode_circuit::>(); + + // alu with imm opcodes + let addi_config = cs.register_opcode_circuit::>(); + let andi_config = cs.register_opcode_circuit::>(); + let ori_config = cs.register_opcode_circuit::>(); + let xori_config = cs.register_opcode_circuit::>(); + let slli_config = cs.register_opcode_circuit::>(); + let srli_config = cs.register_opcode_circuit::>(); + let srai_config = cs.register_opcode_circuit::>(); + let slti_config = cs.register_opcode_circuit::>(); // branching opcodes - let bltu_config = cs.register_opcode_circuit::(); + let beq_config = cs.register_opcode_circuit::>(); + let bne_config = cs.register_opcode_circuit::>(); + let blt_config = cs.register_opcode_circuit::>(); + let bltu_config = cs.register_opcode_circuit::>(); + let bge_config = cs.register_opcode_circuit::>(); + let bgeu_config = cs.register_opcode_circuit::>(); // jump opcodes let lui_config = cs.register_opcode_circuit::>(); @@ -91,7 +156,7 @@ impl Rv32imConfig { // tables let u16_range_config = cs.register_table_circuit::>(); let u14_range_config = cs.register_table_circuit::>(); - let and_config = cs.register_table_circuit::>(); + let and_table_config = cs.register_table_circuit::>(); let ltu_config = cs.register_table_circuit::>(); // RW tables @@ -102,8 +167,31 @@ impl Rv32imConfig { // alu opcodes add_config, sub_config, + and_config, + or_config, + xor_config, + sll_config, + srl_config, + sltu_config, + mul_config, + mulhu_config, + divu_config, + // alu with imm + addi_config, + andi_config, + ori_config, + xori_config, + slli_config, + srli_config, + srai_config, + slti_config, // branching opcodes + beq_config, + bne_config, + blt_config, bltu_config, + bge_config, + bgeu_config, // jump opcodes lui_config, jal_config, @@ -123,7 +211,7 @@ impl Rv32imConfig { // tables u16_range_config, u14_range_config, - and_config, + and_table_config, ltu_config, reg_config, @@ -138,16 +226,41 @@ impl Rv32imConfig { reg_init: &[MemInitRecord], mem_init: &[MemInitRecord], ) { + // alu fixed.register_opcode_circuit::>(cs); fixed.register_opcode_circuit::>(cs); - - fixed.register_opcode_circuit::(cs); - + fixed.register_opcode_circuit::>(cs); + fixed.register_opcode_circuit::>(cs); + fixed.register_opcode_circuit::>(cs); + fixed.register_opcode_circuit::>(cs); + fixed.register_opcode_circuit::>(cs); + fixed.register_opcode_circuit::>(cs); + // TODO: add sra / slt + fixed.register_opcode_circuit::>(cs); + fixed.register_opcode_circuit::>(cs); + fixed.register_opcode_circuit::>(cs); + // alu with imm + fixed.register_opcode_circuit::>(cs); + fixed.register_opcode_circuit::>(cs); + fixed.register_opcode_circuit::>(cs); + fixed.register_opcode_circuit::>(cs); + fixed.register_opcode_circuit::>(cs); + fixed.register_opcode_circuit::>(cs); + fixed.register_opcode_circuit::>(cs); + fixed.register_opcode_circuit::>(cs); + // branching + fixed.register_opcode_circuit::>(cs); + fixed.register_opcode_circuit::>(cs); + fixed.register_opcode_circuit::>(cs); + fixed.register_opcode_circuit::>(cs); + fixed.register_opcode_circuit::>(cs); + fixed.register_opcode_circuit::>(cs); + // jump fixed.register_opcode_circuit::>(cs); fixed.register_opcode_circuit::>(cs); fixed.register_opcode_circuit::>(cs); fixed.register_opcode_circuit::>(cs); - + // memory fixed.register_opcode_circuit::>(cs); fixed.register_opcode_circuit::>(cs); fixed.register_opcode_circuit::>(cs); @@ -161,20 +274,23 @@ impl Rv32imConfig { fixed.register_table_circuit::>(cs, self.u16_range_config.clone(), &()); fixed.register_table_circuit::>(cs, self.u14_range_config.clone(), &()); - fixed.register_table_circuit::>(cs, self.and_config.clone(), &()); + fixed.register_table_circuit::>(cs, self.and_table_config.clone(), &()); fixed.register_table_circuit::>(cs, self.ltu_config.clone(), &()); fixed.register_table_circuit::>(cs, self.reg_config.clone(), reg_init); fixed.register_table_circuit::>(cs, self.mem_config.clone(), mem_init); } + pub fn assign_opcode_circuit( &self, cs: &ZKVMConstraintSystem, witness: &mut ZKVMWitnesses, steps: Vec, ) -> Result<(), ZKVMError> { - let mut all_records = vec![Vec::new(); InsnKind::iter().count()]; + let mut all_records: BTreeMap> = InsnKind::iter() + .map(|insn_kind| (insn_kind.to_usize().unwrap(), Vec::new())) + .collect(); let mut halt_records = Vec::new(); steps.into_iter().for_each(|record| { let insn_kind = record.insn().codes().kind; @@ -183,56 +299,81 @@ impl Rv32imConfig { EANY if record.rs1().unwrap().value == CENO_PLATFORM.ecall_halt() => { halt_records.push(record); } - _ => all_records[insn_kind.to_usize().unwrap()].push(record), + _ => { + let insn_kind = insn_kind.to_usize().unwrap(); + // it's safe to unwrap as all_records are initialized with Vec::new() + all_records.get_mut(&insn_kind).unwrap().push(record); + } } }); - for (insn_kind, records) in InsnKind::iter() + for (insn_kind, (_, records)) in InsnKind::iter() .zip(all_records.iter()) - .sorted_by(|a, b| Ord::cmp(&a.1.len(), &b.1.len())) + .sorted_by(|a, b| Ord::cmp(&a.1.1.len(), &b.1.1.len())) .rev() { - if records.len() != 0 { + if !records.is_empty() { tracing::info!("tracer generated {:?} {} records", insn_kind, records.len()); } } assert_eq!(halt_records.len(), 1); - witness.assign_opcode_circuit::>( - cs, - &self.add_config, - all_records[ADD.to_usize().unwrap()].as_slice(), - )?; - witness.assign_opcode_circuit::( - cs, - &self.bltu_config, - all_records[BLTU.to_usize().unwrap()].as_slice(), - )?; - witness.assign_opcode_circuit::>( - cs, - &self.jal_config, - all_records[JAL.to_usize().unwrap()].as_slice(), - )?; - witness.assign_opcode_circuit::>( - cs, - &self.lui_config, - all_records[LUI.to_usize().unwrap()].as_slice(), - )?; - witness.assign_opcode_circuit::>( - cs, - &self.lw_config, - all_records[LW.to_usize().unwrap()].as_slice(), - )?; - witness.assign_opcode_circuit::>( - cs, - &self.sw_config, - all_records[SW.to_usize().unwrap()].as_slice(), - )?; - witness.assign_opcode_circuit::>( - cs, - &self.halt_config, - &halt_records, - )?; + macro_rules! assign_opcode { + ($insn_kind:ident,$instruction:ty,$config:ident) => { + witness.assign_opcode_circuit::<$instruction>( + cs, + &self.$config, + all_records.remove(&$insn_kind.to_usize().unwrap()).unwrap(), + )?; + } + } + // alu + assign_opcode!(ADD, AddInstruction, add_config); + assign_opcode!(SUB, SubInstruction, sub_config); + assign_opcode!(AND, AndInstruction, and_config); + assign_opcode!(OR, OrInstruction, or_config); + assign_opcode!(XOR, XorInstruction, xor_config); + assign_opcode!(SLL, SllInstruction, sll_config); + assign_opcode!(SRL, SrlInstruction, srl_config); + assign_opcode!(SLTU, SltuInstruction, sltu_config); + assign_opcode!(MUL, MulInstruction, mul_config); + assign_opcode!(MULHU, MulhuInstruction, mulhu_config); + assign_opcode!(DIVU, DivUInstruction, divu_config); + // alu with imm + assign_opcode!(ADDI, AddiInstruction, addi_config); + assign_opcode!(ANDI, AndiInstruction, andi_config); + assign_opcode!(ORI, OriInstruction, ori_config); + assign_opcode!(XORI, XoriInstruction, xori_config); + assign_opcode!(SLLI, SlliInstruction, slli_config); + assign_opcode!(SRLI, SrliInstruction, srli_config); + assign_opcode!(SRAI, SraiInstruction, srai_config); + assign_opcode!(SLTI, SltiInstruction, slti_config); + // branching + assign_opcode!(BEQ, BeqInstruction, beq_config); + assign_opcode!(BNE, BneInstruction, bne_config); + assign_opcode!(BLT, BltInstruction, blt_config); + assign_opcode!(BLTU, BltuInstruction, bltu_config); + assign_opcode!(BGE, BgeInstruction, bge_config); + assign_opcode!(BGEU, BgeuInstruction, bgeu_config); + // jump + assign_opcode!(JAL, JalInstruction, jal_config); + assign_opcode!(JALR, JalrInstruction, jalr_config); + assign_opcode!(AUIPC, AuipcInstruction, auipc_config); + assign_opcode!(LUI, LuiInstruction, lui_config); + // memory + assign_opcode!(LW, LwInstruction, lw_config); + assign_opcode!(LB, LbInstruction, lb_config); + assign_opcode!(LBU, LbuInstruction, lbu_config); + assign_opcode!(LH, LhInstruction, lh_config); + assign_opcode!(LHU, LhuInstruction, lhu_config); + assign_opcode!(SW, SwInstruction, sw_config); + assign_opcode!(SH, ShInstruction, sh_config); + assign_opcode!(SB, SbInstruction, sb_config); + + // ecall / halt + witness.assign_opcode_circuit::>(cs, &self.halt_config, halt_records)?; + + assert!(all_records.is_empty()); Ok(()) } @@ -245,7 +386,7 @@ impl Rv32imConfig { ) -> Result<(), ZKVMError> { witness.assign_table_circuit::>(cs, &self.u16_range_config, &())?; witness.assign_table_circuit::>(cs, &self.u14_range_config, &())?; - witness.assign_table_circuit::>(cs, &self.and_config, &())?; + witness.assign_table_circuit::>(cs, &self.and_table_config, &())?; witness.assign_table_circuit::>(cs, &self.ltu_config, &())?; // assign register finalization. diff --git a/ceno_zkvm/src/instructions/riscv/shift.rs b/ceno_zkvm/src/instructions/riscv/shift.rs index f5d0f8e8d..a529b9e3c 100644 --- a/ceno_zkvm/src/instructions/riscv/shift.rs +++ b/ceno_zkvm/src/instructions/riscv/shift.rs @@ -31,19 +31,17 @@ pub struct ShiftConfig { pub struct ShiftLogicalInstruction(PhantomData<(E, I)>); -#[cfg(test)] -struct SllOp; -#[cfg(test)] +pub struct SllOp; impl RIVInstruction for SllOp { const INST_KIND: InsnKind = InsnKind::SLL; } +pub type SllInstruction = ShiftLogicalInstruction; -#[cfg(test)] -struct SrlOp; -#[cfg(test)] +pub struct SrlOp; impl RIVInstruction for SrlOp { const INST_KIND: InsnKind = InsnKind::SRL; } +pub type SrlInstruction = ShiftLogicalInstruction; impl Instruction for ShiftLogicalInstruction { type InstructionConfig = ShiftConfig; diff --git a/ceno_zkvm/src/instructions/riscv/shift_imm.rs b/ceno_zkvm/src/instructions/riscv/shift_imm.rs index 1f43da85c..04e141cec 100644 --- a/ceno_zkvm/src/instructions/riscv/shift_imm.rs +++ b/ceno_zkvm/src/instructions/riscv/shift_imm.rs @@ -35,16 +35,19 @@ pub struct SlliOp; impl RIVInstruction for SlliOp { const INST_KIND: ceno_emul::InsnKind = ceno_emul::InsnKind::SLLI; } +pub type SlliInstruction = ShiftImmInstruction; pub struct SraiOp; impl RIVInstruction for SraiOp { const INST_KIND: ceno_emul::InsnKind = ceno_emul::InsnKind::SRAI; } +pub type SraiInstruction = ShiftImmInstruction; pub struct SrliOp; impl RIVInstruction for SrliOp { const INST_KIND: ceno_emul::InsnKind = InsnKind::SRLI; } +pub type SrliInstruction = ShiftImmInstruction; impl Instruction for ShiftImmInstruction { type InstructionConfig = ShiftImmConfig; diff --git a/ceno_zkvm/src/scheme/prover.rs b/ceno_zkvm/src/scheme/prover.rs index c87283c01..b83ca4689 100644 --- a/ceno_zkvm/src/scheme/prover.rs +++ b/ceno_zkvm/src/scheme/prover.rs @@ -76,17 +76,23 @@ impl> ZKVMProver { for (circuit_name, witness) in witnesses.witnesses { let commit_dur = std::time::Instant::now(); let num_instances = witness.num_instances(); - let witness = witness.into_mles(); - commitments.insert( - circuit_name.clone(), - PCS::batch_commit_and_write(&self.pk.pp, &witness, &mut transcript) - .map_err(ZKVMError::PCSError)?, - ); - tracing::info!( - "commit to {} traces took {:?}", - circuit_name, - commit_dur.elapsed() - ); + let witness = match num_instances { + 0 => vec![], + _ => { + let witness = witness.into_mles(); + commitments.insert( + circuit_name.clone(), + PCS::batch_commit_and_write(&self.pk.pp, &witness, &mut transcript) + .map_err(ZKVMError::PCSError)?, + ); + tracing::info!( + "commit to {} traces took {:?}", + circuit_name, + commit_dur.elapsed() + ); + witness + } + }; wits.insert(circuit_name, (witness, num_instances)); } @@ -107,6 +113,9 @@ impl> ZKVMProver { let (witness, num_instances) = wits .remove(circuit_name) .ok_or(ZKVMError::WitnessNotFound(circuit_name.clone()))?; + if witness.is_empty() { + continue; + } let wits_commit = commitments.remove(circuit_name).unwrap(); // TODO: add an enum for circuit type either in constraint_system or vk let cs = pk.get_cs(); diff --git a/ceno_zkvm/src/structs.rs b/ceno_zkvm/src/structs.rs index 5ef37de55..e06c1d6dd 100644 --- a/ceno_zkvm/src/structs.rs +++ b/ceno_zkvm/src/structs.rs @@ -215,16 +215,18 @@ impl ZKVMWitnesses { &mut self, cs: &ZKVMConstraintSystem, config: &OC::InstructionConfig, - records: &[StepRecord], + records: Vec, ) -> Result<(), ZKVMError> { assert!(self.combined_lk_mlt.is_none()); - if records.len() == 0 { - return Ok(()); - } let cs = cs.get_cs(&OC::name()).unwrap(); - let (witness, logup_multiplicity) = - OC::assign_instances(config, cs.num_witin as usize, records.to_vec())?; + let (witness, logup_multiplicity) = match records.len() { + 0 => ( + RowMajorMatrix::new(0, cs.num_witin as usize), + LkMultiplicity::default(), + ), + _ => OC::assign_instances(config, cs.num_witin as usize, records)?, + }; assert!(self.witnesses.insert(OC::name(), witness).is_none()); assert!( self.lk_mlts