Skip to content

Commit

Permalink
fix(flags): Defend against invalid data on creation
Browse files Browse the repository at this point in the history
  • Loading branch information
neilkakkar committed Apr 26, 2024
1 parent 06fe778 commit df9fd36
Show file tree
Hide file tree
Showing 6 changed files with 246 additions and 3 deletions.
20 changes: 20 additions & 0 deletions posthog/api/feature_flag.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
get_user_blast_radius,
)
from posthog.models.feature_flag.flag_analytics import increment_request_count
from posthog.models.feature_flag.flag_matching import check_flag_evaluation_query_is_ok
from posthog.models.feedback.survey import Survey
from posthog.models.group_type_mapping import GroupTypeMapping
from posthog.models.property import Property
Expand Down Expand Up @@ -263,6 +264,22 @@ def properties_all_match(predicate):

return filters

def check_flag_evaluation(self, data):
    """Dry-run the flag's evaluation query and reject data the database cannot evaluate.

    An unsaved ``FeatureFlag`` is built from ``data`` and handed, together with
    the ``team_id`` from the serializer context, to
    ``check_flag_evaluation_query_is_ok``.

    Raises:
        serializers.ValidationError: if the evaluation check raises anything;
            the original error text is appended to the message.
    """
    # This is a very rough simulation of the actual query that will be run.
    # The only reason to do it is to catch DB-level errors that would bork at
    # runtime but aren't caught by earlier validation, like a regex that is
    # valid according to re2 but not according to PostgreSQL.
    # The check only looks at a small sample of rows (at most 10 persons or
    # groups, without involving distinct ids) to make sure the query is valid.

    # TODO: Once we move to no DB level evaluation, can get rid of this.

    temporary_flag = FeatureFlag(**data)
    team_id = self.context["team_id"]

    try:
        check_flag_evaluation_query_is_ok(temporary_flag, team_id)
    except Exception as e:
        # Deliberately broad: any failure here means the flag can't be evaluated,
        # and the caller should see a validation error rather than a 500.
        # Chain the cause so the underlying DB error stays visible in tracebacks.
        raise serializers.ValidationError(f"Can't evaluate flag: {e}") from e

def create(self, validated_data: dict, *args: Any, **kwargs: Any) -> FeatureFlag:
request = self.context["request"]
validated_data["created_by"] = request.user
Expand All @@ -289,6 +306,9 @@ def create(self, validated_data: dict, *args: Any, **kwargs: Any) -> FeatureFlag
raise exceptions.ValidationError(
"Feature flag with this key already exists and is used in an experiment. Please delete the experiment before deleting the flag."
)

self.check_flag_evaluation(validated_data)

instance: FeatureFlag = super().create(validated_data)

self._attempt_set_tags(tags, instance)
Expand Down
21 changes: 21 additions & 0 deletions posthog/api/test/__snapshots__/test_feature_flag.ambr
Original file line number Diff line number Diff line change
Expand Up @@ -1500,6 +1500,27 @@
LIMIT 21
'''
# ---
# name: TestFeatureFlag.test_cant_create_flag_with_data_that_fails_to_query
'''
SELECT (("posthog_person"."properties" ->> 'email')::text ~ '2.3.9{0-9}{1}'
AND "posthog_person"."properties" ? 'email'
AND NOT (("posthog_person"."properties" -> 'email') = 'null'::jsonb)) AS "flag_X_condition_0"
FROM "posthog_person"
WHERE "posthog_person"."team_id" = 2
LIMIT 10
'''
# ---
# name: TestFeatureFlag.test_cant_create_flag_with_group_data_that_fails_to_query
'''
SELECT (("posthog_group"."group_properties" ->> 'email')::text ~ '2.3.9{0-9}{1 ef}'
AND "posthog_group"."group_properties" ? 'email'
AND NOT (("posthog_group"."group_properties" -> 'email') = 'null'::jsonb)) AS "flag_X_condition_0"
FROM "posthog_group"
WHERE ("posthog_group"."group_type_index" = 2
AND "posthog_group"."team_id" = 2)
LIMIT 10
'''
# ---
# name: TestFeatureFlag.test_creating_static_cohort
'''
SELECT "posthog_user"."id",
Expand Down
97 changes: 97 additions & 0 deletions posthog/api/test/test_feature_flag.py
Original file line number Diff line number Diff line change
Expand Up @@ -4051,6 +4051,103 @@ def test_cant_update_early_access_flag_with_group(self):
response.json(),
)

