From c6cbed1f0054a708e79006c8edbca6ed66a18651 Mon Sep 17 00:00:00 2001 From: Manuel Holtgrewe Date: Thu, 21 Sep 2023 16:29:24 +0200 Subject: [PATCH] Add Union type support where inner types are scalars --- src/drf_pydantic/fields.py | 48 ++++++++++++++++++++++++++++++++++++-- src/drf_pydantic/parse.py | 15 ++++++++---- tests/test_models.py | 33 ++++++++++++++++++++++++++ 3 files changed, 89 insertions(+), 7 deletions(-) diff --git a/src/drf_pydantic/fields.py b/src/drf_pydantic/fields.py index 809bfe4..e8d8156 100644 --- a/src/drf_pydantic/fields.py +++ b/src/drf_pydantic/fields.py @@ -1,8 +1,9 @@ from enum import Enum -from typing import Type, Optional, Union +from types import NoneType +from typing import Any, List, Type, Optional, Union from rest_framework.fields import empty -from rest_framework.serializers import ChoiceField +from rest_framework.serializers import ChoiceField, Field class EnumField(ChoiceField): @@ -43,3 +44,46 @@ def to_representation(self, value: Optional[Union[Enum, str]]) -> Optional[str]: return value.value return value + + +#: Define shortcut for scalar types +ScalarTypes = int | float | str | bool | NoneType + + +class UnionField(Field): + """ + Custom DRF field that supports union fields of scalar values. + """ + + default_error_messages = {"invalid": "No match in type union"} + + #: The allowed types + types: List[type] + + def __init__(self, + types: List[type], + **kwargs): + super().__init__(**kwargs) + self._check_all_types_scalar(types) + self.types = types + self.allow_null = NoneType in types + + def _check_all_types_scalar(self, types: List[type]): + for type_ in types: + if not any(type_ is t for t in (int, float, str, bool, NoneType)): + raise ValueError(f"UnionField only supports scalar types but found: {type_}") + + def run_validation(self, data: Any) -> ScalarTypes: + if type(data) not in self.types: + self.fail("invalid") + return super().run_validation(data) + + def to_internal_value(self, data: Any) -> ScalarTypes: + if type(data) in self.types: + return data + self.fail("invalid") + + def to_representation(self, data: Any) -> ScalarTypes: + if type(data) in self.types: + return data + self.fail("invalid") diff --git a/src/drf_pydantic/parse.py b/src/drf_pydantic/parse.py index f6281fa..7c044e1 100644 --- a/src/drf_pydantic/parse.py +++ b/src/drf_pydantic/parse.py @@ -10,7 +10,7 @@ import pydantic from rest_framework import serializers -from drf_pydantic.fields import EnumField +from drf_pydantic.fields import EnumField, UnionField # Cache serializer classes to ensure that there is a one-to-one relationship # between pydantic models and serializer classes @@ -138,11 +138,16 @@ def _convert_field(field: pydantic.fields.ModelField) -> serializers.Field: return _convert_type(field.type_)(**extra_kwargs) # Alias - if field.type_.__origin__ is typing.Literal: + if typing.get_origin(field.type_) is typing.Literal: choices = field.type_.__args__ assert all(isinstance(choice, str) for choice in choices) return serializers.ChoiceField(choices=choices, **extra_kwargs) - raise NotImplementedError(f"{field.type_.__name__} is not yet supported") + + # Union types (only supported for scalar options) + if type(field.type_) is types.UnionType: + return UnionField(types=typing.get_args(field.type_)) + + raise NotImplementedError(f"'{repr(field.type_)}' is not yet supported") # Container field assert isinstance( @@ -152,10 +157,10 @@ def _convert_field(field: pydantic.fields.ModelField) -> serializers.Field: getattr(typing, "_GenericAlias"), ), ), f"Unsupported container type '{field.outer_type_.__name__}'" - if field.outer_type_.__origin__ is list or field.outer_type_.__origin__ is tuple: + if typing.get_origin(field.outer_type_) is list or typing.get_origin(field.outer_type_) is tuple: return serializers.ListField(child=_convert_type(field.type_)(**extra_kwargs)) raise NotImplementedError( - f"Container type '{field.outer_type_.__origin__.__name__}' is not yet supported" + f"Container type '{typing.getorigin(field.outer_type_).__name__}' is not yet supported" ) diff --git a/tests/test_models.py b/tests/test_models.py index 9ef34d2..216944e 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -381,3 +381,36 @@ class Human(BaseModel): bad_value_serializer = serializer(data={'sex': 'bad_value', 'age': 25}) assert bad_value_serializer.is_valid() is False + + +def test_union_value(): + + class ModelWithUnions(BaseModel): + number: int | float + number_opt: int | float | None + + serializer = ModelWithUnions.drf_serializer + + normal_serializer = serializer(data={'number': 10, 'number_opt': None}) + + assert normal_serializer.is_valid() + assert normal_serializer.validated_data['number'] == 10 + assert normal_serializer.validated_data['number_opt'] == None + + value_serializer = serializer(data={'number': 10, 'number_opt': None}) + + assert value_serializer.is_valid() + assert value_serializer.validated_data['number'] == 10 + assert value_serializer.validated_data['number_opt'] == None + + bad_value_serializer = serializer(data={'number': None, 'number_opt': 'foo'}) + + assert bad_value_serializer.is_valid() is False + + +def test_union_value_failure(): + """test case of non-scalar field definition""" + + with pytest.raises(ValueError): + class ModelWithUnions(BaseModel): + bad: int | float | list[int]