Skip to content

Commit

Permalink
Got rid of the namespace indirection
Browse files Browse the repository at this point in the history
Signed-off-by: Alex Snaps <[email protected]>
  • Loading branch information
alexsnaps committed Sep 12, 2024
1 parent 57abc95 commit 15f51ba
Showing 1 changed file with 39 additions and 75 deletions.
114 changes: 39 additions & 75 deletions limitador/src/storage/in_memory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,49 +9,42 @@ use std::ops::Deref;
use std::sync::{Arc, RwLock};
use std::time::{Duration, SystemTime};

type NamespacedLimitCounters<T> = HashMap<Namespace, HashMap<Limit, T>>;

pub struct InMemoryStorage {
limits_for_namespace: RwLock<NamespacedLimitCounters<AtomicExpiringValue>>,
simple_limits: RwLock<HashMap<Limit, AtomicExpiringValue>>,
qualified_counters: Cache<Counter, Arc<AtomicExpiringValue>>,
}

impl CounterStorage for InMemoryStorage {
#[tracing::instrument(skip_all)]
fn is_within_limits(&self, counter: &Counter, delta: u64) -> Result<bool, StorageErr> {
let limits_by_namespace = self.limits_for_namespace.read().unwrap();

let mut value = 0;

if counter.is_qualified() {
if let Some(counter) = self.qualified_counters.get(counter) {
value = counter.value();
}
} else if let Some(limits) = limits_by_namespace.get(counter.limit().namespace()) {
if let Some(counter) = limits.get(counter.limit()) {
value = counter.value();
}
}
let value = if counter.is_qualified() {
self.qualified_counters
.get(counter)
.map(|c| c.value())
.unwrap_or_default()
} else {
let limits_by_namespace = self.simple_limits.read().unwrap();
limits_by_namespace
.get(counter.limit())
.map(|c| c.value())
.unwrap_or_default()
};

Ok(counter.max_value() >= value + delta)
}

#[tracing::instrument(skip_all)]
fn add_counter(&self, limit: &Limit) -> Result<(), StorageErr> {
if limit.variables().is_empty() {
let mut limits_by_namespace = self.limits_for_namespace.write().unwrap();
limits_by_namespace
.entry(limit.namespace().clone())
.or_default()
.entry(limit.clone())
.or_default();
let mut limits_by_namespace = self.simple_limits.write().unwrap();
limits_by_namespace.entry(limit.clone()).or_default();
}
Ok(())
}

#[tracing::instrument(skip_all)]
fn update_counter(&self, counter: &Counter, delta: u64) -> Result<(), StorageErr> {
let mut limits_by_namespace = self.limits_for_namespace.write().unwrap();
let mut counters = self.simple_limits.write().unwrap();
let now = SystemTime::now();
if counter.is_qualified() {
let value = match self.qualified_counters.get(counter) {
Expand All @@ -62,23 +55,13 @@ impl CounterStorage for InMemoryStorage {
};
value.update(delta, counter.window(), now);
} else {
match limits_by_namespace.entry(counter.limit().namespace().clone()) {
match counters.entry(counter.limit().clone()) {
Entry::Vacant(v) => {
let mut limits = HashMap::new();
limits.insert(
counter.limit().clone(),
AtomicExpiringValue::new(delta, now + counter.window()),
);
v.insert(limits);
v.insert(AtomicExpiringValue::new(delta, now + counter.window()));
}
Entry::Occupied(o) => {
o.get().update(delta, counter.window(), now);
}
Entry::Occupied(mut o) => match o.get_mut().entry(counter.limit().clone()) {
Entry::Vacant(v) => {
v.insert(AtomicExpiringValue::new(delta, now + counter.window()));
}
Entry::Occupied(o) => {
o.get().update(delta, counter.window(), now);
}
},
}
}
Ok(())
Expand All @@ -91,7 +74,7 @@ impl CounterStorage for InMemoryStorage {
delta: u64,
load_counters: bool,
) -> Result<Authorization, StorageErr> {
let limits_by_namespace = self.limits_for_namespace.read().unwrap();
let limits_by_namespace = self.simple_limits.read().unwrap();
let mut first_limited = None;
let mut counter_values_to_update: Vec<(&AtomicExpiringValue, Duration)> = Vec::new();
let mut qualified_counter_values_to_updated: Vec<(Arc<AtomicExpiringValue>, Duration)> =
Expand Down Expand Up @@ -119,10 +102,8 @@ impl CounterStorage for InMemoryStorage {

// Process simple counters
for counter in counters.iter_mut().filter(|c| !c.is_qualified()) {
let atomic_expiring_value: &AtomicExpiringValue = limits_by_namespace
.get(counter.limit().namespace())
.and_then(|limits| limits.get(counter.limit()))
.unwrap();
let atomic_expiring_value: &AtomicExpiringValue =
limits_by_namespace.get(counter.limit()).unwrap();

if let Some(limited) = process_counter(counter, atomic_expiring_value.value(), delta) {
if !load_counters {
Expand All @@ -135,7 +116,7 @@ impl CounterStorage for InMemoryStorage {
// Process qualified counters
for counter in counters.iter_mut().filter(|c| c.is_qualified()) {
let value = match self.qualified_counters.get(counter) {
None => self.qualified_counters.get_with(counter.clone(), || {
None => self.qualified_counters.get_with_by_ref(counter, || {
Arc::new(AtomicExpiringValue::new(0, now + counter.window()))
}),
Some(counter) => counter,
Expand Down Expand Up @@ -171,24 +152,14 @@ impl CounterStorage for InMemoryStorage {
fn get_counters(&self, limits: &HashSet<Arc<Limit>>) -> Result<HashSet<Counter>, StorageErr> {
let mut res = HashSet::new();

let namespaces: HashSet<&Namespace> = limits.iter().map(|l| l.namespace()).collect();
let limits_by_namespace = self.limits_for_namespace.read().unwrap();

for namespace in namespaces {
if let Some(limits) = limits_by_namespace.get(namespace) {
for limit in limits.keys() {
if limits.contains_key(limit) {
for (counter, expiring_value) in self.counters_in_namespace(namespace) {
let mut counter_with_val = counter.clone();
counter_with_val.set_remaining(
counter_with_val.max_value() - expiring_value.value(),
);
counter_with_val.set_expires_in(expiring_value.ttl());
if counter_with_val.expires_in().unwrap() > Duration::ZERO {
res.insert(counter_with_val);
}
}
}
for limit in limits {
for (counter, expiring_value) in self.counters_in_namespace(limit.namespace()) {
let mut counter_with_val = counter.clone();
counter_with_val
.set_remaining(counter_with_val.max_value() - expiring_value.value());
counter_with_val.set_expires_in(expiring_value.ttl());
if counter_with_val.expires_in().unwrap() > Duration::ZERO {
res.insert(counter_with_val);
}
}
}
Expand Down Expand Up @@ -218,15 +189,15 @@ impl CounterStorage for InMemoryStorage {

#[tracing::instrument(skip_all)]
fn clear(&self) -> Result<(), StorageErr> {
self.limits_for_namespace.write().unwrap().clear();
self.simple_limits.write().unwrap().clear();
Ok(())
}
}

impl InMemoryStorage {
pub fn new(cache_size: u64) -> Self {
Self {
limits_for_namespace: RwLock::new(HashMap::new()),
simple_limits: RwLock::new(HashMap::new()),
qualified_counters: Cache::new(cache_size),
}
}
Expand All @@ -237,11 +208,11 @@ impl InMemoryStorage {
) -> HashMap<Counter, AtomicExpiringValue> {
let mut res: HashMap<Counter, AtomicExpiringValue> = HashMap::new();

if let Some(counters_by_limit) = self.limits_for_namespace.read().unwrap().get(namespace) {
for (limit, value) in counters_by_limit {
for (limit, counter) in self.simple_limits.read().unwrap().iter() {
if limit.namespace() == namespace {
res.insert(
Counter::new(limit.clone(), HashMap::default()),
value.clone(),
counter.clone(),
);
}
}
Expand All @@ -256,14 +227,7 @@ impl InMemoryStorage {
}

fn delete_counters_of_limit(&self, limit: &Limit) {
if let Some(counters_by_limit) = self
.limits_for_namespace
.write()
.unwrap()
.get_mut(limit.namespace())
{
counters_by_limit.remove(limit);
}
self.simple_limits.write().unwrap().remove(limit);
}

fn counter_is_within_limits(counter: &Counter, current_val: Option<&u64>, delta: u64) -> bool {
Expand Down

0 comments on commit 15f51ba

Please sign in to comment.