Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix add bench build error & example #211

Merged
merged 2 commits into from
Sep 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 19 additions & 6 deletions ceno_zkvm/benches/riscv_add.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@ use std::time::{Duration, Instant};

use ark_std::test_rng;
use ceno_zkvm::{
circuit_builder::{CircuitBuilder, ConstraintSystem},
self,
instructions::{riscv::addsub::AddInstruction, Instruction},
scheme::prover::ZKVMProver,
structs::{ZKVMConstraintSystem, ZKVMFixedTraces},
};
use const_env::from_env;
use criterion::*;
Expand Down Expand Up @@ -62,11 +63,22 @@ fn bench_add(c: &mut Criterion) {
RAYON_NUM_THREADS
}
};
let mut cs = ConstraintSystem::new(|| "risv_add");
let mut circuit_builder = CircuitBuilder::<GoldilocksExt2>::new(&mut cs);
let _ = AddInstruction::construct_circuit(&mut circuit_builder);
let pk = cs.key_gen(None);
let num_witin = pk.get_cs().num_witin;
let mut zkvm_cs = ZKVMConstraintSystem::default();
let _ = zkvm_cs.register_opcode_circuit::<AddInstruction<E>>();
let mut zkvm_fixed_traces = ZKVMFixedTraces::default();
zkvm_fixed_traces.register_opcode_circuit::<AddInstruction<E>>(&zkvm_cs);

let pk = zkvm_cs
.clone()
.key_gen(zkvm_fixed_traces)
.expect("keygen failed");

let circuit_pk = pk
.circuit_pks
.get(&AddInstruction::<E>::name())
.unwrap()
.clone();
let num_witin = circuit_pk.get_cs().num_witin;

let prover = ZKVMProver::new(pk);
let mut transcript = Transcript::new(b"riscv");
Expand Down Expand Up @@ -101,6 +113,7 @@ fn bench_add(c: &mut Criterion) {
let timer = Instant::now();
let _ = prover
.create_opcode_proof(
&circuit_pk,
wits_in,
num_instances,
max_threads,
Expand Down
17 changes: 9 additions & 8 deletions ceno_zkvm/examples/riscv_add.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
// - x2 is initialized to -1,
// - x3 is initialized to loop bound.
// we use x4 to hold the acc_sum.
#[allow(clippy::unusual_byte_groupings)]
const PROGRAM_ADD_LOOP: [u32; 4] = [
// func7 rs2 rs1 f3 rd opcode
0b_0000000_00100_00001_000_00100_0110011, // add x4, x4, x1 <=> addi x4, x4, 1
Expand Down Expand Up @@ -54,14 +55,14 @@

let max_threads = {
if !is_power_of_2(RAYON_NUM_THREADS) {
#[cfg(not(feature = "non_pow2_rayon_thread"))]

Check warning on line 58 in ceno_zkvm/examples/riscv_add.rs

View workflow job for this annotation

GitHub Actions / Various lints (x86_64-unknown-linux-gnu)

unexpected `cfg` condition value: `non_pow2_rayon_thread`
{
panic!(
"add --features non_pow2_rayon_thread to enable unsafe feature which support non pow of 2 rayon thread pool"
);
}

#[cfg(feature = "non_pow2_rayon_thread")]

Check warning on line 65 in ceno_zkvm/examples/riscv_add.rs

View workflow job for this annotation

GitHub Actions / Various lints (x86_64-unknown-linux-gnu)

unexpected `cfg` condition value: `non_pow2_rayon_thread`
{
use sumcheck::{local_thread_pool::create_local_pool_once, util::ceil_log2};
let max_thread_id = 1 << ceil_log2(RAYON_NUM_THREADS);
Expand Down Expand Up @@ -106,15 +107,15 @@
let verifier = ZKVMVerifier::new(vk);

for instance_num_vars in args.start..args.end {
let num_instances = 1 << instance_num_vars;
let step_loop = 1 << (instance_num_vars - 1); // 1 step in loop contribute to 2 add instance
let mut vm = VMState::new(CENO_PLATFORM);
let pc_start = ByteAddr(CENO_PLATFORM.pc_start()).waddr();

// init vm.x1 = 1, vm.x2 = -1, vm.x3 = num_instances
// vm.x4 += vm.x1
vm.init_register_unsafe(1usize, 1);
vm.init_register_unsafe(2usize, u32::MAX); // -1 in two's complement
vm.init_register_unsafe(3usize, num_instances as u32);
vm.init_register_unsafe(3usize, step_loop as u32);
for (i, inst) in PROGRAM_ADD_LOOP.iter().enumerate() {
vm.init_memory(pc_start + i, *inst);
}
Expand Down Expand Up @@ -148,17 +149,17 @@
.create_proof(zkvm_witness, max_threads, &mut transcript, &real_challenges)
.expect("create_proof failed");

println!(
"AddInstruction::create_proof, instance_num_vars = {}, time = {}",
instance_num_vars,
timer.elapsed().as_secs_f64()
);

let mut transcript = Transcript::new(b"riscv");
assert!(
verifier
.verify_proof(zkvm_proof, &mut transcript, &real_challenges)
.expect("verify proof return with error"),
);

println!(
"AddInstruction::create_proof, instance_num_vars = {}, time = {}",
instance_num_vars,
timer.elapsed().as_secs_f64()
);
}
}
2 changes: 1 addition & 1 deletion ceno_zkvm/src/structs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ impl<E: ExtensionField> ZKVMWitnesses<E> {
#[derive(Default)]
pub struct ZKVMProvingKey<E: ExtensionField> {
// pk for opcode and table circuits
pub(crate) circuit_pks: BTreeMap<String, ProvingKey<E>>,
pub circuit_pks: BTreeMap<String, ProvingKey<E>>,
}

impl<E: ExtensionField> ZKVMProvingKey<E> {
Expand Down
Loading