Skip to content

Commit

Permalink
Merge pull request #17 from taskiq-python/feature/app-keys
Browse files Browse the repository at this point in the history
  • Loading branch information
s3rius authored Dec 6, 2023
2 parents e69ca2a + ef004ad commit 6e015ca
Show file tree
Hide file tree
Showing 22 changed files with 808 additions and 1,336 deletions.
4 changes: 1 addition & 3 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,8 @@ jobs:
matrix:
cmd:
- black
- flake8
- isort
- mypy
- autoflake
- ruff
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
Expand Down
31 changes: 8 additions & 23 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,26 +22,6 @@ repos:
language: system
types: [python]

- id: autoflake
name: autoflake
entry: poetry run autoflake
language: system
types: [python]
args: [--in-place, --remove-all-unused-imports, --remove-duplicate-keys]

- id: isort
name: isort
entry: poetry run isort
language: system
types: [python]

- id: flake8
name: Check with Flake8
entry: poetry run flake8
language: system
pass_filenames: false
types: [python]
args: [--count, aiohttp_deps]

- id: mypy
name: Validate types with MyPy
Expand All @@ -51,8 +31,13 @@ repos:
pass_filenames: false
args: [aiohttp_deps]

- id: yesqa
name: Remove usless noqa
entry: poetry run yesqa
- id: ruff
name: Run ruff lints
entry: poetry run ruff
language: system
pass_filenames: false
types: [python]
args:
- "--fix"
- "aiohttp_deps"
- "tests"
14 changes: 9 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ router = Router()


async def get_db_session(app: web.Application = Depends()):
async with app["db"] as sess:
async with app[web.AppKey("db")] as sess:
yield sess


Expand Down Expand Up @@ -363,7 +363,7 @@ Sometimes for tests you don't want to calculate actual functions
and you want to pass another functions instead.

To do so, you can add "dependency_overrides" or "values_overrides" to the aplication's state.
These values should be dicts.
These values should be dicts. The keys for these values can be found in `aiohttp_deps.keys` module.

Here's an example.

Expand All @@ -385,18 +385,22 @@ where you create your application. And make sure that keys
of that dict are actual function that are being replaced.

```python
my_app["values_overrides"] = {original_dep: 2}
from aiohttp_deps import VALUES_OVERRIDES_KEY

my_app[VALUES_OVERRIDES_KEY] = {original_dep: 2}
```

But `values_overrides` only overrides values. If you want to
But `values_overrides` only overrides returned values. If you want to
override functions, you have to use `dependency_overrides`. Here's an example:

```python
from aiohttp_deps import DEPENDENCY_OVERRIDES_KEY

def replacing_function() -> int:
return 2


my_app["dependency_overrides"] = {original_dep: replacing_function}
my_app[DEPENDENCY_OVERRIDES_KEY] = {original_dep: replacing_function}
```

The cool point about `dependency_overrides`, is that it recalculates graph and
Expand Down
3 changes: 3 additions & 0 deletions aiohttp_deps/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from taskiq_dependencies import Depends

from aiohttp_deps.initializer import init
from aiohttp_deps.keys import DEPENDENCY_OVERRIDES_KEY, VALUES_OVERRIDES_KEY
from aiohttp_deps.router import Router
from aiohttp_deps.swagger import extra_openapi, openapi_response, setup_swagger
from aiohttp_deps.utils import Form, Header, Json, Path, Query
Expand All @@ -20,4 +21,6 @@
"Form",
"Path",
"openapi_response",
"DEPENDENCY_OVERRIDES_KEY",
"VALUES_OVERRIDES_KEY",
]
11 changes: 8 additions & 3 deletions aiohttp_deps/initializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from aiohttp import hdrs, web
from taskiq_dependencies import DependencyGraph

from aiohttp_deps.keys import DEPENDENCY_OVERRIDES_KEY, VALUES_OVERRIDES_KEY
from aiohttp_deps.view import View


