Skip to content

Commit

Permalink
feat: Added lazy joins for groups on events (#17950)
Browse files Browse the repository at this point in the history
* Added lazy joins for groups on events

* Update query snapshots

* Fixed test_resolver.py tests

* Added aliases for groups from group mappings

---------

Co-authored-by: github-actions <41898282+github-actions[bot]@users.noreply.github.com>
  • Loading branch information
Gilbert09 and github-actions[bot] authored Oct 12, 2023
1 parent d50eed9 commit 60398b7
Show file tree
Hide file tree
Showing 5 changed files with 245 additions and 1 deletion.
4 changes: 4 additions & 0 deletions posthog/hogql/database/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from posthog.hogql.database.schema.session_replay_events import RawSessionReplayEventsTable, SessionReplayEventsTable
from posthog.hogql.database.schema.static_cohort_people import StaticCohortPeople
from posthog.hogql.errors import HogQLException
from posthog.models.group_type_mapping import GroupTypeMapping
from posthog.models.team.team import WeekStartDay
from posthog.utils import PersonOnEventsMode

Expand Down Expand Up @@ -118,6 +119,9 @@ def create_hogql_database(team_id: int) -> Database:
database.events.fields["person"] = FieldTraverser(chain=["poe"])
database.events.fields["person_id"] = StringDatabaseField(name="person_id")

for mapping in GroupTypeMapping.objects.filter(team=team):
database.events.fields[mapping.group_type] = FieldTraverser(chain=[f"group_{mapping.group_type_index}"])

for view in DataWarehouseViewLink.objects.filter(team_id=team.pk).exclude(deleted=True):
table = database.get_table(view.table)

Expand Down
11 changes: 11 additions & 0 deletions posthog/hogql/database/schema/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
FieldTraverser,
FieldOrTable,
)
from posthog.hogql.database.schema.groups import GroupsTable, join_with_group_n_table
from posthog.hogql.database.schema.person_distinct_ids import (
PersonDistinctIdsTable,
join_with_person_distinct_ids_table,
Expand Down Expand Up @@ -85,6 +86,16 @@ class EventsTable(Table):
# These are swapped out if the user has PoE enabled
"person": FieldTraverser(chain=["pdi", "person"]),
"person_id": FieldTraverser(chain=["pdi", "person_id"]),
"$group_0": StringDatabaseField(name="$group_0"),
"group_0": LazyJoin(from_field="$group_0", join_table=GroupsTable(), join_function=join_with_group_n_table(0)),
"$group_1": StringDatabaseField(name="$group_1"),
"group_1": LazyJoin(from_field="$group_1", join_table=GroupsTable(), join_function=join_with_group_n_table(1)),
"$group_2": StringDatabaseField(name="$group_2"),
"group_2": LazyJoin(from_field="$group_2", join_table=GroupsTable(), join_function=join_with_group_n_table(2)),
"$group_3": StringDatabaseField(name="$group_3"),
"group_3": LazyJoin(from_field="$group_3", join_table=GroupsTable(), join_function=join_with_group_n_table(3)),
"$group_4": StringDatabaseField(name="$group_4"),
"group_4": LazyJoin(from_field="$group_4", join_table=GroupsTable(), join_function=join_with_group_n_table(4)),
}

def to_printed_clickhouse(self, context):
Expand Down
31 changes: 30 additions & 1 deletion posthog/hogql/database/schema/groups.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Dict, List
from typing import Any, Dict, List

from posthog.hogql.database.argmax import argmax_select
from posthog.hogql.database.models import (
Expand All @@ -10,6 +10,7 @@
Table,
FieldOrTable,
)
from posthog.hogql.errors import HogQLException

GROUPS_TABLE_FIELDS = {
"index": IntegerDatabaseField(name="group_type_index"),
Expand All @@ -30,6 +31,34 @@ def select_from_groups_table(requested_fields: Dict[str, List[str]]):
)


def join_with_group_n_table(group_index: int):
def join_with_group_table(from_table: str, to_table: str, requested_fields: Dict[str, Any]):
from posthog.hogql import ast

if not requested_fields:
raise HogQLException("No fields requested from person_distinct_ids")

select_query = select_from_groups_table(requested_fields)
select_query.where = ast.CompareOperation(
left=ast.Field(chain=["index"]), op=ast.CompareOperationOp.Eq, right=ast.Constant(value=group_index)
)

