Skip to content

Commit

Permalink
Support eager mode in async celery
Browse files Browse the repository at this point in the history
SDESK-7371
  • Loading branch information
eos87 committed Nov 1, 2024
1 parent beb6af3 commit 700ffb5
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 11 deletions.
39 changes: 34 additions & 5 deletions superdesk/celery_app/context_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,7 @@ def run_async(self, *args: Any, **kwargs: Any) -> Any:
If the event loop is running, returns an asyncio.Task that represents the execution of the coroutine.
Otherwise it runs the tasks and returns the result of the task.
"""

loop = asyncio.get_event_loop()
is_always_eager = self._is_always_eager()

# We need a wrapper to handle exceptions inside the async function because asyncio
# does not propagate them in the same way as synchronous exceptions. This ensures that
Expand All @@ -67,10 +66,36 @@ async def wrapper():
self.handle_exception(e)
return None

if not loop.is_running():
return loop.run_until_complete(wrapper())
if is_always_eager:
return asyncio.create_task(wrapper())
else:
background_tasks = set()
loop = asyncio.get_event_loop()

# the loop might not be running even if `CELERY_TASK_ALWAYS_EAGER` is False
if not loop.is_running():
return loop.run_until_complete(wrapper())

# **Important** from asyncio documentation
# https://docs.python.org/3/library/asyncio-task.html#asyncio.create_task
# Save a reference to the result of this function, to avoid a task disappearing mid-execution.
# The event loop only keeps weak references to tasks. A task that isn’t referenced elsewhere may get
# garbage collected at any time, even before it’s done
task = asyncio.create_task(wrapper())
background_tasks.add(task)
task.add_done_callback(background_tasks.discard)
return task

async def apply_async(self, args: Tuple = (), kwargs: Dict = {}, **other_kwargs) -> Any:
"""
Schedules the task asynchronously. Awaits the result if `CELERY_TASK_ALWAYS_EAGER` is True.
"""
# directly run and await the task if eager
if self._is_always_eager():
async_result = super().apply_async(args=args, kwargs=kwargs, **other_kwargs)
return await async_result.get()

return asyncio.create_task(wrapper())
return super().apply_async(args=args, kwargs=kwargs, **other_kwargs)

def handle_exception(self, exc: Exception) -> None:
"""
Expand All @@ -85,3 +110,7 @@ def on_failure(self, exc: Exception, task_id: str, args: Tuple, kwargs: Dict, ei
# TODO-ASYNC: Support async with ``on_failure`` method
# async with self.get_current_app().app_context():
self.handle_exception(exc)

def _is_always_eager(self):
app = self.get_current_app()
return app.config.get("CELERY_TASK_ALWAYS_EAGER", False)
14 changes: 8 additions & 6 deletions tests/celery_app/context_task_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,19 @@

from superdesk.errors import SuperdeskError
from superdesk.celery_app import HybridAppContextTask
from superdesk.tests import AsyncFlaskTestCase, markers
from superdesk.tests import AsyncFlaskTestCase

# NOTE: all tasks below are in eager mode because of global
# tests settings. See `update_config` function in tests.__init__.py


@markers.requires_async_celery
class TestHybridAppContextTask(AsyncFlaskTestCase):
async def test_sync_task(self):
@self.app.celery.task(base=HybridAppContextTask)
def sync_task():
return "sync result"

result = sync_task.apply_async().get()
result = await sync_task.apply_async()
self.assertEqual(result, "sync result")

async def test_async_task(self):
Expand All @@ -22,7 +24,7 @@ async def async_task():
await asyncio.sleep(0.1)
return "async result"

result = await async_task.apply_async().get()
result = await async_task.apply_async()
self.assertEqual(result, "async result")

async def test_sync_task_exception(self):
Expand All @@ -31,7 +33,7 @@ def sync_task_exception():
raise SuperdeskError("Test exception")

with patch("superdesk.celery_app.context_task.logger") as mock_logger:
sync_task_exception.apply_async().get(propagate=True)
await sync_task_exception.apply_async()
expected_exc = SuperdeskError("Test exception")
expected_msg = f"Error handling task: {str(expected_exc)}"
mock_logger.exception.assert_called_once_with(expected_msg)
Expand All @@ -42,7 +44,7 @@ async def async_task_exception():
raise SuperdeskError("Async exception")

with patch("superdesk.celery_app.context_task.logger") as mock_logger:
await async_task_exception.apply_async().get()
await async_task_exception.apply_async()

expected_exc = SuperdeskError("Async exception")
expected_msg = f"Error handling task: {str(expected_exc)}"
Expand Down

0 comments on commit 700ffb5

Please sign in to comment.