Skip to content

Commit

Permalink
new life
Browse files Browse the repository at this point in the history
  • Loading branch information
dmarticus committed Oct 30, 2024
1 parent 27af814 commit 4c49bc4
Show file tree
Hide file tree
Showing 6 changed files with 283 additions and 319 deletions.
1 change: 1 addition & 0 deletions rust/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions rust/feature-flags/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ health = { path = "../common/health" }
common-metrics = { path = "../common/metrics" }
tower = { workspace = true }
derive_builder = "0.20.1"
petgraph = "0.6.5"

[lints]
workspace = true
Expand Down
279 changes: 143 additions & 136 deletions rust/feature-flags/src/cohort_cache.rs
Original file line number Diff line number Diff line change
@@ -1,184 +1,191 @@
use crate::api::FlagError;
use crate::cohort_models::{Cohort, CohortId};
use crate::cohort_operations::sort_cohorts_topologically;
use crate::flag_definitions::{OperatorType, PropertyFilter};
use crate::flag_matching::{PostgresReader, TeamId};
use crate::cohort_models::{Cohort, CohortId, CohortProperty};
use crate::flag_definitions::PropertyFilter;
use crate::flag_matching::PostgresReader;
use petgraph::algo::toposort;
use petgraph::graphmap::DiGraphMap;
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use tokio::sync::RwLock;

pub type TeamCohortMap = HashMap<CohortId, Vec<PropertyFilter>>;
pub type TeamSortedCohorts = HashMap<TeamId, Vec<CohortId>>;
pub type TeamCacheMap = HashMap<TeamId, TeamCohortMap>;
// Flattened Cohort Map: CohortId -> Combined PropertyFilters
pub type FlattenedCohortMap = HashMap<CohortId, Vec<PropertyFilter>>;

#[derive(Debug, Clone)]
pub struct CachedCohort {
// TODO name this something different
pub filters: Vec<PropertyFilter>, // Non-cohort property filters
pub dependencies: Vec<CohortDependencyFilter>, // Dependencies with operators
}

// Add this struct to facilitate handling operators on cohort filters
#[derive(Debug, Clone)]
pub struct CohortDependencyFilter {
pub cohort_id: CohortId,
pub operator: OperatorType,
}

/// Threadsafety is ensured using Arc and RwLock
#[derive(Clone, Debug)]
/// CohortCache manages the in-memory cache of flattened cohorts
#[derive(Clone)]
pub struct CohortCache {
/// Mapping from TeamId to their respective CohortId and associated CachedCohort
per_team: Arc<RwLock<HashMap<TeamId, HashMap<CohortId, CachedCohort>>>>,
/// Mapping from TeamId to sorted CohortIds based on dependencies
sorted_cohorts: Arc<RwLock<HashMap<TeamId, Vec<CohortId>>>>,
pub per_team_flattened: Arc<RwLock<HashMap<i32, FlattenedCohortMap>>>, // team_id -> (cohort_id -> filters)
}

