From 1b3dc1704cad01b57523814a1f7ca779e6f225b7 Mon Sep 17 00:00:00 2001 From: eric <eeoneric@gmail.com> Date: Tue, 13 Feb 2024 14:09:15 -0500 Subject: [PATCH] split out query builder --- .../data_warehouse_trends_query_builder.py | 526 ++++++++++++++++++ .../insights/trends/test/test_trends.py | 1 + ...ery_builder.py => trends_query_builder.py} | 61 +- .../insights/trends/trends_query_runner.py | 30 +- 4 files changed, 564 insertions(+), 54 deletions(-) create mode 100644 posthog/hogql_queries/insights/trends/data_warehouse_trends_query_builder.py rename posthog/hogql_queries/insights/trends/{query_builder.py => trends_query_builder.py} (92%) diff --git a/posthog/hogql_queries/insights/trends/data_warehouse_trends_query_builder.py b/posthog/hogql_queries/insights/trends/data_warehouse_trends_query_builder.py new file mode 100644 index 0000000000000..4950ea5028f21 --- /dev/null +++ b/posthog/hogql_queries/insights/trends/data_warehouse_trends_query_builder.py @@ -0,0 +1,526 @@ +from typing import List, Optional, cast +from posthog.hogql import ast +from posthog.hogql.parser import parse_expr, parse_select +from posthog.hogql.property import property_to_expr +from posthog.hogql.timings import HogQLTimings +from posthog.hogql_queries.insights.trends.aggregation_operations import ( + AggregationOperations, +) +from posthog.hogql_queries.insights.trends.breakdown import Breakdown +from posthog.hogql_queries.insights.trends.display import TrendsDisplay +from posthog.hogql_queries.utils.query_date_range import QueryDateRange +from posthog.models.filters.mixins.utils import cached_property +from posthog.models.team.team import Team +from posthog.schema import HogQLQueryModifiers, TrendsQuery, DataWarehouseNode + + +class DataWarehouseTrendsQueryBuilder: + query: TrendsQuery + team: Team + query_date_range: QueryDateRange + series: DataWarehouseNode + timings: HogQLTimings + modifiers: HogQLQueryModifiers + + def __init__( + self, + trends_query: TrendsQuery, + team: Team, + query_date_range: QueryDateRange, + series: DataWarehouseNode, + timings: HogQLTimings, + modifiers: HogQLQueryModifiers, + ): + self.query = trends_query + self.team = team + self.query_date_range = query_date_range + self.series = series + self.timings = timings + self.modifiers = modifiers + + def build_query(self) -> ast.SelectQuery | ast.SelectUnionQuery: + breakdown = self._breakdown(is_actors_query=False) + + events_query: ast.SelectQuery | ast.SelectUnionQuery + + if self._trends_display.should_aggregate_values(): + events_query = self._get_events_subquery(False, is_actors_query=False, breakdown=breakdown) + else: + date_subqueries = self._get_date_subqueries(breakdown=breakdown) + event_query = self._get_events_subquery(False, is_actors_query=False, breakdown=breakdown) + + events_query = ast.SelectUnionQuery(select_queries=[*date_subqueries, event_query]) + + inner_select = self._inner_select_query(inner_query=events_query, breakdown=breakdown) + full_query = self._outer_select_query(inner_query=inner_select, breakdown=breakdown) + + return full_query + + def _get_date_subqueries(self, breakdown: Breakdown, ignore_breakdowns: bool = False) -> List[ast.SelectQuery]: + if not breakdown.enabled or ignore_breakdowns: + return [ + cast( + ast.SelectQuery, + parse_select( + """ + SELECT + 0 AS total, + {date_to_start_of_interval} - {number_interval_period} AS day_start + FROM + numbers( + coalesce(dateDiff({interval}, {date_from}, {date_to}), 0) + ) + """, + placeholders={ + **self.query_date_range.to_placeholders(), + }, + ), + ), + cast( + ast.SelectQuery, + parse_select( + """ + SELECT + 0 AS total, + {date_from_start_of_interval} AS day_start + """, + placeholders={ + **self.query_date_range.to_placeholders(), + }, + ), + ), + ] + + return [ + cast( + ast.SelectQuery, + parse_select( + """ + SELECT + 0 AS total, + ticks.day_start as day_start, + breakdown_value + FROM ( + SELECT + {date_to_start_of_interval} - {number_interval_period} AS day_start + FROM + numbers( + coalesce(dateDiff({interval}, {date_from}, {date_to}), 0) + ) + UNION ALL + SELECT {date_from_start_of_interval} AS day_start + ) as ticks + CROSS JOIN ( + SELECT breakdown_value + FROM ( + SELECT {cross_join_breakdown_values} + ) + ARRAY JOIN breakdown_value as breakdown_value + ) as sec + ORDER BY breakdown_value, day_start + """, + placeholders={ + **self.query_date_range.to_placeholders(), + **breakdown.placeholders(), + }, + ), + ) + ] + + def _get_events_subquery( + self, + no_modifications: Optional[bool], + is_actors_query: bool, + breakdown: Breakdown, + breakdown_values_override: Optional[str | int] = None, + actors_query_time_frame: Optional[str | int] = None, + ) -> ast.SelectQuery: + day_start = ast.Alias( + alias="day_start", + expr=ast.Call( + name=f"toStartOf{self.query_date_range.interval_name.title()}", + args=[ast.Call(name="toDateTime", args=[ast.Field(chain=[self.series.timestamp_field])])], + ), + ) + + events_filter = self._events_filter( + ignore_breakdowns=False, + breakdown=breakdown, + is_actors_query=is_actors_query, + breakdown_values_override=breakdown_values_override, + actors_query_time_frame=actors_query_time_frame, + ) + + default_query = cast( + ast.SelectQuery, + parse_select( + """ + SELECT + {aggregation_operation} AS total + FROM {table} AS e + WHERE {events_filter} + """, + placeholders={ + "events_filter": events_filter, + "aggregation_operation": self._aggregation_operation.select_aggregation(), + "table": self._table_expr, + }, + ), + ) + + default_query.group_by = [] + + if not self._trends_display.should_aggregate_values() and not is_actors_query: + default_query.select.append(day_start) + default_query.group_by.append(ast.Field(chain=["day_start"])) + + # TODO: Move this logic into the below branches when working on adding breakdown support for the person modal + if is_actors_query: + default_query.select = [ast.Alias(alias="person_id", expr=ast.Field(chain=["e", "person", "id"]))] + default_query.distinct = True + default_query.group_by = [] + + # No breakdowns and no complex series aggregation + if ( + not breakdown.enabled + and not self._aggregation_operation.requires_query_orchestration() + and not self._aggregation_operation.aggregating_on_session_duration() + ) or no_modifications is True: + return default_query + # Both breakdowns and complex series aggregation + elif breakdown.enabled and self._aggregation_operation.requires_query_orchestration(): + orchestrator = self._aggregation_operation.get_query_orchestrator( + events_where_clause=events_filter, + sample_value=self._sample_value(), + ) + + if is_actors_query: + orchestrator.events_query_builder.append_select( + ast.Alias(alias="person_id", expr=ast.Field(chain=["e", "person", "id"])) + ) + orchestrator.inner_select_query_builder.append_select(ast.Field(chain=["person_id"])) + orchestrator.parent_select_query_builder.append_select(ast.Field(chain=["person_id"])) + else: + orchestrator.events_query_builder.append_select(breakdown.column_expr()) + orchestrator.events_query_builder.append_group_by(ast.Field(chain=["breakdown_value"])) + + orchestrator.inner_select_query_builder.append_select(ast.Field(chain=["breakdown_value"])) + orchestrator.inner_select_query_builder.append_group_by(ast.Field(chain=["breakdown_value"])) + + orchestrator.parent_select_query_builder.append_select(ast.Field(chain=["breakdown_value"])) + + if ( + self._aggregation_operation.should_aggregate_values + and not self._aggregation_operation.is_count_per_actor_variant() + and not is_actors_query + ): + orchestrator.parent_select_query_builder.append_group_by(ast.Field(chain=["breakdown_value"])) + + return orchestrator.build() + # Breakdowns and session duration math property + elif breakdown.enabled and self._aggregation_operation.aggregating_on_session_duration(): + default_query.select = [ + ast.Alias( + alias="session_duration", expr=ast.Call(name="any", args=[ast.Field(chain=["session", "duration"])]) + ), + breakdown.column_expr(), + ] + + default_query.group_by.extend([ast.Field(chain=["session", "id"]), ast.Field(chain=["breakdown_value"])]) + + wrapper = self.session_duration_math_property_wrapper(default_query) + assert wrapper.group_by is not None + + if not self._trends_display.should_aggregate_values() and not is_actors_query: + default_query.select.append(day_start) + default_query.group_by.append(ast.Field(chain=["day_start"])) + + wrapper.select.append(ast.Field(chain=["day_start"])) + wrapper.group_by.append(ast.Field(chain=["day_start"])) + + if is_actors_query: + default_query.select.append(ast.Alias(alias="person_id", expr=ast.Field(chain=["e", "person", "id"]))) + wrapper.select.append(ast.Field(chain=["person_id"])) + else: + wrapper.select.append(ast.Field(chain=["breakdown_value"])) + wrapper.group_by.append(ast.Field(chain=["breakdown_value"])) + + return wrapper + # Just breakdowns + elif breakdown.enabled: + if not is_actors_query: + default_query.select.append(breakdown.column_expr()) + default_query.group_by.append(ast.Field(chain=["breakdown_value"])) + # Just session duration math property + elif self._aggregation_operation.aggregating_on_session_duration(): + default_query.select = [ + ast.Alias( + alias="session_duration", expr=ast.Call(name="any", args=[ast.Field(chain=["session", "duration"])]) + ) + ] + default_query.group_by.append(ast.Field(chain=["session", "id"])) + + wrapper = self.session_duration_math_property_wrapper(default_query) + + if not self._trends_display.should_aggregate_values() and not is_actors_query: + assert wrapper.group_by is not None + + default_query.select.append(day_start) + default_query.group_by.append(ast.Field(chain=["day_start"])) + + wrapper.select.append(ast.Field(chain=["day_start"])) + wrapper.group_by.append(ast.Field(chain=["day_start"])) + + if is_actors_query: + default_query.select.append(ast.Alias(alias="person_id", expr=ast.Field(chain=["e", "person", "id"]))) + wrapper.select.append(ast.Field(chain=["person_id"])) + + return wrapper + # Just complex series aggregation + elif self._aggregation_operation.requires_query_orchestration(): + orchestrator = self._aggregation_operation.get_query_orchestrator( + events_where_clause=events_filter, + sample_value=self._sample_value(), + ) + + if is_actors_query: + orchestrator.events_query_builder.append_select( + ast.Alias(alias="person_id", expr=ast.Field(chain=["e", "person", "id"])) + ) + orchestrator.inner_select_query_builder.append_select(ast.Field(chain=["person_id"])) + orchestrator.parent_select_query_builder.append_select(ast.Field(chain=["person_id"])) + + return orchestrator.build() + + return default_query + + def _outer_select_query(self, breakdown: Breakdown, inner_query: ast.SelectQuery) -> ast.SelectQuery: + query = cast( + ast.SelectQuery, + parse_select( + """ + SELECT + groupArray(day_start) AS date, + groupArray(count) AS total + FROM {inner_query} + """, + placeholders={"inner_query": inner_query}, + ), + ) + + query = self._trends_display.modify_outer_query( + outer_query=query, + inner_query=inner_query, + dates_queries=ast.SelectUnionQuery( + select_queries=self._get_date_subqueries(ignore_breakdowns=True, breakdown=breakdown) + ), + ) + + query.order_by = [ast.OrderExpr(expr=ast.Call(name="sum", args=[ast.Field(chain=["count"])]), order="DESC")] + + if breakdown.enabled: + query.select.append( + ast.Alias( + alias="breakdown_value", + expr=ast.Call( + name="ifNull", + args=[ + ast.Call(name="toString", args=[ast.Field(chain=["breakdown_value"])]), + ast.Constant(value=""), + ], + ), + ) + ) + query.group_by = [ast.Field(chain=["breakdown_value"])] + query.order_by.append(ast.OrderExpr(expr=ast.Field(chain=["breakdown_value"]), order="ASC")) + + return query + + def _inner_select_query( + self, breakdown: Breakdown, inner_query: ast.SelectQuery | ast.SelectUnionQuery + ) -> ast.SelectQuery: + query = cast( + ast.SelectQuery, + parse_select( + """ + SELECT + sum(total) AS count + FROM {inner_query} + """, + placeholders={"inner_query": inner_query}, + ), + ) + + if ( + self.query.trendsFilter is not None + and self.query.trendsFilter.smoothingIntervals is not None + and self.query.trendsFilter.smoothingIntervals > 1 + ): + rolling_average = ast.Alias( + alias="count", + expr=ast.Call( + name="floor", + args=[ + ast.WindowFunction( + name="avg", + args=[ast.Call(name="sum", args=[ast.Field(chain=["total"])])], + over_expr=ast.WindowExpr( + order_by=[ast.OrderExpr(expr=ast.Field(chain=["day_start"]), order="ASC")], + frame_method="ROWS", + frame_start=ast.WindowFrameExpr( + frame_type="PRECEDING", + frame_value=int(self.query.trendsFilter.smoothingIntervals - 1), + ), + frame_end=ast.WindowFrameExpr(frame_type="CURRENT ROW"), + ), + ) + ], + ), + ) + query.select = [rolling_average] + + query.group_by = [] + query.order_by = [] + + if not self._trends_display.should_aggregate_values(): + query.select.append(ast.Field(chain=["day_start"])) + query.group_by.append(ast.Field(chain=["day_start"])) + query.order_by.append(ast.OrderExpr(expr=ast.Field(chain=["day_start"]), order="ASC")) + + if breakdown.enabled: + query.select.append(ast.Field(chain=["breakdown_value"])) + query.group_by.append(ast.Field(chain=["breakdown_value"])) + query.order_by.append(ast.OrderExpr(expr=ast.Field(chain=["breakdown_value"]), order="ASC")) + + if self._trends_display.should_wrap_inner_query(): + query = self._trends_display.wrap_inner_query(query, breakdown.enabled) + if breakdown.enabled: + query.select.append(ast.Field(chain=["breakdown_value"])) + + return query + + def _events_filter( + self, + is_actors_query: bool, + breakdown: Breakdown | None, + ignore_breakdowns: bool = False, + breakdown_values_override: Optional[str | int] = None, + actors_query_time_frame: Optional[str | int] = None, + ) -> ast.Expr: + series = self.series + filters: List[ast.Expr] = [] + + # Dates + if is_actors_query and actors_query_time_frame is not None: + to_start_of_time_frame = f"toStartOf{self.query_date_range.interval_name.capitalize()}" + filters.append( + ast.CompareOperation( + left=ast.Call(name=to_start_of_time_frame, args=[ast.Field(chain=["timestamp"])]), + op=ast.CompareOperationOp.Eq, + right=ast.Call(name="toDateTime", args=[ast.Constant(value=actors_query_time_frame)]), + ) + ) + elif not self._aggregation_operation.requires_query_orchestration(): + filters.extend( + [ + parse_expr( + "{timestamp_field} >= {date_from_with_adjusted_start_of_interval}", + placeholders={ + "timestamp_field": ast.Call( + name="toDateTime", args=[ast.Field(chain=[self.series.timestamp_field])] + ), + **self.query_date_range.to_placeholders(), + }, + ), + parse_expr( + "{timestamp_field} <= {date_to}", + placeholders={ + "timestamp_field": ast.Call( + name="toDateTime", args=[ast.Field(chain=[self.series.timestamp_field])] + ), + **self.query_date_range.to_placeholders(), + }, + ), + ] + ) + + # Properties + if self.query.properties is not None and self.query.properties != []: + filters.append(property_to_expr(self.query.properties, self.team)) + + # Series Filters + if series.properties is not None and series.properties != []: + filters.append(property_to_expr(series.properties, self.team)) + + # Breakdown + if not ignore_breakdowns and breakdown is not None: + if breakdown.enabled and not breakdown.is_histogram_breakdown: + breakdown_filter = breakdown.events_where_filter() + if breakdown_filter is not None: + filters.append(breakdown_filter) + + if len(filters) == 0: + return ast.Constant(value=True) + + return ast.And(exprs=filters) + + # TODO: remove this + def _sample_value(self) -> ast.RatioExpr: + if self.query.samplingFactor is None: + return ast.RatioExpr(left=ast.Constant(value=1)) + + return ast.RatioExpr(left=ast.Constant(value=self.query.samplingFactor)) + + def session_duration_math_property_wrapper(self, default_query: ast.SelectQuery) -> ast.SelectQuery: + query = cast( + ast.SelectQuery, + parse_select( + """ + SELECT {aggregation_operation} AS total + FROM {default_query} + """, + placeholders={ + "aggregation_operation": self._aggregation_operation.select_aggregation(), + "default_query": default_query, + }, + ), + ) + + query.group_by = [] + return query + + def _breakdown(self, is_actors_query: bool, breakdown_values_override: Optional[str | int] = None): + return Breakdown( + team=self.team, + query=self.query, + series=self.series, + query_date_range=self.query_date_range, + timings=self.timings, + modifiers=self.modifiers, + events_filter=self._events_filter( + breakdown=None, # Passing in None because we know we dont actually need it + ignore_breakdowns=True, + is_actors_query=is_actors_query, + breakdown_values_override=breakdown_values_override, + ), + breakdown_values_override=[breakdown_values_override] if breakdown_values_override is not None else None, + ) + + @cached_property + def _aggregation_operation(self) -> AggregationOperations: + return AggregationOperations( + self.team, self.series, self.query_date_range, self._trends_display.should_aggregate_values() + ) + + @cached_property + def _trends_display(self) -> TrendsDisplay: + display = ( + self.query.trendsFilter.display + if self.query.trendsFilter is not None and self.query.trendsFilter.display is not None + else None + ) + return TrendsDisplay(display) + + @cached_property + def _table_expr(self) -> ast.Field: + return ast.Field(chain=[self.series.table_name]) diff --git a/posthog/hogql_queries/insights/trends/test/test_trends.py b/posthog/hogql_queries/insights/trends/test/test_trends.py index de5b55986c7ff..fd19df6433fda 100644 --- a/posthog/hogql_queries/insights/trends/test/test_trends.py +++ b/posthog/hogql_queries/insights/trends/test/test_trends.py @@ -513,6 +513,7 @@ def test_trends_per_day(self): self.assertEqual(response[0]["labels"][5], "2-Jan-2020") self.assertEqual(response[0]["data"][5], 1.0) + @snapshot_clickhouse_queries def test_trends_data_warehouse(self): self._create_events() diff --git a/posthog/hogql_queries/insights/trends/query_builder.py b/posthog/hogql_queries/insights/trends/trends_query_builder.py similarity index 92% rename from posthog/hogql_queries/insights/trends/query_builder.py rename to posthog/hogql_queries/insights/trends/trends_query_builder.py index 9edb400e9d24a..ffbbff9406876 100644 --- a/posthog/hogql_queries/insights/trends/query_builder.py +++ b/posthog/hogql_queries/insights/trends/trends_query_builder.py @@ -13,14 +13,14 @@ from posthog.models.action.action import Action from posthog.models.filters.mixins.utils import cached_property from posthog.models.team.team import Team -from posthog.schema import ActionsNode, HogQLQueryModifiers, TrendsQuery, DataWarehouseNode, SeriesType +from posthog.schema import ActionsNode, EventsNode, HogQLQueryModifiers, TrendsQuery class TrendsQueryBuilder: query: TrendsQuery team: Team query_date_range: QueryDateRange - series: SeriesType + series: EventsNode | ActionsNode timings: HogQLTimings modifiers: HogQLQueryModifiers @@ -29,7 +29,7 @@ def __init__( trends_query: TrendsQuery, team: Team, query_date_range: QueryDateRange, - series: SeriesType, + series: EventsNode | ActionsNode, timings: HogQLTimings, modifiers: HogQLQueryModifiers, ): @@ -158,7 +158,12 @@ def _get_events_subquery( breakdown_values_override: Optional[str | int] = None, actors_query_time_frame: Optional[str | int] = None, ) -> ast.SelectQuery: - day_start = self._day_start_expr + day_start = ast.Alias( + alias="day_start", + expr=ast.Call( + name=f"toStartOf{self.query_date_range.interval_name.title()}", args=[ast.Field(chain=["timestamp"])] + ), + ) events_filter = self._events_filter( ignore_breakdowns=False, @@ -174,20 +179,18 @@ def _get_events_subquery( """ SELECT {aggregation_operation} AS total - FROM {table} AS e + FROM events AS e + SAMPLE {sample} WHERE {events_filter} """, placeholders={ "events_filter": events_filter, "aggregation_operation": self._aggregation_operation.select_aggregation(), - "table": self._table_expr, + "sample": self._sample_value(), }, ), ) - if not isinstance(self.series, DataWarehouseNode): - default_query.select_from.sample = self._sample_value() - default_query.group_by = [] if not self._trends_display.should_aggregate_values() and not is_actors_query: @@ -444,18 +447,12 @@ def _events_filter( filters.extend( [ parse_expr( - "{timestamp_field} >= {date_from_with_adjusted_start_of_interval}", - placeholders={ - "timestamp_field": self._timestamp_field, - **self.query_date_range.to_placeholders(), - }, + "timestamp >= {date_from_with_adjusted_start_of_interval}", + placeholders=self.query_date_range.to_placeholders(), ), parse_expr( - "{timestamp_field} <= {date_to}", - placeholders={ - "timestamp_field": self._timestamp_field, - **self.query_date_range.to_placeholders(), - }, + "timestamp <= {date_to}", + placeholders=self.query_date_range.to_placeholders(), ), ] ) @@ -562,29 +559,3 @@ def _trends_display(self) -> TrendsDisplay: else None ) return TrendsDisplay(display) - - @cached_property - def _day_start_expr(self) -> ast.Expr: - field = ast.Field(chain=["timestamp"]) - - if isinstance(self.series, DataWarehouseNode): - field = ast.Call(name="toDateTime", args=[ast.Field(chain=[self.series.timestamp_field])]) - - return ast.Alias( - alias="day_start", - expr=ast.Call(name=f"toStartOf{self.query_date_range.interval_name.title()}", args=[field]), - ) - - @cached_property - def _table_expr(self) -> ast.Field: - if isinstance(self.series, DataWarehouseNode): - return ast.Field(chain=[self.series.table_name]) - - return ast.Field(chain=["events"]) - - @cached_property - def _timestamp_field(self) -> ast.Field: - if isinstance(self.series, DataWarehouseNode): - return ast.Call(name="toDateTime", args=[ast.Field(chain=[self.series.timestamp_field])]) - - return ast.Field(chain=["timestamp"]) diff --git a/posthog/hogql_queries/insights/trends/trends_query_runner.py b/posthog/hogql_queries/insights/trends/trends_query_runner.py index b6c3ec905f8e1..015cb289af784 100644 --- a/posthog/hogql_queries/insights/trends/trends_query_runner.py +++ b/posthog/hogql_queries/insights/trends/trends_query_runner.py @@ -28,7 +28,8 @@ BREAKDOWN_OTHER_STRING_LABEL, ) from posthog.hogql_queries.insights.trends.display import TrendsDisplay -from posthog.hogql_queries.insights.trends.query_builder import TrendsQueryBuilder +from posthog.hogql_queries.insights.trends.trends_query_builder import TrendsQueryBuilder +from posthog.hogql_queries.insights.trends.data_warehouse_trends_query_builder import DataWarehouseTrendsQueryBuilder from posthog.hogql_queries.insights.trends.series_with_extras import SeriesWithExtras from posthog.hogql_queries.query_runner import QueryRunner from posthog.hogql_queries.utils.formula_ast import FormulaAST @@ -111,14 +112,25 @@ def to_query(self) -> List[ast.SelectQuery | ast.SelectUnionQuery]: else: query_date_range = self.query_previous_date_range - query_builder = TrendsQueryBuilder( - trends_query=series.overriden_query or self.query, - team=self.team, - query_date_range=query_date_range, - series=series.series, - timings=self.timings, - modifiers=self.modifiers, - ) + if isinstance(series.series, DataWarehouseNode): + query_builder = DataWarehouseTrendsQueryBuilder( + trends_query=series.overriden_query or self.query, + team=self.team, + query_date_range=query_date_range, + series=series.series, + timings=self.timings, + modifiers=self.modifiers, + ) + else: + query_builder = TrendsQueryBuilder( + trends_query=series.overriden_query or self.query, + team=self.team, + query_date_range=query_date_range, + series=series.series, + timings=self.timings, + modifiers=self.modifiers, + ) + queries.append(query_builder.build_query()) return queries