-
Notifications
You must be signed in to change notification settings - Fork 25
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
TestWorld for sharded environments (#982)
* TestWorld for sharded environments This change introduces the ability to run very simple circuits on multiple shards in parallel. It is not possible to communicate between shards yet, but it is possible to use the same test infrastructure to create multiple shards and use PRSS inside them as well as provide the input for each shard and consume their output. The next and hopefully final change will bring the ability to communicate across shards. * Address feedback * Replace generic with AT in `Transport` * Get rid of `IdentityHandlerExt` replace it with `ListenerSetup` trait that is hopefully less confusing to use * Document `RequestHandler` trait * s/W/S * Rename Sharded to WithShards * Final touches * ok().unwrap() instead of unwrap_or_else()
- Loading branch information
Showing
20 changed files
with
1,061 additions
and
359 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
use std::{collections::HashSet, future::Future}; | ||
|
||
use crate::{ | ||
helpers::{ | ||
query::{PrepareQuery, QueryConfig}, | ||
transport::in_memory::{routing::Addr, transport::Error, InMemoryTransport}, | ||
HelperIdentity, RouteId, Transport, TransportCallbacks, TransportIdentity, | ||
}, | ||
protocol::QueryId, | ||
sharding::ShardIndex, | ||
}; | ||
|
||
/// Trait for in-memory request handlers. MPC handlers need to be able to process query requests, | ||
/// while shard traffic does not need to and therefore does not make use of it. | ||
/// | ||
/// See [`HelperRequestHandler`]. | ||
pub trait RequestHandler<I: TransportIdentity>: Send { | ||
fn handle( | ||
&mut self, | ||
transport: InMemoryTransport<I>, | ||
addr: Addr<I>, | ||
) -> impl Future<Output = Result<(), Error<I>>> + Send; | ||
} | ||
|
||
impl RequestHandler<ShardIndex> for () { | ||
async fn handle( | ||
&mut self, | ||
_transport: InMemoryTransport<ShardIndex>, | ||
addr: Addr<ShardIndex>, | ||
) -> Result<(), Error<ShardIndex>> { | ||
panic!( | ||
"Shards can only process {:?} requests, got {:?}", | ||
RouteId::Records, | ||
addr.route | ||
) | ||
} | ||
} | ||
|
||
/// Handler that keeps track of running queries and | ||
/// routes [`RouteId::PrepareQuery`] and [`RouteId::ReceiveQuery`] requests to the stored | ||
/// callback instance. This handler works for MPC networks, for sharding network see | ||
/// [`RequestHandler<ShardIndex>`] | ||
pub struct HelperRequestHandler { | ||
active_queries: HashSet<QueryId>, | ||
callbacks: TransportCallbacks<InMemoryTransport<HelperIdentity>>, | ||
} | ||
|
||
impl From<TransportCallbacks<InMemoryTransport<HelperIdentity>>> for HelperRequestHandler { | ||
fn from(callbacks: TransportCallbacks<InMemoryTransport<HelperIdentity>>) -> Self { | ||
Self { | ||
active_queries: HashSet::default(), | ||
callbacks, | ||
} | ||
} | ||
} | ||
|
||
impl RequestHandler<HelperIdentity> for HelperRequestHandler { | ||
async fn handle( | ||
&mut self, | ||
transport: InMemoryTransport<HelperIdentity>, | ||
addr: Addr<HelperIdentity>, | ||
) -> Result<(), Error<HelperIdentity>> { | ||
let dest = transport.identity(); | ||
match addr.route { | ||
RouteId::ReceiveQuery => { | ||
let qc = addr.into::<QueryConfig>(); | ||
(self.callbacks.receive_query)(Transport::clone_ref(&transport), qc) | ||
.await | ||
.map(|query_id| { | ||
assert!( | ||
self.active_queries.insert(query_id), | ||
"the same query id {query_id:?} is generated twice" | ||
); | ||
}) | ||
.map_err(|e| Error::Rejected { | ||
dest, | ||
inner: Box::new(e), | ||
}) | ||
} | ||
RouteId::PrepareQuery => { | ||
let input = addr.into::<PrepareQuery>(); | ||
(self.callbacks.prepare_query)(Transport::clone_ref(&transport), input) | ||
.await | ||
.map_err(|e| Error::Rejected { | ||
dest, | ||
inner: Box::new(e), | ||
}) | ||
} | ||
RouteId::Records => unreachable!(), | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
use std::{borrow::Borrow, fmt::Debug}; | ||
|
||
use serde::de::DeserializeOwned; | ||
|
||
use crate::{ | ||
helpers::{QueryIdBinding, RouteId, RouteParams, StepBinding, TransportIdentity}, | ||
protocol::{step::Gate, QueryId}, | ||
}; | ||
|
||
/// The header/metadata of the incoming request. | ||
#[derive(Debug)] | ||
pub(super) struct Addr<I> { | ||
pub route: RouteId, | ||
pub origin: Option<I>, | ||
pub query_id: Option<QueryId>, | ||
pub gate: Option<Gate>, | ||
pub params: String, | ||
} | ||
|
||
impl<I: TransportIdentity> Addr<I> { | ||
#[allow(clippy::needless_pass_by_value)] // to avoid using double-reference at callsites | ||
pub fn from_route<Q: QueryIdBinding, S: StepBinding, R: RouteParams<RouteId, Q, S>>( | ||
origin: I, | ||
route: R, | ||
) -> Self | ||
where | ||
Option<QueryId>: From<Q>, | ||
Option<Gate>: From<S>, | ||
{ | ||
Self { | ||
route: route.resource_identifier(), | ||
origin: Some(origin), | ||
query_id: route.query_id().into(), | ||
gate: route.gate().into(), | ||
params: route.extra().borrow().to_string(), | ||
} | ||
} | ||
|
||
pub fn into<T: DeserializeOwned>(self) -> T { | ||
serde_json::from_str(&self.params).unwrap() | ||
} | ||
|
||
#[cfg(all(test, unit_test))] | ||
pub fn records(from: I, query_id: QueryId, gate: Gate) -> Self { | ||
Self { | ||
route: RouteId::Records, | ||
origin: Some(from), | ||
query_id: Some(query_id), | ||
gate: Some(gate), | ||
params: String::new(), | ||
} | ||
} | ||
} |
Oops, something went wrong.