Skip to content

Commit

Permalink
Cache typing functions to speed up @serde codegen
Browse files Browse the repository at this point in the history
  • Loading branch information
yukinarit committed Jun 5, 2024
1 parent 41ce566 commit b915348
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 13 deletions.
78 changes: 66 additions & 12 deletions serde/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@
from collections import defaultdict
from collections.abc import Iterator
from dataclasses import is_dataclass
from typing import TypeVar, Generic, Any, ClassVar, Optional, NewType, Union
from typing import TypeVar, Generic, Any, ClassVar, Optional, NewType, Union, Hashable, Callable

import typing_inspect
from typing_extensions import TypeGuard
from typing_extensions import TypeGuard, ParamSpec

from .sqlalchemy import is_sqlalchemy_inspectable

Expand All @@ -29,7 +29,7 @@ def get_np_origin(tp: type[Any]) -> Optional[Any]:
return None


def get_np_args(tp: Any) -> tuple[Any, ...]:
def get_np_args(tp: type[Any]) -> tuple[Any, ...]:
return ()


Expand Down Expand Up @@ -93,20 +93,48 @@ class SerdeSkip(Exception):
"""


def is_hashable(typ: Any) -> TypeGuard[Hashable]:
"""
Test is an object hashable
"""
try:
hash(typ)
except TypeError:
return False
return True


P = ParamSpec("P")


def cache(f: Callable[P, T]) -> Callable[P, T]:
"""
Wrapper for `functools.cache` to avoid `Hashable` related type errors.
"""

def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
return f(*args, **kwargs)

return functools.cache(wrapper) # type: ignore


@cache
def get_origin(typ: Any) -> Optional[Any]:
"""
Provide `get_origin` that works in all python versions.
"""
return typing.get_origin(typ) or get_np_origin(typ)


def get_args(typ: Any) -> tuple[Any, ...]:
@cache
def get_args(typ: type[Any]) -> tuple[Any, ...]:
"""
Provide `get_args` that works in all python versions.
"""
return typing.get_args(typ) or get_np_args(typ)


@cache
def typename(typ: Any, with_typing_module: bool = False) -> str:
"""
>>> from typing import Any
Expand Down Expand Up @@ -268,16 +296,16 @@ def dataclass_fields(cls: type[Any]) -> Iterator[dataclasses.Field]: # type: ig
TypeLike = Union[type[Any], typing.Any]


def iter_types(cls: TypeLike) -> list[TypeLike]:
def iter_types(cls: type[Any]) -> list[type[Any]]:
"""
Iterate field types recursively.
The correct return type is `Iterator[Union[Type, typing._specialform]],
but `typing._specialform` doesn't exist for python 3.6. Use `Any` instead.
"""
lst: set[TypeLike] = set()
lst: set[type[Any]] = set()

def recursive(cls: TypeLike) -> None:
def recursive(cls: type[Any]) -> None:
if cls in lst:
return

Expand All @@ -288,12 +316,12 @@ def recursive(cls: TypeLike) -> None:
elif isinstance(cls, str):
lst.add(cls)
elif is_opt(cls):
lst.add(Optional)
lst.add(Optional) # type: ignore
args = type_args(cls)
if args:
recursive(args[0])
elif is_union(cls):
lst.add(Union)
lst.add(Union) # type: ignore
for arg in type_args(cls):
recursive(arg)
elif is_list(cls):
Expand Down Expand Up @@ -366,13 +394,13 @@ def recursive(cls: TypeLike) -> None:
return list(lst)


def iter_literals(cls: TypeLike) -> list[TypeLike]:
def iter_literals(cls: type[Any]) -> list[TypeLike]:
"""
Iterate over all literals that are used in the dataclass
"""
lst: set[TypeLike] = set()
lst: set[type[Any]] = set()

def recursive(cls: TypeLike) -> None:
def recursive(cls: type[Any]) -> None:
if cls in lst:
return

Expand Down Expand Up @@ -406,6 +434,7 @@ def recursive(cls: TypeLike) -> None:
return list(lst)


@cache
def is_union(typ: Any) -> bool:
"""
Test if the type is `typing.Union`.
Expand Down Expand Up @@ -433,6 +462,7 @@ def is_union(typ: Any) -> bool:
return typing_inspect.is_union_type(typ) # type: ignore


@cache
def is_opt(typ: Any) -> bool:
"""
Test if the type is `typing.Optional`.
Expand Down Expand Up @@ -469,6 +499,7 @@ def is_opt(typ: Any) -> bool:
return typ is Optional


@cache
def is_bare_opt(typ: Any) -> bool:
"""
Test if the type is `typing.Optional` without type args.
Expand All @@ -482,6 +513,7 @@ def is_bare_opt(typ: Any) -> bool:
return not type_args(typ) and typ is Optional


@cache
def is_list(typ: type[Any]) -> bool:
"""
Test if the type is `list`.
Expand All @@ -497,6 +529,7 @@ def is_list(typ: type[Any]) -> bool:
return typ is list


@cache
def is_bare_list(typ: type[Any]) -> bool:
"""
Test if the type is `list` without type args.
Expand All @@ -509,6 +542,7 @@ def is_bare_list(typ: type[Any]) -> bool:
return typ is list


@cache
def is_tuple(typ: Any) -> bool:
"""
Test if the type is tuple.
Expand All @@ -519,6 +553,7 @@ def is_tuple(typ: Any) -> bool:
return typ is tuple


