Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: combinator receiving #1280

Merged
merged 2 commits into from
Oct 27, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
273 changes: 136 additions & 137 deletions crates/core/src/client_events/combinator.rs
Original file line number Diff line number Diff line change
@@ -1,36 +1,33 @@
use std::future::Future;
use std::pin::Pin;
use std::task::Context;
use std::{collections::HashMap, task::Poll};
use std::collections::HashMap;

use freenet_stdlib::client_api::{ErrorKind, HostResponse};
use futures::future::BoxFuture;
use futures::task::AtomicWaker;
use futures::FutureExt;
use futures::stream::FuturesUnordered;
use futures::{FutureExt, StreamExt};
use tokio::sync::mpsc::{channel, Receiver, Sender};

use super::{BoxedClient, ClientError, ClientId, HostResult, OpenRequest};

type HostIncomingMsg = Result<OpenRequest<'static>, ClientError>;

type ClientEventsFut =
BoxFuture<'static, (usize, Receiver<HostIncomingMsg>, Option<HostIncomingMsg>)>;

/// This type allows combining different sources of events into one and interoperation between them.
pub struct ClientEventsCombinator<const N: usize> {
pending_futs: FuturesUnordered<ClientEventsFut>,
/// receiving end of the different client applications from the node
clients: [Sender<(ClientId, HostResult)>; N],
/// receiving end of the host node from the different client applications
hosts_rx: [Receiver<HostIncomingMsg>; N],
/// a map of the individual protocols, external, sending client events ids to an internal list of ids
external_clients: [HashMap<ClientId, ClientId>; N],
/// a map of the external id to which protocol it belongs (represented by the index in the array)
/// and the original id (reverse of indexes)
internal_clients: HashMap<ClientId, (usize, ClientId)>,
#[allow(clippy::type_complexity)]
pend_futs:
[Option<Pin<Box<dyn Future<Output = Option<HostIncomingMsg>> + Sync + Send + 'static>>>; N],
}

impl<const N: usize> ClientEventsCombinator<N> {
pub fn new(clients: [BoxedClient; N]) -> Self {
let pending_futs = FuturesUnordered::new();
let channels = clients.map(|client| {
let (tx, rx) = channel(1);
let (tx_host, rx_host) = channel(1);
Expand All @@ -43,49 +40,37 @@ impl<const N: usize> ClientEventsCombinator<N> {
clients[i] = Some(tx);
hosts_rx[i] = Some(rx_host);
}
let hosts_rx = hosts_rx.map(|h| h.unwrap());
let external_clients = [(); N].map(|_| HashMap::new());

for (i, rx) in hosts_rx.iter_mut().enumerate() {
let Some(mut rx) = rx.take() else {
continue;
};
pending_futs.push(
async move {
let res = rx.recv().await;
(i, rx, res)
}
.boxed(),
);
}

Self {
clients: clients.map(|c| c.unwrap()),
hosts_rx,
external_clients,
internal_clients: HashMap::new(),
pend_futs: [(); N].map(|_| None),
pending_futs,
}
}
}

impl<const N: usize> super::ClientEventsProxy for ClientEventsCombinator<N> {
fn recv<'a>(&'_ mut self) -> BoxFuture<'_, Result<OpenRequest<'static>, ClientError>> {
Box::pin(async {
let mut futs_opt = [(); N].map(|_| None);
let pend_futs = &mut self.pend_futs;
for (i, pend) in pend_futs.iter_mut().enumerate() {
let fut = &mut futs_opt[i];
if let Some(pend_fut) = pend.take() {
*fut = Some(pend_fut);
} else {
// this receiver ain't awaiting, queue a new one
// SAFETY: is safe here to extend the lifetime since clients are required to be 'static
// and we take ownership, so they will be alive for the duration of the program
let f = Box::pin(self.hosts_rx[i].recv())
as Pin<Box<dyn Future<Output = _> + Send + Sync + '_>>;
fn recv(&mut self) -> BoxFuture<'_, Result<OpenRequest<'static>, ClientError>> {
async {
let Some((idx, mut rx, res)) = self.pending_futs.next().await else {
unreachable!();
};

type ExtendedLife<'a, 'b> = Pin<
Box<
dyn Future<Output = Option<Result<OpenRequest<'a>, ClientError>>>
+ Send
+ Sync
+ 'b,
>,
>;
let new_pend = unsafe {
std::mem::transmute::<ExtendedLife<'_, '_>, ExtendedLife<'_, '_>>(f)
};
*fut = Some(new_pend);
}
}
let (res, idx, mut others) = select_all(futs_opt.map(|f| f.unwrap())).await;
let res = res
.map(|res| {
match res {
Expand All @@ -95,19 +80,15 @@ impl<const N: usize> super::ClientEventsProxy for ClientEventsCombinator<N> {
notification_channel,
token,
}) => {
tracing::debug!(
"received request; internal_id={external}; req={request}"
);
let id =
*self.external_clients[idx]
.entry(external)
.or_insert_with(|| {
// add a new mapped external client id
let internal = ClientId::next();
self.internal_clients.insert(internal, (idx, external));
internal
});

let id = *self.external_clients[idx]
.entry(external)
.or_insert_with(|| {
// add a new mapped external client id
let internal = ClientId::next();
self.internal_clients.insert(internal, (idx, external));
internal
});
tracing::debug!("received request for proxy #{idx}; internal_id={id}; external_id={external}; req={request}");
Ok(OpenRequest {
client_id: id,
request,
Expand All @@ -119,23 +100,26 @@ impl<const N: usize> super::ClientEventsProxy for ClientEventsCombinator<N> {
}
})
.unwrap_or_else(|| Err(ErrorKind::TransportProtocolDisconnect.into()));
// place back futs
debug_assert!(pend_futs.iter().all(|f| f.is_none()));
debug_assert_eq!(
others.iter().filter(|a| a.is_some()).count(),
pend_futs.len() - 1

self.pending_futs.push(
async move {
let res = rx.recv().await;
(idx, rx, res)
}
.boxed(),
);
std::mem::swap(pend_futs, &mut others);

res
})
}
.boxed()
}

fn send<'a>(
&mut self,
internal: ClientId,
response: Result<HostResponse, ClientError>,
) -> BoxFuture<'_, Result<(), ClientError>> {
Box::pin(async move {
async move {
let (idx, external) = self
.internal_clients
.get(&internal)
Expand All @@ -145,7 +129,8 @@ impl<const N: usize> super::ClientEventsProxy for ClientEventsCombinator<N> {
.await
.map_err(|_| ErrorKind::TransportProtocolDisconnect)?;
Ok(())
})
}
.boxed()
}
}

Expand All @@ -159,6 +144,7 @@ async fn client_fn(
host_msg = rx.recv() => {
if let Some((client_id, response)) = host_msg {
if client.send(client_id, response).await.is_err() {
eprintln!("disconnected host");
break;
}
} else {
Expand Down Expand Up @@ -189,74 +175,23 @@ async fn client_fn(
tracing::error!("Client shut down");
}

/// A version of `futures::select_all` optimized for this use case which,
/// unlike the generic one, keeps the ordering of the futures: the output of
/// `poll` hands back the still-pending futures in their original slots.
#[must_use = "futures do nothing unless you `.await` or poll them"]
struct SelectAll<Fut, const N: usize> {
    // Waker registered between the two poll passes in `Future::poll` so a
    // readiness change racing with registration is not missed.
    waker: AtomicWaker,
    // Fixed-size set of pending futures; a slot is emptied (set to `None`)
    // when its future completes.
    inner: [Option<Fut>; N],
}

// `SelectAll` only polls its futures through `poll_unpin`, so it is `Unpin`
// whenever the contained futures themselves are; no structural pinning needed.
impl<Fut: Unpin, const N: usize> Unpin for SelectAll<Fut, N> {}

impl<Fut: Future + Unpin, const N: usize> Future for SelectAll<Fut, N> {
    /// The ready future's output, the index of the slot it occupied, and the
    /// remaining futures (ordering preserved, completed slot set to `None`).
    type Output = (Fut::Output, usize, [Option<Fut>; N]);

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        macro_rules! recv {
            () => {
                // Poll every still-present future, stopping at the first one
                // that is ready. `and_then` makes a pending future yield `None`
                // so the scan continues past it. (The previous
                // `.map(..)` + outer `.flatten()` produced `Some(None)` for a
                // pending future, which terminated `find_map` at the first
                // *occupied* slot — futures in later slots were never polled
                // and therefore never woke this task.)
                let item = self.inner.iter_mut().enumerate().find_map(|(i, f)| {
                    f.as_mut().and_then(|f| match f.poll_unpin(cx) {
                        Poll::Pending => None,
                        Poll::Ready(e) => Some((i, e)),
                    })
                });
                if let Some((idx, res)) = item {
                    self.inner[idx] = None;
                    // Hand the not-yet-finished futures back to the caller.
                    let rest = std::mem::replace(&mut self.inner, [(); N].map(|_| None));
                    return Poll::Ready((res, idx, rest));
                }
            };
        }
        // Poll once, register the waker, then poll again: the second pass
        // closes the window where a future becomes ready after the first pass
        // but before the waker was registered.
        recv!();
        self.waker.register(cx.waker());
        recv!();
        Poll::Pending
    }
}

/// Builds a [`SelectAll`] over a fixed-size array of futures, preserving the
/// array ordering in the eventual output.
fn select_all<F, const N: usize>(iter: [F; N]) -> SelectAll<F, N>
where
    F: Future + Unpin,
{
    SelectAll {
        waker: AtomicWaker::new(),
        // `map(Some)` instead of the redundant closure `|f| Some(f)`
        // (clippy::redundant_closure).
        inner: iter.map(Some),
    }
}

#[cfg(test)]
mod test {
use freenet_stdlib::client_api::ClientRequest;
use futures::try_join;

use super::*;
use crate::client_events::ClientEventsProxy;

/// Test double standing in for a client-events proxy when exercising the
/// combinator in these tests.
struct SampleProxy {
    // Identity of this proxy; `send` asserts the target client id matches it
    // and forwards it through `tx` as the acknowledgement payload.
    id: usize,
    // Test-driven input channel — presumably drained by this proxy's `recv`
    // implementation (not visible in this view; confirm against the impl).
    rx: Receiver<usize>,
    // Test-observed output channel: `send` reports completion by sending `id`.
    tx: Sender<usize>,
}

impl SampleProxy {
fn new(id: usize, rx: Receiver<usize>) -> Self {
Self { id, rx }
fn new(id: usize, rx: Receiver<usize>, tx: Sender<usize>) -> Self {
Self { id, rx, tx }
}
}

Expand All @@ -279,41 +214,105 @@ mod test {

fn send(
&mut self,
_id: ClientId,
id: ClientId,
_response: Result<HostResponse, ClientError>,
) -> BoxFuture<'_, Result<(), ClientError>> {
todo!()
assert_eq!(id.0, self.id);
async {
self.tx
.send(self.id)
.await
.map_err(|_| ErrorKind::ChannelClosed.into())
}
.boxed()
}
}

#[ignore]
#[tokio::test]
async fn combinator_recv() {
fn setup_proxies() -> ([BoxedClient; 3], Vec<Sender<usize>>, Vec<Receiver<usize>>) {
let mut cnt = 0;
let mut senders = vec![];
let proxies = [None::<()>; 3].map(|_| {
let (tx, rx) = channel(1);
senders.push(tx);
let r = Box::new(SampleProxy::new(cnt, rx)) as _;
let mut receivers = vec![];
let clients = [None::<()>; 3].map(|_| {
let (tx1, rx1) = channel(1);
let (tx2, rx2) = channel(1);
let r = Box::new(SampleProxy::new(cnt, rx1, tx2)) as _;
senders.push(tx1);
receivers.push(rx2);
cnt += 1;
r
});
(clients, senders, receivers)
}

#[tokio::test]
async fn combinator_recv() {
let (proxies, mut senders, _) = setup_proxies();
let mut combinator = ClientEventsCombinator::new(proxies);

let _senders = tokio::task::spawn(async move {
for (id, tx) in senders.iter_mut().enumerate() {
tx.send(id).await.unwrap();
eprintln!("sent msg {id}");
let _senders: Vec<Sender<usize>> = tokio::task::spawn(async move {
for _ in 1..4 {
for (id, tx) in senders.iter_mut().enumerate() {
tx.send(id).await.unwrap();
tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
eprintln!("sent msg {id}");
}
}
senders
})
.await
.unwrap();

for i in 0..3 {
let OpenRequest { client_id: id, .. } = combinator.recv().await.unwrap();
eprintln!("received: {id:?}");
assert_eq!(ClientId::new(i), id);
for _ in 0..3 {
for i in 1..4 {
let OpenRequest { client_id: id, .. } = combinator.recv().await.unwrap();
eprintln!("received {i}: {id:?}");
assert_eq!(ClientId::new(i), id);
}
}
}

#[ignore]
#[tokio::test]
// Verifies the `send` path of the combinator: responses sent through the
// combinator under an internal client id must arrive at the proxy that owns
// the mapped external id. Runs the send side and the receive side
// concurrently via `try_join!` since the proxy channels have capacity 1.
async fn test_send() {
    let (proxies, _, mut receivers) = setup_proxies();
    let mut combinator = ClientEventsCombinator::new(proxies);

    for idx in 0..3 {
        let client_id = ClientId::new(idx);
        // Insert each client ID mapping into the combinator's internal clients.
        // (Bypasses `recv`, which is where these mappings are normally created.)
        combinator
            .internal_clients
            .insert(client_id, (idx, client_id));
    }

    let receivers = async move {
        // Test sending a response through the combinator for each proxy.
        for (idx, receiver) in receivers.iter_mut().enumerate() {
            // Assert that the receiver received the expected message.
            let received_id = receiver
                .recv()
                .await
                .ok_or(format!("missing {idx} sender"))?;
            assert_eq!(received_id, idx);
            println!(
                "Receiver {} confirmed send for client ID: {}",
                idx, received_id
            );
        }
        Ok::<_, Box<dyn std::error::Error>>(())
    };

    let senders = async {
        for idx in 0..3 {
            // Send a sample response through the combinator.
            combinator
                .send(ClientId::new(idx), Ok(HostResponse::Ok))
                .await
                .map_err(|err| format!("Send failed for client {idx}: {err}",))?;
        }
        Ok::<_, Box<dyn std::error::Error>>(())
    };

    try_join!(senders, receivers).unwrap();
}
}