Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(django_getter): use function cache to increase performance #1150

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 61 additions & 51 deletions ninja/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,15 @@ def resolve_name(obj):

"""

from __future__ import annotations

import warnings
from functools import partial
from typing import (
Any,
Callable,
Dict,
Type,
ClassVar,
TypeVar,
Union,
no_type_check,
)

Expand All @@ -39,6 +40,7 @@ def resolve_name(obj):
from pydantic.json_schema import GenerateJsonSchema, JsonSchemaValue
from typing_extensions import dataclass_transform

from ninja.constants import NOT_SET
from ninja.signature.utils import get_args_names, has_kwargs
from ninja.types import DictStrAny

Expand All @@ -50,45 +52,68 @@ def resolve_name(obj):
S = TypeVar("S", bound="Schema")


def dict_getter(key: str, obj: DjangoGetter) -> Any:
if key not in obj._obj:
raise AttributeError(key)
return obj._obj[key]


def attr_getter(key: str, obj: DjangoGetter) -> Any:
try:
return Variable(key).resolve(obj._obj)
except VariableDoesNotExist as e:
raise AttributeError(key) from e


def resolver(resolve_func: Callable, _: str, obj: DjangoGetter) -> Any:
return resolve_func(getter=obj)


def get_attr(key: str, obj: DjangoGetter) -> Any:
return getattr(obj._obj, key)


class DjangoGetter:
__slots__ = ("_obj", "_schema_cls", "_context", "__dict__")
__slots__ = ("_obj", "_schema_cls", "_context", "__dict__", "_cache_key")
_cache: ClassVar[dict[str, Callable]] = {}

def __init__(self, obj: Any, schema_cls: Type[S], context: Any = None):
def __init__(self, obj: Any, schema_cls: type[S], context: Any = None) -> None:
self._obj = obj
self._schema_cls = schema_cls
self._context = context
self._cache_key = f"{self._schema_cls.__module__}.{self._schema_cls.__name__}.{self._obj.__class__.__name__}"

def __getattr__(self, key: str) -> Any:
# if key.startswith("__pydantic"):
# return getattr(self._obj, key)

resolver = self._schema_cls._ninja_resolvers.get(key)
if resolver:
value = resolver(getter=self)
else:
if isinstance(self._obj, dict):
if key not in self._obj:
raise AttributeError(key)
value = self._obj[key]
else:
try:
value = getattr(self._obj, key)
except AttributeError:
try:
# value = attrgetter(key)(self._obj)
value = Variable(key).resolve(self._obj)
# TODO: Variable(key) __init__ is actually slower than
# Variable.resolve - so it better be cached
except VariableDoesNotExist as e:
raise AttributeError(key) from e
cache_key = f"{self._cache_key}.{key}"
if cache_key in DjangoGetter._cache:
# Use cached function, if available.
value = DjangoGetter._cache[cache_key](key, self)
return self._convert_result(value)

stored_resolver = self._schema_cls._ninja_resolvers.get(key)
if stored_resolver:
# Use resolver when provided for this key.
value = stored_resolver(getter=self)
# bind resolver of this key to the _cache
DjangoGetter._cache[cache_key] = partial(resolver, stored_resolver)
return self._convert_result(value)

if isinstance(self._obj, dict):
# Use dict lookup, faster than getattr
value = dict_getter(key, self)
DjangoGetter._cache[cache_key] = dict_getter
return self._convert_result(value)

value = getattr(self._obj, key, NOT_SET)
if value is not NOT_SET:
# If getattr worked, use that.
DjangoGetter._cache[cache_key] = get_attr
return self._convert_result(value)
# Finally, fallback to attr_getter
value = attr_getter(key, self)
DjangoGetter._cache[cache_key] = attr_getter
return self._convert_result(value)

# def get(self, key: Any, default: Any = None) -> Any:
# try:
# return self[key]
# except KeyError:
# return default

def _convert_result(self, result: Any) -> Any:
if isinstance(result, Manager):
return list(result.all())
Expand Down Expand Up @@ -116,7 +141,7 @@ class Resolver:
_func: Any
_takes_context: bool

def __init__(self, func: Union[Callable, staticmethod]):
def __init__(self, func: Callable | staticmethod):
if isinstance(func, staticmethod):
self._static = True
self._func = func.__func__
Expand All @@ -139,25 +164,10 @@ def __call__(self, getter: DjangoGetter) -> Any:
) # pragma: no cover
# return self._func(self._fake_instance(getter), getter._obj)

# def _fake_instance(self, getter: DjangoGetter) -> "Schema":
# """
# Generate a partial schema instance that can be used as the ``self``
# attribute of resolver functions.
# """

# class PartialSchema(Schema):
# def __getattr__(self, key: str) -> Any:
# value = getattr(getter, key)
# field = getter._schema_cls.model_fields[key]
# value = field.validate(value, values={}, loc=key, cls=None)[0]
# return value

# return PartialSchema()


@dataclass_transform(kw_only_default=True, field_specifiers=(Field,))
class ResolverMetaclass(ModelMetaclass):
_ninja_resolvers: Dict[str, Resolver]
_ninja_resolvers: dict[str, Resolver]

@no_type_check
def __new__(cls, name, bases, namespace, **kwargs):
Expand Down Expand Up @@ -228,7 +238,7 @@ def _run_root_validator(
return handler(values)

@classmethod
def from_orm(cls: Type[S], obj: Any, **kw: Any) -> S:
def from_orm(cls: type[S], obj: Any, **kw: Any) -> S:
return cls.model_validate(obj, **kw)

def dict(self, *a: Any, **kw: Any) -> DictStrAny:
Expand Down