Skip to content

Commit

Permalink
Remove parser, opt for using model in simpler way
Browse files Browse the repository at this point in the history
  • Loading branch information
webjunkie committed Jan 10, 2024
1 parent d6e866f commit a1ab542
Show file tree
Hide file tree
Showing 5 changed files with 52 additions and 63 deletions.
15 changes: 15 additions & 0 deletions posthog/api/mixins.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from typing import TypeVar, Type, Generic

from pydantic import BaseModel, ValidationError

from rest_framework.exceptions import ParseError

T = TypeVar("T", bound=BaseModel)


class PydanticModelMixin(Generic[T]):
def get_model(self, data: dict, model: Type[T]) -> T:
try:
return model.model_validate(data)
except ValidationError as exc:
raise ParseError("JSON parse error - %s" % str(exc))

Check warning

Code scanning / CodeQL

Information exposure through an exception Medium

Stack trace information
flows to this location and may be exposed to an external user.
35 changes: 0 additions & 35 deletions posthog/api/parsers.py

This file was deleted.

13 changes: 3 additions & 10 deletions posthog/api/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import re
import uuid

from pydantic import BaseModel
from django.http import JsonResponse
from drf_spectacular.utils import OpenApiResponse
from rest_framework import viewsets
Expand All @@ -14,7 +13,7 @@
from sentry_sdk import capture_exception

from posthog.api.documentation import extend_schema
from posthog.api.parsers import PydanticJSONParser
from posthog.api.mixins import PydanticModelMixin
from posthog.api.routing import StructuredViewSetMixin
from posthog.api.services.query import process_query_model
from posthog.clickhouse.client.execute_async import (
Expand Down Expand Up @@ -44,16 +43,12 @@ class QueryThrottle(TeamRateThrottle):
rate = "120/hour"


class QueryViewSet(StructuredViewSetMixin, viewsets.ViewSet):
class QueryViewSet(PydanticModelMixin, StructuredViewSetMixin, viewsets.ViewSet):
permission_classes = [
IsAuthenticated,
ProjectMembershipNecessaryPermissions,
TeamMemberAccessPermission,
]
parser_classes = (PydanticJSONParser,)
pydantic_models = {
"create": QueryRequest,
}

def get_throttles(self):
if self.action == "draft_sql":
Expand All @@ -68,7 +63,7 @@ def get_throttles(self):
},
)
def create(self, request, *args, **kwargs) -> Response:
data: QueryRequest = request.data
data = self.get_model(request.data, QueryRequest)
client_query_id = data.client_query_id or uuid.uuid4().hex

self._tag_client_query_id(client_query_id)
Expand All @@ -84,8 +79,6 @@ def create(self, request, *args, **kwargs) -> Response:

try:
result = process_query_model(self.team, data.query, refresh_requested=data.refresh)
if isinstance(result, BaseModel):
return Response(result.model_dump())
return Response(result)
except (HogQLException, ExposedCHQueryError) as e:
raise ValidationError(str(e), getattr(e, "code_name", None))
Expand Down
32 changes: 18 additions & 14 deletions posthog/api/services/query.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import structlog
from typing import Dict, Optional
from typing import Optional

from pydantic import BaseModel
from rest_framework.exceptions import ValidationError
Expand Down Expand Up @@ -41,11 +41,11 @@

def process_query(
team: Team,
query: QuerySchemaRoot,
query_json: dict,
limit_context: Optional[LimitContext] = None,
refresh_requested: Optional[bool] = False,
):
model = QuerySchemaRoot.model_validate(query)
) -> dict:
model = QuerySchemaRoot.model_validate(query_json)
return process_query_model(
team,
model.root,
Expand All @@ -59,26 +59,26 @@ def process_query_model(
query: "QuerySchemaRoot.root",
limit_context: Optional[LimitContext] = None,
refresh_requested: Optional[bool] = False,
) -> Dict | BaseModel:
) -> dict:
tag_queries(query=query.kind) # TODO: ?

if isinstance(query, QUERY_WITH_RUNNER):
query_runner = get_query_runner(query, team, limit_context=limit_context)
return query_runner.run(refresh_requested=refresh_requested)
if isinstance(query, QUERY_WITH_RUNNER_NO_CACHE):
result = query_runner.run(refresh_requested=refresh_requested)
elif isinstance(query, QUERY_WITH_RUNNER_NO_CACHE):
query_runner = get_query_runner(query, team, limit_context=limit_context)
return query_runner.calculate()
result = query_runner.calculate()
elif isinstance(query, HogQLMetadata):
metadata_query = HogQLMetadata.model_validate(query)
metadata_response = get_hogql_metadata(query=metadata_query, team=team)
return metadata_response
result = metadata_response
elif isinstance(query, DatabaseSchemaQuery):
database = create_hogql_database(team.pk, modifiers=create_default_modifiers_for_team(team))
return serialize_database(database)
result = serialize_database(database)
elif isinstance(query, TimeToSeeDataSessionsQuery):
sessions_query_serializer = SessionsQuerySerializer(data=query)
sessions_query_serializer.is_valid(raise_exception=True)
return {"results": get_sessions(sessions_query_serializer).data}
result = {"results": get_sessions(sessions_query_serializer).data}
elif isinstance(query, TimeToSeeDataQuery):
serializer = SessionEventsQuerySerializer(
data={
Expand All @@ -89,8 +89,12 @@ def process_query_model(
}
)
serializer.is_valid(raise_exception=True)
return get_session_events(serializer) or {}
result = get_session_events(serializer) or {}
elif query.source:
return process_query(team, query.source)
result = process_query(team, query.source)
else:
raise ValidationError(f"Unsupported query kind: {query.kind}")

raise ValidationError(f"Unsupported query kind: {query.kind}")
if isinstance(result, BaseModel):
return result.model_dump()
return result
20 changes: 16 additions & 4 deletions posthog/api/test/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -694,10 +694,22 @@ def test_property_definition_annotation_does_not_break_things(self):

def test_invalid_query_kind(self):
api_response = self.client.post(f"/api/projects/{self.team.id}/query/", {"query": {"kind": "Tomato Soup"}})
assert api_response.status_code == 400
assert api_response.json()["code"] == "parse_error"
assert "1 validation error for QueryRequest" in api_response.json()["detail"]
assert "type=literal_error, input_value='Tomato Soup'" in api_response.json()["detail"]
self.assertEqual(api_response.status_code, 400)
self.assertEqual(api_response.json()["code"], "parse_error")
self.assertIn("1 validation error for QueryRequest", api_response.json()["detail"], api_response.content)
self.assertIn(
"Input tag 'Tomato Soup' found using 'kind' does not match any of the expected tags",
api_response.json()["detail"],
api_response.content,
)

def test_missing_query(self):
api_response = self.client.post(f"/api/projects/{self.team.id}/query/", {"query": {}})
self.assertEqual(api_response.status_code, 400)

def test_missing_body(self):
api_response = self.client.post(f"/api/projects/{self.team.id}/query/")
self.assertEqual(api_response.status_code, 400)

@snapshot_clickhouse_queries
def test_full_hogql_query_view(self):
Expand Down

0 comments on commit a1ab542

Please sign in to comment.