Skip to content

Commit

Permalink
chore(data-warehouse): Added some safe guarding around joins and adde…
Browse files Browse the repository at this point in the history
…d tests (#20899)

Added some safe guarding around joins and added tests
  • Loading branch information
Gilbert09 authored Mar 13, 2024
1 parent 428c480 commit 6fda11e
Show file tree
Hide file tree
Showing 2 changed files with 215 additions and 44 deletions.
99 changes: 57 additions & 42 deletions posthog/hogql/database/database.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import TYPE_CHECKING, Any, ClassVar, Dict, List, Literal, Optional, TypedDict
from zoneinfo import ZoneInfo, ZoneInfoNotFoundError
from pydantic import ConfigDict, BaseModel
from sentry_sdk import capture_exception
from posthog.hogql import ast
from posthog.hogql.context import HogQLContext
from posthog.hogql.database.models import (
Expand Down Expand Up @@ -204,50 +205,64 @@ def create_hogql_database(
database.add_warehouse_tables(**tables)

for join in DataWarehouseJoin.objects.filter(team_id=team.pk).exclude(deleted=True):
source_table = database.get_table(join.source_table_name)
joining_table = database.get_table(join.joining_table_name)

field = parse_expr(join.source_table_key)
if not isinstance(field, ast.Field):
raise HogQLException("Data Warehouse Join HogQL expression should be a Field node")
from_field = field.chain

field = parse_expr(join.joining_table_key)
if not isinstance(field, ast.Field):
raise HogQLException("Data Warehouse Join HogQL expression should be a Field node")
to_field = field.chain

source_table.fields[join.field_name] = LazyJoin(
from_field=from_field,
to_field=to_field,
join_table=joining_table,
join_function=join.join_function,
)
try:
source_table = database.get_table(join.source_table_name)
joining_table = database.get_table(join.joining_table_name)

field = parse_expr(join.source_table_key)
if not isinstance(field, ast.Field):
raise HogQLException("Data Warehouse Join HogQL expression should be a Field node")
from_field = field.chain

field = parse_expr(join.joining_table_key)
if not isinstance(field, ast.Field):
raise HogQLException("Data Warehouse Join HogQL expression should be a Field node")
to_field = field.chain

source_table.fields[join.field_name] = LazyJoin(
from_field=from_field,
to_field=to_field,
join_table=joining_table,
join_function=join.join_function,
)

if join.source_table_name == "persons":
person_field = database.events.fields["person"]
if isinstance(person_field, ast.FieldTraverser):
table_or_field: ast.FieldOrTable = database.events
for chain in person_field.chain:
if isinstance(table_or_field, ast.LazyJoin):
table_or_field = table_or_field.resolve_table(HogQLContext(team_id=team_id, database=database))
if table_or_field.has_field(chain):
if join.source_table_name == "persons":
person_field = database.events.fields["person"]
if isinstance(person_field, ast.FieldTraverser):
table_or_field: ast.FieldOrTable = database.events
for chain in person_field.chain:
if isinstance(table_or_field, ast.LazyJoin):
table_or_field = table_or_field.resolve_table(
HogQLContext(team_id=team_id, database=database)
)
if table_or_field.has_field(chain):
table_or_field = table_or_field.get_field(chain)
if isinstance(table_or_field, ast.LazyJoin):
table_or_field = table_or_field.resolve_table(
HogQLContext(team_id=team_id, database=database)
)
elif isinstance(table_or_field, ast.Table):
table_or_field = table_or_field.get_field(chain)
if isinstance(table_or_field, ast.LazyJoin):
table_or_field = table_or_field.resolve_table(
HogQLContext(team_id=team_id, database=database)
)
elif isinstance(table_or_field, ast.Table):
table_or_field = table_or_field.get_field(chain)

assert isinstance(table_or_field, ast.Table)

table_or_field.fields[join.field_name] = LazyJoin(
from_field=from_field,
to_field=to_field,
join_table=joining_table,
join_function=join.join_function,
)

assert isinstance(table_or_field, ast.Table)

if isinstance(table_or_field, ast.VirtualTable):
table_or_field.fields[join.field_name] = ast.FieldTraverser(chain=["..", join.field_name])
database.events.fields[join.field_name] = LazyJoin(
from_field=from_field,
to_field=to_field,
join_table=joining_table,
join_function=join.join_function,
)
else:
table_or_field.fields[join.field_name] = LazyJoin(
from_field=from_field,
to_field=to_field,
join_table=joining_table,
join_function=join.join_function,
)
except Exception as e:
capture_exception(e)

return database

Expand Down
160 changes: 158 additions & 2 deletions posthog/hogql/database/test/test_database.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,25 @@
import json
from typing import Any
from typing import Any, cast

from unittest.mock import patch
import pytest
from django.test import override_settings
from parameterized import parameterized

from posthog.hogql.database.database import create_hogql_database, serialize_database
from posthog.hogql.database.models import FieldTraverser, StringDatabaseField, ExpressionField
from posthog.hogql.database.models import FieldTraverser, LazyJoin, StringDatabaseField, ExpressionField, Table
from posthog.hogql.errors import HogQLException
from posthog.hogql.modifiers import create_default_modifiers_for_team
from posthog.hogql.parser import parse_expr, parse_select
from posthog.hogql.printer import print_ast
from posthog.hogql.context import HogQLContext
from posthog.models.group_type_mapping import GroupTypeMapping
from posthog.models.organization import Organization
from posthog.models.team.team import Team
from posthog.test.base import BaseTest
from posthog.warehouse.models import DataWarehouseTable, DataWarehouseCredential
from posthog.hogql.query import execute_hogql_query
from posthog.warehouse.models.join import DataWarehouseJoin


class TestDatabase(BaseTest):
Expand Down Expand Up @@ -132,3 +136,155 @@ def test_database_expression_fields(self):
query
== "SELECT number AS number FROM (SELECT numbers.number AS number FROM numbers(2) AS numbers) LIMIT 10000"
), query

def test_database_warehouse_joins(self):
DataWarehouseJoin.objects.create(
team=self.team,
source_table_name="events",
source_table_key="event",
joining_table_name="groups",
joining_table_key="key",
field_name="some_field",
)

db = create_hogql_database(team_id=self.team.pk)
context = HogQLContext(
team_id=self.team.pk,
enable_select_queries=True,
database=db,
)

sql = "select some_field.key from events"
print_ast(parse_select(sql), context, dialect="clickhouse")

def test_database_warehouse_joins_deleted_join(self):
DataWarehouseJoin.objects.create(
team=self.team,
source_table_name="events",
source_table_key="lower(event)",
joining_table_name="groups",
joining_table_key="upper(key)",
field_name="some_field",
deleted=True,
)

db = create_hogql_database(team_id=self.team.pk)
context = HogQLContext(
team_id=self.team.pk,
enable_select_queries=True,
database=db,
)

sql = "select some_field.key from events"
with pytest.raises(HogQLException):
print_ast(parse_select(sql), context, dialect="clickhouse")

def test_database_warehouse_joins_other_team(self):
other_organization = Organization.objects.create(name="some_other_org")
other_team = Team.objects.create(organization=other_organization)

DataWarehouseJoin.objects.create(
team=other_team,
source_table_name="events",
source_table_key="lower(event)",
joining_table_name="groups",
joining_table_key="upper(key)",
field_name="some_field",
)

db = create_hogql_database(team_id=self.team.pk)
context = HogQLContext(
team_id=self.team.pk,
enable_select_queries=True,
database=db,
)

sql = "select some_field.key from events"
with pytest.raises(HogQLException):
print_ast(parse_select(sql), context, dialect="clickhouse")

def test_database_warehouse_joins_bad_key_expression(self):
DataWarehouseJoin.objects.create(
team=self.team,
source_table_name="events",
source_table_key="blah_de_blah(event)",
joining_table_name="groups",
joining_table_key="upper(key)",
field_name="some_field",
)

create_hogql_database(team_id=self.team.pk)

@override_settings(PERSON_ON_EVENTS_OVERRIDE=False, PERSON_ON_EVENTS_V2_OVERRIDE=False)
def test_database_warehouse_joins_persons_no_poe(self):
DataWarehouseJoin.objects.create(
team=self.team,
source_table_name="persons",
source_table_key="properties.email",
joining_table_name="groups",
joining_table_key="key",
field_name="some_field",
)

db = create_hogql_database(team_id=self.team.pk)
context = HogQLContext(
team_id=self.team.pk,
enable_select_queries=True,
database=db,
)

pdi = cast(LazyJoin, db.events.fields["pdi"])
pdi_persons_join = cast(LazyJoin, pdi.resolve_table(context).fields["person"])
pdi_table = pdi_persons_join.resolve_table(context)

assert pdi_table.fields["some_field"] is not None

print_ast(parse_select("select person.some_field.key from events"), context, dialect="clickhouse")

@override_settings(PERSON_ON_EVENTS_OVERRIDE=True, PERSON_ON_EVENTS_V2_OVERRIDE=False)
def test_database_warehouse_joins_persons_poe_v1(self):
DataWarehouseJoin.objects.create(
team=self.team,
source_table_name="persons",
source_table_key="properties.email",
joining_table_name="groups",
joining_table_key="key",
field_name="some_field",
)

db = create_hogql_database(team_id=self.team.pk)
context = HogQLContext(
team_id=self.team.pk,
enable_select_queries=True,
database=db,
)

poe = cast(Table, db.events.fields["poe"])

assert poe.fields["some_field"] is not None

print_ast(parse_select("select person.some_field.key from events"), context, dialect="clickhouse")

@override_settings(PERSON_ON_EVENTS_OVERRIDE=False, PERSON_ON_EVENTS_V2_OVERRIDE=True)
def test_database_warehouse_joins_persons_poe_v2(self):
DataWarehouseJoin.objects.create(
team=self.team,
source_table_name="persons",
source_table_key="properties.email",
joining_table_name="groups",
joining_table_key="key",
field_name="some_field",
)

db = create_hogql_database(team_id=self.team.pk)
context = HogQLContext(
team_id=self.team.pk,
enable_select_queries=True,
database=db,
)

poe = cast(Table, db.events.fields["poe"])

assert poe.fields["some_field"] is not None

print_ast(parse_select("select person.some_field.key from events"), context, dialect="clickhouse")

0 comments on commit 6fda11e

Please sign in to comment.