Skip to content

Commit

Permalink
Refactor tests. Add dfft test (wip)
Browse files Browse the repository at this point in the history
  • Loading branch information
Thomas Braun committed Nov 6, 2023
1 parent f43877d commit 443d332
Show file tree
Hide file tree
Showing 7 changed files with 308 additions and 101 deletions.
3 changes: 3 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ sp-runtime = { git = "https://github.com/paritytech/substrate", branch = "polkad
sp-api = { git = "https://github.com/paritytech/substrate", branch = "polkadot-v1.0.0", default-features = false }

mpc-net = { git = "https://github.com/webb-tools/zk-SaaS/" }
dist-primitives = { git = "https://github.com/webb-tools/zk-SaaS/" }
secret-sharing = { git = "https://github.com/webb-tools/zk-SaaS/" }

tokio-rustls = "0.24.1"
tokio = "1.32.0"
bincode2 = "2"
Expand Down
11 changes: 10 additions & 1 deletion test-gadget/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,13 @@ async-trait = { workspace = true }

[dev-dependencies]
tokio = { workspace = true, features = ["macros", "rt", "rt-multi-thread"] }
tracing-subscriber = { version = "0.3.17", features = ["env-filter"] }
tracing-subscriber = { version = "0.3.17", features = ["env-filter"] }

# zk deps
mpc-net = { workspace = true }
dist-primitives = { workspace = true }
secret-sharing = { workspace = true }
bytes = "1.5.0"
ark-bls12-377 = { version = "0.4.0", default-features = false, features = ["curve"] }
ark-ff = {version = "0.4.0", default-features = false}
ark-poly = {version = "0.4.0", default-features = false}
20 changes: 13 additions & 7 deletions test-gadget/src/gadget.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,21 +13,23 @@ use std::sync::Arc;
use tokio::sync::Mutex;

/// An AbstractGadget endowed with a WorkerManager, a fake blockchain that delivers FinalityNotifications to the gadgets, and a TestProtocolMessage stream
pub struct TestGadget {
pub struct TestGadget<B> {
job_manager: ProtocolWorkManager<TestWorkManager>,
blockchain_connection: Mutex<tokio::sync::broadcast::Receiver<TestFinalityNotification>>,
network_connection: Mutex<tokio::sync::mpsc::UnboundedReceiver<TestProtocolMessage>>,
// Specifies at which blocks we should start a job on
run_test_at: Arc<Vec<u64>>,
clock: Arc<RwLock<u64>>,
async_protocol_generator: Box<dyn AsyncProtocolGenerator>,
test_bundle: B,
async_protocol_generator: Box<dyn AsyncProtocolGenerator<B>>,
}

impl TestGadget {
pub fn new<T: AsyncProtocolGenerator + 'static>(
impl<B: Send + Sync + Clone + 'static> TestGadget<B> {
pub fn new<T: AsyncProtocolGenerator<B> + 'static>(
blockchain_connection: tokio::sync::broadcast::Receiver<TestFinalityNotification>,
network_connection: tokio::sync::mpsc::UnboundedReceiver<TestProtocolMessage>,
run_test_at: Arc<Vec<u64>>,
test_bundle: B,
async_protocol_generator: T,
) -> Self {
let clock = Arc::new(RwLock::new(0));
Expand All @@ -43,6 +45,7 @@ impl TestGadget {
blockchain_connection: Mutex::new(blockchain_connection),
network_connection: Mutex::new(network_connection),
run_test_at,
test_bundle,
async_protocol_generator: Box::new(async_protocol_generator),
clock,
}
Expand All @@ -56,7 +59,7 @@ pub struct TestFinalityNotification {
}

#[async_trait]
impl AbstractGadget for TestGadget {
impl<B: Send + Sync + Clone + 'static> AbstractGadget for TestGadget<B> {
type FinalityNotification = TestFinalityNotification;
type BlockImportNotification = ();
type ProtocolMessage = TestProtocolMessage;
Expand Down Expand Up @@ -92,6 +95,7 @@ impl AbstractGadget for TestGadget {
now,
ssid,
task_hash,
self.test_bundle.clone(),
&*self.async_protocol_generator,
);
self.job_manager
Expand Down Expand Up @@ -128,12 +132,13 @@ impl AbstractGadget for TestGadget {
}
}

fn create_test_async_protocol(
fn create_test_async_protocol<B: Send + Sync + 'static>(
session_id: <TestWorkManager as WorkManagerInterface>::SessionID,
now: <TestWorkManager as WorkManagerInterface>::Clock,
ssid: <TestWorkManager as WorkManagerInterface>::SSID,
task_id: <TestWorkManager as WorkManagerInterface>::TaskID,
proto_gen: &dyn AsyncProtocolGenerator,
test_bundle: B,
proto_gen: &dyn AsyncProtocolGenerator<B>,
) -> (TestProtocolRemote, Pin<Box<dyn SendFuture<'static, ()>>>) {
let is_done = Arc::new(AtomicBool::new(false));
let (to_async_protocol, protocol_message_rx) = tokio::sync::mpsc::unbounded_channel();
Expand All @@ -149,6 +154,7 @@ fn create_test_async_protocol(
associated_ssid: ssid,
associated_session_id: session_id,
associated_task_id: task_id,
test_bundle,
};

let remote = TestProtocolRemote {
Expand Down
95 changes: 95 additions & 0 deletions test-gadget/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,101 @@
use crate::error::TestError;
use crate::gadget::TestGadget;
use crate::message::UserID;
use crate::test_network::InMemoryNetwork;
use crate::work_manager::AsyncProtocolGenerator;
use futures::stream::FuturesUnordered;
use futures::TryStreamExt;
use gadget_core::gadget::manager::GadgetManager;
use std::error::Error;
use std::sync::Arc;
use std::time::Duration;

pub mod blockchain;
pub mod error;
pub mod gadget;
pub mod message;
pub mod test_network;
pub mod work_manager;

pub async fn simulate_test<T: AsyncProtocolGenerator<TestBundle> + 'static + Clone>(
n_peers: usize,
n_blocks_per_session: u64,
block_duration: Duration,
run_tests_at: Vec<u64>,
async_proto_gen: T,
) -> Result<(), Box<dyn Error>> {
let (network, recv_handles) = InMemoryNetwork::new(n_peers);
let (bc_tx, _bc_rx) = tokio::sync::broadcast::channel(5000);
let gadget_futures = FuturesUnordered::new();
let run_tests_at = Arc::new(run_tests_at);
let (count_finished_tx, mut count_finished_rx) = tokio::sync::mpsc::unbounded_channel();

for (party_id, network_handle) in recv_handles.into_iter().enumerate() {
// Create a TestGadget
let count_finished_tx = count_finished_tx.clone();
let network = network.clone();
let test_bundle = TestBundle {
count_finished_tx,
network,
n_peers,
party_id: party_id as UserID,
};

let gadget = TestGadget::new(
bc_tx.subscribe(),
network_handle,
run_tests_at.clone(),
test_bundle,
async_proto_gen.clone(),
);

gadget_futures.push(tokio::task::spawn(GadgetManager::new(gadget)));
}

let blockchain_future = blockchain::blockchain(block_duration, n_blocks_per_session, bc_tx);

let finished_future = async move {
let mut count_received = 0;
let expected_count = n_peers * run_tests_at.len();
while count_finished_rx.recv().await.is_some() {
count_received += 1;
log::info!("Received {} finished signals", count_received);
if count_received == expected_count {
return Ok::<(), TestError>(());
}
}

Err(TestError {
reason: "Didn't receive all finished signals".to_string(),
})
};

let gadgets_future = gadget_futures.try_collect::<Vec<_>>();

// The gadgets will run indefinitely if properly behaved, and the blockchain future will as well.
// The finished future will end when all gadgets have finished their async protocols properly
// Thus, select all three futures

tokio::select! {
res0 = blockchain_future => {
res0?;
},
res1 = gadgets_future => {
res1?.map_err(|err| TestError { reason: format!("{err:?}") })
.map(|_| ())?;
},
res2 = finished_future => {
res2?;
}
}

Ok(())
}

#[derive(Clone)]
pub struct TestBundle {
pub count_finished_tx: tokio::sync::mpsc::UnboundedSender<()>,
pub network: InMemoryNetwork,
pub n_peers: usize,
pub party_id: UserID,
}
13 changes: 8 additions & 5 deletions test-gadget/src/work_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ impl ProtocolRemote<TestWorkManager> for TestProtocolRemote {
}
}

pub struct TestAsyncProtocolParameters {
pub struct TestAsyncProtocolParameters<B> {
pub is_done: Arc<AtomicBool>,
pub protocol_message_rx: tokio::sync::mpsc::UnboundedReceiver<
<TestWorkManager as WorkManagerInterface>::ProtocolMessage,
Expand All @@ -126,13 +126,16 @@ pub struct TestAsyncProtocolParameters {
pub associated_ssid: <TestWorkManager as WorkManagerInterface>::SSID,
pub associated_session_id: <TestWorkManager as WorkManagerInterface>::SessionID,
pub associated_task_id: <TestWorkManager as WorkManagerInterface>::TaskID,
pub test_bundle: B,
}

pub trait AsyncProtocolGenerator:
Send + Sync + Fn(TestAsyncProtocolParameters) -> Pin<Box<dyn SendFuture<'static, ()>>>
pub trait AsyncProtocolGenerator<B>:
Send + Sync + Fn(TestAsyncProtocolParameters<B>) -> Pin<Box<dyn SendFuture<'static, ()>>>
{
}
impl<T: Send + Sync + Fn(TestAsyncProtocolParameters) -> Pin<Box<dyn SendFuture<'static, ()>>>>
AsyncProtocolGenerator for T
impl<
B: Send + Sync,
T: Send + Sync + Fn(TestAsyncProtocolParameters<B>) -> Pin<Box<dyn SendFuture<'static, ()>>>,
> AsyncProtocolGenerator<B> for T
{
}
109 changes: 21 additions & 88 deletions test-gadget/tests/basic.rs
Original file line number Diff line number Diff line change
@@ -1,18 +1,12 @@
#[cfg(test)]
mod tests {
use futures::stream::FuturesUnordered;
use futures::TryStreamExt;
use gadget_core::gadget::manager::GadgetManager;
use gadget_core::job_manager::SendFuture;
use std::error::Error;
use std::pin::Pin;
use std::sync::Arc;
use std::time::Duration;
use test_gadget::error::TestError;
use test_gadget::gadget::TestGadget;
use test_gadget::message::{TestProtocolMessage, UserID};
use test_gadget::test_network::InMemoryNetwork;
use test_gadget::message::TestProtocolMessage;
use test_gadget::work_manager::TestAsyncProtocolParameters;
use test_gadget::TestBundle;
use tracing_subscriber::fmt::SubscriberBuilder;
use tracing_subscriber::util::SubscriberInitExt;
use tracing_subscriber::EnvFilter;
Expand All @@ -29,88 +23,21 @@ mod tests {
#[tokio::test]
async fn test_basic_async_protocol() -> Result<(), Box<dyn Error>> {
setup_log();
const N_PEERS: usize = 5;
const N_BLOCKS_PER_SESSION: u64 = 10;
const BLOCK_DURATION: Duration = Duration::from_millis(1000);

let (network, recv_handles) = InMemoryNetwork::new(N_PEERS);
let (bc_tx, _bc_rx) = tokio::sync::broadcast::channel(1000);
let run_tests_at = Arc::new(vec![0, 2, 4]);
let gadget_futures = FuturesUnordered::new();
let (count_finished_tx, mut count_finished_rx) = tokio::sync::mpsc::unbounded_channel();

for (party_id, network_handle) in recv_handles.into_iter().enumerate() {
// Create a TestGadget
let count_finished_tx = count_finished_tx.clone();
let network = network.clone();
let async_protocol_gen = move |params: TestAsyncProtocolParameters| {
async_proto_generator(
params,
count_finished_tx.clone(),
network.clone(),
N_PEERS,
party_id as UserID,
)
};

let gadget = TestGadget::new(
bc_tx.subscribe(),
network_handle,
run_tests_at.clone(),
async_protocol_gen,
);

gadget_futures.push(GadgetManager::new(gadget));
}

let blockchain_future =
test_gadget::blockchain::blockchain(BLOCK_DURATION, N_BLOCKS_PER_SESSION, bc_tx);

let finished_future = async move {
let mut count_received = 0;
let expected_count = N_PEERS * run_tests_at.len();
while count_finished_rx.recv().await.is_some() {
count_received += 1;
log::info!("Received {} finished signals", count_received);
if count_received == expected_count {
return Ok::<(), TestError>(());
}
}

Err(TestError {
reason: "Didn't receive all finished signals".to_string(),
})
};

let gadgets_future = gadget_futures.try_collect::<Vec<_>>();

// The gadgets will run indefinitely if properly behaved, and the blockchain future will as well.
// The finished future will end when all gadgets have finished their async protocols properly
// Thus, select all three futures

tokio::select! {
res0 = blockchain_future => {
res0?;
},
res1 = gadgets_future => {
res1.map_err(|err| TestError { reason: format!("{err:?}") })
.map(|_| ())?;
},
res2 = finished_future => {
res2?;
}
}

Ok(())
test_gadget::simulate_test(
5,
10,
Duration::from_millis(1000),
vec![0, 2, 4],
async_proto_generator,
)
.await
}

fn async_proto_generator(
mut params: TestAsyncProtocolParameters,
on_finish: tokio::sync::mpsc::UnboundedSender<()>,
network: InMemoryNetwork,
n_peers: usize,
party_id: UserID,
mut params: TestAsyncProtocolParameters<TestBundle>,
) -> Pin<Box<dyn SendFuture<'static, ()>>> {
let n_peers = params.test_bundle.n_peers;
let party_id = params.test_bundle.party_id;
Box::pin(async move {
params
.start_rx
Expand All @@ -119,7 +46,9 @@ mod tests {
.await
.expect("Failed to start");
// Broadcast a message to each peer
network
params
.test_bundle
.network
.broadcast(
party_id,
TestProtocolMessage {
Expand All @@ -145,7 +74,11 @@ mod tests {
params
.is_done
.store(true, std::sync::atomic::Ordering::Relaxed);
on_finish.send(()).expect("Didn't send on_finish signal");
params
.test_bundle
.count_finished_tx
.send(())
.expect("Didn't send on_finish signal");
})
}
}
Loading

0 comments on commit 443d332

Please sign in to comment.