diff --git a/Cargo.lock b/Cargo.lock
index 301842c1..7a553227 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1477,6 +1477,7 @@ dependencies = [
  "base64 0.22.0",
  "cfg-if",
  "criterion",
+ "dashmap",
  "futures",
  "getrandom",
  "infinispan",
diff --git a/limitador/Cargo.toml b/limitador/Cargo.toml
index 2d6646c6..7f791fe8 100644
--- a/limitador/Cargo.toml
+++ b/limitador/Cargo.toml
@@ -22,6 +22,7 @@ lenient_conditions = []
 
 [dependencies]
 moka = { version = "0.12", features = ["sync"] }
+dashmap = "5.5.3"
 getrandom = { version = "0.2", features = ["js"] }
 serde = { version = "1", features = ["derive"] }
 postcard = { version = "1.0.4", features = ["use-std"] }
diff --git a/limitador/src/storage/redis/counters_cache.rs b/limitador/src/storage/redis/counters_cache.rs
index 2f2b22ca..c9af9091 100644
--- a/limitador/src/storage/redis/counters_cache.rs
+++ b/limitador/src/storage/redis/counters_cache.rs
@@ -4,27 +4,147 @@ use crate::storage::redis::{
     DEFAULT_MAX_CACHED_COUNTERS, DEFAULT_MAX_TTL_CACHED_COUNTERS_SEC,
     DEFAULT_TTL_RATIO_CACHED_COUNTERS,
 };
+use dashmap::mapref::entry::Entry;
+use dashmap::DashMap;
 use moka::sync::Cache;
+use std::collections::HashMap;
+use std::future::Future;
+use std::sync::atomic::{AtomicBool, AtomicI64, Ordering};
 use std::sync::Arc;
 use std::time::{Duration, SystemTime};
+use tokio::select;
+use tokio::sync::Notify;
+use tokio::time::interval;
 
 pub struct CachedCounterValue {
     value: AtomicExpiringValue,
+    initial_value: AtomicI64,
     expiry: AtomicExpiryTime,
+    from_authority: AtomicBool,
+}
+
+pub struct Batcher {
+    updates: DashMap<Counter, Arc<CachedCounterValue>>,
+    notifier: Notify,
+    interval: Duration,
+    priority_flush: AtomicBool,
+}
+
+impl Batcher {
+    fn new(period: Duration) -> Self {
+        Self {
+            updates: Default::default(),
+            notifier: Default::default(),
+            interval: period,
+            priority_flush: AtomicBool::new(false),
+        }
+    }
+
+    pub fn is_empty(&self) -> bool {
+        self.updates.is_empty()
+    }
+
+    pub async fn consume<F, Fut, O>(&self, min: usize, consumer: F) -> O
+    where
+        F: FnOnce(HashMap<Counter, Arc<CachedCounterValue>>) -> Fut,
+        Fut: Future<Output = O>,
+    {
+        let mut interval = interval(self.interval);
+        let mut ready = self.updates.len() >= min;
+        loop {
+            if ready {
+                let mut batch = Vec::with_capacity(min);
+                for entry in &self.updates {
+                    if entry.value().requires_fast_flush(&self.interval) {
+                        batch.push(entry.key().clone());
+                        if batch.len() == min {
+                            break;
+                        }
+                    }
+                }
+                if let Some(remaining) = min.checked_sub(batch.len()) {
+                    let take = self.updates.iter().take(remaining);
+                    batch.append(&mut take.map(|e| e.key().clone()).collect());
+                }
+                let mut result = HashMap::new();
+                for counter in &batch {
+                    let value = self.updates.get(counter).unwrap().clone();
+                    result.insert(counter.clone(), value);
+                }
+                let result = consumer(result).await;
+                for counter in &batch {
+                    self.updates
+                        .remove_if(counter, |_, v| v.no_pending_writes());
+                }
+                return result;
+            } else {
+                ready = select! {
+                    _ = self.notifier.notified() => {
+                        self.updates.len() >= min ||
+                            self.priority_flush
+                                .compare_exchange(true, false, Ordering::Release, Ordering::Acquire)
+                                .is_ok()
+                    },
+                    _ = interval.tick() => true,
+                }
+            }
+        }
+    }
+
+    pub fn add(&self, counter: Counter, value: Arc<CachedCounterValue>) {
+        let priority = value.requires_fast_flush(&self.interval);
+        match self.updates.entry(counter.clone()) {
+            Entry::Occupied(needs_merge) => {
+                let arc = needs_merge.get();
+                if !Arc::ptr_eq(arc, &value) {
+                    arc.delta(&counter, value.pending_writes().unwrap());
+                }
+            }
+            Entry::Vacant(miss) => {
+                miss.insert_entry(value);
+            }
+        };
+        if priority {
+            self.priority_flush.store(true, Ordering::Release);
+        }
+        self.notifier.notify_one();
+    }
+}
+
+impl Default for Batcher {
+    fn default() -> Self {
+        Self::new(Duration::from_millis(100))
+    }
 }
 
 pub struct CountersCache {
     max_ttl_cached_counters: Duration,
     pub ttl_ratio_cached_counters: u64,
     cache: Cache<Counter, Arc<CachedCounterValue>>,
+    batcher: Batcher,
 }
 
 impl CachedCounterValue {
-    pub fn from(counter: &Counter, value: i64, ttl: Duration) -> Self {
+    pub fn from_authority(counter: &Counter, value: i64, ttl: Duration) -> Self {
         let now = SystemTime::now();
         Self {
             value: AtomicExpiringValue::new(value, now + Duration::from_secs(counter.seconds())),
+            initial_value: AtomicI64::new(value),
             expiry: AtomicExpiryTime::from_now(ttl),
+            from_authority: AtomicBool::new(true),
+        }
+    }
+
+    pub fn load_from_authority_asap(counter: &Counter, temp_value: i64) -> Self {
+        let now = SystemTime::now();
+        Self {
+            value: AtomicExpiringValue::new(
+                temp_value,
+                now + Duration::from_secs(counter.seconds()),
+            ),
+            initial_value: AtomicI64::new(temp_value),
+            expiry: AtomicExpiryTime::from_now(Duration::from_secs(counter.seconds())),
+            from_authority: AtomicBool::new(false),
         }
     }
 
@@ -34,13 +154,58 @@ impl CachedCounterValue {
 
     pub fn set_from_authority(&self, counter: &Counter, value: i64, expiry: Duration) {
         let time_window = Duration::from_secs(counter.seconds());
+        self.initial_value.store(value, Ordering::SeqCst);
         self.value.set(value, time_window);
         self.expiry.update(expiry);
+        self.from_authority.store(true, Ordering::Release);
     }
 
     pub fn delta(&self, counter: &Counter, delta: i64) -> i64 {
-        self.value
-            .update(delta, counter.seconds(), SystemTime::now())
+        let value = self
+            .value
+            .update(delta, counter.seconds(), SystemTime::now());
+        if value == delta {
+            // new window, invalidate initial value
+            self.initial_value.store(0, Ordering::SeqCst);
+        }
+        value
+    }
+
+    pub fn pending_writes(&self) -> Result<i64, ()> {
+        let start = self.initial_value.load(Ordering::SeqCst);
+        let value = self.value.value_at(SystemTime::now());
+        let offset = if start == 0 {
+            value
+        } else {
+            let writes = value - start;
+            if writes > 0 {
+                writes
+            } else {
+                value
+            }
+        };
+        match self
+            .initial_value
+            .compare_exchange(start, value, Ordering::SeqCst, Ordering::SeqCst)
+        {
+            Ok(_) => Ok(offset),
+            Err(newer) => {
+                if newer == 0 {
+                    // We got expired in the meantime, this fresh value can wait the next iteration
+                    Ok(0)
+                } else {
+                    // Concurrent call to this method?
+                    // We could support that with a CAS loop in the future if needed
+                    Err(())
+                }
+            }
+        }
+    }
+
+    fn no_pending_writes(&self) -> bool {
+        let start = self.initial_value.load(Ordering::SeqCst);
+        let value = self.value.value_at(SystemTime::now());
+        value - start == 0
     }
 
     pub fn hits(&self, _: &Counter) -> i64 {
@@ -58,6 +223,10 @@ impl CachedCounterValue {
     pub fn to_next_window(&self) -> Duration {
         self.value.ttl()
     }
+
+    pub fn requires_fast_flush(&self, within: &Duration) -> bool {
+        self.from_authority.load(Ordering::Acquire) || &self.value.ttl() <= within
+    }
 }
 
 pub struct CountersCacheBuilder {
@@ -90,18 +259,31 @@ impl CountersCacheBuilder {
         self
     }
 
-    pub fn build(&self) -> CountersCache {
+    pub fn build(&self, period: Duration) -> CountersCache {
         CountersCache {
             max_ttl_cached_counters: self.max_ttl_cached_counters,
             ttl_ratio_cached_counters: self.ttl_ratio_cached_counters,
             cache: Cache::new(self.max_cached_counters as u64),
+            batcher: Batcher::new(period),
         }
     }
 }
 
 impl CountersCache {
     pub fn get(&self, counter: &Counter) -> Option<Arc<CachedCounterValue>> {
-        self.cache.get(counter)
+        let option = self.cache.get(counter);
+        if option.is_none() {
+            let from_queue = self.batcher.updates.get(counter);
+            if let Some(entry) = from_queue {
+                self.cache.insert(counter.clone(), entry.value().clone());
+                return Some(entry.value().clone());
+            }
+        }
+        option
+    }
+
+    pub fn batcher(&self) -> &Batcher {
+        &self.batcher
     }
 
     pub fn insert(
@@ -122,25 +304,38 @@
         if let Some(ttl) = cache_ttl.checked_sub(ttl_margin) {
             if ttl > Duration::ZERO {
                 let previous = self.cache.get_with(counter.clone(), || {
-                    Arc::new(CachedCounterValue::from(&counter, counter_val, cache_ttl))
+                    if let Some(entry) = self.batcher.updates.get(&counter) {
+                        entry.value().clone()
+                    } else {
+                        Arc::new(CachedCounterValue::from_authority(
+                            &counter,
+                            counter_val,
+                            ttl,
+                        ))
+                    }
                 });
                 if previous.expired_at(now) || previous.value.value() < counter_val {
-                    previous.set_from_authority(&counter, counter_val, cache_ttl);
+                    previous.set_from_authority(&counter, counter_val, ttl);
                 }
                 return previous;
             }
         }
 
-        Arc::new(CachedCounterValue::from(
+        Arc::new(CachedCounterValue::load_from_authority_asap(
             &counter,
             counter_val,
-            Duration::ZERO,
         ))
     }
 
     pub fn increase_by(&self, counter: &Counter, delta: i64) {
-        if let Some(val) = self.cache.get(counter) {
-            val.delta(counter, delta);
-        };
+        let val = self.cache.get_with_by_ref(counter, || {
+            if let Some(entry) = self.batcher.updates.get(counter) {
+                entry.value().clone()
+            } else {
+                Arc::new(CachedCounterValue::load_from_authority_asap(counter, 0))
+            }
+        });
+        val.delta(counter, delta);
+        self.batcher.add(counter.clone(), val.clone());
    }
 
     fn ttl_from_redis_ttl(
@@ -209,7 +404,7 @@ mod tests {
             values,
         );
 
-        let cache = CountersCacheBuilder::new().build();
+        let cache = CountersCacheBuilder::new().build(Duration::default());
         cache.insert(
             counter.clone(),
             Some(10),
@@ -236,7 +431,7 @@
             values,
         );
 
-        let cache = CountersCacheBuilder::new().build();
+        let cache = CountersCacheBuilder::new().build(Duration::default());
         assert!(cache.get(&counter).is_none());
     }
 
@@ -258,7 +453,7 @@
             values,
         );
 
-        let cache = CountersCacheBuilder::new().build();
+        let cache = CountersCacheBuilder::new().build(Duration::default());
         cache.insert(
             counter.clone(),
             Some(current_value),
@@ -289,7 +484,7 @@
             values,
         );
 
-        let cache = CountersCacheBuilder::new().build();
+        let cache = CountersCacheBuilder::new().build(Duration::default());
         cache.insert(
             counter.clone(),
             None,
@@ -318,7 +513,7 @@
             values,
         );
 
-        let cache = CountersCacheBuilder::new().build();
+        let cache = CountersCacheBuilder::new().build(Duration::default());
         cache.insert(
             counter.clone(),
             Some(current_val),
diff --git a/limitador/src/storage/redis/redis_cached.rs b/limitador/src/storage/redis/redis_cached.rs
index 3f9e3da6..20dedcc0 100644
--- a/limitador/src/storage/redis/redis_cached.rs
+++ b/limitador/src/storage/redis/redis_cached.rs
@@ -1,10 +1,11 @@
 use crate::counter::Counter;
 use crate::limit::Limit;
-use crate::storage::atomic_expiring_value::AtomicExpiringValue;
 use crate::storage::keys::*;
-use crate::storage::redis::counters_cache::{CountersCache, CountersCacheBuilder};
+use crate::storage::redis::counters_cache::{
+    CachedCounterValue, CountersCache, CountersCacheBuilder,
+};
 use crate::storage::redis::redis_async::AsyncRedisStorage;
-use crate::storage::redis::scripts::{BATCH_UPDATE_COUNTERS, VALUES_AND_TTLS};
+use crate::storage::redis::scripts::BATCH_UPDATE_COUNTERS;
 use crate::storage::redis::{
     DEFAULT_FLUSHING_PERIOD_SEC, DEFAULT_MAX_CACHED_COUNTERS, DEFAULT_MAX_TTL_CACHED_COUNTERS_SEC,
     DEFAULT_RESPONSE_TIMEOUT_MS, DEFAULT_TTL_RATIO_CACHED_COUNTERS,
@@ -16,7 +17,7 @@ use redis::{ConnectionInfo, RedisError};
 use std::collections::{HashMap, HashSet};
 use std::str::FromStr;
 use std::sync::atomic::{AtomicBool, Ordering};
-use std::sync::{Arc, Mutex};
+use std::sync::Arc;
 use std::time::{Duration, Instant, SystemTime};
 use tracing::{debug_span, error, warn, Instrument};
 
@@ -40,10 +41,7 @@ use tracing::{debug_span, error, warn, Instrument};
 
 pub struct CachedRedisStorage {
     cached_counters: Arc<CountersCache>,
-    batcher_counter_updates: Arc<Mutex<HashMap<Counter, AtomicExpiringValue>>>,
     async_redis_storage: AsyncRedisStorage,
-    redis_conn_manager: ConnectionManager,
-    partitioned: Arc<AtomicBool>,
 }
 
 #[async_trait]
@@ -102,37 +100,9 @@ impl AsyncCounterStorage for CachedRedisStorage {
 
         // Fetch non-cached counters, cache them, and check them
         if !not_cached.is_empty() {
-            let time_start_get_ttl = Instant::now();
-
-            let (counter_vals, counter_ttls_msecs) = if self.is_partitioned() {
-                self.fallback_vals_ttls(&not_cached)
-            } else {
-                self.values_with_ttls(&not_cached).await.or_else(|err| {
-                    if err.is_transient() {
-                        self.partitioned(true);
-                        Ok(self.fallback_vals_ttls(&not_cached))
-                    } else {
-                        Err(err)
-                    }
-                })?
-            };
-
-            // Some time could have passed from the moment we got the TTL from Redis.
-            // This margin is not exact, because we don't know exactly the
-            // moment that Redis returned a particular TTL, but this
-            // approximation should be good enough.
-            let ttl_margin =
-                Duration::from_millis((Instant::now() - time_start_get_ttl).as_millis() as u64);
-
-            for (i, counter) in not_cached.iter_mut().enumerate() {
-                let cached_value = self.cached_counters.insert(
-                    counter.clone(),
-                    counter_vals[i],
-                    counter_ttls_msecs[i],
-                    ttl_margin,
-                    now,
-                );
-                let remaining = cached_value.remaining(counter);
+            for counter in not_cached.iter_mut() {
+                let fake = CachedCounterValue::load_from_authority_asap(counter, 0);
+                let remaining = fake.remaining(counter);
                 if first_limited.is_none() && remaining <= 0 {
                     first_limited = Some(Authorization::Limited(
                         counter.limit().name().map(|n| n.to_owned()),
                     ));
                 }
                 if load_counters {
                     counter.set_remaining(remaining - delta);
-                    counter.set_expires_in(cached_value.to_next_window());
+                    counter.set_expires_in(fake.to_next_window()); // todo: this is a plain lie!
                 }
             }
         }
 
@@ -154,26 +124,6 @@
             self.cached_counters.increase_by(counter, delta);
         }
 
-        // Batch or update depending on configuration
-        let mut batcher = self.batcher_counter_updates.lock().unwrap();
-        let now = SystemTime::now();
-        for counter in counters.iter() {
-            match batcher.get_mut(counter) {
-                Some(val) => {
-                    val.update(delta, counter.seconds(), now);
-                }
-                None => {
-                    batcher.insert(
-                        counter.clone(),
-                        AtomicExpiringValue::new(
-                            delta,
-                            now + Duration::from_secs(counter.seconds()),
-                        ),
-                    );
-                }
-            }
-        }
-
         Ok(Authorization::Ok)
     }
 
@@ -231,96 +181,36 @@ impl CachedRedisStorage {
             .max_cached_counters(max_cached_counters)
             .max_ttl_cached_counter(ttl_cached_counters)
             .ttl_ratio_cached_counter(ttl_ratio_cached_counters)
-            .build();
+            .build(flushing_period);
 
         let counters_cache = Arc::new(cached_counters);
         let partitioned = Arc::new(AtomicBool::new(false));
         let async_redis_storage =
             AsyncRedisStorage::new_with_conn_manager(redis_conn_manager.clone());
 
-        let batcher: Arc<Mutex<HashMap<Counter, AtomicExpiringValue>>> =
-            Arc::new(Mutex::new(Default::default()));
-
         {
             let storage = async_redis_storage.clone();
             let counters_cache_clone = counters_cache.clone();
             let conn = redis_conn_manager.clone();
             let p = Arc::clone(&partitioned);
-            let batcher_flusher = batcher.clone();
-            let mut interval = tokio::time::interval(flushing_period);
             tokio::spawn(async move {
                 loop {
                     flush_batcher_and_update_counters(
                         conn.clone(),
-                        batcher_flusher.clone(),
                         storage.is_alive().await,
                         counters_cache_clone.clone(),
                         p.clone(),
                     )
                     .await;
-                    interval.tick().await;
                 }
             });
         }
 
         Ok(Self {
             cached_counters: counters_cache,
-            batcher_counter_updates: batcher,
-            redis_conn_manager,
             async_redis_storage,
-            partitioned,
         })
     }
-
-    fn is_partitioned(&self) -> bool {
-        self.partitioned.load(Ordering::Acquire)
-    }
-
-    fn partitioned(&self, partition: bool) -> bool {
-        flip_partitioned(&self.partitioned, partition)
-    }
-
-    fn fallback_vals_ttls(&self, counters: &Vec<&mut Counter>) -> (Vec<Option<i64>>, Vec<i64>) {
-        let mut vals = Vec::with_capacity(counters.len());
-        let mut ttls = Vec::with_capacity(counters.len());
-        for counter in counters {
-            vals.push(Some(0i64));
-            ttls.push(counter.limit().seconds() as i64 * 1000);
-        }
-        (vals, ttls)
-    }
-
-    async fn values_with_ttls(
-        &self,
-        counters: &[&mut Counter],
-    ) -> Result<(Vec<Option<i64>>, Vec<i64>), StorageErr> {
-        let mut redis_con = self.redis_conn_manager.clone();
-
-        let counter_keys: Vec<String> = counters
-            .iter()
-            .map(|counter| key_for_counter(counter))
-            .collect();
-
-        let script = redis::Script::new(VALUES_AND_TTLS);
-        let mut script_invocation = script.prepare_invoke();
-
-        for counter_key in counter_keys {
-            script_invocation.key(counter_key);
-        }
-
-        let script_res: Vec<Option<i64>> = script_invocation
-            .invoke_async::<_, _>(&mut redis_con)
-            .await?;
-
-        let mut counter_vals: Vec<Option<i64>> = vec![];
-        let mut counter_ttls_msecs: Vec<i64> = vec![];
-
-        for val_ttl_pair in script_res.chunks(2) {
-            counter_vals.push(val_ttl_pair[0]);
-            counter_ttls_msecs.push(val_ttl_pair[1].unwrap());
-        }
-
-        Ok((counter_vals, counter_ttls_msecs))
-    }
 }
 
 fn flip_partitioned(storage: &AtomicBool, partition: bool) -> bool {
@@ -398,15 +288,14 @@
 
 async fn update_counters<C: ConnectionLike>(
     redis_conn: &mut C,
-    counters_and_deltas: HashMap<Counter, AtomicExpiringValue>,
+    counters_and_deltas: HashMap<Counter, Arc<CachedCounterValue>>,
 ) -> Result<Vec<(Counter, i64, i64)>, StorageErr> {
     let redis_script = redis::Script::new(BATCH_UPDATE_COUNTERS);
     let mut script_invocation = redis_script.prepare_invoke();
 
     let mut res: Vec<(Counter, i64, i64)> = Vec::new();
-    let now = SystemTime::now();
     for (counter, delta) in counters_and_deltas {
-        let delta = delta.value_at(now);
+        let delta = delta.pending_writes().expect("State machine is wrong!");
         if delta > 0 {
             script_invocation.key(key_for_counter(&counter));
             script_invocation.key(key_for_counters_of_limit(counter.limit()));
@@ -439,25 +328,18 @@ async fn update_counters(
 
 async fn flush_batcher_and_update_counters<C: ConnectionLike>(
     mut redis_conn: C,
-    batcher: Arc<Mutex<HashMap<Counter, AtomicExpiringValue>>>,
     storage_is_alive: bool,
     cached_counters: Arc<CountersCache>,
     partitioned: Arc<AtomicBool>,
 ) {
     if partitioned.load(Ordering::Acquire) || !storage_is_alive {
-        let batch = batcher.lock().unwrap();
-        if !batch.is_empty() {
+        if !cached_counters.batcher().is_empty() {
             flip_partitioned(&partitioned, false);
         }
     } else {
-        let counters = {
-            let mut batch = batcher.lock().unwrap();
-            std::mem::take(&mut *batch)
-        };
-
-        let time_start_update_counters = Instant::now();
-
-        let updated_counters = update_counters(&mut redis_conn, counters)
+        let updated_counters = cached_counters
+            .batcher()
+            .consume(1, |counters| update_counters(&mut redis_conn, counters))
             .await
             .or_else(|err| {
                 if err.is_transient() {
@@ -469,6 +351,8 @@
             })
             .expect("Unrecoverable Redis error!");
 
+        let time_start_update_counters = Instant::now();
+
         for (counter, value, ttl) in updated_counters {
             cached_counters.insert(
                 counter,
@@ -487,16 +371,17 @@ mod tests {
     use crate::counter::Counter;
     use crate::limit::Limit;
-    use crate::storage::atomic_expiring_value::AtomicExpiringValue;
     use crate::storage::keys::{key_for_counter, key_for_counters_of_limit};
-    use crate::storage::redis::counters_cache::{CountersCache, CountersCacheBuilder};
+    use crate::storage::redis::counters_cache::{
+        CachedCounterValue, CountersCache, CountersCacheBuilder,
+    };
     use crate::storage::redis::redis_cached::{flush_batcher_and_update_counters, update_counters};
     use crate::storage::redis::CachedRedisStorage;
     use redis::{ErrorKind, Value};
     use redis_test::{MockCmd, MockRedisConnection};
     use std::collections::HashMap;
     use std::sync::atomic::AtomicBool;
-    use std::sync::{Arc, Mutex};
+    use std::sync::Arc;
     use std::time::{Duration, SystemTime};
 
     #[tokio::test]
@@ -529,10 +414,14 @@
             Default::default(),
         );
 
-        let expiring_value =
-            AtomicExpiringValue::new(1, SystemTime::now() + Duration::from_secs(60));
-
-        counters_and_deltas.insert(counter.clone(), expiring_value);
+        counters_and_deltas.insert(
+            counter.clone(),
+            Arc::new(CachedCounterValue::from_authority(
+                &counter,
+                1,
+                Duration::from_secs(60),
+            )),
+        );
 
         let mock_response = Value::Bulk(vec![Value::Int(10), Value::Int(60)]);
 
@@ -586,15 +475,15 @@
             Ok(mock_response.clone()),
         )]);
 
-        let mut batched_counters = HashMap::new();
-        batched_counters.insert(
+        let cache = CountersCacheBuilder::new().build(Duration::from_millis(1));
+        cache.batcher().add(
             counter.clone(),
-            AtomicExpiringValue::new(2, SystemTime::now() + Duration::from_secs(60)),
+            Arc::new(CachedCounterValue::from_authority(
+                &counter,
+                2,
+                Duration::from_secs(60),
+            )),
         );
-
-        let batcher: Arc<Mutex<HashMap<Counter, AtomicExpiringValue>>> =
-            Arc::new(Mutex::new(batched_counters));
-        let cache = CountersCacheBuilder::new().build();
         cache.insert(
             counter.clone(),
             Some(1),
@@ -606,17 +495,11 @@
         let partitioned = Arc::new(AtomicBool::new(false));
 
         if let Some(c) = cached_counters.get(&counter) {
-            assert_eq!(c.hits(&counter), 1);
+            assert_eq!(c.hits(&counter), 2);
         }
 
-        flush_batcher_and_update_counters(
-            mock_client,
-            batcher,
-            true,
-            cached_counters.clone(),
-            partitioned,
-        )
-        .await;
+        flush_batcher_and_update_counters(mock_client, true, cached_counters.clone(), partitioned)
+            .await;
 
         if let Some(c) = cached_counters.get(&counter) {
             assert_eq!(c.hits(&counter), 8);
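
Note on the accounting in pending_writes() above: initial_value remembers the last value the authority (Redis) has seen, and the CAS advances that marker so each accumulated delta is reported exactly once. Below is a minimal, self-contained model of just that scheme; the names (DeltaTracker, hit, flushed) are simplified stand-ins, and expiry plus window rollover are omitted, so this is an illustration, not the real CachedCounterValue.

use std::sync::atomic::{AtomicI64, Ordering};

struct DeltaTracker {
    value: AtomicI64,   // current locally observed counter value
    flushed: AtomicI64, // value as of the last successful flush
}

impl DeltaTracker {
    fn new(value: i64) -> Self {
        Self {
            value: AtomicI64::new(value),
            flushed: AtomicI64::new(value),
        }
    }

    // Local hit: only the in-memory value moves.
    fn hit(&self, delta: i64) {
        self.value.fetch_add(delta, Ordering::SeqCst);
    }

    // Report the writes accumulated since the last flush, exactly once:
    // the CAS moves the `flushed` marker forward, so a concurrent caller
    // racing on the same snapshot gets Err instead of double-counting.
    fn pending_writes(&self) -> Result<i64, ()> {
        let start = self.flushed.load(Ordering::SeqCst);
        let value = self.value.load(Ordering::SeqCst);
        match self
            .flushed
            .compare_exchange(start, value, Ordering::SeqCst, Ordering::SeqCst)
        {
            Ok(_) => Ok(value - start),
            Err(_) => Err(()),
        }
    }
}

fn main() {
    let tracker = DeltaTracker::new(10); // seeded from the authority
    tracker.hit(3);
    tracker.hit(2);
    assert_eq!(tracker.pending_writes(), Ok(5)); // 5 writes to push to Redis
    assert_eq!(tracker.pending_writes(), Ok(0)); // nothing new since then
}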