Skip to content

Commit

Permalink
Annotated syntax support
Browse files Browse the repository at this point in the history
  • Loading branch information
vitalik committed Sep 9, 2023
1 parent 9f9b567 commit 417e637
Show file tree
Hide file tree
Showing 4 changed files with 242 additions and 30 deletions.
7 changes: 6 additions & 1 deletion WHATSNEW_V1.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,14 @@
- CSRF changes
- Auth async support
- Schema.Meta
- pagination request in paginate_queryset
- pagination: request in paginate_queryset
- decorators
- openapi docs plugable
- add_router supports strings

TODO:
- async pagination

Backwards incompatible stuff
- resolve_xxx(self, ...)
- pydantic v1
14 changes: 7 additions & 7 deletions ninja/params_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@


def Path( # noqa: N802
default: Any,
default: Any = ...,
*,
alias: Optional[str] = None,
title: Optional[str] = None,
Expand Down Expand Up @@ -48,7 +48,7 @@ def Path( # noqa: N802


def Query( # noqa: N802
default: Any,
default: Any = ...,
*,
alias: Optional[str] = None,
title: Optional[str] = None,
Expand Down Expand Up @@ -87,7 +87,7 @@ def Query( # noqa: N802


def Header( # noqa: N802
default: Any,
default: Any = ...,
*,
alias: Optional[str] = None,
title: Optional[str] = None,
Expand Down Expand Up @@ -126,7 +126,7 @@ def Header( # noqa: N802


def Cookie( # noqa: N802
default: Any,
default: Any = ...,
*,
alias: Optional[str] = None,
title: Optional[str] = None,
Expand Down Expand Up @@ -165,7 +165,7 @@ def Cookie( # noqa: N802


def Body( # noqa: N802
default: Any,
default: Any = ...,
*,
alias: Optional[str] = None,
title: Optional[str] = None,
Expand Down Expand Up @@ -204,7 +204,7 @@ def Body( # noqa: N802


def Form( # noqa: N802
default: Any,
default: Any = ...,
*,
alias: Optional[str] = None,
title: Optional[str] = None,
Expand Down Expand Up @@ -243,7 +243,7 @@ def Form( # noqa: N802


def File( # noqa: N802
default: Any,
default: Any = ...,
*,
alias: Optional[str] = None,
title: Optional[str] = None,
Expand Down
48 changes: 26 additions & 22 deletions ninja/signature/details.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,10 @@
from django.http import HttpResponse
from pydantic.fields import FieldInfo
from pydantic_core import PydanticUndefined
from typing_extensions import Annotated

from ninja import UploadedFile, params
from ninja.compatibility.util import (
UNION_TYPES,
get_args,
)
from ninja.compatibility.util import (
get_origin as get_collection_origin,
)
from ninja.compatibility.util import UNION_TYPES, get_args, get_origin
from ninja.errors import ConfigError
from ninja.params import Body, File, Form, _MultiPartBody
from ninja.params_models import TModel, TModels
Expand Down Expand Up @@ -205,15 +200,24 @@ def _model_flatten_map(self, model: TModel, prefix: str) -> Generator:
def _get_param_type(self, name: str, arg: inspect.Parameter) -> FuncParam:
# _EMPTY = self.signature.empty
annotation = arg.annotation
default = arg.default

if get_origin(annotation) is Annotated:
args = get_args(annotation)
if isinstance(args[1], params.Param):
prev_default = default
annotation, default = args
if prev_default != self.signature.empty:
default.default = prev_default

if annotation == self.signature.empty:
if arg.default == self.signature.empty:
if default == self.signature.empty:
annotation = str
else:
if isinstance(arg.default, params.Param):
annotation = type(arg.default.default)
if isinstance(default, params.Param):
annotation = type(default.default)
else:
annotation = type(arg.default)
annotation = type(default)

if annotation == PydanticUndefined.__class__:
# TODO: ^ check why is that so
Expand All @@ -228,34 +232,34 @@ def _get_param_type(self, name: str, arg: inspect.Parameter) -> FuncParam:
is_collection and annotation.__args__[0] == UploadedFile
):
# People often forgot to mark UploadedFile as a File, so we better assign it automatically
if arg.default == self.signature.empty or arg.default is None:
default = arg.default == self.signature.empty and ... or arg.default
if default == self.signature.empty or default is None:
default = default == self.signature.empty and ... or default
return FuncParam(name, name, File(default), annotation, is_collection)

# 1) if type of the param is defined as one of the Param's subclasses - we just use that definition
if isinstance(arg.default, params.Param):
param_source = arg.default
if isinstance(default, params.Param):
param_source = default

# 2) if param name is a part of the path parameter
elif name in self.path_params_names:
assert (
arg.default == self.signature.empty
default == self.signature.empty
), f"'{name}' is a path param, default not allowed"
param_source = params.Path(...)

# 3) if param is a collection, or annotation is part of pydantic model:
elif is_collection or is_pydantic_model(annotation):
if arg.default == self.signature.empty:
if default == self.signature.empty:
param_source = params.Body(...)
else:
param_source = params.Body(arg.default)
param_source = params.Body(default)

# 4) the last case is query param
else:
if arg.default == self.signature.empty:
if default == self.signature.empty:
param_source = params.Query(...)
else:
param_source = params.Query(arg.default)
param_source = params.Query(default)

return FuncParam(
name, param_source.alias or name, param_source, annotation, is_collection
Expand All @@ -264,15 +268,15 @@ def _get_param_type(self, name: str, arg: inspect.Parameter) -> FuncParam:

def is_pydantic_model(cls: Any) -> bool:
try:
if get_collection_origin(cls) in UNION_TYPES:
if get_origin(cls) in UNION_TYPES:
return any(issubclass(arg, pydantic.BaseModel) for arg in get_args(cls))
return issubclass(cls, pydantic.BaseModel)
except TypeError:
return False


def is_collection_type(annotation: Any) -> bool:
origin = get_collection_origin(annotation)
origin = get_origin(annotation)
collection_types = (List, list, set, tuple)
if origin is None:
return (
Expand Down
203 changes: 203 additions & 0 deletions tests/test_annotated.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,203 @@
from typing_extensions import Annotated
from ninja import NinjaAPI, Schema, Form, Query, Path, Header, Body, Cookie
from ninja.testing import TestClient


api = NinjaAPI()


class FormData(Schema):
x: int
y: float


class Payload(Schema):
t: int
p: str


@api.post("/multi/{p}")
def multi_op(
request,
q: Annotated[str, Query(description="Query param")],
p: Annotated[int, Path(description="Path param")],
f: Annotated[FormData, Form(description="Form params")],
c: Annotated[str, Cookie(description="Cookie params")],
):
return {"q": q, "p": p, "f": f.dict(), "c": c}


@api.post("/query_list")
def query_list(
request,
q: Annotated[list[str], Query(description="User ID")],
):
return {"q": q}


@api.post("/headers")
def headers(request, h: Annotated[str, Header()] = "some-default"):
return {"h": h}


@api.post("/body")
def body_op(
request, payload: Annotated[Payload, Body(examples=[{"t": 42, "p": "test"}])]
):
return {"payload": payload}


client = TestClient(api)


def test_multi_op():
response = client.post("/multi/42?q=1", data={"x": 1, "y": 2}, COOKIES={"c": "3"})
assert response.status_code == 200, response.content
assert response.json() == {
"q": "1",
"p": 42,
"f": {"x": 1, "y": 2.0},
"c": "3",
}


def test_query_list():
response = client.post("/query_list?q=1&q=2")
assert response.status_code == 200, response.content
assert response.json() == {"q": ["1", "2"]}


def test_body_op():
response = client.post("/body", json={"t": 42, "p": "test"})
assert response.status_code == 200, response.content
assert response.json() == {"payload": {"p": "test", "t": 42}}


def test_headers():
response = client.post("/headers", headers={"h": "test"})
assert response.status_code == 200, response.content
assert response.json() == {"h": "test"}


def test_openapi_schema():
schema = api.get_openapi_schema()["paths"]
print(schema)
assert schema == {
"/api/multi/{p}": {
"post": {
"operationId": "test_annotated_multi_op",
"summary": "Multi Op",
"parameters": [
{
"in": "query",
"name": "q",
"schema": {
"description": "Query param",
"title": "Q",
"type": "string",
},
"required": True,
"description": "Query param",
},
{
"in": "path",
"name": "p",
"schema": {
"description": "Path param",
"title": "P",
"type": "integer",
},
"required": True,
"description": "Path param",
},
{
"in": "cookie",
"name": "c",
"schema": {
"description": "Cookie params",
"title": "C",
"type": "string",
},
"required": True,
"description": "Cookie params",
},
],
"responses": {200: {"description": "OK"}},
"requestBody": {
"content": {
"application/x-www-form-urlencoded": {
"schema": {
"title": "FormParams",
"type": "object",
"properties": {
"x": {"title": "X", "type": "integer"},
"y": {"title": "Y", "type": "number"},
},
"required": ["x", "y"],
}
}
},
"required": True,
},
}
},
"/api/query_list": {
"post": {
"operationId": "test_annotated_query_list",
"summary": "Query List",
"parameters": [
{
"in": "query",
"name": "q",
"schema": {
"description": "User ID",
"items": {"type": "string"},
"title": "Q",
"type": "array",
},
"required": True,
"description": "User ID",
}
],
"responses": {200: {"description": "OK"}},
}
},
"/api/headers": {
"post": {
"operationId": "test_annotated_headers",
"summary": "Headers",
"parameters": [
{
"in": "header",
"name": "h",
"schema": {
"default": "some-default",
"title": "H",
"type": "string",
},
"required": False,
}
],
"responses": {200: {"description": "OK"}},
}
},
"/api/body": {
"post": {
"operationId": "test_annotated_body_op",
"summary": "Body Op",
"parameters": [],
"responses": {200: {"description": "OK"}},
"requestBody": {
"content": {
"application/json": {
"schema": {
"allOf": [{"$ref": "#/components/schemas/Payload"}],
"examples": [{"p": "test", "t": 42}],
}
}
},
"required": True,
},
}
},
}

0 comments on commit 417e637

Please sign in to comment.