Skip to content

Commit

Permalink
Clean up and improve abstraction
Browse files Browse the repository at this point in the history
  • Loading branch information
Thomas Braun committed Nov 17, 2023
1 parent 52de2e4 commit e640a0b
Show file tree
Hide file tree
Showing 7 changed files with 46 additions and 26 deletions.
1 change: 0 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ zk-gadget = { path = "./zk-gadget" }
sc-client-api = { git = "https://github.com/paritytech/substrate", branch = "polkadot-v1.0.0", default-features = false }
sp-core = { git = "https://github.com/paritytech/substrate", branch = "polkadot-v1.0.0", default-features = false }
sp-runtime = { git = "https://github.com/paritytech/substrate", branch = "polkadot-v1.0.0", default-features = false }
sp-api = { git = "https://github.com/paritytech/substrate", branch = "polkadot-v1.0.0", default-features = false }
sc-utils = { git = "https://github.com/paritytech/substrate", branch = "polkadot-v1.0.0", default-features = false }
parity-scale-codec = "3.6.5"

Expand Down
2 changes: 0 additions & 2 deletions gadget-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ edition = "2021"
substrate = [
"sp-runtime",
"sc-client-api",
"sp-api",
"futures"
]

Expand All @@ -22,7 +21,6 @@ async-trait = "0.1.73"

sp-runtime = { optional = true, workspace = true, default-features = false }
sc-client-api = { optional = true, workspace = true, default-features = false }
sp-api = { optional = true, workspace = true, default-features = false }
futures = { optional = true, workspace = true }

[dev-dependencies]
4 changes: 2 additions & 2 deletions webb-gadget/src/gadget/message.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@ pub struct GadgetProtocolMessage {
pub associated_block_id: <WebbWorkManager as WorkManagerInterface>::Clock,
pub associated_session_id: <WebbWorkManager as WorkManagerInterface>::SessionID,
pub associated_ssid: <WebbWorkManager as WorkManagerInterface>::SSID,
// A unique marker for the associated task this message belongs to
pub task_hash: <WebbWorkManager as WorkManagerInterface>::TaskID,
pub from: UserID,
// If None, this is a broadcasted message
pub to: Option<UserID>,
// A unique marker for the associated task this message belongs to
pub task_hash: <WebbWorkManager as WorkManagerInterface>::TaskID,
pub payload: Vec<u8>,
}

