Skip to content

Commit

Permalink
Add post_load delegate to str field (#131)
Browse files Browse the repository at this point in the history
* Add post_load delegate to str field

* Fix imports

* Add tests
  • Loading branch information
Pliner authored Sep 26, 2023
1 parent f5c0421 commit bb6a077
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 15 deletions.
47 changes: 37 additions & 10 deletions marshmallow_recipe/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,15 @@ def str_field(
name: str | None = None,
validate: ValidationFunc | collections.abc.Sequence[ValidationFunc] | None = None,
strip_whitespaces: bool = False,
post_load: collections.abc.Callable[[str], str] | None = None,
**_: Any,
) -> m.fields.Field:
if default is m.missing:
return StrField(
allow_none=allow_none,
validate=validate,
strip_whitespaces=strip_whitespaces,
post_load=post_load,
**default_fields(m.missing),
**data_key_fields(name),
)
Expand All @@ -41,13 +43,15 @@ def str_field(
allow_none=allow_none,
validate=validate,
strip_whitespaces=strip_whitespaces,
post_load=post_load,
**data_key_fields(name),
)

return StrField(
allow_none=allow_none,
validate=validate,
strip_whitespaces=strip_whitespaces,
post_load=post_load,
**default_fields(None if default is dataclasses.MISSING else default),
**data_key_fields(name),
)
Expand Down Expand Up @@ -565,8 +569,14 @@ def default_fields(value: Any) -> dict[str, Any]:
return dict(dump_default=value, load_default=value)

class StrFieldV3(m.fields.Str):
def __init__(self, strip_whitespaces: bool = False, **kwargs: Any):
def __init__(
self,
strip_whitespaces: bool = False,
post_load: collections.abc.Callable[[str], str] | None = None,
**kwargs: Any,
):
super().__init__(**kwargs)
self.post_load = post_load
self.strip_whitespaces = strip_whitespaces

def _serialize(self, value: Any, attr: Any, obj: Any, **kwargs: Any) -> Any:
Expand All @@ -579,10 +589,15 @@ def _serialize(self, value: Any, attr: Any, obj: Any, **kwargs: Any) -> Any:

def _deserialize(self, value: Any, attr: Any, data: Any, **kwargs: Any) -> Any:
result = super()._deserialize(value, attr, data, **kwargs)
if self.strip_whitespaces and result is not None:
result = result.strip()
if self.allow_none and len(result) == 0:
result = None
if result is not None:
if self.strip_whitespaces:
result = result.strip()
if self.allow_none and len(result) == 0:
return None

if self.post_load is not None:
result = self.post_load(result)

return result

StrField = StrFieldV3
Expand Down Expand Up @@ -754,8 +769,14 @@ def default_fields(value: Any) -> dict[str, Any]:
return dict(missing=value, default=value)

class StrFieldV2(m.fields.Str):
def __init__(self, strip_whitespaces: bool = False, **kwargs: Any):
def __init__(
self,
strip_whitespaces: bool = False,
post_load: collections.abc.Callable[[str], str] | None = None,
**kwargs: Any,
):
super().__init__(**kwargs)
self.post_load = post_load
self.strip_whitespaces = strip_whitespaces

def _serialize(self, value: Any, attr: Any, obj: Any, **_: Any) -> Any:
Expand All @@ -768,10 +789,16 @@ def _serialize(self, value: Any, attr: Any, obj: Any, **_: Any) -> Any:

def _deserialize(self, value: Any, attr: Any, data: Any, **_: Any) -> Any:
result = super()._deserialize(value, attr, data)
if self.strip_whitespaces and result is not None:
result = result.strip()
if self.allow_none and len(result) == 0:
result = None

if result is not None:
if self.strip_whitespaces:
result = result.strip()
if self.allow_none and len(result) == 0:
return None

if self.post_load is not None:
result = self.post_load(result)

return result

StrField = StrFieldV2
Expand Down
3 changes: 3 additions & 0 deletions marshmallow_recipe/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def str_metadata(
name: str = MISSING,
validate: ValidationFunc | collections.abc.Sequence[ValidationFunc] | None = None,
strip_whitespaces: bool | None = None,
post_load: collections.abc.Callable[[str], str] | None = None,
) -> Metadata:
values = dict[str, Any]()
if name is not MISSING:
Expand All @@ -55,6 +56,8 @@ def str_metadata(
values.update(validate=validate)
if strip_whitespaces is not None:
values.update(strip_whitespaces=strip_whitespaces)
if post_load is not None:
values.update(post_load=post_load)
return Metadata(values)


Expand Down
22 changes: 17 additions & 5 deletions tests/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -524,9 +524,21 @@ class StrContainer:

@dataclasses.dataclass(frozen=True, slots=True, kw_only=True)
class OptionalStrContainer:
value: Annotated[str | None, mr.str_meta(strip_whitespaces=True)]
value1: Annotated[str | None, mr.str_meta(strip_whitespaces=True)]
value2: Annotated[str | None, mr.str_meta(strip_whitespaces=False)]

assert OptionalStrContainer(value=None) == mr.load(OptionalStrContainer, {"value": ""})
assert OptionalStrContainer(value=None) == mr.load(OptionalStrContainer, {"value": None})
assert mr.dump(OptionalStrContainer(value="")) == {}
assert mr.dump(OptionalStrContainer(value=None)) == {}
assert OptionalStrContainer(value1=None, value2="") == mr.load(OptionalStrContainer, {"value1": "", "value2": ""})
assert OptionalStrContainer(value1=None, value2=None) == mr.load(
OptionalStrContainer, {"value1": None, "value2": None}
)
assert mr.dump(OptionalStrContainer(value1="", value2="")) == {"value2": ""}
assert mr.dump(OptionalStrContainer(value1=None, value2=None)) == {}


def test_str_post_load() -> None:
@dataclasses.dataclass(frozen=True, slots=True, kw_only=True)
class StrContainer:
value: Annotated[str, mr.str_meta(post_load=lambda x: x.replace("-", ""))]

assert StrContainer(value="111111") == mr.load(StrContainer, {"value": "11-11-11"})
assert mr.dump(StrContainer(value="11-11-11")) == {"value": "11-11-11"}

0 comments on commit bb6a077

Please sign in to comment.