Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(flags): Defend against invalid data on creation #21893

Merged
merged 5 commits into from
Apr 29, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions posthog/api/feature_flag.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
get_user_blast_radius,
)
from posthog.models.feature_flag.flag_analytics import increment_request_count
from posthog.models.feature_flag.flag_matching import check_flag_evaluation_query_is_ok
from posthog.models.feedback.survey import Survey
from posthog.models.group_type_mapping import GroupTypeMapping
from posthog.models.property import Property
Expand Down Expand Up @@ -263,6 +264,17 @@

return filters

def check_flag_evaluation(self, data):
    """Reject flag data whose evaluation query fails when run against the database.

    Builds an unsaved ``FeatureFlag`` from ``data`` and runs
    ``check_flag_evaluation_query_is_ok`` for the team found in the
    serializer context (``self.context["team_id"]``).

    Raises:
        serializers.ValidationError: if the check raises anything; the
            original exception is attached as ``__cause__``.
    """
    # TODO: Once we move to no DB level evaluation, can get rid of this.

    temporary_flag = FeatureFlag(**data)
    team_id = self.context["team_id"]

    try:
        check_flag_evaluation_query_is_ok(temporary_flag, team_id)
    except Exception as e:
        # Chain explicitly so the underlying error stays visible in tracebacks
        # and error reporting; the user-facing message text is unchanged.
        raise serializers.ValidationError("Can't evaluate flag: " + str(e)) from e
Fixed Show fixed Hide fixed

def create(self, validated_data: dict, *args: Any, **kwargs: Any) -> FeatureFlag:
request = self.context["request"]
validated_data["created_by"] = request.user
Expand All @@ -289,6 +301,9 @@
raise exceptions.ValidationError(
"Feature flag with this key already exists and is used in an experiment. Please delete the experiment before deleting the flag."
)

self.check_flag_evaluation(validated_data)

instance: FeatureFlag = super().create(validated_data)

self._attempt_set_tags(tags, instance)
Expand Down
21 changes: 21 additions & 0 deletions posthog/api/test/__snapshots__/test_feature_flag.ambr
Original file line number Diff line number Diff line change
Expand Up @@ -1500,6 +1500,27 @@
LIMIT 21
'''
# ---
# name: TestFeatureFlag.test_cant_create_flag_with_data_that_fails_to_query
'''
SELECT (("posthog_person"."properties" ->> 'email')::text ~ '2.3.9{0-9}{1}'
AND "posthog_person"."properties" ? 'email'
AND NOT (("posthog_person"."properties" -> 'email') = 'null'::jsonb)) AS "flag_X_condition_0"
FROM "posthog_person"
WHERE "posthog_person"."team_id" = 2
LIMIT 10
'''
# ---
# name: TestFeatureFlag.test_cant_create_flag_with_group_data_that_fails_to_query
'''
SELECT (("posthog_group"."group_properties" ->> 'email')::text ~ '2.3.9{0-9}{1 ef}'
AND "posthog_group"."group_properties" ? 'email'
AND NOT (("posthog_group"."group_properties" -> 'email') = 'null'::jsonb)) AS "flag_X_condition_0"
FROM "posthog_group"
WHERE ("posthog_group"."group_type_index" = 2
AND "posthog_group"."team_id" = 2)
LIMIT 10
'''
# ---
# name: TestFeatureFlag.test_creating_static_cohort
'''
SELECT "posthog_user"."id",
Expand Down
97 changes: 97 additions & 0 deletions posthog/api/test/test_feature_flag.py
Original file line number Diff line number Diff line change
Expand Up @@ -4051,6 +4051,103 @@ def test_cant_update_early_access_flag_with_group(self):
response.json(),
)

def test_cant_create_flag_with_data_that_fails_to_query(self):
    """Creating a person flag whose regex fails in the evaluation query returns a 400."""
    for distinct_id, email in (("123", "x y z"), ("456", "2.3.999")):
        Person.objects.create(
            distinct_ids=[distinct_id],
            team=self.team,
            properties={"email": email},
        )

    invalid_regex_property = {
        "key": "email",
        "type": "person",
        "value": "2.3.9{0-9}{1}",
        "operator": "regex",
    }
    payload = {
        "name": "Beta feature",
        "key": "beta-x",
        "filters": {
            "groups": [
                {
                    "rollout_percentage": 65,
                    "properties": [invalid_regex_property],
                }
            ]
        },
    }
    expected_error = {
        "type": "validation_error",
        "code": "invalid_input",
        "detail": "Can't evaluate flag: invalid regular expression: invalid repetition count(s)",
        "attr": None,
    }

    def touches_person_table(query):
        return "posthog_person" in query

    # Only snapshot flag evaluation queries
    with snapshot_postgres_queries_context(self, custom_query_matcher=touches_person_table):
        response = self.client.post(f"/api/projects/{self.team.id}/feature_flags", payload)
        self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
        self.assertEqual(response.json(), expected_error)

