Skip to content

Commit

Permalink
Perform periodic health checks on Redis, default to in-memory after a…
Browse files Browse the repository at this point in the history
… failure.
  • Loading branch information
EvieePy committed May 18, 2024
1 parent 2dbaa3e commit f21169a
Show file tree
Hide file tree
Showing 5 changed files with 51 additions and 35 deletions.
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -157,4 +157,7 @@ cython_debug/
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
#.idea/

# Test files
test*.py
22 changes: 1 addition & 21 deletions starlette_plus/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@
if TYPE_CHECKING:
from starlette.types import ASGIApp, Message, Receive, Scope, Send

from .redis import Redis
from .types_.core import Methods, RouteOptions
from .types_.limiter import BucketType, ExemptCallable, RateLimitData

Expand Down Expand Up @@ -192,31 +191,12 @@ def __init__(self, *args: Any, **kwargs: Unpack[ApplicationOptions]) -> None:
middleware_: list[Middleware] = kwargs.pop("middleware", [])
middleware_.insert(0, Middleware(LoggingMiddleware)) if self._access_log else None

statrtups = kwargs.pop("on_startup", [])
statrtups.append(self.__startup)

super().__init__(*args, **kwargs, middleware=middleware_, on_startup=statrtups) # type: ignore
super().__init__(*args, **kwargs, middleware=middleware_) # type: ignore

self.add_view(self)
for view in views:
self.add_view(view)

async def __startup(self) -> None:
for middleware in self.user_middleware:
redis: Redis | None = middleware.kwargs.get("redis", None) # type: ignore

if not redis:
continue

try:
resp: bool = await redis.ping()
except Exception:
resp = False

if not resp:
logger.warning("Unable to connect to redis on %s, defaulting to in-memory.", middleware.cls.__name__)
middleware.kwargs["redis"] = None

def __new__(cls, *args: Any, **kwargs: Any) -> Self:
self: Self = super().__new__(cls)
self.__routes__ = []
Expand Down
4 changes: 2 additions & 2 deletions starlette_plus/limiter.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,14 +50,14 @@ def __init__(self, redis: Redis | None = None) -> None:
async def get_tat(self, key: str, /) -> datetime.datetime:
now: datetime.datetime = datetime.datetime.now(tz=datetime.UTC)

if self.redis:
if self.redis and self.redis.could_connect:
value: str | None = await self.redis.pool.get(key) # type: ignore
return datetime.datetime.fromisoformat(value) if value else now # type: ignore

return self._keys.get(key, {"tat": now}).get("tat", now)

async def set_tat(self, key: str, /, *, tat: datetime.datetime, limit: RateLimit) -> None:
if self.redis:
if self.redis and self.redis.could_connect:
await self.redis.pool.set(key, tat.isoformat(), ex=int(limit.period.total_seconds() + 60)) # type: ignore
else:
self._keys[key] = {"tat": tat, "limit": limit}
Expand Down
18 changes: 8 additions & 10 deletions starlette_plus/middleware/sessions.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,13 @@
from typing import TYPE_CHECKING, Any

import itsdangerous
import redis.asyncio as redis
from starlette.datastructures import MutableHeaders
from starlette.requests import HTTPConnection

from ..redis import Redis


if TYPE_CHECKING:
import redis.asyncio as redis
from starlette.types import ASGIApp, Message, Receive, Scope, Send

from ..redis import Redis
Expand All @@ -43,10 +41,10 @@


class Storage:
__slots__ = ("pool", "_keys")
__slots__ = ("redis", "_keys")

def __init__(self, *, redis: Redis | None = None) -> None:
self.pool: redis.Redis | None = redis.pool if redis else None
self.redis: Redis | None = redis
self._keys: dict[str, Any] = {}

async def get(self, data: dict[str, Any]) -> dict[str, Any]:
Expand All @@ -57,23 +55,23 @@ async def get(self, data: dict[str, Any]) -> dict[str, Any]:
await self.delete(key)
return {}

if self.pool:
session: Any = await self.pool.get(key) # type: ignore
if self.redis and self.redis.could_connect:
session: Any = await self.redis.pool.get(key) # type: ignore
else:
session: Any = self._keys.get(key)

return json.loads(session) if session else {}

async def set(self, key: str, value: dict[str, Any], *, max_age: int) -> None:
if self.pool:
await self.pool.set(key, json.dumps(value), ex=max_age) # type: ignore
if self.redis and self.redis.could_connect:
await self.redis.pool.set(key, json.dumps(value), ex=max_age) # type: ignore
return

self._keys[key] = json.dumps(value)

async def delete(self, key: str) -> None:
if self.pool:
await self.pool.delete(key) # type: ignore
if self.redis and self.redis.could_connect:
await self.redis.pool.delete(key) # type: ignore
else:
self._keys.pop(key, None)

Expand Down
37 changes: 36 additions & 1 deletion starlette_plus/redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,50 @@
limitations under the License.
"""

import asyncio
import logging

import redis.asyncio as redis


logger: logging.Logger = logging.getLogger(__name__)


class Redis:
def __init__(self, *, url: str | None = None) -> None:
url = url or "redis://localhost:6379/0"
pool = redis.ConnectionPool.from_url(url, decode_responses=True) # type: ignore

self.pool: redis.Redis = redis.Redis.from_pool(pool)
self.url = url

self._could_connect: bool | None = None
self._task = asyncio.create_task(self._health_task())

@property
def could_connect(self) -> bool | None:
return self._could_connect

async def ping(self) -> bool:
return bool(await self.pool.ping()) # type: ignore
try:
async with asyncio.timeout(3.0):
self._could_connect = bool(await self.pool.ping()) # type: ignore
except Exception:
if self._could_connect is not False:
logger.warning(
"Unable to connect to Redis: %s. Services relying on this instance will now be in-memory.", self.url
)

self._could_connect = False

return self._could_connect

async def _health_task(self) -> None:
while True:
previous = self.could_connect
await self.ping()

if not previous and self.could_connect:
logger.info("Redis connection has been (re)established: %s", self.url)

await asyncio.sleep(5)

0 comments on commit f21169a

Please sign in to comment.