Skip to content

Commit

Permalink
Almost stop using marshmallow defaults (except Nones) (#143)
Browse files Browse the repository at this point in the history
* Fix test with defaults

* Stop using marshmallow default

---------

Co-authored-by: Iurii Pliner <[email protected]>
  • Loading branch information
roman-anna-money and Pliner authored Dec 15, 2023
1 parent 5c3fcdb commit 73eebb2
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 18 deletions.
34 changes: 17 additions & 17 deletions marshmallow_recipe/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def str_field(
validate=validate,
strip_whitespaces=strip_whitespaces,
post_load=post_load,
**default_fields(None if default is dataclasses.MISSING else default),
**(default_fields(None) if default is dataclasses.MISSING else {}),
**data_key_fields(name),
)

Expand Down Expand Up @@ -84,7 +84,7 @@ def bool_field(
return m.fields.Bool(
allow_none=allow_none,
validate=validate,
**default_fields(None if default is dataclasses.MISSING else default),
**(default_fields(None) if default is dataclasses.MISSING else {}),
**data_key_fields(name),
)

Expand Down Expand Up @@ -127,7 +127,7 @@ def decimal_field(
as_string=as_string,
places=places,
validate=validate,
**default_fields(None if default is dataclasses.MISSING else default),
**(default_fields(None) if default is dataclasses.MISSING else {}),
**data_key_fields(name),
)

Expand Down Expand Up @@ -157,7 +157,7 @@ def int_field(
return m.fields.Int(
allow_none=allow_none,
validate=validate,
**default_fields(None if default is dataclasses.MISSING else default),
**(default_fields(None) if default is dataclasses.MISSING else {}),
**data_key_fields(name),
)

Expand Down Expand Up @@ -187,7 +187,7 @@ def float_field(
return m.fields.Float(
allow_none=allow_none,
validate=validate,
**default_fields(None if default is dataclasses.MISSING else default),
**(default_fields(None) if default is dataclasses.MISSING else {}),
**data_key_fields(name),
)

Expand Down Expand Up @@ -217,7 +217,7 @@ def uuid_field(
return m.fields.UUID(
allow_none=allow_none,
validate=validate,
**default_fields(None if default is dataclasses.MISSING else default),
**(default_fields(None) if default is dataclasses.MISSING else {}),
**data_key_fields(name),
)

Expand Down Expand Up @@ -252,7 +252,7 @@ def datetime_field(
allow_none=allow_none,
validate=validate,
format=format,
**default_fields(None if default is dataclasses.MISSING else default),
**(default_fields(None) if default is dataclasses.MISSING else {}),
**data_key_fields(name),
)

Expand Down Expand Up @@ -282,7 +282,7 @@ def time_field(
return m.fields.Time(
allow_none=allow_none,
validate=validate,
**default_fields(None if default is dataclasses.MISSING else default),
**(default_fields(None) if default is dataclasses.MISSING else {}),
**data_key_fields(name),
)

Expand Down Expand Up @@ -312,7 +312,7 @@ def date_field(
return m.fields.Date(
allow_none=allow_none,
validate=validate,
**default_fields(None if default is dataclasses.MISSING else default),
**(default_fields(None) if default is dataclasses.MISSING else {}),
**data_key_fields(name),
)

Expand Down Expand Up @@ -347,7 +347,7 @@ def nested_field(
nested_schema,
allow_none=allow_none,
validate=validate,
**default_fields(None if default is dataclasses.MISSING else default),
**(default_fields(None) if default is dataclasses.MISSING else {}),
**data_key_fields(name),
)

Expand Down Expand Up @@ -380,7 +380,7 @@ def list_field(
field,
allow_none=allow_none,
validate=validate,
**default_fields(None if default is dataclasses.MISSING else default),
**(default_fields(None) if default is dataclasses.MISSING else {}),
**data_key_fields(name),
)

Expand Down Expand Up @@ -413,7 +413,7 @@ def set_field(
field,
allow_none=allow_none,
validate=validate,
**default_fields(None if default is dataclasses.MISSING else default),
**(default_fields(None) if default is dataclasses.MISSING else {}),
**data_key_fields(name),
)

Expand Down Expand Up @@ -446,7 +446,7 @@ def frozen_set_field(
field,
allow_none=allow_none,
validate=validate,
**default_fields(None if default is dataclasses.MISSING else default),
**(default_fields(None) if default is dataclasses.MISSING else {}),
**data_key_fields(name),
)

Expand Down Expand Up @@ -479,7 +479,7 @@ def tuple_field(
field,
allow_none=allow_none,
validate=validate,
**default_fields(None if default is dataclasses.MISSING else default),
**(default_fields(None) if default is dataclasses.MISSING else {}),
**data_key_fields(name),
)

Expand Down Expand Up @@ -522,7 +522,7 @@ def dict_field(
values=values_field,
allow_none=allow_none,
validate=validate,
**default_fields(None if default is dataclasses.MISSING else default),
**(default_fields(None) if default is dataclasses.MISSING else {}),
**data_key_fields(name),
)

Expand Down Expand Up @@ -561,7 +561,7 @@ def enum_field(
enum_type=enum_type,
allow_none=allow_none,
validate=validate,
**default_fields(None if default is dataclasses.MISSING else default),
**(default_fields(None) if default is dataclasses.MISSING else {}),
**data_key_fields(name),
)

Expand All @@ -576,7 +576,7 @@ def raw_field(
return m.fields.Raw(
allow_none=True,
validate=validate,
**default_fields(None if default is dataclasses.MISSING else default),
**(default_fields(None) if default is dataclasses.MISSING else {}),
**data_key_fields(name),
)

Expand Down
14 changes: 13 additions & 1 deletion tests/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ class SimpleTypesContainers:
optional_enum_field="even",
)

raw_no_defaults = {k: v for k, v in raw.items() if not k.endswith("default") or not k.endswith("default_factory")}
raw_no_defaults = {k: v for k, v in raw.items() if not k.endswith("default") and not k.endswith("default_factory")}

loaded = mr.load(SimpleTypesContainers, raw)
loaded_no_defaults = mr.load(SimpleTypesContainers, raw_no_defaults)
Expand Down Expand Up @@ -568,3 +568,15 @@ class StrContainer:

assert StrContainer(value1="111111", value2=None) == mr.load(StrContainer, {"value1": "11-11-11"})
assert mr.dump(StrContainer(value1="11-11-11", value2=None)) == {"value1": "11-11-11"}


def test_nested_default() -> None:
@dataclasses.dataclass(frozen=True, slots=True, kw_only=True)
class IntContainer:
value: int = 42

@dataclasses.dataclass(frozen=True, slots=True, kw_only=True)
class RootContainer:
int_container: IntContainer = dataclasses.field(default_factory=IntContainer)

assert mr.load(RootContainer, {}) == RootContainer()

0 comments on commit 73eebb2

Please sign in to comment.