diff --git a/starlette_plus/middleware/ratelimiter.py b/starlette_plus/middleware/ratelimiter.py index 6cf2a3f..43e18a2 100644 --- a/starlette_plus/middleware/ratelimiter.py +++ b/starlette_plus/middleware/ratelimiter.py @@ -28,7 +28,7 @@ if TYPE_CHECKING: - from starlette.routing import Route + from starlette.routing import Mount, WebSocketRoute from starlette.types import ASGIApp, Receive, Scope, Send from ..redis import Redis @@ -70,12 +70,15 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: request: Request = Request(scope) forwarded: str | None = request.headers.get("X-Forwarded-For", None) - routes: list[Route] = scope["app"].routes - route: Route | None = None + routes: list[Route | Mount | WebSocketRoute] = scope["app"].routes + route: Route | Mount | WebSocketRoute | None = None for r in routes: - methods: set[str] = r.methods or set() - if r.path == request.url.path and request.method in methods: + methods: set[str] | None = r.methods if isinstance(r, Route) else None + if r.path != request.url.path: + continue + + if not methods or request.method in methods: route = r break