Skip to content

Commit

Permalink
feat(hogql): hidden aliases for special fields (#18725)
Browse files Browse the repository at this point in the history
  • Loading branch information
mariusandra authored Nov 27, 2023
1 parent c0b5061 commit 0582506
Show file tree
Hide file tree
Showing 16 changed files with 2,763 additions and 1,574 deletions.
46 changes: 23 additions & 23 deletions posthog/api/test/__snapshots__/test_query.ambr
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
/* user_id:0 request:_snapshot_ */
SELECT events.event,
events.distinct_id,
replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, 'key'), ''), 'null'), '^"|"$', ''),
replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, 'key'), ''), 'null'), '^"|"$', '') AS key,
'a%sd',
concat(ifNull(toString(events.event), ''), ' ', ifNull(toString(replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, 'key'), ''), 'null'), '^"|"$', '')), ''))
FROM events
Expand All @@ -20,7 +20,7 @@
/* user_id:0 request:_snapshot_ */
SELECT events.event,
events.distinct_id,
replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, 'key'), ''), 'null'), '^"|"$', ''),
replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, 'key'), ''), 'null'), '^"|"$', '') AS key,
'a%sd',
concat(ifNull(toString(events.event), ''), ' ', ifNull(toString(replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, 'key'), ''), 'null'), '^"|"$', '')), ''))
FROM events
Expand All @@ -37,7 +37,7 @@
/* user_id:0 request:_snapshot_ */
SELECT events.event,
events.distinct_id,
replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, 'key'), ''), 'null'), '^"|"$', ''),
replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, 'key'), ''), 'null'), '^"|"$', '') AS key,
'a%sd',
concat(ifNull(toString(events.event), ''), ' ', ifNull(toString(replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, 'key'), ''), 'null'), '^"|"$', '')), ''))
FROM events
Expand All @@ -54,7 +54,7 @@
/* user_id:0 request:_snapshot_ */
SELECT events.event,
events.distinct_id,
nullIf(nullIf(events.mat_key, ''), 'null'),
nullIf(nullIf(events.mat_key, ''), 'null') AS key,
'a%sd',
concat(ifNull(toString(events.event), ''), ' ', ifNull(toString(nullIf(nullIf(events.mat_key, ''), 'null')), ''))
FROM events
Expand All @@ -71,7 +71,7 @@
/* user_id:0 request:_snapshot_ */
SELECT events.event,
events.distinct_id,
nullIf(nullIf(events.mat_key, ''), 'null'),
nullIf(nullIf(events.mat_key, ''), 'null') AS key,
'a%sd',
concat(ifNull(toString(events.event), ''), ' ', ifNull(toString(nullIf(nullIf(events.mat_key, ''), 'null')), ''))
FROM events
Expand All @@ -88,7 +88,7 @@
/* user_id:0 request:_snapshot_ */
SELECT events.event,
events.distinct_id,
nullIf(nullIf(events.mat_key, ''), 'null'),
nullIf(nullIf(events.mat_key, ''), 'null') AS key,
'a%sd',
concat(ifNull(toString(events.event), ''), ' ', ifNull(toString(nullIf(nullIf(events.mat_key, ''), 'null')), ''))
FROM events
Expand Down Expand Up @@ -144,7 +144,7 @@
/* user_id:0 request:_snapshot_ */
SELECT events.event,
events.distinct_id,
replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, 'key'), ''), 'null'), '^"|"$', '')
replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, 'key'), ''), 'null'), '^"|"$', '') AS key
FROM events
WHERE equals(events.team_id, 2)
ORDER BY toTimeZone(events.timestamp, 'UTC') ASC
Expand All @@ -158,7 +158,7 @@
/* user_id:0 request:_snapshot_ */
SELECT events.event,
events.distinct_id,
nullIf(nullIf(events.mat_key, ''), 'null')
nullIf(nullIf(events.mat_key, ''), 'null') AS key
FROM events
WHERE equals(events.team_id, 2)
ORDER BY toTimeZone(events.timestamp, 'UTC') ASC
Expand Down Expand Up @@ -204,7 +204,7 @@
/* user_id:0 request:_snapshot_ */
SELECT events.event,
events.distinct_id,
replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, 'key'), ''), 'null'), '^"|"$', ''),
replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, 'key'), ''), 'null'), '^"|"$', '') AS key,
'a%sd',
concat(ifNull(toString(events.event), ''), ' ', ifNull(toString(replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, 'key'), ''), 'null'), '^"|"$', '')), ''))
FROM events
Expand All @@ -221,7 +221,7 @@
/* user_id:0 request:_snapshot_ */
SELECT events.event,
events.distinct_id,
replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, 'key'), ''), 'null'), '^"|"$', ''),
replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, 'key'), ''), 'null'), '^"|"$', '') AS key,
'a%sd',
concat(ifNull(toString(events.event), ''), ' ', ifNull(toString(replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, 'key'), ''), 'null'), '^"|"$', '')), ''))
FROM events
Expand All @@ -238,7 +238,7 @@
/* user_id:0 request:_snapshot_ */
SELECT events.event,
events.distinct_id,
replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, 'key'), ''), 'null'), '^"|"$', ''),
replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, 'key'), ''), 'null'), '^"|"$', '') AS key,
'a%sd',
concat(ifNull(toString(events.event), ''), ' ', ifNull(toString(replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, 'key'), ''), 'null'), '^"|"$', '')), ''))
FROM events
Expand All @@ -255,7 +255,7 @@
/* user_id:0 request:_snapshot_ */
SELECT events.event,
events.distinct_id,
replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, 'key'), ''), 'null'), '^"|"$', ''),
replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, 'key'), ''), 'null'), '^"|"$', '') AS key,
'a%sd',
concat(ifNull(toString(events.event), ''), ' ', ifNull(toString(replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, 'key'), ''), 'null'), '^"|"$', '')), ''))
FROM events
Expand All @@ -272,7 +272,7 @@
/* user_id:0 request:_snapshot_ */
SELECT events.event,
events.distinct_id,
nullIf(nullIf(events.mat_key, ''), 'null'),
nullIf(nullIf(events.mat_key, ''), 'null') AS key,
'a%sd',
concat(ifNull(toString(events.event), ''), ' ', ifNull(toString(nullIf(nullIf(events.mat_key, ''), 'null')), ''))
FROM events
Expand All @@ -289,7 +289,7 @@
/* user_id:0 request:_snapshot_ */
SELECT events.event,
events.distinct_id,
nullIf(nullIf(events.mat_key, ''), 'null'),
nullIf(nullIf(events.mat_key, ''), 'null') AS key,
'a%sd',
concat(ifNull(toString(events.event), ''), ' ', ifNull(toString(nullIf(nullIf(events.mat_key, ''), 'null')), ''))
FROM events
Expand All @@ -306,7 +306,7 @@
/* user_id:0 request:_snapshot_ */
SELECT events.event,
events.distinct_id,
nullIf(nullIf(events.mat_key, ''), 'null'),
nullIf(nullIf(events.mat_key, ''), 'null') AS key,
'a%sd',
concat(ifNull(toString(events.event), ''), ' ', ifNull(toString(nullIf(nullIf(events.mat_key, ''), 'null')), ''))
FROM events
Expand All @@ -323,7 +323,7 @@
/* user_id:0 request:_snapshot_ */
SELECT events.event,
events.distinct_id,
nullIf(nullIf(events.mat_key, ''), 'null'),
nullIf(nullIf(events.mat_key, ''), 'null') AS key,
'a%sd',
concat(ifNull(toString(events.event), ''), ' ', ifNull(toString(nullIf(nullIf(events.mat_key, ''), 'null')), ''))
FROM events
Expand All @@ -340,7 +340,7 @@
/* user_id:0 request:_snapshot_ */
SELECT events.event,
events.distinct_id,
replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, 'key'), ''), 'null'), '^"|"$', ''),
replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, 'key'), ''), 'null'), '^"|"$', '') AS key,
'a%sd',
concat(ifNull(toString(events.event), ''), ' ', ifNull(toString(replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, 'key'), ''), 'null'), '^"|"$', '')), ''))
FROM events
Expand Down Expand Up @@ -374,7 +374,7 @@
/* user_id:0 request:_snapshot_ */
SELECT events.event,
events.distinct_id,
nullIf(nullIf(events.mat_key, ''), 'null'),
nullIf(nullIf(events.mat_key, ''), 'null') AS key,
'a%sd',
concat(ifNull(toString(events.event), ''), ' ', ifNull(toString(nullIf(nullIf(events.mat_key, ''), 'null')), ''))
FROM events
Expand Down Expand Up @@ -406,7 +406,7 @@
# name: TestQuery.test_property_filter_aggregations
'
/* user_id:0 request:_snapshot_ */
SELECT replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, 'key'), ''), 'null'), '^"|"$', ''),
SELECT replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, 'key'), ''), 'null'), '^"|"$', '') AS key,
count()
FROM events
WHERE and(equals(events.team_id, 2), less(toTimeZone(events.timestamp, 'UTC'), toDateTime64('2020-01-10 12:14:05.000000', 6, 'UTC')), greater(toTimeZone(events.timestamp, 'UTC'), toDateTime64('2020-01-09 12:14:00.000000', 6, 'UTC')))
Expand All @@ -421,7 +421,7 @@
# name: TestQuery.test_property_filter_aggregations.1
'
/* user_id:0 request:_snapshot_ */
SELECT replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, 'key'), ''), 'null'), '^"|"$', ''),
SELECT replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, 'key'), ''), 'null'), '^"|"$', '') AS key,
count()
FROM events
WHERE and(equals(events.team_id, 2), less(toTimeZone(events.timestamp, 'UTC'), toDateTime64('2020-01-10 12:14:05.000000', 6, 'UTC')), greater(toTimeZone(events.timestamp, 'UTC'), toDateTime64('2020-01-09 12:14:00.000000', 6, 'UTC')))
Expand All @@ -437,7 +437,7 @@
# name: TestQuery.test_property_filter_aggregations_materialized
'
/* user_id:0 request:_snapshot_ */
SELECT nullIf(nullIf(events.mat_key, ''), 'null'),
SELECT nullIf(nullIf(events.mat_key, ''), 'null') AS key,
count()
FROM events
WHERE and(equals(events.team_id, 2), less(toTimeZone(events.timestamp, 'UTC'), toDateTime64('2020-01-10 12:14:05.000000', 6, 'UTC')), greater(toTimeZone(events.timestamp, 'UTC'), toDateTime64('2020-01-09 12:14:00.000000', 6, 'UTC')))
Expand All @@ -452,7 +452,7 @@
# name: TestQuery.test_property_filter_aggregations_materialized.1
'
/* user_id:0 request:_snapshot_ */
SELECT nullIf(nullIf(events.mat_key, ''), 'null'),
SELECT nullIf(nullIf(events.mat_key, ''), 'null') AS key,
count()
FROM events
WHERE and(equals(events.team_id, 2), less(toTimeZone(events.timestamp, 'UTC'), toDateTime64('2020-01-10 12:14:05.000000', 6, 'UTC')), greater(toTimeZone(events.timestamp, 'UTC'), toDateTime64('2020-01-09 12:14:00.000000', 6, 'UTC')))
Expand Down Expand Up @@ -483,7 +483,7 @@
# name: TestQuery.test_select_hogql_expressions
'
/* user_id:0 request:_snapshot_ */
SELECT replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, 'key'), ''), 'null'), '^"|"$', ''),
SELECT replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, 'key'), ''), 'null'), '^"|"$', '') AS key,
events.event,
events.distinct_id,
concat(ifNull(toString(events.event), ''), ' ', ifNull(toString(replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(events.properties, 'key'), ''), 'null'), '^"|"$', '')), ''))
Expand Down
14 changes: 14 additions & 0 deletions posthog/hogql/ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,14 @@ def get_child(self, name: str) -> Type:
def has_child(self, name: str) -> bool:
return self.type.has_child(name)

