Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cache typing functions to speed up @serde codegen #536

Merged
merged 1 commit into from
Jun 5, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 66 additions & 12 deletions serde/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@
from collections import defaultdict
from collections.abc import Iterator
from dataclasses import is_dataclass
from typing import TypeVar, Generic, Any, ClassVar, Optional, NewType, Union
from typing import TypeVar, Generic, Any, ClassVar, Optional, NewType, Union, Hashable, Callable

import typing_inspect
from typing_extensions import TypeGuard
from typing_extensions import TypeGuard, ParamSpec

from .sqlalchemy import is_sqlalchemy_inspectable

Expand All @@ -29,7 +29,7 @@ def get_np_origin(tp: type[Any]) -> Optional[Any]:
return None


def get_np_args(tp: Any) -> tuple[Any, ...]:
def get_np_args(tp: type[Any]) -> tuple[Any, ...]:
return ()


Expand Down Expand Up @@ -93,20 +93,48 @@ class SerdeSkip(Exception):
"""


def is_hashable(typ: Any) -> TypeGuard[Hashable]:
"""
Test is an object hashable
"""
try:
hash(typ)
except TypeError:
return False
return True


P = ParamSpec("P")


def cache(f: Callable[P, T]) -> Callable[P, T]:
"""
Wrapper for `functools.cache` to avoid `Hashable` related type errors.
"""

def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
return f(*args, **kwargs)

return functools.cache(wrapper) # type: ignore


@cache
def get_origin(typ: Any) -> Optional[Any]:
"""
Provide `get_origin` that works in all python versions.
"""
return typing.get_origin(typ) or get_np_origin(typ)


def get_args(typ: Any) -> tuple[Any, ...]:
@cache
def get_args(typ: type[Any]) -> tuple[Any, ...]:
"""
Provide `get_args` that works in all python versions.
"""
return typing.get_args(typ) or get_np_args(typ)


@cache
def typename(typ: Any, with_typing_module: bool = False) -> str:
"""
>>> from typing import Any
Expand Down Expand Up @@ -268,16 +296,16 @@ def dataclass_fields(cls: type[Any]) -> Iterator[dataclasses.Field]: # type: ig
TypeLike = Union[type[Any], typing.Any]


def iter_types(cls: TypeLike) -> list[TypeLike]:
def iter_types(cls: type[Any]) -> list[type[Any]]:
"""
Iterate field types recursively.
The correct return type is `Iterator[Union[Type, typing._specialform]],
but `typing._specialform` doesn't exist for python 3.6. Use `Any` instead.
"""
lst: set[TypeLike] = set()
lst: set[type[Any]] = set()

def recursive(cls: TypeLike) -> None:
def recursive(cls: type[Any]) -> None:
if cls in lst:
return

Expand All @@ -288,12 +316,12 @@ def recursive(cls: TypeLike) -> None:
elif isinstance(cls, str):
lst.add(cls)
elif is_opt(cls):
lst.add(Optional)
lst.add(Optional) # type: ignore
args = type_args(cls)
if args:
recursive(args[0])
elif is_union(cls):
lst.add(Union)
lst.add(Union) # type: ignore
for arg in type_args(cls):
recursive(arg)
elif is_list(cls):
Expand Down Expand Up @@ -366,13 +394,13 @@ def recursive(cls: TypeLike) -> None:
return list(lst)


def iter_literals(cls: TypeLike) -> list[TypeLike]:
def iter_literals(cls: type[Any]) -> list[TypeLike]:
"""
Iterate over all literals that are used in the dataclass
"""
lst: set[TypeLike] = set()
lst: set[type[Any]] = set()

def recursive(cls: TypeLike) -> None:
def recursive(cls: type[Any]) -> None:
if cls in lst:
return

Expand Down Expand Up @@ -406,6 +434,7 @@ def recursive(cls: TypeLike) -> None:
return list(lst)


@cache
def is_union(typ: Any) -> bool:
"""
Test if the type is `typing.Union`.
Expand Down Expand Up @@ -433,6 +462,7 @@ def is_union(typ: Any) -> bool:
return typing_inspect.is_union_type(typ) # type: ignore


@cache
def is_opt(typ: Any) -> bool:
"""
Test if the type is `typing.Optional`.
Expand Down Expand Up @@ -469,6 +499,7 @@ def is_opt(typ: Any) -> bool:
return typ is Optional


@cache
def is_bare_opt(typ: Any) -> bool:
"""
Test if the type is `typing.Optional` without type args.
Expand All @@ -482,6 +513,7 @@ def is_bare_opt(typ: Any) -> bool:
return not type_args(typ) and typ is Optional


@cache
def is_list(typ: type[Any]) -> bool:
"""
Test if the type is `list`.
Expand All @@ -497,6 +529,7 @@ def is_list(typ: type[Any]) -> bool:
return typ is list


@cache
def is_bare_list(typ: type[Any]) -> bool:
"""
Test if the type is `list` without type args.
Expand All @@ -509,6 +542,7 @@ def is_bare_list(typ: type[Any]) -> bool:
return typ is list


@cache
def is_tuple(typ: Any) -> bool:
"""
Test if the type is tuple.
Expand All @@ -519,6 +553,7 @@ def is_tuple(typ: Any) -> bool:
return typ is tuple