Expand Down Expand Up @@ -40,13 +41,17 @@ async def __call__(self, request: web.Request) -> web.StreamResponse:
:param request: current request.
:return: response.
"""
# Hack for mypy to work
values_overrides = request.app.get(VALUES_OVERRIDES_KEY)
if values_overrides is None:
values_overrides = {}
async with self.graph.async_ctx(
{
web.Request: request,
web.Application: request.app,
**request.app.get("values_overrides", {}),
**values_overrides,
},
replaced_deps=request.app.get("dependency_overrides"),
replaced_deps=request.app.get(DEPENDENCY_OVERRIDES_KEY),
) as resolver:
return await self.original_handler(**(await resolver.resolve_kwargs()))

Expand All @@ -72,7 +77,7 @@ def __init__(
allowed_methods = {
method.lower()
for method in hdrs.METH_ALL
if hasattr(original_route, method.lower()) # noqa: WPS421
if hasattr(original_route, method.lower())
}
self.graph_map = {
method: DependencyGraph(getattr(original_route, method))
Expand Down
7 changes: 7 additions & 0 deletions aiohttp_deps/keys.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from typing import Any, Dict

from aiohttp import web

SWAGGER_SCHEMA_KEY = web.AppKey("openapi_schema", Dict[str, Any])
VALUES_OVERRIDES_KEY = web.AppKey("values_overrides", Dict[Any, Any])
DEPENDENCY_OVERRIDES_KEY = web.AppKey("dependency_overrides", Dict[Any, Any])
13 changes: 9 additions & 4 deletions aiohttp_deps/router.pyi
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
from typing import Any, Awaitable, Callable, Iterable, Type, Union

import typing_extensions
from aiohttp import web
from aiohttp.abc import AbstractView

_Handler = Callable[..., Awaitable[web.StreamResponse]]
_Handler: typing_extensions.TypeAlias = Callable[..., Awaitable[web.StreamResponse]]

_ViewedHandler = Union[Type[AbstractView], _Handler]
_ViewedHandler: typing_extensions.TypeAlias = Union[Type[AbstractView], _Handler]

_Deco = Callable[[_ViewedHandler], _ViewedHandler]
_Deco: typing_extensions.TypeAlias = Callable[[_ViewedHandler], _ViewedHandler]

class Router(web.RouteTableDef):
def head(self, path: str, **kwargs: Any) -> _Deco: ...
Expand All @@ -18,4 +19,8 @@ class Router(web.RouteTableDef):
def delete(self, path: str, **kwargs: Any) -> _Deco: ...
def options(self, path: str, **kwargs: Any) -> _Deco: ...
def view(self, path: str, **kwargs: Any) -> _Deco: ...
def add_routes(self, router: Iterable[web.AbstractRouteDef], prefix: str = "") -> None: ...
def add_routes(
self,
router: Iterable[web.AbstractRouteDef],
prefix: str = "",
) -> None: ...
14 changes: 7 additions & 7 deletions aiohttp_deps/swagger.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,12 @@
from taskiq_dependencies import DependencyGraph

from aiohttp_deps.initializer import InjectableFuncHandler, InjectableViewHandler
from aiohttp_deps.keys import SWAGGER_SCHEMA_KEY
from aiohttp_deps.utils import Form, Header, Json, Path, Query

_T = TypeVar("_T") # noqa: WPS111
_T = TypeVar("_T")

REF_TEMPLATE = "#/components/schemas/{model}"
SCHEMA_KEY = "openapi_schema"
SWAGGER_HTML_TEMPALTE = """
<html lang="en">
Expand Down Expand Up @@ -52,15 +52,15 @@
</body>
</html>
"""
METHODS_WITH_BODY = {"POST", "PUT", "PATCH"} # noqa: WPS407
METHODS_WITH_BODY = {"POST", "PUT", "PATCH"}

logger = getLogger()


async def _schema_handler(
request: web.Request,
) -> web.Response:
return web.json_response(request.app[SCHEMA_KEY])
return web.json_response(request.app[SWAGGER_SCHEMA_KEY])


def _get_swagger_handler(
Expand Down Expand Up @@ -99,7 +99,7 @@ def dummy(_var: annotation.annotation) -> None: # type: ignore
)


def _add_route_def( # noqa: C901, WPS210, WPS211
def _add_route_def( # noqa: C901
openapi_schema: Dict[str, Any],
route: web.ResourceRoute,
method: str,
Expand Down Expand Up @@ -198,7 +198,7 @@ def _insert_in_params(data: Dict[str, Any]) -> None:
)


def setup_swagger( # noqa: C901, WPS211
def setup_swagger( # noqa: C901
schema_url: str = "/openapi.json",
swagger_ui_url: str = "/docs",
enable_ui: bool = True,
Expand Down Expand Up @@ -302,7 +302,7 @@ async def event_handler(app: web.Application) -> None:
exc_info=True,
)

app[SCHEMA_KEY] = openapi_schema
app[SWAGGER_SCHEMA_KEY] = openapi_schema

app.router.add_get(
schema_url,
Expand Down
24 changes: 12 additions & 12 deletions aiohttp_deps/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,15 @@ def __init__(
alias: Optional[str] = None,
multiple: bool = False,
description: str = "",
):
) -> None:
self.default = default
self.alias = alias
self.multiple = multiple
self.description = description
self.type_initialized = False
self.type_cache: "Union[pydantic.TypeAdapter[Any], None]" = None

def __call__( # noqa: C901
def __call__(
self,
param_info: ParamInfo = Depends(),
request: web.Request = Depends(),
Expand Down Expand Up @@ -85,7 +85,7 @@ def __call__( # noqa: C901
raise web.HTTPBadRequest(
headers={"Content-Type": "application/json"},
text=json.dumps(errors),
)
) from err


class Json:
Expand All @@ -100,7 +100,7 @@ def __init__(self) -> None:
self.type_initialized = False
self.type_cache: "Union[pydantic.TypeAdapter[Any], None]" = None

async def __call__( # noqa: C901
async def __call__(
self,
param_info: ParamInfo = Depends(),
request: web.Request = Depends(),
Expand Down Expand Up @@ -142,7 +142,7 @@ async def __call__( # noqa: C901
raise web.HTTPBadRequest(
headers={"Content-Type": "application/json"},
text=json.dumps(errors),
)
) from err


class Query:
Expand All @@ -165,15 +165,15 @@ def __init__(
alias: Optional[str] = None,
multiple: bool = False,
description: str = "",
):
) -> None:
self.default = default
self.alias = alias
self.multiple = multiple
self.description = description
self.type_initialized = False
self.type_cache: "Union[pydantic.TypeAdapter[Any], None]" = None

def __call__( # noqa: C901
def __call__(
self,
param_info: ParamInfo = Depends(),
request: web.Request = Depends(),
Expand Down Expand Up @@ -223,7 +223,7 @@ def __call__( # noqa: C901
raise web.HTTPBadRequest(
headers={"Content-Type": "application/json"},
text=json.dumps(errors),
)
) from err


class Form:
Expand All @@ -240,7 +240,7 @@ def __init__(self) -> None:
self.type_initialized = False
self.type_cache: "Union[pydantic.TypeAdapter[Any], None]" = None

async def __call__( # noqa: C901
async def __call__(
self,
param_info: ParamInfo = Depends(),
request: web.Request = Depends(),
Expand Down Expand Up @@ -279,7 +279,7 @@ async def __call__( # noqa: C901
raise web.HTTPBadRequest(
headers={"Content-Type": "application/json"},
text=json.dumps(errors),
)
) from err


class Path:
Expand All @@ -304,7 +304,7 @@ def __init__(
self.type_initialized = False
self.type_cache: "Union[pydantic.TypeAdapter[Any], None]" = None

def __call__( # noqa: C901
def __call__(
self,
param_info: ParamInfo = Depends(),
request: web.Request = Depends(),
Expand Down Expand Up @@ -343,4 +343,4 @@ def __call__( # noqa: C901
raise web.HTTPBadRequest(
headers={"Content-Type": "application/json"},
text=json.dumps(errors),
)
) from err
9 changes: 7 additions & 2 deletions aiohttp_deps/view.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from aiohttp.web_response import StreamResponse
from taskiq_dependencies import DependencyGraph

from aiohttp_deps.keys import DEPENDENCY_OVERRIDES_KEY, VALUES_OVERRIDES_KEY


class View(web.View):
"""
Expand Down Expand Up @@ -39,12 +41,15 @@ async def _iter(self) -> StreamResponse:
)
if method is None:
self._raise_allowed_methods()
values_overrides = self.request.app.get(VALUES_OVERRIDES_KEY)
if values_overrides is None:
values_overrides = {}
async with self._graph_map[self.request.method.lower()].async_ctx(
{
web.Request: self.request,
web.Application: self.request.app,
**self.request.app.get("values_overrides", {}),
**values_overrides,
},
replaced_deps=self.request.app.get("dependency_overrides"),
replaced_deps=self.request.app.get(DEPENDENCY_OVERRIDES_KEY),
) as ctx:
return await method(**(await ctx.resolve_kwargs())) # type: ignore
Loading

0 comments on commit 6e015ca

Please sign in to comment.