diff --git a/posthog/hogql/ast.py b/posthog/hogql/ast.py index 806226b8f1b9e..2f72678747ee8 100644 --- a/posthog/hogql/ast.py +++ b/posthog/hogql/ast.py @@ -187,10 +187,24 @@ def has_child(self, name: str, context: HogQLContext) -> bool: class SelectQueryAliasType(Type): alias: str select_query_type: SelectQueryType | SelectUnionQueryType + view_name: Optional[str] = None def get_child(self, name: str, context: HogQLContext) -> Type: if name == "*": return AsteriskType(table_type=self) + if self.view_name: + field = context.database.get_table(self.view_name).get_field(name) + if isinstance(field, LazyJoin): + return LazyJoinType(table_type=self, field=name, lazy_join=field) + if isinstance(field, LazyTable): + return LazyTableType(table=field) + if isinstance(field, FieldTraverser): + return FieldTraverserType(table_type=self, chain=field.chain) + if isinstance(field, VirtualTable): + return VirtualTableType(table_type=self, field=name, virtual_table=field) + if isinstance(field, ExpressionField): + return ExpressionFieldType(table_type=self, name=name, expr=field.expr) + return FieldType(name=name, table_type=self) if self.select_query_type.has_child(name, context): return FieldType(name=name, table_type=self) raise HogQLException(f"Field {name} not found on query with alias {self.alias}") @@ -575,6 +589,7 @@ class SelectQuery(Expr): limit_with_ties: Optional[bool] = None offset: Optional[Expr] = None settings: Optional[HogQLQuerySettings] = None + view_name: Optional[str] = None @dataclass(kw_only=True) diff --git a/posthog/hogql/resolver.py b/posthog/hogql/resolver.py index 19bbddb9f65b7..feff5ed4aaaad 100644 --- a/posthog/hogql/resolver.py +++ b/posthog/hogql/resolver.py @@ -207,6 +207,7 @@ def visit_select_query(self, node: ast.SelectQuery): {name: self.visit(expr) for name, expr in node.window_exprs.items()} if node.window_exprs else None ) new_node.settings = node.settings.model_copy() if node.settings is not None else None + new_node.view_name = node.view_name self.scopes.pop() @@ -276,6 +277,10 @@ def visit_join_expr(self, node: ast.JoinExpr): raise ResolverException("Nested views are not supported") node.table = parse_select(str(database_table.query)) + + if isinstance(node.table, ast.SelectQuery): + node.table.view_name = database_table.name + node.alias = table_alias or database_table.name node = self.visit(node) @@ -334,6 +339,9 @@ def visit_join_expr(self, node: ast.JoinExpr): f'Already have joined a table called "{node.alias}". Can\'t join another one with the same name.' ) node.type = ast.SelectQueryAliasType(alias=node.alias, select_query_type=node.table.type) + if isinstance(node.table, ast.SelectQuery): + node.type.view_name = node.table.view_name + scope.tables[node.alias] = node.type else: node.type = node.table.type diff --git a/posthog/hogql/views.py b/posthog/hogql/views.py new file mode 100644 index 0000000000000..944c32e80e991 --- /dev/null +++ b/posthog/hogql/views.py @@ -0,0 +1,58 @@ +from posthog.hogql import ast +from typing import List, Optional, Literal +from posthog.hogql.context import HogQLContext + +from posthog.hogql.database.models import ( + SavedQuery, +) + +from posthog.hogql.visitor import CloningVisitor +from posthog.hogql.parser import parse_select +from posthog.hogql.transforms.property_types import resolve_property_types + + +def resolve_views( + node: ast.Expr, + context: HogQLContext, + dialect: Literal["hogql", "clickhouse"], + scopes: Optional[List[ast.SelectQueryType]] = None, +) -> ast.Expr: + return ViewResolver(scopes=scopes, context=context, dialect=dialect).visit(node) + + +class ViewResolver(CloningVisitor): + """The ViewResolver only visits an AST and resolves all views""" + + def __init__( + self, + context: HogQLContext, + dialect: Literal["hogql", "clickhouse"] = "clickhouse", + scopes: Optional[List[ast.SelectQueryType]] = None, + ): + super().__init__() + # Each SELECT query creates a new scope (type). Store all of them in a list as we traverse the tree. + self.scopes: List[ast.SelectQueryType] = scopes or [] + self.current_view_depth: int = 0 + self.context = context + self.dialect = dialect + self.database = context.database + + def visit_join_expr(self, node: ast.JoinExpr): + from posthog.hogql.resolver import resolve_types + + if ( + isinstance(node.type, ast.TableAliasType) + and isinstance(node.type.table_type, ast.TableType) + and isinstance(node.type.table_type.table, SavedQuery) + ): + resolved_table = parse_select(str(node.type.table_type.table.query)) + resolved_table = resolve_types(resolved_table, self.context, self.dialect) + resolved_table = resolve_property_types(resolved_table, self.context) + + node.type = ast.SelectQueryAliasType( + select_query_type=resolved_table.type, + alias=node.alias, + ) + node.table = resolved_table + + return node diff --git a/posthog/hogql/visitor.py b/posthog/hogql/visitor.py index 2bf968abf2ab0..21eadda584505 100644 --- a/posthog/hogql/visitor.py +++ b/posthog/hogql/visitor.py @@ -485,6 +485,7 @@ def visit_select_query(self, node: ast.SelectQuery): if node.window_exprs else None, settings=node.settings.model_copy() if node.settings is not None else None, + view_name=node.view_name, ) def visit_select_union_query(self, node: ast.SelectUnionQuery):