diff --git a/mypy-baseline.txt b/mypy-baseline.txt index 9b607b6222cd3..d8679f108cd3a 100644 --- a/mypy-baseline.txt +++ b/mypy-baseline.txt @@ -194,12 +194,7 @@ posthog/hogql/resolver.py:0: error: Incompatible types in assignment (expression posthog/hogql/resolver.py:0: error: Argument 1 to "visit" of "Resolver" has incompatible type "SampleExpr | None"; expected "Expr" [arg-type] posthog/hogql/resolver.py:0: error: Argument 2 to "convert_hogqlx_tag" has incompatible type "int | None"; expected "int" [arg-type] posthog/hogql/resolver.py:0: error: Invalid index type "str | int" for "dict[str, BaseTableType | SelectUnionQueryType | SelectQueryType | SelectQueryAliasType | SelectViewType]"; expected type "str" [index] -posthog/hogql/resolver.py:0: error: Argument 2 to "lookup_field_by_name" has incompatible type "str | int"; expected "str" [arg-type] posthog/hogql/resolver.py:0: error: Argument 2 to "lookup_cte_by_name" has incompatible type "str | int"; expected "str" [arg-type] -posthog/hogql/resolver.py:0: error: Argument 1 to "get_child" of "Type" has incompatible type "str | int"; expected "str" [arg-type] -posthog/hogql/resolver.py:0: error: Incompatible types in assignment (expression has type "Expr", variable has type "Alias") [assignment] -posthog/hogql/resolver.py:0: error: Argument "alias" to "Alias" has incompatible type "str | int"; expected "str" [arg-type] -posthog/hogql/resolver.py:0: error: Argument 1 to "join" of "str" has incompatible type "list[str | int]"; expected "Iterable[str]" [arg-type] posthog/hogql/transforms/lazy_tables.py:0: error: Incompatible default for argument "context" (default has type "None", argument has type "HogQLContext") [assignment] posthog/hogql/transforms/lazy_tables.py:0: note: PEP 484 prohibits implicit Optional. Accordingly, mypy has changed its default to no_implicit_optional=True posthog/hogql/transforms/lazy_tables.py:0: note: Use https://github.com/hauntsaninja/no_implicit_optional to automatically upgrade your codebase diff --git a/posthog/hogql/ast.py b/posthog/hogql/ast.py index e3fa80b3f3ee8..ebc4b5d259df6 100644 --- a/posthog/hogql/ast.py +++ b/posthog/hogql/ast.py @@ -398,6 +398,17 @@ def resolve_table_type(self, context: HogQLContext): return self.table_type +@dataclass(kw_only=True) +class UnresolvedFieldType(Type): + name: str + + def get_child(self, name: str | int, context: HogQLContext) -> "Type": + raise QueryError(f"Unable to resolve field: {self.name}") + + def has_child(self, name: str | int, context: HogQLContext) -> bool: + return False + + @dataclass(kw_only=True) class PropertyType(Type): chain: list[str | int] diff --git a/posthog/hogql/context.py b/posthog/hogql/context.py index 9b5b6092a6911..6e8c982cb038b 100644 --- a/posthog/hogql/context.py +++ b/posthog/hogql/context.py @@ -42,6 +42,9 @@ class HogQLContext: warnings: list["HogQLNotice"] = field(default_factory=list) # Notices returned with the metadata query notices: list["HogQLNotice"] = field(default_factory=list) + # Errors returned with the metadata query + errors: list["HogQLNotice"] = field(default_factory=list) + # Timings in seconds for different parts of the HogQL query timings: HogQLTimings = field(default_factory=HogQLTimings) # Modifications requested by the HogQL client @@ -68,3 +71,23 @@ def add_notice( ): if not any(n.start == start and n.end == end and n.message == message and n.fix == fix for n in self.notices): self.notices.append(HogQLNotice(start=start, end=end, message=message, fix=fix)) + + def add_warning( + self, + message: str, + start: Optional[int] = None, + end: Optional[int] = None, + fix: Optional[str] = None, + ): + if not any(n.start == start and n.end == end and n.message == message and n.fix == fix for n in self.warnings): + self.warnings.append(HogQLNotice(start=start, end=end, message=message, fix=fix)) + + def add_error( + self, + message: str, + start: Optional[int] = None, + end: Optional[int] = None, + fix: Optional[str] = None, + ): + if not any(n.start == start and n.end == end and n.message == message and n.fix == fix for n in self.errors): + self.errors.append(HogQLNotice(start=start, end=end, message=message, fix=fix)) diff --git a/posthog/hogql/metadata.py b/posthog/hogql/metadata.py index 13d2579b56856..c59a3df1bf471 100644 --- a/posthog/hogql/metadata.py +++ b/posthog/hogql/metadata.py @@ -58,6 +58,8 @@ def get_hogql_metadata( raise ValueError("Either expr or select must be provided") response.warnings = context.warnings response.notices = context.notices + response.errors = context.errors + response.isValid = len(response.errors) == 0 except Exception as e: response.isValid = False if isinstance(e, ExposedHogQLError): diff --git a/posthog/hogql/printer.py b/posthog/hogql/printer.py index a829697e9007a..0bca322df6022 100644 --- a/posthog/hogql/printer.py +++ b/posthog/hogql/printer.py @@ -1090,6 +1090,12 @@ def visit_lazy_table_type(self, type: ast.LazyJoinType): def visit_field_traverser_type(self, type: ast.FieldTraverserType): raise ImpossibleASTError("Unexpected ast.FieldTraverserType. This should have been resolved.") + def visit_unresolved_field_type(self, type: ast.UnresolvedFieldType): + if self.dialect == "clickhouse": + raise QueryError(f"Unable to resolve field: {type.name}") + else: + return self._print_identifier(type.name) + def visit_unknown(self, node: AST): raise ImpossibleASTError(f"Unknown AST node {type(node).__name__}") diff --git a/posthog/hogql/query.py b/posthog/hogql/query.py index b42a61b785541..56373e0be2643 100644 --- a/posthog/hogql/query.py +++ b/posthog/hogql/query.py @@ -49,6 +49,11 @@ def execute_hogql_query( context = HogQLContext(team_id=team.pk) query_modifiers = create_default_modifiers_for_team(team, modifiers) + error: Optional[str] = None + explain_output: Optional[list[str]] = None + results = None + types = None + metadata: Optional[HogQLMetadataResponse] = None with timings.measure("query"): if isinstance(query, ast.SelectQuery) or isinstance(query, ast.SelectUnionQuery): @@ -133,48 +138,26 @@ def execute_hogql_query( # Print the ClickHouse SQL query with timings.measure("print_ast"): - clickhouse_context = dataclasses.replace( - context, - # set the team.pk here so someone can't pass a context for a different team 🤷‍️ - team_id=team.pk, - team=team, - enable_select_queries=True, - timings=timings, - modifiers=query_modifiers, - ) - - clickhouse_sql = print_ast( - select_query, - context=clickhouse_context, - dialect="clickhouse", - settings=settings, - pretty=pretty if pretty is not None else True, - ) - - timings_dict = timings.to_dict() - with timings.measure("clickhouse_execute"): - tag_queries( - team_id=team.pk, - query_type=query_type, - has_joins="JOIN" in clickhouse_sql, - has_json_operations="JSONExtract" in clickhouse_sql or "JSONHas" in clickhouse_sql, - timings=timings_dict, - modifiers={k: v for k, v in modifiers.model_dump().items() if v is not None} if modifiers else {}, - ) - - error = None try: - results, types = sync_execute( - clickhouse_sql, - clickhouse_context.values, - with_column_types=True, - workload=workload, + clickhouse_context = dataclasses.replace( + context, + # set the team.pk here so someone can't pass a context for a different team 🤷‍️ team_id=team.pk, - readonly=True, + team=team, + enable_select_queries=True, + timings=timings, + modifiers=query_modifiers, + ) + clickhouse_sql = print_ast( + select_query, + context=clickhouse_context, + dialect="clickhouse", + settings=settings, + pretty=pretty if pretty is not None else True, ) except Exception as e: if explain: - results, types = None, None + clickhouse_sql = None if isinstance(e, ExposedCHQueryError | ExposedHogQLError): error = str(e) else: @@ -182,24 +165,51 @@ def execute_hogql_query( else: raise e - metadata: Optional[HogQLMetadataResponse] = None - if explain and error is None: # If the query errored, explain will fail as well. - with timings.measure("explain"): - explain_results = sync_execute( - f"EXPLAIN {clickhouse_sql}", - clickhouse_context.values, - with_column_types=True, - workload=workload, + if clickhouse_sql is not None: + timings_dict = timings.to_dict() + with timings.measure("clickhouse_execute"): + tag_queries( team_id=team.pk, - readonly=True, + query_type=query_type, + has_joins="JOIN" in clickhouse_sql, + has_json_operations="JSONExtract" in clickhouse_sql or "JSONHas" in clickhouse_sql, + timings=timings_dict, + modifiers={k: v for k, v in modifiers.model_dump().items() if v is not None} if modifiers else {}, ) - explain_output = [str(r[0]) for r in explain_results[0]] - with timings.measure("metadata"): - from posthog.hogql.metadata import get_hogql_metadata - metadata = get_hogql_metadata(HogQLMetadata(select=hogql, debug=True), team) - else: - explain_output = None + try: + results, types = sync_execute( + clickhouse_sql, + clickhouse_context.values, + with_column_types=True, + workload=workload, + team_id=team.pk, + readonly=True, + ) + except Exception as e: + if explain: + if isinstance(e, ExposedCHQueryError | ExposedHogQLError): + error = str(e) + else: + error = "Unknown error" + else: + raise e + + if explain and error is None: # If the query errored, explain will fail as well. + with timings.measure("explain"): + explain_results = sync_execute( + f"EXPLAIN {clickhouse_sql}", + clickhouse_context.values, + with_column_types=True, + workload=workload, + team_id=team.pk, + readonly=True, + ) + explain_output = [str(r[0]) for r in explain_results[0]] + with timings.measure("metadata"): + from posthog.hogql.metadata import get_hogql_metadata + + metadata = get_hogql_metadata(HogQLMetadata(select=hogql, debug=True), team) return HogQLQueryResponse( query=query, diff --git a/posthog/hogql/resolver.py b/posthog/hogql/resolver.py index 5921e5a6f2d94..ea6574e4488b4 100644 --- a/posthog/hogql/resolver.py +++ b/posthog/hogql/resolver.py @@ -55,7 +55,7 @@ def resolve_constant_data_type(constant: Any) -> ConstantType: def resolve_types( - node: ast.Expr, + node: ast.Expr | ast.SelectQuery, context: HogQLContext, dialect: Literal["hogql", "clickhouse"], scopes: Optional[list[ast.SelectQueryType]] = None, @@ -450,7 +450,7 @@ def visit_field(self, node: ast.Field): scope = self.scopes[-1] type: Optional[ast.Type] = None - name = node.chain[0] + name = str(node.chain[0]) # If the field contains at least two parts, the first might be a table. if len(node.chain) > 1 and name in scope.tables: @@ -487,10 +487,18 @@ def visit_field(self, node: ast.Field): return response if not type: - raise QueryError(f"Unable to resolve field: {name}") + if self.dialect == "clickhouse": + raise QueryError(f"Unable to resolve field: {name}") + else: + type = ast.UnresolvedFieldType(name=name) + self.context.add_error( + start=node.start, + end=node.end, + message=f"Unable to resolve field: {name}", + ) # Recursively resolve the rest of the chain until we can point to the deepest node. - field_name = node.chain[-1] + field_name = str(node.chain[-1]) loop_type = type chain_to_parse = node.chain[1:] previous_types = [] @@ -509,7 +517,7 @@ def visit_field(self, node: ast.Field): loop_type = previous_types[-1] next_chain = chain_to_parse.pop(0) - loop_type = loop_type.get_child(next_chain, self.context) + loop_type = loop_type.get_child(str(next_chain), self.context) if loop_type is None: raise ResolutionError(f"Cannot resolve type {'.'.join(node.chain)}. Unable to resolve {next_chain}.") node.type = loop_type @@ -518,7 +526,7 @@ def visit_field(self, node: ast.Field): # only swap out expression fields in ClickHouse if self.dialect == "clickhouse": new_expr = clone_expr(node.type.expr) - new_node = ast.Alias(alias=node.type.name, expr=new_expr, hidden=True) + new_node: ast.Expr = ast.Alias(alias=node.type.name, expr=new_expr, hidden=True) new_node = self.visit(new_node) return new_node @@ -537,7 +545,7 @@ def visit_field(self, node: ast.Field): type=ast.FieldAliasType(alias=node.type.name, type=node.type), ) elif isinstance(node.type, ast.PropertyType): - property_alias = "__".join(node.type.chain) + property_alias = "__".join(str(s) for s in node.type.chain) return ast.Alias( alias=property_alias, expr=node, diff --git a/posthog/hogql/test/test_resolver.py b/posthog/hogql/test/test_resolver.py index 7cbd5a60a3245..869466d8a7446 100644 --- a/posthog/hogql/test/test_resolver.py +++ b/posthog/hogql/test/test_resolver.py @@ -44,7 +44,7 @@ def _print_hogql(self, select: str): def setUp(self): self.database = create_hogql_database(self.team.pk) - self.context = HogQLContext(database=self.database, team_id=self.team.pk) + self.context = HogQLContext(database=self.database, team_id=self.team.pk, enable_select_queries=True) @pytest.mark.usefixtures("unittest_snapshot") def test_resolve_events_table(self): @@ -125,6 +125,16 @@ def test_resolve_errors(self): resolve_types(self._select(query), self.context, dialect="clickhouse") self.assertIn("Unable to resolve field:", str(e.exception)) + def test_unresolved_field_type(self): + query = "SELECT x" + # raises with ClickHouse + with self.assertRaises(QueryError): + resolve_types(self._select(query), self.context, dialect="clickhouse") + # does not raise with HogQL + select = self._select(query) + select = resolve_types(select, self.context, dialect="hogql") + assert isinstance(select.select[0].type, ast.UnresolvedFieldType) + @pytest.mark.usefixtures("unittest_snapshot") def test_resolve_lazy_pdi_person_table(self): expr = self._select("select distinct_id, person.id from person_distinct_ids") @@ -299,7 +309,7 @@ def test_asterisk_expander_multiple_table_error(self): def test_asterisk_expander_select_union(self): self.setUp() # rebuild self.database with PERSON_ON_EVENTS_OVERRIDE=False node = self._select("select * from (select * from events union all select * from events)") - node = resolve_types(node, self.context, dialect="clickhouse") + node = cast(ast.SelectQuery, resolve_types(node, self.context, dialect="clickhouse")) assert pretty_dataclasses(node) == self.snapshot def test_lambda_parent_scope(self): diff --git a/posthog/hogql/visitor.py b/posthog/hogql/visitor.py index 03c0ea2d93284..41a195b8ed4fb 100644 --- a/posthog/hogql/visitor.py +++ b/posthog/hogql/visitor.py @@ -231,6 +231,12 @@ def visit_uuid_type(self, node: ast.UUIDType): def visit_property_type(self, node: ast.PropertyType): self.visit(node.field_type) + def visit_expression_field_type(self, node: ast.ExpressionFieldType): + pass + + def visit_unresolved_field_type(self, node: ast.UnresolvedFieldType): + pass + def visit_window_expr(self, node: ast.WindowExpr): for expr in node.partition_by or []: self.visit(expr) @@ -250,9 +256,6 @@ def visit_window_frame_expr(self, node: ast.WindowFrameExpr): def visit_join_constraint(self, node: ast.JoinConstraint): self.visit(node.expr) - def visit_expression_field_type(self, node: ast.ExpressionFieldType): - pass - def visit_hogqlx_tag(self, node: ast.HogQLXTag): for attribute in node.attributes: self.visit(attribute)