Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

generics support #163

Open
wants to merge 13 commits into
base: main
Choose a base branch
from
Open

generics support #163

wants to merge 13 commits into from

Conversation

slawwan
Copy link
Contributor

@slawwan slawwan commented Nov 8, 2024

Generics support

This PR contains implementation for generics support for both dump and load methods. I have supported all cases which I met and I could imagine. But may be I missed something. Hope no.

Key points:

  • dump and dump_many were improved to extract actual runtime type for serialization with _extract_type function
  • get_fields_type_map returns dict with actual runtime type for each field accessible by field's name.
  • bake_schema enumerates all dataclass fields and gets it's type according dict returned by get_fields_type_map
  • get_field_for mostly was not change because it receives already prepared type to build schema field

I went deeper into Python generics first time and it was not obviously how it works. All issues I met with explanation how I solved them are listed below in Issues section.

All provided code snippets can be executed in Python playground to check how it works:

Generics in Python

Let we have generic dataclass class Xxx(Generic[T1, T2]) then:

  • Xxx[int, str] - subscripted type (closed generic)
  • (int, str) - tuple with args of Xxx[int, str]
  • Xxx - unsubscripted type (open generic), aka origin
  • (T1, T2) - tuple with TypeVar items which are params of Xxx
import dataclasses
from typing import Generic, TypeVar, get_args, get_origin

T1 = TypeVar("T1")
T2 = TypeVar("T2")

@dataclasses.dataclass()
class Xxx(Generic[T1, T2]):
    pass

assert get_args(Xxx[int, str]) == (int, str)
assert get_origin(Xxx[int, str]) == Xxx
assert Xxx.__parameters__ == (T1, T2)

Issues

Extracting generic instance type

This is what we should take into account:

  • type(...) function always returns unsubscripted type
  • __orig_class__ returns subscripted type
  • __orig_class__ does not exist for dataclasses with frozen=True or slots=True
  • __orig_class__ does not exist for non generic dataclasses
import dataclasses
from typing import Any, Generic, TypeVar, get_origin

T1 = TypeVar("T1")
T2 = TypeVar("T2")

@dataclasses.dataclass()
class Xxx(Generic[T1, T2]):
    pass

@dataclasses.dataclass(slots=True)
class Zzz(Generic[T1, T2]):
    pass

xxx = Xxx[int, str]()
zzz = Zzz[int, str]()

assert type(xxx) == Xxx
assert getattr(xxx, "__orig_class__") == Xxx[int, str]
assert type(zzz) == Zzz
assert hasattr(zzz, "__orig_class__") == False

It means in some cases we can not extract subscripted generic type from instance and it should be passed explicitly to dump function. All this stuff implemented in _extract_type and tested by test_generic_extract_type_on_dump

Extracting generic dataclass field type

Subscripted type Xxx[int, str] is not a dataclass. Only unsubscripted Xxx is dataclass. Fields of Xxx are also have unsubscripted types. Here are steps how to calculate actual fields types:

  1. Get list of args from Xxx[int, str]
  2. Get origin Xxx from Xxx[int, str]
  3. Ger list of params from Xxx
  4. Build a map from param to arg
  5. Replace field TypeVar with arg from map
import dataclasses
from typing import Generic, TypeVar, get_args, get_origin

T1 = TypeVar("T1")
T2 = TypeVar("T2")

@dataclasses.dataclass()
class Xxx(Generic[T1, T2]):
    t1: T1
    t2: T2

xxx = Xxx[int, str](t1=123, t2="qwe")

cls = xxx.__orig_class__
origin = get_origin(cls)

assert dataclasses.is_dataclass(cls) is False
assert dataclasses.is_dataclass(origin) is True

args = get_args(cls)
params = origin.__parameters__

type_var_map = {param: args[i] for i, param in enumerate(params)}
assert typevars == {T1: int, T2: str}

fields = {f.name: type_var_map[f.type] for f in dataclasses.fields(origin)}
assert fields == {"t1": int, "t2": str}

Class field type with nested generic

Class field type can be more complicated than just TypeVar:

  • types.UnionType
  • Union
  • Annotated
  • types.GenericAlias
  • _GenericAlias

Subscripted type can be built recursively using the same type_var_map from snippet above. It is implemented in build_subscripted_type and types recognition can be tested using the code snippet below.

import dataclasses
import types
from typing import Annotated, Generic, TypeVar, Union, get_origin, _GenericAlias

T1 = TypeVar("T1")
T2 = TypeVar("T2")

@dataclasses.dataclass()
class Zzz(Generic[T1, T2]):
    pass

@dataclasses.dataclass()
class Xxx(Generic[T1, T2]):
    type_var: T1
    optional_set: set[T1] | None
    optional_zzz: Zzz[T1, T2] | None
    annotated: Annotated[T1, "meta"]
    generic_set: set[T1]
    generic_zzz: Zzz[T1, T2]

fields = {f.name: f.type for f in dataclasses.fields(Xxx)}

assert isinstance(fields["type_var"], TypeVar)
assert get_origin(fields["optional_set"]) is types.UnionType
assert get_origin(fields["optional_zzz"]) is Union
assert get_origin(fields["annotated"]) is Annotated
assert isinstance(fields["generic_set"], types.GenericAlias)
assert isinstance(fields["generic_zzz"], _GenericAlias)

Inheritance with generics

There are three connected issues to solve

Parents or parents of parents can be also generic

  • They can be subscripted
    class Child(Parent[int]):
        pass
  • They can be subscripted with child generic arg
    class Child(Generic[T], Parent[T]):
        pass
  • They can have nested generic
    class Child(Generic[T], Parent[list[T]]):
        pass

