From 6e488fefb82030de25a968dec01fbb6ee7d80b2d Mon Sep 17 00:00:00 2001 From: Nikita Pastukhov Date: Thu, 18 Jul 2024 20:17:32 +0300 Subject: [PATCH] feat: full Kafka Pattern support --- docs/docs/en/kafka/Subscriber/index.md | 15 ++++++++ faststream/kafka/parser.py | 18 +++++++++- faststream/kafka/subscriber/usecase.py | 48 +++++++++++++++++++------- faststream/kafka/testing.py | 27 ++++++++++----- tests/brokers/kafka/test_consume.py | 35 +++++++++++++++++++ tests/utils/context/test_path.py | 30 +++++++++++++++- 6 files changed, 151 insertions(+), 22 deletions(-) diff --git a/docs/docs/en/kafka/Subscriber/index.md b/docs/docs/en/kafka/Subscriber/index.md index 72c61cc84f..f40c9a20ca 100644 --- a/docs/docs/en/kafka/Subscriber/index.md +++ b/docs/docs/en/kafka/Subscriber/index.md @@ -57,3 +57,18 @@ The function decorated with the `#!python @broker.subscriber(...)` decorator wil The message will then be injected into the typed `msg` argument of the function, and its type will be used to parse the message. In this example case, when the message is sent to a `#!python "hello_world"` topic, it will be parsed into a `HelloWorld` class, and the `on_hello_world` function will be called with the parsed class as the `msg` argument value. + +### Pattern data access + +You can also use pattern subscription feature to encode some data directly in the topic name. With **FastStream** you can easily access this data using the following code: + +```python hl_lines="3 6" +from faststream import Path + +@broker.subscriber(pattern="logs.{level}") +async def base_handler( + body: str, + level: str = Path(), +): + ... +``` \ No newline at end of file diff --git a/faststream/kafka/parser.py b/faststream/kafka/parser.py index 44886c6028..8e35ed0a02 100644 --- a/faststream/kafka/parser.py +++ b/faststream/kafka/parser.py @@ -5,6 +5,8 @@ from faststream.utils.context.repository import context if TYPE_CHECKING: + from re import Pattern + from aiokafka import ConsumerRecord from faststream.broker.message import StreamMessage @@ -15,8 +17,13 @@ class AioKafkaParser: """A class to parse Kafka messages.""" - def __init__(self, msg_class: Type[KafkaMessage]) -> None: + def __init__( + self, + msg_class: Type[KafkaMessage], + regex: Optional["Pattern[str]"], + ) -> None: self.msg_class = msg_class + self.regex = regex async def parse_message( self, @@ -25,6 +32,7 @@ async def parse_message( """Parses a Kafka message.""" headers = {i: j.decode() for i, j in message.headers} handler: Optional[LogicSubscriber[Any]] = context.get_local("handler_") + return self.msg_class( body=message.value, headers=headers, @@ -33,6 +41,7 @@ async def parse_message( message_id=f"{message.offset}-{message.timestamp}", correlation_id=headers.get("correlation_id", gen_cor_id()), raw_message=message, + path=self.get_path(message.topic), consumer=getattr(handler, "consumer", None) or FAKE_CONSUMER, ) @@ -43,6 +52,12 @@ async def decode_message( """Decodes a message.""" return decode_message(msg) + def get_path(self, topic: str) -> Dict[str, str]: + if self.regex and (match := self.regex.match(topic)): + return match.groupdict() + else: + return {} + class AioKafkaBatchParser(AioKafkaParser): async def parse_message( @@ -73,6 +88,7 @@ async def parse_message( message_id=f"{first.offset}-{last.offset}-{first.timestamp}", correlation_id=headers.get("correlation_id", gen_cor_id()), raw_message=message, + path=self.get_path(first.topic), consumer=getattr(handler, "consumer", None) or FAKE_CONSUMER, ) diff --git a/faststream/kafka/subscriber/usecase.py b/faststream/kafka/subscriber/usecase.py index e753945538..e9297d0bb1 100644 --- a/faststream/kafka/subscriber/usecase.py +++ b/faststream/kafka/subscriber/usecase.py @@ -28,6 +28,7 @@ ) from faststream.kafka.message import KafkaAckableMessage, KafkaMessage from faststream.kafka.parser import AioKafkaBatchParser, AioKafkaParser +from faststream.utils.path import compile_path if TYPE_CHECKING: from aiokafka import AIOKafkaConsumer, ConsumerRecord @@ -93,15 +94,16 @@ def __init__( self.partitions = partitions self.group_id = group_id - self.builder = None - self.consumer = None - self.task = None + self._pattern = pattern + self.__listener = listener + self.__connection_args = connection_args # Setup it later self.client_id = "" - self.__pattern = pattern - self.__listener = listener - self.__connection_args = connection_args + self.builder = None + + self.consumer = None + self.task = None @override def setup( # type: ignore[override] @@ -149,10 +151,10 @@ async def start(self) -> None: **self.__connection_args, ) - if self.topics or self.__pattern: + if self.topics or self._pattern: consumer.subscribe( topics=self.topics, - pattern=self.__pattern, + pattern=self._pattern, listener=self.__listener, ) @@ -229,8 +231,8 @@ def get_routing_hash( @property def topic_names(self) -> List[str]: - if self.__pattern: - return [self.__pattern] + if self._pattern: + return [self._pattern] elif self.topics: return list(self.topics) else: @@ -305,8 +307,19 @@ def __init__( description_: Optional[str], include_in_schema: bool, ) -> None: + if pattern: + reg, pattern = compile_path( + pattern, + replace_symbol=".*", + patch_regex=lambda x: x.replace(r"\*", ".*"), + ) + + else: + reg = None + parser = AioKafkaParser( - msg_class=KafkaAckableMessage if is_manual else KafkaMessage + msg_class=KafkaAckableMessage if is_manual else KafkaMessage, + regex=reg, ) super().__init__( @@ -365,8 +378,19 @@ def __init__( self.batch_timeout_ms = batch_timeout_ms self.max_records = max_records + if pattern: + reg, pattern = compile_path( + pattern, + replace_symbol=".*", + patch_regex=lambda x: x.replace(r"\*", ".*"), + ) + + else: + reg = None + parser = AioKafkaBatchParser( - msg_class=KafkaAckableMessage if is_manual else KafkaMessage + msg_class=KafkaAckableMessage if is_manual else KafkaMessage, + regex=reg, ) super().__init__( diff --git a/faststream/kafka/testing.py b/faststream/kafka/testing.py index 5abe59cf97..4a52be016f 100755 --- a/faststream/kafka/testing.py +++ b/faststream/kafka/testing.py @@ -1,3 +1,4 @@ +import re from datetime import datetime from typing import TYPE_CHECKING, Any, Callable, Dict, Optional from unittest.mock import AsyncMock, MagicMock @@ -16,6 +17,7 @@ if TYPE_CHECKING: from faststream.broker.wrapper.call import HandlerCallWrapper from faststream.kafka.publisher.asyncapi import AsyncAPIPublisher + from faststream.kafka.subscriber.usecase import LogicSubscriber from faststream.types import SendableMessage __all__ = ("TestKafkaBroker",) @@ -108,13 +110,7 @@ async def publish( # type: ignore[override] return_value = None for handler in self.broker._subscribers.values(): # pragma: no branch - if ( - any( - p.topic == topic and (partition is None or p.partition == partition) - for p in handler.partitions - ) - or topic in handler.topics - ): + if _is_handler_matches(handler, topic, partition): handle_value = await call_handler( handler=handler, message=[incoming] @@ -141,7 +137,7 @@ async def publish_batch( ) -> None: """Publish a batch of messages to the Kafka broker.""" for handler in self.broker._subscribers.values(): # pragma: no branch - if topic in handler.topics: + if _is_handler_matches(handler, topic, partition): messages = ( build_message( message=message, @@ -215,3 +211,18 @@ def _fake_connection(*args: Any, **kwargs: Any) -> AsyncMock: mock.subscribe = MagicMock mock.assign = MagicMock return mock + + +def _is_handler_matches( + handler: "LogicSubscriber[Any]", + topic: str, + partition: Optional[int], +) -> bool: + return bool( + any( + p.topic == topic and (partition is None or p.partition == partition) + for p in handler.partitions + ) + or topic in handler.topics + or (handler._pattern and re.match(handler._pattern, topic)) + ) diff --git a/tests/brokers/kafka/test_consume.py b/tests/brokers/kafka/test_consume.py index 2a7f57b888..bea50c9198 100644 --- a/tests/brokers/kafka/test_consume.py +++ b/tests/brokers/kafka/test_consume.py @@ -16,6 +16,41 @@ class TestConsume(BrokerRealConsumeTestcase): def get_broker(self, apply_types: bool = False): return KafkaBroker(apply_types=apply_types) + @pytest.mark.asyncio() + async def test_consume_by_pattern( + self, + queue: str, + event: asyncio.Event, + ): + consume_broker = self.get_broker() + + @consume_broker.subscriber(queue) + async def handler(msg): + event.set() + + pattern_event = asyncio.Event() + + @consume_broker.subscriber(pattern=f"{queue[:-1]}*") + async def pattern_handler(msg): + pattern_event.set() + + async with self.patch_broker(consume_broker) as br: + await br.start() + + await br.publish(1, topic=queue) + + await asyncio.wait( + ( + asyncio.create_task(br.publish(1, topic=queue)), + asyncio.create_task(event.wait()), + asyncio.create_task(pattern_event.wait()), + ), + timeout=3, + ) + + assert event.is_set() + assert pattern_event.is_set() + @pytest.mark.asyncio() async def test_consume_batch(self, queue: str): consume_broker = self.get_broker() diff --git a/tests/utils/context/test_path.py b/tests/utils/context/test_path.py index babf557b58..81705b2aea 100644 --- a/tests/utils/context/test_path.py +++ b/tests/utils/context/test_path.py @@ -4,7 +4,35 @@ import pytest from faststream import Path -from tests.marks import require_aiopika, require_nats, require_redis +from tests.marks import require_aiokafka, require_aiopika, require_nats, require_redis + + +@pytest.mark.asyncio() +@require_aiokafka +async def test_aiokafka_path(): + from faststream.kafka import KafkaBroker, TestKafkaBroker + + broker = KafkaBroker() + + @broker.subscriber(pattern="in.{name}.{id}") + async def h( + name: str = Path(), + id_: int = Path("id"), + ): + assert name == "john" + assert id_ == 1 + return 1 + + async with TestKafkaBroker(broker) as br: + assert ( + await br.publish( + "", + "in.john.1", + rpc=True, + rpc_timeout=1.0, + ) + == 1 + ) @pytest.mark.asyncio()