diff --git a/Cargo.lock b/Cargo.lock index 0bf3bf786..7c8951e52 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1968,6 +1968,7 @@ dependencies = [ "risc0-zkvm", "risc0-zkvm-platform", "serde", + "sov-db", "sov-risc0-adapter", "sov-rollup-interface", "tracing", diff --git a/bin/citrea/src/rollup/bitcoin.rs b/bin/citrea/src/rollup/bitcoin.rs index bd4d2fb5e..a2b362a3c 100644 --- a/bin/citrea/src/rollup/bitcoin.rs +++ b/bin/citrea/src/rollup/bitcoin.rs @@ -135,11 +135,13 @@ impl RollupBlueprint for BitcoinRollup { prover_config: ProverConfig, _rollup_config: &FullNodeConfig, _da_service: &Arc, + ledger_db: LedgerDB, ) -> Self::ProverService { let vm = Risc0BonsaiHost::new( citrea_risc0::BITCOIN_DA_ELF, std::env::var("BONSAI_API_URL").unwrap_or("".to_string()), std::env::var("BONSAI_API_KEY").unwrap_or("".to_string()), + ledger_db.clone(), ); let zk_stf = StfBlueprint::new(); let zk_storage = ZkStorage::new(); @@ -155,6 +157,7 @@ impl RollupBlueprint for BitcoinRollup { da_verifier, prover_config, zk_storage, + ledger_db, ) .expect("Should be able to instantiate prover service") } diff --git a/bin/citrea/src/rollup/mock.rs b/bin/citrea/src/rollup/mock.rs index 2ae87456f..b2da31133 100644 --- a/bin/citrea/src/rollup/mock.rs +++ b/bin/citrea/src/rollup/mock.rs @@ -100,11 +100,13 @@ impl RollupBlueprint for MockDemoRollup { prover_config: ProverConfig, _rollup_config: &FullNodeConfig, _da_service: &Arc, + ledger_db: LedgerDB, ) -> Self::ProverService { let vm = Risc0BonsaiHost::new( citrea_risc0::MOCK_DA_ELF, std::env::var("BONSAI_API_URL").unwrap_or("".to_string()), std::env::var("BONSAI_API_KEY").unwrap_or("".to_string()), + ledger_db.clone(), ); let zk_stf = StfBlueprint::new(); let zk_storage = ZkStorage::new(); @@ -116,6 +118,7 @@ impl RollupBlueprint for MockDemoRollup { da_verifier, prover_config, zk_storage, + ledger_db, ) .expect("Should be able to instantiate prover service") } diff --git a/bin/citrea/src/rollup/mod.rs b/bin/citrea/src/rollup/mod.rs index dd8e9e2e3..b28973697 100644 --- a/bin/citrea/src/rollup/mod.rs +++ b/bin/citrea/src/rollup/mod.rs @@ -230,15 +230,21 @@ pub trait CitreaRollupBlueprint: RollupBlueprint { { let da_service = self.create_da_service(&rollup_config).await?; + let ledger_db = self.create_ledger_db(&rollup_config); + let prover_service = self - .create_prover_service(prover_config.clone(), &rollup_config, &da_service) + .create_prover_service( + prover_config.clone(), + &rollup_config, + &da_service, + ledger_db.clone(), + ) .await; // TODO: Double check what kind of storage needed here. // Maybe whole "prev_root" can be initialized inside runner // Getting block here, so prover_service doesn't have to be `Send` - let ledger_db = self.create_ledger_db(&rollup_config); let genesis_config = self.create_genesis_config(runtime_genesis_paths, &rollup_config)?; let mut storage_manager = self.create_storage_manager(&rollup_config)?; diff --git a/crates/prover/src/prover_service/parallel/mod.rs b/crates/prover/src/prover_service/parallel/mod.rs index 0c8cd0e6b..de4f8704d 100644 --- a/crates/prover/src/prover_service/parallel/mod.rs +++ b/crates/prover/src/prover_service/parallel/mod.rs @@ -7,6 +7,7 @@ use citrea_stf::verifier::StateTransitionVerifier; use prover::Prover; use serde::de::DeserializeOwned; use serde::Serialize; +use sov_db::ledger_db::{LedgerDB, ProvingServiceLedgerOps}; use sov_rollup_interface::da::{DaData, DaSpec}; use sov_rollup_interface::services::da::DaService; use sov_rollup_interface::stf::StateTransitionFunction; @@ -34,6 +35,7 @@ where zk_storage: V::PreState, prover_state: Prover, + ledger_db: LedgerDB, } impl ParallelProverService @@ -62,6 +64,7 @@ where config: ProverGuestRunConfig, zk_storage: V::PreState, num_threads: usize, + ledger_db: LedgerDB, ) -> anyhow::Result { let stf_verifier = StateTransitionVerifier::::new(zk_stf, da_verifier); @@ -96,6 +99,7 @@ where prover_config, prover_state: Prover::new(num_threads)?, zk_storage, + ledger_db, }) } @@ -106,6 +110,7 @@ where da_verifier: Da::Verifier, prover_config: ProverConfig, zk_storage: V::PreState, + ledger_db: LedgerDB, ) -> anyhow::Result { let num_cpus = num_cpus::get(); assert!(num_cpus > 1, "Unable to create parallel prover service"); @@ -117,6 +122,7 @@ where prover_config.proving_mode, zk_storage, num_cpus - 1, + ledger_db, ) } } @@ -192,6 +198,7 @@ where .send_transaction(da_data) .await .map_err(|e| anyhow::anyhow!(e))?; + self.ledger_db.clear_pending_proving_sessions()?; break Ok((tx_id, proof)); } ProverStatus::ProvingInProgress => { @@ -203,4 +210,27 @@ where } } } + + async fn recover_proving_sessions_and_send_to_da( + &self, + da_service: &Arc, + ) -> Result::TransactionId, Proof)>, anyhow::Error> { + tracing::info!("Checking if ongoing bonsai session exists"); + + let vm = self.vm.clone(); + let proofs = vm.recover_proving_sessions()?; + + let mut results = Vec::new(); + + for proof in proofs.into_iter() { + let da_data = DaData::ZKProof(proof.clone()); + let tx_id = da_service + .send_transaction(da_data) + .await + .map_err(|e| anyhow::anyhow!(e))?; + results.push((tx_id, proof)); + } + self.ledger_db.clear_pending_proving_sessions()?; + Ok(results) + } } diff --git a/crates/prover/src/runner.rs b/crates/prover/src/runner.rs index 01dc76b6c..25c80a4d5 100644 --- a/crates/prover/src/runner.rs +++ b/crates/prover/src/runner.rs @@ -224,6 +224,24 @@ where }); } + async fn check_and_recover_ongoing_proving_sessions(&self) -> Result { + let prover_service = self + .prover_service + .as_ref() + .expect("Prover service should be present"); + let results = prover_service + .recover_proving_sessions_and_send_to_da(&self.da_service) + .await?; + if results.is_empty() { + Ok(false) + } else { + for (tx_id, proof) in results { + self.extract_and_store_proof(tx_id, proof).await?; + } + Ok(true) + } + } + /// Runs the rollup. #[instrument(level = "trace", skip_all, err)] pub async fn run(&mut self) -> Result<(), anyhow::Error> { @@ -416,6 +434,7 @@ where skip_submission_until_l1: u64, prover_config: &ProverConfig, ) -> Result<(), anyhow::Error> { + let mut proving_session_exists = self.check_and_recover_ongoing_proving_sessions().await?; while !pending_l1_blocks.is_empty() { let l1_block = pending_l1_blocks .front() @@ -431,7 +450,8 @@ where ) .unwrap(); - let mut da_data = self.da_service.extract_relevant_blobs(l1_block); + let mut da_data: Vec<<::Spec as DaSpec>::BlobTransaction> = + self.da_service.extract_relevant_blobs(l1_block); // if we don't do this, the zk circuit can't read the sequencer commitments da_data.iter_mut().for_each(|blob| { blob.full_data(); @@ -472,102 +492,51 @@ where break; } - let sequencer_commitments_groups = - self.break_sequencer_commitments_into_groups(sequencer_commitments)?; - - for sequencer_commitments in sequencer_commitments_groups { - let first_l2_height_of_l1 = sequencer_commitments[0].l2_start_block_number; - let last_l2_height_of_l1 = - sequencer_commitments[sequencer_commitments.len() - 1].l2_end_block_number; - - let ( - state_transition_witnesses, - soft_confirmations, - da_block_headers_of_soft_confirmations, - ) = self - .get_state_transition_data_from_commitments( - &sequencer_commitments, - &self.da_service, + let da_block_header_of_commitments: <::Spec as DaSpec>::BlockHeader = + l1_block.header().clone(); + + let hash = da_block_header_of_commitments.hash(); + + if !proving_session_exists { + let sequencer_commitments_groups = + self.break_sequencer_commitments_into_groups(sequencer_commitments)?; + for sequencer_commitments in sequencer_commitments_groups { + // There is no ongoing bonsai session to recover + let transition_data: StateTransitionData< + Stf::StateRoot, + Stf::Witness, + Da::Spec, + > = self + .create_state_transition_data( + &sequencer_commitments, + da_block_header_of_commitments.clone(), + da_data.clone(), + l1_block, + ) + .await?; + + self.prove_state_transition( + transition_data, + prover_config, + skip_submission_until_l1, + l1_height, + hash.clone(), ) .await?; + proving_session_exists = false; - let da_block_header_of_commitments = l1_block.header().clone(); - - let hash = da_block_header_of_commitments.hash(); - let initial_state_root = self - .ledger_db - .get_l2_state_root::(first_l2_height_of_l1 - 1)? - .expect("There should be a state root"); - let initial_batch_hash = self - .ledger_db - .get_soft_confirmation_by_number(&BatchNumber(first_l2_height_of_l1))? - .ok_or(anyhow!( - "Could not find soft confirmation at height {}", - first_l2_height_of_l1 - ))? - .prev_hash; - - let final_state_root = self - .ledger_db - .get_l2_state_root::(last_l2_height_of_l1)? - .expect("There should be a state root"); - - let (inclusion_proof, completeness_proof) = self - .da_service - .get_extraction_proof(l1_block, &da_data.clone()) - .await; - - let transition_data: StateTransitionData = - StateTransitionData { - initial_state_root, - final_state_root, - initial_batch_hash, - da_data: da_data.clone(), - da_block_header_of_commitments, - inclusion_proof, - completeness_proof, - soft_confirmations, - state_transition_witnesses, - da_block_headers_of_soft_confirmations, - sequencer_commitments_range: ( - 0, - (sequencer_commitments.len() - 1).try_into().expect( - "cant be more than 4 billion commitments in a da block; qed", - ), - ), // for now process all commitments - sequencer_public_key: self.sequencer_pub_key.clone(), - sequencer_da_public_key: self.sequencer_da_pub_key.clone(), - }; - - let should_prove: bool = { - let mut rng = rand::thread_rng(); - // if proof_sampling_number is 0, then we always prove and submit - // otherwise we submit and prove with a probability of 1/proof_sampling_number - if prover_config.proof_sampling_number == 0 { - true - } else { - rng.gen_range(0..prover_config.proof_sampling_number) == 0 - } - }; - - // Skip submission until l1 height - if l1_height >= skip_submission_until_l1 && should_prove { - self.generate_and_submit_proof(transition_data, l1_height, hash) - .await?; - } else { - info!("Skipping proving for l1 height {}", l1_height); + self.save_commitments(sequencer_commitments, l1_height); } - self.save_commitments(sequencer_commitments, l1_height); + } - if let Err(e) = self - .ledger_db - .set_last_scanned_l1_height(SlotNumber(l1_height)) - { - panic!( - "Failed to put prover last scanned l1 height in the ledger db: {}", - e - ); - } + if let Err(e) = self + .ledger_db + .set_last_scanned_l1_height(SlotNumber(l1_height)) + { + panic!( + "Failed to put prover last scanned l1 height in the ledger db: {}", + e + ); } pending_l1_blocks.pop_front(); @@ -575,6 +544,93 @@ where Ok(()) } + async fn prove_state_transition( + &self, + transition_data: StateTransitionData, + prover_config: &ProverConfig, + skip_submission_until_l1: u64, + l1_height: u64, + hash: <::Spec as DaSpec>::SlotHash, + ) -> Result<(), anyhow::Error> { + // if proof_sampling_number is 0, then we always prove and submit + // otherwise we submit and prove with a probability of 1/proof_sampling_number + let should_prove = prover_config.proof_sampling_number == 0 + || rand::thread_rng().gen_range(0..prover_config.proof_sampling_number) == 0; + + // Skip submission until l1 height + if l1_height >= skip_submission_until_l1 && should_prove { + self.generate_and_submit_proof(transition_data, hash) + .await?; + } else { + info!("Skipping proving for l1 height {}", l1_height); + } + Ok(()) + } + + async fn create_state_transition_data( + &self, + sequencer_commitments: &[SequencerCommitment], + da_block_header_of_commitments: <::Spec as DaSpec>::BlockHeader, + da_data: Vec<<::Spec as DaSpec>::BlobTransaction>, + l1_block: &Da::FilteredBlock, + ) -> Result, anyhow::Error> { + let first_l2_height_of_l1 = sequencer_commitments[0].l2_start_block_number; + let last_l2_height_of_l1 = + sequencer_commitments[sequencer_commitments.len() - 1].l2_end_block_number; + let ( + state_transition_witnesses, + soft_confirmations, + da_block_headers_of_soft_confirmations, + ) = self + .get_state_transition_data_from_commitments(sequencer_commitments, &self.da_service) + .await?; + let initial_state_root = self + .ledger_db + .get_l2_state_root::(first_l2_height_of_l1 - 1)? + .expect("There should be a state root"); + let initial_batch_hash = self + .ledger_db + .get_soft_confirmation_by_number(&BatchNumber(first_l2_height_of_l1))? + .ok_or(anyhow!( + "Could not find soft batch at height {}", + first_l2_height_of_l1 + ))? + .prev_hash; + + let final_state_root = self + .ledger_db + .get_l2_state_root::(last_l2_height_of_l1)? + .expect("There should be a state root"); + + let (inclusion_proof, completeness_proof) = self + .da_service + .get_extraction_proof(l1_block, &da_data) + .await; + + let transition_data: StateTransitionData = + StateTransitionData { + initial_state_root, + final_state_root, + initial_batch_hash, + da_data, + da_block_header_of_commitments, + inclusion_proof, + completeness_proof, + soft_confirmations, + state_transition_witnesses, + da_block_headers_of_soft_confirmations, + sequencer_commitments_range: ( + 0, + (sequencer_commitments.len() - 1) + .try_into() + .expect("cant be more than 4 billion commitments in a da block; qed"), + ), // for now process all commitments + sequencer_public_key: self.sequencer_pub_key.clone(), + sequencer_da_public_key: self.sequencer_da_pub_key.clone(), + }; + Ok(transition_data) + } + fn extract_sequencer_commitments( &self, l1_block_hash: [u8; 32], @@ -701,7 +757,6 @@ where async fn generate_and_submit_proof( &self, transition_data: StateTransitionData, - l1_height: u64, hash: <::Spec as DaSpec>::SlotHash, ) -> Result<(), anyhow::Error> { let prover_service = self @@ -723,6 +778,14 @@ where } }; + self.extract_and_store_proof(tx_id, proof).await + } + + async fn extract_and_store_proof( + &self, + tx_id: ::TransactionId, + proof: Proof, + ) -> Result<(), anyhow::Error> { let tx_id_u8 = tx_id.into(); // l1_height => (tx_id, proof, transition_data) @@ -754,16 +817,22 @@ where info!("transition data: {:?}", transition_data); + let slot_hash = transition_data.da_slot_hash.into(); + let stored_state_transition = StoredStateTransition { initial_state_root: transition_data.initial_state_root.as_ref().to_vec(), final_state_root: transition_data.final_state_root.as_ref().to_vec(), state_diff: transition_data.state_diff, - da_slot_hash: transition_data.da_slot_hash.into(), + da_slot_hash: slot_hash, sequencer_commitments_range: transition_data.sequencer_commitments_range, sequencer_public_key: transition_data.sequencer_public_key, sequencer_da_public_key: transition_data.sequencer_da_public_key, validity_condition: borsh::to_vec(&transition_data.validity_condition).unwrap(), }; + let l1_height = self + .ledger_db + .get_l1_height_of_l1_hash(slot_hash)? + .expect("l1 height should exist"); if let Err(e) = self.ledger_db diff --git a/crates/prover/tests/prover_tests.rs b/crates/prover/tests/prover_tests.rs index c869f4cb5..df0102735 100644 --- a/crates/prover/tests/prover_tests.rs +++ b/crates/prover/tests/prover_tests.rs @@ -2,6 +2,7 @@ use std::collections::VecDeque; use std::sync::Arc; use citrea_prover::prover_service::ParallelProverService; +use sov_db::ledger_db::LedgerDB; use sov_mock_da::{ MockAddress, MockBlockHeader, MockDaService, MockDaSpec, MockDaVerifier, MockHash, MockValidityCond, @@ -15,7 +16,7 @@ use sov_stf_runner::{ WitnessSubmissionStatus, }; -#[tokio::test] +#[tokio::test(flavor = "multi_thread")] async fn test_successful_prover_execution() -> Result<(), ProverServiceError> { let temp = tempfile::tempdir().unwrap(); @@ -49,7 +50,7 @@ async fn test_successful_prover_execution() -> Result<(), ProverServiceError> { Ok(()) } -#[tokio::test] +#[tokio::test(flavor = "multi_thread")] async fn test_prover_status_busy() -> Result<(), anyhow::Error> { let temp = tempfile::tempdir().unwrap(); let da_service = Arc::new(MockDaService::new(MockAddress::from([0; 32]), temp.path())); @@ -190,6 +191,8 @@ fn make_new_prover() -> TestProver { let prover_config = ProverGuestRunConfig::Execute; let zk_stf = MockStf::::default(); let da_verifier = MockDaVerifier::default(); + let tmpdir = tempfile::tempdir().unwrap(); + let ledger_db = LedgerDB::with_path(tmpdir.path()).unwrap(); TestProver { prover_service: ParallelProverService::new( vm.clone(), @@ -198,6 +201,7 @@ fn make_new_prover() -> TestProver { prover_config, (), num_threads, + ledger_db, ) .expect("Should be able to instantiate Prover service"), vm, diff --git a/crates/risc0-bonsai/Cargo.toml b/crates/risc0-bonsai/Cargo.toml index a3d1cfe77..ad944f7c9 100644 --- a/crates/risc0-bonsai/Cargo.toml +++ b/crates/risc0-bonsai/Cargo.toml @@ -26,6 +26,7 @@ risc0-zkp = { workspace = true, optional = true } risc0-zkvm = { workspace = true, default-features = false, features = ["std"] } risc0-zkvm-platform = { workspace = true } serde = { workspace = true } +sov-db = { path = "../sovereign-sdk/full-node/db/sov-db" } sov-rollup-interface = { path = "../sovereign-sdk/rollup-interface" } tracing = { workspace = true } diff --git a/crates/risc0-bonsai/src/host.rs b/crates/risc0-bonsai/src/host.rs index 16f1e9fa5..b829e938d 100644 --- a/crates/risc0-bonsai/src/host.rs +++ b/crates/risc0-bonsai/src/host.rs @@ -11,10 +11,32 @@ use risc0_zkvm::{ compute_image_id, ExecutorEnvBuilder, ExecutorImpl, Groth16Receipt, InnerReceipt, Journal, Receipt, }; +use sov_db::ledger_db::{LedgerDB, ProvingServiceLedgerOps}; use sov_risc0_adapter::guest::Risc0Guest; use sov_rollup_interface::zk::{Proof, Zkvm, ZkvmHost}; use tracing::{error, info, warn}; +type StarkSessionId = String; +type SnarkSessionId = String; + +/// Bonsai sessions to be recovered in case of a crash. +#[derive(Debug, Clone, BorshSerialize, BorshDeserialize)] +pub enum BonsaiSession { + /// Stark session id if the prover crashed during stark proof generation. + StarkSession(StarkSessionId), + /// Both Stark and Snark session id if the prover crashed during stark to snarkconversion. + SnarkSession(StarkSessionId, SnarkSessionId), +} + +/// Recovered bonsai session. +#[derive(Debug, Clone, BorshSerialize, BorshDeserialize)] +pub struct RecoveredBonsaiSession { + /// Used for sending proofs in order + pub id: u8, + /// Recovered session + pub session: BonsaiSession, +} + macro_rules! retry_backoff_bonsai { ($bonsai_call:expr) => { retry_backoff( @@ -65,11 +87,12 @@ pub struct Risc0BonsaiHost<'a> { image_id: Digest, client: Option, last_input_id: Option, + ledger_db: LedgerDB, } impl<'a> Risc0BonsaiHost<'a> { /// Create a new Risc0Host to prove the given binary. - pub fn new(elf: &'a [u8], api_url: String, api_key: String) -> Self { + pub fn new(elf: &'a [u8], api_url: String, api_key: String, ledger_db: LedgerDB) -> Self { // Compute the image_id, then upload the ELF with the image_id as its key. // handle error let image_id = compute_image_id(elf).unwrap(); @@ -98,6 +121,7 @@ impl<'a> Risc0BonsaiHost<'a> { image_id, client, last_input_id: None, + ledger_db, } } @@ -112,6 +136,126 @@ impl<'a> Risc0BonsaiHost<'a> { tracing::info!("Uploaded input with id: {}", input_id); self.last_input_id = Some(input_id); } + + fn receipt_loop(&self, session: &str, client: &Client) -> Result, anyhow::Error> { + let session = bonsai_sdk::blocking::SessionId::new(session.to_owned()); + loop { + // handle error + let res = + retry_backoff_bonsai!(session.status(client)).expect("Failed to fetch status; qed"); + + if res.status == "RUNNING" { + tracing::info!( + "Current status: {} - state: {} - continue polling...", + res.status, + res.state.unwrap_or_default() + ); + std::thread::sleep(Duration::from_secs(15)); + continue; + } + if res.status == "SUCCEEDED" { + // Download the receipt, containing the output + let receipt_url = res + .receipt_url + .expect("API error, missing receipt on completed session"); + + tracing::info!("Receipt URL: {}", receipt_url); + + let receipt_buf = retry_backoff_bonsai!(client.download(receipt_url.as_str())) + .expect("Failed to download receipt; qed"); + break Ok(receipt_buf); + } else { + return Err(anyhow!( + "Workflow exited: {} with error message: {}", + res.status, + res.error_msg.unwrap_or_default() + )); + } + } + } + + fn wait_for_receipt(&self, session: &str) -> Result, anyhow::Error> { + let session = bonsai_sdk::blocking::SessionId::new(session.to_string()); + let client = self.client.as_ref().unwrap(); + self.receipt_loop(&session.uuid, client) + } + + fn wait_for_stark_to_snark_conversion( + &self, + snark_session: Option<&str>, + stark_session: &str, + receipt_buf: Vec, + ) -> Result { + // If snark session exists use it else create one from stark + let snark_session = match snark_session { + Some(snark_session) => bonsai_sdk::blocking::SnarkId::new(snark_session.to_string()), + None => { + let client = self.client.as_ref().unwrap(); + let session = bonsai_sdk::blocking::SessionId::new(stark_session.to_string()); + retry_backoff_bonsai!(client.create_snark(session.uuid.clone())) + .expect("Failed to create snark session; qed") + } + }; + + let recovered_serialized_snark_session = borsh::to_vec(&RecoveredBonsaiSession { + id: 0, + session: BonsaiSession::SnarkSession( + stark_session.to_string(), + snark_session.uuid.clone(), + ), + })?; + self.ledger_db + .add_pending_proving_session(recovered_serialized_snark_session.clone())?; + + let client = self.client.as_ref().unwrap(); + let receipt: Receipt = bincode::deserialize(&receipt_buf).unwrap(); + loop { + let res = retry_backoff_bonsai!(snark_session.status(client)) + .expect("Failed to fetch status; qed"); + match res.status.as_str() { + "RUNNING" => { + tracing::info!("Current status: {} - continue polling...", res.status,); + std::thread::sleep(Duration::from_secs(15)); + continue; + } + "SUCCEEDED" => { + let snark_receipt = match res.output { + Some(output) => output, + None => { + return Err(anyhow!( + "SNARK session succeeded but no output was provided" + )) + } + }; + tracing::info!("Snark proof!: {snark_receipt:?}"); + + // now we convert the snark_receipt to a full receipt + + use risc0_zkvm::sha::Digestible; + let inner = InnerReceipt::Groth16(Groth16Receipt::new( + snark_receipt.snark.to_vec(), + receipt.claim().expect("stark_2_snark error, receipt claim"), + risc0_zkvm::Groth16ReceiptVerifierParameters::default().digest(), + )); + + let full_snark_receipt = Receipt::new(inner, snark_receipt.journal); + + tracing::info!("Full snark proof!: {full_snark_receipt:?}"); + + let full_snark_receipt = bincode::serialize(&full_snark_receipt)?; + + return Ok(Proof::Full(full_snark_receipt)); + } + _ => { + return Err(anyhow!( + "Workflow exited: {} with error message: {}", + res.status, + res.error_msg.unwrap_or_default() + )); + } + } + } + } } impl<'a> ZkvmHost for Risc0BonsaiHost<'a> { @@ -179,105 +323,36 @@ impl<'a> ZkvmHost for Risc0BonsaiHost<'a> { false )) .expect("Failed to create session; qed"); - tracing::info!("Session created: {}", session.uuid); - let receipt = loop { - // handle error - let res = retry_backoff_bonsai!(session.status(client)) - .expect("Failed to fetch status; qed"); - - if res.status == "RUNNING" { - tracing::info!( - "Current status: {} - state: {} - continue polling...", - res.status, - res.state.unwrap_or_default() - ); - std::thread::sleep(Duration::from_secs(15)); - continue; - } - if res.status == "SUCCEEDED" { - // Download the receipt, containing the output - let receipt_url = res - .receipt_url - .expect("API error, missing receipt on completed session"); - - tracing::info!("Receipt URL: {}", receipt_url); - if let Some(stats) = res.stats { - tracing::info!( - "User cycles: {} - Total cycles: {} - Segments: {}", - stats.cycles, - stats.total_cycles, - stats.segments, - ); - } - - let receipt_buf = retry_backoff_bonsai!(client.download(receipt_url.as_str())) - .expect("Failed to download receipt; qed"); + let stark_session = RecoveredBonsaiSession { + id: 0, + session: BonsaiSession::StarkSession(session.uuid.clone()), + }; + let serialized_stark_session = borsh::to_vec(&stark_session) + .expect("Bonsai host should be able to serialize bonsai sessions"); + self.ledger_db + .add_pending_proving_session(serialized_stark_session.clone())?; - let receipt: Receipt = bincode::deserialize(&receipt_buf).unwrap(); + tracing::info!("Session created: {}", session.uuid); - break receipt; - } else { - return Err(anyhow!( - "Workflow exited: {} with error message: {}", - res.status, - res.error_msg.unwrap_or_default() - )); - } - }; + let receipt = self.wait_for_receipt(&session.uuid)?; tracing::info!("Creating the SNARK"); let snark_session = retry_backoff_bonsai!(client.create_snark(session.uuid.clone())) .expect("Failed to create snark session; qed"); - tracing::info!("SNARK session created: {}", snark_session.uuid); - - loop { - let res = retry_backoff_bonsai!(snark_session.status(client)) - .expect("Failed to fetch status; qed"); - match res.status.as_str() { - "RUNNING" => { - tracing::info!("Current status: {} - continue polling...", res.status,); - std::thread::sleep(Duration::from_secs(15)); - continue; - } - "SUCCEEDED" => { - let snark_receipt = match res.output { - Some(output) => output, - None => { - return Err(anyhow!( - "SNARK session succeeded but no output was provided" - )) - } - }; - tracing::info!("Snark proof!: {snark_receipt:?}"); - - // now we convert the snark_receipt to a full receipt - - use risc0_zkvm::sha::Digestible; - let inner = InnerReceipt::Groth16(Groth16Receipt::new( - snark_receipt.snark.to_vec(), - receipt.claim().expect("stark_2_snark error, receipt claim"), - risc0_zkvm::Groth16ReceiptVerifierParameters::default().digest(), - )); + // Remove the stark session as it is finished + self.ledger_db + .remove_pending_proving_session(serialized_stark_session.clone())?; - let full_snark_receipt = Receipt::new(inner, snark_receipt.journal); - - tracing::info!("Full snark proof!: {full_snark_receipt:?}"); - - let full_snark_receipt = bincode::serialize(&full_snark_receipt)?; + tracing::info!("SNARK session created: {}", snark_session.uuid); - return Ok(Proof::Full(full_snark_receipt)); - } - _ => { - return Err(anyhow!( - "Workflow exited: {} with error message: {}", - res.status, - res.error_msg.unwrap_or_default() - )); - } - } - } + // Snark session is saved in the function + self.wait_for_stark_to_snark_conversion( + Some(&snark_session.uuid), + &session.uuid, + receipt, + ) } } @@ -296,6 +371,33 @@ impl<'a> ZkvmHost for Risc0BonsaiHost<'a> { }; Ok(BorshDeserialize::try_from_slice(&journal.bytes)?) } + + fn recover_proving_sessions(&self) -> Result, anyhow::Error> { + let sessions = self.ledger_db.get_pending_proving_sessions()?; + let mut proofs = Vec::new(); + for session in sessions { + let bonsai_session: RecoveredBonsaiSession = BorshDeserialize::try_from_slice(&session) + .expect("Bonsai host should be able to recover bonsai sessions"); + match bonsai_session.session { + BonsaiSession::StarkSession(stark_session) => { + let receipt = self.wait_for_receipt(&stark_session)?; + let proof = + self.wait_for_stark_to_snark_conversion(None, &stark_session, receipt)?; + proofs.push(proof); + } + BonsaiSession::SnarkSession(stark_session, snark_session) => { + let receipt = self.wait_for_receipt(&stark_session)?; + let proof = self.wait_for_stark_to_snark_conversion( + Some(&snark_session), + &stark_session, + receipt, + )?; + proofs.push(proof) + } + } + } + Ok(proofs) + } } impl<'host> Zkvm for Risc0BonsaiHost<'host> { diff --git a/crates/sovereign-sdk/adapters/mock-zkvm/src/lib.rs b/crates/sovereign-sdk/adapters/mock-zkvm/src/lib.rs index 847275814..620a20b57 100644 --- a/crates/sovereign-sdk/adapters/mock-zkvm/src/lib.rs +++ b/crates/sovereign-sdk/adapters/mock-zkvm/src/lib.rs @@ -9,7 +9,7 @@ use anyhow::ensure; use borsh::{BorshDeserialize, BorshSerialize}; use serde::{Deserialize, Serialize}; use sov_rollup_interface::da::BlockHeaderTrait; -use sov_rollup_interface::zk::{Matches, StateTransitionData, ValidityCondition}; +use sov_rollup_interface::zk::{Matches, Proof, StateTransitionData, ValidityCondition}; /// A mock commitment to a particular zkVM program. #[derive(Debug, Clone, PartialEq, Eq, BorshDeserialize, BorshSerialize, Serialize, Deserialize)] @@ -196,6 +196,10 @@ impl sov_rollup_interface::zk::ZkvmHost } } } + + fn recover_proving_sessions(&self) -> Result, anyhow::Error> { + unimplemented!() + } } /// A mock implementing the Guest. diff --git a/crates/sovereign-sdk/adapters/risc0/src/host.rs b/crates/sovereign-sdk/adapters/risc0/src/host.rs index 4c2e3e922..f1bcf2ce4 100644 --- a/crates/sovereign-sdk/adapters/risc0/src/host.rs +++ b/crates/sovereign-sdk/adapters/risc0/src/host.rs @@ -105,6 +105,10 @@ impl<'a> ZkvmHost for Risc0Host<'a> { }; Ok(BorshDeserialize::deserialize(&mut journal.bytes.as_ref())?) } + + fn recover_proving_sessions(&self) -> Result, anyhow::Error> { + unimplemented!() + } } impl<'host> Zkvm for Risc0Host<'host> { diff --git a/crates/sovereign-sdk/full-node/db/sov-db/src/ledger_db/mod.rs b/crates/sovereign-sdk/full-node/db/sov-db/src/ledger_db/mod.rs index cde955a72..cdd85b82a 100644 --- a/crates/sovereign-sdk/full-node/db/sov-db/src/ledger_db/mod.rs +++ b/crates/sovereign-sdk/full-node/db/sov-db/src/ledger_db/mod.rs @@ -16,9 +16,10 @@ use crate::rocks_db_config::gen_rocksdb_options; use crate::schema::tables::{ ActiveFork, BatchByNumber, CommitmentsByNumber, EventByKey, EventByNumber, L2GenesisStateRoot, L2RangeByL1Height, L2Witness, LastSequencerCommitmentSent, LastStateDiff, MempoolTxs, - PendingSequencerCommitmentL2Range, ProofBySlotNumber, ProverLastScannedSlot, ProverStateDiffs, - SlotByHash, SlotByNumber, SoftConfirmationByHash, SoftConfirmationByNumber, - SoftConfirmationStatus, TxByHash, TxByNumber, VerifiedProofsBySlotNumber, LEDGER_TABLES, + PendingProvingSessions, PendingSequencerCommitmentL2Range, ProofBySlotNumber, + ProverLastScannedSlot, ProverStateDiffs, SlotByHash, SlotByNumber, SoftConfirmationByHash, + SoftConfirmationByNumber, SoftConfirmationStatus, TxByHash, TxByNumber, + VerifiedProofsBySlotNumber, LEDGER_TABLES, }; use crate::schema::types::{ split_tx_for_storage, BatchNumber, EventNumber, L2HeightRange, SlotNumber, StoredProof, @@ -342,6 +343,12 @@ impl SharedLedgerOps for LedgerDB { self.db.put::(&hash, &SlotNumber(height)) } + /// Gets l1 height of l1 hash + #[instrument(level = "trace", skip(self), err, ret)] + fn get_l1_height_of_l1_hash(&self, hash: [u8; 32]) -> Result, anyhow::Error> { + self.db.get::(&hash).map(|v| v.map(|a| a.0)) + } + /// Saves a soft confirmation status for a given L1 height #[instrument(level = "trace", skip(self), err, ret)] fn put_soft_confirmation_status( @@ -569,6 +576,46 @@ impl ProverLedgerOps for LedgerDB { } } +impl ProvingServiceLedgerOps for LedgerDB { + /// Gets all pending sessions and step numbers + #[instrument(level = "trace", skip(self), err)] + fn get_pending_proving_sessions(&self) -> anyhow::Result>> { + let mut iter = self.db.iter::()?; + iter.seek_to_first(); + + let sessions = iter + .map(|item| item.map(|item| (item.key))) + .collect::, _>>()?; + Ok(sessions) + } + + #[instrument(level = "trace", skip(self), err)] + fn add_pending_proving_session(&self, session: Vec) -> anyhow::Result<()> { + self.db.put::(&session, &()) + } + + #[instrument(level = "trace", skip(self), err)] + fn remove_pending_proving_session(&self, session: Vec) -> anyhow::Result<()> { + self.db.delete::(&session) + } + + #[instrument(level = "trace", skip(self), err)] + fn clear_pending_proving_sessions(&self) -> anyhow::Result<()> { + let mut schema_batch = SchemaBatch::new(); + let mut iter = self.db.iter::()?; + iter.seek_to_first(); + + for item in iter { + let item = item?; + schema_batch.delete::(&item.key)?; + } + + self.db.write_schemas(schema_batch)?; + + Ok(()) + } +} + impl SequencerLedgerOps for LedgerDB { /// Put slots #[instrument(level = "trace", skip(self, schema_batch), err, ret)] @@ -748,12 +795,6 @@ impl NodeLedgerOps for LedgerDB { ) -> anyhow::Result>> { self.db.get::(&SlotNumber(height)) } - - /// Gets l1 height of l1 hash - #[instrument(level = "trace", skip(self), err, ret)] - fn get_l1_height_of_l1_hash(&self, hash: [u8; 32]) -> Result, anyhow::Error> { - self.db.get::(&hash).map(|v| v.map(|a| a.0)) - } } impl ForkMigration for LedgerDB { diff --git a/crates/sovereign-sdk/full-node/db/sov-db/src/ledger_db/traits.rs b/crates/sovereign-sdk/full-node/db/sov-db/src/ledger_db/traits.rs index f064eaaed..05eafdbab 100644 --- a/crates/sovereign-sdk/full-node/db/sov-db/src/ledger_db/traits.rs +++ b/crates/sovereign-sdk/full-node/db/sov-db/src/ledger_db/traits.rs @@ -69,6 +69,9 @@ pub trait SharedLedgerOps { /// Sets l1 height of l1 hash fn set_l1_height_of_l1_hash(&self, hash: [u8; 32], height: u64) -> Result<()>; + /// Gets l1 height of l1 hash + fn get_l1_height_of_l1_hash(&self, hash: [u8; 32]) -> Result>; + /// Saves a soft confirmation status for a given L1 height fn put_soft_confirmation_status( &self, @@ -141,13 +144,10 @@ pub trait NodeLedgerOps: SharedLedgerOps { /// Gets the commitments in the da slot with given height if any fn get_commitments_on_da_slot(&self, height: u64) -> Result>>; - - /// Gets l1 height of l1 hash - fn get_l1_height_of_l1_hash(&self, hash: [u8; 32]) -> Result>; } /// Prover ledger operations -pub trait ProverLedgerOps: SharedLedgerOps { +pub trait ProverLedgerOps: SharedLedgerOps + Send + Sync { /// Get the state root by L2 height fn get_l2_state_root( &self, @@ -183,6 +183,21 @@ pub trait ProverLedgerOps: SharedLedgerOps { fn get_l2_state_diff(&self, l2_height: BatchNumber) -> Result>; } +/// Ledger operations for the prover service +pub trait ProvingServiceLedgerOps: ProverLedgerOps + SharedLedgerOps + Send + Sync { + /// Gets all pending sessions and step numbers + fn get_pending_proving_sessions(&self) -> Result>>; + + /// Adds a pending proving session + fn add_pending_proving_session(&self, session: Vec) -> Result<()>; + + /// Removes a pending proving session + fn remove_pending_proving_session(&self, session: Vec) -> Result<()>; + + /// Clears all pending proving sessions + fn clear_pending_proving_sessions(&self) -> Result<()>; +} + /// Sequencer ledger operations pub trait SequencerLedgerOps: SharedLedgerOps { /// Put slots diff --git a/crates/sovereign-sdk/full-node/db/sov-db/src/schema/tables.rs b/crates/sovereign-sdk/full-node/db/sov-db/src/schema/tables.rs index f089a33d6..9bbcb1037 100644 --- a/crates/sovereign-sdk/full-node/db/sov-db/src/schema/tables.rs +++ b/crates/sovereign-sdk/full-node/db/sov-db/src/schema/tables.rs @@ -74,6 +74,7 @@ pub const LEDGER_TABLES: &[&str] = &[ ProofBySlotNumber::table_name(), VerifiedProofsBySlotNumber::table_name(), MempoolTxs::table_name(), + PendingProvingSessions::table_name(), ProverStateDiffs::table_name(), ]; @@ -333,6 +334,12 @@ define_table_with_default_codec!( (VerifiedProofsBySlotNumber) SlotNumber => Vec ); +define_table_with_seek_key_codec!( + /// Proving service uses this table to store pending proving sessions + /// If a session id is completed, remove it + (PendingProvingSessions) Vec => () +); + define_table_with_default_codec!( /// Transactions in mempool (TxHash, TxData) (MempoolTxs) Vec => Vec diff --git a/crates/sovereign-sdk/full-node/sov-stf-runner/src/prover_service/mod.rs b/crates/sovereign-sdk/full-node/sov-stf-runner/src/prover_service/mod.rs index 82900a561..2293bf659 100644 --- a/crates/sovereign-sdk/full-node/sov-stf-runner/src/prover_service/mod.rs +++ b/crates/sovereign-sdk/full-node/sov-stf-runner/src/prover_service/mod.rs @@ -113,4 +113,10 @@ pub trait ProverService { block_header_hash: <::Spec as DaSpec>::SlotHash, da_service: &Arc, ) -> Result<(::TransactionId, Proof), anyhow::Error>; + + /// Recovers pending proving sessions and sends proofs to the DA. + async fn recover_proving_sessions_and_send_to_da( + &self, + da_service: &Arc, + ) -> Result::TransactionId, Proof)>, anyhow::Error>; } diff --git a/crates/sovereign-sdk/module-system/sov-modules-rollup-blueprint/src/lib.rs b/crates/sovereign-sdk/module-system/sov-modules-rollup-blueprint/src/lib.rs index 614f7ddad..5fd0b8b74 100644 --- a/crates/sovereign-sdk/module-system/sov-modules-rollup-blueprint/src/lib.rs +++ b/crates/sovereign-sdk/module-system/sov-modules-rollup-blueprint/src/lib.rs @@ -112,6 +112,7 @@ pub trait RollupBlueprint: Sized + Send + Sync { prover_config: ProverConfig, rollup_config: &FullNodeConfig, da_service: &Arc, + ledger_db: LedgerDB, ) -> Self::ProverService; /// Creates instance of [`Self::StorageManager`]. diff --git a/crates/sovereign-sdk/rollup-interface/src/state_machine/zk/mod.rs b/crates/sovereign-sdk/rollup-interface/src/state_machine/zk/mod.rs index 3e085a5a5..5d378b115 100644 --- a/crates/sovereign-sdk/rollup-interface/src/state_machine/zk/mod.rs +++ b/crates/sovereign-sdk/rollup-interface/src/state_machine/zk/mod.rs @@ -54,6 +54,9 @@ pub trait ZkvmHost: Zkvm + Clone { fn extract_output( proof: &Proof, ) -> Result, Self::Error>; + + /// Host recovers pending proving sessions and returns proving results + fn recover_proving_sessions(&self) -> Result, anyhow::Error>; } /// A Zk proof system capable of proving and verifying arbitrary Rust code