Skip to content

Commit

Permalink
oh lol right let's actually ship
Browse files Browse the repository at this point in the history
dmarticus committed Oct 23, 2024
1 parent 4f20e07 commit ed00224
Showing 4 changed files with 174 additions and 72 deletions.
6 changes: 6 additions & 0 deletions rust/feature-flags/src/api.rs
Original file line number Diff line number Diff line change
@@ -102,6 +102,8 @@ pub enum FlagError {
TimeoutError,
#[error("No group type mappings")]
NoGroupTypeMappings,
#[error("Invalid cohort id")]
InvalidCohortId,
}

impl IntoResponse for FlagError {
@@ -194,6 +196,10 @@ impl IntoResponse for FlagError {
"The requested row was not found in the database. Please try again later or contact support if the problem persists.".to_string(),
)
}
FlagError::InvalidCohortId => {
tracing::error!("Invalid cohort id: {:?}", self);
(StatusCode::BAD_REQUEST, "Invalid cohort id".to_string())
}
}
.into_response()
}
192 changes: 122 additions & 70 deletions rust/feature-flags/src/cohort_definitions.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,10 @@
use serde::{Deserialize, Serialize};
use sqlx::FromRow;
use std::collections::{HashMap, HashSet, VecDeque};
use std::sync::Arc;
use tracing::instrument;

use crate::{
api::FlagError,
database::Client as DatabaseClient,
flag_definitions::{OperatorType, PropertyFilter},
};
use crate::{api::FlagError, database::Client as DatabaseClient, flag_definitions::PropertyFilter};

