Skip to content

Commit

Permalink
code cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
slawwan committed Nov 10, 2024
1 parent c396fa8 commit 886c1cb
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 20 deletions.
6 changes: 3 additions & 3 deletions marshmallow_recipe/bake.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ class _SchemaTypeKey:

class _FieldDescription(NamedTuple):
field: dataclasses.Field
subscripted_type: TypeLike
value_type: TypeLike
metadata: Metadata


Expand Down Expand Up @@ -109,12 +109,12 @@ def bake_schema(
{"__module__": f"{__package__}.auto_generated"}
| {
field.name: get_field_for(
data_type,
value_type,
metadata,
naming_case=naming_case,
none_value_handling=none_value_handling,
)
for field, data_type, metadata in fields
for field, value_type, metadata in fields
},
)
_schema_types[key] = schema_type
Expand Down
23 changes: 13 additions & 10 deletions marshmallow_recipe/generics.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def get_fields_class_map(t: TypeLike) -> FieldsClassMap:
if names.get(field.name) != field:
names[field.name] = field
result[field.name] = base

return result


Expand All @@ -68,33 +69,35 @@ def build_subscripted_type(t: TypeLike, type_var_map: TypeVarMap) -> TypeLike:

def get_class_type_var_map(t: TypeLike) -> ClassTypeVarMap:
class_type_var_map: ClassTypeVarMap = {}
_get_class_type_var_map(t, class_type_var_map)
_build_class_type_var_map(t, class_type_var_map)
return class_type_var_map


def _get_class_type_var_map(t: TypeLike, class_type_var_map: ClassTypeVarMap) -> None:
if _get_parameters(t):
def _build_class_type_var_map(t: TypeLike, class_type_var_map: ClassTypeVarMap) -> None:
if _get_params(t):
raise Exception(f"Expected subscripted generic, but got unsubscripted {t}")

type_var_map: TypeVarMap = {}
origin = get_origin(t) or t
parameters = _get_parameters(origin)
params = _get_params(origin)
args = get_args(t)
if parameters or args:
if not parameters or not args or len(parameters) != len(args):
if params or args:
if not params or not args or len(params) != len(args):
raise Exception(f"Unexpected generic {t}")
class_type_var_map[origin] = type_var_map
for i, parameter in enumerate(parameters):
for i, parameter in enumerate(params):
assert isinstance(parameter, TypeVar)
type_var_map[parameter] = args[i]

if orig_bases := _get_orig_bases(origin):
for orig_base in orig_bases:
if get_origin(orig_base) is not Generic:
_get_class_type_var_map(build_subscripted_type(orig_base, type_var_map), class_type_var_map)
if get_origin(orig_base) is Generic:
continue
subscripted_base = build_subscripted_type(orig_base, type_var_map)
_build_class_type_var_map(subscripted_base, class_type_var_map)


def _get_parameters(t: Any) -> tuple[TypeLike, ...] | None:
def _get_params(t: Any) -> tuple[TypeLike, ...] | None:
return hasattr(t, "__parameters__") and getattr(t, "__parameters__") or None


Expand Down
17 changes: 14 additions & 3 deletions tests/test_generics.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
)


def test_get_fields_type_map_overrides() -> None:
def test_get_fields_type_map_with_field_override() -> None:
@dataclasses.dataclass(frozen=True, slots=True, kw_only=True)
class Value1:
v1: str
Expand Down Expand Up @@ -101,6 +101,17 @@ class Xxx(Generic[_T]):
)


def test_get_fields_type_map_for_subscripted() -> None:
_T = TypeVar("_T")

@dataclasses.dataclass(frozen=True, slots=True, kw_only=True)
class Xxx(Generic[_T]):
xxx: _T

actual = get_fields_type_map(Xxx[str])
assert actual == {"xxx": str}


def test_get_fields_class_map() -> None:
_T = TypeVar("_T")

Expand Down Expand Up @@ -142,7 +153,7 @@ class Base3(Base2, BaseG):
}


def test_get_class_type_var_map_inheritance() -> None:
def test_get_class_type_var_map_with_inheritance() -> None:
_T1 = TypeVar("_T1")
_T2 = TypeVar("_T2")
_T3 = TypeVar("_T3")
Expand Down Expand Up @@ -187,7 +198,7 @@ class Ddd(Generic[_T1, _T2, _T3], Bbb[_T2], Ccc[_T1], NonGeneric):
}


def test_get_class_type_var_map_nesting() -> None:
def test_get_class_type_var_map_with_nesting() -> None:
_T1 = TypeVar("_T1")
_T2 = TypeVar("_T2")
_T3 = TypeVar("_T3")
Expand Down
8 changes: 4 additions & 4 deletions tests/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -652,7 +652,7 @@ class RootContainer:
(True, True, lambda x: x, does_not_raise()),
],
)
def test_dump_generic_extract_type(
def test_generic_extract_type_on_dump(
frozen: bool, slots: bool, get_type: Callable[[type], type | None], context: ContextManager
) -> None:
_TValue = TypeVar("_TValue")
Expand Down Expand Up @@ -696,7 +696,7 @@ class ChildClass(ParentClass[Data[int]]):
assert mr.load(ChildClass, dumped) == instance


def test_generic_reused_type_var() -> None:
def test_generic_type_var_with_reuse() -> None:
_T = TypeVar("_T")

@dataclasses.dataclass(frozen=True, slots=True, kw_only=True)
Expand All @@ -715,7 +715,7 @@ class T2(Generic[_T], T1[int]):
assert mr.load(T2[str], dumped) == instance


def test_override_field_with_generic() -> None:
def test_generic_with_field_override() -> None:
@dataclasses.dataclass(frozen=True, slots=True, kw_only=True)
class Value1:
v1: str
Expand Down Expand Up @@ -745,7 +745,7 @@ class T2(Generic[_TValue, _TItem], T1[_TItem]):
assert mr.load(T2[Value2, int], dumped) == instance


def test_generic_reuse() -> None:
def test_generic_origin_reuse() -> None:
_TItem = TypeVar("_TItem")

@dataclasses.dataclass(frozen=True, slots=True, kw_only=True)
Expand Down

0 comments on commit 886c1cb

Please sign in to comment.