@cache
def is_bare_tuple(typ: type[Any]) -> bool:
"""
Test if the type is tuple without type args.
Expand All @@ -531,6 +566,7 @@ def is_bare_tuple(typ: type[Any]) -> bool:
return typ is tuple


@cache
def is_variable_tuple(typ: type[Any]) -> bool:
"""
Test if the type is a variable length of tuple tuple[T, ...]`.
Expand All @@ -547,6 +583,7 @@ def is_variable_tuple(typ: type[Any]) -> bool:
return istuple and len(args) == 2 and is_ellipsis(args[1])


@cache
def is_set(typ: type[Any]) -> bool:
"""
Test if the type is `set` or `frozenset`.
Expand All @@ -564,6 +601,7 @@ def is_set(typ: type[Any]) -> bool:
return typ in (set, frozenset)


@cache
def is_bare_set(typ: type[Any]) -> bool:
"""
Test if the type is `set` without type args.
Expand All @@ -576,6 +614,7 @@ def is_bare_set(typ: type[Any]) -> bool:
return typ in (set, frozenset)


@cache
def is_frozen_set(typ: type[Any]) -> bool:
"""
Test if the type is `frozenset`.
Expand All @@ -591,6 +630,7 @@ def is_frozen_set(typ: type[Any]) -> bool:
return typ is frozenset


@cache
def is_dict(typ: type[Any]) -> bool:
"""
Test if the type is dict.
Expand All @@ -608,6 +648,7 @@ def is_dict(typ: type[Any]) -> bool:
return typ in (dict, defaultdict)


@cache
def is_bare_dict(typ: type[Any]) -> bool:
"""
Test if the type is `dict` without type args.
Expand All @@ -620,6 +661,7 @@ def is_bare_dict(typ: type[Any]) -> bool:
return typ is dict


@cache
def is_default_dict(typ: type[Any]) -> bool:
"""
Test if the type is `defaultdict`.
Expand All @@ -635,6 +677,7 @@ def is_default_dict(typ: type[Any]) -> bool:
return typ is defaultdict


@cache
def is_none(typ: type[Any]) -> bool:
"""
>>> is_none(int)
Expand All @@ -650,6 +693,7 @@ def is_none(typ: type[Any]) -> bool:
PRIMITIVES = [int, float, bool, str]


@cache
def is_enum(typ: type[Any]) -> TypeGuard[enum.Enum]:
"""
Test if the type is `enum.Enum`.
Expand All @@ -660,6 +704,7 @@ def is_enum(typ: type[Any]) -> TypeGuard[enum.Enum]:
return isinstance(typ, enum.Enum)


@cache
def is_primitive_subclass(typ: type[Any]) -> bool:
"""
Test if the type is a subclass of primitive type.
Expand All @@ -674,6 +719,7 @@ def is_primitive_subclass(typ: type[Any]) -> bool:
return is_primitive(typ) and typ not in PRIMITIVES and not is_new_type_primitive(typ)


@cache
def is_primitive(typ: Union[type[Any], NewType]) -> bool:
"""
Test if the type is primitive.
Expand All @@ -691,6 +737,7 @@ def is_primitive(typ: Union[type[Any], NewType]) -> bool:
return is_new_type_primitive(typ)


@cache
def is_new_type_primitive(typ: Union[type[Any], NewType]) -> bool:
"""
Test if the type is a NewType of primitives.
Expand All @@ -702,10 +749,12 @@ def is_new_type_primitive(typ: Union[type[Any], NewType]) -> bool:
return any(isinstance(typ, ty) for ty in PRIMITIVES)


@cache
def has_generic_base(typ: Any) -> bool:
return Generic in getattr(typ, "__mro__", ()) or Generic in getattr(typ, "__bases__", ())


@cache
def is_generic(typ: Any) -> bool:
"""
Test if the type is derived from `typing.Generic`.
Expand All @@ -722,6 +771,7 @@ def is_generic(typ: Any) -> bool:
return origin is not None and has_generic_base(origin)


@cache
def is_class_var(typ: type[Any]) -> bool:
"""
Test if the type is `typing.ClassVar`.
Expand All @@ -734,6 +784,7 @@ def is_class_var(typ: type[Any]) -> bool:
return get_origin(typ) is ClassVar or typ is ClassVar # type: ignore


@cache
def is_literal(typ: type[Any]) -> bool:
"""
Test if the type is derived from `typing.Literal`.
Expand All @@ -750,13 +801,15 @@ def is_literal(typ: type[Any]) -> bool:
return origin is not None and origin is typing.Literal


@cache
def is_any(typ: type[Any]) -> bool:
"""
Test if the type is `typing.Any`.
"""
return typ is Any


@cache
def is_str_serializable(typ: type[Any]) -> bool:
"""
Test if the type is serializable to `str`.
Expand Down Expand Up @@ -787,6 +840,7 @@ def is_ellipsis(typ: Any) -> bool:
return typ is Ellipsis


@cache
def get_type_var_names(cls: type[Any]) -> Optional[list[str]]:
"""
Get type argument names of a generic class.
Expand Down
2 changes: 1 addition & 1 deletion serde/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
def json_dumps(obj: Any, **opts: Any) -> str:
if "option" not in opts:
opts["option"] = orjson.OPT_SERIALIZE_NUMPY
return orjson.dumps(obj, **opts).decode() # type: ignore
return orjson.dumps(obj, **opts).decode()

def json_loads(s: Union[str, bytes], **opts: Any) -> Any:
return orjson.loads(s, **opts)
Expand Down

0 comments on commit b915348

Please sign in to comment.