diff --git a/client/network/src/protocol/notifications/handler.rs b/client/network/src/protocol/notifications/handler.rs index ca87941cb96df..57561c7b9879d 100644 --- a/client/network/src/protocol/notifications/handler.rs +++ b/client/network/src/protocol/notifications/handler.rs @@ -782,9 +782,6 @@ impl ConnectionHandler for NotifsHandler { // performed before the code paths that can produce `Ready` (with some rare exceptions). // Importantly, however, the flush is performed *after* notifications are queued with // `Sink::start_send`. - // Note that we must call `poll_flush` on all substreams and not only on those we - // have called `Sink::start_send` on, because `NotificationsOutSubstream::poll_flush` - // also reports the substream termination (even if no data was written into it). for protocol_index in 0..self.protocols.len() { match &mut self.protocols[protocol_index].state { State::Open { out_substream: out_substream @ Some(_), .. } => { @@ -827,7 +824,7 @@ impl ConnectionHandler for NotifsHandler { State::OpenDesiredByRemote { in_substream, pending_opening } => match NotificationsInSubstream::poll_process(Pin::new(in_substream), cx) { Poll::Pending => {}, - Poll::Ready(Ok(())) => {}, + Poll::Ready(Ok(void)) => match void {}, Poll::Ready(Err(_)) => { self.protocols[protocol_index].state = State::Closed { pending_opening: *pending_opening }; @@ -843,7 +840,7 @@ impl ConnectionHandler for NotifsHandler { cx, ) { Poll::Pending => {}, - Poll::Ready(Ok(())) => {}, + Poll::Ready(Ok(void)) => match void {}, Poll::Ready(Err(_)) => *in_substream = None, }, } diff --git a/client/network/src/protocol/notifications/upgrade/notifications.rs b/client/network/src/protocol/notifications/upgrade/notifications.rs index 3621c63497d95..71afc3c90e37f 100644 --- a/client/network/src/protocol/notifications/upgrade/notifications.rs +++ b/client/network/src/protocol/notifications/upgrade/notifications.rs @@ -41,6 +41,7 @@ use libp2p::core::{upgrade, InboundUpgrade, OutboundUpgrade, UpgradeInfo}; use log::{error, warn}; use sc_network_common::protocol::ProtocolName; use std::{ + convert::Infallible, io, mem, pin::Pin, task::{Context, Poll}, @@ -220,7 +221,10 @@ where /// Equivalent to `Stream::poll_next`, except that it only drives the handshake and is /// guaranteed to not generate any notification. - pub fn poll_process(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + pub fn poll_process( + self: Pin<&mut Self>, + cx: &mut Context, + ) -> Poll> { let mut this = self.project(); loop { @@ -242,10 +246,8 @@ where }, NotificationsInSubstreamHandshake::Flush => { match Sink::poll_flush(this.socket.as_mut(), cx)? { - Poll::Ready(()) => { - *this.handshake = NotificationsInSubstreamHandshake::Sent; - return Poll::Ready(Ok(())) - }, + Poll::Ready(()) => + *this.handshake = NotificationsInSubstreamHandshake::Sent, Poll::Pending => { *this.handshake = NotificationsInSubstreamHandshake::Flush; return Poll::Pending @@ -258,7 +260,7 @@ where st @ NotificationsInSubstreamHandshake::ClosingInResponseToRemote | st @ NotificationsInSubstreamHandshake::BothSidesClosed => { *this.handshake = st; - return Poll::Ready(Ok(())) + return Poll::Pending }, } } @@ -441,21 +443,6 @@ where fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { let mut this = self.project(); - - // `Sink::poll_flush` does not expose stream closed error until we write something into - // the stream, so the code below makes sure we detect that the substream was closed - // even if we don't write anything into it. - match Stream::poll_next(this.socket.as_mut(), cx) { - Poll::Pending => {}, - Poll::Ready(Some(_)) => { - error!( - target: "sub-libp2p", - "Unexpected incoming data in `NotificationsOutSubstream`", - ); - }, - Poll::Ready(None) => return Poll::Ready(Err(NotificationsOutError::Terminated)), - } - Sink::poll_flush(this.socket.as_mut(), cx).map_err(NotificationsOutError::Io) } @@ -505,21 +492,13 @@ pub enum NotificationsOutError { /// I/O error on the substream. #[error(transparent)] Io(#[from] io::Error), - - /// End of incoming data detected on out substream. - #[error("substream was closed/reset")] - Terminated, } #[cfg(test)] mod tests { - use super::{ - NotificationsIn, NotificationsInOpen, NotificationsOut, NotificationsOutError, - NotificationsOutOpen, - }; - use futures::{channel::oneshot, future, prelude::*}; + use super::{NotificationsIn, NotificationsInOpen, NotificationsOut, NotificationsOutOpen}; + use futures::{channel::oneshot, prelude::*}; use libp2p::core::upgrade; - use std::{pin::Pin, task::Poll}; use tokio::net::{TcpListener, TcpStream}; use tokio_util::compat::TokioAsyncReadCompatExt; @@ -712,95 +691,4 @@ mod tests { client.await.unwrap(); } - - #[tokio::test] - async fn send_handshake_without_polling_for_incoming_data() { - const PROTO_NAME: &str = "/test/proto/1"; - let (listener_addr_tx, listener_addr_rx) = oneshot::channel(); - - let client = tokio::spawn(async move { - let socket = TcpStream::connect(listener_addr_rx.await.unwrap()).await.unwrap(); - let NotificationsOutOpen { handshake, .. } = upgrade::apply_outbound( - socket.compat(), - NotificationsOut::new(PROTO_NAME, Vec::new(), &b"initial message"[..], 1024 * 1024), - upgrade::Version::V1, - ) - .await - .unwrap(); - - assert_eq!(handshake, b"hello world"); - }); - - let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); - listener_addr_tx.send(listener.local_addr().unwrap()).unwrap(); - - let (socket, _) = listener.accept().await.unwrap(); - let NotificationsInOpen { handshake, mut substream, .. } = upgrade::apply_inbound( - socket.compat(), - NotificationsIn::new(PROTO_NAME, Vec::new(), 1024 * 1024), - ) - .await - .unwrap(); - - assert_eq!(handshake, b"initial message"); - substream.send_handshake(&b"hello world"[..]); - - // Actually send the handshake. - future::poll_fn(|cx| Pin::new(&mut substream).poll_process(cx)).await.unwrap(); - - client.await.unwrap(); - } - - #[tokio::test] - async fn can_detect_dropped_out_substream_without_writing_data() { - const PROTO_NAME: &str = "/test/proto/1"; - let (listener_addr_tx, listener_addr_rx) = oneshot::channel(); - - let client = tokio::spawn(async move { - let socket = TcpStream::connect(listener_addr_rx.await.unwrap()).await.unwrap(); - let NotificationsOutOpen { handshake, mut substream, .. } = upgrade::apply_outbound( - socket.compat(), - NotificationsOut::new(PROTO_NAME, Vec::new(), &b"initial message"[..], 1024 * 1024), - upgrade::Version::V1, - ) - .await - .unwrap(); - - assert_eq!(handshake, b"hello world"); - - future::poll_fn(|cx| match Pin::new(&mut substream).poll_flush(cx) { - Poll::Pending => Poll::Pending, - Poll::Ready(Ok(())) => { - cx.waker().wake_by_ref(); - Poll::Pending - }, - Poll::Ready(Err(e)) => { - assert!(matches!(e, NotificationsOutError::Terminated)); - Poll::Ready(()) - }, - }) - .await; - }); - - let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); - listener_addr_tx.send(listener.local_addr().unwrap()).unwrap(); - - let (socket, _) = listener.accept().await.unwrap(); - let NotificationsInOpen { handshake, mut substream, .. } = upgrade::apply_inbound( - socket.compat(), - NotificationsIn::new(PROTO_NAME, Vec::new(), 1024 * 1024), - ) - .await - .unwrap(); - - assert_eq!(handshake, b"initial message"); - - // Send the handhsake. - substream.send_handshake(&b"hello world"[..]); - future::poll_fn(|cx| Pin::new(&mut substream).poll_process(cx)).await.unwrap(); - - drop(substream); - - client.await.unwrap(); - } }