Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(flags): return real feature flag evaluation for basic matching with new endpoint #24358

Merged
merged 15 commits into from
Aug 22, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 30 additions & 1 deletion rust/feature-flags/src/api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,40 @@ pub enum FlagValue {
String(String),
}

// TODO the following two types are kinda general, maybe we should move them to a shared module
#[derive(Debug, PartialEq, Eq, Deserialize, Serialize)]
#[serde(untagged)]
pub enum BooleanOrStringObject {
Boolean(bool),
Object(HashMap<String, String>),
}

#[derive(Debug, PartialEq, Eq, Deserialize, Serialize)]
#[serde(untagged)]
pub enum BooleanOrBooleanObject {
Boolean(bool),
Object(HashMap<String, bool>),
}

#[derive(Debug, PartialEq, Eq, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct FlagsResponse {
pub error_while_computing_flags: bool,
pub feature_flags: HashMap<String, FlagValue>,
// TODO support the other fields in the payload
// pub config: HashMap<String, bool>,
// pub toolbar_params: HashMap<String, String>,
// pub is_authenticated: bool,
// pub supported_compression: Vec<String>,
// pub session_recording: bool,
// pub feature_flag_payloads: HashMap<String, String>,
// pub capture_performance: BooleanOrBooleanObject,
// #[serde(rename = "autocapture_opt_out")]
// pub autocapture_opt_out: bool,
// pub autocapture_exceptions: BooleanOrStringObject,
// pub surveys: bool,
// pub heatmaps: bool,
// pub site_apps: Vec<String>,
}

