From 5971625952966f77639dda6f4cbbf078bc348d5d Mon Sep 17 00:00:00 2001 From: Anton Date: Fri, 19 Jul 2024 17:57:39 +0300 Subject: [PATCH 1/4] fix: sync decorators --- taskiq/receiver/receiver.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/taskiq/receiver/receiver.py b/taskiq/receiver/receiver.py index 7d5a403..853ed24 100644 --- a/taskiq/receiver/receiver.py +++ b/taskiq/receiver/receiver.py @@ -267,8 +267,16 @@ async def run_task( # noqa: C901, PLR0912, PLR0915 if timeout is not None: if not is_coroutine: logger.warning("Timeouts for sync tasks don't work in python well.") - target_future = asyncio.wait_for(target_future, float(timeout)) - returned = await target_future + + with anyio.fail_after(float(timeout)): + while inspect.isawaitable(target_future): + target_future = await target_future + + else: + while inspect.isawaitable(target_future): + target_future = await target_future + + returned = target_future except NoResultError as no_res_exc: found_exception = no_res_exc logger.warning( From c3e051bd4e03e0e0680fa468f671317a8c39152a Mon Sep 17 00:00:00 2001 From: Anton Date: Fri, 19 Jul 2024 18:23:25 +0300 Subject: [PATCH 2/4] tests: new test case --- tests/receiver/test_receiver.py | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/tests/receiver/test_receiver.py b/tests/receiver/test_receiver.py index 6b79e32..c3c4af2 100644 --- a/tests/receiver/test_receiver.py +++ b/tests/receiver/test_receiver.py @@ -2,6 +2,7 @@ import random import time from concurrent.futures import ThreadPoolExecutor +from functools import wraps from typing import Any, ClassVar, List, Optional import pytest @@ -472,3 +473,31 @@ async def task_no_result() -> str: assert resp.return_value is None assert not broker._running_tasks assert isinstance(resp.error, ValueError) + + +@pytest.mark.anyio +async def test_sync_decorator_on_async_function() -> None: + broker = InMemoryBroker() + wrapper_call = False + + def wrapper(f: Any) -> Any: + @wraps(f) + def wrapper_impl(*args: Any, **kwargs: Any) -> Any: + nonlocal wrapper_call + + wrapper_call = True + return f(*args, **kwargs) + + return wrapper_impl + + @broker.task + @wrapper + async def task_no_result() -> str: + return "some value" + + task = await task_no_result.kiq() + resp = await task.wait_result(timeout=1) + + assert resp.return_value == "some value" + assert not broker._running_tasks + assert wrapper_call is True From 30d527bc44963d6c8fcfd2e984c5764f896b7669 Mon Sep 17 00:00:00 2001 From: Anton Date: Fri, 19 Jul 2024 18:30:43 +0300 Subject: [PATCH 3/4] tests: more tests --- tests/receiver/test_receiver.py | 36 +++++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/tests/receiver/test_receiver.py b/tests/receiver/test_receiver.py index c3c4af2..8ed0852 100644 --- a/tests/receiver/test_receiver.py +++ b/tests/receiver/test_receiver.py @@ -501,3 +501,39 @@ async def task_no_result() -> str: assert resp.return_value == "some value" assert not broker._running_tasks assert wrapper_call is True + + +@pytest.mark.anyio +async def test_sync_decorator_on_async_function_with_timeout() -> None: + wrapper_call = False + + def wrapper(f: Any) -> Any: + @wraps(f) + def wrapper_impl(*args: Any, **kwargs: Any) -> Any: + nonlocal wrapper_call + + wrapper_call = True + return f(*args, **kwargs) + + return wrapper_impl + + @wrapper + async def test_func() -> None: + await asyncio.sleep(2) + + receiver = get_receiver() + + result = await receiver.run_task( + test_func, + TaskiqMessage( + task_id="", + task_name="", + labels={"timeout": "0.3"}, + args=[], + kwargs={}, + ), + ) + assert result.return_value is None + assert result.execution_time < 2 + assert result.is_err + assert wrapper_call is True From a646e33792660795e2b6f17847130f94ef84ff3e Mon Sep 17 00:00:00 2001 From: Anton Date: Fri, 19 Jul 2024 19:44:11 +0300 Subject: [PATCH 4/4] fix: `if` instead of `while` --- taskiq/receiver/receiver.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/taskiq/receiver/receiver.py b/taskiq/receiver/receiver.py index 853ed24..c9d5c24 100644 --- a/taskiq/receiver/receiver.py +++ b/taskiq/receiver/receiver.py @@ -269,11 +269,13 @@ async def run_task( # noqa: C901, PLR0912, PLR0915 logger.warning("Timeouts for sync tasks don't work in python well.") with anyio.fail_after(float(timeout)): - while inspect.isawaitable(target_future): + target_future = await target_future + if inspect.isawaitable(target_future): target_future = await target_future else: - while inspect.isawaitable(target_future): + target_future = await target_future + if inspect.isawaitable(target_future): target_future = await target_future returned = target_future