Skip to content

Commit

Permalink
fix: combinator receiving (#1280)
Browse files Browse the repository at this point in the history
  • Loading branch information
iduartgomez authored Oct 27, 2024
1 parent e3ccd92 commit 2092a52
Showing 1 changed file with 136 additions and 137 deletions.
273 changes: 136 additions & 137 deletions crates/core/src/client_events/combinator.rs
Original file line number Diff line number Diff line change
@@ -1,36 +1,33 @@
use std::future::Future;
use std::pin::Pin;
use std::task::Context;
use std::{collections::HashMap, task::Poll};
use std::collections::HashMap;

use freenet_stdlib::client_api::{ErrorKind, HostResponse};
use futures::future::BoxFuture;
use futures::task::AtomicWaker;
use futures::FutureExt;
use futures::stream::FuturesUnordered;
use futures::{FutureExt, StreamExt};
use tokio::sync::mpsc::{channel, Receiver, Sender};

use super::{BoxedClient, ClientError, ClientId, HostResult, OpenRequest};

type HostIncomingMsg = Result<OpenRequest<'static>, ClientError>;

type ClientEventsFut =
BoxFuture<'static, (usize, Receiver<HostIncomingMsg>, Option<HostIncomingMsg>)>;

/// This type allows combining different sources of events into one and interoperation between them.
pub struct ClientEventsCombinator<const N: usize> {
pending_futs: FuturesUnordered<ClientEventsFut>,
/// receiving end of the different client applications from the node
clients: [Sender<(ClientId, HostResult)>; N],
/// receiving end of the host node from the different client applications
hosts_rx: [Receiver<HostIncomingMsg>; N],
/// a map of the individual protocols, external, sending client events ids to an internal list of ids
external_clients: [HashMap<ClientId, ClientId>; N],
/// a map of the external id to which protocol it belongs (represented by the index in the array)
/// and the original id (reverse of indexes)
internal_clients: HashMap<ClientId, (usize, ClientId)>,
#[allow(clippy::type_complexity)]
pend_futs:
[Option<Pin<Box<dyn Future<Output = Option<HostIncomingMsg>> + Sync + Send + 'static>>>; N],
}

impl<const N: usize> ClientEventsCombinator<N> {
pub fn new(clients: [BoxedClient; N]) -> Self {
let pending_futs = FuturesUnordered::new();
let channels = clients.map(|client| {
let (tx, rx) = channel(1);
let (tx_host, rx_host) = channel(1);
Expand All @@ -43,49 +40,37 @@ impl<const N: usize> ClientEventsCombinator<N> {
clients[i] = Some(tx);
hosts_rx[i] = Some(rx_host);
}
let hosts_rx = hosts_rx.map(|h| h.unwrap());
let external_clients = [(); N].map(|_| HashMap::new());

for (i, rx) in hosts_rx.iter_mut().enumerate() {
let Some(mut rx) = rx.take() else {
continue;
};
pending_futs.push(
async move {
let res = rx.recv().await;
(i, rx, res)
}
.boxed(),
);
}

Self {
clients: clients.map(|c| c.unwrap()),
hosts_rx,
external_clients,
internal_clients: HashMap::new(),
pend_futs: [(); N].map(|_| None),
pending_futs,
}
}
}

impl<const N: usize> super::ClientEventsProxy for ClientEventsCombinator<N> {
fn recv<'a>(&'_ mut self) -> BoxFuture<'_, Result<OpenRequest<'static>, ClientError>> {
Box::pin(async {
let mut futs_opt = [(); N].map(|_| None);
let pend_futs = &mut self.pend_futs;
for (i, pend) in pend_futs.iter_mut().enumerate() {
let fut = &mut futs_opt[i];
if let Some(pend_fut) = pend.take() {
*fut = Some(pend_fut);
} else {
// this receiver ain't awaiting, queue a new one
// SAFETY: is safe here to extend the lifetime since clients are required to be 'static
// and we take ownership, so they will be alive for the duration of the program
let f = Box::pin(self.hosts_rx[i].recv())
as Pin<Box<dyn Future<Output = _> + Send + Sync + '_>>;
fn recv(&mut self) -> BoxFuture<'_, Result<OpenRequest<'static>, ClientError>> {
async {
let Some((idx, mut rx, res)) = self.pending_futs.next().await else {
unreachable!();
};

type ExtendedLife<'a, 'b> = Pin<
Box<
dyn Future<Output = Option<Result<OpenRequest<'a>, ClientError>>>
+ Send
+ Sync
+ 'b,
>,
>;
let new_pend = unsafe {
std::mem::transmute::<ExtendedLife<'_, '_>, ExtendedLife<'_, '_>>(f)
};
*fut = Some(new_pend);
}
}
let (res, idx, mut others) = select_all(futs_opt.map(|f| f.unwrap())).await;
let res = res
.map(|res| {
match res {
Expand All @@ -95,19 +80,15 @@ impl<const N: usize> super::ClientEventsProxy for ClientEventsCombinator<N> {
notification_channel,
token,
}) => {
tracing::debug!(
"received request; internal_id={external}; req={request}"
);
let id =
*self.external_clients[idx]
.entry(external)
.or_insert_with(|| {
// add a new mapped external client id
let internal = ClientId::next();
self.internal_clients.insert(internal, (idx, external));
internal
});

let id = *self.external_clients[idx]
.entry(external)
.or_insert_with(|| {
// add a new mapped external client id
let internal = ClientId::next();
self.internal_clients.insert(internal, (idx, external));
internal
});
tracing::debug!("received request for proxy #{idx}; internal_id={id}; external_id={external}; req={request}");
Ok(OpenRequest {
client_id: id,
request,
Expand All @@ -119,23 +100,26 @@ impl<const N: usize> super::ClientEventsProxy for ClientEventsCombinator<N> {
}
})
.unwrap_or_else(|| Err(ErrorKind::TransportProtocolDisconnect.into()));
// place back futs
debug_assert!(pend_futs.iter().all(|f| f.is_none()));
debug_assert_eq!(
others.iter().filter(|a| a.is_some()).count(),
pend_futs.len() - 1

self.pending_futs.push(
async move {
let res = rx.recv().await;
(idx, rx, res)
}
.boxed(),
);
std::mem::swap(pend_futs, &mut others);

res
})
}
.boxed()
}

fn send<'a>(
&mut self,
internal: ClientId,
response: Result<HostResponse, ClientError>,
) -> BoxFuture<'_, Result<(), ClientError>> {
Box::pin(async move {
async move {
let (idx, external) = self
.internal_clients
.get(&internal)
Expand All @@ -145,7 +129,8 @@ impl<const N: usize> super::ClientEventsProxy for ClientEventsCombinator<N> {
.await
.map_err(|_| ErrorKind::TransportProtocolDisconnect)?;
Ok(())
})
}
.boxed()
}
}

Expand All @@ -159,6 +144,7 @@ async fn client_fn(
host_msg = rx.recv() => {
if let Some((client_id, response)) = host_msg {
if client.send(client_id, response).await.is_err() {
eprintln!("disconnected host");
break;
}
} else {
Expand Down Expand Up @@ -189,74 +175,23 @@ async fn client_fn(
tracing::error!("Client shut down");
}

/// An optimized for the use case version of `futures::select_all` which keeps ordering.
#[must_use = "futures do nothing unless you `.await` or poll them"]
struct SelectAll<Fut, const N: usize> {
waker: AtomicWaker,
inner: [Option<Fut>; N],
}

impl<Fut: Unpin, const N: usize> Unpin for SelectAll<Fut, N> {}

impl<Fut: Future + Unpin, const N: usize> Future for SelectAll<Fut, N> {
type Output = (Fut::Output, usize, [Option<Fut>; N]);

fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
macro_rules! recv {
() => {
let item = self
.inner
.iter_mut()
.enumerate()
.find_map(|(i, f)| {
f.as_mut().map(|f| match f.poll_unpin(cx) {
Poll::Pending => None,
Poll::Ready(e) => Some((i, e)),
})
})
.flatten();
match item {
Some((idx, res)) => {
self.inner[idx] = None;
let rest = std::mem::replace(&mut self.inner, [(); N].map(|_| None));
return Poll::Ready((res, idx, rest));
}
None => {}
}
};
}
recv!();
self.waker.register(cx.waker());
recv!();
Poll::Pending
}
}

fn select_all<F, const N: usize>(iter: [F; N]) -> SelectAll<F, N>
where
F: Future + Unpin,
{
SelectAll {
waker: AtomicWaker::new(),
inner: iter.map(|f| Some(f)),
}
}

#[cfg(test)]
mod test {
use freenet_stdlib::client_api::ClientRequest;
use futures::try_join;

use super::*;
use crate::client_events::ClientEventsProxy;

struct SampleProxy {
id: usize,
rx: Receiver<usize>,
tx: Sender<usize>,
}

impl SampleProxy {
fn new(id: usize, rx: Receiver<usize>) -> Self {
Self { id, rx }
fn new(id: usize, rx: Receiver<usize>, tx: Sender<usize>) -> Self {
Self { id, rx, tx }
}
}

Expand All @@ -279,41 +214,105 @@ mod test {

fn send(
&mut self,
_id: ClientId,
id: ClientId,
_response: Result<HostResponse, ClientError>,
) -> BoxFuture<'_, Result<(), ClientError>> {
todo!()
assert_eq!(id.0, self.id);
async {
self.tx
.send(self.id)
.await
.map_err(|_| ErrorKind::ChannelClosed.into())
}
.boxed()
}
}

#[ignore]
#[tokio::test]
async fn combinator_recv() {
fn setup_proxies() -> ([BoxedClient; 3], Vec<Sender<usize>>, Vec<Receiver<usize>>) {
let mut cnt = 0;
let mut senders = vec![];
let proxies = [None::<()>; 3].map(|_| {
let (tx, rx) = channel(1);
senders.push(tx);
let r = Box::new(SampleProxy::new(cnt, rx)) as _;
let mut receivers = vec![];
let clients = [None::<()>; 3].map(|_| {
let (tx1, rx1) = channel(1);
let (tx2, rx2) = channel(1);
let r = Box::new(SampleProxy::new(cnt, rx1, tx2)) as _;
senders.push(tx1);
receivers.push(rx2);
cnt += 1;
r
});
(clients, senders, receivers)
}

#[tokio::test]
async fn combinator_recv() {
let (proxies, mut senders, _) = setup_proxies();
let mut combinator = ClientEventsCombinator::new(proxies);

let _senders = tokio::task::spawn(async move {
for (id, tx) in senders.iter_mut().enumerate() {
tx.send(id).await.unwrap();
eprintln!("sent msg {id}");
let _senders: Vec<Sender<usize>> = tokio::task::spawn(async move {
for _ in 1..4 {
for (id, tx) in senders.iter_mut().enumerate() {
tx.send(id).await.unwrap();
tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
eprintln!("sent msg {id}");
}
}
senders
})
.await
.unwrap();

for i in 0..3 {
let OpenRequest { client_id: id, .. } = combinator.recv().await.unwrap();
eprintln!("received: {id:?}");
assert_eq!(ClientId::new(i), id);
for _ in 0..3 {
for i in 1..4 {
let OpenRequest { client_id: id, .. } = combinator.recv().await.unwrap();
eprintln!("received {i}: {id:?}");
assert_eq!(ClientId::new(i), id);
}
}
}

#[ignore]
#[tokio::test]
async fn test_send() {
let (proxies, _, mut receivers) = setup_proxies();
let mut combinator = ClientEventsCombinator::new(proxies);

for idx in 0..3 {
let client_id = ClientId::new(idx);
// Insert each client ID mapping into the combinator's internal clients.
combinator
.internal_clients
.insert(client_id, (idx, client_id));
}

let receivers = async move {
// Test sending a response through the combinator for each proxy.
for (idx, receiver) in receivers.iter_mut().enumerate() {
// Assert that the receiver received the expected message.
let received_id = receiver
.recv()
.await
.ok_or(format!("missing {idx} sender"))?;
assert_eq!(received_id, idx);
println!(
"Receiver {} confirmed send for client ID: {}",
idx, received_id
);
}
Ok::<_, Box<dyn std::error::Error>>(())
};

let senders = async {
for idx in 0..3 {
// Send a sample response through the combinator.
combinator
.send(ClientId::new(idx), Ok(HostResponse::Ok))
.await
.map_err(|err| format!("Send failed for client {idx}: {err}",))?;
}
Ok::<_, Box<dyn std::error::Error>>(())
};

try_join!(senders, receivers).unwrap();
}
}

0 comments on commit 2092a52

Please sign in to comment.