From e0802dc5dd853914774cbbea42fe69c9ef902a50 Mon Sep 17 00:00:00 2001 From: Caleb Leinz Date: Tue, 12 Dec 2023 10:10:01 -0800 Subject: [PATCH 1/3] Ensure multiple shutdown attempts yield --- src/tokio/io.rs | 2 +- tests/dummy_tests.rs | 25 +++++++++++++++++++++++++ 2 files changed, 26 insertions(+), 1 deletion(-) diff --git a/src/tokio/io.rs b/src/tokio/io.rs index a3bbe9d..7c1f5ac 100644 --- a/src/tokio/io.rs +++ b/src/tokio/io.rs @@ -381,7 +381,7 @@ where poll } - Status::Disconnected(_) => Poll::Pending, + Status::Disconnected(_) => exhausted_err(), Status::FailedAndExhausted => exhausted_err(), } } diff --git a/tests/dummy_tests.rs b/tests/dummy_tests.rs index 260df34..5c19daf 100644 --- a/tests/dummy_tests.rs +++ b/tests/dummy_tests.rs @@ -174,6 +174,31 @@ mod already_connected { use tokio_util::codec::{Framed, LinesCodec}; + #[tokio::test] + async fn back_to_back_shutdown_attempts() { + use stubborn_io::StubbornTcpStream; + use tokio::io::AsyncWriteExt; + + const ADDR: &str = "127.0.0.1:3989"; + + tokio::spawn(async move { + let mut streams = Vec::new(); + let listener = tokio::net::TcpListener::bind(ADDR).await.unwrap(); + loop { + let (stream, _addr) = listener.accept().await.unwrap(); + streams.push(stream); + } + }); + tokio::time::sleep(Duration::from_secs(1)).await; + let mut connection = StubbornTcpStream::connect(ADDR).await.unwrap(); + + connection.shutdown().await.unwrap(); + let elapsed = tokio::time::timeout(Duration::from_secs(5), connection.shutdown()).await; + + let result = elapsed.unwrap(); + assert!(result.is_err()); + } + #[tokio::test] async fn should_ignore_non_fatal_errors_and_continue_as_connected() { let connect_outcomes = Arc::new(Mutex::new(vec![true])); From 1a35332809f43b94832abb78ea82d0cf50293cd4 Mon Sep 17 00:00:00 2001 From: Caleb Leinz Date: Tue, 12 Dec 2023 15:55:09 -0800 Subject: [PATCH 2/3] Move test and delegate port selection to OS --- tests/dummy_tests.rs | 25 ------------------------- tests/integration_tests.rs | 27 +++++++++++++++++++++++++++ 2 files changed, 27 insertions(+), 25 deletions(-) create mode 100644 tests/integration_tests.rs diff --git a/tests/dummy_tests.rs b/tests/dummy_tests.rs index 5c19daf..260df34 100644 --- a/tests/dummy_tests.rs +++ b/tests/dummy_tests.rs @@ -174,31 +174,6 @@ mod already_connected { use tokio_util::codec::{Framed, LinesCodec}; - #[tokio::test] - async fn back_to_back_shutdown_attempts() { - use stubborn_io::StubbornTcpStream; - use tokio::io::AsyncWriteExt; - - const ADDR: &str = "127.0.0.1:3989"; - - tokio::spawn(async move { - let mut streams = Vec::new(); - let listener = tokio::net::TcpListener::bind(ADDR).await.unwrap(); - loop { - let (stream, _addr) = listener.accept().await.unwrap(); - streams.push(stream); - } - }); - tokio::time::sleep(Duration::from_secs(1)).await; - let mut connection = StubbornTcpStream::connect(ADDR).await.unwrap(); - - connection.shutdown().await.unwrap(); - let elapsed = tokio::time::timeout(Duration::from_secs(5), connection.shutdown()).await; - - let result = elapsed.unwrap(); - assert!(result.is_err()); - } - #[tokio::test] async fn should_ignore_non_fatal_errors_and_continue_as_connected() { let connect_outcomes = Arc::new(Mutex::new(vec![true])); diff --git a/tests/integration_tests.rs b/tests/integration_tests.rs new file mode 100644 index 0000000..cc81b7b --- /dev/null +++ b/tests/integration_tests.rs @@ -0,0 +1,27 @@ +use std::time::Duration; + +use stubborn_io::StubbornTcpStream; +use tokio::{io::AsyncWriteExt, sync::oneshot}; + +#[tokio::test] +async fn back_to_back_shutdown_attempts() { + let (port_tx, port_rx) = oneshot::channel(); + tokio::spawn(async move { + let mut streams = Vec::new(); + let listener = tokio::net::TcpListener::bind("0.0.0.0:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + port_tx.send(addr).unwrap(); + loop { + let (stream, _addr) = listener.accept().await.unwrap(); + streams.push(stream); + } + }); + let addr = port_rx.await.unwrap(); + let mut connection = StubbornTcpStream::connect(addr).await.unwrap(); + + connection.shutdown().await.unwrap(); + let elapsed = tokio::time::timeout(Duration::from_secs(5), connection.shutdown()).await; + + let result = elapsed.unwrap(); + assert!(result.is_err()); +} From 134cc4b4cd12a35afd138223c020aea825ed3593 Mon Sep 17 00:00:00 2001 From: Caleb Leinz Date: Wed, 13 Dec 2023 08:07:42 -0800 Subject: [PATCH 3/3] v0.3.3 - Adjust shutdown on disconnect error --- Cargo.toml | 2 +- src/tokio/io.rs | 20 ++++++++++++++++---- tests/integration_tests.rs | 3 ++- 3 files changed, 19 insertions(+), 6 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index f5b9ddf..e889e3f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "stubborn-io" -version = "0.3.2" +version = "0.3.3" authors = ["David Raifaizen "] edition = "2021" description = "io traits/structs that automatically recover from potential disconnections/interruptions." diff --git a/src/tokio/io.rs b/src/tokio/io.rs index 7c1f5ac..d4d1473 100644 --- a/src/tokio/io.rs +++ b/src/tokio/io.rs @@ -95,12 +95,24 @@ enum Status { FailedAndExhausted, // the way one feels after programming in dynamically typed languages } +#[inline] +fn poll_err( + kind: ErrorKind, + reason: impl Into>, +) -> Poll> { + let io_err = io::Error::new(kind, reason); + Poll::Ready(Err(io_err)) +} + fn exhausted_err() -> Poll> { - let io_err = io::Error::new( + poll_err( ErrorKind::NotConnected, "Disconnected. Connection attempts have been exhausted.", - ); - Poll::Ready(Err(io_err)) + ) +} + +fn disconnected_err() -> Poll> { + poll_err(ErrorKind::NotConnected, "Underlying I/O is disconnected.") } impl Deref for StubbornIo { @@ -381,7 +393,7 @@ where poll } - Status::Disconnected(_) => exhausted_err(), + Status::Disconnected(_) => disconnected_err(), Status::FailedAndExhausted => exhausted_err(), } } diff --git a/tests/integration_tests.rs b/tests/integration_tests.rs index cc81b7b..7c31127 100644 --- a/tests/integration_tests.rs +++ b/tests/integration_tests.rs @@ -23,5 +23,6 @@ async fn back_to_back_shutdown_attempts() { let elapsed = tokio::time::timeout(Duration::from_secs(5), connection.shutdown()).await; let result = elapsed.unwrap(); - assert!(result.is_err()); + let error = result.unwrap_err(); + assert_eq!(error.kind(), std::io::ErrorKind::NotConnected); }