Skip to content

Commit

Permalink
improved typing
Browse files Browse the repository at this point in the history
  • Loading branch information
slawwan committed Nov 13, 2024
1 parent 68aff75 commit 4bbccfd
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 16 deletions.
27 changes: 14 additions & 13 deletions marshmallow_recipe/generics.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
import dataclasses
import types
import typing
from typing import Annotated, Any, Generic, TypeAlias, TypeVar, Union, get_args, get_origin
from typing import TYPE_CHECKING, Annotated, Any, Generic, TypeAlias, TypeVar, Union, get_args, get_origin

_GenericAlias: TypeAlias = typing._GenericAlias # type: ignore

if TYPE_CHECKING:
from _typeshed import DataclassInstance
else:
DataclassInstance: TypeAlias = type

TypeLike: TypeAlias = type | TypeVar | types.UnionType | types.GenericAlias | _GenericAlias
FieldsTypeMap: TypeAlias = dict[str, TypeLike]
TypeVarMap: TypeAlias = dict[TypeVar, TypeLike]
Expand All @@ -13,29 +18,25 @@
FieldsTypeVarMap: TypeAlias = dict[str, TypeVarMap]


def get_fields_type_map(t: TypeLike) -> FieldsTypeMap:
origin = get_origin(t) or t
def get_fields_type_map(cls: type) -> FieldsTypeMap:
origin: type = get_origin(cls) or cls
if not dataclasses.is_dataclass(origin):
return {}
raise ValueError(f"{origin} is not a dataclass")

class_type_var_map = get_class_type_var_map(t)
fields_class_map = get_fields_class_map(t)
class_type_var_map = get_class_type_var_map(cls)
fields_class_map = get_fields_class_map(origin)
return {
f.name: build_subscripted_type(f.type, class_type_var_map.get(fields_class_map[f.name], {}))
for f in dataclasses.fields(origin)
}


def get_fields_class_map(t: TypeLike) -> FieldsClassMap:
origin = get_origin(t) or t
if not dataclasses.is_dataclass(origin):
return {}

def get_fields_class_map(cls: type[DataclassInstance]) -> FieldsClassMap:
names: dict[str, dataclasses.Field] = {}
result: FieldsClassMap = {}

mro = origin.__mro__ # type: ignore
for cls in (*mro[-1:0:-1], origin):
mro = cls.__mro__
for cls in (*mro[-1:0:-1], cls):
if not dataclasses.is_dataclass(cls):
continue
for field in dataclasses.fields(cls):
Expand Down
7 changes: 4 additions & 3 deletions tests/test_generics.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,10 @@ class Value2(Value1[int], NonGeneric):
}


def test_get_fields_type_map_non_data_class() -> None:
actual = get_fields_type_map(int | None)
assert actual == {}
def test_get_fields_type_map_non_dataclass() -> None:
with pytest.raises(ValueError) as e:
get_fields_type_map(list[int])
assert e.value.args[0] == "<class 'list'> is not a dataclass"


def test_get_fields_type_map_not_subscripted() -> None:
Expand Down

0 comments on commit 4bbccfd

Please sign in to comment.