Skip to content

Commit

Permalink
chore(server): Add ServerIoStream type
Browse files Browse the repository at this point in the history
  • Loading branch information
tottoto committed Dec 3, 2024
1 parent 9606c5b commit d2f0d97
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 58 deletions.
2 changes: 0 additions & 2 deletions tonic/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ tls-webpki-roots = ["_tls-any","channel", "dep:webpki-roots"]
router = ["dep:axum", "dep:tower", "tower?/util"]
server = [
"router",
"dep:async-stream",
"dep:h2",
"dep:hyper", "hyper?/server",
"dep:hyper-util", "hyper-util?/service", "hyper-util?/server-auto",
Expand Down Expand Up @@ -79,7 +78,6 @@ prost = {version = "0.13", default-features = false, features = ["std"], optiona
async-trait = {version = "0.1.13", optional = true}

# transport
async-stream = {version = "0.3", optional = true}
h2 = {version = "0.4", optional = true}
hyper = {version = "1", features = ["http1", "http2"], optional = true}
hyper-util = { version = "0.1.4", features = ["tokio"], optional = true }
Expand Down
145 changes: 90 additions & 55 deletions tonic/src/transport/server/io_stream.rs
Original file line number Diff line number Diff line change
@@ -1,81 +1,116 @@
use std::{io, ops::ControlFlow, pin::pin};

#[cfg(feature = "_tls-any")]
use std::future::Future;
use std::{
io,
ops::ControlFlow,
pin::{pin, Pin},
task::{ready, Context, Poll},
};

use pin_project::pin_project;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_stream::{Stream, StreamExt as _};
#[cfg(feature = "_tls-any")]
use tokio::task::JoinSet;
use tokio_stream::Stream;
#[cfg(feature = "_tls-any")]
use tokio_stream::StreamExt as _;

use super::service::ServerIo;
#[cfg(feature = "_tls-any")]
use super::service::TlsAcceptor;

#[cfg(not(feature = "_tls-any"))]
pub(crate) fn tcp_incoming<IO, IE>(
incoming: impl Stream<Item = Result<IO, IE>>,
) -> impl Stream<Item = Result<ServerIo<IO>, crate::BoxError>>
#[pin_project]
pub(crate) struct ServerIoStream<S, IO, IE>
where
IO: AsyncRead + AsyncWrite + Unpin + Send + 'static,
IE: Into<crate::BoxError>,
S: Stream<Item = Result<IO, IE>>,
{
async_stream::try_stream! {
let mut incoming = pin!(incoming);

while let Some(item) = incoming.next().await {
yield match item {
Ok(_) => item.map(ServerIo::new_io)?,
Err(e) => match handle_tcp_accept_error(e) {
ControlFlow::Continue(()) => continue,
ControlFlow::Break(e) => Err(e)?,
}
}
#[pin]
inner: S,
#[cfg(feature = "_tls-any")]
tls: Option<TlsAcceptor>,
#[cfg(feature = "_tls-any")]
tasks: JoinSet<Result<ServerIo<IO>, crate::BoxError>>,
}

impl<S, IO, IE> ServerIoStream<S, IO, IE>
where
S: Stream<Item = Result<IO, IE>>,
{
pub(crate) fn new(incoming: S, #[cfg(feature = "_tls-any")] tls: Option<TlsAcceptor>) -> Self {
Self {
inner: incoming,
#[cfg(feature = "_tls-any")]
tls,
#[cfg(feature = "_tls-any")]
tasks: JoinSet::new(),
}
}
}

#[cfg(feature = "_tls-any")]
pub(crate) fn tcp_incoming<IO, IE>(
incoming: impl Stream<Item = Result<IO, IE>>,
tls: Option<TlsAcceptor>,
) -> impl Stream<Item = Result<ServerIo<IO>, crate::BoxError>>
impl<S, IO, IE> Stream for ServerIoStream<S, IO, IE>
where
S: Stream<Item = Result<IO, IE>>,
IO: AsyncRead + AsyncWrite + Unpin + Send + 'static,
IE: Into<crate::BoxError>,
{
async_stream::try_stream! {
let mut incoming = pin!(incoming);

let mut tasks = tokio::task::JoinSet::new();

loop {
match select(&mut incoming, &mut tasks).await {
SelectOutput::Incoming(stream) => {
if let Some(tls) = &tls {
let tls = tls.clone();
tasks.spawn(async move {
let io = tls.accept(stream).await?;
Ok(ServerIo::new_tls_io(io))
});
} else {
yield ServerIo::new_io(stream);
}
type Item = Result<ServerIo<IO>, crate::BoxError>;

#[cfg(not(feature = "_tls-any"))]
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
match ready!(self.as_mut().project().inner.as_mut().poll_next(cx)) {
Some(Ok(io)) => Poll::Ready(Some(Ok(ServerIo::new_io(io)))),
Some(Err(e)) => match handle_tcp_accept_error(e) {
ControlFlow::Continue(()) => {
cx.waker().wake_by_ref();
Poll::Pending
}
ControlFlow::Break(e) => Poll::Ready(Some(Err(e))),
},
None => Poll::Ready(None),
}
}

SelectOutput::Io(io) => {
yield io;
#[cfg(feature = "_tls-any")]
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let mut projected = self.as_mut().project();

let tls = projected.tls;
let tasks = projected.tasks;

let select_output = ready!(pin!(select(&mut projected.inner, tasks)).poll(cx));

match select_output {
SelectOutput::Incoming(stream) => {
if let Some(tls) = tls {
let tls = tls.clone();
tasks.spawn(async move {
let io = tls.accept(stream).await?;
Ok(ServerIo::new_tls_io(io))
});
cx.waker().wake_by_ref();
Poll::Pending
} else {
Poll::Ready(Some(Ok(ServerIo::new_io(stream))))
}
}

SelectOutput::TcpErr(e) => match handle_tcp_accept_error(e) {
ControlFlow::Continue(()) => continue,
ControlFlow::Break(e) => Err(e)?,
}
SelectOutput::Io(io) => Poll::Ready(Some(Ok(io))),

SelectOutput::TlsErr(e) => {
tracing::debug!(error = %e, "tls accept error");
continue;
SelectOutput::TcpErr(e) => match handle_tcp_accept_error(e) {
ControlFlow::Continue(()) => {
cx.waker().wake_by_ref();
Poll::Pending
}
ControlFlow::Break(e) => Poll::Ready(Some(Err(e))),
},

SelectOutput::Done => {
break;
}
SelectOutput::TlsErr(e) => {
tracing::debug!(error = %e, "tls accept error");
cx.waker().wake_by_ref();
Poll::Pending
}

SelectOutput::Done => Poll::Ready(None),
}
}
}
Expand Down Expand Up @@ -103,7 +138,7 @@ fn handle_tcp_accept_error(e: impl Into<crate::BoxError>) -> ControlFlow<crate::
#[cfg(feature = "_tls-any")]
async fn select<IO: 'static, IE>(
incoming: &mut (impl Stream<Item = Result<IO, IE>> + Unpin),
tasks: &mut tokio::task::JoinSet<Result<ServerIo<IO>, crate::BoxError>>,
tasks: &mut JoinSet<Result<ServerIo<IO>, crate::BoxError>>,
) -> SelectOutput<IO>
where
IE: Into<crate::BoxError>,
Expand Down
2 changes: 1 addition & 1 deletion tonic/src/transport/server/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -557,7 +557,7 @@ impl<L> Server<L> {

let svc = self.service_builder.service(svc);

let incoming = io_stream::tcp_incoming(
let incoming = io_stream::ServerIoStream::new(
incoming,
#[cfg(feature = "_tls-any")]
self.tls,
Expand Down

0 comments on commit d2f0d97

Please sign in to comment.