Skip to content

Commit

Permalink
fix(flags): Do the expected for numeric comparisons (#18359)
Browse files Browse the repository at this point in the history
  • Loading branch information
neilkakkar authored Nov 6, 2023
1 parent 11b3745 commit 3847340
Show file tree
Hide file tree
Showing 8 changed files with 693 additions and 53 deletions.
36 changes: 36 additions & 0 deletions ee/clickhouse/models/test/test_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -610,6 +610,42 @@ def test_numerical(self):
events = _filter_events(filter, self.team)
self.assertEqual(events[0]["id"], event1_uuid)

def test_numerical_person_properties(self):
_create_person(team_id=self.team.pk, distinct_ids=["p1"], properties={"$a_number": 4})
_create_person(team_id=self.team.pk, distinct_ids=["p2"], properties={"$a_number": 5})
_create_person(team_id=self.team.pk, distinct_ids=["p3"], properties={"$a_number": 6})

filter = Filter(
data={
"properties": [
{
"type": "person",
"key": "$a_number",
"value": 4,
"operator": "gt",
}
]
}
)
self.assertEqual(len(_filter_persons(filter, self.team)), 2)

filter = Filter(data={"properties": [{"type": "person", "key": "$a_number", "value": 5}]})
self.assertEqual(len(_filter_persons(filter, self.team)), 1)

filter = Filter(
data={
"properties": [
{
"type": "person",
"key": "$a_number",
"value": 6,
"operator": "lt",
}
]
}
)
self.assertEqual(len(_filter_persons(filter, self.team)), 2)

def test_contains(self):
_create_event(team=self.team, distinct_id="test", event="$pageview")
event2_uuid = _create_event(
Expand Down
43 changes: 38 additions & 5 deletions posthog/models/feature_flag/flag_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from django.db import DatabaseError, IntegrityError, OperationalError
from django.db.models.expressions import ExpressionWrapper, RawSQL
from django.db.models.fields import BooleanField
from django.db.models import Q
from django.db.models import Q, Func, F, CharField
from django.db.models.query import QuerySet
from sentry_sdk.api import capture_exception, start_span
from posthog.metrics import LABEL_TEAM_ID
Expand Down Expand Up @@ -396,6 +396,13 @@ def condition_eval(key, condition):
annotate_query = True
nonlocal person_query

property_list = Filter(data=condition).property_groups.flat
properties_with_math_operators = [
key_and_field_for_property(prop)
for prop in property_list
if prop.operator in ["gt", "lt", "gte", "lte"]
]

if len(condition.get("properties", {})) > 0:
# Feature Flags don't support OR filtering yet
target_properties = self.property_value_overrides
Expand All @@ -404,8 +411,9 @@ def condition_eval(key, condition):
self.cache.group_type_index_to_name[feature_flag.aggregation_group_type_index],
{},
)

expr = properties_to_Q(
Filter(data=condition).property_groups.flat,
property_list,
override_property_values=target_properties,
cohorts_cache=self.cohorts_cache,
using_database=DATABASE_FOR_FLAG_MATCHING,
Expand All @@ -428,13 +436,24 @@ def condition_eval(key, condition):

if annotate_query:
if feature_flag.aggregation_group_type_index is None:
# :TRICKY: Flag matching depends on type of property when doing >, <, >=, <= comparisons.
# This requires a generated field to query in Q objects, which sadly don't allow inlining fields,
# hence we need to annotate the query here, even though these annotations are used much deeper,
# in properties_to_q, in empty_or_null_with_value_q
# These need to come in before the expr so they're available to use inside the expr.
# Same holds for the group queries below.
type_property_annotations = {
prop_key: Func(F(prop_field), function="JSONB_TYPEOF", output_field=CharField())
for prop_key, prop_field in properties_with_math_operators
}
person_query = person_query.annotate(
**type_property_annotations,
**{
key: ExpressionWrapper(
expr if expr else RawSQL("true", []),
output_field=BooleanField(),
)
}
),
},
)
person_fields.append(key)
else:
Expand All @@ -445,13 +464,18 @@ def condition_eval(key, condition):
group_query,
group_fields,
) = group_query_per_group_type_mapping[feature_flag.aggregation_group_type_index]
type_property_annotations = {
prop_key: Func(F(prop_field), function="JSONB_TYPEOF", output_field=CharField())
for prop_key, prop_field in properties_with_math_operators
}
group_query = group_query.annotate(
**type_property_annotations,
**{
key: ExpressionWrapper(
expr if expr else RawSQL("true", []),
output_field=BooleanField(),
)
}
},
)
group_fields.append(key)
group_query_per_group_type_mapping[feature_flag.aggregation_group_type_index] = (
Expand Down Expand Up @@ -881,3 +905,12 @@ def parse_exception_for_error_message(err: Exception):
reason = "query_wait_timeout"

return reason


def key_and_field_for_property(property: Property) -> Tuple[str, str]:
column = "group_properties" if property.type == "group" else "properties"
key = property.key
return (
f"{column}_{key}_type",
f"{column}__{key}",
)
88 changes: 51 additions & 37 deletions posthog/models/filters/test/test_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import json
from typing import Any, Callable, Dict, List, Optional, cast

from django.db.models import Q
from django.db.models import Q, Func, F, CharField

from posthog.constants import FILTER_TEST_ACCOUNTS
from posthog.models import Cohort, Filter, Person, Team
Expand Down Expand Up @@ -219,42 +219,6 @@ def test_incomplete_data(self):
)
self.assertListEqual(filter.property_groups.values, [])

