Skip to content

Commit

Permalink
[DISCO-3065] relevancy: caching for multi armed bandit API
Browse files Browse the repository at this point in the history
  • Loading branch information
misaniwere committed Dec 16, 2024
1 parent 5aec372 commit 7896001
Showing 1 changed file with 64 additions and 5 deletions.
69 changes: 64 additions & 5 deletions components/relevancy/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,19 @@ use rand_distr::{Beta, Distribution};
pub use db::RelevancyDb;
pub use error::{ApiResult, Error, RelevancyApiError, Result};
pub use interest::{Interest, InterestVector};
use parking_lot::Mutex;
pub use ranker::score;

use error_support::handle_error;

use std::collections::HashMap;

uniffi::setup_scaffolding!();

#[derive(uniffi::Object)]
pub struct RelevancyStore {
db: RelevancyDb,
cache: Mutex<BanditCache>,
}

/// Top-level API for the Relevancy component
Expand All @@ -45,6 +49,7 @@ impl RelevancyStore {
pub fn new(db_path: String) -> Self {
Self {
db: RelevancyDb::new(db_path),
cache: Mutex::new(BanditCache::new()),
}
}

Expand Down Expand Up @@ -125,15 +130,12 @@ impl RelevancyStore {
/// of success. The arm with the highest sampled probability is selected and returned.
#[handle_error(Error)]
pub fn bandit_select(&self, bandit: String, arms: &[String]) -> ApiResult<String> {
// we should cache the distribution so we don't retrieve each time

let mut cache = self.cache.lock();
let mut best_sample = f64::MIN;
let mut selected_arm = String::new();

for arm in arms {
let (alpha, beta) = self
.db
.read(|dao| dao.retrieve_bandit_arm_beta_distribution(&bandit, arm))?;
let (alpha, beta) = cache.get_beta_distribution(&bandit, arm, &self.db)?;
// this creates a Beta distribution for an alpha & beta pair
let beta_dist = Beta::new(alpha as f64, beta as f64)
.expect("computing betas dist unexpectedly failed");
Expand All @@ -159,12 +161,69 @@ impl RelevancyStore {
/// its likelihood of a negative outcome.
#[handle_error(Error)]
pub fn bandit_update(&self, bandit: String, arm: String, selected: bool) -> ApiResult<()> {
let mut cache = self.cache.lock();

cache.clear(&bandit, &arm);

self.db
.read_write(|dao| dao.update_bandit_arm_data(&bandit, &arm, selected))?;

Ok(())
}
}

#[derive(Default)]
pub struct BanditCache {
cache: HashMap<(String, String), (usize, usize)>,
}

impl BanditCache {
/// Creates a new, empty `BanditCache`.
///
/// The cache is initialized as an empty `HashMap` and is used to store
/// precomputed Beta distribution parameters for faster access during
/// Thompson Sampling operations.
pub fn new() -> Self {
Self::default()
}

/// Retrieves the Beta distribution parameters for a given bandit and arm.
///
/// If the parameters for the specified `bandit` and `arm` are already cached,
/// they are returned directly. Otherwise, the parameters are fetched from
/// the database, added to the cache, and then returned.
pub fn get_beta_distribution(
&mut self,
bandit: &str,
arm: &str,
db: &RelevancyDb,
) -> Result<(usize, usize)> {
let key = (bandit.to_string(), arm.to_string());

// Check if the distribution is already cached
if let Some(&params) = self.cache.get(&key) {
return Ok(params);
}

let params = db.read(|dao| dao.retrieve_bandit_arm_beta_distribution(bandit, arm))?;

// Cache the retrieved parameters for future use
self.cache.insert(key, params);

Ok(params)
}

/// Clears the cached Beta distribution parameters for a given bandit and arm.
///
/// This removes the cached values for the specified `bandit` and `arm` from the cache.
/// Use this method if the cached parameters are no longer valid or need to be refreshed.
pub fn clear(&mut self, bandit: &str, arm: &str) {
let key = (bandit.to_string(), arm.to_string());

self.cache.remove(&key);
}
}

impl RelevancyStore {
/// Download the interest data from remote settings if needed
#[handle_error(Error)]
Expand Down

0 comments on commit 7896001

Please sign in to comment.