Skip to content

Commit

Permalink
feat(django_getter): use function cache to increase performance
Browse files Browse the repository at this point in the history
Update schema.py
  • Loading branch information
Yorick Rommers authored and yorickr-sendcloud committed Aug 2, 2024
1 parent eecb05f commit 67ac1be
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 80 deletions.
42 changes: 19 additions & 23 deletions ninja/openapi/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,23 +40,21 @@ def __init__(self, api: "NinjaAPI", path_prefix: str) -> None:
self.securitySchemes: DictStrAny = {}
self.all_operation_ids: Set = set()
extra_info = api.openapi_extra.get("info", {})
super().__init__(
[
("openapi", "3.1.0"),
(
"info",
{
"title": api.title,
"version": api.version,
"description": api.description,
**extra_info,
},
),
("paths", self.get_paths()),
("components", self.get_components()),
("servers", api.servers),
]
)
super().__init__([
("openapi", "3.1.0"),
(
"info",
{
"title": api.title,
"version": api.version,
"description": api.description,
**extra_info,
},
),
("paths", self.get_paths()),
("components", self.get_components()),
("servers", api.servers),
])
for k, v in api.openapi_extra.items():
if k not in self:
self[k] = v
Expand Down Expand Up @@ -242,12 +240,10 @@ def _create_multipart_schema_from_models(
content_type = BODY_CONTENT_TYPES["file"]

# get the various schemas
result = merge_schemas(
[
self._create_schema_from_model(model, remove_level=False)[0]
for model in models
]
)
result = merge_schemas([
self._create_schema_from_model(model, remove_level=False)[0]
for model in models
])
result["title"] = "MultiPartBodyParams"

return result, content_type
Expand Down
112 changes: 61 additions & 51 deletions ninja/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,15 @@ def resolve_name(obj):
"""

from __future__ import annotations

import warnings
from functools import partial
from typing import (
Any,
Callable,
Dict,
Type,
ClassVar,
TypeVar,
Union,
no_type_check,
)

Expand All @@ -39,6 +40,7 @@ def resolve_name(obj):
from pydantic.json_schema import GenerateJsonSchema, JsonSchemaValue
from typing_extensions import dataclass_transform

from ninja.constants import NOT_SET
from ninja.signature.utils import get_args_names, has_kwargs
from ninja.types import DictStrAny

Expand All @@ -50,45 +52,68 @@ def resolve_name(obj):
S = TypeVar("S", bound="Schema")


def dict_getter(key: str, obj: DjangoGetter) -> Any:
if key not in obj._obj:
raise AttributeError(key)
return obj._obj[key]


def attr_getter(key: str, obj: DjangoGetter) -> Any:
try:
return Variable(key).resolve(obj._obj)
except VariableDoesNotExist as e:
raise AttributeError(key) from e


def resolver(resolve_func: Callable, _: str, obj: DjangoGetter) -> Any:
return resolve_func(getter=obj)


def get_attr(key: str, obj: DjangoGetter) -> Any:
return getattr(obj._obj, key)


class DjangoGetter:
__slots__ = ("_obj", "_schema_cls", "_context", "__dict__")
__slots__ = ("_obj", "_schema_cls", "_context", "__dict__", "_cache_key")
_cache: ClassVar[dict[str, Callable]] = {}

def __init__(self, obj: Any, schema_cls: Type[S], context: Any = None):
def __init__(self, obj: Any, schema_cls: type[S], context: Any = None) -> None:
self._obj = obj
self._schema_cls = schema_cls
self._context = context
self._cache_key = f"{self._schema_cls.__module__}.{self._schema_cls.__name__}.{self._obj.__class__.__name__}"

def __getattr__(self, key: str) -> Any:
# if key.startswith("__pydantic"):
# return getattr(self._obj, key)

resolver = self._schema_cls._ninja_resolvers.get(key)
if resolver:
value = resolver(getter=self)
else:
if isinstance(self._obj, dict):
if key not in self._obj:
raise AttributeError(key)
value = self._obj[key]
else:
try:
value = getattr(self._obj, key)
except AttributeError:
try:
# value = attrgetter(key)(self._obj)
value = Variable(key).resolve(self._obj)
# TODO: Variable(key) __init__ is actually slower than
# Variable.resolve - so it better be cached
except VariableDoesNotExist as e:
raise AttributeError(key) from e
cache_key = f"{self._cache_key}.{key}"
if cache_key in DjangoGetter._cache:
# Use cached function, if available.
value = DjangoGetter._cache[cache_key](key, self)
return self._convert_result(value)

stored_resolver = self._schema_cls._ninja_resolvers.get(key)
if stored_resolver:
# Use resolver when provided for this key.
value = stored_resolver(getter=self)
# bind resolver of this key to the _cache
DjangoGetter._cache[cache_key] = partial(resolver, stored_resolver)
return self._convert_result(value)

if isinstance(self._obj, dict):
# Use dict lookup, faster than getattr
value = dict_getter(key, self)
DjangoGetter._cache[cache_key] = dict_getter
return self._convert_result(value)

value = getattr(self._obj, key, NOT_SET)
if value is not NOT_SET:
# If getattr worked, use that.
DjangoGetter._cache[cache_key] = get_attr
return self._convert_result(value)
# Finally, fallback to attr_getter
value = attr_getter(key, self)
DjangoGetter._cache[cache_key] = attr_getter
return self._convert_result(value)

# def get(self, key: Any, default: Any = None) -> Any:
# try:
# return self[key]
# except KeyError:
# return default

def _convert_result(self, result: Any) -> Any:
if isinstance(result, Manager):
return list(result.all())
Expand Down Expand Up @@ -116,7 +141,7 @@ class Resolver:
_func: Any
_takes_context: bool

def __init__(self, func: Union[Callable, staticmethod]):
def __init__(self, func: Callable | staticmethod):
if isinstance(func, staticmethod):
self._static = True
self._func = func.__func__
Expand All @@ -139,25 +164,10 @@ def __call__(self, getter: DjangoGetter) -> Any:
) # pragma: no cover
# return self._func(self._fake_instance(getter), getter._obj)

# def _fake_instance(self, getter: DjangoGetter) -> "Schema":
# """
# Generate a partial schema instance that can be used as the ``self``
# attribute of resolver functions.
# """

# class PartialSchema(Schema):
# def __getattr__(self, key: str) -> Any:
# value = getattr(getter, key)
# field = getter._schema_cls.model_fields[key]
# value = field.validate(value, values={}, loc=key, cls=None)[0]
# return value

# return PartialSchema()


@dataclass_transform(kw_only_default=True, field_specifiers=(Field,))
class ResolverMetaclass(ModelMetaclass):
_ninja_resolvers: Dict[str, Resolver]
_ninja_resolvers: dict[str, Resolver]

@no_type_check
def __new__(cls, name, bases, namespace, **kwargs):
Expand Down Expand Up @@ -228,7 +238,7 @@ def _run_root_validator(
return handler(values)

@classmethod
def from_orm(cls: Type[S], obj: Any, **kw: Any) -> S:
def from_orm(cls: type[S], obj: Any, **kw: Any) -> S:
return cls.model_validate(obj, **kw)

def dict(self, *a: Any, **kw: Any) -> DictStrAny:
Expand Down
10 changes: 4 additions & 6 deletions ninja/testing/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,12 +127,10 @@ def _build_request(
request.META = request_params.pop("META", {"REMOTE_ADDR": "127.0.0.1"})
request.FILES = request_params.pop("FILES", {})

request.META.update(
{
f"HTTP_{k.replace('-', '_')}": v
for k, v in request_params.pop("headers", {}).items()
}
)
request.META.update({
f"HTTP_{k.replace('-', '_')}": v
for k, v in request_params.pop("headers", {}).items()
})

request.headers = HttpHeaders(request.META)

Expand Down

0 comments on commit 67ac1be

Please sign in to comment.