diff --git a/marshmallow_recipe/bake.py b/marshmallow_recipe/bake.py index 028f47c..776ce2c 100644 --- a/marshmallow_recipe/bake.py +++ b/marshmallow_recipe/bake.py @@ -55,12 +55,12 @@ class _FieldDescription(NamedTuple): def bake_schema( - t: type, + cls: type, *, naming_case: NamingCase | None = None, none_value_handling: NoneValueHandling | None = None, ) -> type[m.Schema]: - origin: type = get_origin(t) or t + origin: type = get_origin(cls) or cls if not dataclasses.is_dataclass(origin): raise ValueError(f"{origin} is not a dataclass") @@ -72,14 +72,14 @@ def bake_schema( cls_naming_case = naming_case key = _SchemaTypeKey( - cls=t, + cls=cls, naming_case=cls_naming_case, none_value_handling=cls_none_value_handling, ) if result := _schema_types.get(key): return result - fields_type_map = get_fields_type_map(t) + fields_type_map = get_fields_type_map(cls) fields = [ _FieldDescription( @@ -104,8 +104,8 @@ def bake_schema( raise ValueError(f"Invalid name={second_name} in metadata for field={second.field.name}") schema_type = type( - t.__name__, - (_get_base_schema(t, cls_none_value_handling or NoneValueHandling.IGNORE),), + cls.__name__, + (_get_base_schema(cls, cls_none_value_handling or NoneValueHandling.IGNORE),), {"__module__": f"{__package__}.auto_generated"} | { field.name: get_field_for( diff --git a/marshmallow_recipe/serialization.py b/marshmallow_recipe/serialization.py index 740c59e..e276fac 100644 --- a/marshmallow_recipe/serialization.py +++ b/marshmallow_recipe/serialization.py @@ -37,12 +37,14 @@ def __call__( class DumpFunction(Protocol): def __call__( - self, data: Any, *, naming_case: NamingCase | None = None, t: type | None = None + self, data: Any, *, naming_case: NamingCase | None = None, cls: type | None = None ) -> dict[str, Any]: ... class DumpManyFunction(Protocol): - def __call__(self, data: list[Any], *, naming_case: NamingCase | None = None) -> list[dict[str, Any]]: ... + def __call__( + self, data: list[Any], *, naming_case: NamingCase | None = None, cls: type | None = None + ) -> list[dict[str, Any]]: ... schema: SchemaFunction @@ -75,13 +77,8 @@ def load_many_v3(cls: type[_T], data: list[dict[str, Any]], *, naming_case: Nami load_many = load_many_v3 - def dump_v3( - data: Any, - *, - naming_case: NamingCase | None = None, - t: type | None = None, - ) -> dict[str, Any]: - data_schema = schema_v3(t or type(data), naming_case=naming_case) + def dump_v3(data: Any, *, naming_case: NamingCase | None = None, cls: type | None = None) -> dict[str, Any]: + data_schema = schema_v3(_extract_type(data, cls), naming_case=naming_case) dumped: dict[str, Any] = data_schema.dump(data) # type: ignore if errors := data_schema.validate(dumped): raise m.ValidationError(errors) @@ -89,10 +86,12 @@ def dump_v3( dump = dump_v3 - def dump_many_v3(data: list[Any], *, naming_case: NamingCase | None = None) -> list[dict[str, Any]]: + def dump_many_v3( + data: list[Any], *, naming_case: NamingCase | None = None, cls: type | None = None + ) -> list[dict[str, Any]]: if not data: return [] - data_schema = schema_v3(type(data[0]), many=True, naming_case=naming_case) + data_schema = schema_v3(_extract_type(data[0], cls), many=True, naming_case=naming_case) dumped: list[dict[str, Any]] = data_schema.dump(data) # type: ignore if errors := data_schema.validate(dumped): raise m.ValidationError(errors) @@ -125,13 +124,8 @@ def load_many_v2(cls: type[_T], data: list[dict[str, Any]], *, naming_case: Nami load_many = load_many_v2 - def dump_v2( - data: Any, - *, - naming_case: NamingCase | None = None, - t: type | None = None, - ) -> dict[str, Any]: - data_schema = schema_v2(t or type(data), naming_case=naming_case) + def dump_v2(data: Any, *, naming_case: NamingCase | None = None, cls: type | None = None) -> dict[str, Any]: + data_schema = schema_v2(_extract_type(data, cls), naming_case=naming_case) dumped, errors = data_schema.dump(data) if errors: raise m.ValidationError(errors) @@ -141,10 +135,12 @@ def dump_v2( dump = dump_v2 - def dump_many_v2(data: list[Any], *, naming_case: NamingCase | None = None) -> list[dict[str, Any]]: + def dump_many_v2( + data: list[Any], *, naming_case: NamingCase | None = None, cls: type | None = None + ) -> list[dict[str, Any]]: if not data: return [] - data_schema = schema_v2(type(data[0]), many=True, naming_case=naming_case) + data_schema = schema_v2(_extract_type(data[0], cls), many=True, naming_case=naming_case) dumped, errors = data_schema.dump(data) if errors: raise m.ValidationError(errors) @@ -155,3 +151,11 @@ def dump_many_v2(data: list[Any], *, naming_case: NamingCase | None = None) -> l dump_many = dump_many_v2 EmptySchema = m.Schema + + +def _extract_type(data: Any, cls: type | None) -> type: + if cls: + return cls + if hasattr(data, "__orig_class__"): + return getattr(data, "__orig_class__") + return type(data) diff --git a/tests/test_serialization.py b/tests/test_serialization.py index 9e2c4e2..ad195f7 100644 --- a/tests/test_serialization.py +++ b/tests/test_serialization.py @@ -3,7 +3,21 @@ import decimal import enum import uuid -from typing import Annotated, Any, Dict, FrozenSet, Generic, Iterable, List, Set, Tuple, TypeVar +from contextlib import nullcontext as does_not_raise +from typing import ( + Annotated, + Any, + Callable, + ContextManager, + Dict, + FrozenSet, + Generic, + Iterable, + List, + Set, + Tuple, + TypeVar, +) import pytest @@ -629,6 +643,35 @@ class RootContainer: assert mr.load(RootContainer, {}) == RootContainer() +@pytest.mark.parametrize( + "frozen, slots, get_type, context", + [ + (False, False, lambda x: None, does_not_raise()), + (True, False, lambda x: None, pytest.raises(Exception, match="Expected subscripted generic")), + (True, True, lambda x: None, pytest.raises(Exception, match="Expected subscripted generic")), + (True, True, lambda x: x, does_not_raise()), + ], +) +def test_dump_generic_extract_type( + frozen: bool, slots: bool, get_type: Callable[[type], type | None], context: ContextManager +) -> None: + _TValue = TypeVar("_TValue") + + @dataclasses.dataclass(frozen=frozen, slots=slots) + class Data(Generic[_TValue]): + value: _TValue + + instance = Data[int](value=123) + with context: + dumped = mr.dump(instance, cls=get_type(Data[int])) + assert dumped == {"value": 123} + + instance_many = [Data[int](value=123), Data[int](value=456)] + with context: + dumped = mr.dump_many(instance_many, cls=get_type(Data[int])) + assert dumped == [{"value": 123}, {"value": 456}] + + def test_generic_in_parents() -> None: _TXxx = TypeVar("_TXxx") _TData = TypeVar("_TData") @@ -666,13 +709,13 @@ class T2(Generic[_T], T1[int]): instance = T2[str](t1=1, t2="2") - dumped = mr.dump(instance, t=T2[str]) + dumped = mr.dump(instance, cls=T2[str]) assert dumped == {"t1": 1, "t2": "2"} assert mr.load(T2[str], dumped) == instance -def test_override_with_generic() -> None: +def test_override_field_with_generic() -> None: @dataclasses.dataclass(frozen=True, slots=True, kw_only=True) class Value1: v1: str @@ -696,7 +739,7 @@ class T2(Generic[_TValue, _TItem], T1[_TItem]): instance = T2[Value2, int](value=Value2(v1="aaa", v2="bbb"), iterable=set([3, 4, 5])) - dumped = mr.dump(instance, t=T2[Value2, int]) + dumped = mr.dump(instance, cls=T2[Value2, int]) assert dumped == {"value": {"v1": "aaa", "v2": "bbb"}, "iterable": [3, 4, 5]} assert mr.load(T2[Value2, int], dumped) == instance @@ -710,13 +753,13 @@ class GenericContainer(Generic[_TItem]): items: list[_TItem] container_int = GenericContainer[int](items=[1, 2, 3]) - dumped = mr.dump(container_int, t=GenericContainer[int]) + dumped = mr.dump(container_int, cls=GenericContainer[int]) assert dumped == {"items": [1, 2, 3]} assert mr.load(GenericContainer[int], dumped) == container_int container_str = GenericContainer[str](items=["q", "w", "e"]) - dumped = mr.dump(container_str, t=GenericContainer[str]) + dumped = mr.dump(container_str, cls=GenericContainer[str]) assert dumped == {"items": ["q", "w", "e"]} assert mr.load(GenericContainer[str], dumped) == container_str