Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add handling for backend error #75

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -239,3 +239,30 @@ async def handle_auth_error(exc: Exception) -> ASGIApp:
# await send({"type": "http.response.start", "status": 429})
return response
```

### Custom backend error handler

Normally exceptions raised in the backend due to Connection/Network errors result in an Internal Server Error, but you can pass a function to handle the errors and send the appropriate response back to the user. For example, if you're using FastAPI or Starlette:

```python
from fastapi.responses import JSONResponse
from ratelimit.types import ASGIApp

async def handle_backend_error(exc: Exception) -> ASGIApp:
return JSONResponse({"message": "Cache unavailable."}, status_code=500)

RateLimitMiddleware(..., on_backend_error=handle_backend_error)
```

For advanced usage you can handle the response completely by yourself:

```python
from fastapi.responses import JSONResponse
from ratelimit.types import ASGIApp, Scope, Receive, Send

async def handle_backend_error(exc: Exception) -> ASGIApp:
async def response(scope: Scope, receive: Receive, send: Send):
# do something here e.g.
# await send({"type": "http.response.start", "status": 500})
return response
```
29 changes: 18 additions & 11 deletions ratelimit/backends/redis.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import json

from redis.asyncio import StrictRedis
from redis.exceptions import ConnectionError

from ..exceptions import BackendConnectionException
from ..rule import Rule
from . import BaseBackend

Expand Down Expand Up @@ -41,17 +43,22 @@ async def is_blocking(self, user: str) -> int:
return int(await self._redis.ttl(f"blocking:{user}"))

async def retry_after(self, path: str, user: str, rule: Rule) -> int:
block_time = await self.is_blocking(user)
if block_time > 0:
return block_time
try:
block_time = await self.is_blocking(user)
if block_time > 0:
return block_time

ruleset = rule.ruleset(path, user)
retry_after = int(
await self.lua_script(keys=list(ruleset.keys()), args=[json.dumps(ruleset)])
)
ruleset = rule.ruleset(path, user)
retry_after = int(
await self.lua_script(
keys=list(ruleset.keys()), args=[json.dumps(ruleset)]
)
)

if retry_after > 0 and rule.block_time:
await self.set_block_time(user, rule.block_time)
retry_after = rule.block_time
if retry_after > 0 and rule.block_time:
await self.set_block_time(user, rule.block_time)
retry_after = rule.block_time

return retry_after
return retry_after
except ConnectionError as ce:
raise BackendConnectionException(f"Error connecting to Redis: {ce}")
21 changes: 20 additions & 1 deletion ratelimit/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple

from .backends import BaseBackend
from .exceptions import BaseBackendException
from .rule import RULENAMES, Rule
from .types import ASGIApp, Receive, Scope, Send

Expand All @@ -23,6 +24,19 @@ async def default_429(scope: Scope, receive: Receive, send: Send) -> None:
return default_429


def _on_backend_error(err: Exception) -> ASGIApp:
async def default_503(scope: Scope, receive: Receive, send: Send) -> None:
await send(
{
"type": "http.response.start",
"status": 503,
}
)
await send({"type": "http.response.body", "body": b"", "more_body": False})

return default_503


class RateLimitMiddleware:
"""
rate limit middleware
Expand All @@ -37,6 +51,7 @@ def __init__(
*,
on_auth_error: Optional[Callable[[Exception], Awaitable[ASGIApp]]] = None,
on_blocked: Callable[[int], ASGIApp] = _on_blocked,
on_backend_error: Callable[[Exception], ASGIApp] = _on_backend_error,
) -> None:
self.app = app
self.authenticate = authenticate
Expand All @@ -53,6 +68,7 @@ def __init__(

self.on_auth_error = on_auth_error
self.on_blocked = on_blocked
self.on_backend_error = on_backend_error

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope["type"] != "http": # pragma: no cover
Expand Down Expand Up @@ -90,7 +106,10 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
return await self.app(scope, receive, send)

path: str = url_path if rule.zone is None else rule.zone
retry_after = await self.backend.retry_after(path, user, rule)
try:
retry_after = await self.backend.retry_after(path, user, rule)
except BaseBackendException as be:
return await self.on_backend_error(be)(scope, receive, send)
if retry_after == 0:
return await self.app(scope, receive, send)

Expand Down
14 changes: 14 additions & 0 deletions ratelimit/exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
class BaseBackendException(Exception):
"""
Base class for exception raised by Backends
"""

pass


class BackendConnectionException(BaseBackendException):
"""
Backend exception for ConnectionError
"""

pass
50 changes: 50 additions & 0 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,56 @@ async def inside_yourself_429(scope: Scope, receive: Receive, send: Send) -> Non
return inside_yourself_429


@pytest.mark.asyncio
async def test_on_backend_error():
# use incorrect port to force connection error
rate_limit = RateLimitMiddleware(
hello_world,
authenticate=auth_func,
backend=RedisBackend(StrictRedis(port=6369)),
config={r"/": [Rule(second=1), Rule(group="admin")]},
)

async with httpx.AsyncClient(
app=rate_limit, base_url="http://testserver"
) as client: # type: httpx.AsyncClient
response = await client.get("/", headers={"user": "user", "group": "default"})
assert response.status_code == 503


@pytest.mark.asyncio
async def test_custom_on_backend_error():
# use incorrect port to force connection error
rate_limit = RateLimitMiddleware(
hello_world,
authenticate=auth_func,
backend=RedisBackend(StrictRedis(port=6369)),
config={r"/": [Rule(second=1), Rule(group="admin")]},
on_backend_error=yourself_503,
)

async with httpx.AsyncClient(
app=rate_limit, base_url="http://testserver"
) as client: # type: httpx.AsyncClient
response = await client.get("/", headers={"user": "user", "group": "default"})
assert response.status_code == 503
assert response.text == "custom 503 page"


def yourself_503(err: Exception):
async def inside_yourself_503(scope: Scope, receive: Receive, send: Send) -> None:
await send({"type": "http.response.start", "status": 503})
await send(
{
"type": "http.response.body",
"body": b"custom 503 page",
"more_body": False,
}
)

return inside_yourself_503


@pytest.mark.asyncio
async def test_custom_blocked():
rate_limit = RateLimitMiddleware(
Expand Down
Loading