From df9fd36ab0ec888caf4e9e1197c403ff572318e2 Mon Sep 17 00:00:00 2001
From: Neil Kakkar
Date: Fri, 26 Apr 2024 13:33:30 +0100
Subject: [PATCH] fix(flags): Defend against invalid data on creation

Run a rough simulation of the flag evaluation query when a feature flag is
created, and reject the flag with a validation error if that query raises
(for example a regex that is valid re2 syntax but that PostgreSQL rejects).

Also stop marking the Postgres healthcheck as down when flag evaluation
fails with a DataError, and let snapshot_postgres_queries_context take a
custom_query_matcher to choose which queries get snapshotted.

---
 posthog/api/feature_flag.py                   | 20 ++++
 .../test/__snapshots__/test_feature_flag.ambr | 21 ++++
 posthog/api/test/test_feature_flag.py         | 97 +++++++++++++++++++
 posthog/models/feature_flag/flag_matching.py  | 42 +++++++-
 posthog/test/base.py                          |  7 +-
 posthog/test/test_feature_flag.py             | 62 ++++++++++++
 6 files changed, 246 insertions(+), 3 deletions(-)

diff --git a/posthog/api/feature_flag.py b/posthog/api/feature_flag.py
index bd53f02955252..40c0632a1b275 100644
--- a/posthog/api/feature_flag.py
+++ b/posthog/api/feature_flag.py
@@ -46,6 +46,7 @@
     get_user_blast_radius,
 )
 from posthog.models.feature_flag.flag_analytics import increment_request_count
+from posthog.models.feature_flag.flag_matching import check_flag_evaluation_query_is_ok
 from posthog.models.feedback.survey import Survey
 from posthog.models.group_type_mapping import GroupTypeMapping
 from posthog.models.property import Property
@@ -263,6 +264,22 @@ def properties_all_match(predicate):
 
         return filters
 
+    def check_flag_evaluation(self, data):
+        # This is a very rough simulation of the actual query that will be run.
+        # The only reason we do it this way is to catch DB-level errors that would break at runtime
+        # but aren't caught by the validation above, like a regex valid according to re2 but not postgresql.
+        # The check queries up to 10 persons (or groups), without filtering by distinct id, to make sure the query is valid.
+
+        # TODO: Once we move to no DB level evaluation, can get rid of this.
+
+        temporary_flag = FeatureFlag(**data)
+        team_id = self.context["team_id"]
+
+        try:
+            check_flag_evaluation_query_is_ok(temporary_flag, team_id)
+        except Exception as e:
+            raise serializers.ValidationError("Can't evaluate flag: " + str(e)) from e
+
     def create(self, validated_data: dict, *args: Any, **kwargs: Any) -> FeatureFlag:
         request = self.context["request"]
         validated_data["created_by"] = request.user
@@ -289,6 +306,9 @@ def create(self, validated_data: dict, *args: Any, **kwargs: Any) -> FeatureFlag
             raise exceptions.ValidationError(
                 "Feature flag with this key already exists and is used in an experiment. Please delete the experiment before deleting the flag."
             )
+
+        self.check_flag_evaluation(validated_data)
+
         instance: FeatureFlag = super().create(validated_data)
 
         self._attempt_set_tags(tags, instance)
