From 41fe84a9560698214649941a7f73a97b37ba48b7 Mon Sep 17 00:00:00 2001
From: Jean-Sébastien
Date: Wed, 7 Feb 2024 09:17:26 +0100
Subject: [PATCH 1/5] feat(broker): allow setting the queue name dynamically
 when kicking

---
 taskiq_redis/redis_broker.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/taskiq_redis/redis_broker.py b/taskiq_redis/redis_broker.py
index 75ba868..0712b85 100644
--- a/taskiq_redis/redis_broker.py
+++ b/taskiq_redis/redis_broker.py
@@ -60,8 +60,9 @@ async def kick(self, message: BrokerMessage) -> None:

         :param message: message to send.
         """
+        queue_name = message.labels.get("queue_name") or self.queue_name
         async with Redis(connection_pool=self.connection_pool) as redis_conn:
-            await redis_conn.publish(self.queue_name, message.message)
+            await redis_conn.publish(queue_name, message.message)

     async def listen(self) -> AsyncGenerator[bytes, None]:
         """
@@ -95,8 +96,9 @@ async def kick(self, message: BrokerMessage) -> None:

         :param message: message to append.
         """
+        queue_name = message.labels.get("queue_name") or self.queue_name
         async with Redis(connection_pool=self.connection_pool) as redis_conn:
-            await redis_conn.lpush(self.queue_name, message.message)
+            await redis_conn.lpush(queue_name, message.message)

     async def listen(self) -> AsyncGenerator[bytes, None]:
         """
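A note on usage: the `queue_name` label above is read from `message.labels` at kick time, so a task can be routed per call without constructing a second broker. A minimal sketch, assuming taskiq's standard label mechanism (`kicker().with_labels(...)`); the URL, queue names, and task are illustrative:

```python
import asyncio

from taskiq_redis import ListQueueBroker

broker = ListQueueBroker(url="redis://localhost:6379", queue_name="default_queue")


@broker.task
async def add(a: int, b: int) -> int:
    return a + b


async def main() -> None:
    await broker.startup()
    # Without a label, the message lands on the broker's default queue.
    await add.kiq(1, 2)
    # A "queue_name" label overrides the target queue for this kick only.
    await add.kicker().with_labels(queue_name="priority_queue").kiq(1, 2)
    await broker.shutdown()


asyncio.run(main())
```

Note that `listen` still reads `self.queue_name`, so a worker consuming `priority_queue` needs its own broker instance pointed at that queue.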
""" super().__init__( result_backend=result_backend, task_id_generator=task_id_generator, ) - self.connection_pool: ConnectionPool = ConnectionPool.from_url( + self.connection_pool: ConnectionPool = BlockingConnectionPool.from_url( url=url, max_connections=max_connection_pool_size, **connection_kwargs, diff --git a/taskiq_redis/schedule_source.py b/taskiq_redis/schedule_source.py index 17ed1ee..bc53141 100644 --- a/taskiq_redis/schedule_source.py +++ b/taskiq_redis/schedule_source.py @@ -1,6 +1,6 @@ from typing import Any, List, Optional -from redis.asyncio import ConnectionPool, Redis, RedisCluster +from redis.asyncio import BlockingConnectionPool, ConnectionPool, Redis, RedisCluster from taskiq import ScheduleSource from taskiq.abc.serializer import TaskiqSerializer from taskiq.compat import model_dump, model_validate @@ -22,7 +22,7 @@ class RedisScheduleSource(ScheduleSource): This is how many keys will be fetched at once. :param max_connection_pool_size: maximum number of connections in pool. :param serializer: serializer for data. - :param connection_kwargs: additional arguments for aio-redis ConnectionPool. + :param connection_kwargs: additional arguments for redis BlockingConnectionPool. """ def __init__( @@ -35,7 +35,7 @@ def __init__( **connection_kwargs: Any, ) -> None: self.prefix = prefix - self.connection_pool: ConnectionPool = ConnectionPool.from_url( + self.connection_pool: ConnectionPool = BlockingConnectionPool.from_url( url=url, max_connections=max_connection_pool_size, **connection_kwargs, diff --git a/tests/test_broker.py b/tests/test_broker.py index 813e72e..08f5dff 100644 --- a/tests/test_broker.py +++ b/tests/test_broker.py @@ -71,6 +71,29 @@ async def test_pub_sub_broker( await broker.shutdown() +@pytest.mark.anyio +async def test_pub_sub_broker_max_connections( + valid_broker_message: BrokerMessage, + redis_url: str, +) -> None: + """Test PubSubBroker with connection limit set.""" + broker = PubSubBroker( + url=redis_url, + queue_name=uuid.uuid4().hex, + max_connection_pool_size=4, + timeout=1, + ) + worker_tasks = [asyncio.create_task(get_message(broker)) for _ in range(3)] + await asyncio.sleep(0.3) + + await asyncio.gather(*[broker.kick(valid_broker_message) for _ in range(50)]) + await asyncio.sleep(0.3) + + for worker in worker_tasks: + worker.cancel() + await broker.shutdown() + + @pytest.mark.anyio async def test_list_queue_broker( valid_broker_message: BrokerMessage, @@ -98,6 +121,29 @@ async def test_list_queue_broker( await broker.shutdown() +@pytest.mark.anyio +async def test_list_queue_broker_max_connections( + valid_broker_message: BrokerMessage, + redis_url: str, +) -> None: + """Test ListQueueBroker with connection limit set.""" + broker = ListQueueBroker( + url=redis_url, + queue_name=uuid.uuid4().hex, + max_connection_pool_size=4, + timeout=1, + ) + worker_tasks = [asyncio.create_task(get_message(broker)) for _ in range(3)] + await asyncio.sleep(0.3) + + await asyncio.gather(*[broker.kick(valid_broker_message) for _ in range(50)]) + await asyncio.sleep(0.3) + + for worker in worker_tasks: + worker.cancel() + await broker.shutdown() + + @pytest.mark.anyio async def test_list_queue_cluster_broker( valid_broker_message: BrokerMessage, diff --git a/tests/test_schedule_source.py b/tests/test_schedule_source.py index b9c1685..ed245fd 100644 --- a/tests/test_schedule_source.py +++ b/tests/test_schedule_source.py @@ -1,3 +1,4 @@ +import asyncio import datetime as dt import uuid @@ -108,6 +109,18 @@ async def test_buffer(redis_url: str) -> None: await 
From ffed4dbfb82db400146e9271b4a3b30213a4c69e Mon Sep 17 00:00:00 2001
From: Jan Musílek
Date: Wed, 7 Feb 2024 14:31:19 +0100
Subject: [PATCH 3/5] Add connection_kwargs to result backends

---
 taskiq_redis/redis_backend.py | 21 +++++++++++++++++----
 tests/test_result_backend.py  | 33 +++++++++++++++++++++++++++++++++
 2 files changed, 50 insertions(+), 4 deletions(-)

diff --git a/taskiq_redis/redis_backend.py b/taskiq_redis/redis_backend.py
index 3a0810d..026653c 100644
--- a/taskiq_redis/redis_backend.py
+++ b/taskiq_redis/redis_backend.py
@@ -1,7 +1,7 @@
 import pickle
-from typing import Dict, Optional, TypeVar, Union
+from typing import Any, Dict, Optional, TypeVar, Union

-from redis.asyncio import ConnectionPool, Redis
+from redis.asyncio import BlockingConnectionPool, Redis
 from redis.asyncio.cluster import RedisCluster
 from taskiq import AsyncResultBackend
 from taskiq.abc.result_backend import TaskiqResult
@@ -24,6 +24,8 @@ def __init__(
         keep_results: bool = True,
         result_ex_time: Optional[int] = None,
         result_px_time: Optional[int] = None,
+        max_connection_pool_size: Optional[int] = None,
+        **connection_kwargs: Any,
     ) -> None:
         """
         Constructs a new result backend.
@@ -32,13 +34,19 @@ def __init__(
         :param keep_results: flag to not remove results from Redis after reading.
         :param result_ex_time: expire time in seconds for result.
         :param result_px_time: expire time in milliseconds for result.
+        :param max_connection_pool_size: maximum number of connections in pool.
+        :param connection_kwargs: additional arguments for redis BlockingConnectionPool.

         :raises DuplicateExpireTimeSelectedError: if both result_ex_time
             and result_px_time are selected.
         :raises ExpireTimeMustBeMoreThanZeroError:
             if result_ex_time or result_px_time is equal to zero.
         """
-        self.redis_pool = ConnectionPool.from_url(redis_url)
+        self.redis_pool = BlockingConnectionPool.from_url(
+            url=redis_url,
+            max_connections=max_connection_pool_size,
+            **connection_kwargs,
+        )
         self.keep_results = keep_results
         self.result_ex_time = result_ex_time
         self.result_px_time = result_px_time
@@ -146,6 +154,7 @@ def __init__(
         keep_results: bool = True,
         result_ex_time: Optional[int] = None,
         result_px_time: Optional[int] = None,
+        **connection_kwargs: Any,
     ) -> None:
         """
         Constructs a new result backend.
@@ -154,13 +163,17 @@ def __init__(
         :param keep_results: flag to not remove results from Redis after reading.
         :param result_ex_time: expire time in seconds for result.
         :param result_px_time: expire time in milliseconds for result.
+        :param connection_kwargs: additional arguments for RedisCluster.

         :raises DuplicateExpireTimeSelectedError: if both result_ex_time
             and result_px_time are selected.
         :raises ExpireTimeMustBeMoreThanZeroError:
             if result_ex_time or result_px_time is equal to zero.
         """
-        self.redis: RedisCluster[bytes] = RedisCluster.from_url(redis_url)
+        self.redis: RedisCluster[bytes] = RedisCluster.from_url(
+            redis_url,
+            **connection_kwargs,
+        )
         self.keep_results = keep_results
         self.result_ex_time = result_ex_time
         self.result_px_time = result_px_time
diff --git a/tests/test_result_backend.py b/tests/test_result_backend.py
index 15ecdd0..d85b28b 100644
--- a/tests/test_result_backend.py
+++ b/tests/test_result_backend.py
@@ -1,3 +1,4 @@
+import asyncio
 import uuid

 import pytest
@@ -132,6 +133,38 @@ async def test_keep_results_after_reading(redis_url: str) -> None:
     await result_backend.shutdown()


+@pytest.mark.anyio
+async def test_set_result_max_connections(redis_url: str) -> None:
+    """
+    Tests that asynchronous backend works with connection limit.
+
+    :param redis_url: redis URL.
+    """
+    result_backend = RedisAsyncResultBackend(  # type: ignore
+        redis_url=redis_url,
+        max_connection_pool_size=1,
+        timeout=3,
+    )
+
+    task_id = uuid.uuid4().hex
+    result: "TaskiqResult[int]" = TaskiqResult(
+        is_err=True,
+        log="My Log",
+        return_value=11,
+        execution_time=112.2,
+    )
+    await result_backend.set_result(
+        task_id=task_id,
+        result=result,
+    )
+
+    async def get_result() -> None:
+        await result_backend.get_result(task_id=task_id, with_logs=True)
+
+    await asyncio.gather(*[get_result() for _ in range(10)])
+    await result_backend.shutdown()
+
+
 @pytest.mark.anyio
 async def test_set_result_success_cluster(redis_cluster_url: str) -> None:
     """
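Since the result backend now builds a `BlockingConnectionPool` too, it accepts the same pool options as the brokers. A wiring sketch (expire time and pool size are illustrative; `with_result_backend` is taskiq's helper for attaching a backend to a broker):

```python
from taskiq_redis import ListQueueBroker, RedisAsyncResultBackend

result_backend: RedisAsyncResultBackend[int] = RedisAsyncResultBackend(
    redis_url="redis://localhost:6379",
    result_ex_time=600,  # results expire after ten minutes
    max_connection_pool_size=10,
    timeout=5,  # forwarded to BlockingConnectionPool via **connection_kwargs
)

broker = ListQueueBroker(url="redis://localhost:6379").with_result_backend(
    result_backend,
)
```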
""" - self.redis: RedisCluster[bytes] = RedisCluster.from_url(redis_url) + self.redis: RedisCluster[bytes] = RedisCluster.from_url( + redis_url, + **connection_kwargs, + ) self.keep_results = keep_results self.result_ex_time = result_ex_time self.result_px_time = result_px_time diff --git a/tests/test_result_backend.py b/tests/test_result_backend.py index 15ecdd0..d85b28b 100644 --- a/tests/test_result_backend.py +++ b/tests/test_result_backend.py @@ -1,3 +1,4 @@ +import asyncio import uuid import pytest @@ -132,6 +133,38 @@ async def test_keep_results_after_reading(redis_url: str) -> None: await result_backend.shutdown() +@pytest.mark.anyio +async def test_set_result_max_connections(redis_url: str) -> None: + """ + Tests that asynchronous backend works with connection limit. + + :param redis_url: redis URL. + """ + result_backend = RedisAsyncResultBackend( # type: ignore + redis_url=redis_url, + max_connection_pool_size=1, + timeout=3, + ) + + task_id = uuid.uuid4().hex + result: "TaskiqResult[int]" = TaskiqResult( + is_err=True, + log="My Log", + return_value=11, + execution_time=112.2, + ) + await result_backend.set_result( + task_id=task_id, + result=result, + ) + + async def get_result() -> None: + await result_backend.get_result(task_id=task_id, with_logs=True) + + await asyncio.gather(*[get_result() for _ in range(10)]) + await result_backend.shutdown() + + @pytest.mark.anyio async def test_set_result_success_cluster(redis_cluster_url: str) -> None: """ From 7cfb6f70c5edd6c4c177fbdfb880f86129c29b16 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jan=20Mus=C3=ADlek?= Date: Thu, 8 Feb 2024 09:34:49 +0100 Subject: [PATCH 4/5] Update docs --- README.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/README.md b/README.md index 45aa1ce..683b581 100644 --- a/README.md +++ b/README.md @@ -71,6 +71,9 @@ Brokers parameters: * `result_backend` - custom result backend. * `queue_name` - name of the pub/sub channel in redis. * `max_connection_pool_size` - maximum number of connections in pool. +* Any other keyword arguments are passed to `redis.asyncio.BlockingConnectionPool`. + Notably, you can use `timeout` to set custom timeout in seconds for reconnects + (or set it to `None` to try reconnects indefinitely). ## RedisAsyncResultBackend configuration @@ -79,6 +82,9 @@ RedisAsyncResultBackend parameters: * `keep_results` - flag to not remove results from Redis after reading. * `result_ex_time` - expire time in seconds (by default - not specified) * `result_px_time` - expire time in milliseconds (by default - not specified) +* Any other keyword arguments are passed to `redis.asyncio.BlockingConnectionPool`. + Notably, you can use `timeout` to set custom timeout in seconds for reconnects + (or set it to `None` to try reconnects indefinitely). > IMPORTANT: **It is highly recommended to use expire time ​​in RedisAsyncResultBackend** > If you want to add expiration, either `result_ex_time` or `result_px_time` must be set. >```python From c5bb440443b21d714e8f583dcfd59d44752e11f9 Mon Sep 17 00:00:00 2001 From: Pavel Kirilin Date: Sat, 20 Apr 2024 00:22:52 +0200 Subject: [PATCH 5/5] Version bumped to 0.5.6. Signed-off-by: Pavel Kirilin --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index b53c6ab..d5155ed 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "taskiq-redis" -version = "0.5.5" +version = "0.5.6" description = "Redis integration for taskiq" authors = ["taskiq-team "] readme = "README.md"