Skip to content

Commit

Permalink
feat: added support for daily and weekly active user aggregations (#1…
Browse files Browse the repository at this point in the history
…8101)

* Added support for daily and weekly active user aggregations

* Removed unused join

* Fixed tests and applied Sample placeholders

* Fixed formula running on an empty string

* Updated the name of QueryModifier

* Comments and fixed name casing
  • Loading branch information
Gilbert09 authored Oct 20, 2023
1 parent 6f707f9 commit af0c46a
Show file tree
Hide file tree
Showing 4 changed files with 241 additions and 17 deletions.
149 changes: 149 additions & 0 deletions posthog/hogql_queries/insights/trends/aggregation_operations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
from typing import List
from posthog.hogql import ast
from posthog.hogql.parser import parse_expr, parse_select
from posthog.hogql_queries.utils.query_date_range import QueryDateRange
from posthog.schema import ActionsNode, EventsNode


class QueryAlternator:
"""Allows query_builder to modify the query without having to expost the whole AST interface"""

_query: ast.SelectQuery
_selects: List[ast.Expr]
_group_bys: List[ast.Expr]
_select_from: ast.JoinExpr | None

def __init__(self, query: ast.SelectQuery):
self._query = query
self._selects = []
self._group_bys = []
self._select_from = None

def build(self) -> ast.SelectQuery:
if len(self._selects) > 0:
self._query.select.extend(self._selects)

if len(self._group_bys) > 0:
if self._query.group_by is None:
self._query.group_by = self._group_bys
else:
self._query.group_by.extend(self._group_bys)

if self._select_from is not None:
self._query.select_from = self._select_from

return self._query

def append_select(self, expr: ast.Expr) -> None:
self._selects.append(expr)

def append_group_by(self, expr: ast.Expr) -> None:
self._group_bys.append(expr)

def replace_select_from(self, join_expr: ast.JoinExpr) -> None:
self._select_from = join_expr


class AggregationOperations:
series: EventsNode | ActionsNode
query_date_range: QueryDateRange

def __init__(self, series: str, query_date_range: QueryDateRange) -> None:
self.series = series
self.query_date_range = query_date_range

def select_aggregation(self) -> ast.Expr:
if self.series.math == "hogql":
return parse_expr(self.series.math_hogql)
elif self.series.math == "total":
return parse_expr("count(e.uuid)")
elif self.series.math == "dau":
return parse_expr("count(DISTINCT e.person_id)")
elif self.series.math == "weekly_active":
return ast.Field(chain=["counts"])

return parse_expr("count(e.uuid)")

def requires_query_orchestration(self) -> bool:
return self.series.math == "weekly_active"

def _parent_select_query(self, inner_query: ast.SelectQuery) -> ast.SelectQuery:
return parse_select(
"""
SELECT
counts AS total,
dateTrunc({interval}, timestamp) AS day_start
FROM {inner_query}
WHERE timestamp >= {date_from} AND timestamp <= {date_to}
""",
placeholders={
**self.query_date_range.to_placeholders(),
"inner_query": inner_query,
},
)

def _inner_select_query(self, cross_join_select_query: ast.SelectQuery) -> ast.SelectQuery:
return parse_select(
"""
SELECT
d.timestamp,
COUNT(DISTINCT actor_id) AS counts
FROM (
SELECT
toStartOfDay({date_to}) - toIntervalDay(number) AS timestamp
FROM
numbers(dateDiff('day', toStartOfDay({date_from} - INTERVAL 7 DAY), {date_to}))
) d
CROSS JOIN {cross_join_select_query} e
WHERE
e.timestamp <= d.timestamp + INTERVAL 1 DAY AND
e.timestamp > d.timestamp - INTERVAL 6 DAY
GROUP BY d.timestamp
ORDER BY d.timestamp
""",
placeholders={
**self.query_date_range.to_placeholders(),
"cross_join_select_query": cross_join_select_query,
},
)

def _events_query(self, events_where_clause: ast.Expr, sample_value: ast.RatioExpr) -> ast.SelectQuery:
return parse_select(
"""
SELECT
timestamp as timestamp,
e.person_id AS actor_id
FROM
events e
SAMPLE {sample}
WHERE {events_where_clause}
GROUP BY
timestamp,
actor_id
""",
placeholders={"events_where_clause": events_where_clause, "sample": sample_value},
)

def get_query_orchestrator(self, events_where_clause: ast.Expr, sample_value: str):
events_query = self._events_query(events_where_clause, sample_value)
inner_select = self._inner_select_query(events_query)
parent_select = self._parent_select_query(inner_select)

class QueryOrchestrator:
events_query_builder: QueryAlternator
inner_select_query_builder: QueryAlternator
parent_select_query_builder: QueryAlternator

def __init__(self):
self.events_query_builder = QueryAlternator(events_query)
self.inner_select_query_builder = QueryAlternator(inner_select)
self.parent_select_query_builder = QueryAlternator(parent_select)

def build(self):
self.events_query_builder.build()
self.inner_select_query_builder.build()
self.parent_select_query_builder.build()

return parent_select

return QueryOrchestrator()
57 changes: 41 additions & 16 deletions posthog/hogql_queries/insights/trends/query_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from posthog.hogql.parser import parse_expr, parse_select
from posthog.hogql.property import property_to_expr
from posthog.hogql.timings import HogQLTimings
from posthog.hogql_queries.insights.trends.aggregation_operations import AggregationOperations
from posthog.hogql_queries.insights.trends.breakdown import Breakdown
from posthog.hogql_queries.insights.trends.breakdown_session import BreakdownSession
from posthog.hogql_queries.insights.trends.utils import series_event_name
Expand Down Expand Up @@ -107,11 +108,11 @@ def _get_date_subqueries(self) -> List[ast.SelectQuery]:
]