def resolve_constant_type(self):
return self.type.resolve_constant_type()

def resolve_database_field(self):
if isinstance(self.type, FieldType):
return self.type.resolve_database_field()
raise NotImplementedException("FieldAliasType.resolve_database_field not implemented")


@dataclass(kw_only=True)
class BaseTableType(Type):
Expand Down Expand Up @@ -346,6 +354,12 @@ class LambdaArgumentType(Type):
class Alias(Expr):
alias: str
expr: Expr
"""
Aliases are "hidden" if they're automatically created by HogQL when abstracting fields.
E.g. "events.timestamp" gets turned into a "toTimeZone(events.timestamp, 'UTC') AS timestamp".
Hidden aliases are printed only when printing the columns of a SELECT query in the ClickHouse dialect.
"""
hidden: bool = False


class ArithmeticOperationOp(str, Enum):
Expand Down
2 changes: 1 addition & 1 deletion posthog/hogql/database/schema/event_sessions.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ def join_with_events_table_session_duration(
"""
select "$session_id" as id, dateDiff('second', min(timestamp), max(timestamp)) as duration
from events
group by id
group by id
"""
)

Expand Down
12 changes: 6 additions & 6 deletions posthog/hogql/database/schema/test/test_event_sessions.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def test_with_simple_equality_clause(self):

