Skip to content

Commit

Permalink
fix: receiver
Browse files Browse the repository at this point in the history
  • Loading branch information
Anton committed Jul 25, 2024
1 parent 25ffb83 commit 72b76e2
Showing 1 changed file with 70 additions and 3 deletions.
73 changes: 70 additions & 3 deletions taskiq/receiver/receiver.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,20 @@
import asyncio
import inspect
from concurrent.futures import Executor
from contextlib import asynccontextmanager
from logging import getLogger
from time import time
from typing import Any, Callable, Dict, List, Optional, Set, Union, get_type_hints
from typing import (
Any,
AsyncIterator,
Callable,
Dict,
List,
Optional,
Set,
Union,
get_type_hints,
)

import anyio
from taskiq_dependencies import DependencyGraph
Expand All @@ -21,6 +32,7 @@

logger = getLogger(__name__)
QUEUE_DONE = b"-1"
QUEUE_SKIP = b"-2"


def _run_sync(
Expand Down Expand Up @@ -83,6 +95,11 @@ def __init__(
"can result in undefined behavior",
)
self.sem_prefetch = asyncio.Semaphore(max_prefetch)
self.idle_tasks: "Set[asyncio.Task[Any]]" = set()
self.sem_lock: asyncio.Lock = asyncio.Lock()
self.listen_queue: "asyncio.Queue[Union[AckableMessage, bytes]]" = (
asyncio.Queue()
)

async def callback( # noqa: C901, PLR0912
self,
Expand Down Expand Up @@ -227,7 +244,7 @@ async def run_task( # noqa: C901, PLR0912, PLR0915
broker_ctx = self.broker.custom_dependency_context
broker_ctx.update(
{
Context: Context(message, self.broker),
Context: Context(message, self.broker, self.idle),
TaskiqState: self.broker.state,
},
)
Expand Down Expand Up @@ -329,6 +346,7 @@ async def listen(self) -> None: # pragma: no cover
await self.broker.startup()
logger.info("Listening started.")
queue: "asyncio.Queue[Union[bytes, AckableMessage]]" = asyncio.Queue()
self.listen_queue = queue

async with anyio.create_task_group() as gr:
gr.start_soon(self.prefetcher, queue)
Expand Down Expand Up @@ -396,7 +414,8 @@ def task_cb(task: "asyncio.Task[Any]") -> None:
while True:
# Waits for semaphore to be released.
if self.sem is not None:
await self.sem.acquire()
async with self.sem_lock:
await self.sem.acquire()

self.sem_prefetch.release()
message = await queue.get()
Expand All @@ -407,6 +426,11 @@ def task_cb(task: "asyncio.Task[Any]") -> None:
await asyncio.wait(tasks, timeout=self.wait_tasks_timeout)
break

if message is QUEUE_SKIP:
if self.sem is not None:
self.sem.release()
continue

task = asyncio.create_task(
self.callback(message=message, raise_err=False),
)
Expand All @@ -420,6 +444,49 @@ def task_cb(task: "asyncio.Task[Any]") -> None:
# https://textual.textualize.io/blog/2023/02/11/the-heisenbug-lurking-in-your-async-code/
task.add_done_callback(task_cb)

@asynccontextmanager
async def idle(self, timeout: Optional[int] = None) -> AsyncIterator[None]:
"""Idle task.
:param timeout: idle time
"""
if self.sem is not None:
self.sem.release()

def acquire() -> "asyncio.Task[Any]":
if self.sem is None:
raise ValueError(self.sem)

task = asyncio.create_task(self.sem.acquire())
task.add_done_callback(self.idle_tasks.discard)
self.idle_tasks.add(task)
return task

cancelled = False
try:
with anyio.fail_after(timeout):
yield
except asyncio.CancelledError:
if self.sem:
acquire()

cancelled = True
raise

finally:
if not cancelled and self.sem is not None:
try:
await self.sem_lock.acquire()
except asyncio.CancelledError:
acquire()
raise

try:
self.listen_queue.put_nowait(QUEUE_SKIP)
await acquire()
finally:
self.sem_lock.release()

def _prepare_task(self, name: str, handler: Callable[..., Any]) -> None:
"""
Prepare task for execution.
Expand Down

0 comments on commit 72b76e2

Please sign in to comment.