Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: added the final set of aggregation functions #18268

Merged
merged 2 commits into from
Nov 2, 2023
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
178 changes: 165 additions & 13 deletions posthog/hogql_queries/insights/trends/aggregation_operations.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List
from typing import List, Optional
from posthog.hogql import ast
from posthog.hogql.parser import parse_expr, parse_select
from posthog.hogql_queries.utils.query_date_range import QueryDateRange
Expand All @@ -13,13 +13,15 @@ class QueryAlternator:
_group_bys: List[ast.Expr]
_select_from: ast.JoinExpr | None

def __init__(self, query: ast.SelectQuery):
def __init__(self, query: ast.SelectQuery | ast.SelectUnionQuery):
assert isinstance(query, ast.SelectQuery)

self._query = query
self._selects = []
self._group_bys = []
self._select_from = None

def build(self) -> ast.SelectQuery:
def build(self) -> ast.SelectQuery | ast.SelectUnionQuery:
if len(self._selects) > 0:
self._query.select.extend(self._selects)

Expand Down Expand Up @@ -48,26 +50,120 @@ class AggregationOperations:
series: EventsNode | ActionsNode
query_date_range: QueryDateRange

def __init__(self, series: str, query_date_range: QueryDateRange) -> None:
def __init__(self, series: EventsNode | ActionsNode, query_date_range: QueryDateRange) -> None:
self.series = series
self.query_date_range = query_date_range

def select_aggregation(self) -> ast.Expr:
if self.series.math == "hogql":
if self.series.math == "hogql" and self.series.math_hogql is not None:
return parse_expr(self.series.math_hogql)
elif self.series.math == "total":
return parse_expr("count(e.uuid)")
elif self.series.math == "dau":
return parse_expr("count(DISTINCT e.person_id)")
elif self.series.math == "weekly_active":
return ast.Field(chain=["counts"])
return ast.Field(chain=["counts"]) # This gets replaced when doing query orchestration
elif self.series.math == "monthly_active":
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"This gets replaced" sounds a lot like ast.Placeholder 😄 I'm curious why this scheme instead of placeholders?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Either is fine: no AST parsing or resolving is actually run before the replacement is done, so we just need something to populate the field. In my mind, Field is the simplest AST node we could pick here.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If they both function the same way here, I think it'd be better to use Placeholder, since it's clearer in terms of intent. If that would need more complex replacement logic, then Field is fine.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated

return ast.Field(chain=["counts"]) # This gets replaced when doing query orchestration
elif self.series.math == "unique_session":
return parse_expr('count(DISTINCT e."$session_id")')
elif self.series.math == "unique_group" and self.series.math_group_type_index is not None:
return parse_expr(f'count(DISTINCT e."$group_{self.series.math_group_type_index}")')
elif self.series.math_property is not None:
if self.series.math == "avg":
return self._math_func("avg", None)
elif self.series.math == "sum":
return self._math_func("sum", None)
elif self.series.math == "min":
return self._math_func("min", None)
elif self.series.math == "max":
return self._math_func("max", None)
elif self.series.math == "median":
return self._math_func("median", None)
elif self.series.math == "p90":
return self._math_quantile(0.9, None)
elif self.series.math == "p95":
return self._math_quantile(0.95, None)
elif self.series.math == "p99":
return self._math_quantile(0.99, None)
else:
raise NotImplementedError()

return parse_expr("count(e.uuid)")

def requires_query_orchestration(self) -> bool:
return self.series.math == "weekly_active"
math_to_return_true = [
"weekly_active",
"monthly_active",
]

return self._is_count_per_actor_variant() or self.series.math in math_to_return_true

def _is_count_per_actor_variant(self):
    """Tell whether this series uses one of the ``*_count_per_actor`` math types.

    Returns:
        True when ``self.series.math`` is one of the avg/min/max/median/p90/p95/p99
        count-per-actor aggregations, otherwise False.
    """
    count_per_actor_maths = [
        "avg_count_per_actor",
        "min_count_per_actor",
        "max_count_per_actor",
        "median_count_per_actor",
        "p90_count_per_actor",
        "p95_count_per_actor",
        "p99_count_per_actor",
    ]
    selected_math = self.series.math
    return selected_math in count_per_actor_maths

def _math_func(self, method: str, override_chain: Optional[List[str | int]]) -> ast.Call:
if override_chain is not None:
return ast.Call(name=method, args=[ast.Field(chain=override_chain)])

if self.series.math_property == "$time":
return ast.Call(
name=method,
args=[
ast.Call(
name="toUnixTimestamp",
args=[ast.Field(chain=["properties", "$time"])],
)
],
)

Comment on lines +117 to +126
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is interesting – I thought the $time property was deprecated, in which case we probably don't need this?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is, but it looks like there's a frontend bug that causes the property default to be set to $time. That's one to fix in a follow-up.

if self.series.math_property == "$session_duration":
chain = ["session", "duration"]
else:
chain = ["properties", self.series.math_property]

return ast.Call(name=method, args=[ast.Field(chain=chain)])

def _math_quantile(self, percentile: float, override_chain: Optional[List[str | int]]) -> ast.Call:
    """Build a ``quantile(<percentile>)(<field>)`` call for this series.

    Args:
        percentile: Quantile level passed as the parametric argument (e.g. 0.9 for p90).
        override_chain: Field chain to aggregate instead of the series' math property.
            A falsy value (``None`` or an empty list) falls back to the math property.

    Returns:
        An ``ast.Call`` named ``quantile`` whose ``params`` hold the percentile and
        whose single argument is the field being aggregated.
    """
    # Resolve the math property the same way `_math_func` does, so percentile
    # aggregations over session duration read the same field as avg/sum/min/max/median
    # rather than a `$session_duration` event property.
    if self.series.math_property == "$session_duration":
        chain = ["session", "duration"]
    else:
        chain = ["properties", self.series.math_property]

    return ast.Call(
        name="quantile",
        params=[ast.Constant(value=percentile)],
        args=[ast.Field(chain=override_chain or chain)],
    )

