From 02fc0278f675ba33d0e749b7dbe0d0b5d289d88e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jakub=20D=C3=B3ka?= <60517552+jakubDoka@users.noreply.github.com> Date: Sat, 6 Jul 2024 01:30:18 +0200 Subject: [PATCH] fix(swarm): eliminating protocol cloning when nothing is happening MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The code keeps the API while eliminating repetitive protocol cloning when the protocols did not change. Only when a protocol change occurs are the protocols cloned into a reused buffer, from which they are borrowed for iteration. The following are the benchmark results: |behaviour count|iterations|protocols|timings|change*| |-|-|-|-|-| |1|1000|10|27.798 µs 28.134 µs 28.493 µs|-15.771% -14.523% -13.269%| |1|1000|100|55.171 µs 55.578 µs 56.009 µs|-51.831% -50.162% -48.437%| |1|1000|1000|289.24 µs 290.99 µs 293.00 µs|-61.748% -60.895% -60.054%| |5|1000|2|34.000 µs 34.216 µs 34.457 µs|-18.538% -16.231% -14.011%| |5|1000|20|70.962 µs 71.428 µs 72.005 µs|-40.501% -38.944% -37.309%| |5|1000|200|426.17 µs 433.27 µs 442.60 µs|-44.824% -42.663% -40.262%| |10|1000|1|42.993 µs 44.382 µs 45.655 µs|-18.839% -16.292% -13.584%| |10|1000|10|94.022 µs 96.787 µs 99.321 µs|-25.469% -23.572% -21.562%| |10|1000|100|543.13 µs 554.91 µs 569.06 µs|-43.781% -42.189% -40.568%| |20|500|1|63.150 µs 64.846 µs 66.860 µs|-9.5693% -6.1722% -2.6400%| |20|500|10|212.21 µs 217.48 µs 222.64 µs|-16.525% -14.234% -11.925%| |20|500|100|1.6651 ms 1.7083 ms 1.7490 ms|-27.704% -25.683% -23.618%| change*: 3da7d918d0d3c443f20d1813772af1ac152b68c7 is the baseline Pull-Request: #5026. 
--- Cargo.lock | 5 +- Cargo.toml | 2 +- swarm/CHANGELOG.md | 7 + swarm/Cargo.toml | 7 +- swarm/benches/connection_handler.rs | 359 ++++++++++++++++++++ swarm/src/connection.rs | 70 ++-- swarm/src/connection/supported_protocols.rs | 20 +- swarm/src/handler.rs | 320 +++++++++++++---- swarm/src/lib.rs | 8 + 9 files changed, 707 insertions(+), 91 deletions(-) create mode 100644 swarm/benches/connection_handler.rs diff --git a/Cargo.lock b/Cargo.lock index 5d45d51ac4b..5ee39cb1786 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1093,6 +1093,7 @@ dependencies = [ "ciborium", "clap", "criterion-plot", + "futures", "is-terminal", "itertools", "num-traits", @@ -1105,6 +1106,7 @@ dependencies = [ "serde_derive", "serde_json", "tinytemplate", + "tokio", "walkdir", ] @@ -3291,9 +3293,10 @@ dependencies = [ [[package]] name = "libp2p-swarm" -version = "0.45.0" +version = "0.45.1" dependencies = [ "async-std", + "criterion", "either", "fnv", "futures", diff --git a/Cargo.toml b/Cargo.toml index ab660cc90e9..1444b469c31 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -102,7 +102,7 @@ libp2p-rendezvous = { version = "0.14.1", path = "protocols/rendezvous" } libp2p-request-response = { version = "0.26.4", path = "protocols/request-response" } libp2p-server = { version = "0.12.7", path = "misc/server" } libp2p-stream = { version = "0.1.0-alpha.1", path = "protocols/stream" } -libp2p-swarm = { version = "0.45.0", path = "swarm" } +libp2p-swarm = { version = "0.45.1", path = "swarm" } libp2p-swarm-derive = { version = "=0.34.2", path = "swarm-derive" } # `libp2p-swarm-derive` may not be compatible with different `libp2p-swarm` non-breaking releases. E.g. `libp2p-swarm` might introduce a new enum variant `FromSwarm` (which is `#[non-exhaustive]`) in a non-breaking release. Older versions of `libp2p-swarm-derive` would not forward this enum variant within the `NetworkBehaviour` hierarchy. Thus the version pinning is required. 
libp2p-swarm-test = { version = "0.3.0", path = "swarm-test" } libp2p-tcp = { version = "0.42.0", path = "transports/tcp" } diff --git a/swarm/CHANGELOG.md b/swarm/CHANGELOG.md index 1139c40021f..8cf8852dd76 100644 --- a/swarm/CHANGELOG.md +++ b/swarm/CHANGELOG.md @@ -1,3 +1,10 @@ + +## 0.45.1 + +- Optimize internal connection `fn poll`. New implementation now scales much better with number of listen protocols active. + No changes to public API introduced. + See [PR 5026](https://github.com/libp2p/rust-libp2p/pull/5026) + ## 0.45.0 - Add peer_id to `FromSwarm::ListenFailure`. diff --git a/swarm/Cargo.toml b/swarm/Cargo.toml index 60cf58cb495..6ce99aa33c5 100644 --- a/swarm/Cargo.toml +++ b/swarm/Cargo.toml @@ -3,7 +3,7 @@ name = "libp2p-swarm" edition = "2021" rust-version = { workspace = true } description = "The libp2p swarm" -version = "0.45.0" +version = "0.45.1" authors = ["Parity Technologies "] license = "MIT" repository = "https://github.com/libp2p/rust-libp2p" @@ -52,6 +52,7 @@ libp2p-swarm-derive = { path = "../swarm-derive" } # Using `pat libp2p-swarm-test = { path = "../swarm-test" } # Using `path` here because this is a cyclic dev-dependency which otherwise breaks releasing. libp2p-yamux = { path = "../muxers/yamux" } # Using `path` here because this is a cyclic dev-dependency which otherwise breaks releasing. 
quickcheck = { workspace = true } +criterion = { version = "0.5", features = ["async_tokio"] } void = "1" once_cell = "1.19.0" trybuild = "1.0.95" @@ -69,5 +70,9 @@ all-features = true rustdoc-args = ["--cfg", "docsrs"] rustc-args = ["--cfg", "docsrs"] +[[bench]] +name = "connection_handler" +harness = false + [lints] workspace = true diff --git a/swarm/benches/connection_handler.rs b/swarm/benches/connection_handler.rs new file mode 100644 index 00000000000..b9986d9649f --- /dev/null +++ b/swarm/benches/connection_handler.rs @@ -0,0 +1,359 @@ +use async_std::stream::StreamExt; +use criterion::{criterion_group, criterion_main, Criterion}; +use libp2p_core::{ + transport::MemoryTransport, InboundUpgrade, Multiaddr, OutboundUpgrade, Transport, UpgradeInfo, +}; +use libp2p_identity::PeerId; +use libp2p_swarm::{ConnectionHandler, NetworkBehaviour, StreamProtocol}; +use std::{convert::Infallible, sync::atomic::AtomicUsize}; +use web_time::Duration; + +macro_rules! gen_behaviour { + ($($name:ident {$($field:ident),*};)*) => {$( + #[derive(libp2p_swarm::NetworkBehaviour, Default)] + #[behaviour(prelude = "libp2p_swarm::derive_prelude")] + struct $name { + $($field: SpinningBehaviour,)* + } + + impl BigBehaviour for $name { + fn behaviours(&mut self) -> &mut [SpinningBehaviour] { + unsafe { + std::slice::from_raw_parts_mut( + self as *mut Self as *mut SpinningBehaviour, + std::mem::size_of::() / std::mem::size_of::(), + ) + } + } + } + )*}; +} + +macro_rules! benchmarks { + ($( + $group:ident::[$( + $beh:ident::bench() + .name($name:ident) + .poll_count($count:expr) + .protocols_per_behaviour($protocols:expr), + )+]; + )*) => { + + $( + $( + fn $name(c: &mut Criterion) { + <$beh>::run_bench(c, $protocols, $count, true); + } + )+ + + criterion_group!($group, $($name),*); + )* + + criterion_main!($($group),*); + }; +} + +// fans go brrr +gen_behaviour! 
{ + SpinningBehaviour5 { a, b, c, d, e }; + SpinningBehaviour10 { a, b, c, d, e, f, g, h, i, j }; + SpinningBehaviour20 { a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p, q, r, s, t, u }; +} + +benchmarks! { + singles::[ + SpinningBehaviour::bench().name(b).poll_count(1000).protocols_per_behaviour(10), + SpinningBehaviour::bench().name(c).poll_count(1000).protocols_per_behaviour(100), + SpinningBehaviour::bench().name(d).poll_count(1000).protocols_per_behaviour(1000), + ]; + big_5::[ + SpinningBehaviour5::bench().name(e).poll_count(1000).protocols_per_behaviour(2), + SpinningBehaviour5::bench().name(f).poll_count(1000).protocols_per_behaviour(20), + SpinningBehaviour5::bench().name(g).poll_count(1000).protocols_per_behaviour(200), + ]; + top_10::[ + SpinningBehaviour10::bench().name(h).poll_count(1000).protocols_per_behaviour(1), + SpinningBehaviour10::bench().name(i).poll_count(1000).protocols_per_behaviour(10), + SpinningBehaviour10::bench().name(j).poll_count(1000).protocols_per_behaviour(100), + ]; + lucky_20::[ + SpinningBehaviour20::bench().name(k).poll_count(500).protocols_per_behaviour(1), + SpinningBehaviour20::bench().name(l).poll_count(500).protocols_per_behaviour(10), + SpinningBehaviour20::bench().name(m).poll_count(500).protocols_per_behaviour(100), + ]; +} +//fn main() {} + +trait BigBehaviour: Sized { + fn behaviours(&mut self) -> &mut [SpinningBehaviour]; + + fn for_each_beh(&mut self, f: impl FnMut(&mut SpinningBehaviour)) { + self.behaviours().iter_mut().for_each(f); + } + + fn any_beh(&mut self, f: impl FnMut(&mut SpinningBehaviour) -> bool) -> bool { + self.behaviours().iter_mut().any(f) + } + + fn run_bench( + c: &mut Criterion, + protocols_per_behaviour: usize, + spam_count: usize, + static_protocols: bool, + ) where + Self: Default + NetworkBehaviour, + { + let name = format!( + "{}::bench().poll_count({}).protocols_per_behaviour({})", + std::any::type_name::(), + spam_count, + protocols_per_behaviour + ); + + let init = || { + let mut 
swarm_a = new_swarm(Self::default()); + let mut swarm_b = new_swarm(Self::default()); + + let behaviour_count = swarm_a.behaviours().len(); + let protocol_count = behaviour_count * protocols_per_behaviour; + let protocols = (0..protocol_count) + .map(|i| { + if static_protocols { + StreamProtocol::new(format!("/protocol/{i}").leak()) + } else { + StreamProtocol::try_from_owned(format!("/protocol/{i}")).unwrap() + } + }) + .collect::>() + .leak(); + + let mut protocol_chunks = protocols.chunks(protocols_per_behaviour); + swarm_a.for_each_beh(|b| b.protocols = protocol_chunks.next().unwrap()); + let mut protocol_chunks = protocols.chunks(protocols_per_behaviour); + swarm_b.for_each_beh(|b| b.protocols = protocol_chunks.next().unwrap()); + + swarm_a.for_each_beh(|b| b.iter_count = spam_count); + swarm_b.for_each_beh(|b| b.iter_count = 0); + + swarm_a.for_each_beh(|b| b.other_peer = Some(*swarm_b.local_peer_id())); + swarm_b.for_each_beh(|b| b.other_peer = Some(*swarm_a.local_peer_id())); + + static OFFSET: AtomicUsize = AtomicUsize::new(8000); + let offset = OFFSET.fetch_add(1, std::sync::atomic::Ordering::Relaxed); + + swarm_a + .listen_on(format!("/memory/{offset}").parse().unwrap()) + .unwrap(); + swarm_b + .dial(format!("/memory/{offset}").parse::().unwrap()) + .unwrap(); + + (swarm_a, swarm_b) + }; + + c.bench_function(&name, |b| { + b.to_async(tokio::runtime::Builder::new_multi_thread().build().unwrap()) + .iter_batched( + init, + |(mut swarm_a, mut swarm_b)| async move { + while swarm_a.any_beh(|b| !b.finished) || swarm_b.any_beh(|b| !b.finished) { + futures::future::select(swarm_b.next(), swarm_a.next()).await; + } + }, + criterion::BatchSize::LargeInput, + ); + }); + } +} + +impl BigBehaviour for libp2p_swarm::Swarm { + fn behaviours(&mut self) -> &mut [SpinningBehaviour] { + self.behaviour_mut().behaviours() + } +} + +fn new_swarm(beh: T) -> libp2p_swarm::Swarm { + let keypair = libp2p_identity::Keypair::generate_ed25519(); + libp2p_swarm::Swarm::new( + 
MemoryTransport::default() + .upgrade(multistream_select::Version::V1) + .authenticate(libp2p_plaintext::Config::new(&keypair)) + .multiplex(libp2p_yamux::Config::default()) + .boxed(), + beh, + keypair.public().to_peer_id(), + libp2p_swarm::Config::without_executor().with_idle_connection_timeout(Duration::MAX), + ) +} + +/// Whole purpose of the behaviour is to rapidly call `poll` on the handler +/// configured amount of times and then emit event when finished. +#[derive(Default)] +struct SpinningBehaviour { + iter_count: usize, + protocols: &'static [StreamProtocol], + finished: bool, + emitted: bool, + other_peer: Option, +} + +#[derive(Debug)] +struct FinishedSpinning; + +impl NetworkBehaviour for SpinningBehaviour { + type ConnectionHandler = SpinningHandler; + type ToSwarm = FinishedSpinning; + + fn handle_established_inbound_connection( + &mut self, + _connection_id: libp2p_swarm::ConnectionId, + _peer: libp2p_identity::PeerId, + _local_addr: &libp2p_core::Multiaddr, + _remote_addr: &libp2p_core::Multiaddr, + ) -> Result, libp2p_swarm::ConnectionDenied> { + Ok(SpinningHandler { + iter_count: 0, + protocols: self.protocols, + }) + } + + fn handle_established_outbound_connection( + &mut self, + _connection_id: libp2p_swarm::ConnectionId, + _peer: libp2p_identity::PeerId, + _addr: &libp2p_core::Multiaddr, + _role_override: libp2p_core::Endpoint, + ) -> Result, libp2p_swarm::ConnectionDenied> { + Ok(SpinningHandler { + iter_count: self.iter_count, + protocols: self.protocols, + }) + } + + fn on_swarm_event(&mut self, _: libp2p_swarm::FromSwarm) {} + + fn on_connection_handler_event( + &mut self, + _peer_id: libp2p_identity::PeerId, + _connection_id: libp2p_swarm::ConnectionId, + _event: libp2p_swarm::THandlerOutEvent, + ) { + self.finished = true; + } + + fn poll( + &mut self, + _: &mut std::task::Context<'_>, + ) -> std::task::Poll>> + { + if self.finished && !self.emitted { + self.emitted = true; + 
std::task::Poll::Ready(libp2p_swarm::ToSwarm::GenerateEvent(FinishedSpinning)) + } else { + std::task::Poll::Pending + } + } +} + +impl BigBehaviour for SpinningBehaviour { + fn behaviours(&mut self) -> &mut [SpinningBehaviour] { + std::slice::from_mut(self) + } +} + +struct SpinningHandler { + iter_count: usize, + protocols: &'static [StreamProtocol], +} + +impl ConnectionHandler for SpinningHandler { + type FromBehaviour = Infallible; + + type ToBehaviour = FinishedSpinning; + + type InboundProtocol = Upgrade; + + type OutboundProtocol = Upgrade; + + type InboundOpenInfo = (); + + type OutboundOpenInfo = (); + + fn listen_protocol( + &self, + ) -> libp2p_swarm::SubstreamProtocol { + libp2p_swarm::SubstreamProtocol::new(Upgrade(self.protocols), ()) + } + + fn poll( + &mut self, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll< + libp2p_swarm::ConnectionHandlerEvent< + Self::OutboundProtocol, + Self::OutboundOpenInfo, + Self::ToBehaviour, + >, + > { + if self.iter_count == usize::MAX { + return std::task::Poll::Pending; + } + + if self.iter_count != 0 { + self.iter_count -= 1; + cx.waker().wake_by_ref(); + return std::task::Poll::Pending; + } + + self.iter_count = usize::MAX; + std::task::Poll::Ready(libp2p_swarm::ConnectionHandlerEvent::NotifyBehaviour( + FinishedSpinning, + )) + } + + fn on_behaviour_event(&mut self, event: Self::FromBehaviour) { + match event {} + } + + fn on_connection_event( + &mut self, + _event: libp2p_swarm::handler::ConnectionEvent< + Self::InboundProtocol, + Self::OutboundProtocol, + Self::InboundOpenInfo, + Self::OutboundOpenInfo, + >, + ) { + } +} + +pub struct Upgrade(&'static [StreamProtocol]); + +impl UpgradeInfo for Upgrade { + type Info = &'static StreamProtocol; + type InfoIter = std::slice::Iter<'static, StreamProtocol>; + + fn protocol_info(&self) -> Self::InfoIter { + self.0.iter() + } +} + +impl OutboundUpgrade for Upgrade { + type Output = libp2p_swarm::Stream; + type Error = Infallible; + type Future = 
futures::future::Ready>; + + fn upgrade_outbound(self, s: libp2p_swarm::Stream, _: Self::Info) -> Self::Future { + futures::future::ready(Ok(s)) + } +} + +impl InboundUpgrade for Upgrade { + type Output = libp2p_swarm::Stream; + type Error = Infallible; + type Future = futures::future::Ready>; + + fn upgrade_inbound(self, s: libp2p_swarm::Stream, _: Self::Info) -> Self::Future { + futures::future::ready(Ok(s)) + } +} diff --git a/swarm/src/connection.rs b/swarm/src/connection.rs index 69f95bca1d3..9c5e39830ed 100644 --- a/swarm/src/connection.rs +++ b/swarm/src/connection.rs @@ -31,8 +31,7 @@ pub use supported_protocols::SupportedProtocols; use crate::handler::{ AddressChange, ConnectionEvent, ConnectionHandler, DialUpgradeError, FullyNegotiatedInbound, - FullyNegotiatedOutbound, ListenUpgradeError, ProtocolSupport, ProtocolsAdded, ProtocolsChange, - UpgradeInfoSend, + FullyNegotiatedOutbound, ListenUpgradeError, ProtocolSupport, ProtocolsChange, UpgradeInfoSend, }; use crate::stream::ActiveStreamCounter; use crate::upgrade::{InboundUpgradeSend, OutboundUpgradeSend}; @@ -51,7 +50,7 @@ use libp2p_core::upgrade; use libp2p_core::upgrade::{NegotiationError, ProtocolError}; use libp2p_core::Endpoint; use libp2p_identity::PeerId; -use std::collections::HashSet; +use std::collections::{HashMap, HashSet}; use std::fmt::{Display, Formatter}; use std::future::Future; use std::sync::atomic::{AtomicUsize, Ordering}; @@ -153,8 +152,11 @@ where SubstreamRequested, >, - local_supported_protocols: HashSet, + local_supported_protocols: + HashMap::Info>, bool>, remote_supported_protocols: HashSet, + protocol_buffer: Vec, + idle_timeout: Duration, stream_counter: ActiveStreamCounter, } @@ -187,11 +189,17 @@ where idle_timeout: Duration, ) -> Self { let initial_protocols = gather_supported_protocols(&handler); + let mut buffer = Vec::new(); + if !initial_protocols.is_empty() { handler.on_connection_event(ConnectionEvent::LocalProtocolsChange( - 
ProtocolsChange::Added(ProtocolsAdded::from_set(&initial_protocols)), + ProtocolsChange::from_initial_protocols( + initial_protocols.keys().map(|e| &e.0), + &mut buffer, + ), )); } + Connection { muxing: muxer, handler, @@ -203,6 +211,7 @@ where requested_substreams: Default::default(), local_supported_protocols: initial_protocols, remote_supported_protocols: Default::default(), + protocol_buffer: buffer, idle_timeout, stream_counter: ActiveStreamCounter::default(), } @@ -250,6 +259,7 @@ where substream_upgrade_protocol_override, local_supported_protocols: supported_protocols, remote_supported_protocols, + protocol_buffer, idle_timeout, stream_counter, .. @@ -287,25 +297,24 @@ where ProtocolSupport::Added(protocols), )) => { if let Some(added) = - ProtocolsChange::add(remote_supported_protocols, &protocols) + ProtocolsChange::add(remote_supported_protocols, protocols, protocol_buffer) { handler.on_connection_event(ConnectionEvent::RemoteProtocolsChange(added)); - remote_supported_protocols.extend(protocols); + remote_supported_protocols.extend(protocol_buffer.drain(..)); } - continue; } Poll::Ready(ConnectionHandlerEvent::ReportRemoteProtocols( ProtocolSupport::Removed(protocols), )) => { - if let Some(removed) = - ProtocolsChange::remove(remote_supported_protocols, &protocols) - { + if let Some(removed) = ProtocolsChange::remove( + remote_supported_protocols, + protocols, + protocol_buffer, + ) { handler .on_connection_event(ConnectionEvent::RemoteProtocolsChange(removed)); - remote_supported_protocols.retain(|p| !protocols.contains(p)); } - continue; } } @@ -431,16 +440,16 @@ where } } - let new_protocols = gather_supported_protocols(handler); - let changes = ProtocolsChange::from_full_sets(supported_protocols, &new_protocols); + let changes = ProtocolsChange::from_full_sets( + supported_protocols, + handler.listen_protocol().upgrade().protocol_info(), + protocol_buffer, + ); if !changes.is_empty() { for change in changes { 
handler.on_connection_event(ConnectionEvent::LocalProtocolsChange(change)); } - - *supported_protocols = new_protocols; - continue; // Go back to the top, handler can potentially make progress again. } @@ -454,12 +463,14 @@ where } } -fn gather_supported_protocols(handler: &impl ConnectionHandler) -> HashSet { +fn gather_supported_protocols( + handler: &C, +) -> HashMap::Info>, bool> { handler .listen_protocol() .upgrade() .protocol_info() - .filter_map(|i| StreamProtocol::try_from_owned(i.as_ref().to_owned()).ok()) + .map(|info| (AsStrHashEq(info), true)) .collect() } @@ -734,6 +745,25 @@ enum Shutdown { Later(Delay), } +// Structure used to avoid allocations when storing the protocols in the `HashMap`. +// Instead of allocating a new `String` for the key, +// we use `T::as_ref()` in `Hash`, `Eq` and `PartialEq` requirements. +pub(crate) struct AsStrHashEq(pub(crate) T); + +impl> Eq for AsStrHashEq {} + +impl> PartialEq for AsStrHashEq { + fn eq(&self, other: &Self) -> bool { + self.0.as_ref() == other.0.as_ref() + } +} + +impl> std::hash::Hash for AsStrHashEq { + fn hash(&self, state: &mut H) { + self.0.as_ref().hash(state) + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/swarm/src/connection/supported_protocols.rs b/swarm/src/connection/supported_protocols.rs index 0575046bb44..124ec93d669 100644 --- a/swarm/src/connection/supported_protocols.rs +++ b/swarm/src/connection/supported_protocols.rs @@ -40,7 +40,6 @@ impl SupportedProtocols { mod tests { use super::*; use crate::handler::{ProtocolsAdded, ProtocolsRemoved}; - use once_cell::sync::Lazy; #[test] fn protocols_change_added_returns_correct_changed_value() { @@ -70,19 +69,24 @@ mod tests { } fn add_foo() -> ProtocolsChange<'static> { - ProtocolsChange::Added(ProtocolsAdded::from_set(&FOO_PROTOCOLS)) + ProtocolsChange::Added(ProtocolsAdded { + protocols: FOO_PROTOCOLS.iter(), + }) } fn add_foo_bar() -> ProtocolsChange<'static> { - 
ProtocolsChange::Added(ProtocolsAdded::from_set(&FOO_BAR_PROTOCOLS)) + ProtocolsChange::Added(ProtocolsAdded { + protocols: FOO_BAR_PROTOCOLS.iter(), + }) } fn remove_foo() -> ProtocolsChange<'static> { - ProtocolsChange::Removed(ProtocolsRemoved::from_set(&FOO_PROTOCOLS)) + ProtocolsChange::Removed(ProtocolsRemoved { + protocols: FOO_PROTOCOLS.iter(), + }) } - static FOO_PROTOCOLS: Lazy> = - Lazy::new(|| HashSet::from([StreamProtocol::new("/foo")])); - static FOO_BAR_PROTOCOLS: Lazy> = - Lazy::new(|| HashSet::from([StreamProtocol::new("/foo"), StreamProtocol::new("/bar")])); + static FOO_PROTOCOLS: &[StreamProtocol] = &[StreamProtocol::new("/foo")]; + static FOO_BAR_PROTOCOLS: &[StreamProtocol] = + &[StreamProtocol::new("/foo"), StreamProtocol::new("/bar")]; } diff --git a/swarm/src/handler.rs b/swarm/src/handler.rs index 20980aff8bd..610b95b8cf1 100644 --- a/swarm/src/handler.rs +++ b/swarm/src/handler.rs @@ -46,22 +46,19 @@ mod one_shot; mod pending; mod select; +use crate::connection::AsStrHashEq; pub use crate::upgrade::{InboundUpgradeSend, OutboundUpgradeSend, SendWrapper, UpgradeInfoSend}; pub use map_in::MapInEvent; pub use map_out::MapOutEvent; pub use one_shot::{OneShotHandler, OneShotHandlerConfig}; pub use pending::PendingConnectionHandler; pub use select::ConnectionHandlerSelect; +use smallvec::SmallVec; use crate::StreamProtocol; -use ::either::Either; +use core::slice; use libp2p_core::Multiaddr; -use once_cell::sync::Lazy; -use smallvec::SmallVec; -use std::collections::hash_map::RandomState; -use std::collections::hash_set::{Difference, Intersection}; -use std::collections::HashSet; -use std::iter::Peekable; +use std::collections::{HashMap, HashSet}; use std::{error, fmt, io, task::Context, task::Poll, time::Duration}; /// A handler for a set of protocols used on a connection with a remote. @@ -335,64 +332,124 @@ pub enum ProtocolsChange<'a> { } impl<'a> ProtocolsChange<'a> { + /// Compute the protocol change for the initial set of protocols. 
+ pub(crate) fn from_initial_protocols<'b, T: AsRef + 'b>( + new_protocols: impl IntoIterator, + buffer: &'a mut Vec, + ) -> Self { + buffer.clear(); + buffer.extend( + new_protocols + .into_iter() + .filter_map(|i| StreamProtocol::try_from_owned(i.as_ref().to_owned()).ok()), + ); + + ProtocolsChange::Added(ProtocolsAdded { + protocols: buffer.iter(), + }) + } + /// Compute the [`ProtocolsChange`] that results from adding `to_add` to `existing_protocols`. /// /// Returns `None` if the change is a no-op, i.e. `to_add` is a subset of `existing_protocols`. pub(crate) fn add( - existing_protocols: &'a HashSet, - to_add: &'a HashSet, + existing_protocols: &HashSet, + to_add: HashSet, + buffer: &'a mut Vec, ) -> Option { - let mut actually_added_protocols = to_add.difference(existing_protocols).peekable(); - - actually_added_protocols.peek()?; + buffer.clear(); + buffer.extend( + to_add + .into_iter() + .filter(|i| !existing_protocols.contains(i)), + ); + + if buffer.is_empty() { + return None; + } - Some(ProtocolsChange::Added(ProtocolsAdded { - protocols: actually_added_protocols, + Some(Self::Added(ProtocolsAdded { + protocols: buffer.iter(), })) } - /// Compute the [`ProtocolsChange`] that results from removing `to_remove` from `existing_protocols`. + /// Compute the [`ProtocolsChange`] that results from removing `to_remove` from `existing_protocols`. Removes the protocols from `existing_protocols`. /// /// Returns `None` if the change is a no-op, i.e. none of the protocols in `to_remove` are in `existing_protocols`. 
pub(crate) fn remove( - existing_protocols: &'a HashSet, - to_remove: &'a HashSet, + existing_protocols: &mut HashSet, + to_remove: HashSet, + buffer: &'a mut Vec, ) -> Option { - let mut actually_removed_protocols = existing_protocols.intersection(to_remove).peekable(); - - actually_removed_protocols.peek()?; + buffer.clear(); + buffer.extend( + to_remove + .into_iter() + .filter_map(|i| existing_protocols.take(&i)), + ); + + if buffer.is_empty() { + return None; + } - Some(ProtocolsChange::Removed(ProtocolsRemoved { - protocols: Either::Right(actually_removed_protocols), + Some(Self::Removed(ProtocolsRemoved { + protocols: buffer.iter(), })) } /// Compute the [`ProtocolsChange`]s required to go from `existing_protocols` to `new_protocols`. - pub(crate) fn from_full_sets( - existing_protocols: &'a HashSet, - new_protocols: &'a HashSet, + pub(crate) fn from_full_sets>( + existing_protocols: &mut HashMap, bool>, + new_protocols: impl IntoIterator, + buffer: &'a mut Vec, ) -> SmallVec<[Self; 2]> { - if existing_protocols == new_protocols { + buffer.clear(); + + // Initially, set the boolean for all protocols to `false`, meaning "not visited". + for v in existing_protocols.values_mut() { + *v = false; + } + + let mut new_protocol_count = 0; // We can only iterate `new_protocols` once, so keep track of its length separately. + for new_protocol in new_protocols { + existing_protocols + .entry(AsStrHashEq(new_protocol)) + .and_modify(|v| *v = true) // Mark protocol as visited (i.e. we still support it) + .or_insert_with_key(|k| { + // Encountered a previously unsupported protocol, remember it in `buffer`. + buffer.extend(StreamProtocol::try_from_owned(k.0.as_ref().to_owned()).ok()); + true + }); + new_protocol_count += 1; + } + + if new_protocol_count == existing_protocols.len() && buffer.is_empty() { return SmallVec::new(); } - let mut changes = SmallVec::new(); + let num_new_protocols = buffer.len(); + // Drain all protocols that we haven't visited. 
+ // For existing protocols that are not in `new_protocols`, the boolean will be false, meaning we need to remove it. + existing_protocols.retain(|p, &mut is_supported| { + if !is_supported { + buffer.extend(StreamProtocol::try_from_owned(p.0.as_ref().to_owned()).ok()); + } - let mut added_protocols = new_protocols.difference(existing_protocols).peekable(); - let mut removed_protocols = existing_protocols.difference(new_protocols).peekable(); + is_supported + }); - if added_protocols.peek().is_some() { + let (added, removed) = buffer.split_at(num_new_protocols); + let mut changes = SmallVec::new(); + if !added.is_empty() { changes.push(ProtocolsChange::Added(ProtocolsAdded { - protocols: added_protocols, + protocols: added.iter(), })); } - - if removed_protocols.peek().is_some() { + if !removed.is_empty() { changes.push(ProtocolsChange::Removed(ProtocolsRemoved { - protocols: Either::Left(removed_protocols), + protocols: removed.iter(), })); } - changes } } @@ -400,33 +457,13 @@ impl<'a> ProtocolsChange<'a> { /// An [`Iterator`] over all protocols that have been added. #[derive(Debug, Clone)] pub struct ProtocolsAdded<'a> { - protocols: Peekable>, -} - -impl<'a> ProtocolsAdded<'a> { - pub(crate) fn from_set(protocols: &'a HashSet) -> Self { - ProtocolsAdded { - protocols: protocols.difference(&EMPTY_HASHSET).peekable(), - } - } + pub(crate) protocols: slice::Iter<'a, StreamProtocol>, } /// An [`Iterator`] over all protocols that have been removed. 
#[derive(Debug, Clone)] pub struct ProtocolsRemoved<'a> { - protocols: Either< - Peekable>, - Peekable>, - >, -} - -impl<'a> ProtocolsRemoved<'a> { - #[cfg(test)] - pub(crate) fn from_set(protocols: &'a HashSet) -> Self { - ProtocolsRemoved { - protocols: Either::Left(protocols.difference(&EMPTY_HASHSET).peekable()), - } - } + pub(crate) protocols: slice::Iter<'a, StreamProtocol>, } impl<'a> Iterator for ProtocolsAdded<'a> { @@ -691,6 +728,169 @@ where } } -/// A statically declared, empty [`HashSet`] allows us to work around borrow-checker rules for -/// [`ProtocolsAdded::from_set`]. The lifetimes don't work unless we have a [`HashSet`] with a `'static' lifetime. -static EMPTY_HASHSET: Lazy> = Lazy::new(HashSet::new); +#[cfg(test)] +mod test { + use super::*; + + fn protocol_set_of(s: &'static str) -> HashSet { + s.split_whitespace() + .map(|p| StreamProtocol::try_from_owned(format!("/{p}")).unwrap()) + .collect() + } + + fn test_remove( + existing: &mut HashSet, + to_remove: HashSet, + ) -> HashSet { + ProtocolsChange::remove(existing, to_remove, &mut Vec::new()) + .into_iter() + .flat_map(|c| match c { + ProtocolsChange::Added(_) => panic!("unexpected added"), + ProtocolsChange::Removed(r) => r.cloned(), + }) + .collect::>() + } + + #[test] + fn test_protocol_remove_subset() { + let mut existing = protocol_set_of("a b c"); + let to_remove = protocol_set_of("a b"); + + let change = test_remove(&mut existing, to_remove); + + assert_eq!(existing, protocol_set_of("c")); + assert_eq!(change, protocol_set_of("a b")); + } + + #[test] + fn test_protocol_remove_all() { + let mut existing = protocol_set_of("a b c"); + let to_remove = protocol_set_of("a b c"); + + let change = test_remove(&mut existing, to_remove); + + assert_eq!(existing, protocol_set_of("")); + assert_eq!(change, protocol_set_of("a b c")); + } + + #[test] + fn test_protocol_remove_superset() { + let mut existing = protocol_set_of("a b c"); + let to_remove = protocol_set_of("a b c d"); + + let change = 
test_remove(&mut existing, to_remove); + + assert_eq!(existing, protocol_set_of("")); + assert_eq!(change, protocol_set_of("a b c")); + } + + #[test] + fn test_protocol_remove_none() { + let mut existing = protocol_set_of("a b c"); + let to_remove = protocol_set_of("d"); + + let change = test_remove(&mut existing, to_remove); + + assert_eq!(existing, protocol_set_of("a b c")); + assert_eq!(change, protocol_set_of("")); + } + + #[test] + fn test_protocol_remove_none_from_empty() { + let mut existing = protocol_set_of(""); + let to_remove = protocol_set_of("d"); + + let change = test_remove(&mut existing, to_remove); + + assert_eq!(existing, protocol_set_of("")); + assert_eq!(change, protocol_set_of("")); + } + + fn test_from_full_sets( + existing: HashSet, + new: HashSet, + ) -> [HashSet; 2] { + let mut buffer = Vec::new(); + let mut existing = existing + .iter() + .map(|p| (AsStrHashEq(p.as_ref()), true)) + .collect::>(); + + let changes = ProtocolsChange::from_full_sets( + &mut existing, + new.iter().map(AsRef::as_ref), + &mut buffer, + ); + + let mut added_changes = HashSet::new(); + let mut removed_changes = HashSet::new(); + + for change in changes { + match change { + ProtocolsChange::Added(a) => { + added_changes.extend(a.cloned()); + } + ProtocolsChange::Removed(r) => { + removed_changes.extend(r.cloned()); + } + } + } + + [removed_changes, added_changes] + } + + #[test] + fn test_from_full_stes_subset() { + let existing = protocol_set_of("a b c"); + let new = protocol_set_of("a b"); + + let [removed_changes, added_changes] = test_from_full_sets(existing, new); + + assert_eq!(added_changes, protocol_set_of("")); + assert_eq!(removed_changes, protocol_set_of("c")); + } + + #[test] + fn test_from_full_sets_superset() { + let existing = protocol_set_of("a b"); + let new = protocol_set_of("a b c"); + + let [removed_changes, added_changes] = test_from_full_sets(existing, new); + + assert_eq!(added_changes, protocol_set_of("c")); + assert_eq!(removed_changes, 
protocol_set_of("")); + } + + #[test] + fn test_from_full_sets_intersection() { + let existing = protocol_set_of("a b c"); + let new = protocol_set_of("b c d"); + + let [removed_changes, added_changes] = test_from_full_sets(existing, new); + + assert_eq!(added_changes, protocol_set_of("d")); + assert_eq!(removed_changes, protocol_set_of("a")); + } + + #[test] + fn test_from_full_sets_disjoint() { + let existing = protocol_set_of("a b c"); + let new = protocol_set_of("d e f"); + + let [removed_changes, added_changes] = test_from_full_sets(existing, new); + + assert_eq!(added_changes, protocol_set_of("d e f")); + assert_eq!(removed_changes, protocol_set_of("a b c")); + } + + #[test] + fn test_from_full_sets_empty() { + let existing = protocol_set_of(""); + let new = protocol_set_of(""); + + let [removed_changes, added_changes] = test_from_full_sets(existing, new); + + assert_eq!(added_changes, protocol_set_of("")); + assert_eq!(removed_changes, protocol_set_of("")); + } +} diff --git a/swarm/src/lib.rs b/swarm/src/lib.rs index ec5b7a109cc..fb02cdce392 100644 --- a/swarm/src/lib.rs +++ b/swarm/src/lib.rs @@ -1416,6 +1416,14 @@ impl Config { } } + #[doc(hidden)] + /// Used on connection benchmarks. + pub fn without_executor() -> Self { + Self { + pool_config: PoolConfig::new(None), + } + } + /// Sets executor to the `wasm` executor. /// Background tasks will be executed by the browser on the next micro-tick. ///