Skip to content

Commit

Permalink
added int enum support (#161)
Browse files Browse the repository at this point in the history
  • Loading branch information
slawwan authored Nov 1, 2024
1 parent 7eb5964 commit 70b7bae
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 35 deletions.
52 changes: 29 additions & 23 deletions marshmallow_recipe/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -581,7 +581,7 @@ def raw_field(


DateTimeField: type[m.fields.DateTime]
EnumField: type[m.fields.String]
EnumField: type[m.fields.Field]
DictField: type[m.fields.Field]
SetField: type[m.fields.List]
FrozenSetField: type[m.fields.List]
Expand Down Expand Up @@ -683,7 +683,7 @@ def _deserialize(self, value: Any, attr: Any, data: Any, **kwargs: Any) -> Any:

TupleField = TupleFieldV3

class EnumFieldV3(m.fields.String):
class EnumFieldV3(m.fields.Field):
default_error = "Not a valid choice: '{input}'. Allowed values: {choices}"

def __init__(
Expand All @@ -705,13 +705,13 @@ def __init__(
)

self.enum_type = enum_type
self._validate_enum(self.enum_type)
enum_value_type = self._extract_enum_value_type(self.enum_type)

self.error = error or EnumFieldV3.default_error
self._validate_error(self.error)

self.choices = [enum_instance.value for enum_instance in enum_type]
self._validate_choices(self.choices)
self._validate_choices(self.choices, enum_value_type)
if allow_none:
self.choices.append(None)

Expand Down Expand Up @@ -746,9 +746,9 @@ def _deserialize(self, value: Any, attr: Any, data: Any, **kwargs: Any) -> Any:
return None
if isinstance(value, self.enum_type):
return value
string_value = super()._deserialize(value, attr, data)
enum_value = super()._deserialize(value, attr, data)
try:
return self.enum_type(string_value)
return self.enum_type(enum_value)
except ValueError:
if self.extendable_default is m.missing:
raise m.ValidationError(self.default_error.format(input=value, choices=self.choices))
Expand All @@ -757,11 +757,14 @@ def _deserialize(self, value: Any, attr: Any, data: Any, **kwargs: Any) -> Any:
raise m.ValidationError(self.default_error.format(input=value, choices=self.choices))

@staticmethod
def _validate_enum(enum_type: Any) -> None:
def _extract_enum_value_type(enum_type: Any) -> type[str | int]:
if not issubclass(enum_type, enum.Enum):
raise ValueError(f"Enum type {enum_type} should be subtype of Enum")
if not issubclass(enum_type, str):
raise ValueError(f"Enum type {enum_type} should be subtype of str")
if issubclass(enum_type, str):
return str
if issubclass(enum_type, int):
return int
raise ValueError(f"Enum type {enum_type} should be subtype of str or int")

@staticmethod
def _validate_error(error: str) -> None:
Expand All @@ -771,10 +774,10 @@ def _validate_error(error: str) -> None:
raise ValueError("Error should contain only {{input}} and {{choices}}'")

@staticmethod
def _validate_choices(choices: list) -> None:
def _validate_choices(choices: list, enum_value_type: type[str | int]) -> None:
for choice in choices:
if not isinstance(choice, str):
raise ValueError(f"There is enum value, which is not a string: {choice}")
if not isinstance(choice, enum_value_type):
raise ValueError(f"There is enum value, which is not isinstance of {enum_value_type}: {choice}")

@staticmethod
def _validate_default(enum_type: Any, default: Any, allow_none: bool) -> None:
Expand Down Expand Up @@ -904,7 +907,7 @@ def _deserialize(self, value: Any, attr: Any, data: Any, **_: Any) -> Any:

TupleField = TupleFieldV2

class EnumFieldV2(m.fields.String):
class EnumFieldV2(m.fields.Field):
default_error = "Not a valid choice: '{input}'. Allowed values: {choices}"

def __init__(
Expand All @@ -926,13 +929,13 @@ def __init__(
)

self.enum_type = enum_type
self._validate_enum(self.enum_type)
enum_value_type = self._extract_enum_value_type(self.enum_type)

self.error = error or EnumFieldV2.default_error
self._validate_error(self.error)

self.choices = [enum_instance.value for enum_instance in enum_type]
self._validate_choices(self.choices)
self._validate_choices(self.choices, enum_value_type)
if allow_none:
self.choices.append(None)

Expand Down Expand Up @@ -967,9 +970,9 @@ def _deserialize(self, value: Any, attr: Any, data: Any, **kwargs: Any) -> Any:
return None
if isinstance(value, self.enum_type):
return value
string_value = super()._deserialize(value, attr, data)
enum_value = super()._deserialize(value, attr, data)
try:
return self.enum_type(string_value)
return self.enum_type(enum_value)
except ValueError:
if self.extendable_default is m.missing:
raise m.ValidationError(self.default_error.format(input=value, choices=self.choices))
Expand All @@ -978,11 +981,14 @@ def _deserialize(self, value: Any, attr: Any, data: Any, **kwargs: Any) -> Any:
raise m.ValidationError(self.default_error.format(input=value, choices=self.choices))

@staticmethod
def _validate_enum(enum_type: Any) -> None:
def _extract_enum_value_type(enum_type: Any) -> type[str | int]:
if not issubclass(enum_type, enum.Enum):
raise ValueError(f"Enum type {enum_type} should be subtype of Enum")
if not issubclass(enum_type, str):
raise ValueError(f"Enum type {enum_type} should be subtype of str")
if issubclass(enum_type, str):
return str
if issubclass(enum_type, int):
return int
raise ValueError(f"Enum type {enum_type} should be subtype of str or int")

@staticmethod
def _validate_error(error: str) -> None:
Expand All @@ -992,10 +998,10 @@ def _validate_error(error: str) -> None:
raise ValueError("Error should contain only {{input}} and {{choices}}'")

@staticmethod
def _validate_choices(choices: list) -> None:
def _validate_choices(choices: list, enum_value_type: type[str | int]) -> None:
for choice in choices:
if not isinstance(choice, str):
raise ValueError(f"There is enum value, which is not a string: {choice}")
if not isinstance(choice, enum_value_type):
raise ValueError(f"There is enum value, which is not isinstance of {enum_value_type}: {choice}")

@staticmethod
def _validate_default(enum_type: Any, default: Any, allow_none: bool) -> None:
Expand Down
71 changes: 59 additions & 12 deletions tests/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,11 @@ class Parity(str, enum.Enum):
EVEN = "even"


class Bit(int, enum.Enum):
Zero = 0
One = 1


def test_simple_types() -> None:
@dataclasses.dataclass(frozen=True, slots=True, kw_only=True)
class SimpleTypesContainers:
Expand Down Expand Up @@ -50,8 +55,10 @@ class SimpleTypesContainers:
optional_frozenset_field: frozenset[str] | None
tuple_field: tuple[str, ...]
optional_tuple_field: tuple[str, ...] | None
enum_field: Parity
optional_enum_field: Parity | None
enum_str_field: Parity
optional_enum_str_field: Parity | None
enum_int_field: Bit
optional_enum_int_field: Bit | None
# with default
str_field_with_default: str = "42"
bool_field_with_default: bool = True
Expand All @@ -64,7 +71,8 @@ class SimpleTypesContainers:
)
time_field_with_default: datetime.time = datetime.time(11, 33, 48)
date_field_with_default: datetime.date = datetime.date(2022, 2, 20)
enum_field_with_default: Parity = Parity.ODD
enum_str_field_with_default: Parity = Parity.ODD
enum_int_field_with_default: Bit = Bit.Zero
# with default factory
str_field_with_default_factory: str = dataclasses.field(default_factory=lambda: "42")
bool_field_with_default_factory: bool = dataclasses.field(default_factory=lambda: True)
Expand All @@ -91,7 +99,8 @@ class SimpleTypesContainers:
set_field_with_default_factory: set[str] = dataclasses.field(default_factory=lambda: set())
frozenset_field_with_default_factory: frozenset[str] = dataclasses.field(default_factory=lambda: frozenset())
tuple_field_with_default_factory: tuple[str, ...] = dataclasses.field(default_factory=lambda: tuple())
enum_field_with_default_factory: Parity = dataclasses.field(default_factory=lambda: Parity.ODD)
enum_str_field_with_default_factory: Parity = dataclasses.field(default_factory=lambda: Parity.ODD)
enum_int_field_with_default_factory: Bit = dataclasses.field(default_factory=lambda: Bit.Zero)

raw = dict(
any_field={},
Expand Down Expand Up @@ -150,10 +159,14 @@ class SimpleTypesContainers:
frozenset_field=["value"],
frozenset_field_with_default_factory=[],
optional_frozenset_field=["value"],
enum_field="odd",
enum_field_with_default="odd",
enum_field_with_default_factory="odd",
optional_enum_field="even",
enum_str_field="odd",
enum_str_field_with_default="odd",
enum_str_field_with_default_factory="odd",
optional_enum_str_field="even",
enum_int_field=0,
enum_int_field_with_default=0,
enum_int_field_with_default_factory=0,
optional_enum_int_field=1,
)

raw_no_defaults = {k: v for k, v in raw.items() if not k.endswith("default") and not k.endswith("default_factory")}
Expand Down Expand Up @@ -198,8 +211,10 @@ class SimpleTypesContainers:
optional_frozenset_field=frozenset({"value"}),
tuple_field=("value",),
optional_tuple_field=("value",),
enum_field=Parity.ODD,
optional_enum_field=Parity.EVEN,
enum_str_field=Parity.ODD,
optional_enum_str_field=Parity.EVEN,
enum_int_field=Bit.Zero,
optional_enum_int_field=Bit.One,
)
)

Expand Down Expand Up @@ -412,7 +427,7 @@ class DateTimeContainer:
(Parity.EVEN, "even"),
],
)
def test_enum_field_dump(value: Parity, raw: str) -> None:
def test_enum_str_field_dump(value: Parity, raw: str) -> None:
@dataclasses.dataclass(frozen=True, slots=True, kw_only=True)
class EnumContainer:
enum_field: Parity
Expand All @@ -428,7 +443,7 @@ class EnumContainer:
("even", Parity.EVEN),
],
)
def test_enum_field_load(value: Parity, raw: str) -> None:
def test_enum_str_field_load(value: Parity, raw: str) -> None:
@dataclasses.dataclass(frozen=True, slots=True, kw_only=True)
class EnumContainer:
enum_field: Parity
Expand All @@ -437,6 +452,38 @@ class EnumContainer:
assert dumped == EnumContainer(enum_field=value)


@pytest.mark.parametrize(
"value, raw",
[
(Bit.Zero, 0),
(Bit.One, 1),
],
)
def test_enum_int_field_dump(value: Bit, raw: str) -> None:
@dataclasses.dataclass(frozen=True, slots=True, kw_only=True)
class EnumContainer:
enum_field: Bit

dumped = mr.dump(EnumContainer(enum_field=value))
assert dumped == dict(enum_field=raw)


@pytest.mark.parametrize(
"raw, value",
[
(0, Bit.Zero),
(1, Bit.One),
],
)
def test_enum_int_field_load(value: Bit, raw: str) -> None:
@dataclasses.dataclass(frozen=True, slots=True, kw_only=True)
class EnumContainer:
enum_field: Bit

dumped = mr.load(EnumContainer, dict(enum_field=raw))
assert dumped == EnumContainer(enum_field=value)


def test_naming_case_in_options() -> None:
@dataclasses.dataclass(frozen=True, slots=True, kw_only=True)
@mr.options(naming_case=mr.CAMEL_CASE)
Expand Down

0 comments on commit 70b7bae

Please sign in to comment.