Skip to content

Commit

Permalink
Don't encode dataclass classes
Browse files Browse the repository at this point in the history
Previously our "is this a dataclass-like thing" check would erroneously
pass for dataclass _classes_ as well as their instances. We now properly
exclude classes and only encode instances.
  • Loading branch information
jcrist committed Oct 20, 2024
1 parent 74ebfe0 commit ada66a7
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 3 deletions.
6 changes: 3 additions & 3 deletions msgspec/_core.c
Original file line number Diff line number Diff line change
Expand Up @@ -13137,7 +13137,7 @@ mpack_encode_uncommon(EncoderState *self, PyTypeObject *type, PyObject *obj)
else if (PyAnySet_Check(obj)) {
return mpack_encode_set(self, obj);
}
else if (type->tp_dict != NULL) {
else if (!PyType_Check(obj) && type->tp_dict != NULL) {
PyObject *fields = PyObject_GetAttr(obj, self->mod->str___dataclass_fields__);
if (fields != NULL) {
int status = mpack_encode_dataclass(self, obj, fields);
Expand Down Expand Up @@ -14231,7 +14231,7 @@ json_encode_uncommon(EncoderState *self, PyTypeObject *type, PyObject *obj) {
else if (PyAnySet_Check(obj)) {
return json_encode_set(self, obj);
}
else if (type->tp_dict != NULL) {
else if (!PyType_Check(obj) && type->tp_dict != NULL) {
PyObject *fields = PyObject_GetAttr(obj, self->mod->str___dataclass_fields__);
if (fields != NULL) {
int status = json_encode_dataclass(self, obj, fields);
Expand Down Expand Up @@ -19924,7 +19924,7 @@ to_builtins(ToBuiltinsState *self, PyObject *obj, bool is_key) {
else if (PyAnySet_Check(obj)) {
return to_builtins_set(self, obj, is_key);
}
else if (type->tp_dict != NULL) {
else if (!PyType_Check(obj) && type->tp_dict != NULL) {
PyObject *fields = PyObject_GetAttr(obj, self->mod->str___dataclass_fields__);
if (fields != NULL) {
PyObject *out = to_builtins_dataclass(self, obj, fields);
Expand Down
8 changes: 8 additions & 0 deletions tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -2490,6 +2490,14 @@ class Ex:
with pytest.raises(RuntimeError, match="is not a dict"):
proto.encode(Ex(1))

def test_encode_dataclass_class_errors(self, proto):
@dataclass
class Ex:
x: int

with pytest.raises(TypeError, match="Encoding objects of type type"):
proto.encode(Ex)

def test_encode_dataclass_no_slots(self, proto):
@dataclass
class Test:
Expand Down
8 changes: 8 additions & 0 deletions tests/test_to_builtins.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,6 +428,14 @@ class Ex:
with pytest.raises(TypeError, match="Encoding objects of type Bad"):
to_builtins(msg)

def test_dataclass_class_errors(self):
@dataclass
class Ex:
x: int

with pytest.raises(TypeError, match="Encoding objects of type type"):
to_builtins(Ex)

@pytest.mark.parametrize("slots", [True, False])
def test_attrs(self, slots):
attrs = pytest.importorskip("attrs")
Expand Down

0 comments on commit ada66a7

Please sign in to comment.