Skip to content

Commit

Permalink
feat: added the ability to view all chart types on trends (#18311)
Browse files Browse the repository at this point in the history
* Added the final set of aggregation functions

* Added the persons query to trends query

* Added the ability to view all chart types on trends
  • Loading branch information
Gilbert09 authored Nov 10, 2023
1 parent f9bea9b commit 0d26804
Show file tree
Hide file tree
Showing 5 changed files with 305 additions and 38 deletions.
68 changes: 68 additions & 0 deletions posthog/hogql_queries/insights/trends/display.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
from posthog.hogql import ast
from posthog.schema import ChartDisplayType


class TrendsDisplay:
display_type: ChartDisplayType

def __init__(self, display_type: ChartDisplayType) -> None:
self.display_type = display_type

def should_aggregate_values(self) -> bool:
return (
self.display_type == ChartDisplayType.BoldNumber
or self.display_type == ChartDisplayType.ActionsPie
or self.display_type == ChartDisplayType.ActionsBarValue
or self.display_type == ChartDisplayType.WorldMap
)

def wrap_inner_query(self, inner_query: ast.SelectQuery, breakdown_enabled: bool) -> ast.SelectQuery:
if self.display_type == ChartDisplayType.ActionsLineGraphCumulative:
return self._get_cumulative_query(inner_query, breakdown_enabled)

return inner_query

def should_wrap_inner_query(self) -> bool:
return self.display_type == ChartDisplayType.ActionsLineGraphCumulative

def modify_outer_query(self, outer_query: ast.SelectQuery, inner_query: ast.SelectQuery) -> ast.SelectQuery:
if (
self.display_type == ChartDisplayType.BoldNumber
or self.display_type == ChartDisplayType.ActionsPie
or self.display_type == ChartDisplayType.WorldMap
):
return ast.SelectQuery(
select=[
ast.Alias(
alias="total",
expr=ast.Call(name="sum", args=[ast.Field(chain=["count"])]),
)
],
select_from=ast.JoinExpr(table=inner_query),
)

return outer_query

def _get_cumulative_query(self, inner_query: ast.SelectQuery, breakdown_enabled: bool) -> ast.SelectQuery:
if breakdown_enabled:
window_expr = ast.WindowExpr(
order_by=[ast.OrderExpr(expr=ast.Field(chain=["day_start"]), order="ASC")],
partition_by=[ast.Field(chain=["breakdown_value"])],
)
else:
window_expr = ast.WindowExpr(order_by=[ast.OrderExpr(expr=ast.Field(chain=["day_start"]), order="ASC")])

return ast.SelectQuery(
select=[
ast.Field(chain=["day_start"]),
ast.Alias(
alias="count",
expr=ast.WindowFunction(
name="sum",
args=[ast.Field(chain=["count"])],
over_expr=window_expr,
),
),
],
select_from=ast.JoinExpr(table=inner_query),
)
21 changes: 19 additions & 2 deletions posthog/hogql_queries/insights/trends/query_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,12 @@
AggregationOperations,
)
from posthog.hogql_queries.insights.trends.breakdown import Breakdown
from posthog.hogql_queries.insights.trends.display import TrendsDisplay
from posthog.hogql_queries.insights.trends.utils import series_event_name
from posthog.hogql_queries.utils.query_date_range import QueryDateRange
from posthog.models.filters.mixins.utils import cached_property
from posthog.models.team.team import Team
from posthog.schema import ActionsNode, EventsNode, TrendsQuery
from posthog.schema import ActionsNode, ChartDisplayType, EventsNode, TrendsQuery


class TrendsQueryBuilder:
Expand Down Expand Up @@ -196,6 +197,8 @@ def _outer_select_query(self, inner_query: ast.SelectQuery) -> ast.SelectQuery:
),
)

query = self._trends_display.modify_outer_query(outer_query=query, inner_query=inner_query)

if self._breakdown.enabled:
query.select.append(ast.Field(chain=["breakdown_value"]))
query.group_by = [ast.Field(chain=["breakdown_value"])]
Expand Down Expand Up @@ -224,6 +227,11 @@ def _inner_select_query(self, inner_query: ast.SelectUnionQuery) -> ast.SelectQu
query.group_by.append(ast.Field(chain=["breakdown_value"]))
query.order_by.append(ast.OrderExpr(expr=ast.Field(chain=["breakdown_value"]), order="ASC"))

if self._trends_display.should_wrap_inner_query():
query = self._trends_display.wrap_inner_query(query, self._breakdown.enabled)
if self._breakdown.enabled:
query.select.append(ast.Field(chain=["breakdown_value"]))

return query

def _events_filter(self) -> ast.Expr:
Expand Down Expand Up @@ -298,5 +306,14 @@ def _breakdown(self):
)

@cached_property
def _aggregation_operation(self):
def _aggregation_operation(self) -> AggregationOperations:
return AggregationOperations(self.series, self.query_date_range)

@cached_property
def _trends_display(self) -> TrendsDisplay:
if self.query.trendsFilter is None or self.query.trendsFilter.display is None:
display = ChartDisplayType.ActionsLineGraph
else:
display = self.query.trendsFilter.display

return TrendsDisplay(display)
3 changes: 3 additions & 0 deletions posthog/hogql_queries/insights/trends/series_with_extras.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,16 @@ class SeriesWithExtras:
series: EventsNode | ActionsNode
is_previous_period_series: Optional[bool]
overriden_query: Optional[TrendsQuery]
aggregate_values: Optional[bool]

def __init__(
self,
series: EventsNode | ActionsNode,
is_previous_period_series: Optional[bool],
overriden_query: Optional[TrendsQuery],
aggregate_values: Optional[bool],
):
self.series = series
self.is_previous_period_series = is_previous_period_series
self.overriden_query = overriden_query
self.aggregate_values = aggregate_values
122 changes: 122 additions & 0 deletions posthog/hogql_queries/insights/trends/test/test_query_builder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
from datetime import datetime
from freezegun import freeze_time

from posthog.hogql.query import execute_hogql_query
from posthog.hogql.timings import HogQLTimings
from posthog.hogql_queries.insights.trends.query_builder import TrendsQueryBuilder
from posthog.hogql_queries.utils.query_date_range import QueryDateRange
from posthog.schema import (
BaseMathType,
BreakdownFilter,
BreakdownType,
ChartDisplayType,
DateRange,
EventsNode,
HogQLQueryResponse,
TrendsFilter,
TrendsQuery,
)
from posthog.test.base import BaseTest, _create_event, _create_person


class TestQueryBuilder(BaseTest):
def setUp(self):
super().setUp()

with freeze_time("2023-02-01"):
_create_person(
distinct_ids=["some_id"],
team_id=self.team.pk,
properties={"$some_prop": "something", "$another_prop": "something"},
)
_create_event(
event="$pageview",
team=self.team,
distinct_id="some_id",
properties={"$geoip_country_code": "AU"},
)

def get_response(self, trends_query: TrendsQuery) -> HogQLQueryResponse:
query_date_range = QueryDateRange(
date_range=trends_query.dateRange,
team=self.team,
interval=trends_query.interval,
now=datetime.now(),
)

timings = HogQLTimings()

query_builder = TrendsQueryBuilder(
trends_query=trends_query,
team=self.team,
query_date_range=query_date_range,
series=trends_query.series[0],
timings=timings,
)

query = query_builder.build_query()

return execute_hogql_query(
query_type="TrendsQuery",
query=query,
team=self.team,
timings=timings,
)

def test_column_names(self):
trends_query = TrendsQuery(
kind="TrendsQuery",
dateRange=DateRange(date_from="2023-01-01"),
series=[EventsNode(event="$pageview", math=BaseMathType.total)],
)

response = self.get_response(trends_query)

assert response.columns is not None
assert set(response.columns).issubset({"date", "total", "breakdown_value"})

def assert_column_names_with_display_type(self, display_type: ChartDisplayType):
trends_query = TrendsQuery(
kind="TrendsQuery",
dateRange=DateRange(date_from="2023-01-01"),
series=[EventsNode(event="$pageview")],
trendsFilter=TrendsFilter(display=display_type),
)

response = self.get_response(trends_query)

assert response.columns is not None
assert set(response.columns).issubset({"date", "total", "breakdown_value"})

def assert_column_names_with_display_type_and_breakdowns(self, display_type: ChartDisplayType):
trends_query = TrendsQuery(
kind="TrendsQuery",
dateRange=DateRange(date_from="2023-01-01"),
series=[EventsNode(event="$pageview")],
trendsFilter=TrendsFilter(display=display_type),
breakdown=BreakdownFilter(breakdown="$geoip_country_code", breakdown_type=BreakdownType.event),
)

response = self.get_response(trends_query)

assert response.columns is not None
assert set(response.columns).issubset({"date", "total", "breakdown_value"})

def test_column_names_with_display_type(self):
self.assert_column_names_with_display_type(ChartDisplayType.ActionsAreaGraph)
self.assert_column_names_with_display_type(ChartDisplayType.ActionsBar)
self.assert_column_names_with_display_type(ChartDisplayType.ActionsBarValue)
self.assert_column_names_with_display_type(ChartDisplayType.ActionsLineGraph)
self.assert_column_names_with_display_type(ChartDisplayType.ActionsPie)
self.assert_column_names_with_display_type(ChartDisplayType.BoldNumber)
self.assert_column_names_with_display_type(ChartDisplayType.WorldMap)
self.assert_column_names_with_display_type(ChartDisplayType.ActionsLineGraphCumulative)

def test_column_names_with_display_type_and_breakdowns(self):
self.assert_column_names_with_display_type_and_breakdowns(ChartDisplayType.ActionsAreaGraph)
self.assert_column_names_with_display_type_and_breakdowns(ChartDisplayType.ActionsBar)
self.assert_column_names_with_display_type_and_breakdowns(ChartDisplayType.ActionsBarValue)
self.assert_column_names_with_display_type_and_breakdowns(ChartDisplayType.ActionsLineGraph)
self.assert_column_names_with_display_type_and_breakdowns(ChartDisplayType.ActionsPie)
self.assert_column_names_with_display_type_and_breakdowns(ChartDisplayType.WorldMap)
self.assert_column_names_with_display_type_and_breakdowns(ChartDisplayType.ActionsLineGraphCumulative)
Loading

0 comments on commit 0d26804

Please sign in to comment.