diff --git a/chitchat-test/src/main.rs b/chitchat-test/src/main.rs index e8f1837..ef468af 100644 --- a/chitchat-test/src/main.rs +++ b/chitchat-test/src/main.rs @@ -1,6 +1,6 @@ use std::net::SocketAddr; use std::sync::Arc; -use std::time::Duration; +use std::time::{Duration, SystemTime}; use chitchat::transport::UdpTransport; use chitchat::{spawn_chitchat, Chitchat, ChitchatConfig, ChitchatId, FailureDetectorConfig}; @@ -28,7 +28,11 @@ impl Api { cluster_id: chitchat_guard.cluster_id().to_string(), cluster_state: chitchat_guard.state_snapshot(), live_nodes: chitchat_guard.live_nodes().cloned().collect::>(), - dead_nodes: chitchat_guard.dead_nodes().cloned().collect::>(), + dead_nodes: chitchat_guard + .dead_nodes() + .cloned() + .map(|node| node.0) + .collect::>(), }; Json(serde_json::to_value(&response).unwrap()) } @@ -84,7 +88,11 @@ async fn main() -> anyhow::Result<()> { let node_id = opt .node_id .unwrap_or_else(|| generate_server_id(public_addr)); - let chitchat_id = ChitchatId::new(node_id, 0, public_addr); + let generation = SystemTime::now() + .duration_since(SystemTime::UNIX_EPOCH) + .unwrap() + .as_secs(); + let chitchat_id = ChitchatId::new(node_id, generation, public_addr); let config = ChitchatConfig { cluster_id: "testing".to_string(), chitchat_id, diff --git a/chitchat/src/delta.rs b/chitchat/src/delta.rs index ced1fe6..b03898f 100644 --- a/chitchat/src/delta.rs +++ b/chitchat/src/delta.rs @@ -2,40 +2,40 @@ use std::collections::{BTreeMap, HashSet}; use std::mem; use crate::serialize::*; -use crate::{ChitchatId, Heartbeat, MaxVersion, VersionedValue}; +use crate::{ChitchatId, ChitchatIdGenerationEq, Heartbeat, MaxVersion, VersionedValue}; #[derive(Debug, Default, Eq, PartialEq)] pub struct Delta { - pub(crate) node_deltas: BTreeMap, - pub(crate) nodes_to_reset: HashSet, + pub(crate) node_deltas: BTreeMap, + pub(crate) nodes_to_reset: HashSet, } impl Serializable for Delta { fn serialize(&self, buf: &mut Vec) { (self.node_deltas.len() as u16).serialize(buf); for (chitchat_id, node_delta) in &self.node_deltas { - chitchat_id.serialize(buf); + chitchat_id.0.serialize(buf); node_delta.serialize(buf); } (self.nodes_to_reset.len() as u16).serialize(buf); for chitchat_id in &self.nodes_to_reset { - chitchat_id.serialize(buf); + chitchat_id.0.serialize(buf); } } fn deserialize(buf: &mut &[u8]) -> anyhow::Result { - let mut node_deltas: BTreeMap = Default::default(); + let mut node_deltas: BTreeMap = Default::default(); let num_nodes = u16::deserialize(buf)?; for _ in 0..num_nodes { let chitchat_id = ChitchatId::deserialize(buf)?; let node_delta = NodeDelta::deserialize(buf)?; - node_deltas.insert(chitchat_id, node_delta); + node_deltas.insert(ChitchatIdGenerationEq(chitchat_id), node_delta); } let num_nodes_to_reset = u16::deserialize(buf)?; let mut nodes_to_reset = HashSet::with_capacity(num_nodes_to_reset as usize); for _ in 0..num_nodes_to_reset { let chitchat_id = ChitchatId::deserialize(buf)?; - nodes_to_reset.insert(chitchat_id); + nodes_to_reset.insert(ChitchatIdGenerationEq(chitchat_id)); } Ok(Delta { node_deltas, @@ -46,12 +46,12 @@ impl Serializable for Delta { fn serialized_len(&self) -> usize { let mut len = 2; for (chitchat_id, node_delta) in &self.node_deltas { - len += chitchat_id.serialized_len(); + len += chitchat_id.0.serialized_len(); len += node_delta.serialized_len(); } len += 2; for chitchat_id in &self.nodes_to_reset { - len += chitchat_id.serialized_len(); + len += chitchat_id.0.serialized_len(); } len } @@ -68,7 +68,7 @@ impl Delta { pub fn add_node(&mut self, chitchat_id: ChitchatId, heartbeat: Heartbeat) { self.node_deltas - .entry(chitchat_id) + .entry(ChitchatIdGenerationEq(chitchat_id)) .or_insert_with(|| NodeDelta { heartbeat, ..Default::default() @@ -83,7 +83,10 @@ impl Delta { version: crate::Version, tombstone: Option, ) { - let node_delta = self.node_deltas.get_mut(chitchat_id).unwrap(); + let node_delta = self + .node_deltas + .get_mut(&ChitchatIdGenerationEq(chitchat_id.clone())) + .unwrap(); node_delta.max_version = node_delta.max_version.max(version); node_delta.key_values.insert( @@ -97,7 +100,8 @@ impl Delta { } pub fn add_node_to_reset(&mut self, chitchat_id: ChitchatId) { - self.nodes_to_reset.insert(chitchat_id); + self.nodes_to_reset + .insert(ChitchatIdGenerationEq(chitchat_id)); } } @@ -141,13 +145,16 @@ impl DeltaWriter { let chitchat_id_opt = mem::take(&mut self.current_chitchat_id); let node_delta = mem::take(&mut self.current_node_delta); if let Some(chitchat_id) = chitchat_id_opt { - self.delta.node_deltas.insert(chitchat_id, node_delta); + self.delta + .node_deltas + .insert(ChitchatIdGenerationEq(chitchat_id), node_delta); } } pub fn add_node_to_reset(&mut self, chitchat_id: ChitchatId) -> bool { + let chitchat_id = ChitchatIdGenerationEq(chitchat_id); assert!(!self.delta.nodes_to_reset.contains(&chitchat_id)); - if !self.attempt_add_bytes(chitchat_id.serialized_len()) { + if !self.attempt_add_bytes(chitchat_id.0.serialized_len()) { return false; } self.delta.nodes_to_reset.insert(chitchat_id); @@ -155,15 +162,21 @@ impl DeltaWriter { } pub fn add_node(&mut self, chitchat_id: ChitchatId, heartbeat: Heartbeat) -> bool { - assert!(self.current_chitchat_id.as_ref() != Some(&chitchat_id)); + assert!(self + .current_chitchat_id + .as_ref() + .map(|current_node| !current_node.eq_generation(&chitchat_id)) + .unwrap_or(true)); + let chitchat_id = ChitchatIdGenerationEq(chitchat_id); assert!(!self.delta.node_deltas.contains_key(&chitchat_id)); self.flush(); // Reserve bytes for [`ChitchatId`], [`Hearbeat`], and for an empty [`NodeDelta`] which has // a size of 2 bytes. - if !self.attempt_add_bytes(chitchat_id.serialized_len() + heartbeat.serialized_len() + 2) { + if !self.attempt_add_bytes(chitchat_id.0.serialized_len() + heartbeat.serialized_len() + 2) + { return false; } - self.current_chitchat_id = Some(chitchat_id); + self.current_chitchat_id = Some(chitchat_id.0); self.current_node_delta.heartbeat = heartbeat; true } diff --git a/chitchat/src/digest.rs b/chitchat/src/digest.rs index 2371dd7..dd69d79 100644 --- a/chitchat/src/digest.rs +++ b/chitchat/src/digest.rs @@ -1,7 +1,7 @@ use std::collections::BTreeMap; use crate::serialize::*; -use crate::{ChitchatId, Heartbeat, MaxVersion}; +use crate::{ChitchatId, ChitchatIdGenerationEq, Heartbeat, MaxVersion}; #[derive(Debug, Clone, Copy, Default, Eq, PartialEq)] pub(crate) struct NodeDigest { @@ -25,14 +25,15 @@ impl NodeDigest { /// peer -> (heartbeat, max version). #[derive(Debug, Default, Eq, PartialEq)] pub struct Digest { - pub(crate) node_digests: BTreeMap, + pub(crate) node_digests: BTreeMap, } #[cfg(test)] impl Digest { pub fn add_node(&mut self, node: ChitchatId, heartbeat: Heartbeat, max_version: MaxVersion) { let node_digest = NodeDigest::new(heartbeat, max_version); - self.node_digests.insert(node, node_digest); + self.node_digests + .insert(ChitchatIdGenerationEq(node), node_digest); } } @@ -40,7 +41,7 @@ impl Serializable for Digest { fn serialize(&self, buf: &mut Vec) { (self.node_digests.len() as u16).serialize(buf); for (chitchat_id, node_digest) in &self.node_digests { - chitchat_id.serialize(buf); + chitchat_id.0.serialize(buf); node_digest.heartbeat.serialize(buf); node_digest.max_version.serialize(buf); } @@ -48,14 +49,14 @@ impl Serializable for Digest { fn deserialize(buf: &mut &[u8]) -> anyhow::Result { let num_nodes = u16::deserialize(buf)?; - let mut node_digests: BTreeMap = Default::default(); + let mut node_digests: BTreeMap = Default::default(); for _ in 0..num_nodes { let chitchat_id = ChitchatId::deserialize(buf)?; let heartbeat = Heartbeat::deserialize(buf)?; let max_version = u64::deserialize(buf)?; let node_digest = NodeDigest::new(heartbeat, max_version); - node_digests.insert(chitchat_id, node_digest); + node_digests.insert(ChitchatIdGenerationEq(chitchat_id), node_digest); } Ok(Digest { node_digests }) } @@ -63,7 +64,7 @@ impl Serializable for Digest { fn serialized_len(&self) -> usize { let mut len = (self.node_digests.len() as u16).serialized_len(); for (chitchat_id, node_digest) in &self.node_digests { - len += chitchat_id.serialized_len(); + len += chitchat_id.0.serialized_len(); len += node_digest.heartbeat.serialized_len(); len += node_digest.max_version.serialized_len(); } diff --git a/chitchat/src/failure_detector.rs b/chitchat/src/failure_detector.rs index a41373b..1efc135 100644 --- a/chitchat/src/failure_detector.rs +++ b/chitchat/src/failure_detector.rs @@ -8,18 +8,18 @@ use mock_instant::Instant; use serde::{Deserialize, Serialize}; use tracing::debug; -use crate::ChitchatId; +use crate::{ChitchatId, ChitchatIdNodeEq}; /// A phi accrual failure detector implementation. pub struct FailureDetector { /// Heartbeat samples for each node. - node_samples: HashMap, + node_samples: HashMap, /// Failure detector configuration. config: FailureDetectorConfig, /// Denotes live nodes. - live_nodes: HashSet, + live_nodes: HashSet, /// Denotes dead nodes. - dead_nodes: HashMap, + dead_nodes: HashMap, } impl FailureDetector { @@ -37,7 +37,7 @@ impl FailureDetector { debug!(node_id=%chitchat_id.node_id, "reporting node heartbeat."); let heartbeat_window = self .node_samples - .entry(chitchat_id.clone()) + .entry(ChitchatIdNodeEq(chitchat_id.clone())) .or_insert_with(|| { SamplingWindow::new( self.config.sampling_window_size, @@ -52,15 +52,22 @@ impl FailureDetector { pub fn update_node_liveness(&mut self, chitchat_id: &ChitchatId) { if let Some(phi) = self.phi(chitchat_id) { debug!(node_id=%chitchat_id.node_id, phi=phi, "updating node liveness"); + let chitchat_id = ChitchatIdNodeEq(chitchat_id.clone()); if phi > self.config.phi_threshold { - self.live_nodes.remove(chitchat_id); - self.dead_nodes.insert(chitchat_id.clone(), Instant::now()); + self.live_nodes.remove(&chitchat_id); + // Remove current sampling window so that when the node // comes back online, we start with a fresh sampling window. - self.node_samples.remove(chitchat_id); + self.node_samples.remove(&chitchat_id); + + // remove and re-add to make sure we have the latest generation id + self.dead_nodes.remove(&chitchat_id); + self.dead_nodes.insert(chitchat_id, Instant::now()); } else { - self.live_nodes.insert(chitchat_id.clone()); - self.dead_nodes.remove(chitchat_id); + self.dead_nodes.remove(&chitchat_id); + // remove and re-add to make sure we have the latest generation id + self.live_nodes.remove(&chitchat_id); + self.live_nodes.insert(chitchat_id); } } } @@ -70,29 +77,30 @@ impl FailureDetector { let mut garbage_collected_nodes = Vec::new(); for (chitchat_id, instant) in self.dead_nodes.iter() { if instant.elapsed() >= self.config.dead_node_grace_period { - garbage_collected_nodes.push(chitchat_id.clone()) + garbage_collected_nodes.push(chitchat_id.0.clone()) } } for chitchat_id in garbage_collected_nodes.iter() { - self.dead_nodes.remove(chitchat_id); + self.dead_nodes + .remove(&ChitchatIdNodeEq(chitchat_id.clone())); } garbage_collected_nodes } /// Returns the list of nodes considered live by the failure detector. - pub fn live_nodes(&self) -> impl Iterator { + pub fn live_nodes(&self) -> impl Iterator { self.live_nodes.iter() } /// Returns the list of nodes considered dead by the failure detector. - pub fn dead_nodes(&self) -> impl Iterator { + pub fn dead_nodes(&self) -> impl Iterator { self.dead_nodes.keys() } /// Returns the current phi value of a node. fn phi(&mut self, chitchat_id: &ChitchatId) -> Option { self.node_samples - .get(chitchat_id) + .get(&ChitchatIdNodeEq(chitchat_id.clone())) .map(|sampling_window| sampling_window.phi()) } } @@ -283,11 +291,11 @@ mod tests { let mut live_nodes = failure_detector .live_nodes() - .map(|chitchat_id| chitchat_id.node_id.as_str()) + .map(|chitchat_id| chitchat_id.0.node_id.as_str()) .collect::>(); live_nodes.sort_unstable(); assert_eq!(live_nodes, vec!["node-10001", "node-10002", "node-10003"]); - assert_eq!(failure_detector.garbage_collect(), Vec::new()); + assert!(failure_detector.garbage_collect().is_empty()); // stop reporting heartbeat for few seconds MockClock::advance(Duration::from_secs(50)); @@ -296,11 +304,11 @@ mod tests { } let mut dead_nodes = failure_detector .dead_nodes() - .map(|chitchat_id| chitchat_id.node_id.as_str()) + .map(|chitchat_id| chitchat_id.0.node_id.as_str()) .collect::>(); dead_nodes.sort_unstable(); assert_eq!(dead_nodes, vec!["node-10001", "node-10002", "node-10003"]); - assert_eq!(failure_detector.garbage_collect(), Vec::new()); + assert!(failure_detector.garbage_collect().is_empty()); // Wait for dead_node_grace_period & garbage collect. MockClock::advance(Duration::from_secs(25 * 60 * 60)); @@ -308,14 +316,14 @@ mod tests { assert_eq!( failure_detector .live_nodes() - .map(|chitchat_id| chitchat_id.node_id.as_str()) + .map(|chitchat_id| chitchat_id.0.node_id.as_str()) .collect::>(), Vec::<&str>::new() ); assert_eq!( failure_detector .dead_nodes() - .map(|chitchat_id| chitchat_id.node_id.as_str()) + .map(|chitchat_id| chitchat_id.0.node_id.as_str()) .collect::>(), Vec::<&str>::new() ); @@ -347,7 +355,7 @@ mod tests { assert_eq!( failure_detector .live_nodes() - .map(|chitchat_id| chitchat_id.node_id.as_str()) + .map(|chitchat_id| chitchat_id.0.node_id.as_str()) .collect::>(), vec!["node-10001"] ); @@ -358,7 +366,7 @@ mod tests { assert_eq!( failure_detector .live_nodes() - .map(|chitchat_id| chitchat_id.node_id.as_str()) + .map(|chitchat_id| chitchat_id.0.node_id.as_str()) .collect::>(), Vec::<&str>::new() ); @@ -373,7 +381,7 @@ mod tests { assert_eq!( failure_detector .live_nodes() - .map(|chitchat_id| chitchat_id.node_id.as_str()) + .map(|chitchat_id| chitchat_id.0.node_id.as_str()) .collect::>(), vec!["node-10001"] ); @@ -391,7 +399,7 @@ mod tests { let live_nodes = failure_detector .live_nodes() - .map(|chitchat_id| chitchat_id.node_id.as_str()) + .map(|chitchat_id| chitchat_id.0.node_id.as_str()) .collect::>(); assert_eq!(live_nodes, vec!["node-10001"]); MockClock::advance(Duration::from_secs(40)); @@ -399,7 +407,7 @@ mod tests { let live_nodes = failure_detector .live_nodes() - .map(|chitchat_id| chitchat_id.node_id.as_str()) + .map(|chitchat_id| chitchat_id.0.node_id.as_str()) .collect::>(); assert_eq!(live_nodes, Vec::<&str>::new()); } diff --git a/chitchat/src/lib.rs b/chitchat/src/lib.rs index d23f3c9..70cb99e 100644 --- a/chitchat/src/lib.rs +++ b/chitchat/src/lib.rs @@ -32,7 +32,10 @@ use crate::message::syn_ack_serialized_len; pub use crate::message::ChitchatMessage; pub use crate::server::{spawn_chitchat, ChitchatHandle}; use crate::state::ClusterState; -pub use crate::types::{ChitchatId, Heartbeat, MaxVersion, Version, VersionedValue}; +pub use crate::types::{ + ChitchatId, ChitchatIdGenerationEq, ChitchatIdNodeEq, Heartbeat, MaxVersion, Version, + VersionedValue, +}; /// Maximum UDP datagram payload size (in bytes). /// @@ -50,9 +53,9 @@ pub struct Chitchat { cluster_state: ClusterState, failure_detector: FailureDetector, /// Notifies listeners when a change has occurred in the set of live nodes. - previous_live_nodes: HashMap, - live_nodes_watcher_tx: watch::Sender>, - live_nodes_watcher_rx: watch::Receiver>, + previous_live_nodes: HashMap, + live_nodes_watcher_tx: watch::Sender>, + live_nodes_watcher_rx: watch::Receiver>, } impl Chitchat { @@ -160,11 +163,16 @@ impl Chitchat { /// update. fn report_heartbeats(&mut self, delta: &Delta) { for (chitchat_id, node_delta) in &delta.node_deltas { - if let Some(node_state) = self.cluster_state.node_states.get(chitchat_id) { + if let Some((node_id, node_state)) = self + .cluster_state + .node_states + .get_key_value(&ChitchatIdNodeEq(chitchat_id.0.clone())) + { if node_state.heartbeat() < node_delta.heartbeat || node_state.max_version() < node_delta.max_version + || node_id.0.generation_id < chitchat_id.0.generation_id { - self.failure_detector.report_heartbeat(chitchat_id); + self.failure_detector.report_heartbeat(&chitchat_id.0); } } } @@ -175,7 +183,7 @@ impl Chitchat { pub(crate) fn update_nodes_liveness(&mut self) { self.cluster_state .nodes() - .filter(|&chitchat_id| *chitchat_id != self.config.chitchat_id) + .filter(|&chitchat_id| !chitchat_id.eq_node_id(&self.config.chitchat_id)) .for_each(|chitchat_id| { self.failure_detector.update_node_liveness(chitchat_id); }); @@ -186,7 +194,10 @@ impl Chitchat { let node_state = self .node_state(chitchat_id) .expect("Node state should exist."); - (chitchat_id.clone(), node_state.max_version()) + ( + ChitchatIdGenerationEq(chitchat_id.clone()), + node_state.max_version(), + ) }) .collect::>(); @@ -196,7 +207,7 @@ impl Chitchat { .cloned() .map(|chitchat_id| { let node_state = self - .node_state(&chitchat_id) + .node_state(&chitchat_id.0) .expect("Node state should exist.") .clone(); (chitchat_id, node_state) @@ -215,7 +226,7 @@ impl Chitchat { } } - pub fn node_states(&self) -> &BTreeMap { + pub fn node_states(&self) -> &BTreeMap { &self.cluster_state.node_states } @@ -230,7 +241,7 @@ impl Chitchat { /// Returns the set of nodes considered alive by the failure detector. It includes the /// current node (also called "self node"), which is always considered alive. pub fn live_nodes(&self) -> impl Iterator { - once(self.self_chitchat_id()).chain(self.failure_detector.live_nodes()) + once(self.self_chitchat_id()).chain(self.failure_detector.live_nodes().map(|node| &node.0)) } /// Returns a watch stream for monitoring changes in the cluster. @@ -241,12 +252,12 @@ impl Chitchat { /// - updates its max version /// /// Heartbeats are not notified. - pub fn live_nodes_watcher(&self) -> WatchStream> { + pub fn live_nodes_watcher(&self) -> WatchStream> { WatchStream::new(self.live_nodes_watcher_rx.clone()) } /// Returns the set of nodes considered dead by the failure detector. - pub fn dead_nodes(&self) -> impl Iterator { + pub fn dead_nodes(&self) -> impl Iterator { self.failure_detector.dead_nodes() } @@ -278,7 +289,7 @@ impl Chitchat { } /// Computes the node's digest. - fn compute_digest(&self, dead_nodes: &HashSet<&ChitchatId>) -> Digest { + fn compute_digest(&self, dead_nodes: &HashSet<&ChitchatIdNodeEq>) -> Digest { self.cluster_state.compute_digest(dead_nodes) } @@ -308,7 +319,7 @@ impl Chitchat { } } -#[derive(Debug, Copy, Clone, Eq, PartialEq)] +#[derive(Debug, Copy, Clone)] pub struct KeyChangeEvent<'a> { /// The matching key without the prefix used to subscribe to the event. pub key: &'a str, @@ -318,6 +329,13 @@ pub struct KeyChangeEvent<'a> { pub node: &'a ChitchatId, } +impl<'a> Eq for KeyChangeEvent<'a> {} +impl<'a> PartialEq for KeyChangeEvent<'a> { + fn eq(&self, other: &Self) -> bool { + self.key == other.key && self.value == other.value && self.node.eq_generation(&other.node) + } +} + impl<'a> KeyChangeEvent<'a> { fn strip_key_prefix(&self, prefix: &str) -> Option { let key_without_prefix = self.key.strip_prefix(prefix)?; @@ -408,7 +426,7 @@ mod tests { for chitchat_id in &chitchat_ids[1..] { let seeds = chitchat_ids .iter() - .filter(|&peer_id| peer_id != chitchat_id) + .filter(|&peer_id| !peer_id.eq_node_id(chitchat_id)) .map(|peer_id| peer_id.gossip_advertise_addr.to_string()) .collect::>(); chitchat_handlers.push(start_node(chitchat_id.clone(), &seeds, transport).await); @@ -434,6 +452,10 @@ mod tests { chitchat: Arc>, expected_nodes: &[ChitchatId], ) { + let expected_nodes: Vec<_> = expected_nodes + .iter() + .map(|node| ChitchatIdGenerationEq(node.clone())) + .collect(); let expected_nodes = expected_nodes.iter().collect::>(); let mut live_nodes_watcher = chitchat diff --git a/chitchat/src/serialize.rs b/chitchat/src/serialize.rs index 5536383..b1027ff 100644 --- a/chitchat/src/serialize.rs +++ b/chitchat/src/serialize.rs @@ -269,9 +269,18 @@ mod tests { #[test] fn test_serialize_chitchat_id() { - test_serdeser_aux( - &ChitchatId::new("node-id".to_string(), 1, "127.0.0.1:7280".parse().unwrap()), - 24, + // we cant use test_serdeser_aux because ChitchatId isn't Eq + let obj = ChitchatId::new("node-id".to_string(), 1, "127.0.0.1:7280".parse().unwrap()); + let num_bytes = 24; + + let mut buf = Vec::new(); + obj.serialize(&mut buf); + assert_eq!(buf.len(), obj.serialized_len()); + assert_eq!(buf.len(), num_bytes); + let obj_serdeser = ChitchatId::deserialize(&mut &buf[..]).unwrap(); + assert!( + obj.eq_generation(&obj_serdeser) + && obj.gossip_advertise_addr == obj_serdeser.gossip_advertise_addr ); } diff --git a/chitchat/src/server.rs b/chitchat/src/server.rs index 3602c0f..55e3e64 100644 --- a/chitchat/src/server.rs +++ b/chitchat/src/server.rs @@ -260,17 +260,17 @@ impl Server { let peer_nodes = cluster_state .nodes() - .filter(|chitchat_id| *chitchat_id != chitchat_guard.self_chitchat_id()) + .filter(|chitchat_id| !chitchat_id.eq_node_id(chitchat_guard.self_chitchat_id())) .map(|chitchat_id| chitchat_id.gossip_advertise_addr) .collect::>(); let live_nodes = chitchat_guard .live_nodes() - .filter(|chitchat_id| *chitchat_id != chitchat_guard.self_chitchat_id()) + .filter(|chitchat_id| !chitchat_id.eq_node_id(chitchat_guard.self_chitchat_id())) .map(|chitchat_id| chitchat_id.gossip_advertise_addr) .collect::>(); let dead_nodes = chitchat_guard .dead_nodes() - .map(|chitchat_id| chitchat_id.gossip_advertise_addr) + .map(|chitchat_id| chitchat_id.0.gossip_advertise_addr) .collect::>(); let seed_nodes: HashSet = chitchat_guard.seed_nodes(); let (selected_nodes, random_dead_node_opt, random_seed_node_opt) = select_nodes_for_gossip( @@ -416,7 +416,7 @@ mod tests { use super::*; use crate::message::ChitchatMessage; use crate::transport::{ChannelTransport, Transport}; - use crate::{Heartbeat, NodeState, MAX_UDP_DATAGRAM_PAYLOAD_SIZE}; + use crate::{ChitchatIdGenerationEq, Heartbeat, NodeState, MAX_UDP_DATAGRAM_PAYLOAD_SIZE}; #[derive(Debug, Default)] struct RngForTest { @@ -601,7 +601,10 @@ mod tests { panic!("Expected ack"); }; - let node_delta = &delta.node_deltas.get(&server_id).unwrap(); + let node_delta = &delta + .node_deltas + .get(&ChitchatIdGenerationEq(server_id)) + .unwrap(); let heartbeat = node_delta.heartbeat; assert_eq!(heartbeat, Heartbeat(3)); @@ -628,7 +631,7 @@ mod tests { { let live_nodes = next_live_nodes(&mut live_nodes_watcher).await; assert_eq!(live_nodes.len(), 1); - assert!(live_nodes.contains_key(&node1_id)); + assert!(live_nodes.contains_key(&ChitchatIdGenerationEq(node1_id))); } let mut node2_config = ChitchatConfig::for_test(6664); node2_config.seed_nodes = vec![node1_addr.to_string()]; @@ -639,16 +642,18 @@ mod tests { { let live_nodes = next_live_nodes(&mut live_nodes_watcher).await; assert_eq!(live_nodes.len(), 2); - assert!(live_nodes.contains_key(&node2_id)); + assert!(live_nodes.contains_key(&ChitchatIdGenerationEq(node2_id))); } node1.shutdown().await.unwrap(); node2.shutdown().await.unwrap(); } - async fn next_live_nodes>>( + async fn next_live_nodes< + S: Unpin + Stream>, + >( watcher: &mut S, - ) -> BTreeMap { + ) -> BTreeMap { tokio::time::timeout(Duration::from_secs(3), watcher.next()) .await .expect("No Change within 3s") diff --git a/chitchat/src/state.rs b/chitchat/src/state.rs index e30a0e3..93e19d2 100644 --- a/chitchat/src/state.rs +++ b/chitchat/src/state.rs @@ -15,7 +15,10 @@ use tracing::warn; use crate::delta::{Delta, DeltaWriter}; use crate::digest::{Digest, NodeDigest}; use crate::listener::Listeners; -use crate::{ChitchatId, Heartbeat, KeyChangeEvent, MaxVersion, Version, VersionedValue}; +use crate::{ + ChitchatId, ChitchatIdGenerationEq, ChitchatIdNodeEq, Heartbeat, KeyChangeEvent, MaxVersion, + Version, VersionedValue, +}; #[derive(Clone, Serialize, Deserialize)] pub struct NodeState { @@ -230,7 +233,9 @@ impl NodeState { } pub(crate) struct ClusterState { - pub(crate) node_states: BTreeMap, + // when inserting in this map, it's up to you to make sure you store the newest generation, + // which possibly means removing and reinserting a key. + pub(crate) node_states: BTreeMap, seed_addrs: watch::Receiver>, pub(crate) listeners: Listeners, } @@ -268,16 +273,16 @@ impl ClusterState { pub(crate) fn node_state_mut(&mut self, chitchat_id: &ChitchatId) -> &mut NodeState { // TODO use the `hash_raw_entry` feature once it gets stabilized. self.node_states - .entry(chitchat_id.clone()) + .entry(ChitchatIdNodeEq(chitchat_id.clone())) .or_insert_with(|| NodeState::new(chitchat_id.clone(), self.listeners.clone())) } pub fn node_state(&self, chitchat_id: &ChitchatId) -> Option<&NodeState> { - self.node_states.get(chitchat_id) + self.node_states.get(&ChitchatIdNodeEq(chitchat_id.clone())) } pub fn nodes(&self) -> impl Iterator { - self.node_states.keys() + self.node_states.keys().map(|node| &node.0) } pub fn seed_addrs(&self) -> HashSet { @@ -285,20 +290,38 @@ impl ClusterState { } pub(crate) fn remove_node(&mut self, chitchat_id: &ChitchatId) { - self.node_states.remove(chitchat_id); + self.node_states + .remove(&ChitchatIdNodeEq(chitchat_id.clone())); } + /// Apply a delta, ignoring any entry from a previous relay of `me`. pub(crate) fn apply_delta(&mut self, delta: Delta) { // Remove nodes to reset. - self.node_states - .retain(|chitchat_id, _| !delta.nodes_to_reset.contains(chitchat_id)); + self.node_states.retain(|chitchat_id, _| { + !delta + .nodes_to_reset + .iter() + .any(|to_reset| to_reset.0.eq_generation(&chitchat_id.0)) + }); // Apply delta. for (chitchat_id, node_delta) in delta.node_deltas { - let node_state = self + // we remove and re-insert to update the key in case generation changed. + let mut node_state = if let Some((old_chitchat_id, old_state)) = self .node_states - .entry(chitchat_id.clone()) - .or_insert_with(|| NodeState::new(chitchat_id, self.listeners.clone())); + .remove_entry(&ChitchatIdNodeEq(chitchat_id.0.clone())) + { + let old_chichat_id = ChitchatIdGenerationEq(old_chitchat_id.0); + if old_chichat_id > chitchat_id { + // we know a newer generation, restore the write and ignore that bit of delta. + self.node_states + .insert(ChitchatIdNodeEq(old_chichat_id.0), old_state); + continue; + } + old_state + } else { + NodeState::new(chitchat_id.0.clone(), self.listeners.clone()) + }; if node_state.heartbeat < node_delta.heartbeat { node_state.heartbeat = node_delta.heartbeat; node_state.last_heartbeat = Instant::now(); @@ -308,16 +331,23 @@ impl ClusterState { node_state.max_version = node_state.max_version.max(versioned_value.version); node_state.set_versioned_value(key, versioned_value); } + self.node_states + .insert(ChitchatIdNodeEq(chitchat_id.0), node_state); } } - pub fn compute_digest(&self, dead_nodes: &HashSet<&ChitchatId>) -> Digest { + pub fn compute_digest(&self, dead_nodes: &HashSet<&ChitchatIdNodeEq>) -> Digest { Digest { node_digests: self .node_states .iter() .filter(|(chitchat_id, _)| !dead_nodes.contains(chitchat_id)) - .map(|(chitchat_id, node_state)| (chitchat_id.clone(), node_state.digest())) + .map(|(chitchat_id, node_state)| { + ( + ChitchatIdGenerationEq(chitchat_id.0.clone()), + node_state.digest(), + ) + }) .collect(), } } @@ -325,7 +355,7 @@ impl ClusterState { pub fn gc_keys_marked_for_deletion( &mut self, marked_for_deletion_grace_period: u64, - dead_nodes: &HashSet, + dead_nodes: &HashSet, ) { for (chitchat_id, node_state) in &mut self.node_states { if dead_nodes.contains(chitchat_id) { @@ -340,7 +370,7 @@ impl ClusterState { &self, digest: &Digest, mtu: usize, - dead_nodes: &HashSet<&ChitchatId>, + dead_nodes: &HashSet<&ChitchatIdNodeEq>, marked_for_deletion_grace_period: u64, ) -> Delta { let mut stale_nodes = SortedStaleNodes::default(); @@ -350,17 +380,20 @@ impl ClusterState { if dead_nodes.contains(chitchat_id) { continue; } - let Some(node_digest) = digest.node_digests.get(chitchat_id) else { - stale_nodes.insert(chitchat_id, node_state); + let Some(node_digest) = digest + .node_digests + .get(&ChitchatIdGenerationEq(chitchat_id.0.clone())) + else { + stale_nodes.insert(&chitchat_id.0, node_state); continue; }; if node_digest.heartbeat.0 + marked_for_deletion_grace_period < node_state.heartbeat.0 { warn!("Node to reset {chitchat_id:?}"); - nodes_to_reset.push(chitchat_id); - stale_nodes.insert(chitchat_id, node_state); + nodes_to_reset.push(&chitchat_id.0); + stale_nodes.insert(&chitchat_id.0, node_state); continue; } - stale_nodes.offer(chitchat_id, node_state, node_digest); + stale_nodes.offer(&chitchat_id.0, node_state, node_digest); } let mut delta_writer = DeltaWriter::with_mtu(mtu); @@ -491,7 +524,7 @@ impl From<&ClusterState> for ClusterStateSnapshot { .node_states .iter() .map(|(chitchat_id, node_state)| NodeStateSnapshot { - chitchat_id: chitchat_id.clone(), + chitchat_id: chitchat_id.0.clone(), node_state: node_state.clone(), }) .collect(); @@ -825,7 +858,8 @@ mod tests { assert_eq!(&digest, &expected_node_digests); // Consider node 1 dead: - let dead_nodes = HashSet::from_iter([&node1]); + let dead_node = ChitchatIdNodeEq(node1); + let dead_nodes = HashSet::from_iter([&dead_node]); let digest = cluster_state.compute_digest(&dead_nodes); let mut expected_node_digests = Digest::default(); @@ -934,7 +968,7 @@ mod tests { fn test_with_varying_max_transmitted_kv_helper( cluster_state: &ClusterState, digest: &Digest, - dead_nodes: &HashSet<&ChitchatId>, + dead_nodes: &HashSet<&ChitchatIdNodeEq>, expected_delta_atoms: &[(&ChitchatId, &str, &str, Version, Option)], ) { let max_delta = cluster_state.compute_delta(digest, usize::MAX, dead_nodes, 10_000); @@ -1060,7 +1094,8 @@ mod tests { let node1 = ChitchatId::for_local_test(10_001); let node2 = ChitchatId::for_local_test(10_002); - let dead_nodes = HashSet::from_iter([&node2]); + let dead_node = ChitchatIdNodeEq(node2); + let dead_nodes = HashSet::from_iter([&dead_node]); test_with_varying_max_transmitted_kv_helper( &cluster_state, diff --git a/chitchat/src/types.rs b/chitchat/src/types.rs index 8a919e8..de62cda 100644 --- a/chitchat/src/types.rs +++ b/chitchat/src/types.rs @@ -1,3 +1,5 @@ +use std::cmp::Ordering; +use std::hash::{Hash, Hasher}; use std::net::SocketAddr; use serde::{Deserialize, Serialize}; @@ -13,7 +15,13 @@ use serde::{Deserialize, Serialize}; /// leaves and rejoins the cluster. Backends such as Cassandra or Quickwit typically use the node's /// startup time as the `generation_id`. Applications with stable state across restarts can use a /// constant `generation_id`, for instance, `0`. -#[derive(Debug, Clone, Eq, PartialEq, Hash, Ord, PartialOrd, Serialize, Deserialize)] +// This type doesn't implement Eq & co because there are multiple notions of equality depending +// on what you want to do with it. Nodes with the same node_id are the same by definition, so +// sometime checking node_id is enough, but sometime we want to compare ChitchatId for generations, +// in which case node_id+generation_id needs to be compared. Mixing both is easy and can lead to +// bugs. Instead you have to use dedicated methods and/or wrappers depending on what equality means +// for you in this context. +#[derive(Debug, Clone, Serialize, Deserialize)] pub struct ChitchatId { /// An identifier unique across the cluster. pub node_id: String, @@ -31,6 +39,14 @@ impl ChitchatId { gossip_advertise_addr, } } + + pub fn eq_node_id(&self, other: &ChitchatId) -> bool { + self.node_id == other.node_id + } + + pub fn eq_generation(&self, other: &ChitchatId) -> bool { + self.eq_node_id(other) && self.generation_id == other.generation_id + } } #[cfg(any(test, feature = "testsuite"))] @@ -46,6 +62,60 @@ impl ChitchatId { } } +#[derive(Debug, Clone)] +pub struct ChitchatIdNodeEq(pub ChitchatId); + +impl Eq for ChitchatIdNodeEq {} +impl PartialEq for ChitchatIdNodeEq { + fn eq(&self, other: &ChitchatIdNodeEq) -> bool { + self.0.eq_node_id(&other.0) + } +} +impl Hash for ChitchatIdNodeEq { + fn hash(&self, state: &mut H) { + self.0.node_id.hash(state); + } +} +impl Ord for ChitchatIdNodeEq { + fn cmp(&self, other: &Self) -> Ordering { + self.0.node_id.cmp(&other.0.node_id) + } +} +impl PartialOrd for ChitchatIdNodeEq { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +#[derive(Debug, Clone)] +pub struct ChitchatIdGenerationEq(pub ChitchatId); + +impl Eq for ChitchatIdGenerationEq {} +impl PartialEq for ChitchatIdGenerationEq { + fn eq(&self, other: &ChitchatIdGenerationEq) -> bool { + self.0.eq_generation(&other.0) + } +} +impl Hash for ChitchatIdGenerationEq { + fn hash(&self, state: &mut H) { + self.0.node_id.hash(state); + self.0.generation_id.hash(state); + } +} +impl Ord for ChitchatIdGenerationEq { + fn cmp(&self, other: &Self) -> Ordering { + self.0 + .node_id + .cmp(&other.0.node_id) + .then(self.0.generation_id.cmp(&other.0.generation_id)) + } +} +impl PartialOrd for ChitchatIdGenerationEq { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + /// A versioned key-value pair. #[derive(Debug, Clone, Eq, PartialEq, Serialize, Deserialize)] pub struct VersionedValue { diff --git a/chitchat/tests/cluster_test.rs b/chitchat/tests/cluster_test.rs index a41906d..d1f9529 100644 --- a/chitchat/tests/cluster_test.rs +++ b/chitchat/tests/cluster_test.rs @@ -5,7 +5,8 @@ use std::time::Duration; use anyhow::anyhow; use chitchat::transport::ChannelTransport; use chitchat::{ - spawn_chitchat, ChitchatConfig, ChitchatHandle, ChitchatId, FailureDetectorConfig, NodeState, + spawn_chitchat, ChitchatConfig, ChitchatHandle, ChitchatId, ChitchatIdGenerationEq, + FailureDetectorConfig, NodeState, }; use rand::seq::SliceRandom; use rand::{thread_rng, Rng}; @@ -66,7 +67,7 @@ impl NodeStatePredicate { struct Simulator { transport: ChannelTransport, - node_handles: HashMap, + node_handles: HashMap, gossip_interval: Duration, marked_for_deletion_key_grace_period: usize, } @@ -124,7 +125,7 @@ impl Simulator { debug!(server_node_id=%server_chitchat_id.node_id, node_id=%chitchat_id.node_id, "node-state-assert"); let chitchat = self .node_handles - .get(&server_chitchat_id) + .get(&ChitchatIdGenerationEq(server_chitchat_id)) .unwrap() .chitchat(); // Wait for node_state & predicate. @@ -173,7 +174,11 @@ impl Simulator { keys_values: Vec<(String, String)>, ) { debug!(node_id=%chitchat_id.node_id, num_keys_values=?keys_values.len(), "insert-keys-values"); - let chitchat = self.node_handles.get(&chitchat_id).unwrap().chitchat(); + let chitchat = self + .node_handles + .get(&ChitchatIdGenerationEq(chitchat_id)) + .unwrap() + .chitchat(); let mut chitchat_guard = chitchat.lock().await; for (key, value) in keys_values.into_iter() { chitchat_guard.self_node_state().set(key.clone(), value); @@ -181,7 +186,11 @@ impl Simulator { } pub async fn mark_for_deletion(&mut self, chitchat_id: ChitchatId, key: String) { - let chitchat = self.node_handles.get(&chitchat_id).unwrap().chitchat(); + let chitchat = self + .node_handles + .get(&ChitchatIdGenerationEq(chitchat_id.clone())) + .unwrap() + .chitchat(); let mut chitchat_guard = chitchat.lock().await; chitchat_guard.self_node_state().mark_for_deletion(&key); let hearbeat = chitchat_guard.self_node_state().heartbeat(); @@ -198,7 +207,7 @@ impl Simulator { .unwrap_or_else(|| { self.node_handles .keys() - .cloned() + .map(|id| id.0.clone()) .collect::>() }) .iter() @@ -219,7 +228,8 @@ impl Simulator { let handle = spawn_chitchat(config, Vec::new(), &self.transport) .await .unwrap(); - self.node_handles.insert(chitchat_id, handle); + self.node_handles + .insert(ChitchatIdGenerationEq(chitchat_id), handle); } } @@ -471,7 +481,7 @@ async fn test_simple_simulation_heavy_insert_delete() { simulator.execute(add_node_operations).await; let key_names: Vec<_> = (0..200).map(|idx| format!("key_{}", idx)).collect(); - let mut keys_values_inserted_per_chitchat_id: HashMap> = + let mut keys_values_inserted_per_chitchat_id: HashMap> = HashMap::new(); for chitchat_id in chitchat_ids.iter() { let mut keys_values = Vec::new(); @@ -479,7 +489,7 @@ async fn test_simple_simulation_heavy_insert_delete() { let value: u64 = rng.gen(); keys_values.push((key.to_string(), value.to_string())); let keys_entry = keys_values_inserted_per_chitchat_id - .entry(chitchat_id.clone()) + .entry(ChitchatIdGenerationEq(chitchat_id.clone())) .or_default(); keys_entry.insert(key.to_string()); } @@ -494,12 +504,12 @@ async fn test_simple_simulation_heavy_insert_delete() { tokio::time::sleep(Duration::from_secs(10)).await; info!("Checking keys are present..."); for (chitchat_id, keys) in keys_values_inserted_per_chitchat_id.clone().into_iter() { - debug!(node_id=%chitchat_id.node_id, keys=?keys, "check"); + debug!(node_id=%chitchat_id.0.node_id, keys=?keys, "check"); for key in keys { let server_chitchat_id = chitchat_ids.choose(&mut rng).unwrap().clone(); let check_operation = Operation::NodeStateAssert { server_chitchat_id, - chitchat_id: chitchat_id.clone(), + chitchat_id: chitchat_id.clone().0, predicate: NodeStatePredicate::KeyPresent(key.to_string(), true), timeout_opt: None, }; @@ -511,7 +521,7 @@ async fn test_simple_simulation_heavy_insert_delete() { for (chitchat_id, keys) in keys_values_inserted_per_chitchat_id.clone().into_iter() { for key in keys { let check_operation = Operation::MarkKeyForDeletion { - chitchat_id: chitchat_id.clone(), + chitchat_id: chitchat_id.clone().0, key, }; simulator.execute(vec![check_operation]).await; @@ -527,7 +537,7 @@ async fn test_simple_simulation_heavy_insert_delete() { let server_chitchat_id = chitchat_ids.choose(&mut rng).unwrap().clone(); let check_operation = Operation::NodeStateAssert { server_chitchat_id, - chitchat_id: chitchat_id.clone(), + chitchat_id: chitchat_id.clone().0, predicate: NodeStatePredicate::KeyPresent(key.to_string(), false), timeout_opt: None, }; diff --git a/chitchat/tests/perf_test.rs b/chitchat/tests/perf_test.rs index 918bccd..6e899df 100644 --- a/chitchat/tests/perf_test.rs +++ b/chitchat/tests/perf_test.rs @@ -4,7 +4,8 @@ use std::time::Duration; use chitchat::transport::{ChannelTransport, Transport, TransportExt}; use chitchat::{ - spawn_chitchat, ChitchatConfig, ChitchatHandle, ChitchatId, FailureDetectorConfig, NodeState, + spawn_chitchat, ChitchatConfig, ChitchatHandle, ChitchatId, ChitchatIdGenerationEq, + FailureDetectorConfig, NodeState, }; use tokio::time::Instant; use tokio_stream::StreamExt; @@ -42,7 +43,7 @@ async fn spawn_nodes(num_nodes: u16, transport: &dyn Transport) -> Vec) -> bool>( +async fn wait_until) -> bool>( handle: &ChitchatHandle, predicate: P, ) -> Duration {