diff --git a/ee/clickhouse/test/test_client.py b/ee/clickhouse/test/test_client.py deleted file mode 100644 index ab5ba1b4a53e0..0000000000000 --- a/ee/clickhouse/test/test_client.py +++ /dev/null @@ -1,129 +0,0 @@ -from unittest.mock import patch - -import fakeredis -from clickhouse_driver.errors import ServerException -from django.test import TestCase - -from posthog.clickhouse.client import execute_async as client -from posthog.client import sync_execute -from posthog.test.base import ClickhouseTestMixin - - -class ClickhouseClientTestCase(TestCase, ClickhouseTestMixin): - def setUp(self): - self.redis_client = fakeredis.FakeStrictRedis() - - def test_async_query_client(self): - query = "SELECT 1+1" - team_id = 2 - query_id = client.enqueue_execute_with_progress(team_id, query, bypass_celery=True) - result = client.get_status_or_results(team_id, query_id) - self.assertFalse(result.error) - self.assertTrue(result.complete) - self.assertEqual(result.results, [[2]]) - - def test_async_query_client_errors(self): - query = "SELECT WOW SUCH DATA FROM NOWHERE THIS WILL CERTAINLY WORK" - team_id = 2 - self.assertRaises( - ServerException, - client.enqueue_execute_with_progress, - **{"team_id": team_id, "query": query, "bypass_celery": True}, - ) - try: - query_id = client.enqueue_execute_with_progress(team_id, query, bypass_celery=True) - except Exception: - pass - - result = client.get_status_or_results(team_id, query_id) - self.assertTrue(result.error) - self.assertRegex(result.error_message, "Code: 62.\nDB::Exception: Syntax error:") - - def test_async_query_client_does_not_leak(self): - query = "SELECT 1+1" - team_id = 2 - wrong_team = 5 - query_id = client.enqueue_execute_with_progress(team_id, query, bypass_celery=True) - result = client.get_status_or_results(wrong_team, query_id) - self.assertTrue(result.error) - self.assertEqual(result.error_message, "Requesting team is not executing team") - - @patch("posthog.clickhouse.client.execute_async.enqueue_clickhouse_execute_with_progress") - def test_async_query_client_is_lazy(self, execute_sync_mock): - query = "SELECT 4 + 4" - team_id = 2 - client.enqueue_execute_with_progress(team_id, query, bypass_celery=True) - - # Try the same query again - client.enqueue_execute_with_progress(team_id, query, bypass_celery=True) - - # Try the same query again (for good measure!) - client.enqueue_execute_with_progress(team_id, query, bypass_celery=True) - - # Assert that we only called clickhouse once - execute_sync_mock.assert_called_once() - - @patch("posthog.clickhouse.client.execute_async.enqueue_clickhouse_execute_with_progress") - def test_async_query_client_is_lazy_but_not_too_lazy(self, execute_sync_mock): - query = "SELECT 8 + 8" - team_id = 2 - client.enqueue_execute_with_progress(team_id, query, bypass_celery=True) - - # Try the same query again, but with force - client.enqueue_execute_with_progress(team_id, query, bypass_celery=True, force=True) - - # Try the same query again (for good measure!) - client.enqueue_execute_with_progress(team_id, query, bypass_celery=True) - - # Assert that we called clickhouse twice - self.assertEqual(execute_sync_mock.call_count, 2) - - @patch("posthog.clickhouse.client.execute_async.enqueue_clickhouse_execute_with_progress") - def test_async_query_client_manual_query_uuid(self, execute_sync_mock): - # This is a unique test because technically in the test pattern `SELECT 8 + 8` is already - # in redis. 
This tests to make sure it is treated as a unique run of that query - query = "SELECT 8 + 8" - team_id = 2 - query_id = "I'm so unique" - client.enqueue_execute_with_progress(team_id, query, query_id=query_id, bypass_celery=True) - - # Try the same query again, but with force - client.enqueue_execute_with_progress(team_id, query, query_id=query_id, bypass_celery=True, force=True) - - # Try the same query again (for good measure!) - client.enqueue_execute_with_progress(team_id, query, query_id=query_id, bypass_celery=True) - - # Assert that we called clickhouse twice - self.assertEqual(execute_sync_mock.call_count, 2) - - def test_client_strips_comments_from_request(self): - """ - To ensure we can easily copy queries from `system.query_log` in e.g. - Metabase, we strip comments from the query we send. Metabase doesn't - display multilined output. - - See https://github.com/metabase/metabase/issues/14253 - - Note I'm not really testing much complexity, I trust that those will - come out as failures in other tests. - """ - from posthog.clickhouse.query_tagging import tag_queries - - # First add in the request information that should be added to the sql. - # We check this to make sure it is not removed by the comment stripping - with self.capture_select_queries() as sqls: - tag_queries(kind="request", id="1") - sync_execute( - query=""" - -- this request returns 1 - SELECT 1 - """ - ) - self.assertEqual(len(sqls), 1) - first_query = sqls[0] - self.assertIn(f"SELECT 1", first_query) - self.assertNotIn("this request returns", first_query) - - # Make sure it still includes the "annotation" comment that includes - # request routing information for debugging purposes - self.assertIn("/* request:1 */", first_query) diff --git a/ee/clickhouse/views/test/__snapshots__/test_clickhouse_experiment_secondary_results.ambr b/ee/clickhouse/views/test/__snapshots__/test_clickhouse_experiment_secondary_results.ambr index f039a2994204e..262c4a8e1e195 100644 --- a/ee/clickhouse/views/test/__snapshots__/test_clickhouse_experiment_secondary_results.ambr +++ b/ee/clickhouse/views/test/__snapshots__/test_clickhouse_experiment_secondary_results.ambr @@ -1,6 +1,6 @@ # name: ClickhouseTestExperimentSecondaryResults.test_basic_secondary_metric_results ' - /* user_id:126 celery:posthog.celery.sync_insight_caching_state */ + /* user_id:132 celery:posthog.celery.sync_insight_caching_state */ SELECT team_id, date_diff('second', max(timestamp), now()) AS age FROM events diff --git a/frontend/__snapshots__/scenes-app-insights--funnel-left-to-right-breakdown.png b/frontend/__snapshots__/scenes-app-insights--funnel-left-to-right-breakdown.png index f90c93d7e84d0..b899e5e48f727 100644 Binary files a/frontend/__snapshots__/scenes-app-insights--funnel-left-to-right-breakdown.png and b/frontend/__snapshots__/scenes-app-insights--funnel-left-to-right-breakdown.png differ diff --git a/frontend/__snapshots__/scenes-app-pipeline--pipeline-transformations-page.png b/frontend/__snapshots__/scenes-app-pipeline--pipeline-transformations-page.png index 80fdf0c1d8632..11799f73e5bb9 100644 Binary files a/frontend/__snapshots__/scenes-app-pipeline--pipeline-transformations-page.png and b/frontend/__snapshots__/scenes-app-pipeline--pipeline-transformations-page.png differ diff --git a/frontend/src/lib/api.ts b/frontend/src/lib/api.ts index d1de3a313acb2..ad84040f92c31 100644 --- a/frontend/src/lib/api.ts +++ b/frontend/src/lib/api.ts @@ -65,7 +65,7 @@ import { EVENT_PROPERTY_DEFINITIONS_PER_PAGE } from 'scenes/data-management/prop import { 
ActivityLogItem, ActivityScope } from 'lib/components/ActivityLog/humanizeActivity' import { ActivityLogProps } from 'lib/components/ActivityLog/ActivityLog' import { SavedSessionRecordingPlaylistsResult } from 'scenes/session-recordings/saved-playlists/savedSessionRecordingPlaylistsLogic' -import { QuerySchema } from '~/queries/schema' +import { QuerySchema, QueryStatus } from '~/queries/schema' import { decompressSync, strFromU8 } from 'fflate' import { getCurrentExporterData } from '~/exporter/exporterViewLogic' import { encodeParams } from 'kea-router' @@ -542,6 +542,10 @@ class ApiRequest { return this.projectsDetail(teamId).addPathComponent('query') } + public queryStatus(queryId: string, teamId?: TeamType['id']): ApiRequest { + return this.query(teamId).addPathComponent(queryId) + } + // Notebooks public notebooks(teamId?: TeamType['id']): ApiRequest { return this.projectsDetail(teamId).addPathComponent('notebooks') @@ -1722,6 +1726,12 @@ const api = { }, }, + queryStatus: { + async get(queryId: string): Promise { + return await new ApiRequest().queryStatus(queryId).get() + }, + }, + queryURL: (): string => { return new ApiRequest().query().assembleFullUrl(true) }, @@ -1730,7 +1740,8 @@ const api = { query: T, options?: ApiMethodOptions, queryId?: string, - refresh?: boolean + refresh?: boolean, + async?: boolean ): Promise< T extends { [response: string]: any } ? T['response'] extends infer P | undefined @@ -1740,7 +1751,7 @@ const api = { > { return await new ApiRequest() .query() - .create({ ...options, data: { query, client_query_id: queryId, refresh: refresh } }) + .create({ ...options, data: { query, client_query_id: queryId, refresh: refresh, async } }) }, /** Fetch data from specified URL. The result already is JSON-parsed. */ diff --git a/frontend/src/lib/constants.tsx b/frontend/src/lib/constants.tsx index 63ef80121bccc..f8b3e96456bbc 100644 --- a/frontend/src/lib/constants.tsx +++ b/frontend/src/lib/constants.tsx @@ -136,6 +136,7 @@ export const FEATURE_FLAGS = { ROLE_BASED_ACCESS: 'role-based-access', // owner: #team-experiments, @liyiy QUERY_RUNNING_TIME: 'query_running_time', // owner: @mariusandra QUERY_TIMINGS: 'query-timings', // owner: @mariusandra + QUERY_ASYNC: 'query-async', // owner: @webjunkie POSTHOG_3000: 'posthog-3000', // owner: @Twixes POSTHOG_3000_NAV: 'posthog-3000-nav', // owner: @Twixes ENABLE_PROMPTS: 'enable-prompts', // owner: @lharries diff --git a/frontend/src/queries/nodes/DataNode/dataNodeLogic.ts b/frontend/src/queries/nodes/DataNode/dataNodeLogic.ts index a83d0398a49d4..94c2651343b31 100644 --- a/frontend/src/queries/nodes/DataNode/dataNodeLogic.ts +++ b/frontend/src/queries/nodes/DataNode/dataNodeLogic.ts @@ -477,7 +477,7 @@ export const dataNodeLogic = kea([ abortQuery: async ({ queryId }) => { try { const { currentTeamId } = values - await api.create(`api/projects/${currentTeamId}/insights/cancel`, { client_query_id: queryId }) + await api.delete(`api/projects/${currentTeamId}/query/${queryId}/`) } catch (e) { console.warn('Failed cancelling query', e) } diff --git a/frontend/src/queries/query.ts b/frontend/src/queries/query.ts index df71c0b1cef4b..01c218a290e6c 100644 --- a/frontend/src/queries/query.ts +++ b/frontend/src/queries/query.ts @@ -26,13 +26,16 @@ import { isStickinessFilter, isTrendsFilter, } from 'scenes/insights/sharedUtils' -import { flattenObject, toParams } from 'lib/utils' +import { flattenObject, delay, toParams } from 'lib/utils' import { queryNodeToFilter } from './nodes/InsightQuery/utils/queryNodeToFilter' import { 
now } from 'lib/dayjs' import { currentSessionId } from 'lib/internalMetrics' import { featureFlagLogic } from 'lib/logic/featureFlagLogic' import { FEATURE_FLAGS } from 'lib/constants' +const QUERY_ASYNC_MAX_INTERVAL_SECONDS = 10 +const QUERY_ASYNC_TOTAL_POLL_SECONDS = 300 + //get export context for a given query export function queryExportContext( query: N, @@ -91,6 +94,43 @@ export function queryExportContext( throw new Error(`Unsupported query: ${query.kind}`) } +async function executeQuery( + queryNode: N, + methodOptions?: ApiMethodOptions, + refresh?: boolean, + queryId?: string +): Promise> { + const queryAsyncEnabled = Boolean(featureFlagLogic.findMounted()?.values.featureFlags?.[FEATURE_FLAGS.QUERY_ASYNC]) + const excludedKinds = ['HogQLMetadata'] + const queryAsync = queryAsyncEnabled && !excludedKinds.includes(queryNode.kind) + const response = await api.query(queryNode, methodOptions, queryId, refresh, queryAsync) + + if (!queryAsync || !response.query_async) { + return response + } + + const pollStart = performance.now() + let currentDelay = 300 // start low, because all queries will take at minimum this + + while (performance.now() - pollStart < QUERY_ASYNC_TOTAL_POLL_SECONDS * 1000) { + await delay(currentDelay) + currentDelay = Math.min(currentDelay * 2, QUERY_ASYNC_MAX_INTERVAL_SECONDS * 1000) + + if (methodOptions?.signal?.aborted) { + const customAbortError = new Error('Query aborted') + customAbortError.name = 'AbortError' + throw customAbortError + } + + const statusResponse = await api.queryStatus.get(response.id) + + if (statusResponse.complete || statusResponse.error) { + return statusResponse.results + } + } + throw new Error('Query timed out') +} + // Return data for a given query export async function query( queryNode: N, @@ -216,7 +256,7 @@ export async function query( response = await fetchLegacyInsights() } } else { - response = await api.query(queryNode, methodOptions, queryId, refresh) + response = await executeQuery(queryNode, methodOptions, refresh, queryId) if (isHogQLQuery(queryNode) && response && typeof response === 'object') { logParams.clickhouse_sql = (response as HogQLQueryResponse)?.clickhouse } diff --git a/frontend/src/queries/schema.json b/frontend/src/queries/schema.json index 019230809f749..d7b3b2ca0da1d 100644 --- a/frontend/src/queries/schema.json +++ b/frontend/src/queries/schema.json @@ -2413,6 +2413,51 @@ } ] }, + "QueryStatus": { + "additionalProperties": false, + "properties": { + "complete": { + "default": false, + "type": "boolean" + }, + "end_time": { + "format": "date-time", + "type": "string" + }, + "error": { + "default": false, + "type": "boolean" + }, + "error_message": { + "default": "", + "type": "string" + }, + "expiration_time": { + "format": "date-time", + "type": "string" + }, + "id": { + "type": "string" + }, + "query_async": { + "default": true, + "type": "boolean" + }, + "results": {}, + "start_time": { + "format": "date-time", + "type": "string" + }, + "task_id": { + "type": "string" + }, + "team_id": { + "type": "integer" + } + }, + "required": ["id", "query_async", "team_id", "error", "complete", "error_message"], + "type": "object" + }, "QueryTiming": { "additionalProperties": false, "properties": { diff --git a/frontend/src/queries/schema.ts b/frontend/src/queries/schema.ts index b0ad656cea099..5e0f19452d438 100644 --- a/frontend/src/queries/schema.ts +++ b/frontend/src/queries/schema.ts @@ -542,6 +542,28 @@ export interface QueryResponse { next_allowed_client_refresh?: string } +export type QueryStatus = { + id: 
string + /** @default true */ + query_async: boolean + /** @asType integer */ + team_id: number + /** @default false */ + error: boolean + /** @default false */ + complete: boolean + /** @default "" */ + error_message: string + results?: any + /** @format date-time */ + start_time?: string + /** @format date-time */ + end_time?: string + /** @format date-time */ + expiration_time?: string + task_id?: string +} + export interface LifecycleQueryResponse extends QueryResponse { results: Record[] } diff --git a/package.json b/package.json index c99c80f987667..f89d7b151eb93 100644 --- a/package.json +++ b/package.json @@ -39,7 +39,7 @@ "build:esbuild": "node frontend/build.mjs", "schema:build": "pnpm run schema:build:json && pnpm run schema:build:python", "schema:build:json": "ts-json-schema-generator -f tsconfig.json --path 'frontend/src/queries/schema.ts' --no-type-check > frontend/src/queries/schema.json && prettier --write frontend/src/queries/schema.json", - "schema:build:python": "datamodel-codegen --class-name='SchemaRoot' --collapse-root-models --disable-timestamp --use-one-literal-as-default --use-default-kwarg --use-subclass-enum --input frontend/src/queries/schema.json --input-file-type jsonschema --output posthog/schema.py --output-model-type pydantic_v2.BaseModel && ruff format posthog/schema.py", + "schema:build:python": "datamodel-codegen --class-name='SchemaRoot' --collapse-root-models --disable-timestamp --use-one-literal-as-default --use-default --use-default-kwarg --use-subclass-enum --input frontend/src/queries/schema.json --input-file-type jsonschema --output posthog/schema.py --output-model-type pydantic_v2.BaseModel && ruff format posthog/schema.py", "grammar:build": "npm run grammar:build:python && npm run grammar:build:cpp", "grammar:build:python": "cd posthog/hogql/grammar && antlr -Dlanguage=Python3 HogQLLexer.g4 && antlr -visitor -no-listener -Dlanguage=Python3 HogQLParser.g4", "grammar:build:cpp": "cd posthog/hogql/grammar && antlr -o ../../../hogql_parser -Dlanguage=Cpp HogQLLexer.g4 && antlr -o ../../../hogql_parser -visitor -no-listener -Dlanguage=Cpp HogQLParser.g4", diff --git a/posthog/api/insight.py b/posthog/api/insight.py index a31f2dd9dbe05..20ec5e93d0619 100644 --- a/posthog/api/insight.py +++ b/posthog/api/insight.py @@ -21,7 +21,6 @@ from rest_framework.settings import api_settings from rest_framework_csv import renderers as csvrenderers from sentry_sdk import capture_exception -from statshog.defaults.django import statsd from posthog import schema from posthog.api.documentation import extend_schema @@ -32,6 +31,7 @@ TrendResultsSerializer, TrendSerializer, ) +from posthog.clickhouse.cancel import cancel_query_on_cluster from posthog.api.routing import StructuredViewSetMixin from posthog.api.shared import UserBasicSerializer from posthog.api.tagged_item import TaggedItemSerializerMixin, TaggedItemViewSetMixin @@ -43,7 +43,6 @@ synchronously_update_cache, ) from posthog.caching.insights_api import should_refresh_insight -from posthog.client import sync_execute from posthog.constants import ( BREAKDOWN_VALUES_LIMIT, INSIGHT, @@ -95,7 +94,6 @@ ClickHouseSustainedRateThrottle, ) from posthog.settings import CAPTURE_TIME_TO_SEE_DATA, SITE_URL -from posthog.settings.data_stores import CLICKHOUSE_CLUSTER from prometheus_client import Counter from posthog.user_permissions import UserPermissionsSerializerMixin from posthog.utils import ( @@ -1034,11 +1032,7 @@ def activity(self, request: request.Request, **kwargs): def cancel(self, request: request.Request, 
**kwargs): if "client_query_id" not in request.data: raise serializers.ValidationError({"client_query_id": "Field is required."}) - sync_execute( - f"KILL QUERY ON CLUSTER '{CLICKHOUSE_CLUSTER}' WHERE query_id LIKE %(client_query_id)s", - {"client_query_id": f"{self.team.pk}_{request.data['client_query_id']}%"}, - ) - statsd.incr("clickhouse.query.cancellation_requested", tags={"team_id": self.team.pk}) + cancel_query_on_cluster(team_id=self.team.pk, client_query_id=request.data["client_query_id"]) return Response(status=status.HTTP_201_CREATED) @action(methods=["POST"], detail=False) diff --git a/posthog/api/query.py b/posthog/api/query.py index 224aedce40464..021139911cb96 100644 --- a/posthog/api/query.py +++ b/posthog/api/query.py @@ -1,11 +1,11 @@ import json import re -from typing import Dict, Optional, cast, Any, List +import uuid +from typing import Dict -from django.http import HttpResponse, JsonResponse +from django.http import JsonResponse from drf_spectacular.types import OpenApiTypes from drf_spectacular.utils import OpenApiParameter, OpenApiResponse -from pydantic import BaseModel from rest_framework import viewsets from rest_framework.decorators import action from rest_framework.exceptions import ParseError, ValidationError, NotAuthenticated @@ -17,46 +17,31 @@ from posthog import schema from posthog.api.documentation import extend_schema +from posthog.api.services.query import process_query from posthog.api.routing import StructuredViewSetMixin +from posthog.clickhouse.client.execute_async import ( + cancel_query, + enqueue_process_query_task, + get_query_status, +) from posthog.clickhouse.query_tagging import tag_queries from posthog.errors import ExposedCHQueryError from posthog.hogql.ai import PromptUnclear, write_sql_from_prompt -from posthog.hogql.database.database import create_hogql_database, serialize_database from posthog.hogql.errors import HogQLException -from posthog.hogql.metadata import get_hogql_metadata -from posthog.hogql.modifiers import create_default_modifiers_for_team -from posthog.hogql_queries.query_runner import get_query_runner -from posthog.models import Team from posthog.models.user import User from posthog.permissions import ( ProjectMembershipNecessaryPermissions, TeamMemberAccessPermission, ) -from posthog.queries.time_to_see_data.serializers import ( - SessionEventsQuerySerializer, - SessionsQuerySerializer, -) -from posthog.queries.time_to_see_data.sessions import get_session_events, get_sessions from posthog.rate_limit import ( AIBurstRateThrottle, AISustainedRateThrottle, TeamRateThrottle, ) -from posthog.schema import HogQLMetadata +from posthog.schema import QueryStatus from posthog.utils import refresh_requested_by_client -QUERY_WITH_RUNNER = [ - "LifecycleQuery", - "TrendsQuery", - "WebOverviewQuery", - "WebTopSourcesQuery", - "WebTopClicksQuery", - "WebTopPagesQuery", - "WebStatsTableQuery", -] -QUERY_WITH_RUNNER_NO_CACHE = ["EventsQuery", "PersonsQuery", "HogQLQuery", "SessionsTimelineQuery"] - class QueryThrottle(TeamRateThrottle): scope = "query" @@ -116,40 +101,73 @@ def get_throttles(self): OpenApiParameter( "client_query_id", OpenApiTypes.STR, - description="Client provided query ID. Can be used to cancel queries.", + description="Client provided query ID. Can be used to retrieve the status or cancel the query.", + ), + OpenApiParameter( + "async", + OpenApiTypes.BOOL, + description=( + "(Experimental) " + "Whether to run the query asynchronously. Defaults to False." 
+ " If True, the `id` of the query can be used to check the status and to cancel it." + ), ), ], responses={ 200: OpenApiResponse(description="Query results"), }, ) - def list(self, request: Request, **kw) -> HttpResponse: - self._tag_client_query_id(request.GET.get("client_query_id")) - query_json = QuerySchemaParser.validate_query(self._query_json_from_request(request)) - # allow lists as well as dicts in response with safe=False - try: - return JsonResponse(process_query(self.team, query_json, request=request), safe=False) - except HogQLException as e: - raise ValidationError(str(e)) - except ExposedCHQueryError as e: - raise ValidationError(str(e), e.code_name) - - def post(self, request, *args, **kwargs): + def create(self, request, *args, **kwargs) -> JsonResponse: request_json = request.data query_json = request_json.get("query") - self._tag_client_query_id(request_json.get("client_query_id")) - # allow lists as well as dicts in response with safe=False + query_async = request_json.get("async", False) + refresh_requested = refresh_requested_by_client(request) + + client_query_id = request_json.get("client_query_id") or uuid.uuid4().hex + self._tag_client_query_id(client_query_id) + + if query_async: + query_id = enqueue_process_query_task( + team_id=self.team.pk, + query_json=query_json, + query_id=client_query_id, + refresh_requested=refresh_requested, + ) + return JsonResponse(QueryStatus(id=query_id, team_id=self.team.pk).model_dump(), safe=False) + try: - return JsonResponse(process_query(self.team, query_json, request=request), safe=False) - except HogQLException as e: - raise ValidationError(str(e)) - except ExposedCHQueryError as e: - raise ValidationError(str(e), e.code_name) + result = process_query(self.team, query_json, refresh_requested=refresh_requested) + return JsonResponse(result, safe=False) + except (HogQLException, ExposedCHQueryError) as e: + raise ValidationError(str(e), getattr(e, "code_name", None)) except Exception as e: self.handle_column_ch_error(e) capture_exception(e) raise e + @extend_schema( + description="(Experimental)", + responses={ + 200: OpenApiResponse(description="Query status"), + }, + ) + @extend_schema( + description="(Experimental)", + responses={ + 200: OpenApiResponse(description="Query status"), + }, + ) + def retrieve(self, request: Request, pk=None, *args, **kwargs) -> JsonResponse: + status = get_query_status(team_id=self.team.pk, query_id=pk) + return JsonResponse(status.__dict__, safe=False) + + @extend_schema( + description="(Experimental)", + ) + def destroy(self, request, pk=None, *args, **kwargs): + cancel_query(self.team.pk, pk) + return Response(status=204) + @action(methods=["GET"], detail=False) def draft_sql(self, request: Request, *args, **kwargs) -> Response: if not isinstance(request.user, User): @@ -177,8 +195,10 @@ def handle_column_ch_error(self, error): return def _tag_client_query_id(self, query_id: str | None): - if query_id is not None: - tag_queries(client_query_id=query_id) + if query_id is None: + return + + tag_queries(client_query_id=query_id) def _query_json_from_request(self, request): if request.method == "POST": @@ -205,73 +225,3 @@ def parsing_error(ex): except (json.JSONDecodeError, UnicodeDecodeError) as error_main: raise ValidationError("Invalid JSON: %s" % (str(error_main))) return query - - -def _unwrap_pydantic(response: Any) -> Dict | List: - if isinstance(response, list): - return [_unwrap_pydantic(item) for item in response] - - elif isinstance(response, BaseModel): - resp1: Dict[str, Any] = {} - 
for key in response.__fields__.keys(): - resp1[key] = _unwrap_pydantic(getattr(response, key)) - return resp1 - - elif isinstance(response, dict): - resp2: Dict[str, Any] = {} - for key in response.keys(): - resp2[key] = _unwrap_pydantic(response.get(key)) - return resp2 - - return response - - -def _unwrap_pydantic_dict(response: Any) -> Dict: - return cast(dict, _unwrap_pydantic(response)) - - -def process_query( - team: Team, - query_json: Dict, - in_export_context: Optional[bool] = False, - request: Optional[Request] = None, -) -> Dict: - # query_json has been parsed by QuerySchemaParser - # it _should_ be impossible to end up in here with a "bad" query - query_kind = query_json.get("kind") - tag_queries(query=query_json) - - if query_kind in QUERY_WITH_RUNNER: - refresh_requested = refresh_requested_by_client(request) if request else False - query_runner = get_query_runner(query_json, team, in_export_context=in_export_context) - return _unwrap_pydantic_dict(query_runner.run(refresh_requested=refresh_requested)) - elif query_kind in QUERY_WITH_RUNNER_NO_CACHE: - query_runner = get_query_runner(query_json, team, in_export_context=in_export_context) - return _unwrap_pydantic_dict(query_runner.calculate()) - elif query_kind == "HogQLMetadata": - metadata_query = HogQLMetadata.model_validate(query_json) - metadata_response = get_hogql_metadata(query=metadata_query, team=team) - return _unwrap_pydantic_dict(metadata_response) - elif query_kind == "DatabaseSchemaQuery": - database = create_hogql_database(team.pk, modifiers=create_default_modifiers_for_team(team)) - return serialize_database(database) - elif query_kind == "TimeToSeeDataSessionsQuery": - sessions_query_serializer = SessionsQuerySerializer(data=query_json) - sessions_query_serializer.is_valid(raise_exception=True) - return {"results": get_sessions(sessions_query_serializer).data} - elif query_kind == "TimeToSeeDataQuery": - serializer = SessionEventsQuerySerializer( - data={ - "team_id": team.pk, - "session_start": query_json["sessionStart"], - "session_end": query_json["sessionEnd"], - "session_id": query_json["sessionId"], - } - ) - serializer.is_valid(raise_exception=True) - return get_session_events(serializer) or {} - else: - if query_json.get("source"): - return process_query(team, query_json["source"]) - - raise ValidationError(f"Unsupported query kind: {query_kind}") diff --git a/posthog/api/services/__init__.py b/posthog/api/services/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/posthog/api/services/query.py b/posthog/api/services/query.py new file mode 100644 index 0000000000000..1ef831bde1b82 --- /dev/null +++ b/posthog/api/services/query.py @@ -0,0 +1,97 @@ +import structlog +from typing import Any, Dict, List, Optional, cast + +from pydantic import BaseModel +from rest_framework.exceptions import ValidationError + +from posthog.clickhouse.query_tagging import tag_queries +from posthog.hogql.database.database import create_hogql_database, serialize_database +from posthog.hogql.metadata import get_hogql_metadata +from posthog.hogql.modifiers import create_default_modifiers_for_team +from posthog.hogql_queries.query_runner import get_query_runner +from posthog.models import Team +from posthog.queries.time_to_see_data.serializers import SessionEventsQuerySerializer, SessionsQuerySerializer +from posthog.queries.time_to_see_data.sessions import get_session_events, get_sessions +from posthog.schema import HogQLMetadata + +logger = structlog.get_logger(__name__) + +QUERY_WITH_RUNNER = [ 
+ "LifecycleQuery", + "TrendsQuery", + "WebOverviewQuery", + "WebTopSourcesQuery", + "WebTopClicksQuery", + "WebTopPagesQuery", + "WebStatsTableQuery", +] +QUERY_WITH_RUNNER_NO_CACHE = ["EventsQuery", "PersonsQuery", "HogQLQuery", "SessionsTimelineQuery"] + + +def _unwrap_pydantic(response: Any) -> Dict | List: + if isinstance(response, list): + return [_unwrap_pydantic(item) for item in response] + + elif isinstance(response, BaseModel): + resp1: Dict[str, Any] = {} + for key in response.__fields__.keys(): + resp1[key] = _unwrap_pydantic(getattr(response, key)) + return resp1 + + elif isinstance(response, dict): + resp2: Dict[str, Any] = {} + for key in response.keys(): + resp2[key] = _unwrap_pydantic(response.get(key)) + return resp2 + + return response + + +def _unwrap_pydantic_dict(response: Any) -> Dict: + return cast(dict, _unwrap_pydantic(response)) + + +def process_query( + team: Team, + query_json: Dict, + in_export_context: Optional[bool] = False, + refresh_requested: Optional[bool] = False, +) -> Dict: + # query_json has been parsed by QuerySchemaParser + # it _should_ be impossible to end up in here with a "bad" query + query_kind = query_json.get("kind") + tag_queries(query=query_json) + + if query_kind in QUERY_WITH_RUNNER: + query_runner = get_query_runner(query_json, team, in_export_context=in_export_context) + return _unwrap_pydantic_dict(query_runner.run(refresh_requested=refresh_requested)) + elif query_kind in QUERY_WITH_RUNNER_NO_CACHE: + query_runner = get_query_runner(query_json, team, in_export_context=in_export_context) + return _unwrap_pydantic_dict(query_runner.calculate()) + elif query_kind == "HogQLMetadata": + metadata_query = HogQLMetadata.model_validate(query_json) + metadata_response = get_hogql_metadata(query=metadata_query, team=team) + return _unwrap_pydantic_dict(metadata_response) + elif query_kind == "DatabaseSchemaQuery": + database = create_hogql_database(team.pk, modifiers=create_default_modifiers_for_team(team)) + return serialize_database(database) + elif query_kind == "TimeToSeeDataSessionsQuery": + sessions_query_serializer = SessionsQuerySerializer(data=query_json) + sessions_query_serializer.is_valid(raise_exception=True) + return {"results": get_sessions(sessions_query_serializer).data} + elif query_kind == "TimeToSeeDataQuery": + serializer = SessionEventsQuerySerializer( + data={ + "team_id": team.pk, + "session_start": query_json["sessionStart"], + "session_end": query_json["sessionEnd"], + "session_id": query_json["sessionId"], + } + ) + serializer.is_valid(raise_exception=True) + return get_session_events(serializer) or {} + else: + if query_json.get("source"): + return process_query(team, query_json["source"]) + + raise ValidationError(f"Unsupported query kind: {query_kind}") diff --git a/posthog/api/test/test_query.py b/posthog/api/test/test_query.py index b49cd25b83287..ff03704605014 100644 --- a/posthog/api/test/test_query.py +++ b/posthog/api/test/test_query.py @@ -1,11 +1,11 @@ import json +from unittest import mock from unittest.mock import patch -from urllib.parse import quote from freezegun import freeze_time from rest_framework import status -from posthog.api.query import process_query +from posthog.api.services.query import process_query from posthog.models.property_definition import PropertyDefinition, PropertyType from posthog.models.utils import UUIDT from posthog.schema import ( @@ -336,51 +336,9 @@ def test_person_property_filter(self): response = self.client.post(f"/api/projects/{self.team.id}/query/", {"query": 
query.dict()}).json() self.assertEqual(len(response["results"]), 2) - def test_json_undefined_constant_error(self): - response = self.client.get( - f"/api/projects/{self.team.id}/query/?query=%7B%22kind%22%3A%22EventsQuery%22%2C%22select%22%3A%5B%22*%22%5D%2C%22limit%22%3AInfinity%7D" - ) - self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) - self.assertEqual( - response.json(), - { - "type": "validation_error", - "code": "invalid_input", - "detail": "Unsupported constant found in JSON: Infinity", - "attr": None, - }, - ) - - response = self.client.get( - f"/api/projects/{self.team.id}/query/?query=%7B%22kind%22%3A%22EventsQuery%22%2C%22select%22%3A%5B%22*%22%5D%2C%22limit%22%3ANaN%7D" - ) - self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) - self.assertEqual( - response.json(), - { - "type": "validation_error", - "code": "invalid_input", - "detail": "Unsupported constant found in JSON: NaN", - "attr": None, - }, - ) - def test_safe_clickhouse_error_passed_through(self): query = {"kind": "EventsQuery", "select": ["timestamp + 'string'"]} - # Safe errors are passed through in GET requests - response_get = self.client.get(f"/api/projects/{self.team.id}/query/?query={quote(json.dumps(query))}") - self.assertEqual(response_get.status_code, status.HTTP_400_BAD_REQUEST) - self.assertEqual( - response_get.json(), - self.validation_error_response( - "Illegal types DateTime64(6, 'UTC') and String of arguments of function plus: " - "While processing toTimeZone(timestamp, 'UTC') + 'string'.", - "illegal_type_of_argument", - ), - ) - - # Safe errors are passed through in POST requests too response_post = self.client.post(f"/api/projects/{self.team.id}/query/", {"query": query}) self.assertEqual(response_post.status_code, status.HTTP_400_BAD_REQUEST) self.assertEqual( @@ -396,11 +354,6 @@ def test_safe_clickhouse_error_passed_through(self): def test_unsafe_clickhouse_error_is_swallowed(self, sqlparse_format_mock): query = {"kind": "EventsQuery", "select": ["timestamp"]} - # Unsafe errors are swallowed in GET requests (in this case we should not expose malformed SQL) - response_get = self.client.get(f"/api/projects/{self.team.id}/query/?query={quote(json.dumps(query))}") - self.assertEqual(response_get.status_code, status.HTTP_500_INTERNAL_SERVER_ERROR) - - # Unsafe errors are swallowed in POST requests too response_post = self.client.post(f"/api/projects/{self.team.id}/query/", {"query": query}) self.assertEqual(response_post.status_code, status.HTTP_500_INTERNAL_SERVER_ERROR) @@ -832,3 +785,87 @@ def test_full_hogql_query_values(self): ) self.assertEqual(response.get("results", [])[0][0], 20) + + +class TestQueryRetrieve(APIBaseTest): + def setUp(self): + super().setUp() + self.team_id = self.team.pk + self.valid_query_id = "12345" + self.invalid_query_id = "invalid-query-id" + self.redis_client_mock = mock.Mock() + self.redis_get_patch = mock.patch("posthog.redis.get_client", return_value=self.redis_client_mock) + self.redis_get_patch.start() + + def tearDown(self): + self.redis_get_patch.stop() + + def test_with_valid_query_id(self): + self.redis_client_mock.get.return_value = json.dumps( + { + "id": self.valid_query_id, + "team_id": self.team_id, + "error": False, + "complete": True, + "results": ["result1", "result2"], + } + ).encode() + response = self.client.get(f"/api/projects/{self.team.id}/query/{self.valid_query_id}/") + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json()["complete"], True, response.content) + + def 
test_with_invalid_query_id(self): + self.redis_client_mock.get.return_value = None + response = self.client.get(f"/api/projects/{self.team.id}/query/{self.invalid_query_id}/") + self.assertEqual(response.status_code, 404) + + def test_completed_query(self): + self.redis_client_mock.get.return_value = json.dumps( + { + "id": self.valid_query_id, + "team_id": self.team_id, + "complete": True, + "results": ["result1", "result2"], + } + ).encode() + response = self.client.get(f"/api/projects/{self.team.id}/query/{self.valid_query_id}/") + self.assertEqual(response.status_code, 200) + self.assertTrue(response.json()["complete"]) + + def test_running_query(self): + self.redis_client_mock.get.return_value = json.dumps( + { + "id": self.valid_query_id, + "team_id": self.team_id, + "complete": False, + } + ).encode() + response = self.client.get(f"/api/projects/{self.team.id}/query/{self.valid_query_id}/") + self.assertEqual(response.status_code, 200) + self.assertFalse(response.json()["complete"]) + + def test_failed_query(self): + self.redis_client_mock.get.return_value = json.dumps( + { + "id": self.valid_query_id, + "team_id": self.team_id, + "error": True, + "error_message": "Query failed", + } + ).encode() + response = self.client.get(f"/api/projects/{self.team.id}/query/{self.valid_query_id}/") + self.assertEqual(response.status_code, 200) + self.assertTrue(response.json()["error"]) + + def test_destroy(self): + self.redis_client_mock.get.return_value = json.dumps( + { + "id": self.valid_query_id, + "team_id": self.team_id, + "error": True, + "error_message": "Query failed", + } + ).encode() + response = self.client.delete(f"/api/projects/{self.team.id}/query/{self.valid_query_id}/") + self.assertEqual(response.status_code, 204) + self.redis_client_mock.delete.assert_called_once() diff --git a/posthog/caching/calculate_results.py b/posthog/caching/calculate_results.py index be11c4ffe48b5..f7ee632e2ad48 100644 --- a/posthog/caching/calculate_results.py +++ b/posthog/caching/calculate_results.py @@ -141,7 +141,7 @@ def calculate_for_query_based_insight( ) # local import to avoid circular reference - from posthog.api.query import process_query + from posthog.api.services.query import process_query # TODO need to properly check that hogql is enabled? 
return cache_key, cache_type, process_query(team, insight.query, True) diff --git a/posthog/celery.py b/posthog/celery.py index 374c90db7e4ec..53c67214783ee 100644 --- a/posthog/celery.py +++ b/posthog/celery.py @@ -395,24 +395,19 @@ def redis_heartbeat(): @app.task(ignore_result=True, bind=True) -def enqueue_clickhouse_execute_with_progress( - self, team_id, query_id, query, args=None, settings=None, with_column_types=False -): +def process_query_task(self, team_id, query_id, query_json, in_export_context=False, refresh_requested=False): """ - Kick off query with progress reporting - Iterate over the progress status - Save status to redis + Kick off query Once complete save results to redis """ - from posthog.client import execute_with_progress - - execute_with_progress( - team_id, - query_id, - query, - args, - settings, - with_column_types, + from posthog.client import execute_process_query + + execute_process_query( + team_id=team_id, + query_id=query_id, + query_json=query_json, + in_export_context=in_export_context, + refresh_requested=refresh_requested, task_id=self.request.id, ) diff --git a/posthog/clickhouse/cancel.py b/posthog/clickhouse/cancel.py new file mode 100644 index 0000000000000..e05eea7ad3d64 --- /dev/null +++ b/posthog/clickhouse/cancel.py @@ -0,0 +1,14 @@ +from statshog.defaults.django import statsd + +from posthog.api.services.query import logger +from posthog.clickhouse.client import sync_execute +from posthog.settings import CLICKHOUSE_CLUSTER + + +def cancel_query_on_cluster(team_id: int, client_query_id: str) -> None: + result = sync_execute( + f"KILL QUERY ON CLUSTER '{CLICKHOUSE_CLUSTER}' WHERE query_id LIKE %(client_query_id)s", + {"client_query_id": f"{team_id}_{client_query_id}%"}, + ) + logger.info("Cancelled query %s for team %s, result: %s", client_query_id, team_id, result) + statsd.incr("clickhouse.query.cancellation_requested", tags={"team_id": team_id}) diff --git a/posthog/clickhouse/client/__init__.py b/posthog/clickhouse/client/__init__.py index f2ad255c395e1..a249ebbabb4ad 100644 --- a/posthog/clickhouse/client/__init__.py +++ b/posthog/clickhouse/client/__init__.py @@ -1,8 +1,8 @@ from posthog.clickhouse.client.execute import query_with_columns, sync_execute -from posthog.clickhouse.client.execute_async import execute_with_progress +from posthog.clickhouse.client.execute_async import execute_process_query __all__ = [ "sync_execute", "query_with_columns", - "execute_with_progress", + "execute_process_query", ] diff --git a/posthog/clickhouse/client/execute_async.py b/posthog/clickhouse/client/execute_async.py index 3bb28c3f20075..fc9e292b08ee4 100644 --- a/posthog/clickhouse/client/execute_async.py +++ b/posthog/clickhouse/client/execute_async.py @@ -1,172 +1,94 @@ -import hashlib +import datetime import json -import time -from dataclasses import asdict as dataclass_asdict -from dataclasses import dataclass -from time import perf_counter -from typing import Any, Optional - -from posthog import celery -from clickhouse_driver import Client as SyncClient -from django.conf import settings as app_settings -from statshog.defaults.django import statsd - -from posthog import redis -from posthog.celery import enqueue_clickhouse_execute_with_progress -from posthog.clickhouse.client.execute import _prepare_query -from posthog.errors import wrap_query_error -from posthog.settings import ( - CLICKHOUSE_CA, - CLICKHOUSE_DATABASE, - CLICKHOUSE_HOST, - CLICKHOUSE_PASSWORD, - CLICKHOUSE_SECURE, - CLICKHOUSE_USER, - CLICKHOUSE_VERIFY, -) - -REDIS_STATUS_TTL = 600 
# 10 minutes - - -@dataclass -class QueryStatus: - team_id: int - num_rows: float = 0 - total_rows: float = 0 - error: bool = False - complete: bool = False - error_message: str = "" - results: Any = None - start_time: Optional[float] = None - end_time: Optional[float] = None - task_id: Optional[str] = None - - -def generate_redis_results_key(query_id): - REDIS_KEY_PREFIX_ASYNC_RESULTS = "query_with_progress" - key = f"{REDIS_KEY_PREFIX_ASYNC_RESULTS}:{query_id}" - return key - - -def execute_with_progress( +import uuid + +import structlog +from rest_framework.exceptions import NotFound + +from posthog import celery, redis +from posthog.celery import process_query_task +from posthog.clickhouse.query_tagging import tag_queries +from posthog.schema import QueryStatus + +logger = structlog.get_logger(__name__) + +REDIS_STATUS_TTL_SECONDS = 600 # 10 minutes +REDIS_KEY_PREFIX_ASYNC_RESULTS = "query_async" + + +class QueryNotFoundError(NotFound): + pass + + +class QueryRetrievalError(Exception): + pass + + +def generate_redis_results_key(query_id: str, team_id: int) -> str: + return f"{REDIS_KEY_PREFIX_ASYNC_RESULTS}:{team_id}:{query_id}" + + +def execute_process_query( team_id, query_id, - query, - args=None, - settings=None, - with_column_types=False, - update_freq=0.2, + query_json, + in_export_context, + refresh_requested, task_id=None, ): - """ - Kick off query with progress reporting - Iterate over the progress status - Save status to redis - Once complete save results to redis - """ - - key = generate_redis_results_key(query_id) - ch_client = SyncClient( - host=CLICKHOUSE_HOST, - database=CLICKHOUSE_DATABASE, - secure=CLICKHOUSE_SECURE, - user=CLICKHOUSE_USER, - password=CLICKHOUSE_PASSWORD, - ca_certs=CLICKHOUSE_CA, - verify=CLICKHOUSE_VERIFY, - settings={"max_result_rows": "10000"}, - ) + key = generate_redis_results_key(query_id, team_id) redis_client = redis.get_client() - start_time = perf_counter() - - prepared_sql, prepared_args, tags = _prepare_query(client=ch_client, query=query, args=args) + from posthog.models import Team + from posthog.api.services.query import process_query - query_status = QueryStatus(team_id, task_id=task_id) + team = Team.objects.get(pk=team_id) - start_time = time.time() + query_status = QueryStatus( + id=query_id, + team_id=team_id, + task_id=task_id, + complete=False, + error=True, # Assume error in case nothing below ends up working + start_time=datetime.datetime.utcnow(), + ) + value = query_status.model_dump_json() try: - progress = ch_client.execute_with_progress( - prepared_sql, - params=prepared_args, - settings=settings, - with_column_types=with_column_types, + tag_queries(client_query_id=query_id, team_id=team_id) + results = process_query( + team=team, query_json=query_json, in_export_context=in_export_context, refresh_requested=refresh_requested ) - for num_rows, total_rows in progress: - query_status = QueryStatus( - team_id=team_id, - num_rows=num_rows, - total_rows=total_rows, - complete=False, - error=False, - error_message="", - results=None, - start_time=start_time, - task_id=task_id, - ) - redis_client.set(key, json.dumps(dataclass_asdict(query_status)), ex=REDIS_STATUS_TTL) - time.sleep(update_freq) - else: - rv = progress.get_result() - query_status = QueryStatus( - team_id=team_id, - num_rows=query_status.num_rows, - total_rows=query_status.total_rows, - complete=True, - error=False, - start_time=query_status.start_time, - end_time=time.time(), - error_message="", - results=rv, - task_id=task_id, - ) - redis_client.set(key, 
json.dumps(dataclass_asdict(query_status)), ex=REDIS_STATUS_TTL) - + logger.info("Got results for team %s query %s", team_id, query_id) + query_status.complete = True + query_status.error = False + query_status.results = results + query_status.expiration_time = datetime.datetime.utcnow() + datetime.timedelta(seconds=REDIS_STATUS_TTL_SECONDS) + query_status.end_time = datetime.datetime.utcnow() + value = query_status.model_dump_json() except Exception as err: - err = wrap_query_error(err) - tags["failed"] = True - tags["reason"] = type(err).__name__ - statsd.incr("clickhouse_sync_execution_failure") - query_status = QueryStatus( - team_id=team_id, - num_rows=query_status.num_rows, - total_rows=query_status.total_rows, - complete=False, - error=True, - start_time=query_status.start_time, - end_time=time.time(), - error_message=str(err), - results=None, - task_id=task_id, - ) - redis_client.set(key, json.dumps(dataclass_asdict(query_status)), ex=REDIS_STATUS_TTL) - + query_status.results = None # Clear results in case they are faulty + query_status.error_message = str(err) + logger.error("Error processing query for team %s query %s: %s", team_id, query_id, err) + value = query_status.model_dump_json() raise err finally: - ch_client.disconnect() + redis_client.set(key, value, ex=REDIS_STATUS_TTL_SECONDS) - execution_time = perf_counter() - start_time - statsd.timing("clickhouse_sync_execution_time", execution_time * 1000.0) - - if app_settings.SHELL_PLUS_PRINT_SQL: - print("Execution time: %.6fs" % (execution_time,)) # noqa T201 - - -def enqueue_execute_with_progress( +def enqueue_process_query_task( team_id, - query, - args=None, - settings=None, - with_column_types=False, - bypass_celery=False, + query_json, query_id=None, + refresh_requested=False, + in_export_context=False, + bypass_celery=False, force=False, ): if not query_id: - query_id = _query_hash(query, team_id, args) - key = generate_redis_results_key(query_id) + query_id = uuid.uuid4().hex + + key = generate_redis_results_key(query_id, team_id) redis_client = redis.get_client() if force: @@ -187,49 +109,55 @@ def enqueue_execute_with_progress( # If we've seen this query before return the query_id and don't resubmit it. 
return query_id - # Immediately set status so we don't have race with celery - query_status = QueryStatus(team_id=team_id, start_time=time.time()) - redis_client.set(key, json.dumps(dataclass_asdict(query_status)), ex=REDIS_STATUS_TTL) + # Immediately set status, so we don't have race with celery + query_status = QueryStatus(id=query_id, team_id=team_id) + redis_client.set(key, query_status.model_dump_json(), ex=REDIS_STATUS_TTL_SECONDS) if bypass_celery: # Call directly ( for testing ) - enqueue_clickhouse_execute_with_progress(team_id, query_id, query, args, settings, with_column_types) + process_query_task( + team_id, query_id, query_json, in_export_context=in_export_context, refresh_requested=refresh_requested + ) else: - enqueue_clickhouse_execute_with_progress.delay(team_id, query_id, query, args, settings, with_column_types) + task = process_query_task.delay( + team_id, query_id, query_json, in_export_context=in_export_context, refresh_requested=refresh_requested + ) + query_status.task_id = task.id + redis_client.set(key, query_status.model_dump_json(), ex=REDIS_STATUS_TTL_SECONDS) return query_id -def get_status_or_results(team_id, query_id): - """ - Returns QueryStatus data class - QueryStatus data class contains either: - Current status of running query - Results of completed query - Error payload of failed query - """ +def get_query_status(team_id, query_id): redis_client = redis.get_client() - key = generate_redis_results_key(query_id) + key = generate_redis_results_key(query_id, team_id) + try: byte_results = redis_client.get(key) - if byte_results: - str_results = byte_results.decode("utf-8") - else: - return QueryStatus(team_id, error=True, error_message="Query is unknown to backend") - query_status = QueryStatus(**json.loads(str_results)) - if query_status.team_id != team_id: - raise Exception("Requesting team is not executing team") except Exception as e: - query_status = QueryStatus(team_id, error=True, error_message=str(e)) - return query_status + raise QueryRetrievalError(f"Error retrieving query {query_id} for team {team_id}") from e + if not byte_results: + raise QueryNotFoundError(f"Query {query_id} not found for team {team_id}") -def _query_hash(query: str, team_id: int, args: Any) -> str: - """ - Takes a query and returns a hex encoded hash of the query and args - """ - if args: - key = hashlib.md5((str(team_id) + query + json.dumps(args)).encode("utf-8")).hexdigest() - else: - key = hashlib.md5((str(team_id) + query).encode("utf-8")).hexdigest() - return key + return QueryStatus(**json.loads(byte_results)) + + +def cancel_query(team_id, query_id): + query_status = get_query_status(team_id, query_id) + + if query_status.task_id: + logger.info("Got task id %s, attempting to revoke", query_status.task_id) + celery.app.control.revoke(query_status.task_id, terminate=True) + + from posthog.clickhouse.cancel import cancel_query_on_cluster + + logger.info("Revoked task id %s, attempting to cancel on cluster", query_status.task_id) + cancel_query_on_cluster(team_id, query_id) + + redis_client = redis.get_client() + key = generate_redis_results_key(query_id, team_id) + logger.info("Deleting redis query key %s", key) + redis_client.delete(key) + + return True diff --git a/posthog/clickhouse/client/test/test_execute_async.py b/posthog/clickhouse/client/test/test_execute_async.py new file mode 100644 index 0000000000000..1ab4bf49e03d3 --- /dev/null +++ b/posthog/clickhouse/client/test/test_execute_async.py @@ -0,0 +1,152 @@ +import uuid +from unittest.mock import patch + 
+from django.test import TestCase + +from posthog.clickhouse.client import execute_async as client +from posthog.client import sync_execute +from posthog.hogql.errors import HogQLException +from posthog.models import Organization, Team +from posthog.test.base import ClickhouseTestMixin + + +def build_query(sql): + return { + "kind": "HogQLQuery", + "query": sql, + } + + +class ClickhouseClientTestCase(TestCase, ClickhouseTestMixin): + def setUp(self): + self.organization = Organization.objects.create(name="test") + self.team = Team.objects.create(organization=self.organization) + self.team_id = self.team.pk + + def test_async_query_client(self): + query = build_query("SELECT 1+1") + team_id = self.team_id + query_id = client.enqueue_process_query_task(team_id, query, bypass_celery=True) + result = client.get_query_status(team_id, query_id) + self.assertFalse(result.error, result.error_message) + self.assertTrue(result.complete) + self.assertEqual(result.results["results"], [[2]]) + + def test_async_query_client_errors(self): + query = build_query("SELECT WOW SUCH DATA FROM NOWHERE THIS WILL CERTAINLY WORK") + self.assertRaises( + HogQLException, + client.enqueue_process_query_task, + **{"team_id": (self.team_id), "query_json": query, "bypass_celery": True}, + ) + query_id = uuid.uuid4().hex + try: + client.enqueue_process_query_task(self.team_id, query, query_id=query_id, bypass_celery=True) + except Exception: + pass + + result = client.get_query_status(self.team_id, query_id) + self.assertTrue(result.error) + self.assertRegex(result.error_message, "Unknown table") + + def test_async_query_client_uuid(self): + query = build_query("SELECT toUUID('00000000-0000-0000-0000-000000000000')") + team_id = self.team_id + query_id = client.enqueue_process_query_task(team_id, query, bypass_celery=True) + result = client.get_query_status(team_id, query_id) + self.assertFalse(result.error, result.error_message) + self.assertTrue(result.complete) + self.assertEqual(result.results["results"], [["00000000-0000-0000-0000-000000000000"]]) + + def test_async_query_client_does_not_leak(self): + query = build_query("SELECT 1+1") + team_id = self.team_id + wrong_team = 5 + query_id = client.enqueue_process_query_task(team_id, query, bypass_celery=True) + + try: + client.get_query_status(wrong_team, query_id) + except Exception as e: + self.assertEqual(str(e), f"Query {query_id} not found for team {wrong_team}") + + @patch("posthog.clickhouse.client.execute_async.process_query_task") + def test_async_query_client_is_lazy(self, execute_sync_mock): + query = build_query("SELECT 4 + 4") + query_id = uuid.uuid4().hex + team_id = self.team_id + client.enqueue_process_query_task(team_id, query, query_id=query_id, bypass_celery=True) + + # Try the same query again + client.enqueue_process_query_task(team_id, query, query_id=query_id, bypass_celery=True) + + # Try the same query again (for good measure!) 
+ client.enqueue_process_query_task(team_id, query, query_id=query_id, bypass_celery=True) + + # Assert that we only called clickhouse once + execute_sync_mock.assert_called_once() + + @patch("posthog.clickhouse.client.execute_async.process_query_task") + def test_async_query_client_is_lazy_but_not_too_lazy(self, execute_sync_mock): + query = build_query("SELECT 8 + 8") + query_id = uuid.uuid4().hex + team_id = self.team_id + client.enqueue_process_query_task(team_id, query, query_id=query_id, bypass_celery=True) + + # Try the same query again, but with force + client.enqueue_process_query_task(team_id, query, query_id=query_id, bypass_celery=True, force=True) + + # Try the same query again (for good measure!) + client.enqueue_process_query_task(team_id, query, query_id=query_id, bypass_celery=True) + + # Assert that we called clickhouse twice + self.assertEqual(execute_sync_mock.call_count, 2) + + @patch("posthog.clickhouse.client.execute_async.process_query_task") + def test_async_query_client_manual_query_uuid(self, execute_sync_mock): + # This is a unique test because technically in the test pattern `SELECT 8 + 8` is already + # in redis. This tests to make sure it is treated as a unique run of that query + query = build_query("SELECT 8 + 8") + team_id = self.team_id + query_id = "I'm so unique" + client.enqueue_process_query_task(team_id, query, query_id=query_id, bypass_celery=True) + + # Try the same query again, but with force + client.enqueue_process_query_task(team_id, query, query_id=query_id, bypass_celery=True, force=True) + + # Try the same query again (for good measure!) + client.enqueue_process_query_task(team_id, query, query_id=query_id, bypass_celery=True) + + # Assert that we called clickhouse twice + self.assertEqual(execute_sync_mock.call_count, 2) + + def test_client_strips_comments_from_request(self): + """ + To ensure we can easily copy queries from `system.query_log` in e.g. + Metabase, we strip comments from the query we send. Metabase doesn't + display multilined output. + + See https://github.com/metabase/metabase/issues/14253 + + Note I'm not really testing much complexity, I trust that those will + come out as failures in other tests. + """ + from posthog.clickhouse.query_tagging import tag_queries + + # First add in the request information that should be added to the sql. 
+ # We check this to make sure it is not removed by the comment stripping + with self.capture_select_queries() as sqls: + tag_queries(kind="request", id="1") + sync_execute( + query=""" + -- this request returns 1 + SELECT 1 + """ + ) + self.assertEqual(len(sqls), 1) + first_query = sqls[0] + self.assertIn(f"SELECT 1", first_query) + self.assertNotIn("this request returns", first_query) + + # Make sure it still includes the "annotation" comment that includes + # request routing information for debugging purposes + self.assertIn("/* request:1 */", first_query) diff --git a/posthog/schema.py b/posthog/schema.py index be7bb8619a4ce..a2057f903768f 100644 --- a/posthog/schema.py +++ b/posthog/schema.py @@ -3,6 +3,7 @@ from __future__ import annotations +from datetime import datetime from enum import Enum from typing import Any, Dict, List, Optional, Union @@ -439,6 +440,23 @@ class PropertyOperator(str, Enum): max = "max" +class QueryStatus(BaseModel): + model_config = ConfigDict( + extra="forbid", + ) + complete: Optional[bool] = False + end_time: Optional[datetime] = None + error: Optional[bool] = False + error_message: Optional[str] = "" + expiration_time: Optional[datetime] = None + id: str + query_async: Optional[bool] = True + results: Optional[Any] = None + start_time: Optional[datetime] = None + task_id: Optional[str] = None + team_id: int + + class QueryTiming(BaseModel): model_config = ConfigDict( extra="forbid", diff --git a/posthog/tasks/exports/csv_exporter.py b/posthog/tasks/exports/csv_exporter.py index 622798774ec1d..8f6fffd0c9f90 100644 --- a/posthog/tasks/exports/csv_exporter.py +++ b/posthog/tasks/exports/csv_exporter.py @@ -7,7 +7,7 @@ from django.http import QueryDict from sentry_sdk import capture_exception, push_scope -from posthog.api.query import process_query +from posthog.api.services.query import process_query from posthog.jwt import PosthogJwtAudience, encode_jwt from posthog.models.exported_asset import ExportedAsset, save_content from posthog.utils import absolute_uri diff --git a/posthog/warehouse/api/test/test_view_link.py b/posthog/warehouse/api/test/test_view_link.py index 3a2dcae6bf160..0bcb57e187b86 100644 --- a/posthog/warehouse/api/test/test_view_link.py +++ b/posthog/warehouse/api/test/test_view_link.py @@ -2,7 +2,7 @@ APIBaseTest, ) from posthog.warehouse.models import DataWarehouseViewLink, DataWarehouseSavedQuery -from posthog.api.query import process_query +from posthog.api.services.query import process_query class TestViewLinkQuery(APIBaseTest): diff --git a/posthog/warehouse/models/datawarehouse_saved_query.py b/posthog/warehouse/models/datawarehouse_saved_query.py index bca809bb30912..9117fa7c4eaf0 100644 --- a/posthog/warehouse/models/datawarehouse_saved_query.py +++ b/posthog/warehouse/models/datawarehouse_saved_query.py @@ -47,7 +47,7 @@ class Meta: ] def get_columns(self) -> Dict[str, str]: - from posthog.api.query import process_query + from posthog.api.services.query import process_query # TODO: catch and raise error response = process_query(self.team, self.query)
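
For reference, a minimal sketch of the async query lifecycle these changes introduce, assuming the helpers behave as shown in posthog/clickhouse/client/execute_async.py and its new tests. The team id and HogQL payload below are illustrative only; the function names, signatures, and the query_async:<team_id>:<query_id> Redis key format are taken from the diff above.

# Sketch only: mirrors posthog/clickhouse/client/test/test_execute_async.py.
# bypass_celery=True runs the task inline instead of dispatching to Celery.
from posthog.clickhouse.client.execute_async import (
    cancel_query,
    enqueue_process_query_task,
    get_query_status,
)

team_id = 1  # illustrative team id, not from the diff
query_json = {"kind": "HogQLQuery", "query": "SELECT 1 + 1"}

# Enqueue the query; a query_id is generated if none is supplied, and the
# initial QueryStatus is written to Redis before the task runs.
query_id = enqueue_process_query_task(team_id, query_json, bypass_celery=True)

# Poll the status stored in Redis under query_async:<team_id>:<query_id>.
status = get_query_status(team_id, query_id)
if status.complete:
    print(status.results)
elif status.error:
    print(status.error_message)
else:
    # Still running: revoke the Celery task, KILL QUERY on the ClickHouse
    # cluster, and delete the Redis status key.
    cancel_query(team_id, query_id)

Over HTTP the same flow maps onto the endpoints added here: POST /api/projects/:id/query/ with "async": true returns a QueryStatus containing the id, GET /api/projects/:id/query/:query_id/ returns the current status or results, and DELETE on the same URL cancels the query.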