Skip to content

Commit

Permalink
improved extract_type
Browse files Browse the repository at this point in the history
  • Loading branch information
slawwan committed Nov 22, 2024
1 parent 1435382 commit 7a944cd
Show file tree
Hide file tree
Showing 4 changed files with 132 additions and 23 deletions.
25 changes: 25 additions & 0 deletions marshmallow_recipe/generics.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,23 @@
FieldsTypeVarMap: TypeAlias = dict[str, TypeVarMap]


def extract_type(data: Any, cls: type | None) -> type:
data_type = _get_orig_class(data) or type(data)

if not _is_unsubscripted_type(data_type):
if cls and data_type != cls:
raise ValueError(f"{cls=} is invalid but can be removed, actual type is {data_type}")
return data_type

if not cls:
raise ValueError(f"Explicit cls required for unsubscripted type {data_type}")

if _is_unsubscripted_type(cls) or get_origin(cls) != data_type:
raise ValueError(f"{cls=} is not subscripted version of {data_type}")

return cls


def get_fields_type_map(cls: type) -> FieldsTypeMap:
origin: type = get_origin(cls) or cls
if not dataclasses.is_dataclass(origin):
Expand Down Expand Up @@ -98,6 +115,14 @@ def _build_class_type_var_map(t: TypeLike, class_type_var_map: ClassTypeVarMap)
_build_class_type_var_map(subscripted_base, class_type_var_map)


def _is_unsubscripted_type(t: TypeLike) -> bool:
return bool(_get_params(t)) or any(_is_unsubscripted_type(arg) for arg in get_args(t) or [])


def _get_orig_class(t: Any) -> type | None:
return hasattr(t, "__orig_class__") and getattr(t, "__orig_class__") or None


def _get_params(t: Any) -> tuple[TypeLike, ...] | None:
return hasattr(t, "__parameters__") and getattr(t, "__parameters__") or None

Expand Down
23 changes: 6 additions & 17 deletions marshmallow_recipe/serialization.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import dataclasses
import importlib.metadata
from typing import Any, Protocol, TypeVar, get_origin
from typing import Any, Protocol, TypeVar

import marshmallow as m

from .bake import bake_schema
from .generics import extract_type
from .naming_case import NamingCase

