From 765207bdea7e7fed2a33a091ffb9a31cb3de8ad3 Mon Sep 17 00:00:00 2001 From: Joseph Birr-Pixton Date: Fri, 26 Apr 2024 16:15:55 +0100 Subject: [PATCH] cache.rs: implement auto-clearing behaviour Every 255 cache operations, clear out any expired sessions. --- rustls-libssl/src/cache.rs | 69 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 69 insertions(+) diff --git a/rustls-libssl/src/cache.rs b/rustls-libssl/src/cache.rs index 419f4af..e40aaac 100644 --- a/rustls-libssl/src/cache.rs +++ b/rustls-libssl/src/cache.rs @@ -1,5 +1,6 @@ use core::ptr; use std::collections::BTreeSet; +use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::{Arc, Mutex}; use std::time::SystemTime; @@ -114,6 +115,7 @@ impl Default for SessionCaches { pub struct ServerSessionStorage { items: Mutex>>, parameters: Mutex, + op_count: AtomicUsize, } impl ServerSessionStorage { @@ -121,6 +123,7 @@ impl ServerSessionStorage { Self { items: Mutex::new(BTreeSet::new()), parameters: Mutex::new(CacheParameters::new(max_size)), + op_count: AtomicUsize::new(0), } } @@ -230,6 +233,8 @@ impl ServerSessionStorage { } fn insert(&self, new: Arc) -> bool { + self.tick(); + if let Ok(mut items) = self.items.lock() { items.insert(new) } else { @@ -238,6 +243,8 @@ impl ServerSessionStorage { } fn take(&self, id: &[u8]) -> Option> { + self.tick(); + if let Ok(mut items) = self.items.lock() { items.take(&SslSessionLookup::for_id(id)) } else { @@ -246,6 +253,8 @@ impl ServerSessionStorage { } fn find_by_id(&self, id: &[u8]) -> Option> { + self.tick(); + if let Ok(items) = self.items.lock() { items.get(&SslSessionLookup::for_id(id)).cloned() } else { @@ -271,6 +280,41 @@ impl ServerSessionStorage { } } } + + fn flush_expired(&self, at_time: TimeBase) { + if let Ok(mut items) = self.items.lock() { + let callbacks = self.callbacks(); + if let Some(callback) = callbacks.remove_callback { + // if we have a callback to invoke, do it the slow way + let mut removal_list = BTreeSet::new(); + for item in items.iter() { + if item.expired(at_time) { + removal_list.insert(item.clone()); + } + } + + while let Some(sess) = removal_list.pop_first() { + items.remove(&sess); + callbacks::invoke_session_remove_callback( + Some(callback), + callbacks.ssl_ctx, + sess, + ); + } + } else { + items.retain(|item| !item.expired(at_time)); + } + } + } + + fn tick(&self) { + // Called every cache operation. Every 255 operations, expire + // sessions (unless application opts out with CACHE_MODE_NO_AUTO_CLEAR). + let op_count = self.op_count.fetch_add(1, Ordering::SeqCst); + if self.mode() & CACHE_MODE_NO_AUTO_CLEAR == 0 && op_count & 0xff == 0xff { + self.flush_expired(TimeBase::now()); + } + } } #[derive(Debug)] @@ -465,3 +509,28 @@ impl TimeBase { ) } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn flush_expired() { + let cache = ServerSessionStorage::new(10); + + for i in 1..=5 { + assert!(cache.insert( + SslSession::new(vec![i], vec![], vec![], ExpiryTime(10 + i as u64)).into() + )); + } + + // expires items 1, 2 + cache.flush_expired(TimeBase(10 + 3)); + + assert!(cache.find_by_id(&[1]).is_none()); + assert!(cache.find_by_id(&[2]).is_none()); + assert!(cache.find_by_id(&[3]).is_some()); + assert!(cache.find_by_id(&[4]).is_some()); + assert!(cache.find_by_id(&[5]).is_some()); + } +}