Skip to content

Commit

Permalink
chore(data-warehouse): pass context through serializer (#21343)
Browse files Browse the repository at this point in the history
* pass context

* add fallback

* typing

* more fallbacks

* pass through context
  • Loading branch information
EDsCODE authored Apr 4, 2024
1 parent ecb2230 commit 6b9617f
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 12 deletions.
14 changes: 12 additions & 2 deletions posthog/warehouse/api/external_data_schema.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from rest_framework import serializers
from posthog.warehouse.models import ExternalDataSchema
from typing import Optional
from typing import Optional, Dict, Any
from posthog.api.routing import TeamAndOrgViewSetMixin
from rest_framework import viewsets, filters
from rest_framework.exceptions import NotAuthenticated
from posthog.models import User
from posthog.hogql.database.database import create_hogql_database


class ExternalDataSchemaSerializer(serializers.ModelSerializer):
Expand All @@ -18,7 +19,11 @@ class Meta:
def get_table(self, schema: ExternalDataSchema) -> Optional[dict]:
from posthog.warehouse.api.table import SimpleTableSerializer

return SimpleTableSerializer(schema.table).data or None
hogql_context = self.context.get("database", None)
if not hogql_context:
hogql_context = create_hogql_database(team_id=self.context["team_id"])

return SimpleTableSerializer(schema.table, context={"database": hogql_context}).data or None


class SimpleExternalDataSchemaSerializer(serializers.ModelSerializer):
Expand All @@ -35,6 +40,11 @@ class ExternalDataSchemaViewset(TeamAndOrgViewSetMixin, viewsets.ModelViewSet):
search_fields = ["name"]
ordering = "-created_at"

def get_serializer_context(self) -> Dict[str, Any]:
context = super().get_serializer_context()
context["database"] = create_hogql_database(team_id=self.team_id)
return context

def get_queryset(self):
if not isinstance(self.request.user, User) or self.request.user.current_team is None:
raise NotAuthenticated()
Expand Down
10 changes: 8 additions & 2 deletions posthog/warehouse/api/external_data_source.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import uuid
from typing import Any, List, Tuple
from typing import Any, List, Tuple, Dict

import structlog
from rest_framework import filters, serializers, status, viewsets
Expand All @@ -20,6 +20,7 @@
)
from posthog.warehouse.models import ExternalDataSource, ExternalDataSchema, ExternalDataJob
from posthog.warehouse.api.external_data_schema import ExternalDataSchemaSerializer
from posthog.hogql.database.database import create_hogql_database
from posthog.temporal.data_imports.pipelines.schemas import (
PIPELINE_TYPE_SCHEMA_DEFAULT_MAPPING,
)
Expand Down Expand Up @@ -69,7 +70,7 @@ def get_last_run_at(self, instance: ExternalDataSource) -> str:

def get_schemas(self, instance: ExternalDataSource):
schemas = instance.schemas.order_by("name").all()
return ExternalDataSchemaSerializer(schemas, many=True, read_only=True).data
return ExternalDataSchemaSerializer(schemas, many=True, read_only=True, context=self.context).data


class SimpleExternalDataSourceSerializers(serializers.ModelSerializer):
Expand Down Expand Up @@ -97,6 +98,11 @@ class ExternalDataSourceViewSet(TeamAndOrgViewSetMixin, viewsets.ModelViewSet):
search_fields = ["source_id"]
ordering = "-created_at"

def get_serializer_context(self) -> Dict[str, Any]:
context = super().get_serializer_context()
context["database"] = create_hogql_database(team_id=self.team_id)
return context

def get_queryset(self):
if not isinstance(self.request.user, User) or self.request.user.current_team is None:
raise NotAuthenticated()
Expand Down
22 changes: 14 additions & 8 deletions posthog/warehouse/api/table.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
from typing import Any, List
from typing import Any, List, Dict

from rest_framework import filters, request, response, serializers, status, viewsets
from rest_framework.exceptions import NotAuthenticated

from posthog.api.routing import TeamAndOrgViewSetMixin
from posthog.api.shared import UserBasicSerializer
from posthog.hogql.context import HogQLContext
from posthog.hogql.database.database import SerializedField, create_hogql_database, serialize_fields
from posthog.models import User
from posthog.warehouse.models import (
Expand Down Expand Up @@ -55,10 +54,11 @@ class Meta:
read_only_fields = ["id", "created_by", "created_at", "columns", "external_data_source", "external_schema"]

def get_columns(self, table: DataWarehouseTable) -> List[SerializedField]:
team_id = self.context["team_id"]
context = HogQLContext(team_id=team_id, database=create_hogql_database(team_id=team_id))
hogql_context = self.context.get("database", None)
if not hogql_context:
hogql_context = create_hogql_database(team_id=self.context["team_id"])

return serialize_fields(table.hogql_definition().fields, context)
return serialize_fields(table.hogql_definition().fields, hogql_context)

def get_external_schema(self, instance: DataWarehouseTable):
from posthog.warehouse.api.external_data_schema import SimpleExternalDataSchemaSerializer
Expand Down Expand Up @@ -92,10 +92,11 @@ class Meta:
read_only_fields = ["id", "name", "columns"]

def get_columns(self, table: DataWarehouseTable) -> List[SerializedField]:
team_id = table.team.pk
context = HogQLContext(team_id=team_id, database=create_hogql_database(team_id=team_id))
hogql_context = self.context.get("database", None)
if not hogql_context:
hogql_context = create_hogql_database(team_id=self.context["team_id"])

return serialize_fields(table.hogql_definition().fields, context)
return serialize_fields(table.hogql_definition().fields, hogql_context)


class TableViewSet(TeamAndOrgViewSetMixin, viewsets.ModelViewSet):
Expand All @@ -110,6 +111,11 @@ class TableViewSet(TeamAndOrgViewSetMixin, viewsets.ModelViewSet):
search_fields = ["name"]
ordering = "-created_at"

def get_serializer_context(self) -> Dict[str, Any]:
context = super().get_serializer_context()
context["database"] = create_hogql_database(team_id=self.team_id)
return context

def get_queryset(self):
if not isinstance(self.request.user, User) or self.request.user.current_team is None:
raise NotAuthenticated()
Expand Down

0 comments on commit 6b9617f

Please sign in to comment.