Skip to content

Commit

Permalink
feat(flags): return real feature flag evaluation for basic matching w…
Browse files Browse the repository at this point in the history
…ith new endpoint (#24358)
  • Loading branch information
dmarticus authored Aug 22, 2024
1 parent 4d3d6d2 commit d6ad9fa
Show file tree
Hide file tree
Showing 7 changed files with 311 additions and 38 deletions.
31 changes: 30 additions & 1 deletion rust/feature-flags/src/api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,40 @@ pub enum FlagValue {
String(String),
}

// TODO the following two types are kinda general, maybe we should move them to a shared module
#[derive(Debug, PartialEq, Eq, Deserialize, Serialize)]
#[serde(untagged)]
pub enum BooleanOrStringObject {
Boolean(bool),
Object(HashMap<String, String>),
}

#[derive(Debug, PartialEq, Eq, Deserialize, Serialize)]
#[serde(untagged)]
pub enum BooleanOrBooleanObject {
Boolean(bool),
Object(HashMap<String, bool>),
}

#[derive(Debug, PartialEq, Eq, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct FlagsResponse {
pub error_while_computing_flags: bool,
pub feature_flags: HashMap<String, FlagValue>,
// TODO support the other fields in the payload
// pub config: HashMap<String, bool>,
// pub toolbar_params: HashMap<String, String>,
// pub is_authenticated: bool,
// pub supported_compression: Vec<String>,
// pub session_recording: bool,
// pub feature_flag_payloads: HashMap<String, String>,
// pub capture_performance: BooleanOrBooleanObject,
// #[serde(rename = "autocapture_opt_out")]
// pub autocapture_opt_out: bool,
// pub autocapture_exceptions: BooleanOrStringObject,
// pub surveys: bool,
// pub heatmaps: bool,
// pub site_apps: Vec<String>,
}

