diff --git a/.idea/limitador.iml b/.idea/limitador.iml index e6ebde49..455a1806 100644 --- a/.idea/limitador.iml +++ b/.idea/limitador.iml @@ -1,12 +1,14 @@ - + + + diff --git a/limitador/src/counter.rs b/limitador/src/counter.rs index a83485ac..c1bcf889 100644 --- a/limitador/src/counter.rs +++ b/limitador/src/counter.rs @@ -102,14 +102,10 @@ impl Hash for Counter { fn hash(&self, state: &mut H) { self.limit.hash(state); - let mut encoded_vars = self - .set_variables - .iter() - .map(|(k, v)| k.to_owned() + ":" + v) - .collect::>(); - - encoded_vars.sort(); - encoded_vars.hash(state); + self.set_variables.iter().for_each(|(k, v)| { + k.hash(state); + v.hash(state); + }); } } diff --git a/limitador/src/limit.rs b/limitador/src/limit.rs index db6e944a..d541007d 100644 --- a/limitador/src/limit.rs +++ b/limitador/src/limit.rs @@ -406,24 +406,8 @@ impl Hash for Limit { fn hash(&self, state: &mut H) { self.namespace.hash(state); self.seconds.hash(state); - - let mut encoded_conditions = self - .conditions - .iter() - .map(|c| c.clone().into()) - .collect::>(); - - encoded_conditions.sort(); - encoded_conditions.hash(state); - - let mut encoded_vars = self - .variables - .iter() - .map(|c| c.to_string()) - .collect::>(); - - encoded_vars.sort(); - encoded_vars.hash(state); + self.conditions.iter().for_each(|e| e.hash(state)); + self.variables.iter().for_each(|e| e.hash(state)); } } diff --git a/limitador/src/storage/disk/mod.rs b/limitador/src/storage/disk/mod.rs index 011f6e2a..beba2377 100644 --- a/limitador/src/storage/disk/mod.rs +++ b/limitador/src/storage/disk/mod.rs @@ -1,6 +1,5 @@ use crate::storage::StorageErr; -mod expiring_value; mod rocksdb_storage; pub use rocksdb_storage::RocksDbStorage as DiskStorage; diff --git a/limitador/src/storage/disk/rocksdb_storage.rs b/limitador/src/storage/disk/rocksdb_storage.rs index d7275529..9b03e141 100644 --- a/limitador/src/storage/disk/rocksdb_storage.rs +++ b/limitador/src/storage/disk/rocksdb_storage.rs @@ -1,7 +1,7 @@ use crate::counter::Counter; use crate::limit::Limit; -use crate::storage::disk::expiring_value::ExpiringValue; use crate::storage::disk::OptimizeFor; +use crate::storage::expiring_value::ExpiringValue; use crate::storage::keys::bin::{ key_for_counter, partial_counter_from_counter_key, prefix_for_namespace, }; diff --git a/limitador/src/storage/disk/expiring_value.rs b/limitador/src/storage/expiring_value.rs similarity index 92% rename from limitador/src/storage/disk/expiring_value.rs rename to limitador/src/storage/expiring_value.rs index cdfd5d5e..922a833e 100644 --- a/limitador/src/storage/disk/expiring_value.rs +++ b/limitador/src/storage/expiring_value.rs @@ -24,6 +24,7 @@ impl ExpiringValue { self.value_at(SystemTime::now()) } + #[must_use] pub fn update(self, delta: i64, ttl: u64, now: SystemTime) -> Self { let expiry = if self.expiry <= now { now + Duration::from_secs(ttl) @@ -35,6 +36,18 @@ impl ExpiringValue { Self { value, expiry } } + pub fn update_mut(&mut self, delta: i64, ttl: u64, now: SystemTime) { + let expiry = if self.expiry <= now { + now + Duration::from_secs(ttl) + } else { + self.expiry + }; + + self.value = self.value_at(now) + delta; + self.expiry = expiry; + } + + #[must_use] pub fn merge(self, other: ExpiringValue, now: SystemTime) -> Self { if self.expiry > now { ExpiringValue { diff --git a/limitador/src/storage/in_memory.rs b/limitador/src/storage/in_memory.rs index 86176a06..4f2230f4 100644 --- a/limitador/src/storage/in_memory.rs +++ b/limitador/src/storage/in_memory.rs @@ -1,31 +1,97 @@ use crate::counter::Counter; use crate::limit::{Limit, Namespace}; +use crate::storage::expiring_value::ExpiringValue; use crate::storage::{Authorization, CounterStorage, StorageErr}; use std::collections::hash_map::Entry; use std::collections::{HashMap, HashSet}; +use std::hash::{Hash, Hasher}; use std::sync::RwLock; -use std::time::Duration; -use ttl_cache::TtlCache; +use std::time::{Duration, SystemTime}; + +#[derive(Eq, Clone)] +struct CounterKey { + set_variables: HashMap, +} + +impl CounterKey { + fn to_counter(&self, limit: &Limit) -> Counter { + Counter::new(limit.clone(), self.set_variables.clone()) + } +} + +impl From<&Counter> for CounterKey { + fn from(counter: &Counter) -> Self { + CounterKey { + set_variables: counter.set_variables().clone(), + } + } +} + +impl From<&mut Counter> for CounterKey { + fn from(counter: &mut Counter) -> Self { + CounterKey { + set_variables: counter.set_variables().clone(), + } + } +} + +impl Hash for CounterKey { + fn hash(&self, state: &mut H) { + self.set_variables.iter().for_each(|(k, v)| { + k.hash(state); + v.hash(state); + }); + } +} + +impl PartialEq for CounterKey { + fn eq(&self, other: &Self) -> bool { + self.set_variables == other.set_variables + } +} + +type NamespacedLimitCounters = HashMap>>; pub struct InMemoryStorage { - limits_for_namespace: RwLock>>>, - counters: RwLock>, + limits_for_namespace: RwLock>, } impl CounterStorage for InMemoryStorage { fn is_within_limits(&self, counter: &Counter, delta: i64) -> Result { - let stored_counters = self.counters.read().unwrap(); + let limits_by_namespace = self.limits_for_namespace.read().unwrap(); - Ok(Self::counter_is_within_limits( - counter, - stored_counters.get(counter), - delta, - )) + let mut value = 0; + if let Some(limits) = limits_by_namespace.get(counter.limit().namespace()) { + if let Some(counters) = limits.get(counter.limit()) { + if let Some(expiring_value) = counters.get(&counter.into()) { + value = expiring_value.value(); + } + } + } + Ok(counter.max_value() >= value + delta) } fn update_counter(&self, counter: &Counter, delta: i64) -> Result<(), StorageErr> { - let mut counters = self.counters.write().unwrap(); - self.insert_or_update_counter(&mut counters, counter, delta); + let mut limits_by_namespace = self.limits_for_namespace.write().unwrap(); + match limits_by_namespace.entry(counter.limit().namespace().clone()) { + Entry::Vacant(v) => { + let mut limits = HashMap::new(); + let mut counters = HashMap::new(); + self.insert_or_update_counter(&mut counters, counter, delta); + limits.insert(counter.limit().clone(), counters); + v.insert(limits); + } + Entry::Occupied(mut o) => match o.get_mut().entry(counter.limit().clone()) { + Entry::Vacant(v) => { + let mut counters = HashMap::new(); + self.insert_or_update_counter(&mut counters, counter, delta); + v.insert(counters); + } + Entry::Occupied(mut o) => { + self.insert_or_update_counter(o.get_mut(), counter, delta); + } + }, + } Ok(()) } @@ -35,36 +101,80 @@ impl CounterStorage for InMemoryStorage { delta: i64, load_counters: bool, ) -> Result { - // This makes the operator of check + update atomic - let mut stored_counters = self.counters.write().unwrap(); - - if load_counters { - let mut first_limited = None; - for counter in counters.iter_mut() { - let remaining = - *stored_counters.get(counter).unwrap_or(&counter.max_value()) - delta; - counter.set_remaining(remaining); - if first_limited.is_none() && remaining < 0 { - first_limited = Some(Authorization::Limited( - counter.limit().name().map(|n| n.to_owned()), - )) + let mut limits_by_namespace = self.limits_for_namespace.write().unwrap(); + let mut first_limited = None; + + let mut process_counter = + |counter: &mut Counter, value: i64, delta: i64| -> Option { + if load_counters { + let remaining = counter.max_value() - (value + delta); + counter.set_remaining(remaining); + if first_limited.is_none() && remaining < 0 { + first_limited = Some(Authorization::Limited( + counter.limit().name().map(|n| n.to_owned()), + )); + } } - } - if let Some(l) = first_limited { - return Ok(l); - } - } else { - for counter in counters.iter() { - if !Self::counter_is_within_limits(counter, stored_counters.get(counter), delta) { - return Ok(Authorization::Limited( + if !Self::counter_is_within_limits(counter, Some(&value), delta) { + return Some(Authorization::Limited( counter.limit().name().map(|n| n.to_owned()), )); } + None + }; + + for counter in counters.iter_mut() { + if counter.max_value() < delta { + if let Some(limited) = process_counter(counter, 0, delta) { + if !load_counters { + return Ok(limited); + } + } + continue; + } + if let Some(limits) = limits_by_namespace.get(counter.limit().namespace()) { + if let Some(counters) = limits.get(counter.limit()) { + if let Some(expiring_value) = counters.get(&counter.into()) { + let value = expiring_value.value(); + if let Some(Authorization::Limited(counter_limited)) = + process_counter(counter, value, delta) + { + if !load_counters { + return Ok(Authorization::Limited(counter_limited)); + } + } + } + } + } else if let Some(limited) = process_counter(counter, 0, delta) { + if !load_counters { + return Ok(limited); + } } } - for counter in counters { - self.insert_or_update_counter(&mut stored_counters, counter, delta) + if let Some(limited) = first_limited { + return Ok(limited); + } + + for counter in counters.iter_mut() { + let now = SystemTime::now(); + match limits_by_namespace + .entry(counter.limit().namespace().clone()) + .or_insert_with(HashMap::new) + .entry(counter.limit().clone()) + .or_insert_with(HashMap::new) + .entry(counter.into()) + { + Entry::Vacant(v) => { + v.insert(ExpiringValue::new( + delta, + now + Duration::from_secs(counter.seconds()), + )); + } + Entry::Occupied(mut o) => { + o.get_mut().update_mut(delta, counter.seconds(), now); + } + } } Ok(Authorization::Ok) @@ -74,18 +184,26 @@ impl CounterStorage for InMemoryStorage { let mut res = HashSet::new(); let namespaces: HashSet<&Namespace> = limits.iter().map(Limit::namespace).collect(); + let limits_by_namespace = self.limits_for_namespace.read().unwrap(); for namespace in namespaces { - for counter in self.counters_in_namespace(namespace) { - if let Some(counter_val) = self.counters.read().unwrap().get(&counter) { - // TODO: return correct TTL - let mut counter_with_val = counter.clone(); - counter_with_val.set_remaining(*counter_val); - res.insert(counter_with_val); + if let Some(limits) = limits_by_namespace.get(namespace) { + for limit in limits.keys() { + if limits.contains_key(limit) { + for (counter, expiring_value) in self.counters_in_namespace(namespace) { + let mut counter_with_val = counter.clone(); + counter_with_val.set_remaining( + counter_with_val.max_value() - expiring_value.value(), + ); + counter_with_val.set_expires_in(expiring_value.ttl()); + if counter_with_val.expires_in().unwrap() > Duration::ZERO { + res.insert(counter_with_val); + } + } + } } } } - Ok(res) } @@ -97,26 +215,26 @@ impl CounterStorage for InMemoryStorage { } fn clear(&self) -> Result<(), StorageErr> { - self.counters.write().unwrap().clear(); self.limits_for_namespace.write().unwrap().clear(); Ok(()) } } impl InMemoryStorage { - pub fn new(capacity: usize) -> Self { + pub fn new() -> Self { Self { limits_for_namespace: RwLock::new(HashMap::new()), - counters: RwLock::new(TtlCache::new(capacity)), } } - fn counters_in_namespace(&self, namespace: &Namespace) -> HashSet { - let mut res: HashSet = HashSet::new(); + fn counters_in_namespace(&self, namespace: &Namespace) -> HashMap { + let mut res: HashMap = HashMap::new(); if let Some(counters_by_limit) = self.limits_for_namespace.read().unwrap().get(namespace) { - for counter in counters_by_limit.values().flatten() { - res.insert(counter.clone()); + for (limit, values) in counters_by_limit { + for (counter_key, expiring_value) in values { + res.insert(counter_key.to_counter(limit), expiring_value.clone()); + } } } @@ -125,78 +243,46 @@ impl InMemoryStorage { fn delete_counters_of_limit(&self, limit: &Limit) { if let Some(counters_by_limit) = self - .limits_for_namespace - .read() - .unwrap() - .get(limit.namespace()) - { - if let Some(counters_of_limit) = counters_by_limit.get(limit) { - let mut counters = self.counters.write().unwrap(); - for counter in counters_of_limit { - counters.remove(counter); - } - } - } - } - - fn add_counter_limit_association(&self, counter: &Counter) { - let namespace = counter.limit().namespace(); - - match self .limits_for_namespace .write() .unwrap() - .entry(namespace.clone()) + .get_mut(limit.namespace()) { - Entry::Occupied(mut e) => { - e.get_mut() - .entry(counter.limit().clone()) - .or_default() - .insert(counter.clone()); - } - Entry::Vacant(e) => { - let mut counters = HashSet::new(); - counters.insert(counter.clone()); - let mut map = HashMap::new(); - map.insert(counter.limit().clone(), counters); - e.insert(map); - } + counters_by_limit.remove(limit); } } fn insert_or_update_counter( &self, - counters: &mut TtlCache, + counters: &mut HashMap, counter: &Counter, delta: i64, ) { - match counters.get_mut(counter) { - Some(value) => { - *value -= delta; + let now = SystemTime::now(); + match counters.entry(counter.into()) { + Entry::Vacant(v) => { + v.insert(ExpiringValue::new( + delta, + now + Duration::from_secs(counter.seconds()), + )); } - None => { - counters.insert( - counter.clone(), - counter.max_value() - delta, - Duration::from_secs(counter.seconds()), - ); - - self.add_counter_limit_association(counter); + Entry::Occupied(mut o) => { + o.get_mut().update_mut(delta, counter.seconds(), now); } } } fn counter_is_within_limits(counter: &Counter, current_val: Option<&i64>, delta: i64) -> bool { match current_val { - Some(current_val) => current_val - delta >= 0, - None => counter.max_value() - delta >= 0, + Some(current_val) => current_val + delta <= counter.max_value(), + None => counter.max_value() >= delta, } } } impl Default for InMemoryStorage { fn default() -> Self { - Self::new(1000) + Self::new() } } @@ -206,7 +292,7 @@ mod tests { #[test] fn counters_for_multiple_limit_per_ns() { - let storage = InMemoryStorage::new(100); + let storage = InMemoryStorage::new(); let namespace = "test_namespace"; let limit_1 = Limit::new(namespace, 1, 1, vec!["req.method == 'GET'"], vec!["app_id"]); let limit_2 = Limit::new( diff --git a/limitador/src/storage/mod.rs b/limitador/src/storage/mod.rs index 11a7cc33..54f2bf3d 100644 --- a/limitador/src/storage/mod.rs +++ b/limitador/src/storage/mod.rs @@ -17,6 +17,7 @@ pub mod redis; #[cfg(feature = "infinispan_storage")] pub mod infinispan; +mod expiring_value; mod keys; pub enum Authorization { diff --git a/limitador/tests/integration_tests.rs b/limitador/tests/integration_tests.rs index bfe53e1e..e6218279 100644 --- a/limitador/tests/integration_tests.rs +++ b/limitador/tests/integration_tests.rs @@ -540,10 +540,13 @@ mod test { // iteration. It should not affect. values.insert("does_not_apply".to_string(), i.to_string()); - assert!(!rate_limiter - .is_rate_limited(namespace, &values, 1) - .await - .unwrap()); + assert!( + !rate_limiter + .is_rate_limited(namespace, &values, 1) + .await + .unwrap(), + "Must not be limited after {i}" + ); rate_limiter .update_counters(namespace, &values, 1) .await