def _interval_placeholders(self):
    """Return the lookback-window placeholders for active-user math types.

    The result maps ``exclusive_lookback`` and ``inclusive_lookback`` to
    ``toIntervalDay(n)`` calls: 6 and 7 days for ``weekly_active``, 29 and 30 days
    for ``monthly_active``.

    Raises:
        NotImplementedError: if ``self.series.math`` is neither of those two types.
    """
    math = self.series.math
    if math == "weekly_active":
        exclusive_days, inclusive_days = 6, 7
    elif math == "monthly_active":
        exclusive_days, inclusive_days = 29, 30
    else:
        raise NotImplementedError()

    def day_interval(days):
        # Wrap a day count in a toIntervalDay(...) call node.
        return ast.Call(name="toIntervalDay", args=[ast.Constant(value=days)])

    return {
        "exclusive_lookback": day_interval(exclusive_days),
        "inclusive_lookback": day_interval(inclusive_days),
    }

def _parent_select_query(
self, inner_query: ast.SelectQuery | ast.SelectUnionQuery
) -> ast.SelectQuery | ast.SelectUnionQuery:
if self._is_count_per_actor_variant():
return parse_select(
"SELECT total, day_start FROM {inner_query}",
placeholders={"inner_query": inner_query},
)

def _parent_select_query(self, inner_query: ast.SelectQuery) -> ast.SelectQuery:
return parse_select(
"""
SELECT
Expand All @@ -82,7 +178,42 @@ def _parent_select_query(self, inner_query: ast.SelectQuery) -> ast.SelectQuery:
},
)

def _inner_select_query(self, cross_join_select_query: ast.SelectQuery) -> ast.SelectQuery:
def _inner_select_query(
self, cross_join_select_query: ast.SelectQuery | ast.SelectUnionQuery
) -> ast.SelectQuery | ast.SelectUnionQuery:
if self._is_count_per_actor_variant():
if self.series.math == "avg_count_per_actor":
math_func = self._math_func("avg", ["total"])
elif self.series.math == "min_count_per_actor":
math_func = self._math_func("min", ["total"])
elif self.series.math == "max_count_per_actor":
math_func = self._math_func("max", ["total"])
elif self.series.math == "median_count_per_actor":
math_func = self._math_func("median", ["total"])
elif self.series.math == "p90_count_per_actor":
math_func = self._math_quantile(0.9, ["total"])
elif self.series.math == "p95_count_per_actor":
math_func = self._math_quantile(0.95, ["total"])
elif self.series.math == "p99_count_per_actor":
math_func = self._math_quantile(0.99, ["total"])
else:
raise NotImplementedError()

total_alias = ast.Alias(alias="total", expr=math_func)

return parse_select(
"""
SELECT
{total_alias}, day_start
FROM {inner_query}
GROUP BY day_start
""",
placeholders={
"inner_query": cross_join_select_query,
"total_alias": total_alias,
},
)

return parse_select(
"""
SELECT
Expand All @@ -92,22 +223,43 @@ def _inner_select_query(self, cross_join_select_query: ast.SelectQuery) -> ast.S
SELECT
toStartOfDay({date_to}) - toIntervalDay(number) AS timestamp
FROM
numbers(dateDiff('day', toStartOfDay({date_from} - INTERVAL 7 DAY), {date_to}))
numbers(dateDiff('day', toStartOfDay({date_from} - {inclusive_lookback}), {date_to}))
) d
CROSS JOIN {cross_join_select_query} e
WHERE
e.timestamp <= d.timestamp + INTERVAL 1 DAY AND
e.timestamp > d.timestamp - INTERVAL 6 DAY
e.timestamp > d.timestamp - {exclusive_lookback}
GROUP BY d.timestamp
ORDER BY d.timestamp
""",
placeholders={
**self.query_date_range.to_placeholders(),
**self._interval_placeholders(),
"cross_join_select_query": cross_join_select_query,
},
)

def _events_query(self, events_where_clause: ast.Expr, sample_value: ast.RatioExpr) -> ast.SelectQuery:
def _events_query(
self, events_where_clause: ast.Expr, sample_value: ast.RatioExpr
) -> ast.SelectQuery | ast.SelectUnionQuery:
if self._is_count_per_actor_variant():
return parse_select(
"""
SELECT
count(e.uuid) AS total,
dateTrunc({interval}, timestamp) AS day_start
FROM events AS e
SAMPLE {sample}
WHERE {events_where_clause}
GROUP BY e.person_id, day_start
""",
placeholders={
**self.query_date_range.to_placeholders(),
"events_where_clause": events_where_clause,
"sample": sample_value,
},
)

return parse_select(
"""
SELECT
Expand All @@ -127,7 +279,7 @@ def _events_query(self, events_where_clause: ast.Expr, sample_value: ast.RatioEx
},
)

def get_query_orchestrator(self, events_where_clause: ast.Expr, sample_value: str):
def get_query_orchestrator(self, events_where_clause: ast.Expr, sample_value: ast.RatioExpr):
events_query = self._events_query(events_where_clause, sample_value)
inner_select = self._inner_select_query(events_query)
parent_select = self._parent_select_query(inner_select)
Expand Down
Loading