diff --git a/Cargo.lock b/Cargo.lock index 07fd7249d..7abb786a6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1955,6 +1955,7 @@ name = "citrea-risc0-bonsai-adapter" version = "0.4.0-rc.3" dependencies = [ "anyhow", + "backoff", "bincode", "bonsai-sdk", "borsh", diff --git a/crates/risc0-bonsai/Cargo.toml b/crates/risc0-bonsai/Cargo.toml index c73f319f4..a3d1cfe77 100644 --- a/crates/risc0-bonsai/Cargo.toml +++ b/crates/risc0-bonsai/Cargo.toml @@ -13,6 +13,7 @@ description = "An adapter allowing Citrea to connect with Bonsai" [dependencies] anyhow = { workspace = true } +backoff = { workspace = true } bincode = { workspace = true } bonsai-sdk = { workspace = true } borsh = { workspace = true } diff --git a/crates/risc0-bonsai/src/host.rs b/crates/risc0-bonsai/src/host.rs index 19cbefdf9..16f1e9fa5 100644 --- a/crates/risc0-bonsai/src/host.rs +++ b/crates/risc0-bonsai/src/host.rs @@ -1,10 +1,10 @@ //! This module implements the [`ZkvmHost`] trait for the RISC0 VM. - -use std::sync::mpsc::{self, Sender}; -use std::sync::Arc; use std::time::Duration; use anyhow::anyhow; +use backoff::exponential::ExponentialBackoffBuilder; +use backoff::{retry as retry_backoff, SystemClock}; +use bonsai_sdk::blocking::Client; use borsh::{BorshDeserialize, BorshSerialize}; use risc0_zkvm::sha::Digest; use risc0_zkvm::{ @@ -13,278 +13,48 @@ use risc0_zkvm::{ }; use sov_risc0_adapter::guest::Risc0Guest; use sov_rollup_interface::zk::{Proof, Zkvm, ZkvmHost}; -use tracing::{debug, error, info, instrument, trace, warn}; - -/// Requests to bonsai client. Each variant represents its own method. -#[derive(Clone)] -enum BonsaiRequest { - UploadImg { - image_id: String, - buf: Vec, - notify: Sender, - }, - UploadInput { - buf: Vec, - notify: Sender, - }, - Download { - url: String, - notify: Sender>, - }, - CreateSession { - img_id: String, - input_id: String, - assumptions: Vec, - notify: Sender, - }, - CreateSnark { - session: bonsai_sdk::blocking::SessionId, - notify: Sender, - }, - Status { - session: bonsai_sdk::blocking::SessionId, - notify: Sender, - }, - SnarkStatus { - session: bonsai_sdk::blocking::SnarkId, - notify: Sender, - }, -} - -/// A wrapper around Bonsai SDK to handle tokio runtime inside another tokio runtime. -/// See https://stackoverflow.com/a/62536772. -#[derive(Clone)] -struct BonsaiClient { - queue: std::sync::mpsc::Sender, - _join_handle: Arc>, -} - -impl BonsaiClient { - fn from_parts(api_url: String, api_key: String, risc0_version: &str) -> Self { - macro_rules! unwrap_bonsai_response { - ($response:expr, $client_loop:lifetime, $queue_loop:lifetime) => ( - match $response { - Ok(r) => r, +use tracing::{error, info, warn}; + +macro_rules! retry_backoff_bonsai { + ($bonsai_call:expr) => { + retry_backoff( + ExponentialBackoffBuilder::::new() + .with_initial_interval(Duration::from_secs(5)) + .with_max_elapsed_time(Some(Duration::from_secs(15 * 60))) + .build(), + || { + let response = $bonsai_call; + match response { + Ok(r) => Ok(r), Err(e) => { use ::bonsai_sdk::SdkErr::*; match e { InternalServerErr(s) => { - warn!(%s, "Got HHTP 500 from Bonsai"); - std::thread::sleep(Duration::from_secs(10)); - continue $queue_loop + let err = format!("Got HHTP 500 from Bonsai: {}", s); + warn!(err); + Err(backoff::Error::transient(err)) } HttpErr(e) => { - error!(?e, "Reconnecting to Bonsai"); - std::thread::sleep(Duration::from_secs(5)); - continue $client_loop + let err = format!("Reconnecting to Bonsai: {}", e); + error!(err); + Err(backoff::Error::transient(err)) } HttpHeaderErr(e) => { - error!(?e, "Reconnecting to Bonsai"); - std::thread::sleep(Duration::from_secs(5)); - continue $client_loop + let err = format!("Reconnecting to Bonsai: {}", e); + error!(err); + Err(backoff::Error::transient(err)) } e => { - error!(?e, "Got unrecoverable error from Bonsai"); - panic!("Bonsai API error: {}", e); + let err = format!("Got unrecoverable error from Bonsai: {}", e); + error!(err); + Err(backoff::Error::permanent(err)) } } } } - ); - } - let risc0_version = risc0_version.to_string(); - let (queue, rx) = std::sync::mpsc::channel(); - let join_handle = std::thread::spawn(move || { - let mut last_request: Option = None; - 'client: loop { - debug!("Connecting to Bonsai"); - let client = match bonsai_sdk::blocking::Client::from_parts( - api_url.clone(), - api_key.clone(), - &risc0_version, - ) { - Ok(client) => client, - Err(e) => { - error!(?e, "Failed to connect to Bonsai"); - std::thread::sleep(Duration::from_secs(5)); - continue 'client; - } - }; - 'queue: loop { - let request = if let Some(last_request) = last_request.clone() { - debug!("Retrying last request after reconnection"); - last_request - } else { - trace!("Waiting for a new request"); - let req: BonsaiRequest = rx.recv().expect("bonsai client sender is dead"); - // Save request for retries - last_request = Some(req.clone()); - req - }; - match request { - BonsaiRequest::UploadImg { - image_id, - buf, - notify, - } => { - debug!(%image_id, "Bonsai:upload_img"); - let res = client.upload_img(&image_id, buf); - let res = unwrap_bonsai_response!(res, 'client, 'queue); - let _ = notify.send(res); - } - BonsaiRequest::UploadInput { buf, notify } => { - debug!("Bonsai:upload_input"); - let res = client.upload_input(buf); - let res = unwrap_bonsai_response!(res, 'client, 'queue); - let _ = notify.send(res); - } - BonsaiRequest::Download { url, notify } => { - debug!(%url, "Bonsai:download"); - let res = client.download(&url); - let res = unwrap_bonsai_response!(res, 'client, 'queue); - let _ = notify.send(res); - } - BonsaiRequest::CreateSession { - img_id, - input_id, - assumptions, - notify, - } => { - debug!(%img_id, %input_id, "Bonsai:create_session"); - // TODO: think about whether we should have a case where we use Bonsai with only execute mode - let res = client.create_session(img_id, input_id, assumptions, false); - let res = unwrap_bonsai_response!(res, 'client, 'queue); - let _ = notify.send(res); - } - BonsaiRequest::Status { session, notify } => { - debug!(?session, "Bonsai:session_status"); - let res = session.status(&client); - let res = unwrap_bonsai_response!(res, 'client, 'queue); - let _ = notify.send(res); - } - BonsaiRequest::CreateSnark { session, notify } => { - debug!(?session, "Bonsai:create_snark"); - let res = client.create_snark(session.uuid); - let res = unwrap_bonsai_response!(res, 'client, 'queue); - let _ = notify.send(res); - } - BonsaiRequest::SnarkStatus { session, notify } => { - debug!(?session, "Bonsai:snark_status"); - let res = session.status(&client); - let res = unwrap_bonsai_response!(res, 'client, 'queue); - let _ = notify.send(res); - } - }; - // We arrive here only on a successful response - last_request = None; - } - } - }); - let _join_handle = Arc::new(join_handle); - Self { - queue, - _join_handle, - } - } - - #[instrument(level = "trace", skip(self, buf), ret)] - fn upload_img(&self, image_id: String, buf: Vec) -> bool { - let (notify, rx) = mpsc::channel(); - self.queue - .send(BonsaiRequest::UploadImg { - image_id, - buf, - notify, - }) - .expect("Bonsai processing queue is dead"); - rx.recv().unwrap() - } - - #[instrument(level = "trace", skip_all, ret)] - fn upload_input(&self, buf: Vec) -> String { - let (notify, rx) = mpsc::channel(); - self.queue - .send(BonsaiRequest::UploadInput { buf, notify }) - .expect("Bonsai processing queue is dead"); - rx.recv().unwrap() - } - - #[instrument(level = "trace", skip(self))] - fn download(&self, url: String) -> Vec { - let (notify, rx) = mpsc::channel(); - self.queue - .send(BonsaiRequest::Download { url, notify }) - .expect("Bonsai processing queue is dead"); - rx.recv().unwrap() - } - - #[instrument(level = "trace", skip(self, assumptions), ret)] - fn create_session( - &self, - img_id: String, - input_id: String, - assumptions: Vec, - ) -> bonsai_sdk::blocking::SessionId { - let (notify, rx) = mpsc::channel(); - self.queue - .send(BonsaiRequest::CreateSession { - img_id, - input_id, - assumptions, - notify, - }) - .expect("Bonsai processing queue is dead"); - rx.recv().unwrap() - } - - #[instrument(level = "trace", skip(self))] - fn status( - &self, - session: &bonsai_sdk::blocking::SessionId, - ) -> bonsai_sdk::responses::SessionStatusRes { - let session = session.clone(); - let (notify, rx) = mpsc::channel(); - self.queue - .send(BonsaiRequest::Status { session, notify }) - .expect("Bonsai processing queue is dead"); - let status = rx.recv().unwrap(); - debug!( - status.status, - status.receipt_url, status.error_msg, status.state, status.elapsed_time - ); - status - } - - #[instrument(level = "trace", skip(self), ret)] - fn create_snark( - &self, - session: &bonsai_sdk::blocking::SessionId, - ) -> bonsai_sdk::blocking::SnarkId { - let session = session.clone(); - let (notify, rx) = mpsc::channel(); - self.queue - .send(BonsaiRequest::CreateSnark { session, notify }) - .expect("Bonsai processing queue is dead"); - rx.recv().unwrap() - } - - #[instrument(level = "trace", skip(self))] - fn snark_status( - &self, - snark_session: &bonsai_sdk::blocking::SnarkId, - ) -> bonsai_sdk::responses::SnarkStatusRes { - let snark_session = snark_session.clone(); - let (notify, rx) = mpsc::channel(); - self.queue - .send(BonsaiRequest::SnarkStatus { - session: snark_session, - notify, - }) - .expect("Bonsai processing queue is dead"); - let status = rx.recv().unwrap(); - debug!(status.status, ?status.output, status.error_msg); - status - } + }, + ) + }; } /// A [`Risc0BonsaiHost`] stores a binary to execute in the Risc0 VM and prove in the Risc0 Bonsai API. @@ -293,7 +63,7 @@ pub struct Risc0BonsaiHost<'a> { elf: &'a [u8], env: Vec, image_id: Digest, - client: Option, + client: Option, last_input_id: Option, } @@ -308,12 +78,14 @@ impl<'a> Risc0BonsaiHost<'a> { // handle error let client = if !api_url.is_empty() && !api_key.is_empty() { - let client = BonsaiClient::from_parts(api_url, api_key, risc0_zkvm::VERSION); - tracing::debug!("Uploading image with id: {}", image_id); - // handle error - client.upload_img(hex::encode(image_id), elf.to_vec()); + let client = Client::from_parts(api_url, api_key, risc0_zkvm::VERSION) + .expect("Failed to create Bonsai client; qed"); + + client + .upload_img(hex::encode(image_id).as_str(), elf.to_vec()) + .expect("Failed to upload image; qed"); Some(client) } else { @@ -331,11 +103,12 @@ impl<'a> Risc0BonsaiHost<'a> { fn upload_to_bonsai(&mut self, buf: Vec) { // handle error - let input_id = self + let input_id = retry_backoff_bonsai!(self .client .as_ref() .expect("Bonsai client is not initialized") - .upload_input(buf); + .upload_input(buf.clone())) + .expect("Failed to upload input; qed"); tracing::info!("Uploaded input with id: {}", input_id); self.last_input_id = Some(input_id); } @@ -398,11 +171,19 @@ impl<'a> ZkvmHost for Risc0BonsaiHost<'a> { }; // Start a session running the prover - let session = client.create_session(hex::encode(self.image_id), input_id, vec![]); + // execute only is set to false because we run bonsai only when proving + let session = retry_backoff_bonsai!(client.create_session( + hex::encode(self.image_id), + input_id.clone(), + vec![], + false + )) + .expect("Failed to create session; qed"); tracing::info!("Session created: {}", session.uuid); let receipt = loop { // handle error - let res = client.status(&session); + let res = retry_backoff_bonsai!(session.status(client)) + .expect("Failed to fetch status; qed"); if res.status == "RUNNING" { tracing::info!( @@ -429,7 +210,8 @@ impl<'a> ZkvmHost for Risc0BonsaiHost<'a> { ); } - let receipt_buf = client.download(receipt_url); + let receipt_buf = retry_backoff_bonsai!(client.download(receipt_url.as_str())) + .expect("Failed to download receipt; qed"); let receipt: Receipt = bincode::deserialize(&receipt_buf).unwrap(); @@ -445,12 +227,14 @@ impl<'a> ZkvmHost for Risc0BonsaiHost<'a> { tracing::info!("Creating the SNARK"); - let snark_session = client.create_snark(&session); + let snark_session = retry_backoff_bonsai!(client.create_snark(session.uuid.clone())) + .expect("Failed to create snark session; qed"); tracing::info!("SNARK session created: {}", snark_session.uuid); loop { - let res = client.snark_status(&snark_session); + let res = retry_backoff_bonsai!(snark_session.status(client)) + .expect("Failed to fetch status; qed"); match res.status.as_str() { "RUNNING" => { tracing::info!("Current status: {} - continue polling...", res.status,);