def test_cant_create_flag_with_group_data_that_fails_to_query(self):
    """Creating a group flag whose regex fails in the evaluation query returns a 400."""
    for group_type, type_index in (("organization", 0), ("xyz", 1)):
        GroupTypeMapping.objects.create(team=self.team, group_type=group_type, group_type_index=type_index)

    for i in range(5):
        create_group(
            team_id=self.team.pk,
            group_type_index=1,
            group_key=f"xyz:{i}",
            properties={"industry": f"{i}", "email": "2.3.4445"},
        )

    invalid_regex_property = {
        "key": "email",
        "type": "group",
        "group_type_index": 1,
        "value": "2.3.9{0-9}{1 ef}",
        "operator": "regex",
    }
    payload = {
        "name": "Beta feature",
        "key": "beta-x",
        "filters": {
            "aggregation_group_type_index": 1,
            "groups": [
                {
                    "rollout_percentage": 65,
                    "properties": [invalid_regex_property],
                }
            ],
        },
    }
    expected_error = {
        "type": "validation_error",
        "code": "invalid_input",
        "detail": "Can't evaluate flag: invalid regular expression: invalid repetition count(s)",
        "attr": None,
    }

    def touches_group_table(query):
        return "posthog_group" in query

    # Only snapshot flag evaluation queries
    with snapshot_postgres_queries_context(self, custom_query_matcher=touches_group_table):
        # Test group flag with invalid regex
        response = self.client.post(f"/api/projects/{self.team.id}/feature_flags", payload)
        self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
        self.assertEqual(response.json(), expected_error)


