From 5eaed8c0a3fd17f65d9770a51173e4e3e4f78a52 Mon Sep 17 00:00:00 2001 From: Bharat Sinha Date: Thu, 6 Jun 2024 12:37:31 -0600 Subject: [PATCH 1/3] Add handling for backend error --- ratelimit/backends/redis.py | 29 ++++++++----- ratelimit/core.py | 21 ++++++++- ratelimit/exceptions/__init__.py | 3 ++ ratelimit/exceptions/backend_connection.py | 9 ++++ ratelimit/exceptions/base_backend.py | 6 +++ tests/test_core.py | 50 ++++++++++++++++++++++ 6 files changed, 106 insertions(+), 12 deletions(-) create mode 100644 ratelimit/exceptions/__init__.py create mode 100644 ratelimit/exceptions/backend_connection.py create mode 100644 ratelimit/exceptions/base_backend.py diff --git a/ratelimit/backends/redis.py b/ratelimit/backends/redis.py index cfa6668..846c562 100644 --- a/ratelimit/backends/redis.py +++ b/ratelimit/backends/redis.py @@ -1,7 +1,9 @@ import json from redis.asyncio import StrictRedis +from redis.exceptions import ConnectionError +from ..exceptions import BackendConnectionException from ..rule import Rule from . import BaseBackend @@ -41,17 +43,22 @@ async def is_blocking(self, user: str) -> int: return int(await self._redis.ttl(f"blocking:{user}")) async def retry_after(self, path: str, user: str, rule: Rule) -> int: - block_time = await self.is_blocking(user) - if block_time > 0: - return block_time + try: + block_time = await self.is_blocking(user) + if block_time > 0: + return block_time - ruleset = rule.ruleset(path, user) - retry_after = int( - await self.lua_script(keys=list(ruleset.keys()), args=[json.dumps(ruleset)]) - ) + ruleset = rule.ruleset(path, user) + retry_after = int( + await self.lua_script( + keys=list(ruleset.keys()), args=[json.dumps(ruleset)] + ) + ) - if retry_after > 0 and rule.block_time: - await self.set_block_time(user, rule.block_time) - retry_after = rule.block_time + if retry_after > 0 and rule.block_time: + await self.set_block_time(user, rule.block_time) + retry_after = rule.block_time - return retry_after + return retry_after + except ConnectionError as ce: + raise BackendConnectionException(f"Error connecting to Redis: {ce}") diff --git a/ratelimit/core.py b/ratelimit/core.py index 36146b8..7dce35e 100644 --- a/ratelimit/core.py +++ b/ratelimit/core.py @@ -3,6 +3,7 @@ from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple from .backends import BaseBackend +from .exceptions import BaseBackendException from .rule import RULENAMES, Rule from .types import ASGIApp, Receive, Scope, Send @@ -23,6 +24,19 @@ async def default_429(scope: Scope, receive: Receive, send: Send) -> None: return default_429 +def _on_backend_error(err) -> ASGIApp: + async def default_503(scope: Scope, receive: Receive, send: Send) -> None: + await send( + { + "type": "http.response.start", + "status": 503, + } + ) + await send({"type": "http.response.body", "body": b"", "more_body": False}) + + return default_503 + + class RateLimitMiddleware: """ rate limit middleware @@ -37,6 +51,7 @@ def __init__( *, on_auth_error: Optional[Callable[[Exception], Awaitable[ASGIApp]]] = None, on_blocked: Callable[[int], ASGIApp] = _on_blocked, + on_backend_error: Callable[[int], ASGIApp] = _on_backend_error, ) -> None: self.app = app self.authenticate = authenticate @@ -53,6 +68,7 @@ def __init__( self.on_auth_error = on_auth_error self.on_blocked = on_blocked + self.on_backend_error = on_backend_error async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: if scope["type"] != "http": # pragma: no cover @@ -90,7 +106,10 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: return await self.app(scope, receive, send) path: str = url_path if rule.zone is None else rule.zone - retry_after = await self.backend.retry_after(path, user, rule) + try: + retry_after = await self.backend.retry_after(path, user, rule) + except BaseBackendException as be: + return await self.on_backend_error(be)(scope, receive, send) if retry_after == 0: return await self.app(scope, receive, send) diff --git a/ratelimit/exceptions/__init__.py b/ratelimit/exceptions/__init__.py new file mode 100644 index 0000000..7358eaa --- /dev/null +++ b/ratelimit/exceptions/__init__.py @@ -0,0 +1,3 @@ +# flake8: noqa: F401 +from .backend_connection import BackendConnectionException +from .base_backend import BaseBackendException diff --git a/ratelimit/exceptions/backend_connection.py b/ratelimit/exceptions/backend_connection.py new file mode 100644 index 0000000..cec419e --- /dev/null +++ b/ratelimit/exceptions/backend_connection.py @@ -0,0 +1,9 @@ +from .base_backend import BaseBackendException + + +class BackendConnectionException(BaseBackendException): + """ + Backend exception for ConnectionError + """ + + pass diff --git a/ratelimit/exceptions/base_backend.py b/ratelimit/exceptions/base_backend.py new file mode 100644 index 0000000..4bc8abb --- /dev/null +++ b/ratelimit/exceptions/base_backend.py @@ -0,0 +1,6 @@ +class BaseBackendException(Exception): + """ + Base class for exception raised by Backends + """ + + pass diff --git a/tests/test_core.py b/tests/test_core.py index 791dd38..a458462 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -143,6 +143,56 @@ async def inside_yourself_429(scope: Scope, receive: Receive, send: Send) -> Non return inside_yourself_429 +@pytest.mark.asyncio +async def test_on_backend_error(): + # use incorrect port to force connection error + rate_limit = RateLimitMiddleware( + hello_world, + authenticate=auth_func, + backend=RedisBackend(StrictRedis(port=6369)), + config={r"/": [Rule(second=1), Rule(group="admin")]}, + ) + + async with httpx.AsyncClient( + app=rate_limit, base_url="http://testserver" + ) as client: # type: httpx.AsyncClient + response = await client.get("/", headers={"user": "user", "group": "default"}) + assert response.status_code == 503 + + +@pytest.mark.asyncio +async def test_custom_on_backend_error(): + # use incorrect port to force connection error + rate_limit = RateLimitMiddleware( + hello_world, + authenticate=auth_func, + backend=RedisBackend(StrictRedis(port=6369)), + config={r"/": [Rule(second=1), Rule(group="admin")]}, + on_backend_error=yourself_503, + ) + + async with httpx.AsyncClient( + app=rate_limit, base_url="http://testserver" + ) as client: # type: httpx.AsyncClient + response = await client.get("/", headers={"user": "user", "group": "default"}) + assert response.status_code == 503 + assert response.text == "custom 503 page" + + +def yourself_503(retry_after: int): + async def inside_yourself_503(scope: Scope, receive: Receive, send: Send) -> None: + await send({"type": "http.response.start", "status": 503}) + await send( + { + "type": "http.response.body", + "body": b"custom 503 page", + "more_body": False, + } + ) + + return inside_yourself_503 + + @pytest.mark.asyncio async def test_custom_blocked(): rate_limit = RateLimitMiddleware( From b8c3ff63df1bb9c9edf94ab733f7ade96d48492e Mon Sep 17 00:00:00 2001 From: Bharat Sinha Date: Tue, 30 Jul 2024 21:41:09 -0600 Subject: [PATCH 2/3] Fix typing, move exceptions to single file --- ratelimit/core.py | 4 ++-- .../{exceptions/backend_connection.py => exceptions.py} | 7 ++++++- ratelimit/exceptions/__init__.py | 3 --- ratelimit/exceptions/base_backend.py | 6 ------ tests/test_core.py | 2 +- 5 files changed, 9 insertions(+), 13 deletions(-) rename ratelimit/{exceptions/backend_connection.py => exceptions.py} (52%) delete mode 100644 ratelimit/exceptions/__init__.py delete mode 100644 ratelimit/exceptions/base_backend.py diff --git a/ratelimit/core.py b/ratelimit/core.py index 7dce35e..9c2c941 100644 --- a/ratelimit/core.py +++ b/ratelimit/core.py @@ -24,7 +24,7 @@ async def default_429(scope: Scope, receive: Receive, send: Send) -> None: return default_429 -def _on_backend_error(err) -> ASGIApp: +def _on_backend_error(err: Exception) -> ASGIApp: async def default_503(scope: Scope, receive: Receive, send: Send) -> None: await send( { @@ -51,7 +51,7 @@ def __init__( *, on_auth_error: Optional[Callable[[Exception], Awaitable[ASGIApp]]] = None, on_blocked: Callable[[int], ASGIApp] = _on_blocked, - on_backend_error: Callable[[int], ASGIApp] = _on_backend_error, + on_backend_error: Callable[[Exception], ASGIApp] = _on_backend_error, ) -> None: self.app = app self.authenticate = authenticate diff --git a/ratelimit/exceptions/backend_connection.py b/ratelimit/exceptions.py similarity index 52% rename from ratelimit/exceptions/backend_connection.py rename to ratelimit/exceptions.py index cec419e..e868c83 100644 --- a/ratelimit/exceptions/backend_connection.py +++ b/ratelimit/exceptions.py @@ -1,4 +1,9 @@ -from .base_backend import BaseBackendException +class BaseBackendException(Exception): + """ + Base class for exception raised by Backends + """ + + pass class BackendConnectionException(BaseBackendException): diff --git a/ratelimit/exceptions/__init__.py b/ratelimit/exceptions/__init__.py deleted file mode 100644 index 7358eaa..0000000 --- a/ratelimit/exceptions/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -# flake8: noqa: F401 -from .backend_connection import BackendConnectionException -from .base_backend import BaseBackendException diff --git a/ratelimit/exceptions/base_backend.py b/ratelimit/exceptions/base_backend.py deleted file mode 100644 index 4bc8abb..0000000 --- a/ratelimit/exceptions/base_backend.py +++ /dev/null @@ -1,6 +0,0 @@ -class BaseBackendException(Exception): - """ - Base class for exception raised by Backends - """ - - pass diff --git a/tests/test_core.py b/tests/test_core.py index a458462..c5f3e5a 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -179,7 +179,7 @@ async def test_custom_on_backend_error(): assert response.text == "custom 503 page" -def yourself_503(retry_after: int): +def yourself_503(err: Exception): async def inside_yourself_503(scope: Scope, receive: Receive, send: Send) -> None: await send({"type": "http.response.start", "status": 503}) await send( From e5fc22732e3bea4b3d1697991e7a5e9af93e9d48 Mon Sep 17 00:00:00 2001 From: Bharat Sinha Date: Sat, 3 Aug 2024 15:39:52 -0600 Subject: [PATCH 3/3] Update readme for custom backend error handler --- README.md | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/README.md b/README.md index 12b0ee9..e7620fc 100644 --- a/README.md +++ b/README.md @@ -239,3 +239,30 @@ async def handle_auth_error(exc: Exception) -> ASGIApp: # await send({"type": "http.response.start", "status": 429}) return response ``` + +### Custom backend error handler + +Normally exceptions raised in the backend due to Connection/Network errors result in an Internal Server Error, but you can pass a function to handle the errors and send the appropriate response back to the user. For example, if you're using FastAPI or Starlette: + +```python +from fastapi.responses import JSONResponse +from ratelimit.types import ASGIApp + +async def handle_backend_error(exc: Exception) -> ASGIApp: + return JSONResponse({"message": "Cache unavailable."}, status_code=500) + +RateLimitMiddleware(..., on_backend_error=handle_backend_error) +``` + +For advanced usage you can handle the response completely by yourself: + +```python +from fastapi.responses import JSONResponse +from ratelimit.types import ASGIApp, Scope, Receive, Send + +async def handle_backend_error(exc: Exception) -> ASGIApp: + async def response(scope: Scope, receive: Receive, send: Send): + # do something here e.g. + # await send({"type": "http.response.start", "status": 500}) + return response +```