diff --git a/posthog/hogql/database/database.py b/posthog/hogql/database/database.py index 2abd7339aeafb..69ebf975fe312 100644 --- a/posthog/hogql/database/database.py +++ b/posthog/hogql/database/database.py @@ -53,7 +53,7 @@ from posthog.hogql.database.schema.person_distinct_ids import ( PersonDistinctIdsTable, RawPersonDistinctIdsTable, - join_with_person_distinct_ids_table, + join_data_warehouse_experiment_table_with_person_distinct_ids_table, ) from posthog.hogql.database.schema.persons import ( PersonsTable, @@ -462,8 +462,7 @@ def define_mappings(warehouse: dict[str, Table], get_table: Callable): source_table.fields["pdi"] = LazyJoin( from_field=from_field, join_table=PersonDistinctIdsTable(), - join_function=join_with_person_distinct_ids_table, - override_source_table_key=join.source_table_key, + join_function=join_data_warehouse_experiment_table_with_person_distinct_ids_table, ) source_table.fields["person"] = FieldTraverser(chain=["pdi", "person"]) diff --git a/posthog/hogql/database/schema/person_distinct_ids.py b/posthog/hogql/database/schema/person_distinct_ids.py index 1db99914eadb1..2fecbae4f5960 100644 --- a/posthog/hogql/database/schema/person_distinct_ids.py +++ b/posthog/hogql/database/schema/person_distinct_ids.py @@ -1,4 +1,3 @@ -from typing import Optional from posthog.hogql.ast import SelectQuery from posthog.hogql.constants import HogQLQuerySettings from posthog.hogql.context import HogQLContext @@ -49,20 +48,41 @@ def join_with_person_distinct_ids_table( join_to_add: LazyJoinToAdd, context: HogQLContext, node: SelectQuery, - override_source_table_key: Optional[str] = None, ): from posthog.hogql import ast if not join_to_add.fields_accessed: raise ResolutionError("No fields requested from person_distinct_ids") - source_table_key = override_source_table_key or "distinct_id" join_expr = ast.JoinExpr(table=select_from_person_distinct_ids_table(join_to_add.fields_accessed)) join_expr.join_type = "INNER JOIN" join_expr.alias = join_to_add.to_table join_expr.constraint = ast.JoinConstraint( expr=ast.CompareOperation( op=ast.CompareOperationOp.Eq, - left=ast.Field(chain=[join_to_add.from_table, source_table_key]), + left=ast.Field(chain=[join_to_add.from_table, "distinct_id"]), + right=ast.Field(chain=[join_to_add.to_table, "distinct_id"]), + ), + constraint_type="ON", + ) + return join_expr + + +def join_data_warehouse_experiment_table_with_person_distinct_ids_table( + join_to_add: LazyJoinToAdd, + context: HogQLContext, + node: SelectQuery, +): + from posthog.hogql import ast + + if not join_to_add.fields_accessed: + raise ResolutionError("No fields requested from person_distinct_ids") + join_expr = ast.JoinExpr(table=select_from_person_distinct_ids_table(join_to_add.fields_accessed)) + join_expr.join_type = "LEFT JOIN" + join_expr.alias = join_to_add.to_table + join_expr.constraint = ast.JoinConstraint( + expr=ast.CompareOperation( + op=ast.CompareOperationOp.Eq, + left=ast.Field(chain=[join_to_add.from_table, *join_to_add.lazy_join.from_field]), right=ast.Field(chain=[join_to_add.to_table, "distinct_id"]), ), constraint_type="ON", diff --git a/posthog/hogql_queries/experiments/test/test_experiment_trends_query_runner.py b/posthog/hogql_queries/experiments/test/test_experiment_trends_query_runner.py index 00c579ecbd89c..9198c2abaef99 100644 --- a/posthog/hogql_queries/experiments/test/test_experiment_trends_query_runner.py +++ b/posthog/hogql_queries/experiments/test/test_experiment_trends_query_runner.py @@ -938,7 +938,7 @@ def test_query_runner_with_data_warehouse_series_no_end_date_and_nested_id(self) ) # Assert the expected join condition in the clickhouse SQL - expected_join_condition = f"and(equals(events.team_id, {query_runner.count_query_runner.team.id}), equals(event, %(hogql_val_12)s), greaterOrEquals(timestamp, assumeNotNull(parseDateTime64BestEffortOrNull(%(hogql_val_13)s, 6, %(hogql_val_14)s))), lessOrEquals(timestamp, assumeNotNull(parseDateTime64BestEffortOrNull(%(hogql_val_15)s, 6, %(hogql_val_16)s))))) AS e__events ON" + expected_join_condition = f"and(equals(events.team_id, {query_runner.count_query_runner.team.id}), equals(event, %(hogql_val_11)s), greaterOrEquals(timestamp, assumeNotNull(parseDateTime64BestEffortOrNull(%(hogql_val_12)s, 6, %(hogql_val_13)s))), lessOrEquals(timestamp, assumeNotNull(parseDateTime64BestEffortOrNull(%(hogql_val_14)s, 6, %(hogql_val_15)s))))) AS e__events ON" self.assertIn( expected_join_condition, str(response.clickhouse),