diff --git a/ninja/operation.py b/ninja/operation.py index 6ca3087c6..55321b6e0 100644 --- a/ninja/operation.py +++ b/ninja/operation.py @@ -204,8 +204,8 @@ def _result_to_response( return temporal_response resp_object = ResponseObject(result) - # ^ we need object because getter_dict seems work only with from_orm - result = response_model.from_orm(resp_object).model_dump( + # ^ we need object because getter_dict seems work only with model_validate + result = response_model.model_validate(resp_object).model_dump( by_alias=self.by_alias, exclude_unset=self.exclude_unset, exclude_defaults=self.exclude_defaults, @@ -419,7 +419,7 @@ def _not_allowed(self) -> HttpResponse: class ResponseObject: - "Basically this is just a helper to be able to pass response to pydantic's from_orm" + "Basically this is just a helper to be able to pass response to pydantic's model_validate" def __init__(self, response: HttpResponse) -> None: self.response = response diff --git a/ninja/schema.py b/ninja/schema.py index 75735204d..a32d759bc 100644 --- a/ninja/schema.py +++ b/ninja/schema.py @@ -21,7 +21,7 @@ def resolve_initials(self, obj): """ import warnings -from typing import Any, Callable, Dict, Type, TypeVar, Union, no_type_check +from typing import Any, Callable, Dict, Optional, Type, TypeVar, Union, no_type_check import pydantic from django.db.models import Manager, QuerySet @@ -29,6 +29,7 @@ def resolve_initials(self, obj): from django.template import Variable, VariableDoesNotExist from pydantic import BaseModel, Field, ValidationInfo, model_validator, validator from pydantic._internal._model_construction import ModelMetaclass +from pydantic.functional_validators import ModelWrapValidatorHandler from pydantic.json_schema import GenerateJsonSchema, JsonSchemaValue from ninja.signature.utils import get_args_names, has_kwargs @@ -45,7 +46,7 @@ def resolve_initials(self, obj): class DjangoGetter: __slots__ = ("_obj", "_schema_cls", "_context") - def __init__(self, obj: Any, schema_cls: "Schema", context: Any = None): + def __init__(self, obj: Any, schema_cls: Type[S], context: Any = None): self._obj = obj self._schema_cls = schema_cls self._context = context @@ -54,7 +55,7 @@ def __getattr__(self, key: str) -> Any: # if key.startswith("__pydantic"): # return getattr(self._obj, key) - resolver = self._schema_cls._ninja_resolvers.get(key) # type: ignore + resolver = self._schema_cls._ninja_resolvers.get(key) if resolver: value = resolver(getter=self) else: @@ -198,15 +199,43 @@ class Schema(BaseModel, metaclass=ResolverMetaclass): class Config: from_attributes = True # aka orm_mode - @model_validator(mode="before") - def _run_root_validator(cls, values: Any, info: ValidationInfo) -> Any: + @model_validator(mode="wrap") + @classmethod + def _run_root_validator( + cls, values: Any, handler: ModelWrapValidatorHandler[S], info: ValidationInfo + ) -> S: + # We dont perform 'before' validations if an validating through 'model_validate' + through_model_validate = ( + info and info.context and info.context.get("through_model_validate", False) + ) + if not through_model_validate: + handler(values) + + # We add our DjangoGetter for the Schema values = DjangoGetter(values, cls, info.context) - return values + + # To update the schema with our DjangoGetter + return handler(values) @classmethod def from_orm(cls: Type[S], obj: Any) -> S: return cls.model_validate(obj) + @classmethod + def model_validate( + cls: Type[S], + obj: Any, + *, + strict: Optional[bool] = None, + from_attributes: Optional[bool] = None, + context: Optional[Dict[str, Any]] = None, + ) -> S: + context = context or {} + context["through_model_validate"] = True + return super().model_validate( + obj, strict=strict, from_attributes=from_attributes, context=context + ) + def dict(self, *a: Any, **kw: Any) -> DictStrAny: "Backward compatibility with pydantic 1.x" return self.model_dump(*a, **kw) diff --git a/tests/test_request.py b/tests/test_request.py index d7e4dfbe7..40d5dd5f3 100644 --- a/tests/test_request.py +++ b/tests/test_request.py @@ -1,8 +1,23 @@ +from typing import Optional + import pytest +from pydantic import ConfigDict -from ninja import Cookie, Header, Router +from ninja import Body, Cookie, Header, Router, Schema from ninja.testing import TestClient + +class OptionalEmptySchema(Schema): + model_config = ConfigDict(extra="forbid") + name: Optional[str] = None + + +class ExtraForbidSchema(Schema): + model_config = ConfigDict(extra="forbid") + name: str + metadata: Optional[OptionalEmptySchema] = None + + router = Router() @@ -41,6 +56,11 @@ def cookies2(request, wpn: str = Cookie(..., alias="weapon")): return wpn +@router.post("/test-schema") +def test_schema(request, payload: ExtraForbidSchema = Body(...)): + return "ok" + + client = TestClient(router) @@ -77,3 +97,57 @@ def test_headers(path, expected_status, expected_response): assert response.status_code == expected_status, response.content print(response.json()) assert response.json() == expected_response + + +@pytest.mark.parametrize( + "path,json,expected_status,expected_response", + [ + ( + "/test-schema", + {"name": "test", "extra_name": "test2"}, + 422, + { + "detail": [ + { + "type": "extra_forbidden", + "loc": ["body", "payload", "extra_name"], + "msg": "Extra inputs are not permitted", + } + ] + }, + ), + ( + "/test-schema", + {"name": "test", "metadata": {"extra_name": "xxx"}}, + 422, + { + "detail": [ + { + "loc": ["body", "payload", "metadata", "extra_name"], + "msg": "Extra inputs are not permitted", + "type": "extra_forbidden", + } + ] + }, + ), + ( + "/test-schema", + {"name": "test", "metadata": "test2"}, + 422, + { + "detail": [ + { + "type": "model_attributes_type", + "loc": ["body", "payload", "metadata"], + "msg": "Input should be a valid dictionary or object to extract fields from", + } + ] + }, + ), + ], +) +def test_pydantic_config(path, json, expected_status, expected_response): + # test extra forbid + response = client.post(path, json=json) + assert response.json() == expected_response + assert response.status_code == expected_status