diff --git a/tests/integration_tests/tests/http2_max_header_list_size.rs b/tests/integration_tests/tests/http2_max_header_list_size.rs index 2909bf048..61e31fd50 100644 --- a/tests/integration_tests/tests/http2_max_header_list_size.rs +++ b/tests/integration_tests/tests/http2_max_header_list_size.rs @@ -34,10 +34,10 @@ async fn test_http_max_header_list_size_and_long_errors() { let addr = format!("http://{}", listener.local_addr().unwrap()); let jh = tokio::spawn(async move { - let (nodelay, keepalive) = (true, Some(Duration::from_secs(1))); - let listener = - tonic::transport::server::TcpIncoming::from_listener(listener, nodelay, keepalive) - .unwrap(); + let (nodelay, keepalive) = (Some(true), Some(Duration::from_secs(1))); + let listener = tonic::transport::server::TcpIncoming::from(listener) + .with_nodelay(nodelay) + .with_keepalive(keepalive); Server::builder() .http2_max_pending_accept_reset_streams(Some(0)) .add_service(svc) diff --git a/tonic/src/transport/server/incoming.rs b/tonic/src/transport/server/incoming.rs index f02c7b4f8..25b015c46 100644 --- a/tonic/src/transport/server/incoming.rs +++ b/tonic/src/transport/server/incoming.rs @@ -5,6 +5,7 @@ use std::{ time::Duration, }; +use socket2::TcpKeepalive; use tokio::net::{TcpListener, TcpStream}; use tokio_stream::{wrappers::TcpListenerStream, Stream}; use tracing::warn; @@ -16,13 +17,13 @@ use tracing::warn; #[derive(Debug)] pub struct TcpIncoming { inner: TcpListenerStream, - nodelay: bool, - keepalive: Option, + nodelay: Option, + keepalive: Option, } impl TcpIncoming { - /// Creates an instance by binding (opening) the specified socket address - /// to which the specified TCP 'nodelay' and 'keepalive' parameters are applied. + /// Creates an instance by binding (opening) the specified socket address. + /// /// Returns a TcpIncoming if the socket address was successfully bound. /// /// # Examples @@ -42,7 +43,7 @@ impl TcpIncoming { /// let mut port = 1322; /// let tinc = loop { /// let addr = format!("127.0.0.1:{}", port).parse().unwrap(); - /// match TcpIncoming::new(addr, true, None) { + /// match TcpIncoming::new(addr) { /// Ok(t) => break t, /// Err(_) => port += 1 /// } @@ -52,43 +53,42 @@ impl TcpIncoming { /// .serve_with_incoming(tinc); /// # Ok(()) /// # } - pub fn new( - addr: SocketAddr, - nodelay: bool, - keepalive: Option, - ) -> Result { + pub fn new(addr: SocketAddr) -> std::io::Result { let std_listener = StdTcpListener::bind(addr)?; std_listener.set_nonblocking(true)?; - let inner = TcpListenerStream::new(TcpListener::from_std(std_listener)?); - Ok(Self { - inner, - nodelay, - keepalive, - }) + Ok(TcpListener::from_std(std_listener)?.into()) + } + + /// Sets the `TCP_NODELAY` option on the accepted connection. + pub fn with_nodelay(self, nodelay: Option) -> Self { + Self { nodelay, ..self } + } + + /// Sets the `TCP_KEEPALIVE` option on the accepted connection. + pub fn with_keepalive(self, keepalive: Option) -> Self { + let keepalive = keepalive.map(|t| TcpKeepalive::new().with_time(t)); + Self { keepalive, ..self } } +} - /// Creates a new `TcpIncoming` from an existing `tokio::net::TcpListener`. - pub fn from_listener( - listener: TcpListener, - nodelay: bool, - keepalive: Option, - ) -> Result { - Ok(Self { +impl From for TcpIncoming { + fn from(listener: TcpListener) -> Self { + Self { inner: TcpListenerStream::new(listener), - nodelay, - keepalive, - }) + nodelay: None, + keepalive: None, + } } } impl Stream for TcpIncoming { - type Item = Result; + type Item = std::io::Result; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { match ready!(Pin::new(&mut self.inner).poll_next(cx)) { Some(Ok(stream)) => { - set_accepted_socket_options(&stream, self.nodelay, self.keepalive); + set_accepted_socket_options(&stream, self.nodelay, &self.keepalive); Some(Ok(stream)).into() } other => Poll::Ready(other), @@ -97,19 +97,21 @@ impl Stream for TcpIncoming { } // Consistent with hyper-0.14, this function does not return an error. -fn set_accepted_socket_options(stream: &TcpStream, nodelay: bool, keepalive: Option) { - if nodelay { - if let Err(e) = stream.set_nodelay(true) { - warn!("error trying to set TCP nodelay: {}", e); +fn set_accepted_socket_options( + stream: &TcpStream, + nodelay: Option, + keepalive: &Option, +) { + if let Some(nodelay) = nodelay { + if let Err(e) = stream.set_nodelay(nodelay) { + warn!("error trying to set TCP_NODELAY: {e}"); } } - if let Some(timeout) = keepalive { + if let Some(keepalive) = keepalive { let sock_ref = socket2::SockRef::from(&stream); - let sock_keepalive = socket2::TcpKeepalive::new().with_time(timeout); - - if let Err(e) = sock_ref.set_tcp_keepalive(&sock_keepalive) { - warn!("error trying to set TCP keepalive: {}", e); + if let Err(e) = sock_ref.set_tcp_keepalive(keepalive) { + warn!("error trying to set TCP_KEEPALIVE: {e}"); } } } @@ -121,9 +123,9 @@ mod tests { async fn one_tcpincoming_at_a_time() { let addr = "127.0.0.1:1322".parse().unwrap(); { - let _t1 = TcpIncoming::new(addr, true, None).unwrap(); - let _t2 = TcpIncoming::new(addr, true, None).unwrap_err(); + let _t1 = TcpIncoming::new(addr).unwrap(); + let _t2 = TcpIncoming::new(addr).unwrap_err(); } - let _t3 = TcpIncoming::new(addr, true, None).unwrap(); + let _t3 = TcpIncoming::new(addr).unwrap(); } } diff --git a/tonic/src/transport/server/mod.rs b/tonic/src/transport/server/mod.rs index dc344615e..fc6ef8dbb 100644 --- a/tonic/src/transport/server/mod.rs +++ b/tonic/src/transport/server/mod.rs @@ -779,8 +779,10 @@ impl Router { ResBody: http_body::Body + Send + 'static, ResBody::Error: Into, { - let incoming = TcpIncoming::new(addr, self.server.tcp_nodelay, self.server.tcp_keepalive) - .map_err(super::Error::from_source)?; + let incoming = TcpIncoming::new(addr) + .map_err(super::Error::from_source)? + .with_nodelay(Some(self.server.tcp_nodelay)) + .with_keepalive(self.server.tcp_keepalive); self.server .serve_with_shutdown::<_, _, future::Ready<()>, _, _, ResBody>( self.routes.prepare(), @@ -810,8 +812,10 @@ impl Router { ResBody: http_body::Body + Send + 'static, ResBody::Error: Into, { - let incoming = TcpIncoming::new(addr, self.server.tcp_nodelay, self.server.tcp_keepalive) - .map_err(super::Error::from_source)?; + let incoming = TcpIncoming::new(addr) + .map_err(super::Error::from_source)? + .with_nodelay(Some(self.server.tcp_nodelay)) + .with_keepalive(self.server.tcp_keepalive); self.server .serve_with_shutdown(self.routes.prepare(), incoming, Some(signal)) .await