Skip to content

Commit

Permalink
feature: extend FastStream.__init__ (#1734)
Browse files Browse the repository at this point in the history
Add possibility set up tasks by __init__

Co-authored-by: sehat1137 <[email protected]>
  • Loading branch information
Sehat1137 and sehat1137 authored Aug 27, 2024
1 parent 3a00e37 commit 0eadfec
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 4 deletions.
14 changes: 10 additions & 4 deletions faststream/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,17 +75,23 @@ def __init__(
Union["ExternalDocs", "ExternalDocsDict", "AnyDict"]
] = None,
identifier: Optional[str] = None,
on_startup: Sequence[Callable[P_HookParams, T_HookReturn]] = (),
after_startup: Sequence[Callable[P_HookParams, T_HookReturn]] = (),
on_shutdown: Sequence[Callable[P_HookParams, T_HookReturn]] = (),
after_shutdown: Sequence[Callable[P_HookParams, T_HookReturn]] = (),
) -> None:
context.set_global("app", self)

self.broker = broker
self.logger = logger
self.context = context

self._on_startup_calling = []
self._after_startup_calling = []
self._on_shutdown_calling = []
self._after_shutdown_calling = []
self._on_startup_calling = [apply_types(to_async(x)) for x in on_startup]
self._after_startup_calling = [apply_types(to_async(x)) for x in after_startup]
self._on_shutdown_calling = [apply_types(to_async(x)) for x in on_shutdown]
self._after_shutdown_calling = [
apply_types(to_async(x)) for x in after_shutdown
]

self.lifespan_context = (
apply_types(
Expand Down
74 changes: 74 additions & 0 deletions tests/cli/rabbit/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,24 @@ def test_log(app: FastStream, app_without_logger: FastStream):
app_without_logger._log(logging.INFO, "test")


@pytest.mark.asyncio
async def test_on_startup_calls(async_mock: AsyncMock, mock: Mock):
def call1():
mock.call_start1()
assert not async_mock.call_start2.called

async def call2():
await async_mock.call_start2()
assert mock.call_start1.call_count == 1

test_app = FastStream(on_startup=[call1, call2])

await test_app.start()

mock.call_start1.assert_called_once()
async_mock.call_start2.assert_called_once()


@pytest.mark.asyncio
async def test_startup_calls_lifespans(mock: Mock, app_without_broker: FastStream):
def call1():
Expand All @@ -55,6 +73,24 @@ def call2():
mock.call_start2.assert_called_once()


@pytest.mark.asyncio
async def test_on_shutdown_calls(async_mock: AsyncMock, mock: Mock):
def call1():
mock.call_stop1()
assert not async_mock.call_stop2.called

async def call2():
await async_mock.call_stop2()
assert mock.call_stop1.call_count == 1

test_app = FastStream(on_shutdown=[call1, call2])

await test_app.stop()

mock.call_stop1.assert_called_once()
async_mock.call_stop2.assert_called_once()


@pytest.mark.asyncio
async def test_shutdown_calls_lifespans(mock: Mock, app_without_broker: FastStream):
def call1():
Expand All @@ -74,6 +110,25 @@ def call2():
mock.call_stop2.assert_called_once()


@pytest.mark.asyncio
async def test_after_startup_calls(async_mock: AsyncMock, mock: Mock, broker):
def call1():
mock.after_startup1()
assert not async_mock.after_startup2.called

async def call2():
await async_mock.after_startup2()
assert mock.after_startup1.call_count == 1

test_app = FastStream(broker=broker, after_startup=[call1, call2])

with patch.object(test_app.broker, "start", async_mock.broker_start):
await test_app.start()

mock.after_startup1.assert_called_once()
async_mock.after_startup2.assert_called_once()


@pytest.mark.asyncio
async def test_startup_lifespan_before_broker_started(async_mock, app: FastStream):
@app.on_startup
Expand All @@ -95,6 +150,25 @@ async def call_after():
async_mock.before.assert_awaited_once()


@pytest.mark.asyncio
async def test_after_shutdown_calls(async_mock: AsyncMock, mock: Mock, broker):
def call1():
mock.after_shutdown1()
assert not async_mock.after_shutdown2.called

async def call2():
await async_mock.after_shutdown2()
assert mock.after_shutdown1.call_count == 1

test_app = FastStream(broker=broker, after_shutdown=[call1, call2])

with patch.object(test_app.broker, "start", async_mock.broker_start):
await test_app.stop()

mock.after_shutdown1.assert_called_once()
async_mock.after_shutdown2.assert_called_once()


@pytest.mark.asyncio
async def test_shutdown_lifespan_after_broker_stopped(
mock, async_mock, app: FastStream
Expand Down

0 comments on commit 0eadfec

Please sign in to comment.