Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix issue for extra fields not appearing in the errors #845

Closed
wants to merge 13 commits into from
6 changes: 3 additions & 3 deletions ninja/operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,8 +204,8 @@ def _result_to_response(
return temporal_response

resp_object = ResponseObject(result)
# ^ we need object because getter_dict seems work only with from_orm
result = response_model.from_orm(resp_object).model_dump(
# ^ we need object because getter_dict seems work only with model_validate
result = response_model.model_validate(resp_object).model_dump(
by_alias=self.by_alias,
exclude_unset=self.exclude_unset,
exclude_defaults=self.exclude_defaults,
Expand Down Expand Up @@ -419,7 +419,7 @@ def _not_allowed(self) -> HttpResponse:


class ResponseObject:
"Basically this is just a helper to be able to pass response to pydantic's from_orm"
"Basically this is just a helper to be able to pass response to pydantic's model_validate"

def __init__(self, response: HttpResponse) -> None:
self.response = response
41 changes: 35 additions & 6 deletions ninja/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,15 @@ def resolve_initials(self, obj):

"""
import warnings
from typing import Any, Callable, Dict, Type, TypeVar, Union, no_type_check
from typing import Any, Callable, Dict, Optional, Type, TypeVar, Union, no_type_check

import pydantic
from django.db.models import Manager, QuerySet
from django.db.models.fields.files import FieldFile
from django.template import Variable, VariableDoesNotExist
from pydantic import BaseModel, Field, ValidationInfo, model_validator, validator
from pydantic._internal._model_construction import ModelMetaclass
from pydantic.functional_validators import ModelWrapValidatorHandler
from pydantic.json_schema import GenerateJsonSchema, JsonSchemaValue

from ninja.signature.utils import get_args_names, has_kwargs
Expand All @@ -45,7 +46,7 @@ def resolve_initials(self, obj):
class DjangoGetter:
__slots__ = ("_obj", "_schema_cls", "_context")

def __init__(self, obj: Any, schema_cls: "Schema", context: Any = None):
def __init__(self, obj: Any, schema_cls: Type[S], context: Any = None):
self._obj = obj
self._schema_cls = schema_cls
self._context = context
Expand All @@ -54,7 +55,7 @@ def __getattr__(self, key: str) -> Any:
# if key.startswith("__pydantic"):
# return getattr(self._obj, key)

resolver = self._schema_cls._ninja_resolvers.get(key) # type: ignore
resolver = self._schema_cls._ninja_resolvers.get(key)
if resolver:
value = resolver(getter=self)
else:
Expand Down Expand Up @@ -198,15 +199,43 @@ class Schema(BaseModel, metaclass=ResolverMetaclass):
class Config:
from_attributes = True # aka orm_mode

@model_validator(mode="before")
def _run_root_validator(cls, values: Any, info: ValidationInfo) -> Any:
@model_validator(mode="wrap")
@classmethod
def _run_root_validator(
cls, values: Any, handler: ModelWrapValidatorHandler[S], info: ValidationInfo
) -> S:
# We dont perform 'before' validations if an validating through 'model_validate'
through_model_validate = (
info and info.context and info.context.get("through_model_validate", False)
)
if not through_model_validate:
handler(values)

# We add our DjangoGetter for the Schema
values = DjangoGetter(values, cls, info.context)
return values

# To update the schema with our DjangoGetter
return handler(values)

@classmethod
def from_orm(cls: Type[S], obj: Any) -> S:
return cls.model_validate(obj)

@classmethod
def model_validate(
cls: Type[S],
obj: Any,
*,
strict: Optional[bool] = None,
from_attributes: Optional[bool] = None,
context: Optional[Dict[str, Any]] = None,
) -> S:
context = context or {}
context["through_model_validate"] = True
return super().model_validate(
obj, strict=strict, from_attributes=from_attributes, context=context
)

def dict(self, *a: Any, **kw: Any) -> DictStrAny:
"Backward compatibility with pydantic 1.x"
return self.model_dump(*a, **kw)
Expand Down
76 changes: 75 additions & 1 deletion tests/test_request.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,23 @@
from typing import Optional

import pytest
from pydantic import ConfigDict

from ninja import Cookie, Header, Router
from ninja import Body, Cookie, Header, Router, Schema
from ninja.testing import TestClient


class OptionalEmptySchema(Schema):
model_config = ConfigDict(extra="forbid")
name: Optional[str] = None


class ExtraForbidSchema(Schema):
model_config = ConfigDict(extra="forbid")
name: str
metadata: Optional[OptionalEmptySchema] = None


router = Router()


Expand Down Expand Up @@ -41,6 +56,11 @@ def cookies2(request, wpn: str = Cookie(..., alias="weapon")):
return wpn


@router.post("/test-schema")
def test_schema(request, payload: ExtraForbidSchema = Body(...)):
return "ok"


client = TestClient(router)


Expand Down Expand Up @@ -77,3 +97,57 @@ def test_headers(path, expected_status, expected_response):
assert response.status_code == expected_status, response.content
print(response.json())
assert response.json() == expected_response


@pytest.mark.parametrize(
"path,json,expected_status,expected_response",
[
(
"/test-schema",
{"name": "test", "extra_name": "test2"},
422,
{
"detail": [
{
"type": "extra_forbidden",
"loc": ["body", "payload", "extra_name"],
"msg": "Extra inputs are not permitted",
}
]
},
),
(
"/test-schema",
{"name": "test", "metadata": {"extra_name": "xxx"}},
422,
{
"detail": [
{
"loc": ["body", "payload", "metadata", "extra_name"],
"msg": "Extra inputs are not permitted",
"type": "extra_forbidden",
}
]
},
),
(
"/test-schema",
{"name": "test", "metadata": "test2"},
422,
{
"detail": [
{
"type": "model_attributes_type",
"loc": ["body", "payload", "metadata"],
"msg": "Input should be a valid dictionary or object to extract fields from",
}
]
},
),
],
)
def test_pydantic_config(path, json, expected_status, expected_response):
# test extra forbid
response = client.post(path, json=json)
assert response.json() == expected_response
assert response.status_code == expected_status