diff --git a/posthog/api/test/__snapshots__/test_feature_flag.ambr b/posthog/api/test/__snapshots__/test_feature_flag.ambr
index 26c0b9679f69d..518781ea5ad77 100644
--- a/posthog/api/test/__snapshots__/test_feature_flag.ambr
+++ b/posthog/api/test/__snapshots__/test_feature_flag.ambr
@@ -1500,6 +1500,27 @@
   LIMIT 21
   '''
 # ---
+# name: TestFeatureFlag.test_cant_create_flag_with_data_that_fails_to_query
+  '''
+  SELECT (("posthog_person"."properties" ->> 'email')::text ~ '2.3.9{0-9}{1}'
+          AND "posthog_person"."properties" ? 'email'
+          AND NOT (("posthog_person"."properties" -> 'email') = 'null'::jsonb)) AS "flag_X_condition_0"
+  FROM "posthog_person"
+  WHERE "posthog_person"."team_id" = 2
+  LIMIT 10
+  '''
+# ---
+# name: TestFeatureFlag.test_cant_create_flag_with_group_data_that_fails_to_query
+  '''
+  SELECT (("posthog_group"."group_properties" ->> 'email')::text ~ '2.3.9{0-9}{1 ef}'
+          AND "posthog_group"."group_properties" ? 'email'
+          AND NOT (("posthog_group"."group_properties" -> 'email') = 'null'::jsonb)) AS "flag_X_condition_0"
+  FROM "posthog_group"
+  WHERE ("posthog_group"."group_type_index" = 2
+         AND "posthog_group"."team_id" = 2)
+  LIMIT 10
+  '''
+# ---
 # name: TestFeatureFlag.test_creating_static_cohort
   '''
   SELECT "posthog_user"."id",
diff --git a/posthog/api/test/test_feature_flag.py b/posthog/api/test/test_feature_flag.py
index 4c353b98124df..b9186fdb9b790 100644
--- a/posthog/api/test/test_feature_flag.py
+++ b/posthog/api/test/test_feature_flag.py
@@ -4051,6 +4051,103 @@ def test_cant_update_early_access_flag_with_group(self):
             response.json(),
         )
 
+    def test_cant_create_flag_with_data_that_fails_to_query(self):
+        Person.objects.create(
+            distinct_ids=["123"],
+            team=self.team,
+            properties={"email": "x y z"},
+        )
+        Person.objects.create(
+            distinct_ids=["456"],
+            team=self.team,
+            properties={"email": "2.3.999"},
+        )
+
+        # Only snapshot flag evaluation queries
+        with snapshot_postgres_queries_context(self, custom_query_matcher=lambda query: "posthog_person" in query):
+            response = self.client.post(
+                f"/api/projects/{self.team.id}/feature_flags",
+                {
+                    "name": "Beta feature",
+                    "key": "beta-x",
+                    "filters": {
+                        "groups": [
+                            {
+                                "rollout_percentage": 65,
+                                "properties": [
+                                    {
+                                        "key": "email",
+                                        "type": "person",
+                                        "value": "2.3.9{0-9}{1}",
+                                        "operator": "regex",
+                                    }
+                                ],
+                            }
+                        ]
+                    },
+                },
+            )
+            self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
+            self.assertEqual(
+                response.json(),
+                {
+                    "type": "validation_error",
+                    "code": "invalid_input",
+                    "detail": "Can't evaluate flag: invalid regular expression: invalid repetition count(s)",
+                    "attr": None,
+                },
+            )
+
+    def test_cant_create_flag_with_group_data_that_fails_to_query(self):
+        GroupTypeMapping.objects.create(team=self.team, group_type="organization", group_type_index=0)
+        GroupTypeMapping.objects.create(team=self.team, group_type="xyz", group_type_index=1)
+
+        for i in range(5):
+            create_group(
+                team_id=self.team.pk,
+                group_type_index=1,
+                group_key=f"xyz:{i}",
+                properties={"industry": f"{i}", "email": "2.3.4445"},
+            )
+
+        # Only snapshot flag evaluation queries
+        with snapshot_postgres_queries_context(self, custom_query_matcher=lambda query: "posthog_group" in query):
+            # Test group flag with invalid regex
+            response = self.client.post(
+                f"/api/projects/{self.team.id}/feature_flags",
+                {
+                    "name": "Beta feature",
+                    "key": "beta-x",
+                    "filters": {
+                        "aggregation_group_type_index": 1,
+                        "groups": [
+                            {
+                                "rollout_percentage": 65,
+                                "properties": [
+                                    {
+                                        "key": "email",
+                                        "type": "group",
+                                        "group_type_index": 1,
+                                        "value": "2.3.9{0-9}{1 ef}",
+                                        "operator": "regex",
+                                    }
+                                ],
+                            }
+                        ],
+                    },
+                },
+            )
+            self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
+            self.assertEqual(
+                response.json(),
+                {
+                    "type": "validation_error",
+                    "code": "invalid_input",
+                    "detail": "Can't evaluate flag: invalid regular expression: invalid repetition count(s)",
+                    "attr": None,
+                },
+            )
+
 
 class TestCohortGenerationForFeatureFlag(APIBaseTest, ClickhouseTestMixin):
     def test_creating_static_cohort_with_deleted_flag(self):
diff --git a/posthog/models/feature_flag/flag_matching.py b/posthog/models/feature_flag/flag_matching.py
index 0b4a6befebc94..890a61a96974d 100644
--- a/posthog/models/feature_flag/flag_matching.py
+++ b/posthog/models/feature_flag/flag_matching.py
@@ -7,7 +7,7 @@
 from prometheus_client import Counter
 
 from django.conf import settings
-from django.db import DatabaseError, IntegrityError
+from django.db import DatabaseError, IntegrityError, DataError
 from django.db.models.expressions import ExpressionWrapper, RawSQL
 from django.db.models.fields import BooleanField
 from django.db.models import Q, Func, F, CharField
@@ -981,7 +981,9 @@ def handle_feature_flag_exception(err: Exception, log_message: str = "", set_hea
     if reason == "unknown":
         capture_exception(err)
 
-    if isinstance(err, DatabaseError) and set_healthcheck:
+    # DataErrors are generally not because the db is down, but because of bad data.
+    # We don't want to set the healthcheck down for bad data.
+    if not isinstance(err, DataError) and isinstance(err, DatabaseError) and set_healthcheck:
         postgres_healthcheck.set_connection(False)
 
 
@@ -1061,3 +1063,39 @@ def check_pure_is_not_operator_condition(condition: dict) -> bool:
     if properties and all(prop.get("operator") in ("is_not_set", "is_not") for prop in properties):
         return True
     return False
+
+
+def check_flag_evaluation_query_is_ok(feature_flag: FeatureFlag, team_id: int) -> bool:
+    # TRICKY: There are some cases where the regex is valid re2 syntax, but postgresql doesn't like it.
+    # This function tries to validate such cases.
+    # It however doesn't catch all cases, like when the property doesn't exist on any person, which shortcircuits regex evaluation
+    # so it's not a guarantee that the query will work.
+
+    group_type_index = feature_flag.aggregation_group_type_index
+
+    base_query: QuerySet = (
+        Person.objects.filter(team_id=team_id)
+        if group_type_index is None
+        else Group.objects.filter(team_id=team_id, group_type_index=group_type_index)
+    )
+    query_fields = []
+
+    for index, condition in enumerate(feature_flag.conditions):
+        key = f"flag_0_condition_{index}"
+        property_list = Filter(data=condition).property_groups.flat
+        expr = properties_to_Q(
+            team_id,
+            property_list,
+        )
+        base_query = base_query.annotate(
+            **{
+                key: ExpressionWrapper(
+                    expr if expr else RawSQL("true", []),
+                    output_field=BooleanField(),
+                ),
+            }
+        )
+        query_fields.append(key)
+
+    values = base_query.values(*query_fields)[:10]
+    return len(values) > 0
diff --git a/posthog/test/base.py b/posthog/test/base.py
index 2ebfa6178e259..1deaf261a9195 100644
--- a/posthog/test/base.py
+++ b/posthog/test/base.py
@@ -8,6 +8,7 @@
 from contextlib import contextmanager
 from functools import wraps
 from typing import Any, Optional, Union
+from collections.abc import Callable
 from collections.abc import Generator
 from unittest.mock import patch
 
@@ -627,6 +628,7 @@ def snapshot_postgres_queries_context(
     replace_all_numbers: bool = True,
     using: str = "default",
     capture_all_queries: bool = False,
+    custom_query_matcher: Optional[Callable[[str], bool]] = None,
 ):
     """
     Captures and snapshots select queries from test using `syrupy` library.
@@ -659,7 +661,10 @@ def test_something(self):
 
     for query_with_time in context.captured_queries:
         query = query_with_time["sql"]
-        if capture_all_queries:
+        if custom_query_matcher:
+            if query and custom_query_matcher(query):
+                testcase.assertQueryMatchesSnapshot(query, replace_all_numbers=replace_all_numbers)
+        elif capture_all_queries:
             testcase.assertQueryMatchesSnapshot(query, replace_all_numbers=replace_all_numbers)
         elif query and "SELECT" in query and "django_session" not in query and not re.match(r"^\s*INSERT", query):
             testcase.assertQueryMatchesSnapshot(query, replace_all_numbers=replace_all_numbers)
diff --git a/posthog/test/test_feature_flag.py b/posthog/test/test_feature_flag.py
index 91db555b31b9e..b04b4584208f5 100644
--- a/posthog/test/test_feature_flag.py
+++ b/posthog/test/test_feature_flag.py
@@ -10,6 +10,7 @@
 from freezegun import freeze_time
 import pytest
 
+from posthog.api.test.test_feature_flag import QueryTimeoutWrapper
 from posthog.models import Cohort, FeatureFlag, GroupTypeMapping, Person
 from posthog.models.feature_flag import get_feature_flags_for_team_in_cache
 from posthog.models.feature_flag.flag_matching import (
@@ -4406,6 +4407,67 @@ def test_invalid_group_filters_dont_set_db_down(self, mock_database_healthcheck)
         self.assertEqual(matcher.failed_to_fetch_conditions, False)
         mock_database_healthcheck.set_connection.assert_not_called()
 
+    @patch("posthog.models.feature_flag.flag_matching.postgres_healthcheck")
+    def test_data_errors_dont_set_db_down(self, mock_database_healthcheck):
+        flag: FeatureFlag = FeatureFlag.objects.create(
+            team=self.team,
+            created_by=self.user,
+            active=True,
+            key="active-flag",
+            filters={"groups": [{"properties": [], "rollout_percentage": 100}]},
+        )
+        flag2: FeatureFlag = FeatureFlag.objects.create(
+            team=self.team,
+            created_by=self.user,
+            active=True,
+            key="other-flag",
+            filters={
+                "groups": [
+                    {"properties": [{"key": "tear", "value": "tear", "type": "person"}], "rollout_percentage": 100}
+                ],
+            },
+        )
+
+        matcher = FeatureFlagMatcher([flag, flag2], "bxss.me/t/xss.html?\x00")
+
+        self.assertEqual(
+            matcher.get_matches(),
+            (
+                {"active-flag": True},
+                {
+                    "active-flag": {
+                        "condition_index": 0,
+                        "reason": FeatureFlagMatchReason.CONDITION_MATCH,
+                    }
+                },
+                {},
+                True,
+            ),
+        )
+        self.assertEqual(matcher.failed_to_fetch_conditions, True)
+        mock_database_healthcheck.set_connection.assert_not_called()
+
+        # with operational error, should set db down
+        with connection.execute_wrapper(QueryTimeoutWrapper()):
+            matcher = FeatureFlagMatcher([flag, flag2], "bxss.me/t/xss.html")
+
+            self.assertEqual(
+                matcher.get_matches(),
+                (
+                    {"active-flag": True},
+                    {
+                        "active-flag": {
+                            "condition_index": 0,
+                            "reason": FeatureFlagMatchReason.CONDITION_MATCH,
+                        }
+                    },
+                    {},
+                    True,
+                ),
+            )
+            self.assertEqual(matcher.failed_to_fetch_conditions, True)
+            mock_database_healthcheck.set_connection.assert_called_once_with(False)
+
     def test_legacy_rollout_percentage(self):
         feature_flag = self.create_feature_flag(rollout_percentage=50)
         self.assertEqual(