Skip to content

Commit

Permalink
chore(server): Refactor TcpIncoming
Browse files Browse the repository at this point in the history
  • Loading branch information
tottoto committed Nov 24, 2024
1 parent 2f2afb0 commit 3abbfe5
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 48 deletions.
8 changes: 4 additions & 4 deletions tests/integration_tests/tests/http2_max_header_list_size.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,10 @@ async fn test_http_max_header_list_size_and_long_errors() {
let addr = format!("http://{}", listener.local_addr().unwrap());

let jh = tokio::spawn(async move {
let (nodelay, keepalive) = (true, Some(Duration::from_secs(1)));
let listener =
tonic::transport::server::TcpIncoming::from_listener(listener, nodelay, keepalive)
.unwrap();
let (nodelay, keepalive) = (Some(true), Some(Duration::from_secs(1)));
let listener = tonic::transport::server::TcpIncoming::from(listener)
.with_nodelay(nodelay)
.with_keepalive(keepalive);
Server::builder()
.http2_max_pending_accept_reset_streams(Some(0))
.add_service(svc)
Expand Down
82 changes: 42 additions & 40 deletions tonic/src/transport/server/incoming.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use std::{
time::Duration,
};

use socket2::TcpKeepalive;
use tokio::net::{TcpListener, TcpStream};
use tokio_stream::{wrappers::TcpListenerStream, Stream};
use tracing::warn;
Expand All @@ -16,13 +17,13 @@ use tracing::warn;
#[derive(Debug)]
pub struct TcpIncoming {
inner: TcpListenerStream,
nodelay: bool,
keepalive: Option<Duration>,
nodelay: Option<bool>,
keepalive: Option<TcpKeepalive>,
}

impl TcpIncoming {
/// Creates an instance by binding (opening) the specified socket address
/// to which the specified TCP 'nodelay' and 'keepalive' parameters are applied.
/// Creates an instance by binding (opening) the specified socket address.
///
/// Returns a TcpIncoming if the socket address was successfully bound.
///
/// # Examples
Expand All @@ -42,7 +43,7 @@ impl TcpIncoming {
/// let mut port = 1322;
/// let tinc = loop {
/// let addr = format!("127.0.0.1:{}", port).parse().unwrap();
/// match TcpIncoming::new(addr, true, None) {
/// match TcpIncoming::new(addr) {
/// Ok(t) => break t,
/// Err(_) => port += 1
/// }
Expand All @@ -52,43 +53,42 @@ impl TcpIncoming {
/// .serve_with_incoming(tinc);
/// # Ok(())
/// # }
pub fn new(
addr: SocketAddr,
nodelay: bool,
keepalive: Option<Duration>,
) -> Result<Self, crate::BoxError> {
pub fn new(addr: SocketAddr) -> std::io::Result<Self> {
let std_listener = StdTcpListener::bind(addr)?;
std_listener.set_nonblocking(true)?;

let inner = TcpListenerStream::new(TcpListener::from_std(std_listener)?);
Ok(Self {
inner,
nodelay,
keepalive,
})
Ok(TcpListener::from_std(std_listener)?.into())
}

/// Sets the `TCP_NODELAY` option on the accepted connection.
pub fn with_nodelay(self, nodelay: Option<bool>) -> Self {
Self { nodelay, ..self }
}

/// Sets the `TCP_KEEPALIVE` option on the accepted connection.
pub fn with_keepalive(self, keepalive: Option<Duration>) -> Self {
let keepalive = keepalive.map(|t| TcpKeepalive::new().with_time(t));
Self { keepalive, ..self }
}
}

/// Creates a new `TcpIncoming` from an existing `tokio::net::TcpListener`.
pub fn from_listener(
listener: TcpListener,
nodelay: bool,
keepalive: Option<Duration>,
) -> Result<Self, crate::BoxError> {
Ok(Self {
impl From<TcpListener> for TcpIncoming {
fn from(listener: TcpListener) -> Self {
Self {
inner: TcpListenerStream::new(listener),
nodelay,
keepalive,
})
nodelay: None,
keepalive: None,
}
}
}

impl Stream for TcpIncoming {
type Item = Result<TcpStream, std::io::Error>;
type Item = std::io::Result<TcpStream>;

fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
match ready!(Pin::new(&mut self.inner).poll_next(cx)) {
Some(Ok(stream)) => {
set_accepted_socket_options(&stream, self.nodelay, self.keepalive);
set_accepted_socket_options(&stream, self.nodelay, &self.keepalive);
Some(Ok(stream)).into()
}
other => Poll::Ready(other),
Expand All @@ -97,19 +97,21 @@ impl Stream for TcpIncoming {
}

// Consistent with hyper-0.14, this function does not return an error.
fn set_accepted_socket_options(stream: &TcpStream, nodelay: bool, keepalive: Option<Duration>) {
if nodelay {
if let Err(e) = stream.set_nodelay(true) {
warn!("error trying to set TCP nodelay: {}", e);
fn set_accepted_socket_options(
stream: &TcpStream,
nodelay: Option<bool>,
keepalive: &Option<TcpKeepalive>,
) {
if let Some(nodelay) = nodelay {
if let Err(e) = stream.set_nodelay(nodelay) {
warn!("error trying to set TCP_NODELAY: {e}");
}
}

if let Some(timeout) = keepalive {
if let Some(keepalive) = keepalive {
let sock_ref = socket2::SockRef::from(&stream);
let sock_keepalive = socket2::TcpKeepalive::new().with_time(timeout);

if let Err(e) = sock_ref.set_tcp_keepalive(&sock_keepalive) {
warn!("error trying to set TCP keepalive: {}", e);
if let Err(e) = sock_ref.set_tcp_keepalive(keepalive) {
warn!("error trying to set TCP_KEEPALIVE: {e}");
}
}
}
Expand All @@ -121,9 +123,9 @@ mod tests {
async fn one_tcpincoming_at_a_time() {
let addr = "127.0.0.1:1322".parse().unwrap();
{
let _t1 = TcpIncoming::new(addr, true, None).unwrap();
let _t2 = TcpIncoming::new(addr, true, None).unwrap_err();
let _t1 = TcpIncoming::new(addr).unwrap();
let _t2 = TcpIncoming::new(addr).unwrap_err();
}
let _t3 = TcpIncoming::new(addr, true, None).unwrap();
let _t3 = TcpIncoming::new(addr).unwrap();
}
}
12 changes: 8 additions & 4 deletions tonic/src/transport/server/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -779,8 +779,10 @@ impl<L> Router<L> {
ResBody: http_body::Body<Data = Bytes> + Send + 'static,
ResBody::Error: Into<crate::BoxError>,
{
let incoming = TcpIncoming::new(addr, self.server.tcp_nodelay, self.server.tcp_keepalive)
.map_err(super::Error::from_source)?;
let incoming = TcpIncoming::new(addr)
.map_err(super::Error::from_source)?
.with_nodelay(Some(self.server.tcp_nodelay))
.with_keepalive(self.server.tcp_keepalive);
self.server
.serve_with_shutdown::<_, _, future::Ready<()>, _, _, ResBody>(
self.routes.prepare(),
Expand Down Expand Up @@ -810,8 +812,10 @@ impl<L> Router<L> {
ResBody: http_body::Body<Data = Bytes> + Send + 'static,
ResBody::Error: Into<crate::BoxError>,
{
let incoming = TcpIncoming::new(addr, self.server.tcp_nodelay, self.server.tcp_keepalive)
.map_err(super::Error::from_source)?;
let incoming = TcpIncoming::new(addr)
.map_err(super::Error::from_source)?
.with_nodelay(Some(self.server.tcp_nodelay))
.with_keepalive(self.server.tcp_keepalive);
self.server
.serve_with_shutdown(self.routes.prepare(), incoming, Some(signal))
.await
Expand Down

0 comments on commit 3abbfe5

Please sign in to comment.