Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore(data-warehouse): joins on views #21151

Merged
merged 20 commits into from
Apr 2, 2024
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions mypy-baseline.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,6 @@ posthog/hogql/database/argmax.py:0: error: Unsupported operand types for + ("lis
posthog/hogql/database/schema/numbers.py:0: error: Incompatible types in assignment (expression has type "dict[str, IntegerDatabaseField]", variable has type "dict[str, FieldOrTable]") [assignment]
posthog/hogql/database/schema/numbers.py:0: note: "Dict" is invariant -- see https://mypy.readthedocs.io/en/stable/common_issues.html#variance
posthog/hogql/database/schema/numbers.py:0: note: Consider using "Mapping" instead, which is covariant in the value type
posthog/hogql/ast.py:0: error: Argument "chain" to "FieldTraverserType" has incompatible type "list[str]"; expected "list[str | int]" [arg-type]
posthog/hogql/ast.py:0: note: "List" is invariant -- see https://mypy.readthedocs.io/en/stable/common_issues.html#variance
posthog/hogql/ast.py:0: note: Consider using "Sequence" instead, which is covariant
posthog/hogql/ast.py:0: error: Incompatible return value type (got "bool | None", expected "bool") [return-value]
posthog/hogql/visitor.py:0: error: Statement is unreachable [unreachable]
posthog/hogql/visitor.py:0: error: Argument 1 to "visit" of "Visitor" has incompatible type "Type | None"; expected "AST" [arg-type]
Expand Down
3 changes: 3 additions & 0 deletions posthog/api/test/__snapshots__/test_decide.ambr
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
# serializer version: 1
# name: TestDatabaseCheckForDecide.test_decide_doesnt_error_out_when_database_is_down_and_database_check_isnt_cached
'SELECT 1'
# ---
# name: TestDecide.test_decide_doesnt_error_out_when_database_is_down
'''
SELECT "posthog_user"."id",
Expand Down
29 changes: 24 additions & 5 deletions posthog/hogql/ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,9 @@ def get_child(self, name: str, context: HogQLContext) -> Type:
raise HogQLException(f"Field not found: {name}")


TableOrSelectType = Union[BaseTableType, "SelectUnionQueryType", "SelectQueryType", "SelectQueryAliasType"]


@dataclass(kw_only=True)
class TableType(BaseTableType):
table: Table
Expand All @@ -104,7 +107,7 @@ def resolve_database_table(self, context: HogQLContext) -> Table:

@dataclass(kw_only=True)
class LazyJoinType(BaseTableType):
table_type: BaseTableType
table_type: TableOrSelectType
field: str
lazy_join: LazyJoin

Expand All @@ -122,7 +125,7 @@ def resolve_database_table(self, context: HogQLContext) -> Table:

@dataclass(kw_only=True)
class VirtualTableType(BaseTableType):
table_type: BaseTableType
table_type: TableOrSelectType
field: str
virtual_table: VirtualTable

Expand All @@ -133,9 +136,6 @@ def has_child(self, name: str, context: HogQLContext) -> bool:
return self.virtual_table.has_field(name)


TableOrSelectType = Union[BaseTableType, "SelectUnionQueryType", "SelectQueryType", "SelectQueryAliasType"]


@dataclass(kw_only=True)
class SelectQueryType(Type):
"""Type and new enclosed scope for a select query. Contains information about all tables and columns in the query."""
Expand Down Expand Up @@ -187,12 +187,30 @@ def has_child(self, name: str, context: HogQLContext) -> bool:
class SelectQueryAliasType(Type):
alias: str
select_query_type: SelectQueryType | SelectUnionQueryType
view_name: Optional[str] = None

def get_child(self, name: str, context: HogQLContext) -> Type:
if name == "*":
return AsteriskType(table_type=self)
if self.select_query_type.has_child(name, context):
return FieldType(name=name, table_type=self)
if self.view_name:
if context.database is None:
raise HogQLException("Database must be set for queries with views")

field = context.database.get_table(self.view_name).get_field(name)

if isinstance(field, LazyJoin):
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

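TODO: this field-resolution logic duplicates what BaseTableType already does; it should be extracted into a shared helper so the two code paths cannot drift apart.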

return LazyJoinType(table_type=self, field=name, lazy_join=field)
if isinstance(field, LazyTable):
return LazyTableType(table=field)
if isinstance(field, FieldTraverser):
return FieldTraverserType(table_type=self, chain=field.chain)
if isinstance(field, VirtualTable):
return VirtualTableType(table_type=self, field=name, virtual_table=field)
if isinstance(field, ExpressionField):
return ExpressionFieldType(table_type=self, name=name, expr=field.expr)
return FieldType(name=name, table_type=self)
raise HogQLException(f"Field {name} not found on query with alias {self.alias}")

def has_child(self, name: str, context: HogQLContext) -> bool:
Expand Down Expand Up @@ -575,6 +593,7 @@ class SelectQuery(Expr):
limit_with_ties: Optional[bool] = None
offset: Optional[Expr] = None
settings: Optional[HogQLQuerySettings] = None
view_name: Optional[str] = None


@dataclass(kw_only=True)
Expand Down
4 changes: 2 additions & 2 deletions posthog/hogql/autocomplete.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,9 +216,9 @@ def resolve_table_field_traversers(table: Table, context: HogQLContext) -> Table
current_table_or_field: FieldOrTable = new_table
for chain in field.chain:
if isinstance(current_table_or_field, Table):
chain_field = current_table_or_field.fields.get(chain)
chain_field = current_table_or_field.fields.get(str(chain))
elif isinstance(current_table_or_field, LazyJoin):
chain_field = current_table_or_field.resolve_table(context).fields.get(chain)
chain_field = current_table_or_field.resolve_table(context).fields.get(str(chain))
elif isinstance(current_table_or_field, DatabaseField):
chain_field = current_table_or_field
else:
Expand Down
2 changes: 1 addition & 1 deletion posthog/hogql/database/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,7 @@ class _SerializedFieldBase(TypedDict):
class SerializedField(_SerializedFieldBase, total=False):
fields: List[str]
table: str
chain: List[str]
chain: List[str | int]


def serialize_database(context: HogQLContext) -> Dict[str, List[SerializedField]]:
Expand Down
9 changes: 5 additions & 4 deletions posthog/hogql/database/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,17 +65,18 @@ class ExpressionField(DatabaseField):
class FieldTraverser(FieldOrTable):
model_config = ConfigDict(extra="forbid")

chain: List[str]
chain: List[str | int]


class Table(FieldOrTable):
fields: Dict[str, FieldOrTable]
model_config = ConfigDict(extra="forbid")

def has_field(self, name: str) -> bool:
return name in self.fields
def has_field(self, name: str | int) -> bool:
return str(name) in self.fields

def get_field(self, name: str) -> FieldOrTable:
def get_field(self, name: str | int) -> FieldOrTable:
name = str(name)
if self.has_field(name):
return self.fields[name]
raise Exception(f'Field "{name}" not found on table {self.__class__.__name__}')
Expand Down
28 changes: 27 additions & 1 deletion posthog/hogql/database/test/test_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from posthog.models.organization import Organization
from posthog.models.team.team import Team
from posthog.test.base import BaseTest
from posthog.warehouse.models import DataWarehouseTable, DataWarehouseCredential
from posthog.warehouse.models import DataWarehouseTable, DataWarehouseCredential, DataWarehouseSavedQuery
from posthog.hogql.query import execute_hogql_query
from posthog.warehouse.models.join import DataWarehouseJoin

Expand Down Expand Up @@ -288,3 +288,29 @@ def test_database_warehouse_joins_persons_poe_v2(self):
assert poe.fields["some_field"] is not None

print_ast(parse_select("select person.some_field.key from events"), context, dialect="clickhouse")

def test_database_warehouse_joins_on_view(self):
    """Smoke test: a warehouse join declared on a saved query (view) resolves.

    Creates a saved query named ``event_view`` and a join from it to
    ``groups`` exposed as ``some_field``, then prints a query that traverses
    the join through an alias of the view. The test passes as long as
    printing does not raise; it does not assert on the generated SQL.
    """
    # The view that the join is attached to.
    DataWarehouseSavedQuery.objects.create(
        team=self.team,
        name="event_view",
        query={"query": "SELECT event AS event from events"},
        columns={"event": "String"},
    )
    # Join event_view.event -> groups.key, exposed on the view as `some_field`.
    DataWarehouseJoin.objects.create(
        team=self.team,
        source_table_name="event_view",
        source_table_key="event",
        joining_table_name="groups",
        joining_table_key="key",
        field_name="some_field",
    )

    team_pk = self.team.pk
    hogql_context = HogQLContext(
        team_id=team_pk,
        enable_select_queries=True,
        database=create_hogql_database(team_id=team_pk),
    )

    # Access the joined field through the view's alias.
    parsed_query = parse_select("select e.some_field.key from event_view as e")
    print_ast(parsed_query, hogql_context, dialect="clickhouse")
8 changes: 8 additions & 0 deletions posthog/hogql/resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,7 @@ def visit_select_query(self, node: ast.SelectQuery):
{name: self.visit(expr) for name, expr in node.window_exprs.items()} if node.window_exprs else None
)
new_node.settings = node.settings.model_copy() if node.settings is not None else None
new_node.view_name = node.view_name

self.scopes.pop()

Expand Down Expand Up @@ -276,6 +277,10 @@ def visit_join_expr(self, node: ast.JoinExpr):
raise ResolverException("Nested views are not supported")

node.table = parse_select(str(database_table.query))

if isinstance(node.table, ast.SelectQuery):
node.table.view_name = database_table.name

node.alias = table_alias or database_table.name
node = self.visit(node)

Expand Down Expand Up @@ -334,6 +339,9 @@ def visit_join_expr(self, node: ast.JoinExpr):
f'Already have joined a table called "{node.alias}". Can\'t join another one with the same name.'
)
node.type = ast.SelectQueryAliasType(alias=node.alias, select_query_type=node.table.type)
if isinstance(node.table, ast.SelectQuery):
node.type.view_name = node.table.view_name

scope.tables[node.alias] = node.type
else:
node.type = node.table.type
Expand Down
1 change: 1 addition & 0 deletions posthog/hogql/visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,6 +485,7 @@ def visit_select_query(self, node: ast.SelectQuery):
if node.window_exprs
else None,
settings=node.settings.model_copy() if node.settings is not None else None,
view_name=node.view_name,
)

def visit_select_union_query(self, node: ast.SelectUnionQuery):
Expand Down
Loading