Skip to content

Commit

Permalink
Added Batcher type back
Browse files Browse the repository at this point in the history
  • Loading branch information
alexsnaps committed Apr 25, 2024
1 parent 5399c32 commit d4eed7b
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 45 deletions.
54 changes: 41 additions & 13 deletions limitador/src/storage/redis/counters_cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use crate::storage::redis::{
use moka::sync::Cache;
use std::collections::HashMap;
use std::sync::atomic::{AtomicI64, Ordering};
use std::sync::{Arc, MutexGuard};
use std::sync::{Arc, Mutex};
use std::time::{Duration, SystemTime};

pub struct CachedCounterValue {
Expand All @@ -16,10 +16,42 @@ pub struct CachedCounterValue {
expiry: AtomicExpiryTime,
}

pub struct Batcher {
updates: Mutex<HashMap<Counter, Arc<CachedCounterValue>>>,
}

impl Batcher {
fn new() -> Self {
Self {
updates: Mutex::new(Default::default()),
}
}

pub fn is_empty(&self) -> bool {
self.updates.lock().unwrap().is_empty()
}

pub fn consume_all(&self) -> HashMap<Counter, Arc<CachedCounterValue>> {
let mut batch = self.updates.lock().unwrap();
std::mem::take(&mut *batch)
}

pub fn add(&self, counter: Counter, value: Arc<CachedCounterValue>) {
self.updates.lock().unwrap().entry(counter).or_insert(value);
}
}

impl Default for Batcher {
fn default() -> Self {
Self::new()
}
}

pub struct CountersCache {
max_ttl_cached_counters: Duration,
pub ttl_ratio_cached_counters: u64,
cache: Cache<Counter, Arc<CachedCounterValue>>,
batcher: Batcher,
}

impl CachedCounterValue {
Expand Down Expand Up @@ -137,6 +169,7 @@ impl CountersCacheBuilder {
max_ttl_cached_counters: self.max_ttl_cached_counters,
ttl_ratio_cached_counters: self.ttl_ratio_cached_counters,
cache: Cache::new(self.max_cached_counters as u64),
batcher: Default::default(),
}
}
}
Expand All @@ -146,6 +179,10 @@ impl CountersCache {
self.cache.get(counter)
}

pub fn batcher(&self) -> &Batcher {
&self.batcher
}

pub fn insert(
&self,
counter: Counter,
Expand Down Expand Up @@ -179,12 +216,7 @@ impl CountersCache {
))
}

