Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add handling for backend error #75

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 18 additions & 11 deletions ratelimit/backends/redis.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import json

from redis.asyncio import StrictRedis
from redis.exceptions import ConnectionError

from ..exceptions import BackendConnectionException
from ..rule import Rule
from . import BaseBackend

Expand Down Expand Up @@ -41,17 +43,22 @@ async def is_blocking(self, user: str) -> int:
return int(await self._redis.ttl(f"blocking:{user}"))

async def retry_after(self, path: str, user: str, rule: Rule) -> int:
block_time = await self.is_blocking(user)
if block_time > 0:
return block_time
try:
block_time = await self.is_blocking(user)
if block_time > 0:
return block_time

ruleset = rule.ruleset(path, user)
retry_after = int(
await self.lua_script(keys=list(ruleset.keys()), args=[json.dumps(ruleset)])
)
ruleset = rule.ruleset(path, user)
retry_after = int(
await self.lua_script(
keys=list(ruleset.keys()), args=[json.dumps(ruleset)]
)
)

if retry_after > 0 and rule.block_time:
await self.set_block_time(user, rule.block_time)
retry_after = rule.block_time
if retry_after > 0 and rule.block_time:
await self.set_block_time(user, rule.block_time)
retry_after = rule.block_time

return retry_after
return retry_after
except ConnectionError as ce:
raise BackendConnectionException(f"Error connecting to Redis: {ce}")
21 changes: 20 additions & 1 deletion ratelimit/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple

from .backends import BaseBackend
from .exceptions import BaseBackendException
from .rule import RULENAMES, Rule
from .types import ASGIApp, Receive, Scope, Send

Expand All @@ -23,6 +24,19 @@ async def default_429(scope: Scope, receive: Receive, send: Send) -> None:
return default_429


def _on_backend_error(err) -> ASGIApp:
async def default_503(scope: Scope, receive: Receive, send: Send) -> None:
await send(
{
"type": "http.response.start",
"status": 503,
}
)
await send({"type": "http.response.body", "body": b"", "more_body": False})

return default_503


class RateLimitMiddleware:
"""
rate limit middleware
Expand All @@ -37,6 +51,7 @@ def __init__(
*,
on_auth_error: Optional[Callable[[Exception], Awaitable[ASGIApp]]] = None,
on_blocked: Callable[[int], ASGIApp] = _on_blocked,
on_backend_error: Callable[[int], ASGIApp] = _on_backend_error,
Bharat23 marked this conversation as resolved.
Show resolved Hide resolved
) -> None:
self.app = app
self.authenticate = authenticate
Expand All @@ -53,6 +68,7 @@ def __init__(

self.on_auth_error = on_auth_error
self.on_blocked = on_blocked
self.on_backend_error = on_backend_error

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope["type"] != "http": # pragma: no cover
Expand Down Expand Up @@ -90,7 +106,10 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
return await self.app(scope, receive, send)

path: str = url_path if rule.zone is None else rule.zone
retry_after = await self.backend.retry_after(path, user, rule)
try:
retry_after = await self.backend.retry_after(path, user, rule)
except BaseBackendException as be:
return await self.on_backend_error(be)(scope, receive, send)
if retry_after == 0:
return await self.app(scope, receive, send)

Expand Down
3 changes: 3 additions & 0 deletions ratelimit/exceptions/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# flake8: noqa: F401
from .backend_connection import BackendConnectionException
Bharat23 marked this conversation as resolved.
Show resolved Hide resolved
from .base_backend import BaseBackendException
9 changes: 9 additions & 0 deletions ratelimit/exceptions/backend_connection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from .base_backend import BaseBackendException


class BackendConnectionException(BaseBackendException):
"""
Backend exception for ConnectionError
"""

pass
6 changes: 6 additions & 0 deletions ratelimit/exceptions/base_backend.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
class BaseBackendException(Exception):
"""
Base class for exception raised by Backends
"""

pass
50 changes: 50 additions & 0 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,56 @@ async def inside_yourself_429(scope: Scope, receive: Receive, send: Send) -> Non
return inside_yourself_429


@pytest.mark.asyncio
async def test_on_backend_error():
# use incorrect port to force connection error
rate_limit = RateLimitMiddleware(
hello_world,
authenticate=auth_func,
backend=RedisBackend(StrictRedis(port=6369)),
config={r"/": [Rule(second=1), Rule(group="admin")]},
)

async with httpx.AsyncClient(
app=rate_limit, base_url="http://testserver"
) as client: # type: httpx.AsyncClient
response = await client.get("/", headers={"user": "user", "group": "default"})
assert response.status_code == 503


@pytest.mark.asyncio
async def test_custom_on_backend_error():
# use incorrect port to force connection error
rate_limit = RateLimitMiddleware(
hello_world,
authenticate=auth_func,
backend=RedisBackend(StrictRedis(port=6369)),
config={r"/": [Rule(second=1), Rule(group="admin")]},
on_backend_error=yourself_503,
)

async with httpx.AsyncClient(
app=rate_limit, base_url="http://testserver"
) as client: # type: httpx.AsyncClient
response = await client.get("/", headers={"user": "user", "group": "default"})
assert response.status_code == 503
assert response.text == "custom 503 page"


def yourself_503(retry_after: int):
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The type hint is incorrect.

async def inside_yourself_503(scope: Scope, receive: Receive, send: Send) -> None:
await send({"type": "http.response.start", "status": 503})
await send(
{
"type": "http.response.body",
"body": b"custom 503 page",
"more_body": False,
}
)

return inside_yourself_503


@pytest.mark.asyncio
async def test_custom_blocked():
rate_limit = RateLimitMiddleware(
Expand Down
Loading