diff --git a/posthog/hogql_queries/insights/trends/aggregation_operations.py b/posthog/hogql_queries/insights/trends/aggregation_operations.py new file mode 100644 index 0000000000000..f585fc313dc70 --- /dev/null +++ b/posthog/hogql_queries/insights/trends/aggregation_operations.py @@ -0,0 +1,149 @@ +from typing import List +from posthog.hogql import ast +from posthog.hogql.parser import parse_expr, parse_select +from posthog.hogql_queries.utils.query_date_range import QueryDateRange +from posthog.schema import ActionsNode, EventsNode + + +class QueryAlternator: + """Allows query_builder to modify the query without having to expost the whole AST interface""" + + _query: ast.SelectQuery + _selects: List[ast.Expr] + _group_bys: List[ast.Expr] + _select_from: ast.JoinExpr | None + + def __init__(self, query: ast.SelectQuery): + self._query = query + self._selects = [] + self._group_bys = [] + self._select_from = None + + def build(self) -> ast.SelectQuery: + if len(self._selects) > 0: + self._query.select.extend(self._selects) + + if len(self._group_bys) > 0: + if self._query.group_by is None: + self._query.group_by = self._group_bys + else: + self._query.group_by.extend(self._group_bys) + + if self._select_from is not None: + self._query.select_from = self._select_from + + return self._query + + def append_select(self, expr: ast.Expr) -> None: + self._selects.append(expr) + + def append_group_by(self, expr: ast.Expr) -> None: + self._group_bys.append(expr) + + def replace_select_from(self, join_expr: ast.JoinExpr) -> None: + self._select_from = join_expr + + +class AggregationOperations: + series: EventsNode | ActionsNode + query_date_range: QueryDateRange + + def __init__(self, series: str, query_date_range: QueryDateRange) -> None: + self.series = series + self.query_date_range = query_date_range + + def select_aggregation(self) -> ast.Expr: + if self.series.math == "hogql": + return parse_expr(self.series.math_hogql) + elif self.series.math == "total": + return parse_expr("count(e.uuid)") + elif self.series.math == "dau": + return parse_expr("count(DISTINCT e.person_id)") + elif self.series.math == "weekly_active": + return ast.Field(chain=["counts"]) + + return parse_expr("count(e.uuid)") + + def requires_query_orchestration(self) -> bool: + return self.series.math == "weekly_active" + + def _parent_select_query(self, inner_query: ast.SelectQuery) -> ast.SelectQuery: + return parse_select( + """ + SELECT + counts AS total, + dateTrunc({interval}, timestamp) AS day_start + FROM {inner_query} + WHERE timestamp >= {date_from} AND timestamp <= {date_to} + """, + placeholders={ + **self.query_date_range.to_placeholders(), + "inner_query": inner_query, + }, + ) + + def _inner_select_query(self, cross_join_select_query: ast.SelectQuery) -> ast.SelectQuery: + return parse_select( + """ + SELECT + d.timestamp, + COUNT(DISTINCT actor_id) AS counts + FROM ( + SELECT + toStartOfDay({date_to}) - toIntervalDay(number) AS timestamp + FROM + numbers(dateDiff('day', toStartOfDay({date_from} - INTERVAL 7 DAY), {date_to})) + ) d + CROSS JOIN {cross_join_select_query} e + WHERE + e.timestamp <= d.timestamp + INTERVAL 1 DAY AND + e.timestamp > d.timestamp - INTERVAL 6 DAY + GROUP BY d.timestamp + ORDER BY d.timestamp + """, + placeholders={ + **self.query_date_range.to_placeholders(), + "cross_join_select_query": cross_join_select_query, + }, + ) + + def _events_query(self, events_where_clause: ast.Expr, sample_value: ast.RatioExpr) -> ast.SelectQuery: + return parse_select( + """ + SELECT + timestamp as timestamp, + e.person_id AS actor_id + FROM + events e + SAMPLE {sample} + WHERE {events_where_clause} + GROUP BY + timestamp, + actor_id + """, + placeholders={"events_where_clause": events_where_clause, "sample": sample_value}, + ) + + def get_query_orchestrator(self, events_where_clause: ast.Expr, sample_value: str): + events_query = self._events_query(events_where_clause, sample_value) + inner_select = self._inner_select_query(events_query) + parent_select = self._parent_select_query(inner_select) + + class QueryOrchestrator: + events_query_builder: QueryAlternator + inner_select_query_builder: QueryAlternator + parent_select_query_builder: QueryAlternator + + def __init__(self): + self.events_query_builder = QueryAlternator(events_query) + self.inner_select_query_builder = QueryAlternator(inner_select) + self.parent_select_query_builder = QueryAlternator(parent_select) + + def build(self): + self.events_query_builder.build() + self.inner_select_query_builder.build() + self.parent_select_query_builder.build() + + return parent_select + + return QueryOrchestrator() diff --git a/posthog/hogql_queries/insights/trends/query_builder.py b/posthog/hogql_queries/insights/trends/query_builder.py index cd5602303c7c7..3c0cd7d9356c7 100644 --- a/posthog/hogql_queries/insights/trends/query_builder.py +++ b/posthog/hogql_queries/insights/trends/query_builder.py @@ -3,6 +3,7 @@ from posthog.hogql.parser import parse_expr, parse_select from posthog.hogql.property import property_to_expr from posthog.hogql.timings import HogQLTimings +from posthog.hogql_queries.insights.trends.aggregation_operations import AggregationOperations from posthog.hogql_queries.insights.trends.breakdown import Breakdown from posthog.hogql_queries.insights.trends.breakdown_session import BreakdownSession from posthog.hogql_queries.insights.trends.utils import series_event_name @@ -107,11 +108,11 @@ def _get_date_subqueries(self) -> List[ast.SelectQuery]: ] def _get_events_subquery(self) -> ast.SelectQuery: - query = parse_select( + default_query = parse_select( """ SELECT {aggregation_operation} AS total, - dateTrunc({interval}, toTimeZone(toDateTime(timestamp), 'UTC')) AS day_start + dateTrunc({interval}, timestamp) AS day_start FROM events AS e SAMPLE {sample} WHERE {events_filter} @@ -120,19 +121,46 @@ def _get_events_subquery(self) -> ast.SelectQuery: placeholders={ **self.query_date_range.to_placeholders(), "events_filter": self._events_filter(), - "aggregation_operation": self._aggregation_operation(), + "aggregation_operation": self._aggregation_operation.select_aggregation(), "sample": self._sample_value(), }, ) - if self._breakdown.enabled: - query.select.append(self._breakdown.column_expr()) - query.group_by.append(ast.Field(chain=["breakdown_value"])) + # No breakdowns and no complex series aggregation + if not self._breakdown.enabled and not self._aggregation_operation.requires_query_orchestration(): + return default_query + # Both breakdowns and complex series aggregation + elif self._breakdown.enabled and self._aggregation_operation.requires_query_orchestration(): + orchestrator = self._aggregation_operation.get_query_orchestrator( + events_where_clause=self._events_filter(), + sample_value=self._sample_value(), + ) + orchestrator.events_query_builder.append_select(self._breakdown.column_expr()) + orchestrator.events_query_builder.append_group_by(ast.Field(chain=["breakdown_value"])) if self._breakdown.is_session_type: - query.select_from = self._breakdown_session.session_inner_join() + orchestrator.events_query_builder.replace_select_from(self._breakdown_session.session_inner_join()) - return query + orchestrator.inner_select_query_builder.append_select(ast.Field(chain=["breakdown_value"])) + orchestrator.inner_select_query_builder.append_group_by(ast.Field(chain=["breakdown_value"])) + + orchestrator.parent_select_query_builder.append_select(ast.Field(chain=["breakdown_value"])) + + return orchestrator.build() + # Just breakdowns + elif self._breakdown.enabled: + default_query.select.append(self._breakdown.column_expr()) + default_query.group_by.append(ast.Field(chain=["breakdown_value"])) + + if self._breakdown.is_session_type: + default_query.select_from = self._breakdown_session.session_inner_join() + # Just complex series aggregation + elif self._aggregation_operation.requires_query_orchestration(): + return self._aggregation_operation.get_query_orchestrator( + events_where_clause=self._events_filter(), sample_value=self._sample_value() + ).build() + + return default_query def _outer_select_query(self, inner_query: ast.SelectQuery) -> ast.SelectQuery: query = parse_select( @@ -234,14 +262,7 @@ def _events_filter(self) -> ast.Expr: else: return ast.And(exprs=filters) - def _aggregation_operation(self) -> ast.Expr: - if self.series.math == "hogql": - return parse_expr(self.series.math_hogql) - - return parse_expr("count(e.uuid)") - - # Using string interpolation for SAMPLE due to HogQL limitations with `UNION ALL` and `SAMPLE` AST nodes - def _sample_value(self) -> str: + def _sample_value(self) -> ast.RatioExpr: if self.query.samplingFactor is None: return ast.RatioExpr(left=ast.Constant(value=1)) @@ -260,3 +281,7 @@ def _breakdown(self): @cached_property def _breakdown_session(self): return BreakdownSession(self.query_date_range) + + @cached_property + def _aggregation_operation(self): + return AggregationOperations(self.series, self.query_date_range) diff --git a/posthog/hogql_queries/insights/trends/test/test_aggregation_operations.py b/posthog/hogql_queries/insights/trends/test/test_aggregation_operations.py new file mode 100644 index 0000000000000..656f2dc26ee69 --- /dev/null +++ b/posthog/hogql_queries/insights/trends/test/test_aggregation_operations.py @@ -0,0 +1,46 @@ +from typing import cast +from posthog.hogql import ast +from posthog.hogql.parser import parse_select +from posthog.hogql_queries.insights.trends.aggregation_operations import QueryAlternator + + +class TestQueryAlternator: + def test_select(self): + query = parse_select("SELECT event from events") + + query_modifier = QueryAlternator(query) + query_modifier.append_select(ast.Field(chain=["test"])) + query_modifier.build() + + assert len(query.select) == 2 + assert cast(ast.Field, query.select[1]).chain == ["test"] + + def test_group_no_pre_existing(self): + query = parse_select("SELECT event from events") + + query_modifier = QueryAlternator(query) + query_modifier.append_group_by(ast.Field(chain=["event"])) + query_modifier.build() + + assert len(query.group_by) == 1 + assert cast(ast.Field, query.group_by[0]).chain == ["event"] + + def test_group_with_pre_existing(self): + query = parse_select("SELECT event from events GROUP BY uuid") + + query_modifier = QueryAlternator(query) + query_modifier.append_group_by(ast.Field(chain=["event"])) + query_modifier.build() + + assert len(query.group_by) == 2 + assert cast(ast.Field, query.group_by[0]).chain == ["uuid"] + assert cast(ast.Field, query.group_by[1]).chain == ["event"] + + def test_replace_select_from(self): + query = parse_select("SELECT event from events") + + query_modifier = QueryAlternator(query) + query_modifier.replace_select_from(ast.JoinExpr(table=ast.Field(chain=["groups"]))) + query_modifier.build() + + assert query.select_from.table.chain == ["groups"] diff --git a/posthog/hogql_queries/insights/trends/trends_query_runner.py b/posthog/hogql_queries/insights/trends/trends_query_runner.py index b3f31ad5f6a6f..9c1dc4eca64f5 100644 --- a/posthog/hogql_queries/insights/trends/trends_query_runner.py +++ b/posthog/hogql_queries/insights/trends/trends_query_runner.py @@ -102,7 +102,11 @@ def calculate(self): res.extend(self.build_series_response(response, series_with_extra)) - if self.query.trendsFilter is not None and self.query.trendsFilter.formula is not None: + if ( + self.query.trendsFilter is not None + and self.query.trendsFilter.formula is not None + and self.query.trendsFilter.formula != "" + ): res = self.apply_formula(self.query.trendsFilter.formula, res) return TrendsQueryResponse(results=res, timings=timings)