@cache
def is_bare_tuple(typ: type[Any]) -> bool:
"""
Test if the type is tuple without type args.
Expand All @@ -531,6 +566,7 @@ def is_bare_tuple(typ: type[Any]) -> bool:
return typ is tuple


@cache
def is_variable_tuple(typ: type[Any]) -> bool:
"""
Test if the type is a variable length of tuple tuple[T, ...]`.
Expand All @@ -547,6 +583,7 @@ def is_variable_tuple(typ: type[Any]) -> bool:
return istuple and len(args) == 2 and is_ellipsis(args[1])


@cache
def is_set(typ: type[Any]) -> bool:
"""
Test if the type is `set` or `frozenset`.
Expand All @@ -564,6 +601,7 @@ def is_set(typ: type[Any]) -> bool:
return typ in (set, frozenset)


@cache
def is_bare_set(typ: type[Any]) -> bool:
"""
Test if the type is `set` without type args.
Expand All @@ -576,6 +614,7 @@ def is_bare_set(typ: type[Any]) -> bool:
return typ in (set, frozenset)


@cache
def is_frozen_set(typ: type[Any]) -> bool:
"""
Test if the type is `frozenset`.
Expand All @@ -591,6 +630,7 @@ def is_frozen_set(typ: type[Any]) -> bool:
return typ is frozenset


@cache
def is_dict(typ: type[Any]) -> bool:
"""
Test if the type is dict.
Expand All @@ -608,6 +648,7 @@ def is_dict(typ: type[Any]) -> bool:
return typ in (dict, defaultdict)


@cache
def is_bare_dict(typ: type[Any]) -> bool:
"""
Test if the type is `dict` without type args.
Expand All @@ -620,6 +661,7 @@ def is_bare_dict(typ: type[Any]) -> bool:
return typ is dict


@cache
def is_default_dict(typ: type[Any]) -> bool:
"""
Test if the type is `defaultdict`.
Expand All @@ -635,6 +677,7 @@ def is_default_dict(typ: type[Any]) -> bool:
return typ is defaultdict


@cache
def is_none(typ: type[Any]) -> bool:
"""
>>> is_none(int)
Expand All @@ -650,6 +693,7 @@ def is_none(typ: type[Any]) -> bool:
PRIMITIVES = [int, float, bool, str]


@cache
def is_enum(typ: type[Any]) -> TypeGuard[enum.Enum]:
"""
Test if the type is `enum.Enum`.
Expand All @@ -660,6 +704,7 @@ def is_enum(typ: type[Any]) -> TypeGuard[enum.Enum]:
return isinstance(typ, enum.Enum)


@cache
def is_primitive_subclass(typ: type[Any]) -> bool:
"""
Test if the type is a subclass of primitive type.
Expand All @@ -674,6 +719,7 @@ def is_primitive_subclass(typ: type[Any]) -> bool:
return is_primitive(typ) and typ not in PRIMITIVES and not is_new_type_primitive(typ)


@cache
def is_primitive(typ: Union[type[Any], NewType]) -> bool:
"""
Test if the type is primitive.
Expand All @@ -691,6 +737,7 @@ def is_primitive(typ: Union[type[Any], NewType]) -> bool:
return is_new_type_primitive(typ)


@cache
def is_new_type_primitive(typ: Union[type[Any], NewType]) -> bool:
"""
Test if the type is a NewType of primitives.
Expand All @@ -702,10 +749,12 @@ def is_new_type_primitive(typ: Union[type[Any], NewType]) -> bool:
return any(isinstance(typ, ty) for ty in PRIMITIVES)


@cache
def has_generic_base(typ: Any) -> bool:
return Generic in getattr(typ, "__mro__", ()) or Generic in getattr(typ, "__bases__", ())


@cache
def is_generic(typ: Any) -> bool:
"""
Test if the type is derived from `typing.Generic`.
Expand All @@ -722,6 +771,7 @@ def is_generic(typ: Any) -> bool:
return origin is not None and has_generic_base(origin)


@cache
def is_class_var(typ: type[Any]) -> bool:
"""
Test if the type is `typing.ClassVar`.
Expand All @@ -734,6 +784,7 @@ def is_class_var(typ: type[Any]) -> bool:
return get_origin(typ) is ClassVar or typ is ClassVar # type: ignore


@cache
def is_literal(typ: type[Any]) -> bool:
"""
Test if the type is derived from `typing.Literal`.
Expand All @@ -750,13 +801,15 @@ def is_literal(typ: type[Any]) -> bool:
return origin is not None and origin is typing.Literal


@cache
def is_any(typ: type[Any]) -> bool:
"""
Test if the type is `typing.Any`.
"""
return typ is Any


@cache
def is_str_serializable(typ: type[Any]) -> bool:
"""
Test if the type is serializable to `str`.
Expand Down Expand Up @@ -787,6 +840,7 @@ def is_ellipsis(typ: Any) -> bool:
return typ is Ellipsis


@cache
def get_type_var_names(cls: type[Any]) -> Optional[list[str]]:
"""
Get type argument names of a generic class.
Expand Down
Loading