diff --git a/limitador/src/storage/redis/counters_cache.rs b/limitador/src/storage/redis/counters_cache.rs index 9eb17af7..9f10e0ca 100644 --- a/limitador/src/storage/redis/counters_cache.rs +++ b/limitador/src/storage/redis/counters_cache.rs @@ -63,8 +63,13 @@ impl CountersCache { redis_ttl_ms: i64, ttl_margin: Duration, ) { - let counter_val = Self::value_from_redis_val(redis_val, counter.max_value()); - let counter_ttl = self.ttl_from_redis_ttl(redis_ttl_ms, counter.seconds(), counter_val); + let counter_val = redis_val.unwrap_or(0); + let counter_ttl = self.ttl_from_redis_ttl( + redis_ttl_ms, + counter.seconds(), + counter_val, + counter.max_value(), + ); if let Some(ttl) = counter_ttl.checked_sub(ttl_margin) { if ttl > Duration::from_secs(0) { self.cache.insert(counter, counter_val, ttl); @@ -72,24 +77,18 @@ impl CountersCache { } } - pub fn decrease_by(&mut self, counter: &Counter, delta: i64) { + pub fn increase_by(&mut self, counter: &Counter, delta: i64) { if let Some(val) = self.cache.get_mut(counter) { - *val -= delta + *val += delta }; } - fn value_from_redis_val(redis_val: Option, counter_max: i64) -> i64 { - match redis_val { - Some(val) => val, - None => counter_max, - } - } - fn ttl_from_redis_ttl( &self, redis_ttl_ms: i64, counter_seconds: u64, counter_val: i64, + counter_max: i64, ) -> Duration { // Redis returns -2 when the key does not exist. Ref: // https://redis.io/commands/ttl @@ -102,11 +101,11 @@ impl CountersCache { Duration::from_secs(counter_seconds) }; - // If a counter is already at 0, we can cache it for as long as its TTL + // If a counter is already at counter_max, we can cache it for as long as its TTL // is in Redis. This does not depend on the requests received by other // instances of Limitador. No matter what they do, we know that the // counter is not going to recover its quota until it expires in Redis. - if counter_val <= 0 { + if counter_val >= counter_max { return counter_ttl; } @@ -205,7 +204,7 @@ mod tests { } #[test] - fn insert_saves_max_value_when_redis_val_is_none() { + fn insert_saves_0_when_redis_val_is_none() { let max_val = 10; let mut values = HashMap::new(); values.insert("app_id".to_string(), "1".to_string()); @@ -223,13 +222,13 @@ mod tests { let mut cache = CountersCacheBuilder::new().build(); cache.insert(counter.clone(), None, 10, Duration::from_secs(0)); - assert_eq!(cache.get(&counter).unwrap(), max_val); + assert_eq!(cache.get(&counter).unwrap(), 0); } #[test] - fn decrease_by() { + fn increase_by() { let current_val = 10; - let decrease_by = 8; + let increase_by = 8; let mut values = HashMap::new(); values.insert("app_id".to_string(), "1".to_string()); let counter = Counter::new( @@ -250,8 +249,8 @@ mod tests { 10, Duration::from_secs(0), ); - cache.decrease_by(&counter, decrease_by); + cache.increase_by(&counter, increase_by); - assert_eq!(cache.get(&counter).unwrap(), current_val - decrease_by); + assert_eq!(cache.get(&counter).unwrap(), current_val + increase_by); } } diff --git a/limitador/src/storage/redis/redis_cached.rs b/limitador/src/storage/redis/redis_cached.rs index d10e92e8..6741811e 100644 --- a/limitador/src/storage/redis/redis_cached.rs +++ b/limitador/src/storage/redis/redis_cached.rs @@ -83,7 +83,7 @@ impl AsyncCounterStorage for CachedRedisStorage { for counter in counters.iter_mut() { match cached_counters.get(counter) { Some(val) => { - if first_limited.is_none() && val - delta < 0 { + if first_limited.is_none() && val + delta > counter.max_value() { let a = Authorization::Limited( counter.limit().name().map(|n| n.to_owned()), ); @@ -93,7 +93,7 @@ impl AsyncCounterStorage for CachedRedisStorage { first_limited = Some(a); } if load_counters { - counter.set_remaining(val); + counter.set_remaining(counter.max_value() - val - delta); // todo: how do we get the ttl for this entry? // counter.set_expires_in(Duration::from_secs(counter.seconds())); } @@ -138,7 +138,7 @@ impl AsyncCounterStorage for CachedRedisStorage { counter_ttls_msecs[i], ttl_margin, ); - let remaining = counter_vals[i].unwrap_or(counter.max_value()) - delta; + let remaining = counter.max_value() - counter_vals[i].unwrap_or(0) - delta; if first_limited.is_none() && remaining < 0 { first_limited = Some(Authorization::Limited( counter.limit().name().map(|n| n.to_owned()), @@ -146,7 +146,13 @@ impl AsyncCounterStorage for CachedRedisStorage { } if load_counters { counter.set_remaining(remaining); - counter.set_expires_in(Duration::from_millis(counter_ttls_msecs[i] as u64)); + let counter_ttl = if counter_ttls_msecs[i] >= 0 { + Duration::from_millis(counter_ttls_msecs[i] as u64) + } else { + Duration::from_secs(counter.max_value() as u64) + }; + + counter.set_expires_in(counter_ttl); } } } @@ -160,7 +166,7 @@ impl AsyncCounterStorage for CachedRedisStorage { { let mut cached_counters = self.cached_counters.lock().unwrap(); for counter in counters.iter() { - cached_counters.decrease_by(counter, delta); + cached_counters.increase_by(counter, delta); } } diff --git a/limitador/tests/integration_tests.rs b/limitador/tests/integration_tests.rs index d300c46f..4807b832 100644 --- a/limitador/tests/integration_tests.rs +++ b/limitador/tests/integration_tests.rs @@ -24,7 +24,7 @@ macro_rules! test_with_all_storage_impls { #[cfg(feature = "redis_storage")] #[tokio::test] #[serial] - async fn [<$function _with_redis>]() { + async fn [<$function _with_sync_redis>]() { let storage = RedisStorage::default(); storage.clear().unwrap(); let rate_limiter = RateLimiter::new_with_storage( @@ -51,7 +51,23 @@ macro_rules! test_with_all_storage_impls { let rate_limiter = AsyncRateLimiter::new_with_storage( Box::new(storage) ); - AsyncRedisStorage::new("redis://127.0.0.1:6379").await.expect("We need a Redis running locally").clear().await.unwrap(); + $function(&mut TestsLimiter::new_from_async_impl(rate_limiter)).await; + } + + #[cfg(feature = "redis_storage")] + #[tokio::test] + #[serial] + async fn [<$function _with_async_redis_and_local_cache>]() { + let storage_builder = CachedRedisStorageBuilder::new("redis://127.0.0.1:6379"). + flushing_period(None). + max_ttl_cached_counters(Duration::from_secs(3600)). + ttl_ratio_cached_counters(1). + max_cached_counters(10000); + let storage = storage_builder.build().await.expect("We need a Redis running locally"); + storage.clear().await.unwrap(); + let rate_limiter = AsyncRateLimiter::new_with_storage( + Box::new(storage) + ); $function(&mut TestsLimiter::new_from_async_impl(rate_limiter)).await; } @@ -82,6 +98,7 @@ mod test { cfg_if::cfg_if! { if #[cfg(feature = "redis_storage")] { use limitador::storage::redis::AsyncRedisStorage; + use limitador::storage::redis::CachedRedisStorageBuilder; use limitador::storage::redis::RedisStorage; use limitador::AsyncRateLimiter;