diff --git a/.github/workflows/validate_pr.yml b/.github/workflows/validate_pr.yml
new file mode 100644
index 00000000..f1a20543
--- /dev/null
+++ b/.github/workflows/validate_pr.yml
@@ -0,0 +1,57 @@
+name: Validate PR
+
+on:
+  pull_request:
+    branches: [main]
+
+concurrency:
+  group: rust-validation-${{ github.head_ref }}
+  cancel-in-progress: true
+
+jobs:
+  formatting:
+    name: Rustfmt
+    runs-on: ubuntu-latest
+    steps:
+      - name: Checkout Code
+        uses: actions/checkout@v2
+
+      - name: Install Rust
+        uses: actions-rs/toolchain@v1
+        with:
+          toolchain: stable
+          components: rustfmt
+
+      - name: Check Formatting
+        run: cargo fmt -- --check
+
+  linting:
+    name: Clippy
+    runs-on: ubuntu-latest
+    steps:
+      - name: Checkout Code
+        uses: actions/checkout@v2
+
+      - name: Install Rust
+        uses: actions-rs/toolchain@v1
+        with:
+          toolchain: stable
+          components: clippy
+
+      - name: Run Clippy
+        run: cargo clippy --tests -- -D warnings
+
+  testing:
+    name: Cargo Test
+    runs-on: ubuntu-latest
+    steps:
+      - name: Checkout Code
+        uses: actions/checkout@v2
+
+      - name: Install Rust
+        uses: actions-rs/toolchain@v1
+        with:
+          toolchain: stable
+
+      - name: Run Tests
+        run: cargo test --features=substrate
\ No newline at end of file
diff --git a/.gitignore b/.gitignore
index 6985cf1b..53798a75 100644
--- a/.gitignore
+++ b/.gitignore
@@ -12,3 +12,9 @@ Cargo.lock
 
 # MSVC Windows builds of rustc generate these, which store debugging information
 *.pdb
+
+
+# Added by cargo
+
+/target
+.idea
\ No newline at end of file
diff --git a/Cargo.toml b/Cargo.toml
new file mode 100644
index 00000000..422734c6
--- /dev/null
+++ b/Cargo.toml
@@ -0,0 +1,24 @@
+[workspace]
+members = [
+    "gadget-core",
+    "webb-gadget",
+    "zk-gadget"
+]
+
+[workspace.dependencies]
+gadget-core = { path = "./gadget-core" }
+webb-gadget = { path = "./webb-gadget" }
+zk-gadget = { path = "./zk-gadget" }
+
+sc-client-api = { git = "https://github.com/paritytech/substrate", branch = "polkadot-v1.0.0", default-features = false }
+sp-core = { git = "https://github.com/paritytech/substrate", branch = "polkadot-v1.0.0", default-features = false }
+sp-runtime = { git = "https://github.com/paritytech/substrate", branch = "polkadot-v1.0.0", default-features = false }
+sp-api = { git = "https://github.com/paritytech/substrate", branch = "polkadot-v1.0.0", default-features = false }
+
+mpc-net = { git = "https://github.com/webb-tools/zk-SaaS/" }
+tokio-rustls = "0.24.1"
+tokio = "1.32.0"
+bincode2 = "2"
+futures-util = "0.3.28"
+serde = "1.0.188"
+async-trait = "0.1.73"
\ No newline at end of file
diff --git a/README.md b/README.md
index 8be34a54..208e57ff 100644
--- a/README.md
+++ b/README.md
@@ -1,2 +1,12 @@
-# gadget
-A common repository for gadgets
+# Gadget
+
+## Design
+
+
+![](./resources/gadget.png)
+
+The core library is `gadget-core`. It standardizes how gadgets behave across different blockchains: every gadget builds on it, and every gadget expects to receive `FinalityNotifications` and `BlockImportNotifications`.
+
+One such blockchain is a Substrate blockchain, and this is where `webb-gadget` comes into play. The `webb-gadget` is `gadget-core` endowed with a connection to a Substrate blockchain, a networking layer to communicate with other gadgets, and a `WebbGadgetModule` that holds application-specific logic.
+
+Since `webb-gadget` allows varying connections to a Substrate blockchain and differing network layers, we can build the `zk-gadget` and `tss-gadget` on top of it. These gadgets are endowed with the same functionalities as the `webb-gadget` but with a (potentially) different blockchain connection, networking layer, and application-specific logic.
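+To make the layering concrete: every gadget, whatever the chain, is driven by the same event loop over finality notifications, block-import notifications, and protocol messages. A minimal, hypothetical sketch of wiring one up (it assumes the `AbstractGadget` and `GadgetManager` items that `gadget-core` defines below; `MyGadget` is illustrative and not part of this PR):
+
+```rust
+// `MyGadget` implements `gadget_core::gadget::manager::AbstractGadget`,
+// i.e. it exposes the three streams plus their `process_*` handlers.
+let gadget = MyGadget::new(chain_connection, network, module);
+
+// `GadgetManager` polls all three streams concurrently and resolves with an
+// error as soon as any stream ends.
+GadgetManager::new(gadget).await?;
+```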
diff --git a/gadget-core/Cargo.toml b/gadget-core/Cargo.toml
new file mode 100644
index 00000000..c26b4b67
--- /dev/null
+++ b/gadget-core/Cargo.toml
@@ -0,0 +1,28 @@
+[package]
+name = "gadget-core"
+version = "0.0.1"
+authors = ["Thomas P Braun"]
+license = "GPL-3.0-or-later WITH Classpath-exception-2.0"
+edition = "2021"
+
+[features]
+substrate = [
+    "sp-runtime",
+    "sc-client-api",
+    "sp-api",
+    "futures"
+]
+
+[dependencies]
+sync_wrapper = "0.1.2"
+parking_lot = "0.12.1"
+tokio = { workspace = true, features = ["sync", "time", "macros", "rt"] }
+hex = "0.4.3"
+async-trait = "0.1.73"
+
+sp-runtime = { optional = true, workspace = true, default-features = false }
+sc-client-api = { optional = true, workspace = true, default-features = false }
+sp-api = { optional = true, workspace = true, default-features = false }
+futures = { optional = true, version = "0.3.28" }
+
+[dev-dependencies]
\ No newline at end of file
diff --git a/gadget-core/src/gadget/manager.rs b/gadget-core/src/gadget/manager.rs
new file mode 100644
index 00000000..3c8e5e74
--- /dev/null
+++ b/gadget-core/src/gadget/manager.rs
@@ -0,0 +1,114 @@
+use async_trait::async_trait;
+use std::error::Error;
+use std::future::Future;
+use std::pin::Pin;
+
+pub struct GadgetManager<'a> {
+    gadget: Pin<Box<dyn Future<Output = Result<(), GadgetError>> + 'a>>,
+}
+
+#[derive(Copy, Clone, Debug, Eq, PartialEq)]
+pub enum GadgetError {
+    FinalityNotificationStreamEnded,
+    BlockImportNotificationStreamEnded,
+    ProtocolMessageStreamEnded,
+}
+
+#[async_trait]
+pub trait AbstractGadget: Send {
+    type FinalityNotification: Send;
+    type BlockImportNotification: Send;
+    type ProtocolMessage: Send;
+    type Error: Error;
+
+    async fn get_next_finality_notification(&self) -> Option<Self::FinalityNotification>;
+    async fn get_next_block_import_notification(&self) -> Option<Self::BlockImportNotification>;
+    async fn get_next_protocol_message(&self) -> Option<Self::ProtocolMessage>;
+
+    async fn process_finality_notification(
+        &self,
+        notification: Self::FinalityNotification,
+    ) -> Result<(), Self::Error>;
+    async fn process_block_import_notification(
+        &self,
+        notification: Self::BlockImportNotification,
+    ) -> Result<(), Self::Error>;
+    async fn process_protocol_message(
+        &self,
+        message: Self::ProtocolMessage,
+    ) -> Result<(), Self::Error>;
+
+    async fn process_error(&self, error: Self::Error);
+}
+
+impl<'a> GadgetManager<'a> {
+    pub fn new<T: AbstractGadget + 'a>(gadget: T) -> Self {
+        let gadget_task = async move {
+            let gadget = &gadget;
+
+            let finality_notification_task = async move {
+                loop {
+                    if let Some(notification) = gadget.get_next_finality_notification().await {
+                        if let Err(err) = gadget.process_finality_notification(notification).await {
+                            gadget.process_error(err).await;
+                        }
+                    } else {
+                        return Err(GadgetError::FinalityNotificationStreamEnded);
+                    }
+                }
+            };
+
+            let block_import_notification_task = async move {
+                loop {
+                    if let Some(notification) = gadget.get_next_block_import_notification().await {
+                        if let Err(err) =
+                            gadget.process_block_import_notification(notification).await
+                        {
+                            gadget.process_error(err).await;
+                        }
+                    } else {
+                        return Err(GadgetError::BlockImportNotificationStreamEnded);
+                    }
+                }
+            };
+
+            let protocol_message_task = async move {
+                loop {
+                    if let Some(message) = gadget.get_next_protocol_message().await {
+                        if let Err(err) = gadget.process_protocol_message(message).await {
+                            gadget.process_error(err).await;
+                        }
+                    } else {
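+                        // Note: each loop only exits when its stream yields `None`;
+                        // per-item handler errors are routed to `process_error` and do
+                        // not stop the gadget. The `tokio::select!` below therefore
+                        // resolves with whichever stream is exhausted first.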
+                        return Err(GadgetError::ProtocolMessageStreamEnded);
+                    }
+                }
+            };
+
+            tokio::select! {
+                res0 = finality_notification_task => res0,
+                res1 = block_import_notification_task => res1,
+                res2 = protocol_message_task => res2
+            }
+        };
+
+        Self {
+            gadget: Box::pin(gadget_task),
+        }
+    }
+}
+
+impl Future for GadgetManager<'_> {
+    type Output = Result<(), GadgetError>;
+    fn poll(
+        mut self: Pin<&mut Self>,
+        cx: &mut std::task::Context<'_>,
+    ) -> std::task::Poll<Self::Output> {
+        self.gadget.as_mut().poll(cx)
+    }
+}
+
+impl<'a, T: AbstractGadget + 'a> From<T> for GadgetManager<'a> {
+    fn from(gadget: T) -> Self {
+        Self::new(gadget)
+    }
+}
diff --git a/gadget-core/src/gadget/mod.rs b/gadget-core/src/gadget/mod.rs
new file mode 100644
index 00000000..383e27d7
--- /dev/null
+++ b/gadget-core/src/gadget/mod.rs
@@ -0,0 +1,3 @@
+pub mod manager;
+#[cfg(feature = "substrate")]
+pub mod substrate;
diff --git a/gadget-core/src/gadget/substrate/mod.rs b/gadget-core/src/gadget/substrate/mod.rs
new file mode 100644
index 00000000..36576b03
--- /dev/null
+++ b/gadget-core/src/gadget/substrate/mod.rs
@@ -0,0 +1,143 @@
+use crate::gadget::manager::AbstractGadget;
+use async_trait::async_trait;
+use futures::stream::StreamExt;
+use sc_client_api::{
+    Backend, BlockImportNotification, BlockchainEvents, FinalityNotification,
+    FinalityNotifications, HeaderBackend, ImportNotifications,
+};
+use sp_api::ProvideRuntimeApi;
+use sp_runtime::traits::Block;
+use std::error::Error;
+use std::fmt::{Debug, Display, Formatter};
+use std::sync::Arc;
+use tokio::sync::Mutex;
+
+pub struct SubstrateGadget<Module: SubstrateGadgetModule> {
+    module: Module,
+    finality_notification_stream: Mutex<FinalityNotifications<Module::Block>>,
+    block_import_notification_stream: Mutex<ImportNotifications<Module::Block>>,
+    client: Arc<Module::Client>,
+}
+
+#[derive(Copy, Clone, Debug, Eq, PartialEq)]
+pub struct SubstrateGadgetError {}
+
+/// Designed to plug-in to the substrate gadget
+#[async_trait]
+pub trait SubstrateGadgetModule: Send + Sync {
+    type Error: Error + Send;
+    type ProtocolMessage: Send;
+    type Block: Block;
+    type Backend: Backend<Self::Block>;
+    type Client: Client<Self::Block, Self::Backend>;
+
+    async fn get_next_protocol_message(&self) -> Option<Self::ProtocolMessage>;
+    async fn process_finality_notification(
+        &self,
+        notification: FinalityNotification<Self::Block>,
+    ) -> Result<(), Self::Error>;
+    async fn process_block_import_notification(
+        &self,
+        notification: BlockImportNotification<Self::Block>,
+    ) -> Result<(), Self::Error>;
+    async fn process_protocol_message(
+        &self,
+        message: Self::ProtocolMessage,
+    ) -> Result<(), Self::Error>;
+    async fn process_error(&self, error: Self::Error);
+}
+
+impl Display for SubstrateGadgetError {
+    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+        Debug::fmt(self, f)
+    }
+}
+
+impl Error for SubstrateGadgetError {}
+
+pub trait Client<B, BE>:
+    BlockchainEvents<B> + HeaderBackend<B> + ProvideRuntimeApi<B> + Send
+where
+    B: Block,
+    BE: Backend<B>,
+{
+}
+
+impl<B: Block, BE: Backend<B>, T: BlockchainEvents<B> + HeaderBackend<B> + ProvideRuntimeApi<B> + Send> Client<B, BE> for T {}
+
+impl<Module> SubstrateGadget<Module>
+where
+    Module: SubstrateGadgetModule,
+{
+    pub fn new(client: Module::Client, module: Module) -> Self {
+        let finality_notification_stream = client.finality_notification_stream();
+        let block_import_notification_stream = client.import_notification_stream();
+
+        Self {
+            module,
+            finality_notification_stream: Mutex::new(finality_notification_stream),
+            block_import_notification_stream: Mutex::new(block_import_notification_stream),
+            client: Arc::new(client),
+        }
+    }
+
+    pub fn client(&self) -> &Arc<Module::Client> {
+        &self.client
+    }
+}
+
+#[async_trait]
+impl<Module> AbstractGadget for SubstrateGadget<Module>
+where
+    Module: SubstrateGadgetModule,
+{
+    type FinalityNotification = FinalityNotification<Module::Block>;
+    type BlockImportNotification = BlockImportNotification<Module::Block>;
+    type ProtocolMessage = Module::ProtocolMessage;
+    type Error = Module::Error;
+
+    async fn get_next_finality_notification(&self) -> Option<Self::FinalityNotification> {
+        self.finality_notification_stream.lock().await.next().await
+    }
+
+    async fn get_next_block_import_notification(&self) -> Option<Self::BlockImportNotification> {
+        self.block_import_notification_stream
+            .lock()
+            .await
+            .next()
+            .await
+    }
+
+    async fn get_next_protocol_message(&self) -> Option<Self::ProtocolMessage> {
+        self.module.get_next_protocol_message().await
+    }
+
+    async fn process_finality_notification(
+        &self,
+        notification: Self::FinalityNotification,
+    ) -> Result<(), Self::Error> {
+        self.module
+            .process_finality_notification(notification)
+            .await
+    }
+
+    async fn process_block_import_notification(
+        &self,
+        notification: Self::BlockImportNotification,
+    ) -> Result<(), Self::Error> {
+        self.module
+            .process_block_import_notification(notification)
+            .await
+    }
+
+    async fn process_protocol_message(
+        &self,
+        message: Self::ProtocolMessage,
+    ) -> Result<(), Self::Error> {
+        self.module.process_protocol_message(message).await
+    }
+
+    async fn process_error(&self, error: Self::Error) {
+        self.module.process_error(error).await
+    }
+}
diff --git a/gadget-core/src/job_manager.rs b/gadget-core/src/job_manager.rs
new file mode 100644
index 00000000..86ced27c
--- /dev/null
+++ b/gadget-core/src/job_manager.rs
@@ -0,0 +1,1259 @@
+use parking_lot::RwLock;
+use std::fmt::{Debug, Display};
+use std::future::Future;
+use std::ops::{Add, Sub};
+use std::{
+    collections::{HashMap, HashSet, VecDeque},
+    hash::{Hash, Hasher},
+    pin::Pin,
+    sync::Arc,
+};
+use sync_wrapper::SyncWrapper;
+
+#[derive(Copy, Clone, PartialEq, Eq)]
+pub enum PollMethod {
+    Interval { millis: u64 },
+    Manual,
+}
+
+#[derive(Copy, Clone, Eq, PartialEq, Debug)]
+pub enum DeliveryType {
+    EnqueuedProtocol,
+    ActiveProtocol,
+    // Protocol for the message is not yet available
+    EnqueuedMessage,
+}
+
+pub struct ProtocolWorkManager<WM: WorkManagerInterface> {
+    inner: Arc<RwLock<WorkManagerInner<WM>>>,
+    utility: Arc<WM>,
+    // for now, use a hard-coded value for the number of tasks
+    max_tasks: Arc<usize>,
+    max_enqueued_tasks: Arc<usize>,
+    poll_method: Arc<PollMethod>,
+}
+
+impl<WM: WorkManagerInterface> Clone for ProtocolWorkManager<WM> {
+    fn clone(&self) -> Self {
+        Self {
+            inner: self.inner.clone(),
+            utility: self.utility.clone(),
+            max_tasks: self.max_tasks.clone(),
+            max_enqueued_tasks: self.max_enqueued_tasks.clone(),
+            poll_method: self.poll_method.clone(),
+        }
+    }
+}
+
+pub struct WorkManagerInner<WM: WorkManagerInterface> {
+    pub active_tasks: HashSet<Job<WM>>,
+    pub enqueued_tasks: VecDeque<Job<WM>>,
+    // task hash => SSID => enqueued messages
+    pub enqueued_messages: EnqueuedMessage<WM::TaskID, WM::SSID, WM::ProtocolMessage>,
+}
+
+pub type EnqueuedMessage<A, B, C> = HashMap<A, HashMap<B, VecDeque<C>>>;
+
+pub trait WorkManagerInterface: Send + Sync + 'static + Sized {
+    type SSID: Copy + Hash + Eq + PartialEq + Send + Sync + 'static;
+    type Clock: Copy
+        + Debug
+        + Default
+        + Eq
+        + Ord
+        + PartialOrd
+        + PartialEq
+        + Send
+        + Sync
+        + Sub<Output = Self::Clock>
+        + Add<Output = Self::Clock>
+        + 'static;
+    type ProtocolMessage: ProtocolMessageMetadata<Self> + Send + Sync + 'static;
+    type Error: Debug + Send + Sync + 'static;
+    type SessionID: Copy + Hash + Eq + PartialEq + Display + Debug + Send + Sync + 'static;
+    type TaskID: Copy + Hash + Eq + PartialEq + Debug + Send + Sync + AsRef<[u8]> + 'static;
+
+    fn debug(&self, input: String);
+    fn error(&self, input: String);
+    fn warn(&self, input: String);
+    fn clock(&self) -> Self::Clock;
+    fn acceptable_block_tolerance() -> Self::Clock;
+    fn associated_block_id_acceptable(expected: Self::Clock, received: Self::Clock) -> bool {
+        // Favor explicit logic for readability
+        let
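+        // Worked example: with `expected = 10` and a tolerance of 5, any
+        // `received` in `5..=15` is accepted, while 4 or 16 is rejected.
+        // The lower bound saturates at `Clock::default()` via `saturating_sub`
+        // below, so a small `expected` cannot underflow.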
tolerance = Self::acceptable_block_tolerance(); + let is_acceptable_above = received >= expected && received <= expected + tolerance; + let is_acceptable_below = + received < expected && received >= saturating_sub(expected, tolerance); + let is_equal = expected == received; + + is_acceptable_above || is_acceptable_below || is_equal + } +} + +fn saturating_sub + Ord + Default>(a: T, b: T) -> T { + if a < b { + T::default() + } else { + a - b + } +} + +pub trait ProtocolMessageMetadata { + fn associated_block_id(&self) -> WM::Clock; + fn associated_session_id(&self) -> WM::SessionID; + fn associated_ssid(&self) -> WM::SSID; + fn associated_task(&self) -> WM::TaskID; +} + +/// The [`ProtocolRemote`] is the interface between the [`ProtocolWorkManager`] and the async protocol. +/// It *must* be unique between each async protocol. +pub trait ProtocolRemote: Send + Sync + 'static { + fn start(&self) -> Result<(), WM::Error>; + fn session_id(&self) -> WM::SessionID; + fn set_as_primary(&self); + fn has_stalled(&self, now: WM::Clock) -> bool; + fn started_at(&self) -> WM::Clock; + fn shutdown(&self, reason: ShutdownReason) -> Result<(), WM::Error>; + fn is_done(&self) -> bool; + fn deliver_message(&self, message: WM::ProtocolMessage) -> Result<(), WM::Error>; + fn has_started(&self) -> bool; + fn is_active(&self) -> bool; + fn ssid(&self) -> WM::SSID; +} + +#[derive(Debug, Eq, PartialEq)] +pub struct JobMetadata { + pub session_id: WM::SessionID, + pub is_stalled: bool, + pub is_finished: bool, + pub has_started: bool, + pub is_active: bool, +} + +impl ProtocolWorkManager { + pub fn new( + utility: WM, + max_tasks: usize, + max_enqueued_tasks: usize, + poll_method: PollMethod, + ) -> Self { + let this = Self { + inner: Arc::new(RwLock::new(WorkManagerInner { + active_tasks: HashSet::new(), + enqueued_tasks: VecDeque::new(), + enqueued_messages: HashMap::new(), + })), + utility: Arc::new(utility), + max_tasks: Arc::new(max_tasks), + max_enqueued_tasks: Arc::new(max_enqueued_tasks), + poll_method: Arc::new(poll_method), + }; + + if let PollMethod::Interval { millis } = poll_method { + let this_worker = this.clone(); + let handler = async move { + let job_receiver_worker = this_worker.clone(); + let logger = job_receiver_worker.utility.clone(); + + let periodic_poller = async move { + let mut interval = + tokio::time::interval(std::time::Duration::from_millis(millis)); + loop { + interval.tick().await; + this_worker.poll(); + } + }; + + periodic_poller.await; + logger.error("[worker] periodic_poller exited".to_string()); + }; + + tokio::task::spawn(handler); + } + + this + } + + pub fn clear_enqueued_tasks(&self) { + let mut lock = self.inner.write(); + lock.enqueued_tasks.clear(); + } + + /// Pushes the task, but does not necessarily start it + pub fn push_task( + &self, + task_hash: WM::TaskID, + force_start: bool, + handle: Arc>, + task: Pin>>, + ) -> Result<(), WorkManagerError> { + let mut lock = self.inner.write(); + // set as primary, that way on drop, the async protocol ends + handle.set_as_primary(); + let job = Job { + task: Arc::new(RwLock::new(Some(task.into()))), + handle, + task_hash, + utility: self.utility.clone(), + }; + + if force_start { + // This job has priority over the max_tasks limit + self.utility.debug(format!( + "[FORCE START] Force starting task {}", + hex::encode(task_hash) + )); + return self.start_job_unconditional(job, &mut *lock); + } + + if lock.enqueued_tasks.len() + lock.active_tasks.len() + >= *self.max_enqueued_tasks + *self.max_tasks + { + return 
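+            // Capacity is the sum of both pools: with `max_tasks = 2` and
+            // `max_enqueued_tasks = 2`, a fifth non-forced `push_task` lands
+            // here (see `test_max_tasks_limit` below), whereas
+            // `force_start = true` bypasses this check entirely.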
Err(WorkManagerError::PushTaskFailed { + reason: "Too many active and enqueued tasks".to_string(), + }); + } + + lock.enqueued_tasks.push_back(job); + + drop(lock); + + if *self.poll_method != PollMethod::Manual { + self.poll(); + } + + Ok(()) + } + + pub fn can_submit_more_tasks(&self) -> bool { + let lock = self.inner.read(); + lock.enqueued_tasks.len() + lock.active_tasks.len() + < *self.max_enqueued_tasks + *self.max_tasks + } + + // Only relevant for keygen + pub fn get_active_sessions_metadata(&self, now: WM::Clock) -> Vec> { + self.inner + .read() + .active_tasks + .iter() + .map(|r| r.metadata(now)) + .collect() + } + + // This will shutdown and drop all tasks and enqueued messages + pub fn force_shutdown_all(&self) { + let mut lock = self.inner.write(); + lock.active_tasks.clear(); + lock.enqueued_tasks.clear(); + lock.enqueued_messages.clear(); + } + + pub fn poll(&self) { + // Go through each task and see if it's done + let now = self.utility.clock(); + let mut lock = self.inner.write(); + let cur_count = lock.active_tasks.len(); + lock.active_tasks.retain(|job| { + let is_stalled = job.handle.has_stalled(now); + if is_stalled { + // If stalled, lets log the start and now blocks for logging purposes + self.utility.debug(format!( + "[worker] Job {:?} | Started at {:?} | Now {:?} | is stalled, shutting down", + hex::encode(job.task_hash), + job.handle.started_at(), + now + )); + + // The task is stalled, lets be pedantic and shutdown + let _ = job.handle.shutdown(ShutdownReason::Stalled); + // Return false so that the proposals are released from the currently signing + // proposals + return false; + } + + let is_done = job.handle.is_done(); + + !is_done + }); + + let new_count = lock.active_tasks.len(); + if cur_count != new_count { + self.utility + .debug(format!("[worker] {} jobs dropped", cur_count - new_count)); + } + + // Now, check to see if there is room to start a new task + let tasks_to_start = self.max_tasks.saturating_sub(lock.active_tasks.len()); + for _ in 0..tasks_to_start { + if let Some(job) = lock.enqueued_tasks.pop_front() { + let task_hash = job.task_hash; + if let Err(err) = self.start_job_unconditional(job, &mut *lock) { + self.utility.error(format!( + "[worker] Failed to start job {:?}: {err:?}", + hex::encode(task_hash) + )); + } + } else { + break; + } + } + + // Next, remove any outdated enqueued messages to prevent RAM bloat + let mut to_remove = vec![]; + for (hash, queue) in lock.enqueued_messages.iter_mut() { + for (ssid, queue) in queue.iter_mut() { + let before = queue.len(); + // Only keep the messages that are not outdated + queue.retain(|msg| { + WM::associated_block_id_acceptable(now, msg.associated_block_id()) + }); + let after = queue.len(); + + if before != after { + self.utility.debug(format!( + "[worker] Removed {} outdated enqueued messages from the queue for {:?}", + before - after, + hex::encode(*hash) + )); + } + + if queue.is_empty() { + to_remove.push((*hash, *ssid)); + } + } + } + + // Next, to prevent the existence of piling-up empty *inner* queues, remove them + for (hash, ssid) in to_remove { + lock.enqueued_messages + .get_mut(&hash) + .expect("Should be available") + .remove(&ssid); + } + + // Finally, remove any empty outer maps + lock.enqueued_messages.retain(|_, v| !v.is_empty()); + } + + fn start_job_unconditional( + &self, + job: Job, + lock: &mut WorkManagerInner, + ) -> Result<(), WorkManagerError> { + self.utility.debug(format!( + "[worker] Starting job {:?}", + hex::encode(job.task_hash) + )); + if let Err(err) = 
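+        // If starting the remote fails, `job` is dropped on the early return
+        // below, and its `Drop` impl shuts the remote down with
+        // `ShutdownReason::DropCode`.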
job.handle.start() { + return Err(WorkManagerError::PushTaskFailed { + reason: format!( + "Failed to start job {:?}: {err:?}", + hex::encode(job.task_hash) + ), + }); + } else { + // deliver all the enqueued messages to the protocol now + if let Some(mut enqueued_messages_map) = lock.enqueued_messages.remove(&job.task_hash) { + let job_ssid = job.handle.ssid(); + if let Some(mut enqueued_messages) = enqueued_messages_map.remove(&job_ssid) { + self.utility.debug(format!( + "Will now deliver {} enqueued message(s) to the async protocol for {:?}", + enqueued_messages.len(), + hex::encode(job.task_hash) + )); + + while let Some(message) = enqueued_messages.pop_front() { + if should_deliver(&job, &message, job.task_hash) { + if let Err(err) = job.handle.deliver_message(message) { + self.utility.error(format!( + "Unable to deliver message for job {:?}: {err:?}", + hex::encode(job.task_hash) + )); + } + } else { + self.utility.warn("Will not deliver enqueued message to async protocol since the message is no longer acceptable".to_string()) + } + } + } + + // If there are any other messages for other SSIDs, put them back in the map + if !enqueued_messages_map.is_empty() { + lock.enqueued_messages + .insert(job.task_hash, enqueued_messages_map); + } + } + } + let task = job.task.clone(); + // Put the job inside here, that way the drop code does not get called right away, + // killing the process + lock.active_tasks.insert(job); + // run the task + let task = async move { + let task = task.write().take().expect("Should not happen"); + task.into_inner().await + }; + + // Spawn the task. When it finishes, it will clean itself up + tokio::task::spawn(task); + Ok(()) + } + + pub fn job_exists(&self, job: &WM::TaskID) -> bool { + let lock = self.inner.read(); + lock.active_tasks.iter().any(|r| &r.task_hash == job) + || lock.enqueued_tasks.iter().any(|j| &j.task_hash == job) + } + + pub fn deliver_message( + &self, + msg: WM::ProtocolMessage, + ) -> Result { + self.utility.debug(format!( + "Delivered message is intended for session_id = {}", + msg.associated_session_id() + )); + let message_task_hash = msg.associated_task(); + let mut lock = self.inner.write(); + + // check the enqueued + for task in lock.enqueued_tasks.iter() { + if should_deliver(task, &msg, message_task_hash) { + self.utility.debug(format!( + "Message is for this ENQUEUED signing execution in session: {}", + task.handle.session_id() + )); + if let Err(err) = task.handle.deliver_message(msg) { + return Err(WorkManagerError::DeliverMessageFailed { + reason: format!("{err:?}"), + }); + } + + return Ok(DeliveryType::EnqueuedProtocol); + } + } + + // check the currently signing + for task in lock.active_tasks.iter() { + if should_deliver(task, &msg, message_task_hash) { + self.utility.debug(format!( + "Message is for this signing CURRENT execution in session: {}", + task.handle.session_id() + )); + if let Err(err) = task.handle.deliver_message(msg) { + return Err(WorkManagerError::DeliverMessageFailed { + reason: format!("{err:?}"), + }); + } + + return Ok(DeliveryType::ActiveProtocol); + } + } + + // if the protocol is neither started nor enqueued, then, this message may be for a future + // async protocol. 
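+        // Enqueued messages are keyed by task hash and then SSID (see
+        // `EnqueuedMessage`); they are replayed once the matching job starts in
+        // `start_job_unconditional`, and are garbage-collected in `poll()` when
+        // their associated block id falls outside the acceptable tolerance.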
Store the message + let current_running_session_ids: Vec = lock + .active_tasks + .iter() + .map(|job| job.handle.session_id()) + .collect(); + let enqueued_session_ids: Vec = lock + .enqueued_tasks + .iter() + .map(|job| job.handle.session_id()) + .collect(); + self.utility + .debug(format!("Enqueuing message for {:?} | current_running_session_ids: {current_running_session_ids:?} | enqueued_session_ids: {enqueued_session_ids:?}", hex::encode(message_task_hash))); + lock.enqueued_messages + .entry(message_task_hash) + .or_default() + .entry(msg.associated_ssid()) + .or_default() + .push_back(msg); + + Ok(DeliveryType::EnqueuedMessage) + } +} + +pub struct Job { + task_hash: WM::TaskID, + utility: Arc, + handle: Arc>, + task: Arc>>>, +} + +impl Job { + fn metadata(&self, now: WM::Clock) -> JobMetadata { + JobMetadata:: { + session_id: self.handle.session_id(), + is_stalled: self.handle.has_stalled(now), + is_finished: self.handle.is_done(), + has_started: self.handle.has_started(), + is_active: self.handle.is_active(), + } + } +} + +#[derive(Debug, Clone)] +pub enum WorkManagerError { + PushTaskFailed { reason: String }, + DeliverMessageFailed { reason: String }, +} + +pub enum ShutdownReason { + Stalled, + DropCode, +} + +pub trait SendFuture<'a, T>: Send + Future + 'a {} +impl<'a, F: Send + Future + 'a, T> SendFuture<'a, T> for F {} + +pub type SyncFuture = SyncWrapper>>>; + +impl PartialEq for Job { + fn eq(&self, other: &Self) -> bool { + self.task_hash == other.task_hash + } +} + +impl Eq for Job {} + +impl Hash for Job { + fn hash(&self, state: &mut H) { + self.task_hash.hash(state); + } +} + +impl Drop for Job { + fn drop(&mut self) { + self.utility.debug(format!( + "Will remove job {:?} from currently_signing_proposals", + hex::encode(self.task_hash) + )); + let _ = self.handle.shutdown(ShutdownReason::DropCode); + } +} + +fn should_deliver( + task: &Job, + msg: &WM::ProtocolMessage, + message_task_hash: WM::TaskID, +) -> bool { + task.handle.session_id() == msg.associated_session_id() + && task.task_hash == message_task_hash + && task.handle.ssid() == msg.associated_ssid() + && WM::associated_block_id_acceptable( + task.handle.started_at(), // use to be associated_block_id + msg.associated_block_id(), + ) +} + +#[cfg(test)] +mod tests { + + use super::*; + use parking_lot::Mutex; + use std::sync::atomic::{AtomicBool, Ordering}; + use std::sync::Arc; + use std::time::Duration; + use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender}; + + #[derive(Debug, Eq, PartialEq)] + struct TestWorkManager; + #[derive(Clone, Eq, PartialEq, Debug)] + pub struct TestMessage { + message: String, + associated_block_id: u64, + associated_session_id: u32, + associated_ssid: u32, + associated_task: [u8; 32], + } + + impl ProtocolMessageMetadata for TestMessage { + fn associated_block_id(&self) -> u64 { + self.associated_block_id + } + fn associated_session_id(&self) -> u32 { + self.associated_session_id + } + fn associated_ssid(&self) -> u32 { + self.associated_ssid + } + fn associated_task(&self) -> [u8; 32] { + self.associated_task + } + } + + impl WorkManagerInterface for TestWorkManager { + type SSID = u32; + type Clock = u64; + type ProtocolMessage = TestMessage; + type Error = (); + type SessionID = u32; + type TaskID = [u8; 32]; + + fn debug(&self, _input: String) {} + fn error(&self, input: String) { + println!("ERROR: {input}") + } + fn warn(&self, input: String) { + println!("WARN: {input}") + } + fn clock(&self) -> Self::Clock { + 0 + } + + fn acceptable_block_tolerance() -> Self::Clock 
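+        // A tolerance of zero means a message is deliverable only while its
+        // `associated_block_id` equals the job's `started_at` block, which is
+        // what these tests rely on to simulate stalled tasks.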
{ + 0 + } + } + + #[derive(Clone)] + pub struct TestProtocolRemote { + session_id: u32, + ssid: u32, + started_at: u64, + delivered_messages: UnboundedSender, + start_tx: Arc>>>, + is_done: Arc, + has_started: Arc, + } + + impl TestProtocolRemote { + pub fn new( + session_id: u32, + ssid: u32, + started_at: u64, + task_tx: UnboundedSender, + start_tx: tokio::sync::oneshot::Sender<()>, + is_done: Arc, + ) -> Arc { + Arc::new(Self { + session_id, + ssid, + started_at, + delivered_messages: task_tx, + start_tx: Arc::new(Mutex::new(Some(start_tx))), + is_done, + has_started: Arc::new(AtomicBool::new(false)), + }) + } + } + + #[allow(clippy::type_complexity)] + pub fn generate_async_protocol( + session_id: u32, + ssid: u32, + started_at: u64, + ) -> ( + Arc, + Pin>>, + UnboundedReceiver, + ) { + let (tx, rx) = tokio::sync::mpsc::unbounded_channel(); + let (start_tx, start_rx) = tokio::sync::oneshot::channel(); + let is_done = Arc::new(AtomicBool::new(false)); + let remote = + TestProtocolRemote::new(session_id, ssid, started_at, tx, start_tx, is_done.clone()); + let task = async move { + start_rx.await.unwrap(); + is_done.store(true, Ordering::SeqCst); + }; + let task = Box::pin(task); + (remote, task, rx) + } + + impl ProtocolRemote for TestProtocolRemote { + fn start(&self) -> Result<(), ()> { + self.start_tx.lock().take().unwrap().send(())?; + self.has_started.store(true, Ordering::SeqCst); + Ok(()) + } + fn session_id(&self) -> u32 { + self.session_id + } + fn set_as_primary(&self) {} + fn has_stalled(&self, now: u64) -> bool { + now > self.started_at + } + fn started_at(&self) -> u64 { + self.started_at + } + fn shutdown(&self, _reason: ShutdownReason) -> Result<(), ()> { + Ok(()) + } + fn is_done(&self) -> bool { + self.is_done.load(Ordering::SeqCst) + } + fn deliver_message(&self, message: TestMessage) -> Result<(), ()> { + self.delivered_messages.send(message).map_err(|_| ()) + } + fn has_started(&self) -> bool { + self.has_started.load(Ordering::SeqCst) + } + fn is_active(&self) -> bool { + self.has_started() && !self.is_done() + } + fn ssid(&self) -> u32 { + self.ssid + } + } + + #[tokio::test] + async fn test_push_task() { + let work_manager = ProtocolWorkManager::new(TestWorkManager, 10, 10, PollMethod::Manual); + let (remote, task, _rx) = generate_async_protocol(1, 1, 0); + let result = work_manager.push_task([0; 32], false, remote, task); + assert!(result.is_ok()); + } + + #[tokio::test] + async fn test_deliver_message() { + let work_manager = ProtocolWorkManager::new(TestWorkManager, 10, 10, PollMethod::Manual); + let (remote, task, mut rx) = generate_async_protocol(0, 0, 0); + + work_manager + .push_task(Default::default(), true, remote.clone(), task) + .unwrap(); + + let message = TestMessage { + message: "test".to_string(), + associated_block_id: 0, + associated_session_id: 0, + associated_ssid: 0, + associated_task: [0; 32], + }; + assert_ne!( + DeliveryType::EnqueuedMessage, + work_manager.deliver_message(message).unwrap() + ); + let _ = rx.recv().await.unwrap(); + } + + #[tokio::test] + async fn test_job_exists() { + let work_manager = ProtocolWorkManager::new(TestWorkManager, 10, 10, PollMethod::Manual); + let (remote, task, _rx) = generate_async_protocol(1, 1, 0); + + let result = work_manager.job_exists(&[0; 32]); + assert!(!result); + + work_manager.push_task([0; 32], true, remote, task).unwrap(); + + let result = work_manager.job_exists(&[0; 32]); + assert!(result); + } + + #[tokio::test] + async fn test_add_multiple_tasks() { + let work_manager = 
ProtocolWorkManager::new( + TestWorkManager, + 2, // max 2 tasks + 0, + PollMethod::Manual, + ); + + let (remote1, task1, _rx) = generate_async_protocol(1, 1, 0); + let (remote2, task2, _rx) = generate_async_protocol(2, 2, 0); + let (remote3, task3, _rx) = generate_async_protocol(3, 3, 0); + + // Add 2 tasks, should succeed + assert!(work_manager + .push_task([1; 32], false, remote1, task1) + .is_ok()); + assert!(work_manager + .push_task([2; 32], false, remote2, task2) + .is_ok()); + + // Try to add a third, should fail + assert!(work_manager + .push_task([3; 32], false, remote3, task3) + .is_err()); + } + + #[tokio::test] + async fn test_deliver_to_queued_task() { + let work_manager = ProtocolWorkManager::new(TestWorkManager, 10, 10, PollMethod::Manual); + let (remote, task, mut rx) = generate_async_protocol(1, 1, 0); + + // Add a queued task + work_manager + .push_task([0; 32], false, remote.clone(), task) + .unwrap(); + + // Deliver message, should succeed + let msg = TestMessage { + message: "test".to_string(), + associated_block_id: 0, + associated_session_id: 1, + associated_ssid: 1, + associated_task: [0; 32], + }; + assert_ne!( + DeliveryType::EnqueuedMessage, + work_manager.deliver_message(msg.clone()).unwrap() + ); + let next_message = rx.recv().await.unwrap(); + assert_eq!(next_message, msg); + } + + #[tokio::test] + async fn test_get_task_metadata() { + let work_manager = ProtocolWorkManager::new(TestWorkManager, 10, 10, PollMethod::Manual); + let (remote1, task1, _rx) = generate_async_protocol(1, 1, 0); + let (remote2, task2, _rx) = generate_async_protocol(2, 2, 0); + + work_manager + .push_task([1; 32], true, remote1, task1) + .unwrap(); + work_manager + .push_task([2; 32], true, remote2, task2) + .unwrap(); + + let now = 0; + let metadata = work_manager.get_active_sessions_metadata(now); + + assert_eq!(metadata.len(), 2); + let expected1 = JobMetadata { + session_id: 1, + is_stalled: false, + is_finished: false, + has_started: true, + is_active: true, + }; + + let expected2 = JobMetadata { + session_id: 2, + is_stalled: false, + is_finished: false, + has_started: true, + is_active: true, + }; + + assert!(metadata.contains(&expected1)); + assert!(metadata.contains(&expected2)); + + // Now, start the tasks + work_manager.poll(); + // Wait some time for the tasks to finish + tokio::time::sleep(Duration::from_millis(100)).await; + // Poll again to cleanup + work_manager.poll(); + // Re-check the statuses + let metadata = work_manager.get_active_sessions_metadata(now); + + assert!(metadata.is_empty()); + } + + #[tokio::test] + async fn test_get_task_metadata_no_force_start() { + let work_manager = ProtocolWorkManager::new(TestWorkManager, 10, 10, PollMethod::Manual); + let (remote1, task1, _rx) = generate_async_protocol(1, 1, 0); + let (remote2, task2, _rx) = generate_async_protocol(2, 2, 0); + + let now = 0; + + work_manager + .push_task([1; 32], false, remote1, task1) + .unwrap(); + work_manager + .push_task([2; 32], false, remote2, task2) + .unwrap(); + + let metadata = work_manager.get_active_sessions_metadata(now); + assert!(metadata.is_empty()); + + // Now, poll to start the tasks + work_manager.poll(); + + let metadata = work_manager.get_active_sessions_metadata(now); + + assert_eq!(metadata.len(), 2); + let expected1 = JobMetadata { + session_id: 1, + is_stalled: false, + is_finished: false, + has_started: true, + is_active: true, + }; + + let expected2 = JobMetadata { + session_id: 2, + is_stalled: false, + is_finished: false, + has_started: true, + is_active: true, + 
}; + + assert!(metadata.contains(&expected1)); + assert!(metadata.contains(&expected2)); + + // Wait some time for the tasks to finish + tokio::time::sleep(Duration::from_millis(100)).await; + // Poll again to cleanup + work_manager.poll(); + // Re-check the statuses + let metadata = work_manager.get_active_sessions_metadata(now); + + assert!(metadata.is_empty()); + } + + #[tokio::test] + async fn test_force_shutdown_all() { + let work_manager = ProtocolWorkManager::new(TestWorkManager, 10, 10, PollMethod::Manual); + let (remote1, task1, _rx) = generate_async_protocol(1, 1, 0); + let (remote2, task2, _rx) = generate_async_protocol(2, 2, 0); + work_manager + .push_task([1; 32], true, remote1, task1) + .unwrap(); + work_manager + .push_task([2; 32], true, remote2, task2) + .unwrap(); + + // Verify that the tasks were added + assert!(work_manager.job_exists(&[1; 32])); + assert!(work_manager.job_exists(&[2; 32])); + + // Force shutdown all tasks + work_manager.force_shutdown_all(); + + // Verify that the tasks were removed + assert!(!work_manager.job_exists(&[1; 32])); + assert!(!work_manager.job_exists(&[2; 32])); + } + + #[tokio::test] + async fn test_clear_enqueued_tasks() { + let work_manager = ProtocolWorkManager::new(TestWorkManager, 10, 10, PollMethod::Manual); + let (remote1, task1, _rx) = generate_async_protocol(1, 1, 0); + let (remote2, task2, _rx) = generate_async_protocol(2, 2, 0); + work_manager + .push_task([1; 32], false, remote1, task1) + .unwrap(); + work_manager + .push_task([2; 32], false, remote2, task2) + .unwrap(); + + // Verify that the tasks were added + assert!(work_manager.job_exists(&[1; 32])); + assert!(work_manager.job_exists(&[2; 32])); + + // Clear enqueued tasks + work_manager.clear_enqueued_tasks(); + + // Verify that the tasks were removed + assert!(!work_manager.job_exists(&[1; 32])); + assert!(!work_manager.job_exists(&[2; 32])); + } + + #[tokio::test] + async fn test_max_tasks_limit() { + let work_manager = ProtocolWorkManager::new(TestWorkManager, 2, 2, PollMethod::Manual); + + let (remote1, task1, _rx) = generate_async_protocol(1, 1, 0); + let (remote2, task2, _rx) = generate_async_protocol(2, 2, 0); + let (remote3, task3, _rx) = generate_async_protocol(3, 3, 0); + let (remote4, task4, _rx) = generate_async_protocol(4, 4, 0); + let (remote5, task5, _rx) = generate_async_protocol(5, 5, 0); + + assert!(work_manager + .push_task([1; 32], true, remote1, task1) + .is_ok()); + assert!(work_manager + .push_task([2; 32], true, remote2, task2) + .is_ok()); + assert!(work_manager + .push_task([3; 32], false, remote3, task3) + .is_ok()); + assert!(work_manager + .push_task([4; 32], false, remote4, task4) + .is_ok()); + + // Try to add a fifth task, should fail + assert!(work_manager + .push_task([5; 32], false, remote5, task5) + .is_err()); + } + + #[tokio::test] + async fn test_task_completion() { + let work_manager = ProtocolWorkManager::new(TestWorkManager, 10, 10, PollMethod::Manual); + let (remote1, task1, _rx) = generate_async_protocol(1, 1, 0); + + work_manager + .push_task([1; 32], true, remote1, task1) + .unwrap(); + + // Wait some time for the tasks to finish + tokio::time::sleep(Duration::from_millis(100)).await; + // Poll again to cleanup + work_manager.poll(); + + // Check that the task has completed + assert!(!work_manager.job_exists(&[1; 32])); + } + + #[tokio::test] + async fn test_job_removal_on_drop() { + let work_manager = ProtocolWorkManager::new(TestWorkManager, 10, 10, PollMethod::Manual); + let (remote1, task1, _rx) = 
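+        // `force_shutdown_all` clears `active_tasks`; dropping each `Job` runs
+        // its `Drop` impl, which signals `ShutdownReason::DropCode` to the remote.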
generate_async_protocol(1, 1, 0); + + work_manager + .push_task([1; 32], true, remote1, task1) + .unwrap(); + + // Manual drop of all jobs + work_manager.force_shutdown_all(); + + // Check that the job has been removed + assert!(!work_manager.job_exists(&[1; 32])); + } + + #[tokio::test] + async fn test_message_delivery_to_non_existent_job() { + let work_manager = ProtocolWorkManager::new(TestWorkManager, 10, 10, PollMethod::Manual); + + let message = TestMessage { + message: "test".to_string(), + associated_block_id: 0, + associated_session_id: 1, + associated_ssid: 1, + associated_task: [0; 32], + }; + + // Deliver a message to a non-existent job + let delivery_type = work_manager.deliver_message(message).unwrap(); + + // The message should be enqueued for future use + assert_eq!(delivery_type, DeliveryType::EnqueuedMessage); + } + + #[tokio::test] + async fn test_message_delivery_to_job_with_outdated_block_id() { + let work_manager = ProtocolWorkManager::new(TestWorkManager, 10, 10, PollMethod::Manual); + let (remote1, task1, _rx) = generate_async_protocol(1, 1, 0); + + work_manager + .push_task([0; 32], true, remote1, task1) + .unwrap(); + + let message = TestMessage { + message: "test".to_string(), + associated_block_id: 10, // Outdated block ID + associated_session_id: 1, + associated_ssid: 1, + associated_task: [0; 32], + }; + + // Try to deliver a message with an outdated block ID + let delivery_type = work_manager.deliver_message(message).unwrap(); + + // The message should be enqueued for future use + assert_eq!(delivery_type, DeliveryType::EnqueuedMessage); + } + + #[tokio::test] + async fn test_can_submit_more_tasks() { + let work_manager = ProtocolWorkManager::new(TestWorkManager, 1, 0, PollMethod::Manual); + let (remote, task, _rx) = generate_async_protocol(1, 1, 0); + + assert!(work_manager.can_submit_more_tasks()); + work_manager + .push_task([1; 32], false, remote, task) + .unwrap(); + assert!(!work_manager.can_submit_more_tasks()); + } + + #[tokio::test] + async fn test_multiple_messages_single_task() { + let work_manager = ProtocolWorkManager::new(TestWorkManager, 10, 10, PollMethod::Manual); + let (remote, task, mut rx) = generate_async_protocol(1, 1, 0); + + work_manager + .push_task([0; 32], true, remote.clone(), task) + .unwrap(); + + let message = TestMessage { + message: "test".to_string(), + associated_block_id: 0, + associated_session_id: 1, + associated_ssid: 1, + associated_task: [0; 32], + }; + + for _ in 0..10 { + assert_ne!( + DeliveryType::EnqueuedMessage, + work_manager.deliver_message(message.clone()).unwrap() + ); + } + + for _ in 0..10 { + let next_message = rx.recv().await.unwrap(); + assert_eq!(next_message, message); + } + } + + #[tokio::test] + async fn test_task_not_started() { + let work_manager = ProtocolWorkManager::new(TestWorkManager, 1, 0, PollMethod::Manual); + let (remote, task, _rx) = generate_async_protocol(1, 1, 0); + + work_manager + .push_task([1; 32], false, remote, task) + .unwrap(); // task is enqueued but not started + assert!(!work_manager + .get_active_sessions_metadata(0) + .contains(&JobMetadata { + session_id: 1, + is_stalled: false, + is_finished: false, + has_started: false, + is_active: false, + })); + } + + #[tokio::test] + async fn test_can_submit_more_tasks_with_enqueued_tasks() { + let work_manager = ProtocolWorkManager::new(TestWorkManager, 1, 1, PollMethod::Manual); + let (remote1, task1, _rx) = generate_async_protocol(1, 1, 0); + let (remote2, task2, _rx) = generate_async_protocol(2, 2, 0); + + 
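+        // Total capacity is max_tasks (1) + max_enqueued_tasks (1) = 2: the
+        // first (forced) push takes the active slot, the second the queue slot.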
assert!(work_manager.can_submit_more_tasks()); + work_manager + .push_task([1; 32], true, remote1, task1) + .unwrap(); + assert!(work_manager.can_submit_more_tasks()); // one more task can be enqueued + work_manager + .push_task([2; 32], false, remote2, task2) + .unwrap(); + assert!(!work_manager.can_submit_more_tasks()); // no more tasks can be added + } + + #[tokio::test] + async fn test_message_delivery_to_stalled_task() { + let work_manager = ProtocolWorkManager::new(TestWorkManager, 1, 0, PollMethod::Manual); + let (remote, task, _rx) = generate_async_protocol(1, 1, 1); // Task will stall immediately + let msg = TestMessage { + message: "test".to_string(), + associated_block_id: 0, + associated_session_id: 1, + associated_ssid: 1, + associated_task: [0; 32], + }; + + work_manager.push_task([0; 32], true, remote, task).unwrap(); + work_manager.poll(); // should identify and remove the stalled task + + let delivery_type = work_manager.deliver_message(msg).unwrap(); + assert_eq!(delivery_type, DeliveryType::EnqueuedMessage); // message should be enqueued because task is stalled and removed + } + + #[tokio::test] + async fn test_message_delivery_with_incorrect_task_hash() { + let work_manager = ProtocolWorkManager::new(TestWorkManager, 1, 0, PollMethod::Manual); + let (remote, task, _rx) = generate_async_protocol(1, 1, 0); + let msg = TestMessage { + message: "test".to_string(), + associated_block_id: 0, + associated_session_id: 1, + associated_ssid: 1, + associated_task: [1; 32], // incorrect task hash, not 0;32 as below + }; + + work_manager.push_task([0; 32], true, remote, task).unwrap(); + + let delivery_type = work_manager.deliver_message(msg).unwrap(); // incorrect task hash + assert_eq!(delivery_type, DeliveryType::EnqueuedMessage); // message should be enqueued because the task hash is incorrect + } + + struct DummyRangeChecker; + struct DummyProtocolMessage; + + impl ProtocolMessageMetadata> for DummyProtocolMessage { + fn associated_block_id(&self) -> u64 { + 0u64 + } + fn associated_session_id(&self) -> u64 { + 0u64 + } + fn associated_ssid(&self) -> u64 { + 0u64 + } + fn associated_task(&self) -> [u8; 32] { + [0; 32] + } + } + + impl WorkManagerInterface for DummyRangeChecker { + type SSID = u64; + type Clock = u64; + type ProtocolMessage = DummyProtocolMessage; + type Error = (); + type SessionID = u64; + type TaskID = [u8; 32]; + + fn debug(&self, _input: String) { + todo!() + } + + fn error(&self, _input: String) { + todo!() + } + + fn warn(&self, _input: String) { + todo!() + } + + fn clock(&self) -> Self::Clock { + todo!() + } + + fn acceptable_block_tolerance() -> Self::Clock { + N + } + } + + #[test] + fn test_range_above() { + let current_block: u64 = 10; + const TOL: u64 = 5; + assert!(DummyRangeChecker::::associated_block_id_acceptable( + current_block, + current_block + )); + assert!(DummyRangeChecker::::associated_block_id_acceptable( + current_block, + current_block + 1 + )); + assert!(DummyRangeChecker::::associated_block_id_acceptable( + current_block, + current_block + TOL + )); + assert!(!DummyRangeChecker::::associated_block_id_acceptable( + current_block, + current_block + TOL + 1 + )); + } + + #[test] + fn test_range_below() { + let current_block: u64 = 10; + const TOL: u64 = 5; + assert!(DummyRangeChecker::::associated_block_id_acceptable( + current_block, + current_block + )); + assert!(DummyRangeChecker::::associated_block_id_acceptable( + current_block, + current_block - 1 + )); + assert!(DummyRangeChecker::::associated_block_id_acceptable( + 
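+            // Lower edge of the window: expected - TOL = 10 - 5 = 5 is still
+            // accepted; one block further below (4) must be rejected.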
current_block, + current_block - TOL + )); + assert!(!DummyRangeChecker::::associated_block_id_acceptable( + current_block, + current_block - TOL - 1 + )); + } +} diff --git a/gadget-core/src/lib.rs b/gadget-core/src/lib.rs new file mode 100644 index 00000000..f0906143 --- /dev/null +++ b/gadget-core/src/lib.rs @@ -0,0 +1,5 @@ +pub use sc_client_api::Backend; +pub use sp_runtime::traits::Block; + +pub mod gadget; +pub mod job_manager; diff --git a/resources/gadget.png b/resources/gadget.png new file mode 100644 index 00000000..918705b5 Binary files /dev/null and b/resources/gadget.png differ diff --git a/webb-gadget/Cargo.toml b/webb-gadget/Cargo.toml new file mode 100644 index 00000000..4a57bd94 --- /dev/null +++ b/webb-gadget/Cargo.toml @@ -0,0 +1,23 @@ +[package] +name = "webb-gadget" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[features] + +[dependencies] +mpc-net = { workspace = true } +tokio-rustls = { workspace = true } +gadget-core = { workspace = true, features = ["substrate"] } +tokio = { workspace = true } +serde = { workspace = true, features = ["derive"] } +tokio-util = { version = "0.7.9" } +async-trait = { workspace = true } +log = "0.4.20" +parking_lot = "0.12.1" +auto_impl = "1.1.0" +sc-client-api = { workspace = true } +sp-core = { workspace = true } +sp-runtime = { workspace = true } \ No newline at end of file diff --git a/webb-gadget/src/gadget/message.rs b/webb-gadget/src/gadget/message.rs new file mode 100644 index 00000000..d2ee3196 --- /dev/null +++ b/webb-gadget/src/gadget/message.rs @@ -0,0 +1,36 @@ +use crate::gadget::work_manager::WebbWorkManager; +use gadget_core::job_manager::{ProtocolMessageMetadata, WorkManagerInterface}; +use serde::{Deserialize, Serialize}; + +#[derive(Serialize, Deserialize)] +pub struct GadgetProtocolMessage { + pub associated_block_id: ::Clock, + pub associated_session_id: ::SessionID, + pub associated_ssid: ::SSID, + pub from: UserID, + // If None, this is a broadcasted message + pub to: Option, + // A unique marker for the associated task this message belongs to + pub task_hash: ::TaskID, + pub payload: Vec, +} + +pub type UserID = u32; + +impl ProtocolMessageMetadata for GadgetProtocolMessage { + fn associated_block_id(&self) -> ::Clock { + self.associated_block_id + } + + fn associated_session_id(&self) -> ::SessionID { + self.associated_session_id + } + + fn associated_ssid(&self) -> ::SSID { + self.associated_ssid + } + + fn associated_task(&self) -> ::TaskID { + self.task_hash + } +} diff --git a/webb-gadget/src/gadget/mod.rs b/webb-gadget/src/gadget/mod.rs new file mode 100644 index 00000000..0947bf0c --- /dev/null +++ b/webb-gadget/src/gadget/mod.rs @@ -0,0 +1,125 @@ +use crate::gadget::message::GadgetProtocolMessage; +use crate::gadget::network::Network; +use crate::gadget::work_manager::WebbWorkManager; +use crate::Error; +use async_trait::async_trait; +use gadget_core::gadget::substrate::{Client, SubstrateGadgetModule}; +use gadget_core::job_manager::{PollMethod, ProtocolWorkManager}; +use parking_lot::RwLock; +use sc_client_api::{Backend, BlockImportNotification, FinalityNotification}; +use sp_runtime::traits::{Block, Header}; +use sp_runtime::SaturatedConversion; +use std::marker::PhantomData; +use std::sync::Arc; +use tokio::sync::Mutex; + +pub mod message; +pub mod network; +pub mod work_manager; + +/// Used as a module to place inside the SubstrateGadget +pub struct WebbGadget { + #[allow(dead_code)] + network: N, + module: 
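+    // `module` carries the application-specific logic; `job_manager` schedules
+    // its async protocols, and `clock` caches the latest finalized block number
+    // for the work manager's block-tolerance checks.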
M, + job_manager: ProtocolWorkManager, + from_network: Mutex>, + clock: Arc>>, + _pd: PhantomData<(B, C, BE)>, +} + +const MAX_ACTIVE_TASKS: usize = 4; +const MAX_PENDING_TASKS: usize = 4; + +impl, B: Block, BE: Backend, N: Network, M: WebbGadgetModule> + WebbGadget +{ + pub fn new(mut network: N, module: M, now: Option) -> Self { + let clock = Arc::new(RwLock::new(now)); + let clock_clone = clock.clone(); + let from_registry = network.take_message_receiver().expect("Should exist"); + + let job_manager_zk = WebbWorkManager::new(move || *clock_clone.read()); + + let job_manager = ProtocolWorkManager::new( + job_manager_zk, + MAX_ACTIVE_TASKS, + MAX_PENDING_TASKS, + PollMethod::Interval { millis: 200 }, + ); + + WebbGadget { + module, + network, + job_manager, + clock, + from_network: Mutex::new(from_registry), + _pd: Default::default(), + } + } +} + +#[async_trait] +impl, B: Block, BE: Backend, N: Network, M: WebbGadgetModule> + SubstrateGadgetModule for WebbGadget +{ + type Error = Error; + type ProtocolMessage = GadgetProtocolMessage; + type Block = B; + type Backend = BE; + type Client = C; + + async fn get_next_protocol_message(&self) -> Option { + self.from_network.lock().await.recv().await + } + + async fn process_finality_notification( + &self, + notification: FinalityNotification, + ) -> Result<(), Self::Error> { + let now: u64 = (*notification.header.number()).saturated_into(); + *self.clock.write() = Some(now); + self.module + .process_finality_notification(notification, now, &self.job_manager) + .await + } + + async fn process_block_import_notification( + &self, + notification: BlockImportNotification, + ) -> Result<(), Self::Error> { + self.module + .process_block_import_notification(notification, &self.job_manager) + .await + } + + async fn process_protocol_message( + &self, + message: Self::ProtocolMessage, + ) -> Result<(), Self::Error> { + self.job_manager + .deliver_message(message) + .map(|_| ()) + .map_err(|err| Error::WorkManagerError { err }) + } + + async fn process_error(&self, error: Self::Error) { + self.module.process_error(error, &self.job_manager).await + } +} + +#[async_trait] +pub trait WebbGadgetModule: Send + Sync { + async fn process_finality_notification( + &self, + notification: FinalityNotification, + now: u64, + job_manager: &ProtocolWorkManager, + ) -> Result<(), Error>; + async fn process_block_import_notification( + &self, + notification: BlockImportNotification, + job_manager: &ProtocolWorkManager, + ) -> Result<(), Error>; + async fn process_error(&self, error: Error, job_manager: &ProtocolWorkManager); +} diff --git a/webb-gadget/src/gadget/network.rs b/webb-gadget/src/gadget/network.rs new file mode 100644 index 00000000..7867377e --- /dev/null +++ b/webb-gadget/src/gadget/network.rs @@ -0,0 +1,6 @@ +use crate::gadget::message::GadgetProtocolMessage; +use tokio::sync::mpsc::UnboundedReceiver; + +pub trait Network: Send + Sync { + fn take_message_receiver(&mut self) -> Option>; +} diff --git a/webb-gadget/src/gadget/work_manager.rs b/webb-gadget/src/gadget/work_manager.rs new file mode 100644 index 00000000..a9fc188a --- /dev/null +++ b/webb-gadget/src/gadget/work_manager.rs @@ -0,0 +1,56 @@ +use crate::gadget::message::GadgetProtocolMessage; +use gadget_core::job_manager::WorkManagerInterface; +use std::sync::Arc; + +pub struct WebbWorkManager { + pub(crate) clock: Arc< + dyn Fn() -> Option<::Clock> + + Send + + Sync + + 'static, + >, +} + +impl WebbWorkManager { + pub fn new( + clock: impl Fn() -> Option<::Clock> + + Send + + Sync + + 'static, + 
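+        // The closure yields `None` until the first finality notification
+        // arrives; `clock()` below unwraps it, so no job may be polled before
+        // the chain connection has reported finality at least once.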
) -> Self { + Self { + clock: Arc::new(clock), + } + } +} + +const ACCEPTABLE_BLOCK_TOLERANCE: u64 = 5; + +impl WorkManagerInterface for WebbWorkManager { + type SSID = u16; + type Clock = u64; + type ProtocolMessage = GadgetProtocolMessage; + type Error = crate::Error; + type SessionID = u64; + type TaskID = [u8; 32]; + + fn debug(&self, input: String) { + log::debug!(target: "gadget", "{input}") + } + + fn error(&self, input: String) { + log::error!(target: "gadget", "{input}") + } + + fn warn(&self, input: String) { + log::warn!(target: "gadget", "{input}") + } + + fn clock(&self) -> Self::Clock { + (self.clock)().expect("No finality notification received") + } + + fn acceptable_block_tolerance() -> Self::Clock { + ACCEPTABLE_BLOCK_TOLERANCE + } +} diff --git a/webb-gadget/src/lib.rs b/webb-gadget/src/lib.rs new file mode 100644 index 00000000..0105e022 --- /dev/null +++ b/webb-gadget/src/lib.rs @@ -0,0 +1,46 @@ +use crate::gadget::network::Network; +use crate::gadget::{WebbGadget, WebbGadgetModule}; +use gadget_core::gadget::manager::{GadgetError, GadgetManager}; +use gadget_core::gadget::substrate::{Client, SubstrateGadget}; +use gadget_core::job_manager::WorkManagerError; +pub use sc_client_api::BlockImportNotification; +pub use sc_client_api::{Backend, FinalityNotification}; +use sp_runtime::traits::Block; +use std::fmt::{Debug, Display, Formatter}; + +pub mod gadget; + +#[derive(Debug)] +pub enum Error { + RegistryCreateError { err: String }, + RegistrySendError { err: String }, + RegistryRecvError { err: String }, + RegistryListenError { err: String }, + GadgetManagerError { err: GadgetError }, + InitError { err: String }, + WorkManagerError { err: WorkManagerError }, +} + +impl Display for Error { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + Debug::fmt(self, f) + } +} + +impl std::error::Error for Error {} + +pub async fn run, B: Block, BE: Backend, N: Network, M: WebbGadgetModule>( + network: N, + module: M, + client: C, +) -> Result<(), Error> { + let now = None; + let webb_gadget = WebbGadget::new(network, module, now); + // Plug the module into the substrate gadget to interface the WebbGadget with Substrate + let substrate_gadget = SubstrateGadget::new(client, webb_gadget); + + // Run the GadgetManager to execute the substrate gadget + GadgetManager::new(substrate_gadget) + .await + .map_err(|err| Error::GadgetManagerError { err }) +} diff --git a/zk-gadget/Cargo.toml b/zk-gadget/Cargo.toml new file mode 100644 index 00000000..077aa763 --- /dev/null +++ b/zk-gadget/Cargo.toml @@ -0,0 +1,17 @@ +[package] +name = "zk-gadget" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +tokio-rustls = { workspace = true } +mpc-net = { workspace = true } +webb-gadget = { workspace = true } +gadget-core = { workspace = true } +bincode2 = { workspace = true } +tokio = { workspace = true } +futures-util = { workspace = true } +serde = { workspace = true, features = ["derive"] } +async-trait = { workspace = true } \ No newline at end of file diff --git a/zk-gadget/src/lib.rs b/zk-gadget/src/lib.rs new file mode 100644 index 00000000..c128b9bd --- /dev/null +++ b/zk-gadget/src/lib.rs @@ -0,0 +1,58 @@ +use crate::module::ZkModule; +use crate::network::RegistantId; +use gadget_core::gadget::substrate::Client; +use gadget_core::{Backend, Block}; +use mpc_net::prod::RustlsCertificate; +use std::net::SocketAddr; +use tokio_rustls::rustls::{Certificate, PrivateKey, RootCertStore}; 
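+// The zk network is star-shaped: exactly one party (the "king") binds
+// `king_bind_addr` and serves a TLS registry, while every other party dials
+// `client_only_king_addr`, pinning the king's certificate via
+// `client_only_king_public_identity_der`.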
+use webb_gadget::Error; + +pub mod module; +pub mod network; + +pub struct ZkGadgetConfig { + king_bind_addr: Option, + client_only_king_addr: Option, + id: RegistantId, + public_identity_der: Vec, + private_identity_der: Vec, + client_only_king_public_identity_der: Option>, +} + +pub async fn run, B: Block, BE: Backend>( + config: ZkGadgetConfig, + client: C, +) -> Result<(), Error> { + // Create the zk gadget module + let our_identity = RustlsCertificate { + cert: Certificate(config.public_identity_der), + private_key: PrivateKey(config.private_identity_der), + }; + + let network = if let Some(addr) = &config.king_bind_addr { + network::ZkNetworkService::new_king(*addr, our_identity).await? + } else { + let king_addr = config + .client_only_king_addr + .expect("King address must be specified if king bind address is not specified"); + + let mut king_certs = RootCertStore::empty(); + king_certs + .add(&Certificate( + config + .client_only_king_public_identity_der + .expect("The client must specify the identity of the king"), + )) + .map_err(|err| Error::InitError { + err: err.to_string(), + })?; + + network::ZkNetworkService::new_client(king_addr, config.id, our_identity, king_certs) + .await? + }; + + let zk_module = ZkModule {}; // TODO: proper implementation + + // Plug the module into the webb gadget + webb_gadget::run(network, zk_module, client).await +} diff --git a/zk-gadget/src/module/mod.rs b/zk-gadget/src/module/mod.rs new file mode 100644 index 00000000..2c326d3f --- /dev/null +++ b/zk-gadget/src/module/mod.rs @@ -0,0 +1,36 @@ +use async_trait::async_trait; +use gadget_core::job_manager::ProtocolWorkManager; +use gadget_core::Block; +use webb_gadget::gadget::work_manager::WebbWorkManager; +use webb_gadget::gadget::WebbGadgetModule; +use webb_gadget::{BlockImportNotification, Error, FinalityNotification}; + +pub struct ZkModule {} + +#[async_trait] +impl WebbGadgetModule for ZkModule { + async fn process_finality_notification( + &self, + _notification: FinalityNotification, + _now: u64, + _job_manager: &ProtocolWorkManager, + ) -> Result<(), Error> { + todo!() + } + + async fn process_block_import_notification( + &self, + _notification: BlockImportNotification, + _job_manager: &ProtocolWorkManager, + ) -> Result<(), Error> { + todo!() + } + + async fn process_error( + &self, + _error: Error, + _job_manager: &ProtocolWorkManager, + ) { + todo!() + } +} diff --git a/zk-gadget/src/network/mod.rs b/zk-gadget/src/network/mod.rs new file mode 100644 index 00000000..7292659e --- /dev/null +++ b/zk-gadget/src/network/mod.rs @@ -0,0 +1,367 @@ +use futures_util::sink::SinkExt; +use futures_util::StreamExt; +use mpc_net::multi::WrappedStream; +use mpc_net::prod::{CertToDer, RustlsCertificate}; +use mpc_net::MpcNetError; +use serde::de::DeserializeOwned; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; +use std::net::SocketAddr; +use std::sync::Arc; +use tokio::io::{AsyncRead, AsyncWrite}; +use tokio::net::TcpStream; +use tokio::sync::mpsc::UnboundedReceiver; +use tokio::sync::Mutex; +use tokio_rustls::rustls::server::NoClientAuth; +use tokio_rustls::rustls::{RootCertStore, ServerConfig}; +use tokio_rustls::{rustls, TlsAcceptor, TlsStream}; + +/// Type should correspond to the on-chain identifier of the registrant +pub type RegistantId = u64; + +pub enum ZkNetworkService { + King { + listener: Option, + registrants: Arc>>, + to_gadget: tokio::sync::mpsc::UnboundedSender, + from_registry: Option>, + identity: RustlsCertificate, + }, + Client { + king_registry_addr: 
diff --git a/zk-gadget/src/module/mod.rs b/zk-gadget/src/module/mod.rs
new file mode 100644
index 00000000..2c326d3f
--- /dev/null
+++ b/zk-gadget/src/module/mod.rs
@@ -0,0 +1,36 @@
+use async_trait::async_trait;
+use gadget_core::job_manager::ProtocolWorkManager;
+use gadget_core::Block;
+use webb_gadget::gadget::work_manager::WebbWorkManager;
+use webb_gadget::gadget::WebbGadgetModule;
+use webb_gadget::{BlockImportNotification, Error, FinalityNotification};
+
+pub struct ZkModule {}
+
+#[async_trait]
+impl<B: Block> WebbGadgetModule<B> for ZkModule {
+    async fn process_finality_notification(
+        &self,
+        _notification: FinalityNotification<B>,
+        _now: u64,
+        _job_manager: &ProtocolWorkManager<WebbWorkManager>,
+    ) -> Result<(), Error> {
+        todo!()
+    }
+
+    async fn process_block_import_notification(
+        &self,
+        _notification: BlockImportNotification<B>,
+        _job_manager: &ProtocolWorkManager<WebbWorkManager>,
+    ) -> Result<(), Error> {
+        todo!()
+    }
+
+    async fn process_error(
+        &self,
+        _error: Error,
+        _job_manager: &ProtocolWorkManager<WebbWorkManager>,
+    ) {
+        todo!()
+    }
+}
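Since `ZkModule` is still all `todo!()`, here is a hypothetical minimal implementor of the same trait shape that just logs, to show what a module plugs into (editor's sketch, not part of the patch; `notification.hash` is the finalized block's hash field from `sc_client_api::FinalityNotification`):

```rust
use async_trait::async_trait;
use gadget_core::job_manager::ProtocolWorkManager;
use gadget_core::Block;
use webb_gadget::gadget::work_manager::WebbWorkManager;
use webb_gadget::gadget::WebbGadgetModule;
use webb_gadget::{BlockImportNotification, Error, FinalityNotification};

pub struct LoggingModule;

#[async_trait]
impl<B: Block> WebbGadgetModule<B> for LoggingModule {
    async fn process_finality_notification(
        &self,
        notification: FinalityNotification<B>,
        now: u64,
        _job_manager: &ProtocolWorkManager<WebbWorkManager>,
    ) -> Result<(), Error> {
        // A real module would start protocol work against the job manager here
        println!("finalized {:?} at clock {now}", notification.hash);
        Ok(())
    }

    async fn process_block_import_notification(
        &self,
        _notification: BlockImportNotification<B>,
        _job_manager: &ProtocolWorkManager<WebbWorkManager>,
    ) -> Result<(), Error> {
        Ok(())
    }

    async fn process_error(
        &self,
        error: Error,
        _job_manager: &ProtocolWorkManager<WebbWorkManager>,
    ) {
        eprintln!("module error: {error}");
    }
}
```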
diff --git a/zk-gadget/src/network/mod.rs b/zk-gadget/src/network/mod.rs
new file mode 100644
index 00000000..7292659e
--- /dev/null
+++ b/zk-gadget/src/network/mod.rs
@@ -0,0 +1,367 @@
+use futures_util::sink::SinkExt;
+use futures_util::StreamExt;
+use mpc_net::multi::WrappedStream;
+use mpc_net::prod::{CertToDer, RustlsCertificate};
+use mpc_net::MpcNetError;
+use serde::de::DeserializeOwned;
+use serde::{Deserialize, Serialize};
+use std::collections::HashMap;
+use std::net::SocketAddr;
+use std::sync::Arc;
+use tokio::io::{AsyncRead, AsyncWrite};
+use tokio::net::TcpStream;
+use tokio::sync::mpsc::UnboundedReceiver;
+use tokio::sync::Mutex;
+use tokio_rustls::rustls::server::NoClientAuth;
+use tokio_rustls::rustls::{RootCertStore, ServerConfig};
+use tokio_rustls::{rustls, TlsAcceptor, TlsStream};
+
+/// Type should correspond to the on-chain identifier of the registrant
+pub type RegistantId = u64;
+
+pub enum ZkNetworkService {
+    King {
+        listener: Option<tokio::net::TcpListener>,
+        registrants: Arc<Mutex<HashMap<RegistantId, Registrant>>>,
+        to_gadget: tokio::sync::mpsc::UnboundedSender<GadgetProtocolMessage>,
+        from_registry: Option<UnboundedReceiver<GadgetProtocolMessage>>,
+        identity: RustlsCertificate,
+    },
+    Client {
+        king_registry_addr: SocketAddr,
+        registrant_id: RegistantId,
+        connection: Option<TlsStream<TcpStream>>,
+        cert_der: Vec<u8>,
+        to_gadget: tokio::sync::mpsc::UnboundedSender<GadgetProtocolMessage>,
+        from_registry: Option<UnboundedReceiver<GadgetProtocolMessage>>,
+    },
+}
+
+#[allow(dead_code)]
+pub struct Registrant {
+    id: RegistantId,
+    cert_der: Vec<u8>,
+}
+
+use webb_gadget::gadget::message::GadgetProtocolMessage;
+use webb_gadget::gadget::network::Network;
+use webb_gadget::Error;
+
+pub fn create_server_tls_acceptor<T: CertToDer>(
+    server_certificate: T,
+) -> Result<TlsAcceptor, MpcNetError> {
+    let client_auth = NoClientAuth::boxed();
+    let server_config = ServerConfig::builder()
+        .with_safe_defaults()
+        .with_client_cert_verifier(client_auth)
+        .with_single_cert(
+            vec![rustls::Certificate(
+                server_certificate.serialize_certificate_to_der()?,
+            )],
+            rustls::PrivateKey(server_certificate.serialize_private_key_to_der()?),
+        )
+        .unwrap();
+    Ok(TlsAcceptor::from(Arc::new(server_config)))
+}
+
+impl ZkNetworkService {
+    pub async fn new_king<T: std::net::ToSocketAddrs>(
+        bind_addr: T,
+        identity: RustlsCertificate,
+    ) -> Result<Self, Error> {
+        let bind_addr: SocketAddr = bind_addr
+            .to_socket_addrs()
+            .map_err(|err| Error::RegistryCreateError {
+                err: err.to_string(),
+            })?
+            .next()
+            .ok_or(Error::RegistryCreateError {
+                err: "No address found".to_string(),
+            })?;
+
+        let listener = tokio::net::TcpListener::bind(bind_addr)
+            .await
+            .map_err(|err| Error::RegistryCreateError {
+                err: err.to_string(),
+            })?;
+        let registrants = Arc::new(Mutex::new(HashMap::new()));
+        let (to_gadget, from_registry) = tokio::sync::mpsc::unbounded_channel();
+        Ok(ZkNetworkService::King {
+            listener: Some(listener),
+            registrants,
+            to_gadget,
+            identity,
+            from_registry: Some(from_registry),
+        })
+    }
+
+    pub async fn new_client<T: std::net::ToSocketAddrs>(
+        king_registry_addr: T,
+        registrant_id: RegistantId,
+        client_identity: RustlsCertificate,
+        king_certs: RootCertStore,
+    ) -> Result<Self, Error> {
+        let king_registry_addr: SocketAddr = king_registry_addr
+            .to_socket_addrs()
+            .map_err(|err| Error::RegistryCreateError {
+                err: err.to_string(),
+            })?
+            .next()
+            .ok_or(Error::RegistryCreateError {
+                err: "No address found".to_string(),
+            })?;
+
+        let cert_der = client_identity.cert.0.clone();
+
+        let connection = TcpStream::connect(king_registry_addr)
+            .await
+            .map_err(|err| Error::RegistryCreateError {
+                err: err.to_string(),
+            })?;
+
+        // Upgrade to TLS
+        let tls = mpc_net::prod::create_client_mutual_tls_connector(king_certs, client_identity)
+            .map_err(|err| Error::RegistryCreateError {
+                err: format!("{err:?}"),
+            })?;
+
+        let connection = tls
+            .connect(
+                tokio_rustls::rustls::ServerName::IpAddress(king_registry_addr.ip()),
+                connection,
+            )
+            .await
+            .map_err(|err| Error::RegistryCreateError {
+                err: err.to_string(),
+            })?;
+
+        let (to_gadget, from_registry) = tokio::sync::mpsc::unbounded_channel();
+
+        let mut this = ZkNetworkService::Client {
+            king_registry_addr,
+            registrant_id,
+            cert_der,
+            connection: Some(TlsStream::Client(connection)),
+            to_gadget,
+            from_registry: Some(from_registry),
+        };
+
+        this.client_register().await?;
+
+        Ok(this)
+    }
+
+    pub async fn run(self) -> Result<(), Error> {
+        match self {
+            Self::King {
+                listener,
+                registrants,
+                to_gadget,
+                identity,
+                ..
+            } => {
+                let listener = listener.expect("Should exist");
+                let tls_acceptor = create_server_tls_acceptor(identity).map_err(|err| {
+                    Error::RegistryCreateError {
+                        err: format!("{err:?}"),
+                    }
+                })?;
+
+                while let Ok((stream, peer_addr)) = listener.accept().await {
+                    println!("[Registry] Accepted connection from {peer_addr}, upgrading to TLS");
+                    let stream = tls_acceptor.accept(stream).await.map_err(|err| {
+                        Error::RegistryCreateError {
+                            err: format!("{err:?}"),
+                        }
+                    })?;
+
+                    handle_stream_as_king(
+                        TlsStream::Server(stream),
+                        peer_addr,
+                        registrants.clone(),
+                        to_gadget.clone(),
+                    );
+                }
+
+                Err(Error::RegistryCreateError {
+                    err: "Listener closed".to_string(),
+                })
+            }
+            Self::Client {
+                connection,
+                to_gadget,
+                ..
+            } => {
+                let stream = connection.expect("Should exist");
+                let mut wrapped_stream = mpc_net::multi::wrap_stream(stream);
+                while let Some(Ok(message)) = wrapped_stream.next().await {
+                    match bincode2::deserialize::<RegistryPacket>(&message) {
+                        Ok(packet) => match packet {
+                            RegistryPacket::SubstrateGadgetMessage { payload } => {
+                                if let Err(err) = to_gadget.send(payload) {
+                                    eprintln!(
+                                        "[Registry] Failed to send message to gadget: {err:?}"
+                                    );
+                                }
+                            }
+                            _ => {
+                                println!("[Registry] Received invalid packet");
+                            }
+                        },
+                        Err(err) => {
+                            println!("[Registry] Received invalid packet: {err}");
+                        }
+                    }
+                }
+
+                Err(Error::RegistryListenError {
+                    err: "Connection closed".to_string(),
+                })
+            }
+        }
+    }
+
+    async fn client_register(&mut self) -> Result<(), Error> {
+        match self {
+            Self::King { .. } => Err(Error::RegistryCreateError {
+                err: "Cannot register as king".to_string(),
+            }),
+            Self::Client {
+                king_registry_addr: _,
+                registrant_id,
+                connection,
+                cert_der,
+                ..
+            } => {
+                let conn = connection.as_mut().expect("Should exist");
+                let mut wrapped_stream = mpc_net::multi::wrap_stream(conn);
+
+                send_stream(
+                    &mut wrapped_stream,
+                    RegistryPacket::Register {
+                        id: *registrant_id,
+                        cert_der: cert_der.clone(),
+                    },
+                )
+                .await?;
+
+                let response = recv_stream::<RegistryPacket, _>(&mut wrapped_stream).await?;
+
+                if !matches!(
+                    &response,
+                    &RegistryPacket::RegisterResponse { success: true, .. }
+                ) {
+                    return Err(Error::RegistryCreateError {
+                        err: "Unexpected response".to_string(),
+                    });
+                }
+
+                Ok(())
+            }
+        }
+    }
+}
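+// Editor's note (sketch): the registration round trip defined by
+// `RegistryPacket` below is, on the wire:
+//   client -> king: Register { id, cert_der }
+//   king -> client: RegisterResponse { id, success: true }
+// after which `SubstrateGadgetMessage` frames carry gadget traffic.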
+#[derive(Serialize, Deserialize)]
+enum RegistryPacket {
+    Register { id: RegistantId, cert_der: Vec<u8> },
+    RegisterResponse { id: RegistantId, success: bool },
+    // A message for the substrate gadget
+    SubstrateGadgetMessage { payload: GadgetProtocolMessage },
+}
+
+fn handle_stream_as_king(
+    stream: TlsStream<TcpStream>,
+    peer_addr: SocketAddr,
+    registrants: Arc<Mutex<HashMap<RegistantId, Registrant>>>,
+    to_gadget: tokio::sync::mpsc::UnboundedSender<GadgetProtocolMessage>,
+) {
+    tokio::task::spawn(async move {
+        let mut wrapped_stream = mpc_net::multi::wrap_stream(stream);
+        let mut peer_id = None;
+        while let Some(Ok(message)) = wrapped_stream.next().await {
+            match bincode2::deserialize::<RegistryPacket>(&message) {
+                Ok(packet) => match packet {
+                    RegistryPacket::Register { id, cert_der } => {
+                        println!("[Registry] Received registration for id {id}");
+                        peer_id = Some(id);
+                        let mut registrants = registrants.lock().await;
+                        registrants.insert(id, Registrant { id, cert_der });
+                        if let Err(err) = send_stream(
+                            &mut wrapped_stream,
+                            RegistryPacket::RegisterResponse { id, success: true },
+                        )
+                        .await
+                        {
+                            eprintln!("[Registry] Failed to send registration response: {err:?}");
+                        }
+                    }
+                    RegistryPacket::SubstrateGadgetMessage { payload } => {
+                        if let Err(err) = to_gadget.send(payload) {
+                            eprintln!("[Registry] Failed to send message to gadget: {err:?}");
+                        }
+                    }
+                    _ => {
+                        println!("[Registry] Received invalid packet");
+                    }
+                },
+                Err(err) => {
+                    println!("[Registry] Received invalid packet: {err}");
+                }
+            }
+        }
+
+        // Deregister peer
+        if let Some(id) = peer_id {
+            let mut registrants = registrants.lock().await;
+            registrants.remove(&id);
+        }
+
+        eprintln!("[Registry] Connection closed to peer {peer_addr}")
+    });
+}
+
+async fn send_stream<T: Serialize, R: AsyncRead + AsyncWrite + Unpin>(
+    stream: &mut WrappedStream<R>,
+    payload: T,
+) -> Result<(), Error> {
+    let serialized = bincode2::serialize(&payload).map_err(|err| Error::RegistrySendError {
+        err: err.to_string(),
+    })?;
+
+    stream
+        .send(serialized.into())
+        .await
+        .map_err(|err| Error::RegistrySendError {
+            err: err.to_string(),
+        })
+}
+
+async fn recv_stream<T: DeserializeOwned, R: AsyncRead + AsyncWrite + Unpin>(
+    stream: &mut WrappedStream<R>,
+) -> Result<T, Error> {
+    let message = stream
+        .next()
+        .await
+        .ok_or(Error::RegistryRecvError {
+            err: "Stream closed".to_string(),
+        })?
+        .map_err(|err| Error::RegistryRecvError {
+            err: err.to_string(),
+        })?;
+
+    let deserialized = bincode2::deserialize(&message).map_err(|err| Error::RegistryRecvError {
+        err: err.to_string(),
+    })?;
+
+    Ok(deserialized)
+}
+
+impl Network for ZkNetworkService {
+    fn take_message_receiver(&mut self) -> Option<UnboundedReceiver<GadgetProtocolMessage>> {
+        match self {
+            Self::King { from_registry, .. } => from_registry.take(),
+            Self::Client { from_registry, .. } => from_registry.take(),
+        }
+    }
+}
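To make the framing concrete, here is a self-contained round-trip sketch over an in-memory pipe. Assumptions: it lives in the same module as the private `send_stream`/`recv_stream` helpers, `mpc_net::multi::wrap_stream` accepts any `AsyncRead + AsyncWrite + Unpin` transport (as its use with both owned and borrowed streams above suggests), and tokio's `macros`/`rt` features are available in tests:

```rust
use serde::{Deserialize, Serialize};

#[derive(Serialize, Deserialize, Debug, PartialEq)]
struct Ping {
    seq: u64,
}

// Hypothetical test: `tokio::io::duplex` stands in for the real TLS stream,
// exercising the same length-delimited + bincode2 path as the registry.
#[tokio::test]
async fn framed_roundtrip() {
    let (a, b) = tokio::io::duplex(1024);
    let mut tx = mpc_net::multi::wrap_stream(a);
    let mut rx = mpc_net::multi::wrap_stream(b);

    send_stream(&mut tx, Ping { seq: 7 }).await.unwrap();
    let got: Ping = recv_stream(&mut rx).await.unwrap();
    assert_eq!(got, Ping { seq: 7 });
}
```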