Skip to content

Commit

Permalink
cache.rs: implement auto-clearing behaviour
Browse files Browse the repository at this point in the history
Every 255 cache operations, clear out any expired sessions.
  • Loading branch information
ctz committed Apr 26, 2024
1 parent d17c827 commit 765207b
Showing 1 changed file with 69 additions and 0 deletions.
69 changes: 69 additions & 0 deletions rustls-libssl/src/cache.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use core::ptr;
use std::collections::BTreeSet;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Mutex};
use std::time::SystemTime;

Expand Down Expand Up @@ -114,13 +115,15 @@ impl Default for SessionCaches {
pub struct ServerSessionStorage {
items: Mutex<BTreeSet<Arc<SslSession>>>,
parameters: Mutex<CacheParameters>,
op_count: AtomicUsize,
}

impl ServerSessionStorage {
fn new(max_size: usize) -> Self {
Self {
items: Mutex::new(BTreeSet::new()),
parameters: Mutex::new(CacheParameters::new(max_size)),
op_count: AtomicUsize::new(0),
}
}

Expand Down Expand Up @@ -230,6 +233,8 @@ impl ServerSessionStorage {
}

fn insert(&self, new: Arc<SslSession>) -> bool {
self.tick();

if let Ok(mut items) = self.items.lock() {
items.insert(new)
} else {
Expand All @@ -238,6 +243,8 @@ impl ServerSessionStorage {
}

fn take(&self, id: &[u8]) -> Option<Arc<SslSession>> {
self.tick();

if let Ok(mut items) = self.items.lock() {
items.take(&SslSessionLookup::for_id(id))
} else {
Expand All @@ -246,6 +253,8 @@ impl ServerSessionStorage {
}

fn find_by_id(&self, id: &[u8]) -> Option<Arc<SslSession>> {
self.tick();

if let Ok(items) = self.items.lock() {
items.get(&SslSessionLookup::for_id(id)).cloned()
} else {
Expand All @@ -271,6 +280,41 @@ impl ServerSessionStorage {
}
}
}

fn flush_expired(&self, at_time: TimeBase) {
if let Ok(mut items) = self.items.lock() {
let callbacks = self.callbacks();
if let Some(callback) = callbacks.remove_callback {
// if we have a callback to invoke, do it the slow way
let mut removal_list = BTreeSet::new();
for item in items.iter() {
if item.expired(at_time) {
removal_list.insert(item.clone());
}
}

while let Some(sess) = removal_list.pop_first() {
items.remove(&sess);
callbacks::invoke_session_remove_callback(
Some(callback),
callbacks.ssl_ctx,
sess,
);
}
} else {
items.retain(|item| !item.expired(at_time));
}
}
}

fn tick(&self) {
// Called every cache operation. Every 255 operations, expire
// sessions (unless application opts out with CACHE_MODE_NO_AUTO_CLEAR).
let op_count = self.op_count.fetch_add(1, Ordering::SeqCst);
if self.mode() & CACHE_MODE_NO_AUTO_CLEAR == 0 && op_count & 0xff == 0xff {
self.flush_expired(TimeBase::now());
}
}
}

#[derive(Debug)]
Expand Down Expand Up @@ -465,3 +509,28 @@ impl TimeBase {
)
}
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn flush_expired() {
let cache = ServerSessionStorage::new(10);

for i in 1..=5 {
assert!(cache.insert(
SslSession::new(vec![i], vec![], vec![], ExpiryTime(10 + i as u64)).into()
));
}

// expires items 1, 2
cache.flush_expired(TimeBase(10 + 3));

assert!(cache.find_by_id(&[1]).is_none());
assert!(cache.find_by_id(&[2]).is_none());
assert!(cache.find_by_id(&[3]).is_some());
assert!(cache.find_by_id(&[4]).is_some());
assert!(cache.find_by_id(&[5]).is_some());
}
}

0 comments on commit 765207b

Please sign in to comment.