diff --git a/ninja/openapi/schema.py b/ninja/openapi/schema.py index 494e4e1d..584b52cf 100644 --- a/ninja/openapi/schema.py +++ b/ninja/openapi/schema.py @@ -40,23 +40,21 @@ def __init__(self, api: "NinjaAPI", path_prefix: str) -> None: self.securitySchemes: DictStrAny = {} self.all_operation_ids: Set = set() extra_info = api.openapi_extra.get("info", {}) - super().__init__( - [ - ("openapi", "3.1.0"), - ( - "info", - { - "title": api.title, - "version": api.version, - "description": api.description, - **extra_info, - }, - ), - ("paths", self.get_paths()), - ("components", self.get_components()), - ("servers", api.servers), - ] - ) + super().__init__([ + ("openapi", "3.1.0"), + ( + "info", + { + "title": api.title, + "version": api.version, + "description": api.description, + **extra_info, + }, + ), + ("paths", self.get_paths()), + ("components", self.get_components()), + ("servers", api.servers), + ]) for k, v in api.openapi_extra.items(): if k not in self: self[k] = v @@ -242,12 +240,10 @@ def _create_multipart_schema_from_models( content_type = BODY_CONTENT_TYPES["file"] # get the various schemas - result = merge_schemas( - [ - self._create_schema_from_model(model, remove_level=False)[0] - for model in models - ] - ) + result = merge_schemas([ + self._create_schema_from_model(model, remove_level=False)[0] + for model in models + ]) result["title"] = "MultiPartBodyParams" return result, content_type diff --git a/ninja/schema.py b/ninja/schema.py index e51dd43f..9d952266 100644 --- a/ninja/schema.py +++ b/ninja/schema.py @@ -18,14 +18,15 @@ def resolve_name(obj): """ +from __future__ import annotations + import warnings +from functools import partial from typing import ( Any, Callable, - Dict, - Type, + ClassVar, TypeVar, - Union, no_type_check, ) @@ -39,6 +40,7 @@ def resolve_name(obj): from pydantic.json_schema import GenerateJsonSchema, JsonSchemaValue from typing_extensions import dataclass_transform +from ninja.constants import NOT_SET from ninja.signature.utils import get_args_names, has_kwargs from ninja.types import DictStrAny @@ -50,45 +52,68 @@ def resolve_name(obj): S = TypeVar("S", bound="Schema") +def dict_getter(key: str, obj: DjangoGetter) -> Any: + if key not in obj._obj: + raise AttributeError(key) + return obj._obj[key] + + +def attr_getter(key: str, obj: DjangoGetter) -> Any: + try: + return Variable(key).resolve(obj._obj) + except VariableDoesNotExist as e: + raise AttributeError(key) from e + + +def resolver(resolve_func: Callable, _: str, obj: DjangoGetter) -> Any: + return resolve_func(getter=obj) + + +def get_attr(key: str, obj: DjangoGetter) -> Any: + return getattr(obj._obj, key) + + class DjangoGetter: - __slots__ = ("_obj", "_schema_cls", "_context", "__dict__") + __slots__ = ("_obj", "_schema_cls", "_context", "__dict__", "_cache_key") + _cache: ClassVar[dict[str, Callable]] = {} - def __init__(self, obj: Any, schema_cls: Type[S], context: Any = None): + def __init__(self, obj: Any, schema_cls: type[S], context: Any = None) -> None: self._obj = obj self._schema_cls = schema_cls self._context = context + self._cache_key = f"{self._schema_cls.__module__}.{self._schema_cls.__name__}.{self._obj.__class__.__name__}" def __getattr__(self, key: str) -> Any: - # if key.startswith("__pydantic"): - # return getattr(self._obj, key) - - resolver = self._schema_cls._ninja_resolvers.get(key) - if resolver: - value = resolver(getter=self) - else: - if isinstance(self._obj, dict): - if key not in self._obj: - raise AttributeError(key) - value = self._obj[key] - else: - try: - value = getattr(self._obj, key) - except AttributeError: - try: - # value = attrgetter(key)(self._obj) - value = Variable(key).resolve(self._obj) - # TODO: Variable(key) __init__ is actually slower than - # Variable.resolve - so it better be cached - except VariableDoesNotExist as e: - raise AttributeError(key) from e + cache_key = f"{self._cache_key}.{key}" + if cache_key in DjangoGetter._cache: + # Use cached function, if available. + value = DjangoGetter._cache[cache_key](key, self) + return self._convert_result(value) + + stored_resolver = self._schema_cls._ninja_resolvers.get(key) + if stored_resolver: + # Use resolver when provided for this key. + value = stored_resolver(getter=self) + # bind resolver of this key to the _cache + DjangoGetter._cache[cache_key] = partial(resolver, stored_resolver) + return self._convert_result(value) + + if isinstance(self._obj, dict): + # Use dict lookup, faster than getattr + value = dict_getter(key, self) + DjangoGetter._cache[cache_key] = dict_getter + return self._convert_result(value) + + value = getattr(self._obj, key, NOT_SET) + if value is not NOT_SET: + # If getattr worked, use that. + DjangoGetter._cache[cache_key] = get_attr + return self._convert_result(value) + # Finally, fallback to attr_getter + value = attr_getter(key, self) + DjangoGetter._cache[cache_key] = attr_getter return self._convert_result(value) - # def get(self, key: Any, default: Any = None) -> Any: - # try: - # return self[key] - # except KeyError: - # return default - def _convert_result(self, result: Any) -> Any: if isinstance(result, Manager): return list(result.all()) @@ -116,7 +141,7 @@ class Resolver: _func: Any _takes_context: bool - def __init__(self, func: Union[Callable, staticmethod]): + def __init__(self, func: Callable | staticmethod): if isinstance(func, staticmethod): self._static = True self._func = func.__func__ @@ -139,25 +164,10 @@ def __call__(self, getter: DjangoGetter) -> Any: ) # pragma: no cover # return self._func(self._fake_instance(getter), getter._obj) - # def _fake_instance(self, getter: DjangoGetter) -> "Schema": - # """ - # Generate a partial schema instance that can be used as the ``self`` - # attribute of resolver functions. - # """ - - # class PartialSchema(Schema): - # def __getattr__(self, key: str) -> Any: - # value = getattr(getter, key) - # field = getter._schema_cls.model_fields[key] - # value = field.validate(value, values={}, loc=key, cls=None)[0] - # return value - - # return PartialSchema() - @dataclass_transform(kw_only_default=True, field_specifiers=(Field,)) class ResolverMetaclass(ModelMetaclass): - _ninja_resolvers: Dict[str, Resolver] + _ninja_resolvers: dict[str, Resolver] @no_type_check def __new__(cls, name, bases, namespace, **kwargs): @@ -228,7 +238,7 @@ def _run_root_validator( return handler(values) @classmethod - def from_orm(cls: Type[S], obj: Any, **kw: Any) -> S: + def from_orm(cls: type[S], obj: Any, **kw: Any) -> S: return cls.model_validate(obj, **kw) def dict(self, *a: Any, **kw: Any) -> DictStrAny: diff --git a/ninja/testing/client.py b/ninja/testing/client.py index 0be8ce0a..ab3df140 100644 --- a/ninja/testing/client.py +++ b/ninja/testing/client.py @@ -127,12 +127,10 @@ def _build_request( request.META = request_params.pop("META", {"REMOTE_ADDR": "127.0.0.1"}) request.FILES = request_params.pop("FILES", {}) - request.META.update( - { - f"HTTP_{k.replace('-', '_')}": v - for k, v in request_params.pop("headers", {}).items() - } - ) + request.META.update({ + f"HTTP_{k.replace('-', '_')}": v + for k, v in request_params.pop("headers", {}).items() + }) request.headers = HttpHeaders(request.META)