diff --git a/aiohttp_deps/swagger.py b/aiohttp_deps/swagger.py index 348ae5d..3c0389d 100644 --- a/aiohttp_deps/swagger.py +++ b/aiohttp_deps/swagger.py @@ -85,6 +85,17 @@ def dummy(_var: annotation.annotation) -> None: # type: ignore return var == Optional[var] +def _get_param_schema(annotation: Optional[inspect.Parameter]) -> Dict[str, Any]: + if annotation is None or annotation.annotation == annotation.empty: + return {} + + def dummy(_var: annotation.annotation) -> None: # type: ignore + """Dummy function to use for type resolution.""" + + var = get_type_hints(dummy).get("_var") + return pydantic.TypeAdapter(var).json_schema(ref_template=REF_TEMPLATE) + + def _add_route_def( # noqa: C901, WPS210, WPS211 openapi_schema: Dict[str, Any], route: web.ResourceRoute, @@ -140,25 +151,33 @@ def _insert_in_params(data: Dict[str, Any]) -> None: "content": {content_type: {}}, } elif isinstance(dependency.dependency, Query): + schema = _get_param_schema(dependency.signature) + openapi_schema["components"]["schemas"].update(schema.pop("$defs", {})) _insert_in_params( { "name": dependency.dependency.alias or dependency.param_name, "in": "query", "description": dependency.dependency.description, "required": not _is_optional(dependency.signature), + "schema": schema, }, ) elif isinstance(dependency.dependency, Header): name = dependency.dependency.alias or dependency.param_name + schema = _get_param_schema(dependency.signature) + openapi_schema["components"]["schemas"].update(schema.pop("$defs", {})) _insert_in_params( { "name": name.capitalize(), "in": "header", "description": dependency.dependency.description, "required": not _is_optional(dependency.signature), + "schema": schema, }, ) elif isinstance(dependency.dependency, Path): + schema = _get_param_schema(dependency.signature) + openapi_schema["components"]["schemas"].update(schema.pop("$defs", {})) _insert_in_params( { "name": dependency.dependency.alias or dependency.param_name, @@ -166,6 +185,7 @@ def _insert_in_params(data: Dict[str, Any]) -> None: "description": dependency.dependency.description, "required": not _is_optional(dependency.signature), "allowEmptyValue": _is_optional(dependency.signature), + "schema": schema, }, ) diff --git a/pyproject.toml b/pyproject.toml index a2eae69..f3866cc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,7 +3,7 @@ name = "aiohttp-deps" description = "Dependency injection for AioHTTP" authors = ["Taskiq team "] maintainers = ["Taskiq team "] -version = "0.3.1" +version = "0.3.2" readme = "README.md" license = "LICENSE" classifiers = [ diff --git a/tests/test_swagger.py b/tests/test_swagger.py index 2715253..4da3f68 100644 --- a/tests/test_swagger.py +++ b/tests/test_swagger.py @@ -215,6 +215,7 @@ async def my_handler(my_var: int = Depends(Query(description="desc"))): "required": True, "in": "query", "description": "desc", + "schema": {"type": "integer"}, } @@ -242,6 +243,7 @@ async def my_handler(my_var: Optional[int] = Depends(Query())): "required": False, "in": "query", "description": "", + "schema": {"anyOf": [{"type": "integer"}, {"type": "null"}]}, } @@ -269,6 +271,7 @@ async def my_handler(my_var: int = Depends(Query(alias="qqq"))): "required": True, "in": "query", "description": "", + "schema": {"type": "integer"}, } @@ -278,7 +281,13 @@ async def my_handler(my_var: int = Depends(Query(alias="qqq"))): ( ( Query(), - {"name": "my_var", "required": True, "in": "query", "description": ""}, + { + "name": "my_var", + "required": True, + "in": "query", + "description": "", + "schema": {"type": "integer"}, + }, ), ( Query(description="my query"), @@ -287,15 +296,28 @@ async def my_handler(my_var: int = Depends(Query(alias="qqq"))): "required": True, "in": "query", "description": "my query", + "schema": {"type": "integer"}, }, ), ( Query(alias="a"), - {"name": "a", "required": True, "in": "query", "description": ""}, + { + "name": "a", + "required": True, + "in": "query", + "description": "", + "schema": {"type": "integer"}, + }, ), ( Header(), - {"name": "My_var", "required": True, "in": "header", "description": ""}, + { + "name": "My_var", + "required": True, + "in": "header", + "description": "", + "schema": {"type": "integer"}, + }, ), ( Header(description="my header"), @@ -304,11 +326,18 @@ async def my_handler(my_var: int = Depends(Query(alias="qqq"))): "required": True, "in": "header", "description": "my header", + "schema": {"type": "integer"}, }, ), ( Header(alias="a"), - {"name": "A", "required": True, "in": "header", "description": ""}, + { + "name": "A", + "required": True, + "in": "header", + "description": "", + "schema": {"type": "integer"}, + }, ), ( Path(), @@ -318,6 +347,7 @@ async def my_handler(my_var: int = Depends(Query(alias="qqq"))): "in": "path", "description": "", "allowEmptyValue": False, + "schema": {"type": "integer"}, }, ), ( @@ -328,6 +358,7 @@ async def my_handler(my_var: int = Depends(Query(alias="qqq"))): "in": "path", "description": "my path", "allowEmptyValue": False, + "schema": {"type": "integer"}, }, ), ( @@ -338,6 +369,7 @@ async def my_handler(my_var: int = Depends(Query(alias="qqq"))): "in": "path", "description": "", "allowEmptyValue": False, + "schema": {"type": "integer"}, }, ), ), @@ -364,6 +396,65 @@ async def my_handler(my_var: int = Depends(dependecy)): assert handler_info["parameters"][0] == param_info +@pytest.mark.anyio +@pytest.mark.parametrize( + ["dependecy", "param_info"], + ( + ( + Query(), + { + "name": "my_var", + "required": False, + "in": "query", + "description": "", + "schema": {}, + }, + ), + ( + Header(), + { + "name": "My_var", + "required": False, + "in": "header", + "description": "", + "schema": {}, + }, + ), + ( + Path(), + { + "name": "my_var", + "required": False, + "in": "path", + "description": "", + "allowEmptyValue": True, + "schema": {}, + }, + ), + ), +) +async def test_parameters_untyped( + my_app: web.Application, + aiohttp_client: ClientGenerator, + dependecy: Any, + param_info: Dict[str, Any], +): + OPENAPI_URL = "/my_api_def.json" + my_app.on_startup.append(setup_swagger(schema_url=OPENAPI_URL)) + + async def my_handler(my_var=Depends(dependecy)): + """Nothing.""" + + my_app.router.add_get("/a", my_handler) + + client = await aiohttp_client(my_app) + resp = await client.get(OPENAPI_URL) + assert resp.status == 200 + resp_json = await resp.json() + handler_info = resp_json["paths"]["/a"]["get"] + assert handler_info["parameters"][0] == param_info + + @pytest.mark.anyio async def test_view_success( my_app: web.Application,