From 84058e07328231de9e4c6a7f7c28b67cbc7cace2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sosth=C3=A8ne=20Gu=C3=A9don?= Date: Thu, 1 Feb 2024 15:21:45 +0100 Subject: [PATCH] Add tests for fetching keys with failure --- pkcs11/src/backend/session.rs | 85 +++++++++++++++++++++++++++++++---- 1 file changed, 77 insertions(+), 8 deletions(-) diff --git a/pkcs11/src/backend/session.rs b/pkcs11/src/backend/session.rs index 30695f8d..c03cc0f6 100644 --- a/pkcs11/src/backend/session.rs +++ b/pkcs11/src/backend/session.rs @@ -1,6 +1,6 @@ use std::{ collections::HashMap, - sync::{atomic::Ordering, Arc, Condvar, Mutex}, + sync::{atomic::Ordering, Arc, Condvar, Mutex, MutexGuard}, }; use cryptoki_sys::{ @@ -527,15 +527,27 @@ impl Session { } /// Drop the Condvar to notify on close - struct NotifyAllGuard<'a>(&'a (Mutex, Condvar)); + struct NotifyAllGuard<'a>(Option<&'a (Mutex, Condvar)>); impl<'a> Drop for NotifyAllGuard<'a> { fn drop(&mut self) { - self.0 .0.lock().unwrap().set_is_being_fetched(false); - self.0 .1.notify_all(); + if let Some(cv) = self.0 { + cv.0.lock().unwrap().set_is_being_fetched(false); + cv.1.notify_all(); + } + } + } + + impl<'a> NotifyAllGuard<'a> { + fn success(&mut self, mut lock: MutexGuard<'a, Db>) { + let cv = self.0.take().unwrap(); + lock.set_is_being_fetched(false); + lock.set_fetched_all_keys(true); + drop(lock); + cv.1.notify_all(); } } - NotifyAllGuard(&self.db); + let mut guard = NotifyAllGuard(Some(&self.db)); if !self .login_ctx @@ -573,7 +585,7 @@ impl Session { .flatten() .map(|o| db.add_object(o)) .collect(); - db.set_fetched_all_keys(true); + guard.success(db); Ok(handles) } @@ -678,7 +690,10 @@ impl Session { #[cfg(test)] mod test { use crate::{ - backend::slot::{get_slot, init_for_tests}, + backend::{ + slot::{get_slot, init_for_tests}, + ApiError, + }, config::config_file::RetryConfig, }; use std::thread; @@ -692,7 +707,7 @@ mod test { let db = Arc::new((Mutex::new(Db::new()), Condvar::new())); let mut sessions = Vec::new(); - for _ in 0..20 { + for _ in 0..10 { let session = Session { db: db.clone(), decrypt_ctx: None, @@ -724,4 +739,58 @@ mod test { } }) } + + #[test] + fn parrallel_fetch_all_keys_fail() { + THREADS_ALLOWED.store(false, Ordering::Relaxed); + init_for_tests(); + let slot = get_slot(0).unwrap(); + + let db = Arc::new((Mutex::new(Db::new()), Condvar::new())); + let mut sessions = Vec::new(); + for _ in 0..10 { + let mut bad_instance = slot.instances[0].clone(); + bad_instance.base_path.push_str("/corrupted_url"); + let session = Session { + db: db.clone(), + decrypt_ctx: None, + encrypt_ctx: None, + sign_ctx: None, + device_error: 0, + enum_ctx: None, + flags: 0, + login_ctx: LoginCtx::new( + Some(crate::config::config_file::UserConfig { + username: "operator".into(), + password: Some("opPassphrase".into()), + }), + None, + vec![bad_instance], + Some(RetryConfig { + count: 2, + delay_seconds: 0, + }), + ), + slot_id: 0, + }; + sessions.push(session); + } + + thread::scope(|s| { + for session in &mut sessions { + s.spawn(|| { + match session.fetch_all_keys() { + Err(Error::Api(ApiError::Ureq(r))) => { + assert!( + r.ends_with(": status code 404"), + "expected 404 error, got {}", + r + ); + } + res => panic!("{res:?}"), + }; + }); + } + }) + } }