Skip to content

Commit

Permalink
TestWorld for sharded environments (#982)
Browse files Browse the repository at this point in the history
* TestWorld for sharded environments

This change introduces the ability to run very simple circuits on multiple
shards in parallel. It is not possible to communicate between shards yet,
but it is possible to use the same test infrastructure to create multiple shards
and use PRSS inside them as well as provide the input for each shard and consume
their output.

The next and hopefully final change will bring the ability to communicate across
shards.

* Address feedback

* Replace generic with AT in `Transport`

* Get rid of `IdentityHandlerExt`

replace it with `ListenerSetup` trait that is hopefully less confusing to use

* Document `RequestHandler` trait

* s/W/S

* Rename Sharded to WithShards

* Final touches

* ok().unwrap() instead of unwrap_or_else()
  • Loading branch information
akoshelev authored Mar 20, 2024
1 parent 5bc5a2a commit 4949f06
Show file tree
Hide file tree
Showing 20 changed files with 1,061 additions and 359 deletions.
7 changes: 3 additions & 4 deletions ipa-core/src/helpers/gateway/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,7 @@ use crate::{
gateway::{
receive::GatewayReceivers, send::GatewaySenders, transport::RoleResolvingTransport,
},
HelperChannelId, HelperIdentity, Message, Role, RoleAssignment, RouteId, TotalRecords,
Transport,
HelperChannelId, Message, Role, RoleAssignment, RouteId, TotalRecords, Transport,
},
protocol::QueryId,
};
Expand All @@ -30,12 +29,12 @@ use crate::{
/// To avoid proliferation of type parameters, most code references this concrete type alias, rather
/// than a type parameter `T: Transport`.
#[cfg(feature = "in-memory-infra")]
pub type TransportImpl = super::transport::InMemoryTransport<HelperIdentity>;
pub type TransportImpl = super::transport::InMemoryTransport<crate::helpers::HelperIdentity>;

#[cfg(feature = "real-world-infra")]
pub type TransportImpl = crate::sync::Arc<crate::net::HttpTransport>;

pub type TransportError = <TransportImpl as Transport<HelperIdentity>>::Error;
pub type TransportError = <TransportImpl as Transport>::Error;

/// Gateway into IPA Network infrastructure. It allows helpers send and receive messages.
pub struct Gateway {
Expand Down
4 changes: 2 additions & 2 deletions ipa-core/src/helpers/gateway/receive.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ pub(super) struct GatewayReceivers {
}

pub(super) type UR = UnorderedReceiver<
<RoleResolvingTransport as Transport<Role>>::RecordsStream,
<<RoleResolvingTransport as Transport<Role>>::RecordsStream as Stream>::Item,
<RoleResolvingTransport as Transport>::RecordsStream,
<<RoleResolvingTransport as Transport>::RecordsStream as Stream>::Item,
>;

impl<M: Message> ReceivingEnd<M> {
Expand Down
30 changes: 7 additions & 23 deletions ipa-core/src/helpers/gateway/transport.rs
Original file line number Diff line number Diff line change
@@ -1,26 +1,17 @@
use std::{
pin::Pin,
task::{Context, Poll},
};

use async_trait::async_trait;
use futures::Stream;

use crate::{
helpers::{
HelperIdentity, NoResourceIdentifier, QueryIdBinding, Role, RoleAssignment, RouteId,
RouteParams, StepBinding, Transport, TransportImpl,
NoResourceIdentifier, QueryIdBinding, Role, RoleAssignment, RouteId, RouteParams,
StepBinding, Transport, TransportImpl,
},
protocol::{step::Gate, QueryId},
};

#[derive(Debug, thiserror::Error)]
#[error("Failed to send to {0:?}: {1:?}")]
pub struct SendToRoleError(Role, <TransportImpl as Transport<HelperIdentity>>::Error);

/// This struct exists to hide the generic type used to index streams internally.
#[pin_project::pin_project]
pub struct RoleRecordsStream(#[pin] <TransportImpl as Transport<HelperIdentity>>::RecordsStream);
pub struct SendToRoleError(Role, <TransportImpl as Transport>::Error);

/// Transport adapter that resolves [`Role`] -> [`HelperIdentity`] mapping. As gateways created
/// per query, it is not ambiguous.
Expand All @@ -32,17 +23,10 @@ pub struct RoleResolvingTransport {
pub(super) inner: TransportImpl,
}

impl Stream for RoleRecordsStream {
type Item = Vec<u8>;

fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
self.project().0.poll_next(cx)
}
}

#[async_trait]
impl Transport<Role> for RoleResolvingTransport {
type RecordsStream = RoleRecordsStream;
impl Transport for RoleResolvingTransport {
type Identity = Role;
type RecordsStream = <TransportImpl as Transport>::RecordsStream;
type Error = SendToRoleError;

fn identity(&self) -> Role {
Expand Down Expand Up @@ -89,6 +73,6 @@ impl Transport<Role> for RoleResolvingTransport {
"can't receive message from itself"
);

RoleRecordsStream(self.inner.receive(origin_helper, route))
self.inner.receive(origin_helper, route)
}
}
4 changes: 2 additions & 2 deletions ipa-core/src/helpers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ pub use transport::{
WrappedBoxBodyStream,
};
#[cfg(feature = "in-memory-infra")]
pub use transport::{InMemoryNetwork, InMemoryTransport};
pub use transport::{InMemoryMpcNetwork, InMemoryShardNetwork, InMemoryTransport};
use typenum::{Unsigned, U8};
use x25519_dalek::PublicKey;

Expand Down Expand Up @@ -352,7 +352,7 @@ impl<T> IndexMut<Role> for Vec<T> {

impl RoleAssignment {
#[must_use]
pub fn new(helper_roles: [HelperIdentity; 3]) -> Self {
pub const fn new(helper_roles: [HelperIdentity; 3]) -> Self {
Self { helper_roles }
}

Expand Down
92 changes: 92 additions & 0 deletions ipa-core/src/helpers/transport/in_memory/handlers.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
use std::{collections::HashSet, future::Future};

use crate::{
helpers::{
query::{PrepareQuery, QueryConfig},
transport::in_memory::{routing::Addr, transport::Error, InMemoryTransport},
HelperIdentity, RouteId, Transport, TransportCallbacks, TransportIdentity,
},
protocol::QueryId,
sharding::ShardIndex,
};

/// Trait for in-memory request handlers. MPC handlers need to be able to process query requests,
/// while shard traffic does not need to and therefore does not make use of it.
///
/// See [`HelperRequestHandler`].
pub trait RequestHandler<I: TransportIdentity>: Send {
fn handle(
&mut self,
transport: InMemoryTransport<I>,
addr: Addr<I>,
) -> impl Future<Output = Result<(), Error<I>>> + Send;
}

impl RequestHandler<ShardIndex> for () {
async fn handle(
&mut self,
_transport: InMemoryTransport<ShardIndex>,
addr: Addr<ShardIndex>,
) -> Result<(), Error<ShardIndex>> {
panic!(
"Shards can only process {:?} requests, got {:?}",
RouteId::Records,
addr.route
)
}
}

/// Handler that keeps track of running queries and
/// routes [`RouteId::PrepareQuery`] and [`RouteId::ReceiveQuery`] requests to the stored
/// callback instance. This handler works for MPC networks, for sharding network see
/// [`RequestHandler<ShardIndex>`]
pub struct HelperRequestHandler {
active_queries: HashSet<QueryId>,
callbacks: TransportCallbacks<InMemoryTransport<HelperIdentity>>,
}

impl From<TransportCallbacks<InMemoryTransport<HelperIdentity>>> for HelperRequestHandler {
fn from(callbacks: TransportCallbacks<InMemoryTransport<HelperIdentity>>) -> Self {
Self {
active_queries: HashSet::default(),
callbacks,
}
}
}

impl RequestHandler<HelperIdentity> for HelperRequestHandler {
async fn handle(
&mut self,
transport: InMemoryTransport<HelperIdentity>,
addr: Addr<HelperIdentity>,
) -> Result<(), Error<HelperIdentity>> {
let dest = transport.identity();
match addr.route {
RouteId::ReceiveQuery => {
let qc = addr.into::<QueryConfig>();
(self.callbacks.receive_query)(Transport::clone_ref(&transport), qc)
.await
.map(|query_id| {
assert!(
self.active_queries.insert(query_id),
"the same query id {query_id:?} is generated twice"
);
})
.map_err(|e| Error::Rejected {
dest,
inner: Box::new(e),
})
}
RouteId::PrepareQuery => {
let input = addr.into::<PrepareQuery>();
(self.callbacks.prepare_query)(Transport::clone_ref(&transport), input)
.await
.map_err(|e| Error::Rejected {
dest,
inner: Box::new(e),
})
}
RouteId::Records => unreachable!(),
}
}
}
60 changes: 25 additions & 35 deletions ipa-core/src/helpers/transport/in_memory/mod.rs
Original file line number Diff line number Diff line change
@@ -1,21 +1,25 @@
mod handlers;
mod routing;
mod sharding;
mod transport;

pub use sharding::InMemoryShardNetwork;
pub use transport::Setup;

use crate::{
helpers::{HelperIdentity, TransportCallbacks, TransportIdentity},
helpers::{transport::in_memory::transport::ListenerSetup, HelperIdentity, TransportCallbacks},
sync::{Arc, Weak},
};

pub type InMemoryTransport<I> = Weak<transport::InMemoryTransport<I>>;

/// Container for all active transports
/// Container for all active MPC communication channels
#[derive(Clone)]
pub struct InMemoryNetwork<I> {
pub transports: [Arc<transport::InMemoryTransport<I>>; 3],
pub struct InMemoryMpcNetwork {
pub transports: [Arc<transport::InMemoryTransport<HelperIdentity>>; 3],
}

impl Default for InMemoryNetwork<HelperIdentity> {
impl Default for InMemoryMpcNetwork {
fn default() -> Self {
Self::new([
TransportCallbacks::default(),
Expand All @@ -25,25 +29,29 @@ impl Default for InMemoryNetwork<HelperIdentity> {
}
}

#[allow(dead_code)]
impl<I: TransportIdentity> InMemoryNetwork<I> {
impl InMemoryMpcNetwork {
#[must_use]
#[allow(clippy::missing_panics_doc)]
pub fn identities(&self) -> [I; 3] {
self.transports
.iter()
.map(|t| t.identity())
.collect::<Vec<_>>()
.try_into()
.unwrap()
pub fn new(callbacks: [TransportCallbacks<InMemoryTransport<HelperIdentity>>; 3]) -> Self {
let [mut first, mut second, mut third]: [_; 3] =
HelperIdentity::make_three().map(Setup::new);

first.connect(&mut second);
second.connect(&mut third);
third.connect(&mut first);

let [cb1, cb2, cb3] = callbacks;

Self {
transports: [first.start(cb1), second.start(cb2), third.start(cb3)],
}
}

/// Returns the transport to communicate with the given helper.
///
/// ## Panics
/// If [`HelperIdentity`] is somehow points to a non-existent helper, which shouldn't happen.
#[must_use]
pub fn transport(&self, id: I) -> InMemoryTransport<I> {
pub fn transport(&self, id: HelperIdentity) -> InMemoryTransport<HelperIdentity> {
self.transports
.iter()
.find(|t| t.identity() == id)
Expand All @@ -52,7 +60,7 @@ impl<I: TransportIdentity> InMemoryNetwork<I> {

#[allow(clippy::missing_panics_doc)]
#[must_use]
pub fn transports(&self) -> [InMemoryTransport<I>; 3] {
pub fn transports(&self) -> [InMemoryTransport<HelperIdentity>; 3] {
let transports: [InMemoryTransport<_>; 3] = self
.transports
.iter()
Expand All @@ -71,21 +79,3 @@ impl<I: TransportIdentity> InMemoryNetwork<I> {
}
}
}

impl InMemoryNetwork<HelperIdentity> {
#[must_use]
pub fn new(callbacks: [TransportCallbacks<InMemoryTransport<HelperIdentity>>; 3]) -> Self {
let [mut first, mut second, mut third]: [_; 3] =
HelperIdentity::make_three().map(Setup::new);

first.connect(&mut second);
second.connect(&mut third);
third.connect(&mut first);

let [cb1, cb2, cb3] = callbacks;

Self {
transports: [first.start(cb1), second.start(cb2), third.start(cb3)],
}
}
}
53 changes: 53 additions & 0 deletions ipa-core/src/helpers/transport/in_memory/routing.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
use std::{borrow::Borrow, fmt::Debug};

use serde::de::DeserializeOwned;

use crate::{
helpers::{QueryIdBinding, RouteId, RouteParams, StepBinding, TransportIdentity},
protocol::{step::Gate, QueryId},
};

/// The header/metadata of the incoming request.
#[derive(Debug)]
pub(super) struct Addr<I> {
pub route: RouteId,
pub origin: Option<I>,
pub query_id: Option<QueryId>,
pub gate: Option<Gate>,
pub params: String,
}

impl<I: TransportIdentity> Addr<I> {
#[allow(clippy::needless_pass_by_value)] // to avoid using double-reference at callsites
pub fn from_route<Q: QueryIdBinding, S: StepBinding, R: RouteParams<RouteId, Q, S>>(
origin: I,
route: R,
) -> Self
where
Option<QueryId>: From<Q>,
Option<Gate>: From<S>,
{
Self {
route: route.resource_identifier(),
origin: Some(origin),
query_id: route.query_id().into(),
gate: route.gate().into(),
params: route.extra().borrow().to_string(),
}
}

pub fn into<T: DeserializeOwned>(self) -> T {
serde_json::from_str(&self.params).unwrap()
}

#[cfg(all(test, unit_test))]
pub fn records(from: I, query_id: QueryId, gate: Gate) -> Self {
Self {
route: RouteId::Records,
origin: Some(from),
query_id: Some(query_id),
gate: Some(gate),
params: String::new(),
}
}
}
Loading

0 comments on commit 4949f06

Please sign in to comment.