diff --git a/.github/workflows/ci-plugin-server.yml b/.github/workflows/ci-plugin-server.yml index a62bd4a66851a1..55b071a49b02a5 100644 --- a/.github/workflows/ci-plugin-server.yml +++ b/.github/workflows/ci-plugin-server.yml @@ -44,7 +44,7 @@ jobs: - 'plugin-server/**' - 'posthog/clickhouse/migrations/**' - 'ee/migrations/**' - - 'ee/management/commands/setup_test_environment.py' + - 'posthog/management/commands/setup_test_environment.py' - 'posthog/migrations/**' - 'posthog/plugins/**' - 'docker*.yml' diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index a6ab35daa8d3b6..3624e6c028c0ff 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -34,6 +34,9 @@ jobs: - '.github/workflows/rust.yml' - '.github/workflows/rust-docker-build.yml' - '.github/workflows/rust-hook-migrator-docker.yml' + - 'posthog/management/commands/setup_test_environment.py' + - 'posthog/migrations/**' + - 'ee/migrations/**' build: name: Build rust services @@ -73,6 +76,11 @@ jobs: test: name: Test rust services + strategy: + matrix: + package: + - feature-flags + - others needs: changes runs-on: depot-ubuntu-22.04-4 timeout-minutes: 10 @@ -86,11 +94,15 @@ jobs: # Use sparse checkout to only select files in rust directory # Turning off cone mode ensures that files in the project root are not included during checkout - uses: actions/checkout@v3 - if: needs.changes.outputs.rust == 'true' + if: needs.changes.outputs.rust == 'true' && matrix.package == 'others' with: sparse-checkout: 'rust/' sparse-checkout-cone-mode: false + # For flags checkout entire repository + - uses: actions/checkout@v3 + if: needs.changes.outputs.rust == 'true' && matrix.package == 'feature-flags' + - name: Login to DockerHub if: needs.changes.outputs.rust == 'true' uses: docker/login-action@v2 @@ -99,8 +111,15 @@ jobs: username: posthog password: ${{ secrets.DOCKERHUB_TOKEN }} + - name: Setup main repo dependencies for flags + if: needs.changes.outputs.rust == 'true' && matrix.package == 'feature-flags' + run: | + docker compose -f ../docker-compose.dev.yml down + docker compose -f ../docker-compose.dev.yml up -d + echo "127.0.0.1 kafka" | sudo tee -a /etc/hosts + - name: Setup dependencies - if: needs.changes.outputs.rust == 'true' + if: needs.changes.outputs.rust == 'true' && matrix.package == 'others' run: | docker compose up kafka redis db echo_server -d --wait docker compose up setup_test_db @@ -119,9 +138,46 @@ jobs: rust/target key: ${ runner.os }-cargo-debug-${{ hashFiles('**/Cargo.lock') }} + - name: Set up Python + if: needs.changes.outputs.rust == 'true' && matrix.package == 'feature-flags' + uses: actions/setup-python@v5 + with: + python-version: 3.11.9 + cache: 'pip' + cache-dependency-path: '**/requirements*.txt' + token: ${{ secrets.POSTHOG_BOT_GITHUB_TOKEN }} + + # uv is a fast pip alternative: https://github.com/astral-sh/uv/ + - run: pip install uv + if: needs.changes.outputs.rust == 'true' && matrix.package == 'feature-flags' + + - name: Install SAML (python3-saml) dependencies + if: needs.changes.outputs.rust == 'true' && matrix.package == 'feature-flags' + run: | + sudo apt-get update + sudo apt-get install libxml2-dev libxmlsec1-dev libxmlsec1-openssl + + - name: Install python dependencies + if: needs.changes.outputs.rust == 'true' && matrix.package == 'feature-flags' + run: | + uv pip install --system -r ../requirements-dev.txt + uv pip install --system -r ../requirements.txt + + - name: Set up databases + if: needs.changes.outputs.rust == 'true' && matrix.package == 'feature-flags' 
+ env: + DEBUG: 'true' + TEST: 'true' + SECRET_KEY: 'abcdef' # unsafe - for testing only + DATABASE_URL: 'postgres://posthog:posthog@localhost:5432/posthog' + run: cd ../ && python manage.py setup_test_environment --only-postgres + - name: Run cargo test if: needs.changes.outputs.rust == 'true' - run: cargo test --all-features + run: | + echo "Starting cargo test" + cargo test --all-features ${{ matrix.package == 'feature-flags' && '--package feature-flags' || '--workspace --exclude feature-flags' }} + echo "Cargo test completed" linting: name: Lint rust services diff --git a/posthog/management/commands/setup_test_environment.py b/posthog/management/commands/setup_test_environment.py index 39549ec864e6d3..07c39f6ce6414f 100644 --- a/posthog/management/commands/setup_test_environment.py +++ b/posthog/management/commands/setup_test_environment.py @@ -26,6 +26,12 @@ class Command(BaseCommand): help = "Set up databases for non-Python tests that depend on the Django server" + # has optional arg to only run postgres setup + def add_arguments(self, parser): + parser.add_argument( + "--only-postgres", action="store_true", help="Only set up the Postgres database", default=False + ) + def handle(self, *args, **options): if not TEST: raise ValueError("TEST environment variable needs to be set for this command to function") @@ -36,6 +42,10 @@ def handle(self, *args, **options): test_runner.setup_databases() test_runner.setup_test_environment() + if options["only_postgres"]: + print("Only setting up Postgres database") # noqa: T201 + return + print("\nCreating test ClickHouse database...") # noqa: T201 database = Database( CLICKHOUSE_DATABASE, diff --git a/rust/Cargo.lock b/rust/Cargo.lock index 3617bee588a416..15c2210f61fb5f 100644 --- a/rust/Cargo.lock +++ b/rust/Cargo.lock @@ -1049,10 +1049,12 @@ dependencies = [ "serde-pickle", "serde_json", "sha1", + "sqlx", "thiserror", "tokio", "tracing", "tracing-subscriber", + "uuid", ] [[package]] diff --git a/rust/feature-flags/Cargo.toml b/rust/feature-flags/Cargo.toml index 08ff21eaed0d85..e4d51dc308d349 100644 --- a/rust/feature-flags/Cargo.toml +++ b/rust/feature-flags/Cargo.toml @@ -15,6 +15,7 @@ tokio = { workspace = true } tracing = { workspace = true } tracing-subscriber = { workspace = true, features = ["env-filter"] } bytes = { workspace = true } +once_cell = "1.18.0" rand = { workspace = true } redis = { version = "0.23.3", features = [ "tokio-comp", @@ -27,12 +28,13 @@ thiserror = { workspace = true } serde-pickle = { version = "1.1.1"} sha1 = "0.10.6" regex = "1.10.4" +sqlx = { workspace = true } +uuid = { workspace = true } [lints] workspace = true [dev-dependencies] assert-json-diff = { workspace = true } -once_cell = "1.18.0" reqwest = { workspace = true } diff --git a/rust/feature-flags/README.md b/rust/feature-flags/README.md index 1c9500900aade1..efce0361245246 100644 --- a/rust/feature-flags/README.md +++ b/rust/feature-flags/README.md @@ -1,6 +1,23 @@ # Testing +First, make sure docker compose is running (from main posthog repo), and test database exists: + +``` +docker compose -f ../docker-compose.dev.yml up -d +``` + +``` +TEST=1 python manage.py setup_test_environment --only-postgres +``` + +We only need to run the above once, when the test database is created. + +TODO: Would be nice to make the above automatic. 
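+
+For new tests that need Postgres fixtures, the helpers in `src/test_utils.rs`
+can insert them directly. A minimal sketch (assuming the one-time setup above
+has been done):
+
+```
+use feature_flags::test_utils::{insert_new_team_in_pg, setup_pg_client};
+
+#[tokio::test]
+async fn my_pg_backed_test() {
+    // Connects via DEFAULT_TEST_CONFIG (localhost:5432/test_posthog).
+    let client = setup_pg_client(None).await;
+    let team = insert_new_team_in_pg(client.clone())
+        .await
+        .expect("failed to insert team");
+    assert!(!team.api_token.is_empty());
+}
+```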
+ + +Then, run the tests: + ``` cargo test --package feature-flags ``` diff --git a/rust/feature-flags/src/api.rs b/rust/feature-flags/src/api.rs index ccf4735e5b04ae..2caae80bf9af6c 100644 --- a/rust/feature-flags/src/api.rs +++ b/rust/feature-flags/src/api.rs @@ -5,6 +5,9 @@ use axum::response::{IntoResponse, Response}; use serde::{Deserialize, Serialize}; use thiserror::Error; +use crate::database::CustomDatabaseError; +use crate::redis::CustomRedisError; + #[derive(Debug, PartialEq, Eq, Deserialize, Serialize)] pub enum FlagsResponseCode { Ok = 1, @@ -42,6 +45,14 @@ pub enum FlagError { DataParsingError, #[error("redis unavailable")] RedisUnavailable, + #[error("database unavailable")] + DatabaseUnavailable, + #[error("Timed out while fetching data")] + TimeoutError, + // TODO: Consider splitting top-level errors (that are returned to the client) + // and FlagMatchingError, like timeouterror which we can gracefully handle. + // This will make the `into_response` a lot clearer as well, since it wouldn't + // have arbitrary errors that actually never make it to the client. } impl IntoResponse for FlagError { @@ -58,10 +69,53 @@ impl IntoResponse for FlagError { FlagError::RateLimited => (StatusCode::TOO_MANY_REQUESTS, self.to_string()), - FlagError::DataParsingError | FlagError::RedisUnavailable => { - (StatusCode::SERVICE_UNAVAILABLE, self.to_string()) - } + FlagError::DataParsingError + | FlagError::RedisUnavailable + | FlagError::DatabaseUnavailable + | FlagError::TimeoutError => (StatusCode::SERVICE_UNAVAILABLE, self.to_string()), } .into_response() } } + +impl From for FlagError { + fn from(e: CustomRedisError) -> Self { + match e { + CustomRedisError::NotFound => FlagError::TokenValidationError, + CustomRedisError::PickleError(e) => { + tracing::error!("failed to fetch data: {}", e); + FlagError::DataParsingError + } + CustomRedisError::Timeout(_) => FlagError::TimeoutError, + CustomRedisError::Other(e) => { + tracing::error!("Unknown redis error: {}", e); + FlagError::RedisUnavailable + } + } + } +} + +impl From for FlagError { + fn from(e: CustomDatabaseError) -> Self { + match e { + CustomDatabaseError::NotFound => FlagError::TokenValidationError, + CustomDatabaseError::Other(_) => { + tracing::error!("failed to get connection: {}", e); + FlagError::DatabaseUnavailable + } + CustomDatabaseError::Timeout(_) => FlagError::TimeoutError, + } + } +} + +impl From for FlagError { + fn from(e: sqlx::Error) -> Self { + // TODO: Be more precise with error handling here + tracing::error!("sqlx error: {}", e); + println!("sqlx error: {}", e); + match e { + sqlx::Error::RowNotFound => FlagError::TokenValidationError, + _ => FlagError::DatabaseUnavailable, + } + } +} diff --git a/rust/feature-flags/src/config.rs b/rust/feature-flags/src/config.rs index cc7ad37bf72c1b..d9e1bf06b1ee37 100644 --- a/rust/feature-flags/src/config.rs +++ b/rust/feature-flags/src/config.rs @@ -1,16 +1,17 @@ -use std::net::SocketAddr; - use envconfig::Envconfig; +use once_cell::sync::Lazy; +use std::net::SocketAddr; +use std::str::FromStr; -#[derive(Envconfig, Clone)] +#[derive(Envconfig, Clone, Debug)] pub struct Config { #[envconfig(default = "127.0.0.1:3001")] pub address: SocketAddr, - #[envconfig(default = "postgres://posthog:posthog@localhost:15432/test_database")] + #[envconfig(default = "postgres://posthog:posthog@localhost:5432/test_posthog")] pub write_database_url: String, - #[envconfig(default = "postgres://posthog:posthog@localhost:15432/test_database")] + #[envconfig(default = 
"postgres://posthog:posthog@localhost:5432/test_posthog")] pub read_database_url: String, #[envconfig(default = "1024")] @@ -21,4 +22,83 @@ pub struct Config { #[envconfig(default = "redis://localhost:6379/")] pub redis_url: String, + + #[envconfig(default = "1")] + pub acquire_timeout_secs: u64, +} + +impl Config { + pub fn default_test_config() -> Self { + Self { + address: SocketAddr::from_str("127.0.0.1:0").unwrap(), + redis_url: "redis://localhost:6379/".to_string(), + write_database_url: "postgres://posthog:posthog@localhost:5432/test_posthog" + .to_string(), + read_database_url: "postgres://posthog:posthog@localhost:5432/test_posthog".to_string(), + max_concurrent_jobs: 1024, + max_pg_connections: 100, + acquire_timeout_secs: 1, + } + } +} + +pub static DEFAULT_TEST_CONFIG: Lazy = Lazy::new(Config::default_test_config); + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_default_config() { + let config = Config::init_from_env().unwrap(); + assert_eq!( + config.address, + SocketAddr::from_str("127.0.0.1:3001").unwrap() + ); + assert_eq!( + config.write_database_url, + "postgres://posthog:posthog@localhost:5432/test_posthog" + ); + assert_eq!( + config.read_database_url, + "postgres://posthog:posthog@localhost:5432/test_posthog" + ); + assert_eq!(config.max_concurrent_jobs, 1024); + assert_eq!(config.max_pg_connections, 100); + assert_eq!(config.redis_url, "redis://localhost:6379/"); + } + + #[test] + fn test_default_test_config() { + let config = Config::default_test_config(); + assert_eq!(config.address, SocketAddr::from_str("127.0.0.1:0").unwrap()); + assert_eq!( + config.write_database_url, + "postgres://posthog:posthog@localhost:5432/test_posthog" + ); + assert_eq!( + config.read_database_url, + "postgres://posthog:posthog@localhost:5432/test_posthog" + ); + assert_eq!(config.max_concurrent_jobs, 1024); + assert_eq!(config.max_pg_connections, 100); + assert_eq!(config.redis_url, "redis://localhost:6379/"); + } + + #[test] + fn test_default_test_config_static() { + let config = &*DEFAULT_TEST_CONFIG; + assert_eq!(config.address, SocketAddr::from_str("127.0.0.1:0").unwrap()); + assert_eq!( + config.write_database_url, + "postgres://posthog:posthog@localhost:5432/test_posthog" + ); + assert_eq!( + config.read_database_url, + "postgres://posthog:posthog@localhost:5432/test_posthog" + ); + assert_eq!(config.max_concurrent_jobs, 1024); + assert_eq!(config.max_pg_connections, 100); + assert_eq!(config.redis_url, "redis://localhost:6379/"); + } } diff --git a/rust/feature-flags/src/database.rs b/rust/feature-flags/src/database.rs new file mode 100644 index 00000000000000..29360d22b9444f --- /dev/null +++ b/rust/feature-flags/src/database.rs @@ -0,0 +1,98 @@ +use std::time::Duration; + +use anyhow::Result; +use async_trait::async_trait; +use sqlx::{ + pool::PoolConnection, + postgres::{PgPoolOptions, PgRow}, + Postgres, +}; +use thiserror::Error; +use tokio::time::timeout; + +use crate::config::Config; + +const DATABASE_TIMEOUT_MILLISECS: u64 = 1000; + +#[derive(Error, Debug)] +pub enum CustomDatabaseError { + #[error("Not found in database")] + NotFound, + + #[error("Pg error: {0}")] + Other(#[from] sqlx::Error), + + #[error("Timeout error")] + Timeout(#[from] tokio::time::error::Elapsed), +} + +/// A simple db wrapper +/// Supports running any arbitrary query with a timeout. +/// TODO: Make sqlx prepared statements work with pgbouncer, potentially by setting pooling mode to session. 
+#[async_trait]
+pub trait Client {
+    async fn get_connection(&self) -> Result<PoolConnection<Postgres>, CustomDatabaseError>;
+    async fn run_query(
+        &self,
+        query: String,
+        parameters: Vec<String>,
+        timeout_ms: Option<u64>,
+    ) -> Result<Vec<PgRow>, CustomDatabaseError>;
+}
+
+pub struct PgClient {
+    pool: sqlx::PgPool,
+}
+
+impl PgClient {
+    pub async fn new_read_client(config: &Config) -> Result<PgClient> {
+        let pool = PgPoolOptions::new()
+            .max_connections(config.max_pg_connections)
+            .acquire_timeout(Duration::from_secs(config.acquire_timeout_secs))
+            .test_before_acquire(true)
+            .connect(&config.read_database_url)
+            .await?;
+
+        Ok(PgClient { pool })
+    }
+
+    pub async fn new_write_client(config: &Config) -> Result<PgClient> {
+        let pool = PgPoolOptions::new()
+            .max_connections(config.max_pg_connections)
+            .acquire_timeout(Duration::from_secs(config.acquire_timeout_secs))
+            .test_before_acquire(true)
+            .connect(&config.write_database_url)
+            .await?;
+
+        Ok(PgClient { pool })
+    }
+}
+
+#[async_trait]
+impl Client for PgClient {
+    async fn run_query(
+        &self,
+        query: String,
+        parameters: Vec<String>,
+        timeout_ms: Option<u64>,
+    ) -> Result<Vec<PgRow>, CustomDatabaseError> {
+        let built_query = sqlx::query(&query);
+        let built_query = parameters
+            .iter()
+            .fold(built_query, |acc, param| acc.bind(param));
+        let query_results = built_query.fetch_all(&self.pool);
+
+        // The timeout budget is in milliseconds, matching DATABASE_TIMEOUT_MILLISECS.
+        let timeout_ms = timeout_ms.unwrap_or(DATABASE_TIMEOUT_MILLISECS);
+
+        let fut = timeout(Duration::from_millis(timeout_ms), query_results).await?;
+
+        Ok(fut?)
+    }
+
+    async fn get_connection(&self) -> Result<PoolConnection<Postgres>, CustomDatabaseError> {
+        Ok(self.pool.acquire().await?)
+    }
+}
diff --git a/rust/feature-flags/src/flag_definitions.rs b/rust/feature-flags/src/flag_definitions.rs
index fbbd0445b59982..cc208ae8b073f2 100644
--- a/rust/feature-flags/src/flag_definitions.rs
+++ b/rust/feature-flags/src/flag_definitions.rs
@@ -1,11 +1,8 @@
-use serde::Deserialize;
+use serde::{Deserialize, Serialize};
 use std::sync::Arc;
 use tracing::instrument;
 
-use crate::{
-    api::FlagError,
-    redis::{Client, CustomRedisError},
-};
+use crate::{api::FlagError, database::Client as DatabaseClient, redis::Client as RedisClient};
 
 // TRICKY: This cache data is coming from django-redis. If it ever goes out of sync, we'll bork.
 // TODO: Add integration tests across repos to ensure this doesn't happen.
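A TODO later in this file ("if not in redis, fallback to pg") points at the eventual direction; with the types in this diff, a call-site fallback could look like the sketch below. The helper itself is hypothetical and not part of this change:

```rust
use std::sync::Arc;

use feature_flags::api::FlagError;
use feature_flags::database::PgClient;
use feature_flags::flag_definitions::FeatureFlagList;
use feature_flags::redis::RedisClient;

// Hypothetical fallback: serve flags from the Redis cache, and fall back to
// Postgres on a cache miss (surfaced today as FlagError::TokenValidationError
// via CustomRedisError::NotFound).
async fn flags_with_fallback(
    redis: Arc<RedisClient>,
    pg: Arc<PgClient>,
    team_id: i32,
) -> Result<FeatureFlagList, FlagError> {
    match FeatureFlagList::from_redis(redis, team_id).await {
        Err(FlagError::TokenValidationError) => FeatureFlagList::from_pg(pg, team_id).await,
        other => other,
    }
}
```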
@@ -46,7 +43,7 @@ pub struct PropertyFilter { pub operator: Option, #[serde(rename = "type")] pub prop_type: String, - pub group_type_index: Option, + pub group_type_index: Option, } #[derive(Debug, Clone, Deserialize)] @@ -74,15 +71,15 @@ pub struct MultivariateFlagOptions { pub struct FlagFilters { pub groups: Vec, pub multivariate: Option, - pub aggregation_group_type_index: Option, + pub aggregation_group_type_index: Option, pub payloads: Option, pub super_groups: Option>, } #[derive(Debug, Clone, Deserialize)] pub struct FeatureFlag { - pub id: i64, - pub team_id: i64, + pub id: i32, + pub team_id: i32, pub name: Option, pub key: String, pub filters: FlagFilters, @@ -94,8 +91,20 @@ pub struct FeatureFlag { pub ensure_experience_continuity: bool, } +#[derive(Debug, Serialize, sqlx::FromRow)] +pub struct FeatureFlagRow { + pub id: i32, + pub team_id: i32, + pub name: Option, + pub key: String, + pub filters: serde_json::Value, + pub deleted: bool, + pub active: bool, + pub ensure_experience_continuity: bool, +} + impl FeatureFlag { - pub fn get_group_type_index(&self) -> Option { + pub fn get_group_type_index(&self) -> Option { self.filters.aggregation_group_type_index } @@ -121,27 +130,13 @@ impl FeatureFlagList { /// Returns feature flags from redis given a team_id #[instrument(skip_all)] pub async fn from_redis( - client: Arc, - team_id: i64, + client: Arc, + team_id: i32, ) -> Result { // TODO: Instead of failing here, i.e. if not in redis, fallback to pg let serialized_flags = client .get(format!("{TEAM_FLAGS_CACHE_PREFIX}{}", team_id)) - .await - .map_err(|e| match e { - CustomRedisError::NotFound => FlagError::TokenValidationError, - CustomRedisError::PickleError(_) => { - // TODO: Implement From trait for FlagError so we don't need to map - // CustomRedisError ourselves - tracing::error!("failed to fetch data: {}", e); - println!("failed to fetch data: {}", e); - FlagError::DataParsingError - } - _ => { - tracing::error!("Unknown redis error: {}", e); - FlagError::RedisUnavailable - } - })?; + .await?; let flags_list: Vec = serde_json::from_str(&serialized_flags).map_err(|e| { @@ -153,13 +148,45 @@ impl FeatureFlagList { Ok(FeatureFlagList { flags: flags_list }) } + + /// Returns feature flags from postgres given a team_id + #[instrument(skip_all)] + pub async fn from_pg( + client: Arc, + team_id: i32, + ) -> Result { + let mut conn = client.get_connection().await?; + // TODO: Clean up error handling here + + let query = "SELECT id, team_id, name, key, filters, deleted, active, ensure_experience_continuity FROM posthog_featureflag WHERE team_id = $1"; + let flags_row = sqlx::query_as::<_, FeatureFlagRow>(query) + .bind(team_id) + .fetch_all(&mut *conn) + .await?; + + let serialized_flags = serde_json::to_string(&flags_row).map_err(|e| { + tracing::error!("failed to serialize flags: {}", e); + println!("failed to serialize flags: {}", e); + FlagError::DataParsingError + })?; + + let flags_list: Vec = + serde_json::from_str(&serialized_flags).map_err(|e| { + tracing::error!("failed to parse data to flags list: {}", e); + println!("failed to parse data: {}", e); + + FlagError::DataParsingError + })?; + Ok(FeatureFlagList { flags: flags_list }) + } } #[cfg(test)] mod tests { use super::*; use crate::test_utils::{ - insert_flags_for_team_in_redis, insert_new_team_in_redis, setup_redis_client, + insert_flags_for_team_in_pg, insert_flags_for_team_in_redis, insert_new_team_in_pg, + insert_new_team_in_redis, setup_pg_client, setup_redis_client, }; #[tokio::test] @@ -211,4 +238,64 @@ mod 
tests { _ => panic!("Expected RedisUnavailable"), }; } + + #[tokio::test] + async fn test_fetch_flags_from_pg() { + let client = setup_pg_client(None).await; + + let team = insert_new_team_in_pg(client.clone()) + .await + .expect("Failed to insert team in pg"); + + insert_flags_for_team_in_pg(client.clone(), team.id, None) + .await + .expect("Failed to insert flags"); + + let flags_from_pg = FeatureFlagList::from_pg(client.clone(), team.id) + .await + .expect("Failed to fetch flags from pg"); + + assert_eq!(flags_from_pg.flags.len(), 1); + let flag = flags_from_pg.flags.get(0).expect("Flags should be in pg"); + + assert_eq!(flag.key, "flag1"); + assert_eq!(flag.team_id, team.id); + assert_eq!(flag.filters.groups.len(), 1); + assert_eq!( + flag.filters.groups[0] + .properties + .as_ref() + .expect("Properties don't exist on flag") + .len(), + 1 + ); + let property_filter = &flag.filters.groups[0] + .properties + .as_ref() + .expect("Properties don't exist on flag")[0]; + + assert_eq!(property_filter.key, "email"); + assert_eq!(property_filter.value, "a@b.com"); + assert_eq!(property_filter.operator, None); + assert_eq!(property_filter.prop_type, "person"); + assert_eq!(property_filter.group_type_index, None); + assert_eq!(flag.filters.groups[0].rollout_percentage, Some(50.0)); + } + + // TODO: Add more tests to validate deserialization of flags. + // TODO: Also make sure old flag data is handled, or everything is migrated to new style in production + + #[tokio::test] + async fn test_fetch_empty_team_from_pg() { + let client = setup_pg_client(None).await; + + match FeatureFlagList::from_pg(client.clone(), 1234) + .await + .expect("Failed to fetch flags from pg") + { + FeatureFlagList { flags } => { + assert_eq!(flags.len(), 0); + } + } + } } diff --git a/rust/feature-flags/src/flag_matching.rs b/rust/feature-flags/src/flag_matching.rs index 510fc153dc87a5..485d8a646e8237 100644 --- a/rust/feature-flags/src/flag_matching.rs +++ b/rust/feature-flags/src/flag_matching.rs @@ -1,6 +1,12 @@ -use crate::flag_definitions::{FeatureFlag, FlagGroupType}; +use crate::{ + api::FlagError, + database::Client as DatabaseClient, + flag_definitions::{FeatureFlag, FlagGroupType}, + property_matching::match_property, +}; +use serde_json::Value; use sha1::{Digest, Sha1}; -use std::fmt::Write; +use std::{collections::HashMap, fmt::Write, sync::Arc}; #[derive(Debug, PartialEq, Eq)] pub struct FeatureFlagMatch { @@ -11,6 +17,11 @@ pub struct FeatureFlagMatch { //payload } +#[derive(Debug, sqlx::FromRow)] +pub struct Person { + pub properties: sqlx::types::Json>, +} + // TODO: Rework FeatureFlagMatcher - python has a pretty awkward interface, where we pass in all flags, and then again // the flag to match. I don't think there's any reason anymore to store the flags in the matcher, since we can just // pass the flag to match directly to the get_match method. This will also make the matcher more stateless. @@ -21,23 +32,30 @@ pub struct FeatureFlagMatch { // for all teams. If not, we can have a LRU cache, or a cache that stores only the most recent N keys. // But, this can be a future refactor, for now just focusing on getting the basic matcher working, write lots and lots of tests // and then we can easily refactor stuff around. 
-#[derive(Debug)] +// #[derive(Debug)] pub struct FeatureFlagMatcher { // pub flags: Vec, pub distinct_id: String, + pub database_client: Option>, + cached_properties: Option>, } const LONG_SCALE: u64 = 0xfffffffffffffff; impl FeatureFlagMatcher { - pub fn new(distinct_id: String) -> Self { + pub fn new( + distinct_id: String, + database_client: Option>, + ) -> Self { FeatureFlagMatcher { // flags, distinct_id, + database_client, + cached_properties: None, } } - pub fn get_match(&self, feature_flag: &FeatureFlag) -> FeatureFlagMatch { + pub async fn get_match(&mut self, feature_flag: &FeatureFlag) -> FeatureFlagMatch { if self.hashed_identifier(feature_flag).is_none() { return FeatureFlagMatch { matches: false, @@ -49,8 +67,9 @@ impl FeatureFlagMatcher { // TODO: Variant overrides condition sort for (index, condition) in feature_flag.get_conditions().iter().enumerate() { - let (is_match, _evaluation_reason) = - self.is_condition_match(feature_flag, condition, index); + let (is_match, _evaluation_reason) = self + .is_condition_match(feature_flag, condition, index) + .await; if is_match { // TODO: This is a bit awkward, we should handle overrides only when variants exist. @@ -82,20 +101,33 @@ impl FeatureFlagMatcher { } } - pub fn is_condition_match( - &self, + // TODO: Making all this mutable just to store a cached value is annoying. Can I refactor this to be non-mutable? + // Leaning a bit more towards a separate cache store for this. + pub async fn is_condition_match( + &mut self, feature_flag: &FeatureFlag, condition: &FlagGroupType, _index: usize, ) -> (bool, String) { let rollout_percentage = condition.rollout_percentage.unwrap_or(100.0); let mut condition_match = true; - if condition.properties.is_some() { - // TODO: Handle matching conditions - if !condition.properties.as_ref().unwrap().is_empty() { - condition_match = false; + + if let Some(ref properties) = condition.properties { + if properties.is_empty() { + condition_match = true; + } else { + // TODO: First handle given override properties before going to db + let target_properties = self + .get_person_properties(feature_flag.team_id, self.distinct_id.clone()) + .await + .unwrap_or_default(); + // TODO: Handle db issues / person not found + + condition_match = properties.iter().all(|property| { + match_property(property, &target_properties, false).unwrap_or(false) + }); } - } + }; if !condition_match { return (false, "NO_CONDITION_MATCH".to_string()); @@ -157,4 +189,133 @@ impl FeatureFlagMatcher { } None } + + pub async fn get_person_properties( + &mut self, + team_id: i32, + distinct_id: String, + ) -> Result, FlagError> { + // TODO: Do we even need to cache here anymore? + // Depends on how often we're calling this function + // to match all flags for a single person + + if let Some(cached_props) = self.cached_properties.clone() { + // TODO: Maybe we don't want to copy around all user properties, this will by far be the largest chunk + // of data we're copying around. Can we work with references here? + // Worst case, just use a Rc. 
+ return Ok(cached_props); + } + + if self.database_client.is_none() { + return Err(FlagError::DatabaseUnavailable); + } + + let mut conn = self + .database_client + .as_ref() + .expect("client should exist here") + .get_connection() + .await?; + + let query = r#" + SELECT "posthog_person"."properties" + FROM "posthog_person" + INNER JOIN "posthog_persondistinctid" ON ("posthog_person"."id" = "posthog_persondistinctid"."person_id") + WHERE ("posthog_persondistinctid"."distinct_id" = $1 + AND "posthog_persondistinctid"."team_id" = $2 + AND "posthog_person"."team_id" = $3) + LIMIT 1; + "#; + + let row = sqlx::query_as::<_, Person>(query) + .bind(&distinct_id) + .bind(team_id) + .bind(team_id) + .fetch_optional(&mut *conn) + .await?; + + let props = match row { + Some(row) => row.properties.0, + None => HashMap::new(), + }; + + self.cached_properties = Some(props.clone()); + + Ok(props) + } +} + +#[cfg(test)] +mod tests { + + use serde_json::json; + + use super::*; + use crate::test_utils::{insert_new_team_in_pg, insert_person_for_team_in_pg, setup_pg_client}; + + #[tokio::test] + async fn test_fetch_properties_from_pg_to_match() { + let client = setup_pg_client(None).await; + + let team = insert_new_team_in_pg(client.clone()) + .await + .expect("Failed to insert team in pg"); + + let distinct_id = "user_distinct_id".to_string(); + insert_person_for_team_in_pg(client.clone(), team.id, distinct_id.clone(), None) + .await + .expect("Failed to insert person"); + + let not_matching_distinct_id = "not_matching_distinct_id".to_string(); + insert_person_for_team_in_pg( + client.clone(), + team.id, + not_matching_distinct_id.clone(), + Some(json!({ "email": "a@x.com"})), + ) + .await + .expect("Failed to insert person"); + + let flag = serde_json::from_value(json!( + { + "id": 1, + "team_id": team.id, + "name": "flag1", + "key": "flag1", + "filters": { + "groups": [ + { + "properties": [ + { + "key": "email", + "value": "a@b.com", + "type": "person" + } + ], + "rollout_percentage": 100 + } + ] + } + } + )) + .unwrap(); + + let mut matcher = FeatureFlagMatcher::new(distinct_id, Some(client.clone())); + let match_result = matcher.get_match(&flag).await; + assert_eq!(match_result.matches, true); + assert_eq!(match_result.variant, None); + + // property value is different + let mut matcher = FeatureFlagMatcher::new(not_matching_distinct_id, Some(client.clone())); + let match_result = matcher.get_match(&flag).await; + assert_eq!(match_result.matches, false); + assert_eq!(match_result.variant, None); + + // person does not exist + let mut matcher = + FeatureFlagMatcher::new("other_distinct_id".to_string(), Some(client.clone())); + let match_result = matcher.get_match(&flag).await; + assert_eq!(match_result.matches, false); + assert_eq!(match_result.variant, None); + } } diff --git a/rust/feature-flags/src/lib.rs b/rust/feature-flags/src/lib.rs index 7f03747b9ee6d8..7784bd7bf1b8dc 100644 --- a/rust/feature-flags/src/lib.rs +++ b/rust/feature-flags/src/lib.rs @@ -1,5 +1,6 @@ pub mod api; pub mod config; +pub mod database; pub mod flag_definitions; pub mod flag_matching; pub mod property_matching; diff --git a/rust/feature-flags/src/router.rs b/rust/feature-flags/src/router.rs index 8824d44efdbde9..2fbc87c8709304 100644 --- a/rust/feature-flags/src/router.rs +++ b/rust/feature-flags/src/router.rs @@ -2,18 +2,59 @@ use std::sync::Arc; use axum::{routing::post, Router}; -use crate::{redis::Client, v0_endpoint}; +use crate::{database::Client as DatabaseClient, redis::Client as RedisClient, v0_endpoint}; 
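Stepping back to the `flag_matching.rs` changes above: the reworked, now-async matcher is exercised roughly as follows. A sketch mirroring the consistency tests; the explicit `deleted`/`active` fields avoid relying on serde defaults:

```rust
use feature_flags::flag_matching::FeatureFlagMatcher;
use feature_flags::test_utils::create_flag_from_json;
use serde_json::json;

#[tokio::test]
async fn matcher_smoke_test() {
    let flags = create_flag_from_json(Some(
        json!([{
            "id": 1,
            "team_id": 1,
            "name": "Beta feature",
            "key": "beta-feature",
            "filters": { "groups": [{ "rollout_percentage": 100 }] },
            "deleted": false,
            "active": true,
            "ensure_experience_continuity": false,
        }])
        .to_string(),
    ));

    // No database client: a 100% rollout with no property filters never hits Postgres.
    let mut matcher = FeatureFlagMatcher::new("user_distinct_id".to_string(), None);
    let result = matcher.get_match(&flags[0]).await;
    assert!(result.matches);
}
```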
#[derive(Clone)] pub struct State { - pub redis: Arc, + pub redis: Arc, // TODO: Add pgClient when ready + pub postgres: Arc, } -pub fn router(redis: Arc) -> Router { - let state = State { redis }; +pub fn router(redis: Arc, postgres: Arc) -> Router +where + R: RedisClient + Send + Sync + 'static, + D: DatabaseClient + Send + Sync + 'static, +{ + let state = State { redis, postgres }; Router::new() .route("/flags", post(v0_endpoint::flags).get(v0_endpoint::flags)) .with_state(state) } + +// TODO, eventually we can differentiate read and write postgres clients, if needed +// I _think_ everything is read-only, but I'm not 100% sure yet +// here's how that client would look +// use std::sync::Arc; + +// use axum::{routing::post, Router}; + +// use crate::{database::Client as DatabaseClient, redis::Client as RedisClient, v0_endpoint}; + +// #[derive(Clone)] +// pub struct State { +// pub redis: Arc, +// pub postgres_read: Arc, +// pub postgres_write: Arc, +// } + +// pub fn router( +// redis: Arc, +// postgres_read: Arc, +// postgres_write: Arc, +// ) -> Router +// where +// R: RedisClient + Send + Sync + 'static, +// D: DatabaseClient + Send + Sync + 'static, +// { +// let state = State { +// redis, +// postgres_read, +// postgres_write, +// }; + +// Router::new() +// .route("/flags", post(v0_endpoint::flags).get(v0_endpoint::flags)) +// .with_state(state) +// } diff --git a/rust/feature-flags/src/server.rs b/rust/feature-flags/src/server.rs index ffe6b0efb70681..37bd721a9a51f6 100644 --- a/rust/feature-flags/src/server.rs +++ b/rust/feature-flags/src/server.rs @@ -5,7 +5,7 @@ use std::sync::Arc; use tokio::net::TcpListener; use crate::config::Config; - +use crate::database::PgClient; use crate::redis::RedisClient; use crate::router; @@ -13,13 +13,25 @@ pub async fn serve(config: Config, listener: TcpListener, shutdown: F) where F: Future + Send + 'static, { - let redis_client = - Arc::new(RedisClient::new(config.redis_url).expect("failed to create redis client")); + let redis_client = match RedisClient::new(config.redis_url.clone()) { + Ok(client) => Arc::new(client), + Err(e) => { + tracing::error!("Failed to create Redis client: {}", e); + return; + } + }; + + let read_postgres_client = match PgClient::new_read_client(&config).await { + Ok(client) => Arc::new(client), + Err(e) => { + tracing::error!("Failed to create read Postgres client: {}", e); + return; + } + }; - let app = router::router(redis_client); + // You can decide which client to pass to the router, or pass both if needed + let app = router::router(redis_client, read_postgres_client); - // run our app with hyper - // `axum::Server` is a re-export of `hyper::Server` tracing::info!("listening on {:?}", listener.local_addr().unwrap()); axum::serve( listener, diff --git a/rust/feature-flags/src/team.rs b/rust/feature-flags/src/team.rs index e872aa477968f3..7c7cfd9547bbf9 100644 --- a/rust/feature-flags/src/team.rs +++ b/rust/feature-flags/src/team.rs @@ -2,18 +2,15 @@ use serde::{Deserialize, Serialize}; use std::sync::Arc; use tracing::instrument; -use crate::{ - api::FlagError, - redis::{Client, CustomRedisError}, -}; +use crate::{api::FlagError, database::Client as DatabaseClient, redis::Client as RedisClient}; // TRICKY: This cache data is coming from django-redis. If it ever goes out of sync, we'll bork. // TODO: Add integration tests across repos to ensure this doesn't happen. 
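For reference, the generic `router` and the `server.rs` wiring above compose as in this sketch (a test-style constructor, assuming default config; error types are folded into `anyhow`):

```rust
use std::sync::Arc;

use feature_flags::config::Config;
use feature_flags::database::PgClient;
use feature_flags::redis::RedisClient;
use feature_flags::router;

// Sketch: build the axum app the same way serve() does, e.g. for tests.
async fn build_app(config: &Config) -> anyhow::Result<axum::Router> {
    let redis = Arc::new(RedisClient::new(config.redis_url.clone())?);
    let postgres = Arc::new(PgClient::new_read_client(config).await?);
    // RedisClient and PgClient satisfy the R/D Send + Sync + 'static bounds.
    Ok(router::router(redis, postgres))
}
```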
pub const TEAM_TOKEN_CACHE_PREFIX: &str = "posthog:1:team_token:"; -#[derive(Debug, Deserialize, Serialize)] +#[derive(Debug, Deserialize, Serialize, sqlx::FromRow)] pub struct Team { - pub id: i64, + pub id: i32, pub name: String, pub api_token: String, } @@ -23,24 +20,13 @@ impl Team { #[instrument(skip_all)] pub async fn from_redis( - client: Arc, + client: Arc, token: String, ) -> Result { // TODO: Instead of failing here, i.e. if not in redis, fallback to pg let serialized_team = client .get(format!("{TEAM_TOKEN_CACHE_PREFIX}{}", token)) - .await - .map_err(|e| match e { - CustomRedisError::NotFound => FlagError::TokenValidationError, - CustomRedisError::PickleError(_) => { - tracing::error!("failed to fetch data: {}", e); - FlagError::DataParsingError - } - _ => { - tracing::error!("Unknown redis error: {}", e); - FlagError::RedisUnavailable - } - })?; + .await?; // TODO: Consider an LRU cache for teams as well, with small TTL to skip redis/pg lookups let team: Team = serde_json::from_str(&serialized_team).map_err(|e| { @@ -50,6 +36,21 @@ impl Team { Ok(team) } + + pub async fn from_pg( + client: Arc, + token: String, + ) -> Result { + let mut conn = client.get_connection().await?; + + let query = "SELECT id, name, api_token FROM posthog_team WHERE api_token = $1"; + let row = sqlx::query_as::<_, Team>(query) + .bind(&token) + .fetch_one(&mut *conn) + .await?; + + Ok(row) + } } #[cfg(test)] @@ -60,14 +61,19 @@ mod tests { use super::*; use crate::{ team, - test_utils::{insert_new_team_in_redis, random_string, setup_redis_client}, + test_utils::{ + insert_new_team_in_pg, insert_new_team_in_redis, random_string, setup_pg_client, + setup_redis_client, + }, }; #[tokio::test] async fn test_fetch_team_from_redis() { let client = setup_redis_client(None); - let team = insert_new_team_in_redis(client.clone()).await.unwrap(); + let team = insert_new_team_in_redis(client.clone()) + .await + .expect("Failed to insert team in redis"); let target_token = team.api_token; @@ -137,4 +143,39 @@ mod tests { Ok(_) => panic!("Expected DataParsingError"), }; } + + #[tokio::test] + async fn test_fetch_team_from_pg() { + let client = setup_pg_client(None).await; + + let team = insert_new_team_in_pg(client.clone()) + .await + .expect("Failed to insert team in pg"); + + let target_token = team.api_token; + + let team_from_pg = Team::from_pg(client.clone(), target_token.clone()) + .await + .expect("Failed to fetch team from pg"); + + assert_eq!(team_from_pg.api_token, target_token); + assert_eq!(team_from_pg.id, team.id); + assert_eq!(team_from_pg.name, team.name); + } + + #[tokio::test] + async fn test_fetch_team_from_pg_with_invalid_token() { + // TODO: Figure out a way such that `run_database_migrations` is called only once, and already called + // before running these tests. + + let client = setup_pg_client(None).await; + let target_token = "xxxx".to_string(); + + match Team::from_pg(client.clone(), target_token.clone()).await { + Err(FlagError::TokenValidationError) => (), + _ => panic!("Expected TokenValidationError"), + }; + } + + // TODO: Handle cases where db connection fails. 
} diff --git a/rust/feature-flags/src/test_utils.rs b/rust/feature-flags/src/test_utils.rs index 92bc8a4ff44941..9d1f5970d46b67 100644 --- a/rust/feature-flags/src/test_utils.rs +++ b/rust/feature-flags/src/test_utils.rs @@ -1,10 +1,13 @@ use anyhow::Error; -use serde_json::json; +use serde_json::{json, Value}; use std::sync::Arc; +use uuid::Uuid; use crate::{ - flag_definitions::{self, FeatureFlag}, - redis::{Client, RedisClient}, + config::{Config, DEFAULT_TEST_CONFIG}, + database::{Client as DatabaseClientTrait, PgClient}, + flag_definitions::{self, FeatureFlag, FeatureFlagRow}, + redis::{Client as RedisClientTrait, RedisClient}, team::{self, Team}, }; use rand::{distributions::Alphanumeric, Rng}; @@ -44,7 +47,7 @@ pub async fn insert_new_team_in_redis(client: Arc) -> Result, - team_id: i64, + team_id: i32, json_value: Option, ) -> Result<(), Error> { let payload = match json_value { @@ -124,3 +127,149 @@ pub fn create_flag_from_json(json_value: Option) -> Vec { serde_json::from_str(&payload).expect("Failed to parse data to flags list"); flags } + +pub async fn setup_pg_client(config: Option<&Config>) -> Arc { + let config = config.unwrap_or(&DEFAULT_TEST_CONFIG); + Arc::new( + PgClient::new_read_client(config) + .await + .expect("Failed to create pg read client"), + ) +} + +pub async fn insert_new_team_in_pg(client: Arc) -> Result { + const ORG_ID: &str = "019026a4be8000005bf3171d00629163"; + + client.run_query( + r#"INSERT INTO posthog_organization + (id, name, slug, created_at, updated_at, plugins_access_level, for_internal_metrics, is_member_join_email_enabled, enforce_2fa, is_hipaa, customer_id, available_product_features, personalization, setup_section_2_completed, domain_whitelist) + VALUES + ($1::uuid, 'Test Organization', 'test-organization', '2024-06-17 14:40:49.298579+00:00', '2024-06-17 14:40:49.298593+00:00', 9, false, true, NULL, false, NULL, '{}', '{}', true, '{}') + ON CONFLICT DO NOTHING"#.to_string(), + vec![ORG_ID.to_string()], + Some(2000), + ).await?; + + client + .run_query( + r#"INSERT INTO posthog_project + (id, organization_id, name, created_at) + VALUES + (1, $1::uuid, 'Test Team', '2024-06-17 14:40:51.329772+00:00') + ON CONFLICT DO NOTHING"# + .to_string(), + vec![ORG_ID.to_string()], + Some(2000), + ) + .await?; + + let id = rand::thread_rng().gen_range(0..10_000_000); + let token = random_string("phc_", 12); + let team = Team { + id, + name: "team".to_string(), + api_token: token, + }; + let uuid = Uuid::now_v7(); + + let mut conn = client.get_connection().await?; + let res = sqlx::query( + r#"INSERT INTO posthog_team + (id, uuid, organization_id, project_id, api_token, name, created_at, updated_at, app_urls, anonymize_ips, completed_snippet_onboarding, ingested_event, session_recording_opt_in, is_demo, access_control, test_account_filters, timezone, data_attributes, plugins_opt_in, opt_out_capture, event_names, event_names_with_usage, event_properties, event_properties_with_usage, event_properties_numerical) VALUES + ($1, $5, $2::uuid, 1, $3, $4, '2024-06-17 14:40:51.332036+00:00', '2024-06-17', '{}', false, false, false, false, false, false, '{}', 'UTC', '["data-attr"]', false, false, '[]', '[]', '[]', '[]', '[]')"# + ).bind(team.id).bind(ORG_ID).bind(&team.api_token).bind(&team.name).bind(uuid).execute(&mut *conn).await?; + + assert_eq!(res.rows_affected(), 1); + + Ok(team) +} + +pub async fn insert_flags_for_team_in_pg( + client: Arc, + team_id: i32, + flag: Option, +) -> Result { + let id = rand::thread_rng().gen_range(0..10_000_000); + + let 
payload_flag = match flag { + Some(value) => value, + None => FeatureFlagRow { + id, + key: "flag1".to_string(), + name: Some("flag1 description".to_string()), + active: true, + deleted: false, + ensure_experience_continuity: false, + team_id, + filters: json!({ + "groups": [ + { + "properties": [ + { + "key": "email", + "value": "a@b.com", + "type": "person", + }, + ], + "rollout_percentage": 50, + }, + ], + }), + }, + }; + + let mut conn = client.get_connection().await?; + let res = sqlx::query( + r#"INSERT INTO posthog_featureflag + (id, team_id, name, key, filters, deleted, active, ensure_experience_continuity, created_at) VALUES + ($1, $2, $3, $4, $5, $6, $7, $8, '2024-06-17')"# + ).bind(payload_flag.id).bind(team_id).bind(&payload_flag.name).bind(&payload_flag.key).bind(&payload_flag.filters).bind(payload_flag.deleted).bind(payload_flag.active).bind(payload_flag.ensure_experience_continuity).execute(&mut *conn).await?; + + assert_eq!(res.rows_affected(), 1); + + Ok(payload_flag) +} + +pub async fn insert_person_for_team_in_pg( + client: Arc, + team_id: i32, + distinct_id: String, + properties: Option, +) -> Result<(), Error> { + let payload = match properties { + Some(value) => value, + None => json!({ + "email": "a@b.com", + "name": "Alice", + }), + }; + + let uuid = Uuid::now_v7(); + + let mut conn = client.get_connection().await?; + let res = sqlx::query( + r#" + WITH inserted_person AS ( + INSERT INTO posthog_person ( + created_at, properties, properties_last_updated_at, + properties_last_operation, team_id, is_user_id, is_identified, uuid, version + ) + VALUES ('2023-04-05', $1, '{}', '{}', $2, NULL, true, $3, 0) + RETURNING * + ) + INSERT INTO posthog_persondistinctid (distinct_id, person_id, team_id, version) + VALUES ($4, (SELECT id FROM inserted_person), $5, 0) + "#, + ) + .bind(&payload) + .bind(team_id) + .bind(uuid) + .bind(&distinct_id) + .bind(team_id) + .execute(&mut *conn) + .await?; + + assert_eq!(res.rows_affected(), 1); + + Ok(()) +} diff --git a/rust/feature-flags/tests/common/mod.rs b/rust/feature-flags/tests/common/mod.rs index c8644fe1f45428..2b14292e0fda39 100644 --- a/rust/feature-flags/tests/common/mod.rs +++ b/rust/feature-flags/tests/common/mod.rs @@ -1,9 +1,6 @@ use std::net::SocketAddr; -use std::str::FromStr; -use std::string::ToString; use std::sync::Arc; -use once_cell::sync::Lazy; use reqwest::header::CONTENT_TYPE; use tokio::net::TcpListener; use tokio::sync::Notify; @@ -11,15 +8,6 @@ use tokio::sync::Notify; use feature_flags::config::Config; use feature_flags::server::serve; -pub static DEFAULT_CONFIG: Lazy = Lazy::new(|| Config { - address: SocketAddr::from_str("127.0.0.1:0").unwrap(), - redis_url: "redis://localhost:6379/".to_string(), - write_database_url: "postgres://posthog:posthog@localhost:15432/test_database".to_string(), - read_database_url: "postgres://posthog:posthog@localhost:15432/test_database".to_string(), - max_concurrent_jobs: 1024, - max_pg_connections: 100, -}); - pub struct ServerHandle { pub addr: SocketAddr, shutdown: Arc, diff --git a/rust/feature-flags/tests/test_flag_matching_consistency.rs b/rust/feature-flags/tests/test_flag_matching_consistency.rs index 4a24b0e16d50e1..d4b55ed4e90016 100644 --- a/rust/feature-flags/tests/test_flag_matching_consistency.rs +++ b/rust/feature-flags/tests/test_flag_matching_consistency.rs @@ -5,8 +5,8 @@ use feature_flags::flag_matching::{FeatureFlagMatch, FeatureFlagMatcher}; use feature_flags::test_utils::create_flag_from_json; use serde_json::json; -#[test] -fn 
it_is_consistent_with_rollout_calculation_for_simple_flags() { +#[tokio::test] +async fn it_is_consistent_with_rollout_calculation_for_simple_flags() { let flags = create_flag_from_json(Some( json!([{ "id": 1, @@ -107,7 +107,9 @@ fn it_is_consistent_with_rollout_calculation_for_simple_flags() { for i in 0..1000 { let distinct_id = format!("distinct_id_{}", i); - let feature_flag_match = FeatureFlagMatcher::new(distinct_id).get_match(&flags[0]); + let feature_flag_match = FeatureFlagMatcher::new(distinct_id, None) + .get_match(&flags[0]) + .await; if results[i] { assert_eq!( @@ -129,8 +131,8 @@ fn it_is_consistent_with_rollout_calculation_for_simple_flags() { } } -#[test] -fn it_is_consistent_with_rollout_calculation_for_multivariate_flags() { +#[tokio::test] +async fn it_is_consistent_with_rollout_calculation_for_multivariate_flags() { let flags = create_flag_from_json(Some( json!([{ "id": 1, @@ -1186,7 +1188,9 @@ fn it_is_consistent_with_rollout_calculation_for_multivariate_flags() { for i in 0..1000 { let distinct_id = format!("distinct_id_{}", i); - let feature_flag_match = FeatureFlagMatcher::new(distinct_id).get_match(&flags[0]); + let feature_flag_match = FeatureFlagMatcher::new(distinct_id, None) + .get_match(&flags[0]) + .await; if results[i].is_some() { assert_eq!( diff --git a/rust/feature-flags/tests/test_flags.rs b/rust/feature-flags/tests/test_flags.rs index 2ceba24efd7120..f9a46e1c543aff 100644 --- a/rust/feature-flags/tests/test_flags.rs +++ b/rust/feature-flags/tests/test_flags.rs @@ -6,13 +6,14 @@ use serde_json::{json, Value}; use crate::common::*; +use feature_flags::config::DEFAULT_TEST_CONFIG; use feature_flags::test_utils::{insert_new_team_in_redis, setup_redis_client}; pub mod common; #[tokio::test] async fn it_sends_flag_request() -> Result<()> { - let config = DEFAULT_CONFIG.clone(); + let config = DEFAULT_TEST_CONFIG.clone(); let distinct_id = "user_distinct_id".to_string(); @@ -50,7 +51,7 @@ async fn it_sends_flag_request() -> Result<()> { #[tokio::test] async fn it_rejects_invalid_headers_flag_request() -> Result<()> { - let config = DEFAULT_CONFIG.clone(); + let config = DEFAULT_TEST_CONFIG.clone(); let distinct_id = "user_distinct_id".to_string();
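For reference, the shape of the end-to-end call these tests exercise, as a sketch with a hand-rolled server. The real tests use the `ServerHandle` helper from `tests/common`; the `token`/`distinct_id` body fields and `phc_example` value are assumptions for illustration:

```rust
use feature_flags::config::DEFAULT_TEST_CONFIG;
use feature_flags::server::serve;
use reqwest::header::CONTENT_TYPE;
use tokio::net::TcpListener;

#[tokio::test]
async fn flags_endpoint_smoke_test() -> anyhow::Result<()> {
    let config = DEFAULT_TEST_CONFIG.clone();
    let listener = TcpListener::bind(config.address).await?;
    let addr = listener.local_addr()?;
    // Run the server until the test process exits.
    tokio::spawn(serve(config, listener, std::future::pending::<()>()));

    let payload = serde_json::json!({
        "token": "phc_example", // placeholder; a real test inserts a team first
        "distinct_id": "user_distinct_id",
    });
    let res = reqwest::Client::new()
        .post(format!("http://{addr}/flags"))
        .header(CONTENT_TYPE, "application/json")
        .body(payload.to_string())
        .send()
        .await?;

    // An unknown token should not return 200 OK.
    assert_ne!(res.status(), reqwest::StatusCode::OK);
    Ok(())
}
```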