def test_cant_create_flag_with_data_that_fails_to_query(self):
    """A person flag whose regex the database rejects must be refused with a 400."""
    # Two persons with an "email" property so the regex actually gets evaluated.
    for distinct_id, email in (("123", "x y z"), ("456", "2.3.999")):
        Person.objects.create(
            distinct_ids=[distinct_id],
            team=self.team,
            properties={"email": email},
        )

    regex_property = {
        "key": "email",
        "type": "person",
        "value": "2.3.9{0-9}{1}",
        "operator": "regex",
    }
    flag_payload = {
        "name": "Beta feature",
        "key": "beta-x",
        "filters": {
            "groups": [
                {
                    "rollout_percentage": 65,
                    "properties": [regex_property],
                }
            ]
        },
    }
    expected_error = {
        "type": "validation_error",
        "code": "invalid_input",
        "detail": "Can't evaluate flag: invalid regular expression: invalid repetition count(s)",
        "attr": None,
    }

    # Only snapshot flag evaluation queries
    with snapshot_postgres_queries_context(self, custom_query_matcher=lambda sql: "posthog_person" in sql):
        response = self.client.post(f"/api/projects/{self.team.id}/feature_flags", flag_payload)
        self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
        self.assertEqual(response.json(), expected_error)

def test_cant_create_flag_with_group_data_that_fails_to_query(self):
    """A group flag whose regex the database rejects must be refused with a 400."""
    for type_index, type_name in enumerate(("organization", "xyz")):
        GroupTypeMapping.objects.create(team=self.team, group_type=type_name, group_type_index=type_index)

    # Five groups of type index 1, each carrying an "email" property for the regex to run against.
    for i in range(5):
        create_group(
            team_id=self.team.pk,
            group_type_index=1,
            group_key=f"xyz:{i}",
            properties={"industry": f"{i}", "email": "2.3.4445"},
        )

    regex_property = {
        "key": "email",
        "type": "group",
        "group_type_index": 1,
        "value": "2.3.9{0-9}{1 ef}",
        "operator": "regex",
    }
    flag_payload = {
        "name": "Beta feature",
        "key": "beta-x",
        "filters": {
            "aggregation_group_type_index": 1,
            "groups": [
                {
                    "rollout_percentage": 65,
                    "properties": [regex_property],
                }
            ],
        },
    }
    expected_error = {
        "type": "validation_error",
        "code": "invalid_input",
        "detail": "Can't evaluate flag: invalid regular expression: invalid repetition count(s)",
        "attr": None,
    }

    # Only snapshot flag evaluation queries
    with snapshot_postgres_queries_context(self, custom_query_matcher=lambda sql: "posthog_group" in sql):
        # Test group flag with invalid regex
        response = self.client.post(f"/api/projects/{self.team.id}/feature_flags", flag_payload)
        self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
        self.assertEqual(response.json(), expected_error)


