-
Notifications
You must be signed in to change notification settings - Fork 1.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
6 changed files
with
283 additions
and
319 deletions.
There are no files selected for viewing
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,184 +1,191 @@ | ||
use crate::api::FlagError; | ||
use crate::cohort_models::{Cohort, CohortId}; | ||
use crate::cohort_operations::sort_cohorts_topologically; | ||
use crate::flag_definitions::{OperatorType, PropertyFilter}; | ||
use crate::flag_matching::{PostgresReader, TeamId}; | ||
use crate::cohort_models::{Cohort, CohortId, CohortProperty}; | ||
use crate::flag_definitions::PropertyFilter; | ||
use crate::flag_matching::PostgresReader; | ||
use petgraph::algo::toposort; | ||
use petgraph::graphmap::DiGraphMap; | ||
use std::collections::{HashMap, HashSet}; | ||
use std::sync::Arc; | ||
use tokio::sync::RwLock; | ||
|
||
pub type TeamCohortMap = HashMap<CohortId, Vec<PropertyFilter>>; | ||
pub type TeamSortedCohorts = HashMap<TeamId, Vec<CohortId>>; | ||
pub type TeamCacheMap = HashMap<TeamId, TeamCohortMap>; | ||
// Flattened Cohort Map: CohortId -> Combined PropertyFilters | ||
pub type FlattenedCohortMap = HashMap<CohortId, Vec<PropertyFilter>>; | ||
|
||
/// Cached representation of a single cohort, split into its plain property | ||
/// filters and its references to other cohorts. | ||
#[derive(Debug, Clone)] | ||
pub struct CachedCohort { | ||
// TODO name this something different | ||
pub filters: Vec<PropertyFilter>, // Non-cohort property filters | ||
pub dependencies: Vec<CohortDependencyFilter>, // Dependencies with operators | ||
} | ||
|
||
/// A reference to another cohort together with the operator attached to that | ||
/// reference, so operators on cohort filters can be handled explicitly. | ||
#[derive(Debug, Clone)] | ||
pub struct CohortDependencyFilter { | ||
pub cohort_id: CohortId, | ||
pub operator: OperatorType, | ||
} | ||
|
||
/// Threadsafety is ensured using Arc and RwLock | ||
#[derive(Clone, Debug)] | ||
/// CohortCache manages the in-memory cache of flattened cohorts | ||
#[derive(Clone)] | ||
pub struct CohortCache { | ||
/// Mapping from TeamId to their respective CohortId and associated CachedCohort | ||
per_team: Arc<RwLock<HashMap<TeamId, HashMap<CohortId, CachedCohort>>>>, | ||
/// Mapping from TeamId to sorted CohortIds based on dependencies | ||
sorted_cohorts: Arc<RwLock<HashMap<TeamId, Vec<CohortId>>>>, | ||
pub per_team_flattened: Arc<RwLock<HashMap<i32, FlattenedCohortMap>>>, // team_id -> (cohort_id -> filters) | ||
} | ||
|
||
impl CohortCache { | ||
/// Creates a new CohortCache instance | ||
pub fn new() -> Self { | ||
Self { | ||
per_team: Arc::new(RwLock::new(HashMap::new())), | ||
sorted_cohorts: Arc::new(RwLock::new(HashMap::new())), | ||
per_team_flattened: Arc::new(RwLock::new(HashMap::new())), | ||
} | ||
} | ||
|
||
/// Fetches, flattens, sorts, and caches all cohorts for a given team if not already cached | ||
pub async fn fetch_and_cache_cohorts( | ||
/// Asynchronous constructor that initializes the CohortCache by fetching and caching cohorts for the given team_id | ||
/// | ||
/// Starts from an empty `per_team_flattened` map and then awaits | ||
/// `fetch_and_cache_all_cohorts(team_id, postgres_reader)`. | ||
/// | ||
/// # Errors | ||
/// | ||
/// Any `FlagError` returned by `fetch_and_cache_all_cohorts` is propagated | ||
/// with `?`, in which case no cache is returned. | ||
pub async fn new_with_team( | ||
team_id: i32, | ||
postgres_reader: PostgresReader, | ||
) -> Result<Self, FlagError> { | ||
let cache = Self { | ||
per_team_flattened: Arc::new(RwLock::new(HashMap::new())), | ||
}; | ||
cache | ||
.fetch_and_cache_all_cohorts(team_id, postgres_reader) | ||
.await?; | ||
Ok(cache) | ||
} | ||
|
||
/// Fetches, parses, and caches all cohorts for a given team | ||
async fn fetch_and_cache_all_cohorts( | ||
&self, | ||
team_id: TeamId, | ||
team_id: i32, | ||
postgres_reader: PostgresReader, | ||
) -> Result<(), FlagError> { | ||
// Acquire write locks to modify the cache | ||
let mut cache = self.per_team.write().await; | ||
let mut sorted = self.sorted_cohorts.write().await; | ||
// Fetch all cohorts for the team | ||
let cohorts = Cohort::list_from_pg(postgres_reader, team_id).await?; | ||
|
||
// Check if the team's cohorts are already cached | ||
if cache.contains_key(&team_id) && sorted.contains_key(&team_id) { | ||
return Ok(()); | ||
// Build a mapping from cohort_id to Cohort | ||
let mut cohort_map: HashMap<CohortId, Cohort> = HashMap::new(); | ||
for cohort in cohorts { | ||
cohort_map.insert(cohort.id, cohort); | ||
} | ||
|
||
// Fetch all cohorts for the team from the database | ||
let all_cohorts = fetch_all_cohorts(team_id, postgres_reader).await?; | ||
// Build dependency graph | ||
let dependency_graph = Self::build_dependency_graph(&cohort_map)?; | ||
|
||
// Flatten the property filters, resolving dependencies | ||
let flattened = flatten_cohorts(&all_cohorts).await?; | ||
// Perform topological sort | ||
let sorted_cohorts = toposort(&dependency_graph, None).map_err(|_| { | ||
FlagError::CohortDependencyCycle("Cycle detected in cohort dependencies".to_string()) | ||
})?; | ||
|
||
// Extract all cohort IDs | ||
let cohort_ids: HashSet<CohortId> = flattened.keys().cloned().collect(); | ||
// Reverse to process dependencies first | ||
let sorted_cohorts: Vec<CohortId> = sorted_cohorts.into_iter().rev().collect(); | ||
|
||
// Sort the cohorts topologically based on dependencies | ||
let sorted_ids = sort_cohorts_topologically(cohort_ids, &flattened)?; | ||
// Flatten cohorts | ||
let flattened = Self::flatten_cohorts(&sorted_cohorts, &cohort_map)?; | ||
|
||
// Insert the flattened cohorts and their sorted order into the cache | ||
// Cache the flattened cohort filters | ||
let mut cache = self.per_team_flattened.write().await; | ||
cache.insert(team_id, flattened); | ||
sorted.insert(team_id, sorted_ids); | ||
|
||
Ok(()) | ||
} | ||
|
||
/// Retrieves sorted cohort IDs for a team from the cache | ||
pub async fn get_sorted_cohort_ids( | ||
/// Retrieves flattened filters for a given team and cohort | ||
pub async fn get_flattened_filters( | ||
&self, | ||
team_id: TeamId, | ||
postgres_reader: PostgresReader, | ||
) -> Result<Vec<CohortId>, FlagError> { | ||
{ | ||
// Acquire read locks to check the cache | ||
let cache = self.per_team.read().await; | ||
let sorted = self.sorted_cohorts.read().await; | ||
if let (Some(_cohort_map), Some(sorted_ids)) = | ||
(cache.get(&team_id), sorted.get(&team_id)) | ||
{ | ||
if !sorted_ids.is_empty() { | ||
return Ok(sorted_ids.clone()); | ||
} | ||
team_id: i32, | ||
cohort_id: CohortId, | ||
) -> Result<Vec<PropertyFilter>, FlagError> { | ||
let cache = self.per_team_flattened.read().await; | ||
if let Some(team_map) = cache.get(&team_id) { | ||
if let Some(filters) = team_map.get(&cohort_id) { | ||
Ok(filters.clone()) | ||
} else { | ||
Err(FlagError::CohortNotFound(cohort_id.to_string())) | ||
} | ||
} else { | ||
Err(FlagError::CohortNotFound(cohort_id.to_string())) | ||
} | ||
} | ||
|
||
// If not cached, fetch and cache the cohorts | ||
self.fetch_and_cache_cohorts(team_id, postgres_reader) | ||
.await?; | ||
/// Builds a dependency graph where an edge from A to B means A depends on B | ||
fn build_dependency_graph( | ||
cohort_map: &HashMap<CohortId, Cohort>, | ||
) -> Result<DiGraphMap<CohortId, ()>, FlagError> { | ||
let mut graph = DiGraphMap::new(); | ||
|
||
// Acquire read locks to retrieve the sorted list after caching | ||
let sorted = self.sorted_cohorts.read().await; | ||
if let Some(sorted_ids) = sorted.get(&team_id) { | ||
Ok(sorted_ids.clone()) | ||
} else { | ||
Ok(Vec::new()) | ||
// Add all cohorts as nodes | ||
for &cohort_id in cohort_map.keys() { | ||
graph.add_node(cohort_id); | ||
} | ||
|
||
// Add edges based on dependencies | ||
for (&cohort_id, cohort) in cohort_map.iter() { | ||
let dependencies = Self::extract_dependencies(cohort.filters.clone())?; | ||
for dep_id in dependencies { | ||
if !cohort_map.contains_key(&dep_id) { | ||
return Err(FlagError::CohortNotFound(dep_id.to_string())); | ||
} | ||
graph.add_edge(cohort_id, dep_id, ()); // A depends on B: A -> B | ||
} | ||
} | ||
|
||
Ok(graph) | ||
} | ||
|
||
/// Retrieves cached cohorts for a team | ||
pub async fn get_cached_cohorts( | ||
&self, | ||
team_id: TeamId, | ||
) -> Result<HashMap<CohortId, CachedCohort>, FlagError> { | ||
let cache = self.per_team.read().await; | ||
if let Some(cohort_map) = cache.get(&team_id) { | ||
Ok(cohort_map.clone()) | ||
} else { | ||
Ok(HashMap::new()) | ||
/// Extracts all dependent CohortIds from the filters | ||
/// | ||
/// `filters_as_json` is deserialized into a `CohortProperty`, and its | ||
/// `properties` tree is handed to `traverse_filters`. The ids are gathered in a | ||
/// `HashSet`, so a cohort referenced more than once appears only once. | ||
/// | ||
/// # Errors | ||
/// | ||
/// Propagates (via `?`) a deserialization failure from `serde_json::from_value` | ||
/// and any error returned by `traverse_filters`. | ||
fn extract_dependencies( | ||
filters_as_json: serde_json::Value, | ||
) -> Result<HashSet<CohortId>, FlagError> { | ||
let filters: CohortProperty = serde_json::from_value(filters_as_json)?; | ||
let mut dependencies = HashSet::new(); | ||
Self::traverse_filters(&filters.properties, &mut dependencies)?; | ||
Ok(dependencies) | ||
} | ||
|
||
/// Collects the cohort ids referenced by `inner` into `dependencies`. | ||
/// | ||
/// Walks exactly two levels: every entry of `inner.values`, then every filter in | ||
/// that entry's `values`. There is no recursive call, so anything nested deeper | ||
/// is not visited. | ||
/// | ||
/// A filter is treated as a cohort reference when `prop_type == "cohort"` and | ||
/// `key == "id"`; all other filters are ignored. | ||
/// | ||
/// # Errors | ||
/// | ||
/// Returns `FlagError::CohortFiltersParsingError` when a cohort reference's | ||
/// `value.as_i64()` yields `None`. | ||
fn traverse_filters( | ||
inner: &crate::cohort_models::InnerCohortProperty, | ||
dependencies: &mut HashSet<CohortId>, | ||
) -> Result<(), FlagError> { | ||
for cohort_values in &inner.values { | ||
for filter in &cohort_values.values { | ||
if filter.prop_type == "cohort" && filter.key == "id" { | ||
// Assuming the value is a single integer CohortId | ||
// NOTE(review): `as` silently truncates if CohortId is narrower than i64 - confirm the alias and consider a checked conversion | ||
if let Some(cohort_id) = filter.value.as_i64() { | ||
dependencies.insert(cohort_id as CohortId); | ||
} else { | ||
return Err(FlagError::CohortFiltersParsingError); // TODO more data here? | ||
} | ||
} | ||
// Handle nested properties if necessary | ||
// If the filter can contain nested properties with more conditions, traverse them here | ||
} | ||
} | ||
Ok(()) | ||
} | ||
} | ||
|
||
/// Loads the cohort rows for `team_id` from the `posthog_cohort` table, | ||
/// restricted to rows where `deleted = FALSE`. | ||
/// | ||
/// # Errors | ||
/// | ||
/// Propagates (via `?`) a failure to obtain a connection from `postgres_reader`. | ||
/// A failure of the query itself is converted to `FlagError::DatabaseError` | ||
/// carrying the error's string form. | ||
async fn fetch_all_cohorts( | ||
team_id: TeamId, | ||
postgres_reader: PostgresReader, | ||
) -> Result<Vec<Cohort>, FlagError> { | ||
let mut conn = postgres_reader.get_connection().await?; | ||
|
||
let query = r#" | ||
SELECT * | ||
FROM posthog_cohort | ||
WHERE team_id = $1 AND deleted = FALSE | ||
"#; | ||
|
||
// team_id is passed as the bound parameter $1 rather than spliced into the SQL text | ||
let cohorts: Vec<Cohort> = sqlx::query_as::<_, Cohort>(query) | ||
.bind(team_id) | ||
.fetch_all(&mut *conn) | ||
.await | ||
.map_err(|e| FlagError::DatabaseError(e.to_string()))?; | ||
|
||
Ok(cohorts) | ||
} | ||
/// Flattens the filters based on sorted cohorts, including only property filters | ||
fn flatten_cohorts( | ||
sorted_cohorts: &[CohortId], | ||
cohort_map: &HashMap<CohortId, Cohort>, | ||
) -> Result<FlattenedCohortMap, FlagError> { | ||
let mut flattened: FlattenedCohortMap = HashMap::new(); | ||
|
||
async fn flatten_cohorts( | ||
all_cohorts: &Vec<Cohort>, | ||
) -> Result<HashMap<CohortId, CachedCohort>, FlagError> { | ||
let mut flattened = HashMap::new(); | ||
|
||
for cohort in all_cohorts { | ||
let filters = cohort.parse_filters()?; | ||
|
||
// Extract dependencies from cohort filters | ||
let dependencies = filters | ||
.iter() | ||
.filter_map(|f| { | ||
if f.prop_type == "cohort" { | ||
Some(CohortDependencyFilter { | ||
cohort_id: f.value.as_i64().unwrap() as CohortId, | ||
operator: f.operator.clone().unwrap_or(OperatorType::In), | ||
}) | ||
for &cohort_id in sorted_cohorts { | ||
let cohort = cohort_map | ||
.get(&cohort_id) | ||
.ok_or(FlagError::CohortNotFound(cohort_id.to_string()))?; | ||
|
||
// Use the updated parse_property_filters method | ||
let property_filters = cohort.parse_filters()?; | ||
|
||
// Extract dependencies using Cohort's method | ||
let dependencies = cohort.extract_dependencies()?; | ||
|
||
let mut combined_filters = Vec::new(); | ||
|
||
// Include filters from dependencies | ||
for dep_id in &dependencies { | ||
if let Some(dep_filters) = flattened.get(dep_id) { | ||
combined_filters.extend(dep_filters.clone()); | ||
} else { | ||
None | ||
return Err(FlagError::CohortNotFound(dep_id.to_string())); | ||
} | ||
}) | ||
.collect(); | ||
|
||
// Filter out cohort filters as they are now represented as dependencies | ||
let non_cohort_filters: Vec<PropertyFilter> = filters | ||
.into_iter() | ||
.filter(|f| f.prop_type != "cohort") | ||
.collect(); | ||
let cached_cohort = CachedCohort { | ||
filters: non_cohort_filters, | ||
dependencies, | ||
}; | ||
} | ||
|
||
flattened.insert(cohort.id, cached_cohort); | ||
} | ||
// Include own filters | ||
combined_filters.extend(property_filters); | ||
|
||
Ok(flattened) | ||
// Insert into flattened map | ||
flattened.insert(cohort_id, combined_filters); | ||
} | ||
|
||
Ok(flattened) | ||
} | ||
} |
Oops, something went wrong.