Skip to content

Commit

Permalink
Use a custom join function
Browse files Browse the repository at this point in the history
  • Loading branch information
danielbachhuber committed Dec 17, 2024
1 parent c7b33e0 commit 006778d
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 8 deletions.
5 changes: 2 additions & 3 deletions posthog/hogql/database/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@
from posthog.hogql.database.schema.person_distinct_ids import (
PersonDistinctIdsTable,
RawPersonDistinctIdsTable,
join_with_person_distinct_ids_table,
join_data_warehouse_experiment_table_with_person_distinct_ids_table,
)
from posthog.hogql.database.schema.persons import (
PersonsTable,
Expand Down Expand Up @@ -462,8 +462,7 @@ def define_mappings(warehouse: dict[str, Table], get_table: Callable):
source_table.fields["pdi"] = LazyJoin(
from_field=from_field,
join_table=PersonDistinctIdsTable(),
join_function=join_with_person_distinct_ids_table,
override_source_table_key=join.source_table_key,
join_function=join_data_warehouse_experiment_table_with_person_distinct_ids_table,
)
source_table.fields["person"] = FieldTraverser(chain=["pdi", "person"])

Expand Down
28 changes: 24 additions & 4 deletions posthog/hogql/database/schema/person_distinct_ids.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from typing import Optional
from posthog.hogql.ast import SelectQuery
from posthog.hogql.constants import HogQLQuerySettings
from posthog.hogql.context import HogQLContext
Expand Down Expand Up @@ -49,20 +48,41 @@ def join_with_person_distinct_ids_table(
join_to_add: LazyJoinToAdd,
context: HogQLContext,
node: SelectQuery,
override_source_table_key: Optional[str] = None,
):
from posthog.hogql import ast

if not join_to_add.fields_accessed:
raise ResolutionError("No fields requested from person_distinct_ids")
source_table_key = override_source_table_key or "distinct_id"
join_expr = ast.JoinExpr(table=select_from_person_distinct_ids_table(join_to_add.fields_accessed))
join_expr.join_type = "INNER JOIN"
join_expr.alias = join_to_add.to_table
join_expr.constraint = ast.JoinConstraint(
expr=ast.CompareOperation(
op=ast.CompareOperationOp.Eq,
left=ast.Field(chain=[join_to_add.from_table, source_table_key]),
left=ast.Field(chain=[join_to_add.from_table, "distinct_id"]),
right=ast.Field(chain=[join_to_add.to_table, "distinct_id"]),
),
constraint_type="ON",
)
return join_expr


def join_data_warehouse_experiment_table_with_person_distinct_ids_table(
join_to_add: LazyJoinToAdd,
context: HogQLContext,
node: SelectQuery,
):
from posthog.hogql import ast

if not join_to_add.fields_accessed:
raise ResolutionError("No fields requested from person_distinct_ids")
join_expr = ast.JoinExpr(table=select_from_person_distinct_ids_table(join_to_add.fields_accessed))
join_expr.join_type = "LEFT JOIN"
join_expr.alias = join_to_add.to_table
join_expr.constraint = ast.JoinConstraint(
expr=ast.CompareOperation(
op=ast.CompareOperationOp.Eq,
left=ast.Field(chain=[join_to_add.from_table, *join_to_add.lazy_join.from_field]),
right=ast.Field(chain=[join_to_add.to_table, "distinct_id"]),
),
constraint_type="ON",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -938,7 +938,7 @@ def test_query_runner_with_data_warehouse_series_no_end_date_and_nested_id(self)
)

# Assert the expected join condition in the clickhouse SQL
expected_join_condition = f"and(equals(events.team_id, {query_runner.count_query_runner.team.id}), equals(event, %(hogql_val_12)s), greaterOrEquals(timestamp, assumeNotNull(parseDateTime64BestEffortOrNull(%(hogql_val_13)s, 6, %(hogql_val_14)s))), lessOrEquals(timestamp, assumeNotNull(parseDateTime64BestEffortOrNull(%(hogql_val_15)s, 6, %(hogql_val_16)s))))) AS e__events ON"
expected_join_condition = f"and(equals(events.team_id, {query_runner.count_query_runner.team.id}), equals(event, %(hogql_val_11)s), greaterOrEquals(timestamp, assumeNotNull(parseDateTime64BestEffortOrNull(%(hogql_val_12)s, 6, %(hogql_val_13)s))), lessOrEquals(timestamp, assumeNotNull(parseDateTime64BestEffortOrNull(%(hogql_val_14)s, 6, %(hogql_val_15)s))))) AS e__events ON"
self.assertIn(
expected_join_condition,
str(response.clickhouse),
Expand Down

0 comments on commit 006778d

Please sign in to comment.