Skip to content

Commit

Permalink
Merge pull request #23 from taskiq-python/feature/deps-extra-swagger
Browse files Browse the repository at this point in the history
  • Loading branch information
s3rius authored Sep 30, 2024
2 parents b2e6742 + 33c7650 commit b1af5ac
Show file tree
Hide file tree
Showing 6 changed files with 218 additions and 14 deletions.
10 changes: 9 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,15 @@ async def my_handler(var: str = Depends(Path())):
```


## Overridiing dependencies
## ExtraOpenAPI

This dependency is used to add additional swagger fields to the endpoint's swagger
that is using this dependency. It might be even indirect dependency.

You can check how this thing can be used in our [examples/swagger_auth.py](https://github.com/taskiq-python/aiohttp-deps/tree/master/examples/swagger_auth.py).


## Overriding dependencies

Sometimes for tests you don't want to calculate actual functions
and you want to pass another functions instead.
Expand Down
3 changes: 2 additions & 1 deletion aiohttp_deps/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from aiohttp_deps.keys import DEPENDENCY_OVERRIDES_KEY, VALUES_OVERRIDES_KEY
from aiohttp_deps.router import Router
from aiohttp_deps.swagger import extra_openapi, openapi_response, setup_swagger
from aiohttp_deps.utils import Form, Header, Json, Path, Query
from aiohttp_deps.utils import ExtraOpenAPI, Form, Header, Json, Path, Query
from aiohttp_deps.view import View

__all__ = [
Expand All @@ -21,6 +21,7 @@
"Query",
"Form",
"Path",
"ExtraOpenAPI",
"openapi_response",
"DEPENDENCY_OVERRIDES_KEY",
"VALUES_OVERRIDES_KEY",
Expand Down
38 changes: 27 additions & 11 deletions aiohttp_deps/swagger.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
Awaitable,
Callable,
Dict,
List,
Optional,
Tuple,
TypeVar,
Expand All @@ -19,7 +20,7 @@

from aiohttp_deps.initializer import InjectableFuncHandler, InjectableViewHandler
from aiohttp_deps.keys import SWAGGER_SCHEMA_KEY
from aiohttp_deps.utils import Form, Header, Json, Path, Query
from aiohttp_deps.utils import ExtraOpenAPI, Form, Header, Json, Path, Query

_T = TypeVar("_T")

Expand Down Expand Up @@ -99,7 +100,7 @@ def dummy(_var: annotation.annotation) -> None: # type: ignore
)


def _add_route_def( # noqa: C901
def _add_route_def( # noqa: C901, PLR0912
openapi_schema: Dict[str, Any],
route: web.ResourceRoute,
method: str,
Expand All @@ -119,6 +120,7 @@ def _add_route_def( # noqa: C901
openapi_schema["components"]["schemas"].update(extra_openapi_schemas)

params: Dict[Tuple[str, str], Any] = {}
updaters: List[Callable[[Dict[str, Any]], None]] = []

def _insert_in_params(data: Dict[str, Any]) -> None:
element = params.get((data["name"], data["in"]))
Expand Down Expand Up @@ -191,8 +193,18 @@ def _insert_in_params(data: Dict[str, Any]) -> None:
"schema": schema,
},
)
elif isinstance(dependency.dependency, ExtraOpenAPI):
if dependency.dependency.updater is not None:
updaters.append(dependency.dependency.updater)
if dependency.dependency.extra_openapi is not None:
extra_openapi = always_merger.merge(
extra_openapi,
dependency.dependency.extra_openapi,
)

route_info["parameters"] = list(params.values())
for updater in updaters:
updater(route_info)
openapi_schema["paths"][route.resource.canonical].update(
{method.lower(): always_merger.merge(route_info, extra_openapi)},
)
Expand All @@ -207,6 +219,7 @@ def setup_swagger( # noqa: C901
title: str = "AioHTTP",
description: Optional[str] = None,
version: str = "1.0.0",
extra_openapi: Optional[Dict[str, Any]] = None,
) -> Callable[[web.Application], Awaitable[None]]:
"""
Add swagger documentation.
Expand All @@ -230,8 +243,11 @@ def setup_swagger( # noqa: C901
:param title: Title of an application.
:param description: description of an application.
:param version: version of an application.
:param extra_openapi: extra openAPI dict that will be merged with generated schema.
:return: startup event handler.
"""
if extra_openapi is None:
extra_openapi = {}

async def event_handler(app: web.Application) -> None: # noqa: C901
openapi_schema = {
Expand All @@ -252,12 +268,12 @@ async def event_handler(app: web.Application) -> None: # noqa: C901
if hide_options and route.method.upper() == "OPTIONS":
continue
if isinstance(route._handler, InjectableFuncHandler):
extra_openapi = getattr(
route_extra_openapi = getattr(
route._handler.original_handler,
"__extra_openapi__",
{},
)
extra_schemas = getattr(
route_extra_schemas = getattr(
route._handler.original_handler,
"__extra_openapi_schemas__",
{},
Expand All @@ -268,8 +284,8 @@ async def event_handler(app: web.Application) -> None: # noqa: C901
route, # type: ignore
route.method,
route._handler.graph,
extra_openapi=extra_openapi,
extra_openapi_schemas=extra_schemas,
extra_openapi=route_extra_openapi,
extra_openapi_schemas=route_extra_schemas,
)
except Exception as exc: # pragma: no cover
logger.warn(
Expand All @@ -280,12 +296,12 @@ async def event_handler(app: web.Application) -> None: # noqa: C901

elif isinstance(route._handler, InjectableViewHandler):
for key, graph in route._handler.graph_map.items():
extra_openapi = getattr(
route_extra_openapi = getattr(
getattr(route._handler.original_handler, key),
"__extra_openapi__",
{},
)
extra_schemas = getattr(
route_extra_schemas = getattr(
getattr(route._handler.original_handler, key),
"__extra_openapi_schemas__",
{},
Expand All @@ -296,8 +312,8 @@ async def event_handler(app: web.Application) -> None: # noqa: C901
route, # type: ignore
key,
graph,
extra_openapi=extra_openapi,
extra_openapi_schemas=extra_schemas,
extra_openapi=route_extra_openapi,
extra_openapi_schemas=route_extra_schemas,
)
except Exception as exc: # pragma: no cover
logger.warn(
Expand All @@ -306,7 +322,7 @@ async def event_handler(app: web.Application) -> None: # noqa: C901
exc_info=True,
)

app[SWAGGER_SCHEMA_KEY] = openapi_schema
app[SWAGGER_SCHEMA_KEY] = always_merger.merge(openapi_schema, extra_openapi)

app.router.add_get(
schema_url,
Expand Down
36 changes: 35 additions & 1 deletion aiohttp_deps/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import inspect
import json
from typing import Any, Optional, Union
from typing import Any, Callable, Dict, Optional, Union

import pydantic
from aiohttp import web
Expand Down Expand Up @@ -344,3 +344,37 @@ def __call__(
headers={"Content-Type": "application/json"},
text=json.dumps(errors),
) from err


class ExtraOpenAPI:
"""
Update swagger for the endpoint.
You can use this dependency to add swagger to an endpoint from
a dependency. It's useful when you want to add some extra swagger
to the route when some specific dependency is used by it.
"""

def __init__(
self,
extra_openapi: Optional[Dict[str, Any]] = None,
swagger_updater: Optional[Callable[[Dict[str, Any]], None]] = None,
) -> None:
"""
Initialize the dependency.
:param swagger_updater: function that takes final swagger endpoint and
updates it.
:param extra_swagger: extra swagger to add to the endpoint. This one might
override other extra_swagger on the endpoint.
"""
self.updater = swagger_updater
self.extra_openapi = extra_openapi

def __call__(self) -> None:
"""
This method is called when dependency is resolved.
It's empty, becuase it's used by the swagger function and
there is no actual dependency.
"""
87 changes: 87 additions & 0 deletions examples/swagger_auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
import base64

from aiohttp import web
from pydantic import BaseModel

from aiohttp_deps import Depends, ExtraOpenAPI, Header, Router, init, setup_swagger


class UserInfo(BaseModel):
"""Abstract user model."""

id: int
name: str
password: str


router = Router()

# Here we create a simple user storage.
# In real-world applications, you would use a database.
users = {
"john": UserInfo(id=1, name="John Doe", password="123"), # noqa: S106
"caren": UserInfo(id=2, name="Caren Doe", password="321"), # noqa: S106
}


def get_current_user(
# Current auth header.
authorization: str = Depends(Header()),
# We don't need a name to this variable,
# because it will only affect the API schema,
# but won't be used in runtime.
_: None = Depends(
ExtraOpenAPI(
extra_openapi={
"security": [{"basicAuth": []}],
},
),
),
) -> UserInfo:
"""This function checks if the user authorized."""
# Here we check if the authorization header is present.
if not authorization.startswith("Basic"):
raise web.HTTPUnauthorized(reason="Unsupported authorization type")
# We decode credentials from the header.
# And check if the user exists.
creds = base64.b64decode(authorization.split(" ")[1]).decode()
username, password = creds.split(":")
found_user = users.get(username)
if found_user is None:
raise web.HTTPUnauthorized(reason="User not found")
if found_user.password != password:
raise web.HTTPUnauthorized(reason="Invalid password")
return found_user


@router.get("/")
async def index(current_user: UserInfo = Depends(get_current_user)) -> web.Response:
"""Index handler returns current user."""
return web.json_response(current_user.model_dump(mode="json"))


app = web.Application()
app.router.add_routes(router)
app.on_startup.extend(
[
init,
setup_swagger(
# Here we add security schemes used
# to authorize users.
extra_openapi={
"components": {
"securitySchemes": {
# We only support basic auth.
"basicAuth": {
"type": "http",
"scheme": "basic",
},
},
},
},
),
],
)

if __name__ == "__main__":
web.run_app(app)
58 changes: 58 additions & 0 deletions tests/test_swagger.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from aiohttp_deps import (
Depends,
ExtraOpenAPI,
Form,
Header,
Json,
Expand Down Expand Up @@ -780,3 +781,60 @@ async def my_handler() -> None:
schema = await response.json()
assert "get" in schema["paths"]["/"]
assert method.lower() not in schema["paths"]["/"]


@pytest.mark.anyio
async def test_extra_openapi_dep_func(
my_app: web.Application,
aiohttp_client: ClientGenerator,
) -> None:
openapi_url = "/my_api_def.json"
my_app.on_startup.append(setup_swagger(schema_url=openapi_url))

async def dep(
_: None = Depends(ExtraOpenAPI(extra_openapi={"responses": {"200": {}}})),
) -> None:
"""Test dep that adds swagger through a dependency."""

async def my_handler(_: None = Depends(dep)) -> None:
"""Nothing."""

my_app.router.add_get("/a", my_handler)

client = await aiohttp_client(my_app)
resp = await client.get(openapi_url)
assert resp.status == 200
resp_json = await resp.json()

handler_info = resp_json["paths"]["/a"]["get"]
assert handler_info["responses"] == {"200": {}}


@pytest.mark.anyio
async def test_extra_openapi_dep_updater_func(
my_app: web.Application,
aiohttp_client: ClientGenerator,
) -> None:
openapi_url = "/my_api_def.json"
my_app.on_startup.append(setup_swagger(schema_url=openapi_url))

def schema_updater(schema: Dict[str, Any]) -> None:
schema["responses"] = {"200": {}}

async def dep(
_: None = Depends(ExtraOpenAPI(swagger_updater=schema_updater)),
) -> None:
"""Test dep that adds swagger through a dependency."""

async def my_handler(_: None = Depends(dep)) -> None:
"""Nothing."""

my_app.router.add_get("/a", my_handler)

client = await aiohttp_client(my_app)
resp = await client.get(openapi_url)
assert resp.status == 200
resp_json = await resp.json()

handler_info = resp_json["paths"]["/a"]["get"]
assert handler_info["responses"] == {"200": {}}

0 comments on commit b1af5ac

Please sign in to comment.