To solve this issue we can recursively iterate all parents via __orig_bases__ and apply build_subscripted_type for each parent class with type_var_map of child class

Class in hierarchy can use same TypeVar as generic param

class Parent(Generic[T]):
    v1: T

class Child(Generic[T], Parent[int]):
    v2: T

instance = Child[str](v1=111, v2="x")

v1 is int and v2 is str but for both field.type is T

So we can not have single type_var_map we should have separate maps for each class in hierarchy. All this stuff is implemented in get_class_type_var_map. Then we just need to get field's owner class to find its type_ver_map and build subscripted field type. BUT... Here is another issue. Check it below.

Dataclass fields have no owner class reference

It means we need to build a map from field to its owner class. But we need take into account these things:

  • dataclasses.fields(...) returns all dataclass fields including all fields from all parents
  • dataclass fields can be overridden in child dataclasses including generic override
  • all dataclass fields have unique names even if they were overridden
  • every field override will have its own descriptor in child dataclass

To build fields_class_map can be used the same approach as in dataclasses module using enumeration of reversed __mro__. Field name can be used as a key as it is unique. This stuff is implemented in get_fields_class_map function. The snippet below explains how it works and can be used to play with it.

import dataclasses

@dataclasses.dataclass()
class Parent():
    v1: str
    v2: str

@dataclasses.dataclass()
class Child(Parent):
    v2: str
    v3: str

fields_p = dataclasses.fields(Parent)
fields_c = dataclasses.fields(Child)

assert [f.name for f in fields_p] == ["v1", "v2"]
assert [f.name for f in fields_c] == ["v1", "v2", "v3"]

assert fields_c[0] == fields_p[0]  # same descriptor for v1
assert fields_c[1] != fields_p[1]  # new descriptor for overridden v2

fields: dict[str, dataclasses.Field] = {}
classes: dict[str, type] = {}

target = Child
for cls in (*target.__mro__[-1:0:-1], target):  # same as dataclass collects fields
    if not dataclasses.is_dataclass(cls):
        continue
    for field in dataclasses.fields(cls):
        if fields.get(field.name) != field:  # new field detected including override
            fields[field.name] = field
            classes[field.name] = cls

assert classes == {
    "v1": Parent,
    "v2": Child,
    "v3": Child
}

That's it

It works!

@@ -36,7 +36,9 @@ def __call__(


class DumpFunction(Protocol):
def __call__(self, data: Any, *, naming_case: NamingCase | None = None) -> dict[str, Any]: ...
def __call__(
self, data: Any, *, naming_case: NamingCase | None = None, t: type | None = None
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the point of passing t here?

Copy link
Contributor Author

@slawwan slawwan Nov 9, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have added commit

that function c396fa8#diff-7f5a7271c4eef44b2f5fda6334ecae0952550de169520dde4a919d86949361c2R655

and that test c396fa8#diff-7f5a7271c4eef44b2f5fda6334ecae0952550de169520dde4a919d86949361c2R655

should give an answer

the issue is that type(instance) returns unsubscripted (open) generic but we can get subscripted (closed) generic from instance.__origin_class__ but it does not for frozen=True or slots=True dataclasses that is why we need to specify type explicitly

I will also add more details in PR description

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

PR description updated

marshmallow_recipe/bake.py Outdated Show resolved Hide resolved

origin = get_origin(t)
if origin is Union or origin is types.UnionType:
return Union[*(build_subscripted_type(x, type_var_map) for x in get_args(t))] # type: ignore
Copy link
Contributor Author

@slawwan slawwan Nov 12, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pyright highlights * as unsupported for Python 3.10 in this case of usage but as I see there are no Python 3.10 in tests matrix.

Should I support this behaviour for Python 3.10 or can I increase Python version for Pyright up to 3.11?

The same thing for Annotated type on the next line

Also Pyright requesting for update

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have updated this branch and removed # type: ignore for this case

also I noticed that Pyright does not notify about redundant # type: ignore as mypy do

@slawwan slawwan marked this pull request as ready for review November 12, 2024 19:22
@slawwan slawwan requested a review from a team as a code owner November 12, 2024 19:22
@slawwan slawwan requested a review from Pliner November 12, 2024 19:33
@@ -72,81 +80,81 @@ def bake_schema(
if result := _schema_types.get(key):
return result

fields_with_metadata = [
(
fields_type_map = get_fields_type_map(cls)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a reason why we pass cls here and not origin? If there is why do we use origin on line 95 and why we even collecting field names one more time and not using the map?

Copy link
Contributor Author

@slawwan slawwan Nov 13, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a reason why we pass cls here and not origin?

in case with genetic type origin will be unsubscripted (open) generic like Xxx[T] without args and we can not use it to calculate actual field type and cls will be subscripted (closed) generic like Xxx[int] which has int as arg to replace with it all TypeVars used in fields types.

If there is why do we use origin on line 95

subscripted generic of unsubscripted dataclass is not a dataclass. we can get fields only from origin - from unsubscripted declared type. please check PR description. There you can find explanation with examples.

why we even collecting field names one more time and not using the map?

that map is only about fields types with access by field name. Here we collect fields to get field descriptor and field metadata to build schema

@slawwan slawwan requested a review from outring November 13, 2024 12:59
@slawwan slawwan requested a review from maradik November 22, 2024 10:12
maradik
maradik previously approved these changes Nov 26, 2024
outring
outring previously approved these changes Dec 2, 2024
@Pliner Pliner dismissed stale reviews from outring and maradik via fcf589b December 17, 2024 16:24
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants