Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: added trends breakdown by cohort #18028

Merged
merged 5 commits into from
Oct 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions posthog/hogql_queries/insights/trends/breakdown.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,16 @@ def placeholders(self) -> Dict[str, ast.Expr]:
def column_expr(self) -> ast.Expr:
if self.is_histogram_breakdown:
return ast.Alias(alias="breakdown_value", expr=self._get_breakdown_histogram_multi_if())
elif self.query.breakdown.breakdown_type == "hogql":
return ast.Alias(
alias="breakdown_value",
expr=parse_expr(self.query.breakdown.breakdown),
)
elif self.query.breakdown.breakdown_type == "cohort":
return ast.Alias(
alias="breakdown_value",
expr=ast.Constant(value=int(self.query.breakdown.breakdown)),
)

if self.query.breakdown.breakdown_type == "hogql":
return ast.Alias(
Expand All @@ -60,6 +70,13 @@ def column_expr(self) -> ast.Expr:
)

def events_where_filter(self) -> ast.Expr:
if self.query.breakdown.breakdown_type == "cohort":
return ast.CompareOperation(
left=ast.Field(chain=["person_id"]),
op=ast.CompareOperationOp.InCohort,
right=ast.Constant(value=int(self.query.breakdown.breakdown)),
)

if self.query.breakdown.breakdown_type == "hogql":
left = parse_expr(self.query.breakdown.breakdown)
else:
Expand Down
9 changes: 6 additions & 3 deletions posthog/hogql_queries/insights/trends/breakdown_values.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Optional
from typing import List, Optional, Union
from posthog.hogql import ast
from posthog.hogql.parser import parse_expr, parse_select
from posthog.hogql.query import execute_hogql_query
Expand All @@ -10,7 +10,7 @@
class BreakdownValues:
team: Team
event_name: str
breakdown_field: str
breakdown_field: Union[str, float]
breakdown_type: str
query_date_range: QueryDateRange
histogram_bin_count: Optional[int]
Expand All @@ -20,7 +20,7 @@ def __init__(
self,
team: Team,
event_name: str,
breakdown_field: str,
breakdown_field: Union[str, float],
query_date_range: QueryDateRange,
breakdown_type: str,
histogram_bin_count: Optional[float] = None,
Expand All @@ -35,6 +35,9 @@ def __init__(
self.group_type_index = int(group_type_index) if group_type_index is not None else None

def get_breakdown_values(self) -> List[str]:
if self.breakdown_type == "cohort":
return [int(self.breakdown_field)]

if self.breakdown_type == "hogql":
select_field = ast.Alias(
alias="value",
Expand Down
11 changes: 9 additions & 2 deletions posthog/hogql_queries/insights/trends/series_with_extras.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,18 @@
from typing import Optional
from posthog.schema import ActionsNode, EventsNode
from posthog.schema import ActionsNode, EventsNode, TrendsQuery


class SeriesWithExtras:
series: EventsNode | ActionsNode
is_previous_period_series: Optional[bool]
overriden_query: Optional[TrendsQuery]

def __init__(self, series: EventsNode | ActionsNode, is_previous_period_series: Optional[bool]):
def __init__(
self,
series: EventsNode | ActionsNode,
is_previous_period_series: Optional[bool],
overriden_query: Optional[TrendsQuery],
):
self.series = series
self.is_previous_period_series = is_previous_period_series
self.overriden_query = overriden_query
78 changes: 70 additions & 8 deletions posthog/hogql_queries/insights/trends/trends_query_runner.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from copy import deepcopy
from datetime import timedelta
from itertools import groupby
from math import ceil
Expand All @@ -18,6 +19,7 @@
from posthog.hogql_queries.utils.query_date_range import QueryDateRange
from posthog.hogql_queries.utils.query_previous_period_date_range import QueryPreviousPeriodDateRange
from posthog.models import Team
from posthog.models.cohort.cohort import Cohort
from posthog.models.filters.mixins.utils import cached_property
from posthog.models.property_definition import PropertyDefinition
from posthog.schema import ActionsNode, EventsNode, HogQLQueryResponse, TrendsQuery, TrendsQueryResponse
Expand Down Expand Up @@ -63,7 +65,13 @@ def to_query(self) -> List[ast.SelectQuery]:
else:
query_date_range = self.query_previous_date_range

query_builder = TrendsQueryBuilder(self.query, self.team, query_date_range, series.series, self.timings)
query_builder = TrendsQueryBuilder(
trends_query=series.overriden_query or self.query,
team=self.team,
query_date_range=query_date_range,
series=series.series,
timings=self.timings,
)
queries.append(query_builder.build_query())

return queries
Expand Down Expand Up @@ -105,6 +113,7 @@ def build_series_response(self, response: HogQLQueryResponse, series: SeriesWith
"days": [item.strftime("%Y-%m-%d") for item in val[0]], # TODO: Add back in hour formatting
"count": float(sum(val[1])),
"label": "All events" if self.series_event(series.series) is None else self.series_event(series.series),
"filter": self._query_to_filter(),
"action": { # TODO: Populate missing props in `action`
"id": self.series_event(series.series),
"type": "events",
Expand Down Expand Up @@ -136,6 +145,12 @@ def build_series_response(self, response: HogQLQueryResponse, series: SeriesWith
remapped_label = self._convert_boolean(val[2])
series_object["label"] = "{} - {}".format(series_object["label"], remapped_label)
series_object["breakdown_value"] = remapped_label
elif self.query.breakdown.breakdown_type == "cohort":
cohort_id = val[2]
cohort_name = Cohort.objects.get(pk=cohort_id).name

series_object["label"] = "{} - {}".format(series_object["label"], cohort_name)
series_object["breakdown_value"] = val[2]
else:
series_object["label"] = "{} - {}".format(series_object["label"], val[2])
series_object["breakdown_value"] = val[2]
Expand All @@ -161,14 +176,40 @@ def series_event(self, series: EventsNode | ActionsNode) -> str | None:
return None

def setup_series(self) -> List[SeriesWithExtras]:
if self.query.trendsFilter is not None and self.query.trendsFilter.compare:
series_with_extras = [SeriesWithExtras(series, None, None) for series in self.query.series]

if self.query.breakdown is not None and self.query.breakdown.breakdown_type == "cohort":
updated_series = []
for series in self.query.series:
updated_series.append(SeriesWithExtras(series, is_previous_period_series=False))
updated_series.append(SeriesWithExtras(series, is_previous_period_series=True))
return updated_series
for cohort_id in self.query.breakdown.breakdown:
for series in series_with_extras:
copied_query = deepcopy(self.query)
copied_query.breakdown.breakdown = cohort_id

updated_series.append(
SeriesWithExtras(
series=series.series,
is_previous_period_series=series.is_previous_period_series,
overriden_query=copied_query,
)
)
series_with_extras = updated_series

return [SeriesWithExtras(series, is_previous_period_series=False) for series in self.query.series]
if self.query.trendsFilter is not None and self.query.trendsFilter.compare:
updated_series = []
for series in series_with_extras:
updated_series.append(
SeriesWithExtras(
series=series.series, is_previous_period_series=False, overriden_query=series.overriden_query
)
)
updated_series.append(
SeriesWithExtras(
series=series.series, is_previous_period_series=True, overriden_query=series.overriden_query
)
)
series_with_extras = updated_series

return series_with_extras

def apply_formula(self, formula: str, results: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
if self.query.trendsFilter is not None and self.query.trendsFilter.compare:
Expand Down Expand Up @@ -199,7 +240,7 @@ def apply_formula(self, formula: str, results: List[Dict[str, Any]]) -> List[Dic
return [new_result]

def _is_breakdown_field_boolean(self):
if self.query.breakdown.breakdown_type == "hogql":
if self.query.breakdown.breakdown_type == "hogql" or self.query.breakdown.breakdown_type == "cohort":
return False

if self.query.breakdown.breakdown_type == "person":
Expand All @@ -225,3 +266,24 @@ def _event_property(self, field: str, field_type: PropertyDefinition.Type, group
type=field_type,
group_type_index=group_type_index if field_type == PropertyDefinition.Type.GROUP else None,
).property_type

def _query_to_filter(self) -> Dict[str, Any]:
    """Build the dict that is attached to each series result under "filter".

    The dict starts from a fixed set of keys read off ``self.query`` and
    ``self.query_date_range``; the attributes of ``self.query.breakdown`` and
    ``self.query.trendsFilter`` are then merged in when those are set.

    Returns:
        The merged mapping with every ``None``-valued entry removed.
    """
    filter_dict: Dict[str, Any] = {
        "insight": "TRENDS",
        "properties": self.query.properties,
        "filter_test_accounts": self.query.filterTestAccounts,
        "date_to": self.query_date_range.date_to(),
        "date_from": self.query_date_range.date_from(),
        "entity_type": "events",
        "sampling_factor": self.query.samplingFactor,
        "aggregation_group_type_index": self.query.aggregation_group_type_index,
        "interval": self.query.interval,
    }

    # Later updates win on key collisions: breakdown attributes override the
    # base keys, and trendsFilter attributes override both.
    if self.query.breakdown is not None:
        filter_dict.update(self.query.breakdown.__dict__)

    if self.query.trendsFilter is not None:
        filter_dict.update(self.query.trendsFilter.__dict__)

    # Unset (None) values are omitted rather than emitted as explicit nulls.
    return {k: v for k, v in filter_dict.items() if v is not None}
Loading