From 70b7bae30430f311f83d5e5d341732033c2673ce Mon Sep 17 00:00:00 2001 From: Viacheslav Ovchinnikov Date: Fri, 1 Nov 2024 19:43:21 +0500 Subject: [PATCH] added int enum support (#161) --- marshmallow_recipe/fields.py | 52 ++++++++++++++------------ tests/test_serialization.py | 71 ++++++++++++++++++++++++++++++------ 2 files changed, 88 insertions(+), 35 deletions(-) diff --git a/marshmallow_recipe/fields.py b/marshmallow_recipe/fields.py index 582e1f6..e2cc871 100644 --- a/marshmallow_recipe/fields.py +++ b/marshmallow_recipe/fields.py @@ -581,7 +581,7 @@ def raw_field( DateTimeField: type[m.fields.DateTime] -EnumField: type[m.fields.String] +EnumField: type[m.fields.Field] DictField: type[m.fields.Field] SetField: type[m.fields.List] FrozenSetField: type[m.fields.List] @@ -683,7 +683,7 @@ def _deserialize(self, value: Any, attr: Any, data: Any, **kwargs: Any) -> Any: TupleField = TupleFieldV3 - class EnumFieldV3(m.fields.String): + class EnumFieldV3(m.fields.Field): default_error = "Not a valid choice: '{input}'. Allowed values: {choices}" def __init__( @@ -705,13 +705,13 @@ def __init__( ) self.enum_type = enum_type - self._validate_enum(self.enum_type) + enum_value_type = self._extract_enum_value_type(self.enum_type) self.error = error or EnumFieldV3.default_error self._validate_error(self.error) self.choices = [enum_instance.value for enum_instance in enum_type] - self._validate_choices(self.choices) + self._validate_choices(self.choices, enum_value_type) if allow_none: self.choices.append(None) @@ -746,9 +746,9 @@ def _deserialize(self, value: Any, attr: Any, data: Any, **kwargs: Any) -> Any: return None if isinstance(value, self.enum_type): return value - string_value = super()._deserialize(value, attr, data) + enum_value = super()._deserialize(value, attr, data) try: - return self.enum_type(string_value) + return self.enum_type(enum_value) except ValueError: if self.extendable_default is m.missing: raise m.ValidationError(self.default_error.format(input=value, choices=self.choices)) @@ -757,11 +757,14 @@ def _deserialize(self, value: Any, attr: Any, data: Any, **kwargs: Any) -> Any: raise m.ValidationError(self.default_error.format(input=value, choices=self.choices)) @staticmethod - def _validate_enum(enum_type: Any) -> None: + def _extract_enum_value_type(enum_type: Any) -> type[str | int]: if not issubclass(enum_type, enum.Enum): raise ValueError(f"Enum type {enum_type} should be subtype of Enum") - if not issubclass(enum_type, str): - raise ValueError(f"Enum type {enum_type} should be subtype of str") + if issubclass(enum_type, str): + return str + if issubclass(enum_type, int): + return int + raise ValueError(f"Enum type {enum_type} should be subtype of str or int") @staticmethod def _validate_error(error: str) -> None: @@ -771,10 +774,10 @@ def _validate_error(error: str) -> None: raise ValueError("Error should contain only {{input}} and {{choices}}'") @staticmethod - def _validate_choices(choices: list) -> None: + def _validate_choices(choices: list, enum_value_type: type[str | int]) -> None: for choice in choices: - if not isinstance(choice, str): - raise ValueError(f"There is enum value, which is not a string: {choice}") + if not isinstance(choice, enum_value_type): + raise ValueError(f"There is enum value, which is not isinstance of {enum_value_type}: {choice}") @staticmethod def _validate_default(enum_type: Any, default: Any, allow_none: bool) -> None: @@ -904,7 +907,7 @@ def _deserialize(self, value: Any, attr: Any, data: Any, **_: Any) -> Any: TupleField = TupleFieldV2 - class EnumFieldV2(m.fields.String): + class EnumFieldV2(m.fields.Field): default_error = "Not a valid choice: '{input}'. Allowed values: {choices}" def __init__( @@ -926,13 +929,13 @@ def __init__( ) self.enum_type = enum_type - self._validate_enum(self.enum_type) + enum_value_type = self._extract_enum_value_type(self.enum_type) self.error = error or EnumFieldV2.default_error self._validate_error(self.error) self.choices = [enum_instance.value for enum_instance in enum_type] - self._validate_choices(self.choices) + self._validate_choices(self.choices, enum_value_type) if allow_none: self.choices.append(None) @@ -967,9 +970,9 @@ def _deserialize(self, value: Any, attr: Any, data: Any, **kwargs: Any) -> Any: return None if isinstance(value, self.enum_type): return value - string_value = super()._deserialize(value, attr, data) + enum_value = super()._deserialize(value, attr, data) try: - return self.enum_type(string_value) + return self.enum_type(enum_value) except ValueError: if self.extendable_default is m.missing: raise m.ValidationError(self.default_error.format(input=value, choices=self.choices)) @@ -978,11 +981,14 @@ def _deserialize(self, value: Any, attr: Any, data: Any, **kwargs: Any) -> Any: raise m.ValidationError(self.default_error.format(input=value, choices=self.choices)) @staticmethod - def _validate_enum(enum_type: Any) -> None: + def _extract_enum_value_type(enum_type: Any) -> type[str | int]: if not issubclass(enum_type, enum.Enum): raise ValueError(f"Enum type {enum_type} should be subtype of Enum") - if not issubclass(enum_type, str): - raise ValueError(f"Enum type {enum_type} should be subtype of str") + if issubclass(enum_type, str): + return str + if issubclass(enum_type, int): + return int + raise ValueError(f"Enum type {enum_type} should be subtype of str or int") @staticmethod def _validate_error(error: str) -> None: @@ -992,10 +998,10 @@ def _validate_error(error: str) -> None: raise ValueError("Error should contain only {{input}} and {{choices}}'") @staticmethod - def _validate_choices(choices: list) -> None: + def _validate_choices(choices: list, enum_value_type: type[str | int]) -> None: for choice in choices: - if not isinstance(choice, str): - raise ValueError(f"There is enum value, which is not a string: {choice}") + if not isinstance(choice, enum_value_type): + raise ValueError(f"There is enum value, which is not isinstance of {enum_value_type}: {choice}") @staticmethod def _validate_default(enum_type: Any, default: Any, allow_none: bool) -> None: diff --git a/tests/test_serialization.py b/tests/test_serialization.py index 63082a1..3029cdf 100644 --- a/tests/test_serialization.py +++ b/tests/test_serialization.py @@ -15,6 +15,11 @@ class Parity(str, enum.Enum): EVEN = "even" +class Bit(int, enum.Enum): + Zero = 0 + One = 1 + + def test_simple_types() -> None: @dataclasses.dataclass(frozen=True, slots=True, kw_only=True) class SimpleTypesContainers: @@ -50,8 +55,10 @@ class SimpleTypesContainers: optional_frozenset_field: frozenset[str] | None tuple_field: tuple[str, ...] optional_tuple_field: tuple[str, ...] | None - enum_field: Parity - optional_enum_field: Parity | None + enum_str_field: Parity + optional_enum_str_field: Parity | None + enum_int_field: Bit + optional_enum_int_field: Bit | None # with default str_field_with_default: str = "42" bool_field_with_default: bool = True @@ -64,7 +71,8 @@ class SimpleTypesContainers: ) time_field_with_default: datetime.time = datetime.time(11, 33, 48) date_field_with_default: datetime.date = datetime.date(2022, 2, 20) - enum_field_with_default: Parity = Parity.ODD + enum_str_field_with_default: Parity = Parity.ODD + enum_int_field_with_default: Bit = Bit.Zero # with default factory str_field_with_default_factory: str = dataclasses.field(default_factory=lambda: "42") bool_field_with_default_factory: bool = dataclasses.field(default_factory=lambda: True) @@ -91,7 +99,8 @@ class SimpleTypesContainers: set_field_with_default_factory: set[str] = dataclasses.field(default_factory=lambda: set()) frozenset_field_with_default_factory: frozenset[str] = dataclasses.field(default_factory=lambda: frozenset()) tuple_field_with_default_factory: tuple[str, ...] = dataclasses.field(default_factory=lambda: tuple()) - enum_field_with_default_factory: Parity = dataclasses.field(default_factory=lambda: Parity.ODD) + enum_str_field_with_default_factory: Parity = dataclasses.field(default_factory=lambda: Parity.ODD) + enum_int_field_with_default_factory: Bit = dataclasses.field(default_factory=lambda: Bit.Zero) raw = dict( any_field={}, @@ -150,10 +159,14 @@ class SimpleTypesContainers: frozenset_field=["value"], frozenset_field_with_default_factory=[], optional_frozenset_field=["value"], - enum_field="odd", - enum_field_with_default="odd", - enum_field_with_default_factory="odd", - optional_enum_field="even", + enum_str_field="odd", + enum_str_field_with_default="odd", + enum_str_field_with_default_factory="odd", + optional_enum_str_field="even", + enum_int_field=0, + enum_int_field_with_default=0, + enum_int_field_with_default_factory=0, + optional_enum_int_field=1, ) raw_no_defaults = {k: v for k, v in raw.items() if not k.endswith("default") and not k.endswith("default_factory")} @@ -198,8 +211,10 @@ class SimpleTypesContainers: optional_frozenset_field=frozenset({"value"}), tuple_field=("value",), optional_tuple_field=("value",), - enum_field=Parity.ODD, - optional_enum_field=Parity.EVEN, + enum_str_field=Parity.ODD, + optional_enum_str_field=Parity.EVEN, + enum_int_field=Bit.Zero, + optional_enum_int_field=Bit.One, ) ) @@ -412,7 +427,7 @@ class DateTimeContainer: (Parity.EVEN, "even"), ], ) -def test_enum_field_dump(value: Parity, raw: str) -> None: +def test_enum_str_field_dump(value: Parity, raw: str) -> None: @dataclasses.dataclass(frozen=True, slots=True, kw_only=True) class EnumContainer: enum_field: Parity @@ -428,7 +443,7 @@ class EnumContainer: ("even", Parity.EVEN), ], ) -def test_enum_field_load(value: Parity, raw: str) -> None: +def test_enum_str_field_load(value: Parity, raw: str) -> None: @dataclasses.dataclass(frozen=True, slots=True, kw_only=True) class EnumContainer: enum_field: Parity @@ -437,6 +452,38 @@ class EnumContainer: assert dumped == EnumContainer(enum_field=value) +@pytest.mark.parametrize( + "value, raw", + [ + (Bit.Zero, 0), + (Bit.One, 1), + ], +) +def test_enum_int_field_dump(value: Bit, raw: str) -> None: + @dataclasses.dataclass(frozen=True, slots=True, kw_only=True) + class EnumContainer: + enum_field: Bit + + dumped = mr.dump(EnumContainer(enum_field=value)) + assert dumped == dict(enum_field=raw) + + +@pytest.mark.parametrize( + "raw, value", + [ + (0, Bit.Zero), + (1, Bit.One), + ], +) +def test_enum_int_field_load(value: Bit, raw: str) -> None: + @dataclasses.dataclass(frozen=True, slots=True, kw_only=True) + class EnumContainer: + enum_field: Bit + + dumped = mr.load(EnumContainer, dict(enum_field=raw)) + assert dumped == EnumContainer(enum_field=value) + + def test_naming_case_in_options() -> None: @dataclasses.dataclass(frozen=True, slots=True, kw_only=True) @mr.options(naming_case=mr.CAMEL_CASE)