From acf30da01900ff3b6103eeab7a0a9ba1d2439f30 Mon Sep 17 00:00:00 2001 From: Tom Owers Date: Tue, 17 Oct 2023 16:42:29 +0100 Subject: [PATCH] feat: added trends breakdown by cohort (#18028) * Fixed missing param when building Breakdown * Added support for breakdowns with hogql expressions * Added trends breakdown by cohort * Remove None valued keys from the filter obj --- .../insights/trends/breakdown.py | 17 ++++ .../insights/trends/breakdown_values.py | 9 ++- .../insights/trends/series_with_extras.py | 11 ++- .../insights/trends/trends_query_runner.py | 78 +++++++++++++++++-- 4 files changed, 102 insertions(+), 13 deletions(-) diff --git a/posthog/hogql_queries/insights/trends/breakdown.py b/posthog/hogql_queries/insights/trends/breakdown.py index b1f48fa51f603..f58f27ff2e6b6 100644 --- a/posthog/hogql_queries/insights/trends/breakdown.py +++ b/posthog/hogql_queries/insights/trends/breakdown.py @@ -47,6 +47,16 @@ def placeholders(self) -> Dict[str, ast.Expr]: def column_expr(self) -> ast.Expr: if self.is_histogram_breakdown: return ast.Alias(alias="breakdown_value", expr=self._get_breakdown_histogram_multi_if()) + elif self.query.breakdown.breakdown_type == "hogql": + return ast.Alias( + alias="breakdown_value", + expr=parse_expr(self.query.breakdown.breakdown), + ) + elif self.query.breakdown.breakdown_type == "cohort": + return ast.Alias( + alias="breakdown_value", + expr=ast.Constant(value=int(self.query.breakdown.breakdown)), + ) if self.query.breakdown.breakdown_type == "hogql": return ast.Alias( @@ -60,6 +70,13 @@ def column_expr(self) -> ast.Expr: ) def events_where_filter(self) -> ast.Expr: + if self.query.breakdown.breakdown_type == "cohort": + return ast.CompareOperation( + left=ast.Field(chain=["person_id"]), + op=ast.CompareOperationOp.InCohort, + right=ast.Constant(value=int(self.query.breakdown.breakdown)), + ) + if self.query.breakdown.breakdown_type == "hogql": left = parse_expr(self.query.breakdown.breakdown) else: diff --git a/posthog/hogql_queries/insights/trends/breakdown_values.py b/posthog/hogql_queries/insights/trends/breakdown_values.py index 1bcb37ea4e46a..118bd0f1dc56a 100644 --- a/posthog/hogql_queries/insights/trends/breakdown_values.py +++ b/posthog/hogql_queries/insights/trends/breakdown_values.py @@ -1,4 +1,4 @@ -from typing import List, Optional +from typing import List, Optional, Union from posthog.hogql import ast from posthog.hogql.parser import parse_expr, parse_select from posthog.hogql.query import execute_hogql_query @@ -10,7 +10,7 @@ class BreakdownValues: team: Team event_name: str - breakdown_field: str + breakdown_field: Union[str, float] breakdown_type: str query_date_range: QueryDateRange histogram_bin_count: Optional[int] @@ -20,7 +20,7 @@ def __init__( self, team: Team, event_name: str, - breakdown_field: str, + breakdown_field: Union[str, float], query_date_range: QueryDateRange, breakdown_type: str, histogram_bin_count: Optional[float] = None, @@ -35,6 +35,9 @@ def __init__( self.group_type_index = int(group_type_index) if group_type_index is not None else None def get_breakdown_values(self) -> List[str]: + if self.breakdown_type == "cohort": + return [int(self.breakdown_field)] + if self.breakdown_type == "hogql": select_field = ast.Alias( alias="value", diff --git a/posthog/hogql_queries/insights/trends/series_with_extras.py b/posthog/hogql_queries/insights/trends/series_with_extras.py index e95035fa907f0..df8ff57fb0e7d 100644 --- a/posthog/hogql_queries/insights/trends/series_with_extras.py +++ b/posthog/hogql_queries/insights/trends/series_with_extras.py @@ -1,11 +1,18 @@ from typing import Optional -from posthog.schema import ActionsNode, EventsNode +from posthog.schema import ActionsNode, EventsNode, TrendsQuery class SeriesWithExtras: series: EventsNode | ActionsNode is_previous_period_series: Optional[bool] + overriden_query: Optional[TrendsQuery] - def __init__(self, series: EventsNode | ActionsNode, is_previous_period_series: Optional[bool]): + def __init__( + self, + series: EventsNode | ActionsNode, + is_previous_period_series: Optional[bool], + overriden_query: Optional[TrendsQuery], + ): self.series = series self.is_previous_period_series = is_previous_period_series + self.overriden_query = overriden_query diff --git a/posthog/hogql_queries/insights/trends/trends_query_runner.py b/posthog/hogql_queries/insights/trends/trends_query_runner.py index 6e6dc8a1c82ba..675f320b2f7ac 100644 --- a/posthog/hogql_queries/insights/trends/trends_query_runner.py +++ b/posthog/hogql_queries/insights/trends/trends_query_runner.py @@ -1,3 +1,4 @@ +from copy import deepcopy from datetime import timedelta from itertools import groupby from math import ceil @@ -18,6 +19,7 @@ from posthog.hogql_queries.utils.query_date_range import QueryDateRange from posthog.hogql_queries.utils.query_previous_period_date_range import QueryPreviousPeriodDateRange from posthog.models import Team +from posthog.models.cohort.cohort import Cohort from posthog.models.filters.mixins.utils import cached_property from posthog.models.property_definition import PropertyDefinition from posthog.schema import ActionsNode, EventsNode, HogQLQueryResponse, TrendsQuery, TrendsQueryResponse @@ -63,7 +65,13 @@ def to_query(self) -> List[ast.SelectQuery]: else: query_date_range = self.query_previous_date_range - query_builder = TrendsQueryBuilder(self.query, self.team, query_date_range, series.series, self.timings) + query_builder = TrendsQueryBuilder( + trends_query=series.overriden_query or self.query, + team=self.team, + query_date_range=query_date_range, + series=series.series, + timings=self.timings, + ) queries.append(query_builder.build_query()) return queries @@ -105,6 +113,7 @@ def build_series_response(self, response: HogQLQueryResponse, series: SeriesWith "days": [item.strftime("%Y-%m-%d") for item in val[0]], # TODO: Add back in hour formatting "count": float(sum(val[1])), "label": "All events" if self.series_event(series.series) is None else self.series_event(series.series), + "filter": self._query_to_filter(), "action": { # TODO: Populate missing props in `action` "id": self.series_event(series.series), "type": "events", @@ -136,6 +145,12 @@ def build_series_response(self, response: HogQLQueryResponse, series: SeriesWith remapped_label = self._convert_boolean(val[2]) series_object["label"] = "{} - {}".format(series_object["label"], remapped_label) series_object["breakdown_value"] = remapped_label + elif self.query.breakdown.breakdown_type == "cohort": + cohort_id = val[2] + cohort_name = Cohort.objects.get(pk=cohort_id).name + + series_object["label"] = "{} - {}".format(series_object["label"], cohort_name) + series_object["breakdown_value"] = val[2] else: series_object["label"] = "{} - {}".format(series_object["label"], val[2]) series_object["breakdown_value"] = val[2] @@ -161,14 +176,40 @@ def series_event(self, series: EventsNode | ActionsNode) -> str | None: return None def setup_series(self) -> List[SeriesWithExtras]: - if self.query.trendsFilter is not None and self.query.trendsFilter.compare: + series_with_extras = [SeriesWithExtras(series, None, None) for series in self.query.series] + + if self.query.breakdown is not None and self.query.breakdown.breakdown_type == "cohort": updated_series = [] - for series in self.query.series: - updated_series.append(SeriesWithExtras(series, is_previous_period_series=False)) - updated_series.append(SeriesWithExtras(series, is_previous_period_series=True)) - return updated_series + for cohort_id in self.query.breakdown.breakdown: + for series in series_with_extras: + copied_query = deepcopy(self.query) + copied_query.breakdown.breakdown = cohort_id + + updated_series.append( + SeriesWithExtras( + series=series.series, + is_previous_period_series=series.is_previous_period_series, + overriden_query=copied_query, + ) + ) + series_with_extras = updated_series - return [SeriesWithExtras(series, is_previous_period_series=False) for series in self.query.series] + if self.query.trendsFilter is not None and self.query.trendsFilter.compare: + updated_series = [] + for series in series_with_extras: + updated_series.append( + SeriesWithExtras( + series=series.series, is_previous_period_series=False, overriden_query=series.overriden_query + ) + ) + updated_series.append( + SeriesWithExtras( + series=series.series, is_previous_period_series=True, overriden_query=series.overriden_query + ) + ) + series_with_extras = updated_series + + return series_with_extras def apply_formula(self, formula: str, results: List[Dict[str, Any]]) -> List[Dict[str, Any]]: if self.query.trendsFilter is not None and self.query.trendsFilter.compare: @@ -199,7 +240,7 @@ def apply_formula(self, formula: str, results: List[Dict[str, Any]]) -> List[Dic return [new_result] def _is_breakdown_field_boolean(self): - if self.query.breakdown.breakdown_type == "hogql": + if self.query.breakdown.breakdown_type == "hogql" or self.query.breakdown.breakdown_type == "cohort": return False if self.query.breakdown.breakdown_type == "person": @@ -225,3 +266,24 @@ def _event_property(self, field: str, field_type: PropertyDefinition.Type, group type=field_type, group_type_index=group_type_index if field_type == PropertyDefinition.Type.GROUP else None, ).property_type + + def _query_to_filter(self) -> Dict[str, any]: + filter_dict = { + "insight": "TRENDS", + "properties": self.query.properties, + "filter_test_accounts": self.query.filterTestAccounts, + "date_to": self.query_date_range.date_to(), + "date_from": self.query_date_range.date_from(), + "entity_type": "events", + "sampling_factor": self.query.samplingFactor, + "aggregation_group_type_index": self.query.aggregation_group_type_index, + "interval": self.query.interval, + } + + if self.query.breakdown is not None: + filter_dict.update(self.query.breakdown.__dict__) + + if self.query.trendsFilter is not None: + filter_dict.update(self.query.trendsFilter.__dict__) + + return {k: v for k, v in filter_dict.items() if v is not None}