diff --git a/components/relevancy/src/lib.rs b/components/relevancy/src/lib.rs index 43ea661d3c..f72c0e613f 100644 --- a/components/relevancy/src/lib.rs +++ b/components/relevancy/src/lib.rs @@ -23,15 +23,19 @@ use rand_distr::{Beta, Distribution}; pub use db::RelevancyDb; pub use error::{ApiResult, Error, RelevancyApiError, Result}; pub use interest::{Interest, InterestVector}; +use parking_lot::Mutex; pub use ranker::score; use error_support::handle_error; +use std::collections::HashMap; + uniffi::setup_scaffolding!(); #[derive(uniffi::Object)] pub struct RelevancyStore { db: RelevancyDb, + cache: Mutex, } /// Top-level API for the Relevancy component @@ -45,6 +49,7 @@ impl RelevancyStore { pub fn new(db_path: String) -> Self { Self { db: RelevancyDb::new(db_path), + cache: Mutex::new(BanditCache::new()), } } @@ -125,15 +130,12 @@ impl RelevancyStore { /// of success. The arm with the highest sampled probability is selected and returned. #[handle_error(Error)] pub fn bandit_select(&self, bandit: String, arms: &[String]) -> ApiResult { - // we should cache the distribution so we don't retrieve each time - + let mut cache = self.cache.lock(); let mut best_sample = f64::MIN; let mut selected_arm = String::new(); for arm in arms { - let (alpha, beta) = self - .db - .read(|dao| dao.retrieve_bandit_arm_beta_distribution(&bandit, arm))?; + let (alpha, beta) = cache.get_beta_distribution(&bandit, arm, &self.db)?; // this creates a Beta distribution for an alpha & beta pair let beta_dist = Beta::new(alpha as f64, beta as f64) .expect("computing betas dist unexpectedly failed"); @@ -159,12 +161,69 @@ impl RelevancyStore { /// its likelihood of a negative outcome. #[handle_error(Error)] pub fn bandit_update(&self, bandit: String, arm: String, selected: bool) -> ApiResult<()> { + let mut cache = self.cache.lock(); + + cache.clear(&bandit, &arm); + self.db .read_write(|dao| dao.update_bandit_arm_data(&bandit, &arm, selected))?; + Ok(()) } } +#[derive(Default)] +pub struct BanditCache { + cache: HashMap<(String, String), (usize, usize)>, +} + +impl BanditCache { + /// Creates a new, empty `BanditCache`. + /// + /// The cache is initialized as an empty `HashMap` and is used to store + /// precomputed Beta distribution parameters for faster access during + /// Thompson Sampling operations. + pub fn new() -> Self { + Self::default() + } + + /// Retrieves the Beta distribution parameters for a given bandit and arm. + /// + /// If the parameters for the specified `bandit` and `arm` are already cached, + /// they are returned directly. Otherwise, the parameters are fetched from + /// the database, added to the cache, and then returned. + pub fn get_beta_distribution( + &mut self, + bandit: &str, + arm: &str, + db: &RelevancyDb, + ) -> Result<(usize, usize)> { + let key = (bandit.to_string(), arm.to_string()); + + // Check if the distribution is already cached + if let Some(¶ms) = self.cache.get(&key) { + return Ok(params); + } + + let params = db.read(|dao| dao.retrieve_bandit_arm_beta_distribution(bandit, arm))?; + + // Cache the retrieved parameters for future use + self.cache.insert(key, params); + + Ok(params) + } + + /// Clears the cached Beta distribution parameters for a given bandit and arm. + /// + /// This removes the cached values for the specified `bandit` and `arm` from the cache. + /// Use this method if the cached parameters are no longer valid or need to be refreshed. + pub fn clear(&mut self, bandit: &str, arm: &str) { + let key = (bandit.to_string(), arm.to_string()); + + self.cache.remove(&key); + } +} + impl RelevancyStore { /// Download the interest data from remote settings if needed #[handle_error(Error)]