diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index 9b38709da35c3..f4d14e9ed49ce 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -165,6 +165,12 @@ jobs: DATABASE_URL: 'postgres://posthog:posthog@localhost:5432/posthog' run: cd ../ && python manage.py setup_test_environment --only-postgres + - name: Download MaxMind Database + if: needs.changes.outputs.rust == 'true' + run: | + mkdir -p ../share + curl -L "https://mmdbcdn.posthog.net/" --http1.1 | brotli --decompress --output=../share/GeoLite2-City.mmdb + - name: Run cargo test if: needs.changes.outputs.rust == 'true' run: | diff --git a/rust/Cargo.lock b/rust/Cargo.lock index 59078d1b019cc..6c15916314724 100644 --- a/rust/Cargo.lock +++ b/rust/Cargo.lock @@ -1132,6 +1132,7 @@ dependencies = [ "axum-client-ip", "bytes", "envconfig", + "maxminddb", "once_cell", "rand", "redis", @@ -2141,6 +2142,16 @@ version = "0.7.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0e7465ac9959cc2b1404e8e2367b43684a6d13790fe23056cc8c6c5a6b7bcb94" +[[package]] +name = "maxminddb" +version = "0.17.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d13fa57adcc4f3aca91e511b3cdaa58ed8cbcbf97f20e342a11218c76e127f51" +dependencies = [ + "log", + "serde", +] + [[package]] name = "md-5" version = "0.10.6" diff --git a/rust/feature-flags/Cargo.toml b/rust/feature-flags/Cargo.toml index e4d51dc308d34..b43d09cc93d2f 100644 --- a/rust/feature-flags/Cargo.toml +++ b/rust/feature-flags/Cargo.toml @@ -28,6 +28,7 @@ thiserror = { workspace = true } serde-pickle = { version = "1.1.1"} sha1 = "0.10.6" regex = "1.10.4" +maxminddb = "0.17" sqlx = { workspace = true } uuid = { workspace = true } diff --git a/rust/feature-flags/src/config.rs b/rust/feature-flags/src/config.rs index a9e3517398770..1f1c47a99249b 100644 --- a/rust/feature-flags/src/config.rs +++ b/rust/feature-flags/src/config.rs @@ -1,8 +1,10 @@ use envconfig::Envconfig; use 
once_cell::sync::Lazy; use std::net::SocketAddr; +use std::path::{Path, PathBuf}; use std::str::FromStr; +// TODO rewrite this to follow the AppConfig pattern in other files #[derive(Envconfig, Clone, Debug)] pub struct Config { #[envconfig(default = "127.0.0.1:3001")] @@ -25,6 +27,9 @@ pub struct Config { #[envconfig(default = "1")] pub acquire_timeout_secs: u64, + + #[envconfig(from = "MAXMIND_DB_PATH", default = "")] + pub maxmind_db_path: String, } impl Config { @@ -38,6 +43,21 @@ impl Config { max_concurrent_jobs: 1024, max_pg_connections: 100, acquire_timeout_secs: 1, + maxmind_db_path: "".to_string(), + } + } + + pub fn get_maxmind_db_path(&self) -> PathBuf { + if self.maxmind_db_path.is_empty() { + Path::new(env!("CARGO_MANIFEST_DIR")) + .parent() + .unwrap() + .parent() + .unwrap() + .join("share") + .join("GeoLite2-City.mmdb") + } else { + PathBuf::from(&self.maxmind_db_path) } } } diff --git a/rust/feature-flags/src/flag_definitions.rs b/rust/feature-flags/src/flag_definitions.rs index 225a4ba4898a9..df0b1998cd1bf 100644 --- a/rust/feature-flags/src/flag_definitions.rs +++ b/rust/feature-flags/src/flag_definitions.rs @@ -195,7 +195,7 @@ impl FeatureFlagList { team_id: i32, flags: &FeatureFlagList, ) -> Result<(), FlagError> { - let payload = serde_json::to_string(flags).map_err(|e| { + let payload = serde_json::to_string(&flags.flags).map_err(|e| { tracing::error!("Failed to serialize flags: {}", e); FlagError::DataParsingError })?; @@ -1392,23 +1392,46 @@ mod tests { } // Fetch flags from both sources - let redis_flags = FeatureFlagList::from_redis(redis_client, team.id) + let mut redis_flags = FeatureFlagList::from_redis(redis_client, team.id) .await .expect("Failed to fetch flags from Redis"); - let pg_flags = FeatureFlagList::from_pg(pg_client, team.id) + let mut pg_flags = FeatureFlagList::from_pg(pg_client, team.id) .await .expect("Failed to fetch flags from Postgres"); + // Sort flags by key to ensure consistent order + 
redis_flags.flags.sort_by(|a, b| a.key.cmp(&b.key)); + pg_flags.flags.sort_by(|a, b| a.key.cmp(&b.key)); + // Compare results - assert_eq!(redis_flags.flags.len(), pg_flags.flags.len()); + assert_eq!( + redis_flags.flags.len(), + pg_flags.flags.len(), + "Number of flags mismatch" + ); + for (redis_flag, pg_flag) in redis_flags.flags.iter().zip(pg_flags.flags.iter()) { - assert_eq!(redis_flag.key, pg_flag.key); - assert_eq!(redis_flag.name, pg_flag.name); - assert_eq!(redis_flag.active, pg_flag.active); - assert_eq!(redis_flag.deleted, pg_flag.deleted); + assert_eq!(redis_flag.key, pg_flag.key, "Flag key mismatch"); + assert_eq!( + redis_flag.name, pg_flag.name, + "Flag name mismatch for key: {}", + redis_flag.key + ); + assert_eq!( + redis_flag.active, pg_flag.active, + "Flag active status mismatch for key: {}", + redis_flag.key + ); + assert_eq!( + redis_flag.deleted, pg_flag.deleted, + "Flag deleted status mismatch for key: {}", + redis_flag.key + ); assert_eq!( redis_flag.filters.groups[0].rollout_percentage, - pg_flag.filters.groups[0].rollout_percentage + pg_flag.filters.groups[0].rollout_percentage, + "Flag rollout percentage mismatch for key: {}", + redis_flag.key ); } } diff --git a/rust/feature-flags/src/flag_matching.rs b/rust/feature-flags/src/flag_matching.rs index 485d8a646e823..88911c90bb7be 100644 --- a/rust/feature-flags/src/flag_matching.rs +++ b/rust/feature-flags/src/flag_matching.rs @@ -1,7 +1,7 @@ use crate::{ api::FlagError, database::Client as DatabaseClient, - flag_definitions::{FeatureFlag, FlagGroupType}, + flag_definitions::{FeatureFlag, FlagGroupType, PropertyFilter}, property_matching::match_property, }; use serde_json::Value; @@ -37,7 +37,13 @@ pub struct FeatureFlagMatcher { // pub flags: Vec, pub distinct_id: String, pub database_client: Option>, + // TODO do I need cached_properties, or do I get them from the request? + // like, in python I get them from the request. Hmm. Let me try that. + // OH, or is this the FlagMatcherCache. 
Yeah, so this is the flag matcher cache cached_properties: Option>, + person_property_overrides: Option>, + // TODO handle group properties + // group_property_overrides: Option>>, } const LONG_SCALE: u64 = 0xfffffffffffffff; @@ -46,21 +52,28 @@ impl FeatureFlagMatcher { pub fn new( distinct_id: String, database_client: Option>, + person_property_overrides: Option>, + // group_property_overrides: Option>>, ) -> Self { FeatureFlagMatcher { // flags, distinct_id, database_client, cached_properties: None, + person_property_overrides, + // group_property_overrides, } } - pub async fn get_match(&mut self, feature_flag: &FeatureFlag) -> FeatureFlagMatch { + pub async fn get_match( + &mut self, + feature_flag: &FeatureFlag, + ) -> Result { if self.hashed_identifier(feature_flag).is_none() { - return FeatureFlagMatch { + return Ok(FeatureFlagMatch { matches: false, variant: None, - }; + }); } // TODO: super groups for early access @@ -69,10 +82,10 @@ impl FeatureFlagMatcher { for (index, condition) in feature_flag.get_conditions().iter().enumerate() { let (is_match, _evaluation_reason) = self .is_condition_match(feature_flag, condition, index) - .await; + .await?; if is_match { - // TODO: This is a bit awkward, we should handle overrides only when variants exist. 
+ // TODO: this is a bit awkward, we should only handle variants when overrides exist let variant = match condition.variant.clone() { Some(variant_override) => { if feature_flag @@ -88,16 +101,25 @@ impl FeatureFlagMatcher { None => self.get_matching_variant(feature_flag), }; - // let payload = self.get_matching_payload(is_match, variant, feature_flag); - return FeatureFlagMatch { + return Ok(FeatureFlagMatch { matches: true, variant, - }; + }); } } - FeatureFlagMatch { + Ok(FeatureFlagMatch { matches: false, variant: None, + }) + } + + fn check_rollout(&self, feature_flag: &FeatureFlag, rollout_percentage: f64) -> (bool, String) { + if rollout_percentage == 100.0 + || self.get_hash(feature_flag, "") <= (rollout_percentage / 100.0) + { + (true, "CONDITION_MATCH".to_string()) + } else { + (false, "OUT_OF_ROLLOUT_BOUND".to_string()) } } @@ -108,39 +130,71 @@ impl FeatureFlagMatcher { feature_flag: &FeatureFlag, condition: &FlagGroupType, _index: usize, - ) -> (bool, String) { + ) -> Result<(bool, String), FlagError> { let rollout_percentage = condition.rollout_percentage.unwrap_or(100.0); - let mut condition_match = true; - - if let Some(ref properties) = condition.properties { + if let Some(properties) = &condition.properties { if properties.is_empty() { - condition_match = true; - } else { - // TODO: First handle given override properties before going to db - let target_properties = self - .get_person_properties(feature_flag.team_id, self.distinct_id.clone()) - .await - .unwrap_or_default(); - // TODO: Handle db issues / person not found - - condition_match = properties.iter().all(|property| { - match_property(property, &target_properties, false).unwrap_or(false) - }); + return Ok(self.check_rollout(feature_flag, rollout_percentage)); } - }; - if !condition_match { - return (false, "NO_CONDITION_MATCH".to_string()); - } else if rollout_percentage == 100.0 { - // TODO: Check floating point schenanigans if any - return (true, "CONDITION_MATCH".to_string()); + let 
target_properties = self.get_target_properties(feature_flag, properties).await?; + + if !self.all_properties_match(properties, &target_properties) { + return Ok((false, "NO_CONDITION_MATCH".to_string())); + } } - if self.get_hash(feature_flag, "") > (rollout_percentage / 100.0) { - return (false, "OUT_OF_ROLLOUT_BOUND".to_string()); + Ok(self.check_rollout(feature_flag, rollout_percentage)) + } + + async fn get_target_properties( + &mut self, + feature_flag: &FeatureFlag, + properties: &[PropertyFilter], + ) -> Result, FlagError> { + self.get_person_properties(feature_flag.team_id, properties) + .await + // TODO handle group properties, will go something like this + // if let Some(group_index) = feature_flag.get_group_type_index() { + // self.get_group_properties(feature_flag.team_id, group_index, properties) + // } else { + // self.get_person_properties(feature_flag.team_id, properties) + // .await + // } + } + + async fn get_person_properties( + &mut self, + team_id: i32, + properties: &[PropertyFilter], + ) -> Result, FlagError> { + if let Some(person_overrides) = &self.person_property_overrides { + // Check if all required properties are present in the overrides + // and none of them are of type "cohort" + let should_prefer_overrides = properties + .iter() + .all(|prop| person_overrides.contains_key(&prop.key) && prop.prop_type != "cohort"); + + if should_prefer_overrides { + // TODO let's count how often this happens + return Ok(person_overrides.clone()); + } } - (true, "CONDITION_MATCH".to_string()) + // If we don't prefer the overrides (they're either not present, don't contain enough properties to evaluate the condition, + // or contain a cohort property), fall back to getting properties from cache or DB + self.get_person_properties_from_cache_or_db(team_id, self.distinct_id.clone()) + .await + } + + fn all_properties_match( + &self, + condition_properties: &[PropertyFilter], + target_properties: &HashMap, + ) -> bool { + condition_properties + .iter() + 
.all(|property| match_property(property, target_properties, false).unwrap_or(false)) } pub fn hashed_identifier(&self, feature_flag: &FeatureFlag) -> Option<String> { @@ -177,6 +231,7 @@ impl FeatureFlagMatcher { hash_val as f64 / LONG_SCALE as f64 } + /// This function takes a feature flag and returns the key of the variant that should be shown to the user. pub fn get_matching_variant(&self, feature_flag: &FeatureFlag) -> Option<String> { let hash = self.get_hash(feature_flag, "variant"); let mut total_percentage = 0.0; @@ -190,7 +245,8 @@ - pub async fn get_person_properties( + /// Fetches person properties for this matcher's distinct_id, preferring the in-memory cache and falling back to the database. + pub async fn get_person_properties_from_cache_or_db( &mut self, team_id: i32, distinct_id: String, @@ -199,6 +255,7 @@ // Depends on how often we're calling this function // to match all flags for a single person + // TODO which of these properties do we need to cache? if let Some(cached_props) = self.cached_properties.clone() { // TODO: Maybe we don't want to copy around all user properties, this will by far be the largest chunk // of data we're copying around. Can we work with references here? 
@@ -243,6 +300,15 @@ impl FeatureFlagMatcher { Ok(props) } + + // async fn get_group_properties_from_cache_or_db( + // &self, + // team_id: i32, + // group_index: usize, + // properties: &Vec, + // ) -> HashMap { + // todo!() + // } } #[cfg(test)] @@ -251,7 +317,33 @@ mod tests { use serde_json::json; use super::*; - use crate::test_utils::{insert_new_team_in_pg, insert_person_for_team_in_pg, setup_pg_client}; + use crate::{ + flag_definitions::{FlagFilters, MultivariateFlagOptions, MultivariateFlagVariant}, + test_utils::{insert_new_team_in_pg, insert_person_for_team_in_pg, setup_pg_client}, + }; + + fn create_test_flag(team_id: i32, properties: Vec) -> FeatureFlag { + FeatureFlag { + id: 1, + team_id, + name: Some("Test Flag".to_string()), + key: "test_flag".to_string(), + filters: FlagFilters { + groups: vec![FlagGroupType { + properties: Some(properties), + rollout_percentage: Some(100.0), + variant: None, + }], + multivariate: None, + aggregation_group_type_index: None, + payloads: None, + super_groups: None, + }, + deleted: false, + active: true, + ensure_experience_continuity: false, + } + } #[tokio::test] async fn test_fetch_properties_from_pg_to_match() { @@ -300,22 +392,129 @@ mod tests { )) .unwrap(); - let mut matcher = FeatureFlagMatcher::new(distinct_id, Some(client.clone())); - let match_result = matcher.get_match(&flag).await; + let mut matcher = FeatureFlagMatcher::new(distinct_id, Some(client.clone()), None); + let match_result = matcher.get_match(&flag).await.unwrap(); assert_eq!(match_result.matches, true); assert_eq!(match_result.variant, None); // property value is different - let mut matcher = FeatureFlagMatcher::new(not_matching_distinct_id, Some(client.clone())); - let match_result = matcher.get_match(&flag).await; + let mut matcher = + FeatureFlagMatcher::new(not_matching_distinct_id, Some(client.clone()), None); + let match_result = matcher.get_match(&flag).await.unwrap(); assert_eq!(match_result.matches, false); 
assert_eq!(match_result.variant, None); // person does not exist let mut matcher = - FeatureFlagMatcher::new("other_distinct_id".to_string(), Some(client.clone())); - let match_result = matcher.get_match(&flag).await; + FeatureFlagMatcher::new("other_distinct_id".to_string(), Some(client.clone()), None); + let match_result = matcher.get_match(&flag).await.unwrap(); assert_eq!(match_result.matches, false); assert_eq!(match_result.variant, None); } + + #[tokio::test] + async fn test_person_property_overrides() { + let client = setup_pg_client(None).await; + let team = insert_new_team_in_pg(client.clone()).await.unwrap(); + + let flag = create_test_flag( + team.id, + vec![PropertyFilter { + key: "email".to_string(), + value: json!("override@example.com"), + operator: None, + prop_type: "email".to_string(), + group_type_index: None, + }], + ); + + let overrides = HashMap::from([("email".to_string(), json!("override@example.com"))]); + + let mut matcher = FeatureFlagMatcher::new( + "test_user".to_string(), + Some(client.clone()), + Some(overrides), + ); + + let match_result = matcher.get_match(&flag).await.unwrap(); + assert_eq!(match_result.matches, true); + } + + #[test] + fn test_hashed_identifier() { + let flag = create_test_flag(1, vec![]); + + let matcher = FeatureFlagMatcher::new("test_user".to_string(), None, None); + assert_eq!( + matcher.hashed_identifier(&flag), + Some("test_user".to_string()) + ); + + // Test with a group type index (this part of the functionality is not implemented yet) + // let mut group_flag = flag.clone(); + // group_flag.filters.aggregation_group_type_index = Some(1); + // assert_eq!(matcher.hashed_identifier(&group_flag), Some("".to_string())); + } + + #[test] + fn test_get_matching_variant() { + let flag = FeatureFlag { + id: 1, + team_id: 1, + name: Some("Test Flag".to_string()), + key: "test_flag".to_string(), + filters: FlagFilters { + groups: vec![], + multivariate: Some(MultivariateFlagOptions { + variants: vec![ + 
MultivariateFlagVariant { + name: Some("Control".to_string()), + key: "control".to_string(), + rollout_percentage: 33.0, + }, + MultivariateFlagVariant { + name: Some("Test".to_string()), + key: "test".to_string(), + rollout_percentage: 33.0, + }, + MultivariateFlagVariant { + name: Some("Test2".to_string()), + key: "test2".to_string(), + rollout_percentage: 34.0, + }, + ], + }), + aggregation_group_type_index: None, + payloads: None, + super_groups: None, + }, + deleted: false, + active: true, + ensure_experience_continuity: false, + }; + + let matcher = FeatureFlagMatcher::new("test_user".to_string(), None, None); + let variant = matcher.get_matching_variant(&flag); + assert!(variant.is_some()); + assert!(["control", "test", "test2"].contains(&variant.unwrap().as_str())); + } + + #[tokio::test] + async fn test_is_condition_match_empty_properties() { + let flag = create_test_flag(1, vec![]); + + let condition = FlagGroupType { + variant: None, + properties: Some(vec![]), + rollout_percentage: Some(100.0), + }; + + let mut matcher = FeatureFlagMatcher::new("test_user".to_string(), None, None); + let (is_match, reason) = matcher + .is_condition_match(&flag, &condition, 0) + .await + .unwrap(); + assert_eq!(is_match, true); + assert_eq!(reason, "CONDITION_MATCH"); + } } diff --git a/rust/feature-flags/src/flag_request.rs b/rust/feature-flags/src/flag_request.rs new file mode 100644 index 0000000000000..d15876a37481b --- /dev/null +++ b/rust/feature-flags/src/flag_request.rs @@ -0,0 +1,488 @@ +use std::{collections::HashMap, sync::Arc}; + +use bytes::Bytes; +use serde::{Deserialize, Serialize}; +use serde_json::Value; +use tracing::instrument; + +use crate::{ + api::FlagError, database::Client as DatabaseClient, flag_definitions::FeatureFlagList, + redis::Client as RedisClient, team::Team, +}; + +#[derive(Default, Debug, Deserialize, Serialize)] +pub struct FlagRequest { + #[serde( + alias = "$token", + alias = "api_key", + skip_serializing_if = "Option::is_none" + )] 
+ pub token: Option, + #[serde(alias = "$distinct_id", skip_serializing_if = "Option::is_none")] + pub distinct_id: Option, + pub geoip_disable: Option, + #[serde(default)] + pub person_properties: Option>, + #[serde(default)] + pub groups: Option>, + // TODO: better type this since we know its going to be a nested json + #[serde(default)] + pub group_properties: Option>>, + #[serde(alias = "$anon_distinct_id", skip_serializing_if = "Option::is_none")] + pub anon_distinct_id: Option, + pub ip_address: Option, +} + +impl FlagRequest { + /// Takes a request payload and tries to read it. + /// Only supports base64 encoded payloads or uncompressed utf-8 as json. + #[instrument(skip_all)] + pub fn from_bytes(bytes: Bytes) -> Result { + tracing::debug!(len = bytes.len(), "decoding new request"); + // TODO: Add base64 decoding + let payload = String::from_utf8(bytes.into()).map_err(|e| { + tracing::error!("failed to decode body: {}", e); + FlagError::RequestDecodingError(String::from("invalid body encoding")) + })?; + + tracing::debug!(json = payload, "decoded event data"); + Ok(serde_json::from_str::(&payload)?) + } + + /// Extracts the token from the request and verifies it against the cache. + /// If the token is not found in the cache, it will be verified against the database. + pub async fn extract_and_verify_token( + &self, + redis_client: Arc, + pg_client: Arc, + ) -> Result { + let token = match self { + FlagRequest { + token: Some(token), .. 
+ } => token.to_string(), + _ => return Err(FlagError::NoTokenError), + }; + + match Team::from_redis(redis_client.clone(), token.clone()).await { + Ok(_) => Ok(token), + Err(_) => { + // Fallback: Check PostgreSQL if not found in Redis + match Team::from_pg(pg_client, token.clone()).await { + Ok(team) => { + // Token found in PostgreSQL, update Redis cache so that we can verify it from Redis next time + if let Err(e) = Team::update_redis_cache(redis_client, &team).await { + tracing::warn!("Failed to update Redis cache: {}", e); + } + Ok(token) + } + // TODO do we need a custom error here to track the fallback + Err(_) => Err(FlagError::TokenValidationError), + } + } + } + } + + /// Fetches the team from the cache or the database. + /// If the team is not found in the cache, it will be fetched from the database and stored in the cache. + /// Returns the team if found, otherwise an error. + pub async fn get_team_from_cache_or_pg( + &self, + token: &str, + redis_client: Arc, + pg_client: Arc, + ) -> Result { + match Team::from_redis(redis_client.clone(), token.to_owned()).await { + Ok(team) => Ok(team), + Err(_) => match Team::from_pg(pg_client, token.to_owned()).await { + Ok(team) => { + // If we have the team in postgres, but not redis, update redis so we're faster next time + // TODO: we have some counters in django for tracking these cache misses + // we should probably do the same here + if let Err(e) = Team::update_redis_cache(redis_client, &team).await { + tracing::warn!("Failed to update Redis cache: {}", e); + } + Ok(team) + } + // TODO what kind of error should we return here? + Err(e) => Err(e), + }, + } + } + + /// Extracts the distinct_id from the request. + /// If the distinct_id is missing or empty, an error is returned. 
+ pub fn extract_distinct_id(&self) -> Result { + let distinct_id = match &self.distinct_id { + None => return Err(FlagError::MissingDistinctId), + Some(id) => id, + }; + + match distinct_id.len() { + 0 => Err(FlagError::EmptyDistinctId), + 1..=200 => Ok(distinct_id.to_owned()), + _ => Ok(distinct_id.chars().take(200).collect()), + } + } + + /// Extracts the properties from the request. + /// If the request contains person_properties, they are returned. + // TODO do I even need this one? + pub fn extract_properties(&self) -> HashMap { + let mut properties = HashMap::new(); + if let Some(person_properties) = &self.person_properties { + properties.extend(person_properties.clone()); + } + properties + } + + /// Fetches the flags from the cache or the database. + /// If the flags are not found in the cache, they will be fetched from the database and stored in the cache. + /// Returns the flags if found, otherwise an error. + pub async fn get_flags_from_cache_or_pg( + &self, + team_id: i32, + redis_client: Arc, + pg_client: Arc, + ) -> Result { + match FeatureFlagList::from_redis(redis_client.clone(), team_id).await { + Ok(flags) => Ok(flags), + Err(_) => match FeatureFlagList::from_pg(pg_client, team_id).await { + Ok(flags) => { + // If we have the flags in postgres, but not redis, update redis so we're faster next time + // TODO: we have some counters in django for tracking these cache misses + // we should probably do the same here + if let Err(e) = + FeatureFlagList::update_flags_in_redis(redis_client, team_id, &flags).await + { + tracing::warn!("Failed to update Redis cache: {}", e); + } + Ok(flags) + } + // TODO what kind of error should we return here? 
+ Err(e) => Err(e), + }, + } + } +} + +#[cfg(test)] +mod tests { + use std::collections::HashMap; + + use crate::api::FlagError; + use crate::flag_definitions::{ + FeatureFlag, FeatureFlagList, FlagFilters, FlagGroupType, OperatorType, PropertyFilter, + TEAM_FLAGS_CACHE_PREFIX, + }; + use crate::flag_request::FlagRequest; + use crate::redis::Client as RedisClient; + use crate::team::Team; + use crate::test_utils::{insert_new_team_in_redis, setup_pg_client, setup_redis_client}; + use bytes::Bytes; + use serde_json::json; + + #[test] + fn empty_distinct_id_not_accepted() { + let json = json!({ + "distinct_id": "", + "token": "my_token1", + }); + let bytes = Bytes::from(json.to_string()); + + let flag_payload = FlagRequest::from_bytes(bytes).expect("failed to parse request"); + + match flag_payload.extract_distinct_id() { + Err(FlagError::EmptyDistinctId) => (), + _ => panic!("expected empty distinct id error"), + }; + } + + #[test] + fn too_large_distinct_id_is_truncated() { + let json = json!({ + "distinct_id": std::iter::repeat("a").take(210).collect::(), + "token": "my_token1", + }); + let bytes = Bytes::from(json.to_string()); + + let flag_payload = FlagRequest::from_bytes(bytes).expect("failed to parse request"); + + assert_eq!(flag_payload.extract_distinct_id().unwrap().len(), 200); + } + + #[test] + fn distinct_id_is_returned_correctly() { + let json = json!({ + "$distinct_id": "alakazam", + "token": "my_token1", + }); + let bytes = Bytes::from(json.to_string()); + + let flag_payload = FlagRequest::from_bytes(bytes).expect("failed to parse request"); + + match flag_payload.extract_distinct_id() { + Ok(id) => assert_eq!(id, "alakazam"), + _ => panic!("expected distinct id"), + }; + } + + #[tokio::test] + async fn token_is_returned_correctly() { + let redis_client = setup_redis_client(None); + let pg_client = setup_pg_client(None).await; + let team = insert_new_team_in_redis(redis_client.clone()) + .await + .expect("Failed to insert new team in Redis"); + + let 
json = json!({ + "$distinct_id": "alakazam", + "token": team.api_token, + }); + let bytes = Bytes::from(json.to_string()); + + let flag_payload = FlagRequest::from_bytes(bytes).expect("failed to parse request"); + + match flag_payload + .extract_and_verify_token(redis_client, pg_client) + .await + { + Ok(extracted_token) => assert_eq!(extracted_token, team.api_token), + Err(e) => panic!("Failed to extract and verify token: {:?}", e), + }; + } + + #[tokio::test] + async fn test_get_team_from_cache_or_pg() { + let redis_client = setup_redis_client(None); + let pg_client = setup_pg_client(None).await; + let team = insert_new_team_in_redis(redis_client.clone()) + .await + .expect("Failed to insert new team in Redis"); + + let flag_request = FlagRequest { + token: Some(team.api_token.clone()), + ..Default::default() + }; + + // Test fetching from Redis + let result = flag_request + .get_team_from_cache_or_pg(&team.api_token, redis_client.clone(), pg_client.clone()) + .await; + assert!(result.is_ok()); + assert_eq!(result.unwrap().id, team.id); + + // Test fetching from PostgreSQL (simulate Redis miss) + // First, remove the team from Redis + redis_client + .del(format!("team:{}", team.api_token)) + .await + .expect("Failed to remove team from Redis"); + + let result = flag_request + .get_team_from_cache_or_pg(&team.api_token, redis_client.clone(), pg_client.clone()) + .await; + assert!(result.is_ok()); + assert_eq!(result.unwrap().id, team.id); + + // Verify that the team was re-added to Redis + let redis_team = Team::from_redis(redis_client.clone(), team.api_token.clone()).await; + assert!(redis_team.is_ok()); + } + + #[test] + fn test_extract_properties() { + let flag_request = FlagRequest { + person_properties: Some(HashMap::from([ + ("key1".to_string(), json!("value1")), + ("key2".to_string(), json!(42)), + ])), + ..Default::default() + }; + + let properties = flag_request.extract_properties(); + assert_eq!(properties.len(), 2); + 
assert_eq!(properties.get("key1").unwrap(), &json!("value1")); + assert_eq!(properties.get("key2").unwrap(), &json!(42)); + } + + #[tokio::test] + async fn test_get_flags_from_cache_or_pg() { + let redis_client = setup_redis_client(None); + let pg_client = setup_pg_client(None).await; + let team = insert_new_team_in_redis(redis_client.clone()) + .await + .expect("Failed to insert new team in Redis"); + + // Insert some mock flags into Redis + let mock_flags = FeatureFlagList { + flags: vec![ + FeatureFlag { + id: 1, + team_id: team.id, + name: Some("Beta Feature".to_string()), + key: "beta_feature".to_string(), + filters: FlagFilters { + groups: vec![FlagGroupType { + properties: Some(vec![PropertyFilter { + key: "country".to_string(), + value: json!("US"), + operator: Some(OperatorType::Exact), + prop_type: "person".to_string(), + group_type_index: None, + }]), + rollout_percentage: Some(50.0), + variant: None, + }], + multivariate: None, + aggregation_group_type_index: None, + payloads: None, + super_groups: None, + }, + deleted: false, + active: true, + ensure_experience_continuity: false, + }, + FeatureFlag { + id: 2, + team_id: team.id, + name: Some("New User Interface".to_string()), + key: "new_ui".to_string(), + filters: FlagFilters { + groups: vec![], + multivariate: None, + aggregation_group_type_index: None, + payloads: None, + super_groups: None, + }, + deleted: false, + active: false, + ensure_experience_continuity: false, + }, + FeatureFlag { + id: 3, + team_id: team.id, + name: Some("Premium Feature".to_string()), + key: "premium_feature".to_string(), + filters: FlagFilters { + groups: vec![FlagGroupType { + properties: Some(vec![PropertyFilter { + key: "is_premium".to_string(), + value: json!(true), + operator: Some(OperatorType::Exact), + prop_type: "person".to_string(), + group_type_index: None, + }]), + rollout_percentage: Some(100.0), + variant: None, + }], + multivariate: None, + aggregation_group_type_index: None, + payloads: None, + 
super_groups: None, + }, + deleted: false, + active: true, + ensure_experience_continuity: false, + }, + ], + }; + + FeatureFlagList::update_flags_in_redis(redis_client.clone(), team.id, &mock_flags) + .await + .expect("Failed to insert mock flags in Redis"); + + let flag_request = FlagRequest::default(); + + // Test fetching from Redis + let result = flag_request + .get_flags_from_cache_or_pg(team.id, redis_client.clone(), pg_client.clone()) + .await; + assert!(result.is_ok()); + let fetched_flags = result.unwrap(); + assert_eq!(fetched_flags.flags.len(), mock_flags.flags.len()); + + // Verify the contents of the fetched flags + let beta_feature = fetched_flags + .flags + .iter() + .find(|f| f.key == "beta_feature") + .unwrap(); + assert!(beta_feature.active); + assert_eq!( + beta_feature.filters.groups[0].rollout_percentage, + Some(50.0) + ); + assert_eq!( + beta_feature.filters.groups[0].properties.as_ref().unwrap()[0].key, + "country" + ); + + let new_ui = fetched_flags + .flags + .iter() + .find(|f| f.key == "new_ui") + .unwrap(); + assert!(!new_ui.active); + assert!(new_ui.filters.groups.is_empty()); + + let premium_feature = fetched_flags + .flags + .iter() + .find(|f| f.key == "premium_feature") + .unwrap(); + assert!(premium_feature.active); + assert_eq!( + premium_feature.filters.groups[0].rollout_percentage, + Some(100.0) + ); + assert_eq!( + premium_feature.filters.groups[0] + .properties + .as_ref() + .unwrap()[0] + .key, + "is_premium" + ); + + // Test fetching from PostgreSQL (simulate Redis miss) + // First, remove the flags from Redis + redis_client + .del(format!("{}:{}", TEAM_FLAGS_CACHE_PREFIX, team.id)) + .await + .expect("Failed to remove flags from Redis"); + + let result = flag_request + .get_flags_from_cache_or_pg(team.id, redis_client.clone(), pg_client.clone()) + .await; + assert!(result.is_ok()); + // Verify that the flags were re-added to Redis + let redis_flags = FeatureFlagList::from_redis(redis_client.clone(), team.id).await; + 
assert!(redis_flags.is_ok()); + assert_eq!(redis_flags.unwrap().flags.len(), mock_flags.flags.len()); + } + + #[tokio::test] + async fn test_error_cases() { + let redis_client = setup_redis_client(None); + let pg_client = setup_pg_client(None).await; + + // Test invalid token + let flag_request = FlagRequest { + token: Some("invalid_token".to_string()), + ..Default::default() + }; + let result = flag_request + .extract_and_verify_token(redis_client.clone(), pg_client.clone()) + .await; + assert!(matches!(result, Err(FlagError::TokenValidationError))); + + // Test missing distinct_id + let flag_request = FlagRequest { + token: Some("valid_token".to_string()), + distinct_id: None, + ..Default::default() + }; + let result = flag_request.extract_distinct_id(); + assert!(matches!(result, Err(FlagError::MissingDistinctId))); + } +} diff --git a/rust/feature-flags/src/geoip.rs b/rust/feature-flags/src/geoip.rs new file mode 100644 index 0000000000000..8ffbf0bfa34e4 --- /dev/null +++ b/rust/feature-flags/src/geoip.rs @@ -0,0 +1,304 @@ +use crate::config::Config; +use maxminddb::Reader; +use serde_json::Value; +use std::collections::HashMap; +use std::net::IpAddr; +use std::str::FromStr; +use thiserror::Error; +use tracing::log::{error, info}; + +#[derive(Error, Debug)] +pub enum GeoIpError { + #[error("Failed to open GeoIP database: {0}")] + DatabaseOpenError(#[from] maxminddb::MaxMindDBError), +} + +pub struct GeoIpClient { + reader: Reader>, +} + +impl GeoIpClient { + /// Creates a new GeoIpClient instance. + /// Returns an error if the database can't be loaded. + pub fn new(config: &Config) -> Result { + let geoip_path = config.get_maxmind_db_path(); + + info!("Attempting to open GeoIP database at: {:?}", geoip_path); + + let reader = Reader::open_readfile(&geoip_path)?; + info!("Successfully opened GeoIP database"); + + Ok(GeoIpClient { reader }) + } + + /// Checks if the given IP address is valid. 
+    fn is_valid_ip(&self, ip: &str) -> bool {
+        // NOTE(review): was `ip != "127.0.0.1" || ip != "::1"`, which is true for
+        // every input (one of the two inequalities always holds); loopback must be
+        // rejected, so the conjunction is required.
+        ip != "127.0.0.1" && ip != "::1"
+    }
+
+    /// Looks up the city data for the given IP address.
+    /// Returns None if the lookup fails.
+    fn lookup_city(&self, ip: &str, addr: IpAddr) -> Option<Value> {
+        match self.reader.lookup::<Value>(addr) {
+            Ok(city) => {
+                info!(
+                    "GeoIP lookup succeeded for IP {}: Full city data: {:?}",
+                    ip, city
+                );
+                Some(city)
+            }
+            Err(e) => {
+                error!("GeoIP lookup error for IP {}: {}", ip, e);
+                None
+            }
+        }
+    }
+
+    /// Returns a dictionary of geoip properties for the given ip address.
+    pub fn get_geoip_properties(&self, ip_address: Option<&str>) -> HashMap<String, String> {
+        match ip_address {
+            None => {
+                info!("No IP address provided; returning empty properties");
+                HashMap::new()
+            }
+            Some(ip) if !self.is_valid_ip(ip) => {
+                info!("Returning empty properties for IP: {}", ip);
+                HashMap::new()
+            }
+            Some(ip) => match IpAddr::from_str(ip) {
+                Ok(addr) => self
+                    .lookup_city(ip, addr)
+                    .map(|city| extract_properties(&city))
+                    .unwrap_or_default(),
+                Err(_) => {
+                    error!("Invalid IP address: {}", ip);
+                    HashMap::new()
+                }
+            },
+        }
+    }
+}
+
+const GEOIP_FIELDS: [(&str, &[&str]); 7] = [
+    ("$geoip_country_name", &["country", "names", "en"]),
+    ("$geoip_city_name", &["city", "names", "en"]),
+    ("$geoip_country_code", &["country", "iso_code"]),
+    ("$geoip_continent_name", &["continent", "names", "en"]),
+    ("$geoip_continent_code", &["continent", "code"]),
+    ("$geoip_postal_code", &["postal", "code"]),
+    ("$geoip_time_zone", &["location", "time_zone"]),
+];
+
+fn get_nested_value<'a>(data: &'a Value, path: &[&str]) -> Option<&'a str> {
+    let mut current = data;
+    for &key in path {
+        current = current.get(key)?;
+    }
+    current.as_str()
+}
+
+fn extract_properties(city: &Value) -> HashMap<String, String> {
+    GEOIP_FIELDS
+        .iter()
+        .filter_map(|&(field, path)| {
+            get_nested_value(city, path).map(|value| (field.to_string(), value.to_string()))
+        })
+        .collect()
+}
+
+#[cfg(test)]
+mod tests {
+    use serde_json::json;
+
+    use super::*;
+ use crate::config::Config; + use std::sync::Once; + + static INIT: Once = Once::new(); + + fn initialize() { + INIT.call_once(|| { + tracing_subscriber::fmt::init(); + }); + } + + fn create_test_service() -> GeoIpClient { + let config = Config::default_test_config(); + GeoIpClient::new(&config).expect("Failed to create GeoIpService") + } + + #[test] + fn test_geoip_service_creation() { + initialize(); + let config = Config::default_test_config(); + let service_result = GeoIpClient::new(&config); + assert!(service_result.is_ok()); + } + + #[test] + fn test_geoip_service_creation_failure() { + initialize(); + let mut config = Config::default_test_config(); + config.maxmind_db_path = "/path/to/nonexistent/file".to_string(); + let service_result = GeoIpClient::new(&config); + assert!(service_result.is_err()); + } + + #[test] + fn test_get_geoip_properties_none() { + initialize(); + let service = create_test_service(); + let result = service.get_geoip_properties(None); + assert!(result.is_empty()); + } + + #[test] + fn test_get_geoip_properties_localhost() { + initialize(); + let service = create_test_service(); + let result = service.get_geoip_properties(Some("127.0.0.1")); + assert!(result.is_empty()); + } + + #[test] + fn test_get_geoip_properties_invalid_ip() { + initialize(); + let service = create_test_service(); + let result = service.get_geoip_properties(Some("not_an_ip")); + assert!(result.is_empty()); + } + + #[test] + fn test_geoip_results() { + initialize(); + let service = create_test_service(); + let test_cases = vec![ + ("13.106.122.3", "Australia"), + ("31.28.64.3", "United Kingdom"), + ("2600:6c52:7a00:11c:1b6:b7b0:ea19:6365", "United States"), + ]; + + for (ip, expected_country) in test_cases { + let result = service.get_geoip_properties(Some(ip)); + info!("GeoIP lookup result for IP {}: {:?}", ip, result); + info!( + "Expected country: {}, Actual country: {:?}", + expected_country, + result.get("$geoip_country_name") + ); + assert_eq!( + 
result.get("$geoip_country_name"), + Some(&expected_country.to_string()) + ); + assert_eq!(result.len(), 7); + } + } + + #[test] + fn test_geoip_on_local_ip() { + initialize(); + let service = create_test_service(); + let result = service.get_geoip_properties(Some("127.0.0.1")); + assert!(result.is_empty()); + } + + #[test] + fn test_geoip_on_invalid_ip() { + initialize(); + let service = create_test_service(); + let result = service.get_geoip_properties(Some("999.999.999.999")); + assert!(result.is_empty()); + } + + #[test] + fn test_get_nested_value() { + let data = json!({ + "country": { + "names": { + "en": "United States" + } + }, + "city": { + "names": { + "en": "New York" + } + }, + "postal": { + "code": "10001" + } + }); + + assert_eq!( + get_nested_value(&data, &["country", "names", "en"]), + Some("United States") + ); + assert_eq!( + get_nested_value(&data, &["city", "names", "en"]), + Some("New York") + ); + assert_eq!(get_nested_value(&data, &["postal", "code"]), Some("10001")); + assert_eq!(get_nested_value(&data, &["country", "code"]), None); + assert_eq!(get_nested_value(&data, &["nonexistent", "path"]), None); + } + + #[test] + fn test_extract_properties() { + let city_data = json!({ + "country": { + "names": { + "en": "United States" + }, + "iso_code": "US" + }, + "city": { + "names": { + "en": "New York" + } + }, + "continent": { + "names": { + "en": "North America" + }, + "code": "NA" + }, + "postal": { + "code": "10001" + }, + "location": { + "time_zone": "America/New_York" + } + }); + + let properties = extract_properties(&city_data); + + assert_eq!( + properties.get("$geoip_country_name"), + Some(&"United States".to_string()) + ); + assert_eq!( + properties.get("$geoip_city_name"), + Some(&"New York".to_string()) + ); + assert_eq!( + properties.get("$geoip_country_code"), + Some(&"US".to_string()) + ); + assert_eq!( + properties.get("$geoip_continent_name"), + Some(&"North America".to_string()) + ); + assert_eq!( + 
properties.get("$geoip_continent_code"), + Some(&"NA".to_string()) + ); + assert_eq!( + properties.get("$geoip_postal_code"), + Some(&"10001".to_string()) + ); + assert_eq!( + properties.get("$geoip_time_zone"), + Some(&"America/New_York".to_string()) + ); + assert_eq!(properties.len(), 7); + } +} diff --git a/rust/feature-flags/src/lib.rs b/rust/feature-flags/src/lib.rs index 7784bd7bf1b8d..de5065723e45a 100644 --- a/rust/feature-flags/src/lib.rs +++ b/rust/feature-flags/src/lib.rs @@ -3,13 +3,15 @@ pub mod config; pub mod database; pub mod flag_definitions; pub mod flag_matching; +pub mod flag_request; +pub mod geoip; pub mod property_matching; pub mod redis; +pub mod request_handler; pub mod router; pub mod server; pub mod team; pub mod v0_endpoint; -pub mod v0_request; // Test modules don't need to be compiled with main binary // #[cfg(test)] diff --git a/rust/feature-flags/src/redis.rs b/rust/feature-flags/src/redis.rs index 89dde421d0abc..954ffe1a09f04 100644 --- a/rust/feature-flags/src/redis.rs +++ b/rust/feature-flags/src/redis.rs @@ -34,6 +34,7 @@ pub trait Client { async fn get(&self, k: String) -> Result; async fn set(&self, k: String, v: String) -> Result<()>; + async fn del(&self, k: String) -> Result<(), CustomRedisError>; } pub struct RedisClient { @@ -93,4 +94,14 @@ impl Client for RedisClient { Ok(fut?) 
}
+
+    async fn del(&self, k: String) -> Result<(), CustomRedisError> {
+        let mut conn = self.client.get_async_connection().await?;
+
+        let results = conn.del(k);
+        let fut: Result<(), RedisError> =
+            timeout(Duration::from_secs(REDIS_TIMEOUT_MILLISECS), results).await?;
+
+        fut.map_err(CustomRedisError::from)
+    }
 }
