diff --git a/Cargo.toml b/Cargo.toml index 5a9b42d4..e117c967 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,6 +17,9 @@ sp-runtime = { git = "https://github.com/paritytech/substrate", branch = "polkad sp-api = { git = "https://github.com/paritytech/substrate", branch = "polkadot-v1.0.0", default-features = false } mpc-net = { git = "https://github.com/webb-tools/zk-SaaS/" } +dist-primitives = { git = "https://github.com/webb-tools/zk-SaaS/" } +secret-sharing = { git = "https://github.com/webb-tools/zk-SaaS/" } + tokio-rustls = "0.24.1" tokio = "1.32.0" bincode2 = "2" diff --git a/test-gadget/Cargo.toml b/test-gadget/Cargo.toml index e44de903..3bb9224d 100644 --- a/test-gadget/Cargo.toml +++ b/test-gadget/Cargo.toml @@ -17,4 +17,13 @@ async-trait = { workspace = true } [dev-dependencies] tokio = { workspace = true, features = ["macros", "rt", "rt-multi-thread"] } -tracing-subscriber = { version = "0.3.17", features = ["env-filter"] } \ No newline at end of file +tracing-subscriber = { version = "0.3.17", features = ["env-filter"] } + +# zk deps +mpc-net = { workspace = true } +dist-primitives = { workspace = true } +secret-sharing = { workspace = true } +bytes = "1.5.0" +ark-bls12-377 = { version = "0.4.0", default-features = false, features = ["curve"] } +ark-ff = {version = "0.4.0", default-features = false} +ark-poly = {version = "0.4.0", default-features = false} diff --git a/test-gadget/src/gadget.rs b/test-gadget/src/gadget.rs index cde943c9..0fa6eb30 100644 --- a/test-gadget/src/gadget.rs +++ b/test-gadget/src/gadget.rs @@ -13,21 +13,23 @@ use std::sync::Arc; use tokio::sync::Mutex; /// An AbstractGadget endowed with a WorkerManager, a fake blockchain that delivers FinalityNotifications to the gadgets, and a TestProtocolMessage stream -pub struct TestGadget { +pub struct TestGadget { job_manager: ProtocolWorkManager, blockchain_connection: Mutex>, network_connection: Mutex>, // Specifies at which blocks we should start a job on run_test_at: Arc>, clock: Arc>, - async_protocol_generator: Box, + test_bundle: B, + async_protocol_generator: Box>, } -impl TestGadget { - pub fn new( +impl TestGadget { + pub fn new + 'static>( blockchain_connection: tokio::sync::broadcast::Receiver, network_connection: tokio::sync::mpsc::UnboundedReceiver, run_test_at: Arc>, + test_bundle: B, async_protocol_generator: T, ) -> Self { let clock = Arc::new(RwLock::new(0)); @@ -43,6 +45,7 @@ impl TestGadget { blockchain_connection: Mutex::new(blockchain_connection), network_connection: Mutex::new(network_connection), run_test_at, + test_bundle, async_protocol_generator: Box::new(async_protocol_generator), clock, } @@ -56,7 +59,7 @@ pub struct TestFinalityNotification { } #[async_trait] -impl AbstractGadget for TestGadget { +impl AbstractGadget for TestGadget { type FinalityNotification = TestFinalityNotification; type BlockImportNotification = (); type ProtocolMessage = TestProtocolMessage; @@ -92,6 +95,7 @@ impl AbstractGadget for TestGadget { now, ssid, task_hash, + self.test_bundle.clone(), &*self.async_protocol_generator, ); self.job_manager @@ -128,12 +132,13 @@ impl AbstractGadget for TestGadget { } } -fn create_test_async_protocol( +fn create_test_async_protocol( session_id: ::SessionID, now: ::Clock, ssid: ::SSID, task_id: ::TaskID, - proto_gen: &dyn AsyncProtocolGenerator, + test_bundle: B, + proto_gen: &dyn AsyncProtocolGenerator, ) -> (TestProtocolRemote, Pin>>) { let is_done = Arc::new(AtomicBool::new(false)); let (to_async_protocol, protocol_message_rx) = tokio::sync::mpsc::unbounded_channel(); @@ -149,6 +154,7 @@ fn create_test_async_protocol( associated_ssid: ssid, associated_session_id: session_id, associated_task_id: task_id, + test_bundle, }; let remote = TestProtocolRemote { diff --git a/test-gadget/src/lib.rs b/test-gadget/src/lib.rs index 8deb56a6..2487e94d 100644 --- a/test-gadget/src/lib.rs +++ b/test-gadget/src/lib.rs @@ -1,6 +1,101 @@ +use crate::error::TestError; +use crate::gadget::TestGadget; +use crate::message::UserID; +use crate::test_network::InMemoryNetwork; +use crate::work_manager::AsyncProtocolGenerator; +use futures::stream::FuturesUnordered; +use futures::TryStreamExt; +use gadget_core::gadget::manager::GadgetManager; +use std::error::Error; +use std::sync::Arc; +use std::time::Duration; + pub mod blockchain; pub mod error; pub mod gadget; pub mod message; pub mod test_network; pub mod work_manager; + +pub async fn simulate_test + 'static + Clone>( + n_peers: usize, + n_blocks_per_session: u64, + block_duration: Duration, + run_tests_at: Vec, + async_proto_gen: T, +) -> Result<(), Box> { + let (network, recv_handles) = InMemoryNetwork::new(n_peers); + let (bc_tx, _bc_rx) = tokio::sync::broadcast::channel(5000); + let gadget_futures = FuturesUnordered::new(); + let run_tests_at = Arc::new(run_tests_at); + let (count_finished_tx, mut count_finished_rx) = tokio::sync::mpsc::unbounded_channel(); + + for (party_id, network_handle) in recv_handles.into_iter().enumerate() { + // Create a TestGadget + let count_finished_tx = count_finished_tx.clone(); + let network = network.clone(); + let test_bundle = TestBundle { + count_finished_tx, + network, + n_peers, + party_id: party_id as UserID, + }; + + let gadget = TestGadget::new( + bc_tx.subscribe(), + network_handle, + run_tests_at.clone(), + test_bundle, + async_proto_gen.clone(), + ); + + gadget_futures.push(tokio::task::spawn(GadgetManager::new(gadget))); + } + + let blockchain_future = blockchain::blockchain(block_duration, n_blocks_per_session, bc_tx); + + let finished_future = async move { + let mut count_received = 0; + let expected_count = n_peers * run_tests_at.len(); + while count_finished_rx.recv().await.is_some() { + count_received += 1; + log::info!("Received {} finished signals", count_received); + if count_received == expected_count { + return Ok::<(), TestError>(()); + } + } + + Err(TestError { + reason: "Didn't receive all finished signals".to_string(), + }) + }; + + let gadgets_future = gadget_futures.try_collect::>(); + + // The gadgets will run indefinitely if properly behaved, and the blockchain future will as well. + // The finished future will end when all gadgets have finished their async protocols properly + // Thus, select all three futures + + tokio::select! { + res0 = blockchain_future => { + res0?; + }, + res1 = gadgets_future => { + res1?.map_err(|err| TestError { reason: format!("{err:?}") }) + .map(|_| ())?; + }, + res2 = finished_future => { + res2?; + } + } + + Ok(()) +} + +#[derive(Clone)] +pub struct TestBundle { + pub count_finished_tx: tokio::sync::mpsc::UnboundedSender<()>, + pub network: InMemoryNetwork, + pub n_peers: usize, + pub party_id: UserID, +} diff --git a/test-gadget/src/work_manager.rs b/test-gadget/src/work_manager.rs index 194b2e49..c5405c47 100644 --- a/test-gadget/src/work_manager.rs +++ b/test-gadget/src/work_manager.rs @@ -115,7 +115,7 @@ impl ProtocolRemote for TestProtocolRemote { } } -pub struct TestAsyncProtocolParameters { +pub struct TestAsyncProtocolParameters { pub is_done: Arc, pub protocol_message_rx: tokio::sync::mpsc::UnboundedReceiver< ::ProtocolMessage, @@ -126,13 +126,16 @@ pub struct TestAsyncProtocolParameters { pub associated_ssid: ::SSID, pub associated_session_id: ::SessionID, pub associated_task_id: ::TaskID, + pub test_bundle: B, } -pub trait AsyncProtocolGenerator: - Send + Sync + Fn(TestAsyncProtocolParameters) -> Pin>> +pub trait AsyncProtocolGenerator: + Send + Sync + Fn(TestAsyncProtocolParameters) -> Pin>> { } -impl Pin>>> - AsyncProtocolGenerator for T +impl< + B: Send + Sync, + T: Send + Sync + Fn(TestAsyncProtocolParameters) -> Pin>>, + > AsyncProtocolGenerator for T { } diff --git a/test-gadget/tests/basic.rs b/test-gadget/tests/basic.rs index 287a819f..40bcbc2e 100644 --- a/test-gadget/tests/basic.rs +++ b/test-gadget/tests/basic.rs @@ -1,18 +1,12 @@ #[cfg(test)] mod tests { - use futures::stream::FuturesUnordered; - use futures::TryStreamExt; - use gadget_core::gadget::manager::GadgetManager; use gadget_core::job_manager::SendFuture; use std::error::Error; use std::pin::Pin; - use std::sync::Arc; use std::time::Duration; - use test_gadget::error::TestError; - use test_gadget::gadget::TestGadget; - use test_gadget::message::{TestProtocolMessage, UserID}; - use test_gadget::test_network::InMemoryNetwork; + use test_gadget::message::TestProtocolMessage; use test_gadget::work_manager::TestAsyncProtocolParameters; + use test_gadget::TestBundle; use tracing_subscriber::fmt::SubscriberBuilder; use tracing_subscriber::util::SubscriberInitExt; use tracing_subscriber::EnvFilter; @@ -29,88 +23,21 @@ mod tests { #[tokio::test] async fn test_basic_async_protocol() -> Result<(), Box> { setup_log(); - const N_PEERS: usize = 5; - const N_BLOCKS_PER_SESSION: u64 = 10; - const BLOCK_DURATION: Duration = Duration::from_millis(1000); - - let (network, recv_handles) = InMemoryNetwork::new(N_PEERS); - let (bc_tx, _bc_rx) = tokio::sync::broadcast::channel(1000); - let run_tests_at = Arc::new(vec![0, 2, 4]); - let gadget_futures = FuturesUnordered::new(); - let (count_finished_tx, mut count_finished_rx) = tokio::sync::mpsc::unbounded_channel(); - - for (party_id, network_handle) in recv_handles.into_iter().enumerate() { - // Create a TestGadget - let count_finished_tx = count_finished_tx.clone(); - let network = network.clone(); - let async_protocol_gen = move |params: TestAsyncProtocolParameters| { - async_proto_generator( - params, - count_finished_tx.clone(), - network.clone(), - N_PEERS, - party_id as UserID, - ) - }; - - let gadget = TestGadget::new( - bc_tx.subscribe(), - network_handle, - run_tests_at.clone(), - async_protocol_gen, - ); - - gadget_futures.push(GadgetManager::new(gadget)); - } - - let blockchain_future = - test_gadget::blockchain::blockchain(BLOCK_DURATION, N_BLOCKS_PER_SESSION, bc_tx); - - let finished_future = async move { - let mut count_received = 0; - let expected_count = N_PEERS * run_tests_at.len(); - while count_finished_rx.recv().await.is_some() { - count_received += 1; - log::info!("Received {} finished signals", count_received); - if count_received == expected_count { - return Ok::<(), TestError>(()); - } - } - - Err(TestError { - reason: "Didn't receive all finished signals".to_string(), - }) - }; - - let gadgets_future = gadget_futures.try_collect::>(); - - // The gadgets will run indefinitely if properly behaved, and the blockchain future will as well. - // The finished future will end when all gadgets have finished their async protocols properly - // Thus, select all three futures - - tokio::select! { - res0 = blockchain_future => { - res0?; - }, - res1 = gadgets_future => { - res1.map_err(|err| TestError { reason: format!("{err:?}") }) - .map(|_| ())?; - }, - res2 = finished_future => { - res2?; - } - } - - Ok(()) + test_gadget::simulate_test( + 5, + 10, + Duration::from_millis(1000), + vec![0, 2, 4], + async_proto_generator, + ) + .await } fn async_proto_generator( - mut params: TestAsyncProtocolParameters, - on_finish: tokio::sync::mpsc::UnboundedSender<()>, - network: InMemoryNetwork, - n_peers: usize, - party_id: UserID, + mut params: TestAsyncProtocolParameters, ) -> Pin>> { + let n_peers = params.test_bundle.n_peers; + let party_id = params.test_bundle.party_id; Box::pin(async move { params .start_rx @@ -119,7 +46,9 @@ mod tests { .await .expect("Failed to start"); // Broadcast a message to each peer - network + params + .test_bundle + .network .broadcast( party_id, TestProtocolMessage { @@ -145,7 +74,11 @@ mod tests { params .is_done .store(true, std::sync::atomic::Ordering::Relaxed); - on_finish.send(()).expect("Didn't send on_finish signal"); + params + .test_bundle + .count_finished_tx + .send(()) + .expect("Didn't send on_finish signal"); }) } } diff --git a/test-gadget/tests/dfft.rs b/test-gadget/tests/dfft.rs new file mode 100644 index 00000000..3f91defe --- /dev/null +++ b/test-gadget/tests/dfft.rs @@ -0,0 +1,158 @@ +#[cfg(test)] +mod tests { + use gadget_core::job_manager::SendFuture; + use std::error::Error; + use std::future::Future; + use std::pin::Pin; + use test_gadget::work_manager::TestAsyncProtocolParameters; + use test_gadget::TestBundle; + use tracing_subscriber::fmt::SubscriberBuilder; + use tracing_subscriber::util::SubscriberInitExt; + use tracing_subscriber::EnvFilter; + + use ark_bls12_377::Fr; + use ark_ff::{FftField, PrimeField}; + use ark_poly::{EvaluationDomain, Radix2EvaluationDomain}; + use async_trait::async_trait; + use dist_primitives::{ + channel::MpcSerNet, + dfft::{d_fft, fft_in_place_rearrange}, + utils::pack::transpose, + }; + use mpc_net::{MpcNet, MpcNetError, MultiplexedStreamID}; + use secret_sharing::pss::PackedSharingParams; + use test_gadget::test_network::InMemoryNetwork; + + pub async fn d_fft_test( + pp: &PackedSharingParams, + dom: &Radix2EvaluationDomain, + net: &Net, + ) { + let mbyl: usize = dom.size() / pp.l; + // We apply FFT on this vector + // let mut x = vec![F::ONE; cd.m]; + let mut x: Vec = Vec::new(); + for i in 0..dom.size() { + x.push(F::from(i as u64)); + } + + // Output to test against + let should_be_output = dom.fft(&x); + + fft_in_place_rearrange(&mut x); + let mut pcoeff: Vec> = Vec::new(); + for i in 0..mbyl { + pcoeff.push(x.iter().skip(i).step_by(mbyl).cloned().collect::>()); + pp.pack_from_public_in_place(&mut pcoeff[i]); + } + + let pcoeff_share = pcoeff + .iter() + .map(|x| x[net.party_id() as usize]) + .collect::>(); + + // Rearranging x + + let peval_share = d_fft( + pcoeff_share, + false, + 1, + false, + dom, + pp, + net, + MultiplexedStreamID::One, + ) + .await + .unwrap(); + + // Send to king who reconstructs and checks the answer + net.send_to_king(&peval_share, MultiplexedStreamID::One) + .await + .unwrap() + .map(|peval_shares| { + let peval_shares = transpose(peval_shares); + + let pevals: Vec = peval_shares + .into_iter() + .flat_map(|x| pp.unpack(x)) + .rev() + .collect(); + + if net.is_king() { + assert_eq!(should_be_output, pevals); + } + }); + } + + pub fn setup_log() { + let _ = SubscriberBuilder::default() + .with_env_filter(EnvFilter::from_default_env()) + .finish() + .try_init(); + } + + #[tokio::test(flavor = "multi_thread")] + async fn test_dfft() -> Result<(), Box> { + setup_log(); + test_gadget::simulate_test( + 5, + 10, + std::time::Duration::from_millis(1000), + vec![0], + async_proto_generator, + ) + .await + } + + fn async_proto_generator( + mut params: TestAsyncProtocolParameters, + ) -> Pin>> { + Box::pin(async move { + let pp = PackedSharingParams::::new(2); + let dom = Radix2EvaluationDomain::::new(1024).unwrap(); + d_fft_test::(&pp, &dom, ¶ms.test_bundle.network).await; + }) + } + + struct TestNetwork { + net: InMemoryNetwork, + n_peers: usize, + party_id: u32, + } + + #[async_trait] + impl MpcNet for TestNetwork { + fn n_parties(&self) -> usize { + self.n_peers + } + + fn party_id(&self) -> u32 { + self.party_id + } + + fn is_init(&self) -> bool { + true + } + + async fn client_send_or_king_receive( + &self, + bytes: &[u8], + sid: MultiplexedStreamID, + ) -> Pin>, MpcNetError>> + Send>> + { + if self.is_king() { + } else { + self.net.send_to() + } + } + + async fn client_receive_or_king_send( + &self, + bytes: Option>, + sid: MultiplexedStreamID, + ) -> Pin> + Send>> { + todo!() + } + } +}