diff --git a/posthog/hogql/database/database.py b/posthog/hogql/database/database.py
index e78df97c56823..bbf8e8d0cb772 100644
--- a/posthog/hogql/database/database.py
+++ b/posthog/hogql/database/database.py
@@ -1,6 +1,7 @@
 from typing import TYPE_CHECKING, Any, ClassVar, Dict, List, Literal, Optional, TypedDict
 from zoneinfo import ZoneInfo, ZoneInfoNotFoundError
 from pydantic import ConfigDict, BaseModel
+from sentry_sdk import capture_exception
 from posthog.hogql import ast
 from posthog.hogql.context import HogQLContext
 from posthog.hogql.database.models import (
@@ -204,50 +205,70 @@ def create_hogql_database(
     database.add_warehouse_tables(**tables)
 
     for join in DataWarehouseJoin.objects.filter(team_id=team.pk).exclude(deleted=True):
-        source_table = database.get_table(join.source_table_name)
-        joining_table = database.get_table(join.joining_table_name)
-
-        field = parse_expr(join.source_table_key)
-        if not isinstance(field, ast.Field):
-            raise HogQLException("Data Warehouse Join HogQL expression should be a Field node")
-        from_field = field.chain
-
-        field = parse_expr(join.joining_table_key)
-        if not isinstance(field, ast.Field):
-            raise HogQLException("Data Warehouse Join HogQL expression should be a Field node")
-        to_field = field.chain
-
-        source_table.fields[join.field_name] = LazyJoin(
-            from_field=from_field,
-            to_field=to_field,
-            join_table=joining_table,
-            join_function=join.join_function,
-        )
+        try:
+            source_table = database.get_table(join.source_table_name)
+            joining_table = database.get_table(join.joining_table_name)
+
+            field = parse_expr(join.source_table_key)
+            if not isinstance(field, ast.Field):
+                raise HogQLException("Data Warehouse Join HogQL expression should be a Field node")
+            from_field = field.chain
+
+            field = parse_expr(join.joining_table_key)
+            if not isinstance(field, ast.Field):
+                raise HogQLException("Data Warehouse Join HogQL expression should be a Field node")
+            to_field = field.chain
+
+            source_table.fields[join.field_name] = LazyJoin(
+                from_field=from_field,
+                to_field=to_field,
+                join_table=joining_table,
+                join_function=join.join_function,
+            )
 
-        if join.source_table_name == "persons":
-            person_field = database.events.fields["person"]
-            if isinstance(person_field, ast.FieldTraverser):
-                table_or_field: ast.FieldOrTable = database.events
-                for chain in person_field.chain:
-                    if isinstance(table_or_field, ast.LazyJoin):
-                        table_or_field = table_or_field.resolve_table(HogQLContext(team_id=team_id, database=database))
-                        if table_or_field.has_field(chain):
+            if join.source_table_name == "persons":
+                person_field = database.events.fields["person"]
+                if isinstance(person_field, ast.FieldTraverser):
+                    table_or_field: ast.FieldOrTable = database.events
+                    for chain in person_field.chain:
+                        if isinstance(table_or_field, ast.LazyJoin):
+                            table_or_field = table_or_field.resolve_table(
+                                HogQLContext(team_id=team_id, database=database)
+                            )
+                            if table_or_field.has_field(chain):
+                                table_or_field = table_or_field.get_field(chain)
+                                if isinstance(table_or_field, ast.LazyJoin):
+                                    table_or_field = table_or_field.resolve_table(
+                                        HogQLContext(team_id=team_id, database=database)
+                                    )
+                        elif isinstance(table_or_field, ast.Table):
                             table_or_field = table_or_field.get_field(chain)
-                            if isinstance(table_or_field, ast.LazyJoin):
-                                table_or_field = table_or_field.resolve_table(
-                                    HogQLContext(team_id=team_id, database=database)
-                                )
-                    elif isinstance(table_or_field, ast.Table):
-                        table_or_field = table_or_field.get_field(chain)
-
-                assert isinstance(table_or_field, ast.Table)
-
-                table_or_field.fields[join.field_name] = LazyJoin(
-                    from_field=from_field,
-                    to_field=to_field,
-                    join_table=joining_table,
-                    join_function=join.join_function,
-                )
+
+                    assert isinstance(table_or_field, ast.Table)
+
+                    if isinstance(table_or_field, ast.VirtualTable):
+                        # Virtual tables get a traverser with chain ["..", field_name]; the LazyJoin itself is
+                        # registered on the events table.
+                        # NOTE(review): from_field was parsed from the persons join key - confirm it resolves as
+                        # intended when the join is evaluated from events.
+                        table_or_field.fields[join.field_name] = ast.FieldTraverser(chain=["..", join.field_name])
+                        database.events.fields[join.field_name] = LazyJoin(
+                            from_field=from_field,
+                            to_field=to_field,
+                            join_table=joining_table,
+                            join_function=join.join_function,
+                        )
+                    else:
+                        table_or_field.fields[join.field_name] = LazyJoin(
+                            from_field=from_field,
+                            to_field=to_field,
+                            join_table=joining_table,
+                            join_function=join.join_function,
+                        )
+        except Exception as e:
+            # Best effort: report the broken join to Sentry and skip it, so that one bad join
+            # definition does not prevent the rest of the database from being built.
+            capture_exception(e)
 
     return database
 
diff --git a/posthog/hogql/database/test/test_database.py b/posthog/hogql/database/test/test_database.py
index ec1ade4231a04..da17e15c03107 100644
--- a/posthog/hogql/database/test/test_database.py
+++ b/posthog/hogql/database/test/test_database.py
@@ -1,5 +1,5 @@
 import json
-from typing import Any
+from typing import Any, cast
 from unittest.mock import patch
 
 import pytest
@@ -7,15 +7,19 @@
 from parameterized import parameterized
 
 from posthog.hogql.database.database import create_hogql_database, serialize_database
-from posthog.hogql.database.models import FieldTraverser, StringDatabaseField, ExpressionField
+from posthog.hogql.database.models import FieldTraverser, LazyJoin, StringDatabaseField, ExpressionField, Table
+from posthog.hogql.errors import HogQLException
 from posthog.hogql.modifiers import create_default_modifiers_for_team
 from posthog.hogql.parser import parse_expr, parse_select
 from posthog.hogql.printer import print_ast
 from posthog.hogql.context import HogQLContext
 from posthog.models.group_type_mapping import GroupTypeMapping
+from posthog.models.organization import Organization
+from posthog.models.team.team import Team
 from posthog.test.base import BaseTest
 from posthog.warehouse.models import DataWarehouseTable, DataWarehouseCredential
 from posthog.hogql.query import execute_hogql_query
+from posthog.warehouse.models.join import DataWarehouseJoin
 
 
 class TestDatabase(BaseTest):
@@ -132,3 +136,164 @@ def test_database_expression_fields(self):
             query
             == "SELECT number AS number FROM (SELECT numbers.number AS number FROM numbers(2) AS numbers) LIMIT 10000"
         ), query
+
+    def test_database_warehouse_joins(self):
+        # An active join on events must be usable as a field when printing a HogQL query.
+        DataWarehouseJoin.objects.create(
+            team=self.team,
+            source_table_name="events",
+            source_table_key="event",
+            joining_table_name="groups",
+            joining_table_key="key",
+            field_name="some_field",
+        )
+
+        db = create_hogql_database(team_id=self.team.pk)
+        context = HogQLContext(
+            team_id=self.team.pk,
+            enable_select_queries=True,
+            database=db,
+        )
+
+        sql = "select some_field.key from events"
+        print_ast(parse_select(sql), context, dialect="clickhouse")
+
+    def test_database_warehouse_joins_deleted_join(self):
+        # The keys are valid Field expressions on purpose: a non-Field key is rejected and swallowed by
+        # create_hogql_database, which would make this test pass even if deleted joins were not excluded.
+        DataWarehouseJoin.objects.create(
+            team=self.team,
+            source_table_name="events",
+            source_table_key="event",
+            joining_table_name="groups",
+            joining_table_key="key",
+            field_name="some_field",
+            deleted=True,
+        )
+
+        db = create_hogql_database(team_id=self.team.pk)
+        context = HogQLContext(
+            team_id=self.team.pk,
+            enable_select_queries=True,
+            database=db,
+        )
+
+        sql = "select some_field.key from events"
+        with pytest.raises(HogQLException):
+            print_ast(parse_select(sql), context, dialect="clickhouse")
+
+    def test_database_warehouse_joins_other_team(self):
+        # Valid keys are used so that only the team filter can keep this join out of self.team's database.
+        other_organization = Organization.objects.create(name="some_other_org")
+        other_team = Team.objects.create(organization=other_organization)
+
+        DataWarehouseJoin.objects.create(
+            team=other_team,
+            source_table_name="events",
+            source_table_key="event",
+            joining_table_name="groups",
+            joining_table_key="key",
+            field_name="some_field",
+        )
+
+        db = create_hogql_database(team_id=self.team.pk)
+        context = HogQLContext(
+            team_id=self.team.pk,
+            enable_select_queries=True,
+            database=db,
+        )
+
+        sql = "select some_field.key from events"
+        with pytest.raises(HogQLException):
+            print_ast(parse_select(sql), context, dialect="clickhouse")
+
+    def test_database_warehouse_joins_bad_key_expression(self):
+        # A join key that is not a plain Field must not break database creation; the join is skipped.
+        DataWarehouseJoin.objects.create(
+            team=self.team,
+            source_table_name="events",
+            source_table_key="blah_de_blah(event)",
+            joining_table_name="groups",
+            joining_table_key="upper(key)",
+            field_name="some_field",
+        )
+
+        db = create_hogql_database(team_id=self.team.pk)
+        assert "some_field" not in db.events.fields
+
+    @override_settings(PERSON_ON_EVENTS_OVERRIDE=False, PERSON_ON_EVENTS_V2_OVERRIDE=False)
+    def test_database_warehouse_joins_persons_no_poe(self):
+        # With both person-on-events overrides off, the join must be present on the table behind pdi.person.
+        DataWarehouseJoin.objects.create(
+            team=self.team,
+            source_table_name="persons",
+            source_table_key="properties.email",
+            joining_table_name="groups",
+            joining_table_key="key",
+            field_name="some_field",
+        )
+
+        db = create_hogql_database(team_id=self.team.pk)
+        context = HogQLContext(
+            team_id=self.team.pk,
+            enable_select_queries=True,
+            database=db,
+        )
+
+        pdi = cast(LazyJoin, db.events.fields["pdi"])
+        pdi_persons_join = cast(LazyJoin, pdi.resolve_table(context).fields["person"])
+        pdi_table = pdi_persons_join.resolve_table(context)
+
+        assert pdi_table.fields["some_field"] is not None
+
+        print_ast(parse_select("select person.some_field.key from events"), context, dialect="clickhouse")
+
+    @override_settings(PERSON_ON_EVENTS_OVERRIDE=True, PERSON_ON_EVENTS_V2_OVERRIDE=False)
+    def test_database_warehouse_joins_persons_poe_v1(self):
+        # With the person-on-events (v1) override on, the join must be exposed on the events.poe table.
+        DataWarehouseJoin.objects.create(
+            team=self.team,
+            source_table_name="persons",
+            source_table_key="properties.email",
+            joining_table_name="groups",
+            joining_table_key="key",
+            field_name="some_field",
+        )
+
+        db = create_hogql_database(team_id=self.team.pk)
+        context = HogQLContext(
+            team_id=self.team.pk,
+            enable_select_queries=True,
+            database=db,
+        )
+
+        poe = cast(Table, db.events.fields["poe"])
+
+        assert poe.fields["some_field"] is not None
+
+        print_ast(parse_select("select person.some_field.key from events"), context, dialect="clickhouse")
+
+    @override_settings(PERSON_ON_EVENTS_OVERRIDE=False, PERSON_ON_EVENTS_V2_OVERRIDE=True)
+    def test_database_warehouse_joins_persons_poe_v2(self):
+        # With the person-on-events v2 override on, the join must be exposed on the events.poe table.
+        DataWarehouseJoin.objects.create(
+            team=self.team,
+            source_table_name="persons",
+            source_table_key="properties.email",
+            joining_table_name="groups",
+            joining_table_key="key",
+            field_name="some_field",
+        )
+
+        db = create_hogql_database(team_id=self.team.pk)
+        context = HogQLContext(
+            team_id=self.team.pk,
+            enable_select_queries=True,
+            database=db,
+        )
+
+        poe = cast(Table, db.events.fields["poe"])
+
+        assert poe.fields["some_field"] is not None
+
+        print_ast(parse_select("select person.some_field.key from events"), context, dialect="clickhouse")