From 92e58feeaabbcfc48a49a4c7a754745bcda734a8 Mon Sep 17 00:00:00 2001
From: Sandy Spicer
Date: Mon, 1 Jul 2024 10:08:51 -0700
Subject: [PATCH] fix: prevent runtime error - make timings work with threads
 (#23341)

---
 posthog/hogql/timings.py                      | 32 ++++++----
 .../trends/test/test_trends_query_runner.py   | 14 +++++
 .../insights/trends/trends_query_runner.py    | 62 +++++++++++--------
 3 files changed, 70 insertions(+), 38 deletions(-)

diff --git a/posthog/hogql/timings.py b/posthog/hogql/timings.py
index 950d0f5bf23ae..7d54e4c9ab3c6 100644
--- a/posthog/hogql/timings.py
+++ b/posthog/hogql/timings.py
@@ -1,4 +1,3 @@
-from dataclasses import dataclass, field
 from time import perf_counter
 from contextlib import contextmanager
 
@@ -7,17 +6,26 @@
 from posthog.schema import QueryTiming
 
 
-@dataclass
+# Not thread safe.
+# See trends_query_runner for an example of how to use for multithreaded queries
 class HogQLTimings:
-    # Completed time in seconds for different parts of the HogQL query
-    timings: dict[str, float] = field(default_factory=dict)
+    timings: dict[str, float]
+    _timing_starts: dict[str, float]
+    _timing_pointer: str
 
-    # Used for housekeeping
-    _timing_starts: dict[str, float] = field(default_factory=dict)
-    _timing_pointer: str = "."
+    def __init__(self, _timing_pointer: str = "."):
+        # Completed time in seconds for different parts of the HogQL query
+        self.timings = {}
 
-    def __post_init__(self):
-        self._timing_starts["."] = perf_counter()
+        # Used for housekeeping
+        self._timing_pointer = _timing_pointer
+        self._timing_starts = {self._timing_pointer: perf_counter()}
+
+    def clone_for_subquery(self, series_index: int):
+        return HogQLTimings(f"{self._timing_pointer}/series_{series_index}")
+
+    def clear_timings(self):
+        self.timings = {}
 
     @contextmanager
     def measure(self, key: str):
@@ -42,5 +50,7 @@ def to_dict(self) -> dict[str, float]:
             timings[key] = timings.get(key, 0.0) + (perf_counter() - start)
         return timings
 
-    def to_list(self) -> list[QueryTiming]:
-        return [QueryTiming(k=key, t=time) for key, time in self.to_dict().items()]
+    def to_list(self, back_out_stack=True) -> list[QueryTiming]:
+        return [
+            QueryTiming(k=key, t=time) for key, time in (self.to_dict() if back_out_stack else self.timings).items()
+        ]
diff --git a/posthog/hogql_queries/insights/trends/test/test_trends_query_runner.py b/posthog/hogql_queries/insights/trends/test/test_trends_query_runner.py
index 947cc6de5d515..3064fce7422f8 100644
--- a/posthog/hogql_queries/insights/trends/test/test_trends_query_runner.py
+++ b/posthog/hogql_queries/insights/trends/test/test_trends_query_runner.py
@@ -1,6 +1,8 @@
+import re
 import zoneinfo
 from dataclasses import dataclass
 from datetime import datetime
+from itertools import groupby
 from typing import Optional
 from unittest.mock import MagicMock, patch
 from django.test import override_settings
@@ -361,6 +363,18 @@ def test_trends_multiple_series(self):
         self.assertEqual([1, 0, 1, 3, 1, 0, 2, 0, 1, 0, 1], response.results[0]["data"])
         self.assertEqual([0, 0, 1, 1, 3, 0, 0, 1, 0, 0, 0], response.results[1]["data"])
 
+        # Check the timings
+        response_groups = [
+            k
+            for k, _ in groupby(
+                response.timings, key=lambda query_timing: "".join(re.findall(r"series_\d+", query_timing.k))
+            )
+        ]
+        assert response_groups[0] == ""
+        assert response_groups[1] == "series_0"
+        assert response_groups[2] == "series_1"
+        assert response_groups[3] == ""
+
     def test_formula(self):
         self._create_test_events()
 
diff --git a/posthog/hogql_queries/insights/trends/trends_query_runner.py b/posthog/hogql_queries/insights/trends/trends_query_runner.py
index 77e53e414a456..7489e9ce9d09b 100644
--- a/posthog/hogql_queries/insights/trends/trends_query_runner.py
+++ b/posthog/hogql_queries/insights/trends/trends_query_runner.py
@@ -297,13 +297,14 @@ def calculate(self):
             response_hogql = to_printed_hogql(response_hogql_query, self.team, self.modifiers)
 
         res_matrix: list[list[Any] | Any | None] = [None] * len(queries)
-        timings_matrix: list[list[QueryTiming] | None] = [None] * len(queries)
+        timings_matrix: list[list[QueryTiming] | None] = [None] * (2 + len(queries))
         errors: list[Exception] = []
         debug_errors: list[str] = []
 
         def run(
             index: int,
             query: ast.SelectQuery | ast.SelectUnionQuery,
+            timings: HogQLTimings,
             is_parallel: bool,
             query_tags: Optional[dict] = None,
         ):
@@ -317,12 +318,12 @@ def run(
                     query_type="TrendsQuery",
                     query=query,
                     team=self.team,
-                    timings=self.timings,
+                    timings=timings,
                     modifiers=self.modifiers,
                     limit_context=self.limit_context,
                 )
 
-                timings_matrix[index] = response.timings
+                timings_matrix[index + 1] = response.timings
                 res_matrix[index] = self.build_series_response(response, series_with_extra, len(queries))
                 if response.error:
                     debug_errors.append(response.error)
@@ -335,26 +336,31 @@ def run(
                     # This will only close the DB connection for the newly spawned thread and not the whole app
                     connection.close()
 
-        # This exists so that we're not spawning threads during unit tests. We can't do
-        # this right now due to the lack of multithreaded support of Django
-        if settings.IN_UNIT_TESTING:
-            for index, query in enumerate(queries):
-                run(index, query, False)
-        elif len(queries) == 1:
-            run(0, queries[0], False)
-        else:
-            jobs = [
-                threading.Thread(target=run, args=(index, query, True, query_tagging.get_query_tags()))
-                for index, query in enumerate(queries)
-            ]
+        with self.timings.measure("execute_queries"):
+            timings_matrix[0] = self.timings.to_list(back_out_stack=False)
+            self.timings.clear_timings()
 
-            # Start the threads
-            for j in jobs:
-                j.start()
-
-            # Ensure all of the threads have finished
-            for j in jobs:
-                j.join()
+            # This exists so that we're not spawning threads during unit tests. We can't do
+            # this right now due to the lack of multithreaded support of Django
+            if len(queries) == 1 or settings.IN_UNIT_TESTING:
+                for index, query in enumerate(queries):
+                    run(index, query, self.timings.clone_for_subquery(index), False)
+            else:
+                jobs = [
+                    threading.Thread(
+                        target=run,
+                        args=(
+                            index,
+                            query,
+                            self.timings.clone_for_subquery(index),
+                            True,
+                            query_tagging.get_query_tags(),
+                        ),
+                    )
+                    for index, query in enumerate(queries)
+                ]
+                [j.start() for j in jobs]  # type:ignore
+                [j.join() for j in jobs]  # type:ignore
 
         # Raise any errors raised in a seperate thread
         if len(errors) > 0:
@@ -368,11 +374,6 @@ def run(
             elif isinstance(result, dict):
                 returned_results.append([result])
 
-        timings: list[QueryTiming] = []
-        for timing in timings_matrix:
-            if isinstance(timing, list):
-                timings.extend(timing)
-
         if (
            self.query.trendsFilter is not None
            and self.query.trendsFilter.formula is not None
@@ -397,6 +398,13 @@ def run(
             elif isinstance(result, dict):
                 raise ValueError("This should not happen")
 
+        timings_matrix[-1] = self.timings.to_list()
+
+        timings: list[QueryTiming] = []
+        for timing in timings_matrix:
+            if isinstance(timing, list):
+                timings.extend(timing)
+
         return TrendsQueryResponse(
             results=final_result,
             timings=timings,
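Usage sketch (not part of the patch): for readers following the "See trends_query_runner" comment added in timings.py, the snippet below roughly mirrors the threading pattern the diff sets up. It is a minimal standalone approximation, not the runner itself: fake_series_work, the sleep() calls, and the "prepare_queries" label stand in for real query work, and the per-series timings are read straight off each clone with to_list() instead of coming back as response.timings from execute_hogql_query.

import threading
from time import sleep

from posthog.hogql.timings import HogQLTimings
from posthog.schema import QueryTiming


def fake_series_work(timings: HogQLTimings) -> None:
    # Each thread measures only against its own clone, so no state is shared across threads.
    with timings.measure("execute_hogql_query"):
        sleep(0.01)


parent = HogQLTimings()
with parent.measure("prepare_queries"):  # hypothetical stand-in for pre-query work on the parent
    sleep(0.01)

series_count = 2
# Slot 0: parent timings before the threads; slots 1..n: one per series; last slot: parent afterwards.
timings_matrix: list[list[QueryTiming] | None] = [None] * (2 + series_count)

with parent.measure("execute_queries"):
    # Snapshot what the parent has measured so far, then reset it so the final
    # snapshot only contains work done after the threads have finished.
    timings_matrix[0] = parent.to_list(back_out_stack=False)
    parent.clear_timings()

    clones = [parent.clone_for_subquery(index) for index in range(series_count)]
    jobs = [threading.Thread(target=fake_series_work, args=(clone,)) for clone in clones]
    [j.start() for j in jobs]
    [j.join() for j in jobs]

    # Clone keys are prefixed "./series_<index>/...", which is what the new test asserts on.
    for index, clone in enumerate(clones):
        timings_matrix[index + 1] = clone.to_list()

timings_matrix[-1] = parent.to_list()

# Flatten into the single list that would go on the query response.
timings: list[QueryTiming] = []
for timing in timings_matrix:
    if isinstance(timing, list):
        timings.extend(timing)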