From 4c49bc46c5faf6534b758b4f665b64fc1e633952 Mon Sep 17 00:00:00 2001 From: dylan Date: Tue, 29 Oct 2024 21:54:54 -0700 Subject: [PATCH] new life --- rust/Cargo.lock | 1 + rust/feature-flags/Cargo.toml | 1 + rust/feature-flags/src/cohort_cache.rs | 279 ++++++++++---------- rust/feature-flags/src/cohort_operations.rs | 142 +++++----- rust/feature-flags/src/flag_definitions.rs | 20 +- rust/feature-flags/src/flag_matching.rs | 159 ++++------- 6 files changed, 283 insertions(+), 319 deletions(-) diff --git a/rust/Cargo.lock b/rust/Cargo.lock index 3f3c36157e36b..5bb5fc25b318d 100644 --- a/rust/Cargo.lock +++ b/rust/Cargo.lock @@ -1316,6 +1316,7 @@ dependencies = [ "health", "maxminddb", "once_cell", + "petgraph", "rand", "redis", "regex", diff --git a/rust/feature-flags/Cargo.toml b/rust/feature-flags/Cargo.toml index 4cf4016767be6..9847569394bb2 100644 --- a/rust/feature-flags/Cargo.toml +++ b/rust/feature-flags/Cargo.toml @@ -39,6 +39,7 @@ health = { path = "../common/health" } common-metrics = { path = "../common/metrics" } tower = { workspace = true } derive_builder = "0.20.1" +petgraph = "0.6.5" [lints] workspace = true diff --git a/rust/feature-flags/src/cohort_cache.rs b/rust/feature-flags/src/cohort_cache.rs index a0e3de8782533..2991a73988c26 100644 --- a/rust/feature-flags/src/cohort_cache.rs +++ b/rust/feature-flags/src/cohort_cache.rs @@ -1,184 +1,191 @@ use crate::api::FlagError; -use crate::cohort_models::{Cohort, CohortId}; -use crate::cohort_operations::sort_cohorts_topologically; -use crate::flag_definitions::{OperatorType, PropertyFilter}; -use crate::flag_matching::{PostgresReader, TeamId}; +use crate::cohort_models::{Cohort, CohortId, CohortProperty}; +use crate::flag_definitions::PropertyFilter; +use crate::flag_matching::PostgresReader; +use petgraph::algo::toposort; +use petgraph::graphmap::DiGraphMap; use std::collections::{HashMap, HashSet}; use std::sync::Arc; use tokio::sync::RwLock; -pub type TeamCohortMap = HashMap>; -pub type TeamSortedCohorts = HashMap>; -pub type TeamCacheMap = HashMap; +// Flattened Cohort Map: CohortId -> Combined PropertyFilters +pub type FlattenedCohortMap = HashMap>; -#[derive(Debug, Clone)] -pub struct CachedCohort { - // TODO name this something different - pub filters: Vec, // Non-cohort property filters - pub dependencies: Vec, // Dependencies with operators -} - -// Add this struct to facilitate handling operators on cohort filters -#[derive(Debug, Clone)] -pub struct CohortDependencyFilter { - pub cohort_id: CohortId, - pub operator: OperatorType, -} - -/// Threadsafety is ensured using Arc and RwLock -#[derive(Clone, Debug)] +/// CohortCache manages the in-memory cache of flattened cohorts +#[derive(Clone)] pub struct CohortCache { - /// Mapping from TeamId to their respective CohortId and associated CachedCohort - per_team: Arc>>>, - /// Mapping from TeamId to sorted CohortIds based on dependencies - sorted_cohorts: Arc>>>, + pub per_team_flattened: Arc>>, // team_id -> (cohort_id -> filters) } impl CohortCache { /// Creates a new CohortCache instance pub fn new() -> Self { Self { - per_team: Arc::new(RwLock::new(HashMap::new())), - sorted_cohorts: Arc::new(RwLock::new(HashMap::new())), + per_team_flattened: Arc::new(RwLock::new(HashMap::new())), } } - /// Fetches, flattens, sorts, and caches all cohorts for a given team if not already cached - pub async fn fetch_and_cache_cohorts( + /// Asynchronous constructor that initializes the CohortCache by fetching and caching cohorts for the given team_id + pub async fn new_with_team( + team_id: i32, + postgres_reader: PostgresReader, + ) -> Result { + let cache = Self { + per_team_flattened: Arc::new(RwLock::new(HashMap::new())), + }; + cache + .fetch_and_cache_all_cohorts(team_id, postgres_reader) + .await?; + Ok(cache) + } + + /// Fetches, parses, and caches all cohorts for a given team + async fn fetch_and_cache_all_cohorts( &self, - team_id: TeamId, + team_id: i32, postgres_reader: PostgresReader, ) -> Result<(), FlagError> { - // Acquire write locks to modify the cache - let mut cache = self.per_team.write().await; - let mut sorted = self.sorted_cohorts.write().await; + // Fetch all cohorts for the team + let cohorts = Cohort::list_from_pg(postgres_reader, team_id).await?; - // Check if the team's cohorts are already cached - if cache.contains_key(&team_id) && sorted.contains_key(&team_id) { - return Ok(()); + // Build a mapping from cohort_id to Cohort + let mut cohort_map: HashMap = HashMap::new(); + for cohort in cohorts { + cohort_map.insert(cohort.id, cohort); } - // Fetch all cohorts for the team from the database - let all_cohorts = fetch_all_cohorts(team_id, postgres_reader).await?; + // Build dependency graph + let dependency_graph = Self::build_dependency_graph(&cohort_map)?; - // Flatten the property filters, resolving dependencies - let flattened = flatten_cohorts(&all_cohorts).await?; + // Perform topological sort + let sorted_cohorts = toposort(&dependency_graph, None).map_err(|_| { + FlagError::CohortDependencyCycle("Cycle detected in cohort dependencies".to_string()) + })?; - // Extract all cohort IDs - let cohort_ids: HashSet = flattened.keys().cloned().collect(); + // Reverse to process dependencies first + let sorted_cohorts: Vec = sorted_cohorts.into_iter().rev().collect(); - // Sort the cohorts topologically based on dependencies - let sorted_ids = sort_cohorts_topologically(cohort_ids, &flattened)?; + // Flatten cohorts + let flattened = Self::flatten_cohorts(&sorted_cohorts, &cohort_map)?; - // Insert the flattened cohorts and their sorted order into the cache + // Cache the flattened cohort filters + let mut cache = self.per_team_flattened.write().await; cache.insert(team_id, flattened); - sorted.insert(team_id, sorted_ids); Ok(()) } - /// Retrieves sorted cohort IDs for a team from the cache - pub async fn get_sorted_cohort_ids( + /// Retrieves flattened filters for a given team and cohort + pub async fn get_flattened_filters( &self, - team_id: TeamId, - postgres_reader: PostgresReader, - ) -> Result, FlagError> { - { - // Acquire read locks to check the cache - let cache = self.per_team.read().await; - let sorted = self.sorted_cohorts.read().await; - if let (Some(_cohort_map), Some(sorted_ids)) = - (cache.get(&team_id), sorted.get(&team_id)) - { - if !sorted_ids.is_empty() { - return Ok(sorted_ids.clone()); - } + team_id: i32, + cohort_id: CohortId, + ) -> Result, FlagError> { + let cache = self.per_team_flattened.read().await; + if let Some(team_map) = cache.get(&team_id) { + if let Some(filters) = team_map.get(&cohort_id) { + Ok(filters.clone()) + } else { + Err(FlagError::CohortNotFound(cohort_id.to_string())) } + } else { + Err(FlagError::CohortNotFound(cohort_id.to_string())) } + } - // If not cached, fetch and cache the cohorts - self.fetch_and_cache_cohorts(team_id, postgres_reader) - .await?; + /// Builds a dependency graph where an edge from A to B means A depends on B + fn build_dependency_graph( + cohort_map: &HashMap, + ) -> Result, FlagError> { + let mut graph = DiGraphMap::new(); - // Acquire read locks to retrieve the sorted list after caching - let sorted = self.sorted_cohorts.read().await; - if let Some(sorted_ids) = sorted.get(&team_id) { - Ok(sorted_ids.clone()) - } else { - Ok(Vec::new()) + // Add all cohorts as nodes + for &cohort_id in cohort_map.keys() { + graph.add_node(cohort_id); } + + // Add edges based on dependencies + for (&cohort_id, cohort) in cohort_map.iter() { + let dependencies = Self::extract_dependencies(cohort.filters.clone())?; + for dep_id in dependencies { + if !cohort_map.contains_key(&dep_id) { + return Err(FlagError::CohortNotFound(dep_id.to_string())); + } + graph.add_edge(cohort_id, dep_id, ()); // A depends on B: A -> B + } + } + + Ok(graph) } - /// Retrieves cached cohorts for a team - pub async fn get_cached_cohorts( - &self, - team_id: TeamId, - ) -> Result, FlagError> { - let cache = self.per_team.read().await; - if let Some(cohort_map) = cache.get(&team_id) { - Ok(cohort_map.clone()) - } else { - Ok(HashMap::new()) + /// Extracts all dependent CohortIds from the filters + fn extract_dependencies( + filters_as_json: serde_json::Value, + ) -> Result, FlagError> { + let filters: CohortProperty = serde_json::from_value(filters_as_json)?; + let mut dependencies = HashSet::new(); + Self::traverse_filters(&filters.properties, &mut dependencies)?; + Ok(dependencies) + } + + /// Recursively traverses the filter tree to find cohort dependencies + fn traverse_filters( + inner: &crate::cohort_models::InnerCohortProperty, + dependencies: &mut HashSet, + ) -> Result<(), FlagError> { + for cohort_values in &inner.values { + for filter in &cohort_values.values { + if filter.prop_type == "cohort" && filter.key == "id" { + // Assuming the value is a single integer CohortId + if let Some(cohort_id) = filter.value.as_i64() { + dependencies.insert(cohort_id as CohortId); + } else { + return Err(FlagError::CohortFiltersParsingError); // TODO more data here? + } + } + // Handle nested properties if necessary + // If the filter can contain nested properties with more conditions, traverse them here + } } + Ok(()) } -} -async fn fetch_all_cohorts( - team_id: TeamId, - postgres_reader: PostgresReader, -) -> Result, FlagError> { - let mut conn = postgres_reader.get_connection().await?; - - let query = r#" - SELECT * - FROM posthog_cohort - WHERE team_id = $1 AND deleted = FALSE - "#; - - let cohorts: Vec = sqlx::query_as::<_, Cohort>(query) - .bind(team_id) - .fetch_all(&mut *conn) - .await - .map_err(|e| FlagError::DatabaseError(e.to_string()))?; - - Ok(cohorts) -} + /// Flattens the filters based on sorted cohorts, including only property filters + fn flatten_cohorts( + sorted_cohorts: &[CohortId], + cohort_map: &HashMap, + ) -> Result { + let mut flattened: FlattenedCohortMap = HashMap::new(); -async fn flatten_cohorts( - all_cohorts: &Vec, -) -> Result, FlagError> { - let mut flattened = HashMap::new(); - - for cohort in all_cohorts { - let filters = cohort.parse_filters()?; - - // Extract dependencies from cohort filters - let dependencies = filters - .iter() - .filter_map(|f| { - if f.prop_type == "cohort" { - Some(CohortDependencyFilter { - cohort_id: f.value.as_i64().unwrap() as CohortId, - operator: f.operator.clone().unwrap_or(OperatorType::In), - }) + for &cohort_id in sorted_cohorts { + let cohort = cohort_map + .get(&cohort_id) + .ok_or(FlagError::CohortNotFound(cohort_id.to_string()))?; + + // Use the updated parse_property_filters method + let property_filters = cohort.parse_filters()?; + + // Extract dependencies using Cohort's method + let dependencies = cohort.extract_dependencies()?; + + let mut combined_filters = Vec::new(); + + // Include filters from dependencies + for dep_id in &dependencies { + if let Some(dep_filters) = flattened.get(dep_id) { + combined_filters.extend(dep_filters.clone()); } else { - None + return Err(FlagError::CohortNotFound(dep_id.to_string())); } - }) - .collect(); - - // Filter out cohort filters as they are now represented as dependencies - let non_cohort_filters: Vec = filters - .into_iter() - .filter(|f| f.prop_type != "cohort") - .collect(); - let cached_cohort = CachedCohort { - filters: non_cohort_filters, - dependencies, - }; + } - flattened.insert(cohort.id, cached_cohort); - } + // Include own filters + combined_filters.extend(property_filters); - Ok(flattened) + // Insert into flattened map + flattened.insert(cohort_id, combined_filters); + } + + Ok(flattened) + } } diff --git a/rust/feature-flags/src/cohort_operations.rs b/rust/feature-flags/src/cohort_operations.rs index b11fd1ca78c83..eb25c91dca3be 100644 --- a/rust/feature-flags/src/cohort_operations.rs +++ b/rust/feature-flags/src/cohort_operations.rs @@ -1,8 +1,7 @@ -use std::collections::{HashMap, HashSet}; +use std::collections::HashSet; use std::sync::Arc; use tracing::instrument; -use crate::cohort_cache::CachedCohort; use crate::cohort_models::{Cohort, CohortId, CohortProperty, InnerCohortProperty}; use crate::{api::FlagError, database::Client as DatabaseClient, flag_definitions::PropertyFilter}; @@ -40,6 +39,29 @@ impl Cohort { }) } + #[instrument(skip_all)] + pub async fn list_from_pg( + client: Arc, + team_id: i32, + ) -> Result, FlagError> { + let mut conn = client.get_connection().await.map_err(|e| { + tracing::error!("Failed to get database connection: {}", e); + FlagError::DatabaseUnavailable + })?; + + let query = "SELECT id, name, description, team_id, deleted, filters, query, version, pending_version, count, is_calculating, is_static, errors_calculating, groups, created_by_id FROM posthog_cohort WHERE team_id = $1"; + let cohorts = sqlx::query_as::<_, Cohort>(query) + .bind(team_id) + .fetch_all(&mut *conn) + .await + .map_err(|e| { + tracing::error!("Failed to fetch cohorts from database: {}", e); + FlagError::Internal(format!("Database query error: {}", e)) + })?; + + Ok(cohorts) + } + /// Parses the filters JSON into a CohortProperty structure // TODO: this doesn't handle the deprecated "groups" field, see // https://github.com/PostHog/posthog/blob/feat/dynamic-cohorts-rust/posthog/models/cohort/cohort.py#L114-L169 @@ -50,95 +72,59 @@ impl Cohort { tracing::error!("Failed to parse filters for cohort {}: {}", self.id, e); FlagError::CohortFiltersParsingError })?; - Ok(cohort_property.properties.to_property_filters()) - } -} -impl InnerCohortProperty { - pub fn to_property_filters(&self) -> Vec { - self.values - .iter() - .flat_map(|value| &value.values) - .cloned() - .collect() + // Filter out cohort filters + Ok(cohort_property + .properties + .to_property_filters() + .into_iter() + .filter(|f| !(f.key == "id" && f.prop_type == "cohort")) + .collect()) } -} -/// Sorts the given cohorts in an order where cohorts with no dependencies are placed first, -/// followed by cohorts that depend on the preceding ones. It ensures that each cohort in the sorted list -/// only depends on cohorts that appear earlier in the list. -pub fn sort_cohorts_topologically( - cohort_ids: HashSet, - cached_cohorts: &HashMap, -) -> Result, FlagError> { - if cohort_ids.is_empty() { - return Ok(Vec::new()); - } + /// Extracts dependent CohortIds from the cohort's filters + pub fn extract_dependencies(&self) -> Result, FlagError> { + let cohort_property: CohortProperty = serde_json::from_value(self.filters.clone()) + .map_err(|e| { + tracing::error!("Failed to parse filters for cohort {}: {}", self.id, e); + FlagError::CohortFiltersParsingError + })?; - let mut dependency_graph: HashMap> = HashMap::new(); - for &cohort_id in &cohort_ids { - if let Some(cohort) = cached_cohorts.get(&cohort_id) { - for dependency in &cohort.dependencies { - dependency_graph - .entry(cohort_id) - .or_default() - .push(dependency.cohort_id); - } - } + let mut dependencies = HashSet::new(); + Self::traverse_filters(&cohort_property.properties, &mut dependencies)?; + Ok(dependencies) } - let mut sorted = Vec::new(); - let mut temporary_marks = HashSet::new(); - let mut permanent_marks = HashSet::new(); - - fn visit( - node: CohortId, - dependency_graph: &HashMap>, - temporary_marks: &mut HashSet, - permanent_marks: &mut HashSet, - sorted: &mut Vec, + /// Recursively traverses the filter tree to find cohort dependencies + fn traverse_filters( + inner: &InnerCohortProperty, + dependencies: &mut HashSet, ) -> Result<(), FlagError> { - if permanent_marks.contains(&node) { - return Ok(()); - } - if temporary_marks.contains(&node) { - return Err(FlagError::CohortDependencyCycle(format!( - "Cycle detected at cohort {}", - node - ))); - } - - temporary_marks.insert(node); - if let Some(neighbors) = dependency_graph.get(&node) { - for &neighbor in neighbors { - visit( - neighbor, - dependency_graph, - temporary_marks, - permanent_marks, - sorted, - )?; + for cohort_values in &inner.values { + for filter in &cohort_values.values { + if filter.prop_type == "cohort" && filter.key == "id" { + // Assuming the value is a single integer CohortId + if let Some(cohort_id) = filter.value.as_i64() { + dependencies.insert(cohort_id as CohortId); + } else { + return Err(FlagError::CohortFiltersParsingError); + } + } + // Handle nested properties if necessary } } - temporary_marks.remove(&node); - permanent_marks.insert(node); - sorted.push(node); Ok(()) } +} - for &node in &cohort_ids { - if !permanent_marks.contains(&node) { - visit( - node, - &dependency_graph, - &mut temporary_marks, - &mut permanent_marks, - &mut sorted, - )?; - } +impl InnerCohortProperty { + pub fn to_property_filters(&self) -> Vec { + self.values + .iter() + .flat_map(|value| &value.values) + .cloned() + .collect() } - - Ok(sorted) } #[cfg(test)] diff --git a/rust/feature-flags/src/flag_definitions.rs b/rust/feature-flags/src/flag_definitions.rs index 4cf3e88e2c16c..329d0321e77be 100644 --- a/rust/feature-flags/src/flag_definitions.rs +++ b/rust/feature-flags/src/flag_definitions.rs @@ -1,4 +1,7 @@ -use crate::{api::FlagError, database::Client as DatabaseClient, redis::Client as RedisClient}; +use crate::{ + api::FlagError, cohort_models::CohortId, database::Client as DatabaseClient, + redis::Client as RedisClient, +}; use serde::{Deserialize, Serialize}; use std::sync::Arc; use tracing::instrument; @@ -43,6 +46,21 @@ pub struct PropertyFilter { pub group_type_index: Option, } +impl PropertyFilter { + /// Checks if the filter is a cohort filter + pub fn is_cohort(&self) -> bool { + self.key == "id" && self.prop_type == "cohort" + } + + /// Returns the cohort id if the filter is a cohort filter + pub fn get_cohort_id(&self) -> Result { + self.value + .as_i64() + .map(|id| id as CohortId) + .ok_or(FlagError::CohortFiltersParsingError) + } +} + #[derive(Debug, Clone, Deserialize, Serialize)] pub struct FlagGroupType { pub properties: Option>, diff --git a/rust/feature-flags/src/flag_matching.rs b/rust/feature-flags/src/flag_matching.rs index 2384097b26880..7abddb24080c3 100644 --- a/rust/feature-flags/src/flag_matching.rs +++ b/rust/feature-flags/src/flag_matching.rs @@ -1,6 +1,6 @@ use crate::{ api::{FlagError, FlagValue, FlagsResponse}, - cohort_cache::{CachedCohort, CohortCache, CohortDependencyFilter}, + cohort_cache::CohortCache, cohort_models::CohortId, database::Client as DatabaseClient, feature_flag_match_reason::FeatureFlagMatchReason, @@ -187,7 +187,6 @@ pub struct FeatureFlagMatcher { group_type_mapping_cache: GroupTypeMappingCache, properties_cache: PropertiesCache, groups: HashMap, - cohort_cache: CohortCache, } const LONG_SCALE: u64 = 0xfffffffffffffff; @@ -211,7 +210,6 @@ impl FeatureFlagMatcher { group_type_mapping_cache: group_type_mapping_cache .unwrap_or_else(|| GroupTypeMappingCache::new(team_id, postgres_reader.clone())), properties_cache: properties_cache.unwrap_or_default(), - cohort_cache: CohortCache::new(), } } @@ -830,144 +828,97 @@ impl FeatureFlagMatcher { /// Evaluates cohort-based property filters pub async fn evaluate_cohort_filters( &self, - cohort_filters: &[PropertyFilter], + cohort_and_property_filters: &[PropertyFilter], target_properties: &HashMap, ) -> Result { - // Step 1: Extract the cohort filters into a list of CohortDependencyFilters - // We will need these because the are the filters associated with the flag itself, and required for the final - // step of evaluating the cohort membership logic - let filter_cohorts = self.extract_cohort_filters(cohort_filters)?; - - // Get all of the cohort IDs associated with the team and sort them topologically by dependency relationships - // We will need to evaluate the cohort membership logic for each of these cohorts in order to evaluate the cohort filters - // themselves. - // TODO can the sorted cohort IDs be aware of IDs in the filter_cohorts list? Or somehow let us know which ones - // we need to evaluate? - let sorted_cohort_ids = self - .cohort_cache - .get_sorted_cohort_ids(self.team_id, self.postgres_reader.clone()) - .await?; + let cohort_cache = + CohortCache::new_with_team(self.team_id, self.postgres_reader.clone()).await?; - // Step 6: Retrieve cached and flattened cohorts associated with this team - let cached_cohorts = self.cohort_cache.get_cached_cohorts(self.team_id).await?; + let (cohort_filters, property_filters) = cohort_and_property_filters + .iter() + .partition::, _>(|f| f.is_cohort()); - // Step 7: Evaluate cohort dependencies and property filters - let cohort_matches = self.evaluate_cohorts_and_associated_dependencies( - &sorted_cohort_ids, - &cached_cohorts, - target_properties, - )?; + // Early exit if any of property filters fail to match + if !self + .evaluate_property_filters(&property_filters, target_properties) + .await? + { + return Ok(false); + } - // Step 8: Apply any cohort operator logic to determine final matching - self.apply_cohort_membership_logic(&filter_cohorts, &cohort_matches) + // Evaluate cohort filters + let cohort_matches = self + .evaluate_cohort_dependencies(&cohort_filters, &cohort_cache, target_properties) + .await?; + + // Apply cohort membership logic + self.apply_cohort_membership_logic(&cohort_filters, &cohort_matches) } - fn extract_cohort_filters( + /// Evaluates property filters against target properties + async fn evaluate_property_filters( &self, - cohort_filters: &[PropertyFilter], - ) -> Result, FlagError> { - let filter_cohorts = cohort_filters - .iter() - .filter_map(|f| { - if f.key == "id" && f.prop_type == "cohort" { - let cohort_id = f.value.as_i64().map(|id| id as CohortId)?; // TODO handle error? - let operator = f.operator.clone().unwrap_or(OperatorType::In); - Some(CohortDependencyFilter { - cohort_id, - operator, - }) - } else { - None - } - }) - .collect(); - - Ok(filter_cohorts) + property_filters: &[&PropertyFilter], + target_properties: &HashMap, + ) -> Result { + for filter in property_filters { + if !match_property(filter, target_properties, false).unwrap_or(false) { + return Ok(false); + } + } + Ok(true) } - fn evaluate_cohorts_and_associated_dependencies( + /// Evaluates cohort dependencies using the cache + async fn evaluate_cohort_dependencies( &self, - sorted_cohort_ids: &[CohortId], - cached_cohorts: &HashMap, + cohort_filters: &[&PropertyFilter], + cohort_cache: &CohortCache, target_properties: &HashMap, ) -> Result, FlagError> { let mut cohort_matches = HashMap::new(); - for &cohort_id in sorted_cohort_ids { - let cached_cohort: &CachedCohort = match cached_cohorts.get(&cohort_id) { - Some(cohort) => cohort, - None => { - return Err(FlagError::CohortNotFound(format!( - "Cohort ID {} not found in cache", - cohort_id - ))); - } - }; - - // Evaluate dependent cohorts membership - if !self.dependencies_satisfied(&cached_cohort.dependencies, &cohort_matches)? { - cohort_matches.insert(cohort_id, false); - continue; - } + for filter in cohort_filters { + let cohort_id = filter.get_cohort_id()?; + let filters_to_evaluate = cohort_cache + .get_flattened_filters(self.team_id, cohort_id) + .await?; - // Evaluate property filters associated with the cohort - let is_match = cached_cohort.filters.iter().all(|prop_filter| { - match_property(prop_filter, target_properties, false).unwrap_or(false) - }); + let all_filters_match = filters_to_evaluate + .iter() + .all(|filter| match_property(filter, target_properties, false).unwrap_or(false)); - cohort_matches.insert(cohort_id, is_match); + cohort_matches.insert(cohort_id, all_filters_match); } Ok(cohort_matches) } - fn dependencies_satisfied( - &self, - dependencies: &[CohortDependencyFilter], - cohort_matches: &HashMap, - ) -> Result { - for dependency in dependencies { - match cohort_matches.get(&dependency.cohort_id) { - Some(true) => match dependency.operator { - OperatorType::In => continue, - OperatorType::NotIn => return Ok(false), - _ => {} - }, - Some(false) => match dependency.operator { - OperatorType::In => return Ok(false), - OperatorType::NotIn => continue, - _ => {} - }, - None => return Ok(false), - } - } - Ok(true) - } - + /// Applies cohort membership logic based on operators fn apply_cohort_membership_logic( &self, - filter_cohorts: &[CohortDependencyFilter], + cohort_filters: &[&PropertyFilter], cohort_matches: &HashMap, ) -> Result { - for filter_cohort in filter_cohorts { - let cohort_match = cohort_matches - .get(&filter_cohort.cohort_id) - .copied() - .unwrap_or(false); + for filter in cohort_filters { + let cohort_id = filter.get_cohort_id()?; + let matches = cohort_matches.get(&cohort_id).copied().unwrap_or(false); + let operator = filter.operator.clone().unwrap_or(OperatorType::In); - if !self.cohort_membership_operator(filter_cohort.operator, cohort_match) { + if !self.cohort_membership_operator(operator, matches) { return Ok(false); } } Ok(true) } + /// Determines the final match based on the operator and match status fn cohort_membership_operator(&self, operator: OperatorType, match_status: bool) -> bool { match operator { OperatorType::In => match_status, OperatorType::NotIn => !match_status, - // TODO, there shouldn't be any other operators here since this is only called from evaluate_cohort_dependencies - _ => false, // Handle other operators as needed + // Extend with other operators as needed + _ => false, } }