Skip to content

Commit

Permalink
fix: prevent runtime error - make timings work with threads (#23341)
Browse files Browse the repository at this point in the history
  • Loading branch information
aspicer authored Jul 1, 2024
1 parent 51ccb84 commit 92e58fe
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 38 deletions.
32 changes: 21 additions & 11 deletions posthog/hogql/timings.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from dataclasses import dataclass, field
from time import perf_counter
from contextlib import contextmanager

Expand All @@ -7,17 +6,26 @@
from posthog.schema import QueryTiming


@dataclass
# Not thread safe.
# See trends_query_runner for an example of how to use it for multithreaded queries
class HogQLTimings:
# Completed time in seconds for different parts of the HogQL query
timings: dict[str, float] = field(default_factory=dict)
timings: dict[str, float]
_timing_starts: dict[str, float]
_timing_pointer: str

# Used for housekeeping
_timing_starts: dict[str, float] = field(default_factory=dict)
_timing_pointer: str = "."
def __init__(self, _timing_pointer: str = "."):
# Completed time in seconds for different parts of the HogQL query
self.timings = {}

def __post_init__(self):
self._timing_starts["."] = perf_counter()
# Used for housekeeping
self._timing_pointer = _timing_pointer
self._timing_starts = {self._timing_pointer: perf_counter()}

def clone_for_subquery(self, series_index: int):
return HogQLTimings(f"{self._timing_pointer}/series_{series_index}")

def clear_timings(self):
self.timings = {}

@contextmanager
def measure(self, key: str):
Expand All @@ -42,5 +50,7 @@ def to_dict(self) -> dict[str, float]:
timings[key] = timings.get(key, 0.0) + (perf_counter() - start)
return timings

def to_list(self) -> list[QueryTiming]:
return [QueryTiming(k=key, t=time) for key, time in self.to_dict().items()]
def to_list(self, back_out_stack=True) -> list[QueryTiming]:
return [
QueryTiming(k=key, t=time) for key, time in (self.to_dict() if back_out_stack else self.timings).items()
]
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import re
import zoneinfo
from dataclasses import dataclass
from datetime import datetime
from itertools import groupby
from typing import Optional
from unittest.mock import MagicMock, patch
from django.test import override_settings
Expand Down Expand Up @@ -361,6 +363,18 @@ def test_trends_multiple_series(self):
self.assertEqual([1, 0, 1, 3, 1, 0, 2, 0, 1, 0, 1], response.results[0]["data"])
self.assertEqual([0, 0, 1, 1, 3, 0, 0, 1, 0, 0, 0], response.results[1]["data"])

# Check the timings
response_groups = [
k
for k, _ in groupby(
response.timings, key=lambda query_timing: "".join(re.findall(r"series_\d+", query_timing.k))
)
]
assert response_groups[0] == ""
assert response_groups[1] == "series_0"
assert response_groups[2] == "series_1"
assert response_groups[3] == ""

def test_formula(self):
self._create_test_events()

Expand Down
62 changes: 35 additions & 27 deletions posthog/hogql_queries/insights/trends/trends_query_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,13 +297,14 @@ def calculate(self):
response_hogql = to_printed_hogql(response_hogql_query, self.team, self.modifiers)

res_matrix: list[list[Any] | Any | None] = [None] * len(queries)
timings_matrix: list[list[QueryTiming] | None] = [None] * len(queries)
timings_matrix: list[list[QueryTiming] | None] = [None] * (2 + len(queries))
errors: list[Exception] = []
debug_errors: list[str] = []

def run(
index: int,
query: ast.SelectQuery | ast.SelectUnionQuery,
timings: HogQLTimings,
is_parallel: bool,
query_tags: Optional[dict] = None,
):
Expand All @@ -317,12 +318,12 @@ def run(
query_type="TrendsQuery",
query=query,
team=self.team,
timings=self.timings,
timings=timings,
modifiers=self.modifiers,
limit_context=self.limit_context,
)

timings_matrix[index] = response.timings
timings_matrix[index + 1] = response.timings
res_matrix[index] = self.build_series_response(response, series_with_extra, len(queries))
if response.error:
debug_errors.append(response.error)
Expand All @@ -335,26 +336,31 @@ def run(
# This will only close the DB connection for the newly spawned thread and not the whole app
connection.close()

# This exists so that we're not spawning threads during unit tests. We can't do
# this right now due to the lack of multithreaded support of Django
if settings.IN_UNIT_TESTING:
for index, query in enumerate(queries):
run(index, query, False)
elif len(queries) == 1:
run(0, queries[0], False)
else:
jobs = [
threading.Thread(target=run, args=(index, query, True, query_tagging.get_query_tags()))
for index, query in enumerate(queries)
]
with self.timings.measure("execute_queries"):
timings_matrix[0] = self.timings.to_list(back_out_stack=False)
self.timings.clear_timings()

# Start the threads
for j in jobs:
j.start()

# Ensure all of the threads have finished
for j in jobs:
j.join()
# This exists so that we're not spawning threads during unit tests. We can't do
# this right now due to the lack of multithreaded support of Django
if len(queries) == 1 or settings.IN_UNIT_TESTING:
for index, query in enumerate(queries):
run(index, query, self.timings.clone_for_subquery(index), False)
else:
jobs = [
threading.Thread(
target=run,
args=(
index,
query,
self.timings.clone_for_subquery(index),
True,
query_tagging.get_query_tags(),
),
)
for index, query in enumerate(queries)
]
[j.start() for j in jobs] # type:ignore
[j.join() for j in jobs] # type:ignore

# Raise any errors raised in a separate thread
if len(errors) > 0:
Expand All @@ -368,11 +374,6 @@ def run(
elif isinstance(result, dict):
returned_results.append([result])

timings: list[QueryTiming] = []
for timing in timings_matrix:
if isinstance(timing, list):
timings.extend(timing)

if (
self.query.trendsFilter is not None
and self.query.trendsFilter.formula is not None
Expand All @@ -397,6 +398,13 @@ def run(
elif isinstance(result, dict):
raise ValueError("This should not happen")

timings_matrix[-1] = self.timings.to_list()

timings: list[QueryTiming] = []
for timing in timings_matrix:
if isinstance(timing, list):
timings.extend(timing)

return TrendsQueryResponse(
results=final_result,
timings=timings,
Expand Down

0 comments on commit 92e58fe

Please sign in to comment.