#[derive(Debug, FromRow)]
struct CohortRow {
@@ -25,7 +22,7 @@ struct CohortRow {
is_static: bool,
}

#[derive(Debug, serde::Serialize, serde::Deserialize)]
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct Cohort {
pub id: i32,
pub name: String,
@@ -90,47 +87,56 @@ impl Cohort {
}

/// Parses the filters JSON into a CohortProperty structure
fn parse_filters(&self) -> Result<CohortProperty, FlagError> {
serde_json::from_value(self.filters.clone()).map_err(|e| {
tracing::error!("Failed to parse filters: {}", e);
FlagError::Internal(format!("Invalid filters format: {}", e))
})
pub fn parse_filters(&self) -> Result<Vec<PropertyFilter>, FlagError> {
let cohort_property: CohortProperty = serde_json::from_value(self.filters.clone())?;
Ok(cohort_property.to_property_filters())
}
}

use std::collections::{HashMap, HashSet};

type CohortId = i32;

// Assuming CohortOrEmpty is an enum or struct representing a Cohort or an empty value
enum CohortOrEmpty {
pub enum CohortOrEmpty {
Cohort(Cohort),
Empty,
}

#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
#[serde(rename_all = "UPPERCASE")]
enum CohortPropertyType {
pub enum CohortPropertyType {
AND,
OR,
}

// TODO this should serialize to "properties" in the DB
#[derive(Debug, Clone, Deserialize, Serialize)]
struct CohortProperty {
pub struct CohortProperty {
#[serde(rename = "type")]
prop_type: CohortPropertyType, // TODO make this an AND|OR string enum
values: Vec<CohortValues>,
}

impl CohortProperty {
pub fn to_property_filters(&self) -> Vec<PropertyFilter> {
self.values
.iter()
.flat_map(|value| &value.values)
.cloned()
.collect()
}
}

#[derive(Debug, Clone, Deserialize, Serialize)]
struct CohortValues {
pub struct CohortValues {
#[serde(rename = "type")]
prop_type: String,
values: Vec<PropertyFilter>,
}

fn sort_cohorts_topologically(
/// Sorts the given cohorts in an order where cohorts with no dependencies are placed first,
/// followed by cohorts that depend on the preceding ones. It ensures that each cohort in the sorted list
/// only depends on cohorts that appear earlier in the list.
pub fn sort_cohorts_topologically(
cohort_ids: HashSet<CohortId>,
seen_cohorts_cache: &HashMap<CohortId, CohortOrEmpty>,
) -> Vec<CohortId> {
@@ -153,20 +159,25 @@ fn sort_cohorts_topologically(
}
seen_cohorts.insert(cohort.id);

// Parse the filters into CohortProperty
let cohort_property = match cohort.parse_filters() {
Ok(property) => property,
// Parse the filters into PropertyFilters
let property_filters = match cohort.parse_filters() {
Ok(filters) => filters,
Err(e) => {
tracing::error!("Error parsing filters for cohort {}: {}", cohort.id, e);
return;
}
};

// Iterate through the properties to find dependencies
for value in &cohort_property.values {
if value.prop_type == "cohort" {
if let Some(id) = value.value.as_i64() {
let child_id = id as CohortId;
// Iterate through the property filters to find dependencies
for filter in property_filters {
if filter.prop_type == "cohort" {
let child_id = match filter.value {
serde_json::Value::Number(num) => num.as_i64().map(|n| n as CohortId),
serde_json::Value::String(ref s) => s.parse::<CohortId>().ok(),
_ => None,
};

if let Some(child_id) = child_id {
dependency_graph
.entry(cohort.id)
.or_insert_with(Vec::new)
@@ -182,50 +193,6 @@ fn sort_cohorts_topologically(
seen_cohorts_cache,
);
}
} else if let Some(id_str) = value.value.as_str() {
if let Ok(child_id) = id_str.parse::<CohortId>() {
dependency_graph
.entry(cohort.id)
.or_insert_with(Vec::new)
.push(child_id);

if let Some(CohortOrEmpty::Cohort(child_cohort)) =
seen_cohorts_cache.get(&child_id)
{
traverse(
child_cohort,
dependency_graph,
seen_cohorts,
seen_cohorts_cache,
);
}
}
}
}

// Handle nested properties recursively if needed
if let Some(nested_values) = &value.values {
for nested in nested_values {
if nested.prop_type == "cohort" {
if let Some(id) = nested.value.as_i64() {
let child_id = id as CohortId;
dependency_graph
.entry(cohort.id)
.or_insert_with(Vec::new)
.push(child_id);

if let Some(CohortOrEmpty::Cohort(child_cohort)) =
seen_cohorts_cache.get(&child_id)
{
traverse(
child_cohort,
dependency_graph,
seen_cohorts,
seen_cohorts_cache,
);
}
}
}
}
}
}
@@ -271,3 +238,88 @@ fn sort_cohorts_topologically(

sorted_cohort_ids
}

pub async fn get_dependent_cohorts(
cohort: &Cohort,
seen_cohorts_cache: &mut HashMap<CohortId, CohortOrEmpty>,
team_id: i32,
db_client: Arc<dyn DatabaseClient + Send + Sync>,
) -> Result<Vec<Cohort>, FlagError> {
let mut dependent_cohorts = Vec::new();
let mut seen_cohort_ids = HashSet::new();
seen_cohort_ids.insert(cohort.id);

let mut queue = VecDeque::new();

let property_filters = match cohort.parse_filters() {
Ok(filters) => filters,
Err(e) => {
tracing::error!("Failed to parse filters for cohort {}: {}", cohort.id, e);
return Err(FlagError::Internal(format!(
"Failed to parse cohort filters: {}",
e
)));
}
};

// Initial queue population
for filter in &property_filters {
if filter.prop_type == "cohort" {
if let Some(id) = filter.value.as_i64().map(|n| n as CohortId).or_else(|| {
filter
.value
.as_str()
.and_then(|s| s.parse::<CohortId>().ok())
}) {
queue.push_back(id);
}
}
}

while let Some(cohort_id) = queue.pop_front() {
let current_cohort = match seen_cohorts_cache.get(&cohort_id) {
Some(CohortOrEmpty::Cohort(c)) => c.clone(),
Some(CohortOrEmpty::Empty) => continue,
None => {
// Fetch the cohort from the database
match Cohort::from_pg(db_client.clone(), cohort_id, team_id).await {
Ok(c) => {
seen_cohorts_cache.insert(cohort_id, CohortOrEmpty::Cohort(c.clone()));
c
}
Err(e) => {
tracing::warn!("Failed to fetch cohort {}: {}", cohort_id, e);
seen_cohorts_cache.insert(cohort_id, CohortOrEmpty::Empty);
continue;
}
}
}
};

if !seen_cohort_ids.contains(&current_cohort.id) {
dependent_cohorts.push(current_cohort.clone());
seen_cohort_ids.insert(current_cohort.id);

// Parse filters for the current cohort
if let Ok(current_filters) = current_cohort.parse_filters() {
// Add new cohort dependencies to the queue
for filter in current_filters {
if filter.prop_type == "cohort" {
if let Some(id) =
filter.value.as_i64().map(|n| n as CohortId).or_else(|| {
filter
.value
.as_str()
.and_then(|s| s.parse::<CohortId>().ok())
})
{
queue.push_back(id);
}
}
}
}
}
}

Ok(dependent_cohorts)
}
33 changes: 32 additions & 1 deletion rust/feature-flags/src/flag_definitions.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
use crate::{api::FlagError, database::Client as DatabaseClient, redis::Client as RedisClient};
use crate::{
api::FlagError, cohort_definitions::Cohort, database::Client as DatabaseClient,
redis::Client as RedisClient,
};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use tracing::instrument;
@@ -120,6 +123,34 @@ impl FeatureFlag {
.and_then(|obj| obj.get(match_val).cloned())
})
}

pub async fn transform_cohort_filters(
&self,
postgres_reader: Arc<dyn DatabaseClient + Send + Sync>,
team_id: i32,
) -> Result<FeatureFlag, FlagError> {
let mut transformed_flag = self.clone();
for group in &mut transformed_flag.filters.groups {
if let Some(properties) = &mut group.properties {
let mut new_properties = Vec::new();
for prop in properties.iter() {
if prop.prop_type == "cohort" {
// TODO is there a cleaner way to handle the who cohort values being numbers thing?
let cohort_id = prop.value.as_i64().ok_or(FlagError::InvalidCohortId)?;
let cohort =
Cohort::from_pg(postgres_reader.clone(), cohort_id as i32, team_id)
.await?;
let cohort_properties = cohort.parse_filters()?;
new_properties.extend(cohort_properties);
} else {
new_properties.push(prop.clone());
}
}
group.properties = Some(new_properties);
}
}
Ok(transformed_flag)
}
}

#[derive(Clone, Debug, Default, Deserialize, Serialize)]
15 changes: 14 additions & 1 deletion rust/feature-flags/src/flag_matching.rs
Original file line number Diff line number Diff line change
@@ -633,6 +633,19 @@ impl FeatureFlagMatcher {
}
}

let has_cohort_filter = flag.filters.groups.iter().any(|group| {
group.properties.as_ref().map_or(false, |props| {
props.iter().any(|prop| prop.prop_type == "cohort")
})
});

let flag_to_evaluate = if has_cohort_filter {
flag.transform_cohort_filters(self.postgres_reader.clone(), self.team_id)
.await?
} else {
flag.clone()
};

// Sort conditions with variant overrides to the top so that we can evaluate them first
let mut sorted_conditions: Vec<(usize, &FlagGroupType)> =
flag.get_conditions().iter().enumerate().collect();
@@ -643,7 +656,7 @@ impl FeatureFlagMatcher {
for (index, condition) in sorted_conditions {
let (is_match, reason) = self
.is_condition_match(
flag,
&flag_to_evaluate,
condition,
property_overrides.clone(),
hash_key_overrides.clone(),

0 comments on commit ed00224

Please sign in to comment.