def _get_events_subquery(self) -> ast.SelectQuery:
query = parse_select(
default_query = parse_select(
"""
SELECT
{aggregation_operation} AS total,
dateTrunc({interval}, toTimeZone(toDateTime(timestamp), 'UTC')) AS day_start
dateTrunc({interval}, timestamp) AS day_start
FROM events AS e
SAMPLE {sample}
WHERE {events_filter}
Expand All @@ -120,19 +121,46 @@ def _get_events_subquery(self) -> ast.SelectQuery:
placeholders={
**self.query_date_range.to_placeholders(),
"events_filter": self._events_filter(),
"aggregation_operation": self._aggregation_operation(),
"aggregation_operation": self._aggregation_operation.select_aggregation(),
"sample": self._sample_value(),
},
)

if self._breakdown.enabled:
query.select.append(self._breakdown.column_expr())
query.group_by.append(ast.Field(chain=["breakdown_value"]))
# No breakdowns and no complex series aggregation
if not self._breakdown.enabled and not self._aggregation_operation.requires_query_orchestration():
return default_query
# Both breakdowns and complex series aggregation
elif self._breakdown.enabled and self._aggregation_operation.requires_query_orchestration():
orchestrator = self._aggregation_operation.get_query_orchestrator(
events_where_clause=self._events_filter(),
sample_value=self._sample_value(),
)

orchestrator.events_query_builder.append_select(self._breakdown.column_expr())
orchestrator.events_query_builder.append_group_by(ast.Field(chain=["breakdown_value"]))
if self._breakdown.is_session_type:
query.select_from = self._breakdown_session.session_inner_join()
orchestrator.events_query_builder.replace_select_from(self._breakdown_session.session_inner_join())

return query
orchestrator.inner_select_query_builder.append_select(ast.Field(chain=["breakdown_value"]))
orchestrator.inner_select_query_builder.append_group_by(ast.Field(chain=["breakdown_value"]))

orchestrator.parent_select_query_builder.append_select(ast.Field(chain=["breakdown_value"]))

return orchestrator.build()
# Just breakdowns
elif self._breakdown.enabled:
default_query.select.append(self._breakdown.column_expr())
default_query.group_by.append(ast.Field(chain=["breakdown_value"]))

if self._breakdown.is_session_type:
default_query.select_from = self._breakdown_session.session_inner_join()
# Just complex series aggregation
elif self._aggregation_operation.requires_query_orchestration():
return self._aggregation_operation.get_query_orchestrator(
events_where_clause=self._events_filter(), sample_value=self._sample_value()
).build()

return default_query

def _outer_select_query(self, inner_query: ast.SelectQuery) -> ast.SelectQuery:
query = parse_select(
Expand Down Expand Up @@ -234,14 +262,7 @@ def _events_filter(self) -> ast.Expr:
else:
return ast.And(exprs=filters)

def _aggregation_operation(self) -> ast.Expr:
if self.series.math == "hogql":
return parse_expr(self.series.math_hogql)

return parse_expr("count(e.uuid)")

# Using string interpolation for SAMPLE due to HogQL limitations with `UNION ALL` and `SAMPLE` AST nodes
def _sample_value(self) -> str:
def _sample_value(self) -> ast.RatioExpr:
if self.query.samplingFactor is None:
return ast.RatioExpr(left=ast.Constant(value=1))

Expand All @@ -260,3 +281,7 @@ def _breakdown(self):
@cached_property
def _breakdown_session(self):
return BreakdownSession(self.query_date_range)

@cached_property
def _aggregation_operation(self):
return AggregationOperations(self.series, self.query_date_range)
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
from typing import cast
from posthog.hogql import ast
from posthog.hogql.parser import parse_select
from posthog.hogql_queries.insights.trends.aggregation_operations import QueryAlternator


class TestQueryAlternator:
def test_select(self):
query = parse_select("SELECT event from events")

query_modifier = QueryAlternator(query)
query_modifier.append_select(ast.Field(chain=["test"]))
query_modifier.build()

assert len(query.select) == 2
assert cast(ast.Field, query.select[1]).chain == ["test"]

def test_group_no_pre_existing(self):
query = parse_select("SELECT event from events")

query_modifier = QueryAlternator(query)
query_modifier.append_group_by(ast.Field(chain=["event"]))
query_modifier.build()

assert len(query.group_by) == 1
assert cast(ast.Field, query.group_by[0]).chain == ["event"]

def test_group_with_pre_existing(self):
query = parse_select("SELECT event from events GROUP BY uuid")

query_modifier = QueryAlternator(query)
query_modifier.append_group_by(ast.Field(chain=["event"]))
query_modifier.build()

assert len(query.group_by) == 2
assert cast(ast.Field, query.group_by[0]).chain == ["uuid"]
assert cast(ast.Field, query.group_by[1]).chain == ["event"]

def test_replace_select_from(self):
query = parse_select("SELECT event from events")

query_modifier = QueryAlternator(query)
query_modifier.replace_select_from(ast.JoinExpr(table=ast.Field(chain=["groups"])))
query_modifier.build()

assert query.select_from.table.chain == ["groups"]
6 changes: 5 additions & 1 deletion posthog/hogql_queries/insights/trends/trends_query_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,11 @@ def calculate(self):

res.extend(self.build_series_response(response, series_with_extra))

if self.query.trendsFilter is not None and self.query.trendsFilter.formula is not None:
if (
self.query.trendsFilter is not None
and self.query.trendsFilter.formula is not None
and self.query.trendsFilter.formula != ""
):
res = self.apply_formula(self.query.trendsFilter.formula, res)

return TrendsQueryResponse(results=res, timings=timings)
Expand Down

0 comments on commit af0c46a

Please sign in to comment.