assert len(compare_operators) == 1
assert compare_operators[0] == ast.CompareOperation(
left=ast.Field(chain=["team_id"]),
left=ast.Alias(alias="team_id", hidden=True, expr=ast.Field(chain=["team_id"])),
op=ast.CompareOperationOp.Eq,
right=ast.Constant(value=1),
)
Expand All @@ -56,7 +56,7 @@ def test_with_timestamps(self):

assert len(compare_operators) == 1
assert compare_operators[0] == ast.CompareOperation(
left=ast.Field(chain=["timestamp"]),
left=ast.Alias(alias="timestamp", hidden=True, expr=ast.Field(chain=["timestamp"])),
op=ast.CompareOperationOp.Gt,
right=ast.Constant(value="2023-01-01"),
)
Expand All @@ -74,7 +74,7 @@ def test_with_alias_table(self):

assert len(compare_operators) == 1
assert compare_operators[0] == ast.CompareOperation(
left=ast.Field(chain=["team_id"]),
left=ast.Alias(alias="team_id", hidden=True, expr=ast.Field(chain=["team_id"])),
op=ast.CompareOperationOp.Eq,
right=ast.Constant(value=1),
)
Expand All @@ -92,12 +92,12 @@ def test_with_multiple_clauses(self):

assert len(compare_operators) == 2
assert compare_operators[0] == ast.CompareOperation(
left=ast.Field(chain=["team_id"]),
left=ast.Alias(alias="team_id", hidden=True, expr=ast.Field(chain=["team_id"])),
op=ast.CompareOperationOp.Eq,
right=ast.Constant(value=1),
)
assert compare_operators[1] == ast.CompareOperation(
left=ast.Field(chain=["timestamp"]),
left=ast.Alias(alias="timestamp", hidden=True, expr=ast.Field(chain=["timestamp"])),
op=ast.CompareOperationOp.Gt,
right=ast.Constant(value="2023-01-01"),
)
Expand All @@ -117,7 +117,7 @@ def test_with_join(self):