join_expr = ast.JoinExpr(table=select_query)
join_expr.join_type = "LEFT JOIN"
join_expr.alias = to_table
join_expr.constraint = ast.JoinConstraint(
expr=ast.CompareOperation(
op=ast.CompareOperationOp.Eq,
left=ast.Field(chain=[from_table, f"$group_{group_index}"]),
right=ast.Field(chain=[to_table, "key"]),
)
)

return join_expr

return join_with_group_table


class RawGroupsTable(Table):
fields: Dict[str, FieldOrTable] = GROUPS_TABLE_FIELDS

Expand Down
170 changes: 170 additions & 0 deletions posthog/hogql/database/test/__snapshots__/test_database.ambr
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,91 @@
"pdi",
"person_id"
]
},
{
"key": "$group_0",
"type": "string"
},
{
"key": "group_0",
"type": "lazy_table",
"table": "groups",
"fields": [
"index",
"team_id",
"key",
"created_at",
"updated_at",
"properties"
]
},
{
"key": "$group_1",
"type": "string"
},
{
"key": "group_1",
"type": "lazy_table",
"table": "groups",
"fields": [
"index",
"team_id",
"key",
"created_at",
"updated_at",
"properties"
]
},
{
"key": "$group_2",
"type": "string"
},
{
"key": "group_2",
"type": "lazy_table",
"table": "groups",
"fields": [
"index",
"team_id",
"key",
"created_at",
"updated_at",
"properties"
]
},
{
"key": "$group_3",
"type": "string"
},
{
"key": "group_3",
"type": "lazy_table",
"table": "groups",
"fields": [
"index",
"team_id",
"key",
"created_at",
"updated_at",
"properties"
]
},
{
"key": "$group_4",
"type": "string"
},
{
"key": "group_4",
"type": "lazy_table",
"table": "groups",
"fields": [
"index",
"team_id",
"key",
"created_at",
"updated_at",
"properties"
]
}
],
"groups": [
Expand Down Expand Up @@ -829,6 +914,91 @@
{
"key": "person_id",
"type": "string"
},
{
"key": "$group_0",
"type": "string"
},
{
"key": "group_0",
"type": "lazy_table",
"table": "groups",
"fields": [
"index",
"team_id",
"key",
"created_at",
"updated_at",
"properties"
]
},
{
"key": "$group_1",
"type": "string"
},
{
"key": "group_1",
"type": "lazy_table",
"table": "groups",
"fields": [
"index",
"team_id",
"key",
"created_at",
"updated_at",
"properties"
]
},
{
"key": "$group_2",
"type": "string"
},
{
"key": "group_2",
"type": "lazy_table",
"table": "groups",
"fields": [
"index",
"team_id",
"key",
"created_at",
"updated_at",
"properties"
]
},
{
"key": "$group_3",
"type": "string"
},
{
"key": "group_3",
"type": "lazy_table",
"table": "groups",
"fields": [
"index",
"team_id",
"key",
"created_at",
"updated_at",
"properties"
]
},
{
"key": "$group_4",
"type": "string"
},
{
"key": "group_4",
"type": "lazy_table",
"table": "groups",
"fields": [
"index",
"team_id",
"key",
"created_at",
"updated_at",
"properties"
]
}
],
"groups": [
Expand Down
30 changes: 30 additions & 0 deletions posthog/hogql/test/test_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -779,6 +779,11 @@ def test_asterisk_expander_table(self):
chain=["elements_chain"], type=ast.FieldType(name="elements_chain", table_type=events_table_type)
),
ast.Field(chain=["created_at"], type=ast.FieldType(name="created_at", table_type=events_table_type)),
ast.Field(chain=["$group_0"], type=ast.FieldType(name="$group_0", table_type=events_table_type)),
ast.Field(chain=["$group_1"], type=ast.FieldType(name="$group_1", table_type=events_table_type)),
ast.Field(chain=["$group_2"], type=ast.FieldType(name="$group_2", table_type=events_table_type)),
ast.Field(chain=["$group_3"], type=ast.FieldType(name="$group_3", table_type=events_table_type)),
ast.Field(chain=["$group_4"], type=ast.FieldType(name="$group_4", table_type=events_table_type)),
],
)