_T = TypeVar("_T")
Expand Down Expand Up @@ -78,7 +79,7 @@ def load_many_v3(cls: type[_T], data: list[dict[str, Any]], *, naming_case: Nami
load_many = load_many_v3

def dump_v3(data: Any, *, naming_case: NamingCase | None = None, cls: type | None = None) -> dict[str, Any]:
data_schema = schema_v3(_extract_type(data, cls), naming_case=naming_case)
data_schema = schema_v3(extract_type(data, cls), naming_case=naming_case)
dumped: dict[str, Any] = data_schema.dump(data) # type: ignore
if errors := data_schema.validate(dumped):
raise m.ValidationError(errors)
Expand All @@ -91,7 +92,7 @@ def dump_many_v3(
) -> list[dict[str, Any]]:
if not data:
return []
data_schema = schema_v3(_extract_type(data[0], cls), many=True, naming_case=naming_case)
data_schema = schema_v3(extract_type(data[0], cls), many=True, naming_case=naming_case)
dumped: list[dict[str, Any]] = data_schema.dump(data) # type: ignore
if errors := data_schema.validate(dumped):
raise m.ValidationError(errors)
Expand Down Expand Up @@ -125,7 +126,7 @@ def load_many_v2(cls: type[_T], data: list[dict[str, Any]], *, naming_case: Nami
load_many = load_many_v2

def dump_v2(data: Any, *, naming_case: NamingCase | None = None, cls: type | None = None) -> dict[str, Any]:
data_schema = schema_v2(_extract_type(data, cls), naming_case=naming_case)
data_schema = schema_v2(extract_type(data, cls), naming_case=naming_case)
dumped, errors = data_schema.dump(data)
if errors:
raise m.ValidationError(errors)
Expand All @@ -140,7 +141,7 @@ def dump_many_v2(
) -> list[dict[str, Any]]:
if not data:
return []
data_schema = schema_v2(_extract_type(data[0], cls), many=True, naming_case=naming_case)
data_schema = schema_v2(extract_type(data[0], cls), many=True, naming_case=naming_case)
dumped, errors = data_schema.dump(data)
if errors:
raise m.ValidationError(errors)
Expand All @@ -151,15 +152,3 @@ def dump_many_v2(
dump_many = dump_many_v2

EmptySchema = m.Schema


def _extract_type(data: Any, cls: type | None) -> type:
if hasattr(data, "__orig_class__"):
return getattr(data, "__orig_class__")
data_type = type(data)
if not cls or cls == data_type:
return data_type
origin = get_origin(cls)
if origin != data_type:
raise ValueError(f"{cls=} is not subscripted version of {data_type}")
return cls
66 changes: 65 additions & 1 deletion tests/test_generics.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,80 @@
import dataclasses
import types
from typing import Annotated, Any, Generic, Iterable, List, TypeVar, Union
from contextlib import nullcontext as does_not_raise
from typing import Annotated, Any, ContextManager, Generic, Iterable, List, TypeVar, Union
from unittest.mock import ANY

import pytest

from marshmallow_recipe.generics import (
build_subscripted_type,
extract_type,
get_class_type_var_map,
get_fields_class_map,
get_fields_type_map,
)

T = TypeVar("T")


@dataclasses.dataclass()
class OtherType:
pass


@dataclasses.dataclass()
class NonGeneric:
pass


@dataclasses.dataclass()
class RegularGeneric(Generic[T]):
pass


@dataclasses.dataclass(frozen=True)
class FrozenGeneric(Generic[T]):
pass


def e(match: str) -> ContextManager:
return pytest.raises(ValueError, match=match)


@pytest.mark.parametrize(
"data, cls, expected, context",
[
(1, None, int, does_not_raise()),
(1, int, int, does_not_raise()),
(1, OtherType, ANY, e("OtherType'> is invalid but can be removed, actual type is <class 'int'>")),
(NonGeneric(), None, NonGeneric, does_not_raise()),
(NonGeneric(), NonGeneric, NonGeneric, does_not_raise()),
(NonGeneric(), OtherType, ANY, e("OtherType'> is invalid but can be removed, actual type is <class 'tests")),
(RegularGeneric(), None, ANY, e("Explicit cls required for unsubscripted type <class 'tests")),
(RegularGeneric(), RegularGeneric, ANY, e("RegularGeneric'> is not subscripted version of <class 'tests")),
(RegularGeneric(), RegularGeneric[int], RegularGeneric[int], does_not_raise()),
(RegularGeneric[int](), None, RegularGeneric[int], does_not_raise()),
(RegularGeneric[int](), RegularGeneric[int], RegularGeneric[int], does_not_raise()),
(RegularGeneric[int](), RegularGeneric[str], ANY, e("str] is invalid but can be removed, actual type is")),
(RegularGeneric[int](), RegularGeneric, ANY, e("RegularGeneric'> is invalid but can be removed, actual type")),
(RegularGeneric[RegularGeneric[int]](), RegularGeneric[RegularGeneric[str]], ANY, e("str]] is invalid but")),
(RegularGeneric[int](), OtherType, ANY, e("OtherType'> is invalid but can be removed, actual type is tests")),
(FrozenGeneric[int](), None, ANY, e("Explicit cls required for unsubscripted type <class")),
(FrozenGeneric[int](), FrozenGeneric[str], FrozenGeneric[str], does_not_raise()),
(FrozenGeneric(), FrozenGeneric[int], FrozenGeneric[int], does_not_raise()),
(FrozenGeneric(), FrozenGeneric[list], FrozenGeneric[list], does_not_raise()),
(FrozenGeneric(), FrozenGeneric[dict], FrozenGeneric[dict], does_not_raise()),
(FrozenGeneric(), FrozenGeneric[FrozenGeneric[str]], FrozenGeneric[FrozenGeneric[str]], does_not_raise()),
(FrozenGeneric(), FrozenGeneric, ANY, e("FrozenGeneric'> is not subscripted version of <class 'tests")),
(FrozenGeneric(), FrozenGeneric[FrozenGeneric], ANY, e("FrozenGeneric] is not subscripted version of")),
(FrozenGeneric(), OtherType, ANY, e("OtherType'> is not subscripted version of <class 'tests")),
],
)
def test_extract_type(data: Any, cls: type, expected: type, context: ContextManager) -> None:
with context:
actual = extract_type(data, cls)
assert actual == expected


def test_get_fields_type_map_with_field_override() -> None:
@dataclasses.dataclass(frozen=True, slots=True, kw_only=True)
Expand Down
41 changes: 36 additions & 5 deletions tests/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -648,11 +648,13 @@ class RootContainer:
"frozen, slots, get_type, context",
[
(False, False, lambda x: None, does_not_raise()),
(True, False, lambda x: None, pytest.raises(ValueError, match="Expected subscripted generic")),
(False, True, lambda x: None, pytest.raises(ValueError, match="Expected subscripted generic")),
(True, True, lambda x: None, pytest.raises(ValueError, match="Expected subscripted generic")),
(True, True, lambda x: get_origin(x), pytest.raises(ValueError, match="Expected subscripted generic")),
(True, True, lambda x: list[int], pytest.raises(ValueError, match="is not subscripted version of")),
(False, False, lambda x: int, pytest.raises(ValueError, match="<class 'int'> is invalid but can be removed")),
(True, False, lambda x: None, pytest.raises(ValueError, match="Explicit cls required for unsubscripted type")),
(False, True, lambda x: None, pytest.raises(ValueError, match="Explicit cls required for unsubscripted type")),
(True, True, lambda x: None, pytest.raises(ValueError, match="Explicit cls required for unsubscripted type")),
(True, True, lambda x: get_origin(x), pytest.raises(ValueError, match=".Data'> is not subscripted version of")),
(True, True, lambda x: list[int], pytest.raises(ValueError, match="int] is not subscripted version of")),
(True, True, lambda x: int, pytest.raises(ValueError, match="<class 'int'> is not subscripted version of")),
(True, True, lambda x: x, does_not_raise()),
],
)
Expand All @@ -676,6 +678,35 @@ class Data(Generic[_TValue]):
assert dumped == [{"value": 123}, {"value": 456}]


@pytest.mark.parametrize(
"frozen, slots, get_type, context",
[
(False, False, lambda x: None, does_not_raise()),
(False, True, lambda x: x, does_not_raise()),
(True, False, lambda x: x, does_not_raise()),
(True, True, lambda x: x, does_not_raise()),
(False, False, lambda x: int, pytest.raises(ValueError, match="<class 'int'> is invalid but can be removed")),
(True, True, lambda x: int, pytest.raises(ValueError, match="<class 'int'> is invalid but can be removed")),
],
)
def test_non_generic_extract_type_on_dump(
frozen: bool, slots: bool, get_type: Callable[[type], type | None], context: ContextManager
) -> None:
@dataclasses.dataclass(frozen=frozen, slots=slots)
class Data:
value: int

instance = Data(value=123)
with context:
dumped = mr.dump(instance, cls=get_type(Data))
assert dumped == {"value": 123}

instance_many = [Data(value=123), Data(value=456)]
with context:
dumped = mr.dump_many(instance_many, cls=get_type(Data))
assert dumped == [{"value": 123}, {"value": 456}]


def test_generic_in_parents() -> None:
_TXxx = TypeVar("_TXxx")
_TData = TypeVar("_TData")
Expand Down

0 comments on commit 7a944cd

Please sign in to comment.