From d6ad9fa096abbdfc684dca7e093afbbd33434503 Mon Sep 17 00:00:00 2001 From: Dylan Martin Date: Thu, 22 Aug 2024 12:10:47 -0400 Subject: [PATCH] feat(flags): return real feature flag evaluation for basic matching with new endpoint (#24358) --- rust/feature-flags/src/api.rs | 31 ++++- rust/feature-flags/src/config.rs | 8 +- rust/feature-flags/src/flag_definitions.rs | 37 ++++-- rust/feature-flags/src/team.rs | 15 ++- rust/feature-flags/src/v0_endpoint.rs | 127 ++++++++++++++++++--- rust/feature-flags/src/v0_request.rs | 99 +++++++++++++++- rust/feature-flags/tests/test_flags.rs | 32 ++++-- 7 files changed, 311 insertions(+), 38 deletions(-) diff --git a/rust/feature-flags/src/api.rs b/rust/feature-flags/src/api.rs index da2b00fbfdef5..285e09edc5c77 100644 --- a/rust/feature-flags/src/api.rs +++ b/rust/feature-flags/src/api.rs @@ -20,11 +20,40 @@ pub enum FlagValue { String(String), } +// TODO the following two types are kinda general, maybe we should move them to a shared module +#[derive(Debug, PartialEq, Eq, Deserialize, Serialize)] +#[serde(untagged)] +pub enum BooleanOrStringObject { + Boolean(bool), + Object(HashMap), +} + +#[derive(Debug, PartialEq, Eq, Deserialize, Serialize)] +#[serde(untagged)] +pub enum BooleanOrBooleanObject { + Boolean(bool), + Object(HashMap), +} + #[derive(Debug, PartialEq, Eq, Deserialize, Serialize)] #[serde(rename_all = "camelCase")] pub struct FlagsResponse { pub error_while_computing_flags: bool, pub feature_flags: HashMap, + // TODO support the other fields in the payload + // pub config: HashMap, + // pub toolbar_params: HashMap, + // pub is_authenticated: bool, + // pub supported_compression: Vec, + // pub session_recording: bool, + // pub feature_flag_payloads: HashMap, + // pub capture_performance: BooleanOrBooleanObject, + // #[serde(rename = "autocapture_opt_out")] + // pub autocapture_opt_out: bool, + // pub autocapture_exceptions: BooleanOrStringObject, + // pub surveys: bool, + // pub heatmaps: bool, + // pub site_apps: Vec, } #[derive(Error, Debug)] @@ -98,7 +127,7 @@ impl IntoResponse for FlagError { (StatusCode::BAD_REQUEST, "The distinct_id field is missing from the request. Please include a valid identifier.".to_string()) } FlagError::NoTokenError => { - (StatusCode::UNAUTHORIZED, "No API key provided. Please include a valid API key in your request.".to_string()) + (StatusCode::UNAUTHORIZED, "No API token provided. Please include a valid API token in your request.".to_string()) } FlagError::TokenValidationError => { (StatusCode::UNAUTHORIZED, "The provided API key is invalid or has expired. Please check your API key and try again.".to_string()) diff --git a/rust/feature-flags/src/config.rs b/rust/feature-flags/src/config.rs index d9e1bf06b1ee3..a9e3517398770 100644 --- a/rust/feature-flags/src/config.rs +++ b/rust/feature-flags/src/config.rs @@ -8,10 +8,10 @@ pub struct Config { #[envconfig(default = "127.0.0.1:3001")] pub address: SocketAddr, - #[envconfig(default = "postgres://posthog:posthog@localhost:5432/test_posthog")] + #[envconfig(default = "postgres://posthog:posthog@localhost:5432/posthog")] pub write_database_url: String, - #[envconfig(default = "postgres://posthog:posthog@localhost:5432/test_posthog")] + #[envconfig(default = "postgres://posthog:posthog@localhost:5432/posthog")] pub read_database_url: String, #[envconfig(default = "1024")] @@ -57,11 +57,11 @@ mod tests { ); assert_eq!( config.write_database_url, - "postgres://posthog:posthog@localhost:5432/test_posthog" + "postgres://posthog:posthog@localhost:5432/posthog" ); assert_eq!( config.read_database_url, - "postgres://posthog:posthog@localhost:5432/test_posthog" + "postgres://posthog:posthog@localhost:5432/posthog" ); assert_eq!(config.max_concurrent_jobs, 1024); assert_eq!(config.max_pg_connections, 100); diff --git a/rust/feature-flags/src/flag_definitions.rs b/rust/feature-flags/src/flag_definitions.rs index ef1db6762a5ce..225a4ba4898a9 100644 --- a/rust/feature-flags/src/flag_definitions.rs +++ b/rust/feature-flags/src/flag_definitions.rs @@ -12,7 +12,7 @@ pub const TEAM_FLAGS_CACHE_PREFIX: &str = "posthog:1:team_feature_flags_"; #[derive(Debug, Deserialize)] pub enum GroupTypeIndex {} -#[derive(Debug, Clone, PartialEq, Eq, Deserialize)] +#[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize)] #[serde(rename_all = "snake_case")] pub enum OperatorType { Exact, @@ -32,7 +32,7 @@ pub enum OperatorType { IsDateBefore, } -#[derive(Debug, Clone, Deserialize)] +#[derive(Debug, Clone, Deserialize, Serialize)] pub struct PropertyFilter { pub key: String, // TODO: Probably need a default for value? @@ -45,26 +45,26 @@ pub struct PropertyFilter { pub group_type_index: Option, } -#[derive(Debug, Clone, Deserialize)] +#[derive(Debug, Clone, Deserialize, Serialize)] pub struct FlagGroupType { pub properties: Option>, pub rollout_percentage: Option, pub variant: Option, } -#[derive(Debug, Clone, Deserialize)] +#[derive(Debug, Clone, Deserialize, Serialize)] pub struct MultivariateFlagVariant { pub key: String, pub name: Option, pub rollout_percentage: f64, } -#[derive(Debug, Clone, Deserialize)] +#[derive(Debug, Clone, Deserialize, Serialize)] pub struct MultivariateFlagOptions { pub variants: Vec, } -#[derive(Debug, Clone, Deserialize)] +#[derive(Debug, Clone, Deserialize, Serialize)] pub struct FlagFilters { pub groups: Vec, pub multivariate: Option, @@ -73,7 +73,7 @@ pub struct FlagFilters { pub super_groups: Option>, } -#[derive(Debug, Clone, Deserialize)] +#[derive(Debug, Clone, Deserialize, Serialize)] pub struct FeatureFlag { pub id: i32, pub team_id: i32, @@ -117,7 +117,7 @@ impl FeatureFlag { } } -#[derive(Debug, Deserialize)] +#[derive(Debug, Deserialize, Serialize)] pub struct FeatureFlagList { pub flags: Vec, } @@ -189,6 +189,27 @@ impl FeatureFlagList { Ok(FeatureFlagList { flags: flags_list }) } + + pub async fn update_flags_in_redis( + client: Arc, + team_id: i32, + flags: &FeatureFlagList, + ) -> Result<(), FlagError> { + let payload = serde_json::to_string(flags).map_err(|e| { + tracing::error!("Failed to serialize flags: {}", e); + FlagError::DataParsingError + })?; + + client + .set(format!("{TEAM_FLAGS_CACHE_PREFIX}{}", team_id), payload) + .await + .map_err(|e| { + tracing::error!("Failed to update Redis cache: {}", e); + FlagError::CacheUpdateError + })?; + + Ok(()) + } } #[cfg(test)] diff --git a/rust/feature-flags/src/team.rs b/rust/feature-flags/src/team.rs index 678668490485d..bd975385eb216 100644 --- a/rust/feature-flags/src/team.rs +++ b/rust/feature-flags/src/team.rs @@ -8,11 +8,22 @@ use crate::{api::FlagError, database::Client as DatabaseClient, redis::Client as // TODO: Add integration tests across repos to ensure this doesn't happen. pub const TEAM_TOKEN_CACHE_PREFIX: &str = "posthog:1:team_token:"; -#[derive(Debug, Deserialize, Serialize, sqlx::FromRow)] +#[derive(Clone, Debug, Deserialize, Serialize, sqlx::FromRow)] pub struct Team { pub id: i32, pub name: String, pub api_token: String, + // TODO: the following fields are used for the `/decide` response, + // but they're not used for flags and they don't live in redis. + // At some point I'll need to differentiate between teams in Redis and teams + // with additional fields in Postgres, since the Postgres team is a superset of the fields + // we use for flags, anyway. + // pub surveys_opt_in: bool, + // pub heatmaps_opt_in: bool, + // pub capture_performance_opt_in: bool, + // pub autocapture_web_vitals_opt_in: bool, + // pub autocapture_opt_out: bool, + // pub autocapture_exceptions_opt_in: bool, } impl Team { @@ -40,7 +51,7 @@ impl Team { #[instrument(skip_all)] pub async fn update_redis_cache( client: Arc, - team: Team, + team: &Team, ) -> Result<(), FlagError> { let serialized_team = serde_json::to_string(&team).map_err(|e| { tracing::error!("Failed to serialize team: {}", e); diff --git a/rust/feature-flags/src/v0_endpoint.rs b/rust/feature-flags/src/v0_endpoint.rs index d32f976d94447..65136e5c42772 100644 --- a/rust/feature-flags/src/v0_endpoint.rs +++ b/rust/feature-flags/src/v0_endpoint.rs @@ -1,4 +1,5 @@ use std::collections::HashMap; +use std::sync::Arc; use axum::{debug_handler, Json}; use bytes::Bytes; @@ -6,9 +7,13 @@ use bytes::Bytes; use axum::extract::{MatchedPath, Query, State}; use axum::http::{HeaderMap, Method}; use axum_client_ip::InsecureClientIp; -use tracing::instrument; +use tracing::{error, instrument, warn}; use crate::api::FlagValue; +use crate::database::Client; +use crate::flag_definitions::FeatureFlagList; +use crate::flag_matching::FeatureFlagMatcher; +use crate::v0_request::Compression; use crate::{ api::{FlagError, FlagsResponse}, router, @@ -42,20 +47,35 @@ pub async fn flags( path: MatchedPath, body: Bytes, ) -> Result, FlagError> { + // TODO this could be extracted into some load_data_for_request type thing let user_agent = headers .get("user-agent") .map_or("unknown", |v| v.to_str().unwrap_or("unknown")); let content_encoding = headers .get("content-encoding") .map_or("unknown", |v| v.to_str().unwrap_or("unknown")); + let comp = match meta.compression { + None => String::from("unknown"), + Some(Compression::Gzip) => String::from("gzip"), + // Some(Compression::Base64) => String::from("base64"), + Some(Compression::Unsupported) => String::from("unsupported"), + }; + // TODO what do we use this for? + let sent_at = meta.sent_at.unwrap_or(0); tracing::Span::current().record("user_agent", user_agent); tracing::Span::current().record("content_encoding", content_encoding); tracing::Span::current().record("version", meta.version.clone()); + tracing::Span::current().record("lib_version", meta.lib_version.clone()); + tracing::Span::current().record("compression", comp.as_str()); tracing::Span::current().record("method", method.as_str()); tracing::Span::current().record("path", path.as_str().trim_end_matches('/')); tracing::Span::current().record("ip", ip.to_string()); + tracing::Span::current().record("sent_at", &sent_at.to_string()); + tracing::debug!("request headers: {:?}", headers); + + // TODO handle different content types and encodings let request = match headers .get("content-type") .map_or("", |v| v.to_str().unwrap_or("")) @@ -64,6 +84,7 @@ pub async fn flags( tracing::Span::current().record("content_type", "application/json"); FlagRequest::from_bytes(body) } + // TODO support other content types ct => { return Err(FlagError::RequestDecodingError(format!( "unsupported content type: {}", @@ -72,27 +93,107 @@ pub async fn flags( } }?; + // this errors up top-level if there's no token + // return the team here, too? let token = request .extract_and_verify_token(state.redis.clone(), state.postgres.clone()) .await?; + // at this point, we should get the team since I need the team values for options on the payload + // Note that the team here is different than the redis team. + // TODO: consider making team an option, since we could fetch a project instead of a team + // if the token is valid by the team doesn't exist for some reason. Like, there might be a case + // where the token exists in the database but the team has been deleted. + // That said, though, I don't think this is necessary because we're already validating that the token exists + // in the database, so if it doesn't exist, we should be returning an error there. + // TODO make that one request; we already extract the token by accessing the team table, so we can just extract the team here + let team = request + .get_team_from_cache_or_pg(&token, state.redis.clone(), state.postgres.clone()) + .await?; + + // this errors up top-level if there's no distinct_id or missing one let distinct_id = request.extract_distinct_id()?; + // TODO handle disabled flags, should probably do that right at the beginning + tracing::Span::current().record("token", &token); tracing::Span::current().record("distinct_id", &distinct_id); tracing::debug!("request: {:?}", request); - // TODO: Some actual processing for evaluating the feature flag - - Ok(Json(FlagsResponse { - error_while_computing_flags: false, - feature_flags: HashMap::from([ - ( - "beta-feature".to_string(), - FlagValue::String("variant-1".to_string()), - ), - ("rollout-flag".to_string(), FlagValue::Boolean(true)), - ]), - })) + // now that I have a team ID and a distinct ID, I can evaluate the feature flags + + // first, get the flags + let all_feature_flags = request + .get_flags_from_cache_or_pg(team.id, state.redis.clone(), state.postgres.clone()) + .await?; + + tracing::Span::current().record("flags", &format!("{:?}", all_feature_flags)); + + // debug log, I'm keeping it around bc it's useful + // tracing::debug!( + // "flags: {}", + // serde_json::to_string_pretty(&all_feature_flags) + // .unwrap_or_else(|_| format!("{:?}", all_feature_flags)) + // ); + + let flags_response = + evaluate_feature_flags(distinct_id, all_feature_flags, Some(state.postgres.clone())).await; + + Ok(Json(flags_response)) + + // TODO need to handle experience continuity here +} + +pub async fn evaluate_feature_flags( + distinct_id: String, + feature_flag_list: FeatureFlagList, + database_client: Option>, +) -> FlagsResponse { + let mut matcher = FeatureFlagMatcher::new(distinct_id.clone(), database_client); + let mut feature_flags = HashMap::new(); + let mut error_while_computing_flags = false; + let all_feature_flags = feature_flag_list.flags; + + for flag in all_feature_flags { + if !flag.active || flag.deleted { + continue; + } + + let flag_match = matcher.get_match(&flag).await; + + let flag_value = if flag_match.matches { + match flag_match.variant { + Some(variant) => FlagValue::String(variant), + None => FlagValue::Boolean(true), + } + } else { + FlagValue::Boolean(false) + }; + + feature_flags.insert(flag.key.clone(), flag_value); + + if let Err(e) = matcher + .get_person_properties(flag.team_id, distinct_id.clone()) + .await + { + error_while_computing_flags = true; + error!( + "Error fetching properties for feature flag '{}' and distinct_id '{}': {:?}", + flag.key, distinct_id, e + ); + } + } + + if error_while_computing_flags { + warn!( + "Errors occurred while computing feature flags for distinct_id '{}'", + distinct_id + ); + } + + FlagsResponse { + error_while_computing_flags, + feature_flags, + } } diff --git a/rust/feature-flags/src/v0_request.rs b/rust/feature-flags/src/v0_request.rs index 4447cb64d1d68..3f28d6da85f04 100644 --- a/rust/feature-flags/src/v0_request.rs +++ b/rust/feature-flags/src/v0_request.rs @@ -6,13 +6,34 @@ use serde_json::Value; use tracing::instrument; use crate::{ - api::FlagError, database::Client as DatabaseClient, redis::Client as RedisClient, team::Team, + api::FlagError, database::Client as DatabaseClient, flag_definitions::FeatureFlagList, + redis::Client as RedisClient, team::Team, }; +#[derive(Deserialize, Default)] +pub enum Compression { + #[default] + Unsupported, + + #[serde(rename = "gzip", alias = "gzip-js")] + Gzip, + // TODO do we want to support this at all? It's not in the spec + // #[serde(rename = "base64")] + // Base64, +} + #[derive(Deserialize, Default)] pub struct FlagsQueryParams { #[serde(alias = "v")] pub version: Option, + + pub compression: Option, + + #[serde(alias = "ver")] + pub lib_version: Option, + + #[serde(alias = "_")] + pub sent_at: Option, } #[derive(Default, Debug, Deserialize, Serialize)] @@ -71,8 +92,8 @@ impl FlagRequest { // Fallback: Check PostgreSQL if not found in Redis match Team::from_pg(pg_client, token.clone()).await { Ok(team) => { - // Token found in PostgreSQL, update Redis cache - if let Err(e) = Team::update_redis_cache(redis_client, team).await { + // Token found in PostgreSQL, update Redis cache so that we can verify it from Redis next time + if let Err(e) = Team::update_redis_cache(redis_client, &team).await { tracing::warn!("Failed to update Redis cache: {}", e); } Ok(token) @@ -83,6 +104,27 @@ impl FlagRequest { } } + pub async fn get_team_from_cache_or_pg( + &self, + token: &str, + redis_client: Arc, + pg_client: Arc, + ) -> Result { + match Team::from_redis(redis_client.clone(), token.to_owned()).await { + Ok(team) => Ok(team), + Err(_) => match Team::from_pg(pg_client, token.to_owned()).await { + Ok(team) => { + // If we have the team in postgres, but not redis, update redis so we're faster next time + if let Err(e) = Team::update_redis_cache(redis_client, &team).await { + tracing::warn!("Failed to update Redis cache: {}", e); + } + Ok(team) + } + Err(_) => Err(FlagError::TokenValidationError), + }, + } + } + pub fn extract_distinct_id(&self) -> Result { let distinct_id = match &self.distinct_id { None => return Err(FlagError::MissingDistinctId), @@ -95,11 +137,37 @@ impl FlagRequest { _ => Ok(distinct_id.chars().take(200).collect()), } } + + pub async fn get_flags_from_cache_or_pg( + &self, + team_id: i32, + redis_client: Arc, + pg_client: Arc, + ) -> Result { + match FeatureFlagList::from_redis(redis_client.clone(), team_id).await { + Ok(flags) => Ok(flags), + Err(_) => match FeatureFlagList::from_pg(pg_client, team_id).await { + Ok(flags) => { + // If we have the flags in postgres, but not redis, update redis so we're faster next time + // TODO: we have some counters in django for tracking these cache misses + // we should probably do the same here + if let Err(e) = + FeatureFlagList::update_flags_in_redis(redis_client, team_id, &flags).await + { + tracing::warn!("Failed to update Redis cache: {}", e); + } + Ok(flags) + } + Err(_) => Err(FlagError::TokenValidationError), + }, + } + } } #[cfg(test)] mod tests { use crate::api::FlagError; + use crate::test_utils::{insert_new_team_in_redis, setup_pg_client, setup_redis_client}; use crate::v0_request::FlagRequest; use bytes::Bytes; use serde_json::json; @@ -148,4 +216,29 @@ mod tests { _ => panic!("expected distinct id"), }; } + + #[tokio::test] + async fn token_is_returned_correctly() { + let redis_client = setup_redis_client(None); + let pg_client = setup_pg_client(None).await; + let team = insert_new_team_in_redis(redis_client.clone()) + .await + .expect("Failed to insert new team in Redis"); + + let json = json!({ + "$distinct_id": "alakazam", + "token": team.api_token, + }); + let bytes = Bytes::from(json.to_string()); + + let flag_payload = FlagRequest::from_bytes(bytes).expect("failed to parse request"); + + match flag_payload + .extract_and_verify_token(redis_client, pg_client) + .await + { + Ok(extracted_token) => assert_eq!(extracted_token, team.api_token), + Err(e) => panic!("Failed to extract and verify token: {:?}", e), + }; + } } diff --git a/rust/feature-flags/tests/test_flags.rs b/rust/feature-flags/tests/test_flags.rs index 7f50064daddb6..706d8fdfed0da 100644 --- a/rust/feature-flags/tests/test_flags.rs +++ b/rust/feature-flags/tests/test_flags.rs @@ -7,7 +7,9 @@ use serde_json::{json, Value}; use crate::common::*; use feature_flags::config::DEFAULT_TEST_CONFIG; -use feature_flags::test_utils::{insert_new_team_in_redis, setup_redis_client}; +use feature_flags::test_utils::{ + insert_flags_for_team_in_redis, insert_new_team_in_redis, setup_redis_client, +}; pub mod common; @@ -21,6 +23,25 @@ async fn it_sends_flag_request() -> Result<()> { let team = insert_new_team_in_redis(client.clone()).await.unwrap(); let token = team.api_token; + // Insert a specific flag for the team + let flag_json = json!([{ + "id": 1, + "key": "test-flag", + "name": "Test Flag", + "active": true, + "deleted": false, + "team_id": team.id, + "filters": { + "groups": [ + { + "properties": [], + "rollout_percentage": 100 + } + ], + }, + }]); + insert_flags_for_team_in_redis(client, team.id, Some(flag_json.to_string())).await?; + let server = ServerHandle::for_config(config).await; let payload = json!({ @@ -28,20 +49,17 @@ async fn it_sends_flag_request() -> Result<()> { "distinct_id": distinct_id, "groups": {"group1": "group1"} }); + let res = server.send_flags_request(payload.to_string()).await; assert_eq!(StatusCode::OK, res.status()); - // We don't want to deserialize the data into a flagResponse struct here, - // because we want to assert the shape of the raw json data. let json_data = res.json::().await?; - assert_json_include!( actual: json_data, expected: json!({ "errorWhileComputingFlags": false, "featureFlags": { - "beta-feature": "variant-1", - "rollout-flag": true, + "test-flag": true } }) ); @@ -139,7 +157,7 @@ async fn it_rejects_missing_token() -> Result<()> { assert_eq!(StatusCode::UNAUTHORIZED, res.status()); assert_eq!( res.text().await?, - "No API key provided. Please include a valid API key in your request." + "No API token provided. Please include a valid API token in your request." ); Ok(()) }