From 837c3eff33ab2c97e0a7b879ba0140efd90190ee Mon Sep 17 00:00:00 2001 From: Vladimir Kibisov Date: Sun, 1 Sep 2024 16:38:21 +0300 Subject: [PATCH] Nats batch pull get_one tests + fixes --- faststream/nats/subscriber/usecase.py | 16 +++--- tests/brokers/nats/test_consume.py | 74 +++++++++++++++++++++++++-- 2 files changed, 78 insertions(+), 12 deletions(-) diff --git a/faststream/nats/subscriber/usecase.py b/faststream/nats/subscriber/usecase.py index 15e4676ea6..c8984aecdb 100644 --- a/faststream/nats/subscriber/usecase.py +++ b/faststream/nats/subscriber/usecase.py @@ -1049,12 +1049,12 @@ async def get_one(self, *, timeout: float = 5) -> Optional[NatsMessage]: **self.extra_options, ) - raw_message ,= await self.subscription.fetch( - batch=1, - timeout=timeout, - ) - - if not raw_message: + try: + raw_messages = await self.subscription.fetch( + batch=1, + timeout=timeout, + ) + except TimeoutError: return None async with AsyncExitStack() as stack: @@ -1063,11 +1063,11 @@ async def get_one(self, *, timeout: float = 5) -> Optional[NatsMessage]: ) for m in self._broker_middlewares: - mid = m(raw_message) + mid = m(raw_messages) await stack.enter_async_context(mid) return_msg = partial(mid.consume_scope, return_msg) - parsed_msg = await self._parser(raw_message) + parsed_msg = await self._parser(raw_messages) parsed_msg._decoded_body = await self._decoder(parsed_msg) return await return_msg(parsed_msg) diff --git a/tests/brokers/nats/test_consume.py b/tests/brokers/nats/test_consume.py index 525a30ef88..5acf1aff78 100644 --- a/tests/brokers/nats/test_consume.py +++ b/tests/brokers/nats/test_consume.py @@ -477,7 +477,7 @@ async def test_get_one_timeout_js( message = object() async def coro(): nonlocal message - message = await subscriber.get_one(timeout=1) + message = await subscriber.get_one(timeout=0.5) await asyncio.wait( ( @@ -488,7 +488,7 @@ async def coro(): assert message is None - async def test_get_one_pool( + async def test_get_one_pull( self, queue: str, event: asyncio.Event, @@ -524,7 +524,7 @@ async def publish(): assert message is not None assert await message.decode() == "test_message" - async def test_get_one_pool_timeout( + async def test_get_one_pull_timeout( self, queue: str, event: asyncio.Event, @@ -543,7 +543,73 @@ async def test_get_one_pool_timeout( message = object async def consume(): nonlocal message - message = await subscriber.get_one(timeout=1) + message = await subscriber.get_one(timeout=0.5) + + await asyncio.wait( + ( + asyncio.create_task(consume()), + ), + timeout=3 + ) + + assert message is None + + async def test_get_one_batch( + self, + queue: str, + event: asyncio.Event, + stream: JStream, + ): + broker = self.get_broker(apply_types=True) + subscriber = broker.subscriber( + queue, + stream=stream, + pull_sub=PullSub(1, batch=True), + ) + + async with self.patch_broker(broker) as br: + await br.start() + + message = None + async def consume(): + nonlocal message + message = await subscriber.get_one(timeout=5) + + async def publish(): + await asyncio.sleep(0.5) + await br.publish("test_message", queue) + + await asyncio.wait( + ( + asyncio.create_task(consume()), + asyncio.create_task(publish()), + ), + timeout=10 + ) + + assert message is not None + assert await message.decode() == ["test_message"] + + async def test_get_one_batch_timeout( + self, + queue: str, + event: asyncio.Event, + stream: JStream, + ): + broker = self.get_broker(apply_types=True) + subscriber = broker.subscriber( + queue, + stream=stream, + pull_sub=PullSub(1, batch=True), + ) + + async with self.patch_broker(broker) as br: + await br.start() + + message = object + async def consume(): + nonlocal message + message = await subscriber.get_one(timeout=0.5) await asyncio.wait( (