pub fn increase_by(
&self,
counter: &Counter,
delta: i64,
batcher: Option<&mut MutexGuard<HashMap<Counter, Arc<CachedCounterValue>>>>,
) {
pub fn increase_by(&self, counter: &Counter, delta: i64) {
let val = self.cache.get_with_by_ref(counter, || {
Arc::new(
// this TTL is wrong, it needs to be the cache's TTL, not the time window of our limit
Expand All @@ -193,11 +225,7 @@ impl CountersCache {
)
});
val.delta(counter, delta);
if let Some(batcher) = batcher {
if batcher.get_mut(counter).is_none() {
batcher.insert(counter.clone(), val.clone());
}
}
self.batcher.add(counter.clone(), val.clone());
}

fn ttl_from_redis_ttl(
Expand Down Expand Up @@ -383,7 +411,7 @@ mod tests {
Duration::from_secs(0),
SystemTime::now(),
);
cache.increase_by(&counter, increase_by, None);
cache.increase_by(&counter, increase_by);

assert_eq!(
cache.get(&counter).map(|e| e.hits(&counter)).unwrap(),
Expand Down
41 changes: 9 additions & 32 deletions limitador/src/storage/redis/redis_cached.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ use redis::{ConnectionInfo, RedisError};
use std::collections::{HashMap, HashSet};
use std::str::FromStr;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Mutex};
use std::sync::Arc;
use std::time::{Duration, Instant, SystemTime};
use tracing::{debug_span, error, warn, Instrument};

Expand All @@ -41,7 +41,6 @@ use tracing::{debug_span, error, warn, Instrument};

pub struct CachedRedisStorage {
cached_counters: Arc<CountersCache>,
batcher_counter_updates: Arc<Mutex<HashMap<Counter, Arc<CachedCounterValue>>>>,
async_redis_storage: AsyncRedisStorage,
redis_conn_manager: ConnectionManager,
partitioned: Arc<AtomicBool>,
Expand Down Expand Up @@ -151,10 +150,8 @@ impl AsyncCounterStorage for CachedRedisStorage {
}

// Update cached values
let mut batcher = self.batcher_counter_updates.lock().unwrap();
for counter in counters.iter() {
self.cached_counters
.increase_by(counter, delta, Some(&mut batcher));
self.cached_counters.increase_by(counter, delta);
}

Ok(Authorization::Ok)
Expand Down Expand Up @@ -220,21 +217,17 @@ impl CachedRedisStorage {
let partitioned = Arc::new(AtomicBool::new(false));
let async_redis_storage =
AsyncRedisStorage::new_with_conn_manager(redis_conn_manager.clone());
let batcher: Arc<Mutex<HashMap<Counter, Arc<CachedCounterValue>>>> =
Arc::new(Mutex::new(Default::default()));

{
let storage = async_redis_storage.clone();
let counters_cache_clone = counters_cache.clone();
let conn = redis_conn_manager.clone();
let p = Arc::clone(&partitioned);
let batcher_flusher = batcher.clone();
let mut interval = tokio::time::interval(flushing_period);
tokio::spawn(async move {
loop {
flush_batcher_and_update_counters(
conn.clone(),
batcher_flusher.clone(),
storage.is_alive().await,
counters_cache_clone.clone(),
p.clone(),
Expand All @@ -247,7 +240,6 @@ impl CachedRedisStorage {

Ok(Self {
cached_counters: counters_cache,
batcher_counter_updates: batcher,
redis_conn_manager,
async_redis_storage,
partitioned,
Expand Down Expand Up @@ -421,21 +413,16 @@ async fn update_counters<C: ConnectionLike>(

async fn flush_batcher_and_update_counters<C: ConnectionLike>(
mut redis_conn: C,
batcher: Arc<Mutex<HashMap<Counter, Arc<CachedCounterValue>>>>,
storage_is_alive: bool,
cached_counters: Arc<CountersCache>,
partitioned: Arc<AtomicBool>,
) {
if partitioned.load(Ordering::Acquire) || !storage_is_alive {
let batch = batcher.lock().unwrap();
if !batch.is_empty() {
if !cached_counters.batcher().is_empty() {
flip_partitioned(&partitioned, false);
}
} else {
let counters = {
let mut batch = batcher.lock().unwrap();
std::mem::take(&mut *batch)
};
let counters = cached_counters.batcher().consume_all();

let time_start_update_counters = Instant::now();

Expand Down Expand Up @@ -479,7 +466,7 @@ mod tests {
use redis_test::{MockCmd, MockRedisConnection};
use std::collections::HashMap;
use std::sync::atomic::AtomicBool;
use std::sync::{Arc, Mutex};
use std::sync::Arc;
use std::time::{Duration, SystemTime};

#[tokio::test]
Expand Down Expand Up @@ -573,19 +560,15 @@ mod tests {
Ok(mock_response.clone()),
)]);

let mut batched_counters = HashMap::new();
batched_counters.insert(
let cache = CountersCacheBuilder::new().build();
cache.batcher().add(
counter.clone(),
Arc::new(CachedCounterValue::from(
&counter,
2,
Duration::from_secs(60),
)),
);

let batcher: Arc<Mutex<HashMap<Counter, Arc<CachedCounterValue>>>> =
Arc::new(Mutex::new(batched_counters));
let cache = CountersCacheBuilder::new().build();
cache.insert(
counter.clone(),
Some(1),
Expand All @@ -600,14 +583,8 @@ mod tests {
assert_eq!(c.hits(&counter), 1);
}

flush_batcher_and_update_counters(
mock_client,
batcher,
true,
cached_counters.clone(),
partitioned,
)
.await;
flush_batcher_and_update_counters(mock_client, true, cached_counters.clone(), partitioned)
.await;

if let Some(c) = cached_counters.get(&counter) {
assert_eq!(c.hits(&counter), 8);
Expand Down

0 comments on commit d4eed7b

Please sign in to comment.