diff --git a/posthog/hogql/database.py b/posthog/hogql/database.py index 5244c37b49cad..680677d6991fe 100644 --- a/posthog/hogql/database.py +++ b/posthog/hogql/database.py @@ -278,6 +278,35 @@ def hogql_table(self): return "session_recording_events" +class CohortPeople(Table): + person_id: StringDatabaseField = StringDatabaseField(name="person_id") + cohort_id: IntegerDatabaseField = IntegerDatabaseField(name="cohort_id") + team_id: IntegerDatabaseField = IntegerDatabaseField(name="team_id") + sign: IntegerDatabaseField = IntegerDatabaseField(name="sign") + version: IntegerDatabaseField = IntegerDatabaseField(name="version") + + # TODO: automatically add "HAVING SUM(sign) > 0" to fields selected from this table? + + person: LazyTable = LazyTable(from_field="person_id", table=PersonsTable(), join_function=join_with_persons_table) + + def clickhouse_table(self): + return "cohortpeople" + + +class StaticCohortPeople(Table): + person_id: StringDatabaseField = StringDatabaseField(name="person_id") + cohort_id: IntegerDatabaseField = IntegerDatabaseField(name="cohort_id") + team_id: IntegerDatabaseField = IntegerDatabaseField(name="team_id") + + person: LazyTable = LazyTable(from_field="person_id", table=PersonsTable(), join_function=join_with_persons_table) + + def avoid_asterisk_fields(self): + return ["_timestamp", "_offset"] + + def clickhouse_table(self): + return "person_static_cohort" + + class Database(BaseModel): class Config: extra = Extra.forbid @@ -287,6 +316,8 @@ class Config: persons: PersonsTable = PersonsTable() person_distinct_ids: PersonDistinctIdTable = PersonDistinctIdTable() session_recording_events: SessionRecordingEvents = SessionRecordingEvents() + cohort_people: CohortPeople = CohortPeople() + static_cohort_people: StaticCohortPeople = StaticCohortPeople() def has_table(self, table_name: str) -> bool: return hasattr(self, table_name) diff --git a/posthog/hogql/property.py b/posthog/hogql/property.py index a0de118726e28..e3e6bab132e17 100644 --- a/posthog/hogql/property.py +++ b/posthog/hogql/property.py @@ -8,7 +8,7 @@ from posthog.hogql.constants import HOGQL_AGGREGATIONS from posthog.hogql.parser import parse_expr from posthog.hogql.visitor import TraversingVisitor -from posthog.models import Action, ActionStep, Property +from posthog.models import Action, ActionStep, Cohort, Property, Team from posthog.models.event import Selector from posthog.models.property import PropertyGroup from posthog.models.property.util import build_selector_regex @@ -32,6 +32,10 @@ def visit(self, node): else: super().visit(node) + def visit_select_query(self, node: ast.SelectQuery): + # don't care about aggregations in subqueries + pass + def visit_call(self, node: ast.Call): if node.name in HOGQL_AGGREGATIONS: self.has_aggregation = True @@ -40,11 +44,13 @@ def visit_call(self, node: ast.Call): self.visit(arg) -def property_to_expr(property: Union[BaseModel, PropertyGroup, Property, dict, list]) -> ast.Expr: +def property_to_expr( + property: Union[BaseModel, PropertyGroup, Property, dict, list], team: Optional[Team] = None +) -> ast.Expr: if isinstance(property, dict): property = Property(**property) elif isinstance(property, list): - properties = [property_to_expr(p) for p in property] + properties = [property_to_expr(p, team) for p in property] if len(properties) == 1: return properties[0] return ast.And(exprs=properties) @@ -53,12 +59,12 @@ def property_to_expr(property: Union[BaseModel, PropertyGroup, Property, dict, l elif isinstance(property, PropertyGroup): if property.type == PropertyOperatorType.AND: if len(property.values) == 1: - return property_to_expr(property.values[0]) - return ast.And(exprs=[property_to_expr(p) for p in property.values]) + return property_to_expr(property.values[0], team) + return ast.And(exprs=[property_to_expr(p, team) for p in property.values]) if property.type == PropertyOperatorType.OR: if len(property.values) == 1: - return property_to_expr(property.values[0]) - return ast.Or(exprs=[property_to_expr(p) for p in property.values]) + return property_to_expr(property.values[0], team) + return ast.Or(exprs=[property_to_expr(p, team) for p in property.values]) raise NotImplementedError(f'PropertyGroup of unknown type "{property.type}"') elif isinstance(property, BaseModel): property = Property(**property.dict()) @@ -76,7 +82,7 @@ def property_to_expr(property: Union[BaseModel, PropertyGroup, Property, dict, l else: exprs = [ property_to_expr( - Property(type=property.type, key=property.key, operator=property.operator, value=v) + Property(type=property.type, key=property.key, operator=property.operator, value=v), team ) for v in value ] @@ -137,7 +143,7 @@ def property_to_expr(property: Union[BaseModel, PropertyGroup, Property, dict, l else: exprs = [ property_to_expr( - Property(type=property.type, key=property.key, operator=property.operator, value=v) + Property(type=property.type, key=property.key, operator=property.operator, value=v), team ) for v in value ] @@ -166,10 +172,17 @@ def property_to_expr(property: Union[BaseModel, PropertyGroup, Property, dict, l return element_chain_key_filter("text", str(value), operator) raise NotImplementedError(f"property_to_expr for type element not implemented for key {property.key}") - # "cohort", - # "element", - # "static-cohort", - # "precalculated-cohort", + elif property.type == "cohort" or property.type == "static-cohort" or property.type == "precalculated-cohort": + if not team: + raise Exception("Can not convert cohort property to expression without team") + cohort = Cohort.objects.get(team=team, id=property.value) + + if cohort.is_static: + sql = "person_id in (SELECT person_id FROM static_cohort_people WHERE cohort_id = {cohort_id})" + else: + sql = "person_id in (SELECT person_id FROM cohort_people WHERE cohort_id = {cohort_id} GROUP BY person_id, cohort_id, version HAVING sum(sign) > 0)" + + return parse_expr(sql, {"cohort_id": ast.Constant(value=cohort.pk)}) # "group", # "recording", # "behavioral", diff --git a/posthog/hogql/test/test_property.py b/posthog/hogql/test/test_property.py index c3a1da5fe5209..5afb188a08ea2 100644 --- a/posthog/hogql/test/test_property.py +++ b/posthog/hogql/test/test_property.py @@ -11,7 +11,7 @@ selector_to_expr, tag_name_to_expr, ) -from posthog.models import Action, ActionStep, Property +from posthog.models import Action, ActionStep, Cohort, Property from posthog.models.property import PropertyGroup from posthog.schema import HogQLPropertyFilter, PropertyOperator from posthog.test.base import BaseTest @@ -323,3 +323,25 @@ def test_action_to_expr(self): }, ), ) + + def test_cohort_filter_static(self): + cohort = Cohort.objects.create( + team=self.team, + is_static=True, + groups=[{"properties": [{"key": "$os", "value": "Chrome", "type": "person"}]}], + ) + self.assertEqual( + property_to_expr({"type": "cohort", "key": "id", "value": cohort.pk}, self.team), + parse_expr(f"person_id IN (SELECT person_id FROM static_cohort_people WHERE cohort_id = {cohort.pk})"), + ) + + def test_cohort_filter_dynamic(self): + cohort = Cohort.objects.create( + team=self.team, groups=[{"properties": [{"key": "$os", "value": "Chrome", "type": "person"}]}] + ) + self.assertEqual( + property_to_expr({"type": "cohort", "key": "id", "value": cohort.pk}, self.team), + parse_expr( + f"person_id IN (SELECT person_id FROM cohort_people WHERE cohort_id = {cohort.pk} GROUP BY person_id, cohort_id, version HAVING sum(sign) > 0)" + ), + ) diff --git a/posthog/hogql/test/test_query.py b/posthog/hogql/test/test_query.py index 3b2726401e467..ecb3c9b4b440d 100644 --- a/posthog/hogql/test/test_query.py +++ b/posthog/hogql/test/test_query.py @@ -4,7 +4,10 @@ from posthog import datetime from posthog.hogql import ast +from posthog.hogql.property import property_to_expr from posthog.hogql.query import execute_hogql_query +from posthog.models import Cohort +from posthog.models.cohort.util import recalculate_cohortpeople from posthog.models.utils import UUIDT from posthog.test.base import APIBaseTest, ClickhouseTestMixin, _create_event, _create_person, flush_persons_and_events @@ -411,3 +414,63 @@ def test_select_person_on_events(self): "SELECT poe.properties.email, count() FROM events AS s GROUP BY poe.properties.email LIMIT 10", ) self.assertEqual(response.results[0][0], "tim@posthog.com") + + def test_prop_cohort_basic(self): + with freeze_time("2020-01-10"): + _create_person(distinct_ids=["some_other_id"], team_id=self.team.pk, properties={"$some_prop": "something"}) + _create_person( + distinct_ids=["some_id"], + team_id=self.team.pk, + properties={"$some_prop": "something", "$another_prop": "something"}, + ) + _create_person(distinct_ids=["no_match"], team_id=self.team.pk) + _create_event(event="$pageview", team=self.team, distinct_id="some_id", properties={"attr": "some_val"}) + _create_event( + event="$pageview", team=self.team, distinct_id="some_other_id", properties={"attr": "some_val"} + ) + cohort = Cohort.objects.create( + team=self.team, + groups=[{"properties": [{"key": "$some_prop", "value": "something", "type": "person"}]}], + name="cohort", + ) + recalculate_cohortpeople(cohort, pending_version=0) + response = execute_hogql_query( + "SELECT event, count() FROM events WHERE {cohort_filter} GROUP BY event", + team=self.team, + placeholders={ + "cohort_filter": property_to_expr({"type": "cohort", "key": "id", "value": cohort.pk}, self.team) + }, + ) + self.assertEqual(response.results, [("$pageview", 2)]) + self.assertEqual( + response.clickhouse, + f"SELECT event, count(*) FROM events INNER JOIN (SELECT argMax(person_distinct_id2.person_id, version) AS person_id, distinct_id FROM person_distinct_id2 WHERE equals(team_id, {self.team.pk}) GROUP BY distinct_id HAVING equals(argMax(is_deleted, version), 0)) AS events__pdi ON equals(events.distinct_id, events__pdi.distinct_id) WHERE and(equals(team_id, {self.team.pk}), in(events__pdi.person_id, (SELECT person_id FROM cohortpeople WHERE and(equals(team_id, {self.team.pk}), equals(cohort_id, {cohort.pk})) GROUP BY person_id, cohort_id, version HAVING greater(sum(sign), 0)))) GROUP BY event LIMIT 100", + ) + + def test_prop_cohort_static(self): + with freeze_time("2020-01-10"): + _create_person(distinct_ids=["some_other_id"], team_id=self.team.pk, properties={"$some_prop": "something"}) + _create_person( + distinct_ids=["some_id"], + team_id=self.team.pk, + properties={"$some_prop": "something", "$another_prop": "something"}, + ) + _create_person(distinct_ids=["no_match"], team_id=self.team.pk) + _create_event(event="$pageview", team=self.team, distinct_id="some_id", properties={"attr": "some_val"}) + _create_event( + event="$pageview", team=self.team, distinct_id="some_other_id", properties={"attr": "some_val"} + ) + cohort = Cohort.objects.create(team=self.team, groups=[], is_static=True) + cohort.insert_users_by_list(["some_id"]) + response = execute_hogql_query( + "SELECT event, count() FROM events WHERE {cohort_filter} GROUP BY event", + team=self.team, + placeholders={ + "cohort_filter": property_to_expr({"type": "cohort", "key": "id", "value": cohort.pk}, self.team) + }, + ) + self.assertEqual(response.results, [("$pageview", 1)]) + self.assertEqual( + response.clickhouse, + f"SELECT event, count(*) FROM events INNER JOIN (SELECT argMax(person_distinct_id2.person_id, version) AS person_id, distinct_id FROM person_distinct_id2 WHERE equals(team_id, {self.team.pk}) GROUP BY distinct_id HAVING equals(argMax(is_deleted, version), 0)) AS events__pdi ON equals(events.distinct_id, events__pdi.distinct_id) WHERE and(equals(team_id, {self.team.pk}), in(events__pdi.person_id, (SELECT person_id FROM person_static_cohort WHERE and(equals(team_id, {self.team.pk}), equals(cohort_id, {cohort.pk}))))) GROUP BY event LIMIT 100", + ) diff --git a/posthog/models/event/events_query.py b/posthog/models/event/events_query.py index 21f98d65db69a..93c008ee7b98d 100644 --- a/posthog/models/event/events_query.py +++ b/posthog/models/event/events_query.py @@ -74,9 +74,9 @@ def run_events_query( where_input = query.where or [] where_exprs = [parse_expr(expr) for expr in where_input] if query.properties: - where_exprs.extend(property_to_expr(property) for property in query.properties) + where_exprs.extend(property_to_expr(property, team) for property in query.properties) if query.fixedProperties: - where_exprs.extend(property_to_expr(property) for property in query.fixedProperties) + where_exprs.extend(property_to_expr(property, team) for property in query.fixedProperties) if query.event: where_exprs.append(parse_expr("event = {event}", {"event": ast.Constant(value=query.event)})) if query.actionId: