Skip to content

Commit

Permalink
allow unresolved field when printing hogql
Browse files Browse the repository at this point in the history
  • Loading branch information
mariusandra committed Apr 25, 2024
1 parent 494e83d commit 7faf6c3
Show file tree
Hide file tree
Showing 9 changed files with 137 additions and 69 deletions.
5 changes: 0 additions & 5 deletions mypy-baseline.txt
Original file line number Diff line number Diff line change
Expand Up @@ -194,12 +194,7 @@ posthog/hogql/resolver.py:0: error: Incompatible types in assignment (expression
posthog/hogql/resolver.py:0: error: Argument 1 to "visit" of "Resolver" has incompatible type "SampleExpr | None"; expected "Expr" [arg-type]
posthog/hogql/resolver.py:0: error: Argument 2 to "convert_hogqlx_tag" has incompatible type "int | None"; expected "int" [arg-type]
posthog/hogql/resolver.py:0: error: Invalid index type "str | int" for "dict[str, BaseTableType | SelectUnionQueryType | SelectQueryType | SelectQueryAliasType | SelectViewType]"; expected type "str" [index]
posthog/hogql/resolver.py:0: error: Argument 2 to "lookup_field_by_name" has incompatible type "str | int"; expected "str" [arg-type]
posthog/hogql/resolver.py:0: error: Argument 2 to "lookup_cte_by_name" has incompatible type "str | int"; expected "str" [arg-type]
posthog/hogql/resolver.py:0: error: Argument 1 to "get_child" of "Type" has incompatible type "str | int"; expected "str" [arg-type]
posthog/hogql/resolver.py:0: error: Incompatible types in assignment (expression has type "Expr", variable has type "Alias") [assignment]
posthog/hogql/resolver.py:0: error: Argument "alias" to "Alias" has incompatible type "str | int"; expected "str" [arg-type]
posthog/hogql/resolver.py:0: error: Argument 1 to "join" of "str" has incompatible type "list[str | int]"; expected "Iterable[str]" [arg-type]
posthog/hogql/transforms/lazy_tables.py:0: error: Incompatible default for argument "context" (default has type "None", argument has type "HogQLContext") [assignment]
posthog/hogql/transforms/lazy_tables.py:0: note: PEP 484 prohibits implicit Optional. Accordingly, mypy has changed its default to no_implicit_optional=True
posthog/hogql/transforms/lazy_tables.py:0: note: Use https://github.com/hauntsaninja/no_implicit_optional to automatically upgrade your codebase
Expand Down
11 changes: 11 additions & 0 deletions posthog/hogql/ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,6 +398,17 @@ def resolve_table_type(self, context: HogQLContext):
return self.table_type


@dataclass(kw_only=True)
class UnresolvedFieldType(Type):
name: str

def get_child(self, name: str | int, context: HogQLContext) -> "Type":
raise QueryError(f"Unable to resolve field: {self.name}")

def has_child(self, name: str | int, context: HogQLContext) -> bool:
return False


@dataclass(kw_only=True)
class PropertyType(Type):
chain: list[str | int]
Expand Down
23 changes: 23 additions & 0 deletions posthog/hogql/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@ class HogQLContext:
warnings: list["HogQLNotice"] = field(default_factory=list)
# Notices returned with the metadata query
notices: list["HogQLNotice"] = field(default_factory=list)
# Errors returned with the metadata query
errors: list["HogQLNotice"] = field(default_factory=list)

# Timings in seconds for different parts of the HogQL query
timings: HogQLTimings = field(default_factory=HogQLTimings)
# Modifications requested by the HogQL client
Expand All @@ -68,3 +71,23 @@ def add_notice(
):
if not any(n.start == start and n.end == end and n.message == message and n.fix == fix for n in self.notices):
self.notices.append(HogQLNotice(start=start, end=end, message=message, fix=fix))

def add_warning(
self,
message: str,
start: Optional[int] = None,
end: Optional[int] = None,
fix: Optional[str] = None,
):
if not any(n.start == start and n.end == end and n.message == message and n.fix == fix for n in self.warnings):
self.warnings.append(HogQLNotice(start=start, end=end, message=message, fix=fix))

def add_error(
self,
message: str,
start: Optional[int] = None,
end: Optional[int] = None,
fix: Optional[str] = None,
):
if not any(n.start == start and n.end == end and n.message == message and n.fix == fix for n in self.errors):
self.errors.append(HogQLNotice(start=start, end=end, message=message, fix=fix))
2 changes: 2 additions & 0 deletions posthog/hogql/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ def get_hogql_metadata(
raise ValueError("Either expr or select must be provided")
response.warnings = context.warnings
response.notices = context.notices
response.errors = context.errors
response.isValid = len(response.errors) == 0
except Exception as e:
response.isValid = False
if isinstance(e, ExposedHogQLError):
Expand Down
6 changes: 6 additions & 0 deletions posthog/hogql/printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1090,6 +1090,12 @@ def visit_lazy_table_type(self, type: ast.LazyJoinType):
def visit_field_traverser_type(self, type: ast.FieldTraverserType):
raise ImpossibleASTError("Unexpected ast.FieldTraverserType. This should have been resolved.")

def visit_unresolved_field_type(self, type: ast.UnresolvedFieldType):
if self.dialect == "clickhouse":
raise QueryError(f"Unable to resolve field: {type.name}")
else:
return self._print_identifier(type.name)

def visit_unknown(self, node: AST):
raise ImpossibleASTError(f"Unknown AST node {type(node).__name__}")

Expand Down
114 changes: 62 additions & 52 deletions posthog/hogql/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,11 @@ def execute_hogql_query(
context = HogQLContext(team_id=team.pk)

query_modifiers = create_default_modifiers_for_team(team, modifiers)
error: Optional[str] = None
explain_output: Optional[list[str]] = None
results = None
types = None
metadata: Optional[HogQLMetadataResponse] = None

with timings.measure("query"):
if isinstance(query, ast.SelectQuery) or isinstance(query, ast.SelectUnionQuery):
Expand Down Expand Up @@ -133,73 +138,78 @@ def execute_hogql_query(

# Print the ClickHouse SQL query
with timings.measure("print_ast"):
clickhouse_context = dataclasses.replace(
context,
# set the team.pk here so someone can't pass a context for a different team 🤷‍️
team_id=team.pk,
team=team,
enable_select_queries=True,
timings=timings,
modifiers=query_modifiers,
)

clickhouse_sql = print_ast(
select_query,
context=clickhouse_context,
dialect="clickhouse",
settings=settings,
pretty=pretty if pretty is not None else True,
)

timings_dict = timings.to_dict()
with timings.measure("clickhouse_execute"):
tag_queries(
team_id=team.pk,
query_type=query_type,
has_joins="JOIN" in clickhouse_sql,
has_json_operations="JSONExtract" in clickhouse_sql or "JSONHas" in clickhouse_sql,
timings=timings_dict,
modifiers={k: v for k, v in modifiers.model_dump().items() if v is not None} if modifiers else {},
)

error = None
try:
results, types = sync_execute(
clickhouse_sql,
clickhouse_context.values,
with_column_types=True,
workload=workload,
clickhouse_context = dataclasses.replace(
context,
# set the team.pk here so someone can't pass a context for a different team 🤷‍️
team_id=team.pk,
readonly=True,
team=team,
enable_select_queries=True,
timings=timings,
modifiers=query_modifiers,
)
clickhouse_sql = print_ast(
select_query,
context=clickhouse_context,
dialect="clickhouse",
settings=settings,
pretty=pretty if pretty is not None else True,
)
except Exception as e:
if explain:
results, types = None, None
clickhouse_sql = None
if isinstance(e, ExposedCHQueryError | ExposedHogQLError):
error = str(e)
else:
error = "Unknown error"
else:
raise e

metadata: Optional[HogQLMetadataResponse] = None
if explain and error is None: # If the query errored, explain will fail as well.
with timings.measure("explain"):
explain_results = sync_execute(
f"EXPLAIN {clickhouse_sql}",
clickhouse_context.values,
with_column_types=True,
workload=workload,
if clickhouse_sql is not None:
timings_dict = timings.to_dict()
with timings.measure("clickhouse_execute"):
tag_queries(
team_id=team.pk,
readonly=True,
query_type=query_type,
has_joins="JOIN" in clickhouse_sql,
has_json_operations="JSONExtract" in clickhouse_sql or "JSONHas" in clickhouse_sql,
timings=timings_dict,
modifiers={k: v for k, v in modifiers.model_dump().items() if v is not None} if modifiers else {},
)
explain_output = [str(r[0]) for r in explain_results[0]]
with timings.measure("metadata"):
from posthog.hogql.metadata import get_hogql_metadata

metadata = get_hogql_metadata(HogQLMetadata(select=hogql, debug=True), team)
else:
explain_output = None
try:
results, types = sync_execute(
clickhouse_sql,
clickhouse_context.values,
with_column_types=True,
workload=workload,
team_id=team.pk,
readonly=True,
)
except Exception as e:
if explain:
if isinstance(e, ExposedCHQueryError | ExposedHogQLError):
error = str(e)
else:
error = "Unknown error"
else:
raise e

if explain and error is None: # If the query errored, explain will fail as well.
with timings.measure("explain"):
explain_results = sync_execute(
f"EXPLAIN {clickhouse_sql}",
clickhouse_context.values,
with_column_types=True,
workload=workload,
team_id=team.pk,
readonly=True,
)
explain_output = [str(r[0]) for r in explain_results[0]]
with timings.measure("metadata"):
from posthog.hogql.metadata import get_hogql_metadata

metadata = get_hogql_metadata(HogQLMetadata(select=hogql, debug=True), team)

return HogQLQueryResponse(
query=query,
Expand Down
22 changes: 15 additions & 7 deletions posthog/hogql/resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def resolve_constant_data_type(constant: Any) -> ConstantType:


def resolve_types(
node: ast.Expr,
node: ast.Expr | ast.SelectQuery,
context: HogQLContext,
dialect: Literal["hogql", "clickhouse"],
scopes: Optional[list[ast.SelectQueryType]] = None,
Expand Down Expand Up @@ -450,7 +450,7 @@ def visit_field(self, node: ast.Field):
scope = self.scopes[-1]

type: Optional[ast.Type] = None
name = node.chain[0]
name = str(node.chain[0])

# If the field contains at least two parts, the first might be a table.
if len(node.chain) > 1 and name in scope.tables:
Expand Down Expand Up @@ -487,10 +487,18 @@ def visit_field(self, node: ast.Field):
return response

if not type:
raise QueryError(f"Unable to resolve field: {name}")
if self.dialect == "clickhouse":
raise QueryError(f"Unable to resolve field: {name}")
else:
type = ast.UnresolvedFieldType(name=name)
self.context.add_error(
start=node.start,
end=node.end,
message=f"Unable to resolve field: {name}",
)

# Recursively resolve the rest of the chain until we can point to the deepest node.
field_name = node.chain[-1]
field_name = str(node.chain[-1])
loop_type = type
chain_to_parse = node.chain[1:]
previous_types = []
Expand All @@ -509,7 +517,7 @@ def visit_field(self, node: ast.Field):
loop_type = previous_types[-1]
next_chain = chain_to_parse.pop(0)

loop_type = loop_type.get_child(next_chain, self.context)
loop_type = loop_type.get_child(str(next_chain), self.context)
if loop_type is None:
raise ResolutionError(f"Cannot resolve type {'.'.join(node.chain)}. Unable to resolve {next_chain}.")
node.type = loop_type
Expand All @@ -518,7 +526,7 @@ def visit_field(self, node: ast.Field):
# only swap out expression fields in ClickHouse
if self.dialect == "clickhouse":
new_expr = clone_expr(node.type.expr)
new_node = ast.Alias(alias=node.type.name, expr=new_expr, hidden=True)
new_node: ast.Expr = ast.Alias(alias=node.type.name, expr=new_expr, hidden=True)
new_node = self.visit(new_node)
return new_node

Expand All @@ -537,7 +545,7 @@ def visit_field(self, node: ast.Field):
type=ast.FieldAliasType(alias=node.type.name, type=node.type),
)
elif isinstance(node.type, ast.PropertyType):
property_alias = "__".join(node.type.chain)
property_alias = "__".join(str(s) for s in node.type.chain)
return ast.Alias(
alias=property_alias,
expr=node,
Expand Down
14 changes: 12 additions & 2 deletions posthog/hogql/test/test_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def _print_hogql(self, select: str):

def setUp(self):
self.database = create_hogql_database(self.team.pk)
self.context = HogQLContext(database=self.database, team_id=self.team.pk)
self.context = HogQLContext(database=self.database, team_id=self.team.pk, enable_select_queries=True)

@pytest.mark.usefixtures("unittest_snapshot")
def test_resolve_events_table(self):
Expand Down Expand Up @@ -125,6 +125,16 @@ def test_resolve_errors(self):
resolve_types(self._select(query), self.context, dialect="clickhouse")
self.assertIn("Unable to resolve field:", str(e.exception))

def test_unresolved_field_type(self):
query = "SELECT x"
# raises with ClickHouse
with self.assertRaises(QueryError):
resolve_types(self._select(query), self.context, dialect="clickhouse")
# does not raise with HogQL
select = self._select(query)
select = resolve_types(select, self.context, dialect="hogql")
assert isinstance(select.select[0].type, ast.UnresolvedFieldType)

@pytest.mark.usefixtures("unittest_snapshot")
def test_resolve_lazy_pdi_person_table(self):
expr = self._select("select distinct_id, person.id from person_distinct_ids")
Expand Down Expand Up @@ -299,7 +309,7 @@ def test_asterisk_expander_multiple_table_error(self):
def test_asterisk_expander_select_union(self):
self.setUp() # rebuild self.database with PERSON_ON_EVENTS_OVERRIDE=False
node = self._select("select * from (select * from events union all select * from events)")
node = resolve_types(node, self.context, dialect="clickhouse")
node = cast(ast.SelectQuery, resolve_types(node, self.context, dialect="clickhouse"))
assert pretty_dataclasses(node) == self.snapshot

def test_lambda_parent_scope(self):
Expand Down
9 changes: 6 additions & 3 deletions posthog/hogql/visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,12 @@ def visit_uuid_type(self, node: ast.UUIDType):
def visit_property_type(self, node: ast.PropertyType):
self.visit(node.field_type)

def visit_expression_field_type(self, node: ast.ExpressionFieldType):
pass

def visit_unresolved_field_type(self, node: ast.UnresolvedFieldType):
pass

def visit_window_expr(self, node: ast.WindowExpr):
for expr in node.partition_by or []:
self.visit(expr)
Expand All @@ -250,9 +256,6 @@ def visit_window_frame_expr(self, node: ast.WindowFrameExpr):
def visit_join_constraint(self, node: ast.JoinConstraint):
self.visit(node.expr)

def visit_expression_field_type(self, node: ast.ExpressionFieldType):
pass

def visit_hogqlx_tag(self, node: ast.HogQLXTag):
for attribute in node.attributes:
self.visit(attribute)
Expand Down

0 comments on commit 7faf6c3

Please sign in to comment.