Skip to content

Commit

Permalink
parser tests with snapshots
Browse files Browse the repository at this point in the history
  • Loading branch information
mariusandra committed Nov 21, 2023
1 parent 5249cae commit 8577afc
Show file tree
Hide file tree
Showing 4 changed files with 109 additions and 968 deletions.
7 changes: 5 additions & 2 deletions posthog/hogql/ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,11 @@ def has_child(self, name: str) -> bool:
def resolve_constant_type(self):
return self.type.resolve_constant_type()

def resolve_database_field(self):
if isinstance(self.type, FieldType):
return self.type.resolve_database_field()
raise NotImplementedException("FieldAliasType.resolve_database_field not implemented")


@dataclass(kw_only=True)
class BaseTableType(Type):
Expand Down Expand Up @@ -126,8 +131,6 @@ class SelectQueryType(Type):
aliases: Dict[str, FieldAliasType] = field(default_factory=dict)
# all types a select query exports
columns: Dict[str, Type] = field(default_factory=dict)
# these column have an explicit alias and can't be overridden
columns_with_explicit_alias: Dict[str, bool] = field(default_factory=dict)
# all from and join, tables and subqueries with aliases
tables: Dict[str, TableOrSelectType] = field(default_factory=dict)
ctes: Dict[str, CTE] = field(default_factory=dict)
Expand Down
30 changes: 13 additions & 17 deletions posthog/hogql/resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,15 +118,17 @@ def visit_select_query(self, node: ast.SelectQuery):
new_node.array_join_list = [self.visit(expr) for expr in node.array_join_list]

# Visit all the "SELECT a,b,c" columns. Mark each for export in "columns".
select_nodes = []
for expr in node.select or []:
new_expr = self.visit(expr)

# if it's an asterisk, carry on in a subroutine
if isinstance(new_expr.type, ast.AsteriskType):
self._expand_asterisk_columns(new_node, new_expr.type)
continue
columns = self._asterisk_columns(new_expr.type)
select_nodes.extend([self.visit(expr) for expr in columns])
else:
select_nodes.append(new_expr)

# not an asterisk
columns_with_explicit_alias = {}
for new_expr in select_nodes:
if isinstance(new_expr.type, ast.FieldAliasType):
alias = new_expr.type.alias
elif isinstance(new_expr.type, ast.FieldType):
Expand All @@ -139,12 +141,12 @@ def visit_select_query(self, node: ast.SelectQuery):
if alias:
# Remember the first visible or last hidden expr for each alias
if isinstance(new_expr, ast.Alias) and new_expr.hidden:
if alias not in node_type.columns or not node_type.columns_with_explicit_alias.get(alias, False):
if alias not in node_type.columns or not columns_with_explicit_alias.get(alias, False):
node_type.columns[alias] = new_expr.type
node_type.columns_with_explicit_alias[alias] = False
columns_with_explicit_alias[alias] = False
else:
node_type.columns[alias] = new_expr.type
node_type.columns_with_explicit_alias[alias] = True
columns_with_explicit_alias[alias] = True

# add the column to the new select query
new_node.select.append(new_expr)
Expand Down Expand Up @@ -172,15 +174,12 @@ def visit_select_query(self, node: ast.SelectQuery):

return new_node

def _expand_asterisk_columns(self, select_query: ast.SelectQuery, asterisk: ast.AsteriskType):
def _asterisk_columns(self, asterisk: ast.AsteriskType) -> List[ast.Expr]:
"""Expand an asterisk. Mutates `select_query.select` and `select_query.type.columns` with the new fields"""
if isinstance(asterisk.table_type, ast.BaseTableType):
table = asterisk.table_type.resolve_database_table()
database_fields = table.get_asterisk()
for key in database_fields.keys():
type = ast.FieldType(name=key, table_type=asterisk.table_type)
select_query.select.append(ast.Field(chain=[key], type=type))
select_query.type.columns[key] = type
return [ast.Field(chain=[key]) for key in database_fields.keys()]
elif (
isinstance(asterisk.table_type, ast.SelectUnionQueryType)
or isinstance(asterisk.table_type, ast.SelectQueryType)
Expand All @@ -192,10 +191,7 @@ def _expand_asterisk_columns(self, select_query: ast.SelectQuery, asterisk: ast.
if isinstance(select, ast.SelectUnionQueryType):
select = select.types[0]
if isinstance(select, ast.SelectQueryType):
for name in select.columns.keys():
type = ast.FieldType(name=name, table_type=asterisk.table_type)
select_query.select.append(ast.Field(chain=[name], type=type))
select_query.type.columns[name] = type
return [ast.Field(chain=[key]) for key in select.columns.keys()]
else:
raise ResolverException("Can't expand asterisk (*) on subquery")
else:
Expand Down
Loading

0 comments on commit 8577afc

Please sign in to comment.