class TestCohortGenerationForFeatureFlag(APIBaseTest, ClickhouseTestMixin):
def test_creating_static_cohort_with_deleted_flag(self):
Expand Down
49 changes: 47 additions & 2 deletions posthog/models/feature_flag/flag_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from prometheus_client import Counter
from django.conf import settings
from django.db import DatabaseError, IntegrityError
from django.db import DatabaseError, IntegrityError, DataError
from django.db.models.expressions import ExpressionWrapper, RawSQL
from django.db.models.fields import BooleanField
from django.db.models import Q, Func, F, CharField
Expand Down Expand Up @@ -981,7 +981,9 @@ def handle_feature_flag_exception(err: Exception, log_message: str = "", set_hea
if reason == "unknown":
capture_exception(err)

if isinstance(err, DatabaseError) and set_healthcheck:
# DataErrors are generally not because the db is down, but because of bad data.
# We don't want to set the healthcheck down for bad data.
if not isinstance(err, DataError) and isinstance(err, DatabaseError) and set_healthcheck:
postgres_healthcheck.set_connection(False)


Expand Down Expand Up @@ -1061,3 +1063,46 @@ def check_pure_is_not_operator_condition(condition: dict) -> bool:
if properties and all(prop.get("operator") in ("is_not_set", "is_not") for prop in properties):
return True
return False


def check_flag_evaluation_query_is_ok(feature_flag: FeatureFlag, team_id: int) -> bool:
    """Run a rough stand-in for the flag evaluation query to surface DB-level errors.

    TRICKY: Some regexes are valid re2 syntax but postgresql doesn't like them.
    This function tries to catch such cases before they bork at runtime; see
    `test_cant_create_flag_with_data_that_fails_to_query` for an example.
    It is not a guarantee that the real query will work: for instance, when
    the property doesn't exist on any person, regex evaluation is
    short-circuited and the bad pattern goes unnoticed.

    Each flag condition is annotated as a boolean column onto a team-scoped
    Person queryset (or a Group queryset when the flag has an
    `aggregation_group_type_index`), and at most 10 rows are fetched.

    Returns:
        True if the sample query returned at least one row, False otherwise.
        Any error raised while running the query propagates to the caller.
    """
    # TODO: Once we move to no DB level evaluation, can get rid of this.

    group_type_index = feature_flag.aggregation_group_type_index

    sample_query: QuerySet
    if group_type_index is None:
        sample_query = Person.objects.filter(team_id=team_id)
    else:
        sample_query = Group.objects.filter(team_id=team_id, group_type_index=group_type_index)

    annotation_names = []
    for position, condition in enumerate(feature_flag.conditions):
        annotation_name = f"flag_0_condition_{position}"
        flat_properties = Filter(data=condition).property_groups.flat
        condition_q = properties_to_Q(team_id, flat_properties)

        # A falsy filter (presumably a condition without properties - verify)
        # is replaced by a literal SQL `true` so the annotation stays valid.
        if condition_q:
            selectable = condition_q
        else:
            selectable = RawSQL("true", [])

        wrapped = ExpressionWrapper(selectable, output_field=BooleanField())
        sample_query = sample_query.annotate(**{annotation_name: wrapped})
        annotation_names.append(annotation_name)

    # len() forces the queryset to execute, which is where DB errors surface.
    sample_rows = sample_query.values(*annotation_names)[:10]
    return len(sample_rows) > 0
7 changes: 6 additions & 1 deletion posthog/test/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from contextlib import contextmanager
from functools import wraps
from typing import Any, Optional, Union
from collections.abc import Callable
from collections.abc import Generator
from unittest.mock import patch

Expand Down Expand Up @@ -627,6 +628,7 @@ def snapshot_postgres_queries_context(
replace_all_numbers: bool = True,
using: str = "default",
capture_all_queries: bool = False,
custom_query_matcher: Optional[Callable] = None,
):
"""
Captures and snapshots select queries from test using `syrupy` library.
Expand Down Expand Up @@ -659,7 +661,10 @@ def test_something(self):

for query_with_time in context.captured_queries:
query = query_with_time["sql"]
if capture_all_queries:
if custom_query_matcher:
if query and custom_query_matcher(query):
testcase.assertQueryMatchesSnapshot(query, replace_all_numbers=replace_all_numbers)
elif capture_all_queries:
testcase.assertQueryMatchesSnapshot(query, replace_all_numbers=replace_all_numbers)
elif query and "SELECT" in query and "django_session" not in query and not re.match(r"^\s*INSERT", query):
testcase.assertQueryMatchesSnapshot(query, replace_all_numbers=replace_all_numbers)
Expand Down
62 changes: 62 additions & 0 deletions posthog/test/test_feature_flag.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from freezegun import freeze_time
import pytest

from posthog.api.test.test_feature_flag import QueryTimeoutWrapper
from posthog.models import Cohort, FeatureFlag, GroupTypeMapping, Person
from posthog.models.feature_flag import get_feature_flags_for_team_in_cache
from posthog.models.feature_flag.flag_matching import (
Expand Down Expand Up @@ -4406,6 +4407,67 @@ def test_invalid_group_filters_dont_set_db_down(self, mock_database_healthcheck)
self.assertEqual(matcher.failed_to_fetch_conditions, False)
mock_database_healthcheck.set_connection.assert_not_called()

@patch("posthog.models.feature_flag.flag_matching.postgres_healthcheck")
def test_data_errors_dont_set_db_down(self, mock_database_healthcheck):
    """Bad input data must not flip the postgres healthcheck; a query timeout must.

    The first matcher uses a distinct id containing a NUL byte (presumably
    what triggers the data error - verify); the second runs under
    `QueryTimeoutWrapper` to simulate an operational failure.
    """
    always_on_flag: FeatureFlag = FeatureFlag.objects.create(
        team=self.team,
        created_by=self.user,
        active=True,
        key="active-flag",
        filters={"groups": [{"properties": [], "rollout_percentage": 100}]},
    )
    person_property_flag: FeatureFlag = FeatureFlag.objects.create(
        team=self.team,
        created_by=self.user,
        active=True,
        key="other-flag",
        filters={
            "groups": [
                {"properties": [{"key": "tear", "value": "tear", "type": "person"}], "rollout_percentage": 100}
            ],
        },
    )
    flags = [always_on_flag, person_property_flag]

    # Both scenarios expect the same result: only the property-less flag
    # matches, and the final tuple element is True.
    expected_matches = (
        {"active-flag": True},
        {
            "active-flag": {
                "condition_index": 0,
                "reason": FeatureFlagMatchReason.CONDITION_MATCH,
            }
        },
        {},
        True,
    )

    matcher = FeatureFlagMatcher(flags, "bxss.me/t/xss.html?\x00")

    self.assertEqual(matcher.get_matches(), expected_matches)
    self.assertEqual(matcher.failed_to_fetch_conditions, True)
    mock_database_healthcheck.set_connection.assert_not_called()

    # with operational error, should set db down
    with connection.execute_wrapper(QueryTimeoutWrapper()):
        matcher = FeatureFlagMatcher(flags, "bxss.me/t/xss.html")

        self.assertEqual(matcher.get_matches(), expected_matches)
        self.assertEqual(matcher.failed_to_fetch_conditions, True)
        mock_database_healthcheck.set_connection.assert_called_once_with(False)

def test_legacy_rollout_percentage(self):
feature_flag = self.create_feature_flag(rollout_percentage=50)
self.assertEqual(
Expand Down
Loading