From 8dc684a760f657c2f776ed2d8b0059bc26ff1317 Mon Sep 17 00:00:00 2001 From: Massimiliano Pippi Date: Fri, 29 Nov 2024 16:32:59 +0100 Subject: [PATCH] fix: make SimpleMessageQueue server cancellable --- .../message_queue_simple/__init__.py | 0 .../test_message_queue.py | 19 ++++++++++++++++++ llama_deploy/message_queues/simple.py | 20 +++++++++---------- 3 files changed, 28 insertions(+), 11 deletions(-) create mode 100644 e2e_tests/message_queues/message_queue_simple/__init__.py create mode 100644 e2e_tests/message_queues/message_queue_simple/test_message_queue.py diff --git a/e2e_tests/message_queues/message_queue_simple/__init__.py b/e2e_tests/message_queues/message_queue_simple/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/e2e_tests/message_queues/message_queue_simple/test_message_queue.py b/e2e_tests/message_queues/message_queue_simple/test_message_queue.py new file mode 100644 index 00000000..3d6dc0ff --- /dev/null +++ b/e2e_tests/message_queues/message_queue_simple/test_message_queue.py @@ -0,0 +1,19 @@ +import asyncio + +import pytest + +from llama_deploy import SimpleMessageQueue + + +@pytest.mark.e2e +@pytest.mark.asyncio +async def test_cancel_launch_server(): + mq = SimpleMessageQueue() + t = asyncio.create_task(mq.launch_server()) + + # Make sure the queue starts + await asyncio.sleep(1) + + # Cancel + t.cancel() + await t diff --git a/llama_deploy/message_queues/simple.py b/llama_deploy/message_queues/simple.py index dbb65277..b345c6c2 100644 --- a/llama_deploy/message_queues/simple.py +++ b/llama_deploy/message_queues/simple.py @@ -3,9 +3,8 @@ import asyncio import random from collections import deque -from contextlib import asynccontextmanager from logging import getLogger -from typing import Any, AsyncGenerator, Dict, List, Literal, Optional, Sequence +from typing import Any, Dict, List, Literal, Optional, Sequence from urllib.parse import urljoin import httpx @@ -197,7 +196,7 @@ def __init__( internal_port=internal_port, ) - self._app = FastAPI(lifespan=self.lifespan) + self._app = FastAPI() self._app.add_api_route( "/", self.home, methods=["GET"], tags=["Message Queue State"] @@ -392,13 +391,6 @@ async def processing_loop(self) -> None: ) # TODO dedicated ack await asyncio.sleep(0.1) - @asynccontextmanager - async def lifespan(self, app: FastAPI) -> AsyncGenerator[None, None]: - """Starts the processing loop when the fastapi app starts.""" - asyncio.create_task(self.processing_loop()) - yield - self.running = False - async def launch_local(self) -> asyncio.Task: """Launch the message queue locally, in-process.""" logger.info("Launching message queue locally") @@ -417,7 +409,13 @@ def install_signal_handlers(self) -> None: cfg = uvicorn.Config(self._app, host=host, port=port) server = CustomServer(cfg) - await server.serve() + pl_task = asyncio.create_task(self.processing_loop()) + + try: + await server.serve() + except asyncio.CancelledError: + self.running = False + await asyncio.gather(server.shutdown(), pl_task, return_exceptions=True) async def home(self) -> Dict[str, str]: return {