From d87d511eb61bba010f7fecdf3c207583c3e6a95a Mon Sep 17 00:00:00 2001 From: "Caleb Leinz (he/him)" <103841857+cmleinz@users.noreply.github.com> Date: Wed, 13 Dec 2023 08:17:17 -0800 Subject: [PATCH] Ensure multiple shutdown attempts behave correctly (#25) * Ensure multiple shutdown attempts behave correctly --------- Authored-by: Caleb Leinz --- src/tokio/io.rs | 20 ++++++++++++++++---- tests/integration_tests.rs | 28 ++++++++++++++++++++++++++++ 2 files changed, 44 insertions(+), 4 deletions(-) create mode 100644 tests/integration_tests.rs diff --git a/src/tokio/io.rs b/src/tokio/io.rs index 4667ba3..70c0a21 100644 --- a/src/tokio/io.rs +++ b/src/tokio/io.rs @@ -95,12 +95,24 @@ enum Status { FailedAndExhausted, // the way one feels after programming in dynamically typed languages } +#[inline] +fn poll_err( + kind: ErrorKind, + reason: impl Into>, +) -> Poll> { + let io_err = io::Error::new(kind, reason); + Poll::Ready(Err(io_err)) +} + fn exhausted_err() -> Poll> { - let io_err = io::Error::new( + poll_err( ErrorKind::NotConnected, "Disconnected. Connection attempts have been exhausted.", - ); - Poll::Ready(Err(io_err)) + ) +} + +fn disconnected_err() -> Poll> { + poll_err(ErrorKind::NotConnected, "Underlying I/O is disconnected.") } impl Deref for StubbornIo { @@ -446,7 +458,7 @@ where poll } - Status::Disconnected(_) => Poll::Pending, + Status::Disconnected(_) => disconnected_err(), Status::FailedAndExhausted => exhausted_err(), } } diff --git a/tests/integration_tests.rs b/tests/integration_tests.rs new file mode 100644 index 0000000..7c31127 --- /dev/null +++ b/tests/integration_tests.rs @@ -0,0 +1,28 @@ +use std::time::Duration; + +use stubborn_io::StubbornTcpStream; +use tokio::{io::AsyncWriteExt, sync::oneshot}; + +#[tokio::test] +async fn back_to_back_shutdown_attempts() { + let (port_tx, port_rx) = oneshot::channel(); + tokio::spawn(async move { + let mut streams = Vec::new(); + let listener = tokio::net::TcpListener::bind("0.0.0.0:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + port_tx.send(addr).unwrap(); + loop { + let (stream, _addr) = listener.accept().await.unwrap(); + streams.push(stream); + } + }); + let addr = port_rx.await.unwrap(); + let mut connection = StubbornTcpStream::connect(addr).await.unwrap(); + + connection.shutdown().await.unwrap(); + let elapsed = tokio::time::timeout(Duration::from_secs(5), connection.shutdown()).await; + + let result = elapsed.unwrap(); + let error = result.unwrap_err(); + assert_eq!(error.kind(), std::io::ErrorKind::NotConnected); +}