From 7a37a712ffa066ae32d6b0c8cb50ded277524432 Mon Sep 17 00:00:00 2001 From: Tom Owers Date: Mon, 30 Oct 2023 15:10:41 +0000 Subject: [PATCH] Added the final set of aggregation functions --- .../insights/trends/aggregation_operations.py | 178 ++++++++++++++++-- 1 file changed, 165 insertions(+), 13 deletions(-) diff --git a/posthog/hogql_queries/insights/trends/aggregation_operations.py b/posthog/hogql_queries/insights/trends/aggregation_operations.py index 3920344cbfd52..8b7d0cf71afda 100644 --- a/posthog/hogql_queries/insights/trends/aggregation_operations.py +++ b/posthog/hogql_queries/insights/trends/aggregation_operations.py @@ -1,4 +1,4 @@ -from typing import List +from typing import List, Optional from posthog.hogql import ast from posthog.hogql.parser import parse_expr, parse_select from posthog.hogql_queries.utils.query_date_range import QueryDateRange @@ -13,13 +13,15 @@ class QueryAlternator: _group_bys: List[ast.Expr] _select_from: ast.JoinExpr | None - def __init__(self, query: ast.SelectQuery): + def __init__(self, query: ast.SelectQuery | ast.SelectUnionQuery): + assert isinstance(query, ast.SelectQuery) + self._query = query self._selects = [] self._group_bys = [] self._select_from = None - def build(self) -> ast.SelectQuery: + def build(self) -> ast.SelectQuery | ast.SelectUnionQuery: if len(self._selects) > 0: self._query.select.extend(self._selects) @@ -48,26 +50,120 @@ class AggregationOperations: series: EventsNode | ActionsNode query_date_range: QueryDateRange - def __init__(self, series: str, query_date_range: QueryDateRange) -> None: + def __init__(self, series: EventsNode | ActionsNode, query_date_range: QueryDateRange) -> None: self.series = series self.query_date_range = query_date_range def select_aggregation(self) -> ast.Expr: - if self.series.math == "hogql": + if self.series.math == "hogql" and self.series.math_hogql is not None: return parse_expr(self.series.math_hogql) elif self.series.math == "total": return parse_expr("count(e.uuid)") elif self.series.math == "dau": return parse_expr("count(DISTINCT e.person_id)") elif self.series.math == "weekly_active": - return ast.Field(chain=["counts"]) + return ast.Field(chain=["counts"]) # This gets replaced when doing query orchestration + elif self.series.math == "monthly_active": + return ast.Field(chain=["counts"]) # This gets replaced when doing query orchestration + elif self.series.math == "unique_session": + return parse_expr('count(DISTINCT e."$session_id")') + elif self.series.math == "unique_group" and self.series.math_group_type_index is not None: + return parse_expr(f'count(DISTINCT e."$group_{self.series.math_group_type_index}")') + elif self.series.math_property is not None: + if self.series.math == "avg": + return self._math_func("avg", None) + elif self.series.math == "sum": + return self._math_func("sum", None) + elif self.series.math == "min": + return self._math_func("min", None) + elif self.series.math == "max": + return self._math_func("max", None) + elif self.series.math == "median": + return self._math_func("median", None) + elif self.series.math == "p90": + return self._math_quantile(0.9, None) + elif self.series.math == "p95": + return self._math_quantile(0.95, None) + elif self.series.math == "p99": + return self._math_quantile(0.99, None) + else: + raise NotImplementedError() return parse_expr("count(e.uuid)") def requires_query_orchestration(self) -> bool: - return self.series.math == "weekly_active" + math_to_return_true = [ + "weekly_active", + "monthly_active", + ] + + return self._is_count_per_actor_variant() or self.series.math in math_to_return_true + + def _is_count_per_actor_variant(self): + return self.series.math in [ + "avg_count_per_actor", + "min_count_per_actor", + "max_count_per_actor", + "median_count_per_actor", + "p90_count_per_actor", + "p95_count_per_actor", + "p99_count_per_actor", + ] + + def _math_func(self, method: str, override_chain: Optional[List[str | int]]) -> ast.Call: + if override_chain is not None: + return ast.Call(name=method, args=[ast.Field(chain=override_chain)]) + + if self.series.math_property == "$time": + return ast.Call( + name=method, + args=[ + ast.Call( + name="toUnixTimestamp", + args=[ast.Field(chain=["properties", "$time"])], + ) + ], + ) + + if self.series.math_property == "$session_duration": + chain = ["session", "duration"] + else: + chain = ["properties", self.series.math_property] + + return ast.Call(name=method, args=[ast.Field(chain=chain)]) + + def _math_quantile(self, percentile: float, override_chain: Optional[List[str | int]]) -> ast.Call: + chain = ["properties", self.series.math_property] + + return ast.Call( + name="quantile", + params=[ast.Constant(value=percentile)], + args=[ast.Field(chain=override_chain or chain)], + ) + + def _interval_placeholders(self): + if self.series.math == "weekly_active": + return { + "exclusive_lookback": ast.Call(name="toIntervalDay", args=[ast.Constant(value=6)]), + "inclusive_lookback": ast.Call(name="toIntervalDay", args=[ast.Constant(value=7)]), + } + elif self.series.math == "monthly_active": + return { + "exclusive_lookback": ast.Call(name="toIntervalDay", args=[ast.Constant(value=29)]), + "inclusive_lookback": ast.Call(name="toIntervalDay", args=[ast.Constant(value=30)]), + } + + raise NotImplementedError() + + def _parent_select_query( + self, inner_query: ast.SelectQuery | ast.SelectUnionQuery + ) -> ast.SelectQuery | ast.SelectUnionQuery: + if self._is_count_per_actor_variant(): + return parse_select( + "SELECT total, day_start FROM {inner_query}", + placeholders={"inner_query": inner_query}, + ) - def _parent_select_query(self, inner_query: ast.SelectQuery) -> ast.SelectQuery: return parse_select( """ SELECT @@ -82,7 +178,42 @@ def _parent_select_query(self, inner_query: ast.SelectQuery) -> ast.SelectQuery: }, ) - def _inner_select_query(self, cross_join_select_query: ast.SelectQuery) -> ast.SelectQuery: + def _inner_select_query( + self, cross_join_select_query: ast.SelectQuery | ast.SelectUnionQuery + ) -> ast.SelectQuery | ast.SelectUnionQuery: + if self._is_count_per_actor_variant(): + if self.series.math == "avg_count_per_actor": + math_func = self._math_func("avg", ["total"]) + elif self.series.math == "min_count_per_actor": + math_func = self._math_func("min", ["total"]) + elif self.series.math == "max_count_per_actor": + math_func = self._math_func("max", ["total"]) + elif self.series.math == "median_count_per_actor": + math_func = self._math_func("median", ["total"]) + elif self.series.math == "p90_count_per_actor": + math_func = self._math_quantile(0.9, ["total"]) + elif self.series.math == "p95_count_per_actor": + math_func = self._math_quantile(0.95, ["total"]) + elif self.series.math == "p99_count_per_actor": + math_func = self._math_quantile(0.99, ["total"]) + else: + raise NotImplementedError() + + total_alias = ast.Alias(alias="total", expr=math_func) + + return parse_select( + """ + SELECT + {total_alias}, day_start + FROM {inner_query} + GROUP BY day_start + """, + placeholders={ + "inner_query": cross_join_select_query, + "total_alias": total_alias, + }, + ) + return parse_select( """ SELECT @@ -92,22 +223,43 @@ def _inner_select_query(self, cross_join_select_query: ast.SelectQuery) -> ast.S SELECT toStartOfDay({date_to}) - toIntervalDay(number) AS timestamp FROM - numbers(dateDiff('day', toStartOfDay({date_from} - INTERVAL 7 DAY), {date_to})) + numbers(dateDiff('day', toStartOfDay({date_from} - {inclusive_lookback}), {date_to})) ) d CROSS JOIN {cross_join_select_query} e WHERE e.timestamp <= d.timestamp + INTERVAL 1 DAY AND - e.timestamp > d.timestamp - INTERVAL 6 DAY + e.timestamp > d.timestamp - {exclusive_lookback} GROUP BY d.timestamp ORDER BY d.timestamp """, placeholders={ **self.query_date_range.to_placeholders(), + **self._interval_placeholders(), "cross_join_select_query": cross_join_select_query, }, ) - def _events_query(self, events_where_clause: ast.Expr, sample_value: ast.RatioExpr) -> ast.SelectQuery: + def _events_query( + self, events_where_clause: ast.Expr, sample_value: ast.RatioExpr + ) -> ast.SelectQuery | ast.SelectUnionQuery: + if self._is_count_per_actor_variant(): + return parse_select( + """ + SELECT + count(e.uuid) AS total, + dateTrunc({interval}, timestamp) AS day_start + FROM events AS e + SAMPLE {sample} + WHERE {events_where_clause} + GROUP BY e.person_id, day_start + """, + placeholders={ + **self.query_date_range.to_placeholders(), + "events_where_clause": events_where_clause, + "sample": sample_value, + }, + ) + return parse_select( """ SELECT @@ -127,7 +279,7 @@ def _events_query(self, events_where_clause: ast.Expr, sample_value: ast.RatioEx }, ) - def get_query_orchestrator(self, events_where_clause: ast.Expr, sample_value: str): + def get_query_orchestrator(self, events_where_clause: ast.Expr, sample_value: ast.RatioExpr): events_query = self._events_query(events_where_clause, sample_value) inner_select = self._inner_select_query(events_query) parent_select = self._parent_select_query(inner_select)