Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore(flags): cleaning up some flags stuff #26444

Merged
merged 9 commits into from
Nov 27, 2024
Merged
2 changes: 1 addition & 1 deletion rust/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -96,4 +96,4 @@ ahash = "0.8.11"
aws-config = { version = "1.1.7", features = ["behavior-version-latest"] }
aws-sdk-s3 = "1.58.0"
mockall = "0.13.0"
moka = { version = "0.12.8", features = ["sync"] }
moka = { version = "0.12.8", features = ["sync", "future"] }
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added this feature to the workspace version of moka because I need it for flags

2 changes: 1 addition & 1 deletion rust/feature-flags/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ common-metrics = { path = "../common/metrics" }
tower = { workspace = true }
derive_builder = "0.20.1"
petgraph = "0.6.5"
moka = { version = "0.12.8", features = ["future"] }
moka = { workspace = true }
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use the workspace version


[lints]
workspace = true
Expand Down
66 changes: 32 additions & 34 deletions rust/feature-flags/src/api/handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,8 @@ pub struct FeatureFlagEvaluationContext {
team_id: i32,
distinct_id: String,
feature_flags: FeatureFlagList,
postgres_reader: Arc<dyn Client + Send + Sync>,
postgres_writer: Arc<dyn Client + Send + Sync>,
reader: Arc<dyn Client + Send + Sync>,
writer: Arc<dyn Client + Send + Sync>,
cohort_cache: Arc<CohortCacheManager>,
#[builder(default)]
person_property_overrides: Option<HashMap<String, Value>>,
Expand All @@ -93,10 +93,10 @@ pub async fn process_request(context: RequestContext) -> Result<FlagsResponse, F

let request = decode_request(&headers, body)?;
let token = request
.extract_and_verify_token(state.redis.clone(), state.postgres_reader.clone())
.extract_and_verify_token(state.redis.clone(), state.reader.clone())
.await?;
let team = request
.get_team_from_cache_or_pg(&token, state.redis.clone(), state.postgres_reader.clone())
.get_team_from_cache_or_pg(&token, state.redis.clone(), state.reader.clone())
.await?;

let distinct_id = request.extract_distinct_id()?;
Expand All @@ -112,16 +112,16 @@ pub async fn process_request(context: RequestContext) -> Result<FlagsResponse, F
let hash_key_override = request.anon_distinct_id.clone();

let feature_flags_from_cache_or_pg = request
.get_flags_from_cache_or_pg(team_id, &state.redis, &state.postgres_reader)
.get_flags_from_cache_or_pg(team_id, &state.redis, &state.reader)
.await?;

let evaluation_context = FeatureFlagEvaluationContextBuilder::default()
.team_id(team_id)
.distinct_id(distinct_id)
.feature_flags(feature_flags_from_cache_or_pg)
.postgres_reader(state.postgres_reader.clone())
.postgres_writer(state.postgres_writer.clone())
.cohort_cache(state.cohort_cache.clone())
.reader(state.reader.clone())
.writer(state.writer.clone())
.cohort_cache(state.cohort_cache_manager.clone())
.person_property_overrides(person_property_overrides)
.group_property_overrides(group_property_overrides)
.groups(groups)
Expand Down Expand Up @@ -220,12 +220,12 @@ fn decode_request(headers: &HeaderMap, body: Bytes) -> Result<FlagRequest, FlagE
// which flags failed to evaluate
pub async fn evaluate_feature_flags(context: FeatureFlagEvaluationContext) -> FlagsResponse {
let group_type_mapping_cache =
GroupTypeMappingCache::new(context.team_id, context.postgres_reader.clone());
GroupTypeMappingCache::new(context.team_id, context.reader.clone());
let mut feature_flag_matcher = FeatureFlagMatcher::new(
context.distinct_id,
context.team_id,
context.postgres_reader,
context.postgres_writer,
context.reader,
context.writer,
context.cohort_cache,
Some(group_type_mapping_cache),
context.groups,
Expand Down Expand Up @@ -362,9 +362,9 @@ mod tests {

#[tokio::test]
async fn test_evaluate_feature_flags() {
let postgres_reader: Arc<dyn Client + Send + Sync> = setup_pg_reader_client(None).await;
let postgres_writer: Arc<dyn Client + Send + Sync> = setup_pg_writer_client(None).await;
let cohort_cache = Arc::new(CohortCacheManager::new(postgres_reader.clone(), None, None));
let reader: Arc<dyn Client + Send + Sync> = setup_pg_reader_client(None).await;
let writer: Arc<dyn Client + Send + Sync> = setup_pg_writer_client(None).await;
let cohort_cache = Arc::new(CohortCacheManager::new(reader.clone(), None, None));
let flag = FeatureFlag {
name: Some("Test Flag".to_string()),
id: 1,
Expand Down Expand Up @@ -402,8 +402,8 @@ mod tests {
.team_id(1)
.distinct_id("user123".to_string())
.feature_flags(feature_flag_list)
.postgres_reader(postgres_reader)
.postgres_writer(postgres_writer)
.reader(reader)
.writer(writer)
.cohort_cache(cohort_cache)
.person_property_overrides(Some(person_properties))
.build()
Expand Down Expand Up @@ -511,9 +511,9 @@ mod tests {

#[tokio::test]
async fn test_evaluate_feature_flags_multiple_flags() {
let postgres_reader: Arc<dyn Client + Send + Sync> = setup_pg_reader_client(None).await;
let postgres_writer: Arc<dyn Client + Send + Sync> = setup_pg_writer_client(None).await;
let cohort_cache = Arc::new(CohortCacheManager::new(postgres_reader.clone(), None, None));
let reader: Arc<dyn Client + Send + Sync> = setup_pg_reader_client(None).await;
let writer: Arc<dyn Client + Send + Sync> = setup_pg_writer_client(None).await;
let cohort_cache = Arc::new(CohortCacheManager::new(reader.clone(), None, None));
let flags = vec![
FeatureFlag {
name: Some("Flag 1".to_string()),
Expand Down Expand Up @@ -563,8 +563,8 @@ mod tests {
.team_id(1)
.distinct_id("user123".to_string())
.feature_flags(feature_flag_list)
.postgres_reader(postgres_reader)
.postgres_writer(postgres_writer)
.reader(reader)
.writer(writer)
.cohort_cache(cohort_cache)
.build()
.expect("Failed to build FeatureFlagEvaluationContext");
Expand Down Expand Up @@ -616,12 +616,10 @@ mod tests {

#[tokio::test]
async fn test_evaluate_feature_flags_with_overrides() {
let postgres_reader: Arc<dyn Client + Send + Sync> = setup_pg_reader_client(None).await;
let postgres_writer: Arc<dyn Client + Send + Sync> = setup_pg_writer_client(None).await;
let cohort_cache = Arc::new(CohortCacheManager::new(postgres_reader.clone(), None, None));
let team = insert_new_team_in_pg(postgres_reader.clone(), None)
.await
.unwrap();
let reader: Arc<dyn Client + Send + Sync> = setup_pg_reader_client(None).await;
let writer: Arc<dyn Client + Send + Sync> = setup_pg_writer_client(None).await;
let cohort_cache = Arc::new(CohortCacheManager::new(reader.clone(), None, None));
let team = insert_new_team_in_pg(reader.clone(), None).await.unwrap();

let flag = FeatureFlag {
name: Some("Test Flag".to_string()),
Expand Down Expand Up @@ -665,8 +663,8 @@ mod tests {
.team_id(team.id)
.distinct_id("user123".to_string())
.feature_flags(feature_flag_list)
.postgres_reader(postgres_reader)
.postgres_writer(postgres_writer)
.reader(reader)
.writer(writer)
.cohort_cache(cohort_cache)
.group_property_overrides(Some(group_property_overrides))
.groups(Some(groups))
Expand Down Expand Up @@ -699,9 +697,9 @@ mod tests {
#[tokio::test]
async fn test_long_distinct_id() {
let long_id = "a".repeat(1000);
let postgres_reader: Arc<dyn Client + Send + Sync> = setup_pg_reader_client(None).await;
let postgres_writer: Arc<dyn Client + Send + Sync> = setup_pg_writer_client(None).await;
let cohort_cache = Arc::new(CohortCacheManager::new(postgres_reader.clone(), None, None));
let reader: Arc<dyn Client + Send + Sync> = setup_pg_reader_client(None).await;
let writer: Arc<dyn Client + Send + Sync> = setup_pg_writer_client(None).await;
let cohort_cache = Arc::new(CohortCacheManager::new(reader.clone(), None, None));
let flag = FeatureFlag {
name: Some("Test Flag".to_string()),
id: 1,
Expand Down Expand Up @@ -729,8 +727,8 @@ mod tests {
.team_id(1)
.distinct_id(long_id)
.feature_flags(feature_flag_list)
.postgres_reader(postgres_reader)
.postgres_writer(postgres_writer)
.reader(reader)
.writer(writer)
.cohort_cache(cohort_cache)
.build()
.expect("Failed to build FeatureFlagEvaluationContext");
Expand Down
84 changes: 44 additions & 40 deletions rust/feature-flags/src/cohort/cohort_cache_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ use crate::api::errors::FlagError;
use crate::cohort::cohort_models::Cohort;
use crate::flags::flag_matching::{PostgresReader, TeamId};
use moka::future::Cache;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::Mutex;

/// CohortCacheManager manages the in-memory cache of cohorts using `moka` for caching.
///
Expand All @@ -12,8 +14,8 @@ use std::time::Duration;
///
/// ```text
/// CohortCacheManager {
/// postgres_reader: PostgresReader,
/// per_team_cohorts: Cache<TeamId, Vec<Cohort>> {
/// reader: PostgresReader,
/// cache: Cache<TeamId, Vec<Cohort>> {
/// // Example:
/// 2: [
/// Cohort { id: 1, name: "Power Users", filters: {...} },
Expand All @@ -22,50 +24,59 @@ use std::time::Duration;
/// 5: [
/// Cohort { id: 3, name: "Beta Users", filters: {...} }
/// ]
/// }
/// },
/// fetch_lock: Mutex<()> // Manager-wide lock
/// }
/// ```
///
#[derive(Clone)]
pub struct CohortCacheManager {
postgres_reader: PostgresReader,
per_team_cohort_cache: Cache<TeamId, Vec<Cohort>>,
reader: PostgresReader,
cache: Cache<TeamId, Vec<Cohort>>,
fetch_lock: Arc<Mutex<()>>, // Added fetch_lock
}

impl CohortCacheManager {
pub fn new(
postgres_reader: PostgresReader,
reader: PostgresReader,
max_capacity: Option<u64>,
ttl_seconds: Option<u64>,
) -> Self {
// We use the size of the cohort list (i.e., the number of cohorts for a given team)as the weight of the entry
let weigher =
|_: &TeamId, value: &Vec<Cohort>| -> u32 { value.len().try_into().unwrap_or(u32::MAX) };
// We use the size of the cohort list (i.e., the number of cohorts for a given team) as the weight of the entry
let weigher = |_: &TeamId, value: &Vec<Cohort>| -> u32 { value.len() as u32 };

let cache = Cache::builder()
.time_to_live(Duration::from_secs(ttl_seconds.unwrap_or(300))) // Default to 5 minutes
.weigher(weigher)
.max_capacity(max_capacity.unwrap_or(10_000)) // Default to 10,000 cohorts
.max_capacity(max_capacity.unwrap_or(100_000)) // Default to 100,000 cohorts
.build();

Self {
postgres_reader,
per_team_cohort_cache: cache,
reader,
cache,
fetch_lock: Arc::new(Mutex::new(())), // Initialize the lock
}
}

/// Retrieves cohorts for a given team.
///
/// If the cohorts are not present in the cache or have expired, it fetches them from the database,
/// caches the result upon successful retrieval, and then returns it.
pub async fn get_cohorts_for_team(&self, team_id: TeamId) -> Result<Vec<Cohort>, FlagError> {
if let Some(cached_cohorts) = self.per_team_cohort_cache.get(&team_id).await {
pub async fn get_cohorts(&self, team_id: TeamId) -> Result<Vec<Cohort>, FlagError> {
if let Some(cached_cohorts) = self.cache.get(&team_id).await {
return Ok(cached_cohorts.clone());
}
let fetched_cohorts = Cohort::list_from_pg(self.postgres_reader.clone(), team_id).await?;
self.per_team_cohort_cache
.insert(team_id, fetched_cohorts.clone())
.await;

// Acquire the lock before fetching
let _lock = self.fetch_lock.lock().await;

// Double-check the cache after acquiring the lock
if let Some(cached_cohorts) = self.cache.get(&team_id).await {
return Ok(cached_cohorts.clone());
}

let fetched_cohorts = Cohort::list_from_pg(self.reader.clone(), team_id).await?;
self.cache.insert(team_id, fetched_cohorts.clone()).await;

Ok(fetched_cohorts)
}
Expand Down Expand Up @@ -116,18 +127,18 @@ mod tests {
Some(1), // 1-second TTL
);

let cohorts = cohort_cache.get_cohorts_for_team(team_id).await?;
let cohorts = cohort_cache.get_cohorts(team_id).await?;
assert_eq!(cohorts.len(), 1);
assert_eq!(cohorts[0].team_id, team_id);

let cached_cohorts = cohort_cache.per_team_cohort_cache.get(&team_id).await;
let cached_cohorts = cohort_cache.cache.get(&team_id).await;
assert!(cached_cohorts.is_some());

// Wait for TTL to expire
sleep(Duration::from_secs(2)).await;

// Attempt to retrieve from cache again
let cached_cohorts = cohort_cache.per_team_cohort_cache.get(&team_id).await;
let cached_cohorts = cohort_cache.cache.get(&team_id).await;
assert!(cached_cohorts.is_none(), "Cache entry should have expired");

Ok(())
Expand All @@ -152,11 +163,11 @@ mod tests {
let team_id = team.id;
inserted_team_ids.push(team_id);
setup_test_cohort(writer_client.clone(), team_id, None).await?;
cohort_cache.get_cohorts_for_team(team_id).await?;
cohort_cache.get_cohorts(team_id).await?;
}

cohort_cache.per_team_cohort_cache.run_pending_tasks().await;
let cache_size = cohort_cache.per_team_cohort_cache.entry_count();
cohort_cache.cache.run_pending_tasks().await;
let cache_size = cohort_cache.cache.entry_count();
assert_eq!(
cache_size, max_capacity,
"Cache size should be equal to max_capacity"
Expand All @@ -165,26 +176,23 @@ mod tests {
let new_team = insert_new_team_in_pg(writer_client.clone(), None).await?;
let new_team_id = new_team.id;
setup_test_cohort(writer_client.clone(), new_team_id, None).await?;
cohort_cache.get_cohorts_for_team(new_team_id).await?;
cohort_cache.get_cohorts(new_team_id).await?;

cohort_cache.per_team_cohort_cache.run_pending_tasks().await;
let cache_size_after = cohort_cache.per_team_cohort_cache.entry_count();
cohort_cache.cache.run_pending_tasks().await;
let cache_size_after = cohort_cache.cache.entry_count();
assert_eq!(
cache_size_after, max_capacity,
"Cache size should remain equal to max_capacity after eviction"
);

let evicted_team_id = &inserted_team_ids[0];
let cached_cohorts = cohort_cache
.per_team_cohort_cache
.get(evicted_team_id)
.await;
let cached_cohorts = cohort_cache.cache.get(evicted_team_id).await;
assert!(
cached_cohorts.is_none(),
"Least recently used cache entry should have been evicted"
);

let cached_new_team = cohort_cache.per_team_cohort_cache.get(&new_team_id).await;
let cached_new_team = cohort_cache.cache.get(&new_team_id).await;
assert!(
cached_new_team.is_some(),
"Newly added cache entry should be present"
Expand All @@ -194,25 +202,21 @@ mod tests {
}

#[tokio::test]
async fn test_get_cohorts_for_team() -> Result<(), anyhow::Error> {
async fn test_get_cohorts() -> Result<(), anyhow::Error> {
let writer_client = setup_pg_writer_client(None).await;
let reader_client = setup_pg_reader_client(None).await;
let team_id = setup_test_team(writer_client.clone()).await?;
let _cohort = setup_test_cohort(writer_client.clone(), team_id, None).await?;
let cohort_cache = CohortCacheManager::new(reader_client.clone(), None, None);

let cached_cohorts = cohort_cache.per_team_cohort_cache.get(&team_id).await;
let cached_cohorts = cohort_cache.cache.get(&team_id).await;
assert!(cached_cohorts.is_none(), "Cache should initially be empty");

let cohorts = cohort_cache.get_cohorts_for_team(team_id).await?;
let cohorts = cohort_cache.get_cohorts(team_id).await?;
assert_eq!(cohorts.len(), 1);
assert_eq!(cohorts[0].team_id, team_id);

let cached_cohorts = cohort_cache
.per_team_cohort_cache
.get(&team_id)
.await
.unwrap();
let cached_cohorts = cohort_cache.cache.get(&team_id).await.unwrap();
assert_eq!(cached_cohorts.len(), 1);
assert_eq!(cached_cohorts[0].team_id, team_id);

Expand Down
Loading