Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore(server): Refactor TcpIncoming #2052

Merged
merged 1 commit into from
Nov 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion tests/integration_tests/tests/client_layer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ async fn connect_supports_standard_tower_layers() {

let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let incoming = TcpIncoming::from_listener(listener, true, None).unwrap();
let incoming = TcpIncoming::from(listener).with_nodelay(Some(true));

// Start the server now, second call should succeed
let jh = tokio::spawn(async move {
Expand Down
2 changes: 1 addition & 1 deletion tests/integration_tests/tests/connect_info.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ async fn getting_connect_info() {

let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let incoming = TcpIncoming::from_listener(listener, true, None).unwrap();
let incoming = TcpIncoming::from(listener).with_nodelay(Some(true));

let jh = tokio::spawn(async move {
Server::builder()
Expand Down
4 changes: 2 additions & 2 deletions tests/integration_tests/tests/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ async fn connect_returns_err_via_call_after_connected() {

let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let incoming = TcpIncoming::from_listener(listener, true, None).unwrap();
let incoming = TcpIncoming::from(listener).with_nodelay(Some(true));

let jh = tokio::spawn(async move {
Server::builder()
Expand Down Expand Up @@ -85,7 +85,7 @@ async fn connect_lazy_reconnects_after_first_failure() {

let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let incoming = TcpIncoming::from_listener(listener, true, None).unwrap();
let incoming = TcpIncoming::from(listener).with_nodelay(Some(true));

// Start the server now, second call should succeed
let jh = tokio::spawn(async move {
Expand Down
4 changes: 2 additions & 2 deletions tests/integration_tests/tests/extensions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ async fn setting_extension_from_interceptor() {

let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let incoming = TcpIncoming::from_listener(listener, true, None).unwrap();
let incoming = TcpIncoming::from(listener).with_nodelay(Some(true));

let jh = tokio::spawn(async move {
Server::builder()
Expand Down Expand Up @@ -90,7 +90,7 @@ async fn setting_extension_from_tower() {

let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let incoming = TcpIncoming::from_listener(listener, true, None).unwrap();
let incoming = TcpIncoming::from(listener).with_nodelay(Some(true));

let jh = tokio::spawn(async move {
Server::builder()
Expand Down
4 changes: 2 additions & 2 deletions tests/integration_tests/tests/http2_keep_alive.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ async fn http2_keepalive_does_not_cause_panics() {

let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let incoming = TcpIncoming::from_listener(listener, true, None).unwrap();
let incoming = TcpIncoming::from(listener).with_nodelay(Some(true));

let jh = tokio::spawn(async move {
Server::builder()
Expand Down Expand Up @@ -52,7 +52,7 @@ async fn http2_keepalive_does_not_cause_panics_on_client_side() {

let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let incoming = TcpIncoming::from_listener(listener, true, None).unwrap();
let incoming = TcpIncoming::from(listener).with_nodelay(Some(true));

let jh = tokio::spawn(async move {
Server::builder()
Expand Down
8 changes: 4 additions & 4 deletions tests/integration_tests/tests/http2_max_header_list_size.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,10 @@ async fn test_http_max_header_list_size_and_long_errors() {
let addr = format!("http://{}", listener.local_addr().unwrap());

let jh = tokio::spawn(async move {
let (nodelay, keepalive) = (true, Some(Duration::from_secs(1)));
let listener =
tonic::transport::server::TcpIncoming::from_listener(listener, nodelay, keepalive)
.unwrap();
let (nodelay, keepalive) = (Some(true), Some(Duration::from_secs(1)));
let listener = tonic::transport::server::TcpIncoming::from(listener)
.with_nodelay(nodelay)
.with_keepalive(keepalive);
Server::builder()
.http2_max_pending_accept_reset_streams(Some(0))
.add_service(svc)
Expand Down
2 changes: 1 addition & 1 deletion tests/integration_tests/tests/interceptor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ async fn interceptor_retrieves_grpc_method() {

let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let incoming = TcpIncoming::from_listener(listener, true, None).unwrap();
let incoming = TcpIncoming::from(listener).with_nodelay(Some(true));

// Start the server now, second call should succeed
let jh = tokio::spawn(async move {
Expand Down
2 changes: 1 addition & 1 deletion tests/integration_tests/tests/origin.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ async fn writes_origin_header() {

let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let incoming = TcpIncoming::from_listener(listener, true, None).unwrap();
let incoming = TcpIncoming::from(listener).with_nodelay(Some(true));

let jh = tokio::spawn(async move {
Server::builder()
Expand Down
2 changes: 1 addition & 1 deletion tests/integration_tests/tests/routes_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ async fn multiple_service_using_routes_builder() {

let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let incoming = TcpIncoming::from_listener(listener, true, None).unwrap();
let incoming = TcpIncoming::from(listener).with_nodelay(Some(true));

let jh = tokio::spawn(async move {
Server::builder()
Expand Down
8 changes: 4 additions & 4 deletions tests/integration_tests/tests/status.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ async fn status_with_details() {

let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let incoming = TcpIncoming::from_listener(listener, true, None).unwrap();
let incoming = TcpIncoming::from(listener).with_nodelay(Some(true));

let jh = tokio::spawn(async move {
Server::builder()
Expand Down Expand Up @@ -94,7 +94,7 @@ async fn status_with_metadata() {

let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let incoming = TcpIncoming::from_listener(listener, true, None).unwrap();
let incoming = TcpIncoming::from(listener).with_nodelay(Some(true));

let jh = tokio::spawn(async move {
Server::builder()
Expand Down Expand Up @@ -165,7 +165,7 @@ async fn status_from_server_stream() {

let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let incoming = TcpIncoming::from_listener(listener, true, None).unwrap();
let incoming = TcpIncoming::from(listener).with_nodelay(Some(true));

tokio::spawn(async move {
Server::builder()
Expand Down Expand Up @@ -235,7 +235,7 @@ async fn message_and_then_status_from_server_stream() {

let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let incoming = TcpIncoming::from_listener(listener, true, None).unwrap();
let incoming = TcpIncoming::from(listener).with_nodelay(Some(true));

tokio::spawn(async move {
Server::builder()
Expand Down
2 changes: 1 addition & 1 deletion tests/integration_tests/tests/user_agent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ async fn writes_user_agent_header() {

let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let incoming = TcpIncoming::from_listener(listener, true, None).unwrap();
let incoming = TcpIncoming::from(listener).with_nodelay(Some(true));

let jh = tokio::spawn(async move {
Server::builder()
Expand Down
94 changes: 48 additions & 46 deletions tonic/src/transport/server/incoming.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
use std::{
net::{SocketAddr, TcpListener as StdTcpListener},
pin::Pin,
task::{ready, Context, Poll},
task::{Context, Poll},
time::Duration,
};

use socket2::TcpKeepalive;
use tokio::net::{TcpListener, TcpStream};
use tokio_stream::{wrappers::TcpListenerStream, Stream};
use tracing::warn;
Expand All @@ -16,13 +17,13 @@ use tracing::warn;
#[derive(Debug)]
pub struct TcpIncoming {
inner: TcpListenerStream,
nodelay: bool,
keepalive: Option<Duration>,
nodelay: Option<bool>,
keepalive: Option<TcpKeepalive>,
}

impl TcpIncoming {
/// Creates an instance by binding (opening) the specified socket address
/// to which the specified TCP 'nodelay' and 'keepalive' parameters are applied.
/// Creates an instance by binding (opening) the specified socket address.
///
/// Returns a TcpIncoming if the socket address was successfully bound.
///
/// # Examples
Expand All @@ -42,7 +43,7 @@ impl TcpIncoming {
/// let mut port = 1322;
/// let tinc = loop {
/// let addr = format!("127.0.0.1:{}", port).parse().unwrap();
/// match TcpIncoming::new(addr, true, None) {
/// match TcpIncoming::bind(addr) {
/// Ok(t) => break t,
/// Err(_) => port += 1
/// }
Expand All @@ -52,64 +53,65 @@ impl TcpIncoming {
/// .serve_with_incoming(tinc);
/// # Ok(())
/// # }
pub fn new(
addr: SocketAddr,
nodelay: bool,
keepalive: Option<Duration>,
) -> Result<Self, crate::BoxError> {
pub fn bind(addr: SocketAddr) -> std::io::Result<Self> {
let std_listener = StdTcpListener::bind(addr)?;
std_listener.set_nonblocking(true)?;

let inner = TcpListenerStream::new(TcpListener::from_std(std_listener)?);
Ok(Self {
inner,
nodelay,
keepalive,
})
Ok(TcpListener::from_std(std_listener)?.into())
}

/// Sets the `TCP_NODELAY` option on the accepted connection.
pub fn with_nodelay(self, nodelay: Option<bool>) -> Self {
Self { nodelay, ..self }
}

/// Sets the `TCP_KEEPALIVE` option on the accepted connection.
pub fn with_keepalive(self, keepalive: Option<Duration>) -> Self {
let keepalive = keepalive.map(|t| TcpKeepalive::new().with_time(t));
Self { keepalive, ..self }
}
}

/// Creates a new `TcpIncoming` from an existing `tokio::net::TcpListener`.
pub fn from_listener(
listener: TcpListener,
nodelay: bool,
keepalive: Option<Duration>,
) -> Result<Self, crate::BoxError> {
Ok(Self {
impl From<TcpListener> for TcpIncoming {
fn from(listener: TcpListener) -> Self {
Self {
inner: TcpListenerStream::new(listener),
nodelay,
keepalive,
})
nodelay: None,
keepalive: None,
}
}
}

impl Stream for TcpIncoming {
type Item = Result<TcpStream, std::io::Error>;
type Item = std::io::Result<TcpStream>;

fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
match ready!(Pin::new(&mut self.inner).poll_next(cx)) {
Some(Ok(stream)) => {
set_accepted_socket_options(&stream, self.nodelay, self.keepalive);
Some(Ok(stream)).into()
}
other => Poll::Ready(other),
let polled = Pin::new(&mut self.inner).poll_next(cx);

if let Poll::Ready(Some(Ok(stream))) = &polled {
set_accepted_socket_options(stream, self.nodelay, &self.keepalive);
}

polled
}
}

// Consistent with hyper-0.14, this function does not return an error.
fn set_accepted_socket_options(stream: &TcpStream, nodelay: bool, keepalive: Option<Duration>) {
if nodelay {
if let Err(e) = stream.set_nodelay(true) {
warn!("error trying to set TCP nodelay: {}", e);
fn set_accepted_socket_options(
stream: &TcpStream,
nodelay: Option<bool>,
keepalive: &Option<TcpKeepalive>,
) {
if let Some(nodelay) = nodelay {
if let Err(e) = stream.set_nodelay(nodelay) {
warn!("error trying to set TCP_NODELAY: {e}");
}
}

if let Some(timeout) = keepalive {
if let Some(keepalive) = keepalive {
let sock_ref = socket2::SockRef::from(&stream);
let sock_keepalive = socket2::TcpKeepalive::new().with_time(timeout);

if let Err(e) = sock_ref.set_tcp_keepalive(&sock_keepalive) {
warn!("error trying to set TCP keepalive: {}", e);
if let Err(e) = sock_ref.set_tcp_keepalive(keepalive) {
warn!("error trying to set TCP_KEEPALIVE: {e}");
}
}
}
Expand All @@ -121,9 +123,9 @@ mod tests {
async fn one_tcpincoming_at_a_time() {
let addr = "127.0.0.1:1322".parse().unwrap();
{
let _t1 = TcpIncoming::new(addr, true, None).unwrap();
let _t2 = TcpIncoming::new(addr, true, None).unwrap_err();
let _t1 = TcpIncoming::bind(addr).unwrap();
let _t2 = TcpIncoming::bind(addr).unwrap_err();
}
let _t3 = TcpIncoming::new(addr, true, None).unwrap();
let _t3 = TcpIncoming::bind(addr).unwrap();
}
}
12 changes: 8 additions & 4 deletions tonic/src/transport/server/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -778,8 +778,10 @@ impl<L> Router<L> {
ResBody: http_body::Body<Data = Bytes> + Send + 'static,
ResBody::Error: Into<crate::BoxError>,
{
let incoming = TcpIncoming::new(addr, self.server.tcp_nodelay, self.server.tcp_keepalive)
.map_err(super::Error::from_source)?;
let incoming = TcpIncoming::bind(addr)
.map_err(super::Error::from_source)?
.with_nodelay(Some(self.server.tcp_nodelay))
.with_keepalive(self.server.tcp_keepalive);
self.server
.serve_with_shutdown::<_, _, future::Ready<()>, _, _, ResBody>(
self.routes.prepare(),
Expand Down Expand Up @@ -809,8 +811,10 @@ impl<L> Router<L> {
ResBody: http_body::Body<Data = Bytes> + Send + 'static,
ResBody::Error: Into<crate::BoxError>,
{
let incoming = TcpIncoming::new(addr, self.server.tcp_nodelay, self.server.tcp_keepalive)
.map_err(super::Error::from_source)?;
let incoming = TcpIncoming::bind(addr)
.map_err(super::Error::from_source)?
.with_nodelay(Some(self.server.tcp_nodelay))
.with_keepalive(self.server.tcp_keepalive);
self.server
.serve_with_shutdown(self.routes.prepare(), incoming, Some(signal))
.await
Expand Down