Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(insights): Attribute async queries to users #21019

Merged
merged 10 commits into the base branch from the contributor's branch
Mar 25, 2024
3 changes: 2 additions & 1 deletion posthog/api/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,10 @@ def create(self, request, *args, **kwargs) -> Response:
if data.async_:
query_status = enqueue_process_query_task(
team_id=self.team.pk,
user_id=self.request.user.pk,
query_json=request.data["query"],
query_id=client_query_id,
refresh_requested=data.refresh,
refresh_requested=data.refresh or False,
)
return Response(query_status.model_dump(), status=status.HTTP_202_ACCEPTED)

Expand Down
44 changes: 28 additions & 16 deletions posthog/clickhouse/client/execute_async.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import datetime
import json
from typing import Optional
import uuid

import structlog
Expand Down Expand Up @@ -69,11 +70,12 @@ def delete_query_status(self):


def execute_process_query(
team_id,
query_id,
query_json,
limit_context,
refresh_requested,
team_id: int,
user_id: int,
query_id: str,
query_json: dict,
limit_context: Optional[LimitContext],
refresh_requested: bool,
):
manager = QueryStatusManager(query_id, team_id)

Expand All @@ -91,7 +93,7 @@ def execute_process_query(
QUERY_WAIT_TIME.observe(wait_duration)

try:
tag_queries(client_query_id=query_id, team_id=team_id)
tag_queries(client_query_id=query_id, team_id=team_id, user_id=user_id)
results = process_query(
team=team, query_json=query_json, limit_context=limit_context, refresh_requested=refresh_requested
)
Expand All @@ -113,12 +115,13 @@ def execute_process_query(


def enqueue_process_query_task(
team_id,
query_json,
query_id=None,
refresh_requested=False,
bypass_celery=False,
force=False,
team_id: int,
user_id: int,
query_json: dict,
query_id: Optional[str] = None,
refresh_requested: bool = False,
force: bool = False,
_test_only_bypass_celery: bool = False,
) -> QueryStatus:
if not query_id:
query_id = uuid.uuid4().hex
Expand All @@ -136,14 +139,23 @@ def enqueue_process_query_task(
query_status = QueryStatus(id=query_id, team_id=team_id, start_time=datetime.datetime.now(datetime.timezone.utc))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Alternative idea: We put the user ID in the query status here, and thereby have it in the task, without the need to pass it to the task as args. As further advantages, it's then also persisted and returned via API.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, that could work too. Though I don't want to spend time reworking this right now to this degree, if the current solution works. 😄 It doesn't seem super useful right now to return the user ID via the API. I guess if that changes, it might be fine to adjust this

manager.store_query_status(query_status)

if bypass_celery:
# Call directly ( for testing )
if _test_only_bypass_celery:
process_query_task(
team_id, query_id, query_json, limit_context=LimitContext.QUERY_ASYNC, refresh_requested=refresh_requested
team_id,
user_id,
query_id,
query_json,
limit_context=LimitContext.QUERY_ASYNC,
refresh_requested=refresh_requested,
)
else:
task = process_query_task.delay(
team_id, query_id, query_json, limit_context=LimitContext.QUERY_ASYNC, refresh_requested=refresh_requested
team_id,
user_id,
query_id,
query_json,
limit_context=LimitContext.QUERY_ASYNC,
refresh_requested=refresh_requested,
)
query_status.task_id = task.id
manager.store_query_status(query_status)
Expand Down
62 changes: 43 additions & 19 deletions posthog/clickhouse/client/test/test_execute_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def setUp(self):
self.organization = Organization.objects.create(name="test")
self.team = Team.objects.create(organization=self.organization)
self.team_id = self.team.pk
self.user_id = 1337
self.query_id = "test_query_id"
self.query_json = {}
self.limit_context = None
Expand All @@ -41,7 +42,9 @@ def test_execute_process_query(self, mock_process_query, mock_redis_client):

mock_process_query.return_value = [float("inf"), float("-inf"), float("nan"), 1.0, "👍"]

execute_process_query(self.team_id, self.query_id, self.query_json, self.limit_context, self.refresh_requested)
execute_process_query(
self.team_id, self.user_id, self.query_id, self.query_json, self.limit_context, self.refresh_requested
)

mock_redis_client.assert_called_once()
mock_process_query.assert_called_once()
Expand All @@ -55,15 +58,16 @@ def test_execute_process_query(self, mock_process_query, mock_redis_client):

class ClickhouseClientTestCase(TestCase, ClickhouseTestMixin):
def setUp(self):
self.organization = Organization.objects.create(name="test")
self.team = Team.objects.create(organization=self.organization)
self.team_id = self.team.pk
self.organization: Organization = Organization.objects.create(name="test")
self.team: Team = Team.objects.create(organization=self.organization)
self.team_id: int = self.team.pk
self.user_id: int = 2137

@snapshot_clickhouse_queries
def test_async_query_client(self):
query = build_query("SELECT 1+1")
team_id = self.team_id
query_id = client.enqueue_process_query_task(team_id, query, bypass_celery=True).id
query_id = client.enqueue_process_query_task(team_id, self.user_id, query, _test_only_bypass_celery=True).id
result = client.get_query_status(team_id, query_id)
self.assertFalse(result.error, result.error_message)
self.assertTrue(result.complete)
Expand All @@ -74,11 +78,13 @@ def test_async_query_client_errors(self):
self.assertRaises(
HogQLException,
client.enqueue_process_query_task,
**{"team_id": (self.team_id), "query_json": query, "bypass_celery": True},
**{"team_id": self.team_id, "user_id": self.user_id, "query_json": query, "_test_only_bypass_celery": True},
)
query_id = uuid.uuid4().hex
try:
client.enqueue_process_query_task(self.team_id, query, query_id=query_id, bypass_celery=True)
client.enqueue_process_query_task(
self.team_id, self.user_id, query, query_id=query_id, _test_only_bypass_celery=True
)
except Exception:
pass

Expand All @@ -89,7 +95,7 @@ def test_async_query_client_errors(self):
def test_async_query_client_uuid(self):
query = build_query("SELECT toUUID('00000000-0000-0000-0000-000000000000')")
team_id = self.team_id
query_id = client.enqueue_process_query_task(team_id, query, bypass_celery=True).id
query_id = client.enqueue_process_query_task(team_id, self.user_id, query, _test_only_bypass_celery=True).id
result = client.get_query_status(team_id, query_id)
self.assertFalse(result.error, result.error_message)
self.assertTrue(result.complete)
Expand All @@ -99,7 +105,7 @@ def test_async_query_client_does_not_leak(self):
query = build_query("SELECT 1+1")
team_id = self.team_id
wrong_team = 5
query_id = client.enqueue_process_query_task(team_id, query, bypass_celery=True).id
query_id = client.enqueue_process_query_task(team_id, self.user_id, query, _test_only_bypass_celery=True).id

try:
client.get_query_status(wrong_team, query_id)
Expand All @@ -111,13 +117,19 @@ def test_async_query_client_is_lazy(self, execute_sync_mock):
query = build_query("SELECT 4 + 4")
query_id = uuid.uuid4().hex
team_id = self.team_id
client.enqueue_process_query_task(team_id, query, query_id=query_id, bypass_celery=True)
client.enqueue_process_query_task(
team_id, self.user_id, query, query_id=query_id, _test_only_bypass_celery=True
)

# Try the same query again
client.enqueue_process_query_task(team_id, query, query_id=query_id, bypass_celery=True)
client.enqueue_process_query_task(
team_id, self.user_id, query, query_id=query_id, _test_only_bypass_celery=True
)

# Try the same query again (for good measure!)
client.enqueue_process_query_task(team_id, query, query_id=query_id, bypass_celery=True)
client.enqueue_process_query_task(
team_id, self.user_id, query, query_id=query_id, _test_only_bypass_celery=True
)

# Assert that we only called clickhouse once
execute_sync_mock.assert_called_once()
Expand All @@ -127,13 +139,19 @@ def test_async_query_client_is_lazy_but_not_too_lazy(self, execute_sync_mock):
query = build_query("SELECT 8 + 8")
query_id = uuid.uuid4().hex
team_id = self.team_id
client.enqueue_process_query_task(team_id, query, query_id=query_id, bypass_celery=True)
client.enqueue_process_query_task(
team_id, self.user_id, query, query_id=query_id, _test_only_bypass_celery=True
)

# Try the same query again, but with force
client.enqueue_process_query_task(team_id, query, query_id=query_id, bypass_celery=True, force=True)
client.enqueue_process_query_task(
team_id, self.user_id, query, query_id=query_id, _test_only_bypass_celery=True, force=True
)

# Try the same query again (for good measure!)
client.enqueue_process_query_task(team_id, query, query_id=query_id, bypass_celery=True)
client.enqueue_process_query_task(
team_id, self.user_id, query, query_id=query_id, _test_only_bypass_celery=True
)

# Assert that we called clickhouse twice
self.assertEqual(execute_sync_mock.call_count, 2)
Expand All @@ -145,13 +163,19 @@ def test_async_query_client_manual_query_uuid(self, execute_sync_mock):
query = build_query("SELECT 8 + 8")
team_id = self.team_id
query_id = "I'm so unique"
client.enqueue_process_query_task(team_id, query, query_id=query_id, bypass_celery=True)
client.enqueue_process_query_task(
team_id, self.user_id, query, query_id=query_id, _test_only_bypass_celery=True
)

# Try the same query again, but with force
client.enqueue_process_query_task(team_id, query, query_id=query_id, bypass_celery=True, force=True)
client.enqueue_process_query_task(
team_id, self.user_id, query, query_id=query_id, _test_only_bypass_celery=True, force=True
)

# Try the same query again (for good measure!)
client.enqueue_process_query_task(team_id, query, query_id=query_id, bypass_celery=True)
client.enqueue_process_query_task(
team_id, self.user_id, query, query_id=query_id, _test_only_bypass_celery=True
)

# Assert that we called clickhouse twice
self.assertEqual(execute_sync_mock.call_count, 2)
Expand Down Expand Up @@ -186,4 +210,4 @@ def test_client_strips_comments_from_request(self):

# Make sure it still includes the "annotation" comment that includes
# request routing information for debugging purposes
self.assertIn("/* request:1 */", first_query)
self.assertIn(f"/* user_id:{self.user_id} request:1 */", first_query)
7 changes: 4 additions & 3 deletions posthog/errors.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
from dataclasses import dataclass
import re
from typing import Dict
from typing import Dict, Optional

from clickhouse_driver.errors import ServerException

from posthog.exceptions import EstimatedQueryExecutionTimeTooLong, QuerySizeExceeded


class InternalCHQueryError(ServerException):
code_name: str
code_name: Optional[str]
"""Can be null if re-raised from a thread (see `failhard_threadhook_context`)."""

def __init__(self, message, *, code=None, nested=None, code_name):
def __init__(self, message, *, code=None, nested=None, code_name=None):
self.code_name = code_name
super().__init__(message, code, nested)

Expand Down
11 changes: 9 additions & 2 deletions posthog/tasks/tasks.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import time
from typing import Any, Optional
from typing import Optional
from uuid import UUID

from celery import shared_task
Expand All @@ -9,6 +9,7 @@
from prometheus_client import Gauge

from posthog.cloud_utils import is_cloud
from posthog.hogql.constants import LimitContext
from posthog.metrics import pushed_metrics_registry
from posthog.ph_client import get_ph_client
from posthog.redis import get_client
Expand All @@ -33,7 +34,12 @@ def redis_heartbeat() -> None:

@shared_task(ignore_result=True, queue=CeleryQueue.ANALYTICS_QUERIES.value)
def process_query_task(
team_id: str, query_id: str, query_json: Any, limit_context: Any = None, refresh_requested: bool = False
team_id: int,
user_id: int,
query_id: str,
query_json: dict,
limit_context: Optional[LimitContext] = None,
refresh_requested: bool = False,
) -> None:
"""
Kick off query
Expand All @@ -43,6 +49,7 @@ def process_query_task(

execute_process_query(
team_id=team_id,
user_id=user_id,
query_id=query_id,
query_json=query_json,
limit_context=limit_context,
Expand Down
Loading