impl CohortCache {
/// Creates a new CohortCache instance
pub fn new() -> Self {
Self {
per_team: Arc::new(RwLock::new(HashMap::new())),
sorted_cohorts: Arc::new(RwLock::new(HashMap::new())),
per_team_flattened: Arc::new(RwLock::new(HashMap::new())),
}
}

/// Fetches, flattens, sorts, and caches all cohorts for a given team if not already cached
pub async fn fetch_and_cache_cohorts(
/// Builds a `CohortCache` and eagerly fills it with the cohorts of `team_id`.
///
/// The cache starts out with an empty team map; `fetch_and_cache_all_cohorts` is then awaited
/// to populate it before the instance is handed back to the caller.
///
/// # Errors
///
/// Propagates any `FlagError` returned by `fetch_and_cache_all_cohorts`; in that case no cache
/// instance is returned.
pub async fn new_with_team(
    team_id: i32,
    postgres_reader: PostgresReader,
) -> Result<Self, FlagError> {
    // Start from an empty team map; it is populated by the fetch below.
    let per_team_flattened = Arc::new(RwLock::new(HashMap::new()));
    let warmed = Self { per_team_flattened };

    warmed
        .fetch_and_cache_all_cohorts(team_id, postgres_reader)
        .await?;

    Ok(warmed)
}

/// Fetches, parses, and caches all cohorts for a given team
async fn fetch_and_cache_all_cohorts(
&self,
team_id: TeamId,
team_id: i32,
postgres_reader: PostgresReader,
) -> Result<(), FlagError> {
// Acquire write locks to modify the cache
let mut cache = self.per_team.write().await;
let mut sorted = self.sorted_cohorts.write().await;
// Fetch all cohorts for the team
let cohorts = Cohort::list_from_pg(postgres_reader, team_id).await?;

// Check if the team's cohorts are already cached
if cache.contains_key(&team_id) && sorted.contains_key(&team_id) {
return Ok(());
// Build a mapping from cohort_id to Cohort
let mut cohort_map: HashMap<CohortId, Cohort> = HashMap::new();
for cohort in cohorts {
cohort_map.insert(cohort.id, cohort);
}

// Fetch all cohorts for the team from the database
let all_cohorts = fetch_all_cohorts(team_id, postgres_reader).await?;
// Build dependency graph
let dependency_graph = Self::build_dependency_graph(&cohort_map)?;

// Flatten the property filters, resolving dependencies
let flattened = flatten_cohorts(&all_cohorts).await?;
// Perform topological sort
let sorted_cohorts = toposort(&dependency_graph, None).map_err(|_| {
FlagError::CohortDependencyCycle("Cycle detected in cohort dependencies".to_string())
})?;

// Extract all cohort IDs
let cohort_ids: HashSet<CohortId> = flattened.keys().cloned().collect();
// Reverse to process dependencies first
let sorted_cohorts: Vec<CohortId> = sorted_cohorts.into_iter().rev().collect();

// Sort the cohorts topologically based on dependencies
let sorted_ids = sort_cohorts_topologically(cohort_ids, &flattened)?;
// Flatten cohorts
let flattened = Self::flatten_cohorts(&sorted_cohorts, &cohort_map)?;

// Insert the flattened cohorts and their sorted order into the cache
// Cache the flattened cohort filters
let mut cache = self.per_team_flattened.write().await;
cache.insert(team_id, flattened);
sorted.insert(team_id, sorted_ids);

Ok(())
}

/// Retrieves sorted cohort IDs for a team from the cache
pub async fn get_sorted_cohort_ids(
/// Retrieves flattened filters for a given team and cohort
pub async fn get_flattened_filters(
&self,
team_id: TeamId,
postgres_reader: PostgresReader,
) -> Result<Vec<CohortId>, FlagError> {
{
// Acquire read locks to check the cache
let cache = self.per_team.read().await;
let sorted = self.sorted_cohorts.read().await;
if let (Some(_cohort_map), Some(sorted_ids)) =
(cache.get(&team_id), sorted.get(&team_id))
{
if !sorted_ids.is_empty() {
return Ok(sorted_ids.clone());
}
team_id: i32,
cohort_id: CohortId,
) -> Result<Vec<PropertyFilter>, FlagError> {
let cache = self.per_team_flattened.read().await;
if let Some(team_map) = cache.get(&team_id) {
if let Some(filters) = team_map.get(&cohort_id) {
Ok(filters.clone())
} else {
Err(FlagError::CohortNotFound(cohort_id.to_string()))
}
} else {
Err(FlagError::CohortNotFound(cohort_id.to_string()))
}
}

// If not cached, fetch and cache the cohorts
self.fetch_and_cache_cohorts(team_id, postgres_reader)
.await?;
/// Builds a dependency graph where an edge from A to B means A depends on B
fn build_dependency_graph(
cohort_map: &HashMap<CohortId, Cohort>,
) -> Result<DiGraphMap<CohortId, ()>, FlagError> {
let mut graph = DiGraphMap::new();

// Acquire read locks to retrieve the sorted list after caching
let sorted = self.sorted_cohorts.read().await;
if let Some(sorted_ids) = sorted.get(&team_id) {
Ok(sorted_ids.clone())
} else {
Ok(Vec::new())
// Add all cohorts as nodes
for &cohort_id in cohort_map.keys() {
graph.add_node(cohort_id);
}

// Add edges based on dependencies
for (&cohort_id, cohort) in cohort_map.iter() {
let dependencies = Self::extract_dependencies(cohort.filters.clone())?;
for dep_id in dependencies {
if !cohort_map.contains_key(&dep_id) {
return Err(FlagError::CohortNotFound(dep_id.to_string()));
}
graph.add_edge(cohort_id, dep_id, ()); // A depends on B: A -> B
}
}

Ok(graph)
}

/// Retrieves cached cohorts for a team
pub async fn get_cached_cohorts(
&self,
team_id: TeamId,
) -> Result<HashMap<CohortId, CachedCohort>, FlagError> {
let cache = self.per_team.read().await;
if let Some(cohort_map) = cache.get(&team_id) {
Ok(cohort_map.clone())
} else {
Ok(HashMap::new())
/// Collects the ids of all cohorts referenced by a cohort's raw JSON filters.
///
/// The JSON is deserialized into a `CohortProperty`, and its `properties` tree is handed to
/// `traverse_filters`, which records every referenced cohort id in the returned set.
///
/// # Errors
///
/// Returns an error if `filters_as_json` cannot be deserialized into a `CohortProperty`
/// (converted into `FlagError` by `?`), or if `traverse_filters` fails.
fn extract_dependencies(
    filters_as_json: serde_json::Value,
) -> Result<HashSet<CohortId>, FlagError> {
    let parsed = serde_json::from_value::<CohortProperty>(filters_as_json)?;

    let mut found: HashSet<CohortId> = HashSet::new();
    Self::traverse_filters(&parsed.properties, &mut found)?;

    Ok(found)
}

/// Recursively traverses the filter tree to find cohort dependencies
fn traverse_filters(
inner: &crate::cohort_models::InnerCohortProperty,
dependencies: &mut HashSet<CohortId>,
) -> Result<(), FlagError> {
for cohort_values in &inner.values {
for filter in &cohort_values.values {
if filter.prop_type == "cohort" && filter.key == "id" {
// Assuming the value is a single integer CohortId
if let Some(cohort_id) = filter.value.as_i64() {
dependencies.insert(cohort_id as CohortId);
} else {
return Err(FlagError::CohortFiltersParsingError); // TODO more data here?
}
}
// Handle nested properties if necessary
// If the filter can contain nested properties with more conditions, traverse them here
}
}
Ok(())
}
}

async fn fetch_all_cohorts(
team_id: TeamId,
postgres_reader: PostgresReader,
) -> Result<Vec<Cohort>, FlagError> {
let mut conn = postgres_reader.get_connection().await?;

let query = r#"
SELECT *
FROM posthog_cohort
WHERE team_id = $1 AND deleted = FALSE
"#;

let cohorts: Vec<Cohort> = sqlx::query_as::<_, Cohort>(query)
.bind(team_id)
.fetch_all(&mut *conn)
.await
.map_err(|e| FlagError::DatabaseError(e.to_string()))?;

Ok(cohorts)
}
/// Flattens the filters based on sorted cohorts, including only property filters
fn flatten_cohorts(
sorted_cohorts: &[CohortId],
cohort_map: &HashMap<CohortId, Cohort>,
) -> Result<FlattenedCohortMap, FlagError> {
let mut flattened: FlattenedCohortMap = HashMap::new();

async fn flatten_cohorts(
all_cohorts: &Vec<Cohort>,
) -> Result<HashMap<CohortId, CachedCohort>, FlagError> {
let mut flattened = HashMap::new();

for cohort in all_cohorts {
let filters = cohort.parse_filters()?;

// Extract dependencies from cohort filters
let dependencies = filters
.iter()
.filter_map(|f| {
if f.prop_type == "cohort" {
Some(CohortDependencyFilter {
cohort_id: f.value.as_i64().unwrap() as CohortId,
operator: f.operator.clone().unwrap_or(OperatorType::In),
})
for &cohort_id in sorted_cohorts {
let cohort = cohort_map
.get(&cohort_id)
.ok_or(FlagError::CohortNotFound(cohort_id.to_string()))?;

// Parse this cohort's own property filters via `parse_filters`
let property_filters = cohort.parse_filters()?;

// Extract dependencies using Cohort's method
let dependencies = cohort.extract_dependencies()?;

let mut combined_filters = Vec::new();

// Include filters from dependencies
for dep_id in &dependencies {
if let Some(dep_filters) = flattened.get(dep_id) {
combined_filters.extend(dep_filters.clone());
} else {
None
return Err(FlagError::CohortNotFound(dep_id.to_string()));
}
})
.collect();

// Filter out cohort filters as they are now represented as dependencies
let non_cohort_filters: Vec<PropertyFilter> = filters
.into_iter()
.filter(|f| f.prop_type != "cohort")
.collect();
let cached_cohort = CachedCohort {
filters: non_cohort_filters,
dependencies,
};
}

flattened.insert(cohort.id, cached_cohort);
}
// Include own filters
combined_filters.extend(property_filters);

Ok(flattened)
// Insert into flattened map
flattened.insert(cohort_id, combined_filters);
}

Ok(flattened)
}
}
Loading

0 comments on commit 4c49bc4

Please sign in to comment.