From 211791658e6b2f5b34e80bfc7cb3b71d031a13e4 Mon Sep 17 00:00:00 2001 From: Dylan Martin Date: Wed, 2 Oct 2024 04:14:05 -0400 Subject: [PATCH] feat(flags): add experience continuity (#25245) Co-authored-by: github-actions <41898282+github-actions[bot]@users.noreply.github.com> --- rust/Cargo.lock | 79 + rust/feature-flags/Cargo.toml | 1 + rust/feature-flags/src/api.rs | 22 +- rust/feature-flags/src/database.rs | 4 +- rust/feature-flags/src/flag_definitions.rs | 139 +- rust/feature-flags/src/flag_matching.rs | 1292 ++++++++++++++--- rust/feature-flags/src/flag_request.rs | 11 +- rust/feature-flags/src/request_handler.rs | 180 ++- rust/feature-flags/src/router.rs | 46 +- rust/feature-flags/src/server.rs | 22 +- rust/feature-flags/src/team.rs | 14 +- rust/feature-flags/src/test_utils.rs | 21 +- .../tests/test_flag_matching_consistency.rs | 44 +- rust/feature-flags/tests/test_flags.rs | 268 +++- 14 files changed, 1745 insertions(+), 398 deletions(-) diff --git a/rust/Cargo.lock b/rust/Cargo.lock index 00996eb404053..a3126214c319a 100644 --- a/rust/Cargo.lock +++ b/rust/Cargo.lock @@ -957,6 +957,41 @@ dependencies = [ "uuid", ] +[[package]] +name = "darling" +version = "0.20.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6f63b86c8a8826a49b8c21f08a2d07338eec8d900540f8630dc76284be802989" +dependencies = [ + "darling_core", + "darling_macro", +] + +[[package]] +name = "darling_core" +version = "0.20.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "95133861a8032aaea082871032f5815eb9e98cef03fa916ab4500513994df9e5" +dependencies = [ + "fnv", + "ident_case", + "proc-macro2", + "quote", + "strsim", + "syn 2.0.48", +] + +[[package]] +name = "darling_macro" +version = "0.20.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d336a2a514f6ccccaa3e09b02d41d35330c07ddf03a62165fcec10bb561c7806" +dependencies = [ + "darling_core", + "quote", + "syn 2.0.48", +] + [[package]] name = "dashmap" version = "5.5.3" @@ -991,6 +1026,37 @@ dependencies = [ "serde", ] +[[package]] +name = "derive_builder" +version = "0.20.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cd33f37ee6a119146a1781d3356a7c26028f83d779b2e04ecd45fdc75c76877b" +dependencies = [ + "derive_builder_macro", +] + +[[package]] +name = "derive_builder_core" +version = "0.20.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7431fa049613920234f22c47fdc33e6cf3ee83067091ea4277a3f8c4587aae38" +dependencies = [ + "darling", + "proc-macro2", + "quote", + "syn 2.0.48", +] + +[[package]] +name = "derive_builder_macro" +version = "0.20.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4abae7035bf79b9877b779505d8cf3749285b80c43941eda66604841889451dc" +dependencies = [ + "derive_builder_core", + "syn 2.0.48", +] + [[package]] name = "digest" version = "0.10.7" @@ -1201,6 +1267,7 @@ dependencies = [ "bytes", "common-alloc", "common-metrics", + "derive_builder", "envconfig", "flate2", "futures", @@ -1959,6 +2026,12 @@ dependencies = [ "cc", ] +[[package]] +name = "ident_case" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39" + [[package]] name = "idna" version = "0.5.0" @@ -3939,6 +4012,12 @@ dependencies = [ "unicode-normalization", ] +[[package]] +name = "strsim" +version = "0.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f" + [[package]] name = "strum" version = "0.26.3" diff --git a/rust/feature-flags/Cargo.toml b/rust/feature-flags/Cargo.toml index 0cf96dfc6756b..4cf4016767be6 100644 --- a/rust/feature-flags/Cargo.toml +++ b/rust/feature-flags/Cargo.toml @@ -38,6 +38,7 @@ strum = { version = "0.26", features = ["derive"] } health = { path = "../common/health" } common-metrics = { path = "../common/metrics" } tower = { workspace = true } +derive_builder = "0.20.1" [lints] workspace = true diff --git a/rust/feature-flags/src/api.rs b/rust/feature-flags/src/api.rs index 30ddf27809dec..4430476d28a52 100644 --- a/rust/feature-flags/src/api.rs +++ b/rust/feature-flags/src/api.rs @@ -86,6 +86,8 @@ pub enum FlagError { NoTokenError, #[error("API key is not valid")] TokenValidationError, + #[error("Row not found in postgres")] + RowNotFound, #[error("failed to parse redis cache data")] DataParsingError, #[error("failed to update redis cache")] @@ -94,6 +96,8 @@ pub enum FlagError { RedisUnavailable, #[error("database unavailable")] DatabaseUnavailable, + #[error("Database error: {0}")] + DatabaseError(String), #[error("Timed out while fetching data")] TimeoutError, #[error("No group type mappings")] @@ -162,6 +166,13 @@ impl IntoResponse for FlagError { "Our database service is currently unavailable. This is likely a temporary issue. Please try again later.".to_string(), ) } + FlagError::DatabaseError(msg) => { + tracing::error!("Database error: {}", msg); + ( + StatusCode::INTERNAL_SERVER_ERROR, + "A database error occurred. Please try again later or contact support if the problem persists.".to_string(), + ) + } FlagError::TimeoutError => { tracing::error!("Timeout error: {:?}", self); ( @@ -176,6 +187,13 @@ impl IntoResponse for FlagError { "No group type mappings found. This is likely a configuration issue. Please contact support.".to_string(), ) } + FlagError::RowNotFound => { + tracing::error!("Row not found in postgres: {:?}", self); + ( + StatusCode::INTERNAL_SERVER_ERROR, + "The requested row was not found in the database. Please try again later or contact support if the problem persists.".to_string(), + ) + } } .into_response() } @@ -216,8 +234,8 @@ impl From for FlagError { tracing::error!("sqlx error: {}", e); println!("sqlx error: {}", e); match e { - sqlx::Error::RowNotFound => FlagError::TokenValidationError, - _ => FlagError::DatabaseUnavailable, + sqlx::Error::RowNotFound => FlagError::RowNotFound, + _ => FlagError::DatabaseError(e.to_string()), } } } diff --git a/rust/feature-flags/src/database.rs b/rust/feature-flags/src/database.rs index c7a45aeffc2d1..c340b61774a7f 100644 --- a/rust/feature-flags/src/database.rs +++ b/rust/feature-flags/src/database.rs @@ -4,8 +4,8 @@ use anyhow::Result; use async_trait::async_trait; use sqlx::{ pool::PoolConnection, - postgres::{PgPoolOptions, PgRow}, - PgPool, Postgres, + postgres::{PgPool, PgPoolOptions, PgRow}, + Postgres, }; use thiserror::Error; use tokio::time::timeout; diff --git a/rust/feature-flags/src/flag_definitions.rs b/rust/feature-flags/src/flag_definitions.rs index 820d68a4d0250..baebaa04da30e 100644 --- a/rust/feature-flags/src/flag_definitions.rs +++ b/rust/feature-flags/src/flag_definitions.rs @@ -120,7 +120,7 @@ impl FeatureFlag { } } -#[derive(Debug, Deserialize, Serialize)] +#[derive(Clone, Debug, Default, Deserialize, Serialize)] pub struct FeatureFlagList { pub flags: Vec, } @@ -226,22 +226,23 @@ mod tests { use super::*; use crate::test_utils::{ insert_flag_for_team_in_pg, insert_flags_for_team_in_redis, insert_new_team_in_pg, - insert_new_team_in_redis, setup_invalid_pg_client, setup_pg_client, setup_redis_client, + insert_new_team_in_redis, setup_invalid_pg_client, setup_pg_reader_client, + setup_redis_client, }; #[tokio::test] async fn test_fetch_flags_from_redis() { - let client = setup_redis_client(None); + let redis_client = setup_redis_client(None); - let team = insert_new_team_in_redis(client.clone()) + let team = insert_new_team_in_redis(redis_client.clone()) .await .expect("Failed to insert team"); - insert_flags_for_team_in_redis(client.clone(), team.id, None) + insert_flags_for_team_in_redis(redis_client.clone(), team.id, None) .await .expect("Failed to insert flags"); - let flags_from_redis = FeatureFlagList::from_redis(client.clone(), team.id) + let flags_from_redis = FeatureFlagList::from_redis(redis_client.clone(), team.id) .await .expect("Failed to fetch flags from redis"); assert_eq!(flags_from_redis.flags.len(), 1); @@ -264,9 +265,9 @@ mod tests { #[tokio::test] async fn test_fetch_invalid_team_from_redis() { - let client = setup_redis_client(None); + let redis_client = setup_redis_client(None); - match FeatureFlagList::from_redis(client.clone(), 1234).await { + match FeatureFlagList::from_redis(redis_client.clone(), 1234).await { Err(FlagError::TokenValidationError) => (), _ => panic!("Expected TokenValidationError"), }; @@ -284,17 +285,17 @@ mod tests { #[tokio::test] async fn test_fetch_flags_from_pg() { - let client = setup_pg_client(None).await; + let postgres_reader = setup_pg_reader_client(None).await; - let team = insert_new_team_in_pg(client.clone()) + let team = insert_new_team_in_pg(postgres_reader.clone(), None) .await .expect("Failed to insert team in pg"); - insert_flag_for_team_in_pg(client.clone(), team.id, None) + insert_flag_for_team_in_pg(postgres_reader.clone(), team.id, None) .await .expect("Failed to insert flags"); - let flags_from_pg = FeatureFlagList::from_pg(client.clone(), team.id) + let flags_from_pg = FeatureFlagList::from_pg(postgres_reader.clone(), team.id) .await .expect("Failed to fetch flags from pg"); @@ -423,9 +424,9 @@ mod tests { #[tokio::test] async fn test_fetch_empty_team_from_pg() { - let client = setup_pg_client(None).await; + let postgres_reader = setup_pg_reader_client(None).await; - let FeatureFlagList { flags } = FeatureFlagList::from_pg(client.clone(), 1234) + let FeatureFlagList { flags } = FeatureFlagList::from_pg(postgres_reader.clone(), 1234) .await .expect("Failed to fetch flags from pg"); { @@ -435,9 +436,9 @@ mod tests { #[tokio::test] async fn test_fetch_nonexistent_team_from_pg() { - let client = setup_pg_client(None).await; + let postgres_reader = setup_pg_reader_client(None).await; - match FeatureFlagList::from_pg(client.clone(), -1).await { + match FeatureFlagList::from_pg(postgres_reader.clone(), -1).await { Ok(flags) => assert_eq!(flags.flags.len(), 0), Err(err) => panic!("Expected empty result, got error: {:?}", err), } @@ -446,9 +447,9 @@ mod tests { #[tokio::test] async fn test_fetch_flags_db_connection_failure() { // Simulate a database connection failure by using an invalid client setup - let client = setup_invalid_pg_client().await; + let invalid_client = setup_invalid_pg_client().await; - match FeatureFlagList::from_pg(client, 1).await { + match FeatureFlagList::from_pg(invalid_client, 1).await { Err(FlagError::DatabaseUnavailable) => (), other => panic!("Expected DatabaseUnavailable error, got: {:?}", other), } @@ -456,9 +457,9 @@ mod tests { #[tokio::test] async fn test_fetch_multiple_flags_from_pg() { - let client = setup_pg_client(None).await; + let postgres_reader = setup_pg_reader_client(None).await; - let team = insert_new_team_in_pg(client.clone()) + let team = insert_new_team_in_pg(postgres_reader.clone(), None) .await .expect("Failed to insert team in pg"); @@ -488,15 +489,15 @@ mod tests { }; // Insert multiple flags for the team - insert_flag_for_team_in_pg(client.clone(), team.id, Some(flag1)) + insert_flag_for_team_in_pg(postgres_reader.clone(), team.id, Some(flag1)) .await .expect("Failed to insert flags"); - insert_flag_for_team_in_pg(client.clone(), team.id, Some(flag2)) + insert_flag_for_team_in_pg(postgres_reader.clone(), team.id, Some(flag2)) .await .expect("Failed to insert flags"); - let flags_from_pg = FeatureFlagList::from_pg(client.clone(), team.id) + let flags_from_pg = FeatureFlagList::from_pg(postgres_reader.clone(), team.id) .await .expect("Failed to fetch flags from pg"); @@ -544,9 +545,9 @@ mod tests { #[tokio::test] async fn test_multivariate_flag_parsing() { let redis_client = setup_redis_client(None); - let pg_client = setup_pg_client(None).await; + let postgres_reader = setup_pg_reader_client(None).await; - let team = insert_new_team_in_pg(pg_client.clone()) + let team = insert_new_team_in_pg(postgres_reader.clone(), None) .await .expect("Failed to insert team in pg"); @@ -597,7 +598,7 @@ mod tests { // Insert into Postgres insert_flag_for_team_in_pg( - pg_client.clone(), + postgres_reader.clone(), team.id, Some(FeatureFlagRow { id: 1, @@ -624,7 +625,7 @@ mod tests { assert_eq!(redis_flag.get_variants().len(), 3); // Fetch and verify from Postgres - let pg_flags = FeatureFlagList::from_pg(pg_client, team.id) + let pg_flags = FeatureFlagList::from_pg(postgres_reader, team.id) .await .expect("Failed to fetch flags from Postgres"); @@ -637,9 +638,9 @@ mod tests { #[tokio::test] async fn test_multivariate_flag_with_payloads() { let redis_client = setup_redis_client(None); - let pg_client = setup_pg_client(None).await; + let postgres_reader = setup_pg_reader_client(None).await; - let team = insert_new_team_in_pg(pg_client.clone()) + let team = insert_new_team_in_pg(postgres_reader.clone(), None) .await .expect("Failed to insert team in pg"); @@ -695,7 +696,7 @@ mod tests { // Insert into Postgres insert_flag_for_team_in_pg( - pg_client.clone(), + postgres_reader.clone(), team.id, Some(FeatureFlagRow { id: 1, @@ -721,7 +722,7 @@ mod tests { assert_eq!(redis_flag.key, "multivariate_flag_with_payloads"); // Fetch and verify from Postgres - let pg_flags = FeatureFlagList::from_pg(pg_client, team.id) + let pg_flags = FeatureFlagList::from_pg(postgres_reader, team.id) .await .expect("Failed to fetch flags from Postgres"); @@ -779,9 +780,9 @@ mod tests { #[tokio::test] async fn test_flag_with_super_groups() { let redis_client = setup_redis_client(None); - let pg_client = setup_pg_client(None).await; + let postgres_reader = setup_pg_reader_client(None).await; - let team = insert_new_team_in_pg(pg_client.clone()) + let team = insert_new_team_in_pg(postgres_reader.clone(), None) .await .expect("Failed to insert team in pg"); @@ -826,7 +827,7 @@ mod tests { // Insert into Postgres insert_flag_for_team_in_pg( - pg_client.clone(), + postgres_reader.clone(), team.id, Some(FeatureFlagRow { id: 1, @@ -854,7 +855,7 @@ mod tests { assert_eq!(redis_flag.filters.super_groups.as_ref().unwrap().len(), 1); // Fetch and verify from Postgres - let pg_flags = FeatureFlagList::from_pg(pg_client, team.id) + let pg_flags = FeatureFlagList::from_pg(postgres_reader, team.id) .await .expect("Failed to fetch flags from Postgres"); @@ -868,9 +869,9 @@ mod tests { #[tokio::test] async fn test_flags_with_different_property_types() { let redis_client = setup_redis_client(None); - let pg_client = setup_pg_client(None).await; + let postgres_reader = setup_pg_reader_client(None).await; - let team = insert_new_team_in_pg(pg_client.clone()) + let team = insert_new_team_in_pg(postgres_reader.clone(), None) .await .expect("Failed to insert team in pg"); @@ -921,7 +922,7 @@ mod tests { // Insert into Postgres insert_flag_for_team_in_pg( - pg_client.clone(), + postgres_reader.clone(), team.id, Some(FeatureFlagRow { id: 1, @@ -952,7 +953,7 @@ mod tests { assert_eq!(redis_properties[2].prop_type, "event"); // Fetch and verify from Postgres - let pg_flags = FeatureFlagList::from_pg(pg_client, team.id) + let pg_flags = FeatureFlagList::from_pg(postgres_reader, team.id) .await .expect("Failed to fetch flags from Postgres"); @@ -969,9 +970,9 @@ mod tests { #[tokio::test] async fn test_deleted_and_inactive_flags() { let redis_client = setup_redis_client(None); - let pg_client = setup_pg_client(None).await; + let postgres_reader = setup_pg_reader_client(None).await; - let team = insert_new_team_in_pg(pg_client.clone()) + let team = insert_new_team_in_pg(postgres_reader.clone(), None) .await .expect("Failed to insert team in pg"); @@ -1006,7 +1007,7 @@ mod tests { // Insert into Postgres insert_flag_for_team_in_pg( - pg_client.clone(), + postgres_reader.clone(), team.id, Some(FeatureFlagRow { id: 0, @@ -1023,7 +1024,7 @@ mod tests { .expect("Failed to insert deleted flag in Postgres"); insert_flag_for_team_in_pg( - pg_client.clone(), + postgres_reader.clone(), team.id, Some(FeatureFlagRow { id: 0, @@ -1055,7 +1056,7 @@ mod tests { .any(|f| f.key == "inactive_flag" && !f.active)); // Fetch and verify from Postgres - let pg_flags = FeatureFlagList::from_pg(pg_client, team.id) + let pg_flags = FeatureFlagList::from_pg(postgres_reader, team.id) .await .expect("Failed to fetch flags from Postgres"); @@ -1073,7 +1074,7 @@ mod tests { #[tokio::test] async fn test_error_handling() { let redis_client = setup_redis_client(Some("redis://localhost:6379/".to_string())); - let pg_client = setup_pg_client(None).await; + let postgres_reader = setup_pg_reader_client(None).await; // Test Redis connection error let bad_redis_client = setup_redis_client(Some("redis://localhost:1111/".to_string())); @@ -1081,7 +1082,7 @@ mod tests { assert!(matches!(result, Err(FlagError::RedisUnavailable))); // Test malformed JSON in Redis - let team = insert_new_team_in_pg(pg_client.clone()) + let team = insert_new_team_in_pg(postgres_reader.clone(), None) .await .expect("Failed to insert team in pg"); @@ -1098,7 +1099,7 @@ mod tests { // Test database query error (using a non-existent table) let result = sqlx::query("SELECT * FROM non_existent_table") - .fetch_all(&mut *pg_client.get_connection().await.unwrap()) + .fetch_all(&mut *postgres_reader.get_connection().await.unwrap()) .await; assert!(result.is_err()); } @@ -1106,9 +1107,9 @@ mod tests { #[tokio::test] async fn test_concurrent_access() { let redis_client = setup_redis_client(None); - let pg_client = setup_pg_client(None).await; + let postgres_reader = setup_pg_reader_client(None).await; - let team = insert_new_team_in_pg(pg_client.clone()) + let team = insert_new_team_in_pg(postgres_reader.clone(), None) .await .expect("Failed to insert team in pg"); @@ -1131,7 +1132,7 @@ mod tests { .expect("Failed to insert flag in Redis"); insert_flag_for_team_in_pg( - pg_client.clone(), + postgres_reader.clone(), team.id, Some(FeatureFlagRow { id: 0, @@ -1150,14 +1151,16 @@ mod tests { let mut handles = vec![]; for _ in 0..10 { let redis_client = redis_client.clone(); - let pg_client = pg_client.clone(); + let postgres_reader = postgres_reader.clone(); let team_id = team.id; let handle = task::spawn(async move { let redis_flags = FeatureFlagList::from_redis(redis_client, team_id) .await .unwrap(); - let pg_flags = FeatureFlagList::from_pg(pg_client, team_id).await.unwrap(); + let pg_flags = FeatureFlagList::from_pg(postgres_reader, team_id) + .await + .unwrap(); (redis_flags, pg_flags) }); @@ -1177,9 +1180,9 @@ mod tests { #[ignore] async fn test_performance() { let redis_client = setup_redis_client(None); - let pg_client = setup_pg_client(None).await; + let postgres_reader = setup_pg_reader_client(None).await; - let team = insert_new_team_in_pg(pg_client.clone()) + let team = insert_new_team_in_pg(postgres_reader.clone(), None) .await .expect("Failed to insert team in pg"); @@ -1209,7 +1212,7 @@ mod tests { for flag in flags { insert_flag_for_team_in_pg( - pg_client.clone(), + postgres_reader.clone(), team.id, Some(FeatureFlagRow { id: 0, @@ -1233,7 +1236,7 @@ mod tests { let redis_duration = start.elapsed(); let start = Instant::now(); - let pg_flags = FeatureFlagList::from_pg(pg_client, team.id) + let pg_flags = FeatureFlagList::from_pg(postgres_reader, team.id) .await .expect("Failed to fetch flags from Postgres"); let pg_duration = start.elapsed(); @@ -1251,9 +1254,9 @@ mod tests { #[tokio::test] async fn test_edge_cases() { let redis_client = setup_redis_client(None); - let pg_client = setup_pg_client(None).await; + let postgres_reader = setup_pg_reader_client(None).await; - let team = insert_new_team_in_pg(pg_client.clone()) + let team = insert_new_team_in_pg(postgres_reader.clone(), None) .await .expect("Failed to insert team in pg"); @@ -1298,7 +1301,7 @@ mod tests { for flag in edge_case_flags.as_array().unwrap() { insert_flag_for_team_in_pg( - pg_client.clone(), + postgres_reader.clone(), team.id, Some(FeatureFlagRow { id: 0, @@ -1319,7 +1322,7 @@ mod tests { let redis_flags = FeatureFlagList::from_redis(redis_client, team.id) .await .expect("Failed to fetch flags from Redis"); - let pg_flags = FeatureFlagList::from_pg(pg_client, team.id) + let pg_flags = FeatureFlagList::from_pg(postgres_reader, team.id) .await .expect("Failed to fetch flags from Postgres"); @@ -1344,9 +1347,9 @@ mod tests { #[tokio::test] async fn test_consistent_behavior_from_both_clients() { let redis_client = setup_redis_client(None); - let pg_client = setup_pg_client(None).await; + let postgres_reader = setup_pg_reader_client(None).await; - let team = insert_new_team_in_pg(pg_client.clone()) + let team = insert_new_team_in_pg(postgres_reader.clone(), None) .await .expect("Failed to insert team in pg"); @@ -1378,7 +1381,7 @@ mod tests { for flag in flags.as_array().unwrap() { insert_flag_for_team_in_pg( - pg_client.clone(), + postgres_reader.clone(), team.id, Some(FeatureFlagRow { id: 0, @@ -1399,7 +1402,7 @@ mod tests { let mut redis_flags = FeatureFlagList::from_redis(redis_client, team.id) .await .expect("Failed to fetch flags from Redis"); - let mut pg_flags = FeatureFlagList::from_pg(pg_client, team.id) + let mut pg_flags = FeatureFlagList::from_pg(postgres_reader, team.id) .await .expect("Failed to fetch flags from Postgres"); @@ -1443,9 +1446,9 @@ mod tests { #[tokio::test] async fn test_rollout_percentage_edge_cases() { let redis_client = setup_redis_client(None); - let pg_client = setup_pg_client(None).await; + let postgres_reader = setup_pg_reader_client(None).await; - let team = insert_new_team_in_pg(pg_client.clone()) + let team = insert_new_team_in_pg(postgres_reader.clone(), None) .await .expect("Failed to insert team in pg"); @@ -1486,7 +1489,7 @@ mod tests { for flag in flags.as_array().unwrap() { insert_flag_for_team_in_pg( - pg_client.clone(), + postgres_reader.clone(), team.id, Some(FeatureFlagRow { id: 0, @@ -1507,7 +1510,7 @@ mod tests { let redis_flags = FeatureFlagList::from_redis(redis_client, team.id) .await .expect("Failed to fetch flags from Redis"); - let pg_flags = FeatureFlagList::from_pg(pg_client, team.id) + let pg_flags = FeatureFlagList::from_pg(postgres_reader, team.id) .await .expect("Failed to fetch flags from Postgres"); diff --git a/rust/feature-flags/src/flag_matching.rs b/rust/feature-flags/src/flag_matching.rs index e75580a6cbc90..4dd72ed32aba3 100644 --- a/rust/feature-flags/src/flag_matching.rs +++ b/rust/feature-flags/src/flag_matching.rs @@ -8,15 +8,20 @@ use crate::{ use anyhow::Result; use serde_json::Value; use sha1::{Digest, Sha1}; -use sqlx::FromRow; -use std::collections::{HashMap, HashSet}; +use sqlx::{postgres::PgQueryResult, Acquire, FromRow}; use std::fmt::Write; use std::sync::Arc; -use tracing::error; +use std::{ + collections::{HashMap, HashSet}, + time::Duration, +}; +use tokio::time::{sleep, timeout}; +use tracing::{error, info}; type TeamId = i32; -type DatabaseClientArc = Arc; type GroupTypeIndex = i32; +type PostgresReader = Arc; +type PostgresWriter = Arc; #[derive(Debug)] struct SuperConditionEvaluation { @@ -61,17 +66,17 @@ pub struct GroupTypeMappingCache { failed_to_fetch_flags: bool, group_types_to_indexes: HashMap, group_indexes_to_types: HashMap, - database_client: DatabaseClientArc, + postgres_reader: PostgresReader, } impl GroupTypeMappingCache { - pub fn new(team_id: TeamId, database_client: DatabaseClientArc) -> Self { + pub fn new(team_id: TeamId, postgres_reader: PostgresReader) -> Self { GroupTypeMappingCache { team_id, failed_to_fetch_flags: false, group_types_to_indexes: HashMap::new(), group_indexes_to_types: HashMap::new(), - database_client, + postgres_reader, } } @@ -86,10 +91,9 @@ impl GroupTypeMappingCache { return Ok(self.group_types_to_indexes.clone()); } - let database_client = self.database_client.clone(); let team_id = self.team_id; let mapping = match self - .fetch_group_type_mapping(database_client, team_id) + .fetch_group_type_mapping(self.postgres_reader.clone(), team_id) .await { Ok(mapping) if !mapping.is_empty() => mapping, @@ -128,10 +132,10 @@ impl GroupTypeMappingCache { async fn fetch_group_type_mapping( &mut self, - database_client: DatabaseClientArc, + postgres_reader: PostgresReader, team_id: TeamId, ) -> Result, FlagError> { - let mut conn = database_client.as_ref().get_connection().await?; + let mut conn = postgres_reader.as_ref().get_connection().await?; let query = r#" SELECT group_type, group_type_index @@ -170,7 +174,8 @@ pub struct PropertiesCache { pub struct FeatureFlagMatcher { pub distinct_id: String, pub team_id: TeamId, - pub database_client: DatabaseClientArc, + pub postgres_reader: PostgresReader, + pub postgres_writer: PostgresWriter, group_type_mapping_cache: GroupTypeMappingCache, properties_cache: PropertiesCache, groups: HashMap, @@ -182,7 +187,8 @@ impl FeatureFlagMatcher { pub fn new( distinct_id: String, team_id: TeamId, - database_client: DatabaseClientArc, + postgres_reader: PostgresReader, + postgres_writer: PostgresWriter, group_type_mapping_cache: Option, properties_cache: Option, groups: Option>, @@ -190,9 +196,10 @@ impl FeatureFlagMatcher { FeatureFlagMatcher { distinct_id, team_id, - database_client: database_client.clone(), + postgres_reader: postgres_reader.clone(), + postgres_writer: postgres_writer.clone(), group_type_mapping_cache: group_type_mapping_cache - .unwrap_or_else(|| GroupTypeMappingCache::new(team_id, database_client.clone())), + .unwrap_or_else(|| GroupTypeMappingCache::new(team_id, postgres_reader.clone())), properties_cache: properties_cache.unwrap_or_default(), groups: groups.unwrap_or_default(), } @@ -201,29 +208,127 @@ impl FeatureFlagMatcher { /// Evaluate feature flags for a given distinct_id /// - Returns a map of feature flag keys to their values /// - If an error occurs while evaluating a flag, it will be logged and the flag will be omitted from the result - pub async fn evaluate_feature_flags( + pub async fn evaluate_all_feature_flags( &mut self, feature_flags: FeatureFlagList, person_property_overrides: Option>, group_property_overrides: Option>>, + hash_key_override: Option, + ) -> FlagsResponse { + let flags_have_experience_continuity_enabled = feature_flags + .flags + .iter() + .any(|flag| flag.ensure_experience_continuity); + + // Process any hash key overrides + let (hash_key_overrides, initial_error) = if flags_have_experience_continuity_enabled { + match hash_key_override { + Some(hash_key) => { + let target_distinct_ids = vec![self.distinct_id.clone(), hash_key.clone()]; + self.process_hash_key_override(hash_key, target_distinct_ids) + .await + } + // if a flag has experience continuity enabled but no hash key override is provided, + // we don't need to write an override, we can just use the distinct_id + None => (None, false), + } + } else { + // if experience continuity is not enabled, we don't need to worry about hash key overrides + (None, false) + }; + + let flags_response = self + .evaluate_flags_with_overrides( + feature_flags, + person_property_overrides, + group_property_overrides, + hash_key_overrides, + ) + .await; + + FlagsResponse { + error_while_computing_flags: initial_error + || flags_response.error_while_computing_flags, + feature_flags: flags_response.feature_flags, + } + } + + async fn process_hash_key_override( + &self, + hash_key: String, + target_distinct_ids: Vec, + ) -> (Option>, bool) { + let should_write = match should_write_hash_key_override( + self.postgres_reader.clone(), + self.team_id, + self.distinct_id.clone(), + hash_key.clone(), + ) + .await + { + Ok(should_write) => should_write, + Err(e) => { + error!( + "Failed to check if hash key override should be written: {:?}", + e + ); + return (None, true); + } + }; + + if should_write { + if let Err(e) = set_feature_flag_hash_key_overrides( + // NB: this is the only method that writes to the database, so it's the only one that should use the writer + self.postgres_writer.clone(), + self.team_id, + target_distinct_ids.clone(), + hash_key, + ) + .await + { + error!("Failed to set feature flag hash key overrides: {:?}", e); + return (None, true); + } + } + + match get_feature_flag_hash_key_overrides( + self.postgres_reader.clone(), + self.team_id, + target_distinct_ids, + ) + .await + { + Ok(overrides) => (Some(overrides), false), + Err(e) => { + error!("Failed to get feature flag hash key overrides: {:?}", e); + (None, true) + } + } + } + + async fn evaluate_flags_with_overrides( + &mut self, + feature_flags: FeatureFlagList, + person_property_overrides: Option>, + group_property_overrides: Option>>, + hash_key_overrides: Option>, ) -> FlagsResponse { let mut result = HashMap::new(); let mut error_while_computing_flags = false; let mut flags_needing_db_properties = Vec::new(); - // Step 1: Evaluate flags that can be resolved with overrides + // Step 1: Evaluate flags with locally computable property overrides first for flag in &feature_flags.flags { - // Skip inactive or deleted flags if !flag.active || flag.deleted { continue; } - // Get any flag matches with overrides, assuming that these overrides can be computed locally match self - .match_flag_with_overrides( + .match_flag_with_property_overrides( flag, &person_property_overrides, &group_property_overrides, + hash_key_overrides.clone(), ) .await { @@ -234,7 +339,6 @@ impl FeatureFlagMatcher { Ok(None) => { flags_needing_db_properties.push(flag.clone()); } - // We had overrides, but couldn't evaluate the flag Err(e) => { error_while_computing_flags = true; error!( @@ -245,46 +349,50 @@ impl FeatureFlagMatcher { } } - // At this point, we have a list of flags that we couldn't locally evaluate (with overrides, or without), so - // we need to hit the DB to fetch the properties for these flags to continue our evaluation. - - // Step 2: Fetch and cache properties for remaining flags + // Step 2: Fetch and cache properties for remaining flags (just one DB lookup for all of relevant properties) if !flags_needing_db_properties.is_empty() { let group_type_indexes: HashSet = flags_needing_db_properties .iter() .filter_map(|flag| flag.get_group_type_index()) .collect(); - let database_client = self.database_client.clone(); + let postgres_reader = self.postgres_reader.clone(); let distinct_id = self.distinct_id.clone(); let team_id = self.team_id; match fetch_and_locally_cache_all_properties( &mut self.properties_cache, - database_client, + postgres_reader, distinct_id, team_id, &group_type_indexes, ) .await { - Ok(_) => {} // `fetch_and_locally_cache_all_properties` returns void on success, - // so at this point we know we've cached the properties, and can continue. + Ok(_) => {} Err(e) => { error_while_computing_flags = true; + // TODO add sentry exception tracking error!("Error fetching properties: {:?}", e); } } - // Step 3: Evaluate remaining flags + // Step 3: Evaluate remaining flags with cached properties + // At this point we've already done a round of flag evaluations with locally computable property overrides + // This step is for flags that couldn't be evaluated locally due to missing property values, + // so we do a single query to fetch all of the remaining properties, and then proceed with flag evaluations for flag in flags_needing_db_properties { - match self.get_match(&flag, None).await { + match self + .get_match(&flag, None, hash_key_overrides.clone()) + .await + { Ok(flag_match) => { let flag_value = self.flag_match_to_value(&flag_match); result.insert(flag.key.clone(), flag_value); } Err(e) => { error_while_computing_flags = true; + // TODO add sentry exception tracking error!( "Error evaluating feature flag '{}' for distinct_id '{}': {:?}", flag.key, self.distinct_id, e @@ -306,11 +414,12 @@ impl FeatureFlagMatcher { /// depending on whether the flag is group-based or person-based. It first collects all property /// filters from the flag's conditions, then retrieves the appropriate overrides, and finally /// attempts to match the flag using these overrides. - async fn match_flag_with_overrides( + async fn match_flag_with_property_overrides( &mut self, flag: &FeatureFlag, person_property_overrides: &Option>, group_property_overrides: &Option>>, + hash_key_overrides: Option>, ) -> Result, FlagError> { let flag_property_filters: Vec = flag .get_conditions() @@ -331,7 +440,10 @@ impl FeatureFlagMatcher { }; match overrides { - Some(props) => self.get_match(flag, Some(props)).await.map(Some), + Some(props) => self + .get_match(flag, Some(props), hash_key_overrides) + .await + .map(Some), None => Ok(None), } } @@ -409,8 +521,13 @@ impl FeatureFlagMatcher { &mut self, flag: &FeatureFlag, property_overrides: Option>, + hash_key_overrides: Option>, ) -> Result { - if self.hashed_identifier(flag).await?.is_empty() { + if self + .hashed_identifier(flag, hash_key_overrides.clone()) + .await? + .is_empty() + { return Ok(FeatureFlagMatch { matches: false, variant: None, @@ -427,7 +544,11 @@ impl FeatureFlagMatcher { if let Some(super_groups) = &flag.filters.super_groups { if !super_groups.is_empty() { let super_condition_evaluation = self - .is_super_condition_match(flag, property_overrides.clone()) + .is_super_condition_match( + flag, + property_overrides.clone(), + hash_key_overrides.clone(), + ) .await?; if super_condition_evaluation.should_evaluate { @@ -452,7 +573,12 @@ impl FeatureFlagMatcher { for (index, condition) in sorted_conditions { let (is_match, reason) = self - .is_condition_match(flag, condition, property_overrides.clone()) + .is_condition_match( + flag, + condition, + property_overrides.clone(), + hash_key_overrides.clone(), + ) .await?; // Update highest_match and highest_index @@ -471,7 +597,9 @@ impl FeatureFlagMatcher { break; // Exit early if we've found a super condition match } - let variant = self.get_matching_variant(flag).await?; + let variant = self + .get_matching_variant(flag, hash_key_overrides.clone()) + .await?; let payload = self.get_matching_payload(variant.as_deref(), flag); return Ok(FeatureFlagMatch { @@ -524,12 +652,15 @@ impl FeatureFlagMatcher { feature_flag: &FeatureFlag, condition: &FlagGroupType, property_overrides: Option>, + hash_key_overrides: Option>, ) -> Result<(bool, FeatureFlagMatchReason), FlagError> { let rollout_percentage = condition.rollout_percentage.unwrap_or(100.0); if let Some(flag_property_filters) = &condition.properties { if flag_property_filters.is_empty() { - return self.check_rollout(feature_flag, rollout_percentage).await; + return self + .check_rollout(feature_flag, rollout_percentage, hash_key_overrides) + .await; } // NB: we can only evaluate group or person properties, not both @@ -542,7 +673,8 @@ impl FeatureFlagMatcher { } } - self.check_rollout(feature_flag, rollout_percentage).await + self.check_rollout(feature_flag, rollout_percentage, hash_key_overrides) + .await } /// Get properties to check for a feature flag. @@ -614,6 +746,7 @@ impl FeatureFlagMatcher { &mut self, feature_flag: &FeatureFlag, property_overrides: Option>, + hash_key_overrides: Option>, ) -> Result { if let Some(first_condition) = feature_flag .filters @@ -638,7 +771,12 @@ impl FeatureFlagMatcher { }); let (is_match, _) = self - .is_condition_match(feature_flag, first_condition, Some(person_properties)) + .is_condition_match( + feature_flag, + first_condition, + Some(person_properties), + hash_key_overrides, + ) .await?; if has_relevant_super_condition_properties { @@ -679,10 +817,10 @@ impl FeatureFlagMatcher { return Ok(result); } - let database_client = self.database_client.clone(); + let postgres_reader = self.postgres_reader.clone(); let team_id = self.team_id; let db_properties = - fetch_group_properties_from_db(database_client, team_id, group_type_index).await?; + fetch_group_properties_from_db(postgres_reader, team_id, group_type_index).await?; // once the properties are fetched, cache them so we don't need to fetch again in a given request self.properties_cache @@ -707,11 +845,11 @@ impl FeatureFlagMatcher { return Ok(result); } - let database_client = self.database_client.clone(); + let postgres_reader = self.postgres_reader.clone(); let distinct_id = self.distinct_id.clone(); let team_id = self.team_id; let db_properties = - fetch_person_properties_from_db(database_client, distinct_id, team_id).await?; + fetch_person_properties_from_db(postgres_reader, distinct_id, team_id).await?; // once the properties are fetched, cache them so we don't need to fetch again in a given request self.properties_cache.person_properties = Some(db_properties.clone()); @@ -723,9 +861,11 @@ impl FeatureFlagMatcher { /// /// This function generates a hashed identifier for a feature flag based on the feature flag's group type index. /// If the feature flag is group-based, it fetches the group key; otherwise, it uses the distinct ID. - async fn hashed_identifier(&mut self, feature_flag: &FeatureFlag) -> Result { - // TODO: Use hash key overrides for experience continuity - + async fn hashed_identifier( + &mut self, + feature_flag: &FeatureFlag, + hash_key_overrides: Option>, + ) -> Result { if let Some(group_type_index) = feature_flag.get_group_type_index() { // Group-based flag let group_key = self @@ -740,7 +880,15 @@ impl FeatureFlagMatcher { Ok(group_key.to_string()) } else { // Person-based flag - Ok(self.distinct_id.clone()) + // Use hash key overrides for experience continuity + if let Some(hash_key_override) = hash_key_overrides + .as_ref() + .and_then(|h| h.get(&feature_flag.key)) + { + Ok(hash_key_override.clone()) + } else { + Ok(self.distinct_id.clone()) + } } } @@ -748,8 +896,15 @@ impl FeatureFlagMatcher { /// Given the same identifier and key, it'll always return the same float. These floats are /// uniformly distributed between 0 and 1, so if we want to show this feature to 20% of traffic /// we can do _hash(key, identifier) < 0.2 - async fn get_hash(&mut self, feature_flag: &FeatureFlag, salt: &str) -> Result { - let hashed_identifier = self.hashed_identifier(feature_flag).await?; + async fn get_hash( + &mut self, + feature_flag: &FeatureFlag, + salt: &str, + hash_key_overrides: Option>, + ) -> Result { + let hashed_identifier = self + .hashed_identifier(feature_flag, hash_key_overrides) + .await?; if hashed_identifier.is_empty() { // Return a hash value that will make the flag evaluate to false // TODO make this cleaner – we should have a way to return a default value @@ -780,8 +935,9 @@ impl FeatureFlagMatcher { &mut self, feature_flag: &FeatureFlag, rollout_percentage: f64, + hash_key_overrides: Option>, ) -> Result<(bool, FeatureFlagMatchReason), FlagError> { - let hash = self.get_hash(feature_flag, "").await?; + let hash = self.get_hash(feature_flag, "", hash_key_overrides).await?; if rollout_percentage == 100.0 || hash <= (rollout_percentage / 100.0) { Ok((true, FeatureFlagMatchReason::ConditionMatch)) } else { @@ -793,8 +949,11 @@ impl FeatureFlagMatcher { async fn get_matching_variant( &mut self, feature_flag: &FeatureFlag, + hash_key_overrides: Option>, ) -> Result, FlagError> { - let hash = self.get_hash(feature_flag, "variant").await?; + let hash = self + .get_hash(feature_flag, "variant", hash_key_overrides) + .await?; let mut cumulative_percentage = 0.0; for variant in feature_flag.get_variants() { @@ -826,12 +985,12 @@ impl FeatureFlagMatcher { /// It updates the properties cache with the fetched properties and returns the result. async fn fetch_and_locally_cache_all_properties( properties_cache: &mut PropertiesCache, - database_client: DatabaseClientArc, + postgres_reader: PostgresReader, distinct_id: String, team_id: TeamId, group_type_indexes: &HashSet, ) -> Result<(), FlagError> { - let mut conn = database_client.as_ref().get_connection().await?; + let mut conn = postgres_reader.as_ref().get_connection().await?; let query = r#" SELECT @@ -899,11 +1058,11 @@ async fn fetch_and_locally_cache_all_properties( /// This function constructs and executes a SQL query to fetch the person properties for a specified distinct ID and team ID. /// It returns the fetched properties as a HashMap. async fn fetch_person_properties_from_db( - database_client: DatabaseClientArc, + postgres_reader: PostgresReader, distinct_id: String, team_id: TeamId, ) -> Result, FlagError> { - let mut conn = database_client.as_ref().get_connection().await?; + let mut conn = postgres_reader.as_ref().get_connection().await?; let query = r#" SELECT "posthog_person"."properties" as person_properties @@ -934,11 +1093,11 @@ async fn fetch_person_properties_from_db( /// This function constructs and executes a SQL query to fetch the group properties for a specified team ID and group type index. /// It returns the fetched properties as a HashMap. async fn fetch_group_properties_from_db( - database_client: DatabaseClientArc, + postgres_reader: PostgresReader, team_id: TeamId, group_type_index: GroupTypeIndex, ) -> Result, FlagError> { - let mut conn = database_client.as_ref().get_connection().await?; + let mut conn = postgres_reader.as_ref().get_connection().await?; let query = r#" SELECT "posthog_group"."group_properties" @@ -995,6 +1154,214 @@ fn all_properties_match( .all(|property| match_property(property, target_properties, false).unwrap_or(false)) } +async fn get_feature_flag_hash_key_overrides( + postgres_reader: PostgresReader, + team_id: TeamId, + distinct_id_and_hash_key_override: Vec, +) -> Result, FlagError> { + let mut feature_flag_hash_key_overrides = HashMap::new(); + let mut conn = postgres_reader.as_ref().get_connection().await?; + + let person_and_distinct_id_query = r#" + SELECT person_id, distinct_id + FROM posthog_persondistinctid + WHERE team_id = $1 AND distinct_id = ANY($2) + "#; + + let person_and_distinct_ids: Vec<(i32, String)> = sqlx::query_as(person_and_distinct_id_query) + .bind(team_id) + .bind(&distinct_id_and_hash_key_override) + .fetch_all(&mut *conn) + .await?; + + let person_id_to_distinct_id: HashMap = + person_and_distinct_ids.into_iter().collect(); + let person_ids: Vec = person_id_to_distinct_id.keys().cloned().collect(); + + // Get hash key overrides + let hash_key_override_query = r#" + SELECT feature_flag_key, hash_key, person_id + FROM posthog_featureflaghashkeyoverride + WHERE team_id = $1 AND person_id = ANY($2) + "#; + + let overrides: Vec<(String, String, i32)> = sqlx::query_as(hash_key_override_query) + .bind(team_id) + .bind(&person_ids) + .fetch_all(&mut *conn) + .await?; + + // Sort and process overrides, with the distinct_id at the start of the array having priority + // We want the highest priority to go last in sort order, so it's the latest update in the hashmap + let mut sorted_overrides = overrides; + sorted_overrides.sort_by_key(|(_, _, person_id)| { + if person_id_to_distinct_id.get(person_id) == Some(&distinct_id_and_hash_key_override[0]) { + std::cmp::Ordering::Greater + } else { + std::cmp::Ordering::Less + } + }); + + for (feature_flag_key, hash_key, _) in sorted_overrides { + feature_flag_hash_key_overrides.insert(feature_flag_key, hash_key); + } + + Ok(feature_flag_hash_key_overrides) +} + +async fn set_feature_flag_hash_key_overrides( + postgres_writer: PostgresReader, + team_id: TeamId, + distinct_ids: Vec, + hash_key_override: String, +) -> Result { + const MAX_RETRIES: u32 = 2; + const RETRY_DELAY: Duration = Duration::from_millis(100); + + for retry in 0..MAX_RETRIES { + let mut conn = postgres_writer.get_connection().await?; + let mut transaction = conn.begin().await?; + + let query = r#" + WITH target_person_ids AS ( + SELECT team_id, person_id FROM posthog_persondistinctid WHERE team_id = $1 AND + distinct_id = ANY($2) + ), + existing_overrides AS ( + SELECT team_id, person_id, feature_flag_key, hash_key FROM posthog_featureflaghashkeyoverride + WHERE team_id = $1 AND person_id IN (SELECT person_id FROM target_person_ids) + ), + flags_to_override AS ( + SELECT key FROM posthog_featureflag WHERE team_id = $1 AND ensure_experience_continuity = TRUE AND active = TRUE AND deleted = FALSE + AND key NOT IN (SELECT feature_flag_key FROM existing_overrides) + ) + INSERT INTO posthog_featureflaghashkeyoverride (team_id, person_id, feature_flag_key, hash_key) + SELECT team_id, person_id, key, $3 + FROM flags_to_override, target_person_ids + WHERE EXISTS (SELECT 1 FROM posthog_person WHERE id = person_id AND team_id = $1) + ON CONFLICT DO NOTHING + "#; + + let result: Result = sqlx::query(query) + .bind(team_id) + .bind(&distinct_ids) + .bind(&hash_key_override) + .execute(&mut *transaction) + .await; + + match result { + Ok(query_result) => { + // Commit the transaction if successful + transaction + .commit() + .await + .map_err(|e| FlagError::DatabaseError(e.to_string()))?; + return Ok(query_result.rows_affected() > 0); + } + Err(e) => { + // Rollback the transaction on error + transaction + .rollback() + .await + .map_err(|e| FlagError::DatabaseError(e.to_string()))?; + + if e.to_string().contains("violates foreign key constraint") + && retry < MAX_RETRIES - 1 + { + // Retry logic for specific error + tracing::info!( + "Retrying set_feature_flag_hash_key_overrides due to person deletion: {:?}", + e + ); + sleep(RETRY_DELAY).await; + } else { + return Err(FlagError::DatabaseError(e.to_string())); + } + } + } + } + + // If we get here, something went wrong + Ok(false) +} + +async fn should_write_hash_key_override( + postgres_reader: PostgresReader, + team_id: TeamId, + distinct_id: String, + hash_key_override: String, +) -> Result { + const QUERY_TIMEOUT: Duration = Duration::from_millis(1000); + const MAX_RETRIES: u32 = 2; + const RETRY_DELAY: Duration = Duration::from_millis(100); + + let distinct_ids = vec![distinct_id, hash_key_override.clone()]; + + let query = r#" + WITH target_person_ids AS ( + SELECT team_id, person_id + FROM posthog_persondistinctid + WHERE team_id = $1 AND distinct_id = ANY($2) + ), + existing_overrides AS ( + SELECT team_id, person_id, feature_flag_key, hash_key + FROM posthog_featureflaghashkeyoverride + WHERE team_id = $1 AND person_id IN (SELECT person_id FROM target_person_ids) + ) + SELECT key + FROM posthog_featureflag + WHERE team_id = $1 + AND ensure_experience_continuity = TRUE + AND active = TRUE + AND deleted = FALSE + AND key NOT IN (SELECT feature_flag_key FROM existing_overrides) + "#; + + for retry in 0..MAX_RETRIES { + let result = timeout(QUERY_TIMEOUT, async { + let mut conn = postgres_reader.get_connection().await.map_err(|e| { + FlagError::DatabaseError(format!("Failed to acquire connection: {}", e)) + })?; + + let rows = sqlx::query(query) + .bind(team_id) + .bind(&distinct_ids) + .fetch_all(&mut *conn) + .await + .map_err(|e| FlagError::DatabaseError(format!("Query execution failed: {}", e)))?; + + Ok::(!rows.is_empty()) + }) + .await; + + match result { + Ok(Ok(flags_present)) => return Ok(flags_present), + Ok(Err(e)) => { + if e.to_string().contains("violates foreign key constraint") + && retry < MAX_RETRIES - 1 + { + info!( + "Retrying set_feature_flag_hash_key_overrides due to person deletion: {:?}", + e + ); + tokio::time::sleep(RETRY_DELAY).await; + continue; + } else { + // For other errors or if max retries exceeded, return the error + return Err(e); + } + } + Err(_) => { + // Handle timeout + return Err(FlagError::TimeoutError); + } + } + } + + // If all retries failed without returning, return false + Ok(false) +} + #[cfg(test)] mod tests { use serde_json::json; @@ -1003,9 +1370,13 @@ mod tests { use super::*; use crate::{ flag_definitions::{ - FlagFilters, MultivariateFlagOptions, MultivariateFlagVariant, OperatorType, + FeatureFlagRow, FlagFilters, MultivariateFlagOptions, MultivariateFlagVariant, + OperatorType, + }, + test_utils::{ + insert_flag_for_team_in_pg, insert_new_team_in_pg, insert_person_for_team_in_pg, + setup_pg_reader_client, setup_pg_writer_client, }, - test_utils::{insert_new_team_in_pg, insert_person_for_team_in_pg, setup_pg_client}, }; fn create_test_flag( @@ -1042,20 +1413,21 @@ mod tests { #[tokio::test] async fn test_fetch_properties_from_pg_to_match() { - let database_client = setup_pg_client(None).await; + let postgres_reader = setup_pg_reader_client(None).await; + let postgres_writer = setup_pg_writer_client(None).await; - let team = insert_new_team_in_pg(database_client.clone()) + let team = insert_new_team_in_pg(postgres_reader.clone(), None) .await .expect("Failed to insert team in pg"); let distinct_id = "user_distinct_id".to_string(); - insert_person_for_team_in_pg(database_client.clone(), team.id, distinct_id.clone(), None) + insert_person_for_team_in_pg(postgres_reader.clone(), team.id, distinct_id.clone(), None) .await .expect("Failed to insert person"); let not_matching_distinct_id = "not_matching_distinct_id".to_string(); insert_person_for_team_in_pg( - database_client.clone(), + postgres_reader.clone(), team.id, not_matching_distinct_id.clone(), Some(json!({ "email": "a@x.com"})), @@ -1090,44 +1462,48 @@ mod tests { let mut matcher = FeatureFlagMatcher::new( distinct_id.clone(), team.id, - database_client.clone(), + postgres_reader.clone(), + postgres_writer.clone(), None, None, None, ); - let match_result = matcher.get_match(&flag, None).await.unwrap(); + let match_result = matcher.get_match(&flag, None, None).await.unwrap(); assert!(match_result.matches); assert_eq!(match_result.variant, None); let mut matcher = FeatureFlagMatcher::new( not_matching_distinct_id.clone(), team.id, - database_client.clone(), + postgres_reader.clone(), + postgres_writer.clone(), None, None, None, ); - let match_result = matcher.get_match(&flag, None).await.unwrap(); + let match_result = matcher.get_match(&flag, None, None).await.unwrap(); assert!(!match_result.matches); assert_eq!(match_result.variant, None); let mut matcher = FeatureFlagMatcher::new( "other_distinct_id".to_string(), team.id, - database_client.clone(), + postgres_reader.clone(), + postgres_writer.clone(), None, None, None, ); - let match_result = matcher.get_match(&flag, None).await.unwrap(); + let match_result = matcher.get_match(&flag, None, None).await.unwrap(); assert!(!match_result.matches); assert_eq!(match_result.variant, None); } #[tokio::test] async fn test_person_property_overrides() { - let database_client = setup_pg_client(None).await; - let team = insert_new_team_in_pg(database_client.clone()) + let postgres_reader = setup_pg_reader_client(None).await; + let postgres_writer = setup_pg_writer_client(None).await; + let team = insert_new_team_in_pg(postgres_reader.clone(), None) .await .unwrap(); @@ -1163,7 +1539,8 @@ mod tests { let mut matcher = FeatureFlagMatcher::new( "test_user".to_string(), team.id, - database_client, + postgres_reader, + postgres_writer, None, None, None, @@ -1173,7 +1550,7 @@ mod tests { flags: vec![flag.clone()], }; let result = matcher - .evaluate_feature_flags(flags, Some(overrides), None) + .evaluate_all_feature_flags(flags, Some(overrides), None, None) .await; assert!(!result.error_while_computing_flags); assert_eq!( @@ -1184,8 +1561,11 @@ mod tests { #[tokio::test] async fn test_group_property_overrides() { - let client = setup_pg_client(None).await; - let team = insert_new_team_in_pg(client.clone()).await.unwrap(); + let postgres_reader = setup_pg_reader_client(None).await; + let postgres_writer = setup_pg_writer_client(None).await; + let team = insert_new_team_in_pg(postgres_reader.clone(), None) + .await + .unwrap(); let flag = create_test_flag( None, @@ -1214,7 +1594,7 @@ mod tests { None, ); - let mut cache = GroupTypeMappingCache::new(team.id, client.clone()); + let mut cache = GroupTypeMappingCache::new(team.id, postgres_reader.clone()); let group_types_to_indexes = [("organization".to_string(), 1)].into_iter().collect(); cache.group_types_to_indexes = group_types_to_indexes; cache.group_indexes_to_types = [(1, "organization".to_string())].into_iter().collect(); @@ -1232,7 +1612,8 @@ mod tests { let mut matcher = FeatureFlagMatcher::new( "test_user".to_string(), team.id, - client.clone(), + postgres_reader.clone(), + postgres_writer.clone(), Some(cache), None, Some(groups), @@ -1242,7 +1623,7 @@ mod tests { flags: vec![flag.clone()], }; let result = matcher - .evaluate_feature_flags(flags, None, Some(group_overrides)) + .evaluate_all_feature_flags(flags, None, Some(group_overrides), None) .await; assert!(!result.error_while_computing_flags); @@ -1255,9 +1636,10 @@ mod tests { #[tokio::test] async fn test_get_matching_variant_with_cache() { let flag = create_test_flag_with_variants(1); - let database_client = setup_pg_client(None).await; + let postgres_reader = setup_pg_reader_client(None).await; + let postgres_writer = setup_pg_writer_client(None).await; - let mut cache = GroupTypeMappingCache::new(1, database_client.clone()); + let mut cache = GroupTypeMappingCache::new(1, postgres_reader.clone()); let group_types_to_indexes = [("group_type_1".to_string(), 1)].into_iter().collect(); let group_type_index_to_name = [(1, "group_type_1".to_string())].into_iter().collect(); @@ -1270,12 +1652,13 @@ mod tests { let mut matcher = FeatureFlagMatcher::new( "test_user".to_string(), 1, - database_client.clone(), + postgres_reader.clone(), + postgres_writer.clone(), Some(cache), None, Some(groups), ); - let variant = matcher.get_matching_variant(&flag).await.unwrap(); + let variant = matcher.get_matching_variant(&flag, None).await.unwrap(); assert!(variant.is_some(), "No variant was selected"); assert!( ["control", "test", "test2"].contains(&variant.unwrap().as_str()), @@ -1285,8 +1668,9 @@ mod tests { #[tokio::test] async fn test_get_matching_variant_with_db() { - let database_client = setup_pg_client(None).await; - let team = insert_new_team_in_pg(database_client.clone()) + let postgres_reader = setup_pg_reader_client(None).await; + let postgres_writer = setup_pg_writer_client(None).await; + let team = insert_new_team_in_pg(postgres_reader.clone(), None) .await .unwrap(); @@ -1295,20 +1679,22 @@ mod tests { let mut matcher = FeatureFlagMatcher::new( "test_user".to_string(), team.id, - database_client.clone(), + postgres_reader.clone(), + postgres_writer.clone(), None, None, None, ); - let variant = matcher.get_matching_variant(&flag).await.unwrap(); + let variant = matcher.get_matching_variant(&flag, None).await.unwrap(); assert!(variant.is_some()); assert!(["control", "test", "test2"].contains(&variant.unwrap().as_str())); } #[tokio::test] async fn test_is_condition_match_empty_properties() { - let database_client = setup_pg_client(None).await; + let postgres_reader = setup_pg_reader_client(None).await; + let postgres_writer = setup_pg_writer_client(None).await; let flag = create_test_flag( Some(1), None, @@ -1339,13 +1725,14 @@ mod tests { let mut matcher = FeatureFlagMatcher::new( "test_user".to_string(), 1, - database_client, + postgres_reader, + postgres_writer, None, None, None, ); let (is_match, reason) = matcher - .is_condition_match(&flag, &condition, None) + .is_condition_match(&flag, &condition, None, None) .await .unwrap(); assert!(is_match); @@ -1395,8 +1782,11 @@ mod tests { #[tokio::test] async fn test_overrides_avoid_db_lookups() { - let client = setup_pg_client(None).await; - let team = insert_new_team_in_pg(client.clone()).await.unwrap(); + let postgres_reader = setup_pg_reader_client(None).await; + let postgres_writer = setup_pg_writer_client(None).await; + let team = insert_new_team_in_pg(postgres_reader.clone(), None) + .await + .unwrap(); let flag = create_test_flag( None, @@ -1431,19 +1821,21 @@ mod tests { let mut matcher = FeatureFlagMatcher::new( "test_user".to_string(), team.id, - client.clone(), + postgres_reader.clone(), + postgres_writer.clone(), None, None, None, ); let result = matcher - .evaluate_feature_flags( + .evaluate_all_feature_flags( FeatureFlagList { flags: vec![flag.clone()], }, Some(person_property_overrides), None, + None, ) .await; @@ -1459,8 +1851,11 @@ mod tests { #[tokio::test] async fn test_fallback_to_db_when_overrides_insufficient() { - let client = setup_pg_client(None).await; - let team = insert_new_team_in_pg(client.clone()).await.unwrap(); + let postgres_reader = setup_pg_reader_client(None).await; + let postgres_writer = setup_pg_writer_client(None).await; + let team = insert_new_team_in_pg(postgres_reader.clone(), None) + .await + .unwrap(); let flag = create_test_flag( None, @@ -1504,7 +1899,7 @@ mod tests { )])); insert_person_for_team_in_pg( - client.clone(), + postgres_reader.clone(), team.id, "test_user".to_string(), Some(json!({"email": "test@example.com", "age": 30})), @@ -1515,14 +1910,15 @@ mod tests { let mut matcher = FeatureFlagMatcher::new( "test_user".to_string(), team.id, - client.clone(), + postgres_reader.clone(), + postgres_writer.clone(), None, None, None, ); let result = matcher - .get_match(&flag, person_property_overrides.clone()) + .get_match(&flag, person_property_overrides.clone(), None) .await .unwrap(); @@ -1538,12 +1934,15 @@ mod tests { #[tokio::test] async fn test_property_fetching_and_caching() { - let client = setup_pg_client(None).await; - let team = insert_new_team_in_pg(client.clone()).await.unwrap(); + let postgres_reader = setup_pg_reader_client(None).await; + let postgres_writer = setup_pg_writer_client(None).await; + let team = insert_new_team_in_pg(postgres_reader.clone(), None) + .await + .unwrap(); let distinct_id = "test_user".to_string(); insert_person_for_team_in_pg( - client.clone(), + postgres_reader.clone(), team.id, distinct_id.clone(), Some(json!({"email": "test@example.com", "age": 30})), @@ -1551,8 +1950,15 @@ mod tests { .await .unwrap(); - let mut matcher = - FeatureFlagMatcher::new(distinct_id, team.id, client.clone(), None, None, None); + let mut matcher = FeatureFlagMatcher::new( + distinct_id, + team.id, + postgres_reader.clone(), + postgres_writer.clone(), + None, + None, + None, + ); let properties = matcher .get_person_properties_from_cache_or_db() @@ -1572,12 +1978,15 @@ mod tests { #[tokio::test] async fn test_property_caching() { - let client = setup_pg_client(None).await; - let team = insert_new_team_in_pg(client.clone()).await.unwrap(); + let postgres_reader = setup_pg_reader_client(None).await; + let postgres_writer = setup_pg_writer_client(None).await; + let team = insert_new_team_in_pg(postgres_reader.clone(), None) + .await + .unwrap(); let distinct_id = "test_user".to_string(); insert_person_for_team_in_pg( - client.clone(), + postgres_reader.clone(), team.id, distinct_id.clone(), Some(json!({"email": "test@example.com", "age": 30})), @@ -1588,7 +1997,8 @@ mod tests { let mut matcher = FeatureFlagMatcher::new( distinct_id.clone(), team.id, - client.clone(), + postgres_reader.clone(), + postgres_writer.clone(), None, None, None, @@ -1620,7 +2030,8 @@ mod tests { let mut new_matcher = FeatureFlagMatcher::new( distinct_id.clone(), team.id, - client.clone(), + postgres_reader.clone(), + postgres_writer.clone(), None, None, None, @@ -1706,8 +2117,11 @@ mod tests { #[tokio::test] async fn test_concurrent_flag_evaluation() { - let client = setup_pg_client(None).await; - let team = insert_new_team_in_pg(client.clone()).await.unwrap(); + let postgres_reader = setup_pg_reader_client(None).await; + let postgres_writer = setup_pg_writer_client(None).await; + let team = insert_new_team_in_pg(postgres_reader.clone(), None) + .await + .unwrap(); let flag = Arc::new(create_test_flag( None, Some(team.id), @@ -1732,17 +2146,19 @@ mod tests { let mut handles = vec![]; for i in 0..100 { let flag_clone = flag.clone(); - let client_clone = client.clone(); + let postgres_reader_clone = postgres_reader.clone(); + let postgres_writer_clone = postgres_writer.clone(); handles.push(tokio::spawn(async move { let mut matcher = FeatureFlagMatcher::new( format!("test_user_{}", i), team.id, - client_clone, + postgres_reader_clone, + postgres_writer_clone, None, None, None, ); - matcher.get_match(&flag_clone, None).await.unwrap() + matcher.get_match(&flag_clone, None, None).await.unwrap() })); } @@ -1758,8 +2174,11 @@ mod tests { #[tokio::test] async fn test_property_operators() { - let client = setup_pg_client(None).await; - let team = insert_new_team_in_pg(client.clone()).await.unwrap(); + let postgres_reader = setup_pg_reader_client(None).await; + let postgres_writer = setup_pg_writer_client(None).await; + let team = insert_new_team_in_pg(postgres_reader.clone(), None) + .await + .unwrap(); let flag = create_test_flag( None, @@ -1798,7 +2217,7 @@ mod tests { ); insert_person_for_team_in_pg( - client.clone(), + postgres_reader.clone(), team.id, "test_user".to_string(), Some(json!({"email": "user@example@domain.com", "age": 30})), @@ -1809,20 +2228,22 @@ mod tests { let mut matcher = FeatureFlagMatcher::new( "test_user".to_string(), team.id, - client.clone(), + postgres_reader.clone(), + postgres_writer.clone(), None, None, None, ); - let result = matcher.get_match(&flag, None).await.unwrap(); + let result = matcher.get_match(&flag, None, None).await.unwrap(); assert!(result.matches); } #[tokio::test] async fn test_empty_hashed_identifier() { - let client = setup_pg_client(None).await; + let postgres_reader = setup_pg_reader_client(None).await; + let postgres_writer = setup_pg_writer_client(None).await; let flag = create_test_flag( Some(1), @@ -1845,17 +2266,25 @@ mod tests { None, ); - let mut matcher = FeatureFlagMatcher::new("".to_string(), 1, client, None, None, None); + let mut matcher = FeatureFlagMatcher::new( + "".to_string(), + 1, + postgres_reader, + postgres_writer, + None, + None, + None, + ); - let result = matcher.get_match(&flag, None).await.unwrap(); + let result = matcher.get_match(&flag, None, None).await.unwrap(); assert!(!result.matches); } #[tokio::test] async fn test_rollout_percentage() { - let client = setup_pg_client(None).await; - + let postgres_reader = setup_pg_reader_client(None).await; + let postgres_writer = setup_pg_writer_client(None).await; let mut flag = create_test_flag( Some(1), None, @@ -1877,24 +2306,32 @@ mod tests { None, ); - let mut matcher = - FeatureFlagMatcher::new("test_user".to_string(), 1, client, None, None, None); + let mut matcher = FeatureFlagMatcher::new( + "test_user".to_string(), + 1, + postgres_reader, + postgres_writer, + None, + None, + None, + ); - let result = matcher.get_match(&flag, None).await.unwrap(); + let result = matcher.get_match(&flag, None, None).await.unwrap(); assert!(!result.matches); // Now set the rollout percentage to 100% flag.filters.groups[0].rollout_percentage = Some(100.0); - let result = matcher.get_match(&flag, None).await.unwrap(); + let result = matcher.get_match(&flag, None, None).await.unwrap(); assert!(result.matches); } #[tokio::test] async fn test_uneven_variant_distribution() { - let client = setup_pg_client(None).await; + let postgres_reader = setup_pg_reader_client(None).await; + let postgres_writer = setup_pg_writer_client(None).await; let mut flag = create_test_flag_with_variants(1); @@ -1920,8 +2357,15 @@ mod tests { // Ensure the flag is person-based by setting aggregation_group_type_index to None flag.filters.aggregation_group_type_index = None; - let mut matcher = - FeatureFlagMatcher::new("test_user".to_string(), 1, client, None, None, None); + let mut matcher = FeatureFlagMatcher::new( + "test_user".to_string(), + 1, + postgres_reader, + postgres_writer, + None, + None, + None, + ); let mut control_count = 0; let mut test_count = 0; @@ -1930,7 +2374,7 @@ mod tests { // Run the test multiple times to simulate distribution for i in 0..1000 { matcher.distinct_id = format!("user_{}", i); - let variant = matcher.get_matching_variant(&flag).await.unwrap(); + let variant = matcher.get_matching_variant(&flag, None).await.unwrap(); match variant.as_deref() { Some("control") => control_count += 1, Some("test") => test_count += 1, @@ -1948,14 +2392,22 @@ mod tests { #[tokio::test] async fn test_missing_properties_in_db() { - let client = setup_pg_client(None).await; - let team = insert_new_team_in_pg(client.clone()).await.unwrap(); - - // Insert a person without properties - insert_person_for_team_in_pg(client.clone(), team.id, "test_user".to_string(), None) + let postgres_reader = setup_pg_reader_client(None).await; + let postgres_writer = setup_pg_writer_client(None).await; + let team = insert_new_team_in_pg(postgres_reader.clone(), None) .await .unwrap(); + // Insert a person without properties + insert_person_for_team_in_pg( + postgres_reader.clone(), + team.id, + "test_user".to_string(), + None, + ) + .await + .unwrap(); + let flag = create_test_flag( None, Some(team.id), @@ -1986,25 +2438,29 @@ mod tests { let mut matcher = FeatureFlagMatcher::new( "test_user".to_string(), team.id, - client.clone(), + postgres_reader.clone(), + postgres_writer.clone(), None, None, None, ); - let result = matcher.get_match(&flag, None).await.unwrap(); + let result = matcher.get_match(&flag, None, None).await.unwrap(); assert!(!result.matches); } #[tokio::test] async fn test_malformed_property_data() { - let client = setup_pg_client(None).await; - let team = insert_new_team_in_pg(client.clone()).await.unwrap(); + let postgres_reader = setup_pg_reader_client(None).await; + let postgres_writer = setup_pg_writer_client(None).await; + let team = insert_new_team_in_pg(postgres_reader.clone(), None) + .await + .unwrap(); // Insert a person with malformed properties insert_person_for_team_in_pg( - client.clone(), + postgres_reader.clone(), team.id, "test_user".to_string(), Some(json!({"age": "not_a_number"})), @@ -2042,13 +2498,14 @@ mod tests { let mut matcher = FeatureFlagMatcher::new( "test_user".to_string(), team.id, - client.clone(), + postgres_reader.clone(), + postgres_writer.clone(), None, None, None, ); - let result = matcher.get_match(&flag, None).await.unwrap(); + let result = matcher.get_match(&flag, None, None).await.unwrap(); // The match should fail due to invalid data type assert!(!result.matches); @@ -2056,8 +2513,11 @@ mod tests { #[tokio::test] async fn test_get_match_with_insufficient_overrides() { - let client = setup_pg_client(None).await; - let team = insert_new_team_in_pg(client.clone()).await.unwrap(); + let postgres_reader = setup_pg_reader_client(None).await; + let postgres_writer = setup_pg_writer_client(None).await; + let team = insert_new_team_in_pg(postgres_reader.clone(), None) + .await + .unwrap(); let flag = create_test_flag( None, @@ -2101,7 +2561,7 @@ mod tests { )])); insert_person_for_team_in_pg( - client.clone(), + postgres_reader.clone(), team.id, "test_user".to_string(), Some(json!({"email": "test@example.com", "age": 30})), @@ -2112,20 +2572,25 @@ mod tests { let mut matcher = FeatureFlagMatcher::new( "test_user".to_string(), team.id, - client.clone(), + postgres_reader.clone(), + postgres_writer.clone(), None, None, None, ); - let result = matcher.get_match(&flag, person_overrides).await.unwrap(); + let result = matcher + .get_match(&flag, person_overrides, None) + .await + .unwrap(); assert!(result.matches); } #[tokio::test] async fn test_evaluation_reasons() { - let client = setup_pg_client(None).await; + let postgres_reader = setup_pg_reader_client(None).await; + let postgres_writer = setup_pg_writer_client(None).await; let flag = create_test_flag( Some(1), None, @@ -2147,11 +2612,18 @@ mod tests { None, ); - let mut matcher = - FeatureFlagMatcher::new("test_user".to_string(), 1, client, None, None, None); + let mut matcher = FeatureFlagMatcher::new( + "test_user".to_string(), + 1, + postgres_reader.clone(), + postgres_writer.clone(), + None, + None, + None, + ); let (is_match, reason) = matcher - .is_condition_match(&flag, &flag.filters.groups[0], None) + .is_condition_match(&flag, &flag.filters.groups[0], None, None) .await .unwrap(); @@ -2161,8 +2633,11 @@ mod tests { #[tokio::test] async fn test_complex_conditions() { - let client = setup_pg_client(None).await; - let team = insert_new_team_in_pg(client.clone()).await.unwrap(); + let postgres_reader = setup_pg_reader_client(None).await; + let postgres_writer = setup_pg_writer_client(None).await; + let team = insert_new_team_in_pg(postgres_reader.clone(), None) + .await + .unwrap(); let flag = create_test_flag( Some(1), @@ -2205,7 +2680,7 @@ mod tests { ); insert_person_for_team_in_pg( - client.clone(), + postgres_reader.clone(), team.id, "test_user".to_string(), Some(json!({"email": "user2@example.com", "age": 35})), @@ -2216,21 +2691,25 @@ mod tests { let mut matcher = FeatureFlagMatcher::new( "test_user".to_string(), team.id, - client.clone(), + postgres_reader.clone(), + postgres_writer.clone(), None, None, None, ); - let result = matcher.get_match(&flag, None).await.unwrap(); + let result = matcher.get_match(&flag, None, None).await.unwrap(); assert!(result.matches); } #[tokio::test] async fn test_super_condition_matches_boolean() { - let client = setup_pg_client(None).await; - let team = insert_new_team_in_pg(client.clone()).await.unwrap(); + let postgres_reader = setup_pg_reader_client(None).await; + let postgres_writer = setup_pg_writer_client(None).await; + let team = insert_new_team_in_pg(postgres_reader.clone(), None) + .await + .unwrap(); let flag = create_test_flag( Some(1), @@ -2288,7 +2767,7 @@ mod tests { ); insert_person_for_team_in_pg( - client.clone(), + postgres_reader.clone(), team.id, "test_id".to_string(), Some(json!({"email": "test@posthog.com", "is_enabled": true})), @@ -2299,7 +2778,8 @@ mod tests { let mut matcher_test_id = FeatureFlagMatcher::new( "test_id".to_string(), team.id, - client.clone(), + postgres_reader.clone(), + postgres_writer.clone(), None, None, None, @@ -2308,7 +2788,8 @@ mod tests { let mut matcher_example_id = FeatureFlagMatcher::new( "lil_id".to_string(), team.id, - client.clone(), + postgres_reader.clone(), + postgres_writer.clone(), None, None, None, @@ -2317,15 +2798,22 @@ mod tests { let mut matcher_another_id = FeatureFlagMatcher::new( "another_id".to_string(), team.id, - client.clone(), + postgres_reader.clone(), + postgres_writer.clone(), None, None, None, ); - let result_test_id = matcher_test_id.get_match(&flag, None).await.unwrap(); - let result_example_id = matcher_example_id.get_match(&flag, None).await.unwrap(); - let result_another_id = matcher_another_id.get_match(&flag, None).await.unwrap(); + let result_test_id = matcher_test_id.get_match(&flag, None, None).await.unwrap(); + let result_example_id = matcher_example_id + .get_match(&flag, None, None) + .await + .unwrap(); + let result_another_id = matcher_another_id + .get_match(&flag, None, None) + .await + .unwrap(); assert!(result_test_id.matches); assert!(result_test_id.reason == FeatureFlagMatchReason::SuperConditionValue); @@ -2337,11 +2825,14 @@ mod tests { #[tokio::test] async fn test_super_condition_matches_string() { - let client = setup_pg_client(None).await; - let team = insert_new_team_in_pg(client.clone()).await.unwrap(); + let postgres_reader = setup_pg_reader_client(None).await; + let postgres_writer = setup_pg_writer_client(None).await; + let team = insert_new_team_in_pg(postgres_reader.clone(), None) + .await + .unwrap(); insert_person_for_team_in_pg( - client.clone(), + postgres_reader.clone(), team.id, "test_id".to_string(), Some(json!({"email": "test@posthog.com", "is_enabled": "true"})), @@ -2407,13 +2898,14 @@ mod tests { let mut matcher = FeatureFlagMatcher::new( "test_id".to_string(), team.id, - client.clone(), + postgres_reader.clone(), + postgres_writer.clone(), None, None, None, ); - let result = matcher.get_match(&flag, None).await.unwrap(); + let result = matcher.get_match(&flag, None, None).await.unwrap(); assert!(result.matches); assert_eq!(result.reason, FeatureFlagMatchReason::SuperConditionValue); @@ -2422,11 +2914,14 @@ mod tests { #[tokio::test] async fn test_super_condition_matches_and_false() { - let client = setup_pg_client(None).await; - let team = insert_new_team_in_pg(client.clone()).await.unwrap(); + let postgres_reader = setup_pg_reader_client(None).await; + let postgres_writer = setup_pg_writer_client(None).await; + let team = insert_new_team_in_pg(postgres_reader.clone(), None) + .await + .unwrap(); insert_person_for_team_in_pg( - client.clone(), + postgres_reader.clone(), team.id, "test_id".to_string(), Some(json!({"email": "test@posthog.com", "is_enabled": true})), @@ -2492,7 +2987,8 @@ mod tests { let mut matcher_test_id = FeatureFlagMatcher::new( "test_id".to_string(), team.id, - client.clone(), + postgres_reader.clone(), + postgres_writer.clone(), None, None, None, @@ -2501,7 +2997,8 @@ mod tests { let mut matcher_example_id = FeatureFlagMatcher::new( "lil_id".to_string(), team.id, - client.clone(), + postgres_reader.clone(), + postgres_writer.clone(), None, None, None, @@ -2510,15 +3007,22 @@ mod tests { let mut matcher_another_id = FeatureFlagMatcher::new( "another_id".to_string(), team.id, - client.clone(), + postgres_reader.clone(), + postgres_writer.clone(), None, None, None, ); - let result_test_id = matcher_test_id.get_match(&flag, None).await.unwrap(); - let result_example_id = matcher_example_id.get_match(&flag, None).await.unwrap(); - let result_another_id = matcher_another_id.get_match(&flag, None).await.unwrap(); + let result_test_id = matcher_test_id.get_match(&flag, None, None).await.unwrap(); + let result_example_id = matcher_example_id + .get_match(&flag, None, None) + .await + .unwrap(); + let result_another_id = matcher_another_id + .get_match(&flag, None, None) + .await + .unwrap(); assert!(!result_test_id.matches); assert_eq!( @@ -2541,4 +3045,430 @@ mod tests { ); assert_eq!(result_another_id.condition_index, Some(2)); } + + #[tokio::test] + async fn test_set_feature_flag_hash_key_overrides_success() { + let postgres_reader = setup_pg_reader_client(None).await; + let postgres_writer = setup_pg_writer_client(None).await; + let team = insert_new_team_in_pg(postgres_reader.clone(), None) + .await + .unwrap(); + let distinct_id = "user1".to_string(); + + // Insert person + insert_person_for_team_in_pg(postgres_reader.clone(), team.id, distinct_id.clone(), None) + .await + .unwrap(); + + // Create a feature flag with ensure_experience_continuity = true + let flag = create_test_flag( + None, + Some(team.id), + Some("Test Flag".to_string()), + Some("test_flag".to_string()), + Some(FlagFilters { + groups: vec![], + multivariate: None, + aggregation_group_type_index: None, + payloads: None, + super_groups: None, + }), + Some(false), // not deleted + Some(true), // active + Some(true), // ensure_experience_continuity + ); + + // need to convert flag to FeatureFlagRow + let flag_row = FeatureFlagRow { + id: flag.id, + team_id: flag.team_id, + name: flag.name, + key: flag.key, + filters: json!(flag.filters), + deleted: flag.deleted, + active: flag.active, + ensure_experience_continuity: flag.ensure_experience_continuity, + }; + + // Insert the feature flag into the database + insert_flag_for_team_in_pg(postgres_writer.clone(), team.id, Some(flag_row)) + .await + .unwrap(); + + // Attempt to set hash key override + let result = set_feature_flag_hash_key_overrides( + postgres_writer.clone(), + team.id, + vec![distinct_id.clone()], + "hash_key_2".to_string(), + ) + .await + .unwrap(); + + assert!(result, "Hash key override should be set successfully"); + + // Retrieve the hash key overrides + let overrides = get_feature_flag_hash_key_overrides( + postgres_reader.clone(), + team.id, + vec![distinct_id.clone()], + ) + .await + .unwrap(); + + assert!( + !overrides.is_empty(), + "At least one hash key override should be set" + ); + assert_eq!( + overrides.get("test_flag"), + Some(&"hash_key_2".to_string()), + "Hash key override for 'test_flag' should match the set value" + ); + } + + #[tokio::test] + async fn test_get_feature_flag_hash_key_overrides_success() { + let postgres_reader = setup_pg_reader_client(None).await; + let postgres_writer = setup_pg_writer_client(None).await; + let team = insert_new_team_in_pg(postgres_reader.clone(), None) + .await + .unwrap(); + let distinct_id = "user2".to_string(); + + // Insert person + insert_person_for_team_in_pg(postgres_reader.clone(), team.id, distinct_id.clone(), None) + .await + .unwrap(); + + // Create a feature flag with ensure_experience_continuity = true + let flag = create_test_flag( + None, + Some(team.id), + Some("Test Flag".to_string()), + Some("test_flag".to_string()), + Some(FlagFilters { + groups: vec![], + multivariate: None, + aggregation_group_type_index: None, + payloads: None, + super_groups: None, + }), + Some(false), // not deleted + Some(true), // active + Some(true), // ensure_experience_continuity + ); + + // Convert flag to FeatureFlagRow + let flag_row = FeatureFlagRow { + id: flag.id, + team_id: flag.team_id, + name: flag.name, + key: flag.key, + filters: json!(flag.filters), + deleted: flag.deleted, + active: flag.active, + ensure_experience_continuity: flag.ensure_experience_continuity, + }; + + // Insert the feature flag into the database + insert_flag_for_team_in_pg(postgres_writer.clone(), team.id, Some(flag_row)) + .await + .unwrap(); + + // Set hash key override + set_feature_flag_hash_key_overrides( + postgres_writer.clone(), + team.id, + vec![distinct_id.clone()], + "hash_key_2".to_string(), + ) + .await + .unwrap(); + + // Retrieve hash key overrides + let overrides = get_feature_flag_hash_key_overrides( + postgres_reader.clone(), + team.id, + vec![distinct_id.clone()], + ) + .await + .unwrap(); + + assert_eq!( + overrides.get("test_flag"), + Some(&"hash_key_2".to_string()), + "Hash key override should match the set value" + ); + } + #[tokio::test] + async fn test_evaluate_feature_flags_with_experience_continuity() { + let postgres_reader = setup_pg_reader_client(None).await; + let postgres_writer = setup_pg_writer_client(None).await; + let team = insert_new_team_in_pg(postgres_reader.clone(), None) + .await + .unwrap(); + let distinct_id = "user3".to_string(); + + // Insert person + insert_person_for_team_in_pg( + postgres_reader.clone(), + team.id, + distinct_id.clone(), + Some(json!({"email": "user3@example.com"})), + ) + .await + .unwrap(); + + // Create flag with experience continuity + let flag = create_test_flag( + None, + Some(team.id), + None, + Some("flag_continuity".to_string()), + Some(FlagFilters { + groups: vec![FlagGroupType { + properties: Some(vec![PropertyFilter { + key: "email".to_string(), + value: json!("user3@example.com"), + operator: None, + prop_type: "person".to_string(), + group_type_index: None, + }]), + rollout_percentage: Some(100.0), + variant: None, + }], + multivariate: None, + aggregation_group_type_index: None, + payloads: None, + super_groups: None, + }), + None, + None, + Some(true), + ); + + // Set hash key override + set_feature_flag_hash_key_overrides( + postgres_writer.clone(), + team.id, + vec![distinct_id.clone()], + "hash_key_continuity".to_string(), + ) + .await + .unwrap(); + + let flags = FeatureFlagList { + flags: vec![flag.clone()], + }; + + let result = FeatureFlagMatcher::new( + distinct_id.clone(), + team.id, + postgres_reader.clone(), + postgres_writer.clone(), + None, + None, + None, + ) + .evaluate_all_feature_flags(flags, None, None, Some("hash_key_continuity".to_string())) + .await; + + assert!(!result.error_while_computing_flags, "No error should occur"); + assert_eq!( + result.feature_flags.get("flag_continuity"), + Some(&FlagValue::Boolean(true)), + "Flag should be evaluated as true with continuity" + ); + } + + #[tokio::test] + async fn test_evaluate_feature_flags_with_continuity_missing_override() { + let postgres_reader = setup_pg_reader_client(None).await; + let postgres_writer = setup_pg_writer_client(None).await; + let team = insert_new_team_in_pg(postgres_reader.clone(), None) + .await + .unwrap(); + let distinct_id = "user4".to_string(); + + // Insert person + insert_person_for_team_in_pg( + postgres_reader.clone(), + team.id, + distinct_id.clone(), + Some(json!({"email": "user4@example.com"})), + ) + .await + .unwrap(); + + // Create flag with experience continuity + let flag = create_test_flag( + None, + Some(team.id), + None, + Some("flag_continuity_missing".to_string()), + Some(FlagFilters { + groups: vec![FlagGroupType { + properties: Some(vec![PropertyFilter { + key: "email".to_string(), + value: json!("user4@example.com"), + operator: None, + prop_type: "person".to_string(), + group_type_index: None, + }]), + rollout_percentage: Some(100.0), + variant: None, + }], + multivariate: None, + aggregation_group_type_index: None, + payloads: None, + super_groups: None, + }), + None, + None, + Some(true), + ); + + let flags = FeatureFlagList { + flags: vec![flag.clone()], + }; + + let result = FeatureFlagMatcher::new( + distinct_id.clone(), + team.id, + postgres_reader.clone(), + postgres_writer.clone(), + None, + None, + None, + ) + .evaluate_all_feature_flags(flags, None, None, None) + .await; + + assert!(!result.error_while_computing_flags, "No error should occur"); + assert_eq!( + result.feature_flags.get("flag_continuity_missing"), + Some(&FlagValue::Boolean(true)), + "Flag should be evaluated as true even without continuity override" + ); + } + + #[tokio::test] + async fn test_evaluate_all_feature_flags_mixed_continuity() { + let postgres_reader = setup_pg_reader_client(None).await; + let postgres_writer = setup_pg_writer_client(None).await; + let team = insert_new_team_in_pg(postgres_reader.clone(), None) + .await + .unwrap(); + let distinct_id = "user5".to_string(); + + // Insert person + insert_person_for_team_in_pg( + postgres_reader.clone(), + team.id, + distinct_id.clone(), + Some(json!({"email": "user5@example.com"})), + ) + .await + .unwrap(); + + // Create flag with continuity + let flag_continuity = create_test_flag( + None, + Some(team.id), + None, + Some("flag_continuity_mix".to_string()), + Some(FlagFilters { + groups: vec![FlagGroupType { + properties: Some(vec![PropertyFilter { + key: "email".to_string(), + value: json!("user5@example.com"), + operator: None, + prop_type: "person".to_string(), + group_type_index: None, + }]), + rollout_percentage: Some(100.0), + variant: None, + }], + multivariate: None, + aggregation_group_type_index: None, + payloads: None, + super_groups: None, + }), + None, + None, + Some(true), + ); + + // Create flag without continuity + let flag_no_continuity = create_test_flag( + None, + Some(team.id), + None, + Some("flag_no_continuity_mix".to_string()), + Some(FlagFilters { + groups: vec![FlagGroupType { + properties: Some(vec![PropertyFilter { + key: "age".to_string(), + value: json!(30), + operator: Some(OperatorType::Gt), + prop_type: "person".to_string(), + group_type_index: None, + }]), + rollout_percentage: Some(100.0), + variant: None, + }], + multivariate: None, + aggregation_group_type_index: None, + payloads: None, + super_groups: None, + }), + None, + None, + Some(false), + ); + + // Set hash key override for the continuity flag + set_feature_flag_hash_key_overrides( + postgres_writer.clone(), + team.id, + vec![distinct_id.clone()], + "hash_key_mixed".to_string(), + ) + .await + .unwrap(); + + let flags = FeatureFlagList { + flags: vec![flag_continuity.clone(), flag_no_continuity.clone()], + }; + + let result = FeatureFlagMatcher::new( + distinct_id.clone(), + team.id, + postgres_reader.clone(), + postgres_writer.clone(), + None, + None, + None, + ) + .evaluate_all_feature_flags( + flags, + Some(HashMap::from([("age".to_string(), json!(35))])), + None, + Some("hash_key_mixed".to_string()), + ) + .await; + + assert!(!result.error_while_computing_flags, "No error should occur"); + assert_eq!( + result.feature_flags.get("flag_continuity_mix"), + Some(&FlagValue::Boolean(true)), + "Continuity flag should be evaluated as true" + ); + assert_eq!( + result.feature_flags.get("flag_no_continuity_mix"), + Some(&FlagValue::Boolean(true)), + "Non-continuity flag should be evaluated based on properties" + ); + } } diff --git a/rust/feature-flags/src/flag_request.rs b/rust/feature-flags/src/flag_request.rs index 05c4ceff047be..4d215867813e9 100644 --- a/rust/feature-flags/src/flag_request.rs +++ b/rust/feature-flags/src/flag_request.rs @@ -160,6 +160,7 @@ impl FlagRequest { redis_client: Arc, pg_client: Arc, ) -> Result { + // TODO add a cache hit/miss counter match FeatureFlagList::from_redis(redis_client.clone(), team_id).await { Ok(flags) => Ok(flags), Err(_) => match FeatureFlagList::from_pg(pg_client, team_id).await { @@ -193,7 +194,7 @@ mod tests { use crate::flag_request::FlagRequest; use crate::redis::Client as RedisClient; use crate::team::Team; - use crate::test_utils::{insert_new_team_in_redis, setup_pg_client, setup_redis_client}; + use crate::test_utils::{insert_new_team_in_redis, setup_pg_reader_client, setup_redis_client}; use bytes::Bytes; use serde_json::json; @@ -245,7 +246,7 @@ mod tests { #[tokio::test] async fn token_is_returned_correctly() { let redis_client = setup_redis_client(None); - let pg_client = setup_pg_client(None).await; + let pg_client = setup_pg_reader_client(None).await; let team = insert_new_team_in_redis(redis_client.clone()) .await .expect("Failed to insert new team in Redis"); @@ -270,7 +271,7 @@ mod tests { #[tokio::test] async fn test_get_team_from_cache_or_pg() { let redis_client = setup_redis_client(None); - let pg_client = setup_pg_client(None).await; + let pg_client = setup_pg_reader_client(None).await; let team = insert_new_team_in_redis(redis_client.clone()) .await .expect("Failed to insert new team in Redis"); @@ -324,7 +325,7 @@ mod tests { #[tokio::test] async fn test_get_flags_from_cache_or_pg() { let redis_client = setup_redis_client(None); - let pg_client = setup_pg_client(None).await; + let pg_client = setup_pg_reader_client(None).await; let team = insert_new_team_in_redis(redis_client.clone()) .await .expect("Failed to insert new team in Redis"); @@ -480,7 +481,7 @@ mod tests { #[tokio::test] async fn test_error_cases() { let redis_client = setup_redis_client(None); - let pg_client = setup_pg_client(None).await; + let pg_client = setup_pg_reader_client(None).await; // Test invalid token let flag_request = FlagRequest { diff --git a/rust/feature-flags/src/request_handler.rs b/rust/feature-flags/src/request_handler.rs index 83d1c0f66f352..6c62e7c5ec091 100644 --- a/rust/feature-flags/src/request_handler.rs +++ b/rust/feature-flags/src/request_handler.rs @@ -10,6 +10,7 @@ use crate::{ use axum::{extract::State, http::HeaderMap}; use base64::{engine::general_purpose, Engine as _}; use bytes::Bytes; +use derive_builder::Builder; use flate2::read::GzDecoder; use serde::{Deserialize, Serialize}; use serde_json::Value; @@ -60,6 +61,24 @@ pub struct RequestContext { pub body: Bytes, } +#[derive(Builder, Clone)] +#[builder(setter(into))] +pub struct FeatureFlagEvaluationContext { + team_id: i32, + distinct_id: String, + feature_flags: FeatureFlagList, + postgres_reader: Arc, + postgres_writer: Arc, + #[builder(default)] + person_property_overrides: Option>, + #[builder(default)] + group_property_overrides: Option>>, + #[builder(default)] + groups: Option>, + #[builder(default)] + hash_key_override: Option, +} + pub async fn process_request(context: RequestContext) -> Result { let RequestContext { state, @@ -71,10 +90,10 @@ pub async fn process_request(context: RequestContext) -> Result Result = state.postgres_reader.clone(); + let postgres_writer_dyn: Arc = state.postgres_writer.clone(); + + let evaluation_context = FeatureFlagEvaluationContextBuilder::default() + .team_id(team_id) + .distinct_id(distinct_id) + .feature_flags(feature_flags_from_cache_or_pg) + .postgres_reader(postgres_reader_dyn) + .postgres_writer(postgres_writer_dyn) + .person_property_overrides(person_property_overrides) + .group_property_overrides(group_property_overrides) + .groups(groups) + .hash_key_override(hash_key_override) + .build() + .expect("Failed to build FeatureFlagEvaluationContext"); + + let flags_response = evaluate_feature_flags(evaluation_context).await; Ok(flags_response) } @@ -189,29 +216,24 @@ fn decode_request(headers: &HeaderMap, body: Bytes) -> Result, - person_property_overrides: Option>, - group_property_overrides: Option>>, - groups: Option>, -) -> FlagsResponse { - let group_type_mapping_cache = GroupTypeMappingCache::new(team_id, database_client.clone()); +pub async fn evaluate_feature_flags(context: FeatureFlagEvaluationContext) -> FlagsResponse { + let group_type_mapping_cache = + GroupTypeMappingCache::new(context.team_id, context.postgres_reader.clone()); let mut feature_flag_matcher = FeatureFlagMatcher::new( - distinct_id.clone(), - team_id, - database_client, + context.distinct_id, + context.team_id, + context.postgres_reader, + context.postgres_writer, Some(group_type_mapping_cache), - None, - groups, + None, // TODO maybe remove this from the matcher struct, since it's used internally but not passed around + context.groups, ); feature_flag_matcher - .evaluate_feature_flags( - feature_flags_from_cache_or_pg, - person_property_overrides, - group_property_overrides, + .evaluate_all_feature_flags( + context.feature_flags, + context.person_property_overrides, + context.group_property_overrides, + context.hash_key_override, ) .await } @@ -234,7 +256,7 @@ mod tests { api::FlagValue, config::Config, flag_definitions::{FeatureFlag, FlagFilters, FlagGroupType, OperatorType, PropertyFilter}, - test_utils::{insert_new_team_in_pg, setup_pg_client}, + test_utils::{insert_new_team_in_pg, setup_pg_reader_client, setup_pg_writer_client}, }; use super::*; @@ -335,7 +357,8 @@ mod tests { #[tokio::test] async fn test_evaluate_feature_flags() { - let pg_client = setup_pg_client(None).await; + let postgres_reader: Arc = setup_pg_reader_client(None).await; + let postgres_writer: Arc = setup_pg_writer_client(None).await; let flag = FeatureFlag { name: Some("Test Flag".to_string()), id: 1, @@ -368,16 +391,17 @@ mod tests { let mut person_properties = HashMap::new(); person_properties.insert("country".to_string(), json!("US")); - let result = evaluate_feature_flags( - 1, - "user123".to_string(), - feature_flag_list, - pg_client, - Some(person_properties), - None, - None, - ) - .await; + let evaluation_context = FeatureFlagEvaluationContextBuilder::default() + .team_id(1) + .distinct_id("user123".to_string()) + .feature_flags(feature_flag_list) + .postgres_reader(postgres_reader) + .postgres_writer(postgres_writer) + .person_property_overrides(Some(person_properties)) + .build() + .expect("Failed to build FeatureFlagEvaluationContext"); + + let result = evaluate_feature_flags(evaluation_context).await; assert!(!result.error_while_computing_flags); assert!(result.feature_flags.contains_key("test_flag")); @@ -479,7 +503,8 @@ mod tests { #[tokio::test] async fn test_evaluate_feature_flags_multiple_flags() { - let pg_client = setup_pg_client(None).await; + let postgres_reader: Arc = setup_pg_reader_client(None).await; + let postgres_writer: Arc = setup_pg_writer_client(None).await; let flags = vec![ FeatureFlag { name: Some("Flag 1".to_string()), @@ -525,16 +550,16 @@ mod tests { let feature_flag_list = FeatureFlagList { flags }; - let result = evaluate_feature_flags( - 1, - "user123".to_string(), - feature_flag_list, - pg_client, - None, - None, - None, - ) - .await; + let evaluation_context = FeatureFlagEvaluationContextBuilder::default() + .team_id(1) + .distinct_id("user123".to_string()) + .feature_flags(feature_flag_list) + .postgres_reader(postgres_reader) + .postgres_writer(postgres_writer) + .build() + .expect("Failed to build FeatureFlagEvaluationContext"); + + let result = evaluate_feature_flags(evaluation_context).await; assert!(!result.error_while_computing_flags); assert_eq!(result.feature_flags["flag_1"], FlagValue::Boolean(true)); @@ -581,8 +606,11 @@ mod tests { #[tokio::test] async fn test_evaluate_feature_flags_with_overrides() { - let pg_client = setup_pg_client(None).await; - let team = insert_new_team_in_pg(pg_client.clone()).await.unwrap(); + let postgres_reader: Arc = setup_pg_reader_client(None).await; + let postgres_writer: Arc = setup_pg_writer_client(None).await; + let team = insert_new_team_in_pg(postgres_reader.clone(), None) + .await + .unwrap(); let flag = FeatureFlag { name: Some("Test Flag".to_string()), @@ -621,16 +649,18 @@ mod tests { ]), )]); - let result = evaluate_feature_flags( - team.id, - "user123".to_string(), - feature_flag_list, - pg_client, - None, - Some(group_property_overrides), - Some(groups), - ) - .await; + let evaluation_context = FeatureFlagEvaluationContextBuilder::default() + .team_id(team.id) + .distinct_id("user123".to_string()) + .feature_flags(feature_flag_list) + .postgres_reader(postgres_reader) + .postgres_writer(postgres_writer) + .group_property_overrides(Some(group_property_overrides)) + .groups(Some(groups)) + .build() + .expect("Failed to build FeatureFlagEvaluationContext"); + + let result = evaluate_feature_flags(evaluation_context).await; assert!( !result.error_while_computing_flags, @@ -656,7 +686,8 @@ mod tests { #[tokio::test] async fn test_long_distinct_id() { let long_id = "a".repeat(1000); - let pg_client = setup_pg_client(None).await; + let postgres_reader: Arc = setup_pg_reader_client(None).await; + let postgres_writer: Arc = setup_pg_writer_client(None).await; let flag = FeatureFlag { name: Some("Test Flag".to_string()), id: 1, @@ -680,9 +711,16 @@ mod tests { let feature_flag_list = FeatureFlagList { flags: vec![flag] }; - let result = - evaluate_feature_flags(1, long_id, feature_flag_list, pg_client, None, None, None) - .await; + let evaluation_context = FeatureFlagEvaluationContextBuilder::default() + .team_id(1) + .distinct_id(long_id) + .feature_flags(feature_flag_list) + .postgres_reader(postgres_reader) + .postgres_writer(postgres_writer) + .build() + .expect("Failed to build FeatureFlagEvaluationContext"); + + let result = evaluate_feature_flags(evaluation_context).await; assert!(!result.error_while_computing_flags); assert_eq!(result.feature_flags["test_flag"], FlagValue::Boolean(true)); diff --git a/rust/feature-flags/src/router.rs b/rust/feature-flags/src/router.rs index 51a4b95e3f9c2..e12b32b464795 100644 --- a/rust/feature-flags/src/router.rs +++ b/rust/feature-flags/src/router.rs @@ -15,15 +15,16 @@ use crate::{ #[derive(Clone)] pub struct State { - // TODO add writers when ready pub redis: Arc, - pub postgres: Arc, + pub postgres_reader: Arc, + pub postgres_writer: Arc, pub geoip: Arc, } pub fn router( redis: Arc, - postgres: Arc, + postgres_reader: Arc, + postgres_writer: Arc, geoip: Arc, liveness: HealthRegistry, metrics: bool, @@ -35,7 +36,8 @@ where { let state = State { redis, - postgres, + postgres_reader, + postgres_writer, geoip, }; @@ -64,39 +66,3 @@ where pub async fn index() -> &'static str { "feature flags service" } - -// TODO, eventually we can differentiate read and write postgres clients, if needed -// I _think_ everything is read-only, but I'm not 100% sure yet -// here's how that client would look -// use std::sync::Arc; - -// use axum::{routing::post, Router}; - -// use crate::{database::Client as DatabaseClient, redis::Client as RedisClient, v0_endpoint}; - -// #[derive(Clone)] -// pub struct State { -// pub redis: Arc, -// pub postgres_read: Arc, -// pub postgres_write: Arc, -// } - -// pub fn router( -// redis: Arc, -// postgres_read: Arc, -// postgres_write: Arc, -// ) -> Router -// where -// R: RedisClient + Send + Sync + 'static, -// D: DatabaseClient + Send + Sync + 'static, -// { -// let state = State { -// redis, -// postgres_read, -// postgres_write, -// }; - -// Router::new() -// .route("/flags", post(v0_endpoint::flags).get(v0_endpoint::flags)) -// .with_state(state) -// } diff --git a/rust/feature-flags/src/server.rs b/rust/feature-flags/src/server.rs index 4c476aaf46169..dbda44fe244e1 100644 --- a/rust/feature-flags/src/server.rs +++ b/rust/feature-flags/src/server.rs @@ -24,11 +24,24 @@ where } }; - let read_postgres_client = - match get_pool(&config.read_database_url, config.max_pg_connections).await { + // TODO - we should have a dedicated URL for both this and the writer – the reader will read + // from the replica, and the writer will write to the main database. + let postgres_reader = match get_pool(&config.read_database_url, config.max_pg_connections).await + { + Ok(client) => Arc::new(client), + Err(e) => { + tracing::error!("Failed to create read Postgres client: {}", e); + return; + } + }; + + let postgres_writer = + // TODO - we should have a dedicated URL for both this and the reader – the reader will read + // from the replica, and the writer will write to the main database. + match get_pool(&config.write_database_url, config.max_pg_connections).await { Ok(client) => Arc::new(client), Err(e) => { - tracing::error!("Failed to create read Postgres client: {}", e); + tracing::error!("Failed to create write Postgres client: {}", e); return; } }; @@ -52,7 +65,8 @@ where // You can decide which client to pass to the router, or pass both if needed let app = router::router( redis_client, - read_postgres_client, + postgres_reader, + postgres_writer, geoip_service, health, config.enable_metrics, diff --git a/rust/feature-flags/src/team.rs b/rust/feature-flags/src/team.rs index bd975385eb216..0fa75f0bd3db7 100644 --- a/rust/feature-flags/src/team.rs +++ b/rust/feature-flags/src/team.rs @@ -97,7 +97,7 @@ mod tests { use crate::{ team, test_utils::{ - insert_new_team_in_pg, insert_new_team_in_redis, random_string, setup_pg_client, + insert_new_team_in_pg, insert_new_team_in_redis, random_string, setup_pg_reader_client, setup_redis_client, }, }; @@ -181,9 +181,9 @@ mod tests { #[tokio::test] async fn test_fetch_team_from_pg() { - let client = setup_pg_client(None).await; + let client = setup_pg_reader_client(None).await; - let team = insert_new_team_in_pg(client.clone()) + let team = insert_new_team_in_pg(client.clone(), None) .await .expect("Failed to insert team in pg"); @@ -203,14 +203,12 @@ mod tests { // TODO: Figure out a way such that `run_database_migrations` is called only once, and already called // before running these tests. - let client = setup_pg_client(None).await; + let client = setup_pg_reader_client(None).await; let target_token = "xxxx".to_string(); match Team::from_pg(client.clone(), target_token.clone()).await { - Err(FlagError::TokenValidationError) => (), - _ => panic!("Expected TokenValidationError"), + Err(FlagError::RowNotFound) => (), + _ => panic!("Expected RowNotFound"), }; } - - // TODO: Handle cases where db connection fails. } diff --git a/rust/feature-flags/src/test_utils.rs b/rust/feature-flags/src/test_utils.rs index e4777485eb66c..769f95039990d 100644 --- a/rust/feature-flags/src/test_utils.rs +++ b/rust/feature-flags/src/test_utils.rs @@ -130,7 +130,7 @@ pub fn create_flag_from_json(json_value: Option) -> Vec { flags } -pub async fn setup_pg_client(config: Option<&Config>) -> Arc { +pub async fn setup_pg_reader_client(config: Option<&Config>) -> Arc { let config = config.unwrap_or(&DEFAULT_TEST_CONFIG); Arc::new( get_pool(&config.read_database_url, config.max_pg_connections) @@ -139,6 +139,15 @@ pub async fn setup_pg_client(config: Option<&Config>) -> Arc { ) } +pub async fn setup_pg_writer_client(config: Option<&Config>) -> Arc { + let config = config.unwrap_or(&DEFAULT_TEST_CONFIG); + Arc::new( + get_pool(&config.write_database_url, config.max_pg_connections) + .await + .expect("Failed to create Postgres client"), + ) +} + pub struct MockPgClient; #[async_trait] @@ -163,7 +172,10 @@ pub async fn setup_invalid_pg_client() -> Arc { Arc::new(MockPgClient) } -pub async fn insert_new_team_in_pg(client: Arc) -> Result { +pub async fn insert_new_team_in_pg( + client: Arc, + team_id: Option, +) -> Result { const ORG_ID: &str = "019026a4be8000005bf3171d00629163"; client.run_query( @@ -189,7 +201,10 @@ pub async fn insert_new_team_in_pg(client: Arc) -> Result { ) .await?; - let id = rand::thread_rng().gen_range(0..10_000_000); + let id = match team_id { + Some(value) => value, + None => rand::thread_rng().gen_range(0..10_000_000), + }; let token = random_string("phc_", 12); let team = Team { id, diff --git a/rust/feature-flags/tests/test_flag_matching_consistency.rs b/rust/feature-flags/tests/test_flag_matching_consistency.rs index 5ce20cd55e89f..94f4f67dcdc56 100644 --- a/rust/feature-flags/tests/test_flag_matching_consistency.rs +++ b/rust/feature-flags/tests/test_flag_matching_consistency.rs @@ -3,7 +3,9 @@ use feature_flags::feature_flag_match_reason::FeatureFlagMatchReason; /// This ensures there are no mismatches between implementations. use feature_flags::flag_matching::{FeatureFlagMatch, FeatureFlagMatcher}; -use feature_flags::test_utils::{create_flag_from_json, setup_pg_client}; +use feature_flags::test_utils::{ + create_flag_from_json, setup_pg_reader_client, setup_pg_writer_client, +}; use serde_json::json; #[tokio::test] @@ -106,15 +108,23 @@ async fn it_is_consistent_with_rollout_calculation_for_simple_flags() { ]; for (i, result) in results.iter().enumerate().take(1000) { - let database_client = setup_pg_client(None).await; + let postgres_reader = setup_pg_reader_client(None).await; + let postgres_writer = setup_pg_writer_client(None).await; let distinct_id = format!("distinct_id_{}", i); - let feature_flag_match = - FeatureFlagMatcher::new(distinct_id, 1, database_client, None, None, None) - .get_match(&flags[0], None) - .await - .unwrap(); + let feature_flag_match = FeatureFlagMatcher::new( + distinct_id, + 1, + postgres_reader, + postgres_writer, + None, + None, + None, + ) + .get_match(&flags[0], None, None) + .await + .unwrap(); if *result { assert_eq!( @@ -1197,14 +1207,22 @@ async fn it_is_consistent_with_rollout_calculation_for_multivariate_flags() { ]; for (i, result) in results.iter().enumerate().take(1000) { - let database_client = setup_pg_client(None).await; + let postgres_reader = setup_pg_reader_client(None).await; + let postgres_writer = setup_pg_writer_client(None).await; let distinct_id = format!("distinct_id_{}", i); - let feature_flag_match = - FeatureFlagMatcher::new(distinct_id, 1, database_client, None, None, None) - .get_match(&flags[0], None) - .await - .unwrap(); + let feature_flag_match = FeatureFlagMatcher::new( + distinct_id, + 1, + postgres_reader, + postgres_writer, + None, + None, + None, + ) + .get_match(&flags[0], None, None) + .await + .unwrap(); if let Some(variant) = &result { assert_eq!( diff --git a/rust/feature-flags/tests/test_flags.rs b/rust/feature-flags/tests/test_flags.rs index f12f8434aface..6b6263b4a772c 100644 --- a/rust/feature-flags/tests/test_flags.rs +++ b/rust/feature-flags/tests/test_flags.rs @@ -8,7 +8,8 @@ use crate::common::*; use feature_flags::config::DEFAULT_TEST_CONFIG; use feature_flags::test_utils::{ - insert_flags_for_team_in_redis, insert_new_team_in_redis, setup_redis_client, + insert_flags_for_team_in_redis, insert_new_team_in_pg, insert_new_team_in_redis, + setup_pg_reader_client, setup_redis_client, }; pub mod common; @@ -40,6 +41,7 @@ async fn it_sends_flag_request() -> Result<()> { ], }, }]); + insert_flags_for_team_in_redis(client, team.id, Some(flag_json.to_string())).await?; let server = ServerHandle::for_config(config).await; @@ -234,3 +236,267 @@ async fn it_handles_malformed_json() -> Result<()> { // ); // Ok(()) // } + +#[tokio::test] +async fn it_handles_multivariate_flags() -> Result<()> { + let config = DEFAULT_TEST_CONFIG.clone(); + let distinct_id = "user_distinct_id".to_string(); + + let client = setup_redis_client(Some(config.redis_url.clone())); + let team = insert_new_team_in_redis(client.clone()).await.unwrap(); + let token = team.api_token; + + let flag_json = json!([{ + "id": 1, + "key": "multivariate-flag", + "name": "Multivariate Flag", + "active": true, + "deleted": false, + "team_id": team.id, + "filters": { + "groups": [ + { + "properties": [], + "rollout_percentage": 100 + } + ], + "multivariate": { + "variants": [ + { + "key": "control", + "name": "Control", + "rollout_percentage": 0 + }, + { + "key": "test_a", + "name": "Test A", + "rollout_percentage": 0 + }, + { + "key": "test_b", + "name": "Test B", + "rollout_percentage": 100 + } + ] + } + }, + }]); + + insert_flags_for_team_in_redis(client, team.id, Some(flag_json.to_string())).await?; + + let server = ServerHandle::for_config(config).await; + + let payload = json!({ + "token": token, + "distinct_id": distinct_id, + }); + + let res = server.send_flags_request(payload.to_string()).await; + assert_eq!(StatusCode::OK, res.status()); + + let json_data = res.json::().await?; + assert_json_include!( + actual: json_data, + expected: json!({ + "errorWhileComputingFlags": false, + "featureFlags": { + "multivariate-flag": "test_b" + } + }) + ); + + let variant = json_data["featureFlags"]["multivariate-flag"] + .as_str() + .unwrap(); + assert!(["control", "test_a", "test_b"].contains(&variant)); + + Ok(()) +} + +#[tokio::test] +async fn it_handles_flag_with_property_filter() -> Result<()> { + let config = DEFAULT_TEST_CONFIG.clone(); + let distinct_id = "user_distinct_id".to_string(); + + let client = setup_redis_client(Some(config.redis_url.clone())); + let team = insert_new_team_in_redis(client.clone()).await.unwrap(); + let token = team.api_token; + + let flag_json = json!([{ + "id": 1, + "key": "property-flag", + "name": "Property Flag", + "active": true, + "deleted": false, + "team_id": team.id, + "filters": { + "groups": [ + { + "properties": [ + { + "key": "email", + "value": "test@example.com", + "operator": "exact", + "type": "person" + } + ], + "rollout_percentage": 100 + } + ], + }, + }]); + + insert_flags_for_team_in_redis(client, team.id, Some(flag_json.to_string())).await?; + + let server = ServerHandle::for_config(config).await; + + // Test with matching property + let payload = json!({ + "token": token, + "distinct_id": distinct_id, + "person_properties": { + "email": "test@example.com" + } + }); + + let res = server.send_flags_request(payload.to_string()).await; + assert_eq!(StatusCode::OK, res.status()); + + let json_data = res.json::().await?; + assert_json_include!( + actual: json_data, + expected: json!({ + "errorWhileComputingFlags": false, + "featureFlags": { + "property-flag": true + } + }) + ); + + // Test with non-matching property + let payload = json!({ + "token": token, + "distinct_id": distinct_id, + "person_properties": { + "email": "other@example.com" + } + }); + + let res = server.send_flags_request(payload.to_string()).await; + assert_eq!(StatusCode::OK, res.status()); + + let json_data = res.json::().await?; + assert_json_include!( + actual: json_data, + expected: json!({ + "errorWhileComputingFlags": false, + "featureFlags": { + "property-flag": false + } + }) + ); + + Ok(()) +} + +#[tokio::test] +async fn it_handles_flag_with_group_properties() -> Result<()> { + let config = DEFAULT_TEST_CONFIG.clone(); + let distinct_id = "user_distinct_id".to_string(); + + let client = setup_redis_client(Some(config.redis_url.clone())); + let pg_client = setup_pg_reader_client(None).await; + let team = insert_new_team_in_redis(client.clone()).await.unwrap(); + insert_new_team_in_pg(pg_client.clone(), Some(team.id)) + .await + .unwrap(); + let token = team.api_token; + + let flag_json = json!([{ + "id": 1, + "key": "group-flag", + "name": "Group Flag", + "active": true, + "deleted": false, + "team_id": team.id, + "filters": { + "groups": [ + { + "properties": [ + { + "key": "name", + "value": "Test Group", + "operator": "exact", + "type": "group", + "group_type_index": 0 + } + ], + "rollout_percentage": 100 + } + ], + "aggregation_group_type_index": 0 + }, + }]); + + insert_flags_for_team_in_redis(client, team.id, Some(flag_json.to_string())).await?; + + let server = ServerHandle::for_config(config).await; + + // Test with matching group property + let payload = json!({ + "token": token, + "distinct_id": distinct_id, + "groups": { + "project": "test_company_id" + }, + "group_properties": { + "project": { + "name": "Test Group" + } + } + }); + + let res = server.send_flags_request(payload.to_string()).await; + assert_eq!(StatusCode::OK, res.status()); + + let json_data = res.json::().await?; + assert_json_include!( + actual: json_data, + expected: json!({ + "errorWhileComputingFlags": false, + "featureFlags": { + "group-flag": true + } + }) + ); + + // Test with non-matching group property + let payload = json!({ + "token": token, + "distinct_id": distinct_id, + "groups": { + "project": "test_company_id" + }, + "group_properties": { + "project": { + "name": "Other Group" + } + } + }); + + let res = server.send_flags_request(payload.to_string()).await; + assert_eq!(StatusCode::OK, res.status()); + + let json_data = res.json::().await?; + assert_json_include!( + actual: json_data, + expected: json!({ + "errorWhileComputingFlags": false, + "featureFlags": { + "group-flag": false + } + }) + ); + + Ok(()) +}