diff --git a/crates/core/src/client_events/combinator.rs b/crates/core/src/client_events/combinator.rs index 145d4b858..f58a403dc 100644 --- a/crates/core/src/client_events/combinator.rs +++ b/crates/core/src/client_events/combinator.rs @@ -1,36 +1,33 @@ -use std::future::Future; -use std::pin::Pin; -use std::task::Context; -use std::{collections::HashMap, task::Poll}; +use std::collections::HashMap; use freenet_stdlib::client_api::{ErrorKind, HostResponse}; use futures::future::BoxFuture; -use futures::task::AtomicWaker; -use futures::FutureExt; +use futures::stream::FuturesUnordered; +use futures::{FutureExt, StreamExt}; use tokio::sync::mpsc::{channel, Receiver, Sender}; use super::{BoxedClient, ClientError, ClientId, HostResult, OpenRequest}; type HostIncomingMsg = Result, ClientError>; +type ClientEventsFut = + BoxFuture<'static, (usize, Receiver, Option)>; + /// This type allows combining different sources of events into one and interoperation between them. pub struct ClientEventsCombinator { + pending_futs: FuturesUnordered, /// receiving end of the different client applications from the node clients: [Sender<(ClientId, HostResult)>; N], - /// receiving end of the host node from the different client applications - hosts_rx: [Receiver; N], /// a map of the individual protocols, external, sending client events ids to an internal list of ids external_clients: [HashMap; N], /// a map of the external id to which protocol it belongs (represented by the index in the array) /// and the original id (reverse of indexes) internal_clients: HashMap, - #[allow(clippy::type_complexity)] - pend_futs: - [Option> + Sync + Send + 'static>>>; N], } impl ClientEventsCombinator { pub fn new(clients: [BoxedClient; N]) -> Self { + let pending_futs = FuturesUnordered::new(); let channels = clients.map(|client| { let (tx, rx) = channel(1); let (tx_host, rx_host) = channel(1); @@ -43,49 +40,37 @@ impl ClientEventsCombinator { clients[i] = Some(tx); hosts_rx[i] = Some(rx_host); } - let hosts_rx = hosts_rx.map(|h| h.unwrap()); let external_clients = [(); N].map(|_| HashMap::new()); + + for (i, rx) in hosts_rx.iter_mut().enumerate() { + let Some(mut rx) = rx.take() else { + continue; + }; + pending_futs.push( + async move { + let res = rx.recv().await; + (i, rx, res) + } + .boxed(), + ); + } + Self { clients: clients.map(|c| c.unwrap()), - hosts_rx, external_clients, internal_clients: HashMap::new(), - pend_futs: [(); N].map(|_| None), + pending_futs, } } } impl super::ClientEventsProxy for ClientEventsCombinator { - fn recv<'a>(&'_ mut self) -> BoxFuture<'_, Result, ClientError>> { - Box::pin(async { - let mut futs_opt = [(); N].map(|_| None); - let pend_futs = &mut self.pend_futs; - for (i, pend) in pend_futs.iter_mut().enumerate() { - let fut = &mut futs_opt[i]; - if let Some(pend_fut) = pend.take() { - *fut = Some(pend_fut); - } else { - // this receiver ain't awaiting, queue a new one - // SAFETY: is safe here to extend the lifetime since clients are required to be 'static - // and we take ownership, so they will be alive for the duration of the program - let f = Box::pin(self.hosts_rx[i].recv()) - as Pin + Send + Sync + '_>>; + fn recv(&mut self) -> BoxFuture<'_, Result, ClientError>> { + async { + let Some((idx, mut rx, res)) = self.pending_futs.next().await else { + unreachable!(); + }; - type ExtendedLife<'a, 'b> = Pin< - Box< - dyn Future, ClientError>>> - + Send - + Sync - + 'b, - >, - >; - let new_pend = unsafe { - std::mem::transmute::, ExtendedLife<'_, '_>>(f) - }; - *fut = Some(new_pend); - } - } - let (res, idx, mut others) = select_all(futs_opt.map(|f| f.unwrap())).await; let res = res .map(|res| { match res { @@ -95,19 +80,15 @@ impl super::ClientEventsProxy for ClientEventsCombinator { notification_channel, token, }) => { - tracing::debug!( - "received request; internal_id={external}; req={request}" - ); - let id = - *self.external_clients[idx] - .entry(external) - .or_insert_with(|| { - // add a new mapped external client id - let internal = ClientId::next(); - self.internal_clients.insert(internal, (idx, external)); - internal - }); - + let id = *self.external_clients[idx] + .entry(external) + .or_insert_with(|| { + // add a new mapped external client id + let internal = ClientId::next(); + self.internal_clients.insert(internal, (idx, external)); + internal + }); + tracing::debug!("received request for proxy #{idx}; internal_id={id}; external_id={external}; req={request}"); Ok(OpenRequest { client_id: id, request, @@ -119,15 +100,18 @@ impl super::ClientEventsProxy for ClientEventsCombinator { } }) .unwrap_or_else(|| Err(ErrorKind::TransportProtocolDisconnect.into())); - // place back futs - debug_assert!(pend_futs.iter().all(|f| f.is_none())); - debug_assert_eq!( - others.iter().filter(|a| a.is_some()).count(), - pend_futs.len() - 1 + + self.pending_futs.push( + async move { + let res = rx.recv().await; + (idx, rx, res) + } + .boxed(), ); - std::mem::swap(pend_futs, &mut others); + res - }) + } + .boxed() } fn send<'a>( @@ -135,7 +119,7 @@ impl super::ClientEventsProxy for ClientEventsCombinator { internal: ClientId, response: Result, ) -> BoxFuture<'_, Result<(), ClientError>> { - Box::pin(async move { + async move { let (idx, external) = self .internal_clients .get(&internal) @@ -145,7 +129,8 @@ impl super::ClientEventsProxy for ClientEventsCombinator { .await .map_err(|_| ErrorKind::TransportProtocolDisconnect)?; Ok(()) - }) + } + .boxed() } } @@ -159,6 +144,7 @@ async fn client_fn( host_msg = rx.recv() => { if let Some((client_id, response)) = host_msg { if client.send(client_id, response).await.is_err() { + eprintln!("disconnected host"); break; } } else { @@ -189,62 +175,10 @@ async fn client_fn( tracing::error!("Client shut down"); } -/// An optimized for the use case version of `futures::select_all` which keeps ordering. -#[must_use = "futures do nothing unless you `.await` or poll them"] -struct SelectAll { - waker: AtomicWaker, - inner: [Option; N], -} - -impl Unpin for SelectAll {} - -impl Future for SelectAll { - type Output = (Fut::Output, usize, [Option; N]); - - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - macro_rules! recv { - () => { - let item = self - .inner - .iter_mut() - .enumerate() - .find_map(|(i, f)| { - f.as_mut().map(|f| match f.poll_unpin(cx) { - Poll::Pending => None, - Poll::Ready(e) => Some((i, e)), - }) - }) - .flatten(); - match item { - Some((idx, res)) => { - self.inner[idx] = None; - let rest = std::mem::replace(&mut self.inner, [(); N].map(|_| None)); - return Poll::Ready((res, idx, rest)); - } - None => {} - } - }; - } - recv!(); - self.waker.register(cx.waker()); - recv!(); - Poll::Pending - } -} - -fn select_all(iter: [F; N]) -> SelectAll -where - F: Future + Unpin, -{ - SelectAll { - waker: AtomicWaker::new(), - inner: iter.map(|f| Some(f)), - } -} - #[cfg(test)] mod test { use freenet_stdlib::client_api::ClientRequest; + use futures::try_join; use super::*; use crate::client_events::ClientEventsProxy; @@ -252,11 +186,12 @@ mod test { struct SampleProxy { id: usize, rx: Receiver, + tx: Sender, } impl SampleProxy { - fn new(id: usize, rx: Receiver) -> Self { - Self { id, rx } + fn new(id: usize, rx: Receiver, tx: Sender) -> Self { + Self { id, rx, tx } } } @@ -279,41 +214,105 @@ mod test { fn send( &mut self, - _id: ClientId, + id: ClientId, _response: Result, ) -> BoxFuture<'_, Result<(), ClientError>> { - todo!() + assert_eq!(id.0, self.id); + async { + self.tx + .send(self.id) + .await + .map_err(|_| ErrorKind::ChannelClosed.into()) + } + .boxed() } } - #[ignore] - #[tokio::test] - async fn combinator_recv() { + fn setup_proxies() -> ([BoxedClient; 3], Vec>, Vec>) { let mut cnt = 0; let mut senders = vec![]; - let proxies = [None::<()>; 3].map(|_| { - let (tx, rx) = channel(1); - senders.push(tx); - let r = Box::new(SampleProxy::new(cnt, rx)) as _; + let mut receivers = vec![]; + let clients = [None::<()>; 3].map(|_| { + let (tx1, rx1) = channel(1); + let (tx2, rx2) = channel(1); + let r = Box::new(SampleProxy::new(cnt, rx1, tx2)) as _; + senders.push(tx1); + receivers.push(rx2); cnt += 1; r }); + (clients, senders, receivers) + } + + #[tokio::test] + async fn combinator_recv() { + let (proxies, mut senders, _) = setup_proxies(); let mut combinator = ClientEventsCombinator::new(proxies); - let _senders = tokio::task::spawn(async move { - for (id, tx) in senders.iter_mut().enumerate() { - tx.send(id).await.unwrap(); - eprintln!("sent msg {id}"); + let _senders: Vec> = tokio::task::spawn(async move { + for _ in 1..4 { + for (id, tx) in senders.iter_mut().enumerate() { + tx.send(id).await.unwrap(); + tokio::time::sleep(tokio::time::Duration::from_millis(10)).await; + eprintln!("sent msg {id}"); + } } senders }) .await .unwrap(); - for i in 0..3 { - let OpenRequest { client_id: id, .. } = combinator.recv().await.unwrap(); - eprintln!("received: {id:?}"); - assert_eq!(ClientId::new(i), id); + for _ in 0..3 { + for i in 1..4 { + let OpenRequest { client_id: id, .. } = combinator.recv().await.unwrap(); + eprintln!("received {i}: {id:?}"); + assert_eq!(ClientId::new(i), id); + } + } + } + + #[ignore] + #[tokio::test] + async fn test_send() { + let (proxies, _, mut receivers) = setup_proxies(); + let mut combinator = ClientEventsCombinator::new(proxies); + + for idx in 0..3 { + let client_id = ClientId::new(idx); + // Insert each client ID mapping into the combinator's internal clients. + combinator + .internal_clients + .insert(client_id, (idx, client_id)); } + + let receivers = async move { + // Test sending a response through the combinator for each proxy. + for (idx, receiver) in receivers.iter_mut().enumerate() { + // Assert that the receiver received the expected message. + let received_id = receiver + .recv() + .await + .ok_or(format!("missing {idx} sender"))?; + assert_eq!(received_id, idx); + println!( + "Receiver {} confirmed send for client ID: {}", + idx, received_id + ); + } + Ok::<_, Box>(()) + }; + + let senders = async { + for idx in 0..3 { + // Send a sample response through the combinator. + combinator + .send(ClientId::new(idx), Ok(HostResponse::Ok)) + .await + .map_err(|err| format!("Send failed for client {idx}: {err}",))?; + } + Ok::<_, Box>(()) + }; + + try_join!(senders, receivers).unwrap(); } }