Skip to content

Commit

Permalink
chore(hogql): Use paginator in EventsQueryRunner and add tests for it
Browse files: browse the repository at this point in the history
  • Loading branch information
webjunkie authored Dec 21, 2023
1 parent 7a34700 commit 41fa973
Show file tree
Hide file tree
Showing 8 changed files with 269 additions and 42 deletions.
8 changes: 7 additions & 1 deletion frontend/src/queries/schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -1047,6 +1047,12 @@
"hogql": {
"type": "string"
},
"limit": {
"type": "integer"
},
"offset": {
"type": "integer"
},
"results": {
"items": {
"items": {},
Expand All @@ -1067,7 +1073,7 @@
"type": "array"
}
},
"required": ["columns", "types", "results", "hogql"],
"required": ["columns", "types", "results", "hogql", "limit", "offset"],
"type": "object"
},
"FeaturePropertyFilter": {
Expand Down
4 changes: 4 additions & 0 deletions frontend/src/queries/schema.ts
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,10 @@ export interface EventsQueryResponse {
hogql: string
hasMore?: boolean
timings?: QueryTiming[]
/** @asType integer */
limit: number
/** @asType integer */
offset: number
}
export interface EventsQueryPersonColumn {
uuid: string
Expand Down
47 changes: 20 additions & 27 deletions posthog/hogql_queries/events_query_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,10 @@
from posthog.api.utils import get_pk_or_uuid
from posthog.clickhouse.client.connection import Workload
from posthog.hogql import ast
from posthog.hogql.constants import get_max_limit_for_context, get_default_limit_for_context
from posthog.hogql.parser import parse_expr, parse_order_expr
from posthog.hogql.property import action_to_expr, has_aggregation, property_to_expr
from posthog.hogql.query import execute_hogql_query
from posthog.hogql.timings import HogQLTimings
from posthog.hogql_queries.insights.paginators import HogQLHasMorePaginator
from posthog.hogql_queries.query_runner import QueryRunner
from posthog.models import Action, Person
from posthog.models.element import chain_to_elements
Expand All @@ -40,15 +39,18 @@ class EventsQueryRunner(QueryRunner):
query: EventsQuery
query_type = EventsQuery

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.paginator = HogQLHasMorePaginator.from_limit_context(
limit_context=self.limit_context, limit=self.query.limit, offset=self.query.offset
)

def to_query(self) -> ast.SelectQuery:
# Note: This code is inefficient and problematic, see https://github.com/PostHog/posthog/issues/13485 for details.
if self.timings is None:
self.timings = HogQLTimings()

with self.timings.measure("build_ast"):
# limit & offset
offset = 0 if self.query.offset is None else self.query.offset

# columns & group_by
with self.timings.measure("columns"):
select_input: List[str] = []
Expand Down Expand Up @@ -175,13 +177,11 @@ def to_query(self) -> ast.SelectQuery:
having=having,
group_by=group_by if has_any_aggregation else None,
order_by=order_by,
limit=ast.Constant(value=self.limit()),
offset=ast.Constant(value=offset),
)
return stmt

def calculate(self) -> EventsQueryResponse:
query_result = execute_hogql_query(
query_result = self.paginator.execute_hogql_query(
query=self.to_query(),
team=self.team,
workload=Workload.ONLINE,
Expand All @@ -195,27 +195,27 @@ def calculate(self) -> EventsQueryResponse:
if "*" in self.select_input_raw():
with self.timings.measure("expand_asterisk"):
star_idx = self.select_input_raw().index("*")
for index, result in enumerate(query_result.results):
query_result.results[index] = list(result)
for index, result in enumerate(self.paginator.results):
self.paginator.results[index] = list(result)
select = result[star_idx]
new_result = dict(zip(SELECT_STAR_FROM_EVENTS_FIELDS, select))
new_result["properties"] = json.loads(new_result["properties"])
if new_result["elements_chain"]:
new_result["elements"] = ElementSerializer(
chain_to_elements(new_result["elements_chain"]), many=True
).data
query_result.results[index][star_idx] = new_result
self.paginator.results[index][star_idx] = new_result

person_indices: List[int] = []
for index, col in enumerate(self.select_input_raw()):
if col.split("--")[0].strip() == "person":
person_indices.append(index)

if len(person_indices) > 0 and len(query_result.results) > 0:
if len(person_indices) > 0 and len(self.paginator.results) > 0:
with self.timings.measure("person_column_extra_query"):
# Make a query into postgres to fetch person
person_idx = person_indices[0]
distinct_ids = list(set(event[person_idx] for event in query_result.results))
distinct_ids = list(set(event[person_idx] for event in self.paginator.results))
persons = get_persons_by_distinct_ids(self.team.pk, distinct_ids)
persons = persons.prefetch_related(Prefetch("persondistinctid_set", to_attr="distinct_ids_cache"))
distinct_to_person: Dict[str, Person] = {}
Expand All @@ -226,41 +226,34 @@ def calculate(self) -> EventsQueryResponse:

# Loop over all columns in case there is more than one "person" column
for column_index in person_indices:
for index, result in enumerate(query_result.results):
for index, result in enumerate(self.paginator.results):
distinct_id: str = result[column_index]
query_result.results[index] = list(result)
self.paginator.results[index] = list(result)
if distinct_to_person.get(distinct_id):
person = distinct_to_person[distinct_id]
query_result.results[index][column_index] = {
self.paginator.results[index][column_index] = {
"uuid": person.uuid,
"created_at": person.created_at,
"properties": person.properties or {},
"distinct_id": distinct_id,
}
else:
query_result.results[index][column_index] = {
self.paginator.results[index][column_index] = {
"distinct_id": distinct_id,
}

received_extra_row = len(query_result.results) == self.limit() # limit was +=1'd above
return EventsQueryResponse(
results=query_result.results[: self.limit() - 1] if received_extra_row else query_result.results,
results=self.paginator.results,
columns=self.select_input_raw(),
types=[type for _, type in query_result.types],
hasMore=received_extra_row,
types=[t for _, t in query_result.types] if query_result.types else None,
timings=self.timings.to_list(),
hogql=query_result.hogql,
**self.paginator.response_params(),
)

def select_input_raw(self) -> List[str]:
return ["*"] if len(self.query.select) == 0 else self.query.select

def limit(self) -> int:
# adding +1 to the limit to check if there's a "next page" after the requested results
max_rows = get_max_limit_for_context(self.limit_context)
default_rows = get_default_limit_for_context(self.limit_context)
return min(max_rows, default_rows if self.query.limit is None else self.query.limit) + 1

def _is_stale(self, cached_result_package):
return True

Expand Down
29 changes: 22 additions & 7 deletions posthog/hogql_queries/insights/paginators.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
from typing import List, Any, Optional, cast, Sequence
from typing import Any, Optional, cast

from posthog.hogql import ast
from posthog.hogql.constants import (
get_max_limit_for_context,
get_default_limit_for_context,
LimitContext,
DEFAULT_RETURNED_ROWS,
)
from posthog.hogql.query import execute_hogql_query
from posthog.schema import HogQLQueryResponse

Expand All @@ -11,15 +17,24 @@ class HogQLHasMorePaginator:
Takes care of setting the limit and offset on the query.
"""

def __init__(self, limit: int, offset: int):
def __init__(self, *, limit: Optional[int], offset: Optional[int]):
self.response: Optional[HogQLQueryResponse] = None
self.results: Sequence[Any] = []
self.limit = limit
self.offset = offset
self.results: list[Any] = []
self.limit = limit if limit and limit > 0 else DEFAULT_RETURNED_ROWS
self.offset = offset if offset and offset > 0 else 0

@classmethod
def from_limit_context(
cls, *, limit_context: LimitContext, limit: Optional[int], offset: Optional[int]
) -> "HogQLHasMorePaginator":
max_rows = get_max_limit_for_context(limit_context)
default_rows = get_default_limit_for_context(limit_context)
limit = min(max_rows, default_rows if (limit is None or limit <= 0) else limit)
return cls(limit=limit, offset=offset)

def paginate(self, query: ast.SelectQuery) -> ast.SelectQuery:
query.limit = ast.Constant(value=self.limit + 1)
query.offset = ast.Constant(value=self.offset or 0)
query.offset = ast.Constant(value=self.offset)
return query

def has_more(self) -> bool:
Expand All @@ -28,7 +43,7 @@ def has_more(self) -> bool:

return len(self.response.results) > self.limit

def trim_results(self) -> List[Any]:
def trim_results(self) -> list[Any]:
if not self.response or not self.response.results:
return []

Expand Down
Loading

0 comments on commit 41fa973

Please sign in to comment.