Skip to content

Commit

Permalink
chore(server): Use same non tls logic at server io stream
Browse files Browse the repository at this point in the history
  • Loading branch information
tottoto committed Nov 30, 2024
1 parent 9367490 commit d136425
Showing 1 changed file with 33 additions and 23 deletions.
56 changes: 33 additions & 23 deletions tonic/src/transport/server/io_stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,18 +45,15 @@ where
tasks: JoinSet::new(),
}
}
}

impl<S, IO, IE> Stream for ServerIoStream<S, IO, IE>
where
S: Stream<Item = Result<IO, IE>>,
IO: AsyncRead + AsyncWrite + Unpin + Send + 'static,
IE: Into<crate::BoxError>,
{
type Item = Result<ServerIo<IO>, crate::BoxError>;

#[cfg(not(feature = "_tls-any"))]
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
fn poll_next_without_tls(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<ServerIo<IO>, crate::BoxError>>>
where
IO: AsyncRead + AsyncWrite + Unpin + Send + 'static,
IE: Into<crate::BoxError>,
{
match ready!(self.as_mut().project().inner.as_mut().poll_next(cx)) {
Some(Ok(io)) => Poll::Ready(Some(Ok(ServerIo::new_io(io)))),
Some(Err(e)) => match handle_tcp_accept_error(e) {
Expand All @@ -69,29 +66,42 @@ where
None => Poll::Ready(None),
}
}
}

impl<S, IO, IE> Stream for ServerIoStream<S, IO, IE>
where
S: Stream<Item = Result<IO, IE>>,
IO: AsyncRead + AsyncWrite + Unpin + Send + 'static,
IE: Into<crate::BoxError>,
{
type Item = Result<ServerIo<IO>, crate::BoxError>;

#[cfg(not(feature = "_tls-any"))]
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
self.poll_next_without_tls(cx)
}

#[cfg(feature = "_tls-any")]
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let mut projected = self.as_mut().project();

let tls = projected.tls;
let Some(tls) = projected.tls else {
return self.poll_next_without_tls(cx);
};

let tasks = projected.tasks;

let select_output = ready!(pin!(select(&mut projected.inner, tasks)).poll(cx));

match select_output {
SelectOutput::Incoming(stream) => {
if let Some(tls) = tls {
let tls = tls.clone();
tasks.spawn(async move {
let io = tls.accept(stream).await?;
Ok(ServerIo::new_tls_io(io))
});
cx.waker().wake_by_ref();
Poll::Pending
} else {
Poll::Ready(Some(Ok(ServerIo::new_io(stream))))
}
let tls = tls.clone();
tasks.spawn(async move {
let io = tls.accept(stream).await?;
Ok(ServerIo::new_tls_io(io))
});
cx.waker().wake_by_ref();
Poll::Pending
}

SelectOutput::Io(io) => Poll::Ready(Some(Ok(io))),
Expand Down

0 comments on commit d136425

Please sign in to comment.