Skip to content

Commit

Permalink
fix: patch broker within testbroker context only (airtai#1619)
Browse files Browse the repository at this point in the history
* fix: patch broker within testbroker context only

* tests: mark test connection required

* tests: fix some warnings

* fix: better monkeypatching + with_real adjustment

* tests: fix real publisher test

* fix: multiple listeners for handler
test: update signature of testclient test in redis

* refactor: do not start NATS subscribers twice

* fix: python3.8 compatibility

* refactor: unify already started subsriber scipping

* chore: use uv to speedup docs building

* chore: revert uv

* chore: use uv with editable

* chore: fix CI

---------

Co-authored-by: Nikita Pastukhov <[email protected]>
Co-authored-by: Pastukhov Nikita <[email protected]>
  • Loading branch information
3 people authored Aug 4, 2024
1 parent 35aac56 commit d705076
Show file tree
Hide file tree
Showing 18 changed files with 263 additions and 32 deletions.
2 changes: 1 addition & 1 deletion .codespell-whitelist.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
dependant
unsecure
socio-economic
socio-economic
7 changes: 6 additions & 1 deletion .github/workflows/docs_update-references.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,12 @@ jobs:
cache-dependency-path: pyproject.toml
- name: Install Dependencies
if: steps.cache.outputs.cache-hit != 'true'
run: pip install -e ".[dev]"
shell: bash
# should install with `-e`
run: |
set -ux
python -m pip install uv
uv pip install --system -e ".[dev]"
- name: Run build docs
run: bash scripts/build-docs.sh
- name: Commit
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/pr_tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ jobs:
run: |
set -ux
python -m pip install uv
uv pip install --system -e ".[lint]"
uv pip install --system ".[lint]"
- name: Run ruff
shell: bash
Expand Down
4 changes: 2 additions & 2 deletions .secrets.baseline
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@
"filename": "docs/docs/en/release.md",
"hashed_secret": "35675e68f4b5af7b995d9205ad0fc43842f16450",
"is_verified": false,
"line_number": 1325,
"line_number": 1423,
"is_secret": false
}
],
Expand Down Expand Up @@ -163,5 +163,5 @@
}
]
},
"generated_at": "2024-06-10T09:56:52Z"
"generated_at": "2024-07-23T21:38:30Z"
}
2 changes: 1 addition & 1 deletion docs/docs/en/kafka/Subscriber/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -71,4 +71,4 @@ async def base_handler(
level: str = Path(),
):
...
```
```
27 changes: 27 additions & 0 deletions faststream/nats/subscriber/usecase.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,6 +426,9 @@ async def _create_subscription( # type: ignore[override]
connection: "Client",
) -> None:
"""Create NATS subscription and start consume task."""
if self.subscription:
return