Expand Down
2 changes: 1 addition & 1 deletion zk-gadget/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ pub async fn run<
C: ClientWithApi<B>,
B: Block,
T: AdditionalProtocolParams,
Gen: AsyncProtocolGenerator<T, Error, ZkNetworkService>,
Gen: AsyncProtocolGenerator<T, Error, ZkNetworkService, C, B>,
>(
config: ZkGadgetConfig,
client: C,
Expand Down
7 changes: 4 additions & 3 deletions zk-gadget/src/module/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,21 +9,21 @@ use webb_gadget::{BlockImportNotification, Error, FinalityNotification};

pub mod proto_gen;

pub struct ZkModule<T, C> {
pub struct ZkModule<T, C, B> {
pub party_id: RegistantId,
pub additional_protocol_params: T,
pub client: C,
pub network: ZkNetworkService,
pub async_protocol_generator:
Box<dyn proto_gen::AsyncProtocolGenerator<T, Error, ZkNetworkService>>,
Box<dyn proto_gen::AsyncProtocolGenerator<T, Error, ZkNetworkService, C, B>>,
}

pub trait AdditionalProtocolParams: Send + Sync + Clone + 'static {}
impl<T: Send + Sync + Clone + 'static> AdditionalProtocolParams for T {}

#[async_trait]
impl<B: Block, T: AdditionalProtocolParams, C: ClientWithApi<B>> WebbGadgetModule<B>
for ZkModule<T, C>
for ZkModule<T, C, B>
{
async fn process_finality_notification(
&self,
Expand All @@ -49,6 +49,7 @@ impl<B: Block, T: AdditionalProtocolParams, C: ClientWithApi<B>> WebbGadgetModul
n_parties,
self.additional_protocol_params.clone(),
self.network.clone(),
self.client.clone(),
&*self.async_protocol_generator,
);

Expand Down
42 changes: 31 additions & 11 deletions zk-gadget/src/module/proto_gen.rs
Original file line number Diff line number Diff line change
@@ -1,20 +1,23 @@
use crate::client_ext::ClientWithApi;
use crate::network::RegistantId;
use async_trait::async_trait;
use bytes::Bytes;
use gadget_core::job_manager::{ProtocolRemote, SendFuture, ShutdownReason, WorkManagerInterface};
use mpc_net::{MpcNet, MpcNetError, MultiplexedStreamID};
use parking_lot::Mutex;
use serde::{Deserialize, Serialize};
use sp_runtime::traits::Block;
use std::collections::HashMap;
use std::fmt::Debug;
use std::marker::PhantomData;
use std::pin::Pin;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use webb_gadget::gadget::message::GadgetProtocolMessage;
use webb_gadget::gadget::network::Network;
use webb_gadget::gadget::work_manager::WebbWorkManager;

pub struct ZkAsyncProtocolParameters<B, N> {
pub struct ZkAsyncProtocolParameters<B, N, C, Bl> {
pub associated_block_id: <WebbWorkManager as WorkManagerInterface>::Clock,
pub associated_ssid: <WebbWorkManager as WorkManagerInterface>::SSID,
pub associated_session_id: <WebbWorkManager as WorkManagerInterface>::SessionID,
Expand All @@ -23,7 +26,9 @@ pub struct ZkAsyncProtocolParameters<B, N> {
pub party_id: RegistantId,
pub n_parties: usize,
pub network: N,
pub client: C,
pub extra_parameters: B,
_pd: PhantomData<Bl>,
}

#[derive(Serialize, Deserialize)]
Expand All @@ -45,27 +50,37 @@ pub struct ZkProtocolRemote {
pub is_done: Arc<AtomicBool>,
}

pub trait AsyncProtocolGenerator<B, E, N>:
pub trait AsyncProtocolGenerator<B, E, N, C, Bl>:
Send
+ Sync
+ 'static
+ Fn(ZkAsyncProtocolParameters<B, N>) -> Pin<Box<dyn SendFuture<'static, Result<(), E>>>>
+ Fn(ZkAsyncProtocolParameters<B, N, C, Bl>) -> Pin<Box<dyn SendFuture<'static, Result<(), E>>>>
{
}
impl<
B: Send + Sync,
E: Debug,
N: Network,
Bl: Block,
C: ClientWithApi<Bl>,
T: Send
+ Sync
+ 'static
+ Fn(ZkAsyncProtocolParameters<B, N>) -> Pin<Box<dyn SendFuture<'static, Result<(), E>>>>,
> AsyncProtocolGenerator<B, E, N> for T
+ Fn(
ZkAsyncProtocolParameters<B, N, C, Bl>,
) -> Pin<Box<dyn SendFuture<'static, Result<(), E>>>>,
> AsyncProtocolGenerator<B, E, N, C, Bl> for T
{
}

#[allow(clippy::too_many_arguments)]
pub fn create_zk_async_protocol<B: Send + Sync + 'static, E: Debug + 'static, N: Network>(
pub fn create_zk_async_protocol<
B: Send + Sync + 'static,
E: Debug + 'static,
N: Network,
C: ClientWithApi<Bl>,
Bl: Block,
>(
session_id: <WebbWorkManager as WorkManagerInterface>::SessionID,
now: <WebbWorkManager as WorkManagerInterface>::Clock,
ssid: <WebbWorkManager as WorkManagerInterface>::SSID,
Expand All @@ -74,7 +89,8 @@ pub fn create_zk_async_protocol<B: Send + Sync + 'static, E: Debug + 'static, N:
n_parties: usize,
extra_parameters: B,
network: N,
proto_gen: &dyn AsyncProtocolGenerator<B, E, N>,
client: C,
proto_gen: &dyn AsyncProtocolGenerator<B, E, N, C, Bl>,
) -> (ZkProtocolRemote, Pin<Box<dyn SendFuture<'static, ()>>>) {
let is_done = Arc::new(AtomicBool::new(false));
let (to_async_protocol, mut protocol_message_rx) = tokio::sync::mpsc::unbounded_channel();
Expand Down Expand Up @@ -120,6 +136,8 @@ pub fn create_zk_async_protocol<B: Send + Sync + 'static, E: Debug + 'static, N:
associated_task_id: task_id,
extra_parameters,
network,
client,
_pd: PhantomData,
};

let remote = ZkProtocolRemote {
Expand All @@ -138,7 +156,7 @@ pub fn create_zk_async_protocol<B: Send + Sync + 'static, E: Debug + 'static, N:
// job manager
let wrapped_future = Box::pin(async move {
if let Err(err) = start_rx.await {
log::error!("Failed to start protocol {proto_hash_hex}: {err:?}");
log::error!("Protocol {proto_hash_hex} failed to receive start signal: {err:?}");
} else {
tokio::select! {
res0 = async_protocol => {
Expand All @@ -152,10 +170,10 @@ pub fn create_zk_async_protocol<B: Send + Sync + 'static, E: Debug + 'static, N:
res1 = shutdown_rx => {
match res1 {
Ok(reason) => {
log::info!("Protocol shutdown: {reason:?}");
log::info!("Protocol {proto_hash_hex} shutdown: {reason:?}");
},
Err(err) => {
log::error!("Protocol shutdown failed: {err:?}");
log::error!("Protocol {proto_hash_hex} shutdown failed: {err:?}");
},
}
}
Expand Down Expand Up @@ -234,7 +252,9 @@ impl ProtocolRemote<WebbWorkManager> for ZkProtocolRemote {
}

#[async_trait]
impl<B: Send + Sync, N: Network> MpcNet for ZkAsyncProtocolParameters<B, N> {
impl<B: Send + Sync, N: Network, C: ClientWithApi<Bl>, Bl: Block> MpcNet
for ZkAsyncProtocolParameters<B, N, C, Bl>
{
fn n_parties(&self) -> usize {
self.n_parties
}
Expand Down
14 changes: 8 additions & 6 deletions zk-gadget/tests/jobs.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#[cfg(test)]
mod tests {
use crate::tests::client::BlockchainClient;
use crate::tests::client::{BlockchainClient, TestBlock};
use futures_util::stream::FuturesUnordered;
use futures_util::TryStreamExt;
use gadget_core::job_manager::SendFuture;
Expand Down Expand Up @@ -340,7 +340,6 @@ mod tests {

let additional_parameters = AdditionalParams {
stop_tx: done_tx.clone(),
client: client.clone(),
};

let zk_gadget_future = zk_gadget::run(
Expand Down Expand Up @@ -385,12 +384,15 @@ mod tests {
#[derive(Clone)]
struct AdditionalParams {
pub stop_tx: UnboundedSender<()>,
#[allow(dead_code)]
pub client: BlockchainClient,
}

fn async_protocol_generator(
params: ZkAsyncProtocolParameters<AdditionalParams, ZkNetworkService>,
params: ZkAsyncProtocolParameters<
AdditionalParams,
ZkNetworkService,
BlockchainClient,
TestBlock,
>,
) -> Pin<Box<dyn SendFuture<'static, Result<(), webb_gadget::Error>>>> {
Box::pin(async move {
if params.party_id == 0 {
Expand Down Expand Up @@ -425,7 +427,7 @@ mod tests {
.expect("Should receive protocol message");
}

// TODO: use the params.extra_parameters.client to get job metadata, **AFTER** the server is given some data
// TODO: use the params.client to get job metadata, **AFTER** the server is given some data
// to store inside its hashmap. By default, there is none. See previous TODO

params
Expand Down

0 comments on commit e640a0b

Please sign in to comment.