diff --git a/rust/feature-flags/src/request_handler.rs b/rust/feature-flags/src/request_handler.rs
new file mode 100644
index 0000000000000..35606727f3259
--- /dev/null
+++ b/rust/feature-flags/src/request_handler.rs
@@ -0,0 +1,375 @@
+use crate::{
+    api::{FlagError, FlagValue, FlagsResponse},
+    database::Client,
+    flag_definitions::FeatureFlagList,
+    flag_matching::FeatureFlagMatcher,
+    flag_request::FlagRequest,
+    geoip::GeoIpClient,
+    router,
+};
+use axum::{extract::State, http::HeaderMap};
+use bytes::Bytes;
+use serde::Deserialize;
+use serde_json::Value;
+use std::sync::Arc;
+use std::{collections::HashMap, net::IpAddr};
+use tracing::error;
+
+#[derive(Deserialize, Default)]
+pub enum Compression {
+    #[default]
+    Unsupported,
+    #[serde(rename = "gzip", alias = "gzip-js")]
+    Gzip,
+}
+
+impl Compression {
+    pub fn as_str(&self) -> &'static str {
+        match self {
+            Compression::Gzip => "gzip",
+            Compression::Unsupported => "unsupported",
+        }
+    }
+}
+
+#[derive(Deserialize, Default)]
+pub struct FlagsQueryParams {
+    #[serde(alias = "v")]
+    pub version: Option<String>,
+
+    pub compression: Option<Compression>,
+
+    #[serde(alias = "ver")]
+    pub lib_version: Option<String>,
+
+    #[serde(alias = "_")]
+    pub sent_at: Option<i64>,
+}
+
+pub struct RequestContext {
+    pub state: State<router::State>,
+    pub ip: IpAddr,
+    pub meta: FlagsQueryParams,
+    pub headers: HeaderMap,
+    pub body: Bytes,
+}
+
+pub async fn process_request(context: RequestContext) -> Result<FlagsResponse, FlagError> {
+    let RequestContext {
+        state,
+        ip,
+        meta: _, // TODO use this
+        headers,
+        body,
+    } = context;
+
+    let request = decode_request(&headers, body)?;
+    let token = request
+        .extract_and_verify_token(state.redis.clone(), state.postgres.clone())
.await?;
+    let team = request
+        .get_team_from_cache_or_pg(&token, state.redis.clone(), state.postgres.clone())
+        .await?;
+    let distinct_id = request.extract_distinct_id()?;
+    let person_property_overrides = get_person_property_overrides(
+        !request.geoip_disable.unwrap_or(false),
+        request.person_properties.clone(),
+        &ip,
+        &state.geoip.clone(),
+    );
+    // TODO group_property_overrides
+
+    let feature_flags_from_cache_or_pg = request
+        .get_flags_from_cache_or_pg(team.id, state.redis.clone(), state.postgres.clone())
+        .await?;
+
+    let flags_response = evaluate_feature_flags(
+        distinct_id,
+        feature_flags_from_cache_or_pg,
+        Some(state.postgres.clone()),
+        person_property_overrides,
+        // group_property_overrides,
+    )
+    .await;
+
+    Ok(flags_response)
+}
+
+/// Get person property overrides based on the request
+/// - If geoip is enabled, fetch geoip properties and merge them with any person properties
+/// - If geoip is disabled, return the person properties as is
+/// - If no person properties are provided, return None
+pub fn get_person_property_overrides(
+    geoip_enabled: bool,
+    person_properties: Option<HashMap<String, Value>>,
+    ip: &IpAddr,
+    geoip_service: &GeoIpClient,
+) -> Option<HashMap<String, Value>> {
+    match (geoip_enabled, person_properties) {
+        (true, Some(mut props)) => {
+            let geoip_props = geoip_service.get_geoip_properties(Some(&ip.to_string()));
+            if !geoip_props.is_empty() {
+                props.extend(geoip_props.into_iter().map(|(k, v)| (k, Value::String(v))));
+            }
+            Some(props)
+        }
+        (true, None) => {
+            let geoip_props = geoip_service.get_geoip_properties(Some(&ip.to_string()));
+            if !geoip_props.is_empty() {
+                Some(
+                    geoip_props
+                        .into_iter()
+                        .map(|(k, v)| (k, Value::String(v)))
+                        .collect(),
+                )
+            } else {
+                None
+            }
+        }
+        (false, Some(props)) => Some(props),
+        (false, None) => None,
+    }
+}
+
+/// Decode a request into a `FlagRequest`
+/// - Currently only supports JSON requests
+// TODO support all supported content types
+fn decode_request(headers: &HeaderMap, body: Bytes) -> Result<FlagRequest, FlagError> {
match headers
+        .get("content-type")
+        .map_or("", |v| v.to_str().unwrap_or(""))
+    {
+        "application/json" => FlagRequest::from_bytes(body),
+        ct => Err(FlagError::RequestDecodingError(format!(
+            "unsupported content type: {}",
+            ct
+        ))),
+    }
+}
+
+/// Evaluate feature flags for a given distinct_id
+/// Returns a map of feature flag keys to their values
+/// If an error occurs while evaluating a flag, it will be logged and the flag will be omitted from the result
+pub async fn evaluate_feature_flags(
+    distinct_id: String,
+    feature_flags_from_cache_or_pg: FeatureFlagList,
+    database_client: Option<Arc<dyn Client + Send + Sync>>,
+    person_property_overrides: Option<HashMap<String, Value>>,
+    // group_property_overrides: Option<HashMap<String, HashMap<String, Value>>>,
+) -> FlagsResponse {
+    let mut matcher = FeatureFlagMatcher::new(
+        distinct_id.clone(),
+        database_client,
+        person_property_overrides,
+        // group_property_overrides,
+    );
+    let mut feature_flags = HashMap::new();
+    let mut error_while_computing_flags = false;
+    let feature_flag_list = feature_flags_from_cache_or_pg.flags;
+
+    for flag in feature_flag_list {
+        if !flag.active || flag.deleted {
+            continue;
+        }
+
+        match matcher.get_match(&flag).await {
+            Ok(flag_match) => {
+                let flag_value = if flag_match.matches {
+                    match flag_match.variant {
+                        Some(variant) => FlagValue::String(variant),
+                        None => FlagValue::Boolean(true),
+                    }
+                } else {
+                    FlagValue::Boolean(false)
+                };
+                feature_flags.insert(flag.key.clone(), flag_value);
+            }
+            Err(e) => {
+                error_while_computing_flags = true;
+                error!(
+                    "Error evaluating feature flag '{}' for distinct_id '{}': {:?}",
+                    flag.key, distinct_id, e
+                );
+            }
+        }
+    }
+
+    FlagsResponse {
+        error_while_computing_flags,
+        feature_flags,
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use crate::{
+        config::Config,
+        flag_definitions::{FeatureFlag, FlagFilters, FlagGroupType, OperatorType, PropertyFilter},
+        test_utils::setup_pg_client,
+    };
+
+    use super::*;
+    use axum::http::HeaderMap;
+    use serde_json::json;
+    use std::net::Ipv4Addr;
+
+    fn create_test_geoip_service()
-> GeoIpClient { + let config = Config::default_test_config(); + GeoIpClient::new(&config).expect("Failed to create GeoIpService for testing") + } + + #[test] + fn test_geoip_enabled_with_person_properties() { + let geoip_service = create_test_geoip_service(); + + let mut person_props = HashMap::new(); + person_props.insert("name".to_string(), Value::String("John".to_string())); + + let result = get_person_property_overrides( + true, + Some(person_props), + &IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8)), // Google's public DNS, should be in the US + &geoip_service, + ); + + assert!(result.is_some()); + let result = result.unwrap(); + assert!(result.len() > 1); + assert_eq!(result.get("name"), Some(&Value::String("John".to_string()))); + assert!(result.contains_key("$geoip_country_name")); + } + + #[test] + fn test_geoip_enabled_without_person_properties() { + let geoip_service = create_test_geoip_service(); + + let result = get_person_property_overrides( + true, + None, + &IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8)), // Google's public DNS, should be in the US + &geoip_service, + ); + + assert!(result.is_some()); + let result = result.unwrap(); + assert!(!result.is_empty()); + assert!(result.contains_key("$geoip_country_name")); + } + + #[test] + fn test_geoip_disabled_with_person_properties() { + let geoip_service = create_test_geoip_service(); + + let mut person_props = HashMap::new(); + person_props.insert("name".to_string(), Value::String("John".to_string())); + + let result = get_person_property_overrides( + false, + Some(person_props), + &IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8)), + &geoip_service, + ); + + assert!(result.is_some()); + let result = result.unwrap(); + assert_eq!(result.len(), 1); + assert_eq!(result.get("name"), Some(&Value::String("John".to_string()))); + } + + #[test] + fn test_geoip_disabled_without_person_properties() { + let geoip_service = create_test_geoip_service(); + + let result = get_person_property_overrides( + false, + None, + 
&IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8)), + &geoip_service, + ); + + assert!(result.is_none()); + } + + #[test] + fn test_geoip_enabled_local_ip() { + let geoip_service = create_test_geoip_service(); + + let result = get_person_property_overrides( + true, + None, + &IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), + &geoip_service, + ); + + assert!(result.is_none()); + } + + #[tokio::test] + async fn test_evaluate_feature_flags() { + let pg_client = setup_pg_client(None).await; + let flag = FeatureFlag { + name: Some("Test Flag".to_string()), + id: 1, + key: "test_flag".to_string(), + active: true, + deleted: false, + team_id: 1, + filters: FlagFilters { + groups: vec![FlagGroupType { + properties: Some(vec![PropertyFilter { + key: "country".to_string(), + value: json!("US"), + operator: Some(OperatorType::Exact), + prop_type: "person".to_string(), + group_type_index: None, + }]), + rollout_percentage: Some(100.0), // Set to 100% to ensure it's always on + variant: None, + }], + multivariate: None, + aggregation_group_type_index: None, + payloads: None, + super_groups: None, + }, + ensure_experience_continuity: false, + }; + + let feature_flag_list = FeatureFlagList { flags: vec![flag] }; + + let mut person_properties = HashMap::new(); + person_properties.insert("country".to_string(), json!("US")); + + let result = evaluate_feature_flags( + "user123".to_string(), + feature_flag_list, + Some(pg_client), + Some(person_properties), + ) + .await; + + assert!(!result.error_while_computing_flags); + assert!(result.feature_flags.contains_key("test_flag")); + assert_eq!(result.feature_flags["test_flag"], FlagValue::Boolean(true)); + } + + #[test] + fn test_decode_request() { + let mut headers = HeaderMap::new(); + headers.insert("content-type", "application/json".parse().unwrap()); + + let body = Bytes::from(r#"{"token": "test_token", "distinct_id": "user123"}"#); + + let result = decode_request(&headers, body); + + assert!(result.is_ok()); + let request = result.unwrap(); + 
assert_eq!(request.token, Some("test_token".to_string())); + assert_eq!(request.distinct_id, Some("user123".to_string())); + } + + #[test] + fn test_compression_as_str() { + assert_eq!(Compression::Gzip.as_str(), "gzip"); + assert_eq!(Compression::Unsupported.as_str(), "unsupported"); + } +} diff --git a/rust/feature-flags/src/router.rs b/rust/feature-flags/src/router.rs index 2fbc87c870930..1a32e0837cede 100644 --- a/rust/feature-flags/src/router.rs +++ b/rust/feature-flags/src/router.rs @@ -2,21 +2,29 @@ use std::sync::Arc; use axum::{routing::post, Router}; -use crate::{database::Client as DatabaseClient, redis::Client as RedisClient, v0_endpoint}; +use crate::{ + database::Client as DatabaseClient, geoip::GeoIpClient, redis::Client as RedisClient, + v0_endpoint, +}; #[derive(Clone)] pub struct State { pub redis: Arc, // TODO: Add pgClient when ready pub postgres: Arc, + pub geoip: Arc, } -pub fn router(redis: Arc, postgres: Arc) -> Router +pub fn router(redis: Arc, postgres: Arc, geoip: Arc) -> Router where R: RedisClient + Send + Sync + 'static, D: DatabaseClient + Send + Sync + 'static, { - let state = State { redis, postgres }; + let state = State { + redis, + postgres, + geoip, + }; Router::new() .route("/flags", post(v0_endpoint::flags).get(v0_endpoint::flags)) diff --git a/rust/feature-flags/src/server.rs b/rust/feature-flags/src/server.rs index 37bd721a9a51f..c718657e3af66 100644 --- a/rust/feature-flags/src/server.rs +++ b/rust/feature-flags/src/server.rs @@ -6,6 +6,7 @@ use tokio::net::TcpListener; use crate::config::Config; use crate::database::PgClient; +use crate::geoip::GeoIpClient; use crate::redis::RedisClient; use crate::router; @@ -29,8 +30,16 @@ where } }; + let geoip_service = match GeoIpClient::new(&config) { + Ok(service) => Arc::new(service), + Err(e) => { + tracing::error!("Failed to create GeoIP service: {}", e); + return; + } + }; + // You can decide which client to pass to the router, or pass both if needed - let app = 
router::router(redis_client, read_postgres_client); + let app = router::router(redis_client, read_postgres_client, geoip_service); tracing::info!("listening on {:?}", listener.local_addr().unwrap()); axum::serve( diff --git a/rust/feature-flags/src/v0_endpoint.rs b/rust/feature-flags/src/v0_endpoint.rs index 95d4c3a813685..56734eae32d45 100644 --- a/rust/feature-flags/src/v0_endpoint.rs +++ b/rust/feature-flags/src/v0_endpoint.rs @@ -1,24 +1,17 @@ -use std::collections::HashMap; -use std::sync::Arc; +use std::net::IpAddr; -use axum::{debug_handler, Json}; -use bytes::Bytes; -// TODO: stream this instead -use axum::extract::{MatchedPath, Query, State}; -use axum::http::{HeaderMap, Method}; -use axum_client_ip::InsecureClientIp; -use tracing::{error, instrument, warn}; - -use crate::api::FlagValue; -use crate::database::Client; -use crate::flag_definitions::FeatureFlagList; -use crate::flag_matching::FeatureFlagMatcher; -use crate::v0_request::Compression; use crate::{ api::{FlagError, FlagsResponse}, + request_handler::{process_request, FlagsQueryParams, RequestContext}, router, - v0_request::{FlagRequest, FlagsQueryParams}, }; +// TODO: stream this instead +use axum::extract::{MatchedPath, Query, State}; +use axum::http::{HeaderMap, Method}; +use axum::{debug_handler, Json}; +use axum_client_ip::InsecureClientIp; +use bytes::Bytes; +use tracing::instrument; /// Feature flag evaluation endpoint. /// Only supports a specific shape of data, and rejects any malformed data. 
@@ -47,153 +40,50 @@ pub async fn flags( path: MatchedPath, body: Bytes, ) -> Result, FlagError> { - // TODO this could be extracted into some load_data_for_request type thing + record_request_metadata(&headers, &method, &path, &ip, &meta); + + let context = RequestContext { + state, + ip, + meta: meta.0, + headers, + body, + }; + + Ok(Json(process_request(context).await?)) +} + +fn record_request_metadata( + headers: &HeaderMap, + method: &Method, + path: &MatchedPath, + ip: &IpAddr, + meta: &Query, +) { let user_agent = headers .get("user-agent") .map_or("unknown", |v| v.to_str().unwrap_or("unknown")); let content_encoding = headers .get("content-encoding") .map_or("unknown", |v| v.to_str().unwrap_or("unknown")); - let comp = match meta.compression { - None => String::from("unknown"), - Some(Compression::Gzip) => String::from("gzip"), - // Some(Compression::Base64) => String::from("base64"), - Some(Compression::Unsupported) => String::from("unsupported"), - }; - // TODO what do we use this for? 
- let sent_at = meta.sent_at.unwrap_or(0); + let content_type = headers + .get("content-type") + .map_or("unknown", |v| v.to_str().unwrap_or("unknown")); tracing::Span::current().record("user_agent", user_agent); tracing::Span::current().record("content_encoding", content_encoding); - tracing::Span::current().record("version", meta.version.clone()); - tracing::Span::current().record("lib_version", meta.lib_version.clone()); - tracing::Span::current().record("compression", comp.as_str()); + tracing::Span::current().record("content_type", content_type); + tracing::Span::current().record("version", meta.version.as_deref().unwrap_or("unknown")); + tracing::Span::current().record( + "lib_version", + meta.lib_version.as_deref().unwrap_or("unknown"), + ); + tracing::Span::current().record( + "compression", + meta.compression.as_ref().map_or("none", |c| c.as_str()), + ); tracing::Span::current().record("method", method.as_str()); tracing::Span::current().record("path", path.as_str().trim_end_matches('/')); tracing::Span::current().record("ip", ip.to_string()); - tracing::Span::current().record("sent_at", sent_at.to_string()); - - tracing::debug!("request headers: {:?}", headers); - - // TODO handle different content types and encodings - let request = match headers - .get("content-type") - .map_or("", |v| v.to_str().unwrap_or("")) - { - "application/json" => { - tracing::Span::current().record("content_type", "application/json"); - FlagRequest::from_bytes(body) - } - // TODO support other content types - ct => { - return Err(FlagError::RequestDecodingError(format!( - "unsupported content type: {}", - ct - ))); - } - }?; - - // this errors up top-level if there's no token - // return the team here, too? - let token = request - .extract_and_verify_token(state.redis.clone(), state.postgres.clone()) - .await?; - - // at this point, we should get the team since I need the team values for options on the payload - // Note that the team here is different than the redis team. 
- // TODO: consider making team an option, since we could fetch a project instead of a team - // if the token is valid by the team doesn't exist for some reason. Like, there might be a case - // where the token exists in the database but the team has been deleted. - // That said, though, I don't think this is necessary because we're already validating that the token exists - // in the database, so if it doesn't exist, we should be returning an error there. - // TODO make that one request; we already extract the token by accessing the team table, so we can just extract the team here - let team = request - .get_team_from_cache_or_pg(&token, state.redis.clone(), state.postgres.clone()) - .await?; - - // this errors up top-level if there's no distinct_id or missing one - let distinct_id = request.extract_distinct_id()?; - - // TODO handle disabled flags, should probably do that right at the beginning - - tracing::Span::current().record("token", &token); - tracing::Span::current().record("distinct_id", &distinct_id); - - tracing::debug!("request: {:?}", request); - - // now that I have a team ID and a distinct ID, I can evaluate the feature flags - - // first, get the flags - let all_feature_flags = request - .get_flags_from_cache_or_pg(team.id, state.redis.clone(), state.postgres.clone()) - .await?; - - tracing::Span::current().record("flags", &format!("{:?}", all_feature_flags)); - - // debug log, I'm keeping it around bc it's useful - // tracing::debug!( - // "flags: {}", - // serde_json::to_string_pretty(&all_feature_flags) - // .unwrap_or_else(|_| format!("{:?}", all_feature_flags)) - // ); - - let flags_response = - evaluate_feature_flags(distinct_id, all_feature_flags, Some(state.postgres.clone())).await; - - Ok(Json(flags_response)) - - // TODO need to handle experience continuity here -} - -pub async fn evaluate_feature_flags( - distinct_id: String, - feature_flag_list: FeatureFlagList, - database_client: Option>, -) -> FlagsResponse { - let mut matcher = 
FeatureFlagMatcher::new(distinct_id.clone(), database_client); - let mut feature_flags = HashMap::new(); - let mut error_while_computing_flags = false; - let all_feature_flags = feature_flag_list.flags; - - for flag in all_feature_flags { - if !flag.active || flag.deleted { - continue; - } - - let flag_match = matcher.get_match(&flag).await; - - let flag_value = if flag_match.matches { - match flag_match.variant { - Some(variant) => FlagValue::String(variant), - None => FlagValue::Boolean(true), - } - } else { - FlagValue::Boolean(false) - }; - - feature_flags.insert(flag.key.clone(), flag_value); - - if let Err(e) = matcher - .get_person_properties(flag.team_id, distinct_id.clone()) - .await - { - error_while_computing_flags = true; - error!( - "Error fetching properties for feature flag '{}' and distinct_id '{}': {:?}", - flag.key, distinct_id, e - ); - } - } - - if error_while_computing_flags { - warn!( - "Errors occurred while computing feature flags for distinct_id '{}'", - distinct_id - ); - } - - FlagsResponse { - error_while_computing_flags, - feature_flags, - } + tracing::Span::current().record("sent_at", &meta.sent_at.unwrap_or(0).to_string()); } diff --git a/rust/feature-flags/src/v0_request.rs b/rust/feature-flags/src/v0_request.rs deleted file mode 100644 index 3f28d6da85f04..0000000000000 --- a/rust/feature-flags/src/v0_request.rs +++ /dev/null @@ -1,244 +0,0 @@ -use std::{collections::HashMap, sync::Arc}; - -use bytes::Bytes; -use serde::{Deserialize, Serialize}; -use serde_json::Value; -use tracing::instrument; - -use crate::{ - api::FlagError, database::Client as DatabaseClient, flag_definitions::FeatureFlagList, - redis::Client as RedisClient, team::Team, -}; - -#[derive(Deserialize, Default)] -pub enum Compression { - #[default] - Unsupported, - - #[serde(rename = "gzip", alias = "gzip-js")] - Gzip, - // TODO do we want to support this at all? 
It's not in the spec - // #[serde(rename = "base64")] - // Base64, -} - -#[derive(Deserialize, Default)] -pub struct FlagsQueryParams { - #[serde(alias = "v")] - pub version: Option, - - pub compression: Option, - - #[serde(alias = "ver")] - pub lib_version: Option, - - #[serde(alias = "_")] - pub sent_at: Option, -} - -#[derive(Default, Debug, Deserialize, Serialize)] -pub struct FlagRequest { - #[serde( - alias = "$token", - alias = "api_key", - skip_serializing_if = "Option::is_none" - )] - pub token: Option, - #[serde(alias = "$distinct_id", skip_serializing_if = "Option::is_none")] - pub distinct_id: Option, - pub geoip_disable: Option, - #[serde(default)] - pub person_properties: Option>, - #[serde(default)] - pub groups: Option>, - // TODO: better type this since we know its going to be a nested json - #[serde(default)] - pub group_properties: Option>, - #[serde(alias = "$anon_distinct_id", skip_serializing_if = "Option::is_none")] - pub anon_distinct_id: Option, -} - -impl FlagRequest { - /// Takes a request payload and tries to read it. - /// Only supports base64 encoded payloads or uncompressed utf-8 as json. - #[instrument(skip_all)] - pub fn from_bytes(bytes: Bytes) -> Result { - tracing::debug!(len = bytes.len(), "decoding new request"); - // TODO: Add base64 decoding - let payload = String::from_utf8(bytes.into()).map_err(|e| { - tracing::error!("failed to decode body: {}", e); - FlagError::RequestDecodingError(String::from("invalid body encoding")) - })?; - - tracing::debug!(json = payload, "decoded event data"); - Ok(serde_json::from_str::(&payload)?) - } - - pub async fn extract_and_verify_token( - &self, - redis_client: Arc, - pg_client: Arc, - ) -> Result { - let token = match self { - FlagRequest { - token: Some(token), .. 
- } => token.to_string(), - _ => return Err(FlagError::NoTokenError), - }; - - match Team::from_redis(redis_client.clone(), token.clone()).await { - Ok(_) => Ok(token), - Err(_) => { - // Fallback: Check PostgreSQL if not found in Redis - match Team::from_pg(pg_client, token.clone()).await { - Ok(team) => { - // Token found in PostgreSQL, update Redis cache so that we can verify it from Redis next time - if let Err(e) = Team::update_redis_cache(redis_client, &team).await { - tracing::warn!("Failed to update Redis cache: {}", e); - } - Ok(token) - } - Err(_) => Err(FlagError::TokenValidationError), - } - } - } - } - - pub async fn get_team_from_cache_or_pg( - &self, - token: &str, - redis_client: Arc, - pg_client: Arc, - ) -> Result { - match Team::from_redis(redis_client.clone(), token.to_owned()).await { - Ok(team) => Ok(team), - Err(_) => match Team::from_pg(pg_client, token.to_owned()).await { - Ok(team) => { - // If we have the team in postgres, but not redis, update redis so we're faster next time - if let Err(e) = Team::update_redis_cache(redis_client, &team).await { - tracing::warn!("Failed to update Redis cache: {}", e); - } - Ok(team) - } - Err(_) => Err(FlagError::TokenValidationError), - }, - } - } - - pub fn extract_distinct_id(&self) -> Result { - let distinct_id = match &self.distinct_id { - None => return Err(FlagError::MissingDistinctId), - Some(id) => id, - }; - - match distinct_id.len() { - 0 => Err(FlagError::EmptyDistinctId), - 1..=200 => Ok(distinct_id.to_owned()), - _ => Ok(distinct_id.chars().take(200).collect()), - } - } - - pub async fn get_flags_from_cache_or_pg( - &self, - team_id: i32, - redis_client: Arc, - pg_client: Arc, - ) -> Result { - match FeatureFlagList::from_redis(redis_client.clone(), team_id).await { - Ok(flags) => Ok(flags), - Err(_) => match FeatureFlagList::from_pg(pg_client, team_id).await { - Ok(flags) => { - // If we have the flags in postgres, but not redis, update redis so we're faster next time - // TODO: we have 
some counters in django for tracking these cache misses - // we should probably do the same here - if let Err(e) = - FeatureFlagList::update_flags_in_redis(redis_client, team_id, &flags).await - { - tracing::warn!("Failed to update Redis cache: {}", e); - } - Ok(flags) - } - Err(_) => Err(FlagError::TokenValidationError), - }, - } - } -} - -#[cfg(test)] -mod tests { - use crate::api::FlagError; - use crate::test_utils::{insert_new_team_in_redis, setup_pg_client, setup_redis_client}; - use crate::v0_request::FlagRequest; - use bytes::Bytes; - use serde_json::json; - - #[test] - fn empty_distinct_id_not_accepted() { - let json = json!({ - "distinct_id": "", - "token": "my_token1", - }); - let bytes = Bytes::from(json.to_string()); - - let flag_payload = FlagRequest::from_bytes(bytes).expect("failed to parse request"); - - match flag_payload.extract_distinct_id() { - Err(FlagError::EmptyDistinctId) => (), - _ => panic!("expected empty distinct id error"), - }; - } - - #[test] - fn too_large_distinct_id_is_truncated() { - let json = json!({ - "distinct_id": std::iter::repeat("a").take(210).collect::(), - "token": "my_token1", - }); - let bytes = Bytes::from(json.to_string()); - - let flag_payload = FlagRequest::from_bytes(bytes).expect("failed to parse request"); - - assert_eq!(flag_payload.extract_distinct_id().unwrap().len(), 200); - } - - #[test] - fn distinct_id_is_returned_correctly() { - let json = json!({ - "$distinct_id": "alakazam", - "token": "my_token1", - }); - let bytes = Bytes::from(json.to_string()); - - let flag_payload = FlagRequest::from_bytes(bytes).expect("failed to parse request"); - - match flag_payload.extract_distinct_id() { - Ok(id) => assert_eq!(id, "alakazam"), - _ => panic!("expected distinct id"), - }; - } - - #[tokio::test] - async fn token_is_returned_correctly() { - let redis_client = setup_redis_client(None); - let pg_client = setup_pg_client(None).await; - let team = insert_new_team_in_redis(redis_client.clone()) - .await - 
.expect("Failed to insert new team in Redis"); - - let json = json!({ - "$distinct_id": "alakazam", - "token": team.api_token, - }); - let bytes = Bytes::from(json.to_string()); - - let flag_payload = FlagRequest::from_bytes(bytes).expect("failed to parse request"); - - match flag_payload - .extract_and_verify_token(redis_client, pg_client) - .await - { - Ok(extracted_token) => assert_eq!(extracted_token, team.api_token), - Err(e) => panic!("Failed to extract and verify token: {:?}", e), - }; - } -} diff --git a/rust/feature-flags/tests/test_flag_matching_consistency.rs b/rust/feature-flags/tests/test_flag_matching_consistency.rs index d4b55ed4e9001..2a4972962019c 100644 --- a/rust/feature-flags/tests/test_flag_matching_consistency.rs +++ b/rust/feature-flags/tests/test_flag_matching_consistency.rs @@ -107,9 +107,10 @@ async fn it_is_consistent_with_rollout_calculation_for_simple_flags() { for i in 0..1000 { let distinct_id = format!("distinct_id_{}", i); - let feature_flag_match = FeatureFlagMatcher::new(distinct_id, None) + let feature_flag_match = FeatureFlagMatcher::new(distinct_id, None, None) .get_match(&flags[0]) - .await; + .await + .unwrap(); if results[i] { assert_eq!( @@ -1188,9 +1189,10 @@ async fn it_is_consistent_with_rollout_calculation_for_multivariate_flags() { for i in 0..1000 { let distinct_id = format!("distinct_id_{}", i); - let feature_flag_match = FeatureFlagMatcher::new(distinct_id, None) + let feature_flag_match = FeatureFlagMatcher::new(distinct_id, None, None) .get_match(&flags[0]) - .await; + .await + .unwrap(); if results[i].is_some() { assert_eq!(