Skip to content

Commit

Permalink
batch assignment per thread (#203)
Browse files Browse the repository at this point in the history
To address comment
#198 (comment) For
further reduce the Arc clone overhead.

I think in the future we might redesign `assign_instances` to support
streaming assign in the future, such that StepRecord will emit as
streaming from emulator and multi-threads as worker to collect batch
data and do assignment.
We can explore further usage when tuning performance
  • Loading branch information
hero78119 authored Sep 10, 2024
1 parent d988e54 commit e524c47
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 10 deletions.
29 changes: 22 additions & 7 deletions ceno_zkvm/src/instructions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@ use std::mem::MaybeUninit;

use ceno_emul::StepRecord;
use ff_ext::ExtensionField;
use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator};
use rayon::{
iter::{IndexedParallelIterator, ParallelIterator},
slice::ParallelSlice,
};

use crate::{
circuit_builder::CircuitBuilder,
Expand All @@ -23,26 +26,38 @@ pub trait Instruction<E: ExtensionField> {
config: &Self::InstructionConfig,
instance: &mut [MaybeUninit<E::BaseField>],
lk_multiplicity: &mut LkMultiplicity,
step: StepRecord,
step: &StepRecord,
) -> Result<(), ZKVMError>;

fn assign_instances(
config: &Self::InstructionConfig,
num_witin: usize,
steps: Vec<StepRecord>,
) -> Result<(RowMajorMatrix<E::BaseField>, LkMultiplicity), ZKVMError> {
let nthreads =
std::env::var("RAYON_NUM_THREADS").map_or(8, |s| s.parse::<usize>().unwrap_or(8));
let num_instance_per_batch = if steps.len() > 256 {
steps.len() / nthreads
} else {
steps.len()
};
let lk_multiplicity = LkMultiplicity::default();
let mut raw_witin = RowMajorMatrix::<E::BaseField>::new(steps.len(), num_witin);
let raw_witin_iter = raw_witin.par_iter_mut();
let raw_witin_iter = raw_witin.par_batch_iter_mut(num_instance_per_batch);

raw_witin_iter
.zip_eq(steps.into_par_iter())
.map(|(instance, step)| {
.zip_eq(steps.par_chunks(num_instance_per_batch))
.flat_map(|(instances, steps)| {
let mut lk_multiplicity = lk_multiplicity.clone();
Self::assign_instance(config, instance, &mut lk_multiplicity, step)
instances
.chunks_mut(num_witin)
.zip(steps)
.map(|(instance, step)| {
Self::assign_instance(config, instance, &mut lk_multiplicity, step)
})
.collect::<Vec<_>>()
})
.collect::<Result<(), ZKVMError>>()?;

Ok((raw_witin, lk_multiplicity))
}
}
4 changes: 2 additions & 2 deletions ceno_zkvm/src/instructions/riscv/addsub.rs
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ impl<E: ExtensionField> Instruction<E> for AddInstruction {
config: &Self::InstructionConfig,
instance: &mut [MaybeUninit<E::BaseField>],
lk_multiplicity: &mut LkMultiplicity,
step: StepRecord,
step: &StepRecord,
) -> Result<(), ZKVMError> {
// TODO use fields from step
set_val!(instance, config.pc, 1);
Expand Down Expand Up @@ -202,7 +202,7 @@ impl<E: ExtensionField> Instruction<E> for SubInstruction {
config: &Self::InstructionConfig,
instance: &mut [MaybeUninit<E::BaseField>],
_lk_multiplicity: &mut LkMultiplicity,
_step: StepRecord,
_step: &StepRecord,
) -> Result<(), ZKVMError> {
// TODO use field from step
set_val!(instance, config.pc, _step.pc().before.0 as u64);
Expand Down
2 changes: 1 addition & 1 deletion ceno_zkvm/src/instructions/riscv/blt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ impl<E: ExtensionField> Instruction<E> for BltInstruction {
config: &Self::InstructionConfig,
instance: &mut [std::mem::MaybeUninit<E::BaseField>],
_lk_multiplicity: &mut LkMultiplicity,
_step: ceno_emul::StepRecord,
_step: &ceno_emul::StepRecord,
) -> Result<(), ZKVMError> {
// take input from _step
let input = BltInput::random();
Expand Down
7 changes: 7 additions & 0 deletions ceno_zkvm/src/witness.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,13 @@ impl<T: Sized + Sync + Clone + Send> RowMajorMatrix<T> {
self.values.par_chunks_mut(self.num_col)
}

pub fn par_batch_iter_mut(
&mut self,
num_rows: usize,
) -> rayon::slice::ChunksMut<MaybeUninit<T>> {
self.values.par_chunks_mut(num_rows * self.num_col)
}

pub fn de_interleaving(mut self) -> Vec<Vec<T>> {
(0..self.num_col)
.map(|i| {
Expand Down

0 comments on commit e524c47

Please sign in to comment.