Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix/transcript fork #224

Merged
merged 2 commits into from
Sep 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions ceno_zkvm/examples/riscv_add.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,14 +55,14 @@

let max_threads = {
if !is_power_of_2(RAYON_NUM_THREADS) {
#[cfg(not(feature = "non_pow2_rayon_thread"))]

Check warning on line 58 in ceno_zkvm/examples/riscv_add.rs

View workflow job for this annotation

GitHub Actions / Various lints (x86_64-unknown-linux-gnu)

unexpected `cfg` condition value: `non_pow2_rayon_thread`
{
panic!(
"add --features non_pow2_rayon_thread to enable unsafe feature which support non pow of 2 rayon thread pool"
);
}

#[cfg(feature = "non_pow2_rayon_thread")]

Check warning on line 65 in ceno_zkvm/examples/riscv_add.rs

View workflow job for this annotation

GitHub Actions / Various lints (x86_64-unknown-linux-gnu)

unexpected `cfg` condition value: `non_pow2_rayon_thread`
{
use sumcheck::{local_thread_pool::create_local_pool_once, util::ceil_log2};
let max_thread_id = 1 << ceil_log2(RAYON_NUM_THREADS);
Expand Down Expand Up @@ -141,12 +141,12 @@

let timer = Instant::now();

let mut transcript = Transcript::new(b"riscv");
let transcript = Transcript::new(b"riscv");
let mut rng = test_rng();
let real_challenges = [E::random(&mut rng), E::random(&mut rng)];

let zkvm_proof = prover
.create_proof(zkvm_witness, max_threads, &mut transcript, &real_challenges)
.create_proof(zkvm_witness, max_threads, transcript, &real_challenges)
.expect("create_proof failed");

println!(
Expand All @@ -155,10 +155,10 @@
timer.elapsed().as_secs_f64()
);

let mut transcript = Transcript::new(b"riscv");
let transcript = Transcript::new(b"riscv");
assert!(
verifier
.verify_proof(zkvm_proof, &mut transcript, &real_challenges)
.verify_proof(zkvm_proof, transcript, &real_challenges)
.expect("verify proof return with error"),
);
}
Expand Down
13 changes: 11 additions & 2 deletions ceno_zkvm/src/scheme.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,17 @@ pub struct ZKVMTableProof<E: ExtensionField> {
pub wits_in_evals: Vec<E>,
}

/// Map circuit names to
/// - an opcode or table proof,
/// - an index unique across both types.
#[derive(Default, Clone)]
pub struct ZKVMProof<E: ExtensionField> {
opcode_proofs: HashMap<String, ZKVMOpcodeProof<E>>,
table_proofs: HashMap<String, ZKVMTableProof<E>>,
opcode_proofs: HashMap<String, (usize, ZKVMOpcodeProof<E>)>,
table_proofs: HashMap<String, (usize, ZKVMTableProof<E>)>,
}

impl<E: ExtensionField> ZKVMProof<E> {
pub fn num_circuits(&self) -> usize {
self.opcode_proofs.len() + self.table_proofs.len()
}
}
15 changes: 11 additions & 4 deletions ceno_zkvm/src/scheme/prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,18 @@ impl<E: ExtensionField> ZKVMProver<E> {
&self,
mut witnesses: ZKVMWitnesses<E>,
max_threads: usize,
transcript: &mut Transcript<E>,
transcript: Transcript<E>,
challenges: &[E; 2],
) -> Result<ZKVMProof<E>, ZKVMError> {
let mut vm_proof = ZKVMProof::default();
for (circuit_name, pk) in self.pk.circuit_pks.iter() {
let mut transcripts = transcript.fork(self.pk.circuit_pks.len());

for ((circuit_name, pk), (i, transcript)) in self
.pk
.circuit_pks
.iter() // Sorted by key.
.zip_eq(transcripts.iter_mut().enumerate())
{
let witness = witnesses
.witnesses
.remove(circuit_name)
Expand Down Expand Up @@ -94,7 +101,7 @@ impl<E: ExtensionField> ZKVMProver<E> {
);
vm_proof
.opcode_proofs
.insert(circuit_name.clone(), opcode_proof);
.insert(circuit_name.clone(), (i, opcode_proof));
} else {
let table_proof = self.create_table_proof(
pk,
Expand All @@ -116,7 +123,7 @@ impl<E: ExtensionField> ZKVMProver<E> {
);
vm_proof
.table_proofs
.insert(circuit_name.clone(), table_proof);
.insert(circuit_name.clone(), (i, table_proof));
}
}

Expand Down
12 changes: 9 additions & 3 deletions ceno_zkvm/src/scheme/verifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ impl<E: ExtensionField> ZKVMVerifier<E> {
pub fn verify_proof(
&self,
vm_proof: ZKVMProof<E>,
transcript: &mut Transcript<E>,
transcript: Transcript<E>,
challenges: &[E; 2],
) -> Result<bool, ZKVMError> {
let mut prod_r = E::ONE;
Expand All @@ -48,7 +48,11 @@ impl<E: ExtensionField> ZKVMVerifier<E> {
let dummy_table_item = challenges[0];
let point_eval = PointAndEval::default();
let mut dummy_table_item_multiplicity = 0;
for (name, opcode_proof) in vm_proof.opcode_proofs {
let mut transcripts = transcript.fork(vm_proof.num_circuits());

for (name, (i, opcode_proof)) in vm_proof.opcode_proofs {
let transcript = &mut transcripts[i];

let circuit_vk = self
.vk
.circuit_vks
Expand Down Expand Up @@ -82,7 +86,9 @@ impl<E: ExtensionField> ZKVMVerifier<E> {
opcode_proof.lk_p2_out_eval * opcode_proof.lk_q2_out_eval.invert().unwrap();
}

for (name, table_proof) in vm_proof.table_proofs {
for (name, (i, table_proof)) in vm_proof.table_proofs {
let transcript = &mut transcripts[i];

let circuit_vk = self
.vk
.circuit_vks
Expand Down
11 changes: 11 additions & 0 deletions transcript/src/basic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,17 @@ impl<E: ExtensionField> Transcript<E> {
}

impl<E: ExtensionField> Transcript<E> {
/// Fork this transcript into n different threads.
pub fn fork(self, n: usize) -> Vec<Self> {
let mut forks = Vec::with_capacity(n);
for i in 0..n {
let mut fork = self.clone();
fork.append_field_element(&(i as u64).into());
forks.push(fork);
}
forks
}

// Append the message to the transcript.
pub fn append_message(&mut self, msg: &[u8]) {
let msg_f = E::BaseField::bytes_to_field_elements(msg);
Expand Down
Loading