Skip to content

Commit

Permalink
Nats batch pull get_one tests + fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
KrySeyt committed Sep 1, 2024
1 parent 3c30b4a commit 837c3ef
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 12 deletions.
16 changes: 8 additions & 8 deletions faststream/nats/subscriber/usecase.py
Original file line number Diff line number Diff line change
Expand Up @@ -1049,12 +1049,12 @@ async def get_one(self, *, timeout: float = 5) -> Optional[NatsMessage]:
**self.extra_options,
)

raw_message ,= await self.subscription.fetch(
batch=1,
timeout=timeout,
)

if not raw_message:
try:
raw_messages = await self.subscription.fetch(
batch=1,
timeout=timeout,
)
except TimeoutError:
return None

async with AsyncExitStack() as stack:
Expand All @@ -1063,11 +1063,11 @@ async def get_one(self, *, timeout: float = 5) -> Optional[NatsMessage]:
)

for m in self._broker_middlewares:
mid = m(raw_message)
mid = m(raw_messages)
await stack.enter_async_context(mid)
return_msg = partial(mid.consume_scope, return_msg)

parsed_msg = await self._parser(raw_message)
parsed_msg = await self._parser(raw_messages)
parsed_msg._decoded_body = await self._decoder(parsed_msg)
return await return_msg(parsed_msg)

Expand Down
74 changes: 70 additions & 4 deletions tests/brokers/nats/test_consume.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,7 +477,7 @@ async def test_get_one_timeout_js(
message = object()
async def coro():
nonlocal message
message = await subscriber.get_one(timeout=1)
message = await subscriber.get_one(timeout=0.5)

await asyncio.wait(
(
Expand All @@ -488,7 +488,7 @@ async def coro():

assert message is None

async def test_get_one_pool(
async def test_get_one_pull(
self,
queue: str,
event: asyncio.Event,
Expand Down Expand Up @@ -524,7 +524,7 @@ async def publish():
assert message is not None
assert await message.decode() == "test_message"

async def test_get_one_pool_timeout(
async def test_get_one_pull_timeout(
self,
queue: str,
event: asyncio.Event,
Expand All @@ -543,7 +543,73 @@ async def test_get_one_pool_timeout(
message = object
async def consume():
nonlocal message
message = await subscriber.get_one(timeout=1)
message = await subscriber.get_one(timeout=0.5)

await asyncio.wait(
(
asyncio.create_task(consume()),
),
timeout=3
)

assert message is None

async def test_get_one_batch(
self,
queue: str,
event: asyncio.Event,
stream: JStream,
):
broker = self.get_broker(apply_types=True)
subscriber = broker.subscriber(
queue,
stream=stream,
pull_sub=PullSub(1, batch=True),
)

async with self.patch_broker(broker) as br:
await br.start()

message = None
async def consume():
nonlocal message
message = await subscriber.get_one(timeout=5)

async def publish():
await asyncio.sleep(0.5)
await br.publish("test_message", queue)

await asyncio.wait(
(
asyncio.create_task(consume()),
asyncio.create_task(publish()),
),
timeout=10
)

assert message is not None
assert await message.decode() == ["test_message"]

async def test_get_one_batch_timeout(
self,
queue: str,
event: asyncio.Event,
stream: JStream,
):
broker = self.get_broker(apply_types=True)
subscriber = broker.subscriber(
queue,
stream=stream,
pull_sub=PullSub(1, batch=True),
)

async with self.patch_broker(broker) as br:
await br.start()

message = object
async def consume():
nonlocal message
message = await subscriber.get_one(timeout=0.5)

await asyncio.wait(
(
Expand Down

0 comments on commit 837c3ef

Please sign in to comment.