From d136425c0cd65699aa21c6bfa757f5ae8c114cdb Mon Sep 17 00:00:00 2001 From: tottoto Date: Sat, 23 Nov 2024 22:08:13 +0900 Subject: [PATCH] chore(server): Use same non tls logic at server io stream --- tonic/src/transport/server/io_stream.rs | 56 +++++++++++++++---------- 1 file changed, 33 insertions(+), 23 deletions(-) diff --git a/tonic/src/transport/server/io_stream.rs b/tonic/src/transport/server/io_stream.rs index e873e4def..57e03fe8c 100644 --- a/tonic/src/transport/server/io_stream.rs +++ b/tonic/src/transport/server/io_stream.rs @@ -45,18 +45,15 @@ where tasks: JoinSet::new(), } } -} -impl Stream for ServerIoStream -where - S: Stream>, - IO: AsyncRead + AsyncWrite + Unpin + Send + 'static, - IE: Into, -{ - type Item = Result, crate::BoxError>; - - #[cfg(not(feature = "_tls-any"))] - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + fn poll_next_without_tls( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll, crate::BoxError>>> + where + IO: AsyncRead + AsyncWrite + Unpin + Send + 'static, + IE: Into, + { match ready!(self.as_mut().project().inner.as_mut().poll_next(cx)) { Some(Ok(io)) => Poll::Ready(Some(Ok(ServerIo::new_io(io)))), Some(Err(e)) => match handle_tcp_accept_error(e) { @@ -69,29 +66,42 @@ where None => Poll::Ready(None), } } +} + +impl Stream for ServerIoStream +where + S: Stream>, + IO: AsyncRead + AsyncWrite + Unpin + Send + 'static, + IE: Into, +{ + type Item = Result, crate::BoxError>; + + #[cfg(not(feature = "_tls-any"))] + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.poll_next_without_tls(cx) + } #[cfg(feature = "_tls-any")] fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let mut projected = self.as_mut().project(); - let tls = projected.tls; + let Some(tls) = projected.tls else { + return self.poll_next_without_tls(cx); + }; + let tasks = projected.tasks; let select_output = ready!(pin!(select(&mut projected.inner, tasks)).poll(cx)); match select_output { SelectOutput::Incoming(stream) => { - if let Some(tls) = tls { - let tls = tls.clone(); - tasks.spawn(async move { - let io = tls.accept(stream).await?; - Ok(ServerIo::new_tls_io(io)) - }); - cx.waker().wake_by_ref(); - Poll::Pending - } else { - Poll::Ready(Some(Ok(ServerIo::new_io(stream)))) - } + let tls = tls.clone(); + tasks.spawn(async move { + let io = tls.accept(stream).await?; + Ok(ServerIo::new_tls_io(io)) + }); + cx.waker().wake_by_ref(); + Poll::Pending } SelectOutput::Io(io) => Poll::Ready(Some(Ok(io))),