From ada66a75d3ec8dd2d0ca5c40646aeb7eeb2c7744 Mon Sep 17 00:00:00 2001 From: Jim Crist-Harif Date: Sun, 20 Oct 2024 14:30:19 -0500 Subject: [PATCH] Don't encode dataclass classes Previously our "is this a dataclass-like thing" check would erroneously pass for dataclass _classes_ as well as their instances. We now properly exclude classes and only encode instances. --- msgspec/_core.c | 6 +++--- tests/test_common.py | 8 ++++++++ tests/test_to_builtins.py | 8 ++++++++ 3 files changed, 19 insertions(+), 3 deletions(-) diff --git a/msgspec/_core.c b/msgspec/_core.c index 1b349c8b..a32df857 100644 --- a/msgspec/_core.c +++ b/msgspec/_core.c @@ -13137,7 +13137,7 @@ mpack_encode_uncommon(EncoderState *self, PyTypeObject *type, PyObject *obj) else if (PyAnySet_Check(obj)) { return mpack_encode_set(self, obj); } - else if (type->tp_dict != NULL) { + else if (!PyType_Check(obj) && type->tp_dict != NULL) { PyObject *fields = PyObject_GetAttr(obj, self->mod->str___dataclass_fields__); if (fields != NULL) { int status = mpack_encode_dataclass(self, obj, fields); @@ -14231,7 +14231,7 @@ json_encode_uncommon(EncoderState *self, PyTypeObject *type, PyObject *obj) { else if (PyAnySet_Check(obj)) { return json_encode_set(self, obj); } - else if (type->tp_dict != NULL) { + else if (!PyType_Check(obj) && type->tp_dict != NULL) { PyObject *fields = PyObject_GetAttr(obj, self->mod->str___dataclass_fields__); if (fields != NULL) { int status = json_encode_dataclass(self, obj, fields); @@ -19924,7 +19924,7 @@ to_builtins(ToBuiltinsState *self, PyObject *obj, bool is_key) { else if (PyAnySet_Check(obj)) { return to_builtins_set(self, obj, is_key); } - else if (type->tp_dict != NULL) { + else if (!PyType_Check(obj) && type->tp_dict != NULL) { PyObject *fields = PyObject_GetAttr(obj, self->mod->str___dataclass_fields__); if (fields != NULL) { PyObject *out = to_builtins_dataclass(self, obj, fields); diff --git a/tests/test_common.py b/tests/test_common.py index c5bd92fa..2ef8aaf2 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -2490,6 +2490,14 @@ class Ex: with pytest.raises(RuntimeError, match="is not a dict"): proto.encode(Ex(1)) + def test_encode_dataclass_class_errors(self, proto): + @dataclass + class Ex: + x: int + + with pytest.raises(TypeError, match="Encoding objects of type type"): + proto.encode(Ex) + def test_encode_dataclass_no_slots(self, proto): @dataclass class Test: diff --git a/tests/test_to_builtins.py b/tests/test_to_builtins.py index 0f726e8b..56da76e9 100644 --- a/tests/test_to_builtins.py +++ b/tests/test_to_builtins.py @@ -428,6 +428,14 @@ class Ex: with pytest.raises(TypeError, match="Encoding objects of type Bad"): to_builtins(msg) + def test_dataclass_class_errors(self): + @dataclass + class Ex: + x: int + + with pytest.raises(TypeError, match="Encoding objects of type type"): + to_builtins(Ex) + @pytest.mark.parametrize("slots", [True, False]) def test_attrs(self, slots): attrs = pytest.importorskip("attrs")