From 25ffb83f9f5b203ba5c9b9f97faca85521e7f6b4 Mon Sep 17 00:00:00 2001 From: Anton Date: Thu, 25 Jul 2024 09:53:37 +0300 Subject: [PATCH 1/3] feat: changed approach to dealing with idle tasks --- taskiq/context.py | 11 ++++- taskiq/depends/task_idler.py | 19 ++++++++ tests/depends/test_task_idler.py | 76 ++++++++++++++++++++++++++++++++ 3 files changed, 104 insertions(+), 2 deletions(-) create mode 100644 taskiq/depends/task_idler.py create mode 100644 tests/depends/test_task_idler.py diff --git a/taskiq/context.py b/taskiq/context.py index 098b9d6b..3ebbbbf1 100644 --- a/taskiq/context.py +++ b/taskiq/context.py @@ -1,4 +1,5 @@ -from typing import TYPE_CHECKING +from contextlib import _AsyncGeneratorContextManager +from typing import TYPE_CHECKING, Callable, Optional from taskiq.abc.broker import AsyncBroker from taskiq.exceptions import NoResultError, TaskRejectedError @@ -11,11 +12,17 @@ class Context: """Context class.""" - def __init__(self, message: TaskiqMessage, broker: AsyncBroker) -> None: + def __init__( + self, + message: TaskiqMessage, + broker: AsyncBroker, + idle: "Callable[[Optional[int]], _AsyncGeneratorContextManager[None]]", + ) -> None: self.message = message self.broker = broker self.state: "TaskiqState" = None # type: ignore self.state = broker.state + self.idle = idle async def requeue(self) -> None: """ diff --git a/taskiq/depends/task_idler.py b/taskiq/depends/task_idler.py new file mode 100644 index 00000000..60d48517 --- /dev/null +++ b/taskiq/depends/task_idler.py @@ -0,0 +1,19 @@ +from contextlib import asynccontextmanager +from typing import AsyncIterator, Optional + +from taskiq_dependencies import Depends + +from taskiq.context import Context + + +class TaskIdler: + """Task's dependency to idle task.""" + + def __init__(self, context: Context = Depends()) -> None: + self.context = context + + @asynccontextmanager + async def __call__(self, timeout: Optional[int] = None) -> AsyncIterator[None]: + """Idle task.""" + async with self.context.idle(timeout): + yield diff --git a/tests/depends/test_task_idler.py b/tests/depends/test_task_idler.py new file mode 100644 index 00000000..672149ba --- /dev/null +++ b/tests/depends/test_task_idler.py @@ -0,0 +1,76 @@ +import asyncio +from asyncio.exceptions import CancelledError + +import anyio +import pytest +from taskiq_dependencies import Depends + +from taskiq.api.receiver import run_receiver_task +from taskiq.brokers.inmemory_broker import InmemoryResultBackend +from taskiq.depends.task_idler import TaskIdler +from tests.utils import AsyncQueueBroker + + +@pytest.mark.anyio +async def test_task_idler() -> None: + broker = AsyncQueueBroker().with_result_backend(InmemoryResultBackend()) + kicked = 0 + desired_kicked = 20 + + @broker.task(timeout=1) + async def test_func(idle: TaskIdler = Depends()) -> None: + nonlocal kicked + async with idle(): + await asyncio.sleep(0.5) + kicked += 1 + + receiver_task = asyncio.create_task(run_receiver_task(broker, max_async_tasks=1)) + + tasks = [] + for _ in range(desired_kicked): + tasks.append(await test_func.kiq()) + + with anyio.fail_after(1): + for task in tasks: + await task.wait_result(check_interval=0.01) + + receiver_task.cancel() + assert kicked == desired_kicked + + +@pytest.mark.anyio +async def test_task_idler_task_cancelled() -> None: + broker = AsyncQueueBroker().with_result_backend(InmemoryResultBackend()) + kicked = 0 + desired_kicked = 20 + + @broker.task(timeout=0.2) + async def test_func_timeout(idle: TaskIdler = Depends()) -> None: + nonlocal kicked + try: + async with idle(): + await asyncio.sleep(2) + except CancelledError: + kicked += 1 + raise + + @broker.task(timeout=2) + async def test_func(idle: TaskIdler = Depends()) -> None: + nonlocal kicked + async with idle(): + await asyncio.sleep(0.5) + kicked += 1 + + receiver_task = asyncio.create_task(run_receiver_task(broker, max_async_tasks=1)) + + tasks = [] + tasks.append(await test_func_timeout.kiq()) + for _ in range(desired_kicked): + tasks.append(await test_func.kiq()) + + with anyio.fail_after(1): + for task in tasks: + await task.wait_result(check_interval=0.01) + + receiver_task.cancel() + assert kicked == desired_kicked + 1 From 72b76e29912ba31af2bb43cca7c8ba00a334f474 Mon Sep 17 00:00:00 2001 From: Anton Date: Thu, 25 Jul 2024 09:54:35 +0300 Subject: [PATCH 2/3] fix: receiver --- taskiq/receiver/receiver.py | 73 +++++++++++++++++++++++++++++++++++-- 1 file changed, 70 insertions(+), 3 deletions(-) diff --git a/taskiq/receiver/receiver.py b/taskiq/receiver/receiver.py index 7d5a4035..58d1d8ff 100644 --- a/taskiq/receiver/receiver.py +++ b/taskiq/receiver/receiver.py @@ -1,9 +1,20 @@ import asyncio import inspect from concurrent.futures import Executor +from contextlib import asynccontextmanager from logging import getLogger from time import time -from typing import Any, Callable, Dict, List, Optional, Set, Union, get_type_hints +from typing import ( + Any, + AsyncIterator, + Callable, + Dict, + List, + Optional, + Set, + Union, + get_type_hints, +) import anyio from taskiq_dependencies import DependencyGraph @@ -21,6 +32,7 @@ logger = getLogger(__name__) QUEUE_DONE = b"-1" +QUEUE_SKIP = b"-2" def _run_sync( @@ -83,6 +95,11 @@ def __init__( "can result in undefined behavior", ) self.sem_prefetch = asyncio.Semaphore(max_prefetch) + self.idle_tasks: "Set[asyncio.Task[Any]]" = set() + self.sem_lock: asyncio.Lock = asyncio.Lock() + self.listen_queue: "asyncio.Queue[Union[AckableMessage, bytes]]" = ( + asyncio.Queue() + ) async def callback( # noqa: C901, PLR0912 self, @@ -227,7 +244,7 @@ async def run_task( # noqa: C901, PLR0912, PLR0915 broker_ctx = self.broker.custom_dependency_context broker_ctx.update( { - Context: Context(message, self.broker), + Context: Context(message, self.broker, self.idle), TaskiqState: self.broker.state, }, ) @@ -329,6 +346,7 @@ async def listen(self) -> None: # pragma: no cover await self.broker.startup() logger.info("Listening started.") queue: "asyncio.Queue[Union[bytes, AckableMessage]]" = asyncio.Queue() + self.listen_queue = queue async with anyio.create_task_group() as gr: gr.start_soon(self.prefetcher, queue) @@ -396,7 +414,8 @@ def task_cb(task: "asyncio.Task[Any]") -> None: while True: # Waits for semaphore to be released. if self.sem is not None: - await self.sem.acquire() + async with self.sem_lock: + await self.sem.acquire() self.sem_prefetch.release() message = await queue.get() @@ -407,6 +426,11 @@ def task_cb(task: "asyncio.Task[Any]") -> None: await asyncio.wait(tasks, timeout=self.wait_tasks_timeout) break + if message is QUEUE_SKIP: + if self.sem is not None: + self.sem.release() + continue + task = asyncio.create_task( self.callback(message=message, raise_err=False), ) @@ -420,6 +444,49 @@ def task_cb(task: "asyncio.Task[Any]") -> None: # https://textual.textualize.io/blog/2023/02/11/the-heisenbug-lurking-in-your-async-code/ task.add_done_callback(task_cb) + @asynccontextmanager + async def idle(self, timeout: Optional[int] = None) -> AsyncIterator[None]: + """Idle task. + + :param timeout: idle time + """ + if self.sem is not None: + self.sem.release() + + def acquire() -> "asyncio.Task[Any]": + if self.sem is None: + raise ValueError(self.sem) + + task = asyncio.create_task(self.sem.acquire()) + task.add_done_callback(self.idle_tasks.discard) + self.idle_tasks.add(task) + return task + + cancelled = False + try: + with anyio.fail_after(timeout): + yield + except asyncio.CancelledError: + if self.sem: + acquire() + + cancelled = True + raise + + finally: + if not cancelled and self.sem is not None: + try: + await self.sem_lock.acquire() + except asyncio.CancelledError: + acquire() + raise + + try: + self.listen_queue.put_nowait(QUEUE_SKIP) + await acquire() + finally: + self.sem_lock.release() + def _prepare_task(self, name: str, handler: Callable[..., Any]) -> None: """ Prepare task for execution. From 7e31b24cc6d2ee0c735f9c71e8cf32220e091177 Mon Sep 17 00:00:00 2001 From: Anton Date: Thu, 25 Jul 2024 10:12:33 +0300 Subject: [PATCH 3/3] fix: issue with python 3.8 --- taskiq/context.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/taskiq/context.py b/taskiq/context.py index 3ebbbbf1..9eaba7ac 100644 --- a/taskiq/context.py +++ b/taskiq/context.py @@ -1,13 +1,23 @@ -from contextlib import _AsyncGeneratorContextManager -from typing import TYPE_CHECKING, Callable, Optional +from typing import TYPE_CHECKING, Any, Callable, Optional + +from typing_extensions import TypeAlias from taskiq.abc.broker import AsyncBroker from taskiq.exceptions import NoResultError, TaskRejectedError from taskiq.message import TaskiqMessage if TYPE_CHECKING: # pragma: no cover + from contextlib import _AsyncGeneratorContextManager + from taskiq.state import TaskiqState + IdleType: TypeAlias = ( + "Callable[[Optional[int]], _AsyncGeneratorContextManager[None]]" + ) + +else: + IdleType: TypeAlias = Any + class Context: """Context class.""" @@ -16,7 +26,7 @@ def __init__( self, message: TaskiqMessage, broker: AsyncBroker, - idle: "Callable[[Optional[int]], _AsyncGeneratorContextManager[None]]", + idle: IdleType, ) -> None: self.message = message self.broker = broker