diff --git a/faststream/app.py b/faststream/app.py index ebf71bfb1f..65449e247d 100644 --- a/faststream/app.py +++ b/faststream/app.py @@ -75,6 +75,10 @@ def __init__( Union["ExternalDocs", "ExternalDocsDict", "AnyDict"] ] = None, identifier: Optional[str] = None, + on_startup: Sequence[Callable[P_HookParams, T_HookReturn]] = (), + after_startup: Sequence[Callable[P_HookParams, T_HookReturn]] = (), + on_shutdown: Sequence[Callable[P_HookParams, T_HookReturn]] = (), + after_shutdown: Sequence[Callable[P_HookParams, T_HookReturn]] = (), ) -> None: context.set_global("app", self) @@ -82,10 +86,12 @@ def __init__( self.logger = logger self.context = context - self._on_startup_calling = [] - self._after_startup_calling = [] - self._on_shutdown_calling = [] - self._after_shutdown_calling = [] + self._on_startup_calling = [apply_types(to_async(x)) for x in on_startup] + self._after_startup_calling = [apply_types(to_async(x)) for x in after_startup] + self._on_shutdown_calling = [apply_types(to_async(x)) for x in on_shutdown] + self._after_shutdown_calling = [ + apply_types(to_async(x)) for x in after_shutdown + ] self.lifespan_context = ( apply_types( diff --git a/tests/cli/rabbit/test_app.py b/tests/cli/rabbit/test_app.py index 27bf29139b..4765ffac9a 100644 --- a/tests/cli/rabbit/test_app.py +++ b/tests/cli/rabbit/test_app.py @@ -36,6 +36,24 @@ def test_log(app: FastStream, app_without_logger: FastStream): app_without_logger._log(logging.INFO, "test") +@pytest.mark.asyncio +async def test_on_startup_calls(async_mock: AsyncMock, mock: Mock): + def call1(): + mock.call_start1() + assert not async_mock.call_start2.called + + async def call2(): + await async_mock.call_start2() + assert mock.call_start1.call_count == 1 + + test_app = FastStream(on_startup=[call1, call2]) + + await test_app.start() + + mock.call_start1.assert_called_once() + async_mock.call_start2.assert_called_once() + + @pytest.mark.asyncio async def test_startup_calls_lifespans(mock: Mock, app_without_broker: FastStream): def call1(): @@ -55,6 +73,24 @@ def call2(): mock.call_start2.assert_called_once() +@pytest.mark.asyncio +async def test_on_shutdown_calls(async_mock: AsyncMock, mock: Mock): + def call1(): + mock.call_stop1() + assert not async_mock.call_stop2.called + + async def call2(): + await async_mock.call_stop2() + assert mock.call_stop1.call_count == 1 + + test_app = FastStream(on_shutdown=[call1, call2]) + + await test_app.stop() + + mock.call_stop1.assert_called_once() + async_mock.call_stop2.assert_called_once() + + @pytest.mark.asyncio async def test_shutdown_calls_lifespans(mock: Mock, app_without_broker: FastStream): def call1(): @@ -74,6 +110,25 @@ def call2(): mock.call_stop2.assert_called_once() +@pytest.mark.asyncio +async def test_after_startup_calls(async_mock: AsyncMock, mock: Mock, broker): + def call1(): + mock.after_startup1() + assert not async_mock.after_startup2.called + + async def call2(): + await async_mock.after_startup2() + assert mock.after_startup1.call_count == 1 + + test_app = FastStream(broker=broker, after_startup=[call1, call2]) + + with patch.object(test_app.broker, "start", async_mock.broker_start): + await test_app.start() + + mock.after_startup1.assert_called_once() + async_mock.after_startup2.assert_called_once() + + @pytest.mark.asyncio async def test_startup_lifespan_before_broker_started(async_mock, app: FastStream): @app.on_startup @@ -95,6 +150,25 @@ async def call_after(): async_mock.before.assert_awaited_once() +@pytest.mark.asyncio +async def test_after_shutdown_calls(async_mock: AsyncMock, mock: Mock, broker): + def call1(): + mock.after_shutdown1() + assert not async_mock.after_shutdown2.called + + async def call2(): + await async_mock.after_shutdown2() + assert mock.after_shutdown1.call_count == 1 + + test_app = FastStream(broker=broker, after_shutdown=[call1, call2]) + + with patch.object(test_app.broker, "start", async_mock.broker_start): + await test_app.stop() + + mock.after_shutdown1.assert_called_once() + async_mock.after_shutdown2.assert_called_once() + + @pytest.mark.asyncio async def test_shutdown_lifespan_after_broker_stopped( mock, async_mock, app: FastStream