diff --git a/src/job_manager.rs b/src/job_manager.rs index 321c1443..886b2445 100644 --- a/src/job_manager.rs +++ b/src/job_manager.rs @@ -1,6 +1,7 @@ use parking_lot::RwLock; use std::fmt::{Debug, Display}; use std::future::Future; +use std::ops::{Add, Sub}; use std::{ collections::{HashMap, HashSet, VecDeque}, hash::{Hash, Hasher}, @@ -54,7 +55,18 @@ pub type EnqueuedMessage = HashMap>>; pub trait WorkManagerInterface: Send + Sync + 'static + Sized { type SSID: Copy + Hash + Eq + PartialEq + Send + Sync + 'static; - type Clock: Copy + Debug + Eq + PartialEq + Send + Sync + 'static; + type Clock: Copy + + Debug + + Default + + Eq + + Ord + + PartialOrd + + PartialEq + + Send + + Sync + + Sub + + Add + + 'static; type ProtocolMessage: ProtocolMessageMetadata + Send + Sync + 'static; type Error: Debug + Send + Sync + 'static; type SessionID: Copy + Hash + Eq + PartialEq + Display + Debug + Send + Sync + 'static; @@ -62,7 +74,25 @@ pub trait WorkManagerInterface: Send + Sync + 'static + Sized { fn error(&self, input: String); fn warn(&self, input: String); fn clock(&self) -> Self::Clock; - fn associated_block_id_acceptable(now: Self::Clock, compare: Self::Clock) -> bool; + fn acceptable_block_tolerance() -> Self::Clock; + fn associated_block_id_acceptable(expected: Self::Clock, received: Self::Clock) -> bool { + // Favor explicit logic for readability + let tolerance = Self::acceptable_block_tolerance(); + let is_acceptable_above = received >= expected && received <= expected + tolerance; + let is_acceptable_below = + received < expected && received >= saturating_sub(expected, tolerance); + let is_equal = expected == received; + + is_acceptable_above || is_acceptable_below || is_equal + } +} + +fn saturating_sub + Ord + Default>(a: T, b: T) -> T { + if a < b { + T::default() + } else { + a - b + } } pub trait ProtocolMessageMetadata { @@ -568,8 +598,9 @@ mod tests { fn clock(&self) -> Self::Clock { 0 } - fn associated_block_id_acceptable(now: Self::Clock, compare: Self::Clock) -> bool { - now == compare + + fn acceptable_block_tolerance() -> Self::Clock { + 0 } } @@ -1126,4 +1157,93 @@ mod tests { let delivery_type = work_manager.deliver_message(msg, [2; 32]).unwrap(); // incorrect task hash assert_eq!(delivery_type, DeliveryType::EnqueuedMessage); // message should be enqueued because the task hash is incorrect } + + struct DummyRangeChecker; + struct DummyProtocolMessage; + + impl ProtocolMessageMetadata> for DummyProtocolMessage { + fn associated_block_id(&self) -> u64 { + 0u64 + } + + fn associated_session_id(&self) -> u64 { + 0u64 + } + + fn associated_ssid(&self) -> u64 { + 0u64 + } + } + + impl WorkManagerInterface for DummyRangeChecker { + type SSID = u64; + type Clock = u64; + type ProtocolMessage = DummyProtocolMessage; + type Error = (); + type SessionID = u64; + + fn debug(&self, _input: String) { + todo!() + } + + fn error(&self, _input: String) { + todo!() + } + + fn warn(&self, _input: String) { + todo!() + } + + fn clock(&self) -> Self::Clock { + todo!() + } + + fn acceptable_block_tolerance() -> Self::Clock { + N + } + } + + #[test] + fn test_range_above() { + let current_block: u64 = 10; + const TOL: u64 = 5; + assert!(DummyRangeChecker::::associated_block_id_acceptable( + current_block, + current_block + )); + assert!(DummyRangeChecker::::associated_block_id_acceptable( + current_block, + current_block + 1 + )); + assert!(DummyRangeChecker::::associated_block_id_acceptable( + current_block, + current_block + TOL + )); + assert!(!DummyRangeChecker::::associated_block_id_acceptable( + current_block, + current_block + TOL + 1 + )); + } + + #[test] + fn test_range_below() { + let current_block: u64 = 10; + const TOL: u64 = 5; + assert!(DummyRangeChecker::::associated_block_id_acceptable( + current_block, + current_block + )); + assert!(DummyRangeChecker::::associated_block_id_acceptable( + current_block, + current_block - 1 + )); + assert!(DummyRangeChecker::::associated_block_id_acceptable( + current_block, + current_block - TOL + )); + assert!(!DummyRangeChecker::::associated_block_id_acceptable( + current_block, + current_block - TOL - 1 + )); + } }