Skip to content

Commit

Permalink
feat(hogql): events query runner (#17892)
Browse files Browse the repository at this point in the history
  • Loading branch information
mariusandra authored Oct 11, 2023
1 parent 9939786 commit 0df3b2f
Show file tree
Hide file tree
Showing 5 changed files with 296 additions and 243 deletions.
8 changes: 2 additions & 6 deletions posthog/api/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,12 @@

from posthog.hogql_queries.query_runner import get_query_runner
from posthog.models import Team
from posthog.models.event.events_query import run_events_query
from posthog.models.user import User
from posthog.permissions import ProjectMembershipNecessaryPermissions, TeamMemberAccessPermission
from posthog.queries.time_to_see_data.serializers import SessionEventsQuerySerializer, SessionsQuerySerializer
from posthog.queries.time_to_see_data.sessions import get_session_events, get_sessions
from posthog.rate_limit import AIBurstRateThrottle, AISustainedRateThrottle, TeamRateThrottle
from posthog.schema import EventsQuery, HogQLQuery, HogQLMetadata
from posthog.schema import HogQLQuery, HogQLMetadata
from posthog.utils import refresh_requested_by_client

QUERY_WITH_RUNNER = [
Expand All @@ -47,6 +46,7 @@
"WebTopPagesQuery",
]
QUERY_WITH_RUNNER_NO_CACHE = [
"EventsQuery",
"PersonsQuery",
]

Expand Down Expand Up @@ -224,10 +224,6 @@ def process_query(
elif query_kind in QUERY_WITH_RUNNER_NO_CACHE:
query_runner = get_query_runner(query_json, team)
return _unwrap_pydantic_dict(query_runner.calculate())
elif query_kind == "EventsQuery":
events_query = EventsQuery.model_validate(query_json)
events_response = run_events_query(query=events_query, team=team, default_limit=default_limit)
return _unwrap_pydantic_dict(events_response)
elif query_kind == "HogQLQuery":
hogql_query = HogQLQuery.model_validate(query_json)
values = (
Expand Down
278 changes: 278 additions & 0 deletions posthog/hogql_queries/events_query_runner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,278 @@
import json
from datetime import timedelta
from typing import Dict, List, Optional, Any

from dateutil.parser import isoparse
from django.db.models import Prefetch
from django.utils.timezone import now

from posthog.api.element import ElementSerializer
from posthog.api.utils import get_pk_or_uuid
from posthog.clickhouse.client.connection import Workload
from posthog.hogql import ast
from posthog.hogql.parser import parse_expr, parse_order_expr
from posthog.hogql.property import action_to_expr, has_aggregation, property_to_expr
from posthog.hogql.query import execute_hogql_query
from posthog.hogql.timings import HogQLTimings
from posthog.hogql_queries.query_runner import QueryRunner
from posthog.models import Action, Person, Team
from posthog.models.element import chain_to_elements
from posthog.models.person.util import get_persons_by_distinct_ids
from posthog.schema import EventsQuery, EventsQueryResponse
from posthog.utils import relative_date_parse

# Allow-listed fields returned when you select "*" from events. Person and group fields will be nested later.
# The order here matters: `calculate()` zips these names against the tuple produced by
# `tuple(uuid, event, ...)` in `to_query()`, so both must stay in sync.
SELECT_STAR_FROM_EVENTS_FIELDS = [
    "uuid",
    "event",
    "properties",
    "timestamp",
    "team_id",
    "distinct_id",
    "elements_chain",
    "created_at",
]


class EventsQueryRunner(QueryRunner):
    """Runs an `EventsQuery` by building a HogQL AST and post-processing the results.

    Post-processing expands a selected "*" column into a dict of event fields and
    replaces selected "person" columns (queried as `distinct_id`) with person dicts
    fetched from Postgres in a follow-up query.
    """

    # The validated query this runner executes; `query_type` is used by the base
    # QueryRunner machinery to identify/validate the schema class.
    query: EventsQuery
    query_type = EventsQuery

    def __init__(
        self,
        query: EventsQuery | Dict[str, Any],
        team: Team,
        timings: Optional[HogQLTimings] = None,
        default_limit: Optional[int] = None,
    ) -> None:
        """Accept either a validated EventsQuery or a raw dict (validated here).

        `default_limit` is used by `limit()` when the query itself sets no limit.
        """
        super().__init__(query, team, timings)
        if isinstance(query, EventsQuery):
            self.query = query
        else:
            self.query = EventsQuery.model_validate(query)
        self.default_limit = default_limit

    def to_query(self) -> ast.SelectQuery:
        """Build the HogQL SELECT for this events query.

        Each phase is wrapped in a `timings.measure` block so the response can
        report where AST-building time went.
        """
        # Note: This code is inefficient and problematic, see https://github.com/PostHog/posthog/issues/13485 for details.
        if self.timings is None:
            self.timings = HogQLTimings()

        with self.timings.measure("build_ast"):
            # limit & offset
            offset = 0 if self.query.offset is None else self.query.offset

            # columns & group_by
            with self.timings.measure("columns"):
                select_input: List[str] = []
                person_indices: List[int] = []
                for index, col in enumerate(self.select_input_raw()):
                    # Selecting a "*" expands the list of columns, resulting in a table that's not what we asked for.
                    # Instead, ask for a tuple with all the columns we want. Later transform this back into a dict.
                    if col == "*":
                        select_input.append(f"tuple({', '.join(SELECT_STAR_FROM_EVENTS_FIELDS)})")
                    elif col.split("--")[0].strip() == "person":
                        # This will be expanded into a followup query
                        # (the "--"-split also matches aliased columns like "person -- Person").
                        select_input.append("distinct_id")
                        person_indices.append(index)
                    else:
                        select_input.append(col)
                select: List[ast.Expr] = [parse_expr(column, timings=self.timings) for column in select_input]

            with self.timings.measure("aggregations"):
                # Non-aggregated columns become GROUP BY keys (only used when any aggregation exists).
                group_by: List[ast.Expr] = [column for column in select if not has_aggregation(column)]
                aggregations: List[ast.Expr] = [column for column in select if has_aggregation(column)]
                has_any_aggregation = len(aggregations) > 0

            # filters
            with self.timings.measure("filters"):
                with self.timings.measure("where"):
                    where_input = self.query.where or []
                    where_exprs = [parse_expr(expr, timings=self.timings) for expr in where_input]
                if self.query.properties:
                    with self.timings.measure("properties"):
                        where_exprs.extend(property_to_expr(property, self.team) for property in self.query.properties)
                if self.query.fixedProperties:
                    with self.timings.measure("fixed_properties"):
                        where_exprs.extend(
                            property_to_expr(property, self.team) for property in self.query.fixedProperties
                        )
                if self.query.event:
                    with self.timings.measure("event"):
                        where_exprs.append(
                            parse_expr(
                                "event = {event}", {"event": ast.Constant(value=self.query.event)}, timings=self.timings
                            )
                        )
                if self.query.actionId:
                    with self.timings.measure("action_id"):
                        try:
                            action = Action.objects.get(pk=self.query.actionId, team_id=self.team.pk)
                        except Action.DoesNotExist:
                            raise Exception("Action does not exist")
                        # An action with no steps can't be converted to a filter expression.
                        if action.steps.count() == 0:
                            raise Exception("Action does not have any match groups")
                        where_exprs.append(action_to_expr(action))
                if self.query.personId:
                    with self.timings.measure("person_id"):
                        # personId may be a pk or a uuid; filter events by all of that person's distinct_ids.
                        person: Optional[Person] = get_pk_or_uuid(Person.objects.all(), self.query.personId).first()
                        distinct_ids = person.distinct_ids if person is not None else []
                        ids_list = list(map(str, distinct_ids))
                        where_exprs.append(
                            parse_expr(
                                "distinct_id in {list}", {"list": ast.Constant(value=ids_list)}, timings=self.timings
                            )
                        )

                with self.timings.measure("timestamps"):
                    # prevent accidentally future events from being visible by default
                    before = self.query.before or (now() + timedelta(seconds=5)).isoformat()
                    try:
                        parsed_date = isoparse(before)
                    except ValueError:
                        # Not ISO-8601 — fall back to relative parsing ("-7d" etc.).
                        parsed_date = relative_date_parse(before, self.team.timezone_info)
                    where_exprs.append(
                        parse_expr(
                            "timestamp < {timestamp}", {"timestamp": ast.Constant(value=parsed_date)}, timings=self.timings
                        )
                    )

                    # limit to the last 24h by default
                    after = self.query.after or "-24h"
                    if after != "all":
                        try:
                            parsed_date = isoparse(after)
                        except ValueError:
                            parsed_date = relative_date_parse(after, self.team.timezone_info)
                        where_exprs.append(
                            parse_expr(
                                "timestamp > {timestamp}",
                                {"timestamp": ast.Constant(value=parsed_date)},
                                timings=self.timings,
                            )
                        )

            # where & having
            with self.timings.measure("where"):
                # Filters containing aggregations must go into HAVING; the rest into WHERE.
                where_list = [expr for expr in where_exprs if not has_aggregation(expr)]
                where = ast.And(exprs=where_list) if len(where_list) > 0 else None
                having_list = [expr for expr in where_exprs if has_aggregation(expr)]
                having = ast.And(exprs=having_list) if len(having_list) > 0 else None

            # order by
            with self.timings.measure("order"):
                if self.query.orderBy is not None:
                    order_by = [parse_order_expr(column, timings=self.timings) for column in self.query.orderBy]
                # Default ordering heuristics when the query specifies none:
                elif "count()" in select_input:
                    order_by = [ast.OrderExpr(expr=parse_expr("count()"), order="DESC")]
                elif len(aggregations) > 0:
                    order_by = [ast.OrderExpr(expr=aggregations[0], order="DESC")]
                elif "timestamp" in select_input:
                    order_by = [ast.OrderExpr(expr=ast.Field(chain=["timestamp"]), order="DESC")]
                elif len(select) > 0:
                    order_by = [ast.OrderExpr(expr=select[0], order="ASC")]
                else:
                    order_by = []

            with self.timings.measure("select"):
                stmt = ast.SelectQuery(
                    select=select,
                    select_from=ast.JoinExpr(table=ast.Field(chain=["events"])),
                    where=where,
                    having=having,
                    group_by=group_by if has_any_aggregation else None,
                    order_by=order_by,
                    # limit() already includes the +1 "next page" probe row.
                    limit=ast.Constant(value=self.limit()),
                    offset=ast.Constant(value=offset),
                )

            return stmt

    def calculate(self) -> EventsQueryResponse:
        """Execute the query and post-process rows ("*" expansion, person columns, paging)."""
        query_result = execute_hogql_query(
            query=self.to_query(),
            team=self.team,
            workload=Workload.ONLINE,
            query_type="EventsQuery",
            timings=self.timings,
        )

        # Convert star field from tuple to dict in each result
        if "*" in self.select_input_raw():
            with self.timings.measure("expand_asterisk"):
                star_idx = self.select_input_raw().index("*")
                for index, result in enumerate(query_result.results):
                    # Rows arrive as tuples; convert to list so the star column can be replaced in place.
                    query_result.results[index] = list(result)
                    select = result[star_idx]
                    new_result = dict(zip(SELECT_STAR_FROM_EVENTS_FIELDS, select))
                    # "properties" comes back as a JSON string — decode it for the response.
                    new_result["properties"] = json.loads(new_result["properties"])
                    if new_result["elements_chain"]:
                        new_result["elements"] = ElementSerializer(
                            chain_to_elements(new_result["elements_chain"]), many=True
                        ).data
                    query_result.results[index][star_idx] = new_result

        # Re-derive the positions of "person" columns (to_query() selected distinct_id there).
        person_indices: List[int] = []
        for index, col in enumerate(self.select_input_raw()):
            if col.split("--")[0].strip() == "person":
                person_indices.append(index)

        if len(person_indices) > 0 and len(query_result.results) > 0:
            with self.timings.measure("person_column_extra_query"):
                # Make a query into postgres to fetch person
                person_idx = person_indices[0]
                distinct_ids = list(set(event[person_idx] for event in query_result.results))
                persons = get_persons_by_distinct_ids(self.team.pk, distinct_ids)
                # Prefetch distinct ids so `person.distinct_ids` below doesn't trigger N extra queries.
                persons = persons.prefetch_related(Prefetch("persondistinctid_set", to_attr="distinct_ids_cache"))
                distinct_to_person: Dict[str, Person] = {}
                for person in persons:
                    if person:
                        for person_distinct_id in person.distinct_ids:
                            distinct_to_person[person_distinct_id] = person

                # Loop over all columns in case there is more than one "person" column
                for column_index in person_indices:
                    for index, result in enumerate(query_result.results):
                        distinct_id: str = result[column_index]
                        query_result.results[index] = list(result)
                        if distinct_to_person.get(distinct_id):
                            person = distinct_to_person[distinct_id]
                            query_result.results[index][column_index] = {
                                "uuid": person.uuid,
                                "created_at": person.created_at,
                                "properties": person.properties or {},
                                "distinct_id": distinct_id,
                            }
                        else:
                            # No matching person in Postgres — return just the distinct_id.
                            query_result.results[index][column_index] = {
                                "distinct_id": distinct_id,
                            }

        received_extra_row = len(query_result.results) == self.limit()  # limit was +=1'd above
        return EventsQueryResponse(
            # Drop the probe row so callers get at most the requested page size.
            results=query_result.results[: self.limit() - 1] if received_extra_row else query_result.results,
            columns=self.select_input_raw(),
            types=[type for _, type in query_result.types],
            hasMore=received_extra_row,
            timings=self.timings.to_list(),
        )

    def select_input_raw(self) -> List[str]:
        """Return the requested columns, defaulting to ["*"] when none were given."""
        return ["*"] if len(self.query.select) == 0 else self.query.select

    def limit(self) -> int:
        """Effective row limit: query limit, else default_limit, else DEFAULT — capped, plus one probe row."""
        # importing locally so we could override in a test
        from posthog.hogql.constants import DEFAULT_RETURNED_ROWS, MAX_SELECT_RETURNED_ROWS

        # adding +1 to the limit to check if there's a "next page" after the requested results
        return (
            min(
                MAX_SELECT_RETURNED_ROWS,
                self.default_limit or DEFAULT_RETURNED_ROWS if self.query.limit is None else self.query.limit,
            )
            + 1
        )

    def _is_stale(self, cached_result_package) -> bool:
        # Events queries are never served from cache — always considered stale.
        return True

    def _refresh_frequency(self) -> timedelta:
        # Nominal refresh interval reported to the caching layer; results are always recomputed anyway.
        return timedelta(minutes=1)
17 changes: 14 additions & 3 deletions posthog/hogql_queries/query_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
WebTopPagesQuery,
WebOverviewStatsQuery,
PersonsQuery,
EventsQuery,
)
from posthog.utils import generate_cache_key, get_safe_cache

Expand Down Expand Up @@ -64,8 +65,9 @@ class CachedQueryResponse(QueryResponse):

RunnableQueryNode = Union[
TrendsQuery,
PersonsQuery,
LifecycleQuery,
EventsQuery,
PersonsQuery,
WebOverviewStatsQuery,
WebTopSourcesQuery,
WebTopClicksQuery,
Expand All @@ -74,7 +76,10 @@ class CachedQueryResponse(QueryResponse):


def get_query_runner(
query: Dict[str, Any] | RunnableQueryNode, team: Team, timings: Optional[HogQLTimings] = None
query: Dict[str, Any] | RunnableQueryNode,
team: Team,
timings: Optional[HogQLTimings] = None,
default_limit: Optional[int] = None,
) -> "QueryRunner":
kind = None
if isinstance(query, dict):
Expand All @@ -90,6 +95,12 @@ def get_query_runner(
from .insights.trends_query_runner import TrendsQueryRunner

return TrendsQueryRunner(query=cast(TrendsQuery | Dict[str, Any], query), team=team, timings=timings)
if kind == "EventsQuery":
from .events_query_runner import EventsQueryRunner

return EventsQueryRunner(
query=cast(EventsQuery | Dict[str, Any], query), team=team, timings=timings, default_limit=default_limit
)
if kind == "PersonsQuery":
from .persons_query_runner import PersonsQueryRunner

Expand Down Expand Up @@ -134,7 +145,7 @@ def calculate(self) -> BaseModel:
# Due to the way schema.py is generated, we don't have a good inheritance story here.
raise NotImplementedError()

def run(self, refresh_requested: bool) -> CachedQueryResponse:
def run(self, refresh_requested: Optional[bool] = None) -> CachedQueryResponse:
cache_key = self._cache_key()
tag_queries(cache_key=cache_key)

Expand Down
Loading

0 comments on commit 0df3b2f

Please sign in to comment.