Skip to content

Commit

Permalink
Ensure multiple shutdown attempts behave correctly (#25)
Browse files Browse the repository at this point in the history
* Ensure multiple shutdown attempts behave correctly

---------

Authored-by: Caleb Leinz <[email protected]>
  • Loading branch information
cmleinz authored and fredclausen committed Dec 30, 2023
1 parent 64f22da commit d87d511
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 4 deletions.
20 changes: 16 additions & 4 deletions src/tokio/io.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,12 +95,24 @@ enum Status<T, C> {
FailedAndExhausted, // the way one feels after programming in dynamically typed languages
}

#[inline]
fn poll_err<T>(
kind: ErrorKind,
reason: impl Into<Box<dyn std::error::Error + Send + Sync>>,
) -> Poll<io::Result<T>> {
let io_err = io::Error::new(kind, reason);
Poll::Ready(Err(io_err))
}

fn exhausted_err<T>() -> Poll<io::Result<T>> {
let io_err = io::Error::new(
poll_err(
ErrorKind::NotConnected,
"Disconnected. Connection attempts have been exhausted.",
);
Poll::Ready(Err(io_err))
)
}

fn disconnected_err<T>() -> Poll<io::Result<T>> {
poll_err(ErrorKind::NotConnected, "Underlying I/O is disconnected.")
}

impl<T, C> Deref for StubbornIo<T, C> {
Expand Down Expand Up @@ -446,7 +458,7 @@ where

poll
}
Status::Disconnected(_) => Poll::Pending,
Status::Disconnected(_) => disconnected_err(),
Status::FailedAndExhausted => exhausted_err(),
}
}
Expand Down
28 changes: 28 additions & 0 deletions tests/integration_tests.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
use std::time::Duration;

use stubborn_io::StubbornTcpStream;
use tokio::{io::AsyncWriteExt, sync::oneshot};

#[tokio::test]
async fn back_to_back_shutdown_attempts() {
let (port_tx, port_rx) = oneshot::channel();
tokio::spawn(async move {
let mut streams = Vec::new();
let listener = tokio::net::TcpListener::bind("0.0.0.0:0").await.unwrap();
let addr = listener.local_addr().unwrap();
port_tx.send(addr).unwrap();
loop {
let (stream, _addr) = listener.accept().await.unwrap();
streams.push(stream);
}
});
let addr = port_rx.await.unwrap();
let mut connection = StubbornTcpStream::connect(addr).await.unwrap();

connection.shutdown().await.unwrap();
let elapsed = tokio::time::timeout(Duration::from_secs(5), connection.shutdown()).await;

let result = elapsed.unwrap();
let error = result.unwrap_err();
assert_eq!(error.kind(), std::io::ErrorKind::NotConnected);
}

0 comments on commit d87d511

Please sign in to comment.