From 2e746fe2c16be6284cd52b6bf551e1d88b9c97f9 Mon Sep 17 00:00:00 2001 From: Alex Koshelev Date: Wed, 20 Mar 2024 22:22:11 -0700 Subject: [PATCH 1/5] Delete `TransportCallbacks` and use `RequestHandler` trait instead See #987 for motivation. I had to decide whether I want to use dynamic dispatch vs clunky HTTP interfaces with another generic parameter propagated through the entire stack. I don't have a conslusive answer which way is better, both have significant downsides. Problems with DD approach that is proposed in this change: * Hard to keep `RequestHandler` trait object safe. No generics for `handle` method, use of `async_trait` etc. That removes the opportunity for some optimizations, namely using a trait to pass data down to the handler. It could be better if HTTP layer just passes the same structs it gets from HTTP layer without an extra conversion that must occur if dynamic dispatch is used. * Non zero-cost abstraction. To get data back from the handler, we have to use the same format, right now it is JSON but I doubt we can do better than binary serialization, which means more work to get the data out. * `Box }/routing.rs | 28 ++- .../src/helpers/transport/stream/axum_body.rs | 6 + .../src/helpers/transport/stream/box_body.rs | 12 +- ipa-core/src/net/client/mod.rs | 149 ++++++------- ipa-core/src/net/http_serde.rs | 73 ++++++- .../src/net/server/handlers/query/create.rs | 40 ++-- .../src/net/server/handlers/query/input.rs | 53 +++-- .../src/net/server/handlers/query/prepare.rs | 39 ++-- .../src/net/server/handlers/query/results.rs | 36 ++-- .../src/net/server/handlers/query/status.rs | 32 +-- ipa-core/src/net/test.rs | 17 +- ipa-core/src/net/transport.rs | 123 +++++++---- ipa-core/src/query/executor.rs | 10 +- ipa-core/src/query/processor.rs | 109 +++++----- ipa-core/src/test_fixture/app.rs | 26 ++- 28 files changed, 935 insertions(+), 675 deletions(-) delete mode 100644 ipa-core/src/helpers/transport/callbacks.rs create mode 100644 ipa-core/src/helpers/transport/handler.rs delete mode 100644 ipa-core/src/helpers/transport/in_memory/handlers.rs rename ipa-core/src/helpers/transport/{in_memory => }/routing.rs (64%) diff --git a/ipa-core/src/app.rs b/ipa-core/src/app.rs index e571690a3..7f6586866 100644 --- a/ipa-core/src/app.rs +++ b/ipa-core/src/app.rs @@ -1,19 +1,23 @@ +use std::sync::Mutex; + +use async_trait::async_trait; + use crate::{ helpers::{ - query::{QueryConfig, QueryInput}, - Transport, TransportCallbacks, TransportImpl, + query::{PrepareQuery, QueryConfig, QueryInput}, + routing::{Addr, RouteId}, + ApiError, BodyStream, HelperIdentity, HelperResponse, RequestHandler, Transport, + TransportImpl, }, hpke::{KeyPair, KeyRegistry}, protocol::QueryId, - query::{ - NewQueryError, QueryCompletionError, QueryInputError, QueryProcessor, QueryStatus, - QueryStatusError, - }, + query::{NewQueryError, QueryProcessor, QueryStatus}, sync::Arc, }; pub struct Setup { query_processor: Arc, + handler_setup: RequestHandlerSetup, } /// The API layer to interact with a helper. @@ -23,65 +27,122 @@ pub struct HelperApp { transport: TransportImpl, } +/// This handles requests to initiate and control IPA queries. +struct QueryRequestHandler { + qp: Arc, + transport: Arc>>, +} + +#[async_trait] +impl RequestHandler for QueryRequestHandler { + type Identity = HelperIdentity; + + async fn handle( + &self, + req: Addr, + data: BodyStream, + ) -> Result { + fn ext_query_id(req: &Addr) -> Result { + req.query_id.ok_or_else(|| { + ApiError::BadRequest("Query input is missing query_id argument".into()) + }) + } + + let qp = Arc::clone(&self.qp); + + Ok(match req.route { + r @ RouteId::Records => { + return Err(ApiError::BadRequest( + format!("{r:?} request must not be handled by query processing flow").into(), + )) + } + RouteId::ReceiveQuery => { + let req = req.into::()?; + HelperResponse::from(qp.new_query(self.transport(), req).await?) + } + RouteId::PrepareQuery => { + let req = req.into::()?; + HelperResponse::from(qp.prepare(&self.transport(), req)?) + } + RouteId::QueryInput => { + let query_id = ext_query_id(&req)?; + HelperResponse::from(qp.receive_inputs( + self.transport(), + QueryInput { + query_id, + input_stream: data, + }, + )?) + } + RouteId::QueryStatus => { + let query_id = ext_query_id(&req)?; + HelperResponse::from(qp.query_status(query_id)?) + } + RouteId::CompleteQuery => { + let query_id = ext_query_id(&req)?; + HelperResponse::from(qp.complete(query_id).await?) + } + }) + } +} + +impl QueryRequestHandler { + fn transport(&self) -> TransportImpl { + Clone::clone(self.transport.lock().unwrap().as_ref().unwrap()) + } +} + +#[derive(Clone)] +pub struct RequestHandlerSetup { + qp: Arc, + transport_container: Arc>>, +} + +impl RequestHandlerSetup { + fn new(qp: Arc) -> Self { + Self { + qp, + transport_container: Arc::new(Mutex::new(None)), + } + } + + pub fn make_handler(&self) -> Box> { + Box::new(QueryRequestHandler { + qp: Arc::clone(&self.qp), + transport: Arc::clone(&self.transport_container), + }) as Box> + } + + fn finish(self, transport: TransportImpl) { + let mut guard = self.transport_container.lock().unwrap(); + *guard = Some(transport); + } +} + impl Setup { #[must_use] - pub fn new() -> (Self, TransportCallbacks) { + pub fn new() -> (Self, RequestHandlerSetup) { Self::with_key_registry(KeyRegistry::empty()) } #[must_use] - pub fn with_key_registry( - key_registry: KeyRegistry, - ) -> (Self, TransportCallbacks) { + pub fn with_key_registry(key_registry: KeyRegistry) -> (Self, RequestHandlerSetup) { let query_processor = Arc::new(QueryProcessor::new(key_registry)); + let handler_setup = RequestHandlerSetup::new(Arc::clone(&query_processor)); let this = Self { - query_processor: Arc::clone(&query_processor), + query_processor, + handler_setup: handler_setup.clone(), }; // TODO: weak reference to query processor to prevent mem leak - (this, Self::callbacks(&query_processor)) + (this, handler_setup) } /// Instantiate [`HelperApp`] by connecting it to the provided transport implementation pub fn connect(self, transport: TransportImpl) -> HelperApp { + self.handler_setup.finish(Clone::clone(&transport)); HelperApp::new(transport, self.query_processor) } - - /// Create callbacks that tie up query processor and transport. - fn callbacks(query_processor: &Arc) -> TransportCallbacks { - let rqp = Arc::clone(query_processor); - let pqp = Arc::clone(query_processor); - let iqp = Arc::clone(query_processor); - let sqp = Arc::clone(query_processor); - let cqp = Arc::clone(query_processor); - - TransportCallbacks { - receive_query: Box::new(move |transport: TransportImpl, receive_query| { - let processor = Arc::clone(&rqp); - Box::pin(async move { - let r = processor.new_query(transport, receive_query).await?; - - Ok(r.query_id) - }) - }), - prepare_query: Box::new(move |transport: TransportImpl, prepare_query| { - let processor = Arc::clone(&pqp); - Box::pin(async move { processor.prepare(&transport, prepare_query) }) - }), - query_input: Box::new(move |transport: TransportImpl, query_input| { - let processor = Arc::clone(&iqp); - Box::pin(async move { processor.receive_inputs(transport, query_input) }) - }), - query_status: Box::new(move |_transport: TransportImpl, query_id| { - let processor = Arc::clone(&sqp); - Box::pin(async move { processor.query_status(query_id) }) - }), - complete_query: Box::new(move |_transport: TransportImpl, query_id| { - let processor = Arc::clone(&cqp); - Box::pin(async move { processor.complete(query_id).await }) - }), - } - } } impl HelperApp { @@ -109,7 +170,7 @@ impl HelperApp { /// /// ## Errors /// Propagates errors from the helper. - pub fn execute_query(&self, input: QueryInput) -> Result<(), Error> { + pub fn execute_query(&self, input: QueryInput) -> Result<(), ApiError> { let transport = ::clone(&self.transport); self.query_processor.receive_inputs(transport, input)?; Ok(()) @@ -119,7 +180,7 @@ impl HelperApp { /// /// ## Errors /// Propagates errors from the helper. - pub fn query_status(&self, query_id: QueryId) -> Result { + pub fn query_status(&self, query_id: QueryId) -> Result { Ok(self.query_processor.query_status(query_id)?) } @@ -127,20 +188,7 @@ impl HelperApp { /// /// ## Errors /// Propagates errors from the helper. - pub async fn complete_query(&self, query_id: QueryId) -> Result, Error> { - Ok(self.query_processor.complete(query_id).await?.into_bytes()) + pub async fn complete_query(&self, query_id: QueryId) -> Result, ApiError> { + Ok(self.query_processor.complete(query_id).await?.as_bytes()) } } - -/// Union of error types returned by API operations. -#[derive(thiserror::Error, Debug)] -pub enum Error { - #[error(transparent)] - NewQuery(#[from] NewQueryError), - #[error(transparent)] - QueryInput(#[from] QueryInputError), - #[error(transparent)] - QueryCompletion(#[from] QueryCompletionError), - #[error(transparent)] - QueryStatus(#[from] QueryStatusError), -} diff --git a/ipa-core/src/bin/helper.rs b/ipa-core/src/bin/helper.rs index d0ef0b14d..32f49ae32 100644 --- a/ipa-core/src/bin/helper.rs +++ b/ipa-core/src/bin/helper.rs @@ -131,7 +131,7 @@ async fn server(args: ServerArgs) -> Result<(), BoxError> { }); let key_registry = hpke_registry(mk_encryption.as_ref()).await?; - let (setup, callbacks) = AppSetup::with_key_registry(key_registry); + let (setup, handler_setup) = AppSetup::with_key_registry(key_registry); let server_config = ServerConfig { port: args.port, @@ -155,7 +155,7 @@ async fn server(args: ServerArgs) -> Result<(), BoxError> { server_config, network_config, clients, - callbacks, + handler_setup.make_handler(), ); let _app = setup.connect(transport.clone()); diff --git a/ipa-core/src/helpers/gateway/mod.rs b/ipa-core/src/helpers/gateway/mod.rs index 6c75de9a6..151a920c6 100644 --- a/ipa-core/src/helpers/gateway/mod.rs +++ b/ipa-core/src/helpers/gateway/mod.rs @@ -19,7 +19,8 @@ use crate::{ gateway::{ receive::GatewayReceivers, send::GatewaySenders, transport::RoleResolvingTransport, }, - HelperChannelId, Message, Role, RoleAssignment, RouteId, TotalRecords, Transport, + transport::routing::RouteId, + HelperChannelId, Message, Role, RoleAssignment, TotalRecords, Transport, }, protocol::QueryId, }; diff --git a/ipa-core/src/helpers/gateway/transport.rs b/ipa-core/src/helpers/gateway/transport.rs index 558e44e40..43840ce4a 100644 --- a/ipa-core/src/helpers/gateway/transport.rs +++ b/ipa-core/src/helpers/gateway/transport.rs @@ -3,8 +3,8 @@ use futures::Stream; use crate::{ helpers::{ - NoResourceIdentifier, QueryIdBinding, Role, RoleAssignment, RouteId, RouteParams, - StepBinding, Transport, TransportImpl, + transport::routing::RouteId, NoResourceIdentifier, QueryIdBinding, Role, RoleAssignment, + RouteParams, StepBinding, Transport, TransportImpl, }, protocol::{step::Gate, QueryId}, }; diff --git a/ipa-core/src/helpers/mod.rs b/ipa-core/src/helpers/mod.rs index 1b4e287f8..a10427cfb 100644 --- a/ipa-core/src/helpers/mod.rs +++ b/ipa-core/src/helpers/mod.rs @@ -52,9 +52,10 @@ pub use prss_protocol::negotiate as negotiate_prss; #[cfg(feature = "web-app")] pub use transport::WrappedAxumBodyStream; pub use transport::{ - callbacks::*, query, BodyStream, BytesStream, Identity as TransportIdentity, - LengthDelimitedStream, LogErrors, NoResourceIdentifier, QueryIdBinding, ReceiveRecords, - RecordsStream, RouteId, RouteParams, StepBinding, StreamCollection, StreamKey, Transport, + make_boxed_handler, query, routing, ApiError, BodyStream, BytesStream, HelperResponse, + Identity as TransportIdentity, LengthDelimitedStream, LogErrors, NoQueryId, + NoResourceIdentifier, NoStep, PanickingHandler, QueryIdBinding, ReceiveRecords, RecordsStream, + RequestHandler, RouteParams, StepBinding, StreamCollection, StreamKey, Transport, WrappedBoxBodyStream, }; #[cfg(feature = "in-memory-infra")] diff --git a/ipa-core/src/helpers/transport/callbacks.rs b/ipa-core/src/helpers/transport/callbacks.rs deleted file mode 100644 index ea32ee005..000000000 --- a/ipa-core/src/helpers/transport/callbacks.rs +++ /dev/null @@ -1,101 +0,0 @@ -use std::{future::Future, pin::Pin}; - -use crate::{ - helpers::query::{PrepareQuery, QueryConfig, QueryInput}, - protocol::QueryId, - query::{ - NewQueryError, PrepareQueryError, ProtocolResult, QueryCompletionError, QueryInputError, - QueryStatus, QueryStatusError, - }, -}; - -/// Macro for defining transport callbacks. -/// -/// Each input definition specifies a callback name, a result type name, and -/// a function signature for the callback. The expansion looks like this: -/// -/// ```ignore -/// pub type ReceiveQueryResult = Pin> + Send>>; -/// -/// /// Called when helper receives a new query request from an external party. -/// pub trait ReceiveQueryCallback: -/// Fn(T, QueryConfig) -> ReceiveQueryResult + Send + Sync {} -/// -/// impl ReceiveQueryCallback for F where -/// F: Fn(T, QueryConfig) -> ReceiveQueryResult + Send + Sync {} -/// ``` -macro_rules! callbacks { - { - $( - $(#[$($attr:meta),+ ])? - ($cb_name:ident, $res_name:ident): async fn($($args:ident),*) -> $result:ty; - )* - } => { - $( - pub type $res_name = Pin + Send>>; - - $(#[$($attr),+ ])? - pub trait $cb_name: Fn($($args),*) -> $res_name + Send + Sync {} - - impl $cb_name for F where - F: Fn($($args),*) -> $res_name + Send + Sync {} - )* - } -} - -callbacks! { - /// Called by clients to initiate a new query. - (ReceiveQueryCallback, ReceiveQueryResult): - async fn(T, QueryConfig) -> Result; - - /// Called by the leader helper to set up followers for a new query. - (PrepareQueryCallback, PrepareQueryResult): - async fn(T, PrepareQuery) -> Result<(), PrepareQueryError>; - - /// Called by clients to deliver query input data. - (QueryInputCallback, QueryInputResult): - async fn(T, QueryInput) -> Result<(), QueryInputError>; - - /// Called by clients to retrieve query status. - (QueryStatusCallback, QueryStatusResult): - async fn(T, QueryId) -> Result; - - /// Called by clients to drive query to completion and retrieve results. - (CompleteQueryCallback, CompleteQueryResult): - async fn(T, QueryId) -> Result, QueryCompletionError>; -} - -pub struct TransportCallbacks { - pub receive_query: Box>, - pub prepare_query: Box>, - pub query_input: Box>, - pub query_status: Box>, - pub complete_query: Box>, -} - -#[cfg(any(test, feature = "in-memory-infra"))] -impl Default for TransportCallbacks { - fn default() -> Self { - // `TransportCallbacks::default()` is commonly used with struct update syntax - // (`..Default::default()`) to fill out the callbacks that aren't relevant to a particular - // test. In that scenario, a call that does occur is "unexpected" in the sense the term - // is used by mocks. - Self { - receive_query: Box::new(move |_, _| { - Box::pin(async { panic!("unexpected call to receive_query") }) - }), - prepare_query: Box::new(move |_, _| { - Box::pin(async { panic!("unexpected call to prepare_query") }) - }), - query_input: Box::new(move |_, _| { - Box::pin(async { panic!("unexpected call to query_input") }) - }), - query_status: Box::new(move |_, _| { - Box::pin(async { panic!("unexpected call to query_status") }) - }), - complete_query: Box::new(move |_, _| { - Box::pin(async { panic!("unexpected call to complete_query") }) - }), - } - } -} diff --git a/ipa-core/src/helpers/transport/handler.rs b/ipa-core/src/helpers/transport/handler.rs new file mode 100644 index 000000000..9823173fb --- /dev/null +++ b/ipa-core/src/helpers/transport/handler.rs @@ -0,0 +1,201 @@ +use std::{ + fmt::{Debug, Formatter}, + future::Future, + marker::PhantomData, +}; + +use async_trait::async_trait; +use serde::de::DeserializeOwned; +use serde_json::json; + +use crate::{ + error::BoxError, + helpers::{ + query::PrepareQuery, transport::routing::Addr, BodyStream, HelperIdentity, + TransportIdentity, + }, + query::{ + NewQueryError, PrepareQueryError, ProtocolResult, QueryCompletionError, QueryInputError, + QueryStatus, QueryStatusError, + }, +}; + +/// Represents some response sent from MPC helper acting on a given request. It is rudimental now +/// because we sent everything as HTTP body, but it could evolve. +/// +/// ## Performance +/// This implementation is far from being optimal. Between HTTP and transport layer, there exists +/// one round of serialization and deserialization to properly represent the types. It is not critical +/// to address, because MPC helpers have to handle a constant number of requests per query. Note +/// that all requests tagged with [`crate::helpers::transport::RouteId::Records`] are not routed +/// through [`RequestHandler`], so there is no penalty. +/// +pub struct HelperResponse { + body: Vec, +} + +impl Debug for HelperResponse { + fn fmt(&self, _f: &mut Formatter<'_>) -> std::fmt::Result { + todo!() + } +} + +impl HelperResponse { + /// Returns an empty response that indicates that incoming request has been processed successfully + #[must_use] + pub fn ok() -> Self { + Self { body: Vec::new() } + } + + /// Consumes [`Self`] and returns the body of the response. + #[must_use] + pub fn into_body(self) -> Vec { + self.body + } + + /// Attempts to interpret [`Self`] body as JSON-serialized `T`. + /// ## Errors + /// if `T` cannot be deserialized from response body. + pub fn try_into_owned(self) -> Result { + serde_json::from_slice(&self.body) + } +} + +impl From for HelperResponse { + fn from(value: PrepareQuery) -> Self { + let v = serde_json::to_vec(&json!({"query_id": value.query_id})).unwrap(); + Self { body: v } + } +} + +impl From<()> for HelperResponse { + fn from(_value: ()) -> Self { + Self::ok() + } +} + +impl From for HelperResponse { + fn from(value: QueryStatus) -> Self { + let v = serde_json::to_vec(&json!({"status": value})).unwrap(); + Self { body: v } + } +} + +impl> From for HelperResponse { + fn from(value: R) -> Self { + let v = value.as_ref().as_bytes(); + Self { body: v } + } +} + +/// Union of error types returned by API operations. +#[derive(thiserror::Error, Debug)] +pub enum Error { + #[error(transparent)] + NewQuery(#[from] NewQueryError), + #[error(transparent)] + QueryInput(#[from] QueryInputError), + #[error(transparent)] + QueryPrepare(#[from] PrepareQueryError), + #[error(transparent)] + QueryCompletion(#[from] QueryCompletionError), + #[error(transparent)] + QueryStatus(#[from] QueryStatusError), + #[error(transparent)] + DeserializationFailure(#[from] serde_json::Error), + #[error("MalformedRequest: {0}")] + BadRequest(BoxError), +} + +/// Trait for custom-handling different request types made against MPC helper parties. +/// There is a limitation for RPITIT that traits can't be made object-safe, hence the use of async_trait +#[async_trait] +pub trait RequestHandler: Send + Sync { + type Identity: TransportIdentity; + /// Handle the incoming request with metadata/headers specified in [`Addr`] and body encoded as + /// [`BodyStream`]. + async fn handle( + &self, + req: Addr, + data: BodyStream, + ) -> Result; +} + +#[async_trait] +impl RequestHandler for F +where + F: Fn(Addr, BodyStream) -> Result + + Send + + Sync + + 'static, +{ + type Identity = HelperIdentity; + + async fn handle( + &self, + req: Addr, + data: BodyStream, + ) -> Result { + self(req, data) + } +} + +pub fn make_boxed_handler<'a, I, F, Fut>(handler: F) -> Box + 'a> +where + I: TransportIdentity, + F: Fn(Addr, BodyStream) -> Fut + Send + Sync + 'a, + Fut: Future> + Send + 'a, +{ + struct Handler { + inner: F, + phantom: PhantomData, + } + #[async_trait] + impl RequestHandler for Handler + where + I: TransportIdentity, + F: Fn(Addr, BodyStream) -> Fut + Send + Sync, + Fut: Future> + Send, + { + type Identity = I; + + async fn handle( + &self, + req: Addr, + data: BodyStream, + ) -> Result { + (self.inner)(req, data).await + } + } + + Box::new(Handler { + inner: handler, + phantom: PhantomData, + }) +} + +// This handler panics when [`Self::handle`] method is called. +pub struct PanickingHandler { + phantom: PhantomData, +} + +impl Default for PanickingHandler { + fn default() -> Self { + Self { + phantom: PhantomData, + } + } +} + +#[async_trait] +impl RequestHandler for PanickingHandler { + type Identity = I; + + async fn handle( + &self, + req: Addr, + _data: BodyStream, + ) -> Result { + panic!("unexpected call: {req:?}"); + } +} diff --git a/ipa-core/src/helpers/transport/in_memory/handlers.rs b/ipa-core/src/helpers/transport/in_memory/handlers.rs deleted file mode 100644 index 209b9d3e4..000000000 --- a/ipa-core/src/helpers/transport/in_memory/handlers.rs +++ /dev/null @@ -1,92 +0,0 @@ -use std::{collections::HashSet, future::Future}; - -use crate::{ - helpers::{ - query::{PrepareQuery, QueryConfig}, - transport::in_memory::{routing::Addr, transport::Error, InMemoryTransport}, - HelperIdentity, RouteId, Transport, TransportCallbacks, TransportIdentity, - }, - protocol::QueryId, - sharding::ShardIndex, -}; - -/// Trait for in-memory request handlers. MPC handlers need to be able to process query requests, -/// while shard traffic does not need to and therefore does not make use of it. -/// -/// See [`HelperRequestHandler`]. -pub trait RequestHandler: Send { - fn handle( - &mut self, - transport: InMemoryTransport, - addr: Addr, - ) -> impl Future>> + Send; -} - -impl RequestHandler for () { - async fn handle( - &mut self, - _transport: InMemoryTransport, - addr: Addr, - ) -> Result<(), Error> { - panic!( - "Shards can only process {:?} requests, got {:?}", - RouteId::Records, - addr.route - ) - } -} - -/// Handler that keeps track of running queries and -/// routes [`RouteId::PrepareQuery`] and [`RouteId::ReceiveQuery`] requests to the stored -/// callback instance. This handler works for MPC networks, for sharding network see -/// [`RequestHandler`] -pub struct HelperRequestHandler { - active_queries: HashSet, - callbacks: TransportCallbacks>, -} - -impl From>> for HelperRequestHandler { - fn from(callbacks: TransportCallbacks>) -> Self { - Self { - active_queries: HashSet::default(), - callbacks, - } - } -} - -impl RequestHandler for HelperRequestHandler { - async fn handle( - &mut self, - transport: InMemoryTransport, - addr: Addr, - ) -> Result<(), Error> { - let dest = transport.identity(); - match addr.route { - RouteId::ReceiveQuery => { - let qc = addr.into::(); - (self.callbacks.receive_query)(Transport::clone_ref(&transport), qc) - .await - .map(|query_id| { - assert!( - self.active_queries.insert(query_id), - "the same query id {query_id:?} is generated twice" - ); - }) - .map_err(|e| Error::Rejected { - dest, - inner: Box::new(e), - }) - } - RouteId::PrepareQuery => { - let input = addr.into::(); - (self.callbacks.prepare_query)(Transport::clone_ref(&transport), input) - .await - .map_err(|e| Error::Rejected { - dest, - inner: Box::new(e), - }) - } - RouteId::Records => unreachable!(), - } - } -} diff --git a/ipa-core/src/helpers/transport/in_memory/mod.rs b/ipa-core/src/helpers/transport/in_memory/mod.rs index 929deca2d..4f8735cf0 100644 --- a/ipa-core/src/helpers/transport/in_memory/mod.rs +++ b/ipa-core/src/helpers/transport/in_memory/mod.rs @@ -1,13 +1,13 @@ -mod handlers; -mod routing; mod sharding; mod transport; +use std::array; + pub use sharding::InMemoryShardNetwork; pub use transport::Setup; use crate::{ - helpers::{transport::in_memory::transport::ListenerSetup, HelperIdentity, TransportCallbacks}, + helpers::{HelperIdentity, RequestHandler}, sync::{Arc, Weak}, }; @@ -21,17 +21,13 @@ pub struct InMemoryMpcNetwork { impl Default for InMemoryMpcNetwork { fn default() -> Self { - Self::new([ - TransportCallbacks::default(), - TransportCallbacks::default(), - TransportCallbacks::default(), - ]) + Self::new(array::from_fn(|_| None)) } } impl InMemoryMpcNetwork { #[must_use] - pub fn new(callbacks: [TransportCallbacks>; 3]) -> Self { + pub fn new(handlers: [Option>>; 3]) -> Self { let [mut first, mut second, mut third]: [_; 3] = HelperIdentity::make_three().map(Setup::new); @@ -39,10 +35,10 @@ impl InMemoryMpcNetwork { second.connect(&mut third); third.connect(&mut first); - let [cb1, cb2, cb3] = callbacks; + let [h1, h2, h3] = handlers; Self { - transports: [first.start(cb1), second.start(cb2), third.start(cb3)], + transports: [first.start(h1), second.start(h2), third.start(h3)], } } diff --git a/ipa-core/src/helpers/transport/in_memory/sharding.rs b/ipa-core/src/helpers/transport/in_memory/sharding.rs index 597bbd2f7..0700793cc 100644 --- a/ipa-core/src/helpers/transport/in_memory/sharding.rs +++ b/ipa-core/src/helpers/transport/in_memory/sharding.rs @@ -1,6 +1,6 @@ use crate::{ helpers::{ - transport::in_memory::transport::{InMemoryTransport, ListenerSetup, Setup}, + transport::in_memory::transport::{InMemoryTransport, Setup}, HelperIdentity, }, sharding::ShardIndex, @@ -36,7 +36,7 @@ impl InMemoryShardNetwork { shard_connections .into_iter() - .map(|s| tracing::info_span!("", ?h).in_scope(|| s.start(()))) + .map(|s| tracing::info_span!("", ?h).in_scope(|| s.start(None))) .collect::>() .into() }); @@ -74,7 +74,10 @@ mod tests { use tokio_stream::wrappers::ReceiverStream; use crate::{ - helpers::{transport::in_memory::InMemoryShardNetwork, HelperIdentity, RouteId, Transport}, + helpers::{ + transport::{in_memory::InMemoryShardNetwork, routing::RouteId}, + HelperIdentity, Transport, + }, protocol::{step::Gate, QueryId}, sharding::ShardIndex, test_executor::run, diff --git a/ipa-core/src/helpers/transport/in_memory/transport.rs b/ipa-core/src/helpers/transport/in_memory/transport.rs index 97fc45c6b..5444809c0 100644 --- a/ipa-core/src/helpers/transport/in_memory/transport.rs +++ b/ipa-core/src/helpers/transport/in_memory/transport.rs @@ -1,6 +1,5 @@ use std::{ collections::HashMap, - convert, fmt::{Debug, Formatter}, io, pin::Pin, @@ -21,22 +20,18 @@ use tracing::Instrument; use crate::{ error::BoxError, helpers::{ - transport::in_memory::{ - handlers::{HelperRequestHandler, RequestHandler}, - routing::Addr, - }, - HelperIdentity, NoResourceIdentifier, QueryIdBinding, ReceiveRecords, RouteId, RouteParams, - StepBinding, StreamCollection, Transport, TransportIdentity, + transport::routing::{Addr, RouteId}, + ApiError, BodyStream, HelperResponse, NoResourceIdentifier, QueryIdBinding, ReceiveRecords, + RequestHandler, RouteParams, StepBinding, StreamCollection, Transport, TransportIdentity, }, protocol::{step::Gate, QueryId}, - sharding::ShardIndex, sync::{Arc, Weak}, }; type Packet = ( Addr, InMemoryStream, - oneshot::Sender>>, + oneshot::Sender>, ); type ConnectionTx = Sender>; type ConnectionRx = Receiver>; @@ -55,6 +50,11 @@ pub enum Error { #[source] inner: BoxError, }, + #[error(transparent)] + DeserializationFailed { + #[from] + inner: serde_json::Error, + }, } /// In-memory implementation of [`Transport`] backed by Tokio mpsc channels. @@ -85,15 +85,14 @@ impl InMemoryTransport { /// out and processes it, the same way as query processor does. That will allow all tasks to be /// created in one place (driver). It does not affect the [`Transport`] interface, /// so I'll leave it as is for now. - fn listen>( + fn listen( self: &Arc, - mut callbacks: L::Handler, + handler: Option>>, mut rx: ConnectionRx, ) { tokio::spawn( { let streams = self.record_streams.clone(); - let this = Arc::downgrade(self); async move { while let Some((addr, stream, ack)) = rx.recv().await { tracing::trace!("received new message: {addr:?}"); @@ -104,10 +103,23 @@ impl InMemoryTransport { let gate = addr.gate.unwrap(); let from = addr.origin.unwrap(); streams.add_stream((query_id, from, gate), stream); - Ok(()) + Ok(HelperResponse::ok()) } - RouteId::ReceiveQuery | RouteId::PrepareQuery => { - callbacks.handle(Clone::clone(&this), addr).await + RouteId::ReceiveQuery + | RouteId::PrepareQuery + | RouteId::QueryInput + | RouteId::QueryStatus + | RouteId::CompleteQuery => { + handler + .as_ref() + .expect("Request handler is provided") + .handle( + addr, + BodyStream::from_infallible( + stream.map(Vec::into_boxed_slice), + ), + ) + .await } }; @@ -164,7 +176,7 @@ impl Transport for Weak> { { let this = self.upgrade().unwrap(); let channel = this.get_channel(dest); - let addr = Addr::from_route(this.identity, route); + let addr = Addr::from_route(Some(this.identity), route); let (ack_tx, ack_rx) = oneshot::channel(); channel @@ -179,8 +191,13 @@ impl Transport for Weak> { .map_err(|_recv_error| Error::Rejected { dest, inner: "channel closed".into(), - }) - .and_then(convert::identity) + })? + .map_err(|e| Error::Rejected { + dest, + inner: e.into(), + })?; + + Ok(()) } fn receive>( @@ -283,47 +300,21 @@ impl Setup { .is_none()); } - fn into_active_conn::Handler>>( + pub(crate) fn start( self, - callbacks: H, - ) -> (ConnectionTx, Arc>) - where - Self: ListenerSetup, - { - let transport = Arc::new(InMemoryTransport::new(self.identity, self.connections)); - transport.listen::(callbacks.into(), self.rx); - - (self.tx, transport) - } -} - -/// Trait to tie up different transports to the requests handlers they can use inside their -/// listen loop. -pub trait ListenerSetup { - type Identity: TransportIdentity; - type Handler: RequestHandler + 'static; - type Listener; - - fn start>(self, handler: I) -> Self::Listener; -} - -impl ListenerSetup for Setup { - type Identity = HelperIdentity; - type Handler = HelperRequestHandler; - type Listener = Arc>; - - fn start>(self, handler: I) -> Self::Listener { + handler: Option>>, + ) -> Arc> { self.into_active_conn(handler).1 } -} -impl ListenerSetup for Setup { - type Identity = ShardIndex; - type Handler = (); - type Listener = Arc>; + fn into_active_conn( + self, + handler: Option>>, + ) -> (ConnectionTx, Arc>) { + let transport = Arc::new(InMemoryTransport::new(self.identity, self.connections)); + transport.listen(handler, self.rx); - fn start>(self, handler: I) -> Self::Listener { - self.into_active_conn(handler).1 + (self.tx, transport) } } @@ -331,7 +322,7 @@ impl ListenerSetup for Setup { mod tests { use std::{ collections::HashMap, - convert, io, + io, io::ErrorKind, num::NonZeroUsize, panic::AssertUnwindSafe, @@ -345,14 +336,15 @@ mod tests { use crate::{ ff::{FieldType, Fp31}, helpers::{ - query::{QueryConfig, QueryType::TestMultiply}, - transport::in_memory::{ - transport::{ - Addr, ConnectionTx, Error, InMemoryStream, InMemoryTransport, ListenerSetup, + query::{PrepareQuery, QueryConfig, QueryType::TestMultiply}, + transport::{ + in_memory::{ + transport::{Addr, ConnectionTx, Error, InMemoryStream, InMemoryTransport}, + InMemoryMpcNetwork, Setup, }, - InMemoryMpcNetwork, Setup, + routing::RouteId, }, - HelperIdentity, OrderingSender, RouteId, Transport, TransportCallbacks, + HelperIdentity, HelperResponse, OrderingSender, Role, RoleAssignment, Transport, TransportIdentity, }, protocol::{step::Gate, QueryId}, @@ -368,41 +360,46 @@ mod tests { ) { let (tx, rx) = oneshot::channel(); sender.send((addr, data, tx)).await.unwrap(); - rx.await - .map_err(|_e| Error::Io { + let _ = rx + .await + .map_err(|_e| Error::::Io { inner: io::Error::new(ErrorKind::ConnectionRefused, "channel closed"), }) - .and_then(convert::identity) + .unwrap() .unwrap(); } #[tokio::test] - async fn callback_is_called() { + async fn handler_is_called() { let (signal_tx, signal_rx) = oneshot::channel(); let signal_tx = Arc::new(Mutex::new(Some(signal_tx))); - let (tx, _transport) = - Setup::new(HelperIdentity::ONE).into_active_conn(TransportCallbacks { - receive_query: Box::new(move |_transport, query_config| { - let signal_tx = Arc::clone(&signal_tx); - Box::pin(async move { - // this works because callback is only called once - signal_tx - .lock() - .unwrap() - .take() - .expect("query callback invoked more than once") - .send(query_config) - .unwrap(); - Ok(QueryId) - }) - }), - ..Default::default() - }); + let (tx, _) = Setup::new(HelperIdentity::ONE).into_active_conn(Some(Box::new( + move |addr: Addr, _| { + let RouteId::ReceiveQuery = addr.route else { + panic!("unexpected call: {addr:?}") + }; + let query_config = addr.into::().unwrap(); + + // this works because callback is only called once + signal_tx + .lock() + .unwrap() + .take() + .expect("query callback invoked more than once") + .send(query_config) + .unwrap(); + Ok(HelperResponse::from(PrepareQuery { + query_id: QueryId, + config: query_config, + roles: RoleAssignment::try_from([Role::H1, Role::H2, Role::H3]).unwrap(), + })) + }, + ))); let expected = QueryConfig::new(TestMultiply, FieldType::Fp32BitPrime, 1u32).unwrap(); send_and_ack( &tx, - Addr::from_route(HelperIdentity::TWO, &expected), + Addr::from_route(Some(HelperIdentity::TWO), expected), InMemoryStream::empty(), ) .await; @@ -412,8 +409,7 @@ mod tests { #[tokio::test] async fn receive_not_ready() { - let (tx, transport) = - Setup::new(HelperIdentity::ONE).into_active_conn(TransportCallbacks::default()); + let (tx, transport) = Setup::new(HelperIdentity::ONE).into_active_conn(None); let transport = Arc::downgrade(&transport); let expected = vec![vec![1], vec![2]]; @@ -436,8 +432,7 @@ mod tests { #[tokio::test] async fn receive_ready() { - let (tx, transport) = - Setup::new(HelperIdentity::ONE).into_active_conn(TransportCallbacks::default()); + let (tx, transport) = Setup::new(HelperIdentity::ONE).into_active_conn(None); let expected = vec![vec![1], vec![2]]; send_and_ack( @@ -500,8 +495,8 @@ mod tests { setup1.connect(&mut setup2); - let transport1 = setup1.start(TransportCallbacks::default()); - let transport2 = setup2.start(TransportCallbacks::default()); + let transport1 = setup1.start(None); + let transport2 = setup2.start(None); let transports = HashMap::from([ (HelperIdentity::ONE, Arc::downgrade(&transport1)), (HelperIdentity::TWO, Arc::downgrade(&transport2)), @@ -513,8 +508,7 @@ mod tests { #[tokio::test] async fn panic_if_stream_received_twice() { - let (tx, owned_transport) = - Setup::new(HelperIdentity::ONE).into_active_conn(TransportCallbacks::default()); + let (tx, owned_transport) = Setup::new(HelperIdentity::ONE).into_active_conn(None); let gate = Gate::from(STEP); let (stream_tx, stream_rx) = channel(1); let stream = InMemoryStream::from(stream_rx); diff --git a/ipa-core/src/helpers/transport/mod.rs b/ipa-core/src/helpers/transport/mod.rs index 458e2b5be..e6a70b199 100644 --- a/ipa-core/src/helpers/transport/mod.rs +++ b/ipa-core/src/helpers/transport/mod.rs @@ -8,13 +8,17 @@ use crate::{ protocol::{step::Gate, QueryId}, }; -pub mod callbacks; +mod handler; #[cfg(feature = "in-memory-infra")] mod in_memory; pub mod query; mod receive; +pub mod routing; mod stream; +pub use handler::{ + make_boxed_handler, Error as ApiError, HelperResponse, PanickingHandler, RequestHandler, +}; #[cfg(feature = "in-memory-infra")] pub use in_memory::{InMemoryMpcNetwork, InMemoryShardNetwork, InMemoryTransport}; pub use receive::{LogErrors, ReceiveRecords}; @@ -26,7 +30,7 @@ pub use stream::{ }; use crate::{ - helpers::{Role, TransportIdentity}, + helpers::{transport::routing::RouteId, Role, TransportIdentity}, sharding::ShardIndex, }; @@ -57,13 +61,6 @@ pub struct NoResourceIdentifier; pub struct NoQueryId; pub struct NoStep; -#[derive(Debug, Copy, Clone)] -pub enum RouteId { - Records, - ReceiveQuery, - PrepareQuery, -} - impl ResourceIdentifier for NoResourceIdentifier {} impl ResourceIdentifier for RouteId {} @@ -90,6 +87,9 @@ where Option: From, Option: From, { + // This is not great and definitely not a zero-cost abstraction. We serialize parameters + // here, only to deserialize them again inside the request handler. I am not too worried + // about it as long as the data we serialize is tiny, which is the case right now. type Params: Borrow; fn resource_identifier(&self) -> R; @@ -139,6 +139,26 @@ impl RouteParams for (RouteId, QueryId, Gate) { } } +impl RouteParams for (RouteId, QueryId) { + type Params = &'static str; + + fn resource_identifier(&self) -> RouteId { + self.0 + } + + fn query_id(&self) -> QueryId { + self.1 + } + + fn gate(&self) -> NoStep { + NoStep + } + + fn extra(&self) -> Self::Params { + "" + } +} + /// Transport that supports per-query,per-step channels #[async_trait] pub trait Transport: Clone + Send + Sync + 'static { diff --git a/ipa-core/src/helpers/transport/query/mod.rs b/ipa-core/src/helpers/transport/query/mod.rs index 5cbaa7aaf..e8e1f35ba 100644 --- a/ipa-core/src/helpers/transport/query/mod.rs +++ b/ipa-core/src/helpers/transport/query/mod.rs @@ -10,8 +10,8 @@ use serde::{Deserialize, Deserializer, Serialize}; use crate::{ ff::FieldType, helpers::{ - transport::{BodyStream, NoQueryId, NoStep}, - GatewayConfig, RoleAssignment, RouteId, RouteParams, + transport::{routing::RouteId, BodyStream, NoQueryId, NoStep}, + GatewayConfig, RoleAssignment, RouteParams, }, protocol::{step::Step, QueryId}, }; @@ -108,6 +108,26 @@ pub struct PrepareQuery { pub roles: RoleAssignment, } +impl RouteParams for PrepareQuery { + type Params = String; + + fn resource_identifier(&self) -> RouteId { + RouteId::PrepareQuery + } + + fn query_id(&self) -> QueryId { + self.query_id + } + + fn gate(&self) -> NoStep { + NoStep + } + + fn extra(&self) -> Self::Params { + serde_json::to_string(self).unwrap() + } +} + impl RouteParams for &QueryConfig { type Params = String; diff --git a/ipa-core/src/helpers/transport/in_memory/routing.rs b/ipa-core/src/helpers/transport/routing.rs similarity index 64% rename from ipa-core/src/helpers/transport/in_memory/routing.rs rename to ipa-core/src/helpers/transport/routing.rs index 68c5015c0..a8da5200c 100644 --- a/ipa-core/src/helpers/transport/in_memory/routing.rs +++ b/ipa-core/src/helpers/transport/routing.rs @@ -3,24 +3,36 @@ use std::{borrow::Borrow, fmt::Debug}; use serde::de::DeserializeOwned; use crate::{ - helpers::{QueryIdBinding, RouteId, RouteParams, StepBinding, TransportIdentity}, + helpers::{QueryIdBinding, RouteParams, StepBinding, TransportIdentity}, protocol::{step::Gate, QueryId}, }; +// The type of request made to an MPC helper. +#[derive(Debug, Copy, Clone)] +pub enum RouteId { + Records, + ReceiveQuery, + PrepareQuery, + QueryInput, + QueryStatus, + CompleteQuery, +} + /// The header/metadata of the incoming request. #[derive(Debug)] -pub(super) struct Addr { +pub struct Addr { pub route: RouteId, pub origin: Option, pub query_id: Option, pub gate: Option, + // String and not vec for readability pub params: String, } impl Addr { #[allow(clippy::needless_pass_by_value)] // to avoid using double-reference at callsites pub fn from_route>( - origin: I, + origin: Option, route: R, ) -> Self where @@ -29,15 +41,19 @@ impl Addr { { Self { route: route.resource_identifier(), - origin: Some(origin), + origin, query_id: route.query_id().into(), gate: route.gate().into(), params: route.extra().borrow().to_string(), } } - pub fn into(self) -> T { - serde_json::from_str(&self.params).unwrap() + /// Deserializes JSON-encoded request parameters into a client-supplied type `T`. + /// + /// ## Errors + /// If deseserialization fails + pub fn into(self) -> Result { + serde_json::from_str(&self.params) } #[cfg(all(test, unit_test))] diff --git a/ipa-core/src/helpers/transport/stream/axum_body.rs b/ipa-core/src/helpers/transport/stream/axum_body.rs index fec8103b9..f4007cc7f 100644 --- a/ipa-core/src/helpers/transport/stream/axum_body.rs +++ b/ipa-core/src/helpers/transport/stream/axum_body.rs @@ -4,6 +4,7 @@ use std::{ }; use axum::extract::{BodyStream, FromRequest, RequestParts}; +use bytes::Bytes; use futures::{Stream, TryStreamExt}; use hyper::Body; use pin_project::pin_project; @@ -28,6 +29,11 @@ impl WrappedAxumBodyStream { pub(super) fn new_internal(inner: BodyStream) -> Self { Self(inner.map_err(axum::Error::into_inner as fn(axum::Error) -> BoxError)) } + + #[must_use] + pub fn empty() -> Self { + Self::from_body(Bytes::new()) + } } impl Stream for WrappedAxumBodyStream { diff --git a/ipa-core/src/helpers/transport/stream/box_body.rs b/ipa-core/src/helpers/transport/stream/box_body.rs index e51d86f87..aa7a25583 100644 --- a/ipa-core/src/helpers/transport/stream/box_body.rs +++ b/ipa-core/src/helpers/transport/stream/box_body.rs @@ -3,7 +3,8 @@ use std::{ task::{Context, Poll}, }; -use futures::Stream; +use bytes::Bytes; +use futures::{stream::StreamExt, Stream}; use crate::helpers::transport::stream::BoxBytesStream; @@ -16,6 +17,15 @@ impl WrappedBoxBodyStream { pub fn new(inner: axum::extract::BodyStream) -> Self { Self(Box::pin(super::WrappedAxumBodyStream::new_internal(inner))) } + + pub fn from_infallible> + Send + 'static>(input: S) -> Self { + Self(Box::pin(input.map(Bytes::from).map(Ok))) + } + + #[must_use] + pub fn empty() -> Self { + WrappedBoxBodyStream(Box::pin(futures::stream::empty())) + } } impl Stream for WrappedBoxBodyStream { diff --git a/ipa-core/src/net/client/mod.rs b/ipa-core/src/net/client/mod.rs index 67296ba63..296298860 100644 --- a/ipa-core/src/net/client/mod.rs +++ b/ipa-core/src/net/client/mod.rs @@ -431,44 +431,17 @@ pub(crate) mod tests { use crate::{ ff::{FieldType, Fp31}, helpers::{ - query::QueryType::TestMultiply, BytesStream, RoleAssignment, Transport, - TransportCallbacks, MESSAGE_PAYLOAD_SIZE_BYTES, + make_boxed_handler, query::QueryType::TestMultiply, BytesStream, HelperResponse, + PanickingHandler, RequestHandler, RoleAssignment, Transport, + MESSAGE_PAYLOAD_SIZE_BYTES, }, - net::{test::TestServer, HttpTransport}, + net::test::TestServer, protocol::step::StepNarrow, query::ProtocolResult, secret_sharing::replicated::semi_honest::AdditiveShare as Replicated, sync::Arc, }; - // This is a kludgy way of working around `TransportCallbacks` not being `Clone`, so - // that tests can run against both HTTP and HTTPS servers with one set. - // - // If the use grows beyond that, it's probably worth doing something more elegant, on the - // TransportCallbacks type itself (references and lifetime parameters, dyn_clone, or make it a - // trait and implement it on an `Arc` type). - fn clone_callbacks( - cb: TransportCallbacks, - ) -> (TransportCallbacks, TransportCallbacks) { - fn wrap(inner: &Arc>) -> TransportCallbacks { - let ri = Arc::clone(inner); - let pi = Arc::clone(inner); - let qi = Arc::clone(inner); - let si = Arc::clone(inner); - let ci = Arc::clone(inner); - TransportCallbacks { - receive_query: Box::new(move |t, req| (ri.receive_query)(t, req)), - prepare_query: Box::new(move |t, req| (pi.prepare_query)(t, req)), - query_input: Box::new(move |t, req| (qi.query_input)(t, req)), - query_status: Box::new(move |t, req| (si.query_status)(t, req)), - complete_query: Box::new(move |t, req| (ci.complete_query)(t, req)), - } - } - - let arc_cb = Arc::new(cb); - (wrap(&arc_cb), wrap(&arc_cb)) - } - #[tokio::test] async fn untrusted_certificate() { const ECHO_DATA: &str = "asdf"; @@ -500,21 +473,18 @@ pub(crate) mod tests { /// Also tests that the same functionality works for both `http` and `https` and all supported /// HTTP versions (HTTP 1.1 and HTTP 2 at the moment) . In order to ensure /// this, the return type of `clientf` must be `Eq + Debug` so that the results can be compared. - async fn test_query_command( + async fn test_query_command( clientf: ClientF, - server_cb: TransportCallbacks>, + server_handler: HandlerF, ) -> ClientOut where ClientOut: Eq + Debug, ClientFut: Future, ClientF: Fn(MpcHelperClient) -> ClientFut, + HandlerF: Fn() -> Box>, { - let mut cb = server_cb; let mut results = Vec::with_capacity(4); for (use_https, use_http1) in zip([true, false], [true, false]) { - let (cur, next) = clone_callbacks(cb); - cb = next; - let mut test_server_builder = TestServer::builder(); if !use_https { test_server_builder = test_server_builder.disable_https(); @@ -527,7 +497,10 @@ pub(crate) mod tests { let TestServer { client: http_client, .. - } = test_server_builder.with_callbacks(cur).build().await; + } = test_server_builder + .with_request_handler(server_handler()) + .build() + .await; results.push(clientf(http_client).await); } @@ -543,7 +516,7 @@ pub(crate) mod tests { let output = test_query_command( |client| async move { client.echo(expected_output).await.unwrap() }, - TransportCallbacks::default(), + || Box::::default(), ) .await; assert_eq!(expected_output, &output); @@ -554,16 +527,21 @@ pub(crate) mod tests { let expected_query_id = QueryId; let expected_query_config = QueryConfig::new(TestMultiply, FieldType::Fp31, 1).unwrap(); - let cb = TransportCallbacks { - receive_query: Box::new(move |_transport, query_config| { + let handler = || { + make_boxed_handler(move |addr, _| async move { + let query_config = addr.into::().unwrap(); assert_eq!(query_config, expected_query_config); - Box::pin(ready(Ok(expected_query_id))) - }), - ..Default::default() + + Ok(HelperResponse::from(PrepareQuery { + query_id: expected_query_id, + config: query_config, + roles: RoleAssignment::new(HelperIdentity::make_three()), + })) + }) }; let query_id = test_query_command( |client| async move { client.create_query(expected_query_config).await.unwrap() }, - cb, + handler, ) .await; assert_eq!(query_id, expected_query_id); @@ -571,25 +549,31 @@ pub(crate) mod tests { #[tokio::test] async fn prepare() { - let input = PrepareQuery { - query_id: QueryId, - config: QueryConfig::new(TestMultiply, FieldType::Fp31, 1).unwrap(), - roles: RoleAssignment::new(HelperIdentity::make_three()), - }; - let expected_data = input.clone(); - let cb = TransportCallbacks { - prepare_query: Box::new(move |_transport, prepare_query| { - assert_eq!(prepare_query, expected_data); - Box::pin(ready(Ok(()))) - }), - ..Default::default() + let config = QueryConfig::new(TestMultiply, FieldType::Fp31, 1).unwrap(); + let handler = move || { + make_boxed_handler(move |addr, _| async move { + let input = PrepareQuery { + query_id: QueryId, + config, + roles: RoleAssignment::new(HelperIdentity::make_three()), + }; + let prepare_query = addr.into::().unwrap(); + assert_eq!(prepare_query, input); + + Ok(HelperResponse::ok()) + }) }; + test_query_command( |client| { - let req = input.clone(); + let req = PrepareQuery { + query_id: QueryId, + config, + roles: RoleAssignment::new(HelperIdentity::make_three()), + }; async move { client.prepare_query(req).await.unwrap() } }, - cb, + handler, ) .await; } @@ -598,15 +582,13 @@ pub(crate) mod tests { async fn input() { let expected_query_id = QueryId; let expected_input = &[8u8; 25]; - let cb = TransportCallbacks { - query_input: Box::new(move |_transport, query_input| { - Box::pin(async move { - assert_eq!(query_input.query_id, expected_query_id); - assert_eq!(&query_input.input_stream.to_vec().await, expected_input); - Ok(()) - }) - }), - ..Default::default() + let handler = move || { + make_boxed_handler(move |addr, data| async move { + assert_eq!(addr.query_id, Some(expected_query_id)); + assert_eq!(data.to_vec().await, expected_input); + + Ok(HelperResponse::ok()) + }) }; test_query_command( |client| async move { @@ -616,7 +598,7 @@ pub(crate) mod tests { }; client.query_input(data).await.unwrap(); }, - cb, + handler, ) .await; } @@ -653,25 +635,30 @@ pub(crate) mod tests { #[tokio::test] async fn results() { - let expected_results = Box::new(vec![Replicated::from(( + let expected_results = [ Fp31::try_from(1u128).unwrap(), Fp31::try_from(2u128).unwrap(), - ))]); + ]; let expected_query_id = QueryId; - let raw_results = expected_results.to_vec(); - let cb = TransportCallbacks { - complete_query: Box::new(move |_transport, query_id| { - let results: Box = Box::new(raw_results.clone()); - assert_eq!(query_id, expected_query_id); - Box::pin(ready(Ok(results))) - }), - ..Default::default() + let handler = move || { + make_boxed_handler(move |addr, _| async move { + let results: Box = Box::new( + [Replicated::from((expected_results[0], expected_results[1]))].to_vec(), + ); + assert_eq!(addr.query_id, Some(expected_query_id)); + Ok(HelperResponse::from(results)) + }) }; let results = test_query_command( |client| async move { client.query_results(expected_query_id).await.unwrap() }, - cb, + handler, ) .await; - assert_eq!(results.to_vec(), expected_results.into_bytes()); + assert_eq!( + results.to_vec(), + [Replicated::from((expected_results[0], expected_results[1]))] + .to_vec() + .as_bytes() + ); } } diff --git a/ipa-core/src/net/http_serde.rs b/ipa-core/src/net/http_serde.rs index 72ad9ceb2..2937c948a 100644 --- a/ipa-core/src/net/http_serde.rs +++ b/ipa-core/src/net/http_serde.rs @@ -178,7 +178,7 @@ pub mod query { use hyper::http::uri; use crate::{ - helpers::query::QueryConfig, + helpers::{query::QueryConfig, HelperResponse}, net::{ http_serde::query::{QueryConfigQueryParams, BASE_AXUM_PATH}, Error, @@ -229,6 +229,14 @@ pub mod query { pub query_id: QueryId, } + impl TryFrom for ResponseBody { + type Error = serde_json::Error; + + fn try_from(value: HelperResponse) -> Result { + value.try_into_owned() + } + } + pub const AXUM_PATH: &str = "/"; } @@ -462,13 +470,38 @@ pub mod query { use axum::extract::{FromRequest, Path, RequestParts}; use serde::{Deserialize, Serialize}; - use crate::{net::Error, protocol::QueryId, query::QueryStatus}; + use crate::{ + helpers::{routing::RouteId, HelperResponse, NoStep, RouteParams}, + net::Error, + protocol::QueryId, + query::QueryStatus, + }; - #[derive(Debug, Clone)] + #[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] pub struct Request { pub query_id: QueryId, } + impl RouteParams for Request { + type Params = String; + + fn resource_identifier(&self) -> RouteId { + RouteId::QueryStatus + } + + fn query_id(&self) -> QueryId { + self.query_id + } + + fn gate(&self) -> NoStep { + NoStep + } + + fn extra(&self) -> Self::Params { + serde_json::to_string(self).unwrap() + } + } + impl Request { #[cfg(any(all(test, not(feature = "shuttle")), feature = "cli"))] // needed because client is blocking; remove when non-blocking pub fn new(query_id: QueryId) -> Self { @@ -509,6 +542,12 @@ pub mod query { pub status: QueryStatus, } + impl From for ResponseBody { + fn from(value: HelperResponse) -> Self { + serde_json::from_slice(value.into_body().as_slice()).unwrap() + } + } + pub const AXUM_PATH: &str = "/:query_id"; } @@ -516,13 +555,37 @@ pub mod query { use async_trait::async_trait; use axum::extract::{FromRequest, Path, RequestParts}; - use crate::{net::Error, protocol::QueryId}; + use crate::{ + helpers::{routing::RouteId, NoStep, RouteParams}, + net::Error, + protocol::QueryId, + }; - #[derive(Debug, Clone)] + #[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] pub struct Request { pub query_id: QueryId, } + impl RouteParams for Request { + type Params = String; + + fn resource_identifier(&self) -> RouteId { + RouteId::CompleteQuery + } + + fn query_id(&self) -> QueryId { + self.query_id + } + + fn gate(&self) -> NoStep { + NoStep + } + + fn extra(&self) -> Self::Params { + serde_json::to_string(self).unwrap() + } + } + impl Request { #[cfg(any(all(test, not(feature = "shuttle")), feature = "cli"))] // needed because client is blocking; remove when non-blocking pub fn new(query_id: QueryId) -> Self { diff --git a/ipa-core/src/net/server/handlers/query/create.rs b/ipa-core/src/net/server/handlers/query/create.rs index 903fbe10f..e69294293 100644 --- a/ipa-core/src/net/server/handlers/query/create.rs +++ b/ipa-core/src/net/server/handlers/query/create.rs @@ -2,7 +2,7 @@ use axum::{routing::post, Extension, Json, Router}; use hyper::StatusCode; use crate::{ - helpers::Transport, + helpers::{ApiError::NewQuery, BodyStream, Transport}, net::{http_serde, Error, HttpTransport}, query::NewQueryError, sync::Arc, @@ -15,9 +15,12 @@ async fn handler( req: http_serde::query::create::Request, ) -> Result, Error> { let transport = Transport::clone_ref(&*transport); - match transport.receive_query(req.query_config).await { - Ok(query_id) => Ok(Json(http_serde::query::create::ResponseBody { query_id })), - Err(err @ NewQueryError::State { .. }) => { + match transport + .dispatch(req.query_config, BodyStream::empty()) + .await + { + Ok(resp) => Ok(Json(resp.try_into()?)), + Err(err @ NewQuery(NewQueryError::State { .. })) => { Err(Error::application(StatusCode::CONFLICT, err)) } Err(err) => Err(Error::application(StatusCode::INTERNAL_SERVER_ERROR, err)), @@ -32,7 +35,7 @@ pub fn router(transport: Arc) -> Router { #[cfg(all(test, unit_test))] mod tests { - use std::{future::ready, num::NonZeroU32}; + use std::num::NonZeroU32; use axum::http::Request; use hyper::{ @@ -43,8 +46,9 @@ mod tests { use crate::{ ff::FieldType, helpers::{ - query::{IpaQueryConfig, QueryConfig, QueryType}, - TransportCallbacks, + query::{IpaQueryConfig, PrepareQuery, QueryConfig, QueryType}, + routing::{Addr, RouteId}, + HelperIdentity, HelperResponse, Role, RoleAssignment, }, net::{ http_serde, @@ -55,14 +59,22 @@ mod tests { }; async fn create_test(expected_query_config: QueryConfig) { - let cb = TransportCallbacks { - receive_query: Box::new(move |_transport, query_config| { + let TestServer { server, .. } = TestServer::builder() + .with_request_handler(Box::new(move |addr: Addr, _| { + let RouteId::ReceiveQuery = addr.route else { + panic!("unexpected call"); + }; + + let query_config = addr.into().unwrap(); assert_eq!(query_config, expected_query_config); - Box::pin(ready(Ok(QueryId))) - }), - ..Default::default() - }; - let TestServer { server, .. } = TestServer::builder().with_callbacks(cb).build().await; + Ok(HelperResponse::from(PrepareQuery { + query_id: QueryId, + config: query_config, + roles: RoleAssignment::try_from([Role::H1, Role::H2, Role::H3]).unwrap(), + })) + })) + .build() + .await; let req = http_serde::query::create::Request::new(expected_query_config); let req = req .try_into_http_request(Scheme::HTTP, Authority::from_static("localhost")) diff --git a/ipa-core/src/net/server/handlers/query/input.rs b/ipa-core/src/net/server/handlers/query/input.rs index f926f4e33..842db5c77 100644 --- a/ipa-core/src/net/server/handlers/query/input.rs +++ b/ipa-core/src/net/server/handlers/query/input.rs @@ -2,7 +2,7 @@ use axum::{routing::post, Extension, Router}; use hyper::StatusCode; use crate::{ - helpers::Transport, + helpers::{routing::RouteId, Transport}, net::{http_serde, Error, HttpTransport}, sync::Arc, }; @@ -12,10 +12,15 @@ async fn handler( req: http_serde::query::input::Request, ) -> Result<(), Error> { let transport = Transport::clone_ref(&*transport); - transport - .query_input(req.query_input) + let _ = transport + .dispatch( + (RouteId::QueryInput, req.query_input.query_id), + req.query_input.input_stream, + ) .await - .map_err(|e| Error::application(StatusCode::INTERNAL_SERVER_ERROR, e)) + .map_err(|e| Error::application(StatusCode::INTERNAL_SERVER_ERROR, e))?; + + Ok(()) } pub fn router(transport: Arc) -> Router { @@ -28,9 +33,14 @@ pub fn router(transport: Arc) -> Router { mod tests { use axum::{http::Request, Extension}; use hyper::{Body, StatusCode}; + use tokio::runtime::Handle; use crate::{ - helpers::{query::QueryInput, BytesStream, TransportCallbacks}, + helpers::{ + query::QueryInput, + routing::{Addr, RouteId}, + BodyStream, BytesStream, HelperIdentity, HelperResponse, + }, net::{ http_serde, server::handlers::query::{ @@ -42,21 +52,30 @@ mod tests { protocol::QueryId, }; - #[tokio::test] + #[tokio::test(flavor = "multi_thread")] async fn input_test() { let expected_query_id = QueryId; let expected_input = &[4u8; 4]; - let cb = TransportCallbacks { - query_input: Box::new(move |_transport, query_input| { - Box::pin(async move { - assert_eq!(query_input.query_id, expected_query_id); - assert_eq!(&query_input.input_stream.to_vec().await, expected_input); - Ok(()) - }) - }), - ..Default::default() - }; - let TestServer { transport, .. } = TestServer::builder().with_callbacks(cb).build().await; + let req_handler = Box::new(move |addr: Addr, data: BodyStream| { + let RouteId::QueryInput = addr.route else { + panic!("unexpected call"); + }; + + assert_eq!(addr.query_id, Some(expected_query_id)); + assert_eq!( + tokio::task::block_in_place(move || { + Handle::current().block_on(async move { data.to_vec().await }) + }), + expected_input + ); + + Ok(HelperResponse::ok()) + }); + + let TestServer { transport, .. } = TestServer::builder() + .with_request_handler(req_handler) + .build() + .await; let req = http_serde::query::input::Request::new(QueryInput { query_id: expected_query_id, input_stream: expected_input.to_vec().into(), diff --git a/ipa-core/src/net/server/handlers/query/prepare.rs b/ipa-core/src/net/server/handlers/query/prepare.rs index bdb746bed..9f55793b0 100644 --- a/ipa-core/src/net/server/handlers/query/prepare.rs +++ b/ipa-core/src/net/server/handlers/query/prepare.rs @@ -2,7 +2,8 @@ use axum::{response::IntoResponse, routing::post, Extension, Router}; use hyper::StatusCode; use crate::{ - net::{http_serde, server::ClientIdentity, HttpTransport}, + helpers::{BodyStream, Transport}, + net::{http_serde, server::ClientIdentity, Error, HttpTransport}, query::PrepareQueryError, sync::Arc, }; @@ -13,8 +14,14 @@ async fn handler( transport: Extension>, _: Extension, // require that client is an authenticated helper req: http_serde::query::prepare::Request, -) -> Result<(), PrepareQueryError> { - Arc::clone(&transport).prepare_query(req.data).await +) -> Result<(), Error> { + let transport = Transport::clone_ref(&*transport); + let _ = transport + .dispatch(req.data, BodyStream::empty()) + .await + .map_err(|e| Error::application(StatusCode::INTERNAL_SERVER_ERROR, e))?; + + Ok(()) } impl IntoResponse for PrepareQueryError { @@ -31,7 +38,6 @@ pub fn router(transport: Arc) -> Router { #[cfg(all(test, unit_test))] mod tests { - use std::future::ready; use axum::{http::Request, Extension}; use hyper::{Body, StatusCode}; @@ -40,7 +46,8 @@ mod tests { ff::FieldType, helpers::{ query::{PrepareQuery, QueryConfig, QueryType::TestMultiply}, - HelperIdentity, RoleAssignment, TransportCallbacks, + routing::{Addr, RouteId}, + BodyStream, HelperIdentity, HelperResponse, RoleAssignment, }, net::{ http_serde, @@ -65,15 +72,21 @@ mod tests { roles: RoleAssignment::new(HelperIdentity::make_three()), }); let expected_prepare_query = req.data.clone(); + let TestServer { transport, .. } = TestServer::builder() + .with_request_handler(Box::new( + move |addr: Addr, _data: BodyStream| { + let RouteId::PrepareQuery = addr.route else { + panic!("unexpected call"); + }; + + let query_config = addr.into::().unwrap(); + assert_eq!(query_config, expected_prepare_query); + Ok(HelperResponse::ok()) + }, + )) + .build() + .await; - let cb = TransportCallbacks { - prepare_query: Box::new(move |_transport, prepare_query| { - assert_eq!(prepare_query, expected_prepare_query); - Box::pin(ready(Ok(()))) - }), - ..Default::default() - }; - let TestServer { transport, .. } = TestServer::builder().with_callbacks(cb).build().await; handler( Extension(transport), Extension(ClientIdentity(HelperIdentity::TWO)), diff --git a/ipa-core/src/net/server/handlers/query/results.rs b/ipa-core/src/net/server/handlers/query/results.rs index cb56e9315..679283f7e 100644 --- a/ipa-core/src/net/server/handlers/query/results.rs +++ b/ipa-core/src/net/server/handlers/query/results.rs @@ -2,7 +2,7 @@ use axum::{routing::get, Extension, Router}; use hyper::StatusCode; use crate::{ - helpers::Transport, + helpers::{BodyStream, Transport}, net::{http_serde, server::Error, HttpTransport}, sync::Arc, }; @@ -14,8 +14,8 @@ async fn handler( ) -> Result, Error> { // TODO: we may be able to stream the response let transport = Transport::clone_ref(&*transport); - match transport.complete_query(req.query_id).await { - Ok(result) => Ok(result.into_bytes()), + match transport.dispatch(req, BodyStream::empty()).await { + Ok(resp) => Ok(resp.into_body()), Err(e) => Err(Error::application(StatusCode::INTERNAL_SERVER_ERROR, e)), } } @@ -28,14 +28,16 @@ pub fn router(transport: Arc) -> Router { #[cfg(all(test, unit_test))] mod tests { - use std::future::ready; use axum::{http::Request, Extension}; use hyper::StatusCode; use crate::{ ff::Fp31, - helpers::TransportCallbacks, + helpers::{ + routing::{Addr, RouteId}, + BodyStream, HelperIdentity, HelperResponse, + }, net::{ http_serde, server::handlers::query::{ @@ -57,18 +59,22 @@ mod tests { ))]); let expected_query_id = QueryId; let raw_results = expected_results.to_vec(); - let cb = TransportCallbacks { - complete_query: Box::new(move |_transport, query_id| { - let results: Box = Box::new(raw_results.clone()); - assert_eq!(query_id, expected_query_id); - Box::pin(ready(Ok(results))) - }), - ..Default::default() - }; - let TestServer { transport, .. } = TestServer::builder().with_callbacks(cb).build().await; + let TestServer { transport, .. } = TestServer::builder() + .with_request_handler(Box::new( + move |addr: Addr, _data: BodyStream| { + let RouteId::CompleteQuery = addr.route else { + panic!("unexpected call"); + }; + let results = Box::new(raw_results.clone()) as Box; + assert_eq!(addr.query_id, Some(expected_query_id)); + Ok(HelperResponse::from(results)) + }, + )) + .build() + .await; let req = http_serde::query::results::Request::new(QueryId); let results = handler(Extension(transport), req.clone()).await.unwrap(); - assert_eq!(results, expected_results.into_bytes()); + assert_eq!(results, expected_results.as_bytes()); } struct OverrideReq { diff --git a/ipa-core/src/net/server/handlers/query/status.rs b/ipa-core/src/net/server/handlers/query/status.rs index f08a475e9..a06d95a28 100644 --- a/ipa-core/src/net/server/handlers/query/status.rs +++ b/ipa-core/src/net/server/handlers/query/status.rs @@ -2,7 +2,7 @@ use axum::{routing::get, Extension, Json, Router}; use hyper::StatusCode; use crate::{ - helpers::Transport, + helpers::{BodyStream, Transport}, net::{http_serde::query::status, server::Error, HttpTransport}, sync::Arc, }; @@ -12,8 +12,8 @@ async fn handler( req: status::Request, ) -> Result, Error> { let transport = Transport::clone_ref(&*transport); - match transport.query_status(req.query_id).await { - Ok(state) => Ok(Json(status::ResponseBody { status: state })), + match transport.dispatch(req, BodyStream::empty()).await { + Ok(state) => Ok(Json(status::ResponseBody::from(state))), Err(e) => Err(Error::application(StatusCode::INTERNAL_SERVER_ERROR, e)), } } @@ -26,13 +26,15 @@ pub fn router(transport: Arc) -> Router { #[cfg(all(test, unit_test))] mod tests { - use std::future::ready; use axum::{http::Request, Extension, Json}; use hyper::StatusCode; use crate::{ - helpers::TransportCallbacks, + helpers::{ + routing::{Addr, RouteId}, + BodyStream, HelperIdentity, HelperResponse, + }, net::{ http_serde, server::handlers::query::{ @@ -49,14 +51,18 @@ mod tests { async fn status_test() { let expected_status = QueryStatus::Running; let expected_query_id = QueryId; - let cb = TransportCallbacks { - query_status: Box::new(move |_transport, query_id| { - assert_eq!(query_id, expected_query_id); - Box::pin(ready(Ok(expected_status))) - }), - ..Default::default() - }; - let TestServer { transport, .. } = TestServer::builder().with_callbacks(cb).build().await; + let TestServer { transport, .. } = TestServer::builder() + .with_request_handler(Box::new( + move |addr: Addr, _data: BodyStream| { + let RouteId::QueryStatus = addr.route else { + panic!("unexpected call"); + }; + assert_eq!(addr.query_id, Some(expected_query_id)); + Ok(HelperResponse::from(expected_status)) + }, + )) + .build() + .await; let req = http_serde::query::status::Request::new(QueryId); let response = handler(Extension(transport), req.clone()).await.unwrap(); diff --git a/ipa-core/src/net/test.rs b/ipa-core/src/net/test.rs index c66a8d72c..ed4312c15 100644 --- a/ipa-core/src/net/test.rs +++ b/ipa-core/src/net/test.rs @@ -8,7 +8,6 @@ //! `net::transport::tests`. #![allow(clippy::missing_panics_doc)] - use std::{ array, net::{SocketAddr, TcpListener}, @@ -23,7 +22,7 @@ use crate::{ ClientConfig, HpkeClientConfig, HpkeServerConfig, NetworkConfig, PeerConfig, ServerConfig, TlsConfig, }, - helpers::{HelperIdentity, TransportCallbacks}, + helpers::{HelperIdentity, PanickingHandler, RequestHandler}, hpke::{Deserializable as _, IpaPublicKey}, net::{ClientIdentity, HttpTransport, MpcHelperClient, MpcHelperServer}, sync::Arc, @@ -204,8 +203,6 @@ impl TestConfigBuilder { } } -type HttpTransportCallbacks = TransportCallbacks>; - pub struct TestServer { pub addr: SocketAddr, pub handle: JoinHandle<()>, @@ -232,7 +229,7 @@ impl TestServer { #[derive(Default)] pub struct TestServerBuilder { - callbacks: Option, + handler: Option>>, metrics: Option, disable_https: bool, use_http1: bool, @@ -241,8 +238,11 @@ pub struct TestServerBuilder { impl TestServerBuilder { #[must_use] - pub fn with_callbacks(mut self, callbacks: HttpTransportCallbacks) -> Self { - self.callbacks = Some(callbacks); + pub fn with_request_handler( + mut self, + handler: Box>, + ) -> Self { + self.handler = Some(handler); self } @@ -300,7 +300,8 @@ impl TestServerBuilder { server_config, network_config.clone(), clients, - self.callbacks.unwrap_or_default(), + self.handler + .unwrap_or_else(|| Box::::default()), ); let (addr, handle) = server.start_on(Some(server_socket), self.metrics).await; // Get the config for HelperIdentity::ONE diff --git a/ipa-core/src/net/transport.rs b/ipa-core/src/net/transport.rs index e26c567a1..36ae1b231 100644 --- a/ipa-core/src/net/transport.rs +++ b/ipa-core/src/net/transport.rs @@ -8,16 +8,17 @@ use std::{ use async_trait::async_trait; use bytes::Bytes; use futures::{Stream, TryFutureExt}; +use pin_project::{pin_project, pinned_drop}; use crate::{ config::{NetworkConfig, ServerConfig}, error::BoxError, helpers::{ - query::{PrepareQuery, QueryConfig, QueryInput}, - BodyStream, CompleteQueryResult, HelperIdentity, LogErrors, NoResourceIdentifier, - PrepareQueryResult, QueryIdBinding, QueryInputResult, QueryStatusResult, - ReceiveQueryResult, ReceiveRecords, RouteId, RouteParams, StepBinding, StreamCollection, - Transport, TransportCallbacks, + query::QueryConfig, + routing::{Addr, RouteId}, + ApiError, BodyStream, HelperIdentity, HelperResponse, LogErrors, NoQueryId, + NoResourceIdentifier, NoStep, QueryIdBinding, ReceiveRecords, RequestHandler, RouteParams, + StepBinding, StreamCollection, Transport, }, net::{client::MpcHelperClient, error::Error, MpcHelperServer}, protocol::{step::Gate, QueryId}, @@ -29,11 +30,31 @@ type LogHttpErrors = LogErrors; /// HTTP transport for IPA helper service. pub struct HttpTransport { identity: HelperIdentity, - callbacks: TransportCallbacks>, clients: [MpcHelperClient; 3], // TODO(615): supporting multiple queries likely require a hashmap here. It will be ok if we // only allow one query at a time. record_streams: StreamCollection, + handler: Box>, +} + +impl RouteParams for QueryConfig { + type Params = String; + + fn resource_identifier(&self) -> RouteId { + RouteId::ReceiveQuery + } + + fn query_id(&self) -> NoQueryId { + NoQueryId + } + + fn gate(&self) -> NoStep { + NoStep + } + + fn extra(&self) -> Self::Params { + serde_json::to_string(self).unwrap() + } } impl HttpTransport { @@ -43,9 +64,9 @@ impl HttpTransport { server_config: ServerConfig, network_config: NetworkConfig, clients: [MpcHelperClient; 3], - callbacks: TransportCallbacks>, + handler: Box>, ) -> (Arc, MpcHelperServer) { - let transport = Self::new_internal(identity, clients, callbacks); + let transport = Self::new_internal(identity, clients, handler); let server = MpcHelperServer::new(Arc::clone(&transport), server_config, network_config); (transport, server) } @@ -53,58 +74,67 @@ impl HttpTransport { fn new_internal( identity: HelperIdentity, clients: [MpcHelperClient; 3], - callbacks: TransportCallbacks>, + handler: Box>, ) -> Arc { Arc::new(Self { identity, - callbacks, clients, + handler, record_streams: StreamCollection::default(), }) } - pub fn receive_query(self: Arc, req: QueryConfig) -> ReceiveQueryResult { - (Arc::clone(&self).callbacks.receive_query)(self, req) - } - - pub fn prepare_query(self: Arc, req: PrepareQuery) -> PrepareQueryResult { - (Arc::clone(&self).callbacks.prepare_query)(self, req) - } - - pub fn query_input(self: Arc, req: QueryInput) -> QueryInputResult { - (Arc::clone(&self).callbacks.query_input)(self, req) - } - - pub fn query_status(self: Arc, query_id: QueryId) -> QueryStatusResult { - (Arc::clone(&self).callbacks.query_status)(self, query_id) - } - - pub fn complete_query(self: Arc, query_id: QueryId) -> CompleteQueryResult { + /// Dispatches the given request to the [`RequestHandler`] connected to this transport. + /// + /// ## Errors + /// Returns an error, if handler rejects the request for any reason. + pub async fn dispatch>( + self: Arc, + req: R, + body: BodyStream, + ) -> Result + where + Option: From, + { /// Cleans up the `records_stream` collection after drop to ensure this transport /// can process the next query even in case of a panic. - struct ClearOnDrop { + /// + /// This implementation is a poor man's safety net and only works because we run + /// one query at a time and don't use query identifiers. + #[pin_project(PinnedDrop)] + struct ClearOnDrop { transport: Arc, - qr: CompleteQueryResult, + #[pin] + inner: F, } - impl Future for ClearOnDrop { - type Output = ::Output; + impl Future for ClearOnDrop { + type Output = F::Output; - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - self.qr.as_mut().poll(cx) + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + self.project().inner.poll(cx) } } - impl Drop for ClearOnDrop { - fn drop(&mut self) { + #[pinned_drop] + impl PinnedDrop for ClearOnDrop { + fn drop(self: Pin<&mut Self>) { self.transport.record_streams.clear(); } } - Box::pin(ClearOnDrop { - transport: Arc::clone(&self), - qr: Box::pin((Arc::clone(&self).callbacks.complete_query)(self, query_id)), - }) + let route_id = req.resource_identifier(); + let r = self.handler.handle(Addr::from_route(None, req), body); + + if let RouteId::CompleteQuery = route_id { + ClearOnDrop { + transport: Arc::clone(&self), + inner: r, + } + .await + } else { + r.await + } } /// Connect an inbound stream of MPC record data. @@ -168,8 +198,13 @@ impl Transport for Arc { let req = serde_json::from_str(route.extra().borrow()).unwrap(); self.clients[dest].prepare_query(req).await } - RouteId::ReceiveQuery => { - unimplemented!("attempting to send ReceiveQuery to another helper") + evt @ (RouteId::QueryInput + | RouteId::ReceiveQuery + | RouteId::QueryStatus + | RouteId::CompleteQuery) => { + unimplemented!( + "attempting to send client-specific request {evt:?} to another helper" + ) } } } @@ -202,7 +237,7 @@ mod tests { use crate::{ config::{NetworkConfig, ServerConfig}, ff::{FieldType, Fp31, Serializable}, - helpers::query::QueryType::TestMultiply, + helpers::query::{QueryInput, QueryType::TestMultiply}, net::{ client::ClientIdentity, test::{get_test_identity, TestConfig, TestConfigBuilder, TestServer}, @@ -272,14 +307,14 @@ mod tests { } else { get_test_identity(id) }; - let (setup, callbacks) = AppSetup::new(); + let (setup, handler_setup) = AppSetup::new(); let clients = MpcHelperClient::from_conf(network_config, identity); let (transport, server) = HttpTransport::new( id, server_config, network_config.clone(), clients, - callbacks, + handler_setup.make_handler(), ); server.start_on(Some(socket), ()).await; diff --git a/ipa-core/src/query/executor.rs b/ipa-core/src/query/executor.rs index b2085980e..409c6c8f5 100644 --- a/ipa-core/src/query/executor.rs +++ b/ipa-core/src/query/executor.rs @@ -36,7 +36,7 @@ use crate::{ }; pub trait Result: Send + Debug { - fn into_bytes(self: Box) -> Vec; + fn as_bytes(&self) -> Vec; } impl Result for Vec @@ -44,9 +44,9 @@ where T: Serializable, Vec: Debug + Send, { - fn into_bytes(self: Box) -> Vec { + fn as_bytes(&self) -> Vec { let mut r = vec![0u8; self.len() * T::Size::USIZE]; - for (i, row) in self.into_iter().enumerate() { + for (i, row) in self.iter().enumerate() { row.serialize(GenericArray::from_mut_slice( &mut r[(i * T::Size::USIZE)..((i + 1) * T::Size::USIZE)], )); @@ -156,10 +156,10 @@ mod tests { fn serialize_result() { let [input, ..] = (0u128..=3).map(Fp31::truncate_from).share(); let expected = input.clone(); - let bytes = Box::new(input).into_bytes(); + let bytes = &input.as_bytes(); assert_eq!( expected, - AdditiveShare::::from_byte_slice(&bytes) + AdditiveShare::::from_byte_slice(bytes) .collect::, _>>() .unwrap() ); diff --git a/ipa-core/src/query/processor.rs b/ipa-core/src/query/processor.rs index fb01aa8ac..81508ddcb 100644 --- a/ipa-core/src/query/processor.rs +++ b/ipa-core/src/query/processor.rs @@ -154,8 +154,8 @@ impl Processor { // Inform other parties about new query. If any of them rejects it, this join will fail try_join( - transport.send(left, &prepare_request, stream::empty()), - transport.send(right, &prepare_request, stream::empty()), + transport.send(left, prepare_request.clone(), stream::empty()), + transport.send(right, prepare_request.clone(), stream::empty()), ) .await .map_err(NewQueryError::Transport)?; @@ -204,7 +204,7 @@ impl Processor { /// if query is not registered on this helper. /// /// ## Panics - /// If failed to obtain an exclusive access to the query collection. + /// If failed to obtain exclusive access to the query collection. pub fn receive_inputs( &self, transport: TransportImpl, @@ -278,7 +278,7 @@ impl Processor { /// if query is not registered on this helper. /// /// ## Panics - /// If failed to obtain an exclusive access to the query collection. + /// If failed to obtain exclusive access to the query collection. pub async fn complete( &self, query_id: QueryId, @@ -321,9 +321,10 @@ mod tests { use crate::{ ff::FieldType, helpers::{ + make_boxed_handler, query::{PrepareQuery, QueryConfig, QueryType::TestMultiply}, - HelperIdentity, InMemoryMpcNetwork, PrepareQueryCallback, RoleAssignment, Transport, - TransportCallbacks, + ApiError, HelperIdentity, HelperResponse, InMemoryMpcNetwork, RequestHandler, + RoleAssignment, Transport, }, protocol::QueryId, query::{ @@ -331,12 +332,19 @@ mod tests { }, }; - fn prepare_query_callback(cb: F) -> Box> + fn prepare_query_handler(cb: F) -> Box> where - F: Fn(T, PrepareQuery) -> Fut + Send + Sync + 'static, - Fut: Future> + Send + 'static, + F: Fn(PrepareQuery) -> Fut + Send + Sync + 'static, + Fut: Future> + Send + Sync + 'static, { - Box::new(move |transport, prepare_query| Box::pin(cb(transport, prepare_query))) + make_boxed_handler(move |req, _| { + let prepare_query = req.into().unwrap(); + cb(prepare_query) + }) + } + + fn respond_ok() -> Box> { + prepare_query_handler(move |_| async move { Ok(HelperResponse::ok()) }) } fn test_multiply_config() -> QueryConfig { @@ -346,29 +354,23 @@ mod tests { #[tokio::test] async fn new_query() { let barrier = Arc::new(Barrier::new(3)); - let cb2_barrier = Arc::clone(&barrier); - let cb3_barrier = Arc::clone(&barrier); - let cb2 = TransportCallbacks { - prepare_query: prepare_query_callback(move |_, _| { - let barrier = Arc::clone(&cb2_barrier); - async move { - barrier.wait().await; - Ok(()) - } - }), - ..Default::default() - }; - let cb3 = TransportCallbacks { - prepare_query: prepare_query_callback(move |_, _| { - let barrier = Arc::clone(&cb3_barrier); - async move { - barrier.wait().await; - Ok(()) - } - }), - ..Default::default() - }; - let network = InMemoryMpcNetwork::new([TransportCallbacks::default(), cb2, cb3]); + let h2_barrier = Arc::clone(&barrier); + let h3_barrier = Arc::clone(&barrier); + let h2 = prepare_query_handler(move |_| { + let barrier = Arc::clone(&h2_barrier); + async move { + barrier.wait().await; + Ok(HelperResponse::ok()) + } + }); + let h3 = prepare_query_handler(move |_| { + let barrier = Arc::clone(&h3_barrier); + async move { + barrier.wait().await; + Ok(HelperResponse::ok()) + } + }); + let network = InMemoryMpcNetwork::new([None, Some(h2), Some(h3)]); let [t0, _, _] = network.transports(); let p0 = Processor::default(); let request = test_multiply_config(); @@ -402,11 +404,12 @@ mod tests { #[tokio::test] async fn rejects_duplicate_query_id() { - let cb = array::from_fn(|_| TransportCallbacks { - prepare_query: prepare_query_callback(|_, _| async { Ok(()) }), - ..Default::default() + let handlers = array::from_fn(|_| { + Some(prepare_query_handler(|_| async { + Ok(HelperResponse::ok()) + })) }); - let network = InMemoryMpcNetwork::new(cb); + let network = InMemoryMpcNetwork::new(handlers); let [t0, _, _] = network.transports(); let p0 = Processor::default(); let request = test_multiply_config(); @@ -423,17 +426,11 @@ mod tests { #[tokio::test] async fn prepare_error() { - let cb2 = TransportCallbacks { - prepare_query: prepare_query_callback(|_, _| async { Ok(()) }), - ..Default::default() - }; - let cb3 = TransportCallbacks { - prepare_query: prepare_query_callback(|_, _| async { - Err(PrepareQueryError::WrongTarget) - }), - ..Default::default() - }; - let network = InMemoryMpcNetwork::new([TransportCallbacks::default(), cb2, cb3]); + let h2 = respond_ok().into(); + let h3 = Some(prepare_query_handler(|_| async move { + Err(ApiError::QueryPrepare(PrepareQueryError::WrongTarget)) + })); + let network = InMemoryMpcNetwork::new([None, h2, h3]); let [t0, _, _] = network.transports(); let p0 = Processor::default(); let request = test_multiply_config(); @@ -446,17 +443,11 @@ mod tests { #[tokio::test] async fn can_recover_from_prepare_error() { - let cb2 = TransportCallbacks { - prepare_query: prepare_query_callback(|_, _| async { Ok(()) }), - ..Default::default() - }; - let cb3 = TransportCallbacks { - prepare_query: prepare_query_callback(|_, _| async { - Err(PrepareQueryError::WrongTarget) - }), - ..Default::default() - }; - let network = InMemoryMpcNetwork::new([TransportCallbacks::default(), cb2, cb3]); + let h2 = respond_ok().into(); + let h3 = Some(prepare_query_handler(|_| async move { + Err(ApiError::QueryPrepare(PrepareQueryError::WrongTarget)) + })); + let network = InMemoryMpcNetwork::new([None, h2, h3]); let [t0, _, _] = network.transports(); let p0 = Processor::default(); let request = test_multiply_config(); diff --git a/ipa-core/src/test_fixture/app.rs b/ipa-core/src/test_fixture/app.rs index 6fda9c056..6ed453523 100644 --- a/ipa-core/src/test_fixture/app.rs +++ b/ipa-core/src/test_fixture/app.rs @@ -1,14 +1,14 @@ -use std::iter::zip; +use std::{array, iter::zip}; use generic_array::GenericArray; use typenum::Unsigned; use crate::{ - app::Error, + app::RequestHandlerSetup, ff::Serializable, helpers::{ query::{QueryConfig, QueryInput}, - InMemoryMpcNetwork, + ApiError, InMemoryMpcNetwork, }, protocol::QueryId, query::QueryStatus, @@ -60,10 +60,14 @@ fn unzip_tuple_array(input: [(T, U); 3]) -> ([T; 3], [U; 3]) { impl Default for TestApp { fn default() -> Self { - let (setup, callbacks) = - unzip_tuple_array([AppSetup::new(), AppSetup::new(), AppSetup::new()]); - - let network = InMemoryMpcNetwork::new(callbacks); + let (setup, handlers) = unzip_tuple_array(array::from_fn(|_| AppSetup::new())); + let handlers_ref = [&handlers[0], &handlers[1], &handlers[2]]; + + let network = InMemoryMpcNetwork::new( + handlers_ref + .map(RequestHandlerSetup::make_handler) + .map(Some), + ); let drivers = network .transports() .iter() @@ -88,7 +92,7 @@ impl TestApp { &self, input: I, query_config: QueryConfig, - ) -> Result + ) -> Result where I: IntoShares, A: IntoBuf, @@ -117,7 +121,7 @@ impl TestApp { /// Propagates errors retrieving the query status. /// ## Panics /// Never. - pub fn query_status(&self, query_id: QueryId) -> Result<[QueryStatus; 3], Error> { + pub fn query_status(&self, query_id: QueryId) -> Result<[QueryStatus; 3], ApiError> { Ok((0..3) .map(|i| self.drivers[i].query_status(query_id)) .collect::, _>>()? @@ -129,7 +133,7 @@ impl TestApp { /// Returns an error if one or more helpers can't finish the processing. /// ## Panics /// Never. - pub async fn complete_query(&self, query_id: QueryId) -> Result<[Vec; 3], Error> { + pub async fn complete_query(&self, query_id: QueryId) -> Result<[Vec; 3], ApiError> { let results = try_join3_array([0, 1, 2].map(|i| self.drivers[i].complete_query(query_id))).await; self.network.reset(); @@ -145,7 +149,7 @@ impl TestApp { &self, input: I, query_config: QueryConfig, - ) -> Result<[Vec; 3], Error> + ) -> Result<[Vec; 3], ApiError> where I: IntoShares, A: IntoBuf, From 04a49c498269eba213afe0d1edeaada3f7059cbb Mon Sep 17 00:00:00 2001 From: Alex Koshelev Date: Mon, 25 Mar 2024 21:15:28 -0700 Subject: [PATCH 2/5] Fix the memory leak inside TestApp --- ipa-core/src/app.rs | 212 ++++++++---------- ipa-core/src/bin/helper.rs | 4 +- ipa-core/src/helpers/mod.rs | 9 +- ipa-core/src/helpers/transport/handler.rs | 106 ++++++--- .../src/helpers/transport/in_memory/mod.rs | 4 +- .../helpers/transport/in_memory/transport.rs | 36 ++- ipa-core/src/helpers/transport/mod.rs | 2 +- ipa-core/src/net/client/mod.rs | 28 +-- .../src/net/server/handlers/query/create.rs | 31 +-- .../src/net/server/handlers/query/input.rs | 13 +- .../src/net/server/handlers/query/prepare.rs | 26 ++- .../src/net/server/handlers/query/results.rs | 26 ++- .../src/net/server/handlers/query/status.rs | 11 +- ipa-core/src/net/test.rs | 12 +- ipa-core/src/net/transport.rs | 24 +- ipa-core/src/query/processor.rs | 52 +++-- ipa-core/src/test_fixture/app.rs | 9 +- 17 files changed, 325 insertions(+), 280 deletions(-) diff --git a/ipa-core/src/app.rs b/ipa-core/src/app.rs index 7f6586866..3d7a322b7 100644 --- a/ipa-core/src/app.rs +++ b/ipa-core/src/app.rs @@ -1,4 +1,4 @@ -use std::sync::Mutex; +use std::sync::Weak; use async_trait::async_trait; @@ -6,8 +6,8 @@ use crate::{ helpers::{ query::{PrepareQuery, QueryConfig, QueryInput}, routing::{Addr, RouteId}, - ApiError, BodyStream, HelperIdentity, HelperResponse, RequestHandler, Transport, - TransportImpl, + ApiError, BodyStream, HandlerBox, HandlerRef, HelperIdentity, HelperResponse, + RequestHandler, Transport, TransportImpl, }, hpke::{KeyPair, KeyRegistry}, protocol::QueryId, @@ -16,143 +16,57 @@ use crate::{ }; pub struct Setup { - query_processor: Arc, - handler_setup: RequestHandlerSetup, + query_processor: QueryProcessor, + handler: HandlerRef, } /// The API layer to interact with a helper. #[must_use] pub struct HelperApp { - query_processor: Arc, - transport: TransportImpl, -} - -/// This handles requests to initiate and control IPA queries. -struct QueryRequestHandler { - qp: Arc, - transport: Arc>>, + inner: Arc, } -#[async_trait] -impl RequestHandler for QueryRequestHandler { - type Identity = HelperIdentity; - - async fn handle( - &self, - req: Addr, - data: BodyStream, - ) -> Result { - fn ext_query_id(req: &Addr) -> Result { - req.query_id.ok_or_else(|| { - ApiError::BadRequest("Query input is missing query_id argument".into()) - }) - } - - let qp = Arc::clone(&self.qp); - - Ok(match req.route { - r @ RouteId::Records => { - return Err(ApiError::BadRequest( - format!("{r:?} request must not be handled by query processing flow").into(), - )) - } - RouteId::ReceiveQuery => { - let req = req.into::()?; - HelperResponse::from(qp.new_query(self.transport(), req).await?) - } - RouteId::PrepareQuery => { - let req = req.into::()?; - HelperResponse::from(qp.prepare(&self.transport(), req)?) - } - RouteId::QueryInput => { - let query_id = ext_query_id(&req)?; - HelperResponse::from(qp.receive_inputs( - self.transport(), - QueryInput { - query_id, - input_stream: data, - }, - )?) - } - RouteId::QueryStatus => { - let query_id = ext_query_id(&req)?; - HelperResponse::from(qp.query_status(query_id)?) - } - RouteId::CompleteQuery => { - let query_id = ext_query_id(&req)?; - HelperResponse::from(qp.complete(query_id).await?) - } - }) - } -} - -impl QueryRequestHandler { - fn transport(&self) -> TransportImpl { - Clone::clone(self.transport.lock().unwrap().as_ref().unwrap()) - } -} - -#[derive(Clone)] -pub struct RequestHandlerSetup { - qp: Arc, - transport_container: Arc>>, -} - -impl RequestHandlerSetup { - fn new(qp: Arc) -> Self { - Self { - qp, - transport_container: Arc::new(Mutex::new(None)), - } - } - - pub fn make_handler(&self) -> Box> { - Box::new(QueryRequestHandler { - qp: Arc::clone(&self.qp), - transport: Arc::clone(&self.transport_container), - }) as Box> - } - - fn finish(self, transport: TransportImpl) { - let mut guard = self.transport_container.lock().unwrap(); - *guard = Some(transport); - } +struct Inner { + query_processor: QueryProcessor, + transport: TransportImpl, } impl Setup { #[must_use] - pub fn new() -> (Self, RequestHandlerSetup) { + pub fn new() -> (Self, HandlerRef) { Self::with_key_registry(KeyRegistry::empty()) } #[must_use] - pub fn with_key_registry(key_registry: KeyRegistry) -> (Self, RequestHandlerSetup) { - let query_processor = Arc::new(QueryProcessor::new(key_registry)); - let handler_setup = RequestHandlerSetup::new(Arc::clone(&query_processor)); + pub fn with_key_registry(key_registry: KeyRegistry) -> (Self, HandlerRef) { + let query_processor = QueryProcessor::new(key_registry); + let handler = HandlerBox::empty(); let this = Self { query_processor, - handler_setup: handler_setup.clone(), + handler: handler.clone(), }; // TODO: weak reference to query processor to prevent mem leak - (this, handler_setup) + (this, handler) } /// Instantiate [`HelperApp`] by connecting it to the provided transport implementation pub fn connect(self, transport: TransportImpl) -> HelperApp { - self.handler_setup.finish(Clone::clone(&transport)); - HelperApp::new(transport, self.query_processor) + let app = Arc::new(Inner { + query_processor: self.query_processor, + transport, + }); + self.handler.set_handler( + Arc::downgrade(&app) as Weak> + ); + + // Handler must be kept inside the app instance. When app is dropped, handler, transport and + // query processor are destroyed. + HelperApp { inner: app } } } impl HelperApp { - pub fn new(transport: TransportImpl, query_processor: Arc) -> Self { - Self { - query_processor, - transport, - } - } - /// Initiates a new query on this helper. In case if query is accepted, the unique [`QueryId`] /// identifier is returned, otherwise an error indicating what went wrong is reported back. /// @@ -160,8 +74,9 @@ impl HelperApp { /// If query is rejected for any reason. pub async fn start_query(&self, query_config: QueryConfig) -> Result { Ok(self + .inner .query_processor - .new_query(Transport::clone_ref(&self.transport), query_config) + .new_query(Transport::clone_ref(&self.inner.transport), query_config) .await? .query_id) } @@ -171,8 +86,10 @@ impl HelperApp { /// ## Errors /// Propagates errors from the helper. pub fn execute_query(&self, input: QueryInput) -> Result<(), ApiError> { - let transport = ::clone(&self.transport); - self.query_processor.receive_inputs(transport, input)?; + let transport = ::clone(&self.inner.transport); + self.inner + .query_processor + .receive_inputs(transport, input)?; Ok(()) } @@ -181,7 +98,7 @@ impl HelperApp { /// ## Errors /// Propagates errors from the helper. pub fn query_status(&self, query_id: QueryId) -> Result { - Ok(self.query_processor.query_status(query_id)?) + Ok(self.inner.query_processor.query_status(query_id)?) } /// Waits for a query to complete and returns the result. @@ -189,6 +106,67 @@ impl HelperApp { /// ## Errors /// Propagates errors from the helper. pub async fn complete_query(&self, query_id: QueryId) -> Result, ApiError> { - Ok(self.query_processor.complete(query_id).await?.as_bytes()) + Ok(self + .inner + .query_processor + .complete(query_id) + .await? + .as_bytes()) + } +} + +#[async_trait] +impl RequestHandler for Inner { + type Identity = HelperIdentity; + + async fn handle( + &self, + req: Addr, + data: BodyStream, + ) -> Result { + fn ext_query_id(req: &Addr) -> Result { + req.query_id.ok_or_else(|| { + ApiError::BadRequest("Query input is missing query_id argument".into()) + }) + } + + let qp = &self.query_processor; + + Ok(match req.route { + r @ RouteId::Records => { + return Err(ApiError::BadRequest( + format!("{r:?} request must not be handled by query processing flow").into(), + )) + } + RouteId::ReceiveQuery => { + let req = req.into::()?; + HelperResponse::from( + qp.new_query(Transport::clone_ref(&self.transport), req) + .await?, + ) + } + RouteId::PrepareQuery => { + let req = req.into::()?; + HelperResponse::from(qp.prepare(&self.transport, req)?) + } + RouteId::QueryInput => { + let query_id = ext_query_id(&req)?; + HelperResponse::from(qp.receive_inputs( + Transport::clone_ref(&self.transport), + QueryInput { + query_id, + input_stream: data, + }, + )?) + } + RouteId::QueryStatus => { + let query_id = ext_query_id(&req)?; + HelperResponse::from(qp.query_status(query_id)?) + } + RouteId::CompleteQuery => { + let query_id = ext_query_id(&req)?; + HelperResponse::from(qp.complete(query_id).await?) + } + }) } } diff --git a/ipa-core/src/bin/helper.rs b/ipa-core/src/bin/helper.rs index 32f49ae32..9ac13f670 100644 --- a/ipa-core/src/bin/helper.rs +++ b/ipa-core/src/bin/helper.rs @@ -131,7 +131,7 @@ async fn server(args: ServerArgs) -> Result<(), BoxError> { }); let key_registry = hpke_registry(mk_encryption.as_ref()).await?; - let (setup, handler_setup) = AppSetup::with_key_registry(key_registry); + let (setup, handler) = AppSetup::with_key_registry(key_registry); let server_config = ServerConfig { port: args.port, @@ -155,7 +155,7 @@ async fn server(args: ServerArgs) -> Result<(), BoxError> { server_config, network_config, clients, - handler_setup.make_handler(), + Some(handler), ); let _app = setup.connect(transport.clone()); diff --git a/ipa-core/src/helpers/mod.rs b/ipa-core/src/helpers/mod.rs index a10427cfb..c02a2a08b 100644 --- a/ipa-core/src/helpers/mod.rs +++ b/ipa-core/src/helpers/mod.rs @@ -52,11 +52,10 @@ pub use prss_protocol::negotiate as negotiate_prss; #[cfg(feature = "web-app")] pub use transport::WrappedAxumBodyStream; pub use transport::{ - make_boxed_handler, query, routing, ApiError, BodyStream, BytesStream, HelperResponse, - Identity as TransportIdentity, LengthDelimitedStream, LogErrors, NoQueryId, - NoResourceIdentifier, NoStep, PanickingHandler, QueryIdBinding, ReceiveRecords, RecordsStream, - RequestHandler, RouteParams, StepBinding, StreamCollection, StreamKey, Transport, - WrappedBoxBodyStream, + make_owned_handler, query, routing, ApiError, BodyStream, BytesStream, HandlerBox, HandlerRef, + HelperResponse, Identity as TransportIdentity, LengthDelimitedStream, LogErrors, NoQueryId, + NoResourceIdentifier, NoStep, QueryIdBinding, ReceiveRecords, RecordsStream, RequestHandler, + RouteParams, StepBinding, StreamCollection, StreamKey, Transport, WrappedBoxBodyStream, }; #[cfg(feature = "in-memory-infra")] pub use transport::{InMemoryMpcNetwork, InMemoryShardNetwork, InMemoryTransport}; diff --git a/ipa-core/src/helpers/transport/handler.rs b/ipa-core/src/helpers/transport/handler.rs index 9823173fb..613c6bdd8 100644 --- a/ipa-core/src/helpers/transport/handler.rs +++ b/ipa-core/src/helpers/transport/handler.rs @@ -18,6 +18,7 @@ use crate::{ NewQueryError, PrepareQueryError, ProtocolResult, QueryCompletionError, QueryInputError, QueryStatus, QueryStatusError, }, + sync::{Arc, Mutex, Weak}, }; /// Represents some response sent from MPC helper acting on a given request. It is rudimental now @@ -34,6 +35,69 @@ pub struct HelperResponse { body: Vec, } +/// The lifecycle of request handlers is somewhat complicated. First, to initialize [`Transport`], +/// an instance of [`RequestHandler`] is required upfront. To function properly, each handler must +/// have a reference to transport. +/// +/// This lifecycle is managed through this struct. An empty [`Option`], protected by a mutex +/// is passed over to transport, and it is given a value later, after transport is fully initialized. +pub struct HandlerBox { + /// There is a cyclic dependency between handlers and transport. + /// Handlers use transports to create MPC infrastructure as response to query requests. + /// Transport uses handler to respond to requests. + /// + /// To break this cycle, transport holds a weak reference to the handler and handler + /// uses strong references to transport. + inner: Mutex>>>, +} + +impl Default for HandlerBox { + fn default() -> Self { + Self { + inner: Mutex::new(None), + } + } +} + +impl HandlerBox { + #[must_use] + pub fn empty() -> HandlerRef { + HandlerRef { + inner: Arc::new(Self::default()), + } + } + + pub fn owning_ref(handler: &Arc>) -> HandlerRef { + HandlerRef { + inner: Arc::new(Self { + inner: Mutex::new(Some(Arc::downgrade(handler))), + }), + } + } + + fn set_handler(&self, handler: Weak>) { + let mut guard = self.inner.lock().unwrap(); + assert!(guard.is_none(), "Handler can be set only once"); + *guard = Some(handler); + } + + fn handler(&self) -> Arc> { + self.inner + .lock() + .unwrap() + .as_ref() + .expect("Handler is set") + .upgrade() + .expect("Handler is not destroyed") + } +} + +/// This struct is passed over to [`Transport`] to initialize it. +#[derive(Clone)] +pub struct HandlerRef { + inner: Arc>, +} + impl Debug for HelperResponse { fn fmt(&self, _f: &mut Formatter<'_>) -> std::fmt::Result { todo!() @@ -121,26 +185,7 @@ pub trait RequestHandler: Send + Sync { ) -> Result; } -#[async_trait] -impl RequestHandler for F -where - F: Fn(Addr, BodyStream) -> Result - + Send - + Sync - + 'static, -{ - type Identity = HelperIdentity; - - async fn handle( - &self, - req: Addr, - data: BodyStream, - ) -> Result { - self(req, data) - } -} - -pub fn make_boxed_handler<'a, I, F, Fut>(handler: F) -> Box + 'a> +pub fn make_owned_handler<'a, I, F, Fut>(handler: F) -> Arc + 'a> where I: TransportIdentity, F: Fn(Addr, BodyStream) -> Fut + Send + Sync + 'a, @@ -168,34 +213,27 @@ where } } - Box::new(Handler { + Arc::new(Handler { inner: handler, phantom: PhantomData, }) } -// This handler panics when [`Self::handle`] method is called. -pub struct PanickingHandler { - phantom: PhantomData, -} - -impl Default for PanickingHandler { - fn default() -> Self { - Self { - phantom: PhantomData, - } +impl HandlerRef { + pub fn set_handler(&self, handler: Weak>) { + self.inner.set_handler(handler); } } #[async_trait] -impl RequestHandler for PanickingHandler { +impl RequestHandler for HandlerRef { type Identity = I; async fn handle( &self, req: Addr, - _data: BodyStream, + data: BodyStream, ) -> Result { - panic!("unexpected call: {req:?}"); + self.inner.handler().handle(req, data).await } } diff --git a/ipa-core/src/helpers/transport/in_memory/mod.rs b/ipa-core/src/helpers/transport/in_memory/mod.rs index 4f8735cf0..f3622ed22 100644 --- a/ipa-core/src/helpers/transport/in_memory/mod.rs +++ b/ipa-core/src/helpers/transport/in_memory/mod.rs @@ -7,7 +7,7 @@ pub use sharding::InMemoryShardNetwork; pub use transport::Setup; use crate::{ - helpers::{HelperIdentity, RequestHandler}, + helpers::{HandlerRef, HelperIdentity}, sync::{Arc, Weak}, }; @@ -27,7 +27,7 @@ impl Default for InMemoryMpcNetwork { impl InMemoryMpcNetwork { #[must_use] - pub fn new(handlers: [Option>>; 3]) -> Self { + pub fn new(handlers: [Option; 3]) -> Self { let [mut first, mut second, mut third]: [_; 3] = HelperIdentity::make_three().map(Setup::new); diff --git a/ipa-core/src/helpers/transport/in_memory/transport.rs b/ipa-core/src/helpers/transport/in_memory/transport.rs index 5444809c0..595154802 100644 --- a/ipa-core/src/helpers/transport/in_memory/transport.rs +++ b/ipa-core/src/helpers/transport/in_memory/transport.rs @@ -21,8 +21,9 @@ use crate::{ error::BoxError, helpers::{ transport::routing::{Addr, RouteId}, - ApiError, BodyStream, HelperResponse, NoResourceIdentifier, QueryIdBinding, ReceiveRecords, - RequestHandler, RouteParams, StepBinding, StreamCollection, Transport, TransportIdentity, + ApiError, BodyStream, HandlerRef, HelperResponse, NoResourceIdentifier, QueryIdBinding, + ReceiveRecords, RequestHandler, RouteParams, StepBinding, StreamCollection, Transport, + TransportIdentity, }, protocol::{step::Gate, QueryId}, sync::{Arc, Weak}, @@ -85,11 +86,7 @@ impl InMemoryTransport { /// out and processes it, the same way as query processor does. That will allow all tasks to be /// created in one place (driver). It does not affect the [`Transport`] interface, /// so I'll leave it as is for now. - fn listen( - self: &Arc, - handler: Option>>, - mut rx: ConnectionRx, - ) { + fn listen(self: &Arc, handler: Option>, mut rx: ConnectionRx) { tokio::spawn( { let streams = self.record_streams.clone(); @@ -112,7 +109,7 @@ impl InMemoryTransport { | RouteId::CompleteQuery => { handler .as_ref() - .expect("Request handler is provided") + .expect("Handler is set") .handle( addr, BodyStream::from_infallible( @@ -300,16 +297,13 @@ impl Setup { .is_none()); } - pub(crate) fn start( - self, - handler: Option>>, - ) -> Arc> { + pub(crate) fn start(self, handler: Option>) -> Arc> { self.into_active_conn(handler).1 } fn into_active_conn( self, - handler: Option>>, + handler: Option>, ) -> (ConnectionTx, Arc>) { let transport = Arc::new(InMemoryTransport::new(self.identity, self.connections)); transport.listen(handler, self.rx); @@ -336,6 +330,7 @@ mod tests { use crate::{ ff::{FieldType, Fp31}, helpers::{ + make_owned_handler, query::{PrepareQuery, QueryConfig, QueryType::TestMultiply}, transport::{ in_memory::{ @@ -344,8 +339,8 @@ mod tests { }, routing::RouteId, }, - HelperIdentity, HelperResponse, OrderingSender, Role, RoleAssignment, Transport, - TransportIdentity, + HandlerBox, HelperIdentity, HelperResponse, OrderingSender, Role, RoleAssignment, + Transport, TransportIdentity, }, protocol::{step::Gate, QueryId}, sync::Arc, @@ -373,8 +368,9 @@ mod tests { async fn handler_is_called() { let (signal_tx, signal_rx) = oneshot::channel(); let signal_tx = Arc::new(Mutex::new(Some(signal_tx))); - let (tx, _) = Setup::new(HelperIdentity::ONE).into_active_conn(Some(Box::new( - move |addr: Addr, _| { + let handler = make_owned_handler(move |addr: Addr, _| { + let signal_tx = Arc::clone(&signal_tx); + async move { let RouteId::ReceiveQuery = addr.route else { panic!("unexpected call: {addr:?}") }; @@ -393,8 +389,10 @@ mod tests { config: query_config, roles: RoleAssignment::try_from([Role::H1, Role::H2, Role::H3]).unwrap(), })) - }, - ))); + } + }); + let (tx, _) = Setup::new(HelperIdentity::ONE) + .into_active_conn(Some(HandlerBox::owning_ref(&handler))); let expected = QueryConfig::new(TestMultiply, FieldType::Fp32BitPrime, 1u32).unwrap(); send_and_ack( diff --git a/ipa-core/src/helpers/transport/mod.rs b/ipa-core/src/helpers/transport/mod.rs index e6a70b199..23c290388 100644 --- a/ipa-core/src/helpers/transport/mod.rs +++ b/ipa-core/src/helpers/transport/mod.rs @@ -17,7 +17,7 @@ pub mod routing; mod stream; pub use handler::{ - make_boxed_handler, Error as ApiError, HelperResponse, PanickingHandler, RequestHandler, + make_owned_handler, Error as ApiError, HandlerBox, HandlerRef, HelperResponse, RequestHandler, }; #[cfg(feature = "in-memory-infra")] pub use in_memory::{InMemoryMpcNetwork, InMemoryShardNetwork, InMemoryTransport}; diff --git a/ipa-core/src/net/client/mod.rs b/ipa-core/src/net/client/mod.rs index 296298860..2d49d4e28 100644 --- a/ipa-core/src/net/client/mod.rs +++ b/ipa-core/src/net/client/mod.rs @@ -431,9 +431,8 @@ pub(crate) mod tests { use crate::{ ff::{FieldType, Fp31}, helpers::{ - make_boxed_handler, query::QueryType::TestMultiply, BytesStream, HelperResponse, - PanickingHandler, RequestHandler, RoleAssignment, Transport, - MESSAGE_PAYLOAD_SIZE_BYTES, + make_owned_handler, query::QueryType::TestMultiply, BytesStream, HelperResponse, + RequestHandler, RoleAssignment, Transport, MESSAGE_PAYLOAD_SIZE_BYTES, }, net::test::TestServer, protocol::step::StepNarrow, @@ -481,7 +480,7 @@ pub(crate) mod tests { ClientOut: Eq + Debug, ClientFut: Future, ClientF: Fn(MpcHelperClient) -> ClientFut, - HandlerF: Fn() -> Box>, + HandlerF: Fn() -> Arc>, { let mut results = Vec::with_capacity(4); for (use_https, use_http1) in zip([true, false], [true, false]) { @@ -494,15 +493,12 @@ pub(crate) mod tests { test_server_builder = test_server_builder.use_http1(); } - let TestServer { - client: http_client, - .. - } = test_server_builder + let test_server = test_server_builder .with_request_handler(server_handler()) .build() .await; - results.push(clientf(http_client).await); + results.push(clientf(test_server.client).await); } assert!(results.windows(2).all(|slice| slice[0] == slice[1])); @@ -516,7 +512,11 @@ pub(crate) mod tests { let output = test_query_command( |client| async move { client.echo(expected_output).await.unwrap() }, - || Box::::default(), + || { + make_owned_handler(move |addr, _| async move { + panic!("unexpected call: {addr:?}"); + }) + }, ) .await; assert_eq!(expected_output, &output); @@ -528,7 +528,7 @@ pub(crate) mod tests { let expected_query_config = QueryConfig::new(TestMultiply, FieldType::Fp31, 1).unwrap(); let handler = || { - make_boxed_handler(move |addr, _| async move { + make_owned_handler(move |addr, _| async move { let query_config = addr.into::().unwrap(); assert_eq!(query_config, expected_query_config); @@ -551,7 +551,7 @@ pub(crate) mod tests { async fn prepare() { let config = QueryConfig::new(TestMultiply, FieldType::Fp31, 1).unwrap(); let handler = move || { - make_boxed_handler(move |addr, _| async move { + make_owned_handler(move |addr, _| async move { let input = PrepareQuery { query_id: QueryId, config, @@ -583,7 +583,7 @@ pub(crate) mod tests { let expected_query_id = QueryId; let expected_input = &[8u8; 25]; let handler = move || { - make_boxed_handler(move |addr, data| async move { + make_owned_handler(move |addr, data| async move { assert_eq!(addr.query_id, Some(expected_query_id)); assert_eq!(data.to_vec().await, expected_input); @@ -641,7 +641,7 @@ pub(crate) mod tests { ]; let expected_query_id = QueryId; let handler = move || { - make_boxed_handler(move |addr, _| async move { + make_owned_handler(move |addr, _| async move { let results: Box = Box::new( [Replicated::from((expected_results[0], expected_results[1]))].to_vec(), ); diff --git a/ipa-core/src/net/server/handlers/query/create.rs b/ipa-core/src/net/server/handlers/query/create.rs index e69294293..2d46f043b 100644 --- a/ipa-core/src/net/server/handlers/query/create.rs +++ b/ipa-core/src/net/server/handlers/query/create.rs @@ -46,6 +46,7 @@ mod tests { use crate::{ ff::FieldType, helpers::{ + make_owned_handler, query::{IpaQueryConfig, PrepareQuery, QueryConfig, QueryType}, routing::{Addr, RouteId}, HelperIdentity, HelperResponse, Role, RoleAssignment, @@ -59,27 +60,29 @@ mod tests { }; async fn create_test(expected_query_config: QueryConfig) { - let TestServer { server, .. } = TestServer::builder() - .with_request_handler(Box::new(move |addr: Addr, _| { - let RouteId::ReceiveQuery = addr.route else { - panic!("unexpected call"); - }; + let test_server = TestServer::builder() + .with_request_handler(make_owned_handler( + move |addr: Addr, _| async move { + let RouteId::ReceiveQuery = addr.route else { + panic!("unexpected call"); + }; - let query_config = addr.into().unwrap(); - assert_eq!(query_config, expected_query_config); - Ok(HelperResponse::from(PrepareQuery { - query_id: QueryId, - config: query_config, - roles: RoleAssignment::try_from([Role::H1, Role::H2, Role::H3]).unwrap(), - })) - })) + let query_config = addr.into().unwrap(); + assert_eq!(query_config, expected_query_config); + Ok(HelperResponse::from(PrepareQuery { + query_id: QueryId, + config: query_config, + roles: RoleAssignment::try_from([Role::H1, Role::H2, Role::H3]).unwrap(), + })) + }, + )) .build() .await; let req = http_serde::query::create::Request::new(expected_query_config); let req = req .try_into_http_request(Scheme::HTTP, Authority::from_static("localhost")) .unwrap(); - let resp = server.handle_req(req).await; + let resp = test_server.server.handle_req(req).await; let status = resp.status(); let body_bytes = hyper::body::to_bytes(resp.into_body()).await.unwrap(); diff --git a/ipa-core/src/net/server/handlers/query/input.rs b/ipa-core/src/net/server/handlers/query/input.rs index 842db5c77..f03418c1a 100644 --- a/ipa-core/src/net/server/handlers/query/input.rs +++ b/ipa-core/src/net/server/handlers/query/input.rs @@ -31,15 +31,14 @@ pub fn router(transport: Arc) -> Router { #[cfg(all(test, unit_test))] mod tests { + use axum::{http::Request, Extension}; use hyper::{Body, StatusCode}; use tokio::runtime::Handle; use crate::{ helpers::{ - query::QueryInput, - routing::{Addr, RouteId}, - BodyStream, BytesStream, HelperIdentity, HelperResponse, + make_owned_handler, query::QueryInput, routing::RouteId, BytesStream, HelperResponse, }, net::{ http_serde, @@ -56,7 +55,7 @@ mod tests { async fn input_test() { let expected_query_id = QueryId; let expected_input = &[4u8; 4]; - let req_handler = Box::new(move |addr: Addr, data: BodyStream| { + let req_handler = make_owned_handler(move |addr, data| async move { let RouteId::QueryInput = addr.route else { panic!("unexpected call"); }; @@ -72,7 +71,7 @@ mod tests { Ok(HelperResponse::ok()) }); - let TestServer { transport, .. } = TestServer::builder() + let test_server = TestServer::builder() .with_request_handler(req_handler) .build() .await; @@ -80,7 +79,9 @@ mod tests { query_id: expected_query_id, input_stream: expected_input.to_vec().into(), }); - handler(Extension(transport), req).await.unwrap(); + handler(Extension(test_server.transport), req) + .await + .unwrap(); } struct OverrideReq { diff --git a/ipa-core/src/net/server/handlers/query/prepare.rs b/ipa-core/src/net/server/handlers/query/prepare.rs index 9f55793b0..ceb2f0840 100644 --- a/ipa-core/src/net/server/handlers/query/prepare.rs +++ b/ipa-core/src/net/server/handlers/query/prepare.rs @@ -45,6 +45,7 @@ mod tests { use crate::{ ff::FieldType, helpers::{ + make_owned_handler, query::{PrepareQuery, QueryConfig, QueryType::TestMultiply}, routing::{Addr, RouteId}, BodyStream, HelperIdentity, HelperResponse, RoleAssignment, @@ -72,23 +73,26 @@ mod tests { roles: RoleAssignment::new(HelperIdentity::make_three()), }); let expected_prepare_query = req.data.clone(); - let TestServer { transport, .. } = TestServer::builder() - .with_request_handler(Box::new( - move |addr: Addr, _data: BodyStream| { - let RouteId::PrepareQuery = addr.route else { - panic!("unexpected call"); - }; - - let query_config = addr.into::().unwrap(); - assert_eq!(query_config, expected_prepare_query); - Ok(HelperResponse::ok()) + let test_server = TestServer::builder() + .with_request_handler(make_owned_handler( + move |addr: Addr, _: BodyStream| { + let expected_prepare_query = expected_prepare_query.clone(); + async move { + let RouteId::PrepareQuery = addr.route else { + panic!("unexpected call"); + }; + + let actual_prepare_query = addr.into::().unwrap(); + assert_eq!(actual_prepare_query, expected_prepare_query); + Ok(HelperResponse::ok()) + } }, )) .build() .await; handler( - Extension(transport), + Extension(test_server.transport), Extension(ClientIdentity(HelperIdentity::TWO)), req.clone(), ) diff --git a/ipa-core/src/net/server/handlers/query/results.rs b/ipa-core/src/net/server/handlers/query/results.rs index 679283f7e..020c157c1 100644 --- a/ipa-core/src/net/server/handlers/query/results.rs +++ b/ipa-core/src/net/server/handlers/query/results.rs @@ -35,6 +35,7 @@ mod tests { use crate::{ ff::Fp31, helpers::{ + make_owned_handler, routing::{Addr, RouteId}, BodyStream, HelperIdentity, HelperResponse, }, @@ -59,21 +60,26 @@ mod tests { ))]); let expected_query_id = QueryId; let raw_results = expected_results.to_vec(); - let TestServer { transport, .. } = TestServer::builder() - .with_request_handler(Box::new( - move |addr: Addr, _data: BodyStream| { - let RouteId::CompleteQuery = addr.route else { - panic!("unexpected call"); - }; - let results = Box::new(raw_results.clone()) as Box; - assert_eq!(addr.query_id, Some(expected_query_id)); - Ok(HelperResponse::from(results)) + let test_server = TestServer::builder() + .with_request_handler(make_owned_handler( + move |addr: Addr, _: BodyStream| { + let raw_results = raw_results.clone(); + async move { + let RouteId::CompleteQuery = addr.route else { + panic!("unexpected call"); + }; + let results = Box::new(raw_results.clone()) as Box; + assert_eq!(addr.query_id, Some(expected_query_id)); + Ok(HelperResponse::from(results)) + } }, )) .build() .await; let req = http_serde::query::results::Request::new(QueryId); - let results = handler(Extension(transport), req.clone()).await.unwrap(); + let results = handler(Extension(test_server.transport), req.clone()) + .await + .unwrap(); assert_eq!(results, expected_results.as_bytes()); } diff --git a/ipa-core/src/net/server/handlers/query/status.rs b/ipa-core/src/net/server/handlers/query/status.rs index a06d95a28..20eb94318 100644 --- a/ipa-core/src/net/server/handlers/query/status.rs +++ b/ipa-core/src/net/server/handlers/query/status.rs @@ -32,6 +32,7 @@ mod tests { use crate::{ helpers::{ + make_owned_handler, routing::{Addr, RouteId}, BodyStream, HelperIdentity, HelperResponse, }, @@ -51,9 +52,9 @@ mod tests { async fn status_test() { let expected_status = QueryStatus::Running; let expected_query_id = QueryId; - let TestServer { transport, .. } = TestServer::builder() - .with_request_handler(Box::new( - move |addr: Addr, _data: BodyStream| { + let test_server = TestServer::builder() + .with_request_handler(make_owned_handler( + move |addr: Addr, _data: BodyStream| async move { let RouteId::QueryStatus = addr.route else { panic!("unexpected call"); }; @@ -64,7 +65,9 @@ mod tests { .build() .await; let req = http_serde::query::status::Request::new(QueryId); - let response = handler(Extension(transport), req.clone()).await.unwrap(); + let response = handler(Extension(test_server.transport), req.clone()) + .await + .unwrap(); let Json(http_serde::query::status::ResponseBody { status }) = response; assert_eq!(status, expected_status); diff --git a/ipa-core/src/net/test.rs b/ipa-core/src/net/test.rs index ed4312c15..3cd0221b9 100644 --- a/ipa-core/src/net/test.rs +++ b/ipa-core/src/net/test.rs @@ -22,7 +22,7 @@ use crate::{ ClientConfig, HpkeClientConfig, HpkeServerConfig, NetworkConfig, PeerConfig, ServerConfig, TlsConfig, }, - helpers::{HelperIdentity, PanickingHandler, RequestHandler}, + helpers::{HandlerBox, HelperIdentity, RequestHandler}, hpke::{Deserializable as _, IpaPublicKey}, net::{ClientIdentity, HttpTransport, MpcHelperClient, MpcHelperServer}, sync::Arc, @@ -209,6 +209,7 @@ pub struct TestServer { pub transport: Arc, pub server: MpcHelperServer, pub client: MpcHelperClient, + pub request_handler: Option>>, } impl TestServer { @@ -229,7 +230,7 @@ impl TestServer { #[derive(Default)] pub struct TestServerBuilder { - handler: Option>>, + handler: Option>>, metrics: Option, disable_https: bool, use_http1: bool, @@ -240,7 +241,7 @@ impl TestServerBuilder { #[must_use] pub fn with_request_handler( mut self, - handler: Box>, + handler: Arc>, ) -> Self { self.handler = Some(handler); self @@ -295,13 +296,13 @@ impl TestServerBuilder { panic!("TestConfig should have allocated ports"); }; let clients = MpcHelperClient::from_conf(&network_config, identity.clone()); + let handler = self.handler.as_ref().map(HandlerBox::owning_ref); let (transport, server) = HttpTransport::new( HelperIdentity::ONE, server_config, network_config.clone(), clients, - self.handler - .unwrap_or_else(|| Box::::default()), + handler, ); let (addr, handle) = server.start_on(Some(server_socket), self.metrics).await; // Get the config for HelperIdentity::ONE @@ -316,6 +317,7 @@ impl TestServerBuilder { transport, server, client, + request_handler: self.handler, } } } diff --git a/ipa-core/src/net/transport.rs b/ipa-core/src/net/transport.rs index 36ae1b231..83d8e76e7 100644 --- a/ipa-core/src/net/transport.rs +++ b/ipa-core/src/net/transport.rs @@ -16,7 +16,7 @@ use crate::{ helpers::{ query::QueryConfig, routing::{Addr, RouteId}, - ApiError, BodyStream, HelperIdentity, HelperResponse, LogErrors, NoQueryId, + ApiError, BodyStream, HandlerRef, HelperIdentity, HelperResponse, LogErrors, NoQueryId, NoResourceIdentifier, NoStep, QueryIdBinding, ReceiveRecords, RequestHandler, RouteParams, StepBinding, StreamCollection, Transport, }, @@ -34,7 +34,7 @@ pub struct HttpTransport { // TODO(615): supporting multiple queries likely require a hashmap here. It will be ok if we // only allow one query at a time. record_streams: StreamCollection, - handler: Box>, + handler: Option, } impl RouteParams for QueryConfig { @@ -64,7 +64,7 @@ impl HttpTransport { server_config: ServerConfig, network_config: NetworkConfig, clients: [MpcHelperClient; 3], - handler: Box>, + handler: Option, ) -> (Arc, MpcHelperServer) { let transport = Self::new_internal(identity, clients, handler); let server = MpcHelperServer::new(Arc::clone(&transport), server_config, network_config); @@ -74,7 +74,7 @@ impl HttpTransport { fn new_internal( identity: HelperIdentity, clients: [MpcHelperClient; 3], - handler: Box>, + handler: Option, ) -> Arc { Arc::new(Self { identity, @@ -88,6 +88,9 @@ impl HttpTransport { /// /// ## Errors /// Returns an error, if handler rejects the request for any reason. + /// + /// ## Panics + /// This will panic if request handler hasn't been previously set for this transport. pub async fn dispatch>( self: Arc, req: R, @@ -124,7 +127,11 @@ impl HttpTransport { } let route_id = req.resource_identifier(); - let r = self.handler.handle(Addr::from_route(None, req), body); + let r = self + .handler + .as_ref() + .expect("Handler is set") + .handle(Addr::from_route(None, req), body); if let RouteId::CompleteQuery = route_id { ClearOnDrop { @@ -221,7 +228,8 @@ impl Transport for Arc { } } -#[cfg(all(test, web_test))] +// #[cfg(all(test, web_test))] //FIXME +#[cfg(all(test, feature = "real-world-infra"))] mod tests { use std::{iter::zip, net::TcpListener, task::Poll}; @@ -307,14 +315,14 @@ mod tests { } else { get_test_identity(id) }; - let (setup, handler_setup) = AppSetup::new(); + let (setup, handler) = AppSetup::new(); let clients = MpcHelperClient::from_conf(network_config, identity); let (transport, server) = HttpTransport::new( id, server_config, network_config.clone(), clients, - handler_setup.make_handler(), + Some(handler), ); server.start_on(Some(socket), ()).await; diff --git a/ipa-core/src/query/processor.rs b/ipa-core/src/query/processor.rs index 81508ddcb..a003e95ac 100644 --- a/ipa-core/src/query/processor.rs +++ b/ipa-core/src/query/processor.rs @@ -321,10 +321,10 @@ mod tests { use crate::{ ff::FieldType, helpers::{ - make_boxed_handler, + make_owned_handler, query::{PrepareQuery, QueryConfig, QueryType::TestMultiply}, - ApiError, HelperIdentity, HelperResponse, InMemoryMpcNetwork, RequestHandler, - RoleAssignment, Transport, + ApiError, HandlerBox, HelperIdentity, HelperResponse, InMemoryMpcNetwork, + RequestHandler, RoleAssignment, Transport, }, protocol::QueryId, query::{ @@ -332,18 +332,18 @@ mod tests { }, }; - fn prepare_query_handler(cb: F) -> Box> + fn prepare_query_handler(cb: F) -> Arc> where F: Fn(PrepareQuery) -> Fut + Send + Sync + 'static, Fut: Future> + Send + Sync + 'static, { - make_boxed_handler(move |req, _| { + make_owned_handler(move |req, _| { let prepare_query = req.into().unwrap(); cb(prepare_query) }) } - fn respond_ok() -> Box> { + fn respond_ok() -> Arc> { prepare_query_handler(move |_| async move { Ok(HelperResponse::ok()) }) } @@ -370,7 +370,11 @@ mod tests { Ok(HelperResponse::ok()) } }); - let network = InMemoryMpcNetwork::new([None, Some(h2), Some(h3)]); + let network = InMemoryMpcNetwork::new([ + None, + Some(HandlerBox::owning_ref(&h2)), + Some(HandlerBox::owning_ref(&h3)), + ]); let [t0, _, _] = network.transports(); let p0 = Processor::default(); let request = test_multiply_config(); @@ -404,12 +408,10 @@ mod tests { #[tokio::test] async fn rejects_duplicate_query_id() { - let handlers = array::from_fn(|_| { - Some(prepare_query_handler(|_| async { - Ok(HelperResponse::ok()) - })) - }); - let network = InMemoryMpcNetwork::new(handlers); + let handlers = + array::from_fn(|_| prepare_query_handler(|_| async { Ok(HelperResponse::ok()) })); + let network = + InMemoryMpcNetwork::new(handlers.each_ref().map(HandlerBox::owning_ref).map(Some)); let [t0, _, _] = network.transports(); let p0 = Processor::default(); let request = test_multiply_config(); @@ -426,11 +428,15 @@ mod tests { #[tokio::test] async fn prepare_error() { - let h2 = respond_ok().into(); - let h3 = Some(prepare_query_handler(|_| async move { + let h2 = respond_ok(); + let h3 = prepare_query_handler(|_| async move { Err(ApiError::QueryPrepare(PrepareQueryError::WrongTarget)) - })); - let network = InMemoryMpcNetwork::new([None, h2, h3]); + }); + let network = InMemoryMpcNetwork::new([ + None, + Some(HandlerBox::owning_ref(&h2)), + Some(HandlerBox::owning_ref(&h3)), + ]); let [t0, _, _] = network.transports(); let p0 = Processor::default(); let request = test_multiply_config(); @@ -443,11 +449,15 @@ mod tests { #[tokio::test] async fn can_recover_from_prepare_error() { - let h2 = respond_ok().into(); - let h3 = Some(prepare_query_handler(|_| async move { + let h2 = respond_ok(); + let h3 = prepare_query_handler(|_| async move { Err(ApiError::QueryPrepare(PrepareQueryError::WrongTarget)) - })); - let network = InMemoryMpcNetwork::new([None, h2, h3]); + }); + let network = InMemoryMpcNetwork::new([ + None, + Some(HandlerBox::owning_ref(&h2)), + Some(HandlerBox::owning_ref(&h3)), + ]); let [t0, _, _] = network.transports(); let p0 = Processor::default(); let request = test_multiply_config(); diff --git a/ipa-core/src/test_fixture/app.rs b/ipa-core/src/test_fixture/app.rs index 6ed453523..96d09fe59 100644 --- a/ipa-core/src/test_fixture/app.rs +++ b/ipa-core/src/test_fixture/app.rs @@ -4,7 +4,6 @@ use generic_array::GenericArray; use typenum::Unsigned; use crate::{ - app::RequestHandlerSetup, ff::Serializable, helpers::{ query::{QueryConfig, QueryInput}, @@ -61,13 +60,8 @@ fn unzip_tuple_array(input: [(T, U); 3]) -> ([T; 3], [U; 3]) { impl Default for TestApp { fn default() -> Self { let (setup, handlers) = unzip_tuple_array(array::from_fn(|_| AppSetup::new())); - let handlers_ref = [&handlers[0], &handlers[1], &handlers[2]]; - let network = InMemoryMpcNetwork::new( - handlers_ref - .map(RequestHandlerSetup::make_handler) - .map(Some), - ); + let network = InMemoryMpcNetwork::new(handlers.map(Some)); let drivers = network .transports() .iter() @@ -88,6 +82,7 @@ impl TestApp { /// ## Errors /// Returns an error if it can't start a query or send query input. #[allow(clippy::missing_panics_doc)] + pub async fn start_query( &self, input: I, From 49c244a3fd07cc142d16c726260616ae4689413d Mon Sep 17 00:00:00 2001 From: Alex Koshelev Date: Tue, 26 Mar 2024 22:51:28 -0700 Subject: [PATCH 3/5] Fix one FIXME --- ipa-core/src/net/transport.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/ipa-core/src/net/transport.rs b/ipa-core/src/net/transport.rs index 83d8e76e7..8e9dbca8e 100644 --- a/ipa-core/src/net/transport.rs +++ b/ipa-core/src/net/transport.rs @@ -228,8 +228,7 @@ impl Transport for Arc { } } -// #[cfg(all(test, web_test))] //FIXME -#[cfg(all(test, feature = "real-world-infra"))] +#[cfg(all(test, web_test))] mod tests { use std::{iter::zip, net::TcpListener, task::Poll}; From facf7062cc76133ddf9039572884fc76c17884bd Mon Sep 17 00:00:00 2001 From: Alex Koshelev Date: Tue, 26 Mar 2024 23:11:27 -0700 Subject: [PATCH 4/5] Clean up code --- ipa-core/src/app.rs | 4 ++++ ipa-core/src/helpers/transport/handler.rs | 12 +----------- .../src/helpers/transport/in_memory/transport.rs | 2 +- 3 files changed, 6 insertions(+), 12 deletions(-) diff --git a/ipa-core/src/app.rs b/ipa-core/src/app.rs index 3d7a322b7..9a90a8a49 100644 --- a/ipa-core/src/app.rs +++ b/ipa-core/src/app.rs @@ -28,6 +28,10 @@ pub struct HelperApp { struct Inner { query_processor: QueryProcessor, + /// For HTTP implementation this transport is also behind an [`Arc`] which causes double indirection + /// on top of atomics and all fun stuff associated with it. I don't see an easy way to avoid that + /// if we want to keep the implementation leak-free, but one may be aware if this shows up on + /// the flamegraph transport: TransportImpl, } diff --git a/ipa-core/src/helpers/transport/handler.rs b/ipa-core/src/helpers/transport/handler.rs index 613c6bdd8..f0e0f0bb7 100644 --- a/ipa-core/src/helpers/transport/handler.rs +++ b/ipa-core/src/helpers/transport/handler.rs @@ -1,8 +1,4 @@ -use std::{ - fmt::{Debug, Formatter}, - future::Future, - marker::PhantomData, -}; +use std::{fmt::Debug, future::Future, marker::PhantomData}; use async_trait::async_trait; use serde::de::DeserializeOwned; @@ -98,12 +94,6 @@ pub struct HandlerRef { inner: Arc>, } -impl Debug for HelperResponse { - fn fmt(&self, _f: &mut Formatter<'_>) -> std::fmt::Result { - todo!() - } -} - impl HelperResponse { /// Returns an empty response that indicates that incoming request has been processed successfully #[must_use] diff --git a/ipa-core/src/helpers/transport/in_memory/transport.rs b/ipa-core/src/helpers/transport/in_memory/transport.rs index 595154802..cb456a599 100644 --- a/ipa-core/src/helpers/transport/in_memory/transport.rs +++ b/ipa-core/src/helpers/transport/in_memory/transport.rs @@ -120,7 +120,7 @@ impl InMemoryTransport { } }; - ack.send(result).unwrap(); + ack.send(result).map_err(|_| "Channel closed").unwrap(); } } } From 02ee736b8f98484eddec15dc3e21c198f5d448aa Mon Sep 17 00:00:00 2001 From: Alex Koshelev Date: Thu, 28 Mar 2024 09:38:17 -0700 Subject: [PATCH 5/5] Feedback --- ipa-core/src/app.rs | 2 +- ipa-core/src/helpers/transport/handler.rs | 2 +- ipa-core/src/net/client/mod.rs | 2 +- ipa-core/src/net/server/handlers/query/create.rs | 4 ++-- ipa-core/src/net/server/handlers/query/results.rs | 2 +- ipa-core/src/query/executor.rs | 6 +++--- 6 files changed, 9 insertions(+), 9 deletions(-) diff --git a/ipa-core/src/app.rs b/ipa-core/src/app.rs index 9a90a8a49..0ca99287d 100644 --- a/ipa-core/src/app.rs +++ b/ipa-core/src/app.rs @@ -115,7 +115,7 @@ impl HelperApp { .query_processor .complete(query_id) .await? - .as_bytes()) + .to_bytes()) } } diff --git a/ipa-core/src/helpers/transport/handler.rs b/ipa-core/src/helpers/transport/handler.rs index f0e0f0bb7..42981d097 100644 --- a/ipa-core/src/helpers/transport/handler.rs +++ b/ipa-core/src/helpers/transport/handler.rs @@ -137,7 +137,7 @@ impl From for HelperResponse { impl> From for HelperResponse { fn from(value: R) -> Self { - let v = value.as_ref().as_bytes(); + let v = value.as_ref().to_bytes(); Self { body: v } } } diff --git a/ipa-core/src/net/client/mod.rs b/ipa-core/src/net/client/mod.rs index 2d49d4e28..795457718 100644 --- a/ipa-core/src/net/client/mod.rs +++ b/ipa-core/src/net/client/mod.rs @@ -658,7 +658,7 @@ pub(crate) mod tests { results.to_vec(), [Replicated::from((expected_results[0], expected_results[1]))] .to_vec() - .as_bytes() + .to_bytes() ); } } diff --git a/ipa-core/src/net/server/handlers/query/create.rs b/ipa-core/src/net/server/handlers/query/create.rs index 2d46f043b..3fc7bc641 100644 --- a/ipa-core/src/net/server/handlers/query/create.rs +++ b/ipa-core/src/net/server/handlers/query/create.rs @@ -2,7 +2,7 @@ use axum::{routing::post, Extension, Json, Router}; use hyper::StatusCode; use crate::{ - helpers::{ApiError::NewQuery, BodyStream, Transport}, + helpers::{ApiError, BodyStream, Transport}, net::{http_serde, Error, HttpTransport}, query::NewQueryError, sync::Arc, @@ -20,7 +20,7 @@ async fn handler( .await { Ok(resp) => Ok(Json(resp.try_into()?)), - Err(err @ NewQuery(NewQueryError::State { .. })) => { + Err(err @ ApiError::NewQuery(NewQueryError::State { .. })) => { Err(Error::application(StatusCode::CONFLICT, err)) } Err(err) => Err(Error::application(StatusCode::INTERNAL_SERVER_ERROR, err)), diff --git a/ipa-core/src/net/server/handlers/query/results.rs b/ipa-core/src/net/server/handlers/query/results.rs index 020c157c1..8e6cad2f6 100644 --- a/ipa-core/src/net/server/handlers/query/results.rs +++ b/ipa-core/src/net/server/handlers/query/results.rs @@ -80,7 +80,7 @@ mod tests { let results = handler(Extension(test_server.transport), req.clone()) .await .unwrap(); - assert_eq!(results, expected_results.as_bytes()); + assert_eq!(results, expected_results.to_bytes()); } struct OverrideReq { diff --git a/ipa-core/src/query/executor.rs b/ipa-core/src/query/executor.rs index 409c6c8f5..c5977e7df 100644 --- a/ipa-core/src/query/executor.rs +++ b/ipa-core/src/query/executor.rs @@ -36,7 +36,7 @@ use crate::{ }; pub trait Result: Send + Debug { - fn as_bytes(&self) -> Vec; + fn to_bytes(&self) -> Vec; } impl Result for Vec @@ -44,7 +44,7 @@ where T: Serializable, Vec: Debug + Send, { - fn as_bytes(&self) -> Vec { + fn to_bytes(&self) -> Vec { let mut r = vec![0u8; self.len() * T::Size::USIZE]; for (i, row) in self.iter().enumerate() { row.serialize(GenericArray::from_mut_slice( @@ -156,7 +156,7 @@ mod tests { fn serialize_result() { let [input, ..] = (0u128..=3).map(Fp31::truncate_from).share(); let expected = input.clone(); - let bytes = &input.as_bytes(); + let bytes = &input.to_bytes(); assert_eq!( expected, AdditiveShare::::from_byte_slice(bytes)