Skip to content

Commit

Permalink
added __orig_class__ as source of cls
Browse files Browse the repository at this point in the history
  • Loading branch information
slawwan committed Nov 9, 2024
1 parent d0e3fca commit c396fa8
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 32 deletions.
12 changes: 6 additions & 6 deletions marshmallow_recipe/bake.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,12 +55,12 @@ class _FieldDescription(NamedTuple):


def bake_schema(
t: type,
cls: type,
*,
naming_case: NamingCase | None = None,
none_value_handling: NoneValueHandling | None = None,
) -> type[m.Schema]:
origin: type = get_origin(t) or t
origin: type = get_origin(cls) or cls
if not dataclasses.is_dataclass(origin):
raise ValueError(f"{origin} is not a dataclass")

Expand All @@ -72,14 +72,14 @@ def bake_schema(
cls_naming_case = naming_case

key = _SchemaTypeKey(
cls=t,
cls=cls,
naming_case=cls_naming_case,
none_value_handling=cls_none_value_handling,
)
if result := _schema_types.get(key):
return result

fields_type_map = get_fields_type_map(t)
fields_type_map = get_fields_type_map(cls)

fields = [
_FieldDescription(
Expand All @@ -104,8 +104,8 @@ def bake_schema(
raise ValueError(f"Invalid name={second_name} in metadata for field={second.field.name}")

schema_type = type(
t.__name__,
(_get_base_schema(t, cls_none_value_handling or NoneValueHandling.IGNORE),),
cls.__name__,
(_get_base_schema(cls, cls_none_value_handling or NoneValueHandling.IGNORE),),
{"__module__": f"{__package__}.auto_generated"}
| {
field.name: get_field_for(
Expand Down
44 changes: 24 additions & 20 deletions marshmallow_recipe/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,14 @@ def __call__(

class DumpFunction(Protocol):
def __call__(
self, data: Any, *, naming_case: NamingCase | None = None, t: type | None = None
self, data: Any, *, naming_case: NamingCase | None = None, cls: type | None = None
) -> dict[str, Any]: ...


class DumpManyFunction(Protocol):
def __call__(self, data: list[Any], *, naming_case: NamingCase | None = None) -> list[dict[str, Any]]: ...
def __call__(
self, data: list[Any], *, naming_case: NamingCase | None = None, cls: type | None = None
) -> list[dict[str, Any]]: ...


schema: SchemaFunction
Expand Down Expand Up @@ -75,24 +77,21 @@ def load_many_v3(cls: type[_T], data: list[dict[str, Any]], *, naming_case: Nami

load_many = load_many_v3

def dump_v3(
data: Any,
*,
naming_case: NamingCase | None = None,
t: type | None = None,
) -> dict[str, Any]:
data_schema = schema_v3(t or type(data), naming_case=naming_case)
def dump_v3(data: Any, *, naming_case: NamingCase | None = None, cls: type | None = None) -> dict[str, Any]:
data_schema = schema_v3(_extract_type(data, cls), naming_case=naming_case)
dumped: dict[str, Any] = data_schema.dump(data) # type: ignore
if errors := data_schema.validate(dumped):
raise m.ValidationError(errors)
return dumped

dump = dump_v3

def dump_many_v3(data: list[Any], *, naming_case: NamingCase | None = None) -> list[dict[str, Any]]:
def dump_many_v3(
data: list[Any], *, naming_case: NamingCase | None = None, cls: type | None = None
) -> list[dict[str, Any]]:
if not data:
return []
data_schema = schema_v3(type(data[0]), many=True, naming_case=naming_case)
data_schema = schema_v3(_extract_type(data[0], cls), many=True, naming_case=naming_case)
dumped: list[dict[str, Any]] = data_schema.dump(data) # type: ignore
if errors := data_schema.validate(dumped):
raise m.ValidationError(errors)
Expand Down Expand Up @@ -125,13 +124,8 @@ def load_many_v2(cls: type[_T], data: list[dict[str, Any]], *, naming_case: Nami

load_many = load_many_v2

def dump_v2(
data: Any,
*,
naming_case: NamingCase | None = None,
t: type | None = None,
) -> dict[str, Any]:
data_schema = schema_v2(t or type(data), naming_case=naming_case)
def dump_v2(data: Any, *, naming_case: NamingCase | None = None, cls: type | None = None) -> dict[str, Any]:
data_schema = schema_v2(_extract_type(data, cls), naming_case=naming_case)
dumped, errors = data_schema.dump(data)
if errors:
raise m.ValidationError(errors)
Expand All @@ -141,10 +135,12 @@ def dump_v2(

dump = dump_v2

def dump_many_v2(data: list[Any], *, naming_case: NamingCase | None = None) -> list[dict[str, Any]]:
def dump_many_v2(
data: list[Any], *, naming_case: NamingCase | None = None, cls: type | None = None
) -> list[dict[str, Any]]:
if not data:
return []
data_schema = schema_v2(type(data[0]), many=True, naming_case=naming_case)
data_schema = schema_v2(_extract_type(data[0], cls), many=True, naming_case=naming_case)
dumped, errors = data_schema.dump(data)
if errors:
raise m.ValidationError(errors)
Expand All @@ -155,3 +151,11 @@ def dump_many_v2(data: list[Any], *, naming_case: NamingCase | None = None) -> l
dump_many = dump_many_v2

EmptySchema = m.Schema


def _extract_type(data: Any, cls: type | None) -> type:
if cls:
return cls
if hasattr(data, "__orig_class__"):
return getattr(data, "__orig_class__")
return type(data)
55 changes: 49 additions & 6 deletions tests/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,21 @@
import decimal
import enum
import uuid
from typing import Annotated, Any, Dict, FrozenSet, Generic, Iterable, List, Set, Tuple, TypeVar
from contextlib import nullcontext as does_not_raise
from typing import (
Annotated,
Any,
Callable,
ContextManager,
Dict,
FrozenSet,
Generic,
Iterable,
List,
Set,
Tuple,
TypeVar,
)

import pytest

Expand Down Expand Up @@ -629,6 +643,35 @@ class RootContainer:
assert mr.load(RootContainer, {}) == RootContainer()


@pytest.mark.parametrize(
"frozen, slots, get_type, context",
[
(False, False, lambda x: None, does_not_raise()),
(True, False, lambda x: None, pytest.raises(Exception, match="Expected subscripted generic")),
(True, True, lambda x: None, pytest.raises(Exception, match="Expected subscripted generic")),
(True, True, lambda x: x, does_not_raise()),
],
)
def test_dump_generic_extract_type(
frozen: bool, slots: bool, get_type: Callable[[type], type | None], context: ContextManager
) -> None:
_TValue = TypeVar("_TValue")

@dataclasses.dataclass(frozen=frozen, slots=slots)
class Data(Generic[_TValue]):
value: _TValue

instance = Data[int](value=123)
with context:
dumped = mr.dump(instance, cls=get_type(Data[int]))
assert dumped == {"value": 123}

instance_many = [Data[int](value=123), Data[int](value=456)]
with context:
dumped = mr.dump_many(instance_many, cls=get_type(Data[int]))
assert dumped == [{"value": 123}, {"value": 456}]


def test_generic_in_parents() -> None:
_TXxx = TypeVar("_TXxx")
_TData = TypeVar("_TData")
Expand Down Expand Up @@ -666,13 +709,13 @@ class T2(Generic[_T], T1[int]):

instance = T2[str](t1=1, t2="2")

dumped = mr.dump(instance, t=T2[str])
dumped = mr.dump(instance, cls=T2[str])

assert dumped == {"t1": 1, "t2": "2"}
assert mr.load(T2[str], dumped) == instance


def test_override_with_generic() -> None:
def test_override_field_with_generic() -> None:
@dataclasses.dataclass(frozen=True, slots=True, kw_only=True)
class Value1:
v1: str
Expand All @@ -696,7 +739,7 @@ class T2(Generic[_TValue, _TItem], T1[_TItem]):

instance = T2[Value2, int](value=Value2(v1="aaa", v2="bbb"), iterable=set([3, 4, 5]))

dumped = mr.dump(instance, t=T2[Value2, int])
dumped = mr.dump(instance, cls=T2[Value2, int])

assert dumped == {"value": {"v1": "aaa", "v2": "bbb"}, "iterable": [3, 4, 5]}
assert mr.load(T2[Value2, int], dumped) == instance
Expand All @@ -710,13 +753,13 @@ class GenericContainer(Generic[_TItem]):
items: list[_TItem]

container_int = GenericContainer[int](items=[1, 2, 3])
dumped = mr.dump(container_int, t=GenericContainer[int])
dumped = mr.dump(container_int, cls=GenericContainer[int])

assert dumped == {"items": [1, 2, 3]}
assert mr.load(GenericContainer[int], dumped) == container_int

container_str = GenericContainer[str](items=["q", "w", "e"])
dumped = mr.dump(container_str, t=GenericContainer[str])
dumped = mr.dump(container_str, cls=GenericContainer[str])

assert dumped == {"items": ["q", "w", "e"]}
assert mr.load(GenericContainer[str], dumped) == container_str

0 comments on commit c396fa8

Please sign in to comment.