Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(hogql): basic pretty print #18086

Merged
merged 3 commits into from
Oct 19, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 49 additions & 16 deletions posthog/hogql/printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,10 @@ def team_id_guard_for_table(table_type: Union[ast.TableType, ast.TableAliasType]
def to_printed_hogql(query: ast.Expr, team_id: int) -> str:
    """Prints the HogQL query without mutating the node"""
    # The query is cloned before printing so the caller's node is not touched,
    # and pretty-printing is always requested for this entry point.
    return print_ast(
        clone_expr(query),
        dialect="hogql",
        context=HogQLContext(team_id=team_id, enable_select_queries=True),
        pretty=True,
    )


Expand All @@ -67,9 +70,12 @@ def print_ast(
dialect: Literal["hogql", "clickhouse"],
stack: Optional[List[ast.SelectQuery]] = None,
settings: Optional[HogQLGlobalSettings] = None,
pretty: bool = False,
) -> str:
prepared_ast = prepare_ast_for_printing(node=node, context=context, dialect=dialect, stack=stack, settings=settings)
return print_prepared_ast(node=prepared_ast, context=context, dialect=dialect, stack=stack, settings=settings)
return print_prepared_ast(
node=prepared_ast, context=context, dialect=dialect, stack=stack, settings=settings, pretty=pretty
)


def prepare_ast_for_printing(
Expand Down Expand Up @@ -111,10 +117,13 @@ def print_prepared_ast(
dialect: Literal["hogql", "clickhouse"],
stack: Optional[List[ast.SelectQuery]] = None,
settings: Optional[HogQLGlobalSettings] = None,
pretty: bool = False,
) -> str:
with context.timings.measure("printer"):
# _Printer also adds a team_id guard if printing clickhouse
return _Printer(context=context, dialect=dialect, stack=stack or [], settings=settings).visit(node)
return _Printer(context=context, dialect=dialect, stack=stack or [], settings=settings, pretty=pretty).visit(
node
)


@dataclass
Expand All @@ -132,15 +141,24 @@ def __init__(
dialect: Literal["hogql", "clickhouse"],
stack: Optional[List[AST]] = None,
settings: Optional[HogQLGlobalSettings] = None,
pretty: bool = False,
):
self.context = context
self.dialect = dialect
self.stack: List[AST] = stack or [] # Keep track of all traversed nodes.
self.settings = settings
self.pretty = pretty
self._indent = -1
self.tab_size = 4

def indent(self, extra: int = 0):
    """Return the whitespace prefix for the current depth plus ``extra`` levels.

    The prefix is ``tab_size`` spaces per level; a depth of zero or less
    yields an empty string.
    """
    depth = self._indent + extra
    one_level = " " * self.tab_size
    return one_level * depth

def visit(self, node: AST):
self.stack.append(node)
self._indent += 1
response = super().visit(node)
self._indent -= 1
self.stack.pop()

if len(self.stack) == 0 and self.dialect == "clickhouse" and self.settings:
Expand All @@ -153,9 +171,15 @@ def visit(self, node: AST):
return response

def visit_select_union_query(self, node: ast.SelectUnionQuery):
    """Print every member of a UNION ALL and join the results.

    In pretty mode each member is stripped and the members are separated by
    ``UNION ALL`` on its own indented line; otherwise they are joined inline
    with `` UNION ALL ``. When this node is not the only entry on the stack
    the joined text is wrapped in parentheses.
    """
    # Undo the depth increment that visit() applied for this node while the
    # member queries (and the separator) are rendered. The finally clause
    # guarantees the depth is restored even if a member fails to print.
    self._indent -= 1
    try:
        queries = [self.visit(expr) for expr in node.select_queries]
        if self.pretty:
            separator = f"\n{self.indent(1)}UNION ALL\n{self.indent(1)}"
            query = separator.join(printed.strip() for printed in queries)
        else:
            query = " UNION ALL ".join(queries)
    finally:
        self._indent += 1
    if len(self.stack) > 1:
        return f"({query.strip()})"
    return query

def visit_select_query(self, node: ast.SelectQuery):
Expand Down Expand Up @@ -221,16 +245,19 @@ def visit_select_query(self, node: ast.SelectQuery):
raise HogQLException(f"Invalid ARRAY JOIN without an array")
array_join += f" {', '.join(self.visit(expr) for expr in node.array_join_list)}"

space = f"\n{self.indent(1)}" if self.pretty else " "
comma = f",\n{self.indent(1)}" if self.pretty else ", "

clauses = [
f"SELECT {'DISTINCT ' if node.distinct else ''}{', '.join(columns)}",
f"FROM {' '.join(joined_tables)}" if len(joined_tables) > 0 else None,
array_join,
"PREWHERE " + prewhere if prewhere else None,
"WHERE " + where if where else None,
f"GROUP BY {', '.join(group_by)}" if group_by and len(group_by) > 0 else None,
"HAVING " + having if having else None,
"WINDOW " + window if window else None,
f"ORDER BY {', '.join(order_by)}" if order_by and len(order_by) > 0 else None,
f"SELECT{space}{'DISTINCT ' if node.distinct else ''}{comma.join(columns)}",
f"FROM{space}{' '.join(joined_tables)}" if len(joined_tables) > 0 else None,
array_join if array_join else None,
f"PREWHERE{space}" + prewhere if prewhere else None,
f"WHERE{space}" + where if where else None,
f"GROUP BY{space}{comma.join(group_by)}" if group_by and len(group_by) > 0 else None,
f"HAVING{space}" + having if having else None,
f"WINDOW{space}" + window if window else None,
f"ORDER BY{space}{comma.join(order_by)}" if order_by and len(order_by) > 0 else None,
]

limit = node.limit
Expand All @@ -257,11 +284,17 @@ def visit_select_query(self, node: ast.SelectQuery):
if settings is not None:
clauses.append(settings)

response = " ".join([clause for clause in clauses if clause])
if self.pretty:
response = "\n".join([f"{self.indent()}{clause}" for clause in clauses if clause is not None])
else:
response = " ".join([clause for clause in clauses if clause is not None])

# If we are printing a SELECT subquery (not the first AST node we are visiting), wrap it in parentheses.
if not part_of_select_union and not is_top_level_query:
response = f"({response})"
if self.pretty:
response = f"({response.strip()})"
else:
response = f"({response})"

return response

Expand Down
98 changes: 98 additions & 0 deletions posthog/hogql/test/test_printer.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Literal, Optional, Dict

import pytest
from django.test import override_settings

from posthog.hogql import ast
Expand Down Expand Up @@ -50,6 +51,15 @@ def _assert_select_error(self, statement, expected_error):
raise AssertionError(f"Expected '{expected_error}' in '{str(context.exception)}'")
self.assertTrue(expected_error in str(context.exception))

def _pretty(self, query: str):
    """Parse ``query`` as a select and return it printed in the hogql dialect with pretty=True."""
    select = parse_select(query)
    context = HogQLContext(team_id=self.team.pk, enable_select_queries=True)
    return print_ast(select, context, "hogql", pretty=True)

def test_to_printed_hogql(self):
expr = parse_select("select 1 + 2, 3 from events")
repsponse = to_printed_hogql(expr, self.team.pk)
Expand Down Expand Up @@ -916,3 +926,91 @@ def test_print_both_settings(self):
printed,
f"SELECT 1 FROM events WHERE equals(events.team_id, {self.team.pk}) LIMIT 10000 SETTINGS optimize_aggregation_in_order=1, readonly=2, max_execution_time=10, allow_experimental_object_type=1",
)

def test_pretty_print(self):
    """A flat select pretty-prints each clause keyword on its own line with its items indented."""
    printed = self._pretty("SELECT 1, event FROM events")
    # Items under a clause sit one level deeper than the keyword; the printer
    # uses tab_size = 4, so each item line starts with four spaces.
    self.assertEqual(
        printed,
        "SELECT\n    1,\n    event\nFROM\n    events\nLIMIT 10000",
    )

def test_pretty_print_subquery(self):
# Pretty-prints a select whose FROM clause is a parenthesised subquery and compares the whole output.
# NOTE(review): the expected literal below shows one-space indents, while _Printer.indent() emits
# multiples of tab_size (4) spaces. The whitespace in this literal looks collapsed; confirm the
# exact nesting depths against the printer's real output before relying on it.
printed = self._pretty("SELECT 1, event FROM (select 1, event from events)")
self.assertEqual(
printed,
f"""SELECT\n 1,\n event\nFROM\n (SELECT\n 1,\n event\n FROM\n events)\nLIMIT 10000""",
)

@pytest.mark.usefixtures("unittest_snapshot")
def test_large_pretty_print(self):
# Runs a large, deeply nested query (subqueries, CROSS JOIN, ARRAY JOIN, UNION ALL, lambdas)
# through self._pretty and compares the printed text with the stored snapshot (self.snapshot).
# NOTE(review): the leading whitespace inside the query literal appears to have been stripped
# in this view; presumably the original is indented. The parser input should be unaffected, but verify.
printed = self._pretty(
"""
SELECT
groupArray(start_of_period) AS date,
groupArray(counts) AS total,
status
FROM
(SELECT
if(equals(status, 'dormant'), negate(sum(counts)), negate(negate(sum(counts)))) AS counts,
start_of_period,
status
FROM
(SELECT
periods.start_of_period AS start_of_period,
0 AS counts,
status
FROM
(SELECT
minus(dateTrunc('day', assumeNotNull(toDateTime('2023-10-19 23:59:59'))), toIntervalDay(number)) AS start_of_period
FROM
numbers(dateDiff('day', dateTrunc('day', assumeNotNull(toDateTime('2023-09-19 00:00:00'))), dateTrunc('day', plus(assumeNotNull(toDateTime('2023-10-19 23:59:59')), toIntervalDay(1))))) AS numbers) AS periods CROSS JOIN (SELECT
status
FROM
(SELECT
1)
ARRAY JOIN ['new', 'returning', 'resurrecting', 'dormant'] AS status) AS sec
ORDER BY
status ASC,
start_of_period ASC
UNION ALL
SELECT
start_of_period,
count(DISTINCT person_id) AS counts,
status
FROM
(SELECT
events.person.id AS person_id,
min(events.person.created_at) AS created_at,
arraySort(groupUniqArray(dateTrunc('day', events.timestamp))) AS all_activity,
arrayPopBack(arrayPushFront(all_activity, dateTrunc('day', created_at))) AS previous_activity,
arrayPopFront(arrayPushBack(all_activity, dateTrunc('day', toDateTime('1970-01-01 00:00:00')))) AS following_activity,
arrayMap((previous, current, index) -> if(equals(previous, current), 'new', if(and(equals(minus(current, toIntervalDay(1)), previous), notEquals(index, 1)), 'returning', 'resurrecting')), previous_activity, all_activity, arrayEnumerate(all_activity)) AS initial_status,
arrayMap((current, next) -> if(equals(plus(current, toIntervalDay(1)), next), '', 'dormant'), all_activity, following_activity) AS dormant_status,
arrayMap(x -> plus(x, toIntervalDay(1)), arrayFilter((current, is_dormant) -> equals(is_dormant, 'dormant'), all_activity, dormant_status)) AS dormant_periods,
arrayMap(x -> 'dormant', dormant_periods) AS dormant_label,
arrayConcat(arrayZip(all_activity, initial_status), arrayZip(dormant_periods, dormant_label)) AS temp_concat,
arrayJoin(temp_concat) AS period_status_pairs,
period_status_pairs.1 AS start_of_period,
period_status_pairs.2 AS status
FROM
events
WHERE
and(greaterOrEquals(timestamp, minus(dateTrunc('day', assumeNotNull(toDateTime('2023-09-19 00:00:00'))), toIntervalDay(1))), less(timestamp, plus(dateTrunc('day', assumeNotNull(toDateTime('2023-10-19 23:59:59'))), toIntervalDay(1))), equals(event, '$pageview'))
GROUP BY
person_id)
GROUP BY
start_of_period,
status)
WHERE
and(lessOrEquals(start_of_period, dateTrunc('day', assumeNotNull(toDateTime('2023-10-19 23:59:59')))), greaterOrEquals(start_of_period, dateTrunc('day', assumeNotNull(toDateTime('2023-09-19 00:00:00')))))
GROUP BY
start_of_period,
status
ORDER BY
start_of_period ASC)
GROUP BY
status
LIMIT 10000
"""
)
assert printed == self.snapshot
Loading