diff --git a/tests/integration_tests/tests/client_layer.rs b/tests/integration_tests/tests/client_layer.rs index f27a5031f..6bf60fc50 100644 --- a/tests/integration_tests/tests/client_layer.rs +++ b/tests/integration_tests/tests/client_layer.rs @@ -28,7 +28,7 @@ async fn connect_supports_standard_tower_layers() { let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); let addr = listener.local_addr().unwrap(); - let incoming = TcpIncoming::from_listener(listener, true, None).unwrap(); + let incoming = TcpIncoming::from(listener).with_nodelay(Some(true)); // Start the server now, second call should succeed let jh = tokio::spawn(async move { diff --git a/tests/integration_tests/tests/connect_info.rs b/tests/integration_tests/tests/connect_info.rs index 08d5dd7ef..0671732c4 100644 --- a/tests/integration_tests/tests/connect_info.rs +++ b/tests/integration_tests/tests/connect_info.rs @@ -30,7 +30,7 @@ async fn getting_connect_info() { let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); let addr = listener.local_addr().unwrap(); - let incoming = TcpIncoming::from_listener(listener, true, None).unwrap(); + let incoming = TcpIncoming::from(listener).with_nodelay(Some(true)); let jh = tokio::spawn(async move { Server::builder() diff --git a/tests/integration_tests/tests/connection.rs b/tests/integration_tests/tests/connection.rs index feefbe4a7..1bc2b1740 100644 --- a/tests/integration_tests/tests/connection.rs +++ b/tests/integration_tests/tests/connection.rs @@ -42,7 +42,7 @@ async fn connect_returns_err_via_call_after_connected() { let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); let addr = listener.local_addr().unwrap(); - let incoming = TcpIncoming::from_listener(listener, true, None).unwrap(); + let incoming = TcpIncoming::from(listener).with_nodelay(Some(true)); let jh = tokio::spawn(async move { Server::builder() @@ -85,7 +85,7 @@ async fn connect_lazy_reconnects_after_first_failure() { let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); let addr = listener.local_addr().unwrap(); - let incoming = TcpIncoming::from_listener(listener, true, None).unwrap(); + let incoming = TcpIncoming::from(listener).with_nodelay(Some(true)); // Start the server now, second call should succeed let jh = tokio::spawn(async move { diff --git a/tests/integration_tests/tests/extensions.rs b/tests/integration_tests/tests/extensions.rs index f7f4d6b78..22ad609d4 100644 --- a/tests/integration_tests/tests/extensions.rs +++ b/tests/integration_tests/tests/extensions.rs @@ -41,7 +41,7 @@ async fn setting_extension_from_interceptor() { let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); let addr = listener.local_addr().unwrap(); - let incoming = TcpIncoming::from_listener(listener, true, None).unwrap(); + let incoming = TcpIncoming::from(listener).with_nodelay(Some(true)); let jh = tokio::spawn(async move { Server::builder() @@ -90,7 +90,7 @@ async fn setting_extension_from_tower() { let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); let addr = listener.local_addr().unwrap(); - let incoming = TcpIncoming::from_listener(listener, true, None).unwrap(); + let incoming = TcpIncoming::from(listener).with_nodelay(Some(true)); let jh = tokio::spawn(async move { Server::builder() diff --git a/tests/integration_tests/tests/http2_keep_alive.rs b/tests/integration_tests/tests/http2_keep_alive.rs index 6bf128a26..13b71b6f7 100644 --- a/tests/integration_tests/tests/http2_keep_alive.rs +++ b/tests/integration_tests/tests/http2_keep_alive.rs @@ -22,7 +22,7 @@ async fn http2_keepalive_does_not_cause_panics() { let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); let addr = listener.local_addr().unwrap(); - let incoming = TcpIncoming::from_listener(listener, true, None).unwrap(); + let incoming = TcpIncoming::from(listener).with_nodelay(Some(true)); let jh = tokio::spawn(async move { Server::builder() @@ -52,7 +52,7 @@ async fn http2_keepalive_does_not_cause_panics_on_client_side() { let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); let addr = listener.local_addr().unwrap(); - let incoming = TcpIncoming::from_listener(listener, true, None).unwrap(); + let incoming = TcpIncoming::from(listener).with_nodelay(Some(true)); let jh = tokio::spawn(async move { Server::builder() diff --git a/tests/integration_tests/tests/http2_max_header_list_size.rs b/tests/integration_tests/tests/http2_max_header_list_size.rs index 2909bf048..61e31fd50 100644 --- a/tests/integration_tests/tests/http2_max_header_list_size.rs +++ b/tests/integration_tests/tests/http2_max_header_list_size.rs @@ -34,10 +34,10 @@ async fn test_http_max_header_list_size_and_long_errors() { let addr = format!("http://{}", listener.local_addr().unwrap()); let jh = tokio::spawn(async move { - let (nodelay, keepalive) = (true, Some(Duration::from_secs(1))); - let listener = - tonic::transport::server::TcpIncoming::from_listener(listener, nodelay, keepalive) - .unwrap(); + let (nodelay, keepalive) = (Some(true), Some(Duration::from_secs(1))); + let listener = tonic::transport::server::TcpIncoming::from(listener) + .with_nodelay(nodelay) + .with_keepalive(keepalive); Server::builder() .http2_max_pending_accept_reset_streams(Some(0)) .add_service(svc) diff --git a/tests/integration_tests/tests/interceptor.rs b/tests/integration_tests/tests/interceptor.rs index f22258e0e..aec9c35fe 100644 --- a/tests/integration_tests/tests/interceptor.rs +++ b/tests/integration_tests/tests/interceptor.rs @@ -25,7 +25,7 @@ async fn interceptor_retrieves_grpc_method() { let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); let addr = listener.local_addr().unwrap(); - let incoming = TcpIncoming::from_listener(listener, true, None).unwrap(); + let incoming = TcpIncoming::from(listener).with_nodelay(Some(true)); // Start the server now, second call should succeed let jh = tokio::spawn(async move { diff --git a/tests/integration_tests/tests/origin.rs b/tests/integration_tests/tests/origin.rs index 95e7bee0b..291ed47d7 100644 --- a/tests/integration_tests/tests/origin.rs +++ b/tests/integration_tests/tests/origin.rs @@ -33,7 +33,7 @@ async fn writes_origin_header() { let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); let addr = listener.local_addr().unwrap(); - let incoming = TcpIncoming::from_listener(listener, true, None).unwrap(); + let incoming = TcpIncoming::from(listener).with_nodelay(Some(true)); let jh = tokio::spawn(async move { Server::builder() diff --git a/tests/integration_tests/tests/routes_builder.rs b/tests/integration_tests/tests/routes_builder.rs index 1dd065720..aeb975a50 100644 --- a/tests/integration_tests/tests/routes_builder.rs +++ b/tests/integration_tests/tests/routes_builder.rs @@ -59,7 +59,7 @@ async fn multiple_service_using_routes_builder() { let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); let addr = listener.local_addr().unwrap(); - let incoming = TcpIncoming::from_listener(listener, true, None).unwrap(); + let incoming = TcpIncoming::from(listener).with_nodelay(Some(true)); let jh = tokio::spawn(async move { Server::builder() diff --git a/tests/integration_tests/tests/status.rs b/tests/integration_tests/tests/status.rs index 5ea3cc20d..b388611f9 100644 --- a/tests/integration_tests/tests/status.rs +++ b/tests/integration_tests/tests/status.rs @@ -36,7 +36,7 @@ async fn status_with_details() { let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); let addr = listener.local_addr().unwrap(); - let incoming = TcpIncoming::from_listener(listener, true, None).unwrap(); + let incoming = TcpIncoming::from(listener).with_nodelay(Some(true)); let jh = tokio::spawn(async move { Server::builder() @@ -94,7 +94,7 @@ async fn status_with_metadata() { let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); let addr = listener.local_addr().unwrap(); - let incoming = TcpIncoming::from_listener(listener, true, None).unwrap(); + let incoming = TcpIncoming::from(listener).with_nodelay(Some(true)); let jh = tokio::spawn(async move { Server::builder() @@ -165,7 +165,7 @@ async fn status_from_server_stream() { let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); let addr = listener.local_addr().unwrap(); - let incoming = TcpIncoming::from_listener(listener, true, None).unwrap(); + let incoming = TcpIncoming::from(listener).with_nodelay(Some(true)); tokio::spawn(async move { Server::builder() @@ -235,7 +235,7 @@ async fn message_and_then_status_from_server_stream() { let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); let addr = listener.local_addr().unwrap(); - let incoming = TcpIncoming::from_listener(listener, true, None).unwrap(); + let incoming = TcpIncoming::from(listener).with_nodelay(Some(true)); tokio::spawn(async move { Server::builder() diff --git a/tests/integration_tests/tests/user_agent.rs b/tests/integration_tests/tests/user_agent.rs index 9ab8306a9..5849fd4ff 100644 --- a/tests/integration_tests/tests/user_agent.rs +++ b/tests/integration_tests/tests/user_agent.rs @@ -26,7 +26,7 @@ async fn writes_user_agent_header() { let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); let addr = listener.local_addr().unwrap(); - let incoming = TcpIncoming::from_listener(listener, true, None).unwrap(); + let incoming = TcpIncoming::from(listener).with_nodelay(Some(true)); let jh = tokio::spawn(async move { Server::builder() diff --git a/tonic/src/transport/server/incoming.rs b/tonic/src/transport/server/incoming.rs index f02c7b4f8..a3ff0da07 100644 --- a/tonic/src/transport/server/incoming.rs +++ b/tonic/src/transport/server/incoming.rs @@ -1,10 +1,11 @@ use std::{ net::{SocketAddr, TcpListener as StdTcpListener}, pin::Pin, - task::{ready, Context, Poll}, + task::{Context, Poll}, time::Duration, }; +use socket2::TcpKeepalive; use tokio::net::{TcpListener, TcpStream}; use tokio_stream::{wrappers::TcpListenerStream, Stream}; use tracing::warn; @@ -16,13 +17,13 @@ use tracing::warn; #[derive(Debug)] pub struct TcpIncoming { inner: TcpListenerStream, - nodelay: bool, - keepalive: Option, + nodelay: Option, + keepalive: Option, } impl TcpIncoming { - /// Creates an instance by binding (opening) the specified socket address - /// to which the specified TCP 'nodelay' and 'keepalive' parameters are applied. + /// Creates an instance by binding (opening) the specified socket address. + /// /// Returns a TcpIncoming if the socket address was successfully bound. /// /// # Examples @@ -42,7 +43,7 @@ impl TcpIncoming { /// let mut port = 1322; /// let tinc = loop { /// let addr = format!("127.0.0.1:{}", port).parse().unwrap(); - /// match TcpIncoming::new(addr, true, None) { + /// match TcpIncoming::bind(addr) { /// Ok(t) => break t, /// Err(_) => port += 1 /// } @@ -52,64 +53,65 @@ impl TcpIncoming { /// .serve_with_incoming(tinc); /// # Ok(()) /// # } - pub fn new( - addr: SocketAddr, - nodelay: bool, - keepalive: Option, - ) -> Result { + pub fn bind(addr: SocketAddr) -> std::io::Result { let std_listener = StdTcpListener::bind(addr)?; std_listener.set_nonblocking(true)?; - let inner = TcpListenerStream::new(TcpListener::from_std(std_listener)?); - Ok(Self { - inner, - nodelay, - keepalive, - }) + Ok(TcpListener::from_std(std_listener)?.into()) + } + + /// Sets the `TCP_NODELAY` option on the accepted connection. + pub fn with_nodelay(self, nodelay: Option) -> Self { + Self { nodelay, ..self } + } + + /// Sets the `TCP_KEEPALIVE` option on the accepted connection. + pub fn with_keepalive(self, keepalive: Option) -> Self { + let keepalive = keepalive.map(|t| TcpKeepalive::new().with_time(t)); + Self { keepalive, ..self } } +} - /// Creates a new `TcpIncoming` from an existing `tokio::net::TcpListener`. - pub fn from_listener( - listener: TcpListener, - nodelay: bool, - keepalive: Option, - ) -> Result { - Ok(Self { +impl From for TcpIncoming { + fn from(listener: TcpListener) -> Self { + Self { inner: TcpListenerStream::new(listener), - nodelay, - keepalive, - }) + nodelay: None, + keepalive: None, + } } } impl Stream for TcpIncoming { - type Item = Result; + type Item = std::io::Result; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - match ready!(Pin::new(&mut self.inner).poll_next(cx)) { - Some(Ok(stream)) => { - set_accepted_socket_options(&stream, self.nodelay, self.keepalive); - Some(Ok(stream)).into() - } - other => Poll::Ready(other), + let polled = Pin::new(&mut self.inner).poll_next(cx); + + if let Poll::Ready(Some(Ok(stream))) = &polled { + set_accepted_socket_options(stream, self.nodelay, &self.keepalive); } + + polled } } // Consistent with hyper-0.14, this function does not return an error. -fn set_accepted_socket_options(stream: &TcpStream, nodelay: bool, keepalive: Option) { - if nodelay { - if let Err(e) = stream.set_nodelay(true) { - warn!("error trying to set TCP nodelay: {}", e); +fn set_accepted_socket_options( + stream: &TcpStream, + nodelay: Option, + keepalive: &Option, +) { + if let Some(nodelay) = nodelay { + if let Err(e) = stream.set_nodelay(nodelay) { + warn!("error trying to set TCP_NODELAY: {e}"); } } - if let Some(timeout) = keepalive { + if let Some(keepalive) = keepalive { let sock_ref = socket2::SockRef::from(&stream); - let sock_keepalive = socket2::TcpKeepalive::new().with_time(timeout); - - if let Err(e) = sock_ref.set_tcp_keepalive(&sock_keepalive) { - warn!("error trying to set TCP keepalive: {}", e); + if let Err(e) = sock_ref.set_tcp_keepalive(keepalive) { + warn!("error trying to set TCP_KEEPALIVE: {e}"); } } } @@ -121,9 +123,9 @@ mod tests { async fn one_tcpincoming_at_a_time() { let addr = "127.0.0.1:1322".parse().unwrap(); { - let _t1 = TcpIncoming::new(addr, true, None).unwrap(); - let _t2 = TcpIncoming::new(addr, true, None).unwrap_err(); + let _t1 = TcpIncoming::bind(addr).unwrap(); + let _t2 = TcpIncoming::bind(addr).unwrap_err(); } - let _t3 = TcpIncoming::new(addr, true, None).unwrap(); + let _t3 = TcpIncoming::bind(addr).unwrap(); } } diff --git a/tonic/src/transport/server/mod.rs b/tonic/src/transport/server/mod.rs index 114c5f452..77bf00710 100644 --- a/tonic/src/transport/server/mod.rs +++ b/tonic/src/transport/server/mod.rs @@ -778,8 +778,10 @@ impl Router { ResBody: http_body::Body + Send + 'static, ResBody::Error: Into, { - let incoming = TcpIncoming::new(addr, self.server.tcp_nodelay, self.server.tcp_keepalive) - .map_err(super::Error::from_source)?; + let incoming = TcpIncoming::bind(addr) + .map_err(super::Error::from_source)? + .with_nodelay(Some(self.server.tcp_nodelay)) + .with_keepalive(self.server.tcp_keepalive); self.server .serve_with_shutdown::<_, _, future::Ready<()>, _, _, ResBody>( self.routes.prepare(), @@ -809,8 +811,10 @@ impl Router { ResBody: http_body::Body + Send + 'static, ResBody::Error: Into, { - let incoming = TcpIncoming::new(addr, self.server.tcp_nodelay, self.server.tcp_keepalive) - .map_err(super::Error::from_source)?; + let incoming = TcpIncoming::bind(addr) + .map_err(super::Error::from_source)? + .with_nodelay(Some(self.server.tcp_nodelay)) + .with_keepalive(self.server.tcp_keepalive); self.server .serve_with_shutdown(self.routes.prepare(), incoming, Some(signal)) .await