diff --git a/sdk/src/multi_prover/common/mod.rs b/sdk/src/multi_prover/common/mod.rs index d499d28937..ea28597359 100644 --- a/sdk/src/multi_prover/common/mod.rs +++ b/sdk/src/multi_prover/common/mod.rs @@ -3,7 +3,7 @@ pub mod types; use crate::{ProverClient, SP1ProofKind}; use anyhow::Result; -use serde::{Deserialize, Serialize}; +use serde::{de::DeserializeOwned, Deserialize, Serialize}; use sp1_core::{ runtime::{Program, Runtime, SP1Context, SP1ContextBuilder}, utils::SP1ProverOpts, @@ -18,12 +18,12 @@ use sysinfo::System; static LIMIT_RAM_GB: u64 = 120; #[derive(Clone, Debug, Serialize, Deserialize)] -pub struct ProveArgs { - pub zkvm_input: Vec, +pub struct ProveArgs { + pub zkvm_input: T, pub elf: Vec, } -impl ProveArgs { +impl ProveArgs { pub fn to_bytes(&self) -> Vec { bincode::serialize(self).unwrap() } @@ -33,7 +33,9 @@ impl ProveArgs { } } -pub fn init_client(args: &ProveArgs) -> (ProverClient, SP1Stdin, SP1ProvingKey, SP1VerifyingKey) { +pub fn init_client( + args: &ProveArgs, +) -> (ProverClient, SP1Stdin, SP1ProvingKey, SP1VerifyingKey) { let client = ProverClient::new(); let (pk, vk) = client.setup(&args.elf); let mut stdin = SP1Stdin::new(); diff --git a/sdk/src/multi_prover/operator/mod.rs b/sdk/src/multi_prover/operator/mod.rs index a1b1311a21..ec1865f76a 100644 --- a/sdk/src/multi_prover/operator/mod.rs +++ b/sdk/src/multi_prover/operator/mod.rs @@ -7,6 +7,8 @@ use crate::multi_prover::common::{ types::{CommitmentType, PublicValueStreamType}, }; use p3_baby_bear::BabyBear; +use serde::de::DeserializeOwned; +use serde::Serialize; use sp1_core::air::{PublicValues, Word}; use sp1_core::stark::{MachineProver, StarkGenericConfig}; use sp1_core::utils::BabyBearPoseidon2; @@ -19,14 +21,14 @@ use steps::{ }; use utils::{read_bin_file_to_vec, ChallengerState}; -pub fn operator_split_into_checkpoints( +pub fn operator_split_into_checkpoints( args: &[u8], o_public_values_stream: &mut Vec, o_public_values: &mut Vec, o_checkpoints: &mut Vec>, o_cycles: &mut u64, ) { - let args_obj = ProveArgs::from_slice(args); + let args_obj: ProveArgs = ProveArgs::from_slice(args); let (public_values_stream, public_values, checkpoints, cycles) = operator_split_into_checkpoints_impl(&args_obj).unwrap(); @@ -39,13 +41,13 @@ pub fn operator_split_into_checkpoints( *o_cycles = cycles; } -pub fn operator_absorb_commits( +pub fn operator_absorb_commits( args: &[u8], commitments_vec: &[Vec>], records_vec: &[Vec>], o_challenger_state: &mut Vec, ) { - let args_obj = ProveArgs::from_slice(args); + let args_obj: ProveArgs = ProveArgs::from_slice(args); let commitments_vec: Vec> = commitments_vec .iter() .map(|commitments| { @@ -76,14 +78,14 @@ pub fn operator_absorb_commits( *o_challenger_state = ChallengerState::from(&challenger).to_bytes(); } -pub fn operator_construct_sp1_core_proof( +pub fn operator_construct_sp1_core_proof( args: &Vec, shard_proofs_vec: &[Vec>], public_values_stream: &[u8], cycles: u64, o_proof: &mut Vec, ) { - let args_obj = ProveArgs::from_slice(args.as_slice()); + let args_obj: ProveArgs = ProveArgs::from_slice(args.as_slice()); let shard_proofs_vec_obj = shard_proofs_vec .iter() .map(|proofs| { @@ -106,14 +108,14 @@ pub fn operator_construct_sp1_core_proof( *o_proof = bincode::serialize(&proof).unwrap(); } -pub fn operator_prepare_compress_inputs( +pub fn operator_prepare_compress_inputs( args: &Vec, core_proof: &[u8], o_rec_layouts: &mut Vec>, o_def_layouts: &mut Vec>, o_last_proof_public_values: &mut Vec, ) { - let args_obj = ProveArgs::from_slice(args.as_slice()); + let args_obj: ProveArgs = ProveArgs::from_slice(args.as_slice()); let core_proof_obj: SP1CoreProof = bincode::deserialize(&core_proof).unwrap(); let (client, stdin, _, vk) = common::init_client(&args_obj); @@ -173,8 +175,12 @@ pub fn operator_prepare_compress_input_chunks( .collect(); } -pub fn operator_prove_shrink(args: &[u8], compressed_proof: &[u8], o_shrink_proof: &mut Vec) { - let args_obj = ProveArgs::from_slice(args); +pub fn operator_prove_shrink( + args: &[u8], + compressed_proof: &[u8], + o_shrink_proof: &mut Vec, +) { + let args_obj: ProveArgs = ProveArgs::from_slice(args); let compressed_proof_obj: SP1ReduceProof = bincode::deserialize(compressed_proof).unwrap(); @@ -183,8 +189,12 @@ pub fn operator_prove_shrink(args: &[u8], compressed_proof: &[u8], o_shrink_proo *o_shrink_proof = bincode::serialize(&shrink_proof).unwrap(); } -pub fn operator_prove_plonk(args: &Vec, shrink_proof: &[u8], o_plonk_proof: &mut Vec) { - let args_obj = ProveArgs::from_slice(args.as_slice()); +pub fn operator_prove_plonk( + args: &Vec, + shrink_proof: &[u8], + o_plonk_proof: &mut Vec, +) { + let args_obj: ProveArgs = ProveArgs::from_slice(args.as_slice()); let shrink_proof_obj: SP1ReduceProof = bincode::deserialize(shrink_proof).unwrap(); diff --git a/sdk/src/multi_prover/operator/steps.rs b/sdk/src/multi_prover/operator/steps.rs index 338004f62f..3eee979033 100644 --- a/sdk/src/multi_prover/operator/steps.rs +++ b/sdk/src/multi_prover/operator/steps.rs @@ -8,6 +8,7 @@ use crate::multi_prover::common::{self, ProveArgs}; use anyhow::Result; use p3_baby_bear::BabyBear; use p3_challenger::CanObserve; +use serde::Serialize; use sp1_core::stark::{MachineRecord, RiscvAir}; use sp1_core::{ runtime::Runtime, @@ -58,8 +59,8 @@ fn operator_split_into_checkpoints( Ok((public_values_stream, public_values, checkpoints)) } -pub fn operator_split_into_checkpoints_impl( - args: &ProveArgs, +pub fn operator_split_into_checkpoints_impl( + args: &ProveArgs, ) -> Result<( PublicValueStreamType, PublicValuesType, @@ -84,8 +85,8 @@ pub fn operator_split_into_checkpoints_impl( )) } -pub fn operator_absorb_commits_impl( - args: &ProveArgs, +pub fn operator_absorb_commits_impl( + args: &ProveArgs, commitments_vec: Vec>, records_vec: Vec>, ) -> Result { @@ -128,8 +129,8 @@ pub fn operator_absorb_commits_impl( Ok(challenger) } -pub fn construct_sp1_core_proof_impl( - args: &ProveArgs, +pub fn construct_sp1_core_proof_impl( + args: &ProveArgs, shard_proofs_vec: Vec>>, public_values_stream: PublicValueStreamType, cycles: u64, @@ -224,8 +225,8 @@ pub fn operator_prepare_compress_input_chunks_impl( Ok(result) } -pub fn operator_prove_shrink_impl( - args: &ProveArgs, +pub fn operator_prove_shrink_impl( + args: &ProveArgs, compress_proof: SP1ReduceProof, ) -> Result> { let (client, _, pk, _) = common::init_client(args); @@ -237,8 +238,8 @@ pub fn operator_prove_shrink_impl( .map_err(|e| anyhow::anyhow!(e)) } -pub fn operator_prove_plonk_impl( - args: &ProveArgs, +pub fn operator_prove_plonk_impl( + args: &ProveArgs, shrink_proof: SP1ReduceProof, ) -> Result { let (client, _, pk, _) = common::init_client(args); diff --git a/sdk/src/multi_prover/scenario/compress_prove.rs b/sdk/src/multi_prover/scenario/compress_prove.rs index 99e8b14b47..f06fbb3224 100644 --- a/sdk/src/multi_prover/scenario/compress_prove.rs +++ b/sdk/src/multi_prover/scenario/compress_prove.rs @@ -8,22 +8,26 @@ use crate::multi_prover::{ }; use crate::{SP1Proof, SP1ProofWithPublicValues}; use anyhow::Result; +use serde::de::DeserializeOwned; +use serde::Serialize; use sp1_core::{stark::ShardProof, utils::BabyBearPoseidon2}; use sp1_prover::SP1ReduceProof; use tracing::info_span; -pub fn mpc_prove_compress(args: &ProveArgs) -> Result<(Vec, Vec)> { +pub fn mpc_prove_compress( + args: &ProveArgs, +) -> Result<(Vec, Vec)> { let span = info_span!("kroma_core"); let _guard = span.entered(); - let core_proof = mpc_prove_core(args.clone()).unwrap(); + let core_proof = mpc_prove_core::(args).unwrap(); let serialize_args = bincode::serialize(&args).unwrap(); let mut rec_layouts: Vec> = Vec::new(); let mut def_layouts: Vec> = Vec::new(); let mut last_proof_public_values = Vec::new(); info_span!("o_prepare_compress_inputs").in_scope(|| { - operator_prepare_compress_inputs( + operator_prepare_compress_inputs::( &serialize_args, &core_proof, &mut rec_layouts, @@ -36,7 +40,7 @@ pub fn mpc_prove_compress(args: &ProveArgs) -> Result<(Vec, Vec)> { info_span!("w_compress_proofs_leaf").in_scope(|| { for layout in rec_layouts { let mut compressed_proof = Vec::new(); - worker_compress_proofs( + worker_compress_proofs::( &serialize_args, &layout, 0, @@ -47,7 +51,7 @@ pub fn mpc_prove_compress(args: &ProveArgs) -> Result<(Vec, Vec)> { } for layout in def_layouts { let mut compressed_proof = Vec::new(); - worker_compress_proofs(&serialize_args, &layout, 1, None, &mut compressed_proof); + worker_compress_proofs::(&serialize_args, &layout, 1, None, &mut compressed_proof); compressed_proofs.push(compressed_proof); } }); @@ -65,7 +69,13 @@ pub fn mpc_prove_compress(args: &ProveArgs) -> Result<(Vec, Vec)> { info_span!("w_compress_proofs").in_scope(|| { for (worker_idx, layout) in red_layout.iter().enumerate() { let mut compressed_proof = Vec::new(); - worker_compress_proofs(&serialize_args, &layout, 2, None, &mut compressed_proof); + worker_compress_proofs::( + &serialize_args, + &layout, + 2, + None, + &mut compressed_proof, + ); compress_layer_proofs.push(compressed_proof); tracing::info!("{:?}/{:?} worker done", worker_idx + 1, red_layout.len()); } @@ -84,7 +94,11 @@ pub fn mpc_prove_compress(args: &ProveArgs) -> Result<(Vec, Vec)> { Ok((core_proof, proof)) } -pub fn scenario_end(args: &ProveArgs, core_proof: &Vec, compress_proof: &Vec) { +pub fn scenario_end( + args: &ProveArgs, + core_proof: &Vec, + compress_proof: &Vec, +) { let compress_proof_obj: SP1ReduceProof = bincode::deserialize(compress_proof).unwrap(); diff --git a/sdk/src/multi_prover/scenario/core_prove.rs b/sdk/src/multi_prover/scenario/core_prove.rs index 855804a4c4..16fdafcb70 100644 --- a/sdk/src/multi_prover/scenario/core_prove.rs +++ b/sdk/src/multi_prover/scenario/core_prove.rs @@ -8,10 +8,12 @@ use crate::multi_prover::{ }; use crate::{SP1Proof, SP1ProofWithPublicValues}; use anyhow::Result; +use serde::de::DeserializeOwned; +use serde::Serialize; use sp1_prover::SP1CoreProof; use tracing::info_span; -pub fn mpc_prove_core(args: ProveArgs) -> Result> { +pub fn mpc_prove_core(args: &ProveArgs) -> Result> { let span = info_span!("kroma_core"); let _guard = span.entered(); @@ -21,7 +23,7 @@ pub fn mpc_prove_core(args: ProveArgs) -> Result> { let mut checkpoints = Vec::new(); let mut cycles = 0; info_span!("o_split_checkpoints").in_scope(|| { - operator_split_into_checkpoints( + operator_split_into_checkpoints::( &serialize_args, &mut public_values_stream, &mut public_values, @@ -37,7 +39,7 @@ pub fn mpc_prove_core(args: ProveArgs) -> Result> { for (worker_idx, checkpoint) in checkpoints.iter_mut().enumerate() { let mut commitments = Vec::new(); let mut records = Vec::new(); - worker_commit_checkpoint( + worker_commit_checkpoint::( &serialize_args, worker_idx as u32, checkpoint, @@ -54,7 +56,7 @@ pub fn mpc_prove_core(args: ProveArgs) -> Result> { let mut challenger_state = Vec::new(); info_span!("o_absorb_commits").in_scope(|| { - operator_absorb_commits( + operator_absorb_commits::( &serialize_args, &commitments_vec, &records_vec, @@ -67,7 +69,7 @@ pub fn mpc_prove_core(args: ProveArgs) -> Result> { let num_workers = records_vec.len(); for (worker_idx, records) in records_vec.into_iter().enumerate() { let mut shard_proofs = Vec::new(); - worker_prove_checkpoint( + worker_prove_checkpoint::( &serialize_args, &challenger_state, records.as_slice(), @@ -80,7 +82,7 @@ pub fn mpc_prove_core(args: ProveArgs) -> Result> { let mut proof = Vec::new(); info_span!("o_construct_sp1_core_proof").in_scope(|| { - operator_construct_sp1_core_proof( + operator_construct_sp1_core_proof::( &serialize_args, &shard_proofs_vec, &public_values_stream, @@ -93,7 +95,10 @@ pub fn mpc_prove_core(args: ProveArgs) -> Result> { Ok(proof) } -pub fn scenario_end(args: &ProveArgs, core_proof: &Vec) -> Result { +pub fn scenario_end( + args: &ProveArgs, + core_proof: &Vec, +) -> Result { let core_proof_obj: SP1CoreProof = bincode::deserialize(core_proof).unwrap(); let (client, _, _, vk) = common::init_client(args); diff --git a/sdk/src/multi_prover/scenario/plonk_prove.rs b/sdk/src/multi_prover/scenario/plonk_prove.rs index 151fa4982f..88f5d7fd8f 100644 --- a/sdk/src/multi_prover/scenario/plonk_prove.rs +++ b/sdk/src/multi_prover/scenario/plonk_prove.rs @@ -5,10 +5,13 @@ use crate::multi_prover::{ }; use crate::{PlonkBn254Proof, SP1Proof, SP1ProofWithPublicValues}; use anyhow::Result; +use serde::{de::DeserializeOwned, Serialize}; use sp1_prover::SP1CoreProof; use tracing::info_span; -pub fn mpc_prove_plonk(args: &ProveArgs) -> Result<(Vec, Vec, Vec)> { +pub fn mpc_prove_plonk( + args: &ProveArgs, +) -> Result<(Vec, Vec, Vec)> { let span = info_span!("kroma_core"); let _guard = span.entered(); @@ -16,17 +19,22 @@ pub fn mpc_prove_plonk(args: &ProveArgs) -> Result<(Vec, Vec, Vec)> let serialize_args = bincode::serialize(&args).unwrap(); let mut shrink_proof = Vec::new(); - info_span!("o_shrink_proof") - .in_scope(|| operator_prove_shrink(&serialize_args, &compress_proof, &mut shrink_proof)); + info_span!("o_shrink_proof").in_scope(|| { + operator_prove_shrink::(&serialize_args, &compress_proof, &mut shrink_proof) + }); let mut plonk_proof = Vec::new(); info_span!("o_plonk_proof") - .in_scope(|| operator_prove_plonk(&serialize_args, &shrink_proof, &mut plonk_proof)); + .in_scope(|| operator_prove_plonk::(&serialize_args, &shrink_proof, &mut plonk_proof)); Ok((core_proof, compress_proof, plonk_proof)) } -pub fn scenario_end(args: &ProveArgs, core_proof: &Vec, plonk_proof: &Vec) { +pub fn scenario_end( + args: &ProveArgs, + core_proof: &Vec, + plonk_proof: &Vec, +) { let plonk_proof: PlonkBn254Proof = bincode::deserialize(plonk_proof).unwrap(); let (client, _, _, vk) = common::init_client(args); diff --git a/sdk/src/multi_prover/worker/mod.rs b/sdk/src/multi_prover/worker/mod.rs index 1edd7dddd4..85f9e8d5b7 100644 --- a/sdk/src/multi_prover/worker/mod.rs +++ b/sdk/src/multi_prover/worker/mod.rs @@ -10,6 +10,8 @@ use crate::multi_prover::{ }, operator::utils::ChallengerState, }; +use serde::de::DeserializeOwned; +use serde::Serialize; use sp1_core::{runtime::ExecutionState, stark::MachineProver}; use steps::{ worker_commit_checkpoint_impl, worker_compress_proofs_for_deferred, @@ -17,7 +19,7 @@ use steps::{ worker_prove_checkpoint_impl, }; -pub fn worker_commit_checkpoint( +pub fn worker_commit_checkpoint( args: &Vec, idx: u32, checkpoint: &Vec, @@ -26,7 +28,7 @@ pub fn worker_commit_checkpoint( o_commitments: &mut Vec>, o_records: &mut Vec>, ) { - let args_obj = ProveArgs::from_slice(args.as_slice()); + let args_obj: ProveArgs = ProveArgs::from_slice(args.as_slice()); let execution_state: ExecutionState = bincode::deserialize(checkpoint.as_slice()).unwrap(); let mut checkpoint_file = tempfile::tempfile().unwrap(); execution_state.save(&mut checkpoint_file).unwrap(); @@ -52,13 +54,13 @@ pub fn worker_commit_checkpoint( .collect(); } -pub fn worker_prove_checkpoint( +pub fn worker_prove_checkpoint( args: &Vec, challenger_state: &Vec, records: &[Vec], o_shard_proofs: &mut Vec>, ) { - let args_obj = ProveArgs::from_slice(args.as_slice()); + let args_obj: ProveArgs = ProveArgs::from_slice(args.as_slice()); let (client, _, _, _) = common::init_client(&args_obj); let challenger = ChallengerState::from_bytes(challenger_state.as_slice()) .to_challenger(&client.prover.sp1_prover().core_prover.config().perm); @@ -76,14 +78,14 @@ pub fn worker_prove_checkpoint( .collect(); } -pub fn worker_compress_proofs( +pub fn worker_compress_proofs( args: &Vec, layout: &Vec, layout_type: usize, last_proof_public_values: Option<&Vec>, o_proof: &mut Vec, ) { - let args_obj = ProveArgs::from_slice(args.as_slice()); + let args_obj: ProveArgs = ProveArgs::from_slice(args.as_slice()); let compressed_shard_proof = match LayoutType::from_usize(layout_type) { LayoutType::Recursion => { let layout: SerializableRecursionLayout = bincode::deserialize(layout).unwrap(); diff --git a/sdk/src/multi_prover/worker/steps.rs b/sdk/src/multi_prover/worker/steps.rs index 9dc86a9c55..f7c5f08672 100644 --- a/sdk/src/multi_prover/worker/steps.rs +++ b/sdk/src/multi_prover/worker/steps.rs @@ -9,6 +9,8 @@ use crate::multi_prover::common::ProveArgs; use anyhow::Result; use p3_baby_bear::BabyBear; use p3_field::AbstractField; +use serde::de::DeserializeOwned; +use serde::Serialize; use sp1_core::air::Word; use sp1_core::runtime::ExecutionReport; use sp1_core::{ @@ -20,8 +22,8 @@ use sp1_core::{ use sp1_prover::ReduceProgramType; use std::fs::File; -pub fn worker_commit_checkpoint_impl( - args: &ProveArgs, +pub fn worker_commit_checkpoint_impl( + args: &ProveArgs, idx: u32, checkpoint: &mut File, is_last_checkpoint: bool, @@ -104,8 +106,8 @@ pub fn worker_commit_checkpoint_impl( Ok((commitments, records)) } -pub fn worker_prove_checkpoint_impl( - args: &ProveArgs, +pub fn worker_prove_checkpoint_impl( + args: &ProveArgs, challenger: ChallengerType, records: Vec, ) -> Result>> { @@ -134,8 +136,8 @@ pub fn worker_prove_checkpoint_impl( Ok(shard_proofs) } -pub fn worker_compress_proofs_for_recursion( - args: &ProveArgs, +pub fn worker_compress_proofs_for_recursion( + args: &ProveArgs, mut layout: SerializableRecursionLayout, ) -> Result<(ShardProof, ReduceProgramType)> { let (client, stdin, pk, _) = common::init_client(&args); @@ -175,8 +177,8 @@ pub fn worker_compress_proofs_for_recursion( .map_err(|e| anyhow::anyhow!("failed to compress machine proof: {:?}", e)) } -pub fn worker_compress_proofs_for_deferred( - args: &ProveArgs, +pub fn worker_compress_proofs_for_deferred( + args: &ProveArgs, mut layout: SerializableDeferredLayout, last_proof_pv: PublicValues, BabyBear>, ) -> Result<(ShardProof, ReduceProgramType)> { @@ -223,8 +225,8 @@ pub fn worker_compress_proofs_for_deferred( .map_err(|e| anyhow::anyhow!("failed to compress machine proof: {:?}", e)) } -pub fn worker_compress_proofs_for_reduce( - args: &ProveArgs, +pub fn worker_compress_proofs_for_reduce( + args: &ProveArgs, layout: SerializableReduceLayout, ) -> Result<(ShardProof, ReduceProgramType)> { let (client, _, pk, _) = common::init_client(&args);