diff --git a/posthog/hogql_queries/insights/trends/trends_query_builder.py b/posthog/hogql_queries/insights/trends/trends_query_builder.py index ed5d867b48b75..9e4d27c369479 100644 --- a/posthog/hogql_queries/insights/trends/trends_query_builder.py +++ b/posthog/hogql_queries/insights/trends/trends_query_builder.py @@ -182,31 +182,15 @@ def _get_events_subquery( actors_query_time_frame=actors_query_time_frame, ) - default_query = cast( - ast.SelectQuery, - parse_select( - """ - SELECT - {aggregation_operation} AS total - FROM {table} AS e - WHERE {events_filter} - """ - if isinstance(self.series, DataWarehouseNode) - else """ - SELECT - {aggregation_operation} AS total - FROM {table} AS e - SAMPLE {sample} - WHERE {events_filter} - """, - placeholders={ - "table": self._table_expr, - "events_filter": events_filter, - "aggregation_operation": self._aggregation_operation.select_aggregation(), - "sample": self._sample_value(), - }, - ), + default_query = ast.SelectQuery( + select=[ast.Alias(alias="total", expr=self._aggregation_operation.select_aggregation())], + select_from=ast.JoinExpr(table=self._table_expr, alias="e"), + where=events_filter, ) + if not isinstance(self.series, DataWarehouseNode): + default_query.select_from.sample = ast.SampleExpr( + sample_value=self._sample_value(), + ) default_query.group_by = [] @@ -461,13 +445,20 @@ def _events_filter( # Dates if is_actors_query and actors_query_time_frame is not None: - to_start_of_time_frame = f"toStartOf{self.query_date_range.interval_name.capitalize()}" - filters.append( - ast.CompareOperation( - left=ast.Call(name=to_start_of_time_frame, args=[ast.Field(chain=["timestamp"])]), - op=ast.CompareOperationOp.Eq, - right=ast.Call(name="toDateTime", args=[ast.Constant(value=actors_query_time_frame)]), - ) + actors_from, actors_to = self.query_date_range.interval_bounds_from_str(actors_query_time_frame) + filters.extend( + [ + ast.CompareOperation( + left=ast.Field(chain=["timestamp"]), + op=ast.CompareOperationOp.GtEq, + right=ast.Call(name="toDateTime", args=[ast.Constant(value=actors_from)]), + ), + ast.CompareOperation( + left=ast.Field(chain=["timestamp"]), + op=ast.CompareOperationOp.Lt, + right=ast.Call(name="toDateTime", args=[ast.Constant(value=actors_to)]), + ), + ] ) elif not self._aggregation_operation.requires_query_orchestration(): filters.extend( diff --git a/posthog/hogql_queries/utils/query_date_range.py b/posthog/hogql_queries/utils/query_date_range.py index f2e5cef3d82a3..9289566b8eeaa 100644 --- a/posthog/hogql_queries/utils/query_date_range.py +++ b/posthog/hogql_queries/utils/query_date_range.py @@ -4,6 +4,7 @@ from typing import Literal, Optional, Dict, List from zoneinfo import ZoneInfo +from dateutil.parser import parse from dateutil.relativedelta import relativedelta from posthog.hogql.errors import HogQLException @@ -116,36 +117,39 @@ def interval_type(self) -> IntervalType: def interval_name(self) -> Literal["hour", "day", "week", "month"]: return self.interval_type.name - def all_values(self) -> List[str]: - start: datetime = self.date_from() - end: datetime = self.date_to() - interval = self.interval_name - - if interval == "hour": - start = start.replace(minute=0, second=0, microsecond=0) - elif interval == "day": - start = start.replace(hour=0, minute=0, second=0, microsecond=0) - elif interval == "week": + def align_with_interval(self, start: datetime) -> datetime: + if self.interval_name == "hour": + return start.replace(minute=0, second=0, microsecond=0) + elif self.interval_name == "day": + return start.replace(hour=0, minute=0, second=0, microsecond=0) + elif self.interval_name == "week": start = start.replace(hour=0, minute=0, second=0, microsecond=0) week_start_alignment_days = start.isoweekday() % 7 if self._team.week_start_day == WeekStartDay.MONDAY: week_start_alignment_days = start.weekday() start -= timedelta(days=week_start_alignment_days) - elif interval == "month": - start = start.replace(day=1, hour=0, minute=0, second=0, microsecond=0) + return start + elif self.interval_name == "month": + return start.replace(day=1, hour=0, minute=0, second=0, microsecond=0) + + def interval_relativedelta(self): + return relativedelta( + days=1 if self.interval_name == "day" else 0, + weeks=1 if self.interval_name == "week" else 0, + months=1 if self.interval_name == "month" else 0, + hours=1 if self.interval_name == "hour" else 0, + ) + def all_values(self) -> List[str]: + start = self.align_with_interval(self.date_from()) + end: datetime = self.date_to() values: List[str] = [] while start <= end: - if interval == "hour": + if self.interval_name == "hour": values.append(start.strftime("%Y-%m-%d %H:%M:%S")) else: values.append(start.strftime("%Y-%m-%d")) - start += relativedelta( - days=1 if interval == "day" else 0, - weeks=1 if interval == "week" else 0, - months=1 if interval == "month" else 0, - hours=1 if interval == "hour" else 0, - ) + start += self.interval_relativedelta() return values def date_to_as_hogql(self) -> ast.Expr: @@ -257,6 +261,11 @@ def to_placeholders(self) -> Dict[str, ast.Expr]: else self.date_from_as_hogql(), } + def interval_bounds_from_str(self, time_frame: str) -> tuple[datetime, datetime]: + date_from = parse(time_frame) + date_to = date_from + self.interval_relativedelta() + return date_from, date_to + class QueryDateRangeWithIntervals(QueryDateRange): def __init__(