assert len(compare_operators) == 1
assert compare_operators[0] == ast.CompareOperation(
left=ast.Field(chain=["team_id"]),
left=ast.Alias(alias="team_id", hidden=True, expr=ast.Field(chain=["team_id"])),
op=ast.CompareOperationOp.Eq,
right=ast.Constant(value=1),
)
Expand Down
57 changes: 53 additions & 4 deletions posthog/hogql/printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,11 +229,52 @@ def visit_select_query(self, node: ast.SelectQuery):
else:
where = ast.And(exprs=[extra_where, where])
else:
raise HogQLException(f"Invalid where of type {type(extra_where).__name__} returned by join_expr")
raise HogQLException(
f"Invalid where of type {type(extra_where).__name__} returned by join_expr", node=visited_join.where
)

next_join = next_join.next_join

columns = [self.visit(column) for column in node.select] if node.select else ["1"]
if node.select:
# Only for ClickHouse: for each unique alias name, pick the one alias that will be printed:
# the first visible alias if there is one, otherwise the last hidden alias (which is then made visible).
if self.dialect == "clickhouse":
visible_aliases = {}
for alias in reversed(node.select):
if isinstance(alias, ast.Alias):
if not visible_aliases.get(alias.alias, None) or not alias.hidden:
visible_aliases[alias.alias] = alias

columns = []
for column in node.select:
if isinstance(column, ast.Alias):
# It's either a visible alias, or the last hidden alias for this name.
if visible_aliases.get(column.alias) == column:
if column.hidden:
if (
isinstance(column.expr, ast.Field)
and isinstance(column.expr.type, ast.FieldType)
and column.expr.type.name == column.alias
):
# Hide the hidden alias only if it's a simple field,
# and we're using the same name for the field and the alias
# E.g. events.event AS event --> events.event.
column = column.expr
else:
# Make the hidden alias visible
column = cast(ast.Alias, clone_expr(column))
column.hidden = False
else:
# Always print visible aliases.
pass
else:
# This is not the alias chosen for this alias name. Drop the alias and print only its expression.
column = column.expr
columns.append(self.visit(column))
else:
columns = [self.visit(column) for column in node.select]
else:
columns = ["1"]
window = (
", ".join(
[f"{self._print_identifier(name)} AS ({self.visit(expr)})" for name, expr in node.window_exprs.items()]
Expand Down Expand Up @@ -810,8 +851,14 @@ def visit_placeholder(self, node: ast.Placeholder):
raise HogQLException(f"Placeholders, such as {{{node.field}}}, are not supported in this context")

def visit_alias(self, node: ast.Alias):
inside = self.visit(node.expr)
if isinstance(node.expr, ast.Alias):
# Skip hidden aliases completely.
if node.hidden:
return self.visit(node.expr)
expr = node.expr
while isinstance(expr, ast.Alias) and expr.hidden:
expr = expr.expr
inside = self.visit(expr)
if isinstance(expr, ast.Alias):
inside = f"({inside})"
alias = self._print_identifier(node.alias)
return f"{inside} AS {alias}"
Expand Down Expand Up @@ -1100,6 +1147,8 @@ def _is_nullable(self, node: ast.Expr) -> bool:
return True
elif isinstance(node.type, ast.FieldType):
return node.type.is_nullable()
elif isinstance(node, ast.Alias):
return self._is_nullable(node.expr)

# we don't know if it's nullable, so we assume it can be
return True
Expand Down
Loading

0 comments on commit 0582506

Please sign in to comment.