fix: add pattern checking (airtai#1590)
* fix: add pattern checking

* feat: full Kafka Pattern support

* chore: bump version

---------

Co-authored-by: Nikita Pastukhov <[email protected]>
Co-authored-by: Pastukhov Nikita <[email protected]>
3 people authored Jul 18, 2024
1 parent 4d210e0 commit 15cbe6b
Showing 7 changed files with 152 additions and 23 deletions.
15 changes: 15 additions & 0 deletions docs/docs/en/kafka/Subscriber/index.md
@@ -57,3 +57,18 @@ The function decorated with the `#!python @broker.subscriber(...)` decorator will
The message will then be injected into the typed `msg` argument of the function, and its type will be used to parse the message.

In this example, when a message is sent to the `#!python "hello_world"` topic, it will be parsed into a `HelloWorld` class, and the `on_hello_world` function will be called with the parsed class as the `msg` argument value.

### Pattern data access

You can also use the pattern subscription feature to encode data directly in the topic name. With **FastStream**, you can easily access this data using the following code:

```python hl_lines="3 6"
from faststream import Path

@broker.subscriber(pattern="logs.{level}")
async def base_handler(
body: str,
level: str = Path(),
):
...
```
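
For instance, publishing to a concrete topic that matches the pattern makes each captured segment available through `Path`. A minimal runnable sketch under the same pattern (the `logs.info` topic name and the printed output are illustrative):

```python
from faststream import FastStream, Path
from faststream.kafka import KafkaBroker

broker = KafkaBroker()
app = FastStream(broker)


@broker.subscriber(pattern="logs.{level}")
async def base_handler(
    body: str,
    level: str = Path(),  # captured from the matched topic name
):
    print(f"[{level}] {body}")


@app.after_startup
async def publish_example() -> None:
    # "logs.info" matches "logs.{level}", so the handler is called
    # with level == "info" alongside the message body.
    await broker.publish("service started", topic="logs.info")
```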
2 changes: 1 addition & 1 deletion faststream/__about__.py
@@ -1,6 +1,6 @@
"""Simple and fast framework to create message brokers based microservices."""

__version__ = "0.5.14"
__version__ = "0.5.15"

SERVICE_NAME = f"faststream-{__version__}"

18 changes: 17 additions & 1 deletion faststream/kafka/parser.py
@@ -5,6 +5,8 @@
from faststream.utils.context.repository import context

if TYPE_CHECKING:
from re import Pattern

from aiokafka import ConsumerRecord

from faststream.broker.message import StreamMessage
@@ -15,8 +17,13 @@
class AioKafkaParser:
"""A class to parse Kafka messages."""

def __init__(self, msg_class: Type[KafkaMessage]) -> None:
def __init__(
self,
msg_class: Type[KafkaMessage],
regex: Optional["Pattern[str]"],
) -> None:
self.msg_class = msg_class
self.regex = regex

async def parse_message(
self,
@@ -25,6 +32,7 @@ async def parse_message(
"""Parses a Kafka message."""
headers = {i: j.decode() for i, j in message.headers}
handler: Optional[LogicSubscriber[Any]] = context.get_local("handler_")

return self.msg_class(
body=message.value,
headers=headers,
@@ -33,6 +41,7 @@ async def parse_message(
message_id=f"{message.offset}-{message.timestamp}",
correlation_id=headers.get("correlation_id", gen_cor_id()),
raw_message=message,
path=self.get_path(message.topic),
consumer=getattr(handler, "consumer", None) or FAKE_CONSUMER,
)

@@ -43,6 +52,12 @@ async def decode_message(
"""Decodes a message."""
return decode_message(msg)

def get_path(self, topic: str) -> Dict[str, str]:
if self.regex and (match := self.regex.match(topic)):
return match.groupdict()
else:
return {}


class AioKafkaBatchParser(AioKafkaParser):
async def parse_message(
@@ -73,6 +88,7 @@ async def parse_message(
message_id=f"{first.offset}-{last.offset}-{first.timestamp}",
correlation_id=headers.get("correlation_id", gen_cor_id()),
raw_message=message,
path=self.get_path(first.topic),
consumer=getattr(handler, "consumer", None) or FAKE_CONSUMER,
)

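
The new `get_path` hook is what feeds those `Path` parameters: when the subscriber was built from a pattern, the compiled regex is matched against the incoming topic and the named groups become the message's path data; plain-topic subscribers get an empty dict. A standalone sketch of that behavior (the `(?P<level>[^.]*)` regex is an assumption about what `compile_path` produces for `logs.{level}`):

```python
import re
from typing import Dict, Optional


def get_path(regex: Optional["re.Pattern[str]"], topic: str) -> Dict[str, str]:
    """Return the named-group captures for a topic, or an empty dict."""
    if regex and (match := regex.match(topic)):
        return match.groupdict()
    return {}


# Assumed compilation result for the "logs.{level}" pattern:
compiled = re.compile(r"logs\.(?P<level>[^.]*)")

assert get_path(compiled, "logs.info") == {"level": "info"}  # named group captured
assert get_path(compiled, "metrics.cpu") == {}               # topic does not match
assert get_path(None, "logs.info") == {}                     # plain-topic subscriber
```

Both parsers reuse the same helper, which is why the batch variant above can resolve the path from `first.topic` for the whole batch.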
48 changes: 36 additions & 12 deletions faststream/kafka/subscriber/usecase.py
@@ -28,6 +28,7 @@
)
from faststream.kafka.message import KafkaAckableMessage, KafkaMessage
from faststream.kafka.parser import AioKafkaBatchParser, AioKafkaParser
from faststream.utils.path import compile_path

if TYPE_CHECKING:
from aiokafka import AIOKafkaConsumer, ConsumerRecord
@@ -93,15 +94,16 @@ def __init__(
self.partitions = partitions
self.group_id = group_id

self.builder = None
self.consumer = None
self.task = None
self._pattern = pattern
self.__listener = listener
self.__connection_args = connection_args

# Setup it later
self.client_id = ""
self.__pattern = pattern
self.__listener = listener
self.__connection_args = connection_args
self.builder = None

self.consumer = None
self.task = None

@override
def setup( # type: ignore[override]
@@ -149,10 +151,10 @@ async def start(self) -> None:
**self.__connection_args,
)

if self.topics:
if self.topics or self._pattern:
consumer.subscribe(
topics=self.topics,
pattern=self.__pattern,
pattern=self._pattern,
listener=self.__listener,
)

@@ -229,8 +231,8 @@ def get_routing_hash(

@property
def topic_names(self) -> List[str]:
if self.__pattern:
return [self.__pattern]
if self._pattern:
return [self._pattern]
elif self.topics:
return list(self.topics)
else:
@@ -305,8 +307,19 @@ def __init__(
description_: Optional[str],
include_in_schema: bool,
) -> None:
if pattern:
reg, pattern = compile_path(
pattern,
replace_symbol=".*",
patch_regex=lambda x: x.replace(r"\*", ".*"),
)

else:
reg = None

parser = AioKafkaParser(
msg_class=KafkaAckableMessage if is_manual else KafkaMessage
msg_class=KafkaAckableMessage if is_manual else KafkaMessage,
regex=reg,
)

super().__init__(
@@ -365,8 +378,19 @@ def __init__(
self.batch_timeout_ms = batch_timeout_ms
self.max_records = max_records

if pattern:
reg, pattern = compile_path(
pattern,
replace_symbol=".*",
patch_regex=lambda x: x.replace(r"\*", ".*"),
)

else:
reg = None

parser = AioKafkaBatchParser(
msg_class=KafkaAckableMessage if is_manual else KafkaMessage
msg_class=KafkaAckableMessage if is_manual else KafkaMessage,
regex=reg,
)

super().__init__(
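
Both the default and batch subscribers now pre-compile their pattern before constructing the parser. Conceptually, `compile_path` turns each `{param}` placeholder into a named regex group and rewrites bare `*` wildcards, yielding a matching regex plus a plain wildcard string that can be handed to `aiokafka`'s `subscribe(pattern=...)`. A rough self-contained illustration of that translation (an assumption about the behavior, not the actual `faststream.utils.path` implementation):

```python
import re
from typing import Optional, Tuple


def compile_kafka_pattern(path: str) -> Tuple[Optional["re.Pattern[str]"], str]:
    """Sketch: "logs.{level}" -> (regex with named groups, wildcard pattern)."""
    if "{" not in path and "*" not in path:
        return None, path  # plain topic: nothing to compile

    regex = re.escape(path)
    # {param} -> named capture group that stops at the next "." separator
    regex = re.sub(r"\\{(\w+)\\}", r"(?P<\1>[^.]*)", regex)
    # mirrors patch_regex=lambda x: x.replace(r"\*", ".*") from the diff
    regex = regex.replace(r"\*", ".*")

    # mirrors replace_symbol=".*": the broker only needs a wildcard string
    wildcard = re.sub(r"{\w+}", ".*", path)
    return re.compile(regex), wildcard


reg, wildcard = compile_kafka_pattern("logs.{level}")
assert wildcard == "logs..*"
assert reg.match("logs.info").groupdict() == {"level": "info"}

reg2, wc2 = compile_kafka_pattern("logs.*")
assert wc2 == "logs.*" and reg2.match("logs.anything")
```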
27 changes: 19 additions & 8 deletions faststream/kafka/testing.py
@@ -1,3 +1,4 @@
import re
from datetime import datetime
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional
from unittest.mock import AsyncMock, MagicMock
@@ -16,6 +17,7 @@
if TYPE_CHECKING:
from faststream.broker.wrapper.call import HandlerCallWrapper
from faststream.kafka.publisher.asyncapi import AsyncAPIPublisher
from faststream.kafka.subscriber.usecase import LogicSubscriber
from faststream.types import SendableMessage

__all__ = ("TestKafkaBroker",)
@@ -108,13 +110,7 @@ async def publish( # type: ignore[override]
return_value = None

for handler in self.broker._subscribers.values(): # pragma: no branch
if (
any(
p.topic == topic and (partition is None or p.partition == partition)
for p in handler.partitions
)
or topic in handler.topics
):
if _is_handler_matches(handler, topic, partition):
handle_value = await call_handler(
handler=handler,
message=[incoming]
@@ -141,7 +137,7 @@
) -> None:
"""Publish a batch of messages to the Kafka broker."""
for handler in self.broker._subscribers.values(): # pragma: no branch
if topic in handler.topics:
if _is_handler_matches(handler, topic, partition):
messages = (
build_message(
message=message,
Expand Down Expand Up @@ -215,3 +211,18 @@ def _fake_connection(*args: Any, **kwargs: Any) -> AsyncMock:
mock.subscribe = MagicMock
mock.assign = MagicMock
return mock


def _is_handler_matches(
handler: "LogicSubscriber[Any]",
topic: str,
partition: Optional[int],
) -> bool:
return bool(
any(
p.topic == topic and (partition is None or p.partition == partition)
for p in handler.partitions
)
or topic in handler.topics
or (handler._pattern and re.match(handler._pattern, topic))
)
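
With this helper extracted, `TestKafkaBroker` routes a published message to a subscriber when the topic/partition matches an explicitly assigned partition, the topic is among the subscribed topics, or the topic satisfies the subscriber's compiled pattern. A simplified standalone version of that decision (the `FakeHandler` stand-in is hypothetical; the real helper reads these fields from `LogicSubscriber`):

```python
import re
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class FakeHandler:
    """Hypothetical stand-in exposing only the fields the matcher reads."""
    topics: List[str] = field(default_factory=list)
    pattern: Optional[str] = None


def is_handler_matches(handler: FakeHandler, topic: str) -> bool:
    # Simplified: the real helper also checks explicitly assigned partitions.
    return bool(
        topic in handler.topics
        or (handler.pattern and re.match(handler.pattern, topic))
    )


assert is_handler_matches(FakeHandler(topics=["logs.info"]), "logs.info")
assert is_handler_matches(FakeHandler(pattern="logs..*"), "logs.debug")
assert not is_handler_matches(FakeHandler(), "other.topic")
```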
35 changes: 35 additions & 0 deletions tests/brokers/kafka/test_consume.py
@@ -16,6 +16,41 @@ class TestConsume(BrokerRealConsumeTestcase):
def get_broker(self, apply_types: bool = False):
return KafkaBroker(apply_types=apply_types)

@pytest.mark.asyncio()
async def test_consume_by_pattern(
self,
queue: str,
event: asyncio.Event,
):
consume_broker = self.get_broker()

@consume_broker.subscriber(queue)
async def handler(msg):
event.set()

pattern_event = asyncio.Event()

@consume_broker.subscriber(pattern=f"{queue[:-1]}*")
async def pattern_handler(msg):
pattern_event.set()

async with self.patch_broker(consume_broker) as br:
await br.start()

await br.publish(1, topic=queue)

await asyncio.wait(
(
asyncio.create_task(br.publish(1, topic=queue)),
asyncio.create_task(event.wait()),
asyncio.create_task(pattern_event.wait()),
),
timeout=3,
)

assert event.is_set()
assert pattern_event.is_set()

@pytest.mark.asyncio()
async def test_consume_batch(self, queue: str):
consume_broker = self.get_broker()
30 changes: 29 additions & 1 deletion tests/utils/context/test_path.py
@@ -4,7 +4,35 @@
import pytest

from faststream import Path
from tests.marks import require_aiopika, require_nats, require_redis
from tests.marks import require_aiokafka, require_aiopika, require_nats, require_redis


@pytest.mark.asyncio()
@require_aiokafka
async def test_aiokafka_path():
from faststream.kafka import KafkaBroker, TestKafkaBroker

broker = KafkaBroker()

@broker.subscriber(pattern="in.{name}.{id}")
async def h(
name: str = Path(),
id_: int = Path("id"),
):
assert name == "john"
assert id_ == 1
return 1

async with TestKafkaBroker(broker) as br:
assert (
await br.publish(
"",
"in.john.1",
rpc=True,
rpc_timeout=1.0,
)
== 1
)


@pytest.mark.asyncio()
