From df13009fe7ce7f5537208262664a91de6e79f51c Mon Sep 17 00:00:00 2001 From: tottoto Date: Sat, 23 Nov 2024 22:08:13 +0900 Subject: [PATCH] chore(server): Use same non tls logic at server io stream --- tonic/src/transport/server/io_stream.rs | 66 ++++++++++++++----------- 1 file changed, 36 insertions(+), 30 deletions(-) diff --git a/tonic/src/transport/server/io_stream.rs b/tonic/src/transport/server/io_stream.rs index e873e4def..222970935 100644 --- a/tonic/src/transport/server/io_stream.rs +++ b/tonic/src/transport/server/io_stream.rs @@ -19,6 +19,9 @@ use super::service::ServerIo; #[cfg(feature = "_tls-any")] use super::service::TlsAcceptor; +#[cfg(feature = "_tls-any")] +struct State(TlsAcceptor, JoinSet, crate::BoxError>>); + #[pin_project] pub(crate) struct ServerIoStream where @@ -27,9 +30,7 @@ where #[pin] inner: S, #[cfg(feature = "_tls-any")] - tls: Option, - #[cfg(feature = "_tls-any")] - tasks: JoinSet, crate::BoxError>>, + state: Option>, } impl ServerIoStream @@ -40,23 +41,17 @@ where Self { inner: incoming, #[cfg(feature = "_tls-any")] - tls, - #[cfg(feature = "_tls-any")] - tasks: JoinSet::new(), + state: tls.map(|tls| State(tls, JoinSet::new())), } } -} - -impl Stream for ServerIoStream -where - S: Stream>, - IO: AsyncRead + AsyncWrite + Unpin + Send + 'static, - IE: Into, -{ - type Item = Result, crate::BoxError>; - #[cfg(not(feature = "_tls-any"))] - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + fn poll_next_without_tls( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll, crate::BoxError>>> + where + IE: Into, + { match ready!(self.as_mut().project().inner.as_mut().poll_next(cx)) { Some(Ok(io)) => Poll::Ready(Some(Ok(ServerIo::new_io(io)))), Some(Err(e)) => match handle_tcp_accept_error(e) { @@ -69,29 +64,40 @@ where None => Poll::Ready(None), } } +} + +impl Stream for ServerIoStream +where + S: Stream>, + IO: AsyncRead + AsyncWrite + Unpin + Send + 'static, + IE: Into, +{ + type Item = Result, crate::BoxError>; + + #[cfg(not(feature = "_tls-any"))] + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.poll_next_without_tls(cx) + } #[cfg(feature = "_tls-any")] fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let mut projected = self.as_mut().project(); - let tls = projected.tls; - let tasks = projected.tasks; + let Some(State(tls, tasks)) = projected.state else { + return self.poll_next_without_tls(cx); + }; let select_output = ready!(pin!(select(&mut projected.inner, tasks)).poll(cx)); match select_output { SelectOutput::Incoming(stream) => { - if let Some(tls) = tls { - let tls = tls.clone(); - tasks.spawn(async move { - let io = tls.accept(stream).await?; - Ok(ServerIo::new_tls_io(io)) - }); - cx.waker().wake_by_ref(); - Poll::Pending - } else { - Poll::Ready(Some(Ok(ServerIo::new_io(stream)))) - } + let tls = tls.clone(); + tasks.spawn(async move { + let io = tls.accept(stream).await?; + Ok(ServerIo::new_tls_io(io)) + }); + cx.waker().wake_by_ref(); + Poll::Pending } SelectOutput::Io(io) => Poll::Ready(Some(Ok(io))),