Skip to content

Commit

Permalink
feat(hogql): basic pretty print (#18086)
Browse files (browse the repository at this point in the history)
  • Loading branch information
mariusandra authored Oct 19, 2023
1 parent a7ae3b0 commit 03d22c7
Show file tree
Hide file tree
Showing 3 changed files with 218 additions and 17 deletions.
65 changes: 49 additions & 16 deletions posthog/hogql/printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,10 @@ def team_id_guard_for_table(table_type: Union[ast.TableType, ast.TableAliasType]
def to_printed_hogql(query: ast.Expr, team_id: int) -> str:
"""Prints the HogQL query without mutating the node"""
return print_ast(
clone_expr(query), dialect="hogql", context=HogQLContext(team_id=team_id, enable_select_queries=True)
clone_expr(query),
dialect="hogql",
context=HogQLContext(team_id=team_id, enable_select_queries=True),
pretty=True,
)


Expand All @@ -67,9 +70,12 @@ def print_ast(
dialect: Literal["hogql", "clickhouse"],
stack: Optional[List[ast.SelectQuery]] = None,
settings: Optional[HogQLGlobalSettings] = None,
pretty: bool = False,
) -> str:
prepared_ast = prepare_ast_for_printing(node=node, context=context, dialect=dialect, stack=stack, settings=settings)
return print_prepared_ast(node=prepared_ast, context=context, dialect=dialect, stack=stack, settings=settings)
return print_prepared_ast(
node=prepared_ast, context=context, dialect=dialect, stack=stack, settings=settings, pretty=pretty
)


def prepare_ast_for_printing(
Expand Down Expand Up @@ -111,10 +117,13 @@ def print_prepared_ast(
dialect: Literal["hogql", "clickhouse"],
stack: Optional[List[ast.SelectQuery]] = None,
settings: Optional[HogQLGlobalSettings] = None,
pretty: bool = False,
) -> str:
with context.timings.measure("printer"):
# _Printer also adds a team_id guard if printing clickhouse
return _Printer(context=context, dialect=dialect, stack=stack or [], settings=settings).visit(node)
return _Printer(context=context, dialect=dialect, stack=stack or [], settings=settings, pretty=pretty).visit(
node
)


@dataclass
Expand All @@ -132,15 +141,24 @@ def __init__(
dialect: Literal["hogql", "clickhouse"],
stack: Optional[List[AST]] = None,
settings: Optional[HogQLGlobalSettings] = None,
pretty: bool = False,
):
self.context = context
self.dialect = dialect
self.stack: List[AST] = stack or [] # Keep track of all traversed nodes.
self.settings = settings
self.pretty = pretty
self._indent = -1
self.tab_size = 4

def indent(self, extra: int = 0):
return " " * self.tab_size * (self._indent + extra)

def visit(self, node: AST):
self.stack.append(node)
self._indent += 1
response = super().visit(node)
self._indent -= 1
self.stack.pop()

if len(self.stack) == 0 and self.dialect == "clickhouse" and self.settings:
Expand All @@ -153,9 +171,15 @@ def visit(self, node: AST):
return response

def visit_select_union_query(self, node: ast.SelectUnionQuery):
query = " UNION ALL ".join([self.visit(expr) for expr in node.select_queries])
self._indent -= 1
queries = [self.visit(expr) for expr in node.select_queries]
if self.pretty:
query = f"\n{self.indent(1)}UNION ALL\n{self.indent(1)}".join([query.strip() for query in queries])
else:
query = " UNION ALL ".join(queries)
self._indent += 1
if len(self.stack) > 1:
return f"({query})"
return f"({query.strip()})"
return query

def visit_select_query(self, node: ast.SelectQuery):
Expand Down Expand Up @@ -221,16 +245,19 @@ def visit_select_query(self, node: ast.SelectQuery):
raise HogQLException(f"Invalid ARRAY JOIN without an array")
array_join += f" {', '.join(self.visit(expr) for expr in node.array_join_list)}"

space = f"\n{self.indent(1)}" if self.pretty else " "
comma = f",\n{self.indent(1)}" if self.pretty else ", "

clauses = [
f"SELECT {'DISTINCT ' if node.distinct else ''}{', '.join(columns)}",
f"FROM {' '.join(joined_tables)}" if len(joined_tables) > 0 else None,
array_join,
"PREWHERE " + prewhere if prewhere else None,
"WHERE " + where if where else None,
f"GROUP BY {', '.join(group_by)}" if group_by and len(group_by) > 0 else None,
"HAVING " + having if having else None,
"WINDOW " + window if window else None,
f"ORDER BY {', '.join(order_by)}" if order_by and len(order_by) > 0 else None,
f"SELECT{space}{'DISTINCT ' if node.distinct else ''}{comma.join(columns)}",
f"FROM{space}{' '.join(joined_tables)}" if len(joined_tables) > 0 else None,
array_join if array_join else None,
f"PREWHERE{space}" + prewhere if prewhere else None,
f"WHERE{space}" + where if where else None,
f"GROUP BY{space}{comma.join(group_by)}" if group_by and len(group_by) > 0 else None,
f"HAVING{space}" + having if having else None,
f"WINDOW{space}" + window if window else None,
f"ORDER BY{space}{comma.join(order_by)}" if order_by and len(order_by) > 0 else None,
]

limit = node.limit
Expand All @@ -257,11 +284,17 @@ def visit_select_query(self, node: ast.SelectQuery):
if settings is not None:
clauses.append(settings)

response = " ".join([clause for clause in clauses if clause])
if self.pretty:
response = "\n".join([f"{self.indent()}{clause}" for clause in clauses if clause is not None])
else:
response = " ".join([clause for clause in clauses if clause is not None])

# If we are printing a SELECT subquery (not the first AST node we are visiting), wrap it in parentheses.
if not part_of_select_union and not is_top_level_query:
response = f"({response})"
if self.pretty:
response = f"({response.strip()})"
else:
response = f"({response})"

return response

Expand Down
70 changes: 70 additions & 0 deletions posthog/hogql/test/__snapshots__/test_printer.ambr
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
# name: TestPrinter.test_large_pretty_print
'
SELECT
groupArray(start_of_period) AS date,
groupArray(counts) AS total,
status
FROM
(SELECT
if(equals(status, 'dormant'), negate(sum(counts)), negate(negate(sum(counts)))) AS counts,
start_of_period,
status
FROM
(SELECT
periods.start_of_period AS start_of_period,
0 AS counts,
status
FROM
(SELECT
minus(dateTrunc('day', assumeNotNull(toDateTime('2023-10-19 23:59:59'))), toIntervalDay(number)) AS start_of_period
FROM
numbers(dateDiff('day', dateTrunc('day', assumeNotNull(toDateTime('2023-09-19 00:00:00'))), dateTrunc('day', plus(assumeNotNull(toDateTime('2023-10-19 23:59:59')), toIntervalDay(1))))) AS numbers) AS periods CROSS JOIN (SELECT
status
FROM
(SELECT
1)
ARRAY JOIN ['new', 'returning', 'resurrecting', 'dormant'] AS status) AS sec
ORDER BY
status ASC,
start_of_period ASC
UNION ALL
SELECT
start_of_period,
count(DISTINCT person_id) AS counts,
status
FROM
(SELECT
events.person.id AS person_id,
min(events.person.created_at) AS created_at,
arraySort(groupUniqArray(dateTrunc('day', events.timestamp))) AS all_activity,
arrayPopBack(arrayPushFront(all_activity, dateTrunc('day', created_at))) AS previous_activity,
arrayPopFront(arrayPushBack(all_activity, dateTrunc('day', toDateTime('1970-01-01 00:00:00')))) AS following_activity,
arrayMap((previous, current, index) -> if(equals(previous, current), 'new', if(and(equals(minus(current, toIntervalDay(1)), previous), notEquals(index, 1)), 'returning', 'resurrecting')), previous_activity, all_activity, arrayEnumerate(all_activity)) AS initial_status,
arrayMap((current, next) -> if(equals(plus(current, toIntervalDay(1)), next), '', 'dormant'), all_activity, following_activity) AS dormant_status,
arrayMap(x -> plus(x, toIntervalDay(1)), arrayFilter((current, is_dormant) -> equals(is_dormant, 'dormant'), all_activity, dormant_status)) AS dormant_periods,
arrayMap(x -> 'dormant', dormant_periods) AS dormant_label,
arrayConcat(arrayZip(all_activity, initial_status), arrayZip(dormant_periods, dormant_label)) AS temp_concat,
arrayJoin(temp_concat) AS period_status_pairs,
period_status_pairs.1 AS start_of_period,
period_status_pairs.2 AS status
FROM
events
WHERE
and(greaterOrEquals(timestamp, minus(dateTrunc('day', assumeNotNull(toDateTime('2023-09-19 00:00:00'))), toIntervalDay(1))), less(timestamp, plus(dateTrunc('day', assumeNotNull(toDateTime('2023-10-19 23:59:59'))), toIntervalDay(1))), equals(event, '$pageview'))
GROUP BY
person_id)
GROUP BY
start_of_period,
status)
WHERE
and(lessOrEquals(start_of_period, dateTrunc('day', assumeNotNull(toDateTime('2023-10-19 23:59:59')))), greaterOrEquals(start_of_period, dateTrunc('day', assumeNotNull(toDateTime('2023-09-19 00:00:00')))))
GROUP BY
start_of_period,
status
ORDER BY
start_of_period ASC)
GROUP BY
status
LIMIT 10000
'
---
100 changes: 99 additions & 1 deletion posthog/hogql/test/test_printer.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Literal, Optional, Dict

import pytest
from django.test import override_settings

from posthog.hogql import ast
Expand Down Expand Up @@ -50,10 +51,19 @@ def _assert_select_error(self, statement, expected_error):
raise AssertionError(f"Expected '{expected_error}' in '{str(context.exception)}'")
self.assertTrue(expected_error in str(context.exception))

def _pretty(self, query: str):
printed = print_ast(
parse_select(query),
HogQLContext(team_id=self.team.pk, enable_select_queries=True),
"hogql",
pretty=True,
)
return printed

def test_to_printed_hogql(self):
expr = parse_select("select 1 + 2, 3 from events")
repsponse = to_printed_hogql(expr, self.team.pk)
self.assertEqual(repsponse, "SELECT plus(1, 2), 3 FROM events LIMIT 10000")
self.assertEqual(repsponse, "SELECT\n plus(1, 2),\n 3\nFROM\n events\nLIMIT 10000")

def test_literals(self):
self.assertEqual(self._expr("1 + 2"), "plus(1, 2)")
Expand Down Expand Up @@ -916,3 +926,91 @@ def test_print_both_settings(self):
printed,
f"SELECT 1 FROM events WHERE equals(events.team_id, {self.team.pk}) LIMIT 10000 SETTINGS optimize_aggregation_in_order=1, readonly=2, max_execution_time=10, allow_experimental_object_type=1",
)

def test_pretty_print(self):
printed = self._pretty("SELECT 1, event FROM events")
self.assertEqual(
printed,
f"SELECT\n 1,\n event\nFROM\n events\nLIMIT 10000",
)

def test_pretty_print_subquery(self):
printed = self._pretty("SELECT 1, event FROM (select 1, event from events)")
self.assertEqual(
printed,
f"""SELECT\n 1,\n event\nFROM\n (SELECT\n 1,\n event\n FROM\n events)\nLIMIT 10000""",
)

@pytest.mark.usefixtures("unittest_snapshot")
def test_large_pretty_print(self):
printed = self._pretty(
"""
SELECT
groupArray(start_of_period) AS date,
groupArray(counts) AS total,
status
FROM
(SELECT
if(equals(status, 'dormant'), negate(sum(counts)), negate(negate(sum(counts)))) AS counts,
start_of_period,
status
FROM
(SELECT
periods.start_of_period AS start_of_period,
0 AS counts,
status
FROM
(SELECT
minus(dateTrunc('day', assumeNotNull(toDateTime('2023-10-19 23:59:59'))), toIntervalDay(number)) AS start_of_period
FROM
numbers(dateDiff('day', dateTrunc('day', assumeNotNull(toDateTime('2023-09-19 00:00:00'))), dateTrunc('day', plus(assumeNotNull(toDateTime('2023-10-19 23:59:59')), toIntervalDay(1))))) AS numbers) AS periods CROSS JOIN (SELECT
status
FROM
(SELECT
1)
ARRAY JOIN ['new', 'returning', 'resurrecting', 'dormant'] AS status) AS sec
ORDER BY
status ASC,
start_of_period ASC
UNION ALL
SELECT
start_of_period,
count(DISTINCT person_id) AS counts,
status
FROM
(SELECT
events.person.id AS person_id,
min(events.person.created_at) AS created_at,
arraySort(groupUniqArray(dateTrunc('day', events.timestamp))) AS all_activity,
arrayPopBack(arrayPushFront(all_activity, dateTrunc('day', created_at))) AS previous_activity,
arrayPopFront(arrayPushBack(all_activity, dateTrunc('day', toDateTime('1970-01-01 00:00:00')))) AS following_activity,
arrayMap((previous, current, index) -> if(equals(previous, current), 'new', if(and(equals(minus(current, toIntervalDay(1)), previous), notEquals(index, 1)), 'returning', 'resurrecting')), previous_activity, all_activity, arrayEnumerate(all_activity)) AS initial_status,
arrayMap((current, next) -> if(equals(plus(current, toIntervalDay(1)), next), '', 'dormant'), all_activity, following_activity) AS dormant_status,
arrayMap(x -> plus(x, toIntervalDay(1)), arrayFilter((current, is_dormant) -> equals(is_dormant, 'dormant'), all_activity, dormant_status)) AS dormant_periods,
arrayMap(x -> 'dormant', dormant_periods) AS dormant_label,
arrayConcat(arrayZip(all_activity, initial_status), arrayZip(dormant_periods, dormant_label)) AS temp_concat,
arrayJoin(temp_concat) AS period_status_pairs,
period_status_pairs.1 AS start_of_period,
period_status_pairs.2 AS status
FROM
events
WHERE
and(greaterOrEquals(timestamp, minus(dateTrunc('day', assumeNotNull(toDateTime('2023-09-19 00:00:00'))), toIntervalDay(1))), less(timestamp, plus(dateTrunc('day', assumeNotNull(toDateTime('2023-10-19 23:59:59'))), toIntervalDay(1))), equals(event, '$pageview'))
GROUP BY
person_id)
GROUP BY
start_of_period,
status)
WHERE
and(lessOrEquals(start_of_period, dateTrunc('day', assumeNotNull(toDateTime('2023-10-19 23:59:59')))), greaterOrEquals(start_of_period, dateTrunc('day', assumeNotNull(toDateTime('2023-09-19 00:00:00')))))
GROUP BY
start_of_period,
status
ORDER BY
start_of_period ASC)
GROUP BY
status
LIMIT 10000
"""
)
assert printed == self.snapshot

0 comments on commit 03d22c7

Please sign in to comment.