#[derive(Error, Debug)]
Expand Down Expand Up @@ -98,7 +127,7 @@ impl IntoResponse for FlagError {
(StatusCode::BAD_REQUEST, "The distinct_id field is missing from the request. Please include a valid identifier.".to_string())
}
FlagError::NoTokenError => {
(StatusCode::UNAUTHORIZED, "No API key provided. Please include a valid API key in your request.".to_string())
(StatusCode::UNAUTHORIZED, "No API token provided. Please include a valid API token in your request.".to_string())
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not sure if we should call it a key or a token, lol. We'll pick one on which to standardize.

}
FlagError::TokenValidationError => {
(StatusCode::UNAUTHORIZED, "The provided API key is invalid or has expired. Please check your API key and try again.".to_string())
Expand Down
8 changes: 4 additions & 4 deletions rust/feature-flags/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@ pub struct Config {
#[envconfig(default = "127.0.0.1:3001")]
pub address: SocketAddr,

#[envconfig(default = "postgres://posthog:posthog@localhost:5432/test_posthog")]
#[envconfig(default = "postgres://posthog:posthog@localhost:5432/posthog")]
Copy link
Contributor Author

@dmarticus dmarticus Aug 13, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

real app uses real DB!

pub write_database_url: String,

#[envconfig(default = "postgres://posthog:posthog@localhost:5432/test_posthog")]
#[envconfig(default = "postgres://posthog:posthog@localhost:5432/posthog")]
pub read_database_url: String,

#[envconfig(default = "1024")]
Expand Down Expand Up @@ -57,11 +57,11 @@ mod tests {
);
assert_eq!(
config.write_database_url,
"postgres://posthog:posthog@localhost:5432/test_posthog"
"postgres://posthog:posthog@localhost:5432/posthog"
);
assert_eq!(
config.read_database_url,
"postgres://posthog:posthog@localhost:5432/test_posthog"
"postgres://posthog:posthog@localhost:5432/posthog"
);
assert_eq!(config.max_concurrent_jobs, 1024);
assert_eq!(config.max_pg_connections, 100);
Expand Down
37 changes: 29 additions & 8 deletions rust/feature-flags/src/flag_definitions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ pub const TEAM_FLAGS_CACHE_PREFIX: &str = "posthog:1:team_feature_flags_";
#[derive(Debug, Deserialize)]
pub enum GroupTypeIndex {}

#[derive(Debug, Clone, PartialEq, Eq, Deserialize)]
#[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum OperatorType {
Exact,
Expand All @@ -32,7 +32,7 @@ pub enum OperatorType {
IsDateBefore,
}

#[derive(Debug, Clone, Deserialize)]
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct PropertyFilter {
pub key: String,
// TODO: Probably need a default for value?
Expand All @@ -45,26 +45,26 @@ pub struct PropertyFilter {
pub group_type_index: Option<i8>,
}

#[derive(Debug, Clone, Deserialize)]
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct FlagGroupType {
pub properties: Option<Vec<PropertyFilter>>,
pub rollout_percentage: Option<f64>,
pub variant: Option<String>,
}

#[derive(Debug, Clone, Deserialize)]
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct MultivariateFlagVariant {
pub key: String,
pub name: Option<String>,
pub rollout_percentage: f64,
}

#[derive(Debug, Clone, Deserialize)]
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct MultivariateFlagOptions {
pub variants: Vec<MultivariateFlagVariant>,
}

#[derive(Debug, Clone, Deserialize)]
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct FlagFilters {
pub groups: Vec<FlagGroupType>,
pub multivariate: Option<MultivariateFlagOptions>,
Expand All @@ -73,7 +73,7 @@ pub struct FlagFilters {
pub super_groups: Option<Vec<FlagGroupType>>,
}

#[derive(Debug, Clone, Deserialize)]
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct FeatureFlag {
pub id: i32,
pub team_id: i32,
Expand Down Expand Up @@ -117,7 +117,7 @@ impl FeatureFlag {
}
}

#[derive(Debug, Deserialize)]
#[derive(Debug, Deserialize, Serialize)]
pub struct FeatureFlagList {
pub flags: Vec<FeatureFlag>,
}
Expand Down Expand Up @@ -189,6 +189,27 @@ impl FeatureFlagList {

Ok(FeatureFlagList { flags: flags_list })
}

pub async fn update_flags_in_redis(
client: Arc<dyn RedisClient + Send + Sync>,
team_id: i32,
flags: &FeatureFlagList,
) -> Result<(), FlagError> {
let payload = serde_json::to_string(flags).map_err(|e| {
tracing::error!("Failed to serialize flags: {}", e);
FlagError::DataParsingError
})?;

client
.set(format!("{TEAM_FLAGS_CACHE_PREFIX}{}", team_id), payload)
.await
.map_err(|e| {
tracing::error!("Failed to update Redis cache: {}", e);
FlagError::CacheUpdateError
})?;

Ok(())
}
}

#[cfg(test)]
Expand Down
15 changes: 13 additions & 2 deletions rust/feature-flags/src/team.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,22 @@ use crate::{api::FlagError, database::Client as DatabaseClient, redis::Client as
// TODO: Add integration tests across repos to ensure this doesn't happen.
pub const TEAM_TOKEN_CACHE_PREFIX: &str = "posthog:1:team_token:";

#[derive(Debug, Deserialize, Serialize, sqlx::FromRow)]
#[derive(Clone, Debug, Deserialize, Serialize, sqlx::FromRow)]
pub struct Team {
pub id: i32,
pub name: String,
pub api_token: String,
// TODO: the following fields are used for the `/decide` response,
// but they're not used for flags and they don't live in redis.
// At some point I'll need to differentiate between teams in Redis and teams
// with additional fields in Postgres, since the Postgres team is a superset of the fields
// we use for flags, anyway.
// pub surveys_opt_in: bool,
// pub heatmaps_opt_in: bool,
// pub capture_performance_opt_in: bool,
// pub autocapture_web_vitals_opt_in: bool,
// pub autocapture_opt_out: bool,
// pub autocapture_exceptions_opt_in: bool,
}

impl Team {
Expand Down Expand Up @@ -40,7 +51,7 @@ impl Team {
#[instrument(skip_all)]
pub async fn update_redis_cache(
client: Arc<dyn RedisClient + Send + Sync>,
team: Team,
team: &Team,
) -> Result<(), FlagError> {
let serialized_team = serde_json::to_string(&team).map_err(|e| {
tracing::error!("Failed to serialize team: {}", e);
Expand Down
127 changes: 114 additions & 13 deletions rust/feature-flags/src/v0_endpoint.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,19 @@
use std::collections::HashMap;
use std::sync::Arc;

use axum::{debug_handler, Json};
use bytes::Bytes;
// TODO: stream this instead
use axum::extract::{MatchedPath, Query, State};
use axum::http::{HeaderMap, Method};
use axum_client_ip::InsecureClientIp;
use tracing::instrument;
use tracing::{error, instrument, warn};

use crate::api::FlagValue;
use crate::database::Client;
use crate::flag_definitions::FeatureFlagList;
use crate::flag_matching::FeatureFlagMatcher;
use crate::v0_request::Compression;
use crate::{
api::{FlagError, FlagsResponse},
router,
Expand Down Expand Up @@ -42,20 +47,35 @@ pub async fn flags(
path: MatchedPath,
body: Bytes,
) -> Result<Json<FlagsResponse>, FlagError> {
// TODO this could be extracted into some load_data_for_request type thing
let user_agent = headers
.get("user-agent")
.map_or("unknown", |v| v.to_str().unwrap_or("unknown"));
let content_encoding = headers
.get("content-encoding")
.map_or("unknown", |v| v.to_str().unwrap_or("unknown"));
let comp = match meta.compression {
None => String::from("unknown"),
Some(Compression::Gzip) => String::from("gzip"),
// Some(Compression::Base64) => String::from("base64"),
Some(Compression::Unsupported) => String::from("unsupported"),
};
// TODO what do we use this for?
let sent_at = meta.sent_at.unwrap_or(0);

tracing::Span::current().record("user_agent", user_agent);
tracing::Span::current().record("content_encoding", content_encoding);
tracing::Span::current().record("version", meta.version.clone());
tracing::Span::current().record("lib_version", meta.lib_version.clone());
tracing::Span::current().record("compression", comp.as_str());
tracing::Span::current().record("method", method.as_str());
tracing::Span::current().record("path", path.as_str().trim_end_matches('/'));
tracing::Span::current().record("ip", ip.to_string());
tracing::Span::current().record("sent_at", &sent_at.to_string());

tracing::debug!("request headers: {:?}", headers);

// TODO handle different content types and encodings
let request = match headers
.get("content-type")
.map_or("", |v| v.to_str().unwrap_or(""))
Expand All @@ -64,6 +84,7 @@ pub async fn flags(
tracing::Span::current().record("content_type", "application/json");
FlagRequest::from_bytes(body)
}
// TODO support other content types
ct => {
return Err(FlagError::RequestDecodingError(format!(
"unsupported content type: {}",
Expand All @@ -72,27 +93,107 @@ pub async fn flags(
}
}?;

// this errors up top-level if there's no token
// return the team here, too?
let token = request
.extract_and_verify_token(state.redis.clone(), state.postgres.clone())
.await?;

// at this point, we should get the team since I need the team values for options on the payload
// Note that the team here is different than the redis team.
// TODO: consider making team an option, since we could fetch a project instead of a team
// if the token is valid by the team doesn't exist for some reason. Like, there might be a case
// where the token exists in the database but the team has been deleted.
// That said, though, I don't think this is necessary because we're already validating that the token exists
// in the database, so if it doesn't exist, we should be returning an error there.
// TODO make that one request; we already extract the token by accessing the team table, so we can just extract the team here
let team = request
.get_team_from_cache_or_pg(&token, state.redis.clone(), state.postgres.clone())
.await?;

// this errors up top-level if there's no distinct_id or missing one
let distinct_id = request.extract_distinct_id()?;

// TODO handle disabled flags, should probably do that right at the beginning

tracing::Span::current().record("token", &token);
tracing::Span::current().record("distinct_id", &distinct_id);

tracing::debug!("request: {:?}", request);

// TODO: Some actual processing for evaluating the feature flag

Ok(Json(FlagsResponse {
error_while_computing_flags: false,
feature_flags: HashMap::from([
(
"beta-feature".to_string(),
FlagValue::String("variant-1".to_string()),
),
("rollout-flag".to_string(), FlagValue::Boolean(true)),
]),
}))
// now that I have a team ID and a distinct ID, I can evaluate the feature flags

// first, get the flags
let all_feature_flags = request
.get_flags_from_cache_or_pg(team.id, state.redis.clone(), state.postgres.clone())
.await?;

tracing::Span::current().record("flags", &format!("{:?}", all_feature_flags));

// debug log, I'm keeping it around bc it's useful
// tracing::debug!(
// "flags: {}",
// serde_json::to_string_pretty(&all_feature_flags)
// .unwrap_or_else(|_| format!("{:?}", all_feature_flags))
// );

let flags_response =
evaluate_feature_flags(distinct_id, all_feature_flags, Some(state.postgres.clone())).await;

Ok(Json(flags_response))

// TODO need to handle experience continuity here
}

pub async fn evaluate_feature_flags(
distinct_id: String,
feature_flag_list: FeatureFlagList,
database_client: Option<Arc<dyn Client + Send + Sync>>,
) -> FlagsResponse {
let mut matcher = FeatureFlagMatcher::new(distinct_id.clone(), database_client);
let mut feature_flags = HashMap::new();
let mut error_while_computing_flags = false;
let all_feature_flags = feature_flag_list.flags;

for flag in all_feature_flags {
if !flag.active || flag.deleted {
continue;
}

let flag_match = matcher.get_match(&flag).await;

let flag_value = if flag_match.matches {
match flag_match.variant {
Some(variant) => FlagValue::String(variant),
None => FlagValue::Boolean(true),
}
} else {
FlagValue::Boolean(false)
};

feature_flags.insert(flag.key.clone(), flag_value);

if let Err(e) = matcher
.get_person_properties(flag.team_id, distinct_id.clone())
.await
{
error_while_computing_flags = true;
error!(
"Error fetching properties for feature flag '{}' and distinct_id '{}': {:?}",
flag.key, distinct_id, e
);
}
}

if error_while_computing_flags {
warn!(
"Errors occurred while computing feature flags for distinct_id '{}'",
distinct_id
);
}

FlagsResponse {
error_while_computing_flags,
feature_flags,
}
}
Loading
Loading