diff --git a/rustls-libssl/src/cache.rs b/rustls-libssl/src/cache.rs index e40aaac..cf6f461 100644 --- a/rustls-libssl/src/cache.rs +++ b/rustls-libssl/src/cache.rs @@ -235,8 +235,20 @@ impl ServerSessionStorage { fn insert(&self, new: Arc) -> bool { self.tick(); + let max_size = self + .parameters + .lock() + .map(|inner| inner.max_size) + .unwrap_or_default(); + if let Ok(mut items) = self.items.lock() { - items.insert(new) + let inserted = items.insert(new); + + while items.len() > max_size { + Self::flush_oldest(&mut items); + } + + inserted } else { false } @@ -315,6 +327,27 @@ impl ServerSessionStorage { self.flush_expired(TimeBase::now()); } } + + fn flush_oldest(items: &mut BTreeSet>) { + let mut oldest = None; + + for item in items.iter() { + oldest = match oldest { + None => Some(item.clone()), + Some(oldest) => { + if item.older_than(&oldest) { + Some(item.clone()) + } else { + Some(oldest) + } + } + }; + } + + if let Some(oldest) = oldest { + items.take(&oldest); + } + } } #[derive(Debug)] @@ -533,4 +566,48 @@ mod tests { assert!(cache.find_by_id(&[4]).is_some()); assert!(cache.find_by_id(&[5]).is_some()); } + + #[test] + fn respects_max_size() { + let cache = ServerSessionStorage::new(4); + + for i in 1..=5 { + assert!(cache.insert( + SslSession::new(vec![i], vec![], vec![], ExpiryTime(10 + i as u64)).into() + )); + } + + assert!(cache.find_by_id(&[1]).is_none()); + assert!(cache.find_by_id(&[2]).is_some()); + assert!(cache.find_by_id(&[3]).is_some()); + assert!(cache.find_by_id(&[4]).is_some()); + assert!(cache.find_by_id(&[5]).is_some()); + } + + #[test] + fn respects_change_in_max_size() { + let cache = ServerSessionStorage::new(5); + + for i in 1..=5 { + assert!(cache.insert( + SslSession::new(vec![i], vec![], vec![], ExpiryTime(10 + i as u64)).into() + )); + } + + assert!(cache.find_by_id(&[1]).is_some()); + assert!(cache.find_by_id(&[2]).is_some()); + assert!(cache.find_by_id(&[3]).is_some()); + assert!(cache.find_by_id(&[4]).is_some()); + assert!(cache.find_by_id(&[5]).is_some()); + + cache.set_size(4); + assert!(cache.insert(SslSession::new(vec![6], vec![], vec![], ExpiryTime(16)).into())); + + assert!(cache.find_by_id(&[1]).is_none()); + assert!(cache.find_by_id(&[2]).is_none()); + assert!(cache.find_by_id(&[3]).is_some()); + assert!(cache.find_by_id(&[4]).is_some()); + assert!(cache.find_by_id(&[5]).is_some()); + assert!(cache.find_by_id(&[6]).is_some()); + } }