diff --git a/crates/core/src/router.rs b/crates/core/src/router.rs index cf495998e..41573bd24 100644 --- a/crates/core/src/router.rs +++ b/crates/core/src/router.rs @@ -16,6 +16,7 @@ pub(crate) struct Router { transfer_rate_estimator: IsotonicEstimator, failure_estimator: IsotonicEstimator, mean_transfer_size: Mean, + consider_n_closest_peers: usize, } impl Router { @@ -106,9 +107,16 @@ impl Router { EstimatorType::Negative, ), mean_transfer_size, + consider_n_closest_peers: 20, } } + #[allow(dead_code)] + pub fn considering_n_closest_peers(mut self, n: u32) -> Self { + self.consider_n_closest_peers = n as usize; + self + } + pub fn add_event(&mut self, event: RouteEvent) { match event.outcome { RouteOutcome::Success { @@ -145,6 +153,33 @@ impl Router { } } + fn select_closest_peers<'a>( + &self, + peers: impl IntoIterator, + target_location: &Location, + ) -> Vec<&'a PeerKeyLocation> { + let mut heap = + std::collections::BinaryHeap::with_capacity(self.consider_n_closest_peers + 1); + + for peer_location in peers { + if let Some(location) = peer_location.location.as_ref() { + let distance = target_location.distance(location); + heap.push((distance, peer_location)); + + // Ensure we keep the heap size to specified capacity + if heap.len() > self.consider_n_closest_peers { + heap.pop(); + } + } + } + + // Convert the heap to a sorted vector + heap.into_sorted_vec() + .into_iter() + .map(|(_, peer_location)| peer_location) + .collect() + } + pub fn select_peer<'a>( &self, peers: impl IntoIterator, @@ -163,7 +198,7 @@ impl Router { .map(|(peer, _)| peer) } else { // Find the peer with the minimum predicted routing outcome time - peers + self.select_closest_peers(peers, &target_location) .into_iter() .map(|peer: &PeerKeyLocation| { let t = self.predict_routing_outcome(peer, target_location).expect( @@ -279,6 +314,8 @@ pub enum RouteOutcome { mod tests { use rand::Rng; + use crate::ring::Distance; + use super::*; #[test] @@ -393,6 +430,46 @@ mod tests { } } + #[test] + fn test_select_closest_peers_size() { + const NUM_PEERS: u32 = 45; + const CAP: u32 = 30; + + assert_eq!( + CAP as usize, + Router::new(&[]) + .considering_n_closest_peers(CAP) + .select_closest_peers(&create_peers(NUM_PEERS), &Location::random()) + .len() + ); + } + + #[test] + fn test_select_closest_peers_equality() { + const NUM_PEERS: u32 = 100; + const CLOSEST_CAP: u32 = 10; + let peers: Vec = create_peers(NUM_PEERS); + let contract_location = Location::random(); + + let expected_closest = select_closest_peers_vec(CLOSEST_CAP, &peers, &contract_location); + + // Create a router with no historical data + let router = Router::new(&[]).considering_n_closest_peers(CLOSEST_CAP); + let asserted_closest: Vec<&PeerKeyLocation> = + router.select_closest_peers(&peers, &contract_location); + + let mut expected_iter = expected_closest.iter(); + let mut asserted_iter = asserted_closest.iter(); + + while let (Some(expected_location), Some(asserted_location)) = + (expected_iter.next(), asserted_iter.next()) + { + assert_eq!(**expected_location, **asserted_location); + } + + assert_eq!(expected_iter.next(), asserted_iter.next()); + } + fn simulate_prediction( random: &mut rand::rngs::ThreadRng, peer: PeerKeyLocation, @@ -413,4 +490,35 @@ mod tests { expected_total_time: time_to_response_start + transfer_time, } } + + fn select_closest_peers_vec<'a>( + closest_peers_capacity: u32, + peers: impl IntoIterator, + target_location: &Location, + ) -> Vec<&'a PeerKeyLocation> + where + PeerKeyLocation: Clone, + { + let mut closest: Vec<&'a PeerKeyLocation> = peers.into_iter().collect(); + closest.sort_by_key(|&peer| { + if let Some(location) = peer.location { + target_location.distance(location) + } else { + Distance::new(f64::MAX) + } + }); + + closest[..closest_peers_capacity as usize].to_vec() + } + + fn create_peers(num_peers: u32) -> Vec { + let mut peers: Vec = vec![]; + + for _ in 0..num_peers { + let peer = PeerKeyLocation::random(); + peers.push(peer); + } + + peers + } }