Skip to content

Commit

Permalink
Call __post_init__ when converting struct from object
Browse files Browse the repository at this point in the history
Previously a struct's `__post_init__` method wouldn't be called when
converting from a custom mapping or object with `from_attributes=True`.
This PR fixes that and expands the tests to cover this case.
  • Loading branch information
jcrist committed Oct 20, 2024
1 parent 3363eea commit 150f64a
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 19 deletions.
3 changes: 3 additions & 0 deletions msgspec/_core.c
Original file line number Diff line number Diff line change
Expand Up @@ -21364,6 +21364,9 @@ convert_object_to_struct(
should_untrack = !MS_MAYBE_TRACKED(val);
}
}

if (Struct_decode_post_init(struct_type, out, path) < 0) goto error;

Py_LeaveRecursiveCall();
if (is_gc && !should_untrack)
PyObject_GC_Track(out);
Expand Down
35 changes: 16 additions & 19 deletions tests/test_convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -2239,18 +2239,18 @@ class Test2(Struct, Generic[T], tag=True, array_like=array_like):


class TestStructPostInit:
@pytest.mark.parametrize("array_like", [False, True])
@pytest.mark.parametrize("union", [False, True])
def test_struct_post_init(self, array_like, union):
count = 0
@mapcls_from_attributes_and_array_like
def test_struct_post_init(self, union, mapcls, from_attributes, array_like):
called = False
singleton = object()

class Ex(Struct, array_like=array_like, tag=union):
x: int

def __post_init__(self):
nonlocal count
count += 1
nonlocal called
called = True
return singleton

if union:
Expand All @@ -2262,25 +2262,23 @@ class Ex2(Struct, array_like=array_like, tag=True):
else:
typ = Ex

msg = Ex(1)
buf = to_builtins(msg)
res = convert(buf, type=typ)
assert res == msg
assert count == 2 # 1 for Ex(), 1 for decode
msg = mapcls(type="Ex", x=1) if union else mapcls(x=1)
res = convert(msg, type=typ, from_attributes=from_attributes)
assert type(res) is Ex
assert called
assert sys.getrefcount(singleton) == 2 # 1 for ref, 1 for call

@pytest.mark.parametrize("array_like", [False, True])
@pytest.mark.parametrize("union", [False, True])
@pytest.mark.parametrize("exc_class", [ValueError, TypeError, OSError])
def test_struct_post_init_errors(self, array_like, union, exc_class):
error = False

@mapcls_from_attributes_and_array_like
def test_struct_post_init_errors(
self, union, exc_class, mapcls, from_attributes, array_like
):
class Ex(Struct, array_like=array_like, tag=union):
x: int

def __post_init__(self):
if error:
raise exc_class("Oh no!")
raise exc_class("Oh no!")

if union:

Expand All @@ -2291,16 +2289,15 @@ class Ex2(Struct, array_like=array_like, tag=True):
else:
typ = Ex

msg = to_builtins([Ex(1)])
error = True
msg = [mapcls(type="Ex", x=1) if union else mapcls(x=1)]

if exc_class in (ValueError, TypeError):
expected = ValidationError
else:
expected = exc_class

with pytest.raises(expected, match="Oh no!") as rec:
convert(msg, type=List[typ])
convert(msg, type=List[typ], from_attributes=from_attributes)

if expected is ValidationError:
assert "- at `$[0]`" in str(rec.value)
Expand Down

0 comments on commit 150f64a

Please sign in to comment.