class TestCohortGenerationForFeatureFlag(APIBaseTest, ClickhouseTestMixin):
def test_creating_static_cohort_with_deleted_flag(self):
Expand Down
42 changes: 40 additions & 2 deletions posthog/models/feature_flag/flag_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from prometheus_client import Counter
from django.conf import settings
from django.db import DatabaseError, IntegrityError
from django.db import DatabaseError, IntegrityError, DataError
from django.db.models.expressions import ExpressionWrapper, RawSQL
from django.db.models.fields import BooleanField
from django.db.models import Q, Func, F, CharField
Expand Down Expand Up @@ -981,7 +981,9 @@ def handle_feature_flag_exception(err: Exception, log_message: str = "", set_hea
if reason == "unknown":
capture_exception(err)

if isinstance(err, DatabaseError) and set_healthcheck:
# DataErrors are generally not because the db is down, but because of bad data.
# We don't want to set the healthcheck down for bad data.
if not isinstance(err, DataError) and isinstance(err, DatabaseError) and set_healthcheck:
postgres_healthcheck.set_connection(False)


Expand Down Expand Up @@ -1061,3 +1063,39 @@ def check_pure_is_not_operator_condition(condition: dict) -> bool:
if properties and all(prop.get("operator") in ("is_not_set", "is_not") for prop in properties):
return True
return False


def check_flag_evaluation_query_is_ok(feature_flag: FeatureFlag, team_id: int) -> bool:
    """Run a small sample query shaped like the flag's evaluation query.

    TRICKY: There are some cases where a regex is valid re2 syntax, but
    PostgreSQL doesn't like it. Executing the query here surfaces such errors
    as exceptions. It does not catch every case (for example, when the property
    doesn't exist on any person, regex evaluation is short-circuited), so a
    clean run is not a guarantee that the real query will work.

    Args:
        feature_flag: flag whose ``conditions`` and
            ``aggregation_group_type_index`` are read.
        team_id: team whose persons (or groups) are queried.

    Returns:
        True if the sample query returned at least one row, False otherwise.
    """
    group_type_index = feature_flag.aggregation_group_type_index

    # Person flags are checked against persons, group flags against the matching group type.
    if group_type_index is None:
        sample_query: QuerySet = Person.objects.filter(team_id=team_id)
    else:
        sample_query = Group.objects.filter(team_id=team_id, group_type_index=group_type_index)

    annotated_keys = []
    for index, condition in enumerate(feature_flag.conditions):
        key = f"flag_0_condition_{index}"
        condition_properties = Filter(data=condition).property_groups.flat
        condition_expr = properties_to_Q(team_id, condition_properties)
        # A condition that yields no expression is annotated as a constant SQL true.
        wrapped = ExpressionWrapper(
            condition_expr or RawSQL("true", []),
            output_field=BooleanField(),
        )
        sample_query = sample_query.annotate(**{key: wrapped})
        annotated_keys.append(key)

    # len() forces the sliced queryset to execute, which is the whole point of this check.
    sampled_rows = sample_query.values(*annotated_keys)[:10]
    return len(sampled_rows) > 0
7 changes: 6 additions & 1 deletion posthog/test/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from contextlib import contextmanager
from functools import wraps
from typing import Any, Optional, Union
from collections.abc import Callable
from collections.abc import Generator
from unittest.mock import patch

Expand Down Expand Up @@ -627,6 +628,7 @@ def snapshot_postgres_queries_context(
replace_all_numbers: bool = True,
using: str = "default",
capture_all_queries: bool = False,
custom_query_matcher: Optional[Callable] = None,
):
"""
Captures and snapshots select queries from test using `syrupy` library.
Expand Down Expand Up @@ -659,7 +661,10 @@ def test_something(self):

for query_with_time in context.captured_queries:
query = query_with_time["sql"]
if capture_all_queries:
if custom_query_matcher:
if query and custom_query_matcher(query):
testcase.assertQueryMatchesSnapshot(query, replace_all_numbers=replace_all_numbers)
elif capture_all_queries:
testcase.assertQueryMatchesSnapshot(query, replace_all_numbers=replace_all_numbers)
elif query and "SELECT" in query and "django_session" not in query and not re.match(r"^\s*INSERT", query):
testcase.assertQueryMatchesSnapshot(query, replace_all_numbers=replace_all_numbers)
Expand Down
62 changes: 62 additions & 0 deletions posthog/test/test_feature_flag.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from freezegun import freeze_time
import pytest

from posthog.api.test.test_feature_flag import QueryTimeoutWrapper
from posthog.models import Cohort, FeatureFlag, GroupTypeMapping, Person
from posthog.models.feature_flag import get_feature_flags_for_team_in_cache
from posthog.models.feature_flag.flag_matching import (
Expand Down Expand Up @@ -4406,6 +4407,67 @@ def test_invalid_group_filters_dont_set_db_down(self, mock_database_healthcheck)
self.assertEqual(matcher.failed_to_fetch_conditions, False)
mock_database_healthcheck.set_connection.assert_not_called()

@patch("posthog.models.feature_flag.flag_matching.postgres_healthcheck")
def test_data_errors_dont_set_db_down(self, mock_database_healthcheck):
    """Bad input data must not mark the database as down; an operational error must."""
    unconditional_flag: FeatureFlag = FeatureFlag.objects.create(
        team=self.team,
        created_by=self.user,
        active=True,
        key="active-flag",
        filters={"groups": [{"properties": [], "rollout_percentage": 100}]},
    )
    property_condition = {
        "properties": [{"key": "tear", "value": "tear", "type": "person"}],
        "rollout_percentage": 100,
    }
    property_flag: FeatureFlag = FeatureFlag.objects.create(
        team=self.team,
        created_by=self.user,
        active=True,
        key="other-flag",
        filters={"groups": [property_condition]},
    )
    flags = [unconditional_flag, property_flag]

    # Both scenarios expect the same result: only the property-less flag matches.
    expected_matches = (
        {"active-flag": True},
        {
            "active-flag": {
                "condition_index": 0,
                "reason": FeatureFlagMatchReason.CONDITION_MATCH,
            }
        },
        {},
        True,
    )

    # The distinct id carries a NUL character.
    matcher = FeatureFlagMatcher(flags, "bxss.me/t/xss.html?\x00")

    self.assertEqual(matcher.get_matches(), expected_matches)
    self.assertEqual(matcher.failed_to_fetch_conditions, True)
    mock_database_healthcheck.set_connection.assert_not_called()

    # with operational error, should set db down
    with connection.execute_wrapper(QueryTimeoutWrapper()):
        matcher = FeatureFlagMatcher(flags, "bxss.me/t/xss.html")

        self.assertEqual(matcher.get_matches(), expected_matches)
        self.assertEqual(matcher.failed_to_fetch_conditions, True)
        mock_database_healthcheck.set_connection.assert_called_once_with(False)

def test_legacy_rollout_percentage(self):
feature_flag = self.create_feature_flag(rollout_percentage=50)
self.assertEqual(
Expand Down

0 comments on commit df9fd36

Please sign in to comment.