#[derive(Error, Debug)]
Expand Down Expand Up @@ -98,7 +127,7 @@ impl IntoResponse for FlagError {
(StatusCode::BAD_REQUEST, "The distinct_id field is missing from the request. Please include a valid identifier.".to_string())
}
FlagError::NoTokenError => {
(StatusCode::UNAUTHORIZED, "No API key provided. Please include a valid API key in your request.".to_string())
(StatusCode::UNAUTHORIZED, "No API token provided. Please include a valid API token in your request.".to_string())
}
FlagError::TokenValidationError => {
(StatusCode::UNAUTHORIZED, "The provided API key is invalid or has expired. Please check your API key and try again.".to_string())
Expand Down
8 changes: 4 additions & 4 deletions rust/feature-flags/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@ pub struct Config {
#[envconfig(default = "127.0.0.1:3001")]
pub address: SocketAddr,

#[envconfig(default = "postgres://posthog:posthog@localhost:5432/test_posthog")]
#[envconfig(default = "postgres://posthog:posthog@localhost:5432/posthog")]
pub write_database_url: String,

#[envconfig(default = "postgres://posthog:posthog@localhost:5432/test_posthog")]
#[envconfig(default = "postgres://posthog:posthog@localhost:5432/posthog")]
pub read_database_url: String,

#[envconfig(default = "1024")]
Expand Down Expand Up @@ -57,11 +57,11 @@ mod tests {
);
assert_eq!(
config.write_database_url,
"postgres://posthog:posthog@localhost:5432/test_posthog"
"postgres://posthog:posthog@localhost:5432/posthog"
);
assert_eq!(
config.read_database_url,
"postgres://posthog:posthog@localhost:5432/test_posthog"
"postgres://posthog:posthog@localhost:5432/posthog"
);
assert_eq!(config.max_concurrent_jobs, 1024);
assert_eq!(config.max_pg_connections, 100);
Expand Down
37 changes: 29 additions & 8 deletions rust/feature-flags/src/flag_definitions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ pub const TEAM_FLAGS_CACHE_PREFIX: &str = "posthog:1:team_feature_flags_";
#[derive(Debug, Deserialize)]
pub enum GroupTypeIndex {}

#[derive(Debug, Clone, PartialEq, Eq, Deserialize)]
#[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum OperatorType {
Exact,
Expand All @@ -32,7 +32,7 @@ pub enum OperatorType {
IsDateBefore,
}

#[derive(Debug, Clone, Deserialize)]
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct PropertyFilter {
pub key: String,
// TODO: Probably need a default for value?
Expand All @@ -45,26 +45,26 @@ pub struct PropertyFilter {
pub group_type_index: Option<i8>,
}

#[derive(Debug, Clone, Deserialize)]
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct FlagGroupType {
pub properties: Option<Vec<PropertyFilter>>,
pub rollout_percentage: Option<f64>,
pub variant: Option<String>,
}

#[derive(Debug, Clone, Deserialize)]
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct MultivariateFlagVariant {
pub key: String,
pub name: Option<String>,
pub rollout_percentage: f64,
}

#[derive(Debug, Clone, Deserialize)]
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct MultivariateFlagOptions {
pub variants: Vec<MultivariateFlagVariant>,
}

#[derive(Debug, Clone, Deserialize)]
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct FlagFilters {
pub groups: Vec<FlagGroupType>,
pub multivariate: Option<MultivariateFlagOptions>,
Expand All @@ -73,7 +73,7 @@ pub struct FlagFilters {
pub super_groups: Option<Vec<FlagGroupType>>,
}

#[derive(Debug, Clone, Deserialize)]
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct FeatureFlag {
pub id: i32,
pub team_id: i32,
Expand Down Expand Up @@ -117,7 +117,7 @@ impl FeatureFlag {
}
}

#[derive(Debug, Deserialize)]
#[derive(Debug, Deserialize, Serialize)]
pub struct FeatureFlagList {
pub flags: Vec<FeatureFlag>,
}
Expand Down Expand Up @@ -189,6 +189,27 @@ impl FeatureFlagList {

Ok(FeatureFlagList { flags: flags_list })
}

pub async fn update_flags_in_redis(
client: Arc<dyn RedisClient + Send + Sync>,
team_id: i32,
flags: &FeatureFlagList,
) -> Result<(), FlagError> {
let payload = serde_json::to_string(flags).map_err(|e| {
tracing::error!("Failed to serialize flags: {}", e);
FlagError::DataParsingError
})?;

client
.set(format!("{TEAM_FLAGS_CACHE_PREFIX}{}", team_id), payload)
.await
.map_err(|e| {
tracing::error!("Failed to update Redis cache: {}", e);
FlagError::CacheUpdateError
})?;

Ok(())
}
}

#[cfg(test)]
Expand Down
15 changes: 13 additions & 2 deletions rust/feature-flags/src/team.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,22 @@ use crate::{api::FlagError, database::Client as DatabaseClient, redis::Client as
// TODO: Add integration tests across repos to ensure this doesn't happen.
pub const TEAM_TOKEN_CACHE_PREFIX: &str = "posthog:1:team_token:";

#[derive(Debug, Deserialize, Serialize, sqlx::FromRow)]
#[derive(Clone, Debug, Deserialize, Serialize, sqlx::FromRow)]
pub struct Team {
pub id: i32,
pub name: String,
pub api_token: String,
// TODO: the following fields are used for the `/decide` response,
// but they're not used for flags and they don't live in redis.
// At some point I'll need to differentiate between teams in Redis and teams
// with additional fields in Postgres, since the Postgres team is a superset of the fields
// we use for flags, anyway.
// pub surveys_opt_in: bool,
// pub heatmaps_opt_in: bool,
// pub capture_performance_opt_in: bool,
// pub autocapture_web_vitals_opt_in: bool,
// pub autocapture_opt_out: bool,
// pub autocapture_exceptions_opt_in: bool,
}

impl Team {
Expand Down Expand Up @@ -40,7 +51,7 @@ impl Team {
#[instrument(skip_all)]
pub async fn update_redis_cache(
client: Arc<dyn RedisClient + Send + Sync>,
team: Team,
team: &Team,
) -> Result<(), FlagError> {
let serialized_team = serde_json::to_string(&team).map_err(|e| {
tracing::error!("Failed to serialize team: {}", e);
Expand Down
127 changes: 114 additions & 13 deletions rust/feature-flags/src/v0_endpoint.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,19 @@
use std::collections::HashMap;
use std::sync::Arc;

use axum::{debug_handler, Json};
use bytes::Bytes;
// TODO: stream this instead
use axum::extract::{MatchedPath, Query, State};
use axum::http::{HeaderMap, Method};
use axum_client_ip::InsecureClientIp;
use tracing::instrument;
use tracing::{error, instrument, warn};

use crate::api::FlagValue;
use crate::database::Client;
use crate::flag_definitions::FeatureFlagList;
use crate::flag_matching::FeatureFlagMatcher;
use crate::v0_request::Compression;
use crate::{
api::{FlagError, FlagsResponse},
router,
Expand Down Expand Up @@ -42,20 +47,35 @@ pub async fn flags(
path: MatchedPath,
body: Bytes,
) -> Result<Json<FlagsResponse>, FlagError> {
// TODO this could be extracted into some load_data_for_request type thing
let user_agent = headers
.get("user-agent")
.map_or("unknown", |v| v.to_str().unwrap_or("unknown"));
let content_encoding = headers
.get("content-encoding")
.map_or("unknown", |v| v.to_str().unwrap_or("unknown"));
let comp = match meta.compression {
None => String::from("unknown"),
Some(Compression::Gzip) => String::from("gzip"),
// Some(Compression::Base64) => String::from("base64"),
Some(Compression::Unsupported) => String::from("unsupported"),
};
// TODO what do we use this for?
let sent_at = meta.sent_at.unwrap_or(0);

tracing::Span::current().record("user_agent", user_agent);
tracing::Span::current().record("content_encoding", content_encoding);
tracing::Span::current().record("version", meta.version.clone());
tracing::Span::current().record("lib_version", meta.lib_version.clone());
tracing::Span::current().record("compression", comp.as_str());
tracing::Span::current().record("method", method.as_str());
tracing::Span::current().record("path", path.as_str().trim_end_matches('/'));
tracing::Span::current().record("ip", ip.to_string());
tracing::Span::current().record("sent_at", &sent_at.to_string());

tracing::debug!("request headers: {:?}", headers);

// TODO handle different content types and encodings
let request = match headers
.get("content-type")
.map_or("", |v| v.to_str().unwrap_or(""))
Expand All @@ -64,6 +84,7 @@ pub async fn flags(
tracing::Span::current().record("content_type", "application/json");
FlagRequest::from_bytes(body)
}
// TODO support other content types
ct => {
return Err(FlagError::RequestDecodingError(format!(
"unsupported content type: {}",
Expand All @@ -72,27 +93,107 @@ pub async fn flags(
}
}?;

// this errors up top-level if there's no token
// return the team here, too?
let token = request
.extract_and_verify_token(state.redis.clone(), state.postgres.clone())
.await?;

// at this point, we should get the team since I need the team values for options on the payload
// Note that the team here is different than the redis team.
// TODO: consider making team an option, since we could fetch a project instead of a team
// if the token is valid by the team doesn't exist for some reason. Like, there might be a case
// where the token exists in the database but the team has been deleted.
// That said, though, I don't think this is necessary because we're already validating that the token exists
// in the database, so if it doesn't exist, we should be returning an error there.
// TODO make that one request; we already extract the token by accessing the team table, so we can just extract the team here
let team = request
.get_team_from_cache_or_pg(&token, state.redis.clone(), state.postgres.clone())
.await?;

// this errors up top-level if there's no distinct_id or missing one
let distinct_id = request.extract_distinct_id()?;

// TODO handle disabled flags, should probably do that right at the beginning

tracing::Span::current().record("token", &token);
tracing::Span::current().record("distinct_id", &distinct_id);

tracing::debug!("request: {:?}", request);

// TODO: Some actual processing for evaluating the feature flag

Ok(Json(FlagsResponse {
error_while_computing_flags: false,
feature_flags: HashMap::from([
(
"beta-feature".to_string(),
FlagValue::String("variant-1".to_string()),
),
("rollout-flag".to_string(), FlagValue::Boolean(true)),
]),
}))
// now that I have a team ID and a distinct ID, I can evaluate the feature flags

// first, get the flags
let all_feature_flags = request
.get_flags_from_cache_or_pg(team.id, state.redis.clone(), state.postgres.clone())
.await?;

tracing::Span::current().record("flags", &format!("{:?}", all_feature_flags));

// debug log, I'm keeping it around bc it's useful
// tracing::debug!(
// "flags: {}",
// serde_json::to_string_pretty(&all_feature_flags)
// .unwrap_or_else(|_| format!("{:?}", all_feature_flags))
// );

let flags_response =
evaluate_feature_flags(distinct_id, all_feature_flags, Some(state.postgres.clone())).await;

Ok(Json(flags_response))

// TODO need to handle experience continuity here
}

pub async fn evaluate_feature_flags(
distinct_id: String,
feature_flag_list: FeatureFlagList,
database_client: Option<Arc<dyn Client + Send + Sync>>,
) -> FlagsResponse {
let mut matcher = FeatureFlagMatcher::new(distinct_id.clone(), database_client);
let mut feature_flags = HashMap::new();
let mut error_while_computing_flags = false;
let all_feature_flags = feature_flag_list.flags;

for flag in all_feature_flags {
if !flag.active || flag.deleted {
continue;
}

let flag_match = matcher.get_match(&flag).await;

let flag_value = if flag_match.matches {
match flag_match.variant {
Some(variant) => FlagValue::String(variant),
None => FlagValue::Boolean(true),
}
} else {
FlagValue::Boolean(false)
};

feature_flags.insert(flag.key.clone(), flag_value);

if let Err(e) = matcher
.get_person_properties(flag.team_id, distinct_id.clone())
.await
{
error_while_computing_flags = true;
error!(
"Error fetching properties for feature flag '{}' and distinct_id '{}': {:?}",
flag.key, distinct_id, e
);
}
}

if error_while_computing_flags {
warn!(
"Errors occurred while computing feature flags for distinct_id '{}'",
distinct_id
);
}

FlagsResponse {
error_while_computing_flags,
feature_flags,
}
}
Loading

0 comments on commit d6ad9fa

Please sign in to comment.