diff --git a/Cargo.lock b/Cargo.lock index 8112055e..d9f83a1b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1784,6 +1784,18 @@ dependencies = [ "slab", ] +[[package]] +name = "futures_ringbuf" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6628abb6eb1fc74beaeb20cd0670c43d158b0150f7689b38c3eaf663f99bdec7" +dependencies = [ + "futures", + "log", + "ringbuf", + "rustc_version", +] + [[package]] name = "fxhash" version = "0.2.1" @@ -2880,6 +2892,7 @@ dependencies = [ "futures", "futures-rustls", "futures-timer", + "futures_ringbuf", "hex-literal", "indexmap 2.1.0", "libc", @@ -4370,6 +4383,15 @@ dependencies = [ "windows-sys 0.48.0", ] +[[package]] +name = "ringbuf" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "79abed428d1fd2a128201cec72c5f6938e2da607c6f3745f769fabea399d950a" +dependencies = [ + "crossbeam-utils", +] + [[package]] name = "rtcp" version = "0.7.2" diff --git a/Cargo.toml b/Cargo.toml index fc0b875f..68fc0eb5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -83,6 +83,7 @@ sc-network = "0.28.0" sc-utils = "8.0.0" serde_json = "1.0.108" tracing-subscriber = { version = "0.3.16", features = ["env-filter"] } +futures_ringbuf = "0.4.0" [features] custom_sc_network = [] diff --git a/src/multistream_select/dialer_select.rs b/src/multistream_select/dialer_select.rs index 2a8a025e..6b345cb6 100644 --- a/src/multistream_select/dialer_select.rs +++ b/src/multistream_select/dialer_select.rs @@ -86,10 +86,24 @@ pub struct DialerSelectFuture { } enum State { - SendHeader { io: MessageIO }, - SendProtocol { io: MessageIO, protocol: N }, - FlushProtocol { io: MessageIO, protocol: N }, - AwaitProtocol { io: MessageIO, protocol: N }, + SendHeader { + io: MessageIO, + }, + SendProtocol { + io: MessageIO, + protocol: N, + header_received: bool, + }, + FlushProtocol { + io: MessageIO, + protocol: N, + header_received: bool, + }, + AwaitProtocol { + io: MessageIO, + protocol: N, + header_received: bool, + }, Done, } @@ -127,14 +141,26 @@ where // The dialer always sends the header and the first protocol // proposal in one go for efficiency. - *this.state = State::SendProtocol { io, protocol }; + *this.state = State::SendProtocol { + io, + protocol, + header_received: false, + }; } - State::SendProtocol { mut io, protocol } => { + State::SendProtocol { + mut io, + protocol, + header_received, + } => { match Pin::new(&mut io).poll_ready(cx)? { Poll::Ready(()) => {} Poll::Pending => { - *this.state = State::SendProtocol { io, protocol }; + *this.state = State::SendProtocol { + io, + protocol, + header_received, + }; return Poll::Pending; } } @@ -146,10 +172,19 @@ where tracing::debug!(target: LOG_TARGET, "Dialer: Proposed protocol: {}", p); if this.protocols.peek().is_some() { - *this.state = State::FlushProtocol { io, protocol } + *this.state = State::FlushProtocol { + io, + protocol, + header_received, + } } else { match this.version { - Version::V1 => *this.state = State::FlushProtocol { io, protocol }, + Version::V1 => + *this.state = State::FlushProtocol { + io, + protocol, + header_received, + }, // This is the only effect that `V1Lazy` has compared to `V1`: // Optimistically settling on the only protocol that // the dialer supports for this negotiation. Notably, @@ -168,21 +203,40 @@ where } } - State::FlushProtocol { mut io, protocol } => { - match Pin::new(&mut io).poll_flush(cx)? { - Poll::Ready(()) => *this.state = State::AwaitProtocol { io, protocol }, - Poll::Pending => { - *this.state = State::FlushProtocol { io, protocol }; - return Poll::Pending; - } + State::FlushProtocol { + mut io, + protocol, + header_received, + } => match Pin::new(&mut io).poll_flush(cx)? { + Poll::Ready(()) => + *this.state = State::AwaitProtocol { + io, + protocol, + header_received, + }, + Poll::Pending => { + *this.state = State::FlushProtocol { + io, + protocol, + header_received, + }; + return Poll::Pending; } - } + }, - State::AwaitProtocol { mut io, protocol } => { + State::AwaitProtocol { + mut io, + protocol, + header_received, + } => { let msg = match Pin::new(&mut io).poll_next(cx)? { Poll::Ready(Some(msg)) => msg, Poll::Pending => { - *this.state = State::AwaitProtocol { io, protocol }; + *this.state = State::AwaitProtocol { + io, + protocol, + header_received, + }; return Poll::Pending; } // Treat EOF error as [`NegotiationError::Failed`], not as @@ -192,8 +246,14 @@ where }; match msg { - Message::Header(v) if v == HeaderLine::from(*this.version) => { - *this.state = State::AwaitProtocol { io, protocol }; + Message::Header(v) + if v == HeaderLine::from(*this.version) && !header_received => + { + *this.state = State::AwaitProtocol { + io, + protocol, + header_received: true, + }; } Message::Protocol(ref p) if p.as_ref() == protocol.as_ref() => { tracing::debug!( @@ -211,7 +271,11 @@ where String::from_utf8_lossy(protocol.as_ref()) ); let protocol = this.protocols.next().ok_or(NegotiationError::Failed)?; - *this.state = State::SendProtocol { io, protocol } + *this.state = State::SendProtocol { + io, + protocol, + header_received, + } } _ => return Poll::Ready(Err(ProtocolError::InvalidMessage.into())), } @@ -346,3 +410,354 @@ impl DialerState { } } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::multistream_select::listener_select_proto; + use std::time::Duration; + use tokio::net::{TcpListener, TcpStream}; + + #[tokio::test] + async fn select_proto_basic() { + async fn run(version: Version) { + let (client_connection, server_connection) = futures_ringbuf::Endpoint::pair(100, 100); + + let server: tokio::task::JoinHandle> = tokio::spawn(async move { + let protos = vec!["/proto1", "/proto2"]; + let (proto, mut io) = + listener_select_proto(server_connection, protos).await.unwrap(); + assert_eq!(proto, "/proto2"); + + let mut out = vec![0; 32]; + let n = io.read(&mut out).await.unwrap(); + out.truncate(n); + assert_eq!(out, b"ping"); + + io.write_all(b"pong").await.unwrap(); + io.flush().await.unwrap(); + + Ok(()) + }); + + let client: tokio::task::JoinHandle> = tokio::spawn(async move { + let protos = vec!["/proto3", "/proto2"]; + let (proto, mut io) = + dialer_select_proto(client_connection, protos, version).await.unwrap(); + assert_eq!(proto, "/proto2"); + + io.write_all(b"ping").await.unwrap(); + io.flush().await.unwrap(); + + let mut out = vec![0; 32]; + let n = io.read(&mut out).await.unwrap(); + out.truncate(n); + assert_eq!(out, b"pong"); + + Ok(()) + }); + + server.await.unwrap(); + client.await.unwrap(); + } + + run(Version::V1).await; + run(Version::V1Lazy).await; + } + + /// Tests the expected behaviour of failed negotiations. + #[tokio::test] + async fn negotiation_failed() { + async fn run( + version: Version, + dial_protos: Vec<&'static str>, + dial_payload: Vec, + listen_protos: Vec<&'static str>, + ) { + let (client_connection, server_connection) = futures_ringbuf::Endpoint::pair(100, 100); + + let server: tokio::task::JoinHandle> = tokio::spawn(async move { + let io = match tokio::time::timeout( + Duration::from_secs(2), + listener_select_proto(server_connection, listen_protos), + ) + .await + .unwrap() + { + Ok((_, io)) => io, + Err(NegotiationError::Failed) => return Ok(()), + Err(NegotiationError::ProtocolError(e)) => { + panic!("Unexpected protocol error {e}") + } + }; + match io.complete().await { + Err(NegotiationError::Failed) => {} + _ => panic!(), + } + + Ok(()) + }); + + let client: tokio::task::JoinHandle> = tokio::spawn(async move { + let mut io = match tokio::time::timeout( + Duration::from_secs(2), + dialer_select_proto(client_connection, dial_protos, version), + ) + .await + .unwrap() + { + Err(NegotiationError::Failed) => return Ok(()), + Ok((_, io)) => io, + Err(_) => panic!(), + }; + + // The dialer may write a payload that is even sent before it + // got confirmation of the last proposed protocol, when `V1Lazy` + // is used. + io.write_all(&dial_payload).await.unwrap(); + match io.complete().await { + Err(NegotiationError::Failed) => {} + _ => panic!(), + } + + Ok(()) + }); + + server.await.unwrap(); + client.await.unwrap(); + } + + // Incompatible protocols. + run(Version::V1, vec!["/proto1"], vec![1], vec!["/proto2"]).await; + run(Version::V1Lazy, vec!["/proto1"], vec![1], vec!["/proto2"]).await; + } + + #[tokio::test] + async fn v1_lazy_do_not_wait_for_negotiation_on_poll_close() { + let (client_connection, _server_connection) = + futures_ringbuf::Endpoint::pair(1024 * 1024, 1); + + let client = tokio::spawn(async move { + // Single protocol to allow for lazy (or optimistic) protocol negotiation. + let protos = vec!["/proto1"]; + let (proto, mut io) = + dialer_select_proto(client_connection, protos, Version::V1Lazy).await.unwrap(); + assert_eq!(proto, "/proto1"); + + // The lazy negotation of protocols can be closed at any time, + // even if the negotiation is not yet done. + io.close().await.unwrap(); + }); + + tokio::time::timeout(Duration::from_secs(10), client).await.unwrap(); + } + + #[tokio::test] + async fn low_level_negotiate() { + async fn run(version: Version) { + let (client_connection, mut server_connection) = + futures_ringbuf::Endpoint::pair(100, 100); + + let server = tokio::spawn(async move { + let protos = vec!["/proto2"]; + + let multistream = b"/multistream/1.0.0\n"; + let len = multistream.len(); + let proto = b"/proto2\n"; + let proto_len = proto.len(); + + // Check that our implementation writes optimally + // the multistream ++ protocol in a single message. + let mut expected_message = Vec::new(); + expected_message.push(len as u8); + expected_message.extend_from_slice(multistream); + expected_message.push(proto_len as u8); + expected_message.extend_from_slice(proto); + + if version == Version::V1Lazy { + expected_message.extend_from_slice(b"ping"); + } + + let mut out = vec![0; 64]; + let n = server_connection.read(&mut out).await.unwrap(); + out.truncate(n); + assert_eq!(out, expected_message); + + // We must send the back the multistream packet. + let mut send_message = Vec::new(); + send_message.push(len as u8); + send_message.extend_from_slice(multistream); + + server_connection.write_all(&mut send_message).await.unwrap(); + + let mut send_message = Vec::new(); + send_message.push(proto_len as u8); + send_message.extend_from_slice(proto); + server_connection.write_all(&mut send_message).await.unwrap(); + + // Handle handshake. + match version { + Version::V1 => { + let mut out = vec![0; 64]; + let n = server_connection.read(&mut out).await.unwrap(); + out.truncate(n); + assert_eq!(out, b"ping"); + + server_connection.write_all(b"pong").await.unwrap(); + } + Version::V1Lazy => { + // Ping (handshake) payload expected in the initial message. + server_connection.write_all(b"pong").await.unwrap(); + } + } + }); + + let client = tokio::spawn(async move { + let protos = vec!["/proto2"]; + let (proto, mut io) = + dialer_select_proto(client_connection, protos, version).await.unwrap(); + assert_eq!(proto, "/proto2"); + + io.write_all(b"ping").await.unwrap(); + io.flush().await.unwrap(); + + let mut out = vec![0; 32]; + let n = io.read(&mut out).await.unwrap(); + out.truncate(n); + assert_eq!(out, b"pong"); + }); + + server.await; + client.await; + } + + run(Version::V1).await; + run(Version::V1Lazy).await; + } + + #[tokio::test] + async fn v1_low_level_negotiate_multiple_headers() { + let (client_connection, mut server_connection) = futures_ringbuf::Endpoint::pair(100, 100); + + let server: tokio::task::JoinHandle> = tokio::spawn(async move { + let protos = vec!["/proto2"]; + + let multistream = b"/multistream/1.0.0\n"; + let len = multistream.len(); + let proto = b"/proto2\n"; + let proto_len = proto.len(); + + // Check that our implementation writes optimally + // the multistream ++ protocol in a single message. + let mut expected_message = Vec::new(); + expected_message.push(len as u8); + expected_message.extend_from_slice(multistream); + expected_message.push(proto_len as u8); + expected_message.extend_from_slice(proto); + + let mut out = vec![0; 64]; + let n = server_connection.read(&mut out).await.unwrap(); + out.truncate(n); + assert_eq!(out, expected_message); + + // We must send the back the multistream packet. + let mut send_message = Vec::new(); + send_message.push(len as u8); + send_message.extend_from_slice(multistream); + + server_connection.write_all(&mut send_message).await.unwrap(); + + // We must send the back the multistream packet again. + let mut send_message = Vec::new(); + send_message.push(len as u8); + send_message.extend_from_slice(multistream); + + server_connection.write_all(&mut send_message).await.unwrap(); + + Ok(()) + }); + + let client: tokio::task::JoinHandle> = tokio::spawn(async move { + let protos = vec!["/proto2"]; + + // Negotiation fails because the protocol receives the `/multistream/1.0.0` header + // multiple times. + let result = + dialer_select_proto(client_connection, protos, Version::V1).await.unwrap_err(); + match result { + NegotiationError::ProtocolError(ProtocolError::InvalidMessage) => {} + _ => panic!("unexpected error: {:?}", result), + }; + + Ok(()) + }); + + server.await.unwrap(); + client.await.unwrap(); + } + + #[tokio::test] + async fn v1_lazy_low_level_negotiate_multiple_headers() { + let (client_connection, mut server_connection) = futures_ringbuf::Endpoint::pair(100, 100); + + let server: tokio::task::JoinHandle> = tokio::spawn(async move { + let protos = vec!["/proto2"]; + + let multistream = b"/multistream/1.0.0\n"; + let len = multistream.len(); + let proto = b"/proto2\n"; + let proto_len = proto.len(); + + // Check that our implementation writes optimally + // the multistream ++ protocol in a single message. + let mut expected_message = Vec::new(); + expected_message.push(len as u8); + expected_message.extend_from_slice(multistream); + expected_message.push(proto_len as u8); + expected_message.extend_from_slice(proto); + + let mut out = vec![0; 64]; + let n = server_connection.read(&mut out).await.unwrap(); + out.truncate(n); + assert_eq!(out, expected_message); + + // We must send the back the multistream packet. + let mut send_message = Vec::new(); + send_message.push(len as u8); + send_message.extend_from_slice(multistream); + + server_connection.write_all(&mut send_message).await.unwrap(); + + // We must send the back the multistream packet again. + let mut send_message = Vec::new(); + send_message.push(len as u8); + send_message.extend_from_slice(multistream); + + server_connection.write_all(&mut send_message).await.unwrap(); + + Ok(()) + }); + + let client: tokio::task::JoinHandle> = tokio::spawn(async move { + let protos = vec!["/proto2"]; + + // Negotiation fails because the protocol receives the `/multistream/1.0.0` header + // multiple times. + let (proto, to_negociate) = + dialer_select_proto(client_connection, protos, Version::V1Lazy).await.unwrap(); + assert_eq!(proto, "/proto2"); + + let result = to_negociate.complete().await.unwrap_err(); + + match result { + NegotiationError::ProtocolError(ProtocolError::InvalidMessage) => {} + _ => panic!("unexpected error: {:?}", result), + }; + + Ok(()) + }); + + server.await.unwrap(); + client.await.unwrap(); + } +} diff --git a/src/multistream_select/negotiated.rs b/src/multistream_select/negotiated.rs index 450bdae3..1aae2285 100644 --- a/src/multistream_select/negotiated.rs +++ b/src/multistream_select/negotiated.rs @@ -176,6 +176,10 @@ impl Negotiated { header: None, }; continue; + } else { + // If we received a header message but it doesn't match the expected + // one, or we have already received the message return an error. + return Poll::Ready(Err(ProtocolError::InvalidMessage.into())); } } @@ -319,9 +323,8 @@ where } fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - // Ensure all data has been flushed and expected negotiation messages + // Ensure all data has been flushed and potentially negotiation messages // have been received. - ready!(self.as_mut().poll(cx).map_err(Into::::into)?); ready!(self.as_mut().poll_flush(cx).map_err(Into::::into)?); // Continue with the shutdown of the underlying I/O stream.