Expand Down Expand Up @@ -811,6 +816,11 @@ def test_asterisk_expander_table_alias(self):
ast.Field(
chain=["created_at"], type=ast.FieldType(name="created_at", table_type=events_table_alias_type)
),
ast.Field(chain=["$group_0"], type=ast.FieldType(name="$group_0", table_type=events_table_alias_type)),
ast.Field(chain=["$group_1"], type=ast.FieldType(name="$group_1", table_type=events_table_alias_type)),
ast.Field(chain=["$group_2"], type=ast.FieldType(name="$group_2", table_type=events_table_alias_type)),
ast.Field(chain=["$group_3"], type=ast.FieldType(name="$group_3", table_type=events_table_alias_type)),
ast.Field(chain=["$group_4"], type=ast.FieldType(name="$group_4", table_type=events_table_alias_type)),
],
)

Expand Down Expand Up @@ -882,6 +892,11 @@ def test_asterisk_expander_from_subquery_table(self):
"distinct_id": ast.FieldType(name="distinct_id", table_type=events_table_type),
"elements_chain": ast.FieldType(name="elements_chain", table_type=events_table_type),
"created_at": ast.FieldType(name="created_at", table_type=events_table_type),
"$group_0": ast.FieldType(name="$group_0", table_type=events_table_type),
"$group_1": ast.FieldType(name="$group_1", table_type=events_table_type),
"$group_2": ast.FieldType(name="$group_2", table_type=events_table_type),
"$group_3": ast.FieldType(name="$group_3", table_type=events_table_type),
"$group_4": ast.FieldType(name="$group_4", table_type=events_table_type),
},
)

Expand All @@ -898,6 +913,11 @@ def test_asterisk_expander_from_subquery_table(self):
type=ast.FieldType(name="elements_chain", table_type=inner_select_type),
),
ast.Field(chain=["created_at"], type=ast.FieldType(name="created_at", table_type=inner_select_type)),
ast.Field(chain=["$group_0"], type=ast.FieldType(name="$group_0", table_type=inner_select_type)),
ast.Field(chain=["$group_1"], type=ast.FieldType(name="$group_1", table_type=inner_select_type)),
ast.Field(chain=["$group_2"], type=ast.FieldType(name="$group_2", table_type=inner_select_type)),
ast.Field(chain=["$group_3"], type=ast.FieldType(name="$group_3", table_type=inner_select_type)),
ast.Field(chain=["$group_4"], type=ast.FieldType(name="$group_4", table_type=inner_select_type)),
],
)

Expand Down Expand Up @@ -930,6 +950,11 @@ def test_asterisk_expander_select_union(self):
"distinct_id": ast.FieldType(name="distinct_id", table_type=events_table_type),
"elements_chain": ast.FieldType(name="elements_chain", table_type=events_table_type),
"created_at": ast.FieldType(name="created_at", table_type=events_table_type),
"$group_0": ast.FieldType(name="$group_0", table_type=events_table_type),
"$group_1": ast.FieldType(name="$group_1", table_type=events_table_type),
"$group_2": ast.FieldType(name="$group_2", table_type=events_table_type),
"$group_3": ast.FieldType(name="$group_3", table_type=events_table_type),
"$group_4": ast.FieldType(name="$group_4", table_type=events_table_type),
},
)
]
Expand All @@ -949,6 +974,11 @@ def test_asterisk_expander_select_union(self):
type=ast.FieldType(name="elements_chain", table_type=inner_select_type),
),
ast.Field(chain=["created_at"], type=ast.FieldType(name="created_at", table_type=inner_select_type)),
ast.Field(chain=["$group_0"], type=ast.FieldType(name="$group_0", table_type=inner_select_type)),
ast.Field(chain=["$group_1"], type=ast.FieldType(name="$group_1", table_type=inner_select_type)),
ast.Field(chain=["$group_2"], type=ast.FieldType(name="$group_2", table_type=inner_select_type)),
ast.Field(chain=["$group_3"], type=ast.FieldType(name="$group_3", table_type=inner_select_type)),
ast.Field(chain=["$group_4"], type=ast.FieldType(name="$group_4", table_type=inner_select_type)),
],
)

Expand Down

0 comments on commit 60398b7

Please sign in to comment.