diff --git a/posthog/hogql_queries/insights/trends/display.py b/posthog/hogql_queries/insights/trends/display.py new file mode 100644 index 0000000000000..db0fa29e0045e --- /dev/null +++ b/posthog/hogql_queries/insights/trends/display.py @@ -0,0 +1,68 @@ +from posthog.hogql import ast +from posthog.schema import ChartDisplayType + + +class TrendsDisplay: + display_type: ChartDisplayType + + def __init__(self, display_type: ChartDisplayType) -> None: + self.display_type = display_type + + def should_aggregate_values(self) -> bool: + return ( + self.display_type == ChartDisplayType.BoldNumber + or self.display_type == ChartDisplayType.ActionsPie + or self.display_type == ChartDisplayType.ActionsBarValue + or self.display_type == ChartDisplayType.WorldMap + ) + + def wrap_inner_query(self, inner_query: ast.SelectQuery, breakdown_enabled: bool) -> ast.SelectQuery: + if self.display_type == ChartDisplayType.ActionsLineGraphCumulative: + return self._get_cumulative_query(inner_query, breakdown_enabled) + + return inner_query + + def should_wrap_inner_query(self) -> bool: + return self.display_type == ChartDisplayType.ActionsLineGraphCumulative + + def modify_outer_query(self, outer_query: ast.SelectQuery, inner_query: ast.SelectQuery) -> ast.SelectQuery: + if ( + self.display_type == ChartDisplayType.BoldNumber + or self.display_type == ChartDisplayType.ActionsPie + or self.display_type == ChartDisplayType.WorldMap + ): + return ast.SelectQuery( + select=[ + ast.Alias( + alias="total", + expr=ast.Call(name="sum", args=[ast.Field(chain=["count"])]), + ) + ], + select_from=ast.JoinExpr(table=inner_query), + ) + + return outer_query + + def _get_cumulative_query(self, inner_query: ast.SelectQuery, breakdown_enabled: bool) -> ast.SelectQuery: + if breakdown_enabled: + window_expr = ast.WindowExpr( + order_by=[ast.OrderExpr(expr=ast.Field(chain=["day_start"]), order="ASC")], + partition_by=[ast.Field(chain=["breakdown_value"])], + ) + else: + window_expr = ast.WindowExpr(order_by=[ast.OrderExpr(expr=ast.Field(chain=["day_start"]), order="ASC")]) + + return ast.SelectQuery( + select=[ + ast.Field(chain=["day_start"]), + ast.Alias( + alias="count", + expr=ast.WindowFunction( + name="sum", + args=[ast.Field(chain=["count"])], + over_expr=window_expr, + ), + ), + ], + select_from=ast.JoinExpr(table=inner_query), + ) diff --git a/posthog/hogql_queries/insights/trends/query_builder.py b/posthog/hogql_queries/insights/trends/query_builder.py index 859f6a2f5e691..ddf873f10a0da 100644 --- a/posthog/hogql_queries/insights/trends/query_builder.py +++ b/posthog/hogql_queries/insights/trends/query_builder.py @@ -7,11 +7,12 @@ AggregationOperations, ) from posthog.hogql_queries.insights.trends.breakdown import Breakdown +from posthog.hogql_queries.insights.trends.display import TrendsDisplay from posthog.hogql_queries.insights.trends.utils import series_event_name from posthog.hogql_queries.utils.query_date_range import QueryDateRange from posthog.models.filters.mixins.utils import cached_property from posthog.models.team.team import Team -from posthog.schema import ActionsNode, EventsNode, TrendsQuery +from posthog.schema import ActionsNode, ChartDisplayType, EventsNode, TrendsQuery class TrendsQueryBuilder: @@ -196,6 +197,8 @@ def _outer_select_query(self, inner_query: ast.SelectQuery) -> ast.SelectQuery: ), ) + query = self._trends_display.modify_outer_query(outer_query=query, inner_query=inner_query) + if self._breakdown.enabled: query.select.append(ast.Field(chain=["breakdown_value"])) query.group_by = [ast.Field(chain=["breakdown_value"])] @@ -224,6 +227,11 @@ def _inner_select_query(self, inner_query: ast.SelectUnionQuery) -> ast.SelectQu query.group_by.append(ast.Field(chain=["breakdown_value"])) query.order_by.append(ast.OrderExpr(expr=ast.Field(chain=["breakdown_value"]), order="ASC")) + if self._trends_display.should_wrap_inner_query(): + query = self._trends_display.wrap_inner_query(query, self._breakdown.enabled) + if self._breakdown.enabled: + query.select.append(ast.Field(chain=["breakdown_value"])) + return query def _events_filter(self) -> ast.Expr: @@ -298,5 +306,14 @@ def _breakdown(self): ) @cached_property - def _aggregation_operation(self): + def _aggregation_operation(self) -> AggregationOperations: return AggregationOperations(self.series, self.query_date_range) + + @cached_property + def _trends_display(self) -> TrendsDisplay: + if self.query.trendsFilter is None or self.query.trendsFilter.display is None: + display = ChartDisplayType.ActionsLineGraph + else: + display = self.query.trendsFilter.display + + return TrendsDisplay(display) diff --git a/posthog/hogql_queries/insights/trends/series_with_extras.py b/posthog/hogql_queries/insights/trends/series_with_extras.py index df8ff57fb0e7d..fb63a205f33d0 100644 --- a/posthog/hogql_queries/insights/trends/series_with_extras.py +++ b/posthog/hogql_queries/insights/trends/series_with_extras.py @@ -6,13 +6,16 @@ class SeriesWithExtras: series: EventsNode | ActionsNode is_previous_period_series: Optional[bool] overriden_query: Optional[TrendsQuery] + aggregate_values: Optional[bool] def __init__( self, series: EventsNode | ActionsNode, is_previous_period_series: Optional[bool], overriden_query: Optional[TrendsQuery], + aggregate_values: Optional[bool], ): self.series = series self.is_previous_period_series = is_previous_period_series self.overriden_query = overriden_query + self.aggregate_values = aggregate_values diff --git a/posthog/hogql_queries/insights/trends/test/test_query_builder.py b/posthog/hogql_queries/insights/trends/test/test_query_builder.py new file mode 100644 index 0000000000000..ce65a1605123c --- /dev/null +++ b/posthog/hogql_queries/insights/trends/test/test_query_builder.py @@ -0,0 +1,122 @@ +from datetime import datetime +from freezegun import freeze_time + +from posthog.hogql.query import execute_hogql_query +from posthog.hogql.timings import HogQLTimings +from posthog.hogql_queries.insights.trends.query_builder import TrendsQueryBuilder +from posthog.hogql_queries.utils.query_date_range import QueryDateRange +from posthog.schema import ( + BaseMathType, + BreakdownFilter, + BreakdownType, + ChartDisplayType, + DateRange, + EventsNode, + HogQLQueryResponse, + TrendsFilter, + TrendsQuery, +) +from posthog.test.base import BaseTest, _create_event, _create_person + + +class TestQueryBuilder(BaseTest): + def setUp(self): + super().setUp() + + with freeze_time("2023-02-01"): + _create_person( + distinct_ids=["some_id"], + team_id=self.team.pk, + properties={"$some_prop": "something", "$another_prop": "something"}, + ) + _create_event( + event="$pageview", + team=self.team, + distinct_id="some_id", + properties={"$geoip_country_code": "AU"}, + ) + + def get_response(self, trends_query: TrendsQuery) -> HogQLQueryResponse: + query_date_range = QueryDateRange( + date_range=trends_query.dateRange, + team=self.team, + interval=trends_query.interval, + now=datetime.now(), + ) + + timings = HogQLTimings() + + query_builder = TrendsQueryBuilder( + trends_query=trends_query, + team=self.team, + query_date_range=query_date_range, + series=trends_query.series[0], + timings=timings, + ) + + query = query_builder.build_query() + + return execute_hogql_query( + query_type="TrendsQuery", + query=query, + team=self.team, + timings=timings, + ) + + def test_column_names(self): + trends_query = TrendsQuery( + kind="TrendsQuery", + dateRange=DateRange(date_from="2023-01-01"), + series=[EventsNode(event="$pageview", math=BaseMathType.total)], + ) + + response = self.get_response(trends_query) + + assert response.columns is not None + assert set(response.columns).issubset({"date", "total", "breakdown_value"}) + + def assert_column_names_with_display_type(self, display_type: ChartDisplayType): + trends_query = TrendsQuery( + kind="TrendsQuery", + dateRange=DateRange(date_from="2023-01-01"), + series=[EventsNode(event="$pageview")], + trendsFilter=TrendsFilter(display=display_type), + ) + + response = self.get_response(trends_query) + + assert response.columns is not None + assert set(response.columns).issubset({"date", "total", "breakdown_value"}) + + def assert_column_names_with_display_type_and_breakdowns(self, display_type: ChartDisplayType): + trends_query = TrendsQuery( + kind="TrendsQuery", + dateRange=DateRange(date_from="2023-01-01"), + series=[EventsNode(event="$pageview")], + trendsFilter=TrendsFilter(display=display_type), + breakdown=BreakdownFilter(breakdown="$geoip_country_code", breakdown_type=BreakdownType.event), + ) + + response = self.get_response(trends_query) + + assert response.columns is not None + assert set(response.columns).issubset({"date", "total", "breakdown_value"}) + + def test_column_names_with_display_type(self): + self.assert_column_names_with_display_type(ChartDisplayType.ActionsAreaGraph) + self.assert_column_names_with_display_type(ChartDisplayType.ActionsBar) + self.assert_column_names_with_display_type(ChartDisplayType.ActionsBarValue) + self.assert_column_names_with_display_type(ChartDisplayType.ActionsLineGraph) + self.assert_column_names_with_display_type(ChartDisplayType.ActionsPie) + self.assert_column_names_with_display_type(ChartDisplayType.BoldNumber) + self.assert_column_names_with_display_type(ChartDisplayType.WorldMap) + self.assert_column_names_with_display_type(ChartDisplayType.ActionsLineGraphCumulative) + + def test_column_names_with_display_type_and_breakdowns(self): + self.assert_column_names_with_display_type_and_breakdowns(ChartDisplayType.ActionsAreaGraph) + self.assert_column_names_with_display_type_and_breakdowns(ChartDisplayType.ActionsBar) + self.assert_column_names_with_display_type_and_breakdowns(ChartDisplayType.ActionsBarValue) + self.assert_column_names_with_display_type_and_breakdowns(ChartDisplayType.ActionsLineGraph) + self.assert_column_names_with_display_type_and_breakdowns(ChartDisplayType.ActionsPie) + self.assert_column_names_with_display_type_and_breakdowns(ChartDisplayType.WorldMap) + self.assert_column_names_with_display_type_and_breakdowns(ChartDisplayType.ActionsLineGraphCumulative) diff --git a/posthog/hogql_queries/insights/trends/trends_query_runner.py b/posthog/hogql_queries/insights/trends/trends_query_runner.py index 7518711e73eb8..a8d25dd1b5081 100644 --- a/posthog/hogql_queries/insights/trends/trends_query_runner.py +++ b/posthog/hogql_queries/insights/trends/trends_query_runner.py @@ -15,6 +15,7 @@ from posthog.hogql import ast from posthog.hogql.query import execute_hogql_query from posthog.hogql.timings import HogQLTimings +from posthog.hogql_queries.insights.trends.display import TrendsDisplay from posthog.hogql_queries.insights.trends.query_builder import TrendsQueryBuilder from posthog.hogql_queries.insights.trends.series_with_extras import SeriesWithExtras from posthog.hogql_queries.query_runner import QueryRunner @@ -29,6 +30,7 @@ from posthog.models.property_definition import PropertyDefinition from posthog.schema import ( ActionsNode, + ChartDisplayType, EventsNode, HogQLQueryResponse, TrendsQuery, @@ -146,38 +148,73 @@ def build_series_response(self, response: HogQLQueryResponse, series: SeriesWith if response.results is None: return [] + def get_value(name: str, val: Any): + if name not in ["date", "total", "breakdown_value"]: + raise Exception("Column not found in hogql results") + if response.columns is None: + raise Exception("No columns returned from hogql results") + + index = response.columns.index(name) + return val[index] + res = [] for val in response.results: - series_object = { - "data": val[1], - "labels": [ - item.strftime( - "%-d-%b-%Y{}".format(" %H:%M" if self.query_date_range.interval_name == "hour" else "") - ) - for item in val[0] - ], - "days": [ - item.strftime( - "%Y-%m-%d{}".format(" %H:%M:%S" if self.query_date_range.interval_name == "hour" else "") - ) - for item in val[0] - ], - "count": float(sum(val[1])), - "label": "All events" if self.series_event(series.series) is None else self.series_event(series.series), - "filter": self._query_to_filter(), - "action": { # TODO: Populate missing props in `action` - "id": self.series_event(series.series), - "type": "events", - "order": 0, - "name": self.series_event(series.series) or "All events", - "custom_name": None, - "math": series.series.math, - "math_property": None, - "math_hogql": None, - "math_group_type_index": None, - "properties": {}, - }, - } + if series.aggregate_values: + series_object = { + "data": [], + "days": [], + "count": 0, + "aggregated_value": get_value("total", val), + "label": "All events" + if self.series_event(series.series) is None + else self.series_event(series.series), + "filter": self._query_to_filter(), + "action": { # TODO: Populate missing props in `action` + "id": self.series_event(series.series), + "type": "events", + "order": 0, + "name": self.series_event(series.series) or "All events", + "custom_name": None, + "math": series.series.math, + "math_property": None, + "math_hogql": None, + "math_group_type_index": None, + "properties": {}, + }, + } + else: + series_object = { + "data": get_value("total", val), + "labels": [ + item.strftime( + "%-d-%b-%Y{}".format(" %H:%M" if self.query_date_range.interval_name == "hour" else "") + ) + for item in val[0] + ], + "days": [ + item.strftime( + "%Y-%m-%d{}".format(" %H:%M:%S" if self.query_date_range.interval_name == "hour" else "") + ) + for item in val[0] + ], + "count": float(sum(get_value("total", val))), + "label": "All events" + if self.series_event(series.series) is None + else self.series_event(series.series), + "filter": self._query_to_filter(), + "action": { # TODO: Populate missing props in `action` + "id": self.series_event(series.series), + "type": "events", + "order": 0, + "name": self.series_event(series.series) or "All events", + "custom_name": None, + "math": series.series.math, + "math_property": None, + "math_hogql": None, + "math_group_type_index": None, + "properties": {}, + }, + } # Modifications for when comparing to previous period if self.query.trendsFilter is not None and self.query.trendsFilter.compare: @@ -196,18 +233,18 @@ def build_series_response(self, response: HogQLQueryResponse, series: SeriesWith # Modifications for when breakdowns are active if self.query.breakdown is not None and self.query.breakdown.breakdown is not None: if self._is_breakdown_field_boolean(): - remapped_label = self._convert_boolean(val[2]) + remapped_label = self._convert_boolean(get_value("breakdown_value", val)) series_object["label"] = "{} - {}".format(series_object["label"], remapped_label) series_object["breakdown_value"] = remapped_label elif self.query.breakdown.breakdown_type == "cohort": - cohort_id = val[2] + cohort_id = get_value("breakdown_value", val) cohort_name = Cohort.objects.get(pk=cohort_id).name series_object["label"] = "{} - {}".format(series_object["label"], cohort_name) - series_object["breakdown_value"] = val[2] + series_object["breakdown_value"] = get_value("breakdown_value", val) else: - series_object["label"] = "{} - {}".format(series_object["label"], val[2]) - series_object["breakdown_value"] = val[2] + series_object["label"] = "{} - {}".format(series_object["label"], get_value("breakdown_value", val)) + series_object["breakdown_value"] = get_value("breakdown_value", val) res.append(series_object) return res @@ -236,7 +273,15 @@ def series_event(self, series: EventsNode | ActionsNode) -> str | None: return None def setup_series(self) -> List[SeriesWithExtras]: - series_with_extras = [SeriesWithExtras(series, None, None) for series in self.query.series] + series_with_extras = [ + SeriesWithExtras( + series, + None, + None, + self._trends_display.should_aggregate_values(), + ) + for series in self.query.series + ] if self.query.breakdown is not None and self.query.breakdown.breakdown_type == "cohort": updated_series = [] @@ -250,6 +295,7 @@ def setup_series(self) -> List[SeriesWithExtras]: series=series.series, is_previous_period_series=series.is_previous_period_series, overriden_query=copied_query, + aggregate_values=self._trends_display.should_aggregate_values(), ) ) series_with_extras = updated_series @@ -262,6 +308,7 @@ def setup_series(self) -> List[SeriesWithExtras]: series=series.series, is_previous_period_series=False, overriden_query=series.overriden_query, + aggregate_values=self._trends_display.should_aggregate_values(), ) ) updated_series.append( @@ -269,6 +316,7 @@ def setup_series(self) -> List[SeriesWithExtras]: series=series.series, is_previous_period_series=True, overriden_query=series.overriden_query, + aggregate_values=self._trends_display.should_aggregate_values(), ) ) series_with_extras = updated_series @@ -362,3 +410,12 @@ def _query_to_filter(self) -> Dict[str, any]: filter_dict.update(**self.query.breakdown.__dict__) return {k: v for k, v in filter_dict.items() if v is not None} + + @cached_property + def _trends_display(self) -> TrendsDisplay: + if self.query.trendsFilter is None or self.query.trendsFilter.display is None: + display = ChartDisplayType.ActionsLineGraph + else: + display = self.query.trendsFilter.display + + return TrendsDisplay(display)