diff --git a/posthog/hogql/database/schema/events.py b/posthog/hogql/database/schema/events.py index 3f85dcd53e4b4d..9934511ef59445 100644 --- a/posthog/hogql/database/schema/events.py +++ b/posthog/hogql/database/schema/events.py @@ -11,6 +11,7 @@ FieldTraverser, FieldOrTable, ) +from posthog.hogql.database.schema.groups import GroupsTable, join_with_group_n_table from posthog.hogql.database.schema.person_distinct_ids import ( PersonDistinctIdsTable, join_with_person_distinct_ids_table, @@ -85,6 +86,16 @@ class EventsTable(Table): # These are swapped out if the user has PoE enabled "person": FieldTraverser(chain=["pdi", "person"]), "person_id": FieldTraverser(chain=["pdi", "person_id"]), + "$group_0": StringDatabaseField(name="$group_0"), + "group_0": LazyJoin(from_field="$group_0", join_table=GroupsTable(), join_function=join_with_group_n_table(0)), + "$group_1": StringDatabaseField(name="$group_1"), + "group_1": LazyJoin(from_field="$group_1", join_table=GroupsTable(), join_function=join_with_group_n_table(1)), + "$group_2": StringDatabaseField(name="$group_2"), + "group_2": LazyJoin(from_field="$group_2", join_table=GroupsTable(), join_function=join_with_group_n_table(2)), + "$group_3": StringDatabaseField(name="$group_3"), + "group_3": LazyJoin(from_field="$group_3", join_table=GroupsTable(), join_function=join_with_group_n_table(3)), + "$group_4": StringDatabaseField(name="$group_4"), + "group_4": LazyJoin(from_field="$group_4", join_table=GroupsTable(), join_function=join_with_group_n_table(4)), } def to_printed_clickhouse(self, context): diff --git a/posthog/hogql/database/schema/groups.py b/posthog/hogql/database/schema/groups.py index 8344c1549c0a04..6a674488d9cfdf 100644 --- a/posthog/hogql/database/schema/groups.py +++ b/posthog/hogql/database/schema/groups.py @@ -1,4 +1,4 @@ -from typing import Dict, List +from typing import Any, Dict, List from posthog.hogql.database.argmax import argmax_select from posthog.hogql.database.models import ( @@ -10,6 +10,7 @@ Table, FieldOrTable, ) +from posthog.hogql.errors import HogQLException GROUPS_TABLE_FIELDS = { "index": IntegerDatabaseField(name="group_type_index"), @@ -30,6 +31,34 @@ def select_from_groups_table(requested_fields: Dict[str, List[str]]): ) +def join_with_group_n_table(group_index: int): + def join_with_group_table(from_table: str, to_table: str, requested_fields: Dict[str, Any]): + from posthog.hogql import ast + + if not requested_fields: + raise HogQLException("No fields requested from person_distinct_ids") + + select_query = select_from_groups_table(requested_fields) + select_query.where = ast.CompareOperation( + left=ast.Field(chain=["index"]), op=ast.CompareOperationOp.Eq, right=ast.Constant(value=group_index) + ) + + join_expr = ast.JoinExpr(table=select_query) + join_expr.join_type = "LEFT JOIN" + join_expr.alias = to_table + join_expr.constraint = ast.JoinConstraint( + expr=ast.CompareOperation( + op=ast.CompareOperationOp.Eq, + left=ast.Field(chain=[from_table, f"$group_{group_index}"]), + right=ast.Field(chain=[to_table, "key"]), + ) + ) + + return join_expr + + return join_with_group_table + + class RawGroupsTable(Table): fields: Dict[str, FieldOrTable] = GROUPS_TABLE_FIELDS