Skip to content

Commit

Permalink
feat: added trends breakdown by cohort (#18028)
Browse files Browse the repository at this point in the history
* Fixed missing param when building Breakdown

* Added support for breakdowns with hogql expressions

* Added trends breakdown by cohort

* Remove None valued keys from the filter obj
  • Loading branch information
Gilbert09 authored Oct 17, 2023
1 parent 3e7c837 commit acf30da
Show file tree
Hide file tree
Showing 4 changed files with 102 additions and 13 deletions.
17 changes: 17 additions & 0 deletions posthog/hogql_queries/insights/trends/breakdown.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,16 @@ def placeholders(self) -> Dict[str, ast.Expr]:
def column_expr(self) -> ast.Expr:
if self.is_histogram_breakdown:
return ast.Alias(alias="breakdown_value", expr=self._get_breakdown_histogram_multi_if())
elif self.query.breakdown.breakdown_type == "hogql":
return ast.Alias(
alias="breakdown_value",
expr=parse_expr(self.query.breakdown.breakdown),
)
elif self.query.breakdown.breakdown_type == "cohort":
return ast.Alias(
alias="breakdown_value",
expr=ast.Constant(value=int(self.query.breakdown.breakdown)),
)

if self.query.breakdown.breakdown_type == "hogql":
return ast.Alias(
Expand All @@ -60,6 +70,13 @@ def column_expr(self) -> ast.Expr:
)

def events_where_filter(self) -> ast.Expr:
if self.query.breakdown.breakdown_type == "cohort":
return ast.CompareOperation(
left=ast.Field(chain=["person_id"]),
op=ast.CompareOperationOp.InCohort,
right=ast.Constant(value=int(self.query.breakdown.breakdown)),
)

if self.query.breakdown.breakdown_type == "hogql":
left = parse_expr(self.query.breakdown.breakdown)
else:
Expand Down
9 changes: 6 additions & 3 deletions posthog/hogql_queries/insights/trends/breakdown_values.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Optional
from typing import List, Optional, Union
from posthog.hogql import ast
from posthog.hogql.parser import parse_expr, parse_select
from posthog.hogql.query import execute_hogql_query
Expand All @@ -10,7 +10,7 @@
class BreakdownValues:
team: Team
event_name: str
breakdown_field: str
breakdown_field: Union[str, float]
breakdown_type: str
query_date_range: QueryDateRange
histogram_bin_count: Optional[int]
Expand All @@ -20,7 +20,7 @@ def __init__(
self,
team: Team,
event_name: str,
breakdown_field: str,
breakdown_field: Union[str, float],
query_date_range: QueryDateRange,
breakdown_type: str,
histogram_bin_count: Optional[float] = None,
Expand All @@ -35,6 +35,9 @@ def __init__(
self.group_type_index = int(group_type_index) if group_type_index is not None else None

def get_breakdown_values(self) -> List[str]:
if self.breakdown_type == "cohort":
return [int(self.breakdown_field)]

if self.breakdown_type == "hogql":
select_field = ast.Alias(
alias="value",
Expand Down
11 changes: 9 additions & 2 deletions posthog/hogql_queries/insights/trends/series_with_extras.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,18 @@
from typing import Optional
from posthog.schema import ActionsNode, EventsNode
from posthog.schema import ActionsNode, EventsNode, TrendsQuery


class SeriesWithExtras:
series: EventsNode | ActionsNode
is_previous_period_series: Optional[bool]
overriden_query: Optional[TrendsQuery]

def __init__(self, series: EventsNode | ActionsNode, is_previous_period_series: Optional[bool]):
def __init__(
self,
series: EventsNode | ActionsNode,
is_previous_period_series: Optional[bool],
overriden_query: Optional[TrendsQuery],
):
self.series = series
self.is_previous_period_series = is_previous_period_series
self.overriden_query = overriden_query
78 changes: 70 additions & 8 deletions posthog/hogql_queries/insights/trends/trends_query_runner.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from copy import deepcopy
from datetime import timedelta
from itertools import groupby
from math import ceil
Expand All @@ -18,6 +19,7 @@
from posthog.hogql_queries.utils.query_date_range import QueryDateRange
from posthog.hogql_queries.utils.query_previous_period_date_range import QueryPreviousPeriodDateRange
from posthog.models import Team
from posthog.models.cohort.cohort import Cohort
from posthog.models.filters.mixins.utils import cached_property
from posthog.models.property_definition import PropertyDefinition
from posthog.schema import ActionsNode, EventsNode, HogQLQueryResponse, TrendsQuery, TrendsQueryResponse
Expand Down Expand Up @@ -63,7 +65,13 @@ def to_query(self) -> List[ast.SelectQuery]:
else:
query_date_range = self.query_previous_date_range

query_builder = TrendsQueryBuilder(self.query, self.team, query_date_range, series.series, self.timings)
query_builder = TrendsQueryBuilder(
trends_query=series.overriden_query or self.query,
team=self.team,
query_date_range=query_date_range,
series=series.series,
timings=self.timings,
)
queries.append(query_builder.build_query())

return queries
Expand Down Expand Up @@ -105,6 +113,7 @@ def build_series_response(self, response: HogQLQueryResponse, series: SeriesWith
"days": [item.strftime("%Y-%m-%d") for item in val[0]], # TODO: Add back in hour formatting
"count": float(sum(val[1])),
"label": "All events" if self.series_event(series.series) is None else self.series_event(series.series),
"filter": self._query_to_filter(),
"action": { # TODO: Populate missing props in `action`
"id": self.series_event(series.series),
"type": "events",
Expand Down Expand Up @@ -136,6 +145,12 @@ def build_series_response(self, response: HogQLQueryResponse, series: SeriesWith
remapped_label = self._convert_boolean(val[2])
series_object["label"] = "{} - {}".format(series_object["label"], remapped_label)
series_object["breakdown_value"] = remapped_label
elif self.query.breakdown.breakdown_type == "cohort":
cohort_id = val[2]
cohort_name = Cohort.objects.get(pk=cohort_id).name

series_object["label"] = "{} - {}".format(series_object["label"], cohort_name)
series_object["breakdown_value"] = val[2]
else:
series_object["label"] = "{} - {}".format(series_object["label"], val[2])
series_object["breakdown_value"] = val[2]
Expand All @@ -161,14 +176,40 @@ def series_event(self, series: EventsNode | ActionsNode) -> str | None:
return None

def setup_series(self) -> List[SeriesWithExtras]:
if self.query.trendsFilter is not None and self.query.trendsFilter.compare:
series_with_extras = [SeriesWithExtras(series, None, None) for series in self.query.series]

if self.query.breakdown is not None and self.query.breakdown.breakdown_type == "cohort":
updated_series = []
for series in self.query.series:
updated_series.append(SeriesWithExtras(series, is_previous_period_series=False))
updated_series.append(SeriesWithExtras(series, is_previous_period_series=True))
return updated_series
for cohort_id in self.query.breakdown.breakdown:
for series in series_with_extras:
copied_query = deepcopy(self.query)
copied_query.breakdown.breakdown = cohort_id

updated_series.append(
SeriesWithExtras(
series=series.series,
is_previous_period_series=series.is_previous_period_series,
overriden_query=copied_query,
)
)
series_with_extras = updated_series

return [SeriesWithExtras(series, is_previous_period_series=False) for series in self.query.series]
if self.query.trendsFilter is not None and self.query.trendsFilter.compare:
updated_series = []
for series in series_with_extras:
updated_series.append(
SeriesWithExtras(
series=series.series, is_previous_period_series=False, overriden_query=series.overriden_query
)
)
updated_series.append(
SeriesWithExtras(
series=series.series, is_previous_period_series=True, overriden_query=series.overriden_query
)
)
series_with_extras = updated_series

return series_with_extras

def apply_formula(self, formula: str, results: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
if self.query.trendsFilter is not None and self.query.trendsFilter.compare:
Expand Down Expand Up @@ -199,7 +240,7 @@ def apply_formula(self, formula: str, results: List[Dict[str, Any]]) -> List[Dic
return [new_result]

def _is_breakdown_field_boolean(self):
if self.query.breakdown.breakdown_type == "hogql":
if self.query.breakdown.breakdown_type == "hogql" or self.query.breakdown.breakdown_type == "cohort":
return False

if self.query.breakdown.breakdown_type == "person":
Expand All @@ -225,3 +266,24 @@ def _event_property(self, field: str, field_type: PropertyDefinition.Type, group
type=field_type,
group_type_index=group_type_index if field_type == PropertyDefinition.Type.GROUP else None,
).property_type

def _query_to_filter(self) -> Dict[str, any]:
filter_dict = {
"insight": "TRENDS",
"properties": self.query.properties,
"filter_test_accounts": self.query.filterTestAccounts,
"date_to": self.query_date_range.date_to(),
"date_from": self.query_date_range.date_from(),
"entity_type": "events",
"sampling_factor": self.query.samplingFactor,
"aggregation_group_type_index": self.query.aggregation_group_type_index,
"interval": self.query.interval,
}

if self.query.breakdown is not None:
filter_dict.update(self.query.breakdown.__dict__)

if self.query.trendsFilter is not None:
filter_dict.update(self.query.trendsFilter.__dict__)

return {k: v for k, v in filter_dict.items() if v is not None}

0 comments on commit acf30da

Please sign in to comment.