Skip to content

Commit

Permalink
perf: Resolve field alias types in nullable checks (#25247)
Browse files Browse the repository at this point in the history
  • Loading branch information
tkaemming authored Sep 27, 2024
1 parent c831181 commit df183b5
Show file tree
Hide file tree
Showing 14 changed files with 104 additions and 120 deletions.
21 changes: 10 additions & 11 deletions posthog/hogql/printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,13 @@ def value_expr(self) -> str:
return f"{self.__qualified_column}[{self.property_name}]"


def resolve_field_type(expr: ast.Expr) -> ast.Type | None:
expr_type = expr.type
while isinstance(expr_type, ast.FieldAliasType):
expr_type = expr_type.type
return expr_type


class _Printer(Visitor):
# NOTE: Call "print_ast()", not this class directly.

Expand Down Expand Up @@ -609,12 +616,6 @@ def __get_optimized_property_group_compare_operation(self, node: ast.CompareOper
if self.context.modifiers.propertyGroupsMode != PropertyGroupsMode.OPTIMIZED:
return None

def resolve_field_type(expr: ast.Expr) -> ast.Type | None:
expr_type = expr.type
while isinstance(expr_type, ast.FieldAliasType):
expr_type = expr_type.type
return expr_type

if node.op in (ast.CompareOperationOp.Eq, ast.CompareOperationOp.NotEq):
# For commutative operations, we can rewrite the expression with parameters in either order without
# affecting the result.
Expand Down Expand Up @@ -946,11 +947,6 @@ def __get_optimized_property_group_call(self, node: ast.Call) -> str | None:
# XXX: A lot of this is duplicated (sometimes just copy/pasted) from the null equality comparison logic -- it
# might make sense to make it so that ``isNull``/``isNotNull`` is rewritten to comparison expressions before
# this step, similar to how ``equals``/``notEquals`` are interpreted as their comparison operation counterparts.
def resolve_field_type(expr: ast.Expr) -> ast.Type | None:
expr_type = expr.type
while isinstance(expr_type, ast.FieldAliasType):
expr_type = expr_type.type
return expr_type

match node:
case ast.Call(name="isNull" | "isNotNull" as function_name, args=[field]):
Expand Down Expand Up @@ -1523,6 +1519,9 @@ def _is_nullable(self, node: ast.Expr) -> bool:
return node.type.is_nullable(self.context)
elif isinstance(node, ast.Alias):
return self._is_nullable(node.expr)
elif isinstance(node.type, ast.FieldAliasType):
if (field_type := resolve_field_type(node)) and isinstance(field_type, ast.FieldType):
return field_type.is_nullable(self.context)

# we don't know if it's nullable, so we assume it can be
return True
Expand Down
27 changes: 6 additions & 21 deletions posthog/hogql/test/test_printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1952,10 +1952,7 @@ def test_inline_persons(self):
dialect="clickhouse",
settings=HogQLGlobalSettings(max_execution_time=10),
)
assert (
f"AS id FROM person WHERE and(equals(person.team_id, {self.team.pk}), ifNull(in(id, tuple(1, 2, 3)), 0))"
in printed
)
assert f"AS id FROM person WHERE and(equals(person.team_id, {self.team.pk}), in(id, tuple(1, 2, 3)))" in printed

def test_dont_inline_persons(self):
query = parse_select(
Expand All @@ -1982,10 +1979,7 @@ def test_inline_persons_alias(self):
dialect="clickhouse",
settings=HogQLGlobalSettings(max_execution_time=10),
)
assert (
f"AS id FROM person WHERE and(equals(person.team_id, {self.team.pk}), ifNull(in(id, tuple(1, 2, 3)), 0))"
in printed
)
assert f"AS id FROM person WHERE and(equals(person.team_id, {self.team.pk}), in(id, tuple(1, 2, 3)))" in printed

def test_two_joins(self):
query = parse_select(
Expand All @@ -2001,14 +1995,8 @@ def test_two_joins(self):
dialect="clickhouse",
settings=HogQLGlobalSettings(max_execution_time=10),
)
assert (
f"AS id FROM person WHERE and(equals(person.team_id, {self.team.pk}), ifNull(in(id, tuple(1, 2, 3)), 0))"
in printed
)
assert (
f"AS id FROM person WHERE and(equals(person.team_id, {self.team.pk}), ifNull(in(id, tuple(4, 5, 6)), 0))"
in printed
)
assert f"AS id FROM person WHERE and(equals(person.team_id, {self.team.pk}), in(id, tuple(1, 2, 3)))" in printed
assert f"AS id FROM person WHERE and(equals(person.team_id, {self.team.pk}), in(id, tuple(4, 5, 6)))" in printed

def test_two_clauses(self):
query = parse_select(
Expand All @@ -2025,10 +2013,7 @@ def test_two_clauses(self):
settings=HogQLGlobalSettings(max_execution_time=10),
)
assert (
f"AS id FROM person WHERE and(equals(person.team_id, {self.team.pk}), ifNull(in(id, tuple(7, 8, 9)), 0), ifNull(in(id, tuple(1, 2, 3)), 0))"
in printed
)
assert (
f"AS id FROM person WHERE and(equals(person.team_id, {self.team.pk}), ifNull(in(id, tuple(4, 5, 6)), 0))"
f"AS id FROM person WHERE and(equals(person.team_id, {self.team.pk}), in(id, tuple(7, 8, 9)), in(id, tuple(1, 2, 3)))"
in printed
)
assert f"AS id FROM person WHERE and(equals(person.team_id, {self.team.pk}), in(id, tuple(4, 5, 6)))" in printed
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
FROM events LEFT JOIN (
SELECT argMax(replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(groups.group_properties, %(hogql_val_0)s), ''), 'null'), '^"|"$', ''), toTimeZone(groups._timestamp, %(hogql_val_1)s)) AS properties___group_boolean, groups.group_type_index AS index, groups.group_key AS key
FROM groups
WHERE and(equals(groups.team_id, 420), ifNull(equals(index, 0), 0))
WHERE and(equals(groups.team_id, 420), equals(index, 0))
GROUP BY groups.group_type_index, groups.group_key) AS events__group_0 ON equals(events.`$group_0`, events__group_0.key)
WHERE equals(events.team_id, 420)
LIMIT 50000
Expand All @@ -36,7 +36,7 @@
FROM events LEFT JOIN (
SELECT argMax(replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(groups.group_properties, %(hogql_val_0)s), ''), 'null'), '^"|"$', ''), toTimeZone(groups._timestamp, %(hogql_val_1)s)) AS properties___inty, groups.group_type_index AS index, groups.group_key AS key
FROM groups
WHERE and(equals(groups.team_id, 420), ifNull(equals(index, 0), 0))
WHERE and(equals(groups.team_id, 420), equals(index, 0))
GROUP BY groups.group_type_index, groups.group_key) AS events__group_0 ON equals(events.`$group_0`, events__group_0.key)
WHERE equals(events.team_id, 420)
LIMIT 50000
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1412,7 +1412,7 @@
groups.group_type_index AS index,
groups.group_key AS key
FROM groups
WHERE and(equals(groups.team_id, 2), ifNull(equals(index, 0), 0))
WHERE and(equals(groups.team_id, 2), equals(index, 0))
GROUP BY groups.group_type_index,
groups.group_key) AS e__group_0 ON equals(e.`$group_0`, e__group_0.key)
WHERE and(equals(e.team_id, 2), and(and(greaterOrEquals(toTimeZone(e.timestamp, 'UTC'), toDateTime64('2020-01-01 00:00:00.000000', 6, 'UTC')), lessOrEquals(toTimeZone(e.timestamp, 'UTC'), toDateTime64('2020-01-08 23:59:59.999999', 6, 'UTC'))), in(e.event, tuple('buy', 'play movie', 'sign up'))), or(ifNull(equals(step_0, 1), 0), ifNull(equals(step_1, 1), 0), ifNull(equals(step_2, 1), 0))))))))
Expand Down Expand Up @@ -1550,7 +1550,7 @@
groups.group_type_index AS index,
groups.group_key AS key
FROM groups
WHERE and(equals(groups.team_id, 2), ifNull(equals(index, 0), 0))
WHERE and(equals(groups.team_id, 2), equals(index, 0))
GROUP BY groups.group_type_index,
groups.group_key) AS e__group_0 ON equals(e.`$group_0`, e__group_0.key)
WHERE and(equals(e.team_id, 2), and(and(greaterOrEquals(toTimeZone(e.timestamp, 'UTC'), toDateTime64('2020-01-01 00:00:00.000000', 6, 'UTC')), lessOrEquals(toTimeZone(e.timestamp, 'UTC'), toDateTime64('2020-01-08 23:59:59.999999', 6, 'UTC'))), in(e.event, tuple('buy', 'play movie', 'sign up'))), or(ifNull(equals(step_0, 1), 0), ifNull(equals(step_1, 1), 0), ifNull(equals(step_2, 1), 0))))))))
Expand Down Expand Up @@ -1695,7 +1695,7 @@
groups.group_type_index AS index,
groups.group_key AS key
FROM groups
WHERE and(equals(groups.team_id, 2), ifNull(equals(index, 0), 0))
WHERE and(equals(groups.team_id, 2), equals(index, 0))
GROUP BY groups.group_type_index,
groups.group_key) AS e__group_0 ON equals(e.`$group_0`, e__group_0.key)
WHERE and(equals(e.team_id, 2), and(and(greaterOrEquals(toTimeZone(e.timestamp, 'UTC'), toDateTime64('2020-01-01 00:00:00.000000', 6, 'UTC')), lessOrEquals(toTimeZone(e.timestamp, 'UTC'), toDateTime64('2020-01-08 23:59:59.999999', 6, 'UTC'))), in(e.event, tuple('buy', 'play movie', 'sign up'))), or(ifNull(equals(step_0, 1), 0), ifNull(equals(step_1, 1), 0), ifNull(equals(step_2, 1), 0))))))))
Expand Down
Loading

0 comments on commit df183b5

Please sign in to comment.