diff --git a/posthog/hogql/database/database.py b/posthog/hogql/database/database.py index a75e9171dbc42..6c03b05a14418 100644 --- a/posthog/hogql/database/database.py +++ b/posthog/hogql/database/database.py @@ -33,6 +33,7 @@ from posthog.hogql.database.schema.session_replay_events import RawSessionReplayEventsTable, SessionReplayEventsTable from posthog.hogql.database.schema.static_cohort_people import StaticCohortPeople from posthog.hogql.errors import HogQLException +from posthog.models.group_type_mapping import GroupTypeMapping from posthog.models.team.team import WeekStartDay from posthog.utils import PersonOnEventsMode @@ -118,6 +119,9 @@ def create_hogql_database(team_id: int) -> Database: database.events.fields["person"] = FieldTraverser(chain=["poe"]) database.events.fields["person_id"] = StringDatabaseField(name="person_id") + for mapping in GroupTypeMapping.objects.filter(team=team): + database.events.fields[mapping.group_type] = FieldTraverser(chain=[f"group_{mapping.group_type_index}"]) + for view in DataWarehouseViewLink.objects.filter(team_id=team.pk).exclude(deleted=True): table = database.get_table(view.table) diff --git a/posthog/hogql/database/schema/events.py b/posthog/hogql/database/schema/events.py index 3f85dcd53e4b4..9934511ef5944 100644 --- a/posthog/hogql/database/schema/events.py +++ b/posthog/hogql/database/schema/events.py @@ -11,6 +11,7 @@ FieldTraverser, FieldOrTable, ) +from posthog.hogql.database.schema.groups import GroupsTable, join_with_group_n_table from posthog.hogql.database.schema.person_distinct_ids import ( PersonDistinctIdsTable, join_with_person_distinct_ids_table, @@ -85,6 +86,16 @@ class EventsTable(Table): # These are swapped out if the user has PoE enabled "person": FieldTraverser(chain=["pdi", "person"]), "person_id": FieldTraverser(chain=["pdi", "person_id"]), + "$group_0": StringDatabaseField(name="$group_0"), + "group_0": LazyJoin(from_field="$group_0", join_table=GroupsTable(), join_function=join_with_group_n_table(0)), + "$group_1": StringDatabaseField(name="$group_1"), + "group_1": LazyJoin(from_field="$group_1", join_table=GroupsTable(), join_function=join_with_group_n_table(1)), + "$group_2": StringDatabaseField(name="$group_2"), + "group_2": LazyJoin(from_field="$group_2", join_table=GroupsTable(), join_function=join_with_group_n_table(2)), + "$group_3": StringDatabaseField(name="$group_3"), + "group_3": LazyJoin(from_field="$group_3", join_table=GroupsTable(), join_function=join_with_group_n_table(3)), + "$group_4": StringDatabaseField(name="$group_4"), + "group_4": LazyJoin(from_field="$group_4", join_table=GroupsTable(), join_function=join_with_group_n_table(4)), } def to_printed_clickhouse(self, context): diff --git a/posthog/hogql/database/schema/groups.py b/posthog/hogql/database/schema/groups.py index 8344c1549c0a0..6a674488d9cfd 100644 --- a/posthog/hogql/database/schema/groups.py +++ b/posthog/hogql/database/schema/groups.py @@ -1,4 +1,4 @@ -from typing import Dict, List +from typing import Any, Dict, List from posthog.hogql.database.argmax import argmax_select from posthog.hogql.database.models import ( @@ -10,6 +10,7 @@ Table, FieldOrTable, ) +from posthog.hogql.errors import HogQLException GROUPS_TABLE_FIELDS = { "index": IntegerDatabaseField(name="group_type_index"), @@ -30,6 +31,34 @@ def select_from_groups_table(requested_fields: Dict[str, List[str]]): ) +def join_with_group_n_table(group_index: int): + def join_with_group_table(from_table: str, to_table: str, requested_fields: Dict[str, Any]): + from posthog.hogql import ast + + if not requested_fields: + raise HogQLException("No fields requested from person_distinct_ids") + + select_query = select_from_groups_table(requested_fields) + select_query.where = ast.CompareOperation( + left=ast.Field(chain=["index"]), op=ast.CompareOperationOp.Eq, right=ast.Constant(value=group_index) + ) + + join_expr = ast.JoinExpr(table=select_query) + join_expr.join_type = "LEFT JOIN" + join_expr.alias = to_table + join_expr.constraint = ast.JoinConstraint( + expr=ast.CompareOperation( + op=ast.CompareOperationOp.Eq, + left=ast.Field(chain=[from_table, f"$group_{group_index}"]), + right=ast.Field(chain=[to_table, "key"]), + ) + ) + + return join_expr + + return join_with_group_table + + class RawGroupsTable(Table): fields: Dict[str, FieldOrTable] = GROUPS_TABLE_FIELDS diff --git a/posthog/hogql/database/test/__snapshots__/test_database.ambr b/posthog/hogql/database/test/__snapshots__/test_database.ambr index f9abb21115a8e..3cd02926282cf 100644 --- a/posthog/hogql/database/test/__snapshots__/test_database.ambr +++ b/posthog/hogql/database/test/__snapshots__/test_database.ambr @@ -137,6 +137,91 @@ "pdi", "person_id" ] + }, + { + "key": "$group_0", + "type": "string" + }, + { + "key": "group_0", + "type": "lazy_table", + "table": "groups", + "fields": [ + "index", + "team_id", + "key", + "created_at", + "updated_at", + "properties" + ] + }, + { + "key": "$group_1", + "type": "string" + }, + { + "key": "group_1", + "type": "lazy_table", + "table": "groups", + "fields": [ + "index", + "team_id", + "key", + "created_at", + "updated_at", + "properties" + ] + }, + { + "key": "$group_2", + "type": "string" + }, + { + "key": "group_2", + "type": "lazy_table", + "table": "groups", + "fields": [ + "index", + "team_id", + "key", + "created_at", + "updated_at", + "properties" + ] + }, + { + "key": "$group_3", + "type": "string" + }, + { + "key": "group_3", + "type": "lazy_table", + "table": "groups", + "fields": [ + "index", + "team_id", + "key", + "created_at", + "updated_at", + "properties" + ] + }, + { + "key": "$group_4", + "type": "string" + }, + { + "key": "group_4", + "type": "lazy_table", + "table": "groups", + "fields": [ + "index", + "team_id", + "key", + "created_at", + "updated_at", + "properties" + ] } ], "groups": [ @@ -829,6 +914,91 @@ { "key": "person_id", "type": "string" + }, + { + "key": "$group_0", + "type": "string" + }, + { + "key": "group_0", + "type": "lazy_table", + "table": "groups", + "fields": [ + "index", + "team_id", + "key", + "created_at", + "updated_at", + "properties" + ] + }, + { + "key": "$group_1", + "type": "string" + }, + { + "key": "group_1", + "type": "lazy_table", + "table": "groups", + "fields": [ + "index", + "team_id", + "key", + "created_at", + "updated_at", + "properties" + ] + }, + { + "key": "$group_2", + "type": "string" + }, + { + "key": "group_2", + "type": "lazy_table", + "table": "groups", + "fields": [ + "index", + "team_id", + "key", + "created_at", + "updated_at", + "properties" + ] + }, + { + "key": "$group_3", + "type": "string" + }, + { + "key": "group_3", + "type": "lazy_table", + "table": "groups", + "fields": [ + "index", + "team_id", + "key", + "created_at", + "updated_at", + "properties" + ] + }, + { + "key": "$group_4", + "type": "string" + }, + { + "key": "group_4", + "type": "lazy_table", + "table": "groups", + "fields": [ + "index", + "team_id", + "key", + "created_at", + "updated_at", + "properties" + ] } ], "groups": [ diff --git a/posthog/hogql/test/test_resolver.py b/posthog/hogql/test/test_resolver.py index 1d1ce4ee8431b..6c281d788624d 100644 --- a/posthog/hogql/test/test_resolver.py +++ b/posthog/hogql/test/test_resolver.py @@ -779,6 +779,11 @@ def test_asterisk_expander_table(self): chain=["elements_chain"], type=ast.FieldType(name="elements_chain", table_type=events_table_type) ), ast.Field(chain=["created_at"], type=ast.FieldType(name="created_at", table_type=events_table_type)), + ast.Field(chain=["$group_0"], type=ast.FieldType(name="$group_0", table_type=events_table_type)), + ast.Field(chain=["$group_1"], type=ast.FieldType(name="$group_1", table_type=events_table_type)), + ast.Field(chain=["$group_2"], type=ast.FieldType(name="$group_2", table_type=events_table_type)), + ast.Field(chain=["$group_3"], type=ast.FieldType(name="$group_3", table_type=events_table_type)), + ast.Field(chain=["$group_4"], type=ast.FieldType(name="$group_4", table_type=events_table_type)), ], ) @@ -811,6 +816,11 @@ def test_asterisk_expander_table_alias(self): ast.Field( chain=["created_at"], type=ast.FieldType(name="created_at", table_type=events_table_alias_type) ), + ast.Field(chain=["$group_0"], type=ast.FieldType(name="$group_0", table_type=events_table_alias_type)), + ast.Field(chain=["$group_1"], type=ast.FieldType(name="$group_1", table_type=events_table_alias_type)), + ast.Field(chain=["$group_2"], type=ast.FieldType(name="$group_2", table_type=events_table_alias_type)), + ast.Field(chain=["$group_3"], type=ast.FieldType(name="$group_3", table_type=events_table_alias_type)), + ast.Field(chain=["$group_4"], type=ast.FieldType(name="$group_4", table_type=events_table_alias_type)), ], ) @@ -882,6 +892,11 @@ def test_asterisk_expander_from_subquery_table(self): "distinct_id": ast.FieldType(name="distinct_id", table_type=events_table_type), "elements_chain": ast.FieldType(name="elements_chain", table_type=events_table_type), "created_at": ast.FieldType(name="created_at", table_type=events_table_type), + "$group_0": ast.FieldType(name="$group_0", table_type=events_table_type), + "$group_1": ast.FieldType(name="$group_1", table_type=events_table_type), + "$group_2": ast.FieldType(name="$group_2", table_type=events_table_type), + "$group_3": ast.FieldType(name="$group_3", table_type=events_table_type), + "$group_4": ast.FieldType(name="$group_4", table_type=events_table_type), }, ) @@ -898,6 +913,11 @@ def test_asterisk_expander_from_subquery_table(self): type=ast.FieldType(name="elements_chain", table_type=inner_select_type), ), ast.Field(chain=["created_at"], type=ast.FieldType(name="created_at", table_type=inner_select_type)), + ast.Field(chain=["$group_0"], type=ast.FieldType(name="$group_0", table_type=inner_select_type)), + ast.Field(chain=["$group_1"], type=ast.FieldType(name="$group_1", table_type=inner_select_type)), + ast.Field(chain=["$group_2"], type=ast.FieldType(name="$group_2", table_type=inner_select_type)), + ast.Field(chain=["$group_3"], type=ast.FieldType(name="$group_3", table_type=inner_select_type)), + ast.Field(chain=["$group_4"], type=ast.FieldType(name="$group_4", table_type=inner_select_type)), ], ) @@ -930,6 +950,11 @@ def test_asterisk_expander_select_union(self): "distinct_id": ast.FieldType(name="distinct_id", table_type=events_table_type), "elements_chain": ast.FieldType(name="elements_chain", table_type=events_table_type), "created_at": ast.FieldType(name="created_at", table_type=events_table_type), + "$group_0": ast.FieldType(name="$group_0", table_type=events_table_type), + "$group_1": ast.FieldType(name="$group_1", table_type=events_table_type), + "$group_2": ast.FieldType(name="$group_2", table_type=events_table_type), + "$group_3": ast.FieldType(name="$group_3", table_type=events_table_type), + "$group_4": ast.FieldType(name="$group_4", table_type=events_table_type), }, ) ] @@ -949,6 +974,11 @@ def test_asterisk_expander_select_union(self): type=ast.FieldType(name="elements_chain", table_type=inner_select_type), ), ast.Field(chain=["created_at"], type=ast.FieldType(name="created_at", table_type=inner_select_type)), + ast.Field(chain=["$group_0"], type=ast.FieldType(name="$group_0", table_type=inner_select_type)), + ast.Field(chain=["$group_1"], type=ast.FieldType(name="$group_1", table_type=inner_select_type)), + ast.Field(chain=["$group_2"], type=ast.FieldType(name="$group_2", table_type=inner_select_type)), + ast.Field(chain=["$group_3"], type=ast.FieldType(name="$group_3", table_type=inner_select_type)), + ast.Field(chain=["$group_4"], type=ast.FieldType(name="$group_4", table_type=inner_select_type)), ], )