Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(hogql): cohort filter #14600

Merged
merged 16 commits into from
Mar 9, 2023
Merged
31 changes: 31 additions & 0 deletions posthog/hogql/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,35 @@ def hogql_table(self):
return "session_recording_events"


class CohortPeople(Table):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tangential to this PR but I wonder what stops these definitions and the actual table definitions from diverging other than people being careful

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

People are always careful :).

Good question. Noting really. I added "Lock down" tables for HogQL. Make sure we don't change something users might query on. Expose only what we deem relevant. as a point to the list of all lists.

person_id: StringDatabaseField = StringDatabaseField(name="person_id")
cohort_id: IntegerDatabaseField = IntegerDatabaseField(name="cohort_id")
team_id: IntegerDatabaseField = IntegerDatabaseField(name="team_id")
sign: IntegerDatabaseField = IntegerDatabaseField(name="sign")
version: IntegerDatabaseField = IntegerDatabaseField(name="version")

# TODO: automatically add "HAVING SUM(sign) > 0" to fields selected from this table?

person: LazyTable = LazyTable(from_field="person_id", table=PersonsTable(), join_function=join_with_persons_table)

def clickhouse_table(self):
return "cohortpeople"


class StaticCohortPeople(Table):
person_id: StringDatabaseField = StringDatabaseField(name="person_id")
cohort_id: IntegerDatabaseField = IntegerDatabaseField(name="cohort_id")
team_id: IntegerDatabaseField = IntegerDatabaseField(name="team_id")

person: LazyTable = LazyTable(from_field="person_id", table=PersonsTable(), join_function=join_with_persons_table)

def avoid_asterisk_fields(self):
return ["_timestamp", "_offset"]

def clickhouse_table(self):
return "person_static_cohort"


class Database(BaseModel):
class Config:
extra = Extra.forbid
Expand All @@ -287,6 +316,8 @@ class Config:
persons: PersonsTable = PersonsTable()
person_distinct_ids: PersonDistinctIdTable = PersonDistinctIdTable()
session_recording_events: SessionRecordingEvents = SessionRecordingEvents()
cohort_people: CohortPeople = CohortPeople()
static_cohort_people: StaticCohortPeople = StaticCohortPeople()

def has_table(self, table_name: str) -> bool:
return hasattr(self, table_name)
Expand Down
39 changes: 26 additions & 13 deletions posthog/hogql/property.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from posthog.hogql.constants import HOGQL_AGGREGATIONS
from posthog.hogql.parser import parse_expr
from posthog.hogql.visitor import TraversingVisitor
from posthog.models import Action, ActionStep, Property
from posthog.models import Action, ActionStep, Cohort, Property, Team
from posthog.models.event import Selector
from posthog.models.property import PropertyGroup
from posthog.models.property.util import build_selector_regex
Expand All @@ -32,6 +32,10 @@ def visit(self, node):
else:
super().visit(node)

def visit_select_query(self, node: ast.SelectQuery):
# don't care about aggregations in subqueries
pass

def visit_call(self, node: ast.Call):
if node.name in HOGQL_AGGREGATIONS:
self.has_aggregation = True
Expand All @@ -40,11 +44,13 @@ def visit_call(self, node: ast.Call):
self.visit(arg)


def property_to_expr(property: Union[BaseModel, PropertyGroup, Property, dict, list]) -> ast.Expr:
def property_to_expr(
property: Union[BaseModel, PropertyGroup, Property, dict, list], team: Optional[Team] = None
) -> ast.Expr:
if isinstance(property, dict):
property = Property(**property)
elif isinstance(property, list):
properties = [property_to_expr(p) for p in property]
properties = [property_to_expr(p, team) for p in property]
if len(properties) == 1:
return properties[0]
return ast.And(exprs=properties)
Expand All @@ -53,12 +59,12 @@ def property_to_expr(property: Union[BaseModel, PropertyGroup, Property, dict, l
elif isinstance(property, PropertyGroup):
if property.type == PropertyOperatorType.AND:
if len(property.values) == 1:
return property_to_expr(property.values[0])
return ast.And(exprs=[property_to_expr(p) for p in property.values])
return property_to_expr(property.values[0], team)
return ast.And(exprs=[property_to_expr(p, team) for p in property.values])
if property.type == PropertyOperatorType.OR:
if len(property.values) == 1:
return property_to_expr(property.values[0])
return ast.Or(exprs=[property_to_expr(p) for p in property.values])
return property_to_expr(property.values[0], team)
return ast.Or(exprs=[property_to_expr(p, team) for p in property.values])
raise NotImplementedError(f'PropertyGroup of unknown type "{property.type}"')
elif isinstance(property, BaseModel):
property = Property(**property.dict())
Expand All @@ -76,7 +82,7 @@ def property_to_expr(property: Union[BaseModel, PropertyGroup, Property, dict, l
else:
exprs = [
property_to_expr(
Property(type=property.type, key=property.key, operator=property.operator, value=v)
Property(type=property.type, key=property.key, operator=property.operator, value=v), team
)
for v in value
]
Expand Down Expand Up @@ -137,7 +143,7 @@ def property_to_expr(property: Union[BaseModel, PropertyGroup, Property, dict, l
else:
exprs = [
property_to_expr(
Property(type=property.type, key=property.key, operator=property.operator, value=v)
Property(type=property.type, key=property.key, operator=property.operator, value=v), team
)
for v in value
]
Expand Down Expand Up @@ -166,10 +172,17 @@ def property_to_expr(property: Union[BaseModel, PropertyGroup, Property, dict, l
return element_chain_key_filter("text", str(value), operator)

raise NotImplementedError(f"property_to_expr for type element not implemented for key {property.key}")
# "cohort",
# "element",
# "static-cohort",
# "precalculated-cohort",
elif property.type == "cohort" or property.type == "static-cohort" or property.type == "precalculated-cohort":
if not team:
raise Exception("Can not convert cohort property to expression without team")
cohort = Cohort.objects.get(team=team, id=property.value)

if cohort.is_static:
sql = "person_id in (SELECT person_id FROM static_cohort_people WHERE cohort_id = {cohort_id})"
else:
sql = "person_id in (SELECT person_id FROM cohort_people WHERE cohort_id = {cohort_id} GROUP BY person_id, cohort_id, version HAVING sum(sign) > 0)"

return parse_expr(sql, {"cohort_id": ast.Constant(value=cohort.pk)})
# "group",
# "recording",
# "behavioral",
Expand Down
24 changes: 23 additions & 1 deletion posthog/hogql/test/test_property.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
selector_to_expr,
tag_name_to_expr,
)
from posthog.models import Action, ActionStep, Property
from posthog.models import Action, ActionStep, Cohort, Property
from posthog.models.property import PropertyGroup
from posthog.schema import HogQLPropertyFilter, PropertyOperator
from posthog.test.base import BaseTest
Expand Down Expand Up @@ -323,3 +323,25 @@ def test_action_to_expr(self):
},
),
)

def test_cohort_filter_static(self):
cohort = Cohort.objects.create(
team=self.team,
is_static=True,
groups=[{"properties": [{"key": "$os", "value": "Chrome", "type": "person"}]}],
)
self.assertEqual(
property_to_expr({"type": "cohort", "key": "id", "value": cohort.pk}, self.team),
parse_expr(f"person_id IN (SELECT person_id FROM static_cohort_people WHERE cohort_id = {cohort.pk})"),
)

def test_cohort_filter_dynamic(self):
cohort = Cohort.objects.create(
team=self.team, groups=[{"properties": [{"key": "$os", "value": "Chrome", "type": "person"}]}]
)
self.assertEqual(
property_to_expr({"type": "cohort", "key": "id", "value": cohort.pk}, self.team),
parse_expr(
f"person_id IN (SELECT person_id FROM cohort_people WHERE cohort_id = {cohort.pk} GROUP BY person_id, cohort_id, version HAVING sum(sign) > 0)"
),
)
63 changes: 63 additions & 0 deletions posthog/hogql/test/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@

from posthog import datetime
from posthog.hogql import ast
from posthog.hogql.property import property_to_expr
from posthog.hogql.query import execute_hogql_query
from posthog.models import Cohort
from posthog.models.cohort.util import recalculate_cohortpeople
from posthog.models.utils import UUIDT
from posthog.test.base import APIBaseTest, ClickhouseTestMixin, _create_event, _create_person, flush_persons_and_events

Expand Down Expand Up @@ -411,3 +414,63 @@ def test_select_person_on_events(self):
"SELECT poe.properties.email, count() FROM events AS s GROUP BY poe.properties.email LIMIT 10",
)
self.assertEqual(response.results[0][0], "[email protected]")

def test_prop_cohort_basic(self):
with freeze_time("2020-01-10"):
_create_person(distinct_ids=["some_other_id"], team_id=self.team.pk, properties={"$some_prop": "something"})
_create_person(
distinct_ids=["some_id"],
team_id=self.team.pk,
properties={"$some_prop": "something", "$another_prop": "something"},
)
_create_person(distinct_ids=["no_match"], team_id=self.team.pk)
_create_event(event="$pageview", team=self.team, distinct_id="some_id", properties={"attr": "some_val"})
_create_event(
event="$pageview", team=self.team, distinct_id="some_other_id", properties={"attr": "some_val"}
)
cohort = Cohort.objects.create(
team=self.team,
groups=[{"properties": [{"key": "$some_prop", "value": "something", "type": "person"}]}],
name="cohort",
)
recalculate_cohortpeople(cohort, pending_version=0)
response = execute_hogql_query(
"SELECT event, count() FROM events WHERE {cohort_filter} GROUP BY event",
team=self.team,
placeholders={
"cohort_filter": property_to_expr({"type": "cohort", "key": "id", "value": cohort.pk}, self.team)
},
)
self.assertEqual(response.results, [("$pageview", 2)])
self.assertEqual(
response.clickhouse,
f"SELECT event, count(*) FROM events INNER JOIN (SELECT argMax(person_distinct_id2.person_id, version) AS person_id, distinct_id FROM person_distinct_id2 WHERE equals(team_id, {self.team.pk}) GROUP BY distinct_id HAVING equals(argMax(is_deleted, version), 0)) AS events__pdi ON equals(events.distinct_id, events__pdi.distinct_id) WHERE and(equals(team_id, {self.team.pk}), in(events__pdi.person_id, (SELECT person_id FROM cohortpeople WHERE and(equals(team_id, {self.team.pk}), equals(cohort_id, {cohort.pk})) GROUP BY person_id, cohort_id, version HAVING greater(sum(sign), 0)))) GROUP BY event LIMIT 100",
)

def test_prop_cohort_static(self):
with freeze_time("2020-01-10"):
_create_person(distinct_ids=["some_other_id"], team_id=self.team.pk, properties={"$some_prop": "something"})
_create_person(
distinct_ids=["some_id"],
team_id=self.team.pk,
properties={"$some_prop": "something", "$another_prop": "something"},
)
_create_person(distinct_ids=["no_match"], team_id=self.team.pk)
_create_event(event="$pageview", team=self.team, distinct_id="some_id", properties={"attr": "some_val"})
_create_event(
event="$pageview", team=self.team, distinct_id="some_other_id", properties={"attr": "some_val"}
)
cohort = Cohort.objects.create(team=self.team, groups=[], is_static=True)
cohort.insert_users_by_list(["some_id"])
response = execute_hogql_query(
"SELECT event, count() FROM events WHERE {cohort_filter} GROUP BY event",
team=self.team,
placeholders={
"cohort_filter": property_to_expr({"type": "cohort", "key": "id", "value": cohort.pk}, self.team)
},
)
self.assertEqual(response.results, [("$pageview", 1)])
self.assertEqual(
response.clickhouse,
f"SELECT event, count(*) FROM events INNER JOIN (SELECT argMax(person_distinct_id2.person_id, version) AS person_id, distinct_id FROM person_distinct_id2 WHERE equals(team_id, {self.team.pk}) GROUP BY distinct_id HAVING equals(argMax(is_deleted, version), 0)) AS events__pdi ON equals(events.distinct_id, events__pdi.distinct_id) WHERE and(equals(team_id, {self.team.pk}), in(events__pdi.person_id, (SELECT person_id FROM person_static_cohort WHERE and(equals(team_id, {self.team.pk}), equals(cohort_id, {cohort.pk}))))) GROUP BY event LIMIT 100",
)
4 changes: 2 additions & 2 deletions posthog/models/event/events_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,9 @@ def run_events_query(
where_input = query.where or []
where_exprs = [parse_expr(expr) for expr in where_input]
if query.properties:
where_exprs.extend(property_to_expr(property) for property in query.properties)
where_exprs.extend(property_to_expr(property, team) for property in query.properties)
if query.fixedProperties:
where_exprs.extend(property_to_expr(property) for property in query.fixedProperties)
where_exprs.extend(property_to_expr(property, team) for property in query.fixedProperties)
if query.event:
where_exprs.append(parse_expr("event = {event}", {"event": ast.Constant(value=query.event)}))
if query.actionId:
Expand Down