From f21169a02b02459fc5bc48491e7330ae447f870f Mon Sep 17 00:00:00 2001 From: MystyPy Date: Sun, 19 May 2024 02:41:29 +1000 Subject: [PATCH] Perform periodic health checks on Redis, default to in-memory after a failure. --- .gitignore | 5 +++- starlette_plus/core.py | 22 +--------------- starlette_plus/limiter.py | 4 +-- starlette_plus/middleware/sessions.py | 18 ++++++------- starlette_plus/redis.py | 37 ++++++++++++++++++++++++++- 5 files changed, 51 insertions(+), 35 deletions(-) diff --git a/.gitignore b/.gitignore index 6769e21..e9bbef8 100644 --- a/.gitignore +++ b/.gitignore @@ -157,4 +157,7 @@ cython_debug/ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore # and can be added to the global gitignore or merged into this file. For a more nuclear # option (not recommended) you can uncomment the following to ignore the entire idea folder. -#.idea/ \ No newline at end of file +#.idea/ + +# Test files +test*.py \ No newline at end of file diff --git a/starlette_plus/core.py b/starlette_plus/core.py index de8388c..ce5b5a8 100644 --- a/starlette_plus/core.py +++ b/starlette_plus/core.py @@ -34,7 +34,6 @@ if TYPE_CHECKING: from starlette.types import ASGIApp, Message, Receive, Scope, Send - from .redis import Redis from .types_.core import Methods, RouteOptions from .types_.limiter import BucketType, ExemptCallable, RateLimitData @@ -192,31 +191,12 @@ def __init__(self, *args: Any, **kwargs: Unpack[ApplicationOptions]) -> None: middleware_: list[Middleware] = kwargs.pop("middleware", []) middleware_.insert(0, Middleware(LoggingMiddleware)) if self._access_log else None - statrtups = kwargs.pop("on_startup", []) - statrtups.append(self.__startup) - - super().__init__(*args, **kwargs, middleware=middleware_, on_startup=statrtups) # type: ignore + super().__init__(*args, **kwargs, middleware=middleware_) # type: ignore self.add_view(self) for view in views: self.add_view(view) - async def __startup(self) -> None: - for middleware in self.user_middleware: - redis: Redis | None = middleware.kwargs.get("redis", None) # type: ignore - - if not redis: - continue - - try: - resp: bool = await redis.ping() - except Exception: - resp = False - - if not resp: - logger.warning("Unable to connect to redis on %s, defaulting to in-memory.", middleware.cls.__name__) - middleware.kwargs["redis"] = None - def __new__(cls, *args: Any, **kwargs: Any) -> Self: self: Self = super().__new__(cls) self.__routes__ = [] diff --git a/starlette_plus/limiter.py b/starlette_plus/limiter.py index ccfc5d2..3186730 100644 --- a/starlette_plus/limiter.py +++ b/starlette_plus/limiter.py @@ -50,14 +50,14 @@ def __init__(self, redis: Redis | None = None) -> None: async def get_tat(self, key: str, /) -> datetime.datetime: now: datetime.datetime = datetime.datetime.now(tz=datetime.UTC) - if self.redis: + if self.redis and self.redis.could_connect: value: str | None = await self.redis.pool.get(key) # type: ignore return datetime.datetime.fromisoformat(value) if value else now # type: ignore return self._keys.get(key, {"tat": now}).get("tat", now) async def set_tat(self, key: str, /, *, tat: datetime.datetime, limit: RateLimit) -> None: - if self.redis: + if self.redis and self.redis.could_connect: await self.redis.pool.set(key, tat.isoformat(), ex=int(limit.period.total_seconds() + 60)) # type: ignore else: self._keys[key] = {"tat": tat, "limit": limit} diff --git a/starlette_plus/middleware/sessions.py b/starlette_plus/middleware/sessions.py index b997ab7..da65878 100644 --- a/starlette_plus/middleware/sessions.py +++ b/starlette_plus/middleware/sessions.py @@ -25,7 +25,6 @@ from typing import TYPE_CHECKING, Any import itsdangerous -import redis.asyncio as redis from starlette.datastructures import MutableHeaders from starlette.requests import HTTPConnection @@ -33,7 +32,6 @@ if TYPE_CHECKING: - import redis.asyncio as redis from starlette.types import ASGIApp, Message, Receive, Scope, Send from ..redis import Redis @@ -43,10 +41,10 @@ class Storage: - __slots__ = ("pool", "_keys") + __slots__ = ("redis", "_keys") def __init__(self, *, redis: Redis | None = None) -> None: - self.pool: redis.Redis | None = redis.pool if redis else None + self.redis: Redis | None = redis self._keys: dict[str, Any] = {} async def get(self, data: dict[str, Any]) -> dict[str, Any]: @@ -57,23 +55,23 @@ async def get(self, data: dict[str, Any]) -> dict[str, Any]: await self.delete(key) return {} - if self.pool: - session: Any = await self.pool.get(key) # type: ignore + if self.redis and self.redis.could_connect: + session: Any = await self.redis.pool.get(key) # type: ignore else: session: Any = self._keys.get(key) return json.loads(session) if session else {} async def set(self, key: str, value: dict[str, Any], *, max_age: int) -> None: - if self.pool: - await self.pool.set(key, json.dumps(value), ex=max_age) # type: ignore + if self.redis and self.redis.could_connect: + await self.redis.pool.set(key, json.dumps(value), ex=max_age) # type: ignore return self._keys[key] = json.dumps(value) async def delete(self, key: str) -> None: - if self.pool: - await self.pool.delete(key) # type: ignore + if self.redis and self.redis.could_connect: + await self.redis.pool.delete(key) # type: ignore else: self._keys.pop(key, None) diff --git a/starlette_plus/redis.py b/starlette_plus/redis.py index 4a85e3e..cec0c41 100644 --- a/starlette_plus/redis.py +++ b/starlette_plus/redis.py @@ -13,15 +13,50 @@ limitations under the License. """ +import asyncio +import logging + import redis.asyncio as redis +logger: logging.Logger = logging.getLogger(__name__) + + class Redis: def __init__(self, *, url: str | None = None) -> None: url = url or "redis://localhost:6379/0" pool = redis.ConnectionPool.from_url(url, decode_responses=True) # type: ignore self.pool: redis.Redis = redis.Redis.from_pool(pool) + self.url = url + + self._could_connect: bool | None = None + self._task = asyncio.create_task(self._health_task()) + + @property + def could_connect(self) -> bool | None: + return self._could_connect async def ping(self) -> bool: - return bool(await self.pool.ping()) # type: ignore + try: + async with asyncio.timeout(3.0): + self._could_connect = bool(await self.pool.ping()) # type: ignore + except Exception: + if self._could_connect is not False: + logger.warning( + "Unable to connect to Redis: %s. Services relying on this instance will now be in-memory.", self.url + ) + + self._could_connect = False + + return self._could_connect + + async def _health_task(self) -> None: + while True: + previous = self.could_connect + await self.ping() + + if not previous and self.could_connect: + logger.info("Redis connection has been (re)established: %s", self.url) + + await asyncio.sleep(5)