From a3c353fc9c61e4c965d0f74c07b1eb7678773b0a Mon Sep 17 00:00:00 2001 From: Pastukhov Nikita Date: Fri, 7 Jun 2024 18:59:19 +0300 Subject: [PATCH] fix: include NatsRouter streams to original broker (#1509) * fix: includer NatsRouter streams to original broker * chore: limit typing-extensions version only for tests * docs: generate API References * fix: remove debug message * chore: update ruff * chore: use GHA concurency to cancel previous run at push * chore: test GHA cancelation * chore: add GHA concurency to generating API CI --------- Co-authored-by: Lancetnik --- .github/workflows/docs_update-references.yaml | 4 ++ .github/workflows/pr_codeql.yml | 4 ++ .github/workflows/pr_dependency-review.yaml | 4 ++ .github/workflows/pr_tests.yaml | 4 ++ docs/docs/SUMMARY.md | 2 + .../kafka/message/KafkaAckableMessage.md | 11 +++++ .../kafka/parser/AioKafkaBatchParser.md | 11 +++++ faststream/__about__.py | 2 +- .../confluent/opentelemetry/provider.py | 6 +-- faststream/kafka/annotations.py | 2 + faststream/kafka/message.py | 6 +-- faststream/kafka/opentelemetry/provider.py | 6 +-- faststream/kafka/parser.py | 43 +++++++++++-------- faststream/kafka/subscriber/usecase.py | 23 ++++++---- faststream/nats/broker/registrator.py | 23 +++++++++- faststream/nats/opentelemetry/provider.py | 6 +-- faststream/nats/parser.py | 2 +- faststream/rabbit/opentelemetry/provider.py | 4 +- faststream/rabbit/parser.py | 2 +- faststream/redis/parser.py | 8 ++-- faststream/utils/context/types.py | 4 -- pyproject.toml | 6 +-- tests/brokers/nats/test_router.py | 14 +++++- 23 files changed, 140 insertions(+), 57 deletions(-) create mode 100644 docs/docs/en/api/faststream/kafka/message/KafkaAckableMessage.md create mode 100644 docs/docs/en/api/faststream/kafka/parser/AioKafkaBatchParser.md diff --git a/.github/workflows/docs_update-references.yaml b/.github/workflows/docs_update-references.yaml index 30f9d91bea..56369ec56f 100644 --- a/.github/workflows/docs_update-references.yaml +++ b/.github/workflows/docs_update-references.yaml @@ -8,6 +8,10 @@ on: paths: - faststream/** +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} + cancel-in-progress: true + permissions: contents: write diff --git a/.github/workflows/pr_codeql.yml b/.github/workflows/pr_codeql.yml index f1fb50d463..a316b38777 100644 --- a/.github/workflows/pr_codeql.yml +++ b/.github/workflows/pr_codeql.yml @@ -23,6 +23,10 @@ on: schedule: - cron: '39 20 * * 0' +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} + cancel-in-progress: true + jobs: analyze: if: github.event.pull_request.draft == false diff --git a/.github/workflows/pr_dependency-review.yaml b/.github/workflows/pr_dependency-review.yaml index a241701673..f32635858b 100644 --- a/.github/workflows/pr_dependency-review.yaml +++ b/.github/workflows/pr_dependency-review.yaml @@ -16,6 +16,10 @@ on: paths: - pyproject.toml +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} + cancel-in-progress: true + permissions: contents: read diff --git a/.github/workflows/pr_tests.yaml b/.github/workflows/pr_tests.yaml index 59c2de5fd4..a2b545c759 100644 --- a/.github/workflows/pr_tests.yaml +++ b/.github/workflows/pr_tests.yaml @@ -12,6 +12,10 @@ on: types: - checks_requested +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} + cancel-in-progress: true + jobs: static_analysis: if: github.event.pull_request.draft == false diff --git a/docs/docs/SUMMARY.md b/docs/docs/SUMMARY.md index f10774cc8c..7f53b19fc5 100644 --- a/docs/docs/SUMMARY.md +++ b/docs/docs/SUMMARY.md @@ -540,6 +540,7 @@ search: - message - [ConsumerProtocol](api/faststream/kafka/message/ConsumerProtocol.md) - [FakeConsumer](api/faststream/kafka/message/FakeConsumer.md) + - [KafkaAckableMessage](api/faststream/kafka/message/KafkaAckableMessage.md) - [KafkaMessage](api/faststream/kafka/message/KafkaMessage.md) - opentelemetry - [KafkaTelemetryMiddleware](api/faststream/kafka/opentelemetry/KafkaTelemetryMiddleware.md) @@ -551,6 +552,7 @@ search: - [KafkaTelemetrySettingsProvider](api/faststream/kafka/opentelemetry/provider/KafkaTelemetrySettingsProvider.md) - [telemetry_attributes_provider_factory](api/faststream/kafka/opentelemetry/provider/telemetry_attributes_provider_factory.md) - parser + - [AioKafkaBatchParser](api/faststream/kafka/parser/AioKafkaBatchParser.md) - [AioKafkaParser](api/faststream/kafka/parser/AioKafkaParser.md) - publisher - asyncapi diff --git a/docs/docs/en/api/faststream/kafka/message/KafkaAckableMessage.md b/docs/docs/en/api/faststream/kafka/message/KafkaAckableMessage.md new file mode 100644 index 0000000000..16461be675 --- /dev/null +++ b/docs/docs/en/api/faststream/kafka/message/KafkaAckableMessage.md @@ -0,0 +1,11 @@ +--- +# 0.5 - API +# 2 - Release +# 3 - Contributing +# 5 - Template Page +# 10 - Default +search: + boost: 0.5 +--- + +::: faststream.kafka.message.KafkaAckableMessage diff --git a/docs/docs/en/api/faststream/kafka/parser/AioKafkaBatchParser.md b/docs/docs/en/api/faststream/kafka/parser/AioKafkaBatchParser.md new file mode 100644 index 0000000000..25df2532c6 --- /dev/null +++ b/docs/docs/en/api/faststream/kafka/parser/AioKafkaBatchParser.md @@ -0,0 +1,11 @@ +--- +# 0.5 - API +# 2 - Release +# 3 - Contributing +# 5 - Template Page +# 10 - Default +search: + boost: 0.5 +--- + +::: faststream.kafka.parser.AioKafkaBatchParser diff --git a/faststream/__about__.py b/faststream/__about__.py index 33eac3b1c1..6e014d02ac 100644 --- a/faststream/__about__.py +++ b/faststream/__about__.py @@ -1,6 +1,6 @@ """Simple and fast framework to create message brokers based microservices.""" -__version__ = "0.5.10" +__version__ = "0.5.11" SERVICE_NAME = f"faststream-{__version__}" diff --git a/faststream/confluent/opentelemetry/provider.py b/faststream/confluent/opentelemetry/provider.py index 6add7330ca..3c157851d9 100644 --- a/faststream/confluent/opentelemetry/provider.py +++ b/faststream/confluent/opentelemetry/provider.py @@ -37,8 +37,8 @@ def get_publish_attrs_from_kwargs( return attrs - @staticmethod def get_publish_destination_name( + self, kwargs: "AnyDict", ) -> str: return cast(str, kwargs["topic"]) @@ -66,8 +66,8 @@ def get_consume_attrs_from_message( return attrs - @staticmethod def get_consume_destination_name( + self, msg: "StreamMessage[Message]", ) -> str: return cast(str, msg.raw_message.topic()) @@ -95,8 +95,8 @@ def get_consume_attrs_from_message( return attrs - @staticmethod def get_consume_destination_name( + self, msg: "StreamMessage[Tuple[Message, ...]]", ) -> str: return cast(str, msg.raw_message[0].topic()) diff --git a/faststream/kafka/annotations.py b/faststream/kafka/annotations.py index 5ea36b7c4e..efca62b227 100644 --- a/faststream/kafka/annotations.py +++ b/faststream/kafka/annotations.py @@ -1,3 +1,4 @@ +from aiokafka import AIOKafkaConsumer from typing_extensions import Annotated from faststream.annotations import ContextRepo, Logger, NoCast @@ -15,6 +16,7 @@ "KafkaProducer", ) +Consumer = Annotated[AIOKafkaConsumer, Context("handler_.consumer")] KafkaMessage = Annotated[KM, Context("message")] KafkaBroker = Annotated[KB, Context("broker")] KafkaProducer = Annotated[AioKafkaFastProducer, Context("broker._producer")] diff --git a/faststream/kafka/message.py b/faststream/kafka/message.py index c051e8f5ba..52f243a6ab 100644 --- a/faststream/kafka/message.py +++ b/faststream/kafka/message.py @@ -39,16 +39,16 @@ def __init__( self, *args: Any, consumer: ConsumerProtocol, - is_manual: bool = False, **kwargs: Any, ) -> None: super().__init__(*args, **kwargs) - self.is_manual = is_manual self.consumer = consumer + +class KafkaAckableMessage(KafkaMessage): async def ack(self) -> None: """Acknowledge the Kafka message.""" - if self.is_manual and not self.committed: + if not self.committed: await self.consumer.commit() await super().ack() diff --git a/faststream/kafka/opentelemetry/provider.py b/faststream/kafka/opentelemetry/provider.py index b1702b6022..b90d82c9fd 100644 --- a/faststream/kafka/opentelemetry/provider.py +++ b/faststream/kafka/opentelemetry/provider.py @@ -37,8 +37,8 @@ def get_publish_attrs_from_kwargs( return attrs - @staticmethod def get_publish_destination_name( + self, kwargs: "AnyDict", ) -> str: return cast(str, kwargs["topic"]) @@ -66,8 +66,8 @@ def get_consume_attrs_from_message( return attrs - @staticmethod def get_consume_destination_name( + self, msg: "StreamMessage[ConsumerRecord]", ) -> str: return cast(str, msg.raw_message.topic) @@ -96,8 +96,8 @@ def get_consume_attrs_from_message( return attrs - @staticmethod def get_consume_destination_name( + self, msg: "StreamMessage[Tuple[ConsumerRecord, ...]]", ) -> str: return cast(str, msg.raw_message[0].topic) diff --git a/faststream/kafka/parser.py b/faststream/kafka/parser.py index 49924c9e97..f6c9964584 100644 --- a/faststream/kafka/parser.py +++ b/faststream/kafka/parser.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type from faststream.broker.message import decode_message, gen_cor_id from faststream.kafka.message import FAKE_CONSUMER, KafkaMessage @@ -15,14 +15,17 @@ class AioKafkaParser: """A class to parse Kafka messages.""" - @staticmethod + def __init__(self, msg_class: Type[KafkaMessage]) -> None: + self.msg_class = msg_class + async def parse_message( + self, message: "ConsumerRecord", ) -> "StreamMessage[ConsumerRecord]": """Parses a Kafka message.""" headers = {i: j.decode() for i, j in message.headers} - handler: Optional[LogicSubscriber[Any]] = context.get_local("handler_") - return KafkaMessage( + handler: Optional["LogicSubscriber[Any]"] = context.get_local("handler_") + return self.msg_class( body=message.value, headers=headers, reply_to=headers.get("reply_to", ""), @@ -31,11 +34,19 @@ async def parse_message( correlation_id=headers.get("correlation_id", gen_cor_id()), raw_message=message, consumer=getattr(handler, "consumer", None) or FAKE_CONSUMER, - is_manual=getattr(handler, "is_manual", True), ) - @staticmethod - async def parse_message_batch( + async def decode_message( + self, + msg: "StreamMessage[ConsumerRecord]", + ) -> "DecodedMessage": + """Decodes a message.""" + return decode_message(msg) + + +class AioKafkaBatchParser(AioKafkaParser): + async def parse_message( + self, message: Tuple["ConsumerRecord", ...], ) -> "StreamMessage[Tuple[ConsumerRecord, ...]]": """Parses a batch of messages from a Kafka consumer.""" @@ -53,7 +64,7 @@ async def parse_message_batch( handler: Optional[LogicSubscriber[Any]] = context.get_local("handler_") - return KafkaMessage( + return self.msg_class( body=body, headers=headers, batch_headers=batch_headers, @@ -63,18 +74,14 @@ async def parse_message_batch( correlation_id=headers.get("correlation_id", gen_cor_id()), raw_message=message, consumer=getattr(handler, "consumer", None) or FAKE_CONSUMER, - is_manual=getattr(handler, "is_manual", True), ) - @staticmethod - async def decode_message(msg: "StreamMessage[ConsumerRecord]") -> "DecodedMessage": - """Decodes a message.""" - return decode_message(msg) - - @classmethod - async def decode_message_batch( - cls, + async def decode_message( + self, msg: "StreamMessage[Tuple[ConsumerRecord, ...]]", ) -> "DecodedMessage": """Decode a batch of messages.""" - return [decode_message(await cls.parse_message(m)) for m in msg.raw_message] + return [ + decode_message(await super(AioKafkaBatchParser, self).parse_message(m)) + for m in msg.raw_message + ] diff --git a/faststream/kafka/subscriber/usecase.py b/faststream/kafka/subscriber/usecase.py index fa01a11fcb..5b077faf73 100644 --- a/faststream/kafka/subscriber/usecase.py +++ b/faststream/kafka/subscriber/usecase.py @@ -26,7 +26,8 @@ CustomCallable, MsgType, ) -from faststream.kafka.parser import AioKafkaParser +from faststream.kafka.message import KafkaAckableMessage, KafkaMessage +from faststream.kafka.parser import AioKafkaBatchParser, AioKafkaParser if TYPE_CHECKING: from aiokafka import AIOKafkaConsumer, ConsumerRecord @@ -60,7 +61,6 @@ def __init__( listener: Optional["ConsumerRebalanceListener"], pattern: Optional[str], partitions: Iterable["TopicPartition"], - is_manual: bool, # Subscriber args default_parser: "AsyncCallable", default_decoder: "AsyncCallable", @@ -93,7 +93,6 @@ def __init__( self.partitions = partitions self.group_id = group_id - self.is_manual = is_manual self.builder = None self.consumer = None self.task = None @@ -306,6 +305,10 @@ def __init__( description_: Optional[str], include_in_schema: bool, ) -> None: + parser = AioKafkaParser( + msg_class=KafkaAckableMessage if is_manual else KafkaMessage + ) + super().__init__( *topics, group_id=group_id, @@ -313,10 +316,9 @@ def __init__( pattern=pattern, connection_args=connection_args, partitions=partitions, - is_manual=is_manual, # subscriber args - default_parser=AioKafkaParser.parse_message, - default_decoder=AioKafkaParser.decode_message, + default_parser=parser.parse_message, + default_decoder=parser.decode_message, # Propagated args no_ack=no_ack, no_reply=no_reply, @@ -363,6 +365,10 @@ def __init__( self.batch_timeout_ms = batch_timeout_ms self.max_records = max_records + parser = AioKafkaBatchParser( + msg_class=KafkaAckableMessage if is_manual else KafkaMessage + ) + super().__init__( *topics, group_id=group_id, @@ -370,10 +376,9 @@ def __init__( pattern=pattern, connection_args=connection_args, partitions=partitions, - is_manual=is_manual, # subscriber args - default_parser=AioKafkaParser.parse_message_batch, - default_decoder=AioKafkaParser.decode_message_batch, + default_parser=parser.parse_message, + default_decoder=parser.decode_message, # Propagated args no_ack=no_ack, no_reply=no_reply, diff --git a/faststream/nats/broker/registrator.py b/faststream/nats/broker/registrator.py index ca6b84d4d4..a77b439b98 100644 --- a/faststream/nats/broker/registrator.py +++ b/faststream/nats/broker/registrator.py @@ -13,9 +13,10 @@ if TYPE_CHECKING: from fast_depends.dependencies import Depends - from nats.aio.msg import Msg # noqa: F401 + from nats.aio.msg import Msg from faststream.broker.types import ( + BrokerMiddleware, CustomCallable, Filter, PublisherMiddleware, @@ -348,3 +349,23 @@ def publisher( # type: ignore[override] ), ) return publisher + + @override + def include_router( # type: ignore[override] + self, + router: "NatsRegistrator", + *, + prefix: str = "", + dependencies: Iterable["Depends"] = (), + middlewares: Iterable["BrokerMiddleware[Msg]"] = (), + include_in_schema: Optional[bool] = None, + ) -> None: + self._stream_builder.objects.update(router._stream_builder.objects) + + return super().include_router( + router, + prefix=prefix, + dependencies=dependencies, + middlewares=middlewares, + include_in_schema=include_in_schema, + ) diff --git a/faststream/nats/opentelemetry/provider.py b/faststream/nats/opentelemetry/provider.py index 7c33a7d76b..a77ff0a2b3 100644 --- a/faststream/nats/opentelemetry/provider.py +++ b/faststream/nats/opentelemetry/provider.py @@ -29,8 +29,8 @@ def get_publish_attrs_from_kwargs( SpanAttributes.MESSAGING_MESSAGE_CONVERSATION_ID: kwargs["correlation_id"], } - @staticmethod def get_publish_destination_name( + self, kwargs: "AnyDict", ) -> str: subject: str = kwargs.get("subject", SERVICE_NAME) @@ -50,8 +50,8 @@ def get_consume_attrs_from_message( MESSAGING_DESTINATION_PUBLISH_NAME: msg.raw_message.subject, } - @staticmethod def get_consume_destination_name( + self, msg: "StreamMessage[Msg]", ) -> str: return msg.raw_message.subject @@ -73,8 +73,8 @@ def get_consume_attrs_from_message( MESSAGING_DESTINATION_PUBLISH_NAME: msg.raw_message[0].subject, } - @staticmethod def get_consume_destination_name( + self, msg: "StreamMessage[List[Msg]]", ) -> str: return msg.raw_message[0].subject diff --git a/faststream/nats/parser.py b/faststream/nats/parser.py index 206e851999..25d61d4901 100644 --- a/faststream/nats/parser.py +++ b/faststream/nats/parser.py @@ -41,8 +41,8 @@ def get_path( return path - @staticmethod async def decode_message( + self, msg: "StreamMessage[Any]", ) -> "DecodedMessage": return decode_message(msg) diff --git a/faststream/rabbit/opentelemetry/provider.py b/faststream/rabbit/opentelemetry/provider.py index da62338e70..7ba8c1900e 100644 --- a/faststream/rabbit/opentelemetry/provider.py +++ b/faststream/rabbit/opentelemetry/provider.py @@ -32,8 +32,8 @@ def get_consume_attrs_from_message( MESSAGING_DESTINATION_PUBLISH_NAME: msg.raw_message.exchange, } - @staticmethod def get_consume_destination_name( + self, msg: "StreamMessage[IncomingMessage]", ) -> str: exchange = msg.raw_message.exchange or "default" @@ -53,8 +53,8 @@ def get_publish_attrs_from_kwargs( SpanAttributes.MESSAGING_MESSAGE_CONVERSATION_ID: kwargs["correlation_id"], } - @staticmethod def get_publish_destination_name( + self, kwargs: "AnyDict", ) -> str: exchange: str = kwargs.get("exchange") or "default" diff --git a/faststream/rabbit/parser.py b/faststream/rabbit/parser.py index 66fed6ed71..8fe02dc4b3 100644 --- a/faststream/rabbit/parser.py +++ b/faststream/rabbit/parser.py @@ -50,8 +50,8 @@ async def parse_message( raw_message=message, ) - @staticmethod async def decode_message( + self, msg: StreamMessage["IncomingMessage"], ) -> "DecodedMessage": """Decode a message.""" diff --git a/faststream/redis/parser.py b/faststream/redis/parser.py index bad91875ef..d42297af77 100644 --- a/faststream/redis/parser.py +++ b/faststream/redis/parser.py @@ -152,8 +152,8 @@ async def parse_message( correlation_id=headers.get("correlation_id", id_), ) - @staticmethod def _parse_data( + self, message: Mapping[str, Any], ) -> Tuple[bytes, "AnyDict", List["AnyDict"]]: return (*RawMessage.parse(message["data"]), []) @@ -169,8 +169,8 @@ def get_path(self, message: Mapping[str, Any]) -> "AnyDict": else: return {} - @staticmethod async def decode_message( + self, msg: "StreamMessage[MsgType]", ) -> DecodedMessage: return decode_message(msg) @@ -187,8 +187,8 @@ class RedisListParser(SimpleParser): class RedisBatchListParser(SimpleParser): msg_class = RedisBatchListMessage - @staticmethod def _parse_data( + self, message: Mapping[str, Any], ) -> Tuple[bytes, "AnyDict", List["AnyDict"]]: body: List[Any] = [] @@ -225,8 +225,8 @@ def _parse_data( class RedisBatchStreamParser(SimpleParser): msg_class = RedisBatchStreamMessage - @staticmethod def _parse_data( + self, message: Mapping[str, Any], ) -> Tuple[bytes, "AnyDict", List["AnyDict"]]: body: List[Any] = [] diff --git a/faststream/utils/context/types.py b/faststream/utils/context/types.py index ee7ce1b5bc..f27d6fe77c 100644 --- a/faststream/utils/context/types.py +++ b/faststream/utils/context/types.py @@ -58,10 +58,6 @@ def use(self, /, **kwargs: Any) -> AnyDict: Returns: A dictionary containing the updated keyword arguments - - Raises: - KeyError: If the parameter name is not found in the keyword arguments - AttributeError: If the parameter name is not a valid attribute """ name = f"{self.prefix}{self.name or self.param_name}" diff --git a/pyproject.toml b/pyproject.toml index fce5418594..1442a88388 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -58,8 +58,7 @@ dependencies = [ "anyio>=3.7.1,<5", "fast-depends>=2.4.0b0,<2.5.0", "typer>=0.9,!=0.12,<1", - "typing-extensions>=4.8.0,<4.12.1; python_version < '3.9'", - "typing-extensions>=4.8.0; python_version >= '3.9'", + "typing-extensions>=4.8.0", ] [project.optional-dependencies] @@ -112,7 +111,7 @@ types = [ lint = [ "faststream[types]", - "ruff==0.4.7", + "ruff==0.4.8", "bandit==1.7.8", "semgrep==1.74.0", "codespell==2.3.0", @@ -123,6 +122,7 @@ test-core = [ "pytest==8.2.1", "pytest-asyncio==0.23.7", "dirty-equals==0.7.1.post0", + "typing-extensions>=4.8.0,<4.12.1; python_version < '3.9'", # to fix dirty-equals ] testing = [ diff --git a/tests/brokers/nats/test_router.py b/tests/brokers/nats/test_router.py index a0951a06d1..c6c2b60bae 100644 --- a/tests/brokers/nats/test_router.py +++ b/tests/brokers/nats/test_router.py @@ -3,7 +3,7 @@ import pytest from faststream import Path -from faststream.nats import NatsPublisher, NatsRoute, NatsRouter +from faststream.nats import NatsBroker, NatsPublisher, NatsRoute, NatsRouter from tests.brokers.base.router import RouterLocalTestcase, RouterTestcase @@ -136,3 +136,15 @@ class TestRouterLocal(RouterLocalTestcase): broker_class = NatsRouter route_class = NatsRoute publisher_class = NatsPublisher + + def test_include_stream( + self, + router: NatsRouter, + pub_broker: NatsBroker, + ): + @router.subscriber("test", stream="stream") + async def handler(): ... + + pub_broker.include_router(router) + + assert next(iter(pub_broker._stream_builder.objects.keys())) == "stream"