self.subscription = await connection.subscribe(
subject=self.clear_subject,
queue=self.queue,
Expand Down Expand Up @@ -495,6 +498,9 @@ async def _create_subscription( # type: ignore[override]
connection: "Client",
) -> None:
"""Create NATS subscription and start consume task."""
if self.subscription:
return

self.start_consume_task()

self.subscription = await connection.subscribe(
Expand Down Expand Up @@ -576,6 +582,9 @@ async def _create_subscription( # type: ignore[override]
connection: "JetStreamContext",
) -> None:
"""Create NATS subscription and start consume task."""
if self.subscription:
return

self.subscription = await connection.subscribe(
subject=self.clear_subject,
queue=self.queue,
Expand Down Expand Up @@ -636,6 +645,9 @@ async def _create_subscription( # type: ignore[override]
connection: "JetStreamContext",
) -> None:
"""Create NATS subscription and start consume task."""
if self.subscription:
return

self.start_consume_task()

self.subscription = await connection.subscribe(
Expand Down Expand Up @@ -698,6 +710,9 @@ async def _create_subscription( # type: ignore[override]
connection: "JetStreamContext",
) -> None:
"""Create NATS subscription and start consume task."""
if self.subscription:
return

self.subscription = await connection.pull_subscribe(
subject=self.clear_subject,
config=self.config,
Expand Down Expand Up @@ -775,6 +790,9 @@ async def _create_subscription( # type: ignore[override]
connection: "JetStreamContext",
) -> None:
"""Create NATS subscription and start consume task."""
if self.subscription:
return

self.start_consume_task()

self.subscription = await connection.pull_subscribe(
Expand Down Expand Up @@ -841,6 +859,9 @@ async def _create_subscription( # type: ignore[override]
connection: "JetStreamContext",
) -> None:
"""Create NATS subscription and start consume task."""
if self.subscription:
return

self.subscription = await connection.pull_subscribe(
subject=self.clear_subject,
config=self.config,
Expand Down Expand Up @@ -905,6 +926,9 @@ async def _create_subscription( # type: ignore[override]
*,
connection: "KVBucketDeclarer",
) -> None:
if self.subscription:
return

bucket = await connection.create_key_value(
bucket=self.kv_watch.name,
declare=self.kv_watch.declare,
Expand Down Expand Up @@ -1012,6 +1036,9 @@ async def _create_subscription( # type: ignore[override]
*,
connection: "OSBucketDeclarer",
) -> None:
if self.subscription:
return

self.bucket = await connection.create_object_store(
bucket=self.subject,
declare=self.obj_watch.declare,
Expand Down
20 changes: 15 additions & 5 deletions faststream/rabbit/testing.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from typing import TYPE_CHECKING, Any, Optional, Union
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, Generator, Optional, Union
from unittest import mock
from unittest.mock import AsyncMock

import aiormq
Expand Down Expand Up @@ -34,10 +36,18 @@ class TestRabbitBroker(TestBroker[RabbitBroker]):
"""A class to test RabbitMQ brokers."""

@classmethod
def _patch_test_broker(cls, broker: RabbitBroker) -> None:
broker._channel = AsyncMock()
broker.declarer = AsyncMock()
super()._patch_test_broker(broker)
@contextmanager
def _patch_broker(cls, broker: RabbitBroker) -> Generator[None, None, None]:
with mock.patch.object(
broker,
"_channel",
new_callable=AsyncMock,
), mock.patch.object(
broker,
"declarer",
new_callable=AsyncMock,
), super()._patch_broker(broker):
yield

@staticmethod
async def _fake_connect(broker: RabbitBroker, *args: Any, **kwargs: Any) -> None:
Expand Down
17 changes: 16 additions & 1 deletion faststream/redis/subscriber/usecase.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,10 +151,15 @@ async def start( # type: ignore[override]
self,
*args: Any,
) -> None:
if self.task:
return

await super().start()

start_signal = anyio.Event()
self.task = asyncio.create_task(self._consume(*args, start_signal=start_signal))
self.task = asyncio.create_task(
self._consume(*args, start_signal=start_signal)
)

with anyio.fail_after(3.0):
await start_signal.wait()
Expand Down Expand Up @@ -253,6 +258,9 @@ def get_log_context(

@override
async def start(self) -> None:
if self.subscription:
return

assert self._client, "You should setup subscriber at first." # nosec B101

self.subscription = psub = self._client.pubsub()
Expand Down Expand Up @@ -352,6 +360,9 @@ async def _consume( # type: ignore[override]

@override
async def start(self) -> None:
if self.task:
return

assert self._client, "You should setup subscriber at first." # nosec B101
await super().start(self._client)

Expand Down Expand Up @@ -512,7 +523,11 @@ def get_log_context(

@override
async def start(self) -> None:
if self.task:
return

assert self._client, "You should setup subscriber at first." # nosec B101

client = self._client

self.extra_watcher_options.update(
Expand Down
55 changes: 38 additions & 17 deletions faststream/testing/broker.py
Original file line number Diff line number Diff line change
@@ -1,35 +1,35 @@
import warnings
from abc import abstractmethod
from contextlib import asynccontextmanager
from contextlib import asynccontextmanager, contextmanager
from functools import partial
from types import MethodType
from typing import (
TYPE_CHECKING,
Any,
AsyncGenerator,
Generator,
Generic,
Optional,
Tuple,
Type,
TypeVar,
)
from unittest.mock import AsyncMock, MagicMock
from unittest import mock
from unittest.mock import MagicMock

from faststream.broker.core.usecase import BrokerUsecase
from faststream.broker.message import StreamMessage, decode_message, encode_message
from faststream.broker.middlewares.logging import CriticalLogMiddleware
from faststream.broker.wrapper.call import HandlerCallWrapper
from faststream.testing.app import TestApp
from faststream.utils.ast import is_contains_context_name
from faststream.utils.functions import timeout_scope
from faststream.utils.functions import sync_fake_context, timeout_scope

if TYPE_CHECKING:
from types import TracebackType

from faststream.broker.subscriber.proto import SubscriberProto
from faststream.broker.types import BrokerMiddleware


Broker = TypeVar("Broker", bound=BrokerUsecase[Any, Any])


Expand Down Expand Up @@ -113,22 +113,43 @@ async def __aexit__(self, *args: Any) -> None:
async def _create_ctx(self) -> AsyncGenerator[Broker, None]:
if self.with_real:
self._fake_start(self.broker)
context = sync_fake_context()
else:
self._patch_test_broker(self.broker)
context = self._patch_broker(self.broker)

async with self.broker:
try:
if not self.connect_only:
await self.broker.start()
yield self.broker
finally:
self._fake_close(self.broker)
with context:
async with self.broker:
try:
if not self.connect_only:
await self.broker.start()
yield self.broker
finally:
self._fake_close(self.broker)

@classmethod
def _patch_test_broker(cls, broker: Broker) -> None:
broker.start = AsyncMock(wraps=partial(cls._fake_start, broker)) # type: ignore[method-assign]
broker._connect = MethodType(cls._fake_connect, broker) # type: ignore[method-assign]
broker.close = AsyncMock() # type: ignore[method-assign]
@contextmanager
def _patch_broker(cls, broker: Broker) -> Generator[None, None, None]:
with mock.patch.object(
broker,
"start",
wraps=partial(cls._fake_start, broker),
), mock.patch.object(
broker,
"_connect",
wraps=partial(cls._fake_connect, broker),
), mock.patch.object(
broker,
"close",
), mock.patch.object(
broker,
"_connection",
new=None,
), mock.patch.object(
broker,
"_producer",
new=None,
):
yield

@classmethod
def _fake_start(cls, broker: Broker, *args: Any, **kwargs: Any) -> None:
Expand Down
2 changes: 1 addition & 1 deletion tests/brokers/base/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,5 +53,5 @@ async def test_ping_timeout(self, settings):
kwargs = self.get_broker_args(settings)
broker = self.broker("wrong_url")
await broker.connect(**kwargs)
assert not await broker.ping(timeout=0.00001)
assert not await broker.ping(timeout=1e-24)
await broker.close()
Loading

0 comments on commit d705076

Please sign in to comment.