From 417e63761dfca718317c765266564ec79da30399 Mon Sep 17 00:00:00 2001 From: Vitaliy Kucheryaviy Date: Sat, 9 Sep 2023 16:11:23 +0300 Subject: [PATCH] Annotated syntax support --- WHATSNEW_V1.md | 7 +- ninja/params_functions.py | 14 +-- ninja/signature/details.py | 48 +++++---- tests/test_annotated.py | 203 +++++++++++++++++++++++++++++++++++++ 4 files changed, 242 insertions(+), 30 deletions(-) create mode 100644 tests/test_annotated.py diff --git a/WHATSNEW_V1.md b/WHATSNEW_V1.md index 1a771978a..d46239470 100644 --- a/WHATSNEW_V1.md +++ b/WHATSNEW_V1.md @@ -3,9 +3,14 @@ - CSRF changes - Auth async support - Schema.Meta - - pagination request in paginate_queryset + - pagination: request in paginate_queryset + - decorators - openapi docs plugable + - add_router supports strings TODO: - async pagination +Backwards incompatible stuff + - resolve_xxx(self, ...) + - pydantic v1 \ No newline at end of file diff --git a/ninja/params_functions.py b/ninja/params_functions.py index e54777b69..aebe03122 100644 --- a/ninja/params_functions.py +++ b/ninja/params_functions.py @@ -9,7 +9,7 @@ def Path( # noqa: N802 - default: Any, + default: Any = ..., *, alias: Optional[str] = None, title: Optional[str] = None, @@ -48,7 +48,7 @@ def Path( # noqa: N802 def Query( # noqa: N802 - default: Any, + default: Any = ..., *, alias: Optional[str] = None, title: Optional[str] = None, @@ -87,7 +87,7 @@ def Query( # noqa: N802 def Header( # noqa: N802 - default: Any, + default: Any = ..., *, alias: Optional[str] = None, title: Optional[str] = None, @@ -126,7 +126,7 @@ def Header( # noqa: N802 def Cookie( # noqa: N802 - default: Any, + default: Any = ..., *, alias: Optional[str] = None, title: Optional[str] = None, @@ -165,7 +165,7 @@ def Cookie( # noqa: N802 def Body( # noqa: N802 - default: Any, + default: Any = ..., *, alias: Optional[str] = None, title: Optional[str] = None, @@ -204,7 +204,7 @@ def Body( # noqa: N802 def Form( # noqa: N802 - default: Any, + default: Any = ..., *, alias: Optional[str] = None, title: Optional[str] = None, @@ -243,7 +243,7 @@ def Form( # noqa: N802 def File( # noqa: N802 - default: Any, + default: Any = ..., *, alias: Optional[str] = None, title: Optional[str] = None, diff --git a/ninja/signature/details.py b/ninja/signature/details.py index 7ac979e5f..5b1188288 100644 --- a/ninja/signature/details.py +++ b/ninja/signature/details.py @@ -7,15 +7,10 @@ from django.http import HttpResponse from pydantic.fields import FieldInfo from pydantic_core import PydanticUndefined +from typing_extensions import Annotated from ninja import UploadedFile, params -from ninja.compatibility.util import ( - UNION_TYPES, - get_args, -) -from ninja.compatibility.util import ( - get_origin as get_collection_origin, -) +from ninja.compatibility.util import UNION_TYPES, get_args, get_origin from ninja.errors import ConfigError from ninja.params import Body, File, Form, _MultiPartBody from ninja.params_models import TModel, TModels @@ -205,15 +200,24 @@ def _model_flatten_map(self, model: TModel, prefix: str) -> Generator: def _get_param_type(self, name: str, arg: inspect.Parameter) -> FuncParam: # _EMPTY = self.signature.empty annotation = arg.annotation + default = arg.default + + if get_origin(annotation) is Annotated: + args = get_args(annotation) + if isinstance(args[1], params.Param): + prev_default = default + annotation, default = args + if prev_default != self.signature.empty: + default.default = prev_default if annotation == self.signature.empty: - if arg.default == self.signature.empty: + if default == self.signature.empty: annotation = str else: - if isinstance(arg.default, params.Param): - annotation = type(arg.default.default) + if isinstance(default, params.Param): + annotation = type(default.default) else: - annotation = type(arg.default) + annotation = type(default) if annotation == PydanticUndefined.__class__: # TODO: ^ check why is that so @@ -228,34 +232,34 @@ def _get_param_type(self, name: str, arg: inspect.Parameter) -> FuncParam: is_collection and annotation.__args__[0] == UploadedFile ): # People often forgot to mark UploadedFile as a File, so we better assign it automatically - if arg.default == self.signature.empty or arg.default is None: - default = arg.default == self.signature.empty and ... or arg.default + if default == self.signature.empty or default is None: + default = default == self.signature.empty and ... or default return FuncParam(name, name, File(default), annotation, is_collection) # 1) if type of the param is defined as one of the Param's subclasses - we just use that definition - if isinstance(arg.default, params.Param): - param_source = arg.default + if isinstance(default, params.Param): + param_source = default # 2) if param name is a part of the path parameter elif name in self.path_params_names: assert ( - arg.default == self.signature.empty + default == self.signature.empty ), f"'{name}' is a path param, default not allowed" param_source = params.Path(...) # 3) if param is a collection, or annotation is part of pydantic model: elif is_collection or is_pydantic_model(annotation): - if arg.default == self.signature.empty: + if default == self.signature.empty: param_source = params.Body(...) else: - param_source = params.Body(arg.default) + param_source = params.Body(default) # 4) the last case is query param else: - if arg.default == self.signature.empty: + if default == self.signature.empty: param_source = params.Query(...) else: - param_source = params.Query(arg.default) + param_source = params.Query(default) return FuncParam( name, param_source.alias or name, param_source, annotation, is_collection @@ -264,7 +268,7 @@ def _get_param_type(self, name: str, arg: inspect.Parameter) -> FuncParam: def is_pydantic_model(cls: Any) -> bool: try: - if get_collection_origin(cls) in UNION_TYPES: + if get_origin(cls) in UNION_TYPES: return any(issubclass(arg, pydantic.BaseModel) for arg in get_args(cls)) return issubclass(cls, pydantic.BaseModel) except TypeError: @@ -272,7 +276,7 @@ def is_pydantic_model(cls: Any) -> bool: def is_collection_type(annotation: Any) -> bool: - origin = get_collection_origin(annotation) + origin = get_origin(annotation) collection_types = (List, list, set, tuple) if origin is None: return ( diff --git a/tests/test_annotated.py b/tests/test_annotated.py new file mode 100644 index 000000000..3c8ab9404 --- /dev/null +++ b/tests/test_annotated.py @@ -0,0 +1,203 @@ +from typing_extensions import Annotated +from ninja import NinjaAPI, Schema, Form, Query, Path, Header, Body, Cookie +from ninja.testing import TestClient + + +api = NinjaAPI() + + +class FormData(Schema): + x: int + y: float + + +class Payload(Schema): + t: int + p: str + + +@api.post("/multi/{p}") +def multi_op( + request, + q: Annotated[str, Query(description="Query param")], + p: Annotated[int, Path(description="Path param")], + f: Annotated[FormData, Form(description="Form params")], + c: Annotated[str, Cookie(description="Cookie params")], +): + return {"q": q, "p": p, "f": f.dict(), "c": c} + + +@api.post("/query_list") +def query_list( + request, + q: Annotated[list[str], Query(description="User ID")], +): + return {"q": q} + + +@api.post("/headers") +def headers(request, h: Annotated[str, Header()] = "some-default"): + return {"h": h} + + +@api.post("/body") +def body_op( + request, payload: Annotated[Payload, Body(examples=[{"t": 42, "p": "test"}])] +): + return {"payload": payload} + + +client = TestClient(api) + + +def test_multi_op(): + response = client.post("/multi/42?q=1", data={"x": 1, "y": 2}, COOKIES={"c": "3"}) + assert response.status_code == 200, response.content + assert response.json() == { + "q": "1", + "p": 42, + "f": {"x": 1, "y": 2.0}, + "c": "3", + } + + +def test_query_list(): + response = client.post("/query_list?q=1&q=2") + assert response.status_code == 200, response.content + assert response.json() == {"q": ["1", "2"]} + + +def test_body_op(): + response = client.post("/body", json={"t": 42, "p": "test"}) + assert response.status_code == 200, response.content + assert response.json() == {"payload": {"p": "test", "t": 42}} + + +def test_headers(): + response = client.post("/headers", headers={"h": "test"}) + assert response.status_code == 200, response.content + assert response.json() == {"h": "test"} + + +def test_openapi_schema(): + schema = api.get_openapi_schema()["paths"] + print(schema) + assert schema == { + "/api/multi/{p}": { + "post": { + "operationId": "test_annotated_multi_op", + "summary": "Multi Op", + "parameters": [ + { + "in": "query", + "name": "q", + "schema": { + "description": "Query param", + "title": "Q", + "type": "string", + }, + "required": True, + "description": "Query param", + }, + { + "in": "path", + "name": "p", + "schema": { + "description": "Path param", + "title": "P", + "type": "integer", + }, + "required": True, + "description": "Path param", + }, + { + "in": "cookie", + "name": "c", + "schema": { + "description": "Cookie params", + "title": "C", + "type": "string", + }, + "required": True, + "description": "Cookie params", + }, + ], + "responses": {200: {"description": "OK"}}, + "requestBody": { + "content": { + "application/x-www-form-urlencoded": { + "schema": { + "title": "FormParams", + "type": "object", + "properties": { + "x": {"title": "X", "type": "integer"}, + "y": {"title": "Y", "type": "number"}, + }, + "required": ["x", "y"], + } + } + }, + "required": True, + }, + } + }, + "/api/query_list": { + "post": { + "operationId": "test_annotated_query_list", + "summary": "Query List", + "parameters": [ + { + "in": "query", + "name": "q", + "schema": { + "description": "User ID", + "items": {"type": "string"}, + "title": "Q", + "type": "array", + }, + "required": True, + "description": "User ID", + } + ], + "responses": {200: {"description": "OK"}}, + } + }, + "/api/headers": { + "post": { + "operationId": "test_annotated_headers", + "summary": "Headers", + "parameters": [ + { + "in": "header", + "name": "h", + "schema": { + "default": "some-default", + "title": "H", + "type": "string", + }, + "required": False, + } + ], + "responses": {200: {"description": "OK"}}, + } + }, + "/api/body": { + "post": { + "operationId": "test_annotated_body_op", + "summary": "Body Op", + "parameters": [], + "responses": {200: {"description": "OK"}}, + "requestBody": { + "content": { + "application/json": { + "schema": { + "allOf": [{"$ref": "#/components/schemas/Payload"}], + "examples": [{"p": "test", "t": 42}], + } + } + }, + "required": True, + }, + } + }, + }