diff --git a/posthog/hogql/printer.py b/posthog/hogql/printer.py index 800b649760285..3108c31d05124 100644 --- a/posthog/hogql/printer.py +++ b/posthog/hogql/printer.py @@ -57,7 +57,10 @@ def team_id_guard_for_table(table_type: Union[ast.TableType, ast.TableAliasType] def to_printed_hogql(query: ast.Expr, team_id: int) -> str: """Prints the HogQL query without mutating the node""" return print_ast( - clone_expr(query), dialect="hogql", context=HogQLContext(team_id=team_id, enable_select_queries=True) + clone_expr(query), + dialect="hogql", + context=HogQLContext(team_id=team_id, enable_select_queries=True), + pretty=True, ) @@ -67,9 +70,12 @@ def print_ast( dialect: Literal["hogql", "clickhouse"], stack: Optional[List[ast.SelectQuery]] = None, settings: Optional[HogQLGlobalSettings] = None, + pretty: bool = False, ) -> str: prepared_ast = prepare_ast_for_printing(node=node, context=context, dialect=dialect, stack=stack, settings=settings) - return print_prepared_ast(node=prepared_ast, context=context, dialect=dialect, stack=stack, settings=settings) + return print_prepared_ast( + node=prepared_ast, context=context, dialect=dialect, stack=stack, settings=settings, pretty=pretty + ) def prepare_ast_for_printing( @@ -111,10 +117,13 @@ def print_prepared_ast( dialect: Literal["hogql", "clickhouse"], stack: Optional[List[ast.SelectQuery]] = None, settings: Optional[HogQLGlobalSettings] = None, + pretty: bool = False, ) -> str: with context.timings.measure("printer"): # _Printer also adds a team_id guard if printing clickhouse - return _Printer(context=context, dialect=dialect, stack=stack or [], settings=settings).visit(node) + return _Printer(context=context, dialect=dialect, stack=stack or [], settings=settings, pretty=pretty).visit( + node + ) @dataclass @@ -132,15 +141,24 @@ def __init__( dialect: Literal["hogql", "clickhouse"], stack: Optional[List[AST]] = None, settings: Optional[HogQLGlobalSettings] = None, + pretty: bool = False, ): self.context = context self.dialect = dialect self.stack: List[AST] = stack or [] # Keep track of all traversed nodes. self.settings = settings + self.pretty = pretty + self._indent = -1 + self.tab_size = 4 + + def indent(self, extra: int = 0): + return " " * self.tab_size * (self._indent + extra) def visit(self, node: AST): self.stack.append(node) + self._indent += 1 response = super().visit(node) + self._indent -= 1 self.stack.pop() if len(self.stack) == 0 and self.dialect == "clickhouse" and self.settings: @@ -153,9 +171,15 @@ def visit(self, node: AST): return response def visit_select_union_query(self, node: ast.SelectUnionQuery): - query = " UNION ALL ".join([self.visit(expr) for expr in node.select_queries]) + self._indent -= 1 + queries = [self.visit(expr) for expr in node.select_queries] + if self.pretty: + query = f"\n{self.indent(1)}UNION ALL\n{self.indent(1)}".join([query.strip() for query in queries]) + else: + query = " UNION ALL ".join(queries) + self._indent += 1 if len(self.stack) > 1: - return f"({query})" + return f"({query.strip()})" return query def visit_select_query(self, node: ast.SelectQuery): @@ -221,16 +245,19 @@ def visit_select_query(self, node: ast.SelectQuery): raise HogQLException(f"Invalid ARRAY JOIN without an array") array_join += f" {', '.join(self.visit(expr) for expr in node.array_join_list)}" + space = f"\n{self.indent(1)}" if self.pretty else " " + comma = f",\n{self.indent(1)}" if self.pretty else ", " + clauses = [ - f"SELECT {'DISTINCT ' if node.distinct else ''}{', '.join(columns)}", - f"FROM {' '.join(joined_tables)}" if len(joined_tables) > 0 else None, - array_join, - "PREWHERE " + prewhere if prewhere else None, - "WHERE " + where if where else None, - f"GROUP BY {', '.join(group_by)}" if group_by and len(group_by) > 0 else None, - "HAVING " + having if having else None, - "WINDOW " + window if window else None, - f"ORDER BY {', '.join(order_by)}" if order_by and len(order_by) > 0 else None, + f"SELECT{space}{'DISTINCT ' if node.distinct else ''}{comma.join(columns)}", + f"FROM{space}{' '.join(joined_tables)}" if len(joined_tables) > 0 else None, + array_join if array_join else None, + f"PREWHERE{space}" + prewhere if prewhere else None, + f"WHERE{space}" + where if where else None, + f"GROUP BY{space}{comma.join(group_by)}" if group_by and len(group_by) > 0 else None, + f"HAVING{space}" + having if having else None, + f"WINDOW{space}" + window if window else None, + f"ORDER BY{space}{comma.join(order_by)}" if order_by and len(order_by) > 0 else None, ] limit = node.limit @@ -257,11 +284,17 @@ def visit_select_query(self, node: ast.SelectQuery): if settings is not None: clauses.append(settings) - response = " ".join([clause for clause in clauses if clause]) + if self.pretty: + response = "\n".join([f"{self.indent()}{clause}" for clause in clauses if clause is not None]) + else: + response = " ".join([clause for clause in clauses if clause is not None]) # If we are printing a SELECT subquery (not the first AST node we are visiting), wrap it in parentheses. if not part_of_select_union and not is_top_level_query: - response = f"({response})" + if self.pretty: + response = f"({response.strip()})" + else: + response = f"({response})" return response diff --git a/posthog/hogql/test/__snapshots__/test_printer.ambr b/posthog/hogql/test/__snapshots__/test_printer.ambr new file mode 100644 index 0000000000000..c59af7ed313bd --- /dev/null +++ b/posthog/hogql/test/__snapshots__/test_printer.ambr @@ -0,0 +1,70 @@ +# name: TestPrinter.test_large_pretty_print + ' + SELECT + groupArray(start_of_period) AS date, + groupArray(counts) AS total, + status + FROM + (SELECT + if(equals(status, 'dormant'), negate(sum(counts)), negate(negate(sum(counts)))) AS counts, + start_of_period, + status + FROM + (SELECT + periods.start_of_period AS start_of_period, + 0 AS counts, + status + FROM + (SELECT + minus(dateTrunc('day', assumeNotNull(toDateTime('2023-10-19 23:59:59'))), toIntervalDay(number)) AS start_of_period + FROM + numbers(dateDiff('day', dateTrunc('day', assumeNotNull(toDateTime('2023-09-19 00:00:00'))), dateTrunc('day', plus(assumeNotNull(toDateTime('2023-10-19 23:59:59')), toIntervalDay(1))))) AS numbers) AS periods CROSS JOIN (SELECT + status + FROM + (SELECT + 1) + ARRAY JOIN ['new', 'returning', 'resurrecting', 'dormant'] AS status) AS sec + ORDER BY + status ASC, + start_of_period ASC + UNION ALL + SELECT + start_of_period, + count(DISTINCT person_id) AS counts, + status + FROM + (SELECT + events.person.id AS person_id, + min(events.person.created_at) AS created_at, + arraySort(groupUniqArray(dateTrunc('day', events.timestamp))) AS all_activity, + arrayPopBack(arrayPushFront(all_activity, dateTrunc('day', created_at))) AS previous_activity, + arrayPopFront(arrayPushBack(all_activity, dateTrunc('day', toDateTime('1970-01-01 00:00:00')))) AS following_activity, + arrayMap((previous, current, index) -> if(equals(previous, current), 'new', if(and(equals(minus(current, toIntervalDay(1)), previous), notEquals(index, 1)), 'returning', 'resurrecting')), previous_activity, all_activity, arrayEnumerate(all_activity)) AS initial_status, + arrayMap((current, next) -> if(equals(plus(current, toIntervalDay(1)), next), '', 'dormant'), all_activity, following_activity) AS dormant_status, + arrayMap(x -> plus(x, toIntervalDay(1)), arrayFilter((current, is_dormant) -> equals(is_dormant, 'dormant'), all_activity, dormant_status)) AS dormant_periods, + arrayMap(x -> 'dormant', dormant_periods) AS dormant_label, + arrayConcat(arrayZip(all_activity, initial_status), arrayZip(dormant_periods, dormant_label)) AS temp_concat, + arrayJoin(temp_concat) AS period_status_pairs, + period_status_pairs.1 AS start_of_period, + period_status_pairs.2 AS status + FROM + events + WHERE + and(greaterOrEquals(timestamp, minus(dateTrunc('day', assumeNotNull(toDateTime('2023-09-19 00:00:00'))), toIntervalDay(1))), less(timestamp, plus(dateTrunc('day', assumeNotNull(toDateTime('2023-10-19 23:59:59'))), toIntervalDay(1))), equals(event, '$pageview')) + GROUP BY + person_id) + GROUP BY + start_of_period, + status) + WHERE + and(lessOrEquals(start_of_period, dateTrunc('day', assumeNotNull(toDateTime('2023-10-19 23:59:59')))), greaterOrEquals(start_of_period, dateTrunc('day', assumeNotNull(toDateTime('2023-09-19 00:00:00'))))) + GROUP BY + start_of_period, + status + ORDER BY + start_of_period ASC) + GROUP BY + status + LIMIT 10000 + ' +--- diff --git a/posthog/hogql/test/test_printer.py b/posthog/hogql/test/test_printer.py index 59c2c80a8ac40..3861bd77fce42 100644 --- a/posthog/hogql/test/test_printer.py +++ b/posthog/hogql/test/test_printer.py @@ -1,5 +1,6 @@ from typing import Literal, Optional, Dict +import pytest from django.test import override_settings from posthog.hogql import ast @@ -50,10 +51,19 @@ def _assert_select_error(self, statement, expected_error): raise AssertionError(f"Expected '{expected_error}' in '{str(context.exception)}'") self.assertTrue(expected_error in str(context.exception)) + def _pretty(self, query: str): + printed = print_ast( + parse_select(query), + HogQLContext(team_id=self.team.pk, enable_select_queries=True), + "hogql", + pretty=True, + ) + return printed + def test_to_printed_hogql(self): expr = parse_select("select 1 + 2, 3 from events") repsponse = to_printed_hogql(expr, self.team.pk) - self.assertEqual(repsponse, "SELECT plus(1, 2), 3 FROM events LIMIT 10000") + self.assertEqual(repsponse, "SELECT\n plus(1, 2),\n 3\nFROM\n events\nLIMIT 10000") def test_literals(self): self.assertEqual(self._expr("1 + 2"), "plus(1, 2)") @@ -916,3 +926,91 @@ def test_print_both_settings(self): printed, f"SELECT 1 FROM events WHERE equals(events.team_id, {self.team.pk}) LIMIT 10000 SETTINGS optimize_aggregation_in_order=1, readonly=2, max_execution_time=10, allow_experimental_object_type=1", ) + + def test_pretty_print(self): + printed = self._pretty("SELECT 1, event FROM events") + self.assertEqual( + printed, + f"SELECT\n 1,\n event\nFROM\n events\nLIMIT 10000", + ) + + def test_pretty_print_subquery(self): + printed = self._pretty("SELECT 1, event FROM (select 1, event from events)") + self.assertEqual( + printed, + f"""SELECT\n 1,\n event\nFROM\n (SELECT\n 1,\n event\n FROM\n events)\nLIMIT 10000""", + ) + + @pytest.mark.usefixtures("unittest_snapshot") + def test_large_pretty_print(self): + printed = self._pretty( + """ + SELECT + groupArray(start_of_period) AS date, + groupArray(counts) AS total, + status + FROM + (SELECT + if(equals(status, 'dormant'), negate(sum(counts)), negate(negate(sum(counts)))) AS counts, + start_of_period, + status + FROM + (SELECT + periods.start_of_period AS start_of_period, + 0 AS counts, + status + FROM + (SELECT + minus(dateTrunc('day', assumeNotNull(toDateTime('2023-10-19 23:59:59'))), toIntervalDay(number)) AS start_of_period + FROM + numbers(dateDiff('day', dateTrunc('day', assumeNotNull(toDateTime('2023-09-19 00:00:00'))), dateTrunc('day', plus(assumeNotNull(toDateTime('2023-10-19 23:59:59')), toIntervalDay(1))))) AS numbers) AS periods CROSS JOIN (SELECT + status + FROM + (SELECT + 1) + ARRAY JOIN ['new', 'returning', 'resurrecting', 'dormant'] AS status) AS sec + ORDER BY + status ASC, + start_of_period ASC + UNION ALL + SELECT + start_of_period, + count(DISTINCT person_id) AS counts, + status + FROM + (SELECT + events.person.id AS person_id, + min(events.person.created_at) AS created_at, + arraySort(groupUniqArray(dateTrunc('day', events.timestamp))) AS all_activity, + arrayPopBack(arrayPushFront(all_activity, dateTrunc('day', created_at))) AS previous_activity, + arrayPopFront(arrayPushBack(all_activity, dateTrunc('day', toDateTime('1970-01-01 00:00:00')))) AS following_activity, + arrayMap((previous, current, index) -> if(equals(previous, current), 'new', if(and(equals(minus(current, toIntervalDay(1)), previous), notEquals(index, 1)), 'returning', 'resurrecting')), previous_activity, all_activity, arrayEnumerate(all_activity)) AS initial_status, + arrayMap((current, next) -> if(equals(plus(current, toIntervalDay(1)), next), '', 'dormant'), all_activity, following_activity) AS dormant_status, + arrayMap(x -> plus(x, toIntervalDay(1)), arrayFilter((current, is_dormant) -> equals(is_dormant, 'dormant'), all_activity, dormant_status)) AS dormant_periods, + arrayMap(x -> 'dormant', dormant_periods) AS dormant_label, + arrayConcat(arrayZip(all_activity, initial_status), arrayZip(dormant_periods, dormant_label)) AS temp_concat, + arrayJoin(temp_concat) AS period_status_pairs, + period_status_pairs.1 AS start_of_period, + period_status_pairs.2 AS status + FROM + events + WHERE + and(greaterOrEquals(timestamp, minus(dateTrunc('day', assumeNotNull(toDateTime('2023-09-19 00:00:00'))), toIntervalDay(1))), less(timestamp, plus(dateTrunc('day', assumeNotNull(toDateTime('2023-10-19 23:59:59'))), toIntervalDay(1))), equals(event, '$pageview')) + GROUP BY + person_id) + GROUP BY + start_of_period, + status) + WHERE + and(lessOrEquals(start_of_period, dateTrunc('day', assumeNotNull(toDateTime('2023-10-19 23:59:59')))), greaterOrEquals(start_of_period, dateTrunc('day', assumeNotNull(toDateTime('2023-09-19 00:00:00'))))) + GROUP BY + start_of_period, + status + ORDER BY + start_of_period ASC) + GROUP BY + status + LIMIT 10000 + """ + ) + assert printed == self.snapshot