generics support #163

Open · wants to merge 13 commits into base: main
37 changes: 35 additions & 2 deletions README.md
@@ -14,8 +14,6 @@ Supported types:
- `Mapping` (with typed keys and values), `Set`, `Sequence`

Example:
class Annotated:
pass

```python
import dataclasses
@@ -77,3 +75,38 @@ loaded = mr.load(CompanyUpdateData, {"annual_turnover": None})
assert loaded.name is mr.MISSING
assert loaded.annual_turnover is None
```

Generics are also supported. Everything works automatically except one case: dumping a generic dataclass declared with `frozen=True` or `slots=True` requires the subscripted generic type to be passed explicitly as the `cls` argument of the `dump` and `dump_many` methods.

```python
import dataclasses
from typing import Generic, TypeVar
import marshmallow_recipe as mr

T = TypeVar("T")

@dataclasses.dataclass()
class Regular(Generic[T]):
value: T

mr.dump(Regular[int](value=123)) # it works without explicit cls arg

@dataclasses.dataclass(frozen=True)
class Frozen(Generic[T]):
value: T

mr.dump(Frozen[int](value=123), cls=Frozen[int]) # cls required for frozen generic

@dataclasses.dataclass(slots=True)
class Slots(Generic[T]):
value: T

mr.dump(Slots[int](value=123), cls=Slots[int]) # cls required for generic with slots

@dataclasses.dataclass(slots=True)
class SlotsNonGeneric(Slots[int]):
pass

mr.dump(SlotsNonGeneric(value=123)) # cls not required

```
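Loading works the same way when the target type is a subscripted generic. A minimal sketch (not part of the PR's README changes), assuming `load` and `load_many` accept a subscripted generic as the target type just like `dump` above:

```python
import dataclasses
from typing import Generic, TypeVar

import marshmallow_recipe as mr

T = TypeVar("T")

@dataclasses.dataclass()
class Regular(Generic[T]):
    value: T

# T is resolved to int, so the "value" field is deserialized as an int.
loaded = mr.load(Regular[int], {"value": 123})
assert loaded.value == 123
```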
94 changes: 44 additions & 50 deletions marshmallow_recipe/bake.py
@@ -6,7 +6,7 @@
import inspect
import types
import uuid
from typing import Annotated, Any, Protocol, TypeVar, Union, get_args, get_origin
from typing import Annotated, Any, NamedTuple, Protocol, TypeVar, Union, cast, get_args, get_origin

import marshmallow as m

@@ -29,6 +29,7 @@
tuple_field,
uuid_field,
)
from .generics import TypeLike, get_fields_type_map
from .hooks import get_pre_loads
from .metadata import EMPTY_METADATA, Metadata, is_metadata
from .naming_case import NamingCase
@@ -48,16 +49,23 @@ class _SchemaTypeKey:
_schema_types: dict[_SchemaTypeKey, type[m.Schema]] = {}


class _FieldDescription(NamedTuple):
field: dataclasses.Field
value_type: TypeLike
metadata: Metadata


def bake_schema(
cls: type,
*,
naming_case: NamingCase | None = None,
none_value_handling: NoneValueHandling | None = None,
) -> type[m.Schema]:
if not dataclasses.is_dataclass(cls):
raise ValueError(f"{cls} is not a dataclass")
origin: type = get_origin(cls) or cls
if not dataclasses.is_dataclass(origin):
raise ValueError(f"{origin} is not a dataclass")

if options := try_get_options_for(cls):
if options := try_get_options_for(origin):
cls_none_value_handling = none_value_handling or options.none_value_handling
cls_naming_case = naming_case or options.naming_case
else:
@@ -72,81 +80,81 @@ def bake_schema(
if result := _schema_types.get(key):
return result

fields_with_metadata = [
(
fields_type_map = get_fields_type_map(cls)
Reviewer comment:

> Is there a reason why we pass cls here and not origin? If there is, why do we use origin on line 95, and why do we even collect the field names one more time instead of using the map?

@slawwan (Contributor Author) replied on Nov 13, 2024:

> Is there a reason why we pass cls here and not origin?

In the generic case, origin is the unsubscripted (open) generic, like Xxx[T], without arguments, so it cannot be used to compute the actual field types. cls is the subscripted (closed) generic, like Xxx[int], which carries int as the argument to substitute for all TypeVars used in the field types.

> If there is, why do we use origin on line 95?

A subscripted generic of an unsubscripted dataclass is not itself a dataclass; fields can only be read from origin, the unsubscripted declared type. Please check the PR description, which explains this with examples.

> Why do we even collect the field names one more time instead of using the map?

That map only provides field types keyed by field name. Here we collect the fields themselves to get each field descriptor and its metadata for building the schema.
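To make the distinction concrete, here is a small standalone sketch (illustrative only; `Xxx` is a hypothetical class, not part of the diff):

```python
import dataclasses
from typing import Generic, TypeVar, get_args, get_origin

T = TypeVar("T")

@dataclasses.dataclass()
class Xxx(Generic[T]):
    value: T

alias = Xxx[int]                        # subscripted (closed) generic
print(get_origin(alias) is Xxx)         # True: the origin is the declared class
print(get_args(alias))                  # (int,): the argument that replaces T
print(dataclasses.is_dataclass(alias))  # False: the subscripted alias is not a dataclass
print(dataclasses.is_dataclass(Xxx))    # True: fields must be read from the origin
print(dataclasses.fields(Xxx)[0].type)  # ~T: the declared annotation is still the TypeVar
```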


fields = [
_FieldDescription(
field,
fields_type_map[field.name],
_get_metadata(
name=field.name if cls_naming_case is None else cls_naming_case(field.name),
default=_get_field_default(field),
metadata=field.metadata,
),
)
for field in dataclasses.fields(cls)
for field in dataclasses.fields(origin)
if field.init
]

for field, _ in fields_with_metadata:
for other_field, metadata in fields_with_metadata:
if field is other_field:
for first in fields:
for second in fields:
if first is second:
continue
second_name = second.metadata["name"]
if first.field.name == second_name:
raise ValueError(f"Invalid name={second_name} in metadata for field={second.field.name}")

other_field_name = metadata["name"]
if field.name == other_field_name:
raise ValueError(f"Invalid name={other_field_name} in metadata for field={other_field.name}")

schema_type: type[m.Schema] = type(
schema_type = type(
cls.__name__,
(_get_base_schema(cls, cls_none_value_handling or NoneValueHandling.IGNORE),),
{"__module__": f"{__package__}.auto_generated"}
| {
field.name: get_field_for(
field.type, # type: ignore
value_type,
metadata,
naming_case=naming_case,
none_value_handling=none_value_handling,
)
for field, metadata in fields_with_metadata
for field, value_type, metadata in fields
},
)
_schema_types[key] = schema_type
return schema_type


def get_field_for(
type: type,
t: TypeLike,
metadata: Metadata,
naming_case: NamingCase | None,
none_value_handling: NoneValueHandling | None,
) -> m.fields.Field:
if type is Any:
if t is Any:
return raw_field(**metadata)

type = _substitute_any_to_open_generic(type)

if underlying_type_from_optional := _try_get_underlying_type_from_optional(type):
if underlying_type_from_optional := _try_get_underlying_type_from_optional(t):
required = False
allow_none = True
type = underlying_type_from_optional
t = underlying_type_from_optional
elif metadata.get("default", dataclasses.MISSING) is not dataclasses.MISSING:
required = False
allow_none = False
else:
required = True
allow_none = False

if inspect.isclass(type) and issubclass(type, enum.Enum):
return enum_field(enum_type=type, required=required, allow_none=allow_none, **metadata)
if inspect.isclass(t) and issubclass(t, enum.Enum):
return enum_field(enum_type=t, required=required, allow_none=allow_none, **metadata)

if dataclasses.is_dataclass(type):
if dataclasses.is_dataclass(get_origin(t) or t):
return nested_field(
bake_schema(type, naming_case=naming_case, none_value_handling=none_value_handling),
bake_schema(cast(type, t), naming_case=naming_case, none_value_handling=none_value_handling),
required=required,
allow_none=allow_none,
**metadata,
)

if (origin := get_origin(type)) is not None:
arguments = get_args(type)
if (origin := get_origin(t)) is not None:
arguments = get_args(t)

if origin is list or origin is collections.abc.Sequence:
collection_field_metadata = dict(metadata)
@@ -268,11 +276,11 @@ def get_field_for(
none_value_handling=none_value_handling,
)

field_factory = _SIMPLE_TYPE_FIELD_FACTORIES.get(type)
if field_factory:
if t in _SIMPLE_TYPE_FIELD_FACTORIES:
field_factory = _SIMPLE_TYPE_FIELD_FACTORIES[t]
return field_factory(required=required, allow_none=allow_none, **metadata)

raise ValueError(f"Unsupported {type=}")
raise ValueError(f"Unsupported {t=}")


if _MARSHMALLOW_VERSION_MAJOR >= 3:
@@ -373,26 +381,12 @@ def _get_metadata(*, name: str, default: Any, metadata: collections.abc.Mapping[
return Metadata(values)


def _substitute_any_to_open_generic(type: type) -> type:
if type is list:
return list[Any]
if type is set:
return set[Any]
if type is frozenset:
return frozenset[Any]
if type is dict:
return dict[Any, Any]
if type is tuple:
return tuple[Any, ...]
return type


def _try_get_underlying_type_from_optional(type: type) -> type | None:
def _try_get_underlying_type_from_optional(t: TypeLike) -> TypeLike | None:
# to support Union[int, None] and int | None
if get_origin(type) is Union or isinstance(type, types.UnionType): # type: ignore
type_args = get_args(type)
if get_origin(t) is Union or isinstance(t, types.UnionType): # type: ignore
type_args = get_args(t)
if types.NoneType not in type_args or len(type_args) != 2:
raise ValueError(f"Unsupported {type=}")
raise ValueError(f"Unsupported {t=}")
return next(type_arg for type_arg in type_args if type_arg is not types.NoneType) # noqa

return None
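The `dataclasses.is_dataclass(get_origin(t) or t)` change above means a field annotated with a subscripted generic dataclass is baked into a nested schema. A hedged sketch of what this enables at the public API level (class names are illustrative; it assumes default naming and that `mr.dump` returns a plain dict as in the README example):

```python
import dataclasses
from typing import Generic, TypeVar

import marshmallow_recipe as mr

T = TypeVar("T")

@dataclasses.dataclass()
class Wrapper(Generic[T]):
    value: T

@dataclasses.dataclass()
class Holder:
    inner: Wrapper[int]

# `inner` is typed with a subscripted generic dataclass, so it is dumped
# through a nested schema baked for Wrapper[int].
dumped = mr.dump(Holder(inner=Wrapper(value=1)))
assert dumped == {"inner": {"value": 1}}  # expected with default naming
```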
120 changes: 120 additions & 0 deletions marshmallow_recipe/generics.py
@@ -0,0 +1,120 @@
import dataclasses
import types
import typing
from typing import TYPE_CHECKING, Annotated, Any, Generic, TypeAlias, TypeVar, Union, get_args, get_origin

_GenericAlias: TypeAlias = typing._GenericAlias # type: ignore

if TYPE_CHECKING:
from _typeshed import DataclassInstance
else:
DataclassInstance: TypeAlias = type

TypeLike: TypeAlias = type | TypeVar | types.UnionType | types.GenericAlias | _GenericAlias
FieldsTypeMap: TypeAlias = dict[str, TypeLike]
TypeVarMap: TypeAlias = dict[TypeVar, TypeLike]
FieldsClassMap: TypeAlias = dict[str, TypeLike]
ClassTypeVarMap: TypeAlias = dict[TypeLike, TypeVarMap]
FieldsTypeVarMap: TypeAlias = dict[str, TypeVarMap]


def get_fields_type_map(cls: type) -> FieldsTypeMap:
origin: type = get_origin(cls) or cls
if not dataclasses.is_dataclass(origin):
raise ValueError(f"{origin} is not a dataclass")

class_type_var_map = get_class_type_var_map(cls)
fields_class_map = get_fields_class_map(origin)
return {
f.name: build_subscripted_type(f.type, class_type_var_map.get(fields_class_map[f.name], {}))
for f in dataclasses.fields(origin)
}


def get_fields_class_map(cls: type[DataclassInstance]) -> FieldsClassMap:
names: dict[str, dataclasses.Field] = {}
result: FieldsClassMap = {}

mro = cls.__mro__
for cls in (*mro[-1:0:-1], cls):
if not dataclasses.is_dataclass(cls):
continue
for field in dataclasses.fields(cls):
if names.get(field.name) != field:
names[field.name] = field
result[field.name] = cls

return result


def build_subscripted_type(t: TypeLike, type_var_map: TypeVarMap) -> TypeLike:
if isinstance(t, TypeVar):
return build_subscripted_type(type_var_map[t], type_var_map)

origin = get_origin(t)
if origin is Union or origin is types.UnionType:
return Union[*(build_subscripted_type(x, type_var_map) for x in get_args(t))]

if origin is Annotated:
t, *annotations = get_args(t)
return Annotated[build_subscripted_type(t, type_var_map), *annotations]

if origin and isinstance(t, types.GenericAlias):
return types.GenericAlias(origin, tuple(build_subscripted_type(x, type_var_map) for x in get_args(t)))

if origin and isinstance(t, _GenericAlias):
return _GenericAlias(origin, tuple(build_subscripted_type(x, type_var_map) for x in get_args(t)))

return _subscript_with_any(t)


def get_class_type_var_map(t: TypeLike) -> ClassTypeVarMap:
class_type_var_map: ClassTypeVarMap = {}
_build_class_type_var_map(t, class_type_var_map)
return class_type_var_map


def _build_class_type_var_map(t: TypeLike, class_type_var_map: ClassTypeVarMap) -> None:
if _get_params(t):
raise ValueError(f"Expected subscripted generic, but got unsubscripted {t}")

type_var_map: TypeVarMap = {}
origin = get_origin(t) or t
params = _get_params(origin)
args = get_args(t)
if params or args:
if not params or not args or len(params) != len(args):
raise ValueError(f"Unexpected generic {t}")
class_type_var_map[origin] = type_var_map
for i, parameter in enumerate(params):
assert isinstance(parameter, TypeVar)
type_var_map[parameter] = args[i]

if orig_bases := _get_orig_bases(origin):
for orig_base in orig_bases:
if get_origin(orig_base) is Generic:
continue
subscripted_base = build_subscripted_type(orig_base, type_var_map)
_build_class_type_var_map(subscripted_base, class_type_var_map)


def _get_params(t: Any) -> tuple[TypeLike, ...] | None:
return hasattr(t, "__parameters__") and getattr(t, "__parameters__") or None


def _get_orig_bases(t: Any) -> tuple[TypeLike, ...] | None:
return hasattr(t, "__orig_bases__") and getattr(t, "__orig_bases__") or None


def _subscript_with_any(t: TypeLike) -> TypeLike:
if t is list:
return list[Any]
if t is set:
return set[Any]
if t is frozenset:
return frozenset[Any]
if t is dict:
return dict[Any, Any]
if t is tuple:
return tuple[Any, ...]
return t
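A short usage sketch of the new helper (the expected mapping is inferred from the implementation above; `Box` is an illustrative class):

```python
import dataclasses
from typing import Generic, TypeVar

from marshmallow_recipe.generics import get_fields_type_map

T = TypeVar("T")

@dataclasses.dataclass()
class Box(Generic[T]):
    value: T
    items: list[T]

# Each field name is mapped to its concrete type, with T substituted by int.
print(get_fields_type_map(Box[int]))  # expected: {'value': <class 'int'>, 'items': list[int]}
```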