Skip to content

Commit

Permalink
feat(sdk): refactor ProveArg to support generic
Browse files Browse the repository at this point in the history
  • Loading branch information
dongchangYoo committed Aug 16, 2024
1 parent 97a94de commit 6814964
Show file tree
Hide file tree
Showing 8 changed files with 106 additions and 62 deletions.
12 changes: 7 additions & 5 deletions sdk/src/multi_prover/common/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ pub mod types;

use crate::{ProverClient, SP1ProofKind};
use anyhow::Result;
use serde::{Deserialize, Serialize};
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use sp1_core::{
runtime::{Program, Runtime, SP1Context, SP1ContextBuilder},
utils::SP1ProverOpts,
Expand All @@ -18,12 +18,12 @@ use sysinfo::System;
static LIMIT_RAM_GB: u64 = 120;

#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ProveArgs {
pub zkvm_input: Vec<u8>,
pub struct ProveArgs<T: Serialize> {
pub zkvm_input: T,
pub elf: Vec<u8>,
}

impl ProveArgs {
impl<T: Serialize + DeserializeOwned> ProveArgs<T> {
pub fn to_bytes(&self) -> Vec<u8> {
bincode::serialize(self).unwrap()
}
Expand All @@ -33,7 +33,9 @@ impl ProveArgs {
}
}

pub fn init_client(args: &ProveArgs) -> (ProverClient, SP1Stdin, SP1ProvingKey, SP1VerifyingKey) {
pub fn init_client<T: Serialize>(
args: &ProveArgs<T>,
) -> (ProverClient, SP1Stdin, SP1ProvingKey, SP1VerifyingKey) {
let client = ProverClient::new();
let (pk, vk) = client.setup(&args.elf);
let mut stdin = SP1Stdin::new();
Expand Down
34 changes: 22 additions & 12 deletions sdk/src/multi_prover/operator/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ use crate::multi_prover::common::{
types::{CommitmentType, PublicValueStreamType},
};
use p3_baby_bear::BabyBear;
use serde::de::DeserializeOwned;
use serde::Serialize;
use sp1_core::air::{PublicValues, Word};
use sp1_core::stark::{MachineProver, StarkGenericConfig};
use sp1_core::utils::BabyBearPoseidon2;
Expand All @@ -19,14 +21,14 @@ use steps::{
};
use utils::{read_bin_file_to_vec, ChallengerState};

pub fn operator_split_into_checkpoints(
pub fn operator_split_into_checkpoints<T: Serialize + DeserializeOwned>(
args: &[u8],
o_public_values_stream: &mut Vec<u8>,
o_public_values: &mut Vec<u8>,
o_checkpoints: &mut Vec<Vec<u8>>,
o_cycles: &mut u64,
) {
let args_obj = ProveArgs::from_slice(args);
let args_obj: ProveArgs<T> = ProveArgs::from_slice(args);
let (public_values_stream, public_values, checkpoints, cycles) =
operator_split_into_checkpoints_impl(&args_obj).unwrap();

Expand All @@ -39,13 +41,13 @@ pub fn operator_split_into_checkpoints(
*o_cycles = cycles;
}

pub fn operator_absorb_commits(
pub fn operator_absorb_commits<T: Serialize + DeserializeOwned>(
args: &[u8],
commitments_vec: &[Vec<Vec<u8>>],
records_vec: &[Vec<Vec<u8>>],
o_challenger_state: &mut Vec<u8>,
) {
let args_obj = ProveArgs::from_slice(args);
let args_obj: ProveArgs<T> = ProveArgs::from_slice(args);
let commitments_vec: Vec<Vec<CommitmentType>> = commitments_vec
.iter()
.map(|commitments| {
Expand Down Expand Up @@ -76,14 +78,14 @@ pub fn operator_absorb_commits(
*o_challenger_state = ChallengerState::from(&challenger).to_bytes();
}

pub fn operator_construct_sp1_core_proof(
pub fn operator_construct_sp1_core_proof<T: Serialize + DeserializeOwned>(
args: &Vec<u8>,
shard_proofs_vec: &[Vec<Vec<u8>>],
public_values_stream: &[u8],
cycles: u64,
o_proof: &mut Vec<u8>,
) {
let args_obj = ProveArgs::from_slice(args.as_slice());
let args_obj: ProveArgs<T> = ProveArgs::from_slice(args.as_slice());
let shard_proofs_vec_obj = shard_proofs_vec
.iter()
.map(|proofs| {
Expand All @@ -106,14 +108,14 @@ pub fn operator_construct_sp1_core_proof(
*o_proof = bincode::serialize(&proof).unwrap();
}

pub fn operator_prepare_compress_inputs(
pub fn operator_prepare_compress_inputs<T: Serialize + DeserializeOwned>(
args: &Vec<u8>,
core_proof: &[u8],
o_rec_layouts: &mut Vec<Vec<u8>>,
o_def_layouts: &mut Vec<Vec<u8>>,
o_last_proof_public_values: &mut Vec<u8>,
) {
let args_obj = ProveArgs::from_slice(args.as_slice());
let args_obj: ProveArgs<T> = ProveArgs::from_slice(args.as_slice());
let core_proof_obj: SP1CoreProof = bincode::deserialize(&core_proof).unwrap();

let (client, stdin, _, vk) = common::init_client(&args_obj);
Expand Down Expand Up @@ -173,8 +175,12 @@ pub fn operator_prepare_compress_input_chunks(
.collect();
}

pub fn operator_prove_shrink(args: &[u8], compressed_proof: &[u8], o_shrink_proof: &mut Vec<u8>) {
let args_obj = ProveArgs::from_slice(args);
pub fn operator_prove_shrink<T: Serialize + DeserializeOwned>(
args: &[u8],
compressed_proof: &[u8],
o_shrink_proof: &mut Vec<u8>,
) {
let args_obj: ProveArgs<T> = ProveArgs::from_slice(args);
let compressed_proof_obj: SP1ReduceProof<BabyBearPoseidon2> =
bincode::deserialize(compressed_proof).unwrap();

Expand All @@ -183,8 +189,12 @@ pub fn operator_prove_shrink(args: &[u8], compressed_proof: &[u8], o_shrink_proo
*o_shrink_proof = bincode::serialize(&shrink_proof).unwrap();
}

pub fn operator_prove_plonk(args: &Vec<u8>, shrink_proof: &[u8], o_plonk_proof: &mut Vec<u8>) {
let args_obj = ProveArgs::from_slice(args.as_slice());
pub fn operator_prove_plonk<T: Serialize + DeserializeOwned>(
args: &Vec<u8>,
shrink_proof: &[u8],
o_plonk_proof: &mut Vec<u8>,
) {
let args_obj: ProveArgs<T> = ProveArgs::from_slice(args.as_slice());
let shrink_proof_obj: SP1ReduceProof<BabyBearPoseidon2> =
bincode::deserialize(shrink_proof).unwrap();

Expand Down
21 changes: 11 additions & 10 deletions sdk/src/multi_prover/operator/steps.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use crate::multi_prover::common::{self, ProveArgs};
use anyhow::Result;
use p3_baby_bear::BabyBear;
use p3_challenger::CanObserve;
use serde::Serialize;
use sp1_core::stark::{MachineRecord, RiscvAir};
use sp1_core::{
runtime::Runtime,
Expand Down Expand Up @@ -58,8 +59,8 @@ fn operator_split_into_checkpoints(
Ok((public_values_stream, public_values, checkpoints))
}

pub fn operator_split_into_checkpoints_impl(
args: &ProveArgs,
pub fn operator_split_into_checkpoints_impl<T: Serialize>(
args: &ProveArgs<T>,
) -> Result<(
PublicValueStreamType,
PublicValuesType,
Expand All @@ -84,8 +85,8 @@ pub fn operator_split_into_checkpoints_impl(
))
}

pub fn operator_absorb_commits_impl(
args: &ProveArgs,
pub fn operator_absorb_commits_impl<T: Serialize>(
args: &ProveArgs<T>,
commitments_vec: Vec<Vec<CommitmentType>>,
records_vec: Vec<Vec<RecordType>>,
) -> Result<ChallengerType> {
Expand Down Expand Up @@ -128,8 +129,8 @@ pub fn operator_absorb_commits_impl(
Ok(challenger)
}

pub fn construct_sp1_core_proof_impl(
args: &ProveArgs,
pub fn construct_sp1_core_proof_impl<T: Serialize>(
args: &ProveArgs<T>,
shard_proofs_vec: Vec<Vec<ShardProof<BabyBearPoseidon2>>>,
public_values_stream: PublicValueStreamType,
cycles: u64,
Expand Down Expand Up @@ -224,8 +225,8 @@ pub fn operator_prepare_compress_input_chunks_impl(
Ok(result)
}

pub fn operator_prove_shrink_impl(
args: &ProveArgs,
pub fn operator_prove_shrink_impl<T: Serialize>(
args: &ProveArgs<T>,
compress_proof: SP1ReduceProof<BabyBearPoseidon2>,
) -> Result<SP1ReduceProof<BabyBearPoseidon2>> {
let (client, _, pk, _) = common::init_client(args);
Expand All @@ -237,8 +238,8 @@ pub fn operator_prove_shrink_impl(
.map_err(|e| anyhow::anyhow!(e))
}

pub fn operator_prove_plonk_impl(
args: &ProveArgs,
pub fn operator_prove_plonk_impl<T: Serialize>(
args: &ProveArgs<T>,
shrink_proof: SP1ReduceProof<BabyBearPoseidon2>,
) -> Result<PlonkBn254Proof> {
let (client, _, pk, _) = common::init_client(args);
Expand Down
28 changes: 21 additions & 7 deletions sdk/src/multi_prover/scenario/compress_prove.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,22 +8,26 @@ use crate::multi_prover::{
};
use crate::{SP1Proof, SP1ProofWithPublicValues};
use anyhow::Result;
use serde::de::DeserializeOwned;
use serde::Serialize;
use sp1_core::{stark::ShardProof, utils::BabyBearPoseidon2};
use sp1_prover::SP1ReduceProof;
use tracing::info_span;

pub fn mpc_prove_compress(args: &ProveArgs) -> Result<(Vec<u8>, Vec<u8>)> {
pub fn mpc_prove_compress<T: Serialize + DeserializeOwned>(
args: &ProveArgs<T>,
) -> Result<(Vec<u8>, Vec<u8>)> {
let span = info_span!("kroma_core");
let _guard = span.entered();

let core_proof = mpc_prove_core(args.clone()).unwrap();
let core_proof = mpc_prove_core::<T>(args).unwrap();
let serialize_args = bincode::serialize(&args).unwrap();

let mut rec_layouts: Vec<Vec<u8>> = Vec::new();
let mut def_layouts: Vec<Vec<u8>> = Vec::new();
let mut last_proof_public_values = Vec::new();
info_span!("o_prepare_compress_inputs").in_scope(|| {
operator_prepare_compress_inputs(
operator_prepare_compress_inputs::<T>(
&serialize_args,
&core_proof,
&mut rec_layouts,
Expand All @@ -36,7 +40,7 @@ pub fn mpc_prove_compress(args: &ProveArgs) -> Result<(Vec<u8>, Vec<u8>)> {
info_span!("w_compress_proofs_leaf").in_scope(|| {
for layout in rec_layouts {
let mut compressed_proof = Vec::new();
worker_compress_proofs(
worker_compress_proofs::<T>(
&serialize_args,
&layout,
0,
Expand All @@ -47,7 +51,7 @@ pub fn mpc_prove_compress(args: &ProveArgs) -> Result<(Vec<u8>, Vec<u8>)> {
}
for layout in def_layouts {
let mut compressed_proof = Vec::new();
worker_compress_proofs(&serialize_args, &layout, 1, None, &mut compressed_proof);
worker_compress_proofs::<T>(&serialize_args, &layout, 1, None, &mut compressed_proof);
compressed_proofs.push(compressed_proof);
}
});
Expand All @@ -65,7 +69,13 @@ pub fn mpc_prove_compress(args: &ProveArgs) -> Result<(Vec<u8>, Vec<u8>)> {
info_span!("w_compress_proofs").in_scope(|| {
for (worker_idx, layout) in red_layout.iter().enumerate() {
let mut compressed_proof = Vec::new();
worker_compress_proofs(&serialize_args, &layout, 2, None, &mut compressed_proof);
worker_compress_proofs::<T>(
&serialize_args,
&layout,
2,
None,
&mut compressed_proof,
);
compress_layer_proofs.push(compressed_proof);
tracing::info!("{:?}/{:?} worker done", worker_idx + 1, red_layout.len());
}
Expand All @@ -84,7 +94,11 @@ pub fn mpc_prove_compress(args: &ProveArgs) -> Result<(Vec<u8>, Vec<u8>)> {
Ok((core_proof, proof))
}

pub fn scenario_end(args: &ProveArgs, core_proof: &Vec<u8>, compress_proof: &Vec<u8>) {
pub fn scenario_end<T: Serialize + DeserializeOwned>(
args: &ProveArgs<T>,
core_proof: &Vec<u8>,
compress_proof: &Vec<u8>,
) {
let compress_proof_obj: SP1ReduceProof<BabyBearPoseidon2> =
bincode::deserialize(compress_proof).unwrap();

Expand Down
19 changes: 12 additions & 7 deletions sdk/src/multi_prover/scenario/core_prove.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,12 @@ use crate::multi_prover::{
};
use crate::{SP1Proof, SP1ProofWithPublicValues};
use anyhow::Result;
use serde::de::DeserializeOwned;
use serde::Serialize;
use sp1_prover::SP1CoreProof;
use tracing::info_span;

pub fn mpc_prove_core(args: ProveArgs) -> Result<Vec<u8>> {
pub fn mpc_prove_core<T: Serialize + DeserializeOwned>(args: &ProveArgs<T>) -> Result<Vec<u8>> {
let span = info_span!("kroma_core");
let _guard = span.entered();

Expand All @@ -21,7 +23,7 @@ pub fn mpc_prove_core(args: ProveArgs) -> Result<Vec<u8>> {
let mut checkpoints = Vec::new();
let mut cycles = 0;
info_span!("o_split_checkpoints").in_scope(|| {
operator_split_into_checkpoints(
operator_split_into_checkpoints::<T>(
&serialize_args,
&mut public_values_stream,
&mut public_values,
Expand All @@ -37,7 +39,7 @@ pub fn mpc_prove_core(args: ProveArgs) -> Result<Vec<u8>> {
for (worker_idx, checkpoint) in checkpoints.iter_mut().enumerate() {
let mut commitments = Vec::new();
let mut records = Vec::new();
worker_commit_checkpoint(
worker_commit_checkpoint::<T>(
&serialize_args,
worker_idx as u32,
checkpoint,
Expand All @@ -54,7 +56,7 @@ pub fn mpc_prove_core(args: ProveArgs) -> Result<Vec<u8>> {

let mut challenger_state = Vec::new();
info_span!("o_absorb_commits").in_scope(|| {
operator_absorb_commits(
operator_absorb_commits::<T>(
&serialize_args,
&commitments_vec,
&records_vec,
Expand All @@ -67,7 +69,7 @@ pub fn mpc_prove_core(args: ProveArgs) -> Result<Vec<u8>> {
let num_workers = records_vec.len();
for (worker_idx, records) in records_vec.into_iter().enumerate() {
let mut shard_proofs = Vec::new();
worker_prove_checkpoint(
worker_prove_checkpoint::<T>(
&serialize_args,
&challenger_state,
records.as_slice(),
Expand All @@ -80,7 +82,7 @@ pub fn mpc_prove_core(args: ProveArgs) -> Result<Vec<u8>> {

let mut proof = Vec::new();
info_span!("o_construct_sp1_core_proof").in_scope(|| {
operator_construct_sp1_core_proof(
operator_construct_sp1_core_proof::<T>(
&serialize_args,
&shard_proofs_vec,
&public_values_stream,
Expand All @@ -93,7 +95,10 @@ pub fn mpc_prove_core(args: ProveArgs) -> Result<Vec<u8>> {
Ok(proof)
}

pub fn scenario_end(args: &ProveArgs, core_proof: &Vec<u8>) -> Result<SP1ProofWithPublicValues> {
pub fn scenario_end<T: Serialize + DeserializeOwned>(
args: &ProveArgs<T>,
core_proof: &Vec<u8>,
) -> Result<SP1ProofWithPublicValues> {
let core_proof_obj: SP1CoreProof = bincode::deserialize(core_proof).unwrap();

let (client, _, _, vk) = common::init_client(args);
Expand Down
18 changes: 13 additions & 5 deletions sdk/src/multi_prover/scenario/plonk_prove.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,28 +5,36 @@ use crate::multi_prover::{
};
use crate::{PlonkBn254Proof, SP1Proof, SP1ProofWithPublicValues};
use anyhow::Result;
use serde::{de::DeserializeOwned, Serialize};
use sp1_prover::SP1CoreProof;
use tracing::info_span;

pub fn mpc_prove_plonk(args: &ProveArgs) -> Result<(Vec<u8>, Vec<u8>, Vec<u8>)> {
pub fn mpc_prove_plonk<T: Serialize + DeserializeOwned>(
args: &ProveArgs<T>,
) -> Result<(Vec<u8>, Vec<u8>, Vec<u8>)> {
let span = info_span!("kroma_core");
let _guard = span.entered();

let (core_proof, compress_proof) = mpc_prove_compress(args).unwrap();
let serialize_args = bincode::serialize(&args).unwrap();

let mut shrink_proof = Vec::new();
info_span!("o_shrink_proof")
.in_scope(|| operator_prove_shrink(&serialize_args, &compress_proof, &mut shrink_proof));
info_span!("o_shrink_proof").in_scope(|| {
operator_prove_shrink::<T>(&serialize_args, &compress_proof, &mut shrink_proof)
});

let mut plonk_proof = Vec::new();
info_span!("o_plonk_proof")
.in_scope(|| operator_prove_plonk(&serialize_args, &shrink_proof, &mut plonk_proof));
.in_scope(|| operator_prove_plonk::<T>(&serialize_args, &shrink_proof, &mut plonk_proof));

Ok((core_proof, compress_proof, plonk_proof))
}

pub fn scenario_end(args: &ProveArgs, core_proof: &Vec<u8>, plonk_proof: &Vec<u8>) {
pub fn scenario_end<T: Serialize + DeserializeOwned>(
args: &ProveArgs<T>,
core_proof: &Vec<u8>,
plonk_proof: &Vec<u8>,
) {
let plonk_proof: PlonkBn254Proof = bincode::deserialize(plonk_proof).unwrap();

let (client, _, _, vk) = common::init_client(args);
Expand Down
Loading

0 comments on commit 6814964

Please sign in to comment.