Skip to content

Commit

Permalink
PatchDict utility
Browse files Browse the repository at this point in the history
  • Loading branch information
vitalik committed Aug 7, 2024
1 parent 3c55d99 commit 57e7bc0
Show file tree
Hide file tree
Showing 4 changed files with 121 additions and 0 deletions.
2 changes: 2 additions & 0 deletions ninja/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
Query,
QueryEx,
)
from ninja.patch_dict import PatchDict
from ninja.router import Router
from ninja.schema import Schema

Expand Down Expand Up @@ -55,4 +56,5 @@
"FilterSchema",
"Swagger",
"Redoc",
"PatchDict",
]
52 changes: 52 additions & 0 deletions ninja/patch_dict.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
from typing import TYPE_CHECKING, Any, Dict, Optional, Type

from pydantic_core import core_schema
from typing_extensions import Annotated

from ninja import Body
from ninja.utils import is_optional_type


class ModelToDict(dict):
_wrapped_model: Any = None
_wrapped_model_dump_params: Dict[str, Any] = {}

@classmethod
def __get_pydantic_core_schema__(cls, _source: Any, _handler: Any) -> Any:
return core_schema.no_info_after_validator_function(
cls._validate,
cls._wrapped_model.__pydantic_core_schema__,
)

@classmethod
def _validate(cls, input_value: Any) -> Any:
return input_value.model_dump(**cls._wrapped_model_dump_params)


def create_patch_schema(schema_cls: Type[Any]) -> Type[ModelToDict]:
values, annotations = {}, {}
for f in schema_cls.__fields__.keys():
t = schema_cls.__annotations__[f]
if not is_optional_type(t):
values[f] = getattr(schema_cls, f, None)
annotations[f] = Optional[t]
values["__annotations__"] = annotations
OptionalSchema = type(f"{schema_cls.__name__}Patch", (schema_cls,), values)

class OptionalDictSchema(ModelToDict):
_wrapped_model = OptionalSchema
_wrapped_model_dump_params = {"exclude_unset": True}

return OptionalDictSchema


class PatchDictUtil:
def __getitem__(self, schema_cls: Any) -> Any:
new_cls = create_patch_schema(schema_cls)
return Body[new_cls] # type: ignore


if TYPE_CHECKING: # pragma: nocover
PatchDict = Annotated[dict, "<PatchDict>"]
else:
PatchDict = PatchDictUtil()
7 changes: 7 additions & 0 deletions ninja/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,13 @@ def is_async_callable(f: Callable[..., Any]) -> bool:
)


def is_optional_type(t: Type[Any]) -> bool:
try:
return type(None) in t.__args__
except AttributeError:
return False


def contribute_operation_callback(
func: Callable[..., Any], callback: Callable[..., Any]
) -> None:
Expand Down
60 changes: 60 additions & 0 deletions tests/test_patch_dict.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
from typing import Optional

import pytest

from ninja import NinjaAPI, Schema
from ninja.patch_dict import PatchDict
from ninja.testing import TestClient

api = NinjaAPI()

client = TestClient(api)


class SomeSchema(Schema):
name: str
age: int
category: Optional[str] = None


@api.patch("/patch")
def patch(request, payload: PatchDict[SomeSchema]):
return {"payload": payload, "type": str(type(payload))}


@pytest.mark.parametrize(
"input,output",
[
({"name": "foo"}, {"name": "foo"}),
({"age": "1"}, {"age": 1}),
({}, {}),
({"wrong_param": 1}, {}),
({"age": None}, {"age": None}),
],
)
def test_patch_calls(input: dict, output: dict):
response = client.patch("/patch", json=input)
assert response.json() == {"payload": output, "type": "<class 'dict'>"}


def test_schema():
"Checking that json schema properties are all optional"
schema = api.get_openapi_schema()
assert schema["components"]["schemas"]["SomeSchemaPatch"] == {
"title": "SomeSchemaPatch",
"type": "object",
"properties": {
"name": {
"anyOf": [{"type": "string"}, {"type": "null"}],
"title": "Name",
},
"age": {
"anyOf": [{"type": "integer"}, {"type": "null"}],
"title": "Age",
},
"category": {
"anyOf": [{"type": "string"}, {"type": "null"}],
"title": "Category",
},
},
}

0 comments on commit 57e7bc0

Please sign in to comment.