Skip to content

Commit

Permalink
Merge pull request #304 from chirino/fix-295
Browse files Browse the repository at this point in the history
fixes #295: Use a semaphore to protect the Batcher from unbounded memory growth.
  • Loading branch information
chirino authored May 15, 2024
2 parents a5415b6 + ad6a8f5 commit 37bb70b
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 31 deletions.
64 changes: 38 additions & 26 deletions limitador/src/storage/redis/counters_cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use std::sync::atomic::{AtomicBool, AtomicI64, Ordering};
use std::sync::Arc;
use std::time::{Duration, SystemTime};
use tokio::select;
use tokio::sync::Notify;
use tokio::sync::{Notify, Semaphore};

#[derive(Debug)]
pub struct CachedCounterValue {
Expand Down Expand Up @@ -129,19 +129,21 @@ pub struct Batcher {
notifier: Notify,
interval: Duration,
priority_flush: AtomicBool,
limiter: Semaphore,
}

impl Batcher {
fn new(period: Duration) -> Self {
fn new(period: Duration, max_cached_counters: usize) -> Self {
Self {
updates: Default::default(),
notifier: Default::default(),
interval: period,
priority_flush: AtomicBool::new(false),
limiter: Semaphore::new(max_cached_counters),
}
}

pub fn add(&self, counter: Counter, value: Arc<CachedCounterValue>) {
pub async fn add(&self, counter: Counter, value: Arc<CachedCounterValue>) {
let priority = value.requires_fast_flush(&self.interval);
match self.updates.entry(counter.clone()) {
Entry::Occupied(needs_merge) => {
Expand All @@ -151,6 +153,7 @@ impl Batcher {
}
}
Entry::Vacant(miss) => {
self.limiter.acquire().await.unwrap().forget();
miss.insert_entry(value);
}
};
Expand Down Expand Up @@ -186,8 +189,12 @@ impl Batcher {
}
let result = consumer(result).await;
batch.iter().for_each(|counter| {
self.updates
let prev = self
.updates
.remove_if(counter, |_, v| v.no_pending_writes());
if prev.is_some() {
self.limiter.add_permits(1);
}
});
return result;
} else {
Expand All @@ -214,7 +221,7 @@ impl Batcher {

impl Default for Batcher {
fn default() -> Self {
Self::new(Duration::from_millis(100))
Self::new(Duration::from_millis(100), DEFAULT_MAX_CACHED_COUNTERS)
}
}

Expand Down Expand Up @@ -272,7 +279,7 @@ impl CountersCache {
))
}

pub fn increase_by(&self, counter: &Counter, delta: i64) {
pub async fn increase_by(&self, counter: &Counter, delta: i64) {
let val = self.cache.get_with_by_ref(counter, || {
if let Some(entry) = self.batcher.updates.get(counter) {
entry.value().clone()
Expand All @@ -281,7 +288,7 @@ impl CountersCache {
}
});
val.delta(counter, delta);
self.batcher.add(counter.clone(), val.clone());
self.batcher.add(counter.clone(), val.clone()).await;
}
}

Expand All @@ -304,25 +311,28 @@ impl CountersCacheBuilder {
pub fn build(&self, period: Duration) -> CountersCache {
CountersCache {
cache: Cache::new(self.max_cached_counters as u64),
batcher: Batcher::new(period),
batcher: Batcher::new(period, self.max_cached_counters),
}
}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::limit::Limit;
use std::collections::HashMap;
use std::ops::Add;
use std::time::UNIX_EPOCH;

use crate::limit::Limit;

use super::*;

mod cached_counter_value {
use crate::storage::redis::counters_cache::tests::test_counter;
use crate::storage::redis::counters_cache::CachedCounterValue;
use std::ops::{Add, Not};
use std::time::{Duration, SystemTime};

use crate::storage::redis::counters_cache::tests::test_counter;
use crate::storage::redis::counters_cache::CachedCounterValue;

#[test]
fn records_pending_writes() {
let counter = test_counter(10, None);
Expand Down Expand Up @@ -401,15 +411,17 @@ mod tests {
}

mod batcher {
use crate::storage::redis::counters_cache::tests::test_counter;
use crate::storage::redis::counters_cache::{Batcher, CachedCounterValue};
use std::sync::Arc;
use std::time::{Duration, SystemTime};

use crate::storage::redis::counters_cache::tests::test_counter;
use crate::storage::redis::counters_cache::{Batcher, CachedCounterValue};
use crate::storage::redis::DEFAULT_MAX_CACHED_COUNTERS;

#[tokio::test]
async fn consume_waits_when_empty() {
let duration = Duration::from_millis(100);
let batcher = Batcher::new(duration);
let batcher = Batcher::new(duration, DEFAULT_MAX_CACHED_COUNTERS);
let start = SystemTime::now();
batcher
.consume(2, |items| {
Expand All @@ -423,15 +435,15 @@ mod tests {
#[tokio::test]
async fn consume_waits_when_batch_not_filled() {
let duration = Duration::from_millis(100);
let batcher = Arc::new(Batcher::new(duration));
let batcher = Arc::new(Batcher::new(duration, DEFAULT_MAX_CACHED_COUNTERS));
let start = SystemTime::now();
{
let batcher = Arc::clone(&batcher);
tokio::spawn(async move {
tokio::time::sleep(Duration::from_millis(40)).await;
let counter = test_counter(6, None);
let arc = Arc::new(CachedCounterValue::from_authority(&counter, 0));
batcher.add(counter, arc);
batcher.add(counter, arc).await;
});
}
batcher
Expand All @@ -449,15 +461,15 @@ mod tests {
#[tokio::test]
async fn consume_waits_until_batch_is_filled() {
let duration = Duration::from_millis(100);
let batcher = Arc::new(Batcher::new(duration));
let batcher = Arc::new(Batcher::new(duration, DEFAULT_MAX_CACHED_COUNTERS));
let start = SystemTime::now();
{
let batcher = Arc::clone(&batcher);
tokio::spawn(async move {
tokio::time::sleep(Duration::from_millis(40)).await;
let counter = test_counter(6, None);
let arc = Arc::new(CachedCounterValue::from_authority(&counter, 0));
batcher.add(counter, arc);
batcher.add(counter, arc).await;
});
}
batcher
Expand All @@ -474,12 +486,12 @@ mod tests {
#[tokio::test]
async fn consume_immediately_when_batch_is_filled() {
let duration = Duration::from_millis(100);
let batcher = Arc::new(Batcher::new(duration));
let batcher = Arc::new(Batcher::new(duration, DEFAULT_MAX_CACHED_COUNTERS));
let start = SystemTime::now();
{
let counter = test_counter(6, None);
let arc = Arc::new(CachedCounterValue::from_authority(&counter, 0));
batcher.add(counter, arc);
batcher.add(counter, arc).await;
}
batcher
.consume(1, |items| {
Expand All @@ -495,15 +507,15 @@ mod tests {
#[tokio::test]
async fn consume_triggers_on_fast_flush() {
let duration = Duration::from_millis(100);
let batcher = Arc::new(Batcher::new(duration));
let batcher = Arc::new(Batcher::new(duration, DEFAULT_MAX_CACHED_COUNTERS));
let start = SystemTime::now();
{
let batcher = Arc::clone(&batcher);
tokio::spawn(async move {
tokio::time::sleep(Duration::from_millis(40)).await;
let counter = test_counter(6, None);
let arc = Arc::new(CachedCounterValue::load_from_authority_asap(&counter, 0));
batcher.add(counter, arc);
batcher.add(counter, arc).await;
});
}
batcher
Expand Down Expand Up @@ -570,8 +582,8 @@ mod tests {
);
}

#[test]
fn increase_by() {
#[tokio::test]
async fn increase_by() {
let current_val = 10;
let increase_by = 8;
let counter = test_counter(current_val, None);
Expand All @@ -587,7 +599,7 @@ mod tests {
.unwrap()
.as_micros() as i64,
);
cache.increase_by(&counter, increase_by);
cache.increase_by(&counter, increase_by).await;

assert_eq!(
cache.get(&counter).map(|e| e.hits(&counter)).unwrap(),
Expand Down
13 changes: 8 additions & 5 deletions limitador/src/storage/redis/redis_cached.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ impl AsyncCounterStorage for CachedRedisStorage {

// Update cached values
for counter in counters.iter() {
self.cached_counters.increase_by(counter, delta);
self.cached_counters.increase_by(counter, delta).await;
}

Ok(Authorization::Ok)
Expand Down Expand Up @@ -489,10 +489,13 @@ mod tests {
)]);

let cache = CountersCacheBuilder::new().build(Duration::from_millis(10));
cache.batcher().add(
counter.clone(),
Arc::new(CachedCounterValue::load_from_authority_asap(&counter, 2)),
);
cache
.batcher()
.add(
counter.clone(),
Arc::new(CachedCounterValue::load_from_authority_asap(&counter, 2)),
)
.await;

let cached_counters: Arc<CountersCache> = Arc::new(cache);
let partitioned = Arc::new(AtomicBool::new(false));
Expand Down

0 comments on commit 37bb70b

Please sign in to comment.