diff --git a/ecowitt2mqtt/runtime.py b/ecowitt2mqtt/runtime.py index bbe785f5..22ec2414 100644 --- a/ecowitt2mqtt/runtime.py +++ b/ecowitt2mqtt/runtime.py @@ -3,7 +3,8 @@ import asyncio import traceback -from contextlib import suppress +from collections.abc import AsyncGenerator +from contextlib import asynccontextmanager, suppress from ssl import SSLContext from typing import TYPE_CHECKING, Any @@ -44,7 +45,21 @@ def __init__(self, ecowitt: Ecowitt) -> None: self._rest_api_server_task: asyncio.Task | None = None self.ecowitt = ecowitt - fastapi = FastAPI() + @asynccontextmanager + async def lifespan(_: FastAPI) -> AsyncGenerator[None, None]: + """Define a lifespan context manager.""" + yield + + # Upon shutdown: + for task in self._mqtt_loop_tasks: + if task.done(): + continue + with suppress(asyncio.CancelledError): + LOGGER.debug("Cancelling MQTT loop task: %s", task.get_name()) + task.cancel() + LOGGER.debug("Runtime shutdown complete") + + fastapi = FastAPI(lifespan=lifespan) for config in ecowitt.configs.iterate(): if config.endpoint not in self._api_servers: api_server = self._api_servers[config.endpoint] = get_api_server( @@ -123,13 +138,6 @@ async def create_loop() -> None: payload_event.clear() retry_attempt = 0 except MqttError as err: - if self._uvicorn.should_exit: - # If we've instructed the server to shutdown while we're - # disconnected from the MQTT broker, we'll land here and be - # delayed from full shutdown while we attempt to reconnect. - # In this case, we raise a CancelledError to allow the task - # to exit: - raise asyncio.CancelledError LOGGER.error("There was an MQTT error: %s", err) payload_event.clear() retry_attempt += 1 @@ -140,15 +148,14 @@ async def create_loop() -> None: retry_attempt, ) await asyncio.sleep(delay) - except asyncio.CancelledError: - LOGGER.debug("Stopping MQTT process loop") - raise except Exception as err: # pylint: disable=broad-except LOGGER.exception("Exception caused a shutdown: %s", err) LOGGER.debug("".join(traceback.format_tb(err.__traceback__))) self.stop() - return asyncio.create_task(create_loop()) + task = asyncio.create_task(create_loop()) + task.set_name(f"mqtt_loop_{config.uuid}") + return task def _process_payload(self, payload: dict[str, Any]) -> None: """Define an endpoint for the Ecowitt device to post data to. @@ -176,15 +183,7 @@ async def async_start(self) -> None: """Start the runtime.""" LOGGER.debug("Starting runtime") self._rest_api_server_task = asyncio.create_task(self._uvicorn.serve()) - try: - await self._rest_api_server_task - except asyncio.CancelledError: - for task in self._mqtt_loop_tasks: - if task.done(): - continue - with suppress(asyncio.CancelledError): - task.cancel() - LOGGER.debug("Runtime shutdown complete") + await self._rest_api_server_task def stop(self) -> None: """Stop the REST API server.""" diff --git a/tests/test_runtime.py b/tests/test_runtime.py index 541a47a7..c3b1235e 100644 --- a/tests/test_runtime.py +++ b/tests/test_runtime.py @@ -57,38 +57,6 @@ async def test_publish_failure( assert any(m for m in caplog.messages if "There was an MQTT error" in m) -@pytest.mark.asyncio -@pytest.mark.parametrize("mqtt_publish_side_effect", [AsyncMock(side_effect=MqttError)]) -async def test_publish_failure_during_shutdown( - caplog: Mock, - device_data: dict[str, Any], - ecowitt: Ecowitt, - setup_aiomqtt: AsyncGenerator[None, None], - setup_uvicorn_server: AsyncGenerator[None, None], -) -> None: - """Test a failed MQTT publish during a runtime shutdown. - - This is a sanity check to ensure that the runtime doesn't attempt to reconnect to - the MQTT broker when it's already in the process of shutting down. - - Args: - caplog: A mock logging utility. - device_data: A dictionary of device data. - ecowitt: A parsed Ecowitt object. - setup_aiomqtt: A mock aiomqtt client connection. - setup_uvicorn_server: A mock Uvicorn + FastAPI application. - """ - async with ClientSession() as session: - # Simulate a server shutdown request: - ecowitt.runtime._uvicorn.should_exit = True # pylint: disable=protected-access - await session.request( - "post", f"http://127.0.0.1:{TEST_PORT}{TEST_ENDPOINT}", data=device_data - ) - - await asyncio.sleep(0.1) - assert not any(m for m in caplog.messages if "There was an MQTT error" in m) - - @pytest.mark.asyncio @pytest.mark.parametrize( "config",