def test_numerical_person_properties(self):
person_factory(team_id=self.team.pk, distinct_ids=["p1"], properties={"$a_number": 4})
person_factory(team_id=self.team.pk, distinct_ids=["p2"], properties={"$a_number": 5})
person_factory(team_id=self.team.pk, distinct_ids=["p3"], properties={"$a_number": 6})

filter = Filter(
data={
"properties": [
{
"type": "person",
"key": "$a_number",
"value": 4,
"operator": "gt",
}
]
}
)
self.assertEqual(len(filter_persons(filter, self.team)), 2)

filter = Filter(data={"properties": [{"type": "person", "key": "$a_number", "value": 5}]})
self.assertEqual(len(filter_persons(filter, self.team)), 1)

filter = Filter(
data={
"properties": [
{
"type": "person",
"key": "$a_number",
"value": 6,
"operator": "lt",
}
]
}
)
self.assertEqual(len(filter_persons(filter, self.team)), 2)

def test_contains_persons(self):
person_factory(
team_id=self.team.pk,
Expand Down Expand Up @@ -819,6 +783,56 @@ def _filter_with_date_range(

return Filter(data=data)

def test_numerical_person_properties(self):
_create_person(team_id=self.team.pk, distinct_ids=["p1"], properties={"$a_number": 4})
_create_person(team_id=self.team.pk, distinct_ids=["p2"], properties={"$a_number": 5})
_create_person(team_id=self.team.pk, distinct_ids=["p3"], properties={"$a_number": 6})
_create_person(team_id=self.team.pk, distinct_ids=["p4"], properties={"$a_number": 14})

flush_persons_and_events()

def filter_persons_with_annotation(filter: Filter, team: Team):
persons = Person.objects.annotate(
**{
"properties_$a_number_type": Func(
F("properties__$a_number"), function="JSONB_TYPEOF", output_field=CharField()
)
}
).filter(properties_to_Q(filter.property_groups.flat))
persons = persons.filter(team_id=team.pk)
return [str(uuid) for uuid in persons.values_list("uuid", flat=True)]

filter = Filter(
data={
"properties": [
{
"type": "person",
"key": "$a_number",
"value": "4",
"operator": "gt",
}
]
}
)
self.assertEqual(len(filter_persons_with_annotation(filter, self.team)), 3)

filter = Filter(data={"properties": [{"type": "person", "key": "$a_number", "value": 5}]})
self.assertEqual(len(filter_persons_with_annotation(filter, self.team)), 1)

filter = Filter(
data={
"properties": [
{
"type": "person",
"key": "$a_number",
"value": 6,
"operator": "lt",
}
]
}
)
self.assertEqual(len(filter_persons_with_annotation(filter, self.team)), 2)


def filter_persons_with_property_group(
filter: Filter, team: Team, property_overrides: Dict[str, Any] = {}
Expand Down
56 changes: 50 additions & 6 deletions posthog/queries/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
)

from dateutil import parser
from django.db.models import Exists, OuterRef, Q
from django.db.models import Exists, OuterRef, Q, Value
from rest_framework.exceptions import ValidationError

from posthog.constants import PropertyOperatorType
Expand All @@ -29,10 +29,10 @@
from posthog.queries.util import convert_to_datetime_aware
from posthog.utils import get_compare_period_dates, is_valid_regex

F = TypeVar("F", Filter, PathFilter)
FilterType = TypeVar("FilterType", Filter, PathFilter)


def determine_compared_filter(filter: F) -> F:
def determine_compared_filter(filter: FilterType) -> FilterType:
if not filter.date_to or not filter.date_from:
raise ValidationError("You need date_from and date_to to compare")
date_from, date_to = get_compare_period_dates(
Expand Down Expand Up @@ -142,8 +142,34 @@ def compute_exact_match(value: ValueT, override_value: Any) -> bool:
except re.error:
return False

if operator == "gt":
return type(override_value) == type(value) and override_value > value
if operator in ("gt", "gte", "lt", "lte"):
# :TRICKY: We adjust comparison based on the override value passed in,
# to make sure we handle both numeric and string comparisons appropriately.
def compare(lhs, rhs, operator):
if operator == "gt":
return lhs > rhs
elif operator == "gte":
return lhs >= rhs
elif operator == "lt":
return lhs < rhs
elif operator == "lte":
return lhs <= rhs
else:
raise ValueError(f"Invalid operator: {operator}")

parsed_value = None
try:
parsed_value = float(value) # type: ignore
except Exception:
pass

if parsed_value is not None:
if isinstance(override_value, str):
return compare(override_value, str(value), operator)
else:
return compare(override_value, parsed_value, operator)
else:
return compare(str(override_value), str(value), operator)

if operator == "gte":
return type(override_value) == type(value) and override_value >= value
Expand Down Expand Up @@ -207,7 +233,25 @@ def empty_or_null_with_value_q(
f"{column}__{key}", value_as_coerced_to_number
)
else:
target_filter = Q(**{f"{column}__{key}__{operator}": value})
if isinstance(value, list):
raise TypeError(f"empty_or_null_with_value_q: Operator {operator} does not support list values")

parsed_value = None
if operator in ("gt", "gte", "lt", "lte"):
try:
parsed_value = float(value)
except (ValueError, TypeError):
pass

if parsed_value is not None:
# When we can coerce given value to a number, check whether the value in DB is a number
# and do a numeric comparison. Otherwise, do a string comparison.
target_filter = Q(
Q(**{f"{column}__{key}__{operator}": str(value), f"{column}_{key}_type": Value("string")})
| Q(**{f"{column}__{key}__{operator}": parsed_value, f"{column}_{key}_type": Value("number")})
)
else:
target_filter = Q(**{f"{column}__{key}__{operator}": value})

query_filter = Q(target_filter & Q(**{f"{column}__has_key": key}) & ~Q(**{f"{column}__{key}": None}))

Expand Down
21 changes: 19 additions & 2 deletions posthog/queries/test/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,8 @@ def test_match_properties_math_operators(self):

self.assertFalse(match_property(property_a, {"key": 0}))
self.assertFalse(match_property(property_a, {"key": -1}))
self.assertFalse(match_property(property_a, {"key": "23"}))
# now we handle type mismatches so this should be true
self.assertTrue(match_property(property_a, {"key": "23"}))

property_b = Property(key="key", value=1, operator="lt")
self.assertTrue(match_property(property_b, {"key": 0}))
Expand All @@ -171,16 +172,32 @@ def test_match_properties_math_operators(self):

self.assertFalse(match_property(property_c, {"key": 0}))
self.assertFalse(match_property(property_c, {"key": -1}))
self.assertFalse(match_property(property_c, {"key": "3"}))
# now we handle type mismatches so this should be true
self.assertTrue(match_property(property_c, {"key": "3"}))

property_d = Property(key="key", value="43", operator="lt")
self.assertTrue(match_property(property_d, {"key": "41"}))
self.assertTrue(match_property(property_d, {"key": "42"}))
self.assertTrue(match_property(property_d, {"key": 42}))

self.assertFalse(match_property(property_d, {"key": "43"}))
self.assertFalse(match_property(property_d, {"key": "44"}))
self.assertFalse(match_property(property_d, {"key": 44}))

property_e = Property(key="key", value="30", operator="lt")
self.assertTrue(match_property(property_e, {"key": "29"}))

# depending on the type of override, we adjust type comparison
self.assertTrue(match_property(property_e, {"key": "100"}))
self.assertFalse(match_property(property_e, {"key": 100}))

property_f = Property(key="key", value="123aloha", operator="gt")
self.assertFalse(match_property(property_f, {"key": "123"}))
self.assertFalse(match_property(property_f, {"key": 122}))

# this turns into a string comparison
self.assertTrue(match_property(property_f, {"key": 129}))

def test_match_property_date_operators(self):
property_a = Property(key="key", value="2022-05-01", operator="is_date_before")
self.assertTrue(match_property(property_a, {"key": "2022-03-01"}))
Expand Down
Loading

0 comments on commit 3847340

Please sign in to comment.