Skip to content

Commit

Permalink
Fix bug handling default values for frozen & slotted dataclasses
Browse files Browse the repository at this point in the history
Previously there was a bug where dataclasses or attrs classes with
`slots=True, frozen=True` wouldn't be able to properly use default
values for optional fields.
  • Loading branch information
jcrist committed Oct 17, 2023
1 parent 260dfbc commit 9f5f50b
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 11 deletions.
2 changes: 1 addition & 1 deletion msgspec/_core.c
Original file line number Diff line number Diff line change
Expand Up @@ -8049,7 +8049,7 @@ DataclassInfo_post_decode(DataclassInfo *self, PyObject *obj, PathNode *path) {
default_value = CALL_NO_ARGS(default_value);
if (default_value == NULL) return -1;
}
int status = PyObject_SetAttr(obj, name, default_value);
int status = PyObject_GenericSetAttr(obj, name, default_value);
if (is_factory) {
Py_DECREF(default_value);
}
Expand Down
15 changes: 9 additions & 6 deletions tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -2694,16 +2694,17 @@ class Example:
):
dec.decode(proto.encode({"a": "bad"}))

@pytest.mark.parametrize("frozen", [False, True])
@pytest.mark.parametrize("slots", [False, True])
def test_decode_dataclass_defaults(self, proto, slots):
def test_decode_dataclass_defaults(self, proto, frozen, slots):
if slots:
if not PY310:
pytest.skip(reason="Python 3.10+ required")
kws = {"slots": True}
else:
kws = {}

@dataclass(**kws)
@dataclass(frozen=frozen, **kws)
class Example:
a: int
b: int
Expand Down Expand Up @@ -2904,9 +2905,10 @@ class Example:
):
dec.decode(proto.encode({"a": "bad"}))

@pytest.mark.parametrize("frozen", [False, True])
@pytest.mark.parametrize("slots", [False, True])
def test_decode_attrs_defaults(self, proto, slots):
@attrs.define(slots=slots)
def test_decode_attrs_defaults(self, proto, frozen, slots):
@attrs.define(frozen=frozen, slots=slots)
class Example:
a: int
b: int
Expand All @@ -2916,9 +2918,10 @@ class Example:

dec = proto.Decoder(Example)
for args in [(1, 2), (1, 2, 3), (1, 2, 3, 4), (1, 2, 3, 4, 5)]:
msg = Example(*args)
sol = Example(*args)
msg = dict(zip("abcde", args))
res = dec.decode(proto.encode(msg))
assert res == msg
assert res == sol

# Missing fields error
with pytest.raises(ValidationError, match="missing required field `a`"):
Expand Down
10 changes: 6 additions & 4 deletions tests/test_convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -1388,17 +1388,18 @@ class Ex:

assert convert(msg, Ex, from_attributes=True) == Ex(1)

@pytest.mark.parametrize("frozen", [False, True])
@pytest.mark.parametrize("slots", [False, True])
@mapcls_and_from_attributes
def test_dataclass_defaults(self, slots, mapcls, from_attributes):
def test_dataclass_defaults(self, frozen, slots, mapcls, from_attributes):
if slots:
if not PY310:
pytest.skip(reason="Python 3.10+ required")
kws = {"slots": True}
else:
kws = {}

@dataclass(**kws)
@dataclass(frozen=frozen, **kws)
class Example:
a: int
b: int
Expand Down Expand Up @@ -1557,10 +1558,11 @@ class Ex:

assert convert(msg, Ex, from_attributes=True) == Ex(1)

@pytest.mark.parametrize("frozen", [False, True])
@pytest.mark.parametrize("slots", [False, True])
@mapcls_and_from_attributes
def test_attrs_defaults(self, slots, mapcls, from_attributes):
@attrs.define(slots=slots)
def test_attrs_defaults(self, frozen, slots, mapcls, from_attributes):
@attrs.define(frozen=frozen, slots=slots)
class Example:
a: int
b: int
Expand Down

0 comments on commit 9f5f50b

Please sign in to comment.