Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: non-blocking tasks #128

Open
wants to merge 3 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 19 additions & 2 deletions taskiq/context.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,38 @@
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any, Callable, Optional

from typing_extensions import TypeAlias

from taskiq.abc.broker import AsyncBroker
from taskiq.exceptions import NoResultError, TaskRejectedError
from taskiq.message import TaskiqMessage

if TYPE_CHECKING: # pragma: no cover
from contextlib import _AsyncGeneratorContextManager

from taskiq.state import TaskiqState

IdleType: TypeAlias = (
"Callable[[Optional[int]], _AsyncGeneratorContextManager[None]]"
)

else:
IdleType: TypeAlias = Any


class Context:
"""Context class."""

def __init__(self, message: TaskiqMessage, broker: AsyncBroker) -> None:
def __init__(
self,
message: TaskiqMessage,
broker: AsyncBroker,
idle: IdleType,
) -> None:
self.message = message
self.broker = broker
self.state: "TaskiqState" = None # type: ignore
self.state = broker.state
self.idle = idle

async def requeue(self) -> None:
"""
Expand Down
19 changes: 19 additions & 0 deletions taskiq/depends/task_idler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from contextlib import asynccontextmanager
from typing import AsyncIterator, Optional

from taskiq_dependencies import Depends

from taskiq.context import Context


class TaskIdler:
"""Task's dependency to idle task."""

def __init__(self, context: Context = Depends()) -> None:
self.context = context

@asynccontextmanager
async def __call__(self, timeout: Optional[int] = None) -> AsyncIterator[None]:
"""Idle task."""
async with self.context.idle(timeout):
yield
73 changes: 70 additions & 3 deletions taskiq/receiver/receiver.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,20 @@
import asyncio
import inspect
from concurrent.futures import Executor
from contextlib import asynccontextmanager
from logging import getLogger
from time import time
from typing import Any, Callable, Dict, List, Optional, Set, Union, get_type_hints
from typing import (
Any,
AsyncIterator,
Callable,
Dict,
List,
Optional,
Set,
Union,
get_type_hints,
)

import anyio
from taskiq_dependencies import DependencyGraph
Expand All @@ -21,6 +32,7 @@

logger = getLogger(__name__)
QUEUE_DONE = b"-1"
QUEUE_SKIP = b"-2"


def _run_sync(
Expand Down Expand Up @@ -83,6 +95,11 @@ def __init__(
"can result in undefined behavior",
)
self.sem_prefetch = asyncio.Semaphore(max_prefetch)
self.idle_tasks: "Set[asyncio.Task[Any]]" = set()
self.sem_lock: asyncio.Lock = asyncio.Lock()
self.listen_queue: "asyncio.Queue[Union[AckableMessage, bytes]]" = (
asyncio.Queue()
)

async def callback( # noqa: C901, PLR0912
self,
Expand Down Expand Up @@ -227,7 +244,7 @@ async def run_task( # noqa: C901, PLR0912, PLR0915
broker_ctx = self.broker.custom_dependency_context
broker_ctx.update(
{
Context: Context(message, self.broker),
Context: Context(message, self.broker, self.idle),
TaskiqState: self.broker.state,
},
)
Expand Down Expand Up @@ -329,6 +346,7 @@ async def listen(self) -> None: # pragma: no cover
await self.broker.startup()
logger.info("Listening started.")
queue: "asyncio.Queue[Union[bytes, AckableMessage]]" = asyncio.Queue()
self.listen_queue = queue

async with anyio.create_task_group() as gr:
gr.start_soon(self.prefetcher, queue)
Expand Down Expand Up @@ -396,7 +414,8 @@ def task_cb(task: "asyncio.Task[Any]") -> None:
while True:
# Waits for semaphore to be released.
if self.sem is not None:
await self.sem.acquire()
async with self.sem_lock:
await self.sem.acquire()

self.sem_prefetch.release()
message = await queue.get()
Expand All @@ -407,6 +426,11 @@ def task_cb(task: "asyncio.Task[Any]") -> None:
await asyncio.wait(tasks, timeout=self.wait_tasks_timeout)
break

if message is QUEUE_SKIP:
if self.sem is not None:
self.sem.release()
continue

task = asyncio.create_task(
self.callback(message=message, raise_err=False),
)
Expand All @@ -420,6 +444,49 @@ def task_cb(task: "asyncio.Task[Any]") -> None:
# https://textual.textualize.io/blog/2023/02/11/the-heisenbug-lurking-in-your-async-code/
task.add_done_callback(task_cb)

@asynccontextmanager
async def idle(self, timeout: Optional[int] = None) -> AsyncIterator[None]:
"""Idle task.

:param timeout: idle time
"""
if self.sem is not None:
self.sem.release()

def acquire() -> "asyncio.Task[Any]":
if self.sem is None:
raise ValueError(self.sem)

task = asyncio.create_task(self.sem.acquire())
task.add_done_callback(self.idle_tasks.discard)
self.idle_tasks.add(task)
return task

cancelled = False
try:
with anyio.fail_after(timeout):
yield
except asyncio.CancelledError:
if self.sem:
acquire()

cancelled = True
raise

finally:
if not cancelled and self.sem is not None:
try:
await self.sem_lock.acquire()
except asyncio.CancelledError:
acquire()
raise

try:
self.listen_queue.put_nowait(QUEUE_SKIP)
await acquire()
finally:
self.sem_lock.release()

def _prepare_task(self, name: str, handler: Callable[..., Any]) -> None:
"""
Prepare task for execution.
Expand Down
76 changes: 76 additions & 0 deletions tests/depends/test_task_idler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
import asyncio
from asyncio.exceptions import CancelledError

import anyio
import pytest
from taskiq_dependencies import Depends

from taskiq.api.receiver import run_receiver_task
from taskiq.brokers.inmemory_broker import InmemoryResultBackend
from taskiq.depends.task_idler import TaskIdler
from tests.utils import AsyncQueueBroker


@pytest.mark.anyio
async def test_task_idler() -> None:
broker = AsyncQueueBroker().with_result_backend(InmemoryResultBackend())
kicked = 0
desired_kicked = 20

@broker.task(timeout=1)
async def test_func(idle: TaskIdler = Depends()) -> None:
nonlocal kicked
async with idle():
await asyncio.sleep(0.5)
kicked += 1

receiver_task = asyncio.create_task(run_receiver_task(broker, max_async_tasks=1))

tasks = []
for _ in range(desired_kicked):
tasks.append(await test_func.kiq())

with anyio.fail_after(1):
for task in tasks:
await task.wait_result(check_interval=0.01)

receiver_task.cancel()
assert kicked == desired_kicked


@pytest.mark.anyio
async def test_task_idler_task_cancelled() -> None:
broker = AsyncQueueBroker().with_result_backend(InmemoryResultBackend())
kicked = 0
desired_kicked = 20

@broker.task(timeout=0.2)
async def test_func_timeout(idle: TaskIdler = Depends()) -> None:
nonlocal kicked
try:
async with idle():
await asyncio.sleep(2)
except CancelledError:
kicked += 1
raise

@broker.task(timeout=2)
async def test_func(idle: TaskIdler = Depends()) -> None:
nonlocal kicked
async with idle():
await asyncio.sleep(0.5)
kicked += 1

receiver_task = asyncio.create_task(run_receiver_task(broker, max_async_tasks=1))

tasks = []
tasks.append(await test_func_timeout.kiq())
for _ in range(desired_kicked):
tasks.append(await test_func.kiq())

with anyio.fail_after(1):
for task in tasks:
await task.wait_result(check_interval=0.01)

receiver_task.cancel()
assert kicked == desired_kicked + 1
Loading