Commit

feat: Validate batch export configuration in API

tomasfarias committed Dec 24, 2024
1 parent 9089380 commit fabd5cf
Showing 3 changed files with 217 additions and 0 deletions.
89 changes: 89 additions & 0 deletions posthog/api/test/batch_exports/test_create.py
@@ -1,3 +1,4 @@
import dataclasses
import datetime as dt
import json
from unittest import mock
@@ -13,7 +14,9 @@
from posthog.api.test.test_organization import create_organization
from posthog.api.test.test_team import create_team
from posthog.api.test.test_user import create_user
from posthog.batch_exports.http import is_compatible_with_field
from posthog.batch_exports.models import BatchExport
from posthog.batch_exports.service import BigQueryBatchExportInputs
from posthog.temporal.common.client import sync_connect
from posthog.temporal.common.codec import EncryptionCodec

@@ -396,3 +399,89 @@ def test_create_batch_export_fails_with_invalid_query(client: HttpClient, invali
)

assert response.status_code == status.HTTP_400_BAD_REQUEST, response.json()


@pytest.mark.parametrize(
"input_class,field_name,value,expected",
[
(BigQueryBatchExportInputs, "team_id", 123, True),
(BigQueryBatchExportInputs, "team_id", "a", False),
(BigQueryBatchExportInputs, "project_id", "a_project", True),
(BigQueryBatchExportInputs, "project_id", 1, False),
(BigQueryBatchExportInputs, "exclude_events", ["one", "two"], True),
(BigQueryBatchExportInputs, "exclude_events", [], True),
(BigQueryBatchExportInputs, "exclude_events", None, True),
(BigQueryBatchExportInputs, "exclude_events", ["one", 1], False),
(BigQueryBatchExportInputs, "use_json_type", True, True),
(BigQueryBatchExportInputs, "use_json_type", False, True),
],
)
def test_is_compatible_with_field_with_inputs(input_class, field_name, value, expected):
"""Test compatibility of several input fields."""
field = next(field for field in dataclasses.fields(input_class) if field.name == field_name)
assert is_compatible_with_field(value, field) is expected


@pytest.mark.parametrize(
"destination, invalid_config",
[
# Invalid exclude_events value
(
"S3",
{
"bucket_name": "my-production-s3-bucket",
"region": "us-east-1",
"prefix": "posthog-events/",
"exclude_events": "invalid",
},
),
# Invalid compression value
(
"S3",
{
"bucket_name": "my-production-s3-bucket",
"region": "us-east-1",
"prefix": "posthog-events/",
"compression": "invalid",
},
),
# Invalid file_format value
(
"S3",
{
"bucket_name": "my-production-s3-bucket",
"region": "us-east-1",
"prefix": "posthog-events/",
"file_format": "invalid",
},
),
],
)
def test_create_batch_export_fails_with_invalid_config(client: HttpClient, destination, invalid_config):
"""Test creating a BatchExport should fail with an invalid query."""
temporal = sync_connect()

destination_data = {
"type": destination,
"config": invalid_config,
}

batch_export_data = {
"name": "my-production-s3-bucket-destination",
"destination": destination_data,
"interval": "hour",
}

organization = create_organization("Test Org")
team = create_team(organization)
user = create_user("[email protected]", "Test User", organization)
client.force_login(user)

with start_test_worker(temporal):
response = create_batch_export(
client,
team.pk,
batch_export_data,
)

assert response.status_code == status.HTTP_400_BAD_REQUEST, response.json()
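

For contrast with the invalid configurations above, a payload along the following lines would be expected to pass the new validation. This is a sketch only: the accepted keys are defined by S3BatchExportInputs, and the event name and values shown are illustrative.

valid_destination_data = {
    "type": "S3",
    "config": {
        "bucket_name": "my-production-s3-bucket",
        "region": "us-east-1",
        "prefix": "posthog-events/",
        "exclude_events": ["some-event"],  # must be a list of strings, not a bare string
        "file_format": "JSONLines",  # one of the S3FileFormat values added in this commit
    },
}
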
122 changes: 122 additions & 0 deletions posthog/batch_exports/http.py
@@ -1,4 +1,7 @@
import dataclasses
import datetime as dt
import types
import typing
from typing import Any, TypedDict, cast

import posthoganalytics
Expand All @@ -19,6 +22,7 @@
from posthog.api.utils import action
from posthog.batch_exports.models import BATCH_EXPORT_INTERVALS
from posthog.batch_exports.service import (
DESTINATION_WORKFLOWS,
BatchExportIdError,
BatchExportSchema,
BatchExportServiceError,
@@ -162,12 +166,130 @@ def cancel(self, *args, **kwargs) -> response.Response:
return response.Response({"cancelled": True})


class BatchExportDestinationConfigurationValidator:
"""Validator for a batch export destination configuration."""

def __init__(self, type: str = "type", config: str = "config"):
self.type_field = type
self.config_field = config

def __call__(self, attrs: dict):
"""Validate configuration values are compatible with destination inputs."""
type = attrs[self.type_field]

try:
destination_inputs = DESTINATION_WORKFLOWS[type]
except KeyError:
raise serializers.ValidationError(
f"Unsupported batch export destination: '{type}'", code="unsupported_destination"
)

config = attrs[self.config_field]

for field in dataclasses.fields(destination_inputs):

[Check failure — GitHub Actions / Python code quality checks: Argument 1 to "fields" has incompatible type "tuple[str, type]"; expected "DataclassInstance | type[DataclassInstance]"]
is_required = (
field.default == dataclasses.MISSING
and field.default_factory == dataclasses.MISSING
# These two are required metadata fields but they are added by us.
and field.name not in {"batch_export_id", "team_id"}
)

if is_required and field.name not in config:
raise serializers.ValidationError(
f"The required configuration field '{field.name}' was not found", code="required_field"
)

provided_value = config.get(field.name, None)
is_compatible = is_compatible_with_field(provided_value, field)

if not is_compatible:
raise serializers.ValidationError(
f"An unsupported value ('{provided_value}') was provided for the configuration field '{field.name}' of type '{field.type}'",
code="unsupported_value",
)


def is_compatible_with_field(value: typing.Any, field: dataclasses.Field) -> bool:
"""Check if a value is compatible with a given dataclass field.
Compatibility is defined as the value being of the expected type by the
dataclass field.
Arguments:
value: The value to check
field: The dataclass field we are checking for compatibility
Returns:
True if compatible, False otherwise.
"""
origin = typing.get_origin(field.type)
args = typing.get_args(field.type)

# Unpack an Optional[X] union field to check its underlying type.
if (origin is typing.Union or origin is types.UnionType) and len(args) == 2 and type(None) in args:
if value is None:
# Value not provided for an optional field is compatible
return True

non_none_type = next(t for t in args if t is not type(None))
new_field = dataclasses.field()
new_field.name = field.name
new_field.type = non_none_type
return is_compatible_with_field(value, new_field)

# Unpack collection types to check underlying type of each element.
if origin in {list, dict, set, tuple}:
if not isinstance(value, origin): # type: ignore

[Check failure — GitHub Actions / Python code quality checks: Unused "type: ignore" comment]
return False

if args:
if origin is list or origin is set:
new_field = dataclasses.field()
new_field.name = field.name
new_field.type = args[0]

return all(is_compatible_with_field(v, new_field) for v in value)

elif origin is dict:
new_key_field = dataclasses.field()
new_key_field.name = field.name
new_key_field.type = args[0]

new_value_field = dataclasses.field()
new_value_field.name = field.name
new_value_field.type = args[1]

return all(
is_compatible_with_field(k, new_key_field) and is_compatible_with_field(v, new_value_field)
for k, v in value.items()
)

elif origin is tuple:
if not len(value) == len(args):
return False

for t, v in zip(args, value):
new_field = dataclasses.field()
new_field.name = field.name
new_field.type = t
is_compatible = is_compatible_with_field(v, new_field)

if not is_compatible:
return False

return True

# Check primitive types.
return isinstance(value, field.type) # type: ignore

[Check failure — GitHub Actions / Python code quality checks: Unused "type: ignore" comment]


class BatchExportDestinationSerializer(serializers.ModelSerializer):
"""Serializer for an BatchExportDestination model."""

class Meta:
model = BatchExportDestination
fields = ["type", "config"]
validators = [BatchExportDestinationConfigurationValidator()]

def create(self, validated_data: dict) -> BatchExportDestination:
"""Create a BatchExportDestination."""
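
To make the validator's behaviour concrete, here is a rough sketch of calling it directly, outside DRF's serializer machinery. It assumes bucket_name is a required field of S3BatchExportInputs with no default, which the tests above suggest.

from rest_framework import serializers

validator = BatchExportDestinationConfigurationValidator()

try:
    # Missing the (assumed) required 'bucket_name' key.
    validator({"type": "S3", "config": {"region": "us-east-1", "prefix": "posthog-events/"}})
except serializers.ValidationError as err:
    print(err)  # reports that the required configuration field 'bucket_name' was not found

When used through BatchExportDestinationSerializer, DRF invokes the same callable on the validated attrs, so API clients receive the same message as a 400 response.
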
6 changes: 6 additions & 0 deletions posthog/batch_exports/service.py
@@ -1,4 +1,5 @@
import datetime as dt
import enum
import typing
from dataclasses import asdict, dataclass, fields
from uuid import UUID
@@ -67,6 +68,11 @@ class BatchExportsInputsProtocol(typing.Protocol):
is_backfill: bool = False


class S3FileFormat(enum.StrEnum):
PARQUET = "Parquet"
JSONLINES = "JSONLines"


@dataclass
class S3BatchExportInputs:
"""Inputs for S3 export workflow.
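
The truncated diff does not show how the validator consumes S3FileFormat, but as a standalone sketch, a StrEnum like this lends itself to membership checks on user-supplied strings (requires Python 3.11+ for enum.StrEnum; is_valid_file_format is a hypothetical helper, not part of this commit).

import enum

class S3FileFormat(enum.StrEnum):
    PARQUET = "Parquet"
    JSONLINES = "JSONLines"

def is_valid_file_format(value: str) -> bool:
    # Compare against the enum's string values.
    return value in {member.value for member in S3FileFormat}

assert is_valid_file_format("Parquet")
assert not is_valid_file_format("invalid")
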
