From a1ab542f95d3e35b559e309e707d2994f1932f2b Mon Sep 17 00:00:00 2001 From: Julian Bez Date: Wed, 10 Jan 2024 11:10:50 +0100 Subject: [PATCH] Remove parser, opt for using model in simpler way --- posthog/api/mixins.py | 15 +++++++++++++++ posthog/api/parsers.py | 35 ---------------------------------- posthog/api/query.py | 13 +++---------- posthog/api/services/query.py | 32 +++++++++++++++++-------------- posthog/api/test/test_query.py | 20 +++++++++++++++---- 5 files changed, 52 insertions(+), 63 deletions(-) create mode 100644 posthog/api/mixins.py delete mode 100644 posthog/api/parsers.py diff --git a/posthog/api/mixins.py b/posthog/api/mixins.py new file mode 100644 index 0000000000000..44681fc246b1a --- /dev/null +++ b/posthog/api/mixins.py @@ -0,0 +1,15 @@ +from typing import TypeVar, Type, Generic + +from pydantic import BaseModel, ValidationError + +from rest_framework.exceptions import ParseError + +T = TypeVar("T", bound=BaseModel) + + +class PydanticModelMixin(Generic[T]): + def get_model(self, data: dict, model: Type[T]) -> T: + try: + return model.model_validate(data) + except ValidationError as exc: + raise ParseError("JSON parse error - %s" % str(exc)) diff --git a/posthog/api/parsers.py b/posthog/api/parsers.py deleted file mode 100644 index 6f8fc48794bc1..0000000000000 --- a/posthog/api/parsers.py +++ /dev/null @@ -1,35 +0,0 @@ -import codecs - -from pydantic import BaseModel, ValidationError -from django.conf import settings - -from rest_framework.parsers import JSONParser -from rest_framework.exceptions import ParseError - - -class PydanticJSONParser(JSONParser): - """ - Parses JSON-serialized data using Pydantic. - """ - - def parse(self, stream, media_type=None, parser_context=None): - """ - Parses the incoming bytestream as JSON and returns the resulting data. - - The view needs a pydantic_models attribute with a dictionary of action names to pydantic models. - This is needed because otherwise the parser doesn't know which model to use. - """ - pydantic_model: type[BaseModel] = getattr(parser_context["view"], "pydantic_models", {}).get( - parser_context["view"].action, None - ) - if not pydantic_model: - return super().parse(stream, media_type, parser_context) - - parser_context = parser_context or {} - encoding = parser_context.get("encoding", settings.DEFAULT_CHARSET) - decoded_stream = codecs.getreader(encoding)(stream) - - try: - return pydantic_model.model_validate_json(decoded_stream.read()) - except ValidationError as exc: - raise ParseError("JSON parse error - %s" % str(exc)) diff --git a/posthog/api/query.py b/posthog/api/query.py index 6ce77565c92ac..8423bcfb11995 100644 --- a/posthog/api/query.py +++ b/posthog/api/query.py @@ -2,7 +2,6 @@ import re import uuid -from pydantic import BaseModel from django.http import JsonResponse from drf_spectacular.utils import OpenApiResponse from rest_framework import viewsets @@ -14,7 +13,7 @@ from sentry_sdk import capture_exception from posthog.api.documentation import extend_schema -from posthog.api.parsers import PydanticJSONParser +from posthog.api.mixins import PydanticModelMixin from posthog.api.routing import StructuredViewSetMixin from posthog.api.services.query import process_query_model from posthog.clickhouse.client.execute_async import ( @@ -44,16 +43,12 @@ class QueryThrottle(TeamRateThrottle): rate = "120/hour" -class QueryViewSet(StructuredViewSetMixin, viewsets.ViewSet): +class QueryViewSet(PydanticModelMixin, StructuredViewSetMixin, viewsets.ViewSet): permission_classes = [ IsAuthenticated, ProjectMembershipNecessaryPermissions, TeamMemberAccessPermission, ] - parser_classes = (PydanticJSONParser,) - pydantic_models = { - "create": QueryRequest, - } def get_throttles(self): if self.action == "draft_sql": @@ -68,7 +63,7 @@ def get_throttles(self): }, ) def create(self, request, *args, **kwargs) -> Response: - data: QueryRequest = request.data + data = self.get_model(request.data, QueryRequest) client_query_id = data.client_query_id or uuid.uuid4().hex self._tag_client_query_id(client_query_id) @@ -84,8 +79,6 @@ def create(self, request, *args, **kwargs) -> Response: try: result = process_query_model(self.team, data.query, refresh_requested=data.refresh) - if isinstance(result, BaseModel): - return Response(result.model_dump()) return Response(result) except (HogQLException, ExposedCHQueryError) as e: raise ValidationError(str(e), getattr(e, "code_name", None)) diff --git a/posthog/api/services/query.py b/posthog/api/services/query.py index eaa5d98d5342e..2380fb933ab60 100644 --- a/posthog/api/services/query.py +++ b/posthog/api/services/query.py @@ -1,5 +1,5 @@ import structlog -from typing import Dict, Optional +from typing import Optional from pydantic import BaseModel from rest_framework.exceptions import ValidationError @@ -41,11 +41,11 @@ def process_query( team: Team, - query: QuerySchemaRoot, + query_json: dict, limit_context: Optional[LimitContext] = None, refresh_requested: Optional[bool] = False, -): - model = QuerySchemaRoot.model_validate(query) +) -> dict: + model = QuerySchemaRoot.model_validate(query_json) return process_query_model( team, model.root, @@ -59,26 +59,26 @@ def process_query_model( query: "QuerySchemaRoot.root", limit_context: Optional[LimitContext] = None, refresh_requested: Optional[bool] = False, -) -> Dict | BaseModel: +) -> dict: tag_queries(query=query.kind) # TODO: ? if isinstance(query, QUERY_WITH_RUNNER): query_runner = get_query_runner(query, team, limit_context=limit_context) - return query_runner.run(refresh_requested=refresh_requested) - if isinstance(query, QUERY_WITH_RUNNER_NO_CACHE): + result = query_runner.run(refresh_requested=refresh_requested) + elif isinstance(query, QUERY_WITH_RUNNER_NO_CACHE): query_runner = get_query_runner(query, team, limit_context=limit_context) - return query_runner.calculate() + result = query_runner.calculate() elif isinstance(query, HogQLMetadata): metadata_query = HogQLMetadata.model_validate(query) metadata_response = get_hogql_metadata(query=metadata_query, team=team) - return metadata_response + result = metadata_response elif isinstance(query, DatabaseSchemaQuery): database = create_hogql_database(team.pk, modifiers=create_default_modifiers_for_team(team)) - return serialize_database(database) + result = serialize_database(database) elif isinstance(query, TimeToSeeDataSessionsQuery): sessions_query_serializer = SessionsQuerySerializer(data=query) sessions_query_serializer.is_valid(raise_exception=True) - return {"results": get_sessions(sessions_query_serializer).data} + result = {"results": get_sessions(sessions_query_serializer).data} elif isinstance(query, TimeToSeeDataQuery): serializer = SessionEventsQuerySerializer( data={ @@ -89,8 +89,12 @@ def process_query_model( } ) serializer.is_valid(raise_exception=True) - return get_session_events(serializer) or {} + result = get_session_events(serializer) or {} elif query.source: - return process_query(team, query.source) + result = process_query(team, query.source) + else: + raise ValidationError(f"Unsupported query kind: {query.kind}") - raise ValidationError(f"Unsupported query kind: {query.kind}") + if isinstance(result, BaseModel): + return result.model_dump() + return result diff --git a/posthog/api/test/test_query.py b/posthog/api/test/test_query.py index b31393316562a..fdf7439617f0f 100644 --- a/posthog/api/test/test_query.py +++ b/posthog/api/test/test_query.py @@ -694,10 +694,22 @@ def test_property_definition_annotation_does_not_break_things(self): def test_invalid_query_kind(self): api_response = self.client.post(f"/api/projects/{self.team.id}/query/", {"query": {"kind": "Tomato Soup"}}) - assert api_response.status_code == 400 - assert api_response.json()["code"] == "parse_error" - assert "1 validation error for QueryRequest" in api_response.json()["detail"] - assert "type=literal_error, input_value='Tomato Soup'" in api_response.json()["detail"] + self.assertEqual(api_response.status_code, 400) + self.assertEqual(api_response.json()["code"], "parse_error") + self.assertIn("1 validation error for QueryRequest", api_response.json()["detail"], api_response.content) + self.assertIn( + "Input tag 'Tomato Soup' found using 'kind' does not match any of the expected tags", + api_response.json()["detail"], + api_response.content, + ) + + def test_missing_query(self): + api_response = self.client.post(f"/api/projects/{self.team.id}/query/", {"query": {}}) + self.assertEqual(api_response.status_code, 400) + + def test_missing_body(self): + api_response = self.client.post(f"/api/projects/{self.team.id}/query/") + self.assertEqual(api_response.status_code, 400) @snapshot_clickhouse_queries def test_full_hogql_query_view(self):