Skip to content

Commit

Permalink
Ensure disconnected sockets do not block the caller
Browse files Browse the repository at this point in the history
  • Loading branch information
fredclausen committed Dec 30, 2023
1 parent 5f5d9b0 commit 64f22da
Showing 1 changed file with 28 additions and 4 deletions.
32 changes: 28 additions & 4 deletions src/tokio/io.rs
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,10 @@ where
T: UnderlyingIo<C> + AsyncWrite,
C: Clone + Send + Unpin + 'static,
{
/// Method for writing to the underlying IO item.
/// If the write results in a disconnect, the write is skipped and the
/// underlying IO item will attempt to be reconnected.
/// No error is returned to the caller.
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
Expand All @@ -389,15 +393,23 @@ where
let poll = AsyncWrite::poll_write(Pin::new(&mut self.underlying_io), cx, buf);

if self.is_write_disconnect_detected(&poll) {
error!(
"{}Write disconnect detected. Skipping message",
&self.get_connection_name()
);
self.on_disconnect(cx);
Poll::Pending
Poll::Ready(Ok(buf.len()))
} else {
poll
}
}
Status::Disconnected(_) => {
error!(
"{}Write disconnect detected. Skipping Message",
&self.get_connection_name()
);
self.poll_disconnect(cx);
Poll::Pending
Poll::Ready(Ok(buf.len()))
}
Status::FailedAndExhausted => exhausted_err(),
}
Expand Down Expand Up @@ -439,6 +451,10 @@ where
}
}

/// Method for writing to the underlying IO item.
/// If the write results in a disconnect, the write is skipped and the
/// underlying IO item will attempt to be reconnected.
/// No error is returned to the caller.
fn poll_write_vectored(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
Expand All @@ -450,15 +466,23 @@ where
AsyncWrite::poll_write_vectored(Pin::new(&mut self.underlying_io), cx, bufs);

if self.is_write_disconnect_detected(&poll) {
error!(
"{}Write disconnect detected. Skipping message",
&self.get_connection_name()
);
self.on_disconnect(cx);
Poll::Pending
Poll::Ready(Ok(bufs.iter().map(|buf| buf.len()).sum()))
} else {
poll
}
}
Status::Disconnected(_) => {
error!(
"{}Write disconnect detected. Skipping Message",
&self.get_connection_name()
);
self.poll_disconnect(cx);
Poll::Pending
Poll::Ready(Ok(bufs.iter().map(|buf| buf.len()).sum()))
}
Status::FailedAndExhausted => exhausted_err(),
}
Expand Down

0 comments on commit 64f22da

Please sign in to comment.