From 3441722fc5f368645c3e7f568b2261cb08c41a02 Mon Sep 17 00:00:00 2001 From: Esad Yusuf Atik Date: Wed, 14 Aug 2024 11:51:39 +0300 Subject: [PATCH 1/5] Use bonsai sdk blocking client (#980) * remove bonsai sdk client wrapper queue * Update crates/risc0-bonsai/src/host.rs Co-authored-by: Roman * use backoff in risc0-bonsai --------- Co-authored-by: Roman --- Cargo.lock | 1 + crates/risc0-bonsai/Cargo.toml | 1 + crates/risc0-bonsai/src/host.rs | 330 ++++++-------------------------- 3 files changed, 59 insertions(+), 273 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 07fd7249d9..7abb786a6b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1955,6 +1955,7 @@ name = "citrea-risc0-bonsai-adapter" version = "0.4.0-rc.3" dependencies = [ "anyhow", + "backoff", "bincode", "bonsai-sdk", "borsh", diff --git a/crates/risc0-bonsai/Cargo.toml b/crates/risc0-bonsai/Cargo.toml index c73f319f44..a3d1cfe776 100644 --- a/crates/risc0-bonsai/Cargo.toml +++ b/crates/risc0-bonsai/Cargo.toml @@ -13,6 +13,7 @@ description = "An adapter allowing Citrea to connect with Bonsai" [dependencies] anyhow = { workspace = true } +backoff = { workspace = true } bincode = { workspace = true } bonsai-sdk = { workspace = true } borsh = { workspace = true } diff --git a/crates/risc0-bonsai/src/host.rs b/crates/risc0-bonsai/src/host.rs index 19cbefdf90..16f1e9fa5d 100644 --- a/crates/risc0-bonsai/src/host.rs +++ b/crates/risc0-bonsai/src/host.rs @@ -1,10 +1,10 @@ //! This module implements the [`ZkvmHost`] trait for the RISC0 VM. - -use std::sync::mpsc::{self, Sender}; -use std::sync::Arc; use std::time::Duration; use anyhow::anyhow; +use backoff::exponential::ExponentialBackoffBuilder; +use backoff::{retry as retry_backoff, SystemClock}; +use bonsai_sdk::blocking::Client; use borsh::{BorshDeserialize, BorshSerialize}; use risc0_zkvm::sha::Digest; use risc0_zkvm::{ @@ -13,278 +13,48 @@ use risc0_zkvm::{ }; use sov_risc0_adapter::guest::Risc0Guest; use sov_rollup_interface::zk::{Proof, Zkvm, ZkvmHost}; -use tracing::{debug, error, info, instrument, trace, warn}; - -/// Requests to bonsai client. Each variant represents its own method. -#[derive(Clone)] -enum BonsaiRequest { - UploadImg { - image_id: String, - buf: Vec, - notify: Sender, - }, - UploadInput { - buf: Vec, - notify: Sender, - }, - Download { - url: String, - notify: Sender>, - }, - CreateSession { - img_id: String, - input_id: String, - assumptions: Vec, - notify: Sender, - }, - CreateSnark { - session: bonsai_sdk::blocking::SessionId, - notify: Sender, - }, - Status { - session: bonsai_sdk::blocking::SessionId, - notify: Sender, - }, - SnarkStatus { - session: bonsai_sdk::blocking::SnarkId, - notify: Sender, - }, -} - -/// A wrapper around Bonsai SDK to handle tokio runtime inside another tokio runtime. -/// See https://stackoverflow.com/a/62536772. -#[derive(Clone)] -struct BonsaiClient { - queue: std::sync::mpsc::Sender, - _join_handle: Arc>, -} - -impl BonsaiClient { - fn from_parts(api_url: String, api_key: String, risc0_version: &str) -> Self { - macro_rules! unwrap_bonsai_response { - ($response:expr, $client_loop:lifetime, $queue_loop:lifetime) => ( - match $response { - Ok(r) => r, +use tracing::{error, info, warn}; + +macro_rules! 
retry_backoff_bonsai { + ($bonsai_call:expr) => { + retry_backoff( + ExponentialBackoffBuilder::::new() + .with_initial_interval(Duration::from_secs(5)) + .with_max_elapsed_time(Some(Duration::from_secs(15 * 60))) + .build(), + || { + let response = $bonsai_call; + match response { + Ok(r) => Ok(r), Err(e) => { use ::bonsai_sdk::SdkErr::*; match e { InternalServerErr(s) => { - warn!(%s, "Got HHTP 500 from Bonsai"); - std::thread::sleep(Duration::from_secs(10)); - continue $queue_loop + let err = format!("Got HHTP 500 from Bonsai: {}", s); + warn!(err); + Err(backoff::Error::transient(err)) } HttpErr(e) => { - error!(?e, "Reconnecting to Bonsai"); - std::thread::sleep(Duration::from_secs(5)); - continue $client_loop + let err = format!("Reconnecting to Bonsai: {}", e); + error!(err); + Err(backoff::Error::transient(err)) } HttpHeaderErr(e) => { - error!(?e, "Reconnecting to Bonsai"); - std::thread::sleep(Duration::from_secs(5)); - continue $client_loop + let err = format!("Reconnecting to Bonsai: {}", e); + error!(err); + Err(backoff::Error::transient(err)) } e => { - error!(?e, "Got unrecoverable error from Bonsai"); - panic!("Bonsai API error: {}", e); + let err = format!("Got unrecoverable error from Bonsai: {}", e); + error!(err); + Err(backoff::Error::permanent(err)) } } } } - ); - } - let risc0_version = risc0_version.to_string(); - let (queue, rx) = std::sync::mpsc::channel(); - let join_handle = std::thread::spawn(move || { - let mut last_request: Option = None; - 'client: loop { - debug!("Connecting to Bonsai"); - let client = match bonsai_sdk::blocking::Client::from_parts( - api_url.clone(), - api_key.clone(), - &risc0_version, - ) { - Ok(client) => client, - Err(e) => { - error!(?e, "Failed to connect to Bonsai"); - std::thread::sleep(Duration::from_secs(5)); - continue 'client; - } - }; - 'queue: loop { - let request = if let Some(last_request) = last_request.clone() { - debug!("Retrying last request after reconnection"); - last_request - } else { - trace!("Waiting for a new request"); - let req: BonsaiRequest = rx.recv().expect("bonsai client sender is dead"); - // Save request for retries - last_request = Some(req.clone()); - req - }; - match request { - BonsaiRequest::UploadImg { - image_id, - buf, - notify, - } => { - debug!(%image_id, "Bonsai:upload_img"); - let res = client.upload_img(&image_id, buf); - let res = unwrap_bonsai_response!(res, 'client, 'queue); - let _ = notify.send(res); - } - BonsaiRequest::UploadInput { buf, notify } => { - debug!("Bonsai:upload_input"); - let res = client.upload_input(buf); - let res = unwrap_bonsai_response!(res, 'client, 'queue); - let _ = notify.send(res); - } - BonsaiRequest::Download { url, notify } => { - debug!(%url, "Bonsai:download"); - let res = client.download(&url); - let res = unwrap_bonsai_response!(res, 'client, 'queue); - let _ = notify.send(res); - } - BonsaiRequest::CreateSession { - img_id, - input_id, - assumptions, - notify, - } => { - debug!(%img_id, %input_id, "Bonsai:create_session"); - // TODO: think about whether we should have a case where we use Bonsai with only execute mode - let res = client.create_session(img_id, input_id, assumptions, false); - let res = unwrap_bonsai_response!(res, 'client, 'queue); - let _ = notify.send(res); - } - BonsaiRequest::Status { session, notify } => { - debug!(?session, "Bonsai:session_status"); - let res = session.status(&client); - let res = unwrap_bonsai_response!(res, 'client, 'queue); - let _ = notify.send(res); - } - BonsaiRequest::CreateSnark { session, notify } => { - 
debug!(?session, "Bonsai:create_snark"); - let res = client.create_snark(session.uuid); - let res = unwrap_bonsai_response!(res, 'client, 'queue); - let _ = notify.send(res); - } - BonsaiRequest::SnarkStatus { session, notify } => { - debug!(?session, "Bonsai:snark_status"); - let res = session.status(&client); - let res = unwrap_bonsai_response!(res, 'client, 'queue); - let _ = notify.send(res); - } - }; - // We arrive here only on a successful response - last_request = None; - } - } - }); - let _join_handle = Arc::new(join_handle); - Self { - queue, - _join_handle, - } - } - - #[instrument(level = "trace", skip(self, buf), ret)] - fn upload_img(&self, image_id: String, buf: Vec) -> bool { - let (notify, rx) = mpsc::channel(); - self.queue - .send(BonsaiRequest::UploadImg { - image_id, - buf, - notify, - }) - .expect("Bonsai processing queue is dead"); - rx.recv().unwrap() - } - - #[instrument(level = "trace", skip_all, ret)] - fn upload_input(&self, buf: Vec) -> String { - let (notify, rx) = mpsc::channel(); - self.queue - .send(BonsaiRequest::UploadInput { buf, notify }) - .expect("Bonsai processing queue is dead"); - rx.recv().unwrap() - } - - #[instrument(level = "trace", skip(self))] - fn download(&self, url: String) -> Vec { - let (notify, rx) = mpsc::channel(); - self.queue - .send(BonsaiRequest::Download { url, notify }) - .expect("Bonsai processing queue is dead"); - rx.recv().unwrap() - } - - #[instrument(level = "trace", skip(self, assumptions), ret)] - fn create_session( - &self, - img_id: String, - input_id: String, - assumptions: Vec, - ) -> bonsai_sdk::blocking::SessionId { - let (notify, rx) = mpsc::channel(); - self.queue - .send(BonsaiRequest::CreateSession { - img_id, - input_id, - assumptions, - notify, - }) - .expect("Bonsai processing queue is dead"); - rx.recv().unwrap() - } - - #[instrument(level = "trace", skip(self))] - fn status( - &self, - session: &bonsai_sdk::blocking::SessionId, - ) -> bonsai_sdk::responses::SessionStatusRes { - let session = session.clone(); - let (notify, rx) = mpsc::channel(); - self.queue - .send(BonsaiRequest::Status { session, notify }) - .expect("Bonsai processing queue is dead"); - let status = rx.recv().unwrap(); - debug!( - status.status, - status.receipt_url, status.error_msg, status.state, status.elapsed_time - ); - status - } - - #[instrument(level = "trace", skip(self), ret)] - fn create_snark( - &self, - session: &bonsai_sdk::blocking::SessionId, - ) -> bonsai_sdk::blocking::SnarkId { - let session = session.clone(); - let (notify, rx) = mpsc::channel(); - self.queue - .send(BonsaiRequest::CreateSnark { session, notify }) - .expect("Bonsai processing queue is dead"); - rx.recv().unwrap() - } - - #[instrument(level = "trace", skip(self))] - fn snark_status( - &self, - snark_session: &bonsai_sdk::blocking::SnarkId, - ) -> bonsai_sdk::responses::SnarkStatusRes { - let snark_session = snark_session.clone(); - let (notify, rx) = mpsc::channel(); - self.queue - .send(BonsaiRequest::SnarkStatus { - session: snark_session, - notify, - }) - .expect("Bonsai processing queue is dead"); - let status = rx.recv().unwrap(); - debug!(status.status, ?status.output, status.error_msg); - status - } + }, + ) + }; } /// A [`Risc0BonsaiHost`] stores a binary to execute in the Risc0 VM and prove in the Risc0 Bonsai API. 
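For orientation between these two hunks: the new `retry_backoff_bonsai!` macro hard-codes the retry policy shown above. As a standalone illustration — not part of the patch; `upload_input_mock`, its error type, and `main` are invented for the example — the same `backoff` crate calls the macro expands to look like this:

use std::time::Duration;

use backoff::exponential::ExponentialBackoffBuilder;
use backoff::{retry, Error as BackoffError, SystemClock};

// Invented stand-in for a Bonsai SDK call: fails twice, then succeeds.
fn upload_input_mock(attempt: &mut u32) -> Result<String, String> {
    *attempt += 1;
    if *attempt < 3 {
        Err("HTTP 500".to_string())
    } else {
        Ok("input-id".to_string())
    }
}

fn main() {
    let mut attempt = 0;
    // Same policy as retry_backoff_bonsai!: 5 s initial interval,
    // give up after 15 minutes of cumulative retrying.
    let backoff = ExponentialBackoffBuilder::<SystemClock>::new()
        .with_initial_interval(Duration::from_secs(5))
        .with_max_elapsed_time(Some(Duration::from_secs(15 * 60)))
        .build();
    let result = retry(backoff, || {
        // transient() asks backoff to retry; permanent() would abort
        // immediately, mirroring the unrecoverable-error arm of the macro.
        upload_input_mock(&mut attempt).map_err(BackoffError::transient)
    });
    println!("{result:?}"); // Ok("input-id") after two retried failures
}

Classifying `InternalServerErr`, `HttpErr`, and `HttpHeaderErr` as transient keeps long Bonsai outages from killing the prover, while genuinely unrecoverable SDK errors now surface as a permanent `Err` instead of the old `panic!`.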
@@ -293,7 +63,7 @@ pub struct Risc0BonsaiHost<'a> { elf: &'a [u8], env: Vec, image_id: Digest, - client: Option, + client: Option, last_input_id: Option, } @@ -308,12 +78,14 @@ impl<'a> Risc0BonsaiHost<'a> { // handle error let client = if !api_url.is_empty() && !api_key.is_empty() { - let client = BonsaiClient::from_parts(api_url, api_key, risc0_zkvm::VERSION); - tracing::debug!("Uploading image with id: {}", image_id); - // handle error - client.upload_img(hex::encode(image_id), elf.to_vec()); + let client = Client::from_parts(api_url, api_key, risc0_zkvm::VERSION) + .expect("Failed to create Bonsai client; qed"); + + client + .upload_img(hex::encode(image_id).as_str(), elf.to_vec()) + .expect("Failed to upload image; qed"); Some(client) } else { @@ -331,11 +103,12 @@ impl<'a> Risc0BonsaiHost<'a> { fn upload_to_bonsai(&mut self, buf: Vec) { // handle error - let input_id = self + let input_id = retry_backoff_bonsai!(self .client .as_ref() .expect("Bonsai client is not initialized") - .upload_input(buf); + .upload_input(buf.clone())) + .expect("Failed to upload input; qed"); tracing::info!("Uploaded input with id: {}", input_id); self.last_input_id = Some(input_id); } @@ -398,11 +171,19 @@ impl<'a> ZkvmHost for Risc0BonsaiHost<'a> { }; // Start a session running the prover - let session = client.create_session(hex::encode(self.image_id), input_id, vec![]); + // execute only is set to false because we run bonsai only when proving + let session = retry_backoff_bonsai!(client.create_session( + hex::encode(self.image_id), + input_id.clone(), + vec![], + false + )) + .expect("Failed to create session; qed"); tracing::info!("Session created: {}", session.uuid); let receipt = loop { // handle error - let res = client.status(&session); + let res = retry_backoff_bonsai!(session.status(client)) + .expect("Failed to fetch status; qed"); if res.status == "RUNNING" { tracing::info!( @@ -429,7 +210,8 @@ impl<'a> ZkvmHost for Risc0BonsaiHost<'a> { ); } - let receipt_buf = client.download(receipt_url); + let receipt_buf = retry_backoff_bonsai!(client.download(receipt_url.as_str())) + .expect("Failed to download receipt; qed"); let receipt: Receipt = bincode::deserialize(&receipt_buf).unwrap(); @@ -445,12 +227,14 @@ impl<'a> ZkvmHost for Risc0BonsaiHost<'a> { tracing::info!("Creating the SNARK"); - let snark_session = client.create_snark(&session); + let snark_session = retry_backoff_bonsai!(client.create_snark(session.uuid.clone())) + .expect("Failed to create snark session; qed"); tracing::info!("SNARK session created: {}", snark_session.uuid); loop { - let res = client.snark_status(&snark_session); + let res = retry_backoff_bonsai!(snark_session.status(client)) + .expect("Failed to fetch status; qed"); match res.status.as_str() { "RUNNING" => { tracing::info!("Current status: {} - continue polling...", res.status,); From 66a1a1911dc4f09b5c736dcb2b5e560007bd45b1 Mon Sep 17 00:00:00 2001 From: Rakan Al-Huneiti Date: Wed, 14 Aug 2024 13:02:43 +0300 Subject: [PATCH 2/5] Prover commitments grouping (#983) * Remove _prover_ in prover method names in ledger * Add prover state diff ledger tables * Move constants to primitives * Implement generic merge_diffs * Be able to clone da_data * Split commitments into groups and prove separately * Fix bug in grouping commitments * Pass flag to control MockDA This flag is to decide whether a single transaction should create a DA block in MockDA or multiple. 
* Group up to n-1 if we exceed * Replace bincode with borsh for diff serialization --- Cargo.lock | 1 - bin/citrea/tests/e2e/reopen.rs | 2 +- crates/primitives/src/constants.rs | 3 + crates/primitives/src/lib.rs | 2 + crates/primitives/src/utils.rs | 10 + crates/prover/src/runner.rs | 245 +++++++++++------- crates/sequencer/Cargo.toml | 1 - crates/sequencer/src/sequencer.rs | 22 +- .../full-node/db/sov-db/src/ledger_db/mod.rs | 27 +- .../full-node/db/sov-db/src/ledger_db/rpc.rs | 2 +- .../db/sov-db/src/ledger_db/traits.rs | 10 +- .../full-node/db/sov-db/src/schema/tables.rs | 6 + .../sov-modules-rollup-blueprint/src/lib.rs | 3 + .../rollup-interface/src/state_machine/da.rs | 2 +- 14 files changed, 221 insertions(+), 115 deletions(-) create mode 100644 crates/primitives/src/utils.rs diff --git a/Cargo.lock b/Cargo.lock index 7abb786a6b..0bf3bf786d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1980,7 +1980,6 @@ dependencies = [ "alloy-rlp", "alloy-sol-types", "anyhow", - "bincode", "borsh", "chrono", "citrea-evm", diff --git a/bin/citrea/tests/e2e/reopen.rs b/bin/citrea/tests/e2e/reopen.rs index 5949473b11..2de0e4f184 100644 --- a/bin/citrea/tests/e2e/reopen.rs +++ b/bin/citrea/tests/e2e/reopen.rs @@ -299,7 +299,7 @@ async fn test_reopen_sequencer() -> Result<(), anyhow::Error> { #[tokio::test(flavor = "multi_thread")] async fn test_reopen_prover() -> Result<(), anyhow::Error> { - // citrea::initialize_logging(tracing::Level::DEBUG); + citrea::initialize_logging(tracing::Level::DEBUG); let storage_dir = tempdir_with_children(&["DA", "sequencer", "prover"]); let da_db_dir = storage_dir.path().join("DA").to_path_buf(); diff --git a/crates/primitives/src/constants.rs b/crates/primitives/src/constants.rs index f17bb704d0..6c812ff8d2 100644 --- a/crates/primitives/src/constants.rs +++ b/crates/primitives/src/constants.rs @@ -5,3 +5,6 @@ pub const DA_TX_ID_LEADING_ZEROS: &[u8] = [0, 0].as_slice(); pub const TEST_PRIVATE_KEY: &str = "1212121212121212121212121212121212121212121212121212121212121212"; + +pub const MAX_STATEDIFF_SIZE_COMMITMENT_THRESHOLD: u64 = 300 * 1024; +pub const MAX_STATEDIFF_SIZE_PROOF_THRESHOLD: u64 = 400 * 1024; diff --git a/crates/primitives/src/lib.rs b/crates/primitives/src/lib.rs index b91c008e15..d68dc33be3 100644 --- a/crates/primitives/src/lib.rs +++ b/crates/primitives/src/lib.rs @@ -8,6 +8,8 @@ mod error; pub mod fork; pub mod forks; pub mod types; +#[cfg(feature = "native")] +pub mod utils; #[cfg(feature = "native")] pub use cache::*; diff --git a/crates/primitives/src/utils.rs b/crates/primitives/src/utils.rs new file mode 100644 index 0000000000..efe298cb5e --- /dev/null +++ b/crates/primitives/src/utils.rs @@ -0,0 +1,10 @@ +use std::collections::HashMap; + +use sov_rollup_interface::stf::StateDiff; + +pub fn merge_state_diffs(old_diff: StateDiff, new_diff: StateDiff) -> StateDiff { + let mut new_diff_map = HashMap::, Option>>::from_iter(old_diff); + + new_diff_map.extend(new_diff); + new_diff_map.into_iter().collect() +} diff --git a/crates/prover/src/runner.rs b/crates/prover/src/runner.rs index 50ff6ef953..01dc76b6c9 100644 --- a/crates/prover/src/runner.rs +++ b/crates/prover/src/runner.rs @@ -11,7 +11,8 @@ use backoff::future::retry as retry_backoff; use borsh::de::BorshDeserialize; use citrea_primitives::fork::{Fork, ForkManager}; use citrea_primitives::types::SoftConfirmationHash; -use citrea_primitives::{get_da_block_at_height, L1BlockCache}; +use citrea_primitives::utils::merge_state_diffs; +use citrea_primitives::{get_da_block_at_height, L1BlockCache, 
MAX_STATEDIFF_SIZE_PROOF_THRESHOLD}; use jsonrpsee::core::client::Error as JsonrpseeError; use jsonrpsee::server::{BatchRequestConfig, ServerBuilder}; use jsonrpsee::RpcModule; @@ -20,7 +21,7 @@ use sequencer_client::{GetSoftConfirmationResponse, SequencerClient}; use sov_db::ledger_db::{ProverLedgerOps, SlotCommit}; use sov_db::schema::types::{BatchNumber, SlotNumber, StoredStateTransition}; use sov_modules_api::storage::HierarchicalStorageManager; -use sov_modules_api::{BlobReaderTrait, Context, SignedSoftConfirmationBatch, SlotData}; +use sov_modules_api::{BlobReaderTrait, Context, SignedSoftConfirmationBatch, SlotData, StateDiff}; use sov_modules_stf_blueprint::StfBlueprintTrait; use sov_rollup_interface::da::{BlockHeaderTrait, DaData, DaSpec, SequencerCommitment}; use sov_rollup_interface::rpc::SoftConfirmationStatus; @@ -233,7 +234,7 @@ where // Check da block get and sync up to the latest block in the latest commitment let last_scanned_l1_height = self .ledger_db - .get_prover_last_scanned_l1_height() + .get_last_scanned_l1_height() .unwrap_or_else(|_| panic!("Failed to get last scanned l1 height from the ledger db")); let start_l1_height = match last_scanned_l1_height { @@ -345,6 +346,9 @@ where bail!("Post state root mismatch at height: {}", l2_height) } + // Save state diff to ledger DB + self.ledger_db + .set_l2_state_diff(BatchNumber(l2_height), slot_result.state_diff)?; // Save witness data to ledger db self.ledger_db .set_l2_witness(l2_height, &slot_result.witness)?; @@ -432,13 +436,15 @@ where da_data.iter_mut().for_each(|blob| { blob.full_data(); }); - let sequencer_commitments: Vec = + let mut sequencer_commitments: Vec = self.extract_sequencer_commitments(l1_block.header().hash().into(), &mut da_data); + // Make sure all sequencer commitments are stored in ascending order. + sequencer_commitments.sort_by_key(|commitment| commitment.l2_start_block_number); if sequencer_commitments.is_empty() { info!("No sequencer commitment found at height {}", l1_height,); self.ledger_db - .set_prover_last_scanned_l1_height(SlotNumber(l1_height)) + .set_last_scanned_l1_height(SlotNumber(l1_height)) .unwrap_or_else(|_| { panic!( "Failed to put prover last scanned l1 height in the ledger db {}", @@ -456,104 +462,112 @@ where l1_block.header().height(), ); - let first_l2_height_of_l1 = sequencer_commitments[0].l2_start_block_number; - let last_l2_height_of_l1 = - sequencer_commitments[sequencer_commitments.len() - 1].l2_end_block_number; - // If the L2 range does not exist, we break off the local loop getting back to // the outer loop / select to make room for other tasks to run. // We retry the L1 block there as well. - if !self.check_l2_range_exists(first_l2_height_of_l1, last_l2_height_of_l1) { + if !self.check_l2_range_exists( + sequencer_commitments[0].l2_start_block_number, + sequencer_commitments[sequencer_commitments.len() - 1].l2_end_block_number, + ) { break; } - let ( - state_transition_witnesses, - soft_confirmations, - da_block_headers_of_soft_confirmations, - ) = self - .get_state_transition_data_from_commitments( - &sequencer_commitments, - &self.da_service, - ) - .await?; - - let da_block_header_of_commitments = l1_block.header().clone(); - - let hash = da_block_header_of_commitments.hash(); - let initial_state_root = self - .ledger_db - .get_l2_state_root::(first_l2_height_of_l1 - 1)? - .expect("There should be a state root"); - let initial_batch_hash = self - .ledger_db - .get_soft_confirmation_by_number(&BatchNumber(first_l2_height_of_l1))? 
- .ok_or(anyhow!( - "Could not find soft confirmation at height {}", - first_l2_height_of_l1 - ))? - .prev_hash; - - let final_state_root = self - .ledger_db - .get_l2_state_root::(last_l2_height_of_l1)? - .expect("There should be a state root"); + let sequencer_commitments_groups = + self.break_sequencer_commitments_into_groups(sequencer_commitments)?; - let (inclusion_proof, completeness_proof) = self - .da_service - .get_extraction_proof(l1_block, &da_data) - .await; + for sequencer_commitments in sequencer_commitments_groups { + let first_l2_height_of_l1 = sequencer_commitments[0].l2_start_block_number; + let last_l2_height_of_l1 = + sequencer_commitments[sequencer_commitments.len() - 1].l2_end_block_number; - let transition_data: StateTransitionData = - StateTransitionData { - initial_state_root, - final_state_root, - initial_batch_hash, - da_data, - da_block_header_of_commitments, - inclusion_proof, - completeness_proof, - soft_confirmations, + let ( state_transition_witnesses, + soft_confirmations, da_block_headers_of_soft_confirmations, - sequencer_commitments_range: ( - 0, - (sequencer_commitments.len() - 1) - .try_into() - .expect("cant be more than 4 billion commitments in a da block; qed"), - ), // for now process all commitments - sequencer_public_key: self.sequencer_pub_key.clone(), - sequencer_da_public_key: self.sequencer_da_pub_key.clone(), + ) = self + .get_state_transition_data_from_commitments( + &sequencer_commitments, + &self.da_service, + ) + .await?; + + let da_block_header_of_commitments = l1_block.header().clone(); + + let hash = da_block_header_of_commitments.hash(); + let initial_state_root = self + .ledger_db + .get_l2_state_root::(first_l2_height_of_l1 - 1)? + .expect("There should be a state root"); + let initial_batch_hash = self + .ledger_db + .get_soft_confirmation_by_number(&BatchNumber(first_l2_height_of_l1))? + .ok_or(anyhow!( + "Could not find soft confirmation at height {}", + first_l2_height_of_l1 + ))? + .prev_hash; + + let final_state_root = self + .ledger_db + .get_l2_state_root::(last_l2_height_of_l1)? 
+ .expect("There should be a state root"); + + let (inclusion_proof, completeness_proof) = self + .da_service + .get_extraction_proof(l1_block, &da_data.clone()) + .await; + + let transition_data: StateTransitionData = + StateTransitionData { + initial_state_root, + final_state_root, + initial_batch_hash, + da_data: da_data.clone(), + da_block_header_of_commitments, + inclusion_proof, + completeness_proof, + soft_confirmations, + state_transition_witnesses, + da_block_headers_of_soft_confirmations, + sequencer_commitments_range: ( + 0, + (sequencer_commitments.len() - 1).try_into().expect( + "cant be more than 4 billion commitments in a da block; qed", + ), + ), // for now process all commitments + sequencer_public_key: self.sequencer_pub_key.clone(), + sequencer_da_public_key: self.sequencer_da_pub_key.clone(), + }; + + let should_prove: bool = { + let mut rng = rand::thread_rng(); + // if proof_sampling_number is 0, then we always prove and submit + // otherwise we submit and prove with a probability of 1/proof_sampling_number + if prover_config.proof_sampling_number == 0 { + true + } else { + rng.gen_range(0..prover_config.proof_sampling_number) == 0 + } }; - let should_prove: bool = { - let mut rng = rand::thread_rng(); - // if proof_sampling_number is 0, then we always prove and submit - // otherwise we submit and prove with a probability of 1/proof_sampling_number - if prover_config.proof_sampling_number == 0 { - true + // Skip submission until l1 height + if l1_height >= skip_submission_until_l1 && should_prove { + self.generate_and_submit_proof(transition_data, l1_height, hash) + .await?; } else { - rng.gen_range(0..prover_config.proof_sampling_number) == 0 + info!("Skipping proving for l1 height {}", l1_height); } - }; - - // Skip submission until l1 height - if l1_height >= skip_submission_until_l1 && should_prove { - self.generate_and_submit_proof(transition_data, l1_height, hash) - .await?; - } else { - info!("Skipping proving for l1 height {}", l1_height); - } - self.save_commitments(sequencer_commitments, l1_height); + self.save_commitments(sequencer_commitments, l1_height); - if let Err(e) = self - .ledger_db - .set_prover_last_scanned_l1_height(SlotNumber(l1_height)) - { - panic!( - "Failed to put prover last scanned l1 height in the ledger db: {}", - e - ); + if let Err(e) = self + .ledger_db + .set_last_scanned_l1_height(SlotNumber(l1_height)) + { + panic!( + "Failed to put prover last scanned l1 height in the ledger db: {}", + e + ); + } } pending_l1_blocks.pop_front(); @@ -786,6 +800,61 @@ where pub fn get_state_root(&self) -> &Stf::StateRoot { &self.state_root } + + fn break_sequencer_commitments_into_groups( + &self, + sequencer_commitments: Vec, + ) -> anyhow::Result>> { + let mut result = vec![]; + + let mut group = vec![]; + let mut cumulative_state_diff = StateDiff::new(); + for sequencer_commitment in sequencer_commitments { + let mut sequencer_commitment_state_diff = StateDiff::new(); + for l2_height in sequencer_commitment.l2_start_block_number + ..=sequencer_commitment.l2_end_block_number + { + let state_diff = self + .ledger_db + .get_l2_state_diff(BatchNumber(l2_height))? 
+ .ok_or(anyhow!( + "Could not find state diff for L2 range {}-{}", + sequencer_commitment.l2_start_block_number, + sequencer_commitment.l2_end_block_number + ))?; + sequencer_commitment_state_diff = + merge_state_diffs(sequencer_commitment_state_diff, state_diff); + } + cumulative_state_diff = merge_state_diffs( + cumulative_state_diff, + sequencer_commitment_state_diff.clone(), + ); + + let serialized_state_diff = borsh::to_vec(&cumulative_state_diff)?; + + let state_diff_threshold_reached = + serialized_state_diff.len() as u64 > MAX_STATEDIFF_SIZE_PROOF_THRESHOLD; + + if state_diff_threshold_reached && !group.is_empty() { + // We've exceeded the limit with the current commitments + // so we have to stop at the previous one. + result.push(group); + // Reset the cumulative state diff to be equal to the current commitment state diff + cumulative_state_diff = sequencer_commitment_state_diff; + group = vec![sequencer_commitment.clone()]; + } else { + group.push(sequencer_commitment.clone()); + } + } + + // If the last group hasn't been reset because it has not reached the threshold, + // Add it anyway + if !group.is_empty() { + result.push(group); + } + + Ok(result) + } } async fn l1_sync( diff --git a/crates/sequencer/Cargo.toml b/crates/sequencer/Cargo.toml index 699c246ef6..c0146665d8 100644 --- a/crates/sequencer/Cargo.toml +++ b/crates/sequencer/Cargo.toml @@ -16,7 +16,6 @@ resolver = "2" alloy-rlp = { workspace = true } alloy-sol-types = { workspace = true } anyhow = { workspace = true } -bincode = { workspace = true } borsh = { workspace = true } chrono = { workspace = true } digest = { workspace = true } diff --git a/crates/sequencer/src/sequencer.rs b/crates/sequencer/src/sequencer.rs index 97fa58fce5..41aacf04ae 100644 --- a/crates/sequencer/src/sequencer.rs +++ b/crates/sequencer/src/sequencer.rs @@ -1,4 +1,4 @@ -use std::collections::{HashMap, HashSet}; +use std::collections::HashSet; use std::marker::PhantomData; use std::net::SocketAddr; use std::ops::RangeInclusive; @@ -11,6 +11,8 @@ use borsh::BorshDeserialize; use citrea_evm::{CallMessage, Evm, RlpEvmTransaction, MIN_TRANSACTION_GAS}; use citrea_primitives::fork::{Fork, ForkManager}; use citrea_primitives::types::SoftConfirmationHash; +use citrea_primitives::utils::merge_state_diffs; +use citrea_primitives::MAX_STATEDIFF_SIZE_COMMITMENT_THRESHOLD; use citrea_stf::runtime::Runtime; use digest::Digest; use futures::channel::mpsc::{unbounded, UnboundedReceiver, UnboundedSender}; @@ -56,8 +58,6 @@ use crate::mempool::CitreaMempool; use crate::rpc::{create_rpc_module, RpcContext}; use crate::utils::recover_raw_transaction; -const MAX_STATEDIFF_SIZE_COMMITMENT_THRESHOLD: u64 = 300 * 1024; - type StateRoot = >::StateRoot; /// Represents information about the current DA state. /// @@ -572,12 +572,11 @@ where self.mempool.update_accounts(account_updates); - let merged_state_diff = self.merge_state_diffs( - self.last_state_diff.clone(), - slot_result.state_diff.clone(), - ); + let merged_state_diff = + merge_state_diffs(self.last_state_diff.clone(), slot_result.state_diff.clone()); + // Serialize the state diff to check size later. 
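+        // The cap checked below is MAX_STATEDIFF_SIZE_COMMITMENT_THRESHOLD,
+        // i.e. 300 * 1024 bytes as defined in crates/primitives/src/constants.rs.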
- let serialized_state_diff = bincode::serialize(&merged_state_diff)?; + let serialized_state_diff = borsh::to_vec(&merged_state_diff)?; let state_diff_threshold_reached = serialized_state_diff.len() as u64 > MAX_STATEDIFF_SIZE_COMMITMENT_THRESHOLD; if state_diff_threshold_reached { @@ -1183,13 +1182,6 @@ where Ok(updates) } - - fn merge_state_diffs(&self, old_diff: StateDiff, new_diff: StateDiff) -> StateDiff { - let mut new_diff_map = HashMap::, Option>>::from_iter(old_diff); - - new_diff_map.extend(new_diff); - new_diff_map.into_iter().collect() - } } fn get_l1_fee_rate_range( diff --git a/crates/sovereign-sdk/full-node/db/sov-db/src/ledger_db/mod.rs b/crates/sovereign-sdk/full-node/db/sov-db/src/ledger_db/mod.rs index 6f1dfd93fd..cde955a722 100644 --- a/crates/sovereign-sdk/full-node/db/sov-db/src/ledger_db/mod.rs +++ b/crates/sovereign-sdk/full-node/db/sov-db/src/ledger_db/mod.rs @@ -16,9 +16,9 @@ use crate::rocks_db_config::gen_rocksdb_options; use crate::schema::tables::{ ActiveFork, BatchByNumber, CommitmentsByNumber, EventByKey, EventByNumber, L2GenesisStateRoot, L2RangeByL1Height, L2Witness, LastSequencerCommitmentSent, LastStateDiff, MempoolTxs, - PendingSequencerCommitmentL2Range, ProofBySlotNumber, ProverLastScannedSlot, SlotByHash, - SlotByNumber, SoftConfirmationByHash, SoftConfirmationByNumber, SoftConfirmationStatus, - TxByHash, TxByNumber, VerifiedProofsBySlotNumber, LEDGER_TABLES, + PendingSequencerCommitmentL2Range, ProofBySlotNumber, ProverLastScannedSlot, ProverStateDiffs, + SlotByHash, SlotByNumber, SoftConfirmationByHash, SoftConfirmationByNumber, + SoftConfirmationStatus, TxByHash, TxByNumber, VerifiedProofsBySlotNumber, LEDGER_TABLES, }; use crate::schema::types::{ split_tx_for_storage, BatchNumber, EventNumber, L2HeightRange, SlotNumber, StoredProof, @@ -486,14 +486,14 @@ impl ProverLedgerOps for LedgerDB { /// Get the last scanned slot by the prover #[instrument(level = "trace", skip(self), err, ret)] - fn get_prover_last_scanned_l1_height(&self) -> anyhow::Result> { + fn get_last_scanned_l1_height(&self) -> anyhow::Result> { self.db.get::(&()) } /// Set the last scanned slot by the prover /// Called by the prover. #[instrument(level = "trace", skip(self), err, ret)] - fn set_prover_last_scanned_l1_height(&self, l1_height: SlotNumber) -> anyhow::Result<()> { + fn set_last_scanned_l1_height(&self, l1_height: SlotNumber) -> anyhow::Result<()> { let mut schema_batch = SchemaBatch::new(); schema_batch.put::(&(), &l1_height)?; @@ -550,6 +550,23 @@ impl ProverLedgerOps for LedgerDB { Ok(()) } + + fn set_l2_state_diff( + &self, + l2_height: BatchNumber, + state_diff: StateDiff, + ) -> anyhow::Result<()> { + let mut schema_batch = SchemaBatch::new(); + schema_batch.put::(&l2_height, &state_diff)?; + + self.db.write_schemas(schema_batch)?; + + Ok(()) + } + + fn get_l2_state_diff(&self, l2_height: BatchNumber) -> anyhow::Result> { + self.db.get::(&l2_height) + } } impl SequencerLedgerOps for LedgerDB { diff --git a/crates/sovereign-sdk/full-node/db/sov-db/src/ledger_db/rpc.rs b/crates/sovereign-sdk/full-node/db/sov-db/src/ledger_db/rpc.rs index 4962b11640..68a727f2f4 100644 --- a/crates/sovereign-sdk/full-node/db/sov-db/src/ledger_db/rpc.rs +++ b/crates/sovereign-sdk/full-node/db/sov-db/src/ledger_db/rpc.rs @@ -135,7 +135,7 @@ impl LedgerRpcProvider for LedgerDB { } fn get_prover_last_scanned_l1_height(&self) -> Result { - match ProverLedgerOps::get_prover_last_scanned_l1_height(self)? { + match ProverLedgerOps::get_last_scanned_l1_height(self)? 
{ Some(height) => Ok(height.0), None => Ok(0), } diff --git a/crates/sovereign-sdk/full-node/db/sov-db/src/ledger_db/traits.rs b/crates/sovereign-sdk/full-node/db/sov-db/src/ledger_db/traits.rs index 26a8707a50..f064eaaed4 100644 --- a/crates/sovereign-sdk/full-node/db/sov-db/src/ledger_db/traits.rs +++ b/crates/sovereign-sdk/full-node/db/sov-db/src/ledger_db/traits.rs @@ -155,11 +155,11 @@ pub trait ProverLedgerOps: SharedLedgerOps { ) -> anyhow::Result>; /// Get the last scanned slot by the prover - fn get_prover_last_scanned_l1_height(&self) -> Result>; + fn get_last_scanned_l1_height(&self) -> Result>; /// Set the last scanned slot by the prover /// Called by the prover. - fn set_prover_last_scanned_l1_height(&self, l1_height: SlotNumber) -> Result<()>; + fn set_last_scanned_l1_height(&self, l1_height: SlotNumber) -> Result<()>; /// Get the witness by L2 height fn get_l2_witness(&self, l2_height: u64) -> Result>; @@ -175,6 +175,12 @@ pub trait ProverLedgerOps: SharedLedgerOps { /// Set the witness by L2 height fn set_l2_witness(&self, l2_height: u64, witness: &Witness) -> Result<()>; + + /// Save a specific L2 range state diff + fn set_l2_state_diff(&self, l2_height: BatchNumber, state_diff: StateDiff) -> Result<()>; + + /// Returns an L2 state diff + fn get_l2_state_diff(&self, l2_height: BatchNumber) -> Result>; } /// Sequencer ledger operations diff --git a/crates/sovereign-sdk/full-node/db/sov-db/src/schema/tables.rs b/crates/sovereign-sdk/full-node/db/sov-db/src/schema/tables.rs index 41d4aab0cc..f089a33d6e 100644 --- a/crates/sovereign-sdk/full-node/db/sov-db/src/schema/tables.rs +++ b/crates/sovereign-sdk/full-node/db/sov-db/src/schema/tables.rs @@ -74,6 +74,7 @@ pub const LEDGER_TABLES: &[&str] = &[ ProofBySlotNumber::table_name(), VerifiedProofsBySlotNumber::table_name(), MempoolTxs::table_name(), + ProverStateDiffs::table_name(), ]; /// A list of all tables used by the NativeDB. These tables store @@ -337,6 +338,11 @@ define_table_with_default_codec!( (MempoolTxs) Vec => Vec ); +define_table_with_default_codec!( + /// L2 height to state diff for prover + (ProverStateDiffs) BatchNumber => StateDiff +); + impl KeyEncoder for NodeKey { fn encode_key(&self) -> sov_schema_db::schema::Result> { // 8 bytes for version, 4 each for the num_nibbles and bytes.len() fields, plus 1 byte per byte of nibllepath diff --git a/crates/sovereign-sdk/module-system/sov-modules-rollup-blueprint/src/lib.rs b/crates/sovereign-sdk/module-system/sov-modules-rollup-blueprint/src/lib.rs index 1491caf6b7..614f7ddad1 100644 --- a/crates/sovereign-sdk/module-system/sov-modules-rollup-blueprint/src/lib.rs +++ b/crates/sovereign-sdk/module-system/sov-modules-rollup-blueprint/src/lib.rs @@ -24,8 +24,10 @@ pub use wallet::*; pub trait RollupBlueprint: Sized + Send + Sync { /// Data Availability service. type DaService: DaService + Send + Sync; + /// A specification for the types used by a DA layer. type DaSpec: DaSpec + Send + Sync; + /// Data Availability config. type DaConfig: Send + Sync; @@ -34,6 +36,7 @@ pub trait RollupBlueprint: Sized + Send + Sync { /// Context for Zero Knowledge environment. type ZkContext: Context; + /// Context for Native environment. 
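+    /// (i.e. the STF running on the host rather than inside the zkVM guest).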
type NativeContext: Context;

diff --git a/crates/sovereign-sdk/rollup-interface/src/state_machine/da.rs b/crates/sovereign-sdk/rollup-interface/src/state_machine/da.rs
index ec9c0bd247..241a5f32c9 100644
--- a/crates/sovereign-sdk/rollup-interface/src/state_machine/da.rs
+++ b/crates/sovereign-sdk/rollup-interface/src/state_machine/da.rs
@@ -42,7 +42,7 @@ pub trait DaSpec:
     type BlockHeader: BlockHeaderTrait + Send + Sync;
 
     /// The transaction type used by the DA layer.
-    type BlobTransaction: BlobReaderTrait<Address = Self::Address> + Send + Sync;
+    type BlobTransaction: BlobReaderTrait<Address = Self::Address>
+ Send + Sync + Clone; /// The type used to represent addresses on the DA layer. type Address: BasicAddress + Send + Sync; From a6807835d0f687e88af65287ee7a9c38e596dca0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Erce=20Can=20Bekt=C3=BCre?= <47954181+ercecan@users.noreply.github.com> Date: Wed, 14 Aug 2024 14:10:52 +0300 Subject: [PATCH 3/5] Crash recovery (#938) * WIP crash recovery * Implement bonsai host crash recovery * Store session ids by l1 height & get rid of latest proof l1 hash * Lint * Nits * Remove no-std breaking function from zkvmhost trait * Fmt * Update session storage logic * Lint * Clear pending sessions in prover service * Use l1 height data from proof data for recovered proofs * Naming * Lint * Fix merge bug --- Cargo.lock | 1 + bin/citrea/src/rollup/bitcoin.rs | 3 + bin/citrea/src/rollup/mock.rs | 3 + bin/citrea/src/rollup/mod.rs | 10 +- .../prover/src/prover_service/parallel/mod.rs | 30 ++ crates/prover/src/runner.rs | 257 ++++++++++------ crates/prover/tests/prover_tests.rs | 8 +- crates/risc0-bonsai/Cargo.toml | 1 + crates/risc0-bonsai/src/host.rs | 282 ++++++++++++------ .../adapters/mock-zkvm/src/lib.rs | 6 +- .../sovereign-sdk/adapters/risc0/src/host.rs | 4 + .../full-node/db/sov-db/src/ledger_db/mod.rs | 59 +++- .../db/sov-db/src/ledger_db/traits.rs | 23 +- .../full-node/db/sov-db/src/schema/tables.rs | 7 + .../sov-stf-runner/src/prover_service/mod.rs | 6 + .../sov-modules-rollup-blueprint/src/lib.rs | 1 + .../src/state_machine/zk/mod.rs | 3 + 17 files changed, 502 insertions(+), 202 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 0bf3bf786d..7c8951e52f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1968,6 +1968,7 @@ dependencies = [ "risc0-zkvm", "risc0-zkvm-platform", "serde", + "sov-db", "sov-risc0-adapter", "sov-rollup-interface", "tracing", diff --git a/bin/citrea/src/rollup/bitcoin.rs b/bin/citrea/src/rollup/bitcoin.rs index bd4d2fb5e1..a2b362a3c6 100644 --- a/bin/citrea/src/rollup/bitcoin.rs +++ b/bin/citrea/src/rollup/bitcoin.rs @@ -135,11 +135,13 @@ impl RollupBlueprint for BitcoinRollup { prover_config: ProverConfig, _rollup_config: &FullNodeConfig, _da_service: &Arc, + ledger_db: LedgerDB, ) -> Self::ProverService { let vm = Risc0BonsaiHost::new( citrea_risc0::BITCOIN_DA_ELF, std::env::var("BONSAI_API_URL").unwrap_or("".to_string()), std::env::var("BONSAI_API_KEY").unwrap_or("".to_string()), + ledger_db.clone(), ); let zk_stf = StfBlueprint::new(); let zk_storage = ZkStorage::new(); @@ -155,6 +157,7 @@ impl RollupBlueprint for BitcoinRollup { da_verifier, prover_config, zk_storage, + ledger_db, ) .expect("Should be able to instantiate prover service") } diff --git a/bin/citrea/src/rollup/mock.rs b/bin/citrea/src/rollup/mock.rs index 2ae87456fb..b2da311337 100644 --- a/bin/citrea/src/rollup/mock.rs +++ b/bin/citrea/src/rollup/mock.rs @@ -100,11 +100,13 @@ impl RollupBlueprint for MockDemoRollup { prover_config: ProverConfig, _rollup_config: &FullNodeConfig, _da_service: &Arc, + ledger_db: LedgerDB, ) -> Self::ProverService { let vm = Risc0BonsaiHost::new( citrea_risc0::MOCK_DA_ELF, std::env::var("BONSAI_API_URL").unwrap_or("".to_string()), std::env::var("BONSAI_API_KEY").unwrap_or("".to_string()), + ledger_db.clone(), ); let zk_stf = StfBlueprint::new(); let zk_storage = ZkStorage::new(); @@ -116,6 +118,7 @@ impl RollupBlueprint for MockDemoRollup { da_verifier, prover_config, zk_storage, + ledger_db, ) .expect("Should be able to instantiate prover service") } diff --git a/bin/citrea/src/rollup/mod.rs b/bin/citrea/src/rollup/mod.rs index 
dd8e9e2e39..b28973697f 100644 --- a/bin/citrea/src/rollup/mod.rs +++ b/bin/citrea/src/rollup/mod.rs @@ -230,15 +230,21 @@ pub trait CitreaRollupBlueprint: RollupBlueprint { { let da_service = self.create_da_service(&rollup_config).await?; + let ledger_db = self.create_ledger_db(&rollup_config); + let prover_service = self - .create_prover_service(prover_config.clone(), &rollup_config, &da_service) + .create_prover_service( + prover_config.clone(), + &rollup_config, + &da_service, + ledger_db.clone(), + ) .await; // TODO: Double check what kind of storage needed here. // Maybe whole "prev_root" can be initialized inside runner // Getting block here, so prover_service doesn't have to be `Send` - let ledger_db = self.create_ledger_db(&rollup_config); let genesis_config = self.create_genesis_config(runtime_genesis_paths, &rollup_config)?; let mut storage_manager = self.create_storage_manager(&rollup_config)?; diff --git a/crates/prover/src/prover_service/parallel/mod.rs b/crates/prover/src/prover_service/parallel/mod.rs index 0c8cd0e6b9..de4f8704d6 100644 --- a/crates/prover/src/prover_service/parallel/mod.rs +++ b/crates/prover/src/prover_service/parallel/mod.rs @@ -7,6 +7,7 @@ use citrea_stf::verifier::StateTransitionVerifier; use prover::Prover; use serde::de::DeserializeOwned; use serde::Serialize; +use sov_db::ledger_db::{LedgerDB, ProvingServiceLedgerOps}; use sov_rollup_interface::da::{DaData, DaSpec}; use sov_rollup_interface::services::da::DaService; use sov_rollup_interface::stf::StateTransitionFunction; @@ -34,6 +35,7 @@ where zk_storage: V::PreState, prover_state: Prover, + ledger_db: LedgerDB, } impl ParallelProverService @@ -62,6 +64,7 @@ where config: ProverGuestRunConfig, zk_storage: V::PreState, num_threads: usize, + ledger_db: LedgerDB, ) -> anyhow::Result { let stf_verifier = StateTransitionVerifier::::new(zk_stf, da_verifier); @@ -96,6 +99,7 @@ where prover_config, prover_state: Prover::new(num_threads)?, zk_storage, + ledger_db, }) } @@ -106,6 +110,7 @@ where da_verifier: Da::Verifier, prover_config: ProverConfig, zk_storage: V::PreState, + ledger_db: LedgerDB, ) -> anyhow::Result { let num_cpus = num_cpus::get(); assert!(num_cpus > 1, "Unable to create parallel prover service"); @@ -117,6 +122,7 @@ where prover_config.proving_mode, zk_storage, num_cpus - 1, + ledger_db, ) } } @@ -192,6 +198,7 @@ where .send_transaction(da_data) .await .map_err(|e| anyhow::anyhow!(e))?; + self.ledger_db.clear_pending_proving_sessions()?; break Ok((tx_id, proof)); } ProverStatus::ProvingInProgress => { @@ -203,4 +210,27 @@ where } } } + + async fn recover_proving_sessions_and_send_to_da( + &self, + da_service: &Arc, + ) -> Result::TransactionId, Proof)>, anyhow::Error> { + tracing::info!("Checking if ongoing bonsai session exists"); + + let vm = self.vm.clone(); + let proofs = vm.recover_proving_sessions()?; + + let mut results = Vec::new(); + + for proof in proofs.into_iter() { + let da_data = DaData::ZKProof(proof.clone()); + let tx_id = da_service + .send_transaction(da_data) + .await + .map_err(|e| anyhow::anyhow!(e))?; + results.push((tx_id, proof)); + } + self.ledger_db.clear_pending_proving_sessions()?; + Ok(results) + } } diff --git a/crates/prover/src/runner.rs b/crates/prover/src/runner.rs index 01dc76b6c9..25c80a4d5b 100644 --- a/crates/prover/src/runner.rs +++ b/crates/prover/src/runner.rs @@ -224,6 +224,24 @@ where }); } + async fn check_and_recover_ongoing_proving_sessions(&self) -> Result { + let prover_service = self + .prover_service + .as_ref() + .expect("Prover service 
should be present"); + let results = prover_service + .recover_proving_sessions_and_send_to_da(&self.da_service) + .await?; + if results.is_empty() { + Ok(false) + } else { + for (tx_id, proof) in results { + self.extract_and_store_proof(tx_id, proof).await?; + } + Ok(true) + } + } + /// Runs the rollup. #[instrument(level = "trace", skip_all, err)] pub async fn run(&mut self) -> Result<(), anyhow::Error> { @@ -416,6 +434,7 @@ where skip_submission_until_l1: u64, prover_config: &ProverConfig, ) -> Result<(), anyhow::Error> { + let mut proving_session_exists = self.check_and_recover_ongoing_proving_sessions().await?; while !pending_l1_blocks.is_empty() { let l1_block = pending_l1_blocks .front() @@ -431,7 +450,8 @@ where ) .unwrap(); - let mut da_data = self.da_service.extract_relevant_blobs(l1_block); + let mut da_data: Vec<<::Spec as DaSpec>::BlobTransaction> = + self.da_service.extract_relevant_blobs(l1_block); // if we don't do this, the zk circuit can't read the sequencer commitments da_data.iter_mut().for_each(|blob| { blob.full_data(); @@ -472,102 +492,51 @@ where break; } - let sequencer_commitments_groups = - self.break_sequencer_commitments_into_groups(sequencer_commitments)?; - - for sequencer_commitments in sequencer_commitments_groups { - let first_l2_height_of_l1 = sequencer_commitments[0].l2_start_block_number; - let last_l2_height_of_l1 = - sequencer_commitments[sequencer_commitments.len() - 1].l2_end_block_number; - - let ( - state_transition_witnesses, - soft_confirmations, - da_block_headers_of_soft_confirmations, - ) = self - .get_state_transition_data_from_commitments( - &sequencer_commitments, - &self.da_service, + let da_block_header_of_commitments: <::Spec as DaSpec>::BlockHeader = + l1_block.header().clone(); + + let hash = da_block_header_of_commitments.hash(); + + if !proving_session_exists { + let sequencer_commitments_groups = + self.break_sequencer_commitments_into_groups(sequencer_commitments)?; + for sequencer_commitments in sequencer_commitments_groups { + // There is no ongoing bonsai session to recover + let transition_data: StateTransitionData< + Stf::StateRoot, + Stf::Witness, + Da::Spec, + > = self + .create_state_transition_data( + &sequencer_commitments, + da_block_header_of_commitments.clone(), + da_data.clone(), + l1_block, + ) + .await?; + + self.prove_state_transition( + transition_data, + prover_config, + skip_submission_until_l1, + l1_height, + hash.clone(), ) .await?; + proving_session_exists = false; - let da_block_header_of_commitments = l1_block.header().clone(); - - let hash = da_block_header_of_commitments.hash(); - let initial_state_root = self - .ledger_db - .get_l2_state_root::(first_l2_height_of_l1 - 1)? - .expect("There should be a state root"); - let initial_batch_hash = self - .ledger_db - .get_soft_confirmation_by_number(&BatchNumber(first_l2_height_of_l1))? - .ok_or(anyhow!( - "Could not find soft confirmation at height {}", - first_l2_height_of_l1 - ))? - .prev_hash; - - let final_state_root = self - .ledger_db - .get_l2_state_root::(last_l2_height_of_l1)? 
- .expect("There should be a state root"); - - let (inclusion_proof, completeness_proof) = self - .da_service - .get_extraction_proof(l1_block, &da_data.clone()) - .await; - - let transition_data: StateTransitionData = - StateTransitionData { - initial_state_root, - final_state_root, - initial_batch_hash, - da_data: da_data.clone(), - da_block_header_of_commitments, - inclusion_proof, - completeness_proof, - soft_confirmations, - state_transition_witnesses, - da_block_headers_of_soft_confirmations, - sequencer_commitments_range: ( - 0, - (sequencer_commitments.len() - 1).try_into().expect( - "cant be more than 4 billion commitments in a da block; qed", - ), - ), // for now process all commitments - sequencer_public_key: self.sequencer_pub_key.clone(), - sequencer_da_public_key: self.sequencer_da_pub_key.clone(), - }; - - let should_prove: bool = { - let mut rng = rand::thread_rng(); - // if proof_sampling_number is 0, then we always prove and submit - // otherwise we submit and prove with a probability of 1/proof_sampling_number - if prover_config.proof_sampling_number == 0 { - true - } else { - rng.gen_range(0..prover_config.proof_sampling_number) == 0 - } - }; - - // Skip submission until l1 height - if l1_height >= skip_submission_until_l1 && should_prove { - self.generate_and_submit_proof(transition_data, l1_height, hash) - .await?; - } else { - info!("Skipping proving for l1 height {}", l1_height); + self.save_commitments(sequencer_commitments, l1_height); } - self.save_commitments(sequencer_commitments, l1_height); + } - if let Err(e) = self - .ledger_db - .set_last_scanned_l1_height(SlotNumber(l1_height)) - { - panic!( - "Failed to put prover last scanned l1 height in the ledger db: {}", - e - ); - } + if let Err(e) = self + .ledger_db + .set_last_scanned_l1_height(SlotNumber(l1_height)) + { + panic!( + "Failed to put prover last scanned l1 height in the ledger db: {}", + e + ); } pending_l1_blocks.pop_front(); @@ -575,6 +544,93 @@ where Ok(()) } + async fn prove_state_transition( + &self, + transition_data: StateTransitionData, + prover_config: &ProverConfig, + skip_submission_until_l1: u64, + l1_height: u64, + hash: <::Spec as DaSpec>::SlotHash, + ) -> Result<(), anyhow::Error> { + // if proof_sampling_number is 0, then we always prove and submit + // otherwise we submit and prove with a probability of 1/proof_sampling_number + let should_prove = prover_config.proof_sampling_number == 0 + || rand::thread_rng().gen_range(0..prover_config.proof_sampling_number) == 0; + + // Skip submission until l1 height + if l1_height >= skip_submission_until_l1 && should_prove { + self.generate_and_submit_proof(transition_data, hash) + .await?; + } else { + info!("Skipping proving for l1 height {}", l1_height); + } + Ok(()) + } + + async fn create_state_transition_data( + &self, + sequencer_commitments: &[SequencerCommitment], + da_block_header_of_commitments: <::Spec as DaSpec>::BlockHeader, + da_data: Vec<<::Spec as DaSpec>::BlobTransaction>, + l1_block: &Da::FilteredBlock, + ) -> Result, anyhow::Error> { + let first_l2_height_of_l1 = sequencer_commitments[0].l2_start_block_number; + let last_l2_height_of_l1 = + sequencer_commitments[sequencer_commitments.len() - 1].l2_end_block_number; + let ( + state_transition_witnesses, + soft_confirmations, + da_block_headers_of_soft_confirmations, + ) = self + .get_state_transition_data_from_commitments(sequencer_commitments, &self.da_service) + .await?; + let initial_state_root = self + .ledger_db + .get_l2_state_root::(first_l2_height_of_l1 - 1)? 
+ .expect("There should be a state root"); + let initial_batch_hash = self + .ledger_db + .get_soft_confirmation_by_number(&BatchNumber(first_l2_height_of_l1))? + .ok_or(anyhow!( + "Could not find soft batch at height {}", + first_l2_height_of_l1 + ))? + .prev_hash; + + let final_state_root = self + .ledger_db + .get_l2_state_root::(last_l2_height_of_l1)? + .expect("There should be a state root"); + + let (inclusion_proof, completeness_proof) = self + .da_service + .get_extraction_proof(l1_block, &da_data) + .await; + + let transition_data: StateTransitionData = + StateTransitionData { + initial_state_root, + final_state_root, + initial_batch_hash, + da_data, + da_block_header_of_commitments, + inclusion_proof, + completeness_proof, + soft_confirmations, + state_transition_witnesses, + da_block_headers_of_soft_confirmations, + sequencer_commitments_range: ( + 0, + (sequencer_commitments.len() - 1) + .try_into() + .expect("cant be more than 4 billion commitments in a da block; qed"), + ), // for now process all commitments + sequencer_public_key: self.sequencer_pub_key.clone(), + sequencer_da_public_key: self.sequencer_da_pub_key.clone(), + }; + Ok(transition_data) + } + fn extract_sequencer_commitments( &self, l1_block_hash: [u8; 32], @@ -701,7 +757,6 @@ where async fn generate_and_submit_proof( &self, transition_data: StateTransitionData, - l1_height: u64, hash: <::Spec as DaSpec>::SlotHash, ) -> Result<(), anyhow::Error> { let prover_service = self @@ -723,6 +778,14 @@ where } }; + self.extract_and_store_proof(tx_id, proof).await + } + + async fn extract_and_store_proof( + &self, + tx_id: ::TransactionId, + proof: Proof, + ) -> Result<(), anyhow::Error> { let tx_id_u8 = tx_id.into(); // l1_height => (tx_id, proof, transition_data) @@ -754,16 +817,22 @@ where info!("transition data: {:?}", transition_data); + let slot_hash = transition_data.da_slot_hash.into(); + let stored_state_transition = StoredStateTransition { initial_state_root: transition_data.initial_state_root.as_ref().to_vec(), final_state_root: transition_data.final_state_root.as_ref().to_vec(), state_diff: transition_data.state_diff, - da_slot_hash: transition_data.da_slot_hash.into(), + da_slot_hash: slot_hash, sequencer_commitments_range: transition_data.sequencer_commitments_range, sequencer_public_key: transition_data.sequencer_public_key, sequencer_da_public_key: transition_data.sequencer_da_public_key, validity_condition: borsh::to_vec(&transition_data.validity_condition).unwrap(), }; + let l1_height = self + .ledger_db + .get_l1_height_of_l1_hash(slot_hash)? 
+ .expect("l1 height should exist"); if let Err(e) = self.ledger_db diff --git a/crates/prover/tests/prover_tests.rs b/crates/prover/tests/prover_tests.rs index c869f4cb57..df0102735d 100644 --- a/crates/prover/tests/prover_tests.rs +++ b/crates/prover/tests/prover_tests.rs @@ -2,6 +2,7 @@ use std::collections::VecDeque; use std::sync::Arc; use citrea_prover::prover_service::ParallelProverService; +use sov_db::ledger_db::LedgerDB; use sov_mock_da::{ MockAddress, MockBlockHeader, MockDaService, MockDaSpec, MockDaVerifier, MockHash, MockValidityCond, @@ -15,7 +16,7 @@ use sov_stf_runner::{ WitnessSubmissionStatus, }; -#[tokio::test] +#[tokio::test(flavor = "multi_thread")] async fn test_successful_prover_execution() -> Result<(), ProverServiceError> { let temp = tempfile::tempdir().unwrap(); @@ -49,7 +50,7 @@ async fn test_successful_prover_execution() -> Result<(), ProverServiceError> { Ok(()) } -#[tokio::test] +#[tokio::test(flavor = "multi_thread")] async fn test_prover_status_busy() -> Result<(), anyhow::Error> { let temp = tempfile::tempdir().unwrap(); let da_service = Arc::new(MockDaService::new(MockAddress::from([0; 32]), temp.path())); @@ -190,6 +191,8 @@ fn make_new_prover() -> TestProver { let prover_config = ProverGuestRunConfig::Execute; let zk_stf = MockStf::::default(); let da_verifier = MockDaVerifier::default(); + let tmpdir = tempfile::tempdir().unwrap(); + let ledger_db = LedgerDB::with_path(tmpdir.path()).unwrap(); TestProver { prover_service: ParallelProverService::new( vm.clone(), @@ -198,6 +201,7 @@ fn make_new_prover() -> TestProver { prover_config, (), num_threads, + ledger_db, ) .expect("Should be able to instantiate Prover service"), vm, diff --git a/crates/risc0-bonsai/Cargo.toml b/crates/risc0-bonsai/Cargo.toml index a3d1cfe776..ad944f7c90 100644 --- a/crates/risc0-bonsai/Cargo.toml +++ b/crates/risc0-bonsai/Cargo.toml @@ -26,6 +26,7 @@ risc0-zkp = { workspace = true, optional = true } risc0-zkvm = { workspace = true, default-features = false, features = ["std"] } risc0-zkvm-platform = { workspace = true } serde = { workspace = true } +sov-db = { path = "../sovereign-sdk/full-node/db/sov-db" } sov-rollup-interface = { path = "../sovereign-sdk/rollup-interface" } tracing = { workspace = true } diff --git a/crates/risc0-bonsai/src/host.rs b/crates/risc0-bonsai/src/host.rs index 16f1e9fa5d..b829e938da 100644 --- a/crates/risc0-bonsai/src/host.rs +++ b/crates/risc0-bonsai/src/host.rs @@ -11,10 +11,32 @@ use risc0_zkvm::{ compute_image_id, ExecutorEnvBuilder, ExecutorImpl, Groth16Receipt, InnerReceipt, Journal, Receipt, }; +use sov_db::ledger_db::{LedgerDB, ProvingServiceLedgerOps}; use sov_risc0_adapter::guest::Risc0Guest; use sov_rollup_interface::zk::{Proof, Zkvm, ZkvmHost}; use tracing::{error, info, warn}; +type StarkSessionId = String; +type SnarkSessionId = String; + +/// Bonsai sessions to be recovered in case of a crash. +#[derive(Debug, Clone, BorshSerialize, BorshDeserialize)] +pub enum BonsaiSession { + /// Stark session id if the prover crashed during stark proof generation. + StarkSession(StarkSessionId), + /// Both Stark and Snark session id if the prover crashed during stark to snarkconversion. + SnarkSession(StarkSessionId, SnarkSessionId), +} + +/// Recovered bonsai session. +#[derive(Debug, Clone, BorshSerialize, BorshDeserialize)] +pub struct RecoveredBonsaiSession { + /// Used for sending proofs in order + pub id: u8, + /// Recovered session + pub session: BonsaiSession, +} + macro_rules! 
retry_backoff_bonsai { ($bonsai_call:expr) => { retry_backoff( @@ -65,11 +87,12 @@ pub struct Risc0BonsaiHost<'a> { image_id: Digest, client: Option, last_input_id: Option, + ledger_db: LedgerDB, } impl<'a> Risc0BonsaiHost<'a> { /// Create a new Risc0Host to prove the given binary. - pub fn new(elf: &'a [u8], api_url: String, api_key: String) -> Self { + pub fn new(elf: &'a [u8], api_url: String, api_key: String, ledger_db: LedgerDB) -> Self { // Compute the image_id, then upload the ELF with the image_id as its key. // handle error let image_id = compute_image_id(elf).unwrap(); @@ -98,6 +121,7 @@ impl<'a> Risc0BonsaiHost<'a> { image_id, client, last_input_id: None, + ledger_db, } } @@ -112,6 +136,126 @@ impl<'a> Risc0BonsaiHost<'a> { tracing::info!("Uploaded input with id: {}", input_id); self.last_input_id = Some(input_id); } + + fn receipt_loop(&self, session: &str, client: &Client) -> Result, anyhow::Error> { + let session = bonsai_sdk::blocking::SessionId::new(session.to_owned()); + loop { + // handle error + let res = + retry_backoff_bonsai!(session.status(client)).expect("Failed to fetch status; qed"); + + if res.status == "RUNNING" { + tracing::info!( + "Current status: {} - state: {} - continue polling...", + res.status, + res.state.unwrap_or_default() + ); + std::thread::sleep(Duration::from_secs(15)); + continue; + } + if res.status == "SUCCEEDED" { + // Download the receipt, containing the output + let receipt_url = res + .receipt_url + .expect("API error, missing receipt on completed session"); + + tracing::info!("Receipt URL: {}", receipt_url); + + let receipt_buf = retry_backoff_bonsai!(client.download(receipt_url.as_str())) + .expect("Failed to download receipt; qed"); + break Ok(receipt_buf); + } else { + return Err(anyhow!( + "Workflow exited: {} with error message: {}", + res.status, + res.error_msg.unwrap_or_default() + )); + } + } + } + + fn wait_for_receipt(&self, session: &str) -> Result, anyhow::Error> { + let session = bonsai_sdk::blocking::SessionId::new(session.to_string()); + let client = self.client.as_ref().unwrap(); + self.receipt_loop(&session.uuid, client) + } + + fn wait_for_stark_to_snark_conversion( + &self, + snark_session: Option<&str>, + stark_session: &str, + receipt_buf: Vec, + ) -> Result { + // If snark session exists use it else create one from stark + let snark_session = match snark_session { + Some(snark_session) => bonsai_sdk::blocking::SnarkId::new(snark_session.to_string()), + None => { + let client = self.client.as_ref().unwrap(); + let session = bonsai_sdk::blocking::SessionId::new(stark_session.to_string()); + retry_backoff_bonsai!(client.create_snark(session.uuid.clone())) + .expect("Failed to create snark session; qed") + } + }; + + let recovered_serialized_snark_session = borsh::to_vec(&RecoveredBonsaiSession { + id: 0, + session: BonsaiSession::SnarkSession( + stark_session.to_string(), + snark_session.uuid.clone(), + ), + })?; + self.ledger_db + .add_pending_proving_session(recovered_serialized_snark_session.clone())?; + + let client = self.client.as_ref().unwrap(); + let receipt: Receipt = bincode::deserialize(&receipt_buf).unwrap(); + loop { + let res = retry_backoff_bonsai!(snark_session.status(client)) + .expect("Failed to fetch status; qed"); + match res.status.as_str() { + "RUNNING" => { + tracing::info!("Current status: {} - continue polling...", res.status,); + std::thread::sleep(Duration::from_secs(15)); + continue; + } + "SUCCEEDED" => { + let snark_receipt = match res.output { + Some(output) => output, + None => { + return 
Err(anyhow!( + "SNARK session succeeded but no output was provided" + )) + } + }; + tracing::info!("Snark proof!: {snark_receipt:?}"); + + // now we convert the snark_receipt to a full receipt + + use risc0_zkvm::sha::Digestible; + let inner = InnerReceipt::Groth16(Groth16Receipt::new( + snark_receipt.snark.to_vec(), + receipt.claim().expect("stark_2_snark error, receipt claim"), + risc0_zkvm::Groth16ReceiptVerifierParameters::default().digest(), + )); + + let full_snark_receipt = Receipt::new(inner, snark_receipt.journal); + + tracing::info!("Full snark proof!: {full_snark_receipt:?}"); + + let full_snark_receipt = bincode::serialize(&full_snark_receipt)?; + + return Ok(Proof::Full(full_snark_receipt)); + } + _ => { + return Err(anyhow!( + "Workflow exited: {} with error message: {}", + res.status, + res.error_msg.unwrap_or_default() + )); + } + } + } + } } impl<'a> ZkvmHost for Risc0BonsaiHost<'a> { @@ -179,105 +323,36 @@ impl<'a> ZkvmHost for Risc0BonsaiHost<'a> { false )) .expect("Failed to create session; qed"); - tracing::info!("Session created: {}", session.uuid); - let receipt = loop { - // handle error - let res = retry_backoff_bonsai!(session.status(client)) - .expect("Failed to fetch status; qed"); - - if res.status == "RUNNING" { - tracing::info!( - "Current status: {} - state: {} - continue polling...", - res.status, - res.state.unwrap_or_default() - ); - std::thread::sleep(Duration::from_secs(15)); - continue; - } - if res.status == "SUCCEEDED" { - // Download the receipt, containing the output - let receipt_url = res - .receipt_url - .expect("API error, missing receipt on completed session"); - - tracing::info!("Receipt URL: {}", receipt_url); - if let Some(stats) = res.stats { - tracing::info!( - "User cycles: {} - Total cycles: {} - Segments: {}", - stats.cycles, - stats.total_cycles, - stats.segments, - ); - } - - let receipt_buf = retry_backoff_bonsai!(client.download(receipt_url.as_str())) - .expect("Failed to download receipt; qed"); + let stark_session = RecoveredBonsaiSession { + id: 0, + session: BonsaiSession::StarkSession(session.uuid.clone()), + }; + let serialized_stark_session = borsh::to_vec(&stark_session) + .expect("Bonsai host should be able to serialize bonsai sessions"); + self.ledger_db + .add_pending_proving_session(serialized_stark_session.clone())?; - let receipt: Receipt = bincode::deserialize(&receipt_buf).unwrap(); + tracing::info!("Session created: {}", session.uuid); - break receipt; - } else { - return Err(anyhow!( - "Workflow exited: {} with error message: {}", - res.status, - res.error_msg.unwrap_or_default() - )); - } - }; + let receipt = self.wait_for_receipt(&session.uuid)?; tracing::info!("Creating the SNARK"); let snark_session = retry_backoff_bonsai!(client.create_snark(session.uuid.clone())) .expect("Failed to create snark session; qed"); - tracing::info!("SNARK session created: {}", snark_session.uuid); - - loop { - let res = retry_backoff_bonsai!(snark_session.status(client)) - .expect("Failed to fetch status; qed"); - match res.status.as_str() { - "RUNNING" => { - tracing::info!("Current status: {} - continue polling...", res.status,); - std::thread::sleep(Duration::from_secs(15)); - continue; - } - "SUCCEEDED" => { - let snark_receipt = match res.output { - Some(output) => output, - None => { - return Err(anyhow!( - "SNARK session succeeded but no output was provided" - )) - } - }; - tracing::info!("Snark proof!: {snark_receipt:?}"); - - // now we convert the snark_receipt to a full receipt - - use risc0_zkvm::sha::Digestible; - let 
inner = InnerReceipt::Groth16(Groth16Receipt::new(
-                        snark_receipt.snark.to_vec(),
-                        receipt.claim().expect("stark_2_snark error, receipt claim"),
-                        risc0_zkvm::Groth16ReceiptVerifierParameters::default().digest(),
-                    ));
+        // Remove the stark session as it is finished
+        self.ledger_db
+            .remove_pending_proving_session(serialized_stark_session.clone())?;
 
-                    let full_snark_receipt = Receipt::new(inner, snark_receipt.journal);
-
-                    tracing::info!("Full snark proof!: {full_snark_receipt:?}");
-
-                    let full_snark_receipt = bincode::serialize(&full_snark_receipt)?;
+        tracing::info!("SNARK session created: {}", snark_session.uuid);
 
-                    return Ok(Proof::Full(full_snark_receipt));
-                }
-                _ => {
-                    return Err(anyhow!(
-                        "Workflow exited: {} with error message: {}",
-                        res.status,
-                        res.error_msg.unwrap_or_default()
-                    ));
-                }
-            }
-        }
+        // Snark session is saved in the function
+        self.wait_for_stark_to_snark_conversion(
+            Some(&snark_session.uuid),
+            &session.uuid,
+            receipt,
+        )
     }
 }
 
@@ -296,6 +371,33 @@ impl<'a> ZkvmHost for Risc0BonsaiHost<'a> {
         };
         Ok(BorshDeserialize::try_from_slice(&journal.bytes)?)
     }
+
+    fn recover_proving_sessions(&self) -> Result<Vec<Proof>, anyhow::Error> {
+        let sessions = self.ledger_db.get_pending_proving_sessions()?;
+        let mut proofs = Vec::new();
+        for session in sessions {
+            let bonsai_session: RecoveredBonsaiSession = BorshDeserialize::try_from_slice(&session)
+                .expect("Bonsai host should be able to recover bonsai sessions");
+            match bonsai_session.session {
+                BonsaiSession::StarkSession(stark_session) => {
+                    let receipt = self.wait_for_receipt(&stark_session)?;
+                    let proof =
+                        self.wait_for_stark_to_snark_conversion(None, &stark_session, receipt)?;
+                    proofs.push(proof);
+                }
+                BonsaiSession::SnarkSession(stark_session, snark_session) => {
+                    let receipt = self.wait_for_receipt(&stark_session)?;
+                    let proof = self.wait_for_stark_to_snark_conversion(
+                        Some(&snark_session),
+                        &stark_session,
+                        receipt,
+                    )?;
+                    proofs.push(proof)
+                }
+            }
+        }
+        Ok(proofs)
+    }
 }
 
 impl<'host> Zkvm for Risc0BonsaiHost<'host> {
diff --git a/crates/sovereign-sdk/adapters/mock-zkvm/src/lib.rs b/crates/sovereign-sdk/adapters/mock-zkvm/src/lib.rs
index 847275814b..620a20b574 100644
--- a/crates/sovereign-sdk/adapters/mock-zkvm/src/lib.rs
+++ b/crates/sovereign-sdk/adapters/mock-zkvm/src/lib.rs
@@ -9,7 +9,7 @@ use anyhow::ensure;
 use borsh::{BorshDeserialize, BorshSerialize};
 use serde::{Deserialize, Serialize};
 use sov_rollup_interface::da::BlockHeaderTrait;
-use sov_rollup_interface::zk::{Matches, StateTransitionData, ValidityCondition};
+use sov_rollup_interface::zk::{Matches, Proof, StateTransitionData, ValidityCondition};
 
 /// A mock commitment to a particular zkVM program.
 #[derive(Debug, Clone, PartialEq, Eq, BorshDeserialize, BorshSerialize, Serialize, Deserialize)]
@@ -196,6 +196,10 @@ impl sov_rollup_interface::zk::ZkvmHost
             }
         }
     }
+
+    fn recover_proving_sessions(&self) -> Result<Vec<Proof>, anyhow::Error> {
+        unimplemented!()
+    }
 }
 
 /// A mock implementing the Guest.
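The mock adapter stubs recovery out, while the Bonsai host above implements it in full. For orientation, a minimal sketch of how a node might drive the new hook at startup; `resume_pending_proofs` and the `submit` sink are hypothetical names introduced only for this example, while `recover_proving_sessions` itself is the method added by this patch:

    use sov_rollup_interface::zk::{Proof, ZkvmHost};

    /// Drain every proving session that was in flight when the node stopped,
    /// handing each finished proof to a caller-supplied sink.
    fn resume_pending_proofs<H: ZkvmHost>(
        host: &H,
        mut submit: impl FnMut(Proof) -> anyhow::Result<()>,
    ) -> anyhow::Result<()> {
        // Blocks until every recovered STARK/SNARK session completes.
        for proof in host.recover_proving_sessions()? {
            submit(proof)?;
        }
        Ok(())
    }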
diff --git a/crates/sovereign-sdk/adapters/risc0/src/host.rs b/crates/sovereign-sdk/adapters/risc0/src/host.rs
index 4c2e3e9224..f1bcf2ce46 100644
--- a/crates/sovereign-sdk/adapters/risc0/src/host.rs
+++ b/crates/sovereign-sdk/adapters/risc0/src/host.rs
@@ -105,6 +105,10 @@ impl<'a> ZkvmHost for Risc0Host<'a> {
         };
         Ok(BorshDeserialize::deserialize(&mut journal.bytes.as_ref())?)
     }
+
+    fn recover_proving_sessions(&self) -> Result<Vec<Proof>, anyhow::Error> {
+        unimplemented!()
+    }
 }
 
 impl<'host> Zkvm for Risc0Host<'host> {
diff --git a/crates/sovereign-sdk/full-node/db/sov-db/src/ledger_db/mod.rs b/crates/sovereign-sdk/full-node/db/sov-db/src/ledger_db/mod.rs
index cde955a722..cdd85b82a7 100644
--- a/crates/sovereign-sdk/full-node/db/sov-db/src/ledger_db/mod.rs
+++ b/crates/sovereign-sdk/full-node/db/sov-db/src/ledger_db/mod.rs
@@ -16,9 +16,10 @@ use crate::rocks_db_config::gen_rocksdb_options;
 use crate::schema::tables::{
     ActiveFork, BatchByNumber, CommitmentsByNumber, EventByKey, EventByNumber, L2GenesisStateRoot,
     L2RangeByL1Height, L2Witness, LastSequencerCommitmentSent, LastStateDiff, MempoolTxs,
-    PendingSequencerCommitmentL2Range, ProofBySlotNumber, ProverLastScannedSlot, ProverStateDiffs,
-    SlotByHash, SlotByNumber, SoftConfirmationByHash, SoftConfirmationByNumber,
-    SoftConfirmationStatus, TxByHash, TxByNumber, VerifiedProofsBySlotNumber, LEDGER_TABLES,
+    PendingProvingSessions, PendingSequencerCommitmentL2Range, ProofBySlotNumber,
+    ProverLastScannedSlot, ProverStateDiffs, SlotByHash, SlotByNumber, SoftConfirmationByHash,
+    SoftConfirmationByNumber, SoftConfirmationStatus, TxByHash, TxByNumber,
+    VerifiedProofsBySlotNumber, LEDGER_TABLES,
 };
 use crate::schema::types::{
     split_tx_for_storage, BatchNumber, EventNumber, L2HeightRange, SlotNumber, StoredProof,
@@ -342,6 +343,12 @@ impl SharedLedgerOps for LedgerDB {
         self.db.put::<SlotByHash>(&hash, &SlotNumber(height))
     }
 
+    /// Gets l1 height of l1 hash
+    #[instrument(level = "trace", skip(self), err, ret)]
+    fn get_l1_height_of_l1_hash(&self, hash: [u8; 32]) -> Result<Option<u64>, anyhow::Error> {
+        self.db.get::<SlotByHash>(&hash).map(|v| v.map(|a| a.0))
+    }
+
     /// Saves a soft confirmation status for a given L1 height
     #[instrument(level = "trace", skip(self), err, ret)]
     fn put_soft_confirmation_status(
@@ -569,6 +576,46 @@ impl ProverLedgerOps for LedgerDB {
     }
 }
 
+impl ProvingServiceLedgerOps for LedgerDB {
+    /// Gets all pending sessions and step numbers
+    #[instrument(level = "trace", skip(self), err)]
+    fn get_pending_proving_sessions(&self) -> anyhow::Result<Vec<Vec<u8>>> {
+        let mut iter = self.db.iter::<PendingProvingSessions>()?;
+        iter.seek_to_first();
+
+        let sessions = iter
+            .map(|item| item.map(|item| (item.key)))
+            .collect::<Result<Vec<_>, _>>()?;
+        Ok(sessions)
+    }
+
+    #[instrument(level = "trace", skip(self), err)]
+    fn add_pending_proving_session(&self, session: Vec<u8>) -> anyhow::Result<()> {
+        self.db.put::<PendingProvingSessions>(&session, &())
+    }
+
+    #[instrument(level = "trace", skip(self), err)]
+    fn remove_pending_proving_session(&self, session: Vec<u8>) -> anyhow::Result<()> {
+        self.db.delete::<PendingProvingSessions>(&session)
+    }
+
+    #[instrument(level = "trace", skip(self), err)]
+    fn clear_pending_proving_sessions(&self) -> anyhow::Result<()> {
+        let mut schema_batch = SchemaBatch::new();
+        let mut iter = self.db.iter::<PendingProvingSessions>()?;
+        iter.seek_to_first();
+
+        for item in iter {
+            let item = item?;
+            schema_batch.delete::<PendingProvingSessions>(&item.key)?;
+        }
+
+        self.db.write_schemas(schema_batch)?;
+
+        Ok(())
+    }
+}
+
 impl SequencerLedgerOps for LedgerDB {
     /// Put slots
     #[instrument(level = "trace", skip(self, schema_batch), err, ret)]
@@ -748,12 +795,6 @@ impl NodeLedgerOps for LedgerDB {
     ) -> anyhow::Result<Option<Vec<SequencerCommitment>>> {
         self.db.get::<CommitmentsByNumber>(&SlotNumber(height))
     }
-
-    /// Gets l1 height of l1 hash
-    #[instrument(level = "trace", skip(self), err, ret)]
-    fn get_l1_height_of_l1_hash(&self, hash: [u8; 32]) -> Result<Option<u64>, anyhow::Error> {
-        self.db.get::<SlotByHash>(&hash).map(|v| v.map(|a| a.0))
-    }
 }
 
 impl ForkMigration for LedgerDB {
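The RocksDB plumbing above keys each pending session by the borsh bytes of the recovered session itself, with a unit value. A self-contained sketch of that round-trip, using stand-in definitions of `RecoveredBonsaiSession` and `BonsaiSession` (their real definitions live in the risc0-bonsai host and are not part of this hunk; the `u8` width of `id` is an assumption, the patch only ever stores 0 there):

    use borsh::{BorshDeserialize, BorshSerialize};

    #[derive(BorshSerialize, BorshDeserialize, Debug, PartialEq)]
    enum BonsaiSession {
        StarkSession(String),
        SnarkSession(String, String), // (stark uuid, snark uuid)
    }

    #[derive(BorshSerialize, BorshDeserialize, Debug, PartialEq)]
    struct RecoveredBonsaiSession {
        id: u8, // assumed width for illustration
        session: BonsaiSession,
    }

    fn main() -> anyhow::Result<()> {
        let pending = RecoveredBonsaiSession {
            id: 0,
            session: BonsaiSession::StarkSession("stark-uuid".into()),
        };
        // This byte vector is what add_pending_proving_session stores as the key.
        let key = borsh::to_vec(&pending)?;
        // And this is what recover_proving_sessions does with each iterated key.
        let recovered = RecoveredBonsaiSession::try_from_slice(&key)?;
        assert_eq!(pending, recovered);
        Ok(())
    }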
diff --git a/crates/sovereign-sdk/full-node/db/sov-db/src/ledger_db/traits.rs b/crates/sovereign-sdk/full-node/db/sov-db/src/ledger_db/traits.rs
index f064eaaed4..05eafdbab6 100644
--- a/crates/sovereign-sdk/full-node/db/sov-db/src/ledger_db/traits.rs
+++ b/crates/sovereign-sdk/full-node/db/sov-db/src/ledger_db/traits.rs
@@ -69,6 +69,9 @@ pub trait SharedLedgerOps {
     /// Sets l1 height of l1 hash
     fn set_l1_height_of_l1_hash(&self, hash: [u8; 32], height: u64) -> Result<()>;
 
+    /// Gets l1 height of l1 hash
+    fn get_l1_height_of_l1_hash(&self, hash: [u8; 32]) -> Result<Option<u64>>;
+
     /// Saves a soft confirmation status for a given L1 height
     fn put_soft_confirmation_status(
         &self,
@@ -141,13 +144,10 @@ pub trait NodeLedgerOps: SharedLedgerOps {
 
     /// Gets the commitments in the da slot with given height if any
     fn get_commitments_on_da_slot(&self, height: u64) -> Result<Option<Vec<SequencerCommitment>>>;
-
-    /// Gets l1 height of l1 hash
-    fn get_l1_height_of_l1_hash(&self, hash: [u8; 32]) -> Result<Option<u64>>;
 }
 
 /// Prover ledger operations
-pub trait ProverLedgerOps: SharedLedgerOps {
+pub trait ProverLedgerOps: SharedLedgerOps + Send + Sync {
     /// Get the state root by L2 height
     fn get_l2_state_root(
         &self,
@@ -183,6 +183,21 @@ pub trait ProverLedgerOps: SharedLedgerOps {
     fn get_l2_state_diff(&self, l2_height: BatchNumber) -> Result<Option<StateDiff>>;
 }
 
+/// Ledger operations for the prover service
+pub trait ProvingServiceLedgerOps: ProverLedgerOps + SharedLedgerOps + Send + Sync {
+    /// Gets all pending sessions and step numbers
+    fn get_pending_proving_sessions(&self) -> Result<Vec<Vec<u8>>>;
+
+    /// Adds a pending proving session
+    fn add_pending_proving_session(&self, session: Vec<u8>) -> Result<()>;
+
+    /// Removes a pending proving session
+    fn remove_pending_proving_session(&self, session: Vec<u8>) -> Result<()>;
+
+    /// Clears all pending proving sessions
+    fn clear_pending_proving_sessions(&self) -> Result<()>;
+}
+
 /// Sequencer ledger operations
 pub trait SequencerLedgerOps: SharedLedgerOps {
     /// Put slots
diff --git a/crates/sovereign-sdk/full-node/db/sov-db/src/schema/tables.rs b/crates/sovereign-sdk/full-node/db/sov-db/src/schema/tables.rs
index f089a33d6e..9bbcb1037b 100644
--- a/crates/sovereign-sdk/full-node/db/sov-db/src/schema/tables.rs
+++ b/crates/sovereign-sdk/full-node/db/sov-db/src/schema/tables.rs
@@ -74,6 +74,7 @@ pub const LEDGER_TABLES: &[&str] = &[
     ProofBySlotNumber::table_name(),
     VerifiedProofsBySlotNumber::table_name(),
     MempoolTxs::table_name(),
+    PendingProvingSessions::table_name(),
     ProverStateDiffs::table_name(),
 ];
 
@@ -333,6 +334,12 @@ define_table_with_default_codec!(
     (VerifiedProofsBySlotNumber) SlotNumber => Vec<StoredVerifiedProof>
 );
 
+define_table_with_seek_key_codec!(
+    /// Proving service uses this table to store pending proving sessions
+    /// If a session id is completed, remove it
+    (PendingProvingSessions) Vec<u8> => ()
+);
+
 define_table_with_default_codec!(
     /// Transactions in mempool (TxHash, TxData)
     (MempoolTxs) Vec<u8> => Vec<u8>
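Taken together, the trait and the table give the prover a simple write-ahead discipline: record the session key before waiting on Bonsai, delete it once the proof is out, so a crash in between leaves a recoverable record. A hedged sketch of that call pattern; the `checkpointed_prove` wrapper and the pared-down `Ops` trait are illustrative stand-ins for sov-db's `ProvingServiceLedgerOps`, not code from this patch:

    use anyhow::Result;

    // Stand-in for the two ProvingServiceLedgerOps methods used here.
    trait Ops {
        fn add_pending_proving_session(&self, session: Vec<u8>) -> Result<()>;
        fn remove_pending_proving_session(&self, session: Vec<u8>) -> Result<()>;
    }

    fn checkpointed_prove<T>(
        db: &impl Ops,
        session_key: Vec<u8>,
        prove: impl FnOnce() -> Result<T>,
    ) -> Result<T> {
        // Persist first, so an interrupted session can be found on restart.
        db.add_pending_proving_session(session_key.clone())?;
        let out = prove()?;
        // Only a finished session is removed; failures keep the checkpoint.
        db.remove_pending_proving_session(session_key)?;
        Ok(out)
    }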
diff --git a/crates/sovereign-sdk/full-node/sov-stf-runner/src/prover_service/mod.rs b/crates/sovereign-sdk/full-node/sov-stf-runner/src/prover_service/mod.rs
index 82900a5619..2293bf659b 100644
--- a/crates/sovereign-sdk/full-node/sov-stf-runner/src/prover_service/mod.rs
+++ b/crates/sovereign-sdk/full-node/sov-stf-runner/src/prover_service/mod.rs
@@ -113,4 +113,10 @@ pub trait ProverService<Da: DaService> {
         block_header_hash: <<Da as DaService>::Spec as DaSpec>::SlotHash,
         da_service: &Arc<Da>,
     ) -> Result<(<Da as DaService>::TransactionId, Proof), anyhow::Error>;
+
+    /// Recovers pending proving sessions and sends proofs to the DA.
+    async fn recover_proving_sessions_and_send_to_da(
+        &self,
+        da_service: &Arc<Da>,
+    ) -> Result<Vec<(<Da as DaService>::TransactionId, Proof)>, anyhow::Error>;
 }
diff --git a/crates/sovereign-sdk/module-system/sov-modules-rollup-blueprint/src/lib.rs b/crates/sovereign-sdk/module-system/sov-modules-rollup-blueprint/src/lib.rs
index 614f7ddad1..5fd0b8b749 100644
--- a/crates/sovereign-sdk/module-system/sov-modules-rollup-blueprint/src/lib.rs
+++ b/crates/sovereign-sdk/module-system/sov-modules-rollup-blueprint/src/lib.rs
@@ -112,6 +112,7 @@ pub trait RollupBlueprint: Sized + Send + Sync {
         prover_config: ProverConfig,
         rollup_config: &FullNodeConfig<Self::DaConfig>,
         da_service: &Arc<Self::DaService>,
+        ledger_db: LedgerDB,
     ) -> Self::ProverService;
 
     /// Creates instance of [`Self::StorageManager`].
diff --git a/crates/sovereign-sdk/rollup-interface/src/state_machine/zk/mod.rs b/crates/sovereign-sdk/rollup-interface/src/state_machine/zk/mod.rs
index 3e085a5a55..5d378b1157 100644
--- a/crates/sovereign-sdk/rollup-interface/src/state_machine/zk/mod.rs
+++ b/crates/sovereign-sdk/rollup-interface/src/state_machine/zk/mod.rs
@@ -54,6 +54,9 @@ pub trait ZkvmHost: Zkvm + Clone {
     fn extract_output<Da: DaSpec, Root: BorshDeserialize>(
         proof: &Proof,
     ) -> Result<StateTransition<Da, Root>, Self::Error>;
+
+    /// Host recovers pending proving sessions and returns proving results
+    fn recover_proving_sessions(&self) -> Result<Vec<Proof>, anyhow::Error>;
 }
 
 /// A Zk proof system capable of proving and verifying arbitrary Rust code

From 12bb40fa958449a8d7c9c49533fdc5fb724b535d Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Ahmet=20Yaz=C4=B1c=C4=B1?= <75089142+yaziciahmet@users.noreply.github.com>
Date: Thu, 15 Aug 2024 10:57:22 +0300
Subject: [PATCH 4/5] Sort commitments in zk (#987)

* Sort commitments in zk

* Impl Ord for commitment

* Replace std with core

---------

Co-authored-by: yaziciahmet
---
 .../sov-modules-stf-blueprint/src/lib.rs      |  3 +++
 .../rollup-interface/src/state_machine/da.rs  | 12 ++++++++++++
 2 files changed, 15 insertions(+)

diff --git a/crates/sovereign-sdk/module-system/sov-modules-stf-blueprint/src/lib.rs b/crates/sovereign-sdk/module-system/sov-modules-stf-blueprint/src/lib.rs
index 51330a93cb..312fb0dd70 100644
--- a/crates/sovereign-sdk/module-system/sov-modules-stf-blueprint/src/lib.rs
+++ b/crates/sovereign-sdk/module-system/sov-modules-stf-blueprint/src/lib.rs
@@ -522,6 +522,9 @@ where
             }
         }
 
+        // Sort commitments just in case
+        sequencer_commitments.sort_unstable();
+
         // Then verify these soft confirmations.
         let mut current_state_root = initial_state_root.clone();
 
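Since the `Ord` impl added in the next diff compares only `l2_start_block_number`, the `sort_unstable` call above normalizes whatever order the commitments came out of DA in. A small illustrative check, with made-up field values and a pared-down struct standing in for `SequencerCommitment`:

    #[derive(Debug, PartialEq, Eq)]
    struct Commitment {
        l2_start_block_number: u64,
        l2_end_block_number: u64,
    }

    fn main() {
        let mut commitments = vec![
            Commitment { l2_start_block_number: 11, l2_end_block_number: 20 },
            Commitment { l2_start_block_number: 1, l2_end_block_number: 10 },
        ];
        // Mirrors the Ord impl below: order solely by the L2 start height.
        commitments.sort_unstable_by_key(|c| c.l2_start_block_number);
        assert_eq!(commitments[0].l2_start_block_number, 1);
    }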
diff --git a/crates/sovereign-sdk/rollup-interface/src/state_machine/da.rs b/crates/sovereign-sdk/rollup-interface/src/state_machine/da.rs
index 241a5f32c9..a19a70aaea 100644
--- a/crates/sovereign-sdk/rollup-interface/src/state_machine/da.rs
+++ b/crates/sovereign-sdk/rollup-interface/src/state_machine/da.rs
@@ -21,6 +21,18 @@ pub struct SequencerCommitment {
     pub l2_end_block_number: u64,
 }
 
+impl core::cmp::PartialOrd for SequencerCommitment {
+    fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
+        Some(self.cmp(other))
+    }
+}
+
+impl core::cmp::Ord for SequencerCommitment {
+    fn cmp(&self, other: &Self) -> core::cmp::Ordering {
+        self.l2_start_block_number.cmp(&other.l2_start_block_number)
+    }
+}
+
 /// Data written to DA can only be one of these two types
 /// Data written to DA and read from DA must be the borsh serialization of this enum
 #[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, BorshDeserialize, BorshSerialize)]

From fe4cadde434eb343801cd48ddabf95b7a5ff8524 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Ahmet=20Yaz=C4=B1c=C4=B1?= <75089142+yaziciahmet@users.noreply.github.com>
Date: Thu, 15 Aug 2024 11:06:02 +0300
Subject: [PATCH 5/5] Fix empty l2 block having one tx (#992)

* Skip apply tx if 0 txs found

* Lint

* Assign vars in place

* Fix test

---------

Co-authored-by: yaziciahmet
---
 bin/citrea/tests/e2e/sequencer_behaviour.rs |  2 +-
 crates/sequencer/src/sequencer.rs           | 29 +++++++++++++--------
 2 files changed, 19 insertions(+), 12 deletions(-)

diff --git a/bin/citrea/tests/e2e/sequencer_behaviour.rs b/bin/citrea/tests/e2e/sequencer_behaviour.rs
index 9a74689bee..6f4359ec86 100644
--- a/bin/citrea/tests/e2e/sequencer_behaviour.rs
+++ b/bin/citrea/tests/e2e/sequencer_behaviour.rs
@@ -306,7 +306,7 @@ async fn transaction_failing_on_l1_is_removed_from_mempool() -> Result<(), anyho
 
     assert_eq!(block.transactions.len(), 0);
     assert!(tx_from_mempool.is_none());
-    assert_eq!(soft_confirmation.txs.unwrap().len(), 1); // TODO: if we can also remove the tx from soft confirmation, that'd be very efficient
+    assert_eq!(soft_confirmation.txs.unwrap().len(), 0);
 
     wait_for_l2_block(&full_node_test_client, block.header.number.unwrap(), None).await;
 
diff --git a/crates/sequencer/src/sequencer.rs b/crates/sequencer/src/sequencer.rs
index 41aacf04ae..51afb283fd 100644
--- a/crates/sequencer/src/sequencer.rs
+++ b/crates/sequencer/src/sequencer.rs
@@ -443,18 +443,25 @@ where
             &mut signed_batch,
         ) {
             (Ok(()), mut batch_workspace) => {
-                let evm_txs_count = txs_to_run.len();
-                let call_txs = CallMessage { txs: txs_to_run };
-                let raw_message =
-                    <Runtime<C, Da::Spec> as EncodeCall<citrea_evm::Evm<C>>>::encode_call(call_txs);
-                let signed_blob = self.make_blob(raw_message, &mut batch_workspace)?;
-                let txs = vec![signed_blob.clone()];
+                let mut txs = vec![];
+                let mut tx_receipts = vec![];
 
-                let (batch_workspace, tx_receipts) = self.stf.apply_soft_confirmation_txs(
-                    self.fork_manager.active_fork(),
-                    txs.clone(),
-                    batch_workspace,
-                );
+                let evm_txs_count = txs_to_run.len();
+                if evm_txs_count > 0 {
+                    let call_txs = CallMessage { txs: txs_to_run };
+                    let raw_message =
+                        <Runtime<C, Da::Spec> as EncodeCall<citrea_evm::Evm<C>>>::encode_call(
+                            call_txs,
+                        );
+                    let signed_blob = self.make_blob(raw_message, &mut batch_workspace)?;
+                    txs.push(signed_blob);
+
+                    (batch_workspace, tx_receipts) = self.stf.apply_soft_confirmation_txs(
+                        self.fork_manager.active_fork(),
+                        txs.clone(),
+                        batch_workspace,
+                    );
+                }
 
                 // create the unsigned batch with the txs then sign the soft confirmation
                 let unsigned_batch = UnsignedSoftConfirmationBatch::new(