From ca431b50344b2b72b65f0b2622e9bd0b87a56361 Mon Sep 17 00:00:00 2001 From: dylan Date: Tue, 15 Oct 2024 15:12:51 -0700 Subject: [PATCH 01/30] unifying some types --- rust/feature-flags/src/flag_request.rs | 19 +++++++++++-------- rust/feature-flags/src/request_handler.rs | 9 +++------ rust/feature-flags/src/test_utils.rs | 20 ++++++++++++-------- 3 files changed, 26 insertions(+), 22 deletions(-) diff --git a/rust/feature-flags/src/flag_request.rs b/rust/feature-flags/src/flag_request.rs index 771c216834c96..feb8577b24dea 100644 --- a/rust/feature-flags/src/flag_request.rs +++ b/rust/feature-flags/src/flag_request.rs @@ -158,8 +158,8 @@ impl FlagRequest { pub async fn get_flags_from_cache_or_pg( &self, team_id: i32, - redis_client: Arc, - pg_client: Arc, + redis_client: &Arc, + pg_client: &Arc, ) -> Result { let mut cache_hit = false; let flags = match FeatureFlagList::from_redis(redis_client.clone(), team_id).await { @@ -167,10 +167,14 @@ impl FlagRequest { cache_hit = true; Ok(flags) } - Err(_) => match FeatureFlagList::from_pg(pg_client, team_id).await { + Err(_) => match FeatureFlagList::from_pg(pg_client.clone(), team_id).await { Ok(flags) => { - if let Err(e) = - FeatureFlagList::update_flags_in_redis(redis_client, team_id, &flags).await + if let Err(e) = FeatureFlagList::update_flags_in_redis( + redis_client.clone(), + team_id, + &flags, + ) + .await { tracing::warn!("Failed to update Redis cache: {}", e); // TODO add new metric category for this @@ -206,7 +210,6 @@ mod tests { TEAM_FLAGS_CACHE_PREFIX, }; use crate::flag_request::FlagRequest; - use crate::redis::Client as RedisClient; use crate::team::Team; use crate::test_utils::{insert_new_team_in_redis, setup_pg_reader_client, setup_redis_client}; use bytes::Bytes; @@ -426,7 +429,7 @@ mod tests { // Test fetching from Redis let result = flag_request - .get_flags_from_cache_or_pg(team.id, redis_client.clone(), pg_client.clone()) + .get_flags_from_cache_or_pg(team.id, &redis_client, &pg_client) .await; assert!(result.is_ok()); let fetched_flags = result.unwrap(); @@ -483,7 +486,7 @@ mod tests { .expect("Failed to remove flags from Redis"); let result = flag_request - .get_flags_from_cache_or_pg(team.id, redis_client.clone(), pg_client.clone()) + .get_flags_from_cache_or_pg(team.id, &redis_client, &pg_client) .await; assert!(result.is_ok()); // Verify that the flags were re-added to Redis diff --git a/rust/feature-flags/src/request_handler.rs b/rust/feature-flags/src/request_handler.rs index 6c62e7c5ec091..066a9c8f8a889 100644 --- a/rust/feature-flags/src/request_handler.rs +++ b/rust/feature-flags/src/request_handler.rs @@ -108,18 +108,15 @@ pub async fn process_request(context: RequestContext) -> Result = state.postgres_reader.clone(); - let postgres_writer_dyn: Arc = state.postgres_writer.clone(); - let evaluation_context = FeatureFlagEvaluationContextBuilder::default() .team_id(team_id) .distinct_id(distinct_id) .feature_flags(feature_flags_from_cache_or_pg) - .postgres_reader(postgres_reader_dyn) - .postgres_writer(postgres_writer_dyn) + .postgres_reader(state.postgres_reader.clone()) + .postgres_writer(state.postgres_writer.clone()) .person_property_overrides(person_property_overrides) .group_property_overrides(group_property_overrides) .groups(groups) diff --git a/rust/feature-flags/src/test_utils.rs b/rust/feature-flags/src/test_utils.rs index 769f95039990d..8dd1f76a6596f 100644 --- a/rust/feature-flags/src/test_utils.rs +++ b/rust/feature-flags/src/test_utils.rs @@ -1,7 +1,7 @@ use anyhow::Error; use axum::async_trait; use serde_json::{json, Value}; -use sqlx::{pool::PoolConnection, postgres::PgRow, Error as SqlxError, PgPool, Postgres}; +use sqlx::{pool::PoolConnection, postgres::PgRow, Error as SqlxError, Postgres}; use std::sync::Arc; use uuid::Uuid; @@ -23,7 +23,9 @@ pub fn random_string(prefix: &str, length: usize) -> String { format!("{}{}", prefix, suffix) } -pub async fn insert_new_team_in_redis(client: Arc) -> Result { +pub async fn insert_new_team_in_redis( + client: Arc, +) -> Result { let id = rand::thread_rng().gen_range(0..10_000_000); let token = random_string("phc_", 12); let team = Team { @@ -48,7 +50,7 @@ pub async fn insert_new_team_in_redis(client: Arc) -> Result, + client: Arc, team_id: i32, json_value: Option, ) -> Result<(), Error> { @@ -88,7 +90,9 @@ pub async fn insert_flags_for_team_in_redis( Ok(()) } -pub fn setup_redis_client(url: Option) -> Arc { +// type RedisClientTrait = dyn RedisClient + Send + Sync; + +pub fn setup_redis_client(url: Option) -> Arc { let redis_url = match url { Some(value) => value, None => "redis://localhost:6379/".to_string(), @@ -130,7 +134,7 @@ pub fn create_flag_from_json(json_value: Option) -> Vec { flags } -pub async fn setup_pg_reader_client(config: Option<&Config>) -> Arc { +pub async fn setup_pg_reader_client(config: Option<&Config>) -> Arc { let config = config.unwrap_or(&DEFAULT_TEST_CONFIG); Arc::new( get_pool(&config.read_database_url, config.max_pg_connections) @@ -139,7 +143,7 @@ pub async fn setup_pg_reader_client(config: Option<&Config>) -> Arc { ) } -pub async fn setup_pg_writer_client(config: Option<&Config>) -> Arc { +pub async fn setup_pg_writer_client(config: Option<&Config>) -> Arc { let config = config.unwrap_or(&DEFAULT_TEST_CONFIG); Arc::new( get_pool(&config.write_database_url, config.max_pg_connections) @@ -250,7 +254,7 @@ pub async fn insert_new_team_in_pg( } pub async fn insert_flag_for_team_in_pg( - client: Arc, + client: Arc, team_id: i32, flag: Option, ) -> Result { @@ -299,7 +303,7 @@ pub async fn insert_flag_for_team_in_pg( } pub async fn insert_person_for_team_in_pg( - client: Arc, + client: Arc, team_id: i32, distinct_id: String, properties: Option, From e89f16902967119e7b29a157db52993b471dceba Mon Sep 17 00:00:00 2001 From: dylan Date: Tue, 15 Oct 2024 16:12:30 -0700 Subject: [PATCH 02/30] in progress but not done yet --- rust/feature-flags/src/cohort_definitions.rs | 273 +++++++++++++++++++ rust/feature-flags/src/flag_definitions.rs | 2 + rust/feature-flags/src/flag_matching.rs | 30 ++ rust/feature-flags/src/flag_request.rs | 2 + rust/feature-flags/src/lib.rs | 1 + rust/feature-flags/src/property_matching.rs | 35 +++ rust/feature-flags/src/request_handler.rs | 2 + 7 files changed, 345 insertions(+) create mode 100644 rust/feature-flags/src/cohort_definitions.rs diff --git a/rust/feature-flags/src/cohort_definitions.rs b/rust/feature-flags/src/cohort_definitions.rs new file mode 100644 index 0000000000000..2d7e7c20b60d1 --- /dev/null +++ b/rust/feature-flags/src/cohort_definitions.rs @@ -0,0 +1,273 @@ +use serde::{Deserialize, Serialize}; +use sqlx::FromRow; +use std::sync::Arc; +use tracing::instrument; + +use crate::{ + api::FlagError, + database::Client as DatabaseClient, + flag_definitions::{OperatorType, PropertyFilter}, +}; + +#[derive(Debug, FromRow)] +struct CohortRow { + id: i32, + name: String, + description: Option, + team_id: i32, + deleted: bool, + filters: serde_json::Value, + query: Option, + version: Option, + pending_version: Option, + count: Option, + is_calculating: bool, + is_static: bool, +} + +#[derive(Debug, serde::Serialize, serde::Deserialize)] +pub struct Cohort { + pub id: i32, + pub name: String, + pub description: Option, + pub team_id: i32, + pub deleted: bool, + pub filters: serde_json::Value, + pub query: Option, + pub version: Option, + pub pending_version: Option, + pub count: Option, + pub is_calculating: bool, + pub is_static: bool, +} + +impl Cohort { + /// Returns a cohort from postgres given a cohort_id and team_id + #[instrument(skip_all)] + pub async fn from_pg( + client: Arc, + cohort_id: i32, + team_id: i32, + ) -> Result { + let mut conn = client.get_connection().await.map_err(|e| { + tracing::error!("Failed to get database connection: {}", e); + // TODO should I model my errors more generally? Like, yes, everything behind this API is technically a FlagError, + // but I'm not sure if accessing Cohort definitions should be a FlagError (vs idk, a CohortError? A more general API error?) + FlagError::DatabaseUnavailable + })?; + + let query = "SELECT id, name, description, team_id, deleted, filters, query, version, pending_version, count, is_calculating, is_static FROM posthog_cohort WHERE id = $1 AND team_id = $2"; + let cohort_row = sqlx::query_as::<_, CohortRow>(query) + .bind(cohort_id) + .bind(team_id) + .fetch_optional(&mut *conn) + .await + .map_err(|e| { + tracing::error!("Failed to fetch cohort from database: {}", e); + FlagError::Internal(format!("Database query error: {}", e)) + })?; + + match cohort_row { + Some(row) => Ok(Cohort { + id: row.id, + name: row.name, + description: row.description, + team_id: row.team_id, + deleted: row.deleted, + filters: row.filters, + query: row.query, + version: row.version, + pending_version: row.pending_version, + count: row.count, + is_calculating: row.is_calculating, + is_static: row.is_static, + }), + None => Err(FlagError::DatabaseError(format!( + "Cohort with id {} not found for team {}", + cohort_id, team_id + ))), + } + } + + /// Parses the filters JSON into a CohortProperty structure + fn parse_filters(&self) -> Result { + serde_json::from_value(self.filters.clone()).map_err(|e| { + tracing::error!("Failed to parse filters: {}", e); + FlagError::Internal(format!("Invalid filters format: {}", e)) + }) + } +} + +use std::collections::{HashMap, HashSet}; + +type CohortId = i32; + +// Assuming CohortOrEmpty is an enum or struct representing a Cohort or an empty value +enum CohortOrEmpty { + Cohort(Cohort), + Empty, +} + +#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)] +#[serde(rename_all = "UPPERCASE")] +enum CohortPropertyType { + AND, + OR, +} + +// TODO this should serialize to "properties" in the DB +#[derive(Debug, Clone, Deserialize, Serialize)] +struct CohortProperty { + #[serde(rename = "type")] + prop_type: CohortPropertyType, // TODO make this an AND|OR string enum + values: Vec, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +struct CohortValues { + #[serde(rename = "type")] + prop_type: String, + values: Vec, +} + +fn sort_cohorts_topologically( + cohort_ids: HashSet, + seen_cohorts_cache: &HashMap, +) -> Vec { + if cohort_ids.is_empty() { + return Vec::new(); + } + + let mut dependency_graph: HashMap> = HashMap::new(); + let mut seen = HashSet::new(); + + // Build graph (adjacency list) + fn traverse( + cohort: &Cohort, + dependency_graph: &mut HashMap>, + seen_cohorts: &mut HashSet, + seen_cohorts_cache: &HashMap, + ) { + if seen_cohorts.contains(&cohort.id) { + return; + } + seen_cohorts.insert(cohort.id); + + // Parse the filters into CohortProperty + let cohort_property = match cohort.parse_filters() { + Ok(property) => property, + Err(e) => { + tracing::error!("Error parsing filters for cohort {}: {}", cohort.id, e); + return; + } + }; + + // Iterate through the properties to find dependencies + for value in &cohort_property.values { + if value.prop_type == "cohort" { + if let Some(id) = value.value.as_i64() { + let child_id = id as CohortId; + dependency_graph + .entry(cohort.id) + .or_insert_with(Vec::new) + .push(child_id); + + if let Some(CohortOrEmpty::Cohort(child_cohort)) = + seen_cohorts_cache.get(&child_id) + { + traverse( + child_cohort, + dependency_graph, + seen_cohorts, + seen_cohorts_cache, + ); + } + } else if let Some(id_str) = value.value.as_str() { + if let Ok(child_id) = id_str.parse::() { + dependency_graph + .entry(cohort.id) + .or_insert_with(Vec::new) + .push(child_id); + + if let Some(CohortOrEmpty::Cohort(child_cohort)) = + seen_cohorts_cache.get(&child_id) + { + traverse( + child_cohort, + dependency_graph, + seen_cohorts, + seen_cohorts_cache, + ); + } + } + } + } + + // Handle nested properties recursively if needed + if let Some(nested_values) = &value.values { + for nested in nested_values { + if nested.prop_type == "cohort" { + if let Some(id) = nested.value.as_i64() { + let child_id = id as CohortId; + dependency_graph + .entry(cohort.id) + .or_insert_with(Vec::new) + .push(child_id); + + if let Some(CohortOrEmpty::Cohort(child_cohort)) = + seen_cohorts_cache.get(&child_id) + { + traverse( + child_cohort, + dependency_graph, + seen_cohorts, + seen_cohorts_cache, + ); + } + } + } + } + } + } + } + + for &cohort_id in &cohort_ids { + if let Some(CohortOrEmpty::Cohort(cohort)) = seen_cohorts_cache.get(&cohort_id) { + traverse(cohort, &mut dependency_graph, &mut seen, seen_cohorts_cache); + } + } + + // Post-order DFS (children first, then the parent) + fn dfs( + node: CohortId, + seen: &mut HashSet, + sorted_arr: &mut Vec, + dependency_graph: &HashMap>, + ) { + if let Some(neighbors) = dependency_graph.get(&node) { + for &neighbor in neighbors { + if !seen.contains(&neighbor) { + dfs(neighbor, seen, sorted_arr, dependency_graph); + } + } + } + sorted_arr.push(node); + seen.insert(node); + } + + let mut sorted_cohort_ids = Vec::new(); + let mut seen = HashSet::new(); + for &cohort_id in &cohort_ids { + if !seen.contains(&cohort_id) { + seen.insert(cohort_id); + dfs( + cohort_id, + &mut seen, + &mut sorted_cohort_ids, + &dependency_graph, + ); + } + } + + sorted_cohort_ids +} diff --git a/rust/feature-flags/src/flag_definitions.rs b/rust/feature-flags/src/flag_definitions.rs index baebaa04da30e..627f4c3400569 100644 --- a/rust/feature-flags/src/flag_definitions.rs +++ b/rust/feature-flags/src/flag_definitions.rs @@ -37,6 +37,8 @@ pub struct PropertyFilter { pub operator: Option, #[serde(rename = "type")] pub prop_type: String, + // TODO add negation here? + pub negation: Option, pub group_type_index: Option, } diff --git a/rust/feature-flags/src/flag_matching.rs b/rust/feature-flags/src/flag_matching.rs index d6449f993d15d..861acb30c6f93 100644 --- a/rust/feature-flags/src/flag_matching.rs +++ b/rust/feature-flags/src/flag_matching.rs @@ -1589,6 +1589,7 @@ mod tests { operator: None, prop_type: "person".to_string(), group_type_index: None, + negation: None, }]), rollout_percentage: Some(100.0), variant: None, @@ -1649,6 +1650,7 @@ mod tests { operator: None, prop_type: "group".to_string(), group_type_index: Some(1), + negation: None, }]), rollout_percentage: Some(100.0), variant: None, @@ -1870,6 +1872,7 @@ mod tests { operator: Some(OperatorType::Exact), prop_type: "person".to_string(), group_type_index: None, + negation: None, }]), rollout_percentage: Some(100.0), variant: None, @@ -1940,6 +1943,7 @@ mod tests { operator: Some(OperatorType::Exact), prop_type: "person".to_string(), group_type_index: None, + negation: None, }, PropertyFilter { key: "age".to_string(), @@ -1947,6 +1951,7 @@ mod tests { operator: Some(OperatorType::Gte), prop_type: "person".to_string(), group_type_index: None, + negation: None, }, ]), rollout_percentage: Some(100.0), @@ -2149,6 +2154,7 @@ mod tests { operator: None, prop_type: "person".to_string(), group_type_index: None, + negation: None, }, PropertyFilter { key: "age".to_string(), @@ -2156,6 +2162,7 @@ mod tests { operator: Some(OperatorType::Gte), prop_type: "person".to_string(), group_type_index: None, + negation: None, }, ]; @@ -2169,6 +2176,7 @@ mod tests { operator: None, prop_type: "person".to_string(), group_type_index: None, + negation: None, }, PropertyFilter { key: "cohort".to_string(), @@ -2176,6 +2184,7 @@ mod tests { operator: None, prop_type: "cohort".to_string(), group_type_index: None, + negation: None, }, ]; @@ -2263,6 +2272,7 @@ mod tests { operator: Some(OperatorType::Gte), prop_type: "person".to_string(), group_type_index: None, + negation: None, }, PropertyFilter { key: "email".to_string(), @@ -2270,6 +2280,7 @@ mod tests { operator: Some(OperatorType::Icontains), prop_type: "person".to_string(), group_type_index: None, + negation: None, }, ]), rollout_percentage: Some(100.0), @@ -2490,6 +2501,7 @@ mod tests { operator: None, prop_type: "person".to_string(), group_type_index: None, + negation: None, }]), rollout_percentage: Some(100.0), variant: None, @@ -2550,6 +2562,7 @@ mod tests { operator: Some(OperatorType::Gte), prop_type: "person".to_string(), group_type_index: None, + negation: None, }]), rollout_percentage: Some(100.0), variant: None, @@ -2602,6 +2615,7 @@ mod tests { operator: None, prop_type: "person".to_string(), group_type_index: None, + negation: None, }, PropertyFilter { key: "age".to_string(), @@ -2609,6 +2623,7 @@ mod tests { operator: Some(OperatorType::Gte), prop_type: "person".to_string(), group_type_index: None, + negation: None, }, ]), rollout_percentage: Some(100.0), @@ -2722,6 +2737,7 @@ mod tests { operator: None, prop_type: "person".to_string(), group_type_index: None, + negation: None, }]), rollout_percentage: Some(100.0), variant: None, @@ -2733,6 +2749,7 @@ mod tests { operator: Some(OperatorType::Gte), prop_type: "person".to_string(), group_type_index: None, + negation: None, }]), rollout_percentage: Some(100.0), variant: None, @@ -2794,6 +2811,7 @@ mod tests { operator: Some(OperatorType::Exact), prop_type: "person".to_string(), group_type_index: None, + negation: None, }]), rollout_percentage: Some(0.0), variant: None, @@ -2805,6 +2823,7 @@ mod tests { operator: Some(OperatorType::Exact), prop_type: "person".to_string(), group_type_index: None, + negation: None, }]), rollout_percentage: Some(100.0), variant: None, @@ -2825,6 +2844,7 @@ mod tests { operator: Some(OperatorType::Exact), prop_type: "person".to_string(), group_type_index: None, + negation: None, }]), rollout_percentage: Some(100.0), variant: None, @@ -2923,6 +2943,7 @@ mod tests { operator: Some(OperatorType::Exact), prop_type: "person".to_string(), group_type_index: None, + negation: None, }]), rollout_percentage: Some(0.0), variant: None, @@ -2934,6 +2955,7 @@ mod tests { operator: Some(OperatorType::Exact), prop_type: "person".to_string(), group_type_index: None, + negation: None, }]), rollout_percentage: Some(100.0), variant: None, @@ -2954,6 +2976,7 @@ mod tests { operator: Some(OperatorType::Exact), prop_type: "person".to_string(), group_type_index: None, + negation: None, }]), rollout_percentage: Some(100.0), variant: None, @@ -3012,6 +3035,7 @@ mod tests { operator: Some(OperatorType::Exact), prop_type: "person".to_string(), group_type_index: None, + negation: None, }]), rollout_percentage: Some(0.0), variant: None, @@ -3023,6 +3047,7 @@ mod tests { operator: Some(OperatorType::Exact), prop_type: "person".to_string(), group_type_index: None, + negation: None, }]), rollout_percentage: Some(100.0), variant: None, @@ -3043,6 +3068,7 @@ mod tests { operator: Some(OperatorType::Exact), prop_type: "person".to_string(), group_type_index: None, + negation: None, }]), rollout_percentage: Some(100.0), variant: None, @@ -3303,6 +3329,7 @@ mod tests { operator: None, prop_type: "person".to_string(), group_type_index: None, + negation: None, }]), rollout_percentage: Some(100.0), variant: None, @@ -3384,6 +3411,7 @@ mod tests { operator: None, prop_type: "person".to_string(), group_type_index: None, + negation: None, }]), rollout_percentage: Some(100.0), variant: None, @@ -3455,6 +3483,7 @@ mod tests { operator: None, prop_type: "person".to_string(), group_type_index: None, + negation: None, }]), rollout_percentage: Some(100.0), variant: None, @@ -3483,6 +3512,7 @@ mod tests { operator: Some(OperatorType::Gt), prop_type: "person".to_string(), group_type_index: None, + negation: None, }]), rollout_percentage: Some(100.0), variant: None, diff --git a/rust/feature-flags/src/flag_request.rs b/rust/feature-flags/src/flag_request.rs index feb8577b24dea..1cf64eb879ac4 100644 --- a/rust/feature-flags/src/flag_request.rs +++ b/rust/feature-flags/src/flag_request.rs @@ -363,6 +363,7 @@ mod tests { operator: Some(OperatorType::Exact), prop_type: "person".to_string(), group_type_index: None, + negation: None, }]), rollout_percentage: Some(50.0), variant: None, @@ -405,6 +406,7 @@ mod tests { operator: Some(OperatorType::Exact), prop_type: "person".to_string(), group_type_index: None, + negation: None, }]), rollout_percentage: Some(100.0), variant: None, diff --git a/rust/feature-flags/src/lib.rs b/rust/feature-flags/src/lib.rs index 051b3e27697f3..2a001c1953914 100644 --- a/rust/feature-flags/src/lib.rs +++ b/rust/feature-flags/src/lib.rs @@ -1,4 +1,5 @@ pub mod api; +pub mod cohort_definitions; pub mod config; pub mod database; pub mod feature_flag_match_reason; diff --git a/rust/feature-flags/src/property_matching.rs b/rust/feature-flags/src/property_matching.rs index 8d12fe6ab5e9d..228d793d52478 100644 --- a/rust/feature-flags/src/property_matching.rs +++ b/rust/feature-flags/src/property_matching.rs @@ -260,6 +260,7 @@ mod test_match_properties { operator: None, prop_type: "person".to_string(), group_type_index: None, + negation: None, }; assert!(match_property( @@ -313,6 +314,7 @@ mod test_match_properties { operator: Some(OperatorType::Exact), prop_type: "person".to_string(), group_type_index: None, + negation: None, }; assert!(match_property( @@ -335,6 +337,7 @@ mod test_match_properties { operator: Some(OperatorType::Exact), prop_type: "person".to_string(), group_type_index: None, + negation: None, }; assert!(match_property( @@ -379,6 +382,7 @@ mod test_match_properties { operator: Some(OperatorType::IsNot), prop_type: "person".to_string(), group_type_index: None, + negation: None, }; assert!(match_property( @@ -416,6 +420,7 @@ mod test_match_properties { operator: Some(OperatorType::IsNot), prop_type: "person".to_string(), group_type_index: None, + negation: None, }; assert!(match_property( @@ -490,6 +495,7 @@ mod test_match_properties { operator: Some(OperatorType::IsSet), prop_type: "person".to_string(), group_type_index: None, + negation: None, }; assert!(match_property( @@ -538,6 +544,7 @@ mod test_match_properties { operator: Some(OperatorType::Icontains), prop_type: "person".to_string(), group_type_index: None, + negation: None, }; assert!(match_property( @@ -595,6 +602,7 @@ mod test_match_properties { operator: Some(OperatorType::Icontains), prop_type: "person".to_string(), group_type_index: None, + negation: None, }; assert!(match_property( @@ -634,6 +642,7 @@ mod test_match_properties { operator: Some(OperatorType::Regex), prop_type: "person".to_string(), group_type_index: None, + negation: None, }; assert!(match_property( @@ -674,6 +683,7 @@ mod test_match_properties { operator: Some(OperatorType::Regex), prop_type: "person".to_string(), group_type_index: None, + negation: None, }; assert!(match_property( &property_b, @@ -708,6 +718,7 @@ mod test_match_properties { operator: Some(OperatorType::Regex), prop_type: "person".to_string(), group_type_index: None, + negation: None, }; assert!(!match_property( @@ -730,6 +741,7 @@ mod test_match_properties { operator: Some(OperatorType::Regex), prop_type: "person".to_string(), group_type_index: None, + negation: None, }; assert!(match_property( &property_d, @@ -760,6 +772,7 @@ mod test_match_properties { operator: Some(OperatorType::Gt), prop_type: "person".to_string(), group_type_index: None, + negation: None, }; assert!(match_property( @@ -802,6 +815,7 @@ mod test_match_properties { operator: Some(OperatorType::Lt), prop_type: "person".to_string(), group_type_index: None, + negation: None, }; assert!(match_property( @@ -848,6 +862,7 @@ mod test_match_properties { operator: Some(OperatorType::Gte), prop_type: "person".to_string(), group_type_index: None, + negation: None, }; assert!(match_property( @@ -889,6 +904,7 @@ mod test_match_properties { operator: Some(OperatorType::Lt), prop_type: "person".to_string(), group_type_index: None, + negation: None, }; assert!(match_property( @@ -935,6 +951,7 @@ mod test_match_properties { operator: Some(OperatorType::Lt), prop_type: "person".to_string(), group_type_index: None, + negation: None, }; assert!(match_property( @@ -1013,6 +1030,7 @@ mod test_match_properties { operator: Some(OperatorType::IsNot), prop_type: "person".to_string(), group_type_index: None, + negation: None, }; assert!(!match_property( @@ -1034,6 +1052,7 @@ mod test_match_properties { operator: Some(OperatorType::IsSet), prop_type: "person".to_string(), group_type_index: None, + negation: None, }; assert!(match_property( @@ -1049,6 +1068,7 @@ mod test_match_properties { operator: Some(OperatorType::Icontains), prop_type: "person".to_string(), group_type_index: None, + negation: None, }; assert!(match_property( @@ -1070,6 +1090,7 @@ mod test_match_properties { operator: Some(OperatorType::Regex), prop_type: "person".to_string(), group_type_index: None, + negation: None, }; assert!(!match_property( @@ -1085,6 +1106,7 @@ mod test_match_properties { operator: Some(OperatorType::Regex), prop_type: "person".to_string(), group_type_index: None, + negation: None, }; assert!(!match_property( @@ -1118,6 +1140,7 @@ mod test_match_properties { operator: None, prop_type: "person".to_string(), group_type_index: None, + negation: None, }; assert!(!match_property( @@ -1137,6 +1160,7 @@ mod test_match_properties { operator: Some(OperatorType::Exact), prop_type: "person".to_string(), group_type_index: None, + negation: None, }; assert!(!match_property( @@ -1152,6 +1176,7 @@ mod test_match_properties { operator: Some(OperatorType::IsSet), prop_type: "person".to_string(), group_type_index: None, + negation: None, }; assert!(!match_property( @@ -1167,6 +1192,7 @@ mod test_match_properties { operator: Some(OperatorType::IsNotSet), prop_type: "person".to_string(), group_type_index: None, + negation: None, }; assert!(match_property( @@ -1203,6 +1229,7 @@ mod test_match_properties { operator: Some(OperatorType::Icontains), prop_type: "person".to_string(), group_type_index: None, + negation: None, }; assert!(!match_property( @@ -1218,6 +1245,7 @@ mod test_match_properties { operator: Some(OperatorType::NotIcontains), prop_type: "person".to_string(), group_type_index: None, + negation: None, }; assert!(!match_property( @@ -1233,6 +1261,7 @@ mod test_match_properties { operator: Some(OperatorType::Regex), prop_type: "person".to_string(), group_type_index: None, + negation: None, }; assert!(!match_property( @@ -1248,6 +1277,7 @@ mod test_match_properties { operator: Some(OperatorType::NotRegex), prop_type: "person".to_string(), group_type_index: None, + negation: None, }; assert!(!match_property( @@ -1263,6 +1293,7 @@ mod test_match_properties { operator: Some(OperatorType::Gt), prop_type: "person".to_string(), group_type_index: None, + negation: None, }; assert!(!match_property( @@ -1278,6 +1309,7 @@ mod test_match_properties { operator: Some(OperatorType::Gte), prop_type: "person".to_string(), group_type_index: None, + negation: None, }; assert!(!match_property( @@ -1293,6 +1325,7 @@ mod test_match_properties { operator: Some(OperatorType::Lt), prop_type: "person".to_string(), group_type_index: None, + negation: None, }; assert!(!match_property( @@ -1308,6 +1341,7 @@ mod test_match_properties { operator: Some(OperatorType::Lte), prop_type: "person".to_string(), group_type_index: None, + negation: None, }; assert!(!match_property( @@ -1324,6 +1358,7 @@ mod test_match_properties { operator: Some(OperatorType::IsDateBefore), prop_type: "person".to_string(), group_type_index: None, + negation: None, }; assert!(!match_property( diff --git a/rust/feature-flags/src/request_handler.rs b/rust/feature-flags/src/request_handler.rs index 066a9c8f8a889..4c327cd32b589 100644 --- a/rust/feature-flags/src/request_handler.rs +++ b/rust/feature-flags/src/request_handler.rs @@ -371,6 +371,7 @@ mod tests { operator: Some(OperatorType::Exact), prop_type: "person".to_string(), group_type_index: None, + negation: None, }]), rollout_percentage: Some(100.0), // Set to 100% to ensure it's always on variant: None, @@ -624,6 +625,7 @@ mod tests { operator: Some(OperatorType::Exact), prop_type: "group".to_string(), group_type_index: Some(0), + negation: None, }]), rollout_percentage: Some(100.0), variant: None, From ed00224ecf4bd6fbe0898c3f28f34f23d30824c9 Mon Sep 17 00:00:00 2001 From: dylan Date: Wed, 23 Oct 2024 15:33:20 -0700 Subject: [PATCH 03/30] oh lol right let's actually ship --- rust/feature-flags/src/api.rs | 6 + rust/feature-flags/src/cohort_definitions.rs | 192 ++++++++++++------- rust/feature-flags/src/flag_definitions.rs | 33 +++- rust/feature-flags/src/flag_matching.rs | 15 +- 4 files changed, 174 insertions(+), 72 deletions(-) diff --git a/rust/feature-flags/src/api.rs b/rust/feature-flags/src/api.rs index 4430476d28a52..b9f12a4f77d41 100644 --- a/rust/feature-flags/src/api.rs +++ b/rust/feature-flags/src/api.rs @@ -102,6 +102,8 @@ pub enum FlagError { TimeoutError, #[error("No group type mappings")] NoGroupTypeMappings, + #[error("Invalid cohort id")] + InvalidCohortId, } impl IntoResponse for FlagError { @@ -194,6 +196,10 @@ impl IntoResponse for FlagError { "The requested row was not found in the database. Please try again later or contact support if the problem persists.".to_string(), ) } + FlagError::InvalidCohortId => { + tracing::error!("Invalid cohort id: {:?}", self); + (StatusCode::BAD_REQUEST, "Invalid cohort id".to_string()) + } } .into_response() } diff --git a/rust/feature-flags/src/cohort_definitions.rs b/rust/feature-flags/src/cohort_definitions.rs index 2d7e7c20b60d1..a362a39447374 100644 --- a/rust/feature-flags/src/cohort_definitions.rs +++ b/rust/feature-flags/src/cohort_definitions.rs @@ -1,13 +1,10 @@ use serde::{Deserialize, Serialize}; use sqlx::FromRow; +use std::collections::{HashMap, HashSet, VecDeque}; use std::sync::Arc; use tracing::instrument; -use crate::{ - api::FlagError, - database::Client as DatabaseClient, - flag_definitions::{OperatorType, PropertyFilter}, -}; +use crate::{api::FlagError, database::Client as DatabaseClient, flag_definitions::PropertyFilter}; #[derive(Debug, FromRow)] struct CohortRow { @@ -25,7 +22,7 @@ struct CohortRow { is_static: bool, } -#[derive(Debug, serde::Serialize, serde::Deserialize)] +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] pub struct Cohort { pub id: i32, pub name: String, @@ -90,47 +87,56 @@ impl Cohort { } /// Parses the filters JSON into a CohortProperty structure - fn parse_filters(&self) -> Result { - serde_json::from_value(self.filters.clone()).map_err(|e| { - tracing::error!("Failed to parse filters: {}", e); - FlagError::Internal(format!("Invalid filters format: {}", e)) - }) + pub fn parse_filters(&self) -> Result, FlagError> { + let cohort_property: CohortProperty = serde_json::from_value(self.filters.clone())?; + Ok(cohort_property.to_property_filters()) } } -use std::collections::{HashMap, HashSet}; - type CohortId = i32; // Assuming CohortOrEmpty is an enum or struct representing a Cohort or an empty value -enum CohortOrEmpty { +pub enum CohortOrEmpty { Cohort(Cohort), Empty, } #[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)] #[serde(rename_all = "UPPERCASE")] -enum CohortPropertyType { +pub enum CohortPropertyType { AND, OR, } // TODO this should serialize to "properties" in the DB #[derive(Debug, Clone, Deserialize, Serialize)] -struct CohortProperty { +pub struct CohortProperty { #[serde(rename = "type")] prop_type: CohortPropertyType, // TODO make this an AND|OR string enum values: Vec, } +impl CohortProperty { + pub fn to_property_filters(&self) -> Vec { + self.values + .iter() + .flat_map(|value| &value.values) + .cloned() + .collect() + } +} + #[derive(Debug, Clone, Deserialize, Serialize)] -struct CohortValues { +pub struct CohortValues { #[serde(rename = "type")] prop_type: String, values: Vec, } -fn sort_cohorts_topologically( +/// Sorts the given cohorts in an order where cohorts with no dependencies are placed first, +/// followed by cohorts that depend on the preceding ones. It ensures that each cohort in the sorted list +/// only depends on cohorts that appear earlier in the list. +pub fn sort_cohorts_topologically( cohort_ids: HashSet, seen_cohorts_cache: &HashMap, ) -> Vec { @@ -153,20 +159,25 @@ fn sort_cohorts_topologically( } seen_cohorts.insert(cohort.id); - // Parse the filters into CohortProperty - let cohort_property = match cohort.parse_filters() { - Ok(property) => property, + // Parse the filters into PropertyFilters + let property_filters = match cohort.parse_filters() { + Ok(filters) => filters, Err(e) => { tracing::error!("Error parsing filters for cohort {}: {}", cohort.id, e); return; } }; - // Iterate through the properties to find dependencies - for value in &cohort_property.values { - if value.prop_type == "cohort" { - if let Some(id) = value.value.as_i64() { - let child_id = id as CohortId; + // Iterate through the property filters to find dependencies + for filter in property_filters { + if filter.prop_type == "cohort" { + let child_id = match filter.value { + serde_json::Value::Number(num) => num.as_i64().map(|n| n as CohortId), + serde_json::Value::String(ref s) => s.parse::().ok(), + _ => None, + }; + + if let Some(child_id) = child_id { dependency_graph .entry(cohort.id) .or_insert_with(Vec::new) @@ -182,50 +193,6 @@ fn sort_cohorts_topologically( seen_cohorts_cache, ); } - } else if let Some(id_str) = value.value.as_str() { - if let Ok(child_id) = id_str.parse::() { - dependency_graph - .entry(cohort.id) - .or_insert_with(Vec::new) - .push(child_id); - - if let Some(CohortOrEmpty::Cohort(child_cohort)) = - seen_cohorts_cache.get(&child_id) - { - traverse( - child_cohort, - dependency_graph, - seen_cohorts, - seen_cohorts_cache, - ); - } - } - } - } - - // Handle nested properties recursively if needed - if let Some(nested_values) = &value.values { - for nested in nested_values { - if nested.prop_type == "cohort" { - if let Some(id) = nested.value.as_i64() { - let child_id = id as CohortId; - dependency_graph - .entry(cohort.id) - .or_insert_with(Vec::new) - .push(child_id); - - if let Some(CohortOrEmpty::Cohort(child_cohort)) = - seen_cohorts_cache.get(&child_id) - { - traverse( - child_cohort, - dependency_graph, - seen_cohorts, - seen_cohorts_cache, - ); - } - } - } } } } @@ -271,3 +238,88 @@ fn sort_cohorts_topologically( sorted_cohort_ids } + +pub async fn get_dependent_cohorts( + cohort: &Cohort, + seen_cohorts_cache: &mut HashMap, + team_id: i32, + db_client: Arc, +) -> Result, FlagError> { + let mut dependent_cohorts = Vec::new(); + let mut seen_cohort_ids = HashSet::new(); + seen_cohort_ids.insert(cohort.id); + + let mut queue = VecDeque::new(); + + let property_filters = match cohort.parse_filters() { + Ok(filters) => filters, + Err(e) => { + tracing::error!("Failed to parse filters for cohort {}: {}", cohort.id, e); + return Err(FlagError::Internal(format!( + "Failed to parse cohort filters: {}", + e + ))); + } + }; + + // Initial queue population + for filter in &property_filters { + if filter.prop_type == "cohort" { + if let Some(id) = filter.value.as_i64().map(|n| n as CohortId).or_else(|| { + filter + .value + .as_str() + .and_then(|s| s.parse::().ok()) + }) { + queue.push_back(id); + } + } + } + + while let Some(cohort_id) = queue.pop_front() { + let current_cohort = match seen_cohorts_cache.get(&cohort_id) { + Some(CohortOrEmpty::Cohort(c)) => c.clone(), + Some(CohortOrEmpty::Empty) => continue, + None => { + // Fetch the cohort from the database + match Cohort::from_pg(db_client.clone(), cohort_id, team_id).await { + Ok(c) => { + seen_cohorts_cache.insert(cohort_id, CohortOrEmpty::Cohort(c.clone())); + c + } + Err(e) => { + tracing::warn!("Failed to fetch cohort {}: {}", cohort_id, e); + seen_cohorts_cache.insert(cohort_id, CohortOrEmpty::Empty); + continue; + } + } + } + }; + + if !seen_cohort_ids.contains(¤t_cohort.id) { + dependent_cohorts.push(current_cohort.clone()); + seen_cohort_ids.insert(current_cohort.id); + + // Parse filters for the current cohort + if let Ok(current_filters) = current_cohort.parse_filters() { + // Add new cohort dependencies to the queue + for filter in current_filters { + if filter.prop_type == "cohort" { + if let Some(id) = + filter.value.as_i64().map(|n| n as CohortId).or_else(|| { + filter + .value + .as_str() + .and_then(|s| s.parse::().ok()) + }) + { + queue.push_back(id); + } + } + } + } + } + } + + Ok(dependent_cohorts) +} diff --git a/rust/feature-flags/src/flag_definitions.rs b/rust/feature-flags/src/flag_definitions.rs index 627f4c3400569..e746cda5b00f5 100644 --- a/rust/feature-flags/src/flag_definitions.rs +++ b/rust/feature-flags/src/flag_definitions.rs @@ -1,4 +1,7 @@ -use crate::{api::FlagError, database::Client as DatabaseClient, redis::Client as RedisClient}; +use crate::{ + api::FlagError, cohort_definitions::Cohort, database::Client as DatabaseClient, + redis::Client as RedisClient, +}; use serde::{Deserialize, Serialize}; use std::sync::Arc; use tracing::instrument; @@ -120,6 +123,34 @@ impl FeatureFlag { .and_then(|obj| obj.get(match_val).cloned()) }) } + + pub async fn transform_cohort_filters( + &self, + postgres_reader: Arc, + team_id: i32, + ) -> Result { + let mut transformed_flag = self.clone(); + for group in &mut transformed_flag.filters.groups { + if let Some(properties) = &mut group.properties { + let mut new_properties = Vec::new(); + for prop in properties.iter() { + if prop.prop_type == "cohort" { + // TODO is there a cleaner way to handle the who cohort values being numbers thing? + let cohort_id = prop.value.as_i64().ok_or(FlagError::InvalidCohortId)?; + let cohort = + Cohort::from_pg(postgres_reader.clone(), cohort_id as i32, team_id) + .await?; + let cohort_properties = cohort.parse_filters()?; + new_properties.extend(cohort_properties); + } else { + new_properties.push(prop.clone()); + } + } + group.properties = Some(new_properties); + } + } + Ok(transformed_flag) + } } #[derive(Clone, Debug, Default, Deserialize, Serialize)] diff --git a/rust/feature-flags/src/flag_matching.rs b/rust/feature-flags/src/flag_matching.rs index 861acb30c6f93..4b18ca9c104c1 100644 --- a/rust/feature-flags/src/flag_matching.rs +++ b/rust/feature-flags/src/flag_matching.rs @@ -633,6 +633,19 @@ impl FeatureFlagMatcher { } } + let has_cohort_filter = flag.filters.groups.iter().any(|group| { + group.properties.as_ref().map_or(false, |props| { + props.iter().any(|prop| prop.prop_type == "cohort") + }) + }); + + let flag_to_evaluate = if has_cohort_filter { + flag.transform_cohort_filters(self.postgres_reader.clone(), self.team_id) + .await? + } else { + flag.clone() + }; + // Sort conditions with variant overrides to the top so that we can evaluate them first let mut sorted_conditions: Vec<(usize, &FlagGroupType)> = flag.get_conditions().iter().enumerate().collect(); @@ -643,7 +656,7 @@ impl FeatureFlagMatcher { for (index, condition) in sorted_conditions { let (is_match, reason) = self .is_condition_match( - flag, + &flag_to_evaluate, condition, property_overrides.clone(), hash_key_overrides.clone(), From fb8aab877cd056815f70a04a1009c17b97380d46 Mon Sep 17 00:00:00 2001 From: dylan Date: Wed, 23 Oct 2024 16:20:19 -0700 Subject: [PATCH 04/30] or default --- rust/feature-flags/src/cohort_definitions.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rust/feature-flags/src/cohort_definitions.rs b/rust/feature-flags/src/cohort_definitions.rs index a362a39447374..dc24bfcc07d46 100644 --- a/rust/feature-flags/src/cohort_definitions.rs +++ b/rust/feature-flags/src/cohort_definitions.rs @@ -180,7 +180,7 @@ pub fn sort_cohorts_topologically( if let Some(child_id) = child_id { dependency_graph .entry(cohort.id) - .or_insert_with(Vec::new) + .or_default() .push(child_id); if let Some(CohortOrEmpty::Cohort(child_cohort)) = From eeea8cc88e173a3b1c9b4f53cbd2ce14d69bb3ae Mon Sep 17 00:00:00 2001 From: dylan Date: Thu, 24 Oct 2024 15:31:43 -0700 Subject: [PATCH 05/30] let's goooo --- rust/feature-flags/src/cohort_definitions.rs | 295 ++++++--- rust/feature-flags/src/flag_definitions.rs | 35 +- rust/feature-flags/src/flag_matching.rs | 600 ++++++++++++++++++- rust/feature-flags/src/property_matching.rs | 9 + rust/feature-flags/src/test_utils.rs | 55 ++ 5 files changed, 840 insertions(+), 154 deletions(-) diff --git a/rust/feature-flags/src/cohort_definitions.rs b/rust/feature-flags/src/cohort_definitions.rs index dc24bfcc07d46..3fc4ded9c4dcf 100644 --- a/rust/feature-flags/src/cohort_definitions.rs +++ b/rust/feature-flags/src/cohort_definitions.rs @@ -1,25 +1,28 @@ use serde::{Deserialize, Serialize}; use sqlx::FromRow; -use std::collections::{HashMap, HashSet, VecDeque}; +use std::collections::{HashMap, HashSet}; use std::sync::Arc; use tracing::instrument; use crate::{api::FlagError, database::Client as DatabaseClient, flag_definitions::PropertyFilter}; #[derive(Debug, FromRow)] -struct CohortRow { - id: i32, - name: String, - description: Option, - team_id: i32, - deleted: bool, - filters: serde_json::Value, - query: Option, - version: Option, - pending_version: Option, - count: Option, - is_calculating: bool, - is_static: bool, +pub struct CohortRow { + pub id: i32, + pub name: String, + pub description: Option, + pub team_id: i32, + pub deleted: bool, + pub filters: serde_json::Value, + pub query: Option, + pub version: Option, + pub pending_version: Option, + pub count: Option, + pub is_calculating: bool, + pub is_static: bool, + pub errors_calculating: Option, // I think this has a null constraint, so maybe it shouldn't be optional + pub groups: serde_json::Value, + pub created_by_id: Option, } #[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] @@ -36,6 +39,9 @@ pub struct Cohort { pub count: Option, pub is_calculating: bool, pub is_static: bool, + pub errors_calculating: Option, + pub groups: serde_json::Value, + pub created_by_id: Option, } impl Cohort { @@ -53,7 +59,7 @@ impl Cohort { FlagError::DatabaseUnavailable })?; - let query = "SELECT id, name, description, team_id, deleted, filters, query, version, pending_version, count, is_calculating, is_static FROM posthog_cohort WHERE id = $1 AND team_id = $2"; + let query = "SELECT id, name, description, team_id, deleted, filters, query, version, pending_version, count, is_calculating, is_static, errors_calculating, groups, created_by_id FROM posthog_cohort WHERE id = $1 AND team_id = $2"; let cohort_row = sqlx::query_as::<_, CohortRow>(query) .bind(cohort_id) .bind(team_id) @@ -78,6 +84,9 @@ impl Cohort { count: row.count, is_calculating: row.is_calculating, is_static: row.is_static, + errors_calculating: row.errors_calculating, + groups: row.groups, + created_by_id: row.created_by_id, }), None => Err(FlagError::DatabaseError(format!( "Cohort with id {} not found for team {}", @@ -93,7 +102,7 @@ impl Cohort { } } -type CohortId = i32; +pub type CohortId = i32; // Assuming CohortOrEmpty is an enum or struct representing a Cohort or an empty value pub enum CohortOrEmpty { @@ -239,87 +248,193 @@ pub fn sort_cohorts_topologically( sorted_cohort_ids } -pub async fn get_dependent_cohorts( - cohort: &Cohort, - seen_cohorts_cache: &mut HashMap, - team_id: i32, - db_client: Arc, -) -> Result, FlagError> { - let mut dependent_cohorts = Vec::new(); - let mut seen_cohort_ids = HashSet::new(); - seen_cohort_ids.insert(cohort.id); - - let mut queue = VecDeque::new(); - - let property_filters = match cohort.parse_filters() { - Ok(filters) => filters, - Err(e) => { - tracing::error!("Failed to parse filters for cohort {}: {}", cohort.id, e); - return Err(FlagError::Internal(format!( - "Failed to parse cohort filters: {}", - e - ))); - } +#[cfg(test)] +mod tests { + use super::*; + use crate::test_utils::{ + insert_cohort_for_team_in_pg, insert_new_team_in_pg, setup_pg_reader_client, + setup_pg_writer_client, }; + use serde_json::json; - // Initial queue population - for filter in &property_filters { - if filter.prop_type == "cohort" { - if let Some(id) = filter.value.as_i64().map(|n| n as CohortId).or_else(|| { - filter - .value - .as_str() - .and_then(|s| s.parse::().ok()) - }) { - queue.push_back(id); - } - } - } + #[tokio::test] + async fn test_cohort_from_pg() { + let postgres_reader = setup_pg_reader_client(None).await; + let postgres_writer = setup_pg_writer_client(None).await; - while let Some(cohort_id) = queue.pop_front() { - let current_cohort = match seen_cohorts_cache.get(&cohort_id) { - Some(CohortOrEmpty::Cohort(c)) => c.clone(), - Some(CohortOrEmpty::Empty) => continue, - None => { - // Fetch the cohort from the database - match Cohort::from_pg(db_client.clone(), cohort_id, team_id).await { - Ok(c) => { - seen_cohorts_cache.insert(cohort_id, CohortOrEmpty::Cohort(c.clone())); - c + let team = insert_new_team_in_pg(postgres_reader.clone(), None) + .await + .expect("Failed to insert team"); + + let cohort = insert_cohort_for_team_in_pg( + postgres_writer.clone(), + team.id, + None, + json!({ + "type": "AND", + "values": [ + { + "type": "property", + "values": [ + { + "key": "email", + "value": "test@example.com", + "type": "person" + } + ] } - Err(e) => { - tracing::warn!("Failed to fetch cohort {}: {}", cohort_id, e); - seen_cohorts_cache.insert(cohort_id, CohortOrEmpty::Empty); - continue; + ] + }), + false, + ) + .await + .expect("Failed to insert cohort"); + + let fetched_cohort = Cohort::from_pg(postgres_reader, cohort.id, team.id) + .await + .expect("Failed to fetch cohort"); + + assert_eq!(fetched_cohort.id, cohort.id); + assert_eq!(fetched_cohort.name, "Test Cohort"); + assert_eq!(fetched_cohort.team_id, team.id); + } + + #[test] + fn test_cohort_parse_filters() { + let cohort = Cohort { + id: 1, + name: "Test Cohort".to_string(), + description: None, + team_id: 1, + deleted: false, + filters: json!({ + "type": "AND", + "values": [ + { + "type": "property", + "values": [ + { + "key": "email", + "value": "test@example.com", + "type": "person" + } + ] } - } - } + ] + }), + query: None, + version: None, + pending_version: None, + count: None, + is_calculating: false, + is_static: false, + errors_calculating: None, + groups: json!({}), + created_by_id: None, }; - if !seen_cohort_ids.contains(¤t_cohort.id) { - dependent_cohorts.push(current_cohort.clone()); - seen_cohort_ids.insert(current_cohort.id); - - // Parse filters for the current cohort - if let Ok(current_filters) = current_cohort.parse_filters() { - // Add new cohort dependencies to the queue - for filter in current_filters { - if filter.prop_type == "cohort" { - if let Some(id) = - filter.value.as_i64().map(|n| n as CohortId).or_else(|| { - filter - .value - .as_str() - .and_then(|s| s.parse::().ok()) - }) - { - queue.push_back(id); - } - } - } - } - } + let result = cohort.parse_filters().unwrap(); + assert_eq!(result.len(), 1); + assert_eq!(result[0].key, "email"); + assert_eq!(result[0].value, json!("test@example.com")); + assert_eq!(result[0].prop_type, "person"); } - Ok(dependent_cohorts) + #[test] + fn test_sort_cohorts_topologically() { + let mut cohorts = HashMap::new(); + cohorts.insert( + 1, + CohortOrEmpty::Cohort(Cohort { + id: 1, + name: "Cohort 1".to_string(), + description: None, + team_id: 1, + deleted: false, + filters: json!({"type": "AND", "values": []}), + query: None, + version: None, + pending_version: None, + count: None, + is_calculating: false, + is_static: false, + errors_calculating: None, + groups: json!({}), + created_by_id: None, + }), + ); + cohorts.insert(2, CohortOrEmpty::Cohort(Cohort { + id: 2, + name: "Cohort 2".to_string(), + description: None, + team_id: 1, + deleted: false, + filters: json!({"type": "AND", "values": [{"type": "property", "values": [{"key": "cohort", "value": 1, "type": "cohort"}]}]}), + query: None, + version: None, + pending_version: None, + count: None, + is_calculating: false, + is_static: false, + errors_calculating: None, + groups: json!({}), + created_by_id: None, + })); + cohorts.insert(3, CohortOrEmpty::Cohort(Cohort { + id: 3, + name: "Cohort 3".to_string(), + description: None, + team_id: 1, + deleted: false, + filters: json!({"type": "AND", "values": [{"type": "property", "values": [{"key": "cohort", "value": 2, "type": "cohort"}]}]}), + query: None, + version: None, + pending_version: None, + count: None, + is_calculating: false, + is_static: false, + errors_calculating: None, + groups: json!({}), + created_by_id: None, + })); + + let cohort_ids: HashSet = vec![1, 2, 3].into_iter().collect(); + let result = sort_cohorts_topologically(cohort_ids, &cohorts); + assert_eq!(result, vec![1, 2, 3]); + } + + #[test] + fn test_cohort_property_to_property_filters() { + let cohort_property = CohortProperty { + prop_type: CohortPropertyType::AND, + values: vec![CohortValues { + prop_type: "property".to_string(), + values: vec![ + PropertyFilter { + key: "email".to_string(), + value: json!("test@example.com"), + operator: None, + prop_type: "person".to_string(), + group_type_index: None, + negation: None, + }, + PropertyFilter { + key: "age".to_string(), + value: json!(25), + operator: None, + prop_type: "person".to_string(), + group_type_index: None, + negation: None, + }, + ], + }], + }; + + let result = cohort_property.to_property_filters(); + assert_eq!(result.len(), 2); + assert_eq!(result[0].key, "email"); + assert_eq!(result[0].value, json!("test@example.com")); + assert_eq!(result[1].key, "age"); + assert_eq!(result[1].value, json!(25)); + } } diff --git a/rust/feature-flags/src/flag_definitions.rs b/rust/feature-flags/src/flag_definitions.rs index e746cda5b00f5..f081fbd98f21b 100644 --- a/rust/feature-flags/src/flag_definitions.rs +++ b/rust/feature-flags/src/flag_definitions.rs @@ -1,7 +1,4 @@ -use crate::{ - api::FlagError, cohort_definitions::Cohort, database::Client as DatabaseClient, - redis::Client as RedisClient, -}; +use crate::{api::FlagError, database::Client as DatabaseClient, redis::Client as RedisClient}; use serde::{Deserialize, Serialize}; use std::sync::Arc; use tracing::instrument; @@ -28,6 +25,8 @@ pub enum OperatorType { IsDateExact, IsDateAfter, IsDateBefore, + In, + NotIn, } #[derive(Debug, Clone, Deserialize, Serialize)] @@ -123,34 +122,6 @@ impl FeatureFlag { .and_then(|obj| obj.get(match_val).cloned()) }) } - - pub async fn transform_cohort_filters( - &self, - postgres_reader: Arc, - team_id: i32, - ) -> Result { - let mut transformed_flag = self.clone(); - for group in &mut transformed_flag.filters.groups { - if let Some(properties) = &mut group.properties { - let mut new_properties = Vec::new(); - for prop in properties.iter() { - if prop.prop_type == "cohort" { - // TODO is there a cleaner way to handle the who cohort values being numbers thing? - let cohort_id = prop.value.as_i64().ok_or(FlagError::InvalidCohortId)?; - let cohort = - Cohort::from_pg(postgres_reader.clone(), cohort_id as i32, team_id) - .await?; - let cohort_properties = cohort.parse_filters()?; - new_properties.extend(cohort_properties); - } else { - new_properties.push(prop.clone()); - } - } - group.properties = Some(new_properties); - } - } - Ok(transformed_flag) - } } #[derive(Clone, Debug, Default, Deserialize, Serialize)] diff --git a/rust/feature-flags/src/flag_matching.rs b/rust/feature-flags/src/flag_matching.rs index c6ce6eb29dfd4..f7097e696da79 100644 --- a/rust/feature-flags/src/flag_matching.rs +++ b/rust/feature-flags/src/flag_matching.rs @@ -1,8 +1,9 @@ use crate::{ api::{FlagError, FlagValue, FlagsResponse}, + cohort_definitions::{sort_cohorts_topologically, Cohort, CohortId, CohortOrEmpty}, database::Client as DatabaseClient, feature_flag_match_reason::FeatureFlagMatchReason, - flag_definitions::{FeatureFlag, FeatureFlagList, FlagGroupType, PropertyFilter}, + flag_definitions::{FeatureFlag, FeatureFlagList, FlagGroupType, OperatorType, PropertyFilter}, metrics_consts::{FLAG_EVALUATION_ERROR_COUNTER, FLAG_HASH_KEY_WRITES_COUNTER}, property_matching::match_property, utils::parse_exception_for_prometheus_label, @@ -633,19 +634,6 @@ impl FeatureFlagMatcher { } } - let has_cohort_filter = flag.filters.groups.iter().any(|group| { - group.properties.as_ref().map_or(false, |props| { - props.iter().any(|prop| prop.prop_type == "cohort") - }) - }); - - let flag_to_evaluate = if has_cohort_filter { - flag.transform_cohort_filters(self.postgres_reader.clone(), self.team_id) - .await? - } else { - flag.clone() - }; - // Sort conditions with variant overrides to the top so that we can evaluate them first let mut sorted_conditions: Vec<(usize, &FlagGroupType)> = flag.get_conditions().iter().enumerate().collect(); @@ -656,7 +644,7 @@ impl FeatureFlagMatcher { for (index, condition) in sorted_conditions { let (is_match, reason) = self .is_condition_match( - &flag_to_evaluate, + flag, condition, property_overrides.clone(), hash_key_overrides.clone(), @@ -745,12 +733,30 @@ impl FeatureFlagMatcher { .await; } - // NB: we can only evaluate group or person properties, not both + // Separate cohort and non-cohort filters + let (cohort_filters, non_cohort_filters): (Vec, Vec) = + flag_property_filters + .iter() + .cloned() + .partition(|prop| prop.prop_type == "cohort"); + + // Evaluate non-cohort properties first to get properties_to_check let properties_to_check = self - .get_properties_to_check(feature_flag, property_overrides, flag_property_filters) + .get_properties_to_check(feature_flag, property_overrides, &non_cohort_filters) .await?; - if !all_properties_match(flag_property_filters, &properties_to_check) { + // Evaluate cohort conditions + if !cohort_filters.is_empty() { + if !self + .evaluate_cohort_filters(&cohort_filters, &properties_to_check) + .await? + { + return Ok((false, FeatureFlagMatchReason::NoConditionMatch)); + } + } + + // Evaluate non-cohort properties + if !all_properties_match(&non_cohort_filters, &properties_to_check) { return Ok((false, FeatureFlagMatchReason::NoConditionMatch)); } } @@ -818,6 +824,85 @@ impl FeatureFlagMatcher { } } + pub async fn evaluate_cohort_filters( + &self, + filters: &[PropertyFilter], + target_properties: &HashMap, + ) -> Result { + Box::pin(self.evaluate_potentially_nested_cohort_filters(filters, target_properties)).await + } + + async fn evaluate_potentially_nested_cohort_filters( + &self, + filters: &[PropertyFilter], + target_properties: &HashMap, + ) -> Result { + let mut cohort_filters = Vec::new(); + let mut non_cohort_filters = Vec::new(); + + // Separate cohort filters from non-cohort filters + for filter in filters { + if filter.prop_type == "cohort" { + cohort_filters.push(filter); + } else { + non_cohort_filters.push(filter); + } + } + + // Evaluate non-cohort filters + for filter in &non_cohort_filters { + if !match_property(filter, target_properties, false).unwrap_or(false) { + return Ok(false); + } + } + + // Evaluate cohort filters + if !cohort_filters.is_empty() { + let cohort_ids: HashSet = cohort_filters + .iter() + .filter_map(|f| f.value.as_i64().map(|id| id as CohortId)) + .collect(); + + let cohorts = self.fetch_cohorts(cohort_ids.clone()).await?; + let seen_cohorts_cache: HashMap = cohorts + .into_iter() + .map(|cohort| (cohort.id, CohortOrEmpty::Cohort(cohort))) + .collect(); + + let sorted_cohort_ids = sort_cohorts_topologically(cohort_ids, &seen_cohorts_cache); + + for cohort_id in sorted_cohort_ids { + if let Some(CohortOrEmpty::Cohort(cohort)) = seen_cohorts_cache.get(&cohort_id) { + let cohort_property_filters = cohort.parse_filters()?; // TODO error handle + let cohort_match = self + .evaluate_cohort_filters(&cohort_property_filters, target_properties) + .await?; + + let filter = cohort_filters + .iter() + .find(|f| f.value.as_i64() == Some(cohort_id as i64)) + .unwrap(); + match filter.operator { + Some(OperatorType::In) if !cohort_match => return Ok(false), + Some(OperatorType::NotIn) if cohort_match => return Ok(false), + _ => {} + } + } + } + } + + Ok(true) + } + + async fn fetch_cohorts(&self, cohort_ids: HashSet) -> Result, FlagError> { + let mut cohorts = Vec::new(); + for &id in &cohort_ids { + let cohort = Cohort::from_pg(self.postgres_reader.clone(), id, self.team_id).await?; + cohorts.push(cohort); + } + Ok(cohorts) + } + /// Check if a super condition matches for a feature flag. /// /// This function evaluates the super conditions of a feature flag to determine if any of them should be enabled. @@ -1456,8 +1541,8 @@ mod tests { OperatorType, }, test_utils::{ - insert_flag_for_team_in_pg, insert_new_team_in_pg, insert_person_for_team_in_pg, - setup_pg_reader_client, setup_pg_writer_client, + insert_cohort_for_team_in_pg, insert_flag_for_team_in_pg, insert_new_team_in_pg, + insert_person_for_team_in_pg, setup_pg_reader_client, setup_pg_writer_client, }, }; @@ -3155,6 +3240,462 @@ mod tests { assert_eq!(result_another_id.condition_index, Some(2)); } + #[tokio::test] + async fn test_basic_cohort_matching() { + let postgres_reader = setup_pg_reader_client(None).await; + let postgres_writer = setup_pg_writer_client(None).await; + let team = insert_new_team_in_pg(postgres_reader.clone(), None) + .await + .unwrap(); + + // Insert a cohort with the condition that matches the test user's properties + let cohort_row = insert_cohort_for_team_in_pg( + postgres_reader.clone(), + team.id, + None, + json!({ + "type": "OR", + "values": [{ + "type": "OR", + "values": [{ + "key": "$browser_version", + "type": "person", + "value": "125", + "negation": false, + "operator": "gt" + }] + }] + }), + false, + ) + .await + .unwrap(); + + // Insert a person with properties that match the cohort condition + insert_person_for_team_in_pg( + postgres_reader.clone(), + team.id, + "test_user".to_string(), + Some(json!({"$browser_version": 126})), + ) + .await + .unwrap(); + + // Define a flag with a cohort filter + let flag = create_test_flag( + None, + Some(team.id), + None, + None, + Some(FlagFilters { + groups: vec![FlagGroupType { + properties: Some(vec![PropertyFilter { + key: "id".to_string(), + value: json!(cohort_row.id), + operator: Some(OperatorType::In), + prop_type: "cohort".to_string(), + group_type_index: None, + negation: Some(false), + }]), + rollout_percentage: Some(100.0), + variant: None, + }], + multivariate: None, + aggregation_group_type_index: None, + payloads: None, + super_groups: None, + }), + None, + None, + None, + ); + + let mut matcher = FeatureFlagMatcher::new( + "test_user".to_string(), + team.id, + postgres_reader.clone(), + postgres_writer.clone(), + None, + None, + None, + ); + + let result = matcher.get_match(&flag, None, None).await.unwrap(); + println!("{:?}", result); + + assert!(result.matches); + } + + #[tokio::test] + async fn test_not_in_cohort_matching() { + let postgres_reader = setup_pg_reader_client(None).await; + let postgres_writer = setup_pg_writer_client(None).await; + let team = insert_new_team_in_pg(postgres_reader.clone(), None) + .await + .unwrap(); + + // Insert a cohort with a condition that does not match the test user's properties + let cohort_row = insert_cohort_for_team_in_pg( + postgres_reader.clone(), + team.id, + None, + json!({ + "type": "OR", + "values": [{ + "type": "OR", + "values": [{ + "key": "$browser_version", + "type": "person", + "value": "130", + "negation": false, + "operator": "gt" + }] + }] + }), + false, + ) + .await + .unwrap(); + + // Insert a person with properties that do not match the cohort condition + insert_person_for_team_in_pg( + postgres_reader.clone(), + team.id, + "test_user".to_string(), + Some(json!({"$browser_version": 126})), + ) + .await + .unwrap(); + + // Define a flag with a NotIn cohort filter + let flag = create_test_flag( + None, + Some(team.id), + None, + None, + Some(FlagFilters { + groups: vec![FlagGroupType { + properties: Some(vec![PropertyFilter { + key: "id".to_string(), + value: json!(cohort_row.id), + operator: Some(OperatorType::NotIn), + prop_type: "cohort".to_string(), + group_type_index: None, + negation: Some(false), + }]), + rollout_percentage: Some(100.0), + variant: None, + }], + multivariate: None, + aggregation_group_type_index: None, + payloads: None, + super_groups: None, + }), + None, + None, + None, + ); + + let mut matcher = FeatureFlagMatcher::new( + "test_user".to_string(), + team.id, + postgres_reader.clone(), + postgres_writer.clone(), + None, + None, + None, + ); + + let result = matcher.get_match(&flag, None, None).await.unwrap(); + println!("{:?}", result); + + assert!(result.matches); + } + + #[tokio::test] + async fn test_not_in_cohort_matching_user_in_cohort() { + let postgres_reader = setup_pg_reader_client(None).await; + let postgres_writer = setup_pg_writer_client(None).await; + let team = insert_new_team_in_pg(postgres_reader.clone(), None) + .await + .unwrap(); + + // Insert a cohort with a condition that matches the test user's properties + let cohort_row = insert_cohort_for_team_in_pg( + postgres_reader.clone(), + team.id, + None, + json!({ + "type": "OR", + "values": [{ + "type": "OR", + "values": [{ + "key": "$browser_version", + "type": "person", + "value": "125", + "negation": false, + "operator": "gt" + }] + }] + }), + false, + ) + .await + .unwrap(); + + // Insert a person with properties that match the cohort condition + insert_person_for_team_in_pg( + postgres_reader.clone(), + team.id, + "test_user".to_string(), + Some(json!({"$browser_version": 126})), + ) + .await + .unwrap(); + + // Define a flag with a NotIn cohort filter + let flag = create_test_flag( + None, + Some(team.id), + None, + None, + Some(FlagFilters { + groups: vec![FlagGroupType { + properties: Some(vec![PropertyFilter { + key: "id".to_string(), + value: json!(cohort_row.id), + operator: Some(OperatorType::NotIn), + prop_type: "cohort".to_string(), + group_type_index: None, + negation: Some(false), + }]), + rollout_percentage: Some(100.0), + variant: None, + }], + multivariate: None, + aggregation_group_type_index: None, + payloads: None, + super_groups: None, + }), + None, + None, + None, + ); + + let mut matcher = FeatureFlagMatcher::new( + "test_user".to_string(), + team.id, + postgres_reader.clone(), + postgres_writer.clone(), + None, + None, + None, + ); + + let result = matcher.get_match(&flag, None, None).await.unwrap(); + println!("{:?}", result); + + // The user matches the cohort, but the flag is set to NotIn, so it should evaluate to false + assert!(!result.matches); + } + + #[tokio::test] + async fn test_cohort_dependent_on_another_cohort() { + let postgres_reader = setup_pg_reader_client(None).await; + let postgres_writer = setup_pg_writer_client(None).await; + let team = insert_new_team_in_pg(postgres_reader.clone(), None) + .await + .unwrap(); + + // Insert a base cohort + let base_cohort_row = insert_cohort_for_team_in_pg( + postgres_reader.clone(), + team.id, + None, + json!({ + "type": "OR", + "values": [{ + "type": "OR", + "values": [{ + "key": "$browser_version", + "type": "person", + "value": "125", + "negation": false, + "operator": "gt" + }] + }] + }), + false, + ) + .await + .unwrap(); + + // Insert a dependent cohort that includes the base cohort + let dependent_cohort_row = insert_cohort_for_team_in_pg( + postgres_reader.clone(), + team.id, + None, + json!({ + "type": "OR", + "values": [{ + "type": "OR", + "values": [{ + "key": "id", + "type": "cohort", + "value": base_cohort_row.id, + "negation": false, + "operator": "in" + }] + }] + }), + false, + ) + .await + .unwrap(); + + // Insert a person with properties that match the base cohort condition + insert_person_for_team_in_pg( + postgres_reader.clone(), + team.id, + "test_user".to_string(), + Some(json!({"$browser_version": 126})), + ) + .await + .unwrap(); + + // Define a flag with a cohort filter that depends on another cohort + let flag = create_test_flag( + None, + Some(team.id), + None, + None, + Some(FlagFilters { + groups: vec![FlagGroupType { + properties: Some(vec![PropertyFilter { + key: "id".to_string(), + value: json!(dependent_cohort_row.id), + operator: Some(OperatorType::In), + prop_type: "cohort".to_string(), + group_type_index: None, + negation: Some(false), + }]), + rollout_percentage: Some(100.0), + variant: None, + }], + multivariate: None, + aggregation_group_type_index: None, + payloads: None, + super_groups: None, + }), + None, + None, + None, + ); + + let mut matcher = FeatureFlagMatcher::new( + "test_user".to_string(), + team.id, + postgres_reader.clone(), + postgres_writer.clone(), + None, + None, + None, + ); + + let result = matcher.get_match(&flag, None, None).await.unwrap(); + println!("{:?}", result); + + // This test might fail if the system doesn't support cohort dependencies + assert!(result.matches); + } + + #[tokio::test] + async fn test_in_cohort_matching_user_not_in_cohort() { + let postgres_reader = setup_pg_reader_client(None).await; + let postgres_writer = setup_pg_writer_client(None).await; + let team = insert_new_team_in_pg(postgres_reader.clone(), None) + .await + .unwrap(); + + // Insert a cohort with a condition that does not match the test user's properties + let cohort_row = insert_cohort_for_team_in_pg( + postgres_reader.clone(), + team.id, + None, + json!({ + "type": "OR", + "values": [{ + "type": "OR", + "values": [{ + "key": "$browser_version", + "type": "person", + "value": "130", + "negation": false, + "operator": "gt" + }] + }] + }), + false, + ) + .await + .unwrap(); + + // Insert a person with properties that do not match the cohort condition + insert_person_for_team_in_pg( + postgres_reader.clone(), + team.id, + "test_user".to_string(), + Some(json!({"$browser_version": 125})), + ) + .await + .unwrap(); + + // Define a flag with an In cohort filter + let flag = create_test_flag( + None, + Some(team.id), + None, + None, + Some(FlagFilters { + groups: vec![FlagGroupType { + properties: Some(vec![PropertyFilter { + key: "id".to_string(), + value: json!(cohort_row.id), + operator: Some(OperatorType::In), + prop_type: "cohort".to_string(), + group_type_index: None, + negation: Some(false), + }]), + rollout_percentage: Some(100.0), + variant: None, + }], + multivariate: None, + aggregation_group_type_index: None, + payloads: None, + super_groups: None, + }), + None, + None, + None, + ); + + let mut matcher = FeatureFlagMatcher::new( + "test_user".to_string(), + team.id, + postgres_reader.clone(), + postgres_writer.clone(), + None, + None, + None, + ); + + let result = matcher.get_match(&flag, None, None).await.unwrap(); + println!("{:?}", result); + + // The user does not match the cohort, and the flag is set to In, so it should evaluate to false + assert!(!result.matches); + } + #[tokio::test] async fn test_set_feature_flag_hash_key_overrides_success() { let postgres_reader = setup_pg_reader_client(None).await; @@ -3162,7 +3703,7 @@ mod tests { let team = insert_new_team_in_pg(postgres_reader.clone(), None) .await .unwrap(); - let distinct_id = "user1".to_string(); + let distinct_id = "user2".to_string(); // Insert person insert_person_for_team_in_pg(postgres_reader.clone(), team.id, distinct_id.clone(), None) @@ -3187,7 +3728,7 @@ mod tests { Some(true), // ensure_experience_continuity ); - // need to convert flag to FeatureFlagRow + // Convert flag to FeatureFlagRow let flag_row = FeatureFlagRow { id: flag.id, team_id: flag.team_id, @@ -3204,8 +3745,8 @@ mod tests { .await .unwrap(); - // Attempt to set hash key override - let result = set_feature_flag_hash_key_overrides( + // Set hash key override + set_feature_flag_hash_key_overrides( postgres_writer.clone(), team.id, vec![distinct_id.clone()], @@ -3214,9 +3755,7 @@ mod tests { .await .unwrap(); - assert!(result, "Hash key override should be set successfully"); - - // Retrieve the hash key overrides + // Retrieve hash key overrides let overrides = get_feature_flag_hash_key_overrides( postgres_reader.clone(), team.id, @@ -3225,14 +3764,10 @@ mod tests { .await .unwrap(); - assert!( - !overrides.is_empty(), - "At least one hash key override should be set" - ); assert_eq!( overrides.get("test_flag"), Some(&"hash_key_2".to_string()), - "Hash key override for 'test_flag' should match the set value" + "Hash key override should match the set value" ); } @@ -3310,6 +3845,7 @@ mod tests { "Hash key override should match the set value" ); } + #[tokio::test] async fn test_evaluate_feature_flags_with_experience_continuity() { let postgres_reader = setup_pg_reader_client(None).await; diff --git a/rust/feature-flags/src/property_matching.rs b/rust/feature-flags/src/property_matching.rs index 228d793d52478..2d13276135223 100644 --- a/rust/feature-flags/src/property_matching.rs +++ b/rust/feature-flags/src/property_matching.rs @@ -43,6 +43,9 @@ pub fn match_property( ))); } + println!("property: {:?}", property); + println!("matching_property_values: {:?}", matching_property_values); + let key = &property.key; let operator = property.operator.clone().unwrap_or(OperatorType::Exact); let value = &property.value; @@ -193,6 +196,12 @@ pub fn match_property( // Ok(false) // } } + OperatorType::In | OperatorType::NotIn => { + // TODO: we handle these in cohort matching, so we can just return false here + // because by the time we match properties, we've already decomposed the cohort + // filter into multiple property filters + Ok(false) + } } } diff --git a/rust/feature-flags/src/test_utils.rs b/rust/feature-flags/src/test_utils.rs index a36cae6016eca..c6084d89a20e9 100644 --- a/rust/feature-flags/src/test_utils.rs +++ b/rust/feature-flags/src/test_utils.rs @@ -6,6 +6,7 @@ use std::sync::Arc; use uuid::Uuid; use crate::{ + cohort_definitions::CohortRow, config::{Config, DEFAULT_TEST_CONFIG}, database::{get_pool, Client, CustomDatabaseError}, flag_definitions::{self, FeatureFlag, FeatureFlagRow}, @@ -356,3 +357,57 @@ pub async fn insert_person_for_team_in_pg( Ok(()) } + +pub async fn insert_cohort_for_team_in_pg( + client: Arc, + team_id: i32, + name: Option, + filters: serde_json::Value, + is_static: bool, +) -> Result { + let cohort_row = CohortRow { + id: 0, // Placeholder, will be updated after insertion + name: name.unwrap_or("Test Cohort".to_string()), + description: Some("Description for cohort".to_string()), + team_id, + deleted: false, + filters, + query: None, + version: Some(1), + pending_version: None, + count: None, + is_calculating: false, + is_static, + errors_calculating: Some(0), + groups: serde_json::json!([]), + created_by_id: None, + }; + + let mut conn = client.get_connection().await?; + let row: (i32,) = sqlx::query_as( + r#"INSERT INTO posthog_cohort + (name, description, team_id, deleted, filters, query, version, pending_version, count, is_calculating, is_static, errors_calculating, groups, created_by_id) VALUES + ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14) + RETURNING id"#, + ) + .bind(&cohort_row.name) + .bind(&cohort_row.description) + .bind(cohort_row.team_id) + .bind(cohort_row.deleted) + .bind(&cohort_row.filters) + .bind(&cohort_row.query) + .bind(cohort_row.version) + .bind(cohort_row.pending_version) + .bind(cohort_row.count) + .bind(cohort_row.is_calculating) + .bind(cohort_row.is_static) + .bind(cohort_row.errors_calculating.unwrap_or(0)) + .bind(&cohort_row.groups) + .bind(cohort_row.created_by_id) + .fetch_one(&mut *conn) + .await?; + + let id = row.0; + + Ok(CohortRow { id, ..cohort_row }) +} From db8cd8d8cc933d1128402bec59b0ab44d6d1dc24 Mon Sep 17 00:00:00 2001 From: dylan Date: Thu, 24 Oct 2024 15:59:56 -0700 Subject: [PATCH 06/30] modeled the data correctly this time :sweat: --- rust/feature-flags/src/cohort_definitions.rs | 58 ++++------ rust/feature-flags/src/flag_matching.rs | 108 ++++++++++--------- 2 files changed, 78 insertions(+), 88 deletions(-) diff --git a/rust/feature-flags/src/cohort_definitions.rs b/rust/feature-flags/src/cohort_definitions.rs index 3fc4ded9c4dcf..a655ff5f07259 100644 --- a/rust/feature-flags/src/cohort_definitions.rs +++ b/rust/feature-flags/src/cohort_definitions.rs @@ -97,7 +97,9 @@ impl Cohort { /// Parses the filters JSON into a CohortProperty structure pub fn parse_filters(&self) -> Result, FlagError> { - let cohort_property: CohortProperty = serde_json::from_value(self.filters.clone())?; + let wrapper: serde_json::Value = serde_json::from_value(self.filters.clone())?; + let cohort_property: InnerCohortProperty = + serde_json::from_value(wrapper["properties"].clone())?; Ok(cohort_property.to_property_filters()) } } @@ -117,15 +119,19 @@ pub enum CohortPropertyType { OR, } -// TODO this should serialize to "properties" in the DB #[derive(Debug, Clone, Deserialize, Serialize)] pub struct CohortProperty { + properties: InnerCohortProperty, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct InnerCohortProperty { #[serde(rename = "type")] - prop_type: CohortPropertyType, // TODO make this an AND|OR string enum + prop_type: CohortPropertyType, values: Vec, } -impl CohortProperty { +impl InnerCohortProperty { pub fn to_property_filters(&self) -> Vec { self.values .iter() @@ -270,21 +276,7 @@ mod tests { postgres_writer.clone(), team.id, None, - json!({ - "type": "AND", - "values": [ - { - "type": "property", - "values": [ - { - "key": "email", - "value": "test@example.com", - "type": "person" - } - ] - } - ] - }), + json!({"properties": {"type": "OR", "values": [{"type": "OR", "values": [{"key": "$initial_browser_version", "type": "person", "value": ["125"], "negation": false, "operator": "exact"}]}]}}), false, ) .await @@ -307,21 +299,7 @@ mod tests { description: None, team_id: 1, deleted: false, - filters: json!({ - "type": "AND", - "values": [ - { - "type": "property", - "values": [ - { - "key": "email", - "value": "test@example.com", - "type": "person" - } - ] - } - ] - }), + filters: json!({"properties": {"type": "OR", "values": [{"type": "OR", "values": [{"key": "$initial_browser_version", "type": "person", "value": ["125"], "negation": false, "operator": "exact"}]}]}}), query: None, version: None, pending_version: None, @@ -335,8 +313,8 @@ mod tests { let result = cohort.parse_filters().unwrap(); assert_eq!(result.len(), 1); - assert_eq!(result[0].key, "email"); - assert_eq!(result[0].value, json!("test@example.com")); + assert_eq!(result[0].key, "$initial_browser_version"); + assert_eq!(result[0].value, json!(["125"])); assert_eq!(result[0].prop_type, "person"); } @@ -351,7 +329,7 @@ mod tests { description: None, team_id: 1, deleted: false, - filters: json!({"type": "AND", "values": []}), + filters: json!({"properties": {"type": "AND", "values": []}}), query: None, version: None, pending_version: None, @@ -369,7 +347,7 @@ mod tests { description: None, team_id: 1, deleted: false, - filters: json!({"type": "AND", "values": [{"type": "property", "values": [{"key": "cohort", "value": 1, "type": "cohort"}]}]}), + filters: json!({"properties": {"type": "AND", "values": [{"type": "property", "values": [{"key": "cohort", "value": 1, "type": "cohort"}]}]}}), query: None, version: None, pending_version: None, @@ -386,7 +364,7 @@ mod tests { description: None, team_id: 1, deleted: false, - filters: json!({"type": "AND", "values": [{"type": "property", "values": [{"key": "cohort", "value": 2, "type": "cohort"}]}]}), + filters: json!({"properties": {"type": "AND", "values": [{"type": "property", "values": [{"key": "cohort", "value": 2, "type": "cohort"}]}]}}), query: None, version: None, pending_version: None, @@ -405,7 +383,7 @@ mod tests { #[test] fn test_cohort_property_to_property_filters() { - let cohort_property = CohortProperty { + let cohort_property = InnerCohortProperty { prop_type: CohortPropertyType::AND, values: vec![CohortValues { prop_type: "property".to_string(), diff --git a/rust/feature-flags/src/flag_matching.rs b/rust/feature-flags/src/flag_matching.rs index f7097e696da79..a78d3e683f2f1 100644 --- a/rust/feature-flags/src/flag_matching.rs +++ b/rust/feature-flags/src/flag_matching.rs @@ -3254,17 +3254,19 @@ mod tests { team.id, None, json!({ - "type": "OR", - "values": [{ + "properties": { "type": "OR", "values": [{ - "key": "$browser_version", - "type": "person", - "value": "125", - "negation": false, - "operator": "gt" + "type": "OR", + "values": [{ + "key": "$browser_version", + "type": "person", + "value": "125", + "negation": false, + "operator": "gt" + }] }] - }] + } }), false, ) @@ -3340,17 +3342,19 @@ mod tests { team.id, None, json!({ - "type": "OR", - "values": [{ + "properties": { "type": "OR", "values": [{ - "key": "$browser_version", - "type": "person", - "value": "130", - "negation": false, - "operator": "gt" + "type": "OR", + "values": [{ + "key": "$browser_version", + "type": "person", + "value": "130", + "negation": false, + "operator": "gt" + }] }] - }] + } }), false, ) @@ -3426,17 +3430,19 @@ mod tests { team.id, None, json!({ - "type": "OR", - "values": [{ + "properties": { "type": "OR", "values": [{ - "key": "$browser_version", - "type": "person", - "value": "125", - "negation": false, - "operator": "gt" + "type": "OR", + "values": [{ + "key": "$browser_version", + "type": "person", + "value": "125", + "negation": false, + "operator": "gt" + }] }] - }] + } }), false, ) @@ -3513,17 +3519,19 @@ mod tests { team.id, None, json!({ - "type": "OR", - "values": [{ + "properties": { "type": "OR", "values": [{ - "key": "$browser_version", - "type": "person", - "value": "125", - "negation": false, - "operator": "gt" + "type": "OR", + "values": [{ + "key": "$browser_version", + "type": "person", + "value": "125", + "negation": false, + "operator": "gt" + }] }] - }] + } }), false, ) @@ -3536,17 +3544,19 @@ mod tests { team.id, None, json!({ - "type": "OR", - "values": [{ + "properties": { "type": "OR", "values": [{ - "key": "id", - "type": "cohort", - "value": base_cohort_row.id, - "negation": false, - "operator": "in" + "type": "OR", + "values": [{ + "key": "id", + "type": "cohort", + "value": base_cohort_row.id, + "negation": false, + "operator": "in" + }] }] - }] + } }), false, ) @@ -3623,17 +3633,19 @@ mod tests { team.id, None, json!({ - "type": "OR", - "values": [{ + "properties": { "type": "OR", "values": [{ - "key": "$browser_version", - "type": "person", - "value": "130", - "negation": false, - "operator": "gt" + "type": "OR", + "values": [{ + "key": "$browser_version", + "type": "person", + "value": "130", + "negation": false, + "operator": "gt" + }] }] - }] + } }), false, ) From 43cda7655d1cbec4e8662418ab19adc7153c3bbc Mon Sep 17 00:00:00 2001 From: dylan Date: Thu, 24 Oct 2024 16:04:01 -0700 Subject: [PATCH 07/30] clippy my frickin GUY --- rust/feature-flags/src/flag_matching.rs | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/rust/feature-flags/src/flag_matching.rs b/rust/feature-flags/src/flag_matching.rs index a78d3e683f2f1..5f9869ac5fdf3 100644 --- a/rust/feature-flags/src/flag_matching.rs +++ b/rust/feature-flags/src/flag_matching.rs @@ -746,13 +746,12 @@ impl FeatureFlagMatcher { .await?; // Evaluate cohort conditions - if !cohort_filters.is_empty() { - if !self + if !cohort_filters.is_empty() + && !self .evaluate_cohort_filters(&cohort_filters, &properties_to_check) .await? - { - return Ok((false, FeatureFlagMatchReason::NoConditionMatch)); - } + { + return Ok((false, FeatureFlagMatchReason::NoConditionMatch)); } // Evaluate non-cohort properties From 8d2ab857ecfc7d2860d62e3570c269cffa48befa Mon Sep 17 00:00:00 2001 From: dylan Date: Thu, 24 Oct 2024 21:36:43 -0700 Subject: [PATCH 08/30] some light renaming --- rust/feature-flags/src/api.rs | 6 -- rust/feature-flags/src/cohort_models.rs | 74 +++++++++++++++ ...rt_definitions.rs => cohort_operations.rs} | 92 +++---------------- rust/feature-flags/src/flag_matching.rs | 5 +- rust/feature-flags/src/lib.rs | 5 +- .../src/{utils.rs => metrics_utils.rs} | 0 rust/feature-flags/src/router.rs | 2 +- rust/feature-flags/src/test_utils.rs | 6 +- 8 files changed, 95 insertions(+), 95 deletions(-) create mode 100644 rust/feature-flags/src/cohort_models.rs rename rust/feature-flags/src/{cohort_definitions.rs => cohort_operations.rs} (84%) rename rust/feature-flags/src/{utils.rs => metrics_utils.rs} (100%) diff --git a/rust/feature-flags/src/api.rs b/rust/feature-flags/src/api.rs index b9f12a4f77d41..4430476d28a52 100644 --- a/rust/feature-flags/src/api.rs +++ b/rust/feature-flags/src/api.rs @@ -102,8 +102,6 @@ pub enum FlagError { TimeoutError, #[error("No group type mappings")] NoGroupTypeMappings, - #[error("Invalid cohort id")] - InvalidCohortId, } impl IntoResponse for FlagError { @@ -196,10 +194,6 @@ impl IntoResponse for FlagError { "The requested row was not found in the database. Please try again later or contact support if the problem persists.".to_string(), ) } - FlagError::InvalidCohortId => { - tracing::error!("Invalid cohort id: {:?}", self); - (StatusCode::BAD_REQUEST, "Invalid cohort id".to_string()) - } } .into_response() } diff --git a/rust/feature-flags/src/cohort_models.rs b/rust/feature-flags/src/cohort_models.rs new file mode 100644 index 0000000000000..500f3a3fbb37a --- /dev/null +++ b/rust/feature-flags/src/cohort_models.rs @@ -0,0 +1,74 @@ +use crate::flag_definitions::PropertyFilter; +use serde::{Deserialize, Serialize}; +use sqlx::FromRow; + +#[derive(Debug, FromRow)] +pub struct CohortRow { + pub id: i32, + pub name: String, + pub description: Option, + pub team_id: i32, + pub deleted: bool, + pub filters: serde_json::Value, + pub query: Option, + pub version: Option, + pub pending_version: Option, + pub count: Option, + pub is_calculating: bool, + pub is_static: bool, + pub errors_calculating: i32, + pub groups: serde_json::Value, + pub created_by_id: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Cohort { + pub id: i32, + pub name: String, + pub description: Option, + pub team_id: i32, + pub deleted: bool, + pub filters: serde_json::Value, + pub query: Option, + pub version: Option, + pub pending_version: Option, + pub count: Option, + pub is_calculating: bool, + pub is_static: bool, + pub errors_calculating: i32, + pub groups: serde_json::Value, + pub created_by_id: Option, +} + +pub type CohortId = i32; + +pub enum CohortOrEmpty { + Cohort(Cohort), + Empty, +} + +#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)] +#[serde(rename_all = "UPPERCASE")] +pub enum CohortPropertyType { + AND, + OR, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct CohortProperty { + pub properties: InnerCohortProperty, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct InnerCohortProperty { + #[serde(rename = "type")] + pub prop_type: CohortPropertyType, + pub values: Vec, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct CohortValues { + #[serde(rename = "type")] + pub prop_type: String, + pub values: Vec, +} diff --git a/rust/feature-flags/src/cohort_definitions.rs b/rust/feature-flags/src/cohort_operations.rs similarity index 84% rename from rust/feature-flags/src/cohort_definitions.rs rename to rust/feature-flags/src/cohort_operations.rs index a655ff5f07259..9cf046785d85a 100644 --- a/rust/feature-flags/src/cohort_definitions.rs +++ b/rust/feature-flags/src/cohort_operations.rs @@ -1,49 +1,10 @@ -use serde::{Deserialize, Serialize}; -use sqlx::FromRow; use std::collections::{HashMap, HashSet}; use std::sync::Arc; use tracing::instrument; +use crate::cohort_models::{Cohort, CohortId, CohortOrEmpty, CohortRow, InnerCohortProperty}; use crate::{api::FlagError, database::Client as DatabaseClient, flag_definitions::PropertyFilter}; -#[derive(Debug, FromRow)] -pub struct CohortRow { - pub id: i32, - pub name: String, - pub description: Option, - pub team_id: i32, - pub deleted: bool, - pub filters: serde_json::Value, - pub query: Option, - pub version: Option, - pub pending_version: Option, - pub count: Option, - pub is_calculating: bool, - pub is_static: bool, - pub errors_calculating: Option, // I think this has a null constraint, so maybe it shouldn't be optional - pub groups: serde_json::Value, - pub created_by_id: Option, -} - -#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] -pub struct Cohort { - pub id: i32, - pub name: String, - pub description: Option, - pub team_id: i32, - pub deleted: bool, - pub filters: serde_json::Value, - pub query: Option, - pub version: Option, - pub pending_version: Option, - pub count: Option, - pub is_calculating: bool, - pub is_static: bool, - pub errors_calculating: Option, - pub groups: serde_json::Value, - pub created_by_id: Option, -} - impl Cohort { /// Returns a cohort from postgres given a cohort_id and team_id #[instrument(skip_all)] @@ -104,33 +65,6 @@ impl Cohort { } } -pub type CohortId = i32; - -// Assuming CohortOrEmpty is an enum or struct representing a Cohort or an empty value -pub enum CohortOrEmpty { - Cohort(Cohort), - Empty, -} - -#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)] -#[serde(rename_all = "UPPERCASE")] -pub enum CohortPropertyType { - AND, - OR, -} - -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct CohortProperty { - properties: InnerCohortProperty, -} - -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct InnerCohortProperty { - #[serde(rename = "type")] - prop_type: CohortPropertyType, - values: Vec, -} - impl InnerCohortProperty { pub fn to_property_filters(&self) -> Vec { self.values @@ -141,13 +75,6 @@ impl InnerCohortProperty { } } -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct CohortValues { - #[serde(rename = "type")] - prop_type: String, - values: Vec, -} - /// Sorts the given cohorts in an order where cohorts with no dependencies are placed first, /// followed by cohorts that depend on the preceding ones. It ensures that each cohort in the sorted list /// only depends on cohorts that appear earlier in the list. @@ -257,9 +184,12 @@ pub fn sort_cohorts_topologically( #[cfg(test)] mod tests { use super::*; - use crate::test_utils::{ - insert_cohort_for_team_in_pg, insert_new_team_in_pg, setup_pg_reader_client, - setup_pg_writer_client, + use crate::{ + cohort_models::{CohortPropertyType, CohortValues}, + test_utils::{ + insert_cohort_for_team_in_pg, insert_new_team_in_pg, setup_pg_reader_client, + setup_pg_writer_client, + }, }; use serde_json::json; @@ -306,7 +236,7 @@ mod tests { count: None, is_calculating: false, is_static: false, - errors_calculating: None, + errors_calculating: 0, groups: json!({}), created_by_id: None, }; @@ -336,7 +266,7 @@ mod tests { count: None, is_calculating: false, is_static: false, - errors_calculating: None, + errors_calculating: 0, groups: json!({}), created_by_id: None, }), @@ -354,7 +284,7 @@ mod tests { count: None, is_calculating: false, is_static: false, - errors_calculating: None, + errors_calculating: 0, groups: json!({}), created_by_id: None, })); @@ -371,7 +301,7 @@ mod tests { count: None, is_calculating: false, is_static: false, - errors_calculating: None, + errors_calculating: 0, groups: json!({}), created_by_id: None, })); diff --git a/rust/feature-flags/src/flag_matching.rs b/rust/feature-flags/src/flag_matching.rs index 5f9869ac5fdf3..e820fdbdefb24 100644 --- a/rust/feature-flags/src/flag_matching.rs +++ b/rust/feature-flags/src/flag_matching.rs @@ -1,12 +1,13 @@ use crate::{ api::{FlagError, FlagValue, FlagsResponse}, - cohort_definitions::{sort_cohorts_topologically, Cohort, CohortId, CohortOrEmpty}, + cohort_models::{Cohort, CohortId, CohortOrEmpty}, + cohort_operations::sort_cohorts_topologically, database::Client as DatabaseClient, feature_flag_match_reason::FeatureFlagMatchReason, flag_definitions::{FeatureFlag, FeatureFlagList, FlagGroupType, OperatorType, PropertyFilter}, metrics_consts::{FLAG_EVALUATION_ERROR_COUNTER, FLAG_HASH_KEY_WRITES_COUNTER}, + metrics_utils::parse_exception_for_prometheus_label, property_matching::match_property, - utils::parse_exception_for_prometheus_label, }; use anyhow::Result; use common_metrics::inc; diff --git a/rust/feature-flags/src/lib.rs b/rust/feature-flags/src/lib.rs index 2a001c1953914..8899566edb274 100644 --- a/rust/feature-flags/src/lib.rs +++ b/rust/feature-flags/src/lib.rs @@ -1,5 +1,6 @@ pub mod api; -pub mod cohort_definitions; +pub mod cohort_models; +pub mod cohort_operations; pub mod config; pub mod database; pub mod feature_flag_match_reason; @@ -9,13 +10,13 @@ pub mod flag_matching; pub mod flag_request; pub mod geoip; pub mod metrics_consts; +pub mod metrics_utils; pub mod property_matching; pub mod redis; pub mod request_handler; pub mod router; pub mod server; pub mod team; -pub mod utils; pub mod v0_endpoint; // Test modules don't need to be compiled with main binary diff --git a/rust/feature-flags/src/utils.rs b/rust/feature-flags/src/metrics_utils.rs similarity index 100% rename from rust/feature-flags/src/utils.rs rename to rust/feature-flags/src/metrics_utils.rs diff --git a/rust/feature-flags/src/router.rs b/rust/feature-flags/src/router.rs index 505f18adfb008..9cb6a8415cfd8 100644 --- a/rust/feature-flags/src/router.rs +++ b/rust/feature-flags/src/router.rs @@ -12,8 +12,8 @@ use crate::{ config::{Config, TeamIdsToTrack}, database::Client as DatabaseClient, geoip::GeoIpClient, + metrics_utils::team_id_label_filter, redis::Client as RedisClient, - utils::team_id_label_filter, v0_endpoint, }; diff --git a/rust/feature-flags/src/test_utils.rs b/rust/feature-flags/src/test_utils.rs index c6084d89a20e9..1a19ec4cd4c73 100644 --- a/rust/feature-flags/src/test_utils.rs +++ b/rust/feature-flags/src/test_utils.rs @@ -6,7 +6,7 @@ use std::sync::Arc; use uuid::Uuid; use crate::{ - cohort_definitions::CohortRow, + cohort_models::CohortRow, config::{Config, DEFAULT_TEST_CONFIG}, database::{get_pool, Client, CustomDatabaseError}, flag_definitions::{self, FeatureFlag, FeatureFlagRow}, @@ -378,7 +378,7 @@ pub async fn insert_cohort_for_team_in_pg( count: None, is_calculating: false, is_static, - errors_calculating: Some(0), + errors_calculating: 0, groups: serde_json::json!([]), created_by_id: None, }; @@ -401,7 +401,7 @@ pub async fn insert_cohort_for_team_in_pg( .bind(cohort_row.count) .bind(cohort_row.is_calculating) .bind(cohort_row.is_static) - .bind(cohort_row.errors_calculating.unwrap_or(0)) + .bind(cohort_row.errors_calculating) .bind(&cohort_row.groups) .bind(cohort_row.created_by_id) .fetch_one(&mut *conn) From 9ccf47966bae9cb198986d2e6431fc108ab62317 Mon Sep 17 00:00:00 2001 From: dylan Date: Thu, 24 Oct 2024 21:41:51 -0700 Subject: [PATCH 09/30] yeah --- rust/feature-flags/src/flag_definitions.rs | 1 - rust/feature-flags/src/property_matching.rs | 3 --- rust/feature-flags/src/test_utils.rs | 3 +-- 3 files changed, 1 insertion(+), 6 deletions(-) diff --git a/rust/feature-flags/src/flag_definitions.rs b/rust/feature-flags/src/flag_definitions.rs index f081fbd98f21b..9d8d2e7074b9a 100644 --- a/rust/feature-flags/src/flag_definitions.rs +++ b/rust/feature-flags/src/flag_definitions.rs @@ -39,7 +39,6 @@ pub struct PropertyFilter { pub operator: Option, #[serde(rename = "type")] pub prop_type: String, - // TODO add negation here? pub negation: Option, pub group_type_index: Option, } diff --git a/rust/feature-flags/src/property_matching.rs b/rust/feature-flags/src/property_matching.rs index 2d13276135223..2f174117befe9 100644 --- a/rust/feature-flags/src/property_matching.rs +++ b/rust/feature-flags/src/property_matching.rs @@ -43,9 +43,6 @@ pub fn match_property( ))); } - println!("property: {:?}", property); - println!("matching_property_values: {:?}", matching_property_values); - let key = &property.key; let operator = property.operator.clone().unwrap_or(OperatorType::Exact); let value = &property.value; diff --git a/rust/feature-flags/src/test_utils.rs b/rust/feature-flags/src/test_utils.rs index 1a19ec4cd4c73..7b0c0fa4b2d4c 100644 --- a/rust/feature-flags/src/test_utils.rs +++ b/rust/feature-flags/src/test_utils.rs @@ -91,8 +91,6 @@ pub async fn insert_flags_for_team_in_redis( Ok(()) } -// type RedisClientTrait = dyn RedisClient + Send + Sync; - pub fn setup_redis_client(url: Option) -> Arc { let redis_url = match url { Some(value) => value, @@ -407,6 +405,7 @@ pub async fn insert_cohort_for_team_in_pg( .fetch_one(&mut *conn) .await?; + // Update the cohort_row with the actual id generated by sqlx let id = row.0; Ok(CohortRow { id, ..cohort_row }) From 797adbe4f9d46f3b05b200f5d1f63babdf3cb79c Mon Sep 17 00:00:00 2001 From: dylan Date: Thu, 24 Oct 2024 21:52:14 -0700 Subject: [PATCH 10/30] remove printlns --- rust/feature-flags/src/flag_matching.rs | 5 ----- 1 file changed, 5 deletions(-) diff --git a/rust/feature-flags/src/flag_matching.rs b/rust/feature-flags/src/flag_matching.rs index e820fdbdefb24..5d40fe951be1d 100644 --- a/rust/feature-flags/src/flag_matching.rs +++ b/rust/feature-flags/src/flag_matching.rs @@ -3323,7 +3323,6 @@ mod tests { ); let result = matcher.get_match(&flag, None, None).await.unwrap(); - println!("{:?}", result); assert!(result.matches); } @@ -3411,7 +3410,6 @@ mod tests { ); let result = matcher.get_match(&flag, None, None).await.unwrap(); - println!("{:?}", result); assert!(result.matches); } @@ -3499,7 +3497,6 @@ mod tests { ); let result = matcher.get_match(&flag, None, None).await.unwrap(); - println!("{:?}", result); // The user matches the cohort, but the flag is set to NotIn, so it should evaluate to false assert!(!result.matches); @@ -3613,7 +3610,6 @@ mod tests { ); let result = matcher.get_match(&flag, None, None).await.unwrap(); - println!("{:?}", result); // This test might fail if the system doesn't support cohort dependencies assert!(result.matches); @@ -3702,7 +3698,6 @@ mod tests { ); let result = matcher.get_match(&flag, None, None).await.unwrap(); - println!("{:?}", result); // The user does not match the cohort, and the flag is set to In, so it should evaluate to false assert!(!result.matches); From 71def6759039727f654746e9d4f5b28775878b7e Mon Sep 17 00:00:00 2001 From: dylan Date: Sun, 27 Oct 2024 21:29:23 -0700 Subject: [PATCH 11/30] add note about not handling groups --- rust/feature-flags/src/cohort_operations.rs | 3 +++ 1 file changed, 3 insertions(+) diff --git a/rust/feature-flags/src/cohort_operations.rs b/rust/feature-flags/src/cohort_operations.rs index 9cf046785d85a..2839432e04742 100644 --- a/rust/feature-flags/src/cohort_operations.rs +++ b/rust/feature-flags/src/cohort_operations.rs @@ -57,6 +57,9 @@ impl Cohort { } /// Parses the filters JSON into a CohortProperty structure + // TODO: this doesn't handle the deprecated "groups" field, see + // https://github.com/PostHog/posthog/blob/feat/dynamic-cohorts-rust/posthog/models/cohort/cohort.py#L114-L169 + // I'll handle that in a separate PR. pub fn parse_filters(&self) -> Result, FlagError> { let wrapper: serde_json::Value = serde_json::from_value(self.filters.clone())?; let cohort_property: InnerCohortProperty = From 27af814a8b3379a3d8be5b1169da8a784ed6996c Mon Sep 17 00:00:00 2001 From: dylan Date: Tue, 29 Oct 2024 14:05:21 -0700 Subject: [PATCH 12/30] saving a working version that supports caching, since this is the right idea. Next up I will implement a version that stores the dependency graph as well so that we can only cache the relevant cohorts instead of caching and iterating through cohort --- rust/feature-flags/src/api.rs | 30 +- rust/feature-flags/src/cohort_cache.rs | 184 +++++++++++++ rust/feature-flags/src/cohort_models.rs | 26 +- rust/feature-flags/src/cohort_operations.rs | 290 +++++++++----------- rust/feature-flags/src/flag_definitions.rs | 13 +- rust/feature-flags/src/flag_matching.rs | 206 +++++++++----- rust/feature-flags/src/lib.rs | 1 + rust/feature-flags/src/team.rs | 6 +- rust/feature-flags/src/test_utils.rs | 36 +-- 9 files changed, 503 insertions(+), 289 deletions(-) create mode 100644 rust/feature-flags/src/cohort_cache.rs diff --git a/rust/feature-flags/src/api.rs b/rust/feature-flags/src/api.rs index 4430476d28a52..f1fe8fe485999 100644 --- a/rust/feature-flags/src/api.rs +++ b/rust/feature-flags/src/api.rs @@ -89,7 +89,7 @@ pub enum FlagError { #[error("Row not found in postgres")] RowNotFound, #[error("failed to parse redis cache data")] - DataParsingError, + RedisDataParsingError, #[error("failed to update redis cache")] CacheUpdateError, #[error("redis unavailable")] @@ -102,6 +102,14 @@ pub enum FlagError { TimeoutError, #[error("No group type mappings")] NoGroupTypeMappings, + #[error("Cohort not found")] + CohortNotFound(String), + #[error("Failed to parse cohort filters")] + CohortFiltersParsingError, + #[error("Cohort dependency cycle")] + CohortDependencyCycle(String), + #[error("Cohort dependency error")] + CohortDependencyError(String), } impl IntoResponse for FlagError { @@ -138,7 +146,7 @@ impl IntoResponse for FlagError { FlagError::TokenValidationError => { (StatusCode::UNAUTHORIZED, "The provided API key is invalid or has expired. Please check your API key and try again.".to_string()) } - FlagError::DataParsingError => { + FlagError::RedisDataParsingError => { tracing::error!("Data parsing error: {:?}", self); ( StatusCode::SERVICE_UNAVAILABLE, @@ -194,6 +202,22 @@ impl IntoResponse for FlagError { "The requested row was not found in the database. Please try again later or contact support if the problem persists.".to_string(), ) } + FlagError::CohortNotFound(msg) => { + tracing::error!("Cohort not found: {}", msg); + (StatusCode::NOT_FOUND, msg) + } + FlagError::CohortFiltersParsingError => { + tracing::error!("Failed to parse cohort filters: {:?}", self); + (StatusCode::BAD_REQUEST, "Failed to parse cohort filters. Please try again later or contact support if the problem persists.".to_string()) + } + FlagError::CohortDependencyCycle(msg) => { + tracing::error!("Cohort dependency cycle: {}", msg); + (StatusCode::BAD_REQUEST, msg) + } + FlagError::CohortDependencyError(msg) => { + tracing::error!("Cohort dependency error: {}", msg); + (StatusCode::BAD_REQUEST, msg) + } } .into_response() } @@ -205,7 +229,7 @@ impl From for FlagError { CustomRedisError::NotFound => FlagError::TokenValidationError, CustomRedisError::PickleError(e) => { tracing::error!("failed to fetch data: {}", e); - FlagError::DataParsingError + FlagError::RedisDataParsingError } CustomRedisError::Timeout(_) => FlagError::TimeoutError, CustomRedisError::Other(e) => { diff --git a/rust/feature-flags/src/cohort_cache.rs b/rust/feature-flags/src/cohort_cache.rs new file mode 100644 index 0000000000000..a0e3de8782533 --- /dev/null +++ b/rust/feature-flags/src/cohort_cache.rs @@ -0,0 +1,184 @@ +use crate::api::FlagError; +use crate::cohort_models::{Cohort, CohortId}; +use crate::cohort_operations::sort_cohorts_topologically; +use crate::flag_definitions::{OperatorType, PropertyFilter}; +use crate::flag_matching::{PostgresReader, TeamId}; +use std::collections::{HashMap, HashSet}; +use std::sync::Arc; +use tokio::sync::RwLock; + +pub type TeamCohortMap = HashMap>; +pub type TeamSortedCohorts = HashMap>; +pub type TeamCacheMap = HashMap; + +#[derive(Debug, Clone)] +pub struct CachedCohort { + // TODO name this something different + pub filters: Vec, // Non-cohort property filters + pub dependencies: Vec, // Dependencies with operators +} + +// Add this struct to facilitate handling operators on cohort filters +#[derive(Debug, Clone)] +pub struct CohortDependencyFilter { + pub cohort_id: CohortId, + pub operator: OperatorType, +} + +/// Threadsafety is ensured using Arc and RwLock +#[derive(Clone, Debug)] +pub struct CohortCache { + /// Mapping from TeamId to their respective CohortId and associated CachedCohort + per_team: Arc>>>, + /// Mapping from TeamId to sorted CohortIds based on dependencies + sorted_cohorts: Arc>>>, +} + +impl CohortCache { + /// Creates a new CohortCache instance + pub fn new() -> Self { + Self { + per_team: Arc::new(RwLock::new(HashMap::new())), + sorted_cohorts: Arc::new(RwLock::new(HashMap::new())), + } + } + + /// Fetches, flattens, sorts, and caches all cohorts for a given team if not already cached + pub async fn fetch_and_cache_cohorts( + &self, + team_id: TeamId, + postgres_reader: PostgresReader, + ) -> Result<(), FlagError> { + // Acquire write locks to modify the cache + let mut cache = self.per_team.write().await; + let mut sorted = self.sorted_cohorts.write().await; + + // Check if the team's cohorts are already cached + if cache.contains_key(&team_id) && sorted.contains_key(&team_id) { + return Ok(()); + } + + // Fetch all cohorts for the team from the database + let all_cohorts = fetch_all_cohorts(team_id, postgres_reader).await?; + + // Flatten the property filters, resolving dependencies + let flattened = flatten_cohorts(&all_cohorts).await?; + + // Extract all cohort IDs + let cohort_ids: HashSet = flattened.keys().cloned().collect(); + + // Sort the cohorts topologically based on dependencies + let sorted_ids = sort_cohorts_topologically(cohort_ids, &flattened)?; + + // Insert the flattened cohorts and their sorted order into the cache + cache.insert(team_id, flattened); + sorted.insert(team_id, sorted_ids); + + Ok(()) + } + + /// Retrieves sorted cohort IDs for a team from the cache + pub async fn get_sorted_cohort_ids( + &self, + team_id: TeamId, + postgres_reader: PostgresReader, + ) -> Result, FlagError> { + { + // Acquire read locks to check the cache + let cache = self.per_team.read().await; + let sorted = self.sorted_cohorts.read().await; + if let (Some(_cohort_map), Some(sorted_ids)) = + (cache.get(&team_id), sorted.get(&team_id)) + { + if !sorted_ids.is_empty() { + return Ok(sorted_ids.clone()); + } + } + } + + // If not cached, fetch and cache the cohorts + self.fetch_and_cache_cohorts(team_id, postgres_reader) + .await?; + + // Acquire read locks to retrieve the sorted list after caching + let sorted = self.sorted_cohorts.read().await; + if let Some(sorted_ids) = sorted.get(&team_id) { + Ok(sorted_ids.clone()) + } else { + Ok(Vec::new()) + } + } + + /// Retrieves cached cohorts for a team + pub async fn get_cached_cohorts( + &self, + team_id: TeamId, + ) -> Result, FlagError> { + let cache = self.per_team.read().await; + if let Some(cohort_map) = cache.get(&team_id) { + Ok(cohort_map.clone()) + } else { + Ok(HashMap::new()) + } + } +} + +async fn fetch_all_cohorts( + team_id: TeamId, + postgres_reader: PostgresReader, +) -> Result, FlagError> { + let mut conn = postgres_reader.get_connection().await?; + + let query = r#" + SELECT * + FROM posthog_cohort + WHERE team_id = $1 AND deleted = FALSE + "#; + + let cohorts: Vec = sqlx::query_as::<_, Cohort>(query) + .bind(team_id) + .fetch_all(&mut *conn) + .await + .map_err(|e| FlagError::DatabaseError(e.to_string()))?; + + Ok(cohorts) +} + +async fn flatten_cohorts( + all_cohorts: &Vec, +) -> Result, FlagError> { + let mut flattened = HashMap::new(); + + for cohort in all_cohorts { + let filters = cohort.parse_filters()?; + + // Extract dependencies from cohort filters + let dependencies = filters + .iter() + .filter_map(|f| { + if f.prop_type == "cohort" { + Some(CohortDependencyFilter { + cohort_id: f.value.as_i64().unwrap() as CohortId, + operator: f.operator.clone().unwrap_or(OperatorType::In), + }) + } else { + None + } + }) + .collect(); + + // Filter out cohort filters as they are now represented as dependencies + let non_cohort_filters: Vec = filters + .into_iter() + .filter(|f| f.prop_type != "cohort") + .collect(); + let cached_cohort = CachedCohort { + filters: non_cohort_filters, + dependencies, + }; + + flattened.insert(cohort.id, cached_cohort); + } + + Ok(flattened) +} diff --git a/rust/feature-flags/src/cohort_models.rs b/rust/feature-flags/src/cohort_models.rs index 500f3a3fbb37a..d109983901772 100644 --- a/rust/feature-flags/src/cohort_models.rs +++ b/rust/feature-flags/src/cohort_models.rs @@ -2,26 +2,7 @@ use crate::flag_definitions::PropertyFilter; use serde::{Deserialize, Serialize}; use sqlx::FromRow; -#[derive(Debug, FromRow)] -pub struct CohortRow { - pub id: i32, - pub name: String, - pub description: Option, - pub team_id: i32, - pub deleted: bool, - pub filters: serde_json::Value, - pub query: Option, - pub version: Option, - pub pending_version: Option, - pub count: Option, - pub is_calculating: bool, - pub is_static: bool, - pub errors_calculating: i32, - pub groups: serde_json::Value, - pub created_by_id: Option, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize, FromRow)] pub struct Cohort { pub id: i32, pub name: String, @@ -42,11 +23,6 @@ pub struct Cohort { pub type CohortId = i32; -pub enum CohortOrEmpty { - Cohort(Cohort), - Empty, -} - #[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)] #[serde(rename_all = "UPPERCASE")] pub enum CohortPropertyType { diff --git a/rust/feature-flags/src/cohort_operations.rs b/rust/feature-flags/src/cohort_operations.rs index 2839432e04742..b11fd1ca78c83 100644 --- a/rust/feature-flags/src/cohort_operations.rs +++ b/rust/feature-flags/src/cohort_operations.rs @@ -2,7 +2,8 @@ use std::collections::{HashMap, HashSet}; use std::sync::Arc; use tracing::instrument; -use crate::cohort_models::{Cohort, CohortId, CohortOrEmpty, CohortRow, InnerCohortProperty}; +use crate::cohort_cache::CachedCohort; +use crate::cohort_models::{Cohort, CohortId, CohortProperty, InnerCohortProperty}; use crate::{api::FlagError, database::Client as DatabaseClient, flag_definitions::PropertyFilter}; impl Cohort { @@ -21,7 +22,7 @@ impl Cohort { })?; let query = "SELECT id, name, description, team_id, deleted, filters, query, version, pending_version, count, is_calculating, is_static, errors_calculating, groups, created_by_id FROM posthog_cohort WHERE id = $1 AND team_id = $2"; - let cohort_row = sqlx::query_as::<_, CohortRow>(query) + let cohort = sqlx::query_as::<_, Cohort>(query) .bind(cohort_id) .bind(team_id) .fetch_optional(&mut *conn) @@ -31,29 +32,12 @@ impl Cohort { FlagError::Internal(format!("Database query error: {}", e)) })?; - match cohort_row { - Some(row) => Ok(Cohort { - id: row.id, - name: row.name, - description: row.description, - team_id: row.team_id, - deleted: row.deleted, - filters: row.filters, - query: row.query, - version: row.version, - pending_version: row.pending_version, - count: row.count, - is_calculating: row.is_calculating, - is_static: row.is_static, - errors_calculating: row.errors_calculating, - groups: row.groups, - created_by_id: row.created_by_id, - }), - None => Err(FlagError::DatabaseError(format!( + cohort.ok_or_else(|| { + FlagError::CohortNotFound(format!( "Cohort with id {} not found for team {}", cohort_id, team_id - ))), - } + )) + }) } /// Parses the filters JSON into a CohortProperty structure @@ -61,10 +45,12 @@ impl Cohort { // https://github.com/PostHog/posthog/blob/feat/dynamic-cohorts-rust/posthog/models/cohort/cohort.py#L114-L169 // I'll handle that in a separate PR. pub fn parse_filters(&self) -> Result, FlagError> { - let wrapper: serde_json::Value = serde_json::from_value(self.filters.clone())?; - let cohort_property: InnerCohortProperty = - serde_json::from_value(wrapper["properties"].clone())?; - Ok(cohort_property.to_property_filters()) + let cohort_property: CohortProperty = serde_json::from_value(self.filters.clone()) + .map_err(|e| { + tracing::error!("Failed to parse filters for cohort {}: {}", self.id, e); + FlagError::CohortFiltersParsingError + })?; + Ok(cohort_property.properties.to_property_filters()) } } @@ -83,105 +69,76 @@ impl InnerCohortProperty { /// only depends on cohorts that appear earlier in the list. pub fn sort_cohorts_topologically( cohort_ids: HashSet, - seen_cohorts_cache: &HashMap, -) -> Vec { + cached_cohorts: &HashMap, +) -> Result, FlagError> { if cohort_ids.is_empty() { - return Vec::new(); + return Ok(Vec::new()); } let mut dependency_graph: HashMap> = HashMap::new(); - let mut seen = HashSet::new(); - - // Build graph (adjacency list) - fn traverse( - cohort: &Cohort, - dependency_graph: &mut HashMap>, - seen_cohorts: &mut HashSet, - seen_cohorts_cache: &HashMap, - ) { - if seen_cohorts.contains(&cohort.id) { - return; - } - seen_cohorts.insert(cohort.id); - - // Parse the filters into PropertyFilters - let property_filters = match cohort.parse_filters() { - Ok(filters) => filters, - Err(e) => { - tracing::error!("Error parsing filters for cohort {}: {}", cohort.id, e); - return; - } - }; - - // Iterate through the property filters to find dependencies - for filter in property_filters { - if filter.prop_type == "cohort" { - let child_id = match filter.value { - serde_json::Value::Number(num) => num.as_i64().map(|n| n as CohortId), - serde_json::Value::String(ref s) => s.parse::().ok(), - _ => None, - }; - - if let Some(child_id) = child_id { - dependency_graph - .entry(cohort.id) - .or_default() - .push(child_id); - - if let Some(CohortOrEmpty::Cohort(child_cohort)) = - seen_cohorts_cache.get(&child_id) - { - traverse( - child_cohort, - dependency_graph, - seen_cohorts, - seen_cohorts_cache, - ); - } - } + for &cohort_id in &cohort_ids { + if let Some(cohort) = cached_cohorts.get(&cohort_id) { + for dependency in &cohort.dependencies { + dependency_graph + .entry(cohort_id) + .or_default() + .push(dependency.cohort_id); } } } - for &cohort_id in &cohort_ids { - if let Some(CohortOrEmpty::Cohort(cohort)) = seen_cohorts_cache.get(&cohort_id) { - traverse(cohort, &mut dependency_graph, &mut seen, seen_cohorts_cache); - } - } + let mut sorted = Vec::new(); + let mut temporary_marks = HashSet::new(); + let mut permanent_marks = HashSet::new(); - // Post-order DFS (children first, then the parent) - fn dfs( + fn visit( node: CohortId, - seen: &mut HashSet, - sorted_arr: &mut Vec, dependency_graph: &HashMap>, - ) { + temporary_marks: &mut HashSet, + permanent_marks: &mut HashSet, + sorted: &mut Vec, + ) -> Result<(), FlagError> { + if permanent_marks.contains(&node) { + return Ok(()); + } + if temporary_marks.contains(&node) { + return Err(FlagError::CohortDependencyCycle(format!( + "Cycle detected at cohort {}", + node + ))); + } + + temporary_marks.insert(node); if let Some(neighbors) = dependency_graph.get(&node) { for &neighbor in neighbors { - if !seen.contains(&neighbor) { - dfs(neighbor, seen, sorted_arr, dependency_graph); - } + visit( + neighbor, + dependency_graph, + temporary_marks, + permanent_marks, + sorted, + )?; } } - sorted_arr.push(node); - seen.insert(node); + temporary_marks.remove(&node); + permanent_marks.insert(node); + sorted.push(node); + Ok(()) } - let mut sorted_cohort_ids = Vec::new(); - let mut seen = HashSet::new(); - for &cohort_id in &cohort_ids { - if !seen.contains(&cohort_id) { - seen.insert(cohort_id); - dfs( - cohort_id, - &mut seen, - &mut sorted_cohort_ids, + for &node in &cohort_ids { + if !permanent_marks.contains(&node) { + visit( + node, &dependency_graph, - ); + &mut temporary_marks, + &mut permanent_marks, + &mut sorted, + )?; } } - sorted_cohort_ids + Ok(sorted) } #[cfg(test)] @@ -251,68 +208,69 @@ mod tests { assert_eq!(result[0].prop_type, "person"); } - #[test] - fn test_sort_cohorts_topologically() { - let mut cohorts = HashMap::new(); - cohorts.insert( - 1, - CohortOrEmpty::Cohort(Cohort { - id: 1, - name: "Cohort 1".to_string(), - description: None, - team_id: 1, - deleted: false, - filters: json!({"properties": {"type": "AND", "values": []}}), - query: None, - version: None, - pending_version: None, - count: None, - is_calculating: false, - is_static: false, - errors_calculating: 0, - groups: json!({}), - created_by_id: None, - }), - ); - cohorts.insert(2, CohortOrEmpty::Cohort(Cohort { - id: 2, - name: "Cohort 2".to_string(), - description: None, - team_id: 1, - deleted: false, - filters: json!({"properties": {"type": "AND", "values": [{"type": "property", "values": [{"key": "cohort", "value": 1, "type": "cohort"}]}]}}), - query: None, - version: None, - pending_version: None, - count: None, - is_calculating: false, - is_static: false, - errors_calculating: 0, - groups: json!({}), - created_by_id: None, - })); - cohorts.insert(3, CohortOrEmpty::Cohort(Cohort { - id: 3, - name: "Cohort 3".to_string(), - description: None, - team_id: 1, - deleted: false, - filters: json!({"properties": {"type": "AND", "values": [{"type": "property", "values": [{"key": "cohort", "value": 2, "type": "cohort"}]}]}}), - query: None, - version: None, - pending_version: None, - count: None, - is_calculating: false, - is_static: false, - errors_calculating: 0, - groups: json!({}), - created_by_id: None, - })); + // #[test] + // fn test_sort_cohorts_topologically() { + // let mut cohorts = HashMap::new(); + // cohorts.insert( + // 1, + // Cohort { + // id: 1, + // name: "Cohort 1".to_string(), + // description: None, + // team_id: 1, + // deleted: false, + // filters: json!({"properties": {"type": "AND", "values": []}}), + // query: None, + // version: None, + // pending_version: None, + // count: None, + // is_calculating: false, + // is_static: false, + // errors_calculating: 0, + // groups: json!({}), + // created_by_id: None, + // }, + // ); + // cohorts.insert(2, Cohort { + // id: 2, + // name: "Cohort 2".to_string(), + // description: None, + // team_id: 1, + // deleted: false, + // filters: json!({"properties": {"type": "AND", "values": [{"type": "property", "values": [{"key": "cohort", "value": 1, "type": "cohort"}]}]}}), + // query: None, + // version: None, + // pending_version: None, + // count: None, + // is_calculating: false, + // is_static: false, + // errors_calculating: 0, + // groups: json!({}), + // created_by_id: None, + // }); + // cohorts.insert( + // 3, Cohort { + // id: 3, + // name: "Cohort 3".to_string(), + // description: None, + // team_id: 1, + // deleted: false, + // filters: json!({"properties": {"type": "AND", "values": [{"type": "property", "values": [{"key": "cohort", "value": 2, "type": "cohort"}]}]}}), + // query: None, + // version: None, + // pending_version: None, + // count: None, + // is_calculating: false, + // is_static: false, + // errors_calculating: 0, + // groups: json!({}), + // created_by_id: None, + // }); - let cohort_ids: HashSet = vec![1, 2, 3].into_iter().collect(); - let result = sort_cohorts_topologically(cohort_ids, &cohorts); - assert_eq!(result, vec![1, 2, 3]); - } + // let cohort_ids: HashSet = vec![1, 2, 3].into_iter().collect(); + // let result = sort_cohorts_topologically(cohort_ids, &cohorts); + // assert_eq!(result, vec![1, 2, 3]); + // } #[test] fn test_cohort_property_to_property_filters() { diff --git a/rust/feature-flags/src/flag_definitions.rs b/rust/feature-flags/src/flag_definitions.rs index 9d8d2e7074b9a..4cf3e88e2c16c 100644 --- a/rust/feature-flags/src/flag_definitions.rs +++ b/rust/feature-flags/src/flag_definitions.rs @@ -7,7 +7,7 @@ use tracing::instrument; // TODO: Add integration tests across repos to ensure this doesn't happen. pub const TEAM_FLAGS_CACHE_PREFIX: &str = "posthog:1:team_feature_flags_"; -#[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Serialize)] #[serde(rename_all = "snake_case")] pub enum OperatorType { Exact, @@ -71,6 +71,9 @@ pub struct FlagFilters { pub super_groups: Option>, } +// TODO: see if you can combine these two structs, like we do with cohort models +// this will require not deserializing on read and instead doing it lazily, on-demand +// (which, tbh, is probably a better idea) #[derive(Debug, Clone, Deserialize, Serialize)] pub struct FeatureFlag { pub id: i32, @@ -145,7 +148,7 @@ impl FeatureFlagList { tracing::error!("failed to parse data to flags list: {}", e); println!("failed to parse data: {}", e); - FlagError::DataParsingError + FlagError::RedisDataParsingError })?; Ok(FeatureFlagList { flags: flags_list }) @@ -177,7 +180,7 @@ impl FeatureFlagList { .map(|row| { let filters = serde_json::from_value(row.filters).map_err(|e| { tracing::error!("Failed to deserialize filters for flag {}: {}", row.key, e); - FlagError::DataParsingError + FlagError::RedisDataParsingError })?; Ok(FeatureFlag { @@ -203,7 +206,7 @@ impl FeatureFlagList { ) -> Result<(), FlagError> { let payload = serde_json::to_string(&flags.flags).map_err(|e| { tracing::error!("Failed to serialize flags: {}", e); - FlagError::DataParsingError + FlagError::RedisDataParsingError })?; client @@ -1098,7 +1101,7 @@ mod tests { .expect("Failed to set malformed JSON in Redis"); let result = FeatureFlagList::from_redis(redis_client, team.id).await; - assert!(matches!(result, Err(FlagError::DataParsingError))); + assert!(matches!(result, Err(FlagError::RedisDataParsingError))); // Test database query error (using a non-existent table) let result = sqlx::query("SELECT * FROM non_existent_table") diff --git a/rust/feature-flags/src/flag_matching.rs b/rust/feature-flags/src/flag_matching.rs index 5d40fe951be1d..2384097b26880 100644 --- a/rust/feature-flags/src/flag_matching.rs +++ b/rust/feature-flags/src/flag_matching.rs @@ -1,7 +1,7 @@ use crate::{ api::{FlagError, FlagValue, FlagsResponse}, - cohort_models::{Cohort, CohortId, CohortOrEmpty}, - cohort_operations::sort_cohorts_topologically, + cohort_cache::{CachedCohort, CohortCache, CohortDependencyFilter}, + cohort_models::CohortId, database::Client as DatabaseClient, feature_flag_match_reason::FeatureFlagMatchReason, flag_definitions::{FeatureFlag, FeatureFlagList, FlagGroupType, OperatorType, PropertyFilter}, @@ -23,10 +23,10 @@ use std::{ use tokio::time::{sleep, timeout}; use tracing::{error, info}; -type TeamId = i32; -type GroupTypeIndex = i32; -type PostgresReader = Arc; -type PostgresWriter = Arc; +pub type TeamId = i32; +pub type GroupTypeIndex = i32; +pub type PostgresReader = Arc; +pub type PostgresWriter = Arc; #[derive(Debug)] struct SuperConditionEvaluation { @@ -187,6 +187,7 @@ pub struct FeatureFlagMatcher { group_type_mapping_cache: GroupTypeMappingCache, properties_cache: PropertiesCache, groups: HashMap, + cohort_cache: CohortCache, } const LONG_SCALE: u64 = 0xfffffffffffffff; @@ -206,10 +207,11 @@ impl FeatureFlagMatcher { team_id, postgres_reader: postgres_reader.clone(), postgres_writer: postgres_writer.clone(), + groups: groups.unwrap_or_default(), group_type_mapping_cache: group_type_mapping_cache .unwrap_or_else(|| GroupTypeMappingCache::new(team_id, postgres_reader.clone())), properties_cache: properties_cache.unwrap_or_default(), - groups: groups.unwrap_or_default(), + cohort_cache: CohortCache::new(), } } @@ -718,7 +720,7 @@ impl FeatureFlagMatcher { /// It first checks if the condition has any property filters. If not, it performs a rollout check. /// Otherwise, it fetches the relevant properties and checks if they match the condition's filters. /// The function returns a tuple indicating whether the condition matched and the reason for the match. - async fn is_condition_match( + pub async fn is_condition_match( &mut self, feature_flag: &FeatureFlag, condition: &FlagGroupType, @@ -746,13 +748,14 @@ impl FeatureFlagMatcher { .get_properties_to_check(feature_flag, property_overrides, &non_cohort_filters) .await?; - // Evaluate cohort conditions - if !cohort_filters.is_empty() - && !self + // Evaluate cohort filters + if !cohort_filters.is_empty() { + let cohorts_match = self .evaluate_cohort_filters(&cohort_filters, &properties_to_check) - .await? - { - return Ok((false, FeatureFlagMatchReason::NoConditionMatch)); + .await?; + if !cohorts_match { + return Ok((false, FeatureFlagMatchReason::NoConditionMatch)); + } } // Evaluate non-cohort properties @@ -824,83 +827,148 @@ impl FeatureFlagMatcher { } } + /// Evaluates cohort-based property filters pub async fn evaluate_cohort_filters( &self, - filters: &[PropertyFilter], + cohort_filters: &[PropertyFilter], target_properties: &HashMap, ) -> Result { - Box::pin(self.evaluate_potentially_nested_cohort_filters(filters, target_properties)).await + // Step 1: Extract the cohort filters into a list of CohortDependencyFilters + // We will need these because the are the filters associated with the flag itself, and required for the final + // step of evaluating the cohort membership logic + let filter_cohorts = self.extract_cohort_filters(cohort_filters)?; + + // Get all of the cohort IDs associated with the team and sort them topologically by dependency relationships + // We will need to evaluate the cohort membership logic for each of these cohorts in order to evaluate the cohort filters + // themselves. + // TODO can the sorted cohort IDs be aware of IDs in the filter_cohorts list? Or somehow let us know which ones + // we need to evaluate? + let sorted_cohort_ids = self + .cohort_cache + .get_sorted_cohort_ids(self.team_id, self.postgres_reader.clone()) + .await?; + + // Step 6: Retrieve cached and flattened cohorts associated with this team + let cached_cohorts = self.cohort_cache.get_cached_cohorts(self.team_id).await?; + + // Step 7: Evaluate cohort dependencies and property filters + let cohort_matches = self.evaluate_cohorts_and_associated_dependencies( + &sorted_cohort_ids, + &cached_cohorts, + target_properties, + )?; + + // Step 8: Apply any cohort operator logic to determine final matching + self.apply_cohort_membership_logic(&filter_cohorts, &cohort_matches) + } + + fn extract_cohort_filters( + &self, + cohort_filters: &[PropertyFilter], + ) -> Result, FlagError> { + let filter_cohorts = cohort_filters + .iter() + .filter_map(|f| { + if f.key == "id" && f.prop_type == "cohort" { + let cohort_id = f.value.as_i64().map(|id| id as CohortId)?; // TODO handle error? + let operator = f.operator.clone().unwrap_or(OperatorType::In); + Some(CohortDependencyFilter { + cohort_id, + operator, + }) + } else { + None + } + }) + .collect(); + + Ok(filter_cohorts) } - async fn evaluate_potentially_nested_cohort_filters( + fn evaluate_cohorts_and_associated_dependencies( &self, - filters: &[PropertyFilter], + sorted_cohort_ids: &[CohortId], + cached_cohorts: &HashMap, target_properties: &HashMap, - ) -> Result { - let mut cohort_filters = Vec::new(); - let mut non_cohort_filters = Vec::new(); + ) -> Result, FlagError> { + let mut cohort_matches = HashMap::new(); + + for &cohort_id in sorted_cohort_ids { + let cached_cohort: &CachedCohort = match cached_cohorts.get(&cohort_id) { + Some(cohort) => cohort, + None => { + return Err(FlagError::CohortNotFound(format!( + "Cohort ID {} not found in cache", + cohort_id + ))); + } + }; - // Separate cohort filters from non-cohort filters - for filter in filters { - if filter.prop_type == "cohort" { - cohort_filters.push(filter); - } else { - non_cohort_filters.push(filter); + // Evaluate dependent cohorts membership + if !self.dependencies_satisfied(&cached_cohort.dependencies, &cohort_matches)? { + cohort_matches.insert(cohort_id, false); + continue; } - } - // Evaluate non-cohort filters - for filter in &non_cohort_filters { - if !match_property(filter, target_properties, false).unwrap_or(false) { - return Ok(false); - } - } + // Evaluate property filters associated with the cohort + let is_match = cached_cohort.filters.iter().all(|prop_filter| { + match_property(prop_filter, target_properties, false).unwrap_or(false) + }); - // Evaluate cohort filters - if !cohort_filters.is_empty() { - let cohort_ids: HashSet = cohort_filters - .iter() - .filter_map(|f| f.value.as_i64().map(|id| id as CohortId)) - .collect(); + cohort_matches.insert(cohort_id, is_match); + } - let cohorts = self.fetch_cohorts(cohort_ids.clone()).await?; - let seen_cohorts_cache: HashMap = cohorts - .into_iter() - .map(|cohort| (cohort.id, CohortOrEmpty::Cohort(cohort))) - .collect(); + Ok(cohort_matches) + } - let sorted_cohort_ids = sort_cohorts_topologically(cohort_ids, &seen_cohorts_cache); + fn dependencies_satisfied( + &self, + dependencies: &[CohortDependencyFilter], + cohort_matches: &HashMap, + ) -> Result { + for dependency in dependencies { + match cohort_matches.get(&dependency.cohort_id) { + Some(true) => match dependency.operator { + OperatorType::In => continue, + OperatorType::NotIn => return Ok(false), + _ => {} + }, + Some(false) => match dependency.operator { + OperatorType::In => return Ok(false), + OperatorType::NotIn => continue, + _ => {} + }, + None => return Ok(false), + } + } + Ok(true) + } - for cohort_id in sorted_cohort_ids { - if let Some(CohortOrEmpty::Cohort(cohort)) = seen_cohorts_cache.get(&cohort_id) { - let cohort_property_filters = cohort.parse_filters()?; // TODO error handle - let cohort_match = self - .evaluate_cohort_filters(&cohort_property_filters, target_properties) - .await?; + fn apply_cohort_membership_logic( + &self, + filter_cohorts: &[CohortDependencyFilter], + cohort_matches: &HashMap, + ) -> Result { + for filter_cohort in filter_cohorts { + let cohort_match = cohort_matches + .get(&filter_cohort.cohort_id) + .copied() + .unwrap_or(false); - let filter = cohort_filters - .iter() - .find(|f| f.value.as_i64() == Some(cohort_id as i64)) - .unwrap(); - match filter.operator { - Some(OperatorType::In) if !cohort_match => return Ok(false), - Some(OperatorType::NotIn) if cohort_match => return Ok(false), - _ => {} - } - } + if !self.cohort_membership_operator(filter_cohort.operator, cohort_match) { + return Ok(false); } } - Ok(true) } - async fn fetch_cohorts(&self, cohort_ids: HashSet) -> Result, FlagError> { - let mut cohorts = Vec::new(); - for &id in &cohort_ids { - let cohort = Cohort::from_pg(self.postgres_reader.clone(), id, self.team_id).await?; - cohorts.push(cohort); + fn cohort_membership_operator(&self, operator: OperatorType, match_status: bool) -> bool { + match operator { + OperatorType::In => match_status, + OperatorType::NotIn => !match_status, + // TODO, there shouldn't be any other operators here since this is only called from evaluate_cohort_dependencies + _ => false, // Handle other operators as needed } - Ok(cohorts) } /// Check if a super condition matches for a feature flag. diff --git a/rust/feature-flags/src/lib.rs b/rust/feature-flags/src/lib.rs index 8899566edb274..67659bfcf9dcd 100644 --- a/rust/feature-flags/src/lib.rs +++ b/rust/feature-flags/src/lib.rs @@ -1,4 +1,5 @@ pub mod api; +pub mod cohort_cache; pub mod cohort_models; pub mod cohort_operations; pub mod config; diff --git a/rust/feature-flags/src/team.rs b/rust/feature-flags/src/team.rs index 0fa75f0bd3db7..f13cf29094b85 100644 --- a/rust/feature-flags/src/team.rs +++ b/rust/feature-flags/src/team.rs @@ -42,7 +42,7 @@ impl Team { // TODO: Consider an LRU cache for teams as well, with small TTL to skip redis/pg lookups let team: Team = serde_json::from_str(&serialized_team).map_err(|e| { tracing::error!("failed to parse data to team: {}", e); - FlagError::DataParsingError + FlagError::RedisDataParsingError })?; Ok(team) @@ -55,7 +55,7 @@ impl Team { ) -> Result<(), FlagError> { let serialized_team = serde_json::to_string(&team).map_err(|e| { tracing::error!("Failed to serialize team: {}", e); - FlagError::DataParsingError + FlagError::RedisDataParsingError })?; client @@ -173,7 +173,7 @@ mod tests { let client = setup_redis_client(None); match Team::from_redis(client.clone(), team.api_token.clone()).await { - Err(FlagError::DataParsingError) => (), + Err(FlagError::RedisDataParsingError) => (), Err(other) => panic!("Expected DataParsingError, got {:?}", other), Ok(_) => panic!("Expected DataParsingError"), }; diff --git a/rust/feature-flags/src/test_utils.rs b/rust/feature-flags/src/test_utils.rs index 7b0c0fa4b2d4c..d7f7d53fbd22b 100644 --- a/rust/feature-flags/src/test_utils.rs +++ b/rust/feature-flags/src/test_utils.rs @@ -6,7 +6,7 @@ use std::sync::Arc; use uuid::Uuid; use crate::{ - cohort_models::CohortRow, + cohort_models::Cohort, config::{Config, DEFAULT_TEST_CONFIG}, database::{get_pool, Client, CustomDatabaseError}, flag_definitions::{self, FeatureFlag, FeatureFlagRow}, @@ -362,8 +362,8 @@ pub async fn insert_cohort_for_team_in_pg( name: Option, filters: serde_json::Value, is_static: bool, -) -> Result { - let cohort_row = CohortRow { +) -> Result { + let cohort = Cohort { id: 0, // Placeholder, will be updated after insertion name: name.unwrap_or("Test Cohort".to_string()), description: Some("Description for cohort".to_string()), @@ -388,25 +388,25 @@ pub async fn insert_cohort_for_team_in_pg( ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14) RETURNING id"#, ) - .bind(&cohort_row.name) - .bind(&cohort_row.description) - .bind(cohort_row.team_id) - .bind(cohort_row.deleted) - .bind(&cohort_row.filters) - .bind(&cohort_row.query) - .bind(cohort_row.version) - .bind(cohort_row.pending_version) - .bind(cohort_row.count) - .bind(cohort_row.is_calculating) - .bind(cohort_row.is_static) - .bind(cohort_row.errors_calculating) - .bind(&cohort_row.groups) - .bind(cohort_row.created_by_id) + .bind(&cohort.name) + .bind(&cohort.description) + .bind(cohort.team_id) + .bind(cohort.deleted) + .bind(&cohort.filters) + .bind(&cohort.query) + .bind(cohort.version) + .bind(cohort.pending_version) + .bind(cohort.count) + .bind(cohort.is_calculating) + .bind(cohort.is_static) + .bind(cohort.errors_calculating) + .bind(&cohort.groups) + .bind(cohort.created_by_id) .fetch_one(&mut *conn) .await?; // Update the cohort_row with the actual id generated by sqlx let id = row.0; - Ok(CohortRow { id, ..cohort_row }) + Ok(Cohort { id, ..cohort }) } From 4c49bc46c5faf6534b758b4f665b64fc1e633952 Mon Sep 17 00:00:00 2001 From: dylan Date: Tue, 29 Oct 2024 21:54:54 -0700 Subject: [PATCH 13/30] new life --- rust/Cargo.lock | 1 + rust/feature-flags/Cargo.toml | 1 + rust/feature-flags/src/cohort_cache.rs | 279 ++++++++++---------- rust/feature-flags/src/cohort_operations.rs | 142 +++++----- rust/feature-flags/src/flag_definitions.rs | 20 +- rust/feature-flags/src/flag_matching.rs | 159 ++++------- 6 files changed, 283 insertions(+), 319 deletions(-) diff --git a/rust/Cargo.lock b/rust/Cargo.lock index 3f3c36157e36b..5bb5fc25b318d 100644 --- a/rust/Cargo.lock +++ b/rust/Cargo.lock @@ -1316,6 +1316,7 @@ dependencies = [ "health", "maxminddb", "once_cell", + "petgraph", "rand", "redis", "regex", diff --git a/rust/feature-flags/Cargo.toml b/rust/feature-flags/Cargo.toml index 4cf4016767be6..9847569394bb2 100644 --- a/rust/feature-flags/Cargo.toml +++ b/rust/feature-flags/Cargo.toml @@ -39,6 +39,7 @@ health = { path = "../common/health" } common-metrics = { path = "../common/metrics" } tower = { workspace = true } derive_builder = "0.20.1" +petgraph = "0.6.5" [lints] workspace = true diff --git a/rust/feature-flags/src/cohort_cache.rs b/rust/feature-flags/src/cohort_cache.rs index a0e3de8782533..2991a73988c26 100644 --- a/rust/feature-flags/src/cohort_cache.rs +++ b/rust/feature-flags/src/cohort_cache.rs @@ -1,184 +1,191 @@ use crate::api::FlagError; -use crate::cohort_models::{Cohort, CohortId}; -use crate::cohort_operations::sort_cohorts_topologically; -use crate::flag_definitions::{OperatorType, PropertyFilter}; -use crate::flag_matching::{PostgresReader, TeamId}; +use crate::cohort_models::{Cohort, CohortId, CohortProperty}; +use crate::flag_definitions::PropertyFilter; +use crate::flag_matching::PostgresReader; +use petgraph::algo::toposort; +use petgraph::graphmap::DiGraphMap; use std::collections::{HashMap, HashSet}; use std::sync::Arc; use tokio::sync::RwLock; -pub type TeamCohortMap = HashMap>; -pub type TeamSortedCohorts = HashMap>; -pub type TeamCacheMap = HashMap; +// Flattened Cohort Map: CohortId -> Combined PropertyFilters +pub type FlattenedCohortMap = HashMap>; -#[derive(Debug, Clone)] -pub struct CachedCohort { - // TODO name this something different - pub filters: Vec, // Non-cohort property filters - pub dependencies: Vec, // Dependencies with operators -} - -// Add this struct to facilitate handling operators on cohort filters -#[derive(Debug, Clone)] -pub struct CohortDependencyFilter { - pub cohort_id: CohortId, - pub operator: OperatorType, -} - -/// Threadsafety is ensured using Arc and RwLock -#[derive(Clone, Debug)] +/// CohortCache manages the in-memory cache of flattened cohorts +#[derive(Clone)] pub struct CohortCache { - /// Mapping from TeamId to their respective CohortId and associated CachedCohort - per_team: Arc>>>, - /// Mapping from TeamId to sorted CohortIds based on dependencies - sorted_cohorts: Arc>>>, + pub per_team_flattened: Arc>>, // team_id -> (cohort_id -> filters) } impl CohortCache { /// Creates a new CohortCache instance pub fn new() -> Self { Self { - per_team: Arc::new(RwLock::new(HashMap::new())), - sorted_cohorts: Arc::new(RwLock::new(HashMap::new())), + per_team_flattened: Arc::new(RwLock::new(HashMap::new())), } } - /// Fetches, flattens, sorts, and caches all cohorts for a given team if not already cached - pub async fn fetch_and_cache_cohorts( + /// Asynchronous constructor that initializes the CohortCache by fetching and caching cohorts for the given team_id + pub async fn new_with_team( + team_id: i32, + postgres_reader: PostgresReader, + ) -> Result { + let cache = Self { + per_team_flattened: Arc::new(RwLock::new(HashMap::new())), + }; + cache + .fetch_and_cache_all_cohorts(team_id, postgres_reader) + .await?; + Ok(cache) + } + + /// Fetches, parses, and caches all cohorts for a given team + async fn fetch_and_cache_all_cohorts( &self, - team_id: TeamId, + team_id: i32, postgres_reader: PostgresReader, ) -> Result<(), FlagError> { - // Acquire write locks to modify the cache - let mut cache = self.per_team.write().await; - let mut sorted = self.sorted_cohorts.write().await; + // Fetch all cohorts for the team + let cohorts = Cohort::list_from_pg(postgres_reader, team_id).await?; - // Check if the team's cohorts are already cached - if cache.contains_key(&team_id) && sorted.contains_key(&team_id) { - return Ok(()); + // Build a mapping from cohort_id to Cohort + let mut cohort_map: HashMap = HashMap::new(); + for cohort in cohorts { + cohort_map.insert(cohort.id, cohort); } - // Fetch all cohorts for the team from the database - let all_cohorts = fetch_all_cohorts(team_id, postgres_reader).await?; + // Build dependency graph + let dependency_graph = Self::build_dependency_graph(&cohort_map)?; - // Flatten the property filters, resolving dependencies - let flattened = flatten_cohorts(&all_cohorts).await?; + // Perform topological sort + let sorted_cohorts = toposort(&dependency_graph, None).map_err(|_| { + FlagError::CohortDependencyCycle("Cycle detected in cohort dependencies".to_string()) + })?; - // Extract all cohort IDs - let cohort_ids: HashSet = flattened.keys().cloned().collect(); + // Reverse to process dependencies first + let sorted_cohorts: Vec = sorted_cohorts.into_iter().rev().collect(); - // Sort the cohorts topologically based on dependencies - let sorted_ids = sort_cohorts_topologically(cohort_ids, &flattened)?; + // Flatten cohorts + let flattened = Self::flatten_cohorts(&sorted_cohorts, &cohort_map)?; - // Insert the flattened cohorts and their sorted order into the cache + // Cache the flattened cohort filters + let mut cache = self.per_team_flattened.write().await; cache.insert(team_id, flattened); - sorted.insert(team_id, sorted_ids); Ok(()) } - /// Retrieves sorted cohort IDs for a team from the cache - pub async fn get_sorted_cohort_ids( + /// Retrieves flattened filters for a given team and cohort + pub async fn get_flattened_filters( &self, - team_id: TeamId, - postgres_reader: PostgresReader, - ) -> Result, FlagError> { - { - // Acquire read locks to check the cache - let cache = self.per_team.read().await; - let sorted = self.sorted_cohorts.read().await; - if let (Some(_cohort_map), Some(sorted_ids)) = - (cache.get(&team_id), sorted.get(&team_id)) - { - if !sorted_ids.is_empty() { - return Ok(sorted_ids.clone()); - } + team_id: i32, + cohort_id: CohortId, + ) -> Result, FlagError> { + let cache = self.per_team_flattened.read().await; + if let Some(team_map) = cache.get(&team_id) { + if let Some(filters) = team_map.get(&cohort_id) { + Ok(filters.clone()) + } else { + Err(FlagError::CohortNotFound(cohort_id.to_string())) } + } else { + Err(FlagError::CohortNotFound(cohort_id.to_string())) } + } - // If not cached, fetch and cache the cohorts - self.fetch_and_cache_cohorts(team_id, postgres_reader) - .await?; + /// Builds a dependency graph where an edge from A to B means A depends on B + fn build_dependency_graph( + cohort_map: &HashMap, + ) -> Result, FlagError> { + let mut graph = DiGraphMap::new(); - // Acquire read locks to retrieve the sorted list after caching - let sorted = self.sorted_cohorts.read().await; - if let Some(sorted_ids) = sorted.get(&team_id) { - Ok(sorted_ids.clone()) - } else { - Ok(Vec::new()) + // Add all cohorts as nodes + for &cohort_id in cohort_map.keys() { + graph.add_node(cohort_id); } + + // Add edges based on dependencies + for (&cohort_id, cohort) in cohort_map.iter() { + let dependencies = Self::extract_dependencies(cohort.filters.clone())?; + for dep_id in dependencies { + if !cohort_map.contains_key(&dep_id) { + return Err(FlagError::CohortNotFound(dep_id.to_string())); + } + graph.add_edge(cohort_id, dep_id, ()); // A depends on B: A -> B + } + } + + Ok(graph) } - /// Retrieves cached cohorts for a team - pub async fn get_cached_cohorts( - &self, - team_id: TeamId, - ) -> Result, FlagError> { - let cache = self.per_team.read().await; - if let Some(cohort_map) = cache.get(&team_id) { - Ok(cohort_map.clone()) - } else { - Ok(HashMap::new()) + /// Extracts all dependent CohortIds from the filters + fn extract_dependencies( + filters_as_json: serde_json::Value, + ) -> Result, FlagError> { + let filters: CohortProperty = serde_json::from_value(filters_as_json)?; + let mut dependencies = HashSet::new(); + Self::traverse_filters(&filters.properties, &mut dependencies)?; + Ok(dependencies) + } + + /// Recursively traverses the filter tree to find cohort dependencies + fn traverse_filters( + inner: &crate::cohort_models::InnerCohortProperty, + dependencies: &mut HashSet, + ) -> Result<(), FlagError> { + for cohort_values in &inner.values { + for filter in &cohort_values.values { + if filter.prop_type == "cohort" && filter.key == "id" { + // Assuming the value is a single integer CohortId + if let Some(cohort_id) = filter.value.as_i64() { + dependencies.insert(cohort_id as CohortId); + } else { + return Err(FlagError::CohortFiltersParsingError); // TODO more data here? + } + } + // Handle nested properties if necessary + // If the filter can contain nested properties with more conditions, traverse them here + } } + Ok(()) } -} -async fn fetch_all_cohorts( - team_id: TeamId, - postgres_reader: PostgresReader, -) -> Result, FlagError> { - let mut conn = postgres_reader.get_connection().await?; - - let query = r#" - SELECT * - FROM posthog_cohort - WHERE team_id = $1 AND deleted = FALSE - "#; - - let cohorts: Vec = sqlx::query_as::<_, Cohort>(query) - .bind(team_id) - .fetch_all(&mut *conn) - .await - .map_err(|e| FlagError::DatabaseError(e.to_string()))?; - - Ok(cohorts) -} + /// Flattens the filters based on sorted cohorts, including only property filters + fn flatten_cohorts( + sorted_cohorts: &[CohortId], + cohort_map: &HashMap, + ) -> Result { + let mut flattened: FlattenedCohortMap = HashMap::new(); -async fn flatten_cohorts( - all_cohorts: &Vec, -) -> Result, FlagError> { - let mut flattened = HashMap::new(); - - for cohort in all_cohorts { - let filters = cohort.parse_filters()?; - - // Extract dependencies from cohort filters - let dependencies = filters - .iter() - .filter_map(|f| { - if f.prop_type == "cohort" { - Some(CohortDependencyFilter { - cohort_id: f.value.as_i64().unwrap() as CohortId, - operator: f.operator.clone().unwrap_or(OperatorType::In), - }) + for &cohort_id in sorted_cohorts { + let cohort = cohort_map + .get(&cohort_id) + .ok_or(FlagError::CohortNotFound(cohort_id.to_string()))?; + + // Use the updated parse_property_filters method + let property_filters = cohort.parse_filters()?; + + // Extract dependencies using Cohort's method + let dependencies = cohort.extract_dependencies()?; + + let mut combined_filters = Vec::new(); + + // Include filters from dependencies + for dep_id in &dependencies { + if let Some(dep_filters) = flattened.get(dep_id) { + combined_filters.extend(dep_filters.clone()); } else { - None + return Err(FlagError::CohortNotFound(dep_id.to_string())); } - }) - .collect(); - - // Filter out cohort filters as they are now represented as dependencies - let non_cohort_filters: Vec = filters - .into_iter() - .filter(|f| f.prop_type != "cohort") - .collect(); - let cached_cohort = CachedCohort { - filters: non_cohort_filters, - dependencies, - }; + } - flattened.insert(cohort.id, cached_cohort); - } + // Include own filters + combined_filters.extend(property_filters); - Ok(flattened) + // Insert into flattened map + flattened.insert(cohort_id, combined_filters); + } + + Ok(flattened) + } } diff --git a/rust/feature-flags/src/cohort_operations.rs b/rust/feature-flags/src/cohort_operations.rs index b11fd1ca78c83..eb25c91dca3be 100644 --- a/rust/feature-flags/src/cohort_operations.rs +++ b/rust/feature-flags/src/cohort_operations.rs @@ -1,8 +1,7 @@ -use std::collections::{HashMap, HashSet}; +use std::collections::HashSet; use std::sync::Arc; use tracing::instrument; -use crate::cohort_cache::CachedCohort; use crate::cohort_models::{Cohort, CohortId, CohortProperty, InnerCohortProperty}; use crate::{api::FlagError, database::Client as DatabaseClient, flag_definitions::PropertyFilter}; @@ -40,6 +39,29 @@ impl Cohort { }) } + #[instrument(skip_all)] + pub async fn list_from_pg( + client: Arc, + team_id: i32, + ) -> Result, FlagError> { + let mut conn = client.get_connection().await.map_err(|e| { + tracing::error!("Failed to get database connection: {}", e); + FlagError::DatabaseUnavailable + })?; + + let query = "SELECT id, name, description, team_id, deleted, filters, query, version, pending_version, count, is_calculating, is_static, errors_calculating, groups, created_by_id FROM posthog_cohort WHERE team_id = $1"; + let cohorts = sqlx::query_as::<_, Cohort>(query) + .bind(team_id) + .fetch_all(&mut *conn) + .await + .map_err(|e| { + tracing::error!("Failed to fetch cohorts from database: {}", e); + FlagError::Internal(format!("Database query error: {}", e)) + })?; + + Ok(cohorts) + } + /// Parses the filters JSON into a CohortProperty structure // TODO: this doesn't handle the deprecated "groups" field, see // https://github.com/PostHog/posthog/blob/feat/dynamic-cohorts-rust/posthog/models/cohort/cohort.py#L114-L169 @@ -50,95 +72,59 @@ impl Cohort { tracing::error!("Failed to parse filters for cohort {}: {}", self.id, e); FlagError::CohortFiltersParsingError })?; - Ok(cohort_property.properties.to_property_filters()) - } -} -impl InnerCohortProperty { - pub fn to_property_filters(&self) -> Vec { - self.values - .iter() - .flat_map(|value| &value.values) - .cloned() - .collect() + // Filter out cohort filters + Ok(cohort_property + .properties + .to_property_filters() + .into_iter() + .filter(|f| !(f.key == "id" && f.prop_type == "cohort")) + .collect()) } -} -/// Sorts the given cohorts in an order where cohorts with no dependencies are placed first, -/// followed by cohorts that depend on the preceding ones. It ensures that each cohort in the sorted list -/// only depends on cohorts that appear earlier in the list. -pub fn sort_cohorts_topologically( - cohort_ids: HashSet, - cached_cohorts: &HashMap, -) -> Result, FlagError> { - if cohort_ids.is_empty() { - return Ok(Vec::new()); - } + /// Extracts dependent CohortIds from the cohort's filters + pub fn extract_dependencies(&self) -> Result, FlagError> { + let cohort_property: CohortProperty = serde_json::from_value(self.filters.clone()) + .map_err(|e| { + tracing::error!("Failed to parse filters for cohort {}: {}", self.id, e); + FlagError::CohortFiltersParsingError + })?; - let mut dependency_graph: HashMap> = HashMap::new(); - for &cohort_id in &cohort_ids { - if let Some(cohort) = cached_cohorts.get(&cohort_id) { - for dependency in &cohort.dependencies { - dependency_graph - .entry(cohort_id) - .or_default() - .push(dependency.cohort_id); - } - } + let mut dependencies = HashSet::new(); + Self::traverse_filters(&cohort_property.properties, &mut dependencies)?; + Ok(dependencies) } - let mut sorted = Vec::new(); - let mut temporary_marks = HashSet::new(); - let mut permanent_marks = HashSet::new(); - - fn visit( - node: CohortId, - dependency_graph: &HashMap>, - temporary_marks: &mut HashSet, - permanent_marks: &mut HashSet, - sorted: &mut Vec, + /// Recursively traverses the filter tree to find cohort dependencies + fn traverse_filters( + inner: &InnerCohortProperty, + dependencies: &mut HashSet, ) -> Result<(), FlagError> { - if permanent_marks.contains(&node) { - return Ok(()); - } - if temporary_marks.contains(&node) { - return Err(FlagError::CohortDependencyCycle(format!( - "Cycle detected at cohort {}", - node - ))); - } - - temporary_marks.insert(node); - if let Some(neighbors) = dependency_graph.get(&node) { - for &neighbor in neighbors { - visit( - neighbor, - dependency_graph, - temporary_marks, - permanent_marks, - sorted, - )?; + for cohort_values in &inner.values { + for filter in &cohort_values.values { + if filter.prop_type == "cohort" && filter.key == "id" { + // Assuming the value is a single integer CohortId + if let Some(cohort_id) = filter.value.as_i64() { + dependencies.insert(cohort_id as CohortId); + } else { + return Err(FlagError::CohortFiltersParsingError); + } + } + // Handle nested properties if necessary } } - temporary_marks.remove(&node); - permanent_marks.insert(node); - sorted.push(node); Ok(()) } +} - for &node in &cohort_ids { - if !permanent_marks.contains(&node) { - visit( - node, - &dependency_graph, - &mut temporary_marks, - &mut permanent_marks, - &mut sorted, - )?; - } +impl InnerCohortProperty { + pub fn to_property_filters(&self) -> Vec { + self.values + .iter() + .flat_map(|value| &value.values) + .cloned() + .collect() } - - Ok(sorted) } #[cfg(test)] diff --git a/rust/feature-flags/src/flag_definitions.rs b/rust/feature-flags/src/flag_definitions.rs index 4cf3e88e2c16c..329d0321e77be 100644 --- a/rust/feature-flags/src/flag_definitions.rs +++ b/rust/feature-flags/src/flag_definitions.rs @@ -1,4 +1,7 @@ -use crate::{api::FlagError, database::Client as DatabaseClient, redis::Client as RedisClient}; +use crate::{ + api::FlagError, cohort_models::CohortId, database::Client as DatabaseClient, + redis::Client as RedisClient, +}; use serde::{Deserialize, Serialize}; use std::sync::Arc; use tracing::instrument; @@ -43,6 +46,21 @@ pub struct PropertyFilter { pub group_type_index: Option, } +impl PropertyFilter { + /// Checks if the filter is a cohort filter + pub fn is_cohort(&self) -> bool { + self.key == "id" && self.prop_type == "cohort" + } + + /// Returns the cohort id if the filter is a cohort filter + pub fn get_cohort_id(&self) -> Result { + self.value + .as_i64() + .map(|id| id as CohortId) + .ok_or(FlagError::CohortFiltersParsingError) + } +} + #[derive(Debug, Clone, Deserialize, Serialize)] pub struct FlagGroupType { pub properties: Option>, diff --git a/rust/feature-flags/src/flag_matching.rs b/rust/feature-flags/src/flag_matching.rs index 2384097b26880..7abddb24080c3 100644 --- a/rust/feature-flags/src/flag_matching.rs +++ b/rust/feature-flags/src/flag_matching.rs @@ -1,6 +1,6 @@ use crate::{ api::{FlagError, FlagValue, FlagsResponse}, - cohort_cache::{CachedCohort, CohortCache, CohortDependencyFilter}, + cohort_cache::CohortCache, cohort_models::CohortId, database::Client as DatabaseClient, feature_flag_match_reason::FeatureFlagMatchReason, @@ -187,7 +187,6 @@ pub struct FeatureFlagMatcher { group_type_mapping_cache: GroupTypeMappingCache, properties_cache: PropertiesCache, groups: HashMap, - cohort_cache: CohortCache, } const LONG_SCALE: u64 = 0xfffffffffffffff; @@ -211,7 +210,6 @@ impl FeatureFlagMatcher { group_type_mapping_cache: group_type_mapping_cache .unwrap_or_else(|| GroupTypeMappingCache::new(team_id, postgres_reader.clone())), properties_cache: properties_cache.unwrap_or_default(), - cohort_cache: CohortCache::new(), } } @@ -830,144 +828,97 @@ impl FeatureFlagMatcher { /// Evaluates cohort-based property filters pub async fn evaluate_cohort_filters( &self, - cohort_filters: &[PropertyFilter], + cohort_and_property_filters: &[PropertyFilter], target_properties: &HashMap, ) -> Result { - // Step 1: Extract the cohort filters into a list of CohortDependencyFilters - // We will need these because the are the filters associated with the flag itself, and required for the final - // step of evaluating the cohort membership logic - let filter_cohorts = self.extract_cohort_filters(cohort_filters)?; - - // Get all of the cohort IDs associated with the team and sort them topologically by dependency relationships - // We will need to evaluate the cohort membership logic for each of these cohorts in order to evaluate the cohort filters - // themselves. - // TODO can the sorted cohort IDs be aware of IDs in the filter_cohorts list? Or somehow let us know which ones - // we need to evaluate? - let sorted_cohort_ids = self - .cohort_cache - .get_sorted_cohort_ids(self.team_id, self.postgres_reader.clone()) - .await?; + let cohort_cache = + CohortCache::new_with_team(self.team_id, self.postgres_reader.clone()).await?; - // Step 6: Retrieve cached and flattened cohorts associated with this team - let cached_cohorts = self.cohort_cache.get_cached_cohorts(self.team_id).await?; + let (cohort_filters, property_filters) = cohort_and_property_filters + .iter() + .partition::, _>(|f| f.is_cohort()); - // Step 7: Evaluate cohort dependencies and property filters - let cohort_matches = self.evaluate_cohorts_and_associated_dependencies( - &sorted_cohort_ids, - &cached_cohorts, - target_properties, - )?; + // Early exit if any of property filters fail to match + if !self + .evaluate_property_filters(&property_filters, target_properties) + .await? + { + return Ok(false); + } - // Step 8: Apply any cohort operator logic to determine final matching - self.apply_cohort_membership_logic(&filter_cohorts, &cohort_matches) + // Evaluate cohort filters + let cohort_matches = self + .evaluate_cohort_dependencies(&cohort_filters, &cohort_cache, target_properties) + .await?; + + // Apply cohort membership logic + self.apply_cohort_membership_logic(&cohort_filters, &cohort_matches) } - fn extract_cohort_filters( + /// Evaluates property filters against target properties + async fn evaluate_property_filters( &self, - cohort_filters: &[PropertyFilter], - ) -> Result, FlagError> { - let filter_cohorts = cohort_filters - .iter() - .filter_map(|f| { - if f.key == "id" && f.prop_type == "cohort" { - let cohort_id = f.value.as_i64().map(|id| id as CohortId)?; // TODO handle error? - let operator = f.operator.clone().unwrap_or(OperatorType::In); - Some(CohortDependencyFilter { - cohort_id, - operator, - }) - } else { - None - } - }) - .collect(); - - Ok(filter_cohorts) + property_filters: &[&PropertyFilter], + target_properties: &HashMap, + ) -> Result { + for filter in property_filters { + if !match_property(filter, target_properties, false).unwrap_or(false) { + return Ok(false); + } + } + Ok(true) } - fn evaluate_cohorts_and_associated_dependencies( + /// Evaluates cohort dependencies using the cache + async fn evaluate_cohort_dependencies( &self, - sorted_cohort_ids: &[CohortId], - cached_cohorts: &HashMap, + cohort_filters: &[&PropertyFilter], + cohort_cache: &CohortCache, target_properties: &HashMap, ) -> Result, FlagError> { let mut cohort_matches = HashMap::new(); - for &cohort_id in sorted_cohort_ids { - let cached_cohort: &CachedCohort = match cached_cohorts.get(&cohort_id) { - Some(cohort) => cohort, - None => { - return Err(FlagError::CohortNotFound(format!( - "Cohort ID {} not found in cache", - cohort_id - ))); - } - }; - - // Evaluate dependent cohorts membership - if !self.dependencies_satisfied(&cached_cohort.dependencies, &cohort_matches)? { - cohort_matches.insert(cohort_id, false); - continue; - } + for filter in cohort_filters { + let cohort_id = filter.get_cohort_id()?; + let filters_to_evaluate = cohort_cache + .get_flattened_filters(self.team_id, cohort_id) + .await?; - // Evaluate property filters associated with the cohort - let is_match = cached_cohort.filters.iter().all(|prop_filter| { - match_property(prop_filter, target_properties, false).unwrap_or(false) - }); + let all_filters_match = filters_to_evaluate + .iter() + .all(|filter| match_property(filter, target_properties, false).unwrap_or(false)); - cohort_matches.insert(cohort_id, is_match); + cohort_matches.insert(cohort_id, all_filters_match); } Ok(cohort_matches) } - fn dependencies_satisfied( - &self, - dependencies: &[CohortDependencyFilter], - cohort_matches: &HashMap, - ) -> Result { - for dependency in dependencies { - match cohort_matches.get(&dependency.cohort_id) { - Some(true) => match dependency.operator { - OperatorType::In => continue, - OperatorType::NotIn => return Ok(false), - _ => {} - }, - Some(false) => match dependency.operator { - OperatorType::In => return Ok(false), - OperatorType::NotIn => continue, - _ => {} - }, - None => return Ok(false), - } - } - Ok(true) - } - + /// Applies cohort membership logic based on operators fn apply_cohort_membership_logic( &self, - filter_cohorts: &[CohortDependencyFilter], + cohort_filters: &[&PropertyFilter], cohort_matches: &HashMap, ) -> Result { - for filter_cohort in filter_cohorts { - let cohort_match = cohort_matches - .get(&filter_cohort.cohort_id) - .copied() - .unwrap_or(false); + for filter in cohort_filters { + let cohort_id = filter.get_cohort_id()?; + let matches = cohort_matches.get(&cohort_id).copied().unwrap_or(false); + let operator = filter.operator.clone().unwrap_or(OperatorType::In); - if !self.cohort_membership_operator(filter_cohort.operator, cohort_match) { + if !self.cohort_membership_operator(operator, matches) { return Ok(false); } } Ok(true) } + /// Determines the final match based on the operator and match status fn cohort_membership_operator(&self, operator: OperatorType, match_status: bool) -> bool { match operator { OperatorType::In => match_status, OperatorType::NotIn => !match_status, - // TODO, there shouldn't be any other operators here since this is only called from evaluate_cohort_dependencies - _ => false, // Handle other operators as needed + // Extend with other operators as needed + _ => false, } } From d4af2f0f8bd22e2184b64bfcadced743f8f5ff41 Mon Sep 17 00:00:00 2001 From: dylan Date: Tue, 29 Oct 2024 22:27:27 -0700 Subject: [PATCH 14/30] clippy u dawg --- rust/feature-flags/src/cohort_cache.rs | 6 ++++++ rust/feature-flags/src/flag_matching.rs | 2 +- rust/feature-flags/src/property_matching.rs | 2 +- 3 files changed, 8 insertions(+), 2 deletions(-) diff --git a/rust/feature-flags/src/cohort_cache.rs b/rust/feature-flags/src/cohort_cache.rs index 2991a73988c26..bcfe67d398d8d 100644 --- a/rust/feature-flags/src/cohort_cache.rs +++ b/rust/feature-flags/src/cohort_cache.rs @@ -17,6 +17,12 @@ pub struct CohortCache { pub per_team_flattened: Arc>>, // team_id -> (cohort_id -> filters) } +impl Default for CohortCache { + fn default() -> Self { + Self::new() + } +} + impl CohortCache { /// Creates a new CohortCache instance pub fn new() -> Self { diff --git a/rust/feature-flags/src/flag_matching.rs b/rust/feature-flags/src/flag_matching.rs index 7abddb24080c3..90525aa6db8ac 100644 --- a/rust/feature-flags/src/flag_matching.rs +++ b/rust/feature-flags/src/flag_matching.rs @@ -903,7 +903,7 @@ impl FeatureFlagMatcher { for filter in cohort_filters { let cohort_id = filter.get_cohort_id()?; let matches = cohort_matches.get(&cohort_id).copied().unwrap_or(false); - let operator = filter.operator.clone().unwrap_or(OperatorType::In); + let operator = filter.operator.unwrap_or(OperatorType::In); if !self.cohort_membership_operator(operator, matches) { return Ok(false); diff --git a/rust/feature-flags/src/property_matching.rs b/rust/feature-flags/src/property_matching.rs index 2f174117befe9..84479f131611f 100644 --- a/rust/feature-flags/src/property_matching.rs +++ b/rust/feature-flags/src/property_matching.rs @@ -44,7 +44,7 @@ pub fn match_property( } let key = &property.key; - let operator = property.operator.clone().unwrap_or(OperatorType::Exact); + let operator = property.operator.unwrap_or(OperatorType::Exact); let value = &property.value; let match_value = matching_property_values.get(key); From 870f7190dab5bb3f5f94aff3ad1f09189e23901a Mon Sep 17 00:00:00 2001 From: dylan Date: Wed, 30 Oct 2024 23:03:56 -0700 Subject: [PATCH 15/30] traverse the dependency graph post-cache access --- rust/feature-flags/src/cohort_cache.rs | 170 +++------------ rust/feature-flags/src/flag_matching.rs | 263 +++++++++++++++++++++--- 2 files changed, 256 insertions(+), 177 deletions(-) diff --git a/rust/feature-flags/src/cohort_cache.rs b/rust/feature-flags/src/cohort_cache.rs index bcfe67d398d8d..2c7714e6ed207 100644 --- a/rust/feature-flags/src/cohort_cache.rs +++ b/rust/feature-flags/src/cohort_cache.rs @@ -1,20 +1,14 @@ use crate::api::FlagError; -use crate::cohort_models::{Cohort, CohortId, CohortProperty}; -use crate::flag_definitions::PropertyFilter; +use crate::cohort_models::{Cohort, CohortId}; use crate::flag_matching::PostgresReader; -use petgraph::algo::toposort; -use petgraph::graphmap::DiGraphMap; -use std::collections::{HashMap, HashSet}; +use std::collections::HashMap; use std::sync::Arc; use tokio::sync::RwLock; -// Flattened Cohort Map: CohortId -> Combined PropertyFilters -pub type FlattenedCohortMap = HashMap>; - -/// CohortCache manages the in-memory cache of flattened cohorts +/// CohortCache manages the in-memory cache of cohorts #[derive(Clone)] pub struct CohortCache { - pub per_team_flattened: Arc>>, // team_id -> (cohort_id -> filters) + pub per_team_cohorts: Arc>>>, // team_id -> list of Cohorts } impl Default for CohortCache { @@ -27,7 +21,7 @@ impl CohortCache { /// Creates a new CohortCache instance pub fn new() -> Self { Self { - per_team_flattened: Arc::new(RwLock::new(HashMap::new())), + per_team_cohorts: Arc::new(RwLock::new(HashMap::new())), } } @@ -37,7 +31,7 @@ impl CohortCache { postgres_reader: PostgresReader, ) -> Result { let cache = Self { - per_team_flattened: Arc::new(RwLock::new(HashMap::new())), + per_team_cohorts: Arc::new(RwLock::new(HashMap::new())), }; cache .fetch_and_cache_all_cohorts(team_id, postgres_reader) @@ -45,7 +39,7 @@ impl CohortCache { Ok(cache) } - /// Fetches, parses, and caches all cohorts for a given team + /// Fetches and caches all cohorts for a given team async fn fetch_and_cache_all_cohorts( &self, team_id: i32, @@ -54,144 +48,34 @@ impl CohortCache { // Fetch all cohorts for the team let cohorts = Cohort::list_from_pg(postgres_reader, team_id).await?; - // Build a mapping from cohort_id to Cohort - let mut cohort_map: HashMap = HashMap::new(); - for cohort in cohorts { - cohort_map.insert(cohort.id, cohort); - } - - // Build dependency graph - let dependency_graph = Self::build_dependency_graph(&cohort_map)?; - - // Perform topological sort - let sorted_cohorts = toposort(&dependency_graph, None).map_err(|_| { - FlagError::CohortDependencyCycle("Cycle detected in cohort dependencies".to_string()) - })?; - - // Reverse to process dependencies first - let sorted_cohorts: Vec = sorted_cohorts.into_iter().rev().collect(); - - // Flatten cohorts - let flattened = Self::flatten_cohorts(&sorted_cohorts, &cohort_map)?; - - // Cache the flattened cohort filters - let mut cache = self.per_team_flattened.write().await; - cache.insert(team_id, flattened); + // Cache the cohorts without flattening + let mut cache = self.per_team_cohorts.write().await; + cache.insert(team_id, cohorts); Ok(()) } - /// Retrieves flattened filters for a given team and cohort - pub async fn get_flattened_filters( + /// Retrieves all cohorts for a given team + pub async fn get_all_cohorts(&self, team_id: i32) -> Result, FlagError> { + let cache = self.per_team_cohorts.read().await; + cache + .get(&team_id) + .cloned() + .ok_or_else(|| FlagError::CohortNotFound("No cohorts found for the team".to_string())) + } + + /// Retrieves a specific cohort by ID for a given team + pub async fn get_cohort_by_id( &self, team_id: i32, cohort_id: CohortId, - ) -> Result, FlagError> { - let cache = self.per_team_flattened.read().await; - if let Some(team_map) = cache.get(&team_id) { - if let Some(filters) = team_map.get(&cohort_id) { - Ok(filters.clone()) - } else { - Err(FlagError::CohortNotFound(cohort_id.to_string())) - } - } else { - Err(FlagError::CohortNotFound(cohort_id.to_string())) - } - } - - /// Builds a dependency graph where an edge from A to B means A depends on B - fn build_dependency_graph( - cohort_map: &HashMap, - ) -> Result, FlagError> { - let mut graph = DiGraphMap::new(); - - // Add all cohorts as nodes - for &cohort_id in cohort_map.keys() { - graph.add_node(cohort_id); - } - - // Add edges based on dependencies - for (&cohort_id, cohort) in cohort_map.iter() { - let dependencies = Self::extract_dependencies(cohort.filters.clone())?; - for dep_id in dependencies { - if !cohort_map.contains_key(&dep_id) { - return Err(FlagError::CohortNotFound(dep_id.to_string())); - } - graph.add_edge(cohort_id, dep_id, ()); // A depends on B: A -> B - } - } - - Ok(graph) - } - - /// Extracts all dependent CohortIds from the filters - fn extract_dependencies( - filters_as_json: serde_json::Value, - ) -> Result, FlagError> { - let filters: CohortProperty = serde_json::from_value(filters_as_json)?; - let mut dependencies = HashSet::new(); - Self::traverse_filters(&filters.properties, &mut dependencies)?; - Ok(dependencies) - } - - /// Recursively traverses the filter tree to find cohort dependencies - fn traverse_filters( - inner: &crate::cohort_models::InnerCohortProperty, - dependencies: &mut HashSet, - ) -> Result<(), FlagError> { - for cohort_values in &inner.values { - for filter in &cohort_values.values { - if filter.prop_type == "cohort" && filter.key == "id" { - // Assuming the value is a single integer CohortId - if let Some(cohort_id) = filter.value.as_i64() { - dependencies.insert(cohort_id as CohortId); - } else { - return Err(FlagError::CohortFiltersParsingError); // TODO more data here? - } - } - // Handle nested properties if necessary - // If the filter can contain nested properties with more conditions, traverse them here + ) -> Result { + let cache = self.per_team_cohorts.read().await; + if let Some(cohorts) = cache.get(&team_id) { + if let Some(cohort) = cohorts.iter().find(|c| c.id == cohort_id) { + return Ok(cohort.clone()); } } - Ok(()) - } - - /// Flattens the filters based on sorted cohorts, including only property filters - fn flatten_cohorts( - sorted_cohorts: &[CohortId], - cohort_map: &HashMap, - ) -> Result { - let mut flattened: FlattenedCohortMap = HashMap::new(); - - for &cohort_id in sorted_cohorts { - let cohort = cohort_map - .get(&cohort_id) - .ok_or(FlagError::CohortNotFound(cohort_id.to_string()))?; - - // Use the updated parse_property_filters method - let property_filters = cohort.parse_filters()?; - - // Extract dependencies using Cohort's method - let dependencies = cohort.extract_dependencies()?; - - let mut combined_filters = Vec::new(); - - // Include filters from dependencies - for dep_id in &dependencies { - if let Some(dep_filters) = flattened.get(dep_id) { - combined_filters.extend(dep_filters.clone()); - } else { - return Err(FlagError::CohortNotFound(dep_id.to_string())); - } - } - - // Include own filters - combined_filters.extend(property_filters); - - // Insert into flattened map - flattened.insert(cohort_id, combined_filters); - } - - Ok(flattened) + Err(FlagError::CohortNotFound(cohort_id.to_string())) } } diff --git a/rust/feature-flags/src/flag_matching.rs b/rust/feature-flags/src/flag_matching.rs index 90525aa6db8ac..351a32708369d 100644 --- a/rust/feature-flags/src/flag_matching.rs +++ b/rust/feature-flags/src/flag_matching.rs @@ -11,13 +11,15 @@ use crate::{ }; use anyhow::Result; use common_metrics::inc; +use petgraph::algo::{is_cyclic_directed, toposort}; +use petgraph::graph::DiGraph; use serde_json::Value; use sha1::{Digest, Sha1}; use sqlx::{postgres::PgQueryResult, Acquire, FromRow}; use std::fmt::Write; use std::sync::Arc; use std::{ - collections::{HashMap, HashSet}, + collections::{HashMap, HashSet, VecDeque}, time::Duration, }; use tokio::time::{sleep, timeout}; @@ -825,7 +827,7 @@ impl FeatureFlagMatcher { } } - /// Evaluates cohort-based property filters + /// Evaluates cohort-based property filters dynamically pub async fn evaluate_cohort_filters( &self, cohort_and_property_filters: &[PropertyFilter], @@ -833,12 +835,12 @@ impl FeatureFlagMatcher { ) -> Result { let cohort_cache = CohortCache::new_with_team(self.team_id, self.postgres_reader.clone()).await?; - + // Partition filters into cohort and non-cohort let (cohort_filters, property_filters) = cohort_and_property_filters .iter() .partition::, _>(|f| f.is_cohort()); - // Early exit if any of property filters fail to match + // Early exit if any property filters fail to match if !self .evaluate_property_filters(&property_filters, target_properties) .await? @@ -846,15 +848,90 @@ impl FeatureFlagMatcher { return Ok(false); } - // Evaluate cohort filters - let cohort_matches = self - .evaluate_cohort_dependencies(&cohort_filters, &cohort_cache, target_properties) - .await?; + // Evaluate cohort filters dynamically + let mut cohort_matches = HashMap::new(); + + for filter in &cohort_filters { + let cohort_id = filter.get_cohort_id()?; + let match_result = self + .evaluate_single_cohort(self.team_id, cohort_id, target_properties, &cohort_cache) + .await?; + cohort_matches.insert(cohort_id, match_result); + } // Apply cohort membership logic self.apply_cohort_membership_logic(&cohort_filters, &cohort_matches) } + /// Evaluates a single cohort and its dependencies using a dependency graph walk + async fn evaluate_single_cohort( + &self, + team_id: i32, + initial_cohort_id: CohortId, + target_properties: &HashMap, + cohort_cache: &CohortCache, + ) -> Result { + // Build the dependency graph + let graph = build_dependency_graph(team_id, initial_cohort_id, cohort_cache).await?; + + // Perform topological sort + let sorted_nodes = toposort(&graph, None).map_err(|e| { + FlagError::CohortDependencyCycle(format!("Cyclic dependency detected: {:?}", e)) + })?; + + // Map to store evaluation results + let mut evaluation_results: HashMap = HashMap::new(); + + // Iterate in reverse topological order (dependencies first) + for node in sorted_nodes.into_iter().rev() { + let cohort_id = graph[node]; + let cohort = cohort_cache.get_cohort_by_id(team_id, cohort_id).await?; + let property_filters = cohort.parse_filters()?; // Assuming parse_filters returns Vec + + // Evaluate dependencies + let dependencies = cohort.extract_dependencies()?; + let mut deps_match = true; + for dep_id in dependencies { + if let Some(&dep_result) = evaluation_results.get(&dep_id) { + if !dep_result { + deps_match = false; + break; + } + } else { + // This should not happen due to topological sorting + return Err(FlagError::CohortDependencyCycle(format!( + "Missing dependency result for cohort {}", + dep_id + ))); + } + } + + if !deps_match { + evaluation_results.insert(cohort_id, false); + continue; + } + + // Evaluate own property filters + let all_filters_match = property_filters + .iter() + .all(|filter| match_property(filter, target_properties, false).unwrap_or(false)); + + // Store the result in the cache + evaluation_results.insert(cohort_id, all_filters_match); + + // Optional: Early exit if desired + if !all_filters_match { + // break; + } + } + + // Return the result for the initial cohort + evaluation_results + .get(&initial_cohort_id) + .copied() + .ok_or_else(|| FlagError::NoGroupTypeMappings) + } + /// Evaluates property filters against target properties async fn evaluate_property_filters( &self, @@ -869,32 +946,7 @@ impl FeatureFlagMatcher { Ok(true) } - /// Evaluates cohort dependencies using the cache - async fn evaluate_cohort_dependencies( - &self, - cohort_filters: &[&PropertyFilter], - cohort_cache: &CohortCache, - target_properties: &HashMap, - ) -> Result, FlagError> { - let mut cohort_matches = HashMap::new(); - - for filter in cohort_filters { - let cohort_id = filter.get_cohort_id()?; - let filters_to_evaluate = cohort_cache - .get_flattened_filters(self.team_id, cohort_id) - .await?; - - let all_filters_match = filters_to_evaluate - .iter() - .all(|filter| match_property(filter, target_properties, false).unwrap_or(false)); - - cohort_matches.insert(cohort_id, all_filters_match); - } - - Ok(cohort_matches) - } - - /// Applies cohort membership logic based on operators + /// Apply cohort membership logic based on operators fn apply_cohort_membership_logic( &self, cohort_filters: &[&PropertyFilter], @@ -922,6 +974,103 @@ impl FeatureFlagMatcher { } } + // /// Evaluates cohort-based property filters + // pub async fn evaluate_cohort_filters( + // &self, + // cohort_and_property_filters: &[PropertyFilter], + // target_properties: &HashMap, + // ) -> Result { + // let cohort_cache = + // CohortCache::new_with_team(self.team_id, self.postgres_reader.clone()).await?; + + // let (cohort_filters, property_filters) = cohort_and_property_filters + // .iter() + // .partition::, _>(|f| f.is_cohort()); + + // // Early exit if any of property filters fail to match + // if !self + // .evaluate_property_filters(&property_filters, target_properties) + // .await? + // { + // return Ok(false); + // } + + // // Evaluate cohort filters + // let cohort_matches = self + // .evaluate_cohort_dependencies(&cohort_filters, &cohort_cache, target_properties) + // .await?; + + // // Apply cohort membership logic + // self.apply_cohort_membership_logic(&cohort_filters, &cohort_matches) + // } + + // /// Evaluates property filters against target properties + // async fn evaluate_property_filters( + // &self, + // property_filters: &[&PropertyFilter], + // target_properties: &HashMap, + // ) -> Result { + // for filter in property_filters { + // if !match_property(filter, target_properties, false).unwrap_or(false) { + // return Ok(false); + // } + // } + // Ok(true) + // } + + // /// Evaluates cohort dependencies using the cache + // async fn evaluate_cohort_dependencies( + // &self, + // cohort_filters: &[&PropertyFilter], + // cohort_cache: &CohortCache, + // target_properties: &HashMap, + // ) -> Result, FlagError> { + // let mut cohort_matches = HashMap::new(); + + // for filter in cohort_filters { + // let cohort_id = filter.get_cohort_id()?; + // let filters_to_evaluate = cohort_cache + // .get_flattened_filters(self.team_id, cohort_id) + // .await?; + + // let all_filters_match = filters_to_evaluate + // .iter() + // .all(|filter| match_property(filter, target_properties, false).unwrap_or(false)); + + // cohort_matches.insert(cohort_id, all_filters_match); + // } + + // Ok(cohort_matches) + // } + + // /// Applies cohort membership logic based on operators + // fn apply_cohort_membership_logic( + // &self, + // cohort_filters: &[&PropertyFilter], + // cohort_matches: &HashMap, + // ) -> Result { + // for filter in cohort_filters { + // let cohort_id = filter.get_cohort_id()?; + // let matches = cohort_matches.get(&cohort_id).copied().unwrap_or(false); + // let operator = filter.operator.clone().unwrap_or(OperatorType::In); + + // if !self.cohort_membership_operator(operator, matches) { + // return Ok(false); + // } + // } + // Ok(true) + // } + + // /// Determines the final match based on the operator and match status + // fn cohort_membership_operator(&self, operator: OperatorType, match_status: bool) -> bool { + // match operator { + // OperatorType::In => match_status, + // OperatorType::NotIn => !match_status, + // // Extend with other operators as needed + // _ => false, + // } + // } + /// Check if a super condition matches for a feature flag. /// /// This function evaluates the super conditions of a feature flag to determine if any of them should be enabled. @@ -1165,6 +1314,52 @@ impl FeatureFlagMatcher { } } +/// Constructs a dependency graph for cohorts. +async fn build_dependency_graph( + team_id: i32, + initial_cohort_id: CohortId, + cohort_cache: &CohortCache, +) -> Result, FlagError> { + let mut graph: DiGraph = DiGraph::new(); + let mut node_map = HashMap::new(); + + // Queue for BFS traversal + let mut queue = VecDeque::new(); + queue.push_back(initial_cohort_id); + node_map.insert(initial_cohort_id, graph.add_node(initial_cohort_id)); + + while let Some(cohort_id) = queue.pop_front() { + let cohort = cohort_cache.get_cohort_by_id(team_id, cohort_id).await?; + let dependencies = cohort.extract_dependencies()?; + + for dep_id in dependencies { + // Retrieve the current node **before** mutable borrowing + let current_node = node_map[&cohort_id]; + + // Add dependency node if not present + let dep_node = node_map + .entry(dep_id) + .or_insert_with(|| graph.add_node(dep_id)); + + graph.add_edge(current_node, *dep_node, ()); + + if !node_map.contains_key(&dep_id) { + queue.push_back(dep_id); + } + } + } + + // Check for cycles + if is_cyclic_directed(&graph) { + return Err(FlagError::CohortDependencyCycle(format!( + "Cyclic dependency detected starting at cohort {}", + initial_cohort_id + ))); + } + + Ok(graph) +} + /// Fetch and locally cache all properties for a given distinct ID and team ID. /// /// This function fetches both person and group properties for a specified distinct ID and team ID. From 57d98850af1df16b863611d3f1605d3c19fd3ce3 Mon Sep 17 00:00:00 2001 From: dylan Date: Thu, 31 Oct 2024 13:18:33 -0700 Subject: [PATCH 16/30] cleaning up --- rust/feature-flags/src/cohort_cache.rs | 42 +++- rust/feature-flags/src/cohort_operations.rs | 64 ----- rust/feature-flags/src/flag_matching.rs | 262 ++++++-------------- 3 files changed, 103 insertions(+), 265 deletions(-) diff --git a/rust/feature-flags/src/cohort_cache.rs b/rust/feature-flags/src/cohort_cache.rs index 2c7714e6ed207..f3582bf15dc37 100644 --- a/rust/feature-flags/src/cohort_cache.rs +++ b/rust/feature-flags/src/cohort_cache.rs @@ -6,6 +6,20 @@ use std::sync::Arc; use tokio::sync::RwLock; /// CohortCache manages the in-memory cache of cohorts +/// +/// Example cache structure: +/// ```text +/// per_team_cohorts: { +/// 1: [ +/// Cohort { id: 101, name: "Active Users", filters: [...] }, +/// Cohort { id: 102, name: "Power Users", filters: [...] } +/// ], +/// 2: [ +/// Cohort { id: 201, name: "New Users", filters: [...] }, +/// Cohort { id: 202, name: "Churned Users", filters: [...] } +/// ] +/// } +/// ``` #[derive(Clone)] pub struct CohortCache { pub per_team_cohorts: Arc>>>, // team_id -> list of Cohorts @@ -18,7 +32,6 @@ impl Default for CohortCache { } impl CohortCache { - /// Creates a new CohortCache instance pub fn new() -> Self { Self { per_team_cohorts: Arc::new(RwLock::new(HashMap::new())), @@ -40,30 +53,33 @@ impl CohortCache { } /// Fetches and caches all cohorts for a given team + /// + /// Cache structure: + /// ```text + /// per_team_cohorts: { + /// team_id_1: [ + /// Cohort { id: 1, filters: [...], ... }, + /// Cohort { id: 2, filters: [...], ... }, + /// ... + /// ], + /// team_id_2: [ + /// Cohort { id: 3, filters: [...], ... }, + /// ... + /// ] + /// } + /// ``` async fn fetch_and_cache_all_cohorts( &self, team_id: i32, postgres_reader: PostgresReader, ) -> Result<(), FlagError> { - // Fetch all cohorts for the team let cohorts = Cohort::list_from_pg(postgres_reader, team_id).await?; - - // Cache the cohorts without flattening let mut cache = self.per_team_cohorts.write().await; cache.insert(team_id, cohorts); Ok(()) } - /// Retrieves all cohorts for a given team - pub async fn get_all_cohorts(&self, team_id: i32) -> Result, FlagError> { - let cache = self.per_team_cohorts.read().await; - cache - .get(&team_id) - .cloned() - .ok_or_else(|| FlagError::CohortNotFound("No cohorts found for the team".to_string())) - } - /// Retrieves a specific cohort by ID for a given team pub async fn get_cohort_by_id( &self, diff --git a/rust/feature-flags/src/cohort_operations.rs b/rust/feature-flags/src/cohort_operations.rs index eb25c91dca3be..30cf900218e01 100644 --- a/rust/feature-flags/src/cohort_operations.rs +++ b/rust/feature-flags/src/cohort_operations.rs @@ -194,70 +194,6 @@ mod tests { assert_eq!(result[0].prop_type, "person"); } - // #[test] - // fn test_sort_cohorts_topologically() { - // let mut cohorts = HashMap::new(); - // cohorts.insert( - // 1, - // Cohort { - // id: 1, - // name: "Cohort 1".to_string(), - // description: None, - // team_id: 1, - // deleted: false, - // filters: json!({"properties": {"type": "AND", "values": []}}), - // query: None, - // version: None, - // pending_version: None, - // count: None, - // is_calculating: false, - // is_static: false, - // errors_calculating: 0, - // groups: json!({}), - // created_by_id: None, - // }, - // ); - // cohorts.insert(2, Cohort { - // id: 2, - // name: "Cohort 2".to_string(), - // description: None, - // team_id: 1, - // deleted: false, - // filters: json!({"properties": {"type": "AND", "values": [{"type": "property", "values": [{"key": "cohort", "value": 1, "type": "cohort"}]}]}}), - // query: None, - // version: None, - // pending_version: None, - // count: None, - // is_calculating: false, - // is_static: false, - // errors_calculating: 0, - // groups: json!({}), - // created_by_id: None, - // }); - // cohorts.insert( - // 3, Cohort { - // id: 3, - // name: "Cohort 3".to_string(), - // description: None, - // team_id: 1, - // deleted: false, - // filters: json!({"properties": {"type": "AND", "values": [{"type": "property", "values": [{"key": "cohort", "value": 2, "type": "cohort"}]}]}}), - // query: None, - // version: None, - // pending_version: None, - // count: None, - // is_calculating: false, - // is_static: false, - // errors_calculating: 0, - // groups: json!({}), - // created_by_id: None, - // }); - - // let cohort_ids: HashSet = vec![1, 2, 3].into_iter().collect(); - // let result = sort_cohorts_topologically(cohort_ids, &cohorts); - // assert_eq!(result, vec![1, 2, 3]); - // } - #[test] fn test_cohort_property_to_property_filters() { let cohort_property = InnerCohortProperty { diff --git a/rust/feature-flags/src/flag_matching.rs b/rust/feature-flags/src/flag_matching.rs index 351a32708369d..b7d92329d9e37 100644 --- a/rust/feature-flags/src/flag_matching.rs +++ b/rust/feature-flags/src/flag_matching.rs @@ -741,7 +741,7 @@ impl FeatureFlagMatcher { flag_property_filters .iter() .cloned() - .partition(|prop| prop.prop_type == "cohort"); + .partition(|prop| prop.is_cohort()); // Evaluate non-cohort properties first to get properties_to_check let properties_to_check = self @@ -827,126 +827,108 @@ impl FeatureFlagMatcher { } } - /// Evaluates cohort-based property filters dynamically + /// Evaluates dynamic cohort property filters + /// + /// NB: This method first caches all of the cohorts associated with the team, which allows us to avoid + /// hitting the database for each cohort filter. pub async fn evaluate_cohort_filters( &self, - cohort_and_property_filters: &[PropertyFilter], + property_filters: &[PropertyFilter], target_properties: &HashMap, ) -> Result { + // Caching all of the cohorts like this will make it so that we don't have to hit the database for each cohort filter let cohort_cache = CohortCache::new_with_team(self.team_id, self.postgres_reader.clone()).await?; - // Partition filters into cohort and non-cohort - let (cohort_filters, property_filters) = cohort_and_property_filters - .iter() - .partition::, _>(|f| f.is_cohort()); - // Early exit if any property filters fail to match - if !self - .evaluate_property_filters(&property_filters, target_properties) - .await? - { - return Ok(false); - } + // At this point, we shouldn't have any non-cohort property filters, but we'll filter them out anyway + let cohort_property_filters: Vec<_> = + property_filters.iter().filter(|f| f.is_cohort()).collect(); - // Evaluate cohort filters dynamically + // Store cohort match results in a HashMap to avoid re-evaluating the same cohort multiple times, + // since the same cohort could appear in multiple property filters. This is especially important + // because evaluating a cohort requires evaluating all of its dependencies, which can be expensive. let mut cohort_matches = HashMap::new(); - - for filter in &cohort_filters { + for filter in &cohort_property_filters { let cohort_id = filter.get_cohort_id()?; let match_result = self - .evaluate_single_cohort(self.team_id, cohort_id, target_properties, &cohort_cache) + .evaluate_cohort_dependencies( + self.team_id, + cohort_id, + target_properties, + &cohort_cache, + ) .await?; cohort_matches.insert(cohort_id, match_result); } - // Apply cohort membership logic - self.apply_cohort_membership_logic(&cohort_filters, &cohort_matches) + // Apply cohort membership logic (IN|NOT_IN) + self.apply_cohort_membership_logic(&cohort_property_filters, &cohort_matches) } - /// Evaluates a single cohort and its dependencies using a dependency graph walk - async fn evaluate_single_cohort( + /// Evaluates a single cohort and its dependencies. + /// This uses a topological sort to evaluate dependencies first, which is necessary + /// because a cohort can depend on another cohort, and we need to respect the dependency order. + async fn evaluate_cohort_dependencies( &self, team_id: i32, initial_cohort_id: CohortId, target_properties: &HashMap, cohort_cache: &CohortCache, ) -> Result { - // Build the dependency graph - let graph = build_dependency_graph(team_id, initial_cohort_id, cohort_cache).await?; - - // Perform topological sort - let sorted_nodes = toposort(&graph, None).map_err(|e| { - FlagError::CohortDependencyCycle(format!("Cyclic dependency detected: {:?}", e)) - })?; - - // Map to store evaluation results - let mut evaluation_results: HashMap = HashMap::new(); + let cohort_dependency_graph = + build_cohort_dependency_graph(team_id, initial_cohort_id, cohort_cache).await?; + + // We need to sort cohorts topologically to ensure we evaluate dependencies before the cohorts that depend on them. + // For example, if cohort A depends on cohort B, we need to evaluate B first to know if A matches. + // This also helps detect cycles - if cohort A depends on B which depends on A, toposort will fail. + let sorted_cohort_ids_as_graph_nodes = + toposort(&cohort_dependency_graph, None).map_err(|e| { + FlagError::CohortDependencyCycle(format!("Cyclic dependency detected: {:?}", e)) + })?; - // Iterate in reverse topological order (dependencies first) - for node in sorted_nodes.into_iter().rev() { - let cohort_id = graph[node]; - let cohort = cohort_cache.get_cohort_by_id(team_id, cohort_id).await?; - let property_filters = cohort.parse_filters()?; // Assuming parse_filters returns Vec + // Store evaluation results for each cohort in a map, so we can look up whether a cohort matched + // when evaluating cohorts that depend on it, and also return the final result for the initial cohort + let mut evaluation_results = HashMap::new(); - // Evaluate dependencies + // Iterate through the sorted nodes in reverse order (so that we can evaluate dependencies first) + for node in sorted_cohort_ids_as_graph_nodes.into_iter().rev() { + let cohort_id = cohort_dependency_graph[node]; + let cohort = cohort_cache + .get_cohort_by_id(team_id, cohort_dependency_graph[node]) + .await?; + let property_filters = cohort.parse_filters()?; let dependencies = cohort.extract_dependencies()?; - let mut deps_match = true; - for dep_id in dependencies { - if let Some(&dep_result) = evaluation_results.get(&dep_id) { - if !dep_result { - deps_match = false; - break; - } - } else { - // This should not happen due to topological sorting - return Err(FlagError::CohortDependencyCycle(format!( - "Missing dependency result for cohort {}", - dep_id - ))); - } - } - if !deps_match { + // Check if all dependencies have been met (i.e., previous cohorts matched) + let dependencies_met = dependencies + .iter() + .all(|dep_id| evaluation_results.get(dep_id).copied().unwrap_or(false)); + + // If dependencies are not met, mark the current cohort as not matched and continue + // NB: We don't want to _exit_ here, since the non-matching cohort could be wrapped in a `not_in` operator + // and we want to evaluate all cohorts to determine if the initial cohort matches. + if !dependencies_met { evaluation_results.insert(cohort_id, false); continue; } - // Evaluate own property filters + // Evaluate all property filters for the current cohort let all_filters_match = property_filters .iter() .all(|filter| match_property(filter, target_properties, false).unwrap_or(false)); - // Store the result in the cache + // Store the evaluation result for the current cohort evaluation_results.insert(cohort_id, all_filters_match); - - // Optional: Early exit if desired - if !all_filters_match { - // break; - } } - // Return the result for the initial cohort + // Retrieve and return the evaluation result for the initial cohort evaluation_results .get(&initial_cohort_id) .copied() - .ok_or_else(|| FlagError::NoGroupTypeMappings) - } - - /// Evaluates property filters against target properties - async fn evaluate_property_filters( - &self, - property_filters: &[&PropertyFilter], - target_properties: &HashMap, - ) -> Result { - for filter in property_filters { - if !match_property(filter, target_properties, false).unwrap_or(false) { - return Ok(false); - } - } - Ok(true) + .ok_or_else(|| FlagError::CohortNotFound(initial_cohort_id.to_string())) } - /// Apply cohort membership logic based on operators + /// Apply cohort membership logic (i.e., IN|NOT_IN) fn apply_cohort_membership_logic( &self, cohort_filters: &[&PropertyFilter], @@ -957,120 +939,24 @@ impl FeatureFlagMatcher { let matches = cohort_matches.get(&cohort_id).copied().unwrap_or(false); let operator = filter.operator.unwrap_or(OperatorType::In); - if !self.cohort_membership_operator(operator, matches) { + // Combine the operator logic directly within this method + let membership_match = match operator { + OperatorType::In => matches, + OperatorType::NotIn => !matches, + // Currently supported operators are IN and NOT IN + // Any other operator defaults to false + _ => false, + }; + + // If any filter does not match, return false early + if !membership_match { return Ok(false); } } + // All filters matched Ok(true) } - /// Determines the final match based on the operator and match status - fn cohort_membership_operator(&self, operator: OperatorType, match_status: bool) -> bool { - match operator { - OperatorType::In => match_status, - OperatorType::NotIn => !match_status, - // Extend with other operators as needed - _ => false, - } - } - - // /// Evaluates cohort-based property filters - // pub async fn evaluate_cohort_filters( - // &self, - // cohort_and_property_filters: &[PropertyFilter], - // target_properties: &HashMap, - // ) -> Result { - // let cohort_cache = - // CohortCache::new_with_team(self.team_id, self.postgres_reader.clone()).await?; - - // let (cohort_filters, property_filters) = cohort_and_property_filters - // .iter() - // .partition::, _>(|f| f.is_cohort()); - - // // Early exit if any of property filters fail to match - // if !self - // .evaluate_property_filters(&property_filters, target_properties) - // .await? - // { - // return Ok(false); - // } - - // // Evaluate cohort filters - // let cohort_matches = self - // .evaluate_cohort_dependencies(&cohort_filters, &cohort_cache, target_properties) - // .await?; - - // // Apply cohort membership logic - // self.apply_cohort_membership_logic(&cohort_filters, &cohort_matches) - // } - - // /// Evaluates property filters against target properties - // async fn evaluate_property_filters( - // &self, - // property_filters: &[&PropertyFilter], - // target_properties: &HashMap, - // ) -> Result { - // for filter in property_filters { - // if !match_property(filter, target_properties, false).unwrap_or(false) { - // return Ok(false); - // } - // } - // Ok(true) - // } - - // /// Evaluates cohort dependencies using the cache - // async fn evaluate_cohort_dependencies( - // &self, - // cohort_filters: &[&PropertyFilter], - // cohort_cache: &CohortCache, - // target_properties: &HashMap, - // ) -> Result, FlagError> { - // let mut cohort_matches = HashMap::new(); - - // for filter in cohort_filters { - // let cohort_id = filter.get_cohort_id()?; - // let filters_to_evaluate = cohort_cache - // .get_flattened_filters(self.team_id, cohort_id) - // .await?; - - // let all_filters_match = filters_to_evaluate - // .iter() - // .all(|filter| match_property(filter, target_properties, false).unwrap_or(false)); - - // cohort_matches.insert(cohort_id, all_filters_match); - // } - - // Ok(cohort_matches) - // } - - // /// Applies cohort membership logic based on operators - // fn apply_cohort_membership_logic( - // &self, - // cohort_filters: &[&PropertyFilter], - // cohort_matches: &HashMap, - // ) -> Result { - // for filter in cohort_filters { - // let cohort_id = filter.get_cohort_id()?; - // let matches = cohort_matches.get(&cohort_id).copied().unwrap_or(false); - // let operator = filter.operator.clone().unwrap_or(OperatorType::In); - - // if !self.cohort_membership_operator(operator, matches) { - // return Ok(false); - // } - // } - // Ok(true) - // } - - // /// Determines the final match based on the operator and match status - // fn cohort_membership_operator(&self, operator: OperatorType, match_status: bool) -> bool { - // match operator { - // OperatorType::In => match_status, - // OperatorType::NotIn => !match_status, - // // Extend with other operators as needed - // _ => false, - // } - // } - /// Check if a super condition matches for a feature flag. /// /// This function evaluates the super conditions of a feature flag to determine if any of them should be enabled. @@ -1315,12 +1201,12 @@ impl FeatureFlagMatcher { } /// Constructs a dependency graph for cohorts. -async fn build_dependency_graph( +async fn build_cohort_dependency_graph( team_id: i32, initial_cohort_id: CohortId, cohort_cache: &CohortCache, ) -> Result, FlagError> { - let mut graph: DiGraph = DiGraph::new(); + let mut graph = DiGraph::new(); let mut node_map = HashMap::new(); // Queue for BFS traversal From 9eb0f182ee42fa5bda9ebd93143fb092bf002bd4 Mon Sep 17 00:00:00 2001 From: dylan Date: Thu, 31 Oct 2024 14:02:29 -0700 Subject: [PATCH 17/30] adding more tests --- rust/feature-flags/src/api.rs | 6 - rust/feature-flags/src/cohort_operations.rs | 119 +++++++++- rust/feature-flags/src/flag_definitions.rs | 1 + rust/feature-flags/src/flag_matching.rs | 244 +++++++++++--------- 4 files changed, 247 insertions(+), 123 deletions(-) diff --git a/rust/feature-flags/src/api.rs b/rust/feature-flags/src/api.rs index f1fe8fe485999..9d6b649719bd2 100644 --- a/rust/feature-flags/src/api.rs +++ b/rust/feature-flags/src/api.rs @@ -108,8 +108,6 @@ pub enum FlagError { CohortFiltersParsingError, #[error("Cohort dependency cycle")] CohortDependencyCycle(String), - #[error("Cohort dependency error")] - CohortDependencyError(String), } impl IntoResponse for FlagError { @@ -214,10 +212,6 @@ impl IntoResponse for FlagError { tracing::error!("Cohort dependency cycle: {}", msg); (StatusCode::BAD_REQUEST, msg) } - FlagError::CohortDependencyError(msg) => { - tracing::error!("Cohort dependency error: {}", msg); - (StatusCode::BAD_REQUEST, msg) - } } .into_response() } diff --git a/rust/feature-flags/src/cohort_operations.rs b/rust/feature-flags/src/cohort_operations.rs index 30cf900218e01..42a2c861f4ea0 100644 --- a/rust/feature-flags/src/cohort_operations.rs +++ b/rust/feature-flags/src/cohort_operations.rs @@ -39,6 +39,7 @@ impl Cohort { }) } + /// Returns all cohorts for a given team #[instrument(skip_all)] pub async fn list_from_pg( client: Arc, @@ -72,8 +73,6 @@ impl Cohort { tracing::error!("Failed to parse filters for cohort {}: {}", self.id, e); FlagError::CohortFiltersParsingError })?; - - // Filter out cohort filters Ok(cohort_property .properties .to_property_filters() @@ -96,13 +95,41 @@ impl Cohort { } /// Recursively traverses the filter tree to find cohort dependencies + /// + /// Example filter tree structure: + /// ```json + /// { + /// "properties": { + /// "type": "OR", + /// "values": [ + /// { + /// "type": "OR", + /// "values": [ + /// { + /// "key": "id", + /// "value": 123, + /// "type": "cohort", + /// "operator": "exact" + /// }, + /// { + /// "key": "email", + /// "value": "@posthog.com", + /// "type": "person", + /// "operator": "icontains" + /// } + /// ] + /// } + /// ] + /// } + /// } + /// ``` fn traverse_filters( inner: &InnerCohortProperty, dependencies: &mut HashSet, ) -> Result<(), FlagError> { for cohort_values in &inner.values { for filter in &cohort_values.values { - if filter.prop_type == "cohort" && filter.key == "id" { + if filter.is_cohort() { // Assuming the value is a single integer CohortId if let Some(cohort_id) = filter.value.as_i64() { dependencies.insert(cohort_id as CohortId); @@ -110,7 +137,7 @@ impl Cohort { return Err(FlagError::CohortFiltersParsingError); } } - // Handle nested properties if necessary + // NB: we don't support nested cohort properties, so we don't need to traverse further } } Ok(()) @@ -167,6 +194,46 @@ mod tests { assert_eq!(fetched_cohort.team_id, team.id); } + #[tokio::test] + async fn test_list_from_pg() { + let postgres_reader = setup_pg_reader_client(None).await; + let postgres_writer = setup_pg_writer_client(None).await; + + let team = insert_new_team_in_pg(postgres_reader.clone(), None) + .await + .expect("Failed to insert team"); + + // Insert multiple cohorts for the team + insert_cohort_for_team_in_pg( + postgres_writer.clone(), + team.id, + Some("Cohort 1".to_string()), + json!({"properties": {"type": "AND", "values": [{"type": "property", "values": [{"key": "age", "type": "person", "value": [30], "negation": false, "operator": "gt"}]}]}}), + false, + ) + .await + .expect("Failed to insert cohort1"); + + insert_cohort_for_team_in_pg( + postgres_writer.clone(), + team.id, + Some("Cohort 2".to_string()), + json!({"properties": {"type": "OR", "values": [{"type": "property", "values": [{"key": "country", "type": "person", "value": ["USA"], "negation": false, "operator": "exact"}]}]}}), + false, + ) + .await + .expect("Failed to insert cohort2"); + + let cohorts = Cohort::list_from_pg(postgres_reader, team.id) + .await + .expect("Failed to list cohorts"); + + assert_eq!(cohorts.len(), 2); + let names: HashSet = cohorts.into_iter().map(|c| c.name).collect(); + assert!(names.contains("Cohort 1")); + assert!(names.contains("Cohort 2")); + } + #[test] fn test_cohort_parse_filters() { let cohort = Cohort { @@ -228,4 +295,48 @@ mod tests { assert_eq!(result[1].key, "age"); assert_eq!(result[1].value, json!(25)); } + + #[tokio::test] + async fn test_extract_dependencies() { + let postgres_reader = setup_pg_reader_client(None).await; + let postgres_writer = setup_pg_writer_client(None).await; + + let team = insert_new_team_in_pg(postgres_reader.clone(), None) + .await + .expect("Failed to insert team"); + + // Insert a single cohort that is dependent on another cohort + let dependent_cohort = insert_cohort_for_team_in_pg( + postgres_writer.clone(), + team.id, + Some("Dependent Cohort".to_string()), + json!({"properties": {"type": "OR", "values": [{"type": "OR", "values": [{"key": "$browser", "type": "person", "value": ["Safari"], "negation": false, "operator": "exact"}]}]}}), + false, + ) + .await + .expect("Failed to insert dependent_cohort"); + + // Insert main cohort with a single dependency + let main_cohort = insert_cohort_for_team_in_pg( + postgres_writer.clone(), + team.id, + Some("Main Cohort".to_string()), + json!({"properties": {"type": "OR", "values": [{"type": "OR", "values": [{"key": "id", "type": "cohort", "value": dependent_cohort.id, "negation": false}]}]}}), + false, + ) + .await + .expect("Failed to insert main_cohort"); + + let fetched_main_cohort = Cohort::from_pg(postgres_reader.clone(), main_cohort.id, team.id) + .await + .expect("Failed to fetch main cohort"); + + println!("fetched_main_cohort: {:?}", fetched_main_cohort); + + let dependencies = fetched_main_cohort.extract_dependencies().unwrap(); + let expected_dependencies: HashSet = + [dependent_cohort.id].iter().cloned().collect(); + + assert_eq!(dependencies, expected_dependencies); + } } diff --git a/rust/feature-flags/src/flag_definitions.rs b/rust/feature-flags/src/flag_definitions.rs index 329d0321e77be..8b37072ec494b 100644 --- a/rust/feature-flags/src/flag_definitions.rs +++ b/rust/feature-flags/src/flag_definitions.rs @@ -41,6 +41,7 @@ pub struct PropertyFilter { pub value: serde_json::Value, pub operator: Option, #[serde(rename = "type")] + // TODO: worth making a enum here to differentiate between cohort and person filters? pub prop_type: String, pub negation: Option, pub group_type_index: Option, diff --git a/rust/feature-flags/src/flag_matching.rs b/rust/feature-flags/src/flag_matching.rs index b7d92329d9e37..59fc6c5a34b52 100644 --- a/rust/feature-flags/src/flag_matching.rs +++ b/rust/feature-flags/src/flag_matching.rs @@ -833,128 +833,31 @@ impl FeatureFlagMatcher { /// hitting the database for each cohort filter. pub async fn evaluate_cohort_filters( &self, - property_filters: &[PropertyFilter], + cohort_property_filters: &[PropertyFilter], target_properties: &HashMap, ) -> Result { // Caching all of the cohorts like this will make it so that we don't have to hit the database for each cohort filter let cohort_cache = CohortCache::new_with_team(self.team_id, self.postgres_reader.clone()).await?; - // At this point, we shouldn't have any non-cohort property filters, but we'll filter them out anyway - let cohort_property_filters: Vec<_> = - property_filters.iter().filter(|f| f.is_cohort()).collect(); - // Store cohort match results in a HashMap to avoid re-evaluating the same cohort multiple times, // since the same cohort could appear in multiple property filters. This is especially important // because evaluating a cohort requires evaluating all of its dependencies, which can be expensive. let mut cohort_matches = HashMap::new(); - for filter in &cohort_property_filters { + for filter in cohort_property_filters { let cohort_id = filter.get_cohort_id()?; - let match_result = self - .evaluate_cohort_dependencies( - self.team_id, - cohort_id, - target_properties, - &cohort_cache, - ) - .await?; + let match_result = evaluate_cohort_dependencies( + self.team_id, + cohort_id, + target_properties, + &cohort_cache, + ) + .await?; cohort_matches.insert(cohort_id, match_result); } // Apply cohort membership logic (IN|NOT_IN) - self.apply_cohort_membership_logic(&cohort_property_filters, &cohort_matches) - } - - /// Evaluates a single cohort and its dependencies. - /// This uses a topological sort to evaluate dependencies first, which is necessary - /// because a cohort can depend on another cohort, and we need to respect the dependency order. - async fn evaluate_cohort_dependencies( - &self, - team_id: i32, - initial_cohort_id: CohortId, - target_properties: &HashMap, - cohort_cache: &CohortCache, - ) -> Result { - let cohort_dependency_graph = - build_cohort_dependency_graph(team_id, initial_cohort_id, cohort_cache).await?; - - // We need to sort cohorts topologically to ensure we evaluate dependencies before the cohorts that depend on them. - // For example, if cohort A depends on cohort B, we need to evaluate B first to know if A matches. - // This also helps detect cycles - if cohort A depends on B which depends on A, toposort will fail. - let sorted_cohort_ids_as_graph_nodes = - toposort(&cohort_dependency_graph, None).map_err(|e| { - FlagError::CohortDependencyCycle(format!("Cyclic dependency detected: {:?}", e)) - })?; - - // Store evaluation results for each cohort in a map, so we can look up whether a cohort matched - // when evaluating cohorts that depend on it, and also return the final result for the initial cohort - let mut evaluation_results = HashMap::new(); - - // Iterate through the sorted nodes in reverse order (so that we can evaluate dependencies first) - for node in sorted_cohort_ids_as_graph_nodes.into_iter().rev() { - let cohort_id = cohort_dependency_graph[node]; - let cohort = cohort_cache - .get_cohort_by_id(team_id, cohort_dependency_graph[node]) - .await?; - let property_filters = cohort.parse_filters()?; - let dependencies = cohort.extract_dependencies()?; - - // Check if all dependencies have been met (i.e., previous cohorts matched) - let dependencies_met = dependencies - .iter() - .all(|dep_id| evaluation_results.get(dep_id).copied().unwrap_or(false)); - - // If dependencies are not met, mark the current cohort as not matched and continue - // NB: We don't want to _exit_ here, since the non-matching cohort could be wrapped in a `not_in` operator - // and we want to evaluate all cohorts to determine if the initial cohort matches. - if !dependencies_met { - evaluation_results.insert(cohort_id, false); - continue; - } - - // Evaluate all property filters for the current cohort - let all_filters_match = property_filters - .iter() - .all(|filter| match_property(filter, target_properties, false).unwrap_or(false)); - - // Store the evaluation result for the current cohort - evaluation_results.insert(cohort_id, all_filters_match); - } - - // Retrieve and return the evaluation result for the initial cohort - evaluation_results - .get(&initial_cohort_id) - .copied() - .ok_or_else(|| FlagError::CohortNotFound(initial_cohort_id.to_string())) - } - - /// Apply cohort membership logic (i.e., IN|NOT_IN) - fn apply_cohort_membership_logic( - &self, - cohort_filters: &[&PropertyFilter], - cohort_matches: &HashMap, - ) -> Result { - for filter in cohort_filters { - let cohort_id = filter.get_cohort_id()?; - let matches = cohort_matches.get(&cohort_id).copied().unwrap_or(false); - let operator = filter.operator.unwrap_or(OperatorType::In); - - // Combine the operator logic directly within this method - let membership_match = match operator { - OperatorType::In => matches, - OperatorType::NotIn => !matches, - // Currently supported operators are IN and NOT IN - // Any other operator defaults to false - _ => false, - }; - - // If any filter does not match, return false early - if !membership_match { - return Ok(false); - } - } - // All filters matched - Ok(true) + apply_cohort_membership_logic(cohort_property_filters, &cohort_matches) } /// Check if a super condition matches for a feature flag. @@ -1200,7 +1103,115 @@ impl FeatureFlagMatcher { } } +/// Evaluates a single cohort and its dependencies. +/// This uses a topological sort to evaluate dependencies first, which is necessary +/// because a cohort can depend on another cohort, and we need to respect the dependency order. +async fn evaluate_cohort_dependencies( + team_id: i32, + initial_cohort_id: CohortId, + target_properties: &HashMap, + cohort_cache: &CohortCache, +) -> Result { + let cohort_dependency_graph = + build_cohort_dependency_graph(team_id, initial_cohort_id, cohort_cache).await?; + + // We need to sort cohorts topologically to ensure we evaluate dependencies before the cohorts that depend on them. + // For example, if cohort A depends on cohort B, we need to evaluate B first to know if A matches. + // This also helps detect cycles - if cohort A depends on B which depends on A, toposort will fail. + let sorted_cohort_ids_as_graph_nodes = + toposort(&cohort_dependency_graph, None).map_err(|e| { + FlagError::CohortDependencyCycle(format!("Cyclic dependency detected: {:?}", e)) + })?; + + // Store evaluation results for each cohort in a map, so we can look up whether a cohort matched + // when evaluating cohorts that depend on it, and also return the final result for the initial cohort + let mut evaluation_results = HashMap::new(); + + // Iterate through the sorted nodes in reverse order (so that we can evaluate dependencies first) + for node in sorted_cohort_ids_as_graph_nodes.into_iter().rev() { + let cohort_id = cohort_dependency_graph[node]; + let cohort = cohort_cache + .get_cohort_by_id(team_id, cohort_dependency_graph[node]) + .await?; + let property_filters = cohort.parse_filters()?; + let dependencies = cohort.extract_dependencies()?; + + // Check if all dependencies have been met (i.e., previous cohorts matched) + let dependencies_met = dependencies + .iter() + .all(|dep_id| evaluation_results.get(dep_id).copied().unwrap_or(false)); + + // If dependencies are not met, mark the current cohort as not matched and continue + // NB: We don't want to _exit_ here, since the non-matching cohort could be wrapped in a `not_in` operator + // and we want to evaluate all cohorts to determine if the initial cohort matches. + if !dependencies_met { + evaluation_results.insert(cohort_id, false); + continue; + } + + // Evaluate all property filters for the current cohort + let all_filters_match = property_filters + .iter() + .all(|filter| match_property(filter, target_properties, false).unwrap_or(false)); + + // Store the evaluation result for the current cohort + evaluation_results.insert(cohort_id, all_filters_match); + } + + // Retrieve and return the evaluation result for the initial cohort + evaluation_results + .get(&initial_cohort_id) + .copied() + .ok_or_else(|| FlagError::CohortNotFound(initial_cohort_id.to_string())) +} + +/// Apply cohort membership logic (i.e., IN|NOT_IN) +fn apply_cohort_membership_logic( + cohort_filters: &[PropertyFilter], + cohort_matches: &HashMap, +) -> Result { + for filter in cohort_filters { + let cohort_id = filter.get_cohort_id()?; + let matches = cohort_matches.get(&cohort_id).copied().unwrap_or(false); + let operator = filter.operator.unwrap_or(OperatorType::In); + + // Combine the operator logic directly within this method + let membership_match = match operator { + OperatorType::In => matches, + OperatorType::NotIn => !matches, + // Currently supported operators are IN and NOT IN + // Any other operator defaults to false + _ => false, + }; + + // If any filter does not match, return false early + if !membership_match { + return Ok(false); + } + } + // All filters matched + Ok(true) +} + /// Constructs a dependency graph for cohorts. +/// +/// Example dependency graph: +/// ```text +/// A B +/// | /| +/// | / | +/// | / | +/// C D +/// \ / +/// \ / +/// E +/// ``` +/// In this example: +/// - Cohorts A and B are root nodes (no dependencies) +/// - C depends on A and B +/// - D depends on B +/// - E depends on C and D +/// The graph is acyclic, which is required for valid cohort dependencies. async fn build_cohort_dependency_graph( team_id: i32, initial_cohort_id: CohortId, @@ -1208,21 +1219,28 @@ async fn build_cohort_dependency_graph( ) -> Result, FlagError> { let mut graph = DiGraph::new(); let mut node_map = HashMap::new(); - - // Queue for BFS traversal let mut queue = VecDeque::new(); + // This implements a breadth-first search (BFS) traversal to build a directed graph of cohort dependencies. + // Starting from the initial cohort, we: + // 1. Add each cohort as a node in the graph + // 2. Track visited nodes in a map to avoid duplicates + // 3. For each cohort, get its dependencies and add directed edges from the cohort to its dependencies + // 4. Queue up any unvisited dependencies to process their dependencies later + // This builds up the full dependency graph level by level, which we can later check for cycles queue.push_back(initial_cohort_id); node_map.insert(initial_cohort_id, graph.add_node(initial_cohort_id)); while let Some(cohort_id) = queue.pop_front() { let cohort = cohort_cache.get_cohort_by_id(team_id, cohort_id).await?; let dependencies = cohort.extract_dependencies()?; - for dep_id in dependencies { // Retrieve the current node **before** mutable borrowing + // This is safe because we're not mutating the node map, + // and it keeps the borrow checker happy let current_node = node_map[&cohort_id]; - - // Add dependency node if not present + // Add dependency node if we haven't seen this cohort ID before in our traversal. + // This happens when we discover a new dependency that wasn't previously + // encountered while processing other cohorts in the graph. let dep_node = node_map .entry(dep_id) .or_insert_with(|| graph.add_node(dep_id)); @@ -1235,7 +1253,7 @@ async fn build_cohort_dependency_graph( } } - // Check for cycles + // Check for cycles, this is an directed acyclic graph so we use is_cyclic_directed if is_cyclic_directed(&graph) { return Err(FlagError::CohortDependencyCycle(format!( "Cyclic dependency detected starting at cohort {}", From 3cfc590e2f8d20c93200f89705338869328482a2 Mon Sep 17 00:00:00 2001 From: dylan Date: Thu, 31 Oct 2024 14:13:31 -0700 Subject: [PATCH 18/30] test for the cohort cache --- rust/feature-flags/src/cohort_cache.rs | 386 +++++++++++++++++++++++++ 1 file changed, 386 insertions(+) diff --git a/rust/feature-flags/src/cohort_cache.rs b/rust/feature-flags/src/cohort_cache.rs index f3582bf15dc37..978be17642f38 100644 --- a/rust/feature-flags/src/cohort_cache.rs +++ b/rust/feature-flags/src/cohort_cache.rs @@ -95,3 +95,389 @@ impl CohortCache { Err(FlagError::CohortNotFound(cohort_id.to_string())) } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::test_utils::{ + insert_cohort_for_team_in_pg, insert_new_team_in_pg, setup_pg_reader_client, + setup_pg_writer_client, + }; + use serde_json::json; + use std::collections::HashSet; + + #[tokio::test] + async fn test_default_cache_is_empty() { + let cache = CohortCache::default(); + let cache_guard = cache.per_team_cohorts.read().await; + assert!(cache_guard.is_empty(), "Default cache should be empty"); + } + + #[tokio::test] + async fn test_new_with_team_initializes_cache() { + let postgres_reader = setup_pg_reader_client(None).await; + let postgres_writer = setup_pg_writer_client(None).await; + + let team = insert_new_team_in_pg(postgres_reader.clone(), None) + .await + .expect("Failed to insert team"); + + // Insert cohorts for the team + insert_cohort_for_team_in_pg( + postgres_writer.clone(), + team.id, + Some("Active Users".to_string()), + json!({"properties": {"type": "OR", "values": [{"type": "OR", "values": [{"key": "$active", "type": "person", "value": [true], "negation": false, "operator": "exact"}]}]}}), + false, + ) + .await + .expect("Failed to insert Active Users cohort"); + + insert_cohort_for_team_in_pg( + postgres_writer.clone(), + team.id, + Some("Power Users".to_string()), + json!({"properties": {"type": "AND", "values": [{"type": "property", "values": [{"key": "usage", "type": "person", "value": [100], "negation": false, "operator": "gt"}]}]}}), + false, + ) + .await + .expect("Failed to insert Power Users cohort"); + + // Initialize the cache with the team + let cache = CohortCache::new_with_team(team.id, postgres_reader.clone()) + .await + .expect("Failed to initialize CohortCache with team"); + + let cache_guard = cache.per_team_cohorts.read().await; + assert!( + cache_guard.contains_key(&team.id), + "Cache should contain the team_id" + ); + + let cohorts = cache_guard.get(&team.id).unwrap(); + assert_eq!(cohorts.len(), 2, "There should be 2 cohorts for the team"); + let cohort_names: HashSet = cohorts.iter().map(|c| c.name.clone()).collect(); + assert!( + cohort_names.contains("Active Users"), + "Cache should contain 'Active Users' cohort" + ); + assert!( + cohort_names.contains("Power Users"), + "Cache should contain 'Power Users' cohort" + ); + } + + #[tokio::test] + async fn test_get_cohort_by_id_success() { + let postgres_reader = setup_pg_reader_client(None).await; + let postgres_writer = setup_pg_writer_client(None).await; + + let team = insert_new_team_in_pg(postgres_reader.clone(), None) + .await + .expect("Failed to insert team"); + + let cohort = insert_cohort_for_team_in_pg( + postgres_writer.clone(), + team.id, + Some("Active Users".to_string()), + json!({"properties": {"type": "OR", "values": [{"type": "OR", "values": [{"key": "$active", "type": "person", "value": [true], "negation": false, "operator": "exact"}]}]}}), + false, + ) + .await + .expect("Failed to insert Active Users cohort"); + + let cache = CohortCache::new_with_team(team.id, postgres_reader.clone()) + .await + .expect("Failed to initialize CohortCache with team"); + + let fetched_cohort = cache + .get_cohort_by_id(team.id, cohort.id) + .await + .expect("Failed to retrieve cohort by ID"); + + assert_eq!( + fetched_cohort.id, cohort.id, + "Fetched cohort ID should match" + ); + assert_eq!( + fetched_cohort.name, "Active Users", + "Fetched cohort name should match" + ); + } + + #[tokio::test] + async fn test_get_cohort_by_id_not_found() { + let postgres_reader = setup_pg_reader_client(None).await; + let postgres_writer = setup_pg_writer_client(None).await; + + let team = insert_new_team_in_pg(postgres_reader.clone(), None) + .await + .expect("Failed to insert team"); + + // Insert a cohort to ensure the team has at least one cohort + insert_cohort_for_team_in_pg( + postgres_writer.clone(), + team.id, + Some("Active Users".to_string()), + json!({"properties": {"type": "OR", "values": [{"type": "OR", "values": [{"key": "$active", "type": "person", "value": [true], "negation": false, "operator": "exact"}]}]}}), + false, + ) + .await + .expect("Failed to insert Active Users cohort"); + + let cache = CohortCache::new_with_team(team.id, postgres_reader.clone()) + .await + .expect("Failed to initialize CohortCache with team"); + + let non_existent_cohort_id = 9999; + let result = cache + .get_cohort_by_id(team.id, non_existent_cohort_id) + .await; + + assert!( + matches!(result, Err(FlagError::CohortNotFound(_))), + "Should return CohortNotFound error for non-existent cohort ID" + ); + } + + #[tokio::test] + async fn test_fetch_and_cache_all_cohorts() { + let postgres_reader = setup_pg_reader_client(None).await; + let postgres_writer = setup_pg_writer_client(None).await; + + let team = insert_new_team_in_pg(postgres_reader.clone(), None) + .await + .expect("Failed to insert team"); + + // Insert multiple cohorts for the team + insert_cohort_for_team_in_pg( + postgres_writer.clone(), + team.id, + Some("Active Users".to_string()), + json!({"properties": {"type": "OR", "values": [{"type": "OR", "values": [{"key": "$active", "type": "person", "value": [true], "negation": false, "operator": "exact"}]}]}}), + false, + ) + .await + .expect("Failed to insert Active Users cohort"); + + insert_cohort_for_team_in_pg( + postgres_writer.clone(), + team.id, + Some("Power Users".to_string()), + json!({"properties": {"type": "AND", "values": [{"type": "property", "values": [{"key": "usage", "type": "person", "value": [100], "negation": false, "operator": "gt"}]}]}}), + false, + ) + .await + .expect("Failed to insert Power Users cohort"); + + let cache = CohortCache::new(); + + // Fetch and cache all cohorts for the team + cache + .fetch_and_cache_all_cohorts(team.id, postgres_reader.clone()) + .await + .expect("Failed to fetch and cache all cohorts"); + + let cache_guard = cache.per_team_cohorts.read().await; + assert!( + cache_guard.contains_key(&team.id), + "Cache should contain the team_id" + ); + + let cohorts = cache_guard.get(&team.id).unwrap(); + assert_eq!( + cohorts.len(), + 2, + "There should be 2 cohorts cached for the team" + ); + + let cohort_names: HashSet = cohorts.iter().map(|c| c.name.clone()).collect(); + assert!( + cohort_names.contains("Active Users"), + "Cache should contain 'Active Users' cohort" + ); + assert!( + cohort_names.contains("Power Users"), + "Cache should contain 'Power Users' cohort" + ); + } + + #[tokio::test] + async fn test_cache_updates_on_new_cohort() { + let postgres_reader = setup_pg_reader_client(None).await; + let postgres_writer = setup_pg_writer_client(None).await; + + let team = insert_new_team_in_pg(postgres_reader.clone(), None) + .await + .expect("Failed to insert team"); + + // Initialize the cache + let cache = CohortCache::new(); + + // Fetch and cache all cohorts for the team (initially, there should be none) + cache + .fetch_and_cache_all_cohorts(team.id, postgres_reader.clone()) + .await + .expect("Failed to fetch and cache cohorts"); + + // Assert that the cache now contains the team_id with an empty vector + { + let cache_guard = cache.per_team_cohorts.read().await; + assert!( + cache_guard.contains_key(&team.id), + "Cache should contain the team_id after initial fetch" + ); + let cohorts = cache_guard.get(&team.id).unwrap(); + assert!( + cohorts.is_empty(), + "Cache for team_id should be empty initially" + ); + } + + // Insert a new cohort for the team + insert_cohort_for_team_in_pg( + postgres_writer.clone(), + team.id, + Some("New Cohort".to_string()), + json!({ + "properties": { + "type": "AND", + "values": [{ + "type": "property", + "values": [{ + "key": "subscription", + "type": "person", + "value": ["premium"], + "negation": false, + "operator": "exact" + }] + }] + } + }), + false, + ) + .await + .expect("Failed to insert New Cohort"); + + // Update the cache by fetching again after inserting the new cohort + cache + .fetch_and_cache_all_cohorts(team.id, postgres_reader.clone()) + .await + .expect("Failed to update cache with new cohort"); + + // Verify the cache has been updated with the new cohort + { + let cache_guard = cache.per_team_cohorts.read().await; + assert!( + cache_guard.contains_key(&team.id), + "Cache should contain the team_id after update" + ); + + let cohorts = cache_guard.get(&team.id).unwrap(); + assert_eq!( + cohorts.len(), + 1, + "There should be 1 cohort cached for the team after update" + ); + assert_eq!( + cohorts[0].name, "New Cohort", + "Cached cohort should be 'New Cohort'" + ); + } + } + + #[tokio::test] + async fn test_cache_handles_multiple_teams() { + let postgres_reader = setup_pg_reader_client(None).await; + let postgres_writer = setup_pg_writer_client(None).await; + + // Insert two teams + let team1 = insert_new_team_in_pg(postgres_reader.clone(), None) + .await + .expect("Failed to insert team1"); + let team2 = insert_new_team_in_pg(postgres_reader.clone(), None) + .await + .expect("Failed to insert team2"); + + // Insert cohorts for team1 + insert_cohort_for_team_in_pg( + postgres_writer.clone(), + team1.id, + Some("Team1 Cohort1".to_string()), + json!({"properties": {"type": "OR", "values": [{"type": "OR", "values": [{"key": "feature_x", "type": "feature", "value": [true], "negation": false, "operator": "exact"}]}]}}), + false, + ) + .await + .expect("Failed to insert Team1 Cohort1"); + + insert_cohort_for_team_in_pg( + postgres_writer.clone(), + team1.id, + Some("Team1 Cohort2".to_string()), + json!({"properties": {"type": "AND", "values": [{"type": "property", "values": [{"key": "usage", "type": "feature", "value": [50], "negation": false, "operator": "gt"}]}]}}), + false, + ) + .await + .expect("Failed to insert Team1 Cohort2"); + + // Insert cohorts for team2 + insert_cohort_for_team_in_pg( + postgres_writer.clone(), + team2.id, + Some("Team2 Cohort1".to_string()), + json!({"properties": {"type": "OR", "values": [{"type": "OR", "values": [{"key": "region", "type": "geo", "value": ["NA"], "negation": false, "operator": "exact"}]}]}}), + false, + ) + .await + .expect("Failed to insert Team2 Cohort1"); + + // Initialize and cache cohorts for both teams + let cache = CohortCache::new(); + + cache + .fetch_and_cache_all_cohorts(team1.id, postgres_reader.clone()) + .await + .expect("Failed to cache team1 cohorts"); + + cache + .fetch_and_cache_all_cohorts(team2.id, postgres_reader.clone()) + .await + .expect("Failed to cache team2 cohorts"); + + // Verify team1's cache + { + let cache_guard = cache.per_team_cohorts.read().await; + assert!( + cache_guard.contains_key(&team1.id), + "Cache should contain team1" + ); + let team1_cohorts = cache_guard.get(&team1.id).unwrap(); + let team1_names: HashSet = + team1_cohorts.iter().map(|c| c.name.clone()).collect(); + assert!( + team1_names.contains("Team1 Cohort1"), + "Cache should contain 'Team1 Cohort1'" + ); + assert!( + team1_names.contains("Team1 Cohort2"), + "Cache should contain 'Team1 Cohort2'" + ); + } + + // Verify team2's cache + { + let cache_guard = cache.per_team_cohorts.read().await; + assert!( + cache_guard.contains_key(&team2.id), + "Cache should contain team2" + ); + let team2_cohorts = cache_guard.get(&team2.id).unwrap(); + let team2_names: HashSet = + team2_cohorts.iter().map(|c| c.name.clone()).collect(); + assert!( + team2_names.contains("Team2 Cohort1"), + "Cache should contain 'Team2 Cohort1'" + ); + } + } +} From 3e8e5d2d960f6ba0a21d088d6468dfe354cd3c65 Mon Sep 17 00:00:00 2001 From: dylan Date: Thu, 31 Oct 2024 14:37:14 -0700 Subject: [PATCH 19/30] a few things --- rust/feature-flags/src/cohort_operations.rs | 27 ++++++++++++++++++++ rust/feature-flags/src/flag_definitions.rs | 13 +++++----- rust/feature-flags/src/flag_matching.rs | 28 ++++++++++++--------- 3 files changed, 50 insertions(+), 18 deletions(-) diff --git a/rust/feature-flags/src/cohort_operations.rs b/rust/feature-flags/src/cohort_operations.rs index 42a2c861f4ea0..ea4214ccdc08b 100644 --- a/rust/feature-flags/src/cohort_operations.rs +++ b/rust/feature-flags/src/cohort_operations.rs @@ -145,6 +145,33 @@ impl Cohort { } impl InnerCohortProperty { + /// Flattens the nested cohort property structure into a list of property filters. + /// + /// The cohort property structure in Postgres looks like: + /// ```json + /// { + /// "type": "OR", + /// "values": [ + /// { + /// "type": "OR", + /// "values": [ + /// { + /// "key": "email", + /// "value": "@posthog.com", + /// "type": "person", + /// "operator": "icontains" + /// }, + /// { + /// "key": "age", + /// "value": 25, + /// "type": "person", + /// "operator": "gt" + /// } + /// ] + /// } + /// ] + /// } + /// ``` pub fn to_property_filters(&self) -> Vec { self.values .iter() diff --git a/rust/feature-flags/src/flag_definitions.rs b/rust/feature-flags/src/flag_definitions.rs index 8b37072ec494b..d62ecc9e0e0c1 100644 --- a/rust/feature-flags/src/flag_definitions.rs +++ b/rust/feature-flags/src/flag_definitions.rs @@ -53,12 +53,13 @@ impl PropertyFilter { self.key == "id" && self.prop_type == "cohort" } - /// Returns the cohort id if the filter is a cohort filter - pub fn get_cohort_id(&self) -> Result { - self.value - .as_i64() - .map(|id| id as CohortId) - .ok_or(FlagError::CohortFiltersParsingError) + /// Returns the cohort id if the filter is a cohort filter, or None if it's not a cohort filter + /// or if the value cannot be parsed as a cohort id + pub fn get_cohort_id(&self) -> Option { + if !self.is_cohort() { + return None; + } + self.value.as_i64().map(|id| id as CohortId) } } diff --git a/rust/feature-flags/src/flag_matching.rs b/rust/feature-flags/src/flag_matching.rs index 59fc6c5a34b52..28acda0252d93 100644 --- a/rust/feature-flags/src/flag_matching.rs +++ b/rust/feature-flags/src/flag_matching.rs @@ -720,7 +720,7 @@ impl FeatureFlagMatcher { /// It first checks if the condition has any property filters. If not, it performs a rollout check. /// Otherwise, it fetches the relevant properties and checks if they match the condition's filters. /// The function returns a tuple indicating whether the condition matched and the reason for the match. - pub async fn is_condition_match( + async fn is_condition_match( &mut self, feature_flag: &FeatureFlag, condition: &FlagGroupType, @@ -743,25 +743,25 @@ impl FeatureFlagMatcher { .cloned() .partition(|prop| prop.is_cohort()); - // Evaluate non-cohort properties first to get properties_to_check - let properties_to_check = self + // Get the relevant properties to check for the condition + let target_properties = self .get_properties_to_check(feature_flag, property_overrides, &non_cohort_filters) .await?; - // Evaluate cohort filters + // Evaluate non-cohort properties first, since they're cheaper to evaluate + if !all_properties_match(&non_cohort_filters, &target_properties) { + return Ok((false, FeatureFlagMatchReason::NoConditionMatch)); + } + + // Evaluate cohort filters, if any if !cohort_filters.is_empty() { let cohorts_match = self - .evaluate_cohort_filters(&cohort_filters, &properties_to_check) + .evaluate_cohort_filters(&cohort_filters, &target_properties) .await?; if !cohorts_match { return Ok((false, FeatureFlagMatchReason::NoConditionMatch)); } } - - // Evaluate non-cohort properties - if !all_properties_match(&non_cohort_filters, &properties_to_check) { - return Ok((false, FeatureFlagMatchReason::NoConditionMatch)); - } } self.check_rollout(feature_flag, rollout_percentage, hash_key_overrides) @@ -845,7 +845,9 @@ impl FeatureFlagMatcher { // because evaluating a cohort requires evaluating all of its dependencies, which can be expensive. let mut cohort_matches = HashMap::new(); for filter in cohort_property_filters { - let cohort_id = filter.get_cohort_id()?; + let cohort_id = filter + .get_cohort_id() + .ok_or(FlagError::CohortFiltersParsingError)?; let match_result = evaluate_cohort_dependencies( self.team_id, cohort_id, @@ -1171,7 +1173,9 @@ fn apply_cohort_membership_logic( cohort_matches: &HashMap, ) -> Result { for filter in cohort_filters { - let cohort_id = filter.get_cohort_id()?; + let cohort_id = filter + .get_cohort_id() + .ok_or(FlagError::CohortFiltersParsingError)?; let matches = cohort_matches.get(&cohort_id).copied().unwrap_or(false); let operator = filter.operator.unwrap_or(OperatorType::In); From 09317c4039229f5709f3553600735c303bc77a77 Mon Sep 17 00:00:00 2001 From: dylan Date: Fri, 1 Nov 2024 16:10:37 -0700 Subject: [PATCH 20/30] use global cohort cache --- rust/Cargo.lock | 48 +- rust/feature-flags/Cargo.toml | 1 + rust/feature-flags/src/cohort_cache.rs | 600 +++++------------- rust/feature-flags/src/flag_matching.rs | 136 +++- rust/feature-flags/src/request_handler.rs | 12 + rust/feature-flags/src/router.rs | 4 + rust/feature-flags/src/server.rs | 4 + .../tests/test_flag_matching_consistency.rs | 7 + 8 files changed, 354 insertions(+), 458 deletions(-) diff --git a/rust/Cargo.lock b/rust/Cargo.lock index 5bb5fc25b318d..efaebf561991b 100644 --- a/rust/Cargo.lock +++ b/rust/Cargo.lock @@ -1315,6 +1315,7 @@ dependencies = [ "futures", "health", "maxminddb", + "moka", "once_cell", "petgraph", "rand", @@ -2467,6 +2468,30 @@ dependencies = [ "windows-sys 0.48.0", ] +[[package]] +name = "moka" +version = "0.12.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32cf62eb4dd975d2dde76432fb1075c49e3ee2331cf36f1f8fd4b66550d32b6f" +dependencies = [ + "async-lock 3.4.0", + "async-trait", + "crossbeam-channel", + "crossbeam-epoch", + "crossbeam-utils", + "event-listener 5.3.1", + "futures-util", + "once_cell", + "parking_lot", + "quanta 0.12.2", + "rustc_version 0.4.1", + "smallvec", + "tagptr", + "thiserror", + "triomphe", + "uuid", +] + [[package]] name = "native-tls" version = "0.2.11" @@ -3455,6 +3480,15 @@ dependencies = [ "semver 0.9.0", ] +[[package]] +name = "rustc_version" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cfcb3a22ef46e85b45de6ee7e79d063319ebb6594faafcf1c225ea92ab6e9b92" +dependencies = [ + "semver 1.0.23", +] + [[package]] name = "rustix" version = "0.37.27" @@ -3841,7 +3875,7 @@ dependencies = [ "debugid", "if_chain", "rustc-hash", - "rustc_version", + "rustc_version 0.2.3", "serde", "serde_json", "unicode-id-start", @@ -4208,6 +4242,12 @@ dependencies = [ "libc", ] +[[package]] +name = "tagptr" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7b2093cf4c8eb1e67749a6762251bc9cd836b6fc171623bd0a9d324d37af2417" + [[package]] name = "tap" version = "1.0.1" @@ -4600,6 +4640,12 @@ dependencies = [ "tracing-log", ] +[[package]] +name = "triomphe" +version = "0.1.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "859eb650cfee7434994602c3a68b25d77ad9e68c8a6cd491616ef86661382eb3" + [[package]] name = "try-lock" version = "0.2.5" diff --git a/rust/feature-flags/Cargo.toml b/rust/feature-flags/Cargo.toml index 9847569394bb2..4099fd8ab06fd 100644 --- a/rust/feature-flags/Cargo.toml +++ b/rust/feature-flags/Cargo.toml @@ -40,6 +40,7 @@ common-metrics = { path = "../common/metrics" } tower = { workspace = true } derive_builder = "0.20.1" petgraph = "0.6.5" +moka = { version = "0.12.8", features = ["future"] } [lints] workspace = true diff --git a/rust/feature-flags/src/cohort_cache.rs b/rust/feature-flags/src/cohort_cache.rs index 978be17642f38..a948289d899d8 100644 --- a/rust/feature-flags/src/cohort_cache.rs +++ b/rust/feature-flags/src/cohort_cache.rs @@ -1,483 +1,235 @@ use crate::api::FlagError; -use crate::cohort_models::{Cohort, CohortId}; -use crate::flag_matching::PostgresReader; -use std::collections::HashMap; -use std::sync::Arc; -use tokio::sync::RwLock; +use crate::cohort_models::Cohort; +use crate::flag_matching::{PostgresReader, TeamId}; +use moka::future::Cache; +use std::time::Duration; -/// CohortCache manages the in-memory cache of cohorts +/// CohortCache manages the in-memory cache of cohorts using `moka` for caching. /// -/// Example cache structure: /// ```text /// per_team_cohorts: { -/// 1: [ -/// Cohort { id: 101, name: "Active Users", filters: [...] }, -/// Cohort { id: 102, name: "Power Users", filters: [...] } -/// ], -/// 2: [ -/// Cohort { id: 201, name: "New Users", filters: [...] }, -/// Cohort { id: 202, name: "Churned Users", filters: [...] } -/// ] +/// 1: [Cohort { id: 101, name: "Active Users", filters: [...] }, ...], +/// 2: [Cohort { id: 201, name: "New Users", filters: [...] }, ...] /// } /// ``` +/// +/// Features: +/// - **TTL**: Each cache entry expires after 5 minutes. +/// - **Size-based eviction**: The cache evicts least recently used entries when the maximum capacity is reached. +/// +/// Caches only successful cohort lists to maintain cache integrity. #[derive(Clone)] pub struct CohortCache { - pub per_team_cohorts: Arc>>>, // team_id -> list of Cohorts -} - -impl Default for CohortCache { - fn default() -> Self { - Self::new() - } + postgres_reader: PostgresReader, + per_team_cohorts: Cache>, // team_id -> list of Cohorts } impl CohortCache { - pub fn new() -> Self { - Self { - per_team_cohorts: Arc::new(RwLock::new(HashMap::new())), - } - } - - /// Asynchronous constructor that initializes the CohortCache by fetching and caching cohorts for the given team_id - pub async fn new_with_team( - team_id: i32, + /// Creates a new `CohortCache` with configurable TTL and maximum capacity. + pub fn new( postgres_reader: PostgresReader, - ) -> Result { - let cache = Self { - per_team_cohorts: Arc::new(RwLock::new(HashMap::new())), + max_capacity: Option, + ttl_seconds: Option, + ) -> Self { + // Define the weigher closure. Here, we consider the number of cohorts as the weight. + let weigher = |_: &TeamId, value: &Vec| -> u32 { + return value.len().try_into().unwrap_or(u32::MAX); }; - cache - .fetch_and_cache_all_cohorts(team_id, postgres_reader) - .await?; - Ok(cache) - } - /// Fetches and caches all cohorts for a given team - /// - /// Cache structure: - /// ```text - /// per_team_cohorts: { - /// team_id_1: [ - /// Cohort { id: 1, filters: [...], ... }, - /// Cohort { id: 2, filters: [...], ... }, - /// ... - /// ], - /// team_id_2: [ - /// Cohort { id: 3, filters: [...], ... }, - /// ... - /// ] - /// } - /// ``` - async fn fetch_and_cache_all_cohorts( - &self, - team_id: i32, - postgres_reader: PostgresReader, - ) -> Result<(), FlagError> { - let cohorts = Cohort::list_from_pg(postgres_reader, team_id).await?; - let mut cache = self.per_team_cohorts.write().await; - cache.insert(team_id, cohorts); + // Initialize the Moka cache with TTL and size-based eviction. + let cache = Cache::builder() + .time_to_live(Duration::from_secs(ttl_seconds.unwrap_or(300))) // Default to 5 minutes + .weigher(weigher) + .max_capacity(max_capacity.unwrap_or(10_000)) // Default to 10,000 cohorts + .build(); - Ok(()) + Self { + postgres_reader, + per_team_cohorts: cache, + } } - /// Retrieves a specific cohort by ID for a given team - pub async fn get_cohort_by_id( - &self, - team_id: i32, - cohort_id: CohortId, - ) -> Result { - let cache = self.per_team_cohorts.read().await; - if let Some(cohorts) = cache.get(&team_id) { - if let Some(cohort) = cohorts.iter().find(|c| c.id == cohort_id) { - return Ok(cohort.clone()); - } + /// Retrieves cohorts for a given team. + /// + /// If the cohorts are not present in the cache or have expired, it fetches them from the database, + /// caches the result upon successful retrieval, and then returns it. + pub async fn get_cohorts_for_team(&self, team_id: TeamId) -> Result, FlagError> { + if let Some(cached_cohorts) = self.per_team_cohorts.get(&team_id).await { + return Ok(cached_cohorts.clone()); } - Err(FlagError::CohortNotFound(cohort_id.to_string())) + // Fetch from database + let fetched_cohorts = Cohort::list_from_pg(self.postgres_reader.clone(), team_id).await?; + // Insert into cache + self.per_team_cohorts + .insert(team_id.clone(), fetched_cohorts.clone()) + .await; + + Ok(fetched_cohorts) } } #[cfg(test)] mod tests { use super::*; + use crate::cohort_models::Cohort; use crate::test_utils::{ insert_cohort_for_team_in_pg, insert_new_team_in_pg, setup_pg_reader_client, setup_pg_writer_client, }; - use serde_json::json; - use std::collections::HashSet; + use std::sync::Arc; + use tokio::time::{sleep, Duration}; + + /// Helper function to setup a new team for testing. + async fn setup_test_team( + writer_client: Arc, + ) -> Result { + let team = crate::test_utils::insert_new_team_in_pg(writer_client, None).await?; + Ok(team.id) + } - #[tokio::test] - async fn test_default_cache_is_empty() { - let cache = CohortCache::default(); - let cache_guard = cache.per_team_cohorts.read().await; - assert!(cache_guard.is_empty(), "Default cache should be empty"); + /// Helper function to insert a cohort for a team. + async fn setup_test_cohort( + writer_client: Arc, + team_id: TeamId, // Adjusted to accept TeamId + name: Option, + ) -> Result { + let filters = serde_json::json!({"properties": {"type": "OR", "values": [{"type": "OR", "values": [{"key": "$active", "type": "person", "value": [true], "negation": false, "operator": "exact"}]}]}}); + insert_cohort_for_team_in_pg(writer_client, team_id, name, filters, false).await } + /// Tests that cache entries expire after the specified TTL. #[tokio::test] - async fn test_new_with_team_initializes_cache() { - let postgres_reader = setup_pg_reader_client(None).await; - let postgres_writer = setup_pg_writer_client(None).await; - - let team = insert_new_team_in_pg(postgres_reader.clone(), None) - .await - .expect("Failed to insert team"); - - // Insert cohorts for the team - insert_cohort_for_team_in_pg( - postgres_writer.clone(), - team.id, - Some("Active Users".to_string()), - json!({"properties": {"type": "OR", "values": [{"type": "OR", "values": [{"key": "$active", "type": "person", "value": [true], "negation": false, "operator": "exact"}]}]}}), - false, - ) - .await - .expect("Failed to insert Active Users cohort"); - - insert_cohort_for_team_in_pg( - postgres_writer.clone(), - team.id, - Some("Power Users".to_string()), - json!({"properties": {"type": "AND", "values": [{"type": "property", "values": [{"key": "usage", "type": "person", "value": [100], "negation": false, "operator": "gt"}]}]}}), - false, - ) - .await - .expect("Failed to insert Power Users cohort"); - - // Initialize the cache with the team - let cache = CohortCache::new_with_team(team.id, postgres_reader.clone()) - .await - .expect("Failed to initialize CohortCache with team"); - - let cache_guard = cache.per_team_cohorts.read().await; - assert!( - cache_guard.contains_key(&team.id), - "Cache should contain the team_id" + async fn test_cache_expiry() -> Result<(), anyhow::Error> { + // Setup PostgreSQL clients + let writer_client = setup_pg_writer_client(None).await; + let reader_client = setup_pg_reader_client(None).await; + + // Setup test team and cohort + let team_id = setup_test_team(writer_client.clone()).await?; + let _cohort = setup_test_cohort(writer_client.clone(), team_id.clone(), None).await?; + + // Initialize CohortCache with a short TTL for testing + let cohort_cache = CohortCache::new( + reader_client.clone(), + Some(100), + Some(1), // 1-second TTL ); - let cohorts = cache_guard.get(&team.id).unwrap(); - assert_eq!(cohorts.len(), 2, "There should be 2 cohorts for the team"); - let cohort_names: HashSet = cohorts.iter().map(|c| c.name.clone()).collect(); - assert!( - cohort_names.contains("Active Users"), - "Cache should contain 'Active Users' cohort" - ); - assert!( - cohort_names.contains("Power Users"), - "Cache should contain 'Power Users' cohort" - ); - } + // Fetch cohorts, which should populate the cache + let cohorts = cohort_cache.get_cohorts_for_team(team_id.clone()).await?; + assert_eq!(cohorts.len(), 1); + assert_eq!(cohorts[0].team_id, team_id.clone()); - #[tokio::test] - async fn test_get_cohort_by_id_success() { - let postgres_reader = setup_pg_reader_client(None).await; - let postgres_writer = setup_pg_writer_client(None).await; - - let team = insert_new_team_in_pg(postgres_reader.clone(), None) - .await - .expect("Failed to insert team"); - - let cohort = insert_cohort_for_team_in_pg( - postgres_writer.clone(), - team.id, - Some("Active Users".to_string()), - json!({"properties": {"type": "OR", "values": [{"type": "OR", "values": [{"key": "$active", "type": "person", "value": [true], "negation": false, "operator": "exact"}]}]}}), - false, - ) - .await - .expect("Failed to insert Active Users cohort"); - - let cache = CohortCache::new_with_team(team.id, postgres_reader.clone()) - .await - .expect("Failed to initialize CohortCache with team"); - - let fetched_cohort = cache - .get_cohort_by_id(team.id, cohort.id) - .await - .expect("Failed to retrieve cohort by ID"); + // Ensure the cohort is cached + let cached_cohorts = cohort_cache.per_team_cohorts.get(&team_id).await; + assert!(cached_cohorts.is_some()); - assert_eq!( - fetched_cohort.id, cohort.id, - "Fetched cohort ID should match" - ); - assert_eq!( - fetched_cohort.name, "Active Users", - "Fetched cohort name should match" - ); - } + // Wait for TTL to expire + sleep(Duration::from_secs(2)).await; - #[tokio::test] - async fn test_get_cohort_by_id_not_found() { - let postgres_reader = setup_pg_reader_client(None).await; - let postgres_writer = setup_pg_writer_client(None).await; - - let team = insert_new_team_in_pg(postgres_reader.clone(), None) - .await - .expect("Failed to insert team"); - - // Insert a cohort to ensure the team has at least one cohort - insert_cohort_for_team_in_pg( - postgres_writer.clone(), - team.id, - Some("Active Users".to_string()), - json!({"properties": {"type": "OR", "values": [{"type": "OR", "values": [{"key": "$active", "type": "person", "value": [true], "negation": false, "operator": "exact"}]}]}}), - false, - ) - .await - .expect("Failed to insert Active Users cohort"); - - let cache = CohortCache::new_with_team(team.id, postgres_reader.clone()) - .await - .expect("Failed to initialize CohortCache with team"); - - let non_existent_cohort_id = 9999; - let result = cache - .get_cohort_by_id(team.id, non_existent_cohort_id) - .await; + // Attempt to retrieve from cache again + let cached_cohorts = cohort_cache.per_team_cohorts.get(&team_id).await; + assert!(cached_cohorts.is_none(), "Cache entry should have expired"); - assert!( - matches!(result, Err(FlagError::CohortNotFound(_))), - "Should return CohortNotFound error for non-existent cohort ID" - ); + Ok(()) } + /// Tests that the cache correctly evicts least recently used entries based on the weigher. #[tokio::test] - async fn test_fetch_and_cache_all_cohorts() { - let postgres_reader = setup_pg_reader_client(None).await; - let postgres_writer = setup_pg_writer_client(None).await; - - let team = insert_new_team_in_pg(postgres_reader.clone(), None) - .await - .expect("Failed to insert team"); - - // Insert multiple cohorts for the team - insert_cohort_for_team_in_pg( - postgres_writer.clone(), - team.id, - Some("Active Users".to_string()), - json!({"properties": {"type": "OR", "values": [{"type": "OR", "values": [{"key": "$active", "type": "person", "value": [true], "negation": false, "operator": "exact"}]}]}}), - false, - ) - .await - .expect("Failed to insert Active Users cohort"); - - insert_cohort_for_team_in_pg( - postgres_writer.clone(), - team.id, - Some("Power Users".to_string()), - json!({"properties": {"type": "AND", "values": [{"type": "property", "values": [{"key": "usage", "type": "person", "value": [100], "negation": false, "operator": "gt"}]}]}}), - false, - ) - .await - .expect("Failed to insert Power Users cohort"); - - let cache = CohortCache::new(); - - // Fetch and cache all cohorts for the team - cache - .fetch_and_cache_all_cohorts(team.id, postgres_reader.clone()) - .await - .expect("Failed to fetch and cache all cohorts"); - - let cache_guard = cache.per_team_cohorts.read().await; - assert!( - cache_guard.contains_key(&team.id), - "Cache should contain the team_id" + async fn test_cache_weigher() -> Result<(), anyhow::Error> { + // Setup PostgreSQL clients + let writer_client = setup_pg_writer_client(None).await; + let reader_client = setup_pg_reader_client(None).await; + + // Define a smaller max_capacity and TTL for testing + let max_capacity: u64 = 3; + let ttl_seconds: u64 = 300; // 5 minutes + + // Initialize CohortCache + let cohort_cache = + CohortCache::new(reader_client.clone(), Some(max_capacity), Some(ttl_seconds)); + + let mut inserted_team_ids = Vec::new(); + + // Insert multiple teams and their cohorts + for _ in 0..max_capacity { + let team = insert_new_team_in_pg(writer_client.clone(), None).await?; + let team_id = team.id; + inserted_team_ids.push(team_id.clone()); + setup_test_cohort(writer_client.clone(), team_id.clone(), None).await?; + cohort_cache.get_cohorts_for_team(team_id.clone()).await?; + } + + // The cache should be at max_capacity + cohort_cache.per_team_cohorts.run_pending_tasks().await; + let cache_size = cohort_cache.per_team_cohorts.entry_count(); + assert_eq!( + cache_size, max_capacity, + "Cache size should be equal to max_capacity" ); - let cohorts = cache_guard.get(&team.id).unwrap(); + // Insert one more team to trigger eviction + let new_team = insert_new_team_in_pg(writer_client.clone(), None).await?; + let new_team_id = new_team.id; + setup_test_cohort(writer_client.clone(), new_team_id.clone(), None).await?; + cohort_cache + .get_cohorts_for_team(new_team_id.clone()) + .await?; + + // Now, the cache should still have max_capacity entries + cohort_cache.per_team_cohorts.run_pending_tasks().await; + let cache_size_after = cohort_cache.per_team_cohorts.entry_count(); assert_eq!( - cohorts.len(), - 2, - "There should be 2 cohorts cached for the team" + cache_size_after, max_capacity, + "Cache size should remain equal to max_capacity after eviction" ); - let cohort_names: HashSet = cohorts.iter().map(|c| c.name.clone()).collect(); + // The least recently used team should have been evicted + let evicted_team_id = &inserted_team_ids[0]; + let cached_cohorts = cohort_cache.per_team_cohorts.get(evicted_team_id).await; assert!( - cohort_names.contains("Active Users"), - "Cache should contain 'Active Users' cohort" + cached_cohorts.is_none(), + "Least recently used cache entry should have been evicted" ); + + // The new team should be present + let cached_new_team = cohort_cache.per_team_cohorts.get(&new_team_id).await; assert!( - cohort_names.contains("Power Users"), - "Cache should contain 'Power Users' cohort" + cached_new_team.is_some(), + "Newly added cache entry should be present" ); + + Ok(()) } + /// Functional test to verify that fetching cohorts populates the cache correctly. #[tokio::test] - async fn test_cache_updates_on_new_cohort() { - let postgres_reader = setup_pg_reader_client(None).await; - let postgres_writer = setup_pg_writer_client(None).await; - - let team = insert_new_team_in_pg(postgres_reader.clone(), None) - .await - .expect("Failed to insert team"); - - // Initialize the cache - let cache = CohortCache::new(); - - // Fetch and cache all cohorts for the team (initially, there should be none) - cache - .fetch_and_cache_all_cohorts(team.id, postgres_reader.clone()) - .await - .expect("Failed to fetch and cache cohorts"); - - // Assert that the cache now contains the team_id with an empty vector - { - let cache_guard = cache.per_team_cohorts.read().await; - assert!( - cache_guard.contains_key(&team.id), - "Cache should contain the team_id after initial fetch" - ); - let cohorts = cache_guard.get(&team.id).unwrap(); - assert!( - cohorts.is_empty(), - "Cache for team_id should be empty initially" - ); - } + async fn test_get_cohorts_for_team() -> Result<(), anyhow::Error> { + // Setup PostgreSQL clients + let writer_client = setup_pg_writer_client(None).await; + let reader_client = setup_pg_reader_client(None).await; - // Insert a new cohort for the team - insert_cohort_for_team_in_pg( - postgres_writer.clone(), - team.id, - Some("New Cohort".to_string()), - json!({ - "properties": { - "type": "AND", - "values": [{ - "type": "property", - "values": [{ - "key": "subscription", - "type": "person", - "value": ["premium"], - "negation": false, - "operator": "exact" - }] - }] - } - }), - false, - ) - .await - .expect("Failed to insert New Cohort"); - - // Update the cache by fetching again after inserting the new cohort - cache - .fetch_and_cache_all_cohorts(team.id, postgres_reader.clone()) - .await - .expect("Failed to update cache with new cohort"); - - // Verify the cache has been updated with the new cohort - { - let cache_guard = cache.per_team_cohorts.read().await; - assert!( - cache_guard.contains_key(&team.id), - "Cache should contain the team_id after update" - ); - - let cohorts = cache_guard.get(&team.id).unwrap(); - assert_eq!( - cohorts.len(), - 1, - "There should be 1 cohort cached for the team after update" - ); - assert_eq!( - cohorts[0].name, "New Cohort", - "Cached cohort should be 'New Cohort'" - ); - } - } + // Setup test team and cohort + let team_id = setup_test_team(writer_client.clone()).await?; + let _cohort = setup_test_cohort(writer_client.clone(), team_id.clone(), None).await?; - #[tokio::test] - async fn test_cache_handles_multiple_teams() { - let postgres_reader = setup_pg_reader_client(None).await; - let postgres_writer = setup_pg_writer_client(None).await; - - // Insert two teams - let team1 = insert_new_team_in_pg(postgres_reader.clone(), None) - .await - .expect("Failed to insert team1"); - let team2 = insert_new_team_in_pg(postgres_reader.clone(), None) - .await - .expect("Failed to insert team2"); - - // Insert cohorts for team1 - insert_cohort_for_team_in_pg( - postgres_writer.clone(), - team1.id, - Some("Team1 Cohort1".to_string()), - json!({"properties": {"type": "OR", "values": [{"type": "OR", "values": [{"key": "feature_x", "type": "feature", "value": [true], "negation": false, "operator": "exact"}]}]}}), - false, - ) - .await - .expect("Failed to insert Team1 Cohort1"); - - insert_cohort_for_team_in_pg( - postgres_writer.clone(), - team1.id, - Some("Team1 Cohort2".to_string()), - json!({"properties": {"type": "AND", "values": [{"type": "property", "values": [{"key": "usage", "type": "feature", "value": [50], "negation": false, "operator": "gt"}]}]}}), - false, - ) - .await - .expect("Failed to insert Team1 Cohort2"); - - // Insert cohorts for team2 - insert_cohort_for_team_in_pg( - postgres_writer.clone(), - team2.id, - Some("Team2 Cohort1".to_string()), - json!({"properties": {"type": "OR", "values": [{"type": "OR", "values": [{"key": "region", "type": "geo", "value": ["NA"], "negation": false, "operator": "exact"}]}]}}), - false, - ) - .await - .expect("Failed to insert Team2 Cohort1"); - - // Initialize and cache cohorts for both teams - let cache = CohortCache::new(); - - cache - .fetch_and_cache_all_cohorts(team1.id, postgres_reader.clone()) - .await - .expect("Failed to cache team1 cohorts"); - - cache - .fetch_and_cache_all_cohorts(team2.id, postgres_reader.clone()) - .await - .expect("Failed to cache team2 cohorts"); - - // Verify team1's cache - { - let cache_guard = cache.per_team_cohorts.read().await; - assert!( - cache_guard.contains_key(&team1.id), - "Cache should contain team1" - ); - let team1_cohorts = cache_guard.get(&team1.id).unwrap(); - let team1_names: HashSet = - team1_cohorts.iter().map(|c| c.name.clone()).collect(); - assert!( - team1_names.contains("Team1 Cohort1"), - "Cache should contain 'Team1 Cohort1'" - ); - assert!( - team1_names.contains("Team1 Cohort2"), - "Cache should contain 'Team1 Cohort2'" - ); - } + // Initialize CohortCache + let cohort_cache = CohortCache::new(reader_client.clone(), None, None); - // Verify team2's cache - { - let cache_guard = cache.per_team_cohorts.read().await; - assert!( - cache_guard.contains_key(&team2.id), - "Cache should contain team2" - ); - let team2_cohorts = cache_guard.get(&team2.id).unwrap(); - let team2_names: HashSet = - team2_cohorts.iter().map(|c| c.name.clone()).collect(); - assert!( - team2_names.contains("Team2 Cohort1"), - "Cache should contain 'Team2 Cohort1'" - ); - } + // Initially, cache should be empty + let cached_cohorts = cohort_cache.per_team_cohorts.get(&team_id).await; + assert!(cached_cohorts.is_none(), "Cache should initially be empty"); + + // Fetch cohorts, which should populate the cache + let cohorts = cohort_cache.get_cohorts_for_team(team_id.clone()).await?; + assert_eq!(cohorts.len(), 1); + assert_eq!(cohorts[0].team_id, team_id.clone()); + + // Now, cache should have the cohorts + let cached_cohorts = cohort_cache.per_team_cohorts.get(&team_id).await.unwrap(); + assert_eq!(cached_cohorts.len(), 1); + assert_eq!(cached_cohorts[0].team_id, team_id.clone()); + + Ok(()) } } diff --git a/rust/feature-flags/src/flag_matching.rs b/rust/feature-flags/src/flag_matching.rs index 28acda0252d93..9255175a3816e 100644 --- a/rust/feature-flags/src/flag_matching.rs +++ b/rust/feature-flags/src/flag_matching.rs @@ -1,7 +1,7 @@ use crate::{ api::{FlagError, FlagValue, FlagsResponse}, cohort_cache::CohortCache, - cohort_models::CohortId, + cohort_models::{Cohort, CohortId}, database::Client as DatabaseClient, feature_flag_match_reason::FeatureFlagMatchReason, flag_definitions::{FeatureFlag, FeatureFlagList, FlagGroupType, OperatorType, PropertyFilter}, @@ -186,6 +186,7 @@ pub struct FeatureFlagMatcher { pub team_id: TeamId, pub postgres_reader: PostgresReader, pub postgres_writer: PostgresWriter, + pub cohort_cache: Arc, group_type_mapping_cache: GroupTypeMappingCache, properties_cache: PropertiesCache, groups: HashMap, @@ -199,6 +200,7 @@ impl FeatureFlagMatcher { team_id: TeamId, postgres_reader: PostgresReader, postgres_writer: PostgresWriter, + cohort_cache: Arc, group_type_mapping_cache: Option, properties_cache: Option, groups: Option>, @@ -208,6 +210,7 @@ impl FeatureFlagMatcher { team_id, postgres_reader: postgres_reader.clone(), postgres_writer: postgres_writer.clone(), + cohort_cache, groups: groups.unwrap_or_default(), group_type_mapping_cache: group_type_mapping_cache .unwrap_or_else(|| GroupTypeMappingCache::new(team_id, postgres_reader.clone())), @@ -836,9 +839,10 @@ impl FeatureFlagMatcher { cohort_property_filters: &[PropertyFilter], target_properties: &HashMap, ) -> Result { - // Caching all of the cohorts like this will make it so that we don't have to hit the database for each cohort filter - let cohort_cache = - CohortCache::new_with_team(self.team_id, self.postgres_reader.clone()).await?; + // At the start of the request, fetch all of the cohorts for the team from the cache + // This method also caches the cohorts in memory for the duration of the application, so we don't need to fetch from + // the database again until we restart the application. + let cohorts = self.cohort_cache.get_cohorts_for_team(self.team_id).await?; // Store cohort match results in a HashMap to avoid re-evaluating the same cohort multiple times, // since the same cohort could appear in multiple property filters. This is especially important @@ -848,13 +852,8 @@ impl FeatureFlagMatcher { let cohort_id = filter .get_cohort_id() .ok_or(FlagError::CohortFiltersParsingError)?; - let match_result = evaluate_cohort_dependencies( - self.team_id, - cohort_id, - target_properties, - &cohort_cache, - ) - .await?; + let match_result = + evaluate_cohort_dependencies(cohort_id, target_properties, cohorts.clone())?; cohort_matches.insert(cohort_id, match_result); } @@ -1108,14 +1107,13 @@ impl FeatureFlagMatcher { /// Evaluates a single cohort and its dependencies. /// This uses a topological sort to evaluate dependencies first, which is necessary /// because a cohort can depend on another cohort, and we need to respect the dependency order. -async fn evaluate_cohort_dependencies( - team_id: i32, +fn evaluate_cohort_dependencies( initial_cohort_id: CohortId, target_properties: &HashMap, - cohort_cache: &CohortCache, + cohorts: Vec, ) -> Result { let cohort_dependency_graph = - build_cohort_dependency_graph(team_id, initial_cohort_id, cohort_cache).await?; + build_cohort_dependency_graph(initial_cohort_id, cohorts.clone())?; // We need to sort cohorts topologically to ensure we evaluate dependencies before the cohorts that depend on them. // For example, if cohort A depends on cohort B, we need to evaluate B first to know if A matches. @@ -1132,9 +1130,10 @@ async fn evaluate_cohort_dependencies( // Iterate through the sorted nodes in reverse order (so that we can evaluate dependencies first) for node in sorted_cohort_ids_as_graph_nodes.into_iter().rev() { let cohort_id = cohort_dependency_graph[node]; - let cohort = cohort_cache - .get_cohort_by_id(team_id, cohort_dependency_graph[node]) - .await?; + let cohort = cohorts + .iter() + .find(|c| c.id == cohort_id) + .ok_or(FlagError::CohortNotFound(cohort_id.to_string()))?; let property_filters = cohort.parse_filters()?; let dependencies = cohort.extract_dependencies()?; @@ -1216,10 +1215,9 @@ fn apply_cohort_membership_logic( /// - D depends on B /// - E depends on C and D /// The graph is acyclic, which is required for valid cohort dependencies. -async fn build_cohort_dependency_graph( - team_id: i32, +fn build_cohort_dependency_graph( initial_cohort_id: CohortId, - cohort_cache: &CohortCache, + cohorts: Vec, ) -> Result, FlagError> { let mut graph = DiGraph::new(); let mut node_map = HashMap::new(); @@ -1235,7 +1233,10 @@ async fn build_cohort_dependency_graph( node_map.insert(initial_cohort_id, graph.add_node(initial_cohort_id)); while let Some(cohort_id) = queue.pop_front() { - let cohort = cohort_cache.get_cohort_by_id(team_id, cohort_id).await?; + let cohort = cohorts + .iter() + .find(|c| c.id == cohort_id) + .ok_or(FlagError::CohortNotFound(cohort_id.to_string()))?; let dependencies = cohort.extract_dependencies()?; for dep_id in dependencies { // Retrieve the current node **before** mutable borrowing @@ -1705,6 +1706,7 @@ mod tests { async fn test_fetch_properties_from_pg_to_match() { let postgres_reader = setup_pg_reader_client(None).await; let postgres_writer = setup_pg_writer_client(None).await; + let cohort_cache = Arc::new(CohortCache::new(postgres_reader.clone(), None, None)); let team = insert_new_team_in_pg(postgres_reader.clone(), None) .await @@ -1754,6 +1756,7 @@ mod tests { team.id, postgres_reader.clone(), postgres_writer.clone(), + cohort_cache.clone(), None, None, None, @@ -1767,6 +1770,7 @@ mod tests { team.id, postgres_reader.clone(), postgres_writer.clone(), + cohort_cache.clone(), None, None, None, @@ -1780,6 +1784,7 @@ mod tests { team.id, postgres_reader.clone(), postgres_writer.clone(), + cohort_cache.clone(), None, None, None, @@ -1793,6 +1798,7 @@ mod tests { async fn test_person_property_overrides() { let postgres_reader = setup_pg_reader_client(None).await; let postgres_writer = setup_pg_writer_client(None).await; + let cohort_cache = Arc::new(CohortCache::new(postgres_reader.clone(), None, None)); let team = insert_new_team_in_pg(postgres_reader.clone(), None) .await .unwrap(); @@ -1832,6 +1838,7 @@ mod tests { team.id, postgres_reader, postgres_writer, + cohort_cache, None, None, None, @@ -1854,6 +1861,7 @@ mod tests { async fn test_group_property_overrides() { let postgres_reader = setup_pg_reader_client(None).await; let postgres_writer = setup_pg_writer_client(None).await; + let cohort_cache = Arc::new(CohortCache::new(postgres_reader.clone(), None, None)); let team = insert_new_team_in_pg(postgres_reader.clone(), None) .await .unwrap(); @@ -1886,10 +1894,12 @@ mod tests { None, ); - let mut cache = GroupTypeMappingCache::new(team.id, postgres_reader.clone()); + let mut group_type_mapping_cache = + GroupTypeMappingCache::new(team.id, postgres_reader.clone()); let group_types_to_indexes = [("organization".to_string(), 1)].into_iter().collect(); - cache.group_types_to_indexes = group_types_to_indexes; - cache.group_indexes_to_types = [(1, "organization".to_string())].into_iter().collect(); + group_type_mapping_cache.group_types_to_indexes = group_types_to_indexes; + group_type_mapping_cache.group_indexes_to_types = + [(1, "organization".to_string())].into_iter().collect(); let groups = HashMap::from([("organization".to_string(), json!("org_123"))]); @@ -1906,7 +1916,8 @@ mod tests { team.id, postgres_reader.clone(), postgres_writer.clone(), - Some(cache), + cohort_cache.clone(), + Some(group_type_mapping_cache), None, Some(groups), ); @@ -1930,14 +1941,14 @@ mod tests { let flag = create_test_flag_with_variants(1); let postgres_reader = setup_pg_reader_client(None).await; let postgres_writer = setup_pg_writer_client(None).await; - - let mut cache = GroupTypeMappingCache::new(1, postgres_reader.clone()); + let cohort_cache = Arc::new(CohortCache::new(postgres_reader.clone(), None, None)); + let mut group_type_mapping_cache = GroupTypeMappingCache::new(1, postgres_reader.clone()); let group_types_to_indexes = [("group_type_1".to_string(), 1)].into_iter().collect(); let group_type_index_to_name = [(1, "group_type_1".to_string())].into_iter().collect(); - cache.group_types_to_indexes = group_types_to_indexes; - cache.group_indexes_to_types = group_type_index_to_name; + group_type_mapping_cache.group_types_to_indexes = group_types_to_indexes; + group_type_mapping_cache.group_indexes_to_types = group_type_index_to_name; let groups = HashMap::from([("group_type_1".to_string(), json!("group_key_1"))]); @@ -1946,7 +1957,8 @@ mod tests { 1, postgres_reader.clone(), postgres_writer.clone(), - Some(cache), + cohort_cache.clone(), + Some(group_type_mapping_cache), None, Some(groups), ); @@ -1962,6 +1974,7 @@ mod tests { async fn test_get_matching_variant_with_db() { let postgres_reader = setup_pg_reader_client(None).await; let postgres_writer = setup_pg_writer_client(None).await; + let cohort_cache = Arc::new(CohortCache::new(postgres_reader.clone(), None, None)); let team = insert_new_team_in_pg(postgres_reader.clone(), None) .await .unwrap(); @@ -1973,6 +1986,7 @@ mod tests { team.id, postgres_reader.clone(), postgres_writer.clone(), + cohort_cache.clone(), None, None, None, @@ -1987,6 +2001,7 @@ mod tests { async fn test_is_condition_match_empty_properties() { let postgres_reader = setup_pg_reader_client(None).await; let postgres_writer = setup_pg_writer_client(None).await; + let cohort_cache = Arc::new(CohortCache::new(postgres_reader.clone(), None, None)); let flag = create_test_flag( Some(1), None, @@ -2019,6 +2034,7 @@ mod tests { 1, postgres_reader, postgres_writer, + cohort_cache, None, None, None, @@ -2076,6 +2092,7 @@ mod tests { async fn test_overrides_avoid_db_lookups() { let postgres_reader = setup_pg_reader_client(None).await; let postgres_writer = setup_pg_writer_client(None).await; + let cohort_cache = Arc::new(CohortCache::new(postgres_reader.clone(), None, None)); let team = insert_new_team_in_pg(postgres_reader.clone(), None) .await .unwrap(); @@ -2116,6 +2133,7 @@ mod tests { team.id, postgres_reader.clone(), postgres_writer.clone(), + cohort_cache.clone(), None, None, None, @@ -2146,6 +2164,7 @@ mod tests { async fn test_fallback_to_db_when_overrides_insufficient() { let postgres_reader = setup_pg_reader_client(None).await; let postgres_writer = setup_pg_writer_client(None).await; + let cohort_cache = Arc::new(CohortCache::new(postgres_reader.clone(), None, None)); let team = insert_new_team_in_pg(postgres_reader.clone(), None) .await .unwrap(); @@ -2207,6 +2226,7 @@ mod tests { team.id, postgres_reader.clone(), postgres_writer.clone(), + cohort_cache.clone(), None, None, None, @@ -2231,6 +2251,7 @@ mod tests { async fn test_property_fetching_and_caching() { let postgres_reader = setup_pg_reader_client(None).await; let postgres_writer = setup_pg_writer_client(None).await; + let cohort_cache = Arc::new(CohortCache::new(postgres_reader.clone(), None, None)); let team = insert_new_team_in_pg(postgres_reader.clone(), None) .await .unwrap(); @@ -2250,6 +2271,7 @@ mod tests { team.id, postgres_reader.clone(), postgres_writer.clone(), + cohort_cache.clone(), None, None, None, @@ -2275,6 +2297,7 @@ mod tests { async fn test_property_caching() { let postgres_reader = setup_pg_reader_client(None).await; let postgres_writer = setup_pg_writer_client(None).await; + let cohort_cache = Arc::new(CohortCache::new(postgres_reader.clone(), None, None)); let team = insert_new_team_in_pg(postgres_reader.clone(), None) .await .unwrap(); @@ -2294,6 +2317,7 @@ mod tests { team.id, postgres_reader.clone(), postgres_writer.clone(), + cohort_cache.clone(), None, None, None, @@ -2327,6 +2351,7 @@ mod tests { team.id, postgres_reader.clone(), postgres_writer.clone(), + cohort_cache.clone(), None, None, None, @@ -2418,6 +2443,7 @@ mod tests { async fn test_concurrent_flag_evaluation() { let postgres_reader = setup_pg_reader_client(None).await; let postgres_writer = setup_pg_writer_client(None).await; + let cohort_cache = Arc::new(CohortCache::new(postgres_reader.clone(), None, None)); let team = insert_new_team_in_pg(postgres_reader.clone(), None) .await .unwrap(); @@ -2447,12 +2473,14 @@ mod tests { let flag_clone = flag.clone(); let postgres_reader_clone = postgres_reader.clone(); let postgres_writer_clone = postgres_writer.clone(); + let cohort_cache_clone = cohort_cache.clone(); handles.push(tokio::spawn(async move { let mut matcher = FeatureFlagMatcher::new( format!("test_user_{}", i), team.id, postgres_reader_clone, postgres_writer_clone, + cohort_cache_clone, None, None, None, @@ -2475,6 +2503,7 @@ mod tests { async fn test_property_operators() { let postgres_reader = setup_pg_reader_client(None).await; let postgres_writer = setup_pg_writer_client(None).await; + let cohort_cache = Arc::new(CohortCache::new(postgres_reader.clone(), None, None)); let team = insert_new_team_in_pg(postgres_reader.clone(), None) .await .unwrap(); @@ -2531,6 +2560,7 @@ mod tests { team.id, postgres_reader.clone(), postgres_writer.clone(), + cohort_cache.clone(), None, None, None, @@ -2545,7 +2575,7 @@ mod tests { async fn test_empty_hashed_identifier() { let postgres_reader = setup_pg_reader_client(None).await; let postgres_writer = setup_pg_writer_client(None).await; - + let cohort_cache = Arc::new(CohortCache::new(postgres_reader.clone(), None, None)); let flag = create_test_flag( Some(1), None, @@ -2572,6 +2602,7 @@ mod tests { 1, postgres_reader, postgres_writer, + cohort_cache, None, None, None, @@ -2586,6 +2617,7 @@ mod tests { async fn test_rollout_percentage() { let postgres_reader = setup_pg_reader_client(None).await; let postgres_writer = setup_pg_writer_client(None).await; + let cohort_cache = Arc::new(CohortCache::new(postgres_reader.clone(), None, None)); let mut flag = create_test_flag( Some(1), None, @@ -2612,6 +2644,7 @@ mod tests { 1, postgres_reader, postgres_writer, + cohort_cache, None, None, None, @@ -2633,7 +2666,7 @@ mod tests { async fn test_uneven_variant_distribution() { let postgres_reader = setup_pg_reader_client(None).await; let postgres_writer = setup_pg_writer_client(None).await; - + let cohort_cache = Arc::new(CohortCache::new(postgres_reader.clone(), None, None)); let mut flag = create_test_flag_with_variants(1); // Adjust variant rollout percentages to be uneven @@ -2663,6 +2696,7 @@ mod tests { 1, postgres_reader, postgres_writer, + cohort_cache, None, None, None, @@ -2695,6 +2729,7 @@ mod tests { async fn test_missing_properties_in_db() { let postgres_reader = setup_pg_reader_client(None).await; let postgres_writer = setup_pg_writer_client(None).await; + let cohort_cache = Arc::new(CohortCache::new(postgres_reader.clone(), None, None)); let team = insert_new_team_in_pg(postgres_reader.clone(), None) .await .unwrap(); @@ -2742,6 +2777,7 @@ mod tests { team.id, postgres_reader.clone(), postgres_writer.clone(), + cohort_cache, None, None, None, @@ -2756,6 +2792,7 @@ mod tests { async fn test_malformed_property_data() { let postgres_reader = setup_pg_reader_client(None).await; let postgres_writer = setup_pg_writer_client(None).await; + let cohort_cache = Arc::new(CohortCache::new(postgres_reader.clone(), None, None)); let team = insert_new_team_in_pg(postgres_reader.clone(), None) .await .unwrap(); @@ -2803,6 +2840,7 @@ mod tests { team.id, postgres_reader.clone(), postgres_writer.clone(), + cohort_cache, None, None, None, @@ -2818,6 +2856,7 @@ mod tests { async fn test_get_match_with_insufficient_overrides() { let postgres_reader = setup_pg_reader_client(None).await; let postgres_writer = setup_pg_writer_client(None).await; + let cohort_cache = Arc::new(CohortCache::new(postgres_reader.clone(), None, None)); let team = insert_new_team_in_pg(postgres_reader.clone(), None) .await .unwrap(); @@ -2879,6 +2918,7 @@ mod tests { team.id, postgres_reader.clone(), postgres_writer.clone(), + cohort_cache, None, None, None, @@ -2896,6 +2936,7 @@ mod tests { async fn test_evaluation_reasons() { let postgres_reader = setup_pg_reader_client(None).await; let postgres_writer = setup_pg_writer_client(None).await; + let cohort_cache = Arc::new(CohortCache::new(postgres_reader.clone(), None, None)); let flag = create_test_flag( Some(1), None, @@ -2922,6 +2963,7 @@ mod tests { 1, postgres_reader.clone(), postgres_writer.clone(), + cohort_cache, None, None, None, @@ -2940,6 +2982,7 @@ mod tests { async fn test_complex_conditions() { let postgres_reader = setup_pg_reader_client(None).await; let postgres_writer = setup_pg_writer_client(None).await; + let cohort_cache = Arc::new(CohortCache::new(postgres_reader.clone(), None, None)); let team = insert_new_team_in_pg(postgres_reader.clone(), None) .await .unwrap(); @@ -3000,6 +3043,7 @@ mod tests { team.id, postgres_reader.clone(), postgres_writer.clone(), + cohort_cache, None, None, None, @@ -3014,6 +3058,7 @@ mod tests { async fn test_super_condition_matches_boolean() { let postgres_reader = setup_pg_reader_client(None).await; let postgres_writer = setup_pg_writer_client(None).await; + let cohort_cache = Arc::new(CohortCache::new(postgres_reader.clone(), None, None)); let team = insert_new_team_in_pg(postgres_reader.clone(), None) .await .unwrap(); @@ -3090,6 +3135,7 @@ mod tests { team.id, postgres_reader.clone(), postgres_writer.clone(), + cohort_cache.clone(), None, None, None, @@ -3100,6 +3146,7 @@ mod tests { team.id, postgres_reader.clone(), postgres_writer.clone(), + cohort_cache.clone(), None, None, None, @@ -3110,6 +3157,7 @@ mod tests { team.id, postgres_reader.clone(), postgres_writer.clone(), + cohort_cache.clone(), None, None, None, @@ -3137,6 +3185,7 @@ mod tests { async fn test_super_condition_matches_string() { let postgres_reader = setup_pg_reader_client(None).await; let postgres_writer = setup_pg_writer_client(None).await; + let cohort_cache = Arc::new(CohortCache::new(postgres_reader.clone(), None, None)); let team = insert_new_team_in_pg(postgres_reader.clone(), None) .await .unwrap(); @@ -3213,6 +3262,7 @@ mod tests { team.id, postgres_reader.clone(), postgres_writer.clone(), + cohort_cache.clone(), None, None, None, @@ -3229,6 +3279,7 @@ mod tests { async fn test_super_condition_matches_and_false() { let postgres_reader = setup_pg_reader_client(None).await; let postgres_writer = setup_pg_writer_client(None).await; + let cohort_cache = Arc::new(CohortCache::new(postgres_reader.clone(), None, None)); let team = insert_new_team_in_pg(postgres_reader.clone(), None) .await .unwrap(); @@ -3305,6 +3356,7 @@ mod tests { team.id, postgres_reader.clone(), postgres_writer.clone(), + cohort_cache.clone(), None, None, None, @@ -3315,6 +3367,7 @@ mod tests { team.id, postgres_reader.clone(), postgres_writer.clone(), + cohort_cache.clone(), None, None, None, @@ -3325,6 +3378,7 @@ mod tests { team.id, postgres_reader.clone(), postgres_writer.clone(), + cohort_cache.clone(), None, None, None, @@ -3366,6 +3420,7 @@ mod tests { async fn test_basic_cohort_matching() { let postgres_reader = setup_pg_reader_client(None).await; let postgres_writer = setup_pg_writer_client(None).await; + let cohort_cache = Arc::new(CohortCache::new(postgres_reader.clone(), None, None)); let team = insert_new_team_in_pg(postgres_reader.clone(), None) .await .unwrap(); @@ -3439,6 +3494,7 @@ mod tests { team.id, postgres_reader.clone(), postgres_writer.clone(), + cohort_cache.clone(), None, None, None, @@ -3453,6 +3509,7 @@ mod tests { async fn test_not_in_cohort_matching() { let postgres_reader = setup_pg_reader_client(None).await; let postgres_writer = setup_pg_writer_client(None).await; + let cohort_cache = Arc::new(CohortCache::new(postgres_reader.clone(), None, None)); let team = insert_new_team_in_pg(postgres_reader.clone(), None) .await .unwrap(); @@ -3526,6 +3583,7 @@ mod tests { team.id, postgres_reader.clone(), postgres_writer.clone(), + cohort_cache.clone(), None, None, None, @@ -3540,6 +3598,7 @@ mod tests { async fn test_not_in_cohort_matching_user_in_cohort() { let postgres_reader = setup_pg_reader_client(None).await; let postgres_writer = setup_pg_writer_client(None).await; + let cohort_cache = Arc::new(CohortCache::new(postgres_reader.clone(), None, None)); let team = insert_new_team_in_pg(postgres_reader.clone(), None) .await .unwrap(); @@ -3613,6 +3672,7 @@ mod tests { team.id, postgres_reader.clone(), postgres_writer.clone(), + cohort_cache.clone(), None, None, None, @@ -3628,6 +3688,7 @@ mod tests { async fn test_cohort_dependent_on_another_cohort() { let postgres_reader = setup_pg_reader_client(None).await; let postgres_writer = setup_pg_writer_client(None).await; + let cohort_cache = Arc::new(CohortCache::new(postgres_reader.clone(), None, None)); let team = insert_new_team_in_pg(postgres_reader.clone(), None) .await .unwrap(); @@ -3726,6 +3787,7 @@ mod tests { team.id, postgres_reader.clone(), postgres_writer.clone(), + cohort_cache.clone(), None, None, None, @@ -3741,6 +3803,7 @@ mod tests { async fn test_in_cohort_matching_user_not_in_cohort() { let postgres_reader = setup_pg_reader_client(None).await; let postgres_writer = setup_pg_writer_client(None).await; + let cohort_cache = Arc::new(CohortCache::new(postgres_reader.clone(), None, None)); let team = insert_new_team_in_pg(postgres_reader.clone(), None) .await .unwrap(); @@ -3814,6 +3877,7 @@ mod tests { team.id, postgres_reader.clone(), postgres_writer.clone(), + cohort_cache.clone(), None, None, None, @@ -3979,6 +4043,7 @@ mod tests { async fn test_evaluate_feature_flags_with_experience_continuity() { let postgres_reader = setup_pg_reader_client(None).await; let postgres_writer = setup_pg_writer_client(None).await; + let cohort_cache = Arc::new(CohortCache::new(postgres_reader.clone(), None, None)); let team = insert_new_team_in_pg(postgres_reader.clone(), None) .await .unwrap(); @@ -4042,6 +4107,7 @@ mod tests { team.id, postgres_reader.clone(), postgres_writer.clone(), + cohort_cache.clone(), None, None, None, @@ -4061,6 +4127,7 @@ mod tests { async fn test_evaluate_feature_flags_with_continuity_missing_override() { let postgres_reader = setup_pg_reader_client(None).await; let postgres_writer = setup_pg_writer_client(None).await; + let cohort_cache = Arc::new(CohortCache::new(postgres_reader.clone(), None, None)); let team = insert_new_team_in_pg(postgres_reader.clone(), None) .await .unwrap(); @@ -4114,6 +4181,7 @@ mod tests { team.id, postgres_reader.clone(), postgres_writer.clone(), + cohort_cache.clone(), None, None, None, @@ -4133,6 +4201,7 @@ mod tests { async fn test_evaluate_all_feature_flags_mixed_continuity() { let postgres_reader = setup_pg_reader_client(None).await; let postgres_writer = setup_pg_writer_client(None).await; + let cohort_cache = Arc::new(CohortCache::new(postgres_reader.clone(), None, None)); let team = insert_new_team_in_pg(postgres_reader.clone(), None) .await .unwrap(); @@ -4225,6 +4294,7 @@ mod tests { team.id, postgres_reader.clone(), postgres_writer.clone(), + cohort_cache.clone(), None, None, None, diff --git a/rust/feature-flags/src/request_handler.rs b/rust/feature-flags/src/request_handler.rs index bd770be819d4d..5277a2565ef2b 100644 --- a/rust/feature-flags/src/request_handler.rs +++ b/rust/feature-flags/src/request_handler.rs @@ -1,5 +1,6 @@ use crate::{ api::{FlagError, FlagsResponse}, + cohort_cache::CohortCache, database::Client, flag_definitions::FeatureFlagList, flag_matching::{FeatureFlagMatcher, GroupTypeMappingCache}, @@ -69,6 +70,7 @@ pub struct FeatureFlagEvaluationContext { feature_flags: FeatureFlagList, postgres_reader: Arc, postgres_writer: Arc, + cohort_cache: Arc, #[builder(default)] person_property_overrides: Option>, #[builder(default)] @@ -117,6 +119,7 @@ pub async fn process_request(context: RequestContext) -> Result Fl context.team_id, context.postgres_reader, context.postgres_writer, + context.cohort_cache, Some(group_type_mapping_cache), None, // TODO maybe remove this from the matcher struct, since it's used internally but not passed around context.groups, @@ -356,6 +360,7 @@ mod tests { async fn test_evaluate_feature_flags() { let postgres_reader: Arc = setup_pg_reader_client(None).await; let postgres_writer: Arc = setup_pg_writer_client(None).await; + let cohort_cache = Arc::new(CohortCache::new(postgres_reader.clone(), None, None)); let flag = FeatureFlag { name: Some("Test Flag".to_string()), id: 1, @@ -395,6 +400,7 @@ mod tests { .feature_flags(feature_flag_list) .postgres_reader(postgres_reader) .postgres_writer(postgres_writer) + .cohort_cache(cohort_cache) .person_property_overrides(Some(person_properties)) .build() .expect("Failed to build FeatureFlagEvaluationContext"); @@ -503,6 +509,7 @@ mod tests { async fn test_evaluate_feature_flags_multiple_flags() { let postgres_reader: Arc = setup_pg_reader_client(None).await; let postgres_writer: Arc = setup_pg_writer_client(None).await; + let cohort_cache = Arc::new(CohortCache::new(postgres_reader.clone(), None, None)); let flags = vec![ FeatureFlag { name: Some("Flag 1".to_string()), @@ -554,6 +561,7 @@ mod tests { .feature_flags(feature_flag_list) .postgres_reader(postgres_reader) .postgres_writer(postgres_writer) + .cohort_cache(cohort_cache) .build() .expect("Failed to build FeatureFlagEvaluationContext"); @@ -606,6 +614,7 @@ mod tests { async fn test_evaluate_feature_flags_with_overrides() { let postgres_reader: Arc = setup_pg_reader_client(None).await; let postgres_writer: Arc = setup_pg_writer_client(None).await; + let cohort_cache = Arc::new(CohortCache::new(postgres_reader.clone(), None, None)); let team = insert_new_team_in_pg(postgres_reader.clone(), None) .await .unwrap(); @@ -654,6 +663,7 @@ mod tests { .feature_flags(feature_flag_list) .postgres_reader(postgres_reader) .postgres_writer(postgres_writer) + .cohort_cache(cohort_cache) .group_property_overrides(Some(group_property_overrides)) .groups(Some(groups)) .build() @@ -687,6 +697,7 @@ mod tests { let long_id = "a".repeat(1000); let postgres_reader: Arc = setup_pg_reader_client(None).await; let postgres_writer: Arc = setup_pg_writer_client(None).await; + let cohort_cache = Arc::new(CohortCache::new(postgres_reader.clone(), None, None)); let flag = FeatureFlag { name: Some("Test Flag".to_string()), id: 1, @@ -716,6 +727,7 @@ mod tests { .feature_flags(feature_flag_list) .postgres_reader(postgres_reader) .postgres_writer(postgres_writer) + .cohort_cache(cohort_cache) .build() .expect("Failed to build FeatureFlagEvaluationContext"); diff --git a/rust/feature-flags/src/router.rs b/rust/feature-flags/src/router.rs index 9cb6a8415cfd8..ceb908a9a869b 100644 --- a/rust/feature-flags/src/router.rs +++ b/rust/feature-flags/src/router.rs @@ -9,6 +9,7 @@ use health::HealthRegistry; use tower::limit::ConcurrencyLimitLayer; use crate::{ + cohort_cache::CohortCache, config::{Config, TeamIdsToTrack}, database::Client as DatabaseClient, geoip::GeoIpClient, @@ -22,6 +23,7 @@ pub struct State { pub redis: Arc, pub postgres_reader: Arc, pub postgres_writer: Arc, + pub cohort_cache: Arc, // TODO does this need a better name than just `cohort_cache`? pub geoip: Arc, pub team_ids_to_track: TeamIdsToTrack, } @@ -30,6 +32,7 @@ pub fn router( redis: Arc, postgres_reader: Arc, postgres_writer: Arc, + cohort_cache: Arc, geoip: Arc, liveness: HealthRegistry, config: Config, @@ -42,6 +45,7 @@ where redis, postgres_reader, postgres_writer, + cohort_cache, geoip, team_ids_to_track: config.team_ids_to_track.clone(), }; diff --git a/rust/feature-flags/src/server.rs b/rust/feature-flags/src/server.rs index c9e238fa8fd4e..93492a05f4400 100644 --- a/rust/feature-flags/src/server.rs +++ b/rust/feature-flags/src/server.rs @@ -6,6 +6,7 @@ use std::time::Duration; use health::{HealthHandle, HealthRegistry}; use tokio::net::TcpListener; +use crate::cohort_cache::CohortCache; use crate::config::Config; use crate::database::get_pool; use crate::geoip::GeoIpClient; @@ -54,6 +55,8 @@ where } }; + let cohort_cache = Arc::new(CohortCache::new(postgres_reader.clone(), None, None)); + let health = HealthRegistry::new("liveness"); // TODO - we don't have a more complex health check yet, but we should add e.g. some around DB operations @@ -67,6 +70,7 @@ where redis_client, postgres_reader, postgres_writer, + cohort_cache, geoip_service, health, config, diff --git a/rust/feature-flags/tests/test_flag_matching_consistency.rs b/rust/feature-flags/tests/test_flag_matching_consistency.rs index 94f4f67dcdc56..0080fd470261f 100644 --- a/rust/feature-flags/tests/test_flag_matching_consistency.rs +++ b/rust/feature-flags/tests/test_flag_matching_consistency.rs @@ -1,3 +1,6 @@ +use std::sync::Arc; + +use feature_flags::cohort_cache::CohortCache; use feature_flags::feature_flag_match_reason::FeatureFlagMatchReason; /// These tests are common between all libraries doing local evaluation of feature flags. /// This ensures there are no mismatches between implementations. @@ -110,6 +113,7 @@ async fn it_is_consistent_with_rollout_calculation_for_simple_flags() { for (i, result) in results.iter().enumerate().take(1000) { let postgres_reader = setup_pg_reader_client(None).await; let postgres_writer = setup_pg_writer_client(None).await; + let cohort_cache = Arc::new(CohortCache::new(postgres_reader.clone(), None, None)); let distinct_id = format!("distinct_id_{}", i); @@ -118,6 +122,7 @@ async fn it_is_consistent_with_rollout_calculation_for_simple_flags() { 1, postgres_reader, postgres_writer, + cohort_cache, None, None, None, @@ -1209,6 +1214,7 @@ async fn it_is_consistent_with_rollout_calculation_for_multivariate_flags() { for (i, result) in results.iter().enumerate().take(1000) { let postgres_reader = setup_pg_reader_client(None).await; let postgres_writer = setup_pg_writer_client(None).await; + let cohort_cache = Arc::new(CohortCache::new(postgres_reader.clone(), None, None)); let distinct_id = format!("distinct_id_{}", i); let feature_flag_match = FeatureFlagMatcher::new( @@ -1216,6 +1222,7 @@ async fn it_is_consistent_with_rollout_calculation_for_multivariate_flags() { 1, postgres_reader, postgres_writer, + cohort_cache, None, None, None, From 43e8692e344452ca68cdf44447826e2973c7b67b Mon Sep 17 00:00:00 2001 From: dylan Date: Fri, 1 Nov 2024 16:14:24 -0700 Subject: [PATCH 21/30] less yapping --- rust/feature-flags/src/cohort_cache.rs | 57 ++++++++++---------------- 1 file changed, 21 insertions(+), 36 deletions(-) diff --git a/rust/feature-flags/src/cohort_cache.rs b/rust/feature-flags/src/cohort_cache.rs index a948289d899d8..447d1c6971691 100644 --- a/rust/feature-flags/src/cohort_cache.rs +++ b/rust/feature-flags/src/cohort_cache.rs @@ -6,22 +6,31 @@ use std::time::Duration; /// CohortCache manages the in-memory cache of cohorts using `moka` for caching. /// -/// ```text -/// per_team_cohorts: { -/// 1: [Cohort { id: 101, name: "Active Users", filters: [...] }, ...], -/// 2: [Cohort { id: 201, name: "New Users", filters: [...] }, ...] -/// } -/// ``` -/// /// Features: /// - **TTL**: Each cache entry expires after 5 minutes. /// - **Size-based eviction**: The cache evicts least recently used entries when the maximum capacity is reached. /// +/// ```text +/// CohortCache { +/// postgres_reader: PostgresReader, +/// per_team_cohorts: Cache> { +/// // Example: +/// 2: [ +/// Cohort { id: 1, name: "Power Users", filters: {...} }, +/// Cohort { id: 2, name: "Churned", filters: {...} } +/// ], +/// 5: [ +/// Cohort { id: 3, name: "Beta Users", filters: {...} } +/// ] +/// } +/// } +/// ``` +/// /// Caches only successful cohort lists to maintain cache integrity. #[derive(Clone)] pub struct CohortCache { postgres_reader: PostgresReader, - per_team_cohorts: Cache>, // team_id -> list of Cohorts + per_team_cohorts: Cache>, } impl CohortCache { @@ -31,7 +40,7 @@ impl CohortCache { max_capacity: Option, ttl_seconds: Option, ) -> Self { - // Define the weigher closure. Here, we consider the number of cohorts as the weight. + // We use the size of the cohort list as the weight of the entry let weigher = |_: &TeamId, value: &Vec| -> u32 { return value.len().try_into().unwrap_or(u32::MAX); }; @@ -57,9 +66,7 @@ impl CohortCache { if let Some(cached_cohorts) = self.per_team_cohorts.get(&team_id).await { return Ok(cached_cohorts.clone()); } - // Fetch from database let fetched_cohorts = Cohort::list_from_pg(self.postgres_reader.clone(), team_id).await?; - // Insert into cache self.per_team_cohorts .insert(team_id.clone(), fetched_cohorts.clone()) .await; @@ -90,7 +97,7 @@ mod tests { /// Helper function to insert a cohort for a team. async fn setup_test_cohort( writer_client: Arc, - team_id: TeamId, // Adjusted to accept TeamId + team_id: TeamId, name: Option, ) -> Result { let filters = serde_json::json!({"properties": {"type": "OR", "values": [{"type": "OR", "values": [{"key": "$active", "type": "person", "value": [true], "negation": false, "operator": "exact"}]}]}}); @@ -100,11 +107,9 @@ mod tests { /// Tests that cache entries expire after the specified TTL. #[tokio::test] async fn test_cache_expiry() -> Result<(), anyhow::Error> { - // Setup PostgreSQL clients let writer_client = setup_pg_writer_client(None).await; let reader_client = setup_pg_reader_client(None).await; - // Setup test team and cohort let team_id = setup_test_team(writer_client.clone()).await?; let _cohort = setup_test_cohort(writer_client.clone(), team_id.clone(), None).await?; @@ -115,12 +120,10 @@ mod tests { Some(1), // 1-second TTL ); - // Fetch cohorts, which should populate the cache let cohorts = cohort_cache.get_cohorts_for_team(team_id.clone()).await?; assert_eq!(cohorts.len(), 1); assert_eq!(cohorts[0].team_id, team_id.clone()); - // Ensure the cohort is cached let cached_cohorts = cohort_cache.per_team_cohorts.get(&team_id).await; assert!(cached_cohorts.is_some()); @@ -137,17 +140,13 @@ mod tests { /// Tests that the cache correctly evicts least recently used entries based on the weigher. #[tokio::test] async fn test_cache_weigher() -> Result<(), anyhow::Error> { - // Setup PostgreSQL clients let writer_client = setup_pg_writer_client(None).await; let reader_client = setup_pg_reader_client(None).await; - // Define a smaller max_capacity and TTL for testing + // Define a smaller max_capacity for testing let max_capacity: u64 = 3; - let ttl_seconds: u64 = 300; // 5 minutes - // Initialize CohortCache - let cohort_cache = - CohortCache::new(reader_client.clone(), Some(max_capacity), Some(ttl_seconds)); + let cohort_cache = CohortCache::new(reader_client.clone(), Some(max_capacity), None); let mut inserted_team_ids = Vec::new(); @@ -160,7 +159,6 @@ mod tests { cohort_cache.get_cohorts_for_team(team_id.clone()).await?; } - // The cache should be at max_capacity cohort_cache.per_team_cohorts.run_pending_tasks().await; let cache_size = cohort_cache.per_team_cohorts.entry_count(); assert_eq!( @@ -168,7 +166,6 @@ mod tests { "Cache size should be equal to max_capacity" ); - // Insert one more team to trigger eviction let new_team = insert_new_team_in_pg(writer_client.clone(), None).await?; let new_team_id = new_team.id; setup_test_cohort(writer_client.clone(), new_team_id.clone(), None).await?; @@ -176,7 +173,6 @@ mod tests { .get_cohorts_for_team(new_team_id.clone()) .await?; - // Now, the cache should still have max_capacity entries cohort_cache.per_team_cohorts.run_pending_tasks().await; let cache_size_after = cohort_cache.per_team_cohorts.entry_count(); assert_eq!( @@ -184,7 +180,6 @@ mod tests { "Cache size should remain equal to max_capacity after eviction" ); - // The least recently used team should have been evicted let evicted_team_id = &inserted_team_ids[0]; let cached_cohorts = cohort_cache.per_team_cohorts.get(evicted_team_id).await; assert!( @@ -192,7 +187,6 @@ mod tests { "Least recently used cache entry should have been evicted" ); - // The new team should be present let cached_new_team = cohort_cache.per_team_cohorts.get(&new_team_id).await; assert!( cached_new_team.is_some(), @@ -202,30 +196,21 @@ mod tests { Ok(()) } - /// Functional test to verify that fetching cohorts populates the cache correctly. #[tokio::test] async fn test_get_cohorts_for_team() -> Result<(), anyhow::Error> { - // Setup PostgreSQL clients let writer_client = setup_pg_writer_client(None).await; let reader_client = setup_pg_reader_client(None).await; - - // Setup test team and cohort let team_id = setup_test_team(writer_client.clone()).await?; let _cohort = setup_test_cohort(writer_client.clone(), team_id.clone(), None).await?; - - // Initialize CohortCache let cohort_cache = CohortCache::new(reader_client.clone(), None, None); - // Initially, cache should be empty let cached_cohorts = cohort_cache.per_team_cohorts.get(&team_id).await; assert!(cached_cohorts.is_none(), "Cache should initially be empty"); - // Fetch cohorts, which should populate the cache let cohorts = cohort_cache.get_cohorts_for_team(team_id.clone()).await?; assert_eq!(cohorts.len(), 1); assert_eq!(cohorts[0].team_id, team_id.clone()); - // Now, cache should have the cohorts let cached_cohorts = cohort_cache.per_team_cohorts.get(&team_id).await.unwrap(); assert_eq!(cached_cohorts.len(), 1); assert_eq!(cached_cohorts[0].team_id, team_id.clone()); From 3a656837fe5726ee3d278a44a4ab264846143f6d Mon Sep 17 00:00:00 2001 From: dylan Date: Fri, 1 Nov 2024 17:16:06 -0700 Subject: [PATCH 22/30] appeasing the linter --- rust/feature-flags/src/cohort_cache.rs | 7 ++- rust/feature-flags/src/flag_matching.rs | 43 +------------------ rust/feature-flags/src/request_handler.rs | 1 - .../tests/test_flag_matching_consistency.rs | 2 - 4 files changed, 5 insertions(+), 48 deletions(-) diff --git a/rust/feature-flags/src/cohort_cache.rs b/rust/feature-flags/src/cohort_cache.rs index 447d1c6971691..3adf9a606fb8c 100644 --- a/rust/feature-flags/src/cohort_cache.rs +++ b/rust/feature-flags/src/cohort_cache.rs @@ -41,9 +41,8 @@ impl CohortCache { ttl_seconds: Option, ) -> Self { // We use the size of the cohort list as the weight of the entry - let weigher = |_: &TeamId, value: &Vec| -> u32 { - return value.len().try_into().unwrap_or(u32::MAX); - }; + let weigher = + |_: &TeamId, value: &Vec| -> u32 { value.len().try_into().unwrap_or(u32::MAX) }; // Initialize the Moka cache with TTL and size-based eviction. let cache = Cache::builder() @@ -68,7 +67,7 @@ impl CohortCache { } let fetched_cohorts = Cohort::list_from_pg(self.postgres_reader.clone(), team_id).await?; self.per_team_cohorts - .insert(team_id.clone(), fetched_cohorts.clone()) + .insert(team_id, fetched_cohorts.clone()) .await; Ok(fetched_cohorts) diff --git a/rust/feature-flags/src/flag_matching.rs b/rust/feature-flags/src/flag_matching.rs index 9255175a3816e..89729b0ab54d7 100644 --- a/rust/feature-flags/src/flag_matching.rs +++ b/rust/feature-flags/src/flag_matching.rs @@ -202,7 +202,6 @@ impl FeatureFlagMatcher { postgres_writer: PostgresWriter, cohort_cache: Arc, group_type_mapping_cache: Option, - properties_cache: Option, groups: Option>, ) -> Self { FeatureFlagMatcher { @@ -211,10 +210,10 @@ impl FeatureFlagMatcher { postgres_reader: postgres_reader.clone(), postgres_writer: postgres_writer.clone(), cohort_cache, - groups: groups.unwrap_or_default(), group_type_mapping_cache: group_type_mapping_cache .unwrap_or_else(|| GroupTypeMappingCache::new(team_id, postgres_reader.clone())), - properties_cache: properties_cache.unwrap_or_default(), + groups: groups.unwrap_or_default(), + properties_cache: PropertiesCache::default(), } } @@ -1759,7 +1758,6 @@ mod tests { cohort_cache.clone(), None, None, - None, ); let match_result = matcher.get_match(&flag, None, None).await.unwrap(); assert!(match_result.matches); @@ -1773,7 +1771,6 @@ mod tests { cohort_cache.clone(), None, None, - None, ); let match_result = matcher.get_match(&flag, None, None).await.unwrap(); assert!(!match_result.matches); @@ -1787,7 +1784,6 @@ mod tests { cohort_cache.clone(), None, None, - None, ); let match_result = matcher.get_match(&flag, None, None).await.unwrap(); assert!(!match_result.matches); @@ -1841,7 +1837,6 @@ mod tests { cohort_cache, None, None, - None, ); let flags = FeatureFlagList { @@ -1918,7 +1913,6 @@ mod tests { postgres_writer.clone(), cohort_cache.clone(), Some(group_type_mapping_cache), - None, Some(groups), ); @@ -1959,7 +1953,6 @@ mod tests { postgres_writer.clone(), cohort_cache.clone(), Some(group_type_mapping_cache), - None, Some(groups), ); let variant = matcher.get_matching_variant(&flag, None).await.unwrap(); @@ -1989,7 +1982,6 @@ mod tests { cohort_cache.clone(), None, None, - None, ); let variant = matcher.get_matching_variant(&flag, None).await.unwrap(); @@ -2037,7 +2029,6 @@ mod tests { cohort_cache, None, None, - None, ); let (is_match, reason) = matcher .is_condition_match(&flag, &condition, None, None) @@ -2136,7 +2127,6 @@ mod tests { cohort_cache.clone(), None, None, - None, ); let result = matcher @@ -2229,7 +2219,6 @@ mod tests { cohort_cache.clone(), None, None, - None, ); let result = matcher @@ -2274,7 +2263,6 @@ mod tests { cohort_cache.clone(), None, None, - None, ); let properties = matcher @@ -2320,7 +2308,6 @@ mod tests { cohort_cache.clone(), None, None, - None, ); // First access should fetch from the database @@ -2354,7 +2341,6 @@ mod tests { cohort_cache.clone(), None, None, - None, ); // First access with new matcher should fetch from the database again @@ -2483,7 +2469,6 @@ mod tests { cohort_cache_clone, None, None, - None, ); matcher.get_match(&flag_clone, None, None).await.unwrap() })); @@ -2563,7 +2548,6 @@ mod tests { cohort_cache.clone(), None, None, - None, ); let result = matcher.get_match(&flag, None, None).await.unwrap(); @@ -2605,7 +2589,6 @@ mod tests { cohort_cache, None, None, - None, ); let result = matcher.get_match(&flag, None, None).await.unwrap(); @@ -2647,7 +2630,6 @@ mod tests { cohort_cache, None, None, - None, ); let result = matcher.get_match(&flag, None, None).await.unwrap(); @@ -2699,7 +2681,6 @@ mod tests { cohort_cache, None, None, - None, ); let mut control_count = 0; @@ -2780,7 +2761,6 @@ mod tests { cohort_cache, None, None, - None, ); let result = matcher.get_match(&flag, None, None).await.unwrap(); @@ -2843,7 +2823,6 @@ mod tests { cohort_cache, None, None, - None, ); let result = matcher.get_match(&flag, None, None).await.unwrap(); @@ -2921,7 +2900,6 @@ mod tests { cohort_cache, None, None, - None, ); let result = matcher @@ -2966,7 +2944,6 @@ mod tests { cohort_cache, None, None, - None, ); let (is_match, reason) = matcher @@ -3046,7 +3023,6 @@ mod tests { cohort_cache, None, None, - None, ); let result = matcher.get_match(&flag, None, None).await.unwrap(); @@ -3138,7 +3114,6 @@ mod tests { cohort_cache.clone(), None, None, - None, ); let mut matcher_example_id = FeatureFlagMatcher::new( @@ -3149,7 +3124,6 @@ mod tests { cohort_cache.clone(), None, None, - None, ); let mut matcher_another_id = FeatureFlagMatcher::new( @@ -3160,7 +3134,6 @@ mod tests { cohort_cache.clone(), None, None, - None, ); let result_test_id = matcher_test_id.get_match(&flag, None, None).await.unwrap(); @@ -3265,7 +3238,6 @@ mod tests { cohort_cache.clone(), None, None, - None, ); let result = matcher.get_match(&flag, None, None).await.unwrap(); @@ -3359,7 +3331,6 @@ mod tests { cohort_cache.clone(), None, None, - None, ); let mut matcher_example_id = FeatureFlagMatcher::new( @@ -3370,7 +3341,6 @@ mod tests { cohort_cache.clone(), None, None, - None, ); let mut matcher_another_id = FeatureFlagMatcher::new( @@ -3381,7 +3351,6 @@ mod tests { cohort_cache.clone(), None, None, - None, ); let result_test_id = matcher_test_id.get_match(&flag, None, None).await.unwrap(); @@ -3497,7 +3466,6 @@ mod tests { cohort_cache.clone(), None, None, - None, ); let result = matcher.get_match(&flag, None, None).await.unwrap(); @@ -3586,7 +3554,6 @@ mod tests { cohort_cache.clone(), None, None, - None, ); let result = matcher.get_match(&flag, None, None).await.unwrap(); @@ -3675,7 +3642,6 @@ mod tests { cohort_cache.clone(), None, None, - None, ); let result = matcher.get_match(&flag, None, None).await.unwrap(); @@ -3790,7 +3756,6 @@ mod tests { cohort_cache.clone(), None, None, - None, ); let result = matcher.get_match(&flag, None, None).await.unwrap(); @@ -3880,7 +3845,6 @@ mod tests { cohort_cache.clone(), None, None, - None, ); let result = matcher.get_match(&flag, None, None).await.unwrap(); @@ -4110,7 +4074,6 @@ mod tests { cohort_cache.clone(), None, None, - None, ) .evaluate_all_feature_flags(flags, None, None, Some("hash_key_continuity".to_string())) .await; @@ -4184,7 +4147,6 @@ mod tests { cohort_cache.clone(), None, None, - None, ) .evaluate_all_feature_flags(flags, None, None, None) .await; @@ -4297,7 +4259,6 @@ mod tests { cohort_cache.clone(), None, None, - None, ) .evaluate_all_feature_flags( flags, diff --git a/rust/feature-flags/src/request_handler.rs b/rust/feature-flags/src/request_handler.rs index 5277a2565ef2b..73bc6129f2c10 100644 --- a/rust/feature-flags/src/request_handler.rs +++ b/rust/feature-flags/src/request_handler.rs @@ -226,7 +226,6 @@ pub async fn evaluate_feature_flags(context: FeatureFlagEvaluationContext) -> Fl context.postgres_writer, context.cohort_cache, Some(group_type_mapping_cache), - None, // TODO maybe remove this from the matcher struct, since it's used internally but not passed around context.groups, ); feature_flag_matcher diff --git a/rust/feature-flags/tests/test_flag_matching_consistency.rs b/rust/feature-flags/tests/test_flag_matching_consistency.rs index 0080fd470261f..0ffb0f687e4aa 100644 --- a/rust/feature-flags/tests/test_flag_matching_consistency.rs +++ b/rust/feature-flags/tests/test_flag_matching_consistency.rs @@ -125,7 +125,6 @@ async fn it_is_consistent_with_rollout_calculation_for_simple_flags() { cohort_cache, None, None, - None, ) .get_match(&flags[0], None, None) .await @@ -1225,7 +1224,6 @@ async fn it_is_consistent_with_rollout_calculation_for_multivariate_flags() { cohort_cache, None, None, - None, ) .get_match(&flags[0], None, None) .await From a5812e67b66bdde48110353d5acdf9b2a4b1abd4 Mon Sep 17 00:00:00 2001 From: dylan Date: Fri, 1 Nov 2024 17:27:39 -0700 Subject: [PATCH 23/30] that should do it --- rust/feature-flags/src/cohort_cache.rs | 26 ++++++++++++-------------- 1 file changed, 12 insertions(+), 14 deletions(-) diff --git a/rust/feature-flags/src/cohort_cache.rs b/rust/feature-flags/src/cohort_cache.rs index 3adf9a606fb8c..4cb2ac475e1a1 100644 --- a/rust/feature-flags/src/cohort_cache.rs +++ b/rust/feature-flags/src/cohort_cache.rs @@ -110,7 +110,7 @@ mod tests { let reader_client = setup_pg_reader_client(None).await; let team_id = setup_test_team(writer_client.clone()).await?; - let _cohort = setup_test_cohort(writer_client.clone(), team_id.clone(), None).await?; + let _cohort = setup_test_cohort(writer_client.clone(), team_id, None).await?; // Initialize CohortCache with a short TTL for testing let cohort_cache = CohortCache::new( @@ -119,9 +119,9 @@ mod tests { Some(1), // 1-second TTL ); - let cohorts = cohort_cache.get_cohorts_for_team(team_id.clone()).await?; + let cohorts = cohort_cache.get_cohorts_for_team(team_id).await?; assert_eq!(cohorts.len(), 1); - assert_eq!(cohorts[0].team_id, team_id.clone()); + assert_eq!(cohorts[0].team_id, team_id); let cached_cohorts = cohort_cache.per_team_cohorts.get(&team_id).await; assert!(cached_cohorts.is_some()); @@ -153,9 +153,9 @@ mod tests { for _ in 0..max_capacity { let team = insert_new_team_in_pg(writer_client.clone(), None).await?; let team_id = team.id; - inserted_team_ids.push(team_id.clone()); - setup_test_cohort(writer_client.clone(), team_id.clone(), None).await?; - cohort_cache.get_cohorts_for_team(team_id.clone()).await?; + inserted_team_ids.push(team_id); + setup_test_cohort(writer_client.clone(), team_id, None).await?; + cohort_cache.get_cohorts_for_team(team_id).await?; } cohort_cache.per_team_cohorts.run_pending_tasks().await; @@ -167,10 +167,8 @@ mod tests { let new_team = insert_new_team_in_pg(writer_client.clone(), None).await?; let new_team_id = new_team.id; - setup_test_cohort(writer_client.clone(), new_team_id.clone(), None).await?; - cohort_cache - .get_cohorts_for_team(new_team_id.clone()) - .await?; + setup_test_cohort(writer_client.clone(), new_team_id, None).await?; + cohort_cache.get_cohorts_for_team(new_team_id).await?; cohort_cache.per_team_cohorts.run_pending_tasks().await; let cache_size_after = cohort_cache.per_team_cohorts.entry_count(); @@ -200,19 +198,19 @@ mod tests { let writer_client = setup_pg_writer_client(None).await; let reader_client = setup_pg_reader_client(None).await; let team_id = setup_test_team(writer_client.clone()).await?; - let _cohort = setup_test_cohort(writer_client.clone(), team_id.clone(), None).await?; + let _cohort = setup_test_cohort(writer_client.clone(), team_id, None).await?; let cohort_cache = CohortCache::new(reader_client.clone(), None, None); let cached_cohorts = cohort_cache.per_team_cohorts.get(&team_id).await; assert!(cached_cohorts.is_none(), "Cache should initially be empty"); - let cohorts = cohort_cache.get_cohorts_for_team(team_id.clone()).await?; + let cohorts = cohort_cache.get_cohorts_for_team(team_id).await?; assert_eq!(cohorts.len(), 1); - assert_eq!(cohorts[0].team_id, team_id.clone()); + assert_eq!(cohorts[0].team_id, team_id); let cached_cohorts = cohort_cache.per_team_cohorts.get(&team_id).await.unwrap(); assert_eq!(cached_cohorts.len(), 1); - assert_eq!(cached_cohorts[0].team_id, team_id.clone()); + assert_eq!(cached_cohorts[0].team_id, team_id); Ok(()) } From fd52b24e019ea4c6de419ae5c9afa1b534e2c2cc Mon Sep 17 00:00:00 2001 From: dylan Date: Mon, 4 Nov 2024 11:09:31 -0800 Subject: [PATCH 24/30] clean up --- rust/feature-flags/src/cohort_cache.rs | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/rust/feature-flags/src/cohort_cache.rs b/rust/feature-flags/src/cohort_cache.rs index 4cb2ac475e1a1..b4550117dcada 100644 --- a/rust/feature-flags/src/cohort_cache.rs +++ b/rust/feature-flags/src/cohort_cache.rs @@ -26,7 +26,6 @@ use std::time::Duration; /// } /// ``` /// -/// Caches only successful cohort lists to maintain cache integrity. #[derive(Clone)] pub struct CohortCache { postgres_reader: PostgresReader, @@ -34,17 +33,15 @@ pub struct CohortCache { } impl CohortCache { - /// Creates a new `CohortCache` with configurable TTL and maximum capacity. pub fn new( postgres_reader: PostgresReader, max_capacity: Option, ttl_seconds: Option, ) -> Self { - // We use the size of the cohort list as the weight of the entry + // We use the size of the cohort list (i.e., the number of cohorts for a given team)as the weight of the entry let weigher = |_: &TeamId, value: &Vec| -> u32 { value.len().try_into().unwrap_or(u32::MAX) }; - // Initialize the Moka cache with TTL and size-based eviction. let cache = Cache::builder() .time_to_live(Duration::from_secs(ttl_seconds.unwrap_or(300))) // Default to 5 minutes .weigher(weigher) From 59f7c10df19c1feba7b4873d901f2fdf41b0cff4 Mon Sep 17 00:00:00 2001 From: dylan Date: Mon, 4 Nov 2024 11:13:49 -0800 Subject: [PATCH 25/30] rename --- rust/feature-flags/src/cohort_cache.rs | 51 ++++++++------ rust/feature-flags/src/flag_matching.rs | 68 +++++++++---------- rust/feature-flags/src/request_handler.rs | 12 ++-- rust/feature-flags/src/router.rs | 6 +- rust/feature-flags/src/server.rs | 4 +- .../tests/test_flag_matching_consistency.rs | 6 +- 6 files changed, 77 insertions(+), 70 deletions(-) diff --git a/rust/feature-flags/src/cohort_cache.rs b/rust/feature-flags/src/cohort_cache.rs index b4550117dcada..68894c19f88e2 100644 --- a/rust/feature-flags/src/cohort_cache.rs +++ b/rust/feature-flags/src/cohort_cache.rs @@ -4,14 +4,14 @@ use crate::flag_matching::{PostgresReader, TeamId}; use moka::future::Cache; use std::time::Duration; -/// CohortCache manages the in-memory cache of cohorts using `moka` for caching. +/// CohortCacheManager manages the in-memory cache of cohorts using `moka` for caching. /// /// Features: /// - **TTL**: Each cache entry expires after 5 minutes. /// - **Size-based eviction**: The cache evicts least recently used entries when the maximum capacity is reached. /// /// ```text -/// CohortCache { +/// CohortCacheManager { /// postgres_reader: PostgresReader, /// per_team_cohorts: Cache> { /// // Example: @@ -27,12 +27,12 @@ use std::time::Duration; /// ``` /// #[derive(Clone)] -pub struct CohortCache { +pub struct CohortCacheManager { postgres_reader: PostgresReader, - per_team_cohorts: Cache>, + per_team_cohort_cache: Cache>, } -impl CohortCache { +impl CohortCacheManager { pub fn new( postgres_reader: PostgresReader, max_capacity: Option, @@ -50,7 +50,7 @@ impl CohortCache { Self { postgres_reader, - per_team_cohorts: cache, + per_team_cohort_cache: cache, } } @@ -59,11 +59,11 @@ impl CohortCache { /// If the cohorts are not present in the cache or have expired, it fetches them from the database, /// caches the result upon successful retrieval, and then returns it. pub async fn get_cohorts_for_team(&self, team_id: TeamId) -> Result, FlagError> { - if let Some(cached_cohorts) = self.per_team_cohorts.get(&team_id).await { + if let Some(cached_cohorts) = self.per_team_cohort_cache.get(&team_id).await { return Ok(cached_cohorts.clone()); } let fetched_cohorts = Cohort::list_from_pg(self.postgres_reader.clone(), team_id).await?; - self.per_team_cohorts + self.per_team_cohort_cache .insert(team_id, fetched_cohorts.clone()) .await; @@ -109,8 +109,8 @@ mod tests { let team_id = setup_test_team(writer_client.clone()).await?; let _cohort = setup_test_cohort(writer_client.clone(), team_id, None).await?; - // Initialize CohortCache with a short TTL for testing - let cohort_cache = CohortCache::new( + // Initialize CohortCacheManager with a short TTL for testing + let cohort_cache = CohortCacheManager::new( reader_client.clone(), Some(100), Some(1), // 1-second TTL @@ -120,14 +120,14 @@ mod tests { assert_eq!(cohorts.len(), 1); assert_eq!(cohorts[0].team_id, team_id); - let cached_cohorts = cohort_cache.per_team_cohorts.get(&team_id).await; + let cached_cohorts = cohort_cache.per_team_cohort_cache.get(&team_id).await; assert!(cached_cohorts.is_some()); // Wait for TTL to expire sleep(Duration::from_secs(2)).await; // Attempt to retrieve from cache again - let cached_cohorts = cohort_cache.per_team_cohorts.get(&team_id).await; + let cached_cohorts = cohort_cache.per_team_cohort_cache.get(&team_id).await; assert!(cached_cohorts.is_none(), "Cache entry should have expired"); Ok(()) @@ -142,7 +142,7 @@ mod tests { // Define a smaller max_capacity for testing let max_capacity: u64 = 3; - let cohort_cache = CohortCache::new(reader_client.clone(), Some(max_capacity), None); + let cohort_cache = CohortCacheManager::new(reader_client.clone(), Some(max_capacity), None); let mut inserted_team_ids = Vec::new(); @@ -155,8 +155,8 @@ mod tests { cohort_cache.get_cohorts_for_team(team_id).await?; } - cohort_cache.per_team_cohorts.run_pending_tasks().await; - let cache_size = cohort_cache.per_team_cohorts.entry_count(); + cohort_cache.per_team_cohort_cache.run_pending_tasks().await; + let cache_size = cohort_cache.per_team_cohort_cache.entry_count(); assert_eq!( cache_size, max_capacity, "Cache size should be equal to max_capacity" @@ -167,21 +167,24 @@ mod tests { setup_test_cohort(writer_client.clone(), new_team_id, None).await?; cohort_cache.get_cohorts_for_team(new_team_id).await?; - cohort_cache.per_team_cohorts.run_pending_tasks().await; - let cache_size_after = cohort_cache.per_team_cohorts.entry_count(); + cohort_cache.per_team_cohort_cache.run_pending_tasks().await; + let cache_size_after = cohort_cache.per_team_cohort_cache.entry_count(); assert_eq!( cache_size_after, max_capacity, "Cache size should remain equal to max_capacity after eviction" ); let evicted_team_id = &inserted_team_ids[0]; - let cached_cohorts = cohort_cache.per_team_cohorts.get(evicted_team_id).await; + let cached_cohorts = cohort_cache + .per_team_cohort_cache + .get(evicted_team_id) + .await; assert!( cached_cohorts.is_none(), "Least recently used cache entry should have been evicted" ); - let cached_new_team = cohort_cache.per_team_cohorts.get(&new_team_id).await; + let cached_new_team = cohort_cache.per_team_cohort_cache.get(&new_team_id).await; assert!( cached_new_team.is_some(), "Newly added cache entry should be present" @@ -196,16 +199,20 @@ mod tests { let reader_client = setup_pg_reader_client(None).await; let team_id = setup_test_team(writer_client.clone()).await?; let _cohort = setup_test_cohort(writer_client.clone(), team_id, None).await?; - let cohort_cache = CohortCache::new(reader_client.clone(), None, None); + let cohort_cache = CohortCacheManager::new(reader_client.clone(), None, None); - let cached_cohorts = cohort_cache.per_team_cohorts.get(&team_id).await; + let cached_cohorts = cohort_cache.per_team_cohort_cache.get(&team_id).await; assert!(cached_cohorts.is_none(), "Cache should initially be empty"); let cohorts = cohort_cache.get_cohorts_for_team(team_id).await?; assert_eq!(cohorts.len(), 1); assert_eq!(cohorts[0].team_id, team_id); - let cached_cohorts = cohort_cache.per_team_cohorts.get(&team_id).await.unwrap(); + let cached_cohorts = cohort_cache + .per_team_cohort_cache + .get(&team_id) + .await + .unwrap(); assert_eq!(cached_cohorts.len(), 1); assert_eq!(cached_cohorts[0].team_id, team_id); diff --git a/rust/feature-flags/src/flag_matching.rs b/rust/feature-flags/src/flag_matching.rs index 89729b0ab54d7..dd94b1456f3d5 100644 --- a/rust/feature-flags/src/flag_matching.rs +++ b/rust/feature-flags/src/flag_matching.rs @@ -1,6 +1,6 @@ use crate::{ api::{FlagError, FlagValue, FlagsResponse}, - cohort_cache::CohortCache, + cohort_cache::CohortCacheManager, cohort_models::{Cohort, CohortId}, database::Client as DatabaseClient, feature_flag_match_reason::FeatureFlagMatchReason, @@ -186,7 +186,7 @@ pub struct FeatureFlagMatcher { pub team_id: TeamId, pub postgres_reader: PostgresReader, pub postgres_writer: PostgresWriter, - pub cohort_cache: Arc, + pub cohort_cache: Arc, group_type_mapping_cache: GroupTypeMappingCache, properties_cache: PropertiesCache, groups: HashMap, @@ -200,7 +200,7 @@ impl FeatureFlagMatcher { team_id: TeamId, postgres_reader: PostgresReader, postgres_writer: PostgresWriter, - cohort_cache: Arc, + cohort_cache: Arc, group_type_mapping_cache: Option, groups: Option>, ) -> Self { @@ -1705,7 +1705,7 @@ mod tests { async fn test_fetch_properties_from_pg_to_match() { let postgres_reader = setup_pg_reader_client(None).await; let postgres_writer = setup_pg_writer_client(None).await; - let cohort_cache = Arc::new(CohortCache::new(postgres_reader.clone(), None, None)); + let cohort_cache = Arc::new(CohortCacheManager::new(postgres_reader.clone(), None, None)); let team = insert_new_team_in_pg(postgres_reader.clone(), None) .await @@ -1794,7 +1794,7 @@ mod tests { async fn test_person_property_overrides() { let postgres_reader = setup_pg_reader_client(None).await; let postgres_writer = setup_pg_writer_client(None).await; - let cohort_cache = Arc::new(CohortCache::new(postgres_reader.clone(), None, None)); + let cohort_cache = Arc::new(CohortCacheManager::new(postgres_reader.clone(), None, None)); let team = insert_new_team_in_pg(postgres_reader.clone(), None) .await .unwrap(); @@ -1856,7 +1856,7 @@ mod tests { async fn test_group_property_overrides() { let postgres_reader = setup_pg_reader_client(None).await; let postgres_writer = setup_pg_writer_client(None).await; - let cohort_cache = Arc::new(CohortCache::new(postgres_reader.clone(), None, None)); + let cohort_cache = Arc::new(CohortCacheManager::new(postgres_reader.clone(), None, None)); let team = insert_new_team_in_pg(postgres_reader.clone(), None) .await .unwrap(); @@ -1935,7 +1935,7 @@ mod tests { let flag = create_test_flag_with_variants(1); let postgres_reader = setup_pg_reader_client(None).await; let postgres_writer = setup_pg_writer_client(None).await; - let cohort_cache = Arc::new(CohortCache::new(postgres_reader.clone(), None, None)); + let cohort_cache = Arc::new(CohortCacheManager::new(postgres_reader.clone(), None, None)); let mut group_type_mapping_cache = GroupTypeMappingCache::new(1, postgres_reader.clone()); let group_types_to_indexes = [("group_type_1".to_string(), 1)].into_iter().collect(); @@ -1967,7 +1967,7 @@ mod tests { async fn test_get_matching_variant_with_db() { let postgres_reader = setup_pg_reader_client(None).await; let postgres_writer = setup_pg_writer_client(None).await; - let cohort_cache = Arc::new(CohortCache::new(postgres_reader.clone(), None, None)); + let cohort_cache = Arc::new(CohortCacheManager::new(postgres_reader.clone(), None, None)); let team = insert_new_team_in_pg(postgres_reader.clone(), None) .await .unwrap(); @@ -1993,7 +1993,7 @@ mod tests { async fn test_is_condition_match_empty_properties() { let postgres_reader = setup_pg_reader_client(None).await; let postgres_writer = setup_pg_writer_client(None).await; - let cohort_cache = Arc::new(CohortCache::new(postgres_reader.clone(), None, None)); + let cohort_cache = Arc::new(CohortCacheManager::new(postgres_reader.clone(), None, None)); let flag = create_test_flag( Some(1), None, @@ -2083,7 +2083,7 @@ mod tests { async fn test_overrides_avoid_db_lookups() { let postgres_reader = setup_pg_reader_client(None).await; let postgres_writer = setup_pg_writer_client(None).await; - let cohort_cache = Arc::new(CohortCache::new(postgres_reader.clone(), None, None)); + let cohort_cache = Arc::new(CohortCacheManager::new(postgres_reader.clone(), None, None)); let team = insert_new_team_in_pg(postgres_reader.clone(), None) .await .unwrap(); @@ -2154,7 +2154,7 @@ mod tests { async fn test_fallback_to_db_when_overrides_insufficient() { let postgres_reader = setup_pg_reader_client(None).await; let postgres_writer = setup_pg_writer_client(None).await; - let cohort_cache = Arc::new(CohortCache::new(postgres_reader.clone(), None, None)); + let cohort_cache = Arc::new(CohortCacheManager::new(postgres_reader.clone(), None, None)); let team = insert_new_team_in_pg(postgres_reader.clone(), None) .await .unwrap(); @@ -2240,7 +2240,7 @@ mod tests { async fn test_property_fetching_and_caching() { let postgres_reader = setup_pg_reader_client(None).await; let postgres_writer = setup_pg_writer_client(None).await; - let cohort_cache = Arc::new(CohortCache::new(postgres_reader.clone(), None, None)); + let cohort_cache = Arc::new(CohortCacheManager::new(postgres_reader.clone(), None, None)); let team = insert_new_team_in_pg(postgres_reader.clone(), None) .await .unwrap(); @@ -2285,7 +2285,7 @@ mod tests { async fn test_property_caching() { let postgres_reader = setup_pg_reader_client(None).await; let postgres_writer = setup_pg_writer_client(None).await; - let cohort_cache = Arc::new(CohortCache::new(postgres_reader.clone(), None, None)); + let cohort_cache = Arc::new(CohortCacheManager::new(postgres_reader.clone(), None, None)); let team = insert_new_team_in_pg(postgres_reader.clone(), None) .await .unwrap(); @@ -2429,7 +2429,7 @@ mod tests { async fn test_concurrent_flag_evaluation() { let postgres_reader = setup_pg_reader_client(None).await; let postgres_writer = setup_pg_writer_client(None).await; - let cohort_cache = Arc::new(CohortCache::new(postgres_reader.clone(), None, None)); + let cohort_cache = Arc::new(CohortCacheManager::new(postgres_reader.clone(), None, None)); let team = insert_new_team_in_pg(postgres_reader.clone(), None) .await .unwrap(); @@ -2488,7 +2488,7 @@ mod tests { async fn test_property_operators() { let postgres_reader = setup_pg_reader_client(None).await; let postgres_writer = setup_pg_writer_client(None).await; - let cohort_cache = Arc::new(CohortCache::new(postgres_reader.clone(), None, None)); + let cohort_cache = Arc::new(CohortCacheManager::new(postgres_reader.clone(), None, None)); let team = insert_new_team_in_pg(postgres_reader.clone(), None) .await .unwrap(); @@ -2559,7 +2559,7 @@ mod tests { async fn test_empty_hashed_identifier() { let postgres_reader = setup_pg_reader_client(None).await; let postgres_writer = setup_pg_writer_client(None).await; - let cohort_cache = Arc::new(CohortCache::new(postgres_reader.clone(), None, None)); + let cohort_cache = Arc::new(CohortCacheManager::new(postgres_reader.clone(), None, None)); let flag = create_test_flag( Some(1), None, @@ -2600,7 +2600,7 @@ mod tests { async fn test_rollout_percentage() { let postgres_reader = setup_pg_reader_client(None).await; let postgres_writer = setup_pg_writer_client(None).await; - let cohort_cache = Arc::new(CohortCache::new(postgres_reader.clone(), None, None)); + let cohort_cache = Arc::new(CohortCacheManager::new(postgres_reader.clone(), None, None)); let mut flag = create_test_flag( Some(1), None, @@ -2648,7 +2648,7 @@ mod tests { async fn test_uneven_variant_distribution() { let postgres_reader = setup_pg_reader_client(None).await; let postgres_writer = setup_pg_writer_client(None).await; - let cohort_cache = Arc::new(CohortCache::new(postgres_reader.clone(), None, None)); + let cohort_cache = Arc::new(CohortCacheManager::new(postgres_reader.clone(), None, None)); let mut flag = create_test_flag_with_variants(1); // Adjust variant rollout percentages to be uneven @@ -2710,7 +2710,7 @@ mod tests { async fn test_missing_properties_in_db() { let postgres_reader = setup_pg_reader_client(None).await; let postgres_writer = setup_pg_writer_client(None).await; - let cohort_cache = Arc::new(CohortCache::new(postgres_reader.clone(), None, None)); + let cohort_cache = Arc::new(CohortCacheManager::new(postgres_reader.clone(), None, None)); let team = insert_new_team_in_pg(postgres_reader.clone(), None) .await .unwrap(); @@ -2772,7 +2772,7 @@ mod tests { async fn test_malformed_property_data() { let postgres_reader = setup_pg_reader_client(None).await; let postgres_writer = setup_pg_writer_client(None).await; - let cohort_cache = Arc::new(CohortCache::new(postgres_reader.clone(), None, None)); + let cohort_cache = Arc::new(CohortCacheManager::new(postgres_reader.clone(), None, None)); let team = insert_new_team_in_pg(postgres_reader.clone(), None) .await .unwrap(); @@ -2835,7 +2835,7 @@ mod tests { async fn test_get_match_with_insufficient_overrides() { let postgres_reader = setup_pg_reader_client(None).await; let postgres_writer = setup_pg_writer_client(None).await; - let cohort_cache = Arc::new(CohortCache::new(postgres_reader.clone(), None, None)); + let cohort_cache = Arc::new(CohortCacheManager::new(postgres_reader.clone(), None, None)); let team = insert_new_team_in_pg(postgres_reader.clone(), None) .await .unwrap(); @@ -2914,7 +2914,7 @@ mod tests { async fn test_evaluation_reasons() { let postgres_reader = setup_pg_reader_client(None).await; let postgres_writer = setup_pg_writer_client(None).await; - let cohort_cache = Arc::new(CohortCache::new(postgres_reader.clone(), None, None)); + let cohort_cache = Arc::new(CohortCacheManager::new(postgres_reader.clone(), None, None)); let flag = create_test_flag( Some(1), None, @@ -2959,7 +2959,7 @@ mod tests { async fn test_complex_conditions() { let postgres_reader = setup_pg_reader_client(None).await; let postgres_writer = setup_pg_writer_client(None).await; - let cohort_cache = Arc::new(CohortCache::new(postgres_reader.clone(), None, None)); + let cohort_cache = Arc::new(CohortCacheManager::new(postgres_reader.clone(), None, None)); let team = insert_new_team_in_pg(postgres_reader.clone(), None) .await .unwrap(); @@ -3034,7 +3034,7 @@ mod tests { async fn test_super_condition_matches_boolean() { let postgres_reader = setup_pg_reader_client(None).await; let postgres_writer = setup_pg_writer_client(None).await; - let cohort_cache = Arc::new(CohortCache::new(postgres_reader.clone(), None, None)); + let cohort_cache = Arc::new(CohortCacheManager::new(postgres_reader.clone(), None, None)); let team = insert_new_team_in_pg(postgres_reader.clone(), None) .await .unwrap(); @@ -3158,7 +3158,7 @@ mod tests { async fn test_super_condition_matches_string() { let postgres_reader = setup_pg_reader_client(None).await; let postgres_writer = setup_pg_writer_client(None).await; - let cohort_cache = Arc::new(CohortCache::new(postgres_reader.clone(), None, None)); + let cohort_cache = Arc::new(CohortCacheManager::new(postgres_reader.clone(), None, None)); let team = insert_new_team_in_pg(postgres_reader.clone(), None) .await .unwrap(); @@ -3251,7 +3251,7 @@ mod tests { async fn test_super_condition_matches_and_false() { let postgres_reader = setup_pg_reader_client(None).await; let postgres_writer = setup_pg_writer_client(None).await; - let cohort_cache = Arc::new(CohortCache::new(postgres_reader.clone(), None, None)); + let cohort_cache = Arc::new(CohortCacheManager::new(postgres_reader.clone(), None, None)); let team = insert_new_team_in_pg(postgres_reader.clone(), None) .await .unwrap(); @@ -3389,7 +3389,7 @@ mod tests { async fn test_basic_cohort_matching() { let postgres_reader = setup_pg_reader_client(None).await; let postgres_writer = setup_pg_writer_client(None).await; - let cohort_cache = Arc::new(CohortCache::new(postgres_reader.clone(), None, None)); + let cohort_cache = Arc::new(CohortCacheManager::new(postgres_reader.clone(), None, None)); let team = insert_new_team_in_pg(postgres_reader.clone(), None) .await .unwrap(); @@ -3477,7 +3477,7 @@ mod tests { async fn test_not_in_cohort_matching() { let postgres_reader = setup_pg_reader_client(None).await; let postgres_writer = setup_pg_writer_client(None).await; - let cohort_cache = Arc::new(CohortCache::new(postgres_reader.clone(), None, None)); + let cohort_cache = Arc::new(CohortCacheManager::new(postgres_reader.clone(), None, None)); let team = insert_new_team_in_pg(postgres_reader.clone(), None) .await .unwrap(); @@ -3565,7 +3565,7 @@ mod tests { async fn test_not_in_cohort_matching_user_in_cohort() { let postgres_reader = setup_pg_reader_client(None).await; let postgres_writer = setup_pg_writer_client(None).await; - let cohort_cache = Arc::new(CohortCache::new(postgres_reader.clone(), None, None)); + let cohort_cache = Arc::new(CohortCacheManager::new(postgres_reader.clone(), None, None)); let team = insert_new_team_in_pg(postgres_reader.clone(), None) .await .unwrap(); @@ -3654,7 +3654,7 @@ mod tests { async fn test_cohort_dependent_on_another_cohort() { let postgres_reader = setup_pg_reader_client(None).await; let postgres_writer = setup_pg_writer_client(None).await; - let cohort_cache = Arc::new(CohortCache::new(postgres_reader.clone(), None, None)); + let cohort_cache = Arc::new(CohortCacheManager::new(postgres_reader.clone(), None, None)); let team = insert_new_team_in_pg(postgres_reader.clone(), None) .await .unwrap(); @@ -3768,7 +3768,7 @@ mod tests { async fn test_in_cohort_matching_user_not_in_cohort() { let postgres_reader = setup_pg_reader_client(None).await; let postgres_writer = setup_pg_writer_client(None).await; - let cohort_cache = Arc::new(CohortCache::new(postgres_reader.clone(), None, None)); + let cohort_cache = Arc::new(CohortCacheManager::new(postgres_reader.clone(), None, None)); let team = insert_new_team_in_pg(postgres_reader.clone(), None) .await .unwrap(); @@ -4007,7 +4007,7 @@ mod tests { async fn test_evaluate_feature_flags_with_experience_continuity() { let postgres_reader = setup_pg_reader_client(None).await; let postgres_writer = setup_pg_writer_client(None).await; - let cohort_cache = Arc::new(CohortCache::new(postgres_reader.clone(), None, None)); + let cohort_cache = Arc::new(CohortCacheManager::new(postgres_reader.clone(), None, None)); let team = insert_new_team_in_pg(postgres_reader.clone(), None) .await .unwrap(); @@ -4090,7 +4090,7 @@ mod tests { async fn test_evaluate_feature_flags_with_continuity_missing_override() { let postgres_reader = setup_pg_reader_client(None).await; let postgres_writer = setup_pg_writer_client(None).await; - let cohort_cache = Arc::new(CohortCache::new(postgres_reader.clone(), None, None)); + let cohort_cache = Arc::new(CohortCacheManager::new(postgres_reader.clone(), None, None)); let team = insert_new_team_in_pg(postgres_reader.clone(), None) .await .unwrap(); @@ -4163,7 +4163,7 @@ mod tests { async fn test_evaluate_all_feature_flags_mixed_continuity() { let postgres_reader = setup_pg_reader_client(None).await; let postgres_writer = setup_pg_writer_client(None).await; - let cohort_cache = Arc::new(CohortCache::new(postgres_reader.clone(), None, None)); + let cohort_cache = Arc::new(CohortCacheManager::new(postgres_reader.clone(), None, None)); let team = insert_new_team_in_pg(postgres_reader.clone(), None) .await .unwrap(); diff --git a/rust/feature-flags/src/request_handler.rs b/rust/feature-flags/src/request_handler.rs index 73bc6129f2c10..538c6845d2a02 100644 --- a/rust/feature-flags/src/request_handler.rs +++ b/rust/feature-flags/src/request_handler.rs @@ -1,6 +1,6 @@ use crate::{ api::{FlagError, FlagsResponse}, - cohort_cache::CohortCache, + cohort_cache::CohortCacheManager, database::Client, flag_definitions::FeatureFlagList, flag_matching::{FeatureFlagMatcher, GroupTypeMappingCache}, @@ -70,7 +70,7 @@ pub struct FeatureFlagEvaluationContext { feature_flags: FeatureFlagList, postgres_reader: Arc, postgres_writer: Arc, - cohort_cache: Arc, + cohort_cache: Arc, #[builder(default)] person_property_overrides: Option>, #[builder(default)] @@ -359,7 +359,7 @@ mod tests { async fn test_evaluate_feature_flags() { let postgres_reader: Arc = setup_pg_reader_client(None).await; let postgres_writer: Arc = setup_pg_writer_client(None).await; - let cohort_cache = Arc::new(CohortCache::new(postgres_reader.clone(), None, None)); + let cohort_cache = Arc::new(CohortCacheManager::new(postgres_reader.clone(), None, None)); let flag = FeatureFlag { name: Some("Test Flag".to_string()), id: 1, @@ -508,7 +508,7 @@ mod tests { async fn test_evaluate_feature_flags_multiple_flags() { let postgres_reader: Arc = setup_pg_reader_client(None).await; let postgres_writer: Arc = setup_pg_writer_client(None).await; - let cohort_cache = Arc::new(CohortCache::new(postgres_reader.clone(), None, None)); + let cohort_cache = Arc::new(CohortCacheManager::new(postgres_reader.clone(), None, None)); let flags = vec![ FeatureFlag { name: Some("Flag 1".to_string()), @@ -613,7 +613,7 @@ mod tests { async fn test_evaluate_feature_flags_with_overrides() { let postgres_reader: Arc = setup_pg_reader_client(None).await; let postgres_writer: Arc = setup_pg_writer_client(None).await; - let cohort_cache = Arc::new(CohortCache::new(postgres_reader.clone(), None, None)); + let cohort_cache = Arc::new(CohortCacheManager::new(postgres_reader.clone(), None, None)); let team = insert_new_team_in_pg(postgres_reader.clone(), None) .await .unwrap(); @@ -696,7 +696,7 @@ mod tests { let long_id = "a".repeat(1000); let postgres_reader: Arc = setup_pg_reader_client(None).await; let postgres_writer: Arc = setup_pg_writer_client(None).await; - let cohort_cache = Arc::new(CohortCache::new(postgres_reader.clone(), None, None)); + let cohort_cache = Arc::new(CohortCacheManager::new(postgres_reader.clone(), None, None)); let flag = FeatureFlag { name: Some("Test Flag".to_string()), id: 1, diff --git a/rust/feature-flags/src/router.rs b/rust/feature-flags/src/router.rs index ceb908a9a869b..e34ea31a3c65a 100644 --- a/rust/feature-flags/src/router.rs +++ b/rust/feature-flags/src/router.rs @@ -9,7 +9,7 @@ use health::HealthRegistry; use tower::limit::ConcurrencyLimitLayer; use crate::{ - cohort_cache::CohortCache, + cohort_cache::CohortCacheManager, config::{Config, TeamIdsToTrack}, database::Client as DatabaseClient, geoip::GeoIpClient, @@ -23,7 +23,7 @@ pub struct State { pub redis: Arc, pub postgres_reader: Arc, pub postgres_writer: Arc, - pub cohort_cache: Arc, // TODO does this need a better name than just `cohort_cache`? + pub cohort_cache: Arc, // TODO does this need a better name than just `cohort_cache`? pub geoip: Arc, pub team_ids_to_track: TeamIdsToTrack, } @@ -32,7 +32,7 @@ pub fn router( redis: Arc, postgres_reader: Arc, postgres_writer: Arc, - cohort_cache: Arc, + cohort_cache: Arc, geoip: Arc, liveness: HealthRegistry, config: Config, diff --git a/rust/feature-flags/src/server.rs b/rust/feature-flags/src/server.rs index 93492a05f4400..69ff759ddfcdf 100644 --- a/rust/feature-flags/src/server.rs +++ b/rust/feature-flags/src/server.rs @@ -6,7 +6,7 @@ use std::time::Duration; use health::{HealthHandle, HealthRegistry}; use tokio::net::TcpListener; -use crate::cohort_cache::CohortCache; +use crate::cohort_cache::CohortCacheManager; use crate::config::Config; use crate::database::get_pool; use crate::geoip::GeoIpClient; @@ -55,7 +55,7 @@ where } }; - let cohort_cache = Arc::new(CohortCache::new(postgres_reader.clone(), None, None)); + let cohort_cache = Arc::new(CohortCacheManager::new(postgres_reader.clone(), None, None)); let health = HealthRegistry::new("liveness"); diff --git a/rust/feature-flags/tests/test_flag_matching_consistency.rs b/rust/feature-flags/tests/test_flag_matching_consistency.rs index 0ffb0f687e4aa..c632d28bc151d 100644 --- a/rust/feature-flags/tests/test_flag_matching_consistency.rs +++ b/rust/feature-flags/tests/test_flag_matching_consistency.rs @@ -1,6 +1,6 @@ use std::sync::Arc; -use feature_flags::cohort_cache::CohortCache; +use feature_flags::cohort_cache::CohortCacheManager; use feature_flags::feature_flag_match_reason::FeatureFlagMatchReason; /// These tests are common between all libraries doing local evaluation of feature flags. /// This ensures there are no mismatches between implementations. @@ -113,7 +113,7 @@ async fn it_is_consistent_with_rollout_calculation_for_simple_flags() { for (i, result) in results.iter().enumerate().take(1000) { let postgres_reader = setup_pg_reader_client(None).await; let postgres_writer = setup_pg_writer_client(None).await; - let cohort_cache = Arc::new(CohortCache::new(postgres_reader.clone(), None, None)); + let cohort_cache = Arc::new(CohortCacheManager::new(postgres_reader.clone(), None, None)); let distinct_id = format!("distinct_id_{}", i); @@ -1213,7 +1213,7 @@ async fn it_is_consistent_with_rollout_calculation_for_multivariate_flags() { for (i, result) in results.iter().enumerate().take(1000) { let postgres_reader = setup_pg_reader_client(None).await; let postgres_writer = setup_pg_writer_client(None).await; - let cohort_cache = Arc::new(CohortCache::new(postgres_reader.clone(), None, None)); + let cohort_cache = Arc::new(CohortCacheManager::new(postgres_reader.clone(), None, None)); let distinct_id = format!("distinct_id_{}", i); let feature_flag_match = FeatureFlagMatcher::new( From 8066aff2c7c5b12f1cf1108fb465701525bd5766 Mon Sep 17 00:00:00 2001 From: dylan Date: Mon, 4 Nov 2024 11:21:11 -0800 Subject: [PATCH 26/30] bit more --- rust/feature-flags/src/flag_matching.rs | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/rust/feature-flags/src/flag_matching.rs b/rust/feature-flags/src/flag_matching.rs index dd94b1456f3d5..255bd016f6fa4 100644 --- a/rust/feature-flags/src/flag_matching.rs +++ b/rust/feature-flags/src/flag_matching.rs @@ -745,22 +745,22 @@ impl FeatureFlagMatcher { .cloned() .partition(|prop| prop.is_cohort()); - // Get the relevant properties to check for the condition + // Get the properties we need to check for in this condition match from the flag + any overrides let target_properties = self .get_properties_to_check(feature_flag, property_overrides, &non_cohort_filters) .await?; - // Evaluate non-cohort properties first, since they're cheaper to evaluate + // Evaluate non-cohort filters first, since they're cheaper to evaluate and we can return early if they don't match if !all_properties_match(&non_cohort_filters, &target_properties) { return Ok((false, FeatureFlagMatchReason::NoConditionMatch)); } - // Evaluate cohort filters, if any + // Evaluate cohort filters, if any. if !cohort_filters.is_empty() { - let cohorts_match = self + if !self .evaluate_cohort_filters(&cohort_filters, &target_properties) - .await?; - if !cohorts_match { + .await? + { return Ok((false, FeatureFlagMatchReason::NoConditionMatch)); } } @@ -3760,7 +3760,6 @@ mod tests { let result = matcher.get_match(&flag, None, None).await.unwrap(); - // This test might fail if the system doesn't support cohort dependencies assert!(result.matches); } From 4d5ecd9c4eb13579be3f4f300a5fdad526f96721 Mon Sep 17 00:00:00 2001 From: dylan Date: Mon, 4 Nov 2024 11:23:59 -0800 Subject: [PATCH 27/30] collapse condition --- rust/feature-flags/src/flag_matching.rs | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/rust/feature-flags/src/flag_matching.rs b/rust/feature-flags/src/flag_matching.rs index 255bd016f6fa4..38f2eb83ef979 100644 --- a/rust/feature-flags/src/flag_matching.rs +++ b/rust/feature-flags/src/flag_matching.rs @@ -756,13 +756,12 @@ impl FeatureFlagMatcher { } // Evaluate cohort filters, if any. - if !cohort_filters.is_empty() { - if !self + if !cohort_filters.is_empty() + && !self .evaluate_cohort_filters(&cohort_filters, &target_properties) .await? - { - return Ok((false, FeatureFlagMatchReason::NoConditionMatch)); - } + { + return Ok((false, FeatureFlagMatchReason::NoConditionMatch)); } } From fe37b0462152a48b636e5ac1fba43be7f060b014 Mon Sep 17 00:00:00 2001 From: dylan Date: Thu, 7 Nov 2024 08:36:36 -0800 Subject: [PATCH 28/30] working on it --- posthog/api/test/test_decide.py | 131 ++++++++++++++++++++++++++++++++ rust/Cargo.lock | 6 -- 2 files changed, 131 insertions(+), 6 deletions(-) diff --git a/posthog/api/test/test_decide.py b/posthog/api/test/test_decide.py index 7532c3405d27a..a753d5778adc1 100644 --- a/posthog/api/test/test_decide.py +++ b/posthog/api/test/test_decide.py @@ -925,6 +925,75 @@ def test_feature_flags_v2(self, *args): "third-variant", response.json()["featureFlags"]["multivariate-flag"] ) # different hash, different variant assigned + + # @patch( + # "posthog.models.feature_flag.flag_matching.postgres_healthcheck.is_connected", + # return_value=True, + # ) + def test_feature_flags_v3_with_groups(self, mock_is_connected): + self.team.app_urls = ["https://example.com"] + self.team.save() + self.client.logout() + + FeatureFlag.objects.create( + team=self.team, + key="groups-flag", + name="Groups flag", + created_by=self.user, + filters={ + "filters": { + "aggregation_group_type_index": 0, + "groups": [ + { + "properties": [ + { + "key": "$group_key", + "type": "group", + "value": "13", + "operator": "exact", + "group_type_index": 0 + } + ], + "rollout_percentage": 100, + } + ], + }, + "name": "This is a group-based flag", + "key": "groups-flag", + }, + ) + + GroupTypeMapping.objects.db_manager().create( + team=self.team, project_id=self.team.project_id, group_type="tenant", group_type_index=0 + ) + + Group.objects.db_manager().create( + team_id=self.team.pk, + group_type_index=0, + group_key="13", + group_properties={"tenant_name": "Tacit", "tenant_type": "Organization"}, + version=2, + ) + + with self.assertNumQueries(9), self.assertNumQueries(0, using="default"): + # E 1. SET LOCAL statement_timeout = 300 + # E 2. SELECT "posthog_grouptypemapping"."id", "posthog_grouptypemapping"."team_id", -- a.k.a get group type mappings + + # E 3. SET LOCAL statement_timeout = 600 + # E 4. SELECT (UPPER(("posthog_group"."group_properties" ->> 'email')::text) AS "flag_182_condition_0" FROM "posthog_group" -- a.k.a get group0 conditions + # E 5. SELECT (true) AS "flag_181_condition_0" FROM "posthog_group" WHERE ("posthog_group"."team_id" = 91 -- a.k.a get group1 conditions + response = self._post_decide( + distinct_id="example_id", + groups={"tenant": "13"}, + api_version=3, + ) + print(response.json()) + # self.assertFalse( + # response.json()["featureFlags"]["groups-flag"], + # ) + self.assertFalse(response.json()["errorsWhileComputingFlags"]) + + def test_feature_flags_v2_with_property_overrides(self, *args): self.team.app_urls = ["https://example.com"] self.team.save() @@ -4553,6 +4622,68 @@ def test_feature_flags_v3_consistent_flags_with_write_on_hash_key_overrides(self "first-variant", response.json()["featureFlags"]["multivariate-flag"] ) # assigned by distinct_id hash + @patch( + "posthog.models.feature_flag.flag_matching.postgres_healthcheck.is_connected", + return_value=True, + ) + def test_feature_flags_v3_with_groups(self, mock_is_connected): + org, team, user = self.setup_user_and_team_in_db("replica") + self.organization, self.team, self.user = org, team, user + + flags = [ + { + "filters": { + "aggregation_group_type_index": 0, + "groups": [ + { + "properties": [ + { + "key": "$group_key", + "type": "group", + "value": "13", + "operator": "exact", + "group_type_index": 0 + } + ], + "rollout_percentage": 100, + } + ], + }, + "name": "This is a group-based flag", + "key": "groups-flag", + }, + ] + self.setup_flags_in_db("replica", team, user, flags, persons) + + GroupTypeMapping.objects.db_manager("replica").create( + team=self.team, project_id=self.team.project_id, group_type="tenant", group_type_index=0 + ) + + Group.objects.db_manager("replica").create( + team_id=self.team.pk, + group_type_index=0, + group_key="13", + group_properties={"tenant_name": "Tacit", "tenant_type": "Organization"}, + version=2, + ) + + with self.assertNumQueries(9, using="replica"), self.assertNumQueries(0, using="default"): + # E 1. SET LOCAL statement_timeout = 300 + # E 2. SELECT "posthog_grouptypemapping"."id", "posthog_grouptypemapping"."team_id", -- a.k.a get group type mappings + + # E 3. SET LOCAL statement_timeout = 600 + # E 4. SELECT (UPPER(("posthog_group"."group_properties" ->> 'email')::text) AS "flag_182_condition_0" FROM "posthog_group" -- a.k.a get group0 conditions + # E 5. SELECT (true) AS "flag_181_condition_0" FROM "posthog_group" WHERE ("posthog_group"."team_id" = 91 -- a.k.a get group1 conditions + response = self._post_decide( + distinct_id="example_id", + groups={"tenant": "13"}, + ) + self.assertEqual( + response.json()["featureFlags"], + {"groups-flag": False}, + ) + self.assertFalse(response.json()["errorsWhileComputingFlags"]) + @patch( "posthog.models.feature_flag.flag_matching.postgres_healthcheck.is_connected", return_value=True, diff --git a/rust/Cargo.lock b/rust/Cargo.lock index d6e1257080c73..88ce33468dea4 100644 --- a/rust/Cargo.lock +++ b/rust/Cargo.lock @@ -3047,7 +3047,6 @@ version = "0.12.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "32cf62eb4dd975d2dde76432fb1075c49e3ee2331cf36f1f8fd4b66550d32b6f" dependencies = [ -<<<<<<< HEAD "async-lock 3.4.0", "async-trait", "crossbeam-channel", @@ -3055,11 +3054,6 @@ dependencies = [ "crossbeam-utils", "event-listener 5.3.1", "futures-util", -======= - "crossbeam-channel", - "crossbeam-epoch", - "crossbeam-utils", ->>>>>>> master "once_cell", "parking_lot", "quanta 0.12.2", From 41d3db30382663ac71e3dcc11a4ef789d68b0b56 Mon Sep 17 00:00:00 2001 From: dylan Date: Thu, 7 Nov 2024 12:43:36 -0800 Subject: [PATCH 29/30] not this either --- posthog/api/test/test_decide.py | 131 -------------------------------- 1 file changed, 131 deletions(-) diff --git a/posthog/api/test/test_decide.py b/posthog/api/test/test_decide.py index a753d5778adc1..7532c3405d27a 100644 --- a/posthog/api/test/test_decide.py +++ b/posthog/api/test/test_decide.py @@ -925,75 +925,6 @@ def test_feature_flags_v2(self, *args): "third-variant", response.json()["featureFlags"]["multivariate-flag"] ) # different hash, different variant assigned - - # @patch( - # "posthog.models.feature_flag.flag_matching.postgres_healthcheck.is_connected", - # return_value=True, - # ) - def test_feature_flags_v3_with_groups(self, mock_is_connected): - self.team.app_urls = ["https://example.com"] - self.team.save() - self.client.logout() - - FeatureFlag.objects.create( - team=self.team, - key="groups-flag", - name="Groups flag", - created_by=self.user, - filters={ - "filters": { - "aggregation_group_type_index": 0, - "groups": [ - { - "properties": [ - { - "key": "$group_key", - "type": "group", - "value": "13", - "operator": "exact", - "group_type_index": 0 - } - ], - "rollout_percentage": 100, - } - ], - }, - "name": "This is a group-based flag", - "key": "groups-flag", - }, - ) - - GroupTypeMapping.objects.db_manager().create( - team=self.team, project_id=self.team.project_id, group_type="tenant", group_type_index=0 - ) - - Group.objects.db_manager().create( - team_id=self.team.pk, - group_type_index=0, - group_key="13", - group_properties={"tenant_name": "Tacit", "tenant_type": "Organization"}, - version=2, - ) - - with self.assertNumQueries(9), self.assertNumQueries(0, using="default"): - # E 1. SET LOCAL statement_timeout = 300 - # E 2. SELECT "posthog_grouptypemapping"."id", "posthog_grouptypemapping"."team_id", -- a.k.a get group type mappings - - # E 3. SET LOCAL statement_timeout = 600 - # E 4. SELECT (UPPER(("posthog_group"."group_properties" ->> 'email')::text) AS "flag_182_condition_0" FROM "posthog_group" -- a.k.a get group0 conditions - # E 5. SELECT (true) AS "flag_181_condition_0" FROM "posthog_group" WHERE ("posthog_group"."team_id" = 91 -- a.k.a get group1 conditions - response = self._post_decide( - distinct_id="example_id", - groups={"tenant": "13"}, - api_version=3, - ) - print(response.json()) - # self.assertFalse( - # response.json()["featureFlags"]["groups-flag"], - # ) - self.assertFalse(response.json()["errorsWhileComputingFlags"]) - - def test_feature_flags_v2_with_property_overrides(self, *args): self.team.app_urls = ["https://example.com"] self.team.save() @@ -4622,68 +4553,6 @@ def test_feature_flags_v3_consistent_flags_with_write_on_hash_key_overrides(self "first-variant", response.json()["featureFlags"]["multivariate-flag"] ) # assigned by distinct_id hash - @patch( - "posthog.models.feature_flag.flag_matching.postgres_healthcheck.is_connected", - return_value=True, - ) - def test_feature_flags_v3_with_groups(self, mock_is_connected): - org, team, user = self.setup_user_and_team_in_db("replica") - self.organization, self.team, self.user = org, team, user - - flags = [ - { - "filters": { - "aggregation_group_type_index": 0, - "groups": [ - { - "properties": [ - { - "key": "$group_key", - "type": "group", - "value": "13", - "operator": "exact", - "group_type_index": 0 - } - ], - "rollout_percentage": 100, - } - ], - }, - "name": "This is a group-based flag", - "key": "groups-flag", - }, - ] - self.setup_flags_in_db("replica", team, user, flags, persons) - - GroupTypeMapping.objects.db_manager("replica").create( - team=self.team, project_id=self.team.project_id, group_type="tenant", group_type_index=0 - ) - - Group.objects.db_manager("replica").create( - team_id=self.team.pk, - group_type_index=0, - group_key="13", - group_properties={"tenant_name": "Tacit", "tenant_type": "Organization"}, - version=2, - ) - - with self.assertNumQueries(9, using="replica"), self.assertNumQueries(0, using="default"): - # E 1. SET LOCAL statement_timeout = 300 - # E 2. SELECT "posthog_grouptypemapping"."id", "posthog_grouptypemapping"."team_id", -- a.k.a get group type mappings - - # E 3. SET LOCAL statement_timeout = 600 - # E 4. SELECT (UPPER(("posthog_group"."group_properties" ->> 'email')::text) AS "flag_182_condition_0" FROM "posthog_group" -- a.k.a get group0 conditions - # E 5. SELECT (true) AS "flag_181_condition_0" FROM "posthog_group" WHERE ("posthog_group"."team_id" = 91 -- a.k.a get group1 conditions - response = self._post_decide( - distinct_id="example_id", - groups={"tenant": "13"}, - ) - self.assertEqual( - response.json()["featureFlags"], - {"groups-flag": False}, - ) - self.assertFalse(response.json()["errorsWhileComputingFlags"]) - @patch( "posthog.models.feature_flag.flag_matching.postgres_healthcheck.is_connected", return_value=True, From 0a409f41c5495f300020dec1be82a9176f23b8fc Mon Sep 17 00:00:00 2001 From: dylan Date: Thu, 7 Nov 2024 12:56:01 -0800 Subject: [PATCH 30/30] docs --- rust/feature-flags/src/flag_matching.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/rust/feature-flags/src/flag_matching.rs b/rust/feature-flags/src/flag_matching.rs index 38f2eb83ef979..571fe9c84b40a 100644 --- a/rust/feature-flags/src/flag_matching.rs +++ b/rust/feature-flags/src/flag_matching.rs @@ -1212,6 +1212,7 @@ fn apply_cohort_membership_logic( /// - C depends on A and B /// - D depends on B /// - E depends on C and D +/// /// The graph is acyclic, which is required for valid cohort dependencies. fn build_cohort_dependency_graph( initial_cohort_id: CohortId,