From 9bc7a05c0979bf2017a3ee3c095b84851ee3c7ea Mon Sep 17 00:00:00 2001 From: Flosckow <66554425+Flosckow@users.noreply.github.com> Date: Fri, 22 Nov 2024 16:35:40 +0700 Subject: [PATCH] Feat: Add kafka concurrent subscriber (#1912) * Feat: stage 1 add typing, and mock class for concurrent subscriber * Fix: lint * Feat: stage 2 add concurrent consume * Fix: lint * Feat: change consume to put * Fix: topics, typo * Feat: add tests * docs: generate API References * chore: polish PR * chore: update python version in precommit --------- Co-authored-by: Daniil Dumchenko Co-authored-by: Flosckow Co-authored-by: Nikita Pastukhov --- .github/workflows/pr_tests.yaml | 5 +- docs/docs/SUMMARY.md | 2 + .../AsyncAPIConcurrentDefaultSubscriber.md | 11 ++ .../usecase/ConcurrentDefaultSubscriber.md | 11 ++ faststream/broker/fastapi/get_dependant.py | 4 +- faststream/broker/subscriber/mixins.py | 8 +- faststream/broker/utils.py | 20 +--- faststream/confluent/broker/broker.py | 2 +- faststream/confluent/broker/registrator.py | 108 +++++++++--------- faststream/confluent/client.py | 6 +- faststream/confluent/fastapi/fastapi.py | 2 +- faststream/confluent/publisher/asyncapi.py | 32 ++++-- faststream/confluent/router.py | 2 +- faststream/confluent/subscriber/factory.py | 20 +++- faststream/confluent/subscriber/usecase.py | 7 +- faststream/kafka/broker/registrator.py | 36 ++++-- faststream/kafka/fastapi/fastapi.py | 12 +- faststream/kafka/router.py | 5 + faststream/kafka/subscriber/asyncapi.py | 8 ++ faststream/kafka/subscriber/factory.py | 63 +++++++--- faststream/kafka/subscriber/usecase.py | 73 +++++++++--- faststream/nats/subscriber/usecase.py | 10 +- tests/brokers/kafka/test_consume.py | 41 ++++++- tests/brokers/kafka/test_misconfigure.py | 11 ++ 24 files changed, 342 insertions(+), 157 deletions(-) create mode 100644 docs/docs/en/api/faststream/kafka/subscriber/asyncapi/AsyncAPIConcurrentDefaultSubscriber.md create mode 100644 docs/docs/en/api/faststream/kafka/subscriber/usecase/ConcurrentDefaultSubscriber.md create mode 100644 tests/brokers/kafka/test_misconfigure.py diff --git a/.github/workflows/pr_tests.yaml b/.github/workflows/pr_tests.yaml index e9282d496c..4f7048a8c1 100644 --- a/.github/workflows/pr_tests.yaml +++ b/.github/workflows/pr_tests.yaml @@ -30,10 +30,7 @@ jobs: - uses: actions/checkout@v4 - uses: actions/setup-python@v5 with: - python-version: | - 3.8 - 3.9 - 3.10 + python-version: "3.12" - name: Set $PY environment variable run: echo "PY=$(python -VV | sha256sum | cut -d' ' -f1)" >> $GITHUB_ENV - uses: actions/cache@v4 diff --git a/docs/docs/SUMMARY.md b/docs/docs/SUMMARY.md index 0c51c2d6a0..596d8e0091 100644 --- a/docs/docs/SUMMARY.md +++ b/docs/docs/SUMMARY.md @@ -668,12 +668,14 @@ search: - subscriber - asyncapi - [AsyncAPIBatchSubscriber](api/faststream/kafka/subscriber/asyncapi/AsyncAPIBatchSubscriber.md) + - [AsyncAPIConcurrentDefaultSubscriber](api/faststream/kafka/subscriber/asyncapi/AsyncAPIConcurrentDefaultSubscriber.md) - [AsyncAPIDefaultSubscriber](api/faststream/kafka/subscriber/asyncapi/AsyncAPIDefaultSubscriber.md) - [AsyncAPISubscriber](api/faststream/kafka/subscriber/asyncapi/AsyncAPISubscriber.md) - factory - [create_subscriber](api/faststream/kafka/subscriber/factory/create_subscriber.md) - usecase - [BatchSubscriber](api/faststream/kafka/subscriber/usecase/BatchSubscriber.md) + - [ConcurrentDefaultSubscriber](api/faststream/kafka/subscriber/usecase/ConcurrentDefaultSubscriber.md) - [DefaultSubscriber](api/faststream/kafka/subscriber/usecase/DefaultSubscriber.md) - [LogicSubscriber](api/faststream/kafka/subscriber/usecase/LogicSubscriber.md) - testing diff --git a/docs/docs/en/api/faststream/kafka/subscriber/asyncapi/AsyncAPIConcurrentDefaultSubscriber.md b/docs/docs/en/api/faststream/kafka/subscriber/asyncapi/AsyncAPIConcurrentDefaultSubscriber.md new file mode 100644 index 0000000000..8ce5838961 --- /dev/null +++ b/docs/docs/en/api/faststream/kafka/subscriber/asyncapi/AsyncAPIConcurrentDefaultSubscriber.md @@ -0,0 +1,11 @@ +--- +# 0.5 - API +# 2 - Release +# 3 - Contributing +# 5 - Template Page +# 10 - Default +search: + boost: 0.5 +--- + +::: faststream.kafka.subscriber.asyncapi.AsyncAPIConcurrentDefaultSubscriber diff --git a/docs/docs/en/api/faststream/kafka/subscriber/usecase/ConcurrentDefaultSubscriber.md b/docs/docs/en/api/faststream/kafka/subscriber/usecase/ConcurrentDefaultSubscriber.md new file mode 100644 index 0000000000..16f09d9334 --- /dev/null +++ b/docs/docs/en/api/faststream/kafka/subscriber/usecase/ConcurrentDefaultSubscriber.md @@ -0,0 +1,11 @@ +--- +# 0.5 - API +# 2 - Release +# 3 - Contributing +# 5 - Template Page +# 10 - Default +search: + boost: 0.5 +--- + +::: faststream.kafka.subscriber.usecase.ConcurrentDefaultSubscriber diff --git a/faststream/broker/fastapi/get_dependant.py b/faststream/broker/fastapi/get_dependant.py index 45d5aaba30..663812bda1 100644 --- a/faststream/broker/fastapi/get_dependant.py +++ b/faststream/broker/fastapi/get_dependant.py @@ -89,7 +89,7 @@ def _patch_fastapi_dependent(dependant: "Dependant") -> "Dependant": lambda x: isinstance(x, FieldInfo), p.field_info.metadata or (), ), - Field(**field_data), # type: ignore[pydantic-field] + Field(**field_data), ) else: @@ -109,7 +109,7 @@ def _patch_fastapi_dependent(dependant: "Dependant") -> "Dependant": "le": info.field_info.le, } ) - f = Field(**field_data) # type: ignore[pydantic-field] + f = Field(**field_data) params_unique[p.name] = ( info.annotation, diff --git a/faststream/broker/subscriber/mixins.py b/faststream/broker/subscriber/mixins.py index 24b0fd7e46..f1e2274a3b 100644 --- a/faststream/broker/subscriber/mixins.py +++ b/faststream/broker/subscriber/mixins.py @@ -16,8 +16,8 @@ class TasksMixin(SubscriberUsecase[Any]): - def __init__(self, **kwargs: Any) -> None: - super().__init__(**kwargs) + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) self.tasks: List[asyncio.Task[Any]] = [] def add_task(self, coro: Coroutine[Any, Any, Any]) -> None: @@ -40,7 +40,7 @@ class ConcurrentMixin(TasksMixin): def __init__( self, - *, + *args: Any, max_workers: int, **kwargs: Any, ) -> None: @@ -51,7 +51,7 @@ def __init__( ) self.limiter = anyio.Semaphore(max_workers) - super().__init__(**kwargs) + super().__init__(*args, **kwargs) def start_consume_task(self) -> None: self.add_task(self._serve_consume_queue()) diff --git a/faststream/broker/utils.py b/faststream/broker/utils.py index c12c3fc967..067446de40 100644 --- a/faststream/broker/utils.py +++ b/faststream/broker/utils.py @@ -16,7 +16,7 @@ ) import anyio -from typing_extensions import Literal, Self, overload +from typing_extensions import Self from faststream.broker.acknowledgement_watcher import WatcherContext, get_watcher from faststream.broker.types import MsgType @@ -35,24 +35,6 @@ from faststream.types import LoggerProto -@overload -async def process_msg( - msg: Literal[None], - middlewares: Iterable["BrokerMiddleware[MsgType]"], - parser: Callable[[MsgType], Awaitable["StreamMessage[MsgType]"]], - decoder: Callable[["StreamMessage[MsgType]"], "Any"], -) -> None: ... - - -@overload -async def process_msg( - msg: MsgType, - middlewares: Iterable["BrokerMiddleware[MsgType]"], - parser: Callable[[MsgType], Awaitable["StreamMessage[MsgType]"]], - decoder: Callable[["StreamMessage[MsgType]"], "Any"], -) -> "StreamMessage[MsgType]": ... - - async def process_msg( msg: Optional[MsgType], middlewares: Iterable["BrokerMiddleware[MsgType]"], diff --git a/faststream/confluent/broker/broker.py b/faststream/confluent/broker/broker.py index e5facb8647..329a9440de 100644 --- a/faststream/confluent/broker/broker.py +++ b/faststream/confluent/broker/broker.py @@ -58,7 +58,7 @@ Partition = TypeVar("Partition") -class KafkaBroker( +class KafkaBroker( # type: ignore[misc] KafkaRegistrator, KafkaLoggingBroker, ): diff --git a/faststream/confluent/broker/registrator.py b/faststream/confluent/broker/registrator.py index 5f7d8e1354..38e99fc877 100644 --- a/faststream/confluent/broker/registrator.py +++ b/faststream/confluent/broker/registrator.py @@ -52,10 +52,12 @@ class KafkaRegistrator( ): """Includable to KafkaBroker router.""" - _subscribers: Dict[ + _subscribers: Dict[ # type: ignore[assignment] int, Union["AsyncAPIBatchSubscriber", "AsyncAPIDefaultSubscriber"] ] - _publishers: Dict[int, Union["AsyncAPIBatchPublisher", "AsyncAPIDefaultPublisher"]] + _publishers: Dict[ # type: ignore[assignment] + int, Union["AsyncAPIBatchPublisher", "AsyncAPIDefaultPublisher"] + ] @overload # type: ignore[override] def subscriber( @@ -1193,60 +1195,56 @@ def subscriber( if not auto_commit and not group_id: raise SetupError("You should install `group_id` with manual commit mode") - subscriber = super().subscriber( - create_subscriber( - *topics, - polling_interval=polling_interval, - partitions=partitions, - batch=batch, - max_records=max_records, - group_id=group_id, - connection_data={ - "group_instance_id": group_instance_id, - "fetch_max_wait_ms": fetch_max_wait_ms, - "fetch_max_bytes": fetch_max_bytes, - "fetch_min_bytes": fetch_min_bytes, - "max_partition_fetch_bytes": max_partition_fetch_bytes, - "auto_offset_reset": auto_offset_reset, - "enable_auto_commit": auto_commit, - "auto_commit_interval_ms": auto_commit_interval_ms, - "check_crcs": check_crcs, - "partition_assignment_strategy": partition_assignment_strategy, - "max_poll_interval_ms": max_poll_interval_ms, - "session_timeout_ms": session_timeout_ms, - "heartbeat_interval_ms": heartbeat_interval_ms, - "isolation_level": isolation_level, - }, - is_manual=not auto_commit, - # subscriber args - no_ack=no_ack, - no_reply=no_reply, - retry=retry, - broker_middlewares=self._middlewares, - broker_dependencies=self._dependencies, - # AsyncAPI - title_=title, - description_=description, - include_in_schema=self._solve_include_in_schema(include_in_schema), - ) + subscriber = create_subscriber( + *topics, + polling_interval=polling_interval, + partitions=partitions, + batch=batch, + max_records=max_records, + group_id=group_id, + connection_data={ + "group_instance_id": group_instance_id, + "fetch_max_wait_ms": fetch_max_wait_ms, + "fetch_max_bytes": fetch_max_bytes, + "fetch_min_bytes": fetch_min_bytes, + "max_partition_fetch_bytes": max_partition_fetch_bytes, + "auto_offset_reset": auto_offset_reset, + "enable_auto_commit": auto_commit, + "auto_commit_interval_ms": auto_commit_interval_ms, + "check_crcs": check_crcs, + "partition_assignment_strategy": partition_assignment_strategy, + "max_poll_interval_ms": max_poll_interval_ms, + "session_timeout_ms": session_timeout_ms, + "heartbeat_interval_ms": heartbeat_interval_ms, + "isolation_level": isolation_level, + }, + is_manual=not auto_commit, + # subscriber args + no_ack=no_ack, + no_reply=no_reply, + retry=retry, + broker_middlewares=self._middlewares, + broker_dependencies=self._dependencies, + # AsyncAPI + title_=title, + description_=description, + include_in_schema=self._solve_include_in_schema(include_in_schema), ) if batch: - return cast("AsyncAPIBatchSubscriber", subscriber).add_call( - filter_=filter, - parser_=parser or self._parser, - decoder_=decoder or self._decoder, - dependencies_=dependencies, - middlewares_=middlewares, - ) + subscriber = cast("AsyncAPIBatchSubscriber", subscriber) else: - return cast("AsyncAPIDefaultSubscriber", subscriber).add_call( - filter_=filter, - parser_=parser or self._parser, - decoder_=decoder or self._decoder, - dependencies_=dependencies, - middlewares_=middlewares, - ) + subscriber = cast("AsyncAPIDefaultSubscriber", subscriber) + + subscriber = super().subscriber(subscriber) # type: ignore[arg-type,assignment] + + return subscriber.add_call( + filter_=filter, + parser_=parser or self._parser, + decoder_=decoder or self._decoder, + dependencies_=dependencies, + middlewares_=middlewares, + ) @overload # type: ignore[override] def publisher( @@ -1577,6 +1575,8 @@ def publisher( ) if batch: - return cast("AsyncAPIBatchPublisher", super().publisher(publisher)) + publisher = cast("AsyncAPIBatchPublisher", publisher) else: - return cast("AsyncAPIDefaultPublisher", super().publisher(publisher)) + publisher = cast("AsyncAPIDefaultPublisher", publisher) + + return super().publisher(publisher) # type: ignore[return-value,arg-type] diff --git a/faststream/confluent/client.py b/faststream/confluent/client.py index 3bf60f205d..db6f8370a2 100644 --- a/faststream/confluent/client.py +++ b/faststream/confluent/client.py @@ -112,7 +112,7 @@ def __init__( } ) - self.producer = Producer(final_config, logger=self.logger) + self.producer = Producer(final_config) self.__running = True self._poll_task = asyncio.create_task(self._poll_loop()) @@ -312,7 +312,7 @@ def __init__( ) self.config = final_config - self.consumer = Consumer(final_config, logger=self.logger) + self.consumer = Consumer(final_config) @property def topics_to_create(self) -> List[str]: @@ -381,7 +381,7 @@ async def getmany( ) -> Tuple[Message, ...]: """Consumes a batch of messages from Kafka and groups them by topic and partition.""" raw_messages: List[Optional[Message]] = await call_or_await( - self.consumer.consume, + self.consumer.consume, # type: ignore[arg-type] num_messages=max_records or 10, timeout=timeout, ) diff --git a/faststream/confluent/fastapi/fastapi.py b/faststream/confluent/fastapi/fastapi.py index eacfc7b37a..bc2c2a1d71 100644 --- a/faststream/confluent/fastapi/fastapi.py +++ b/faststream/confluent/fastapi/fastapi.py @@ -564,7 +564,7 @@ def __init__( graceful_timeout=graceful_timeout, decoder=decoder, parser=parser, - middlewares=middlewares, + middlewares=middlewares, # type: ignore[arg-type] schema_url=schema_url, setup_state=setup_state, # logger options diff --git a/faststream/confluent/publisher/asyncapi.py b/faststream/confluent/publisher/asyncapi.py index f82c0a12f9..f41834b9c2 100644 --- a/faststream/confluent/publisher/asyncapi.py +++ b/faststream/confluent/publisher/asyncapi.py @@ -7,6 +7,7 @@ Optional, Tuple, Union, + cast, overload, ) @@ -64,41 +65,41 @@ def get_schema(self) -> Dict[str, Channel]: @staticmethod def create( *, - batch: Literal[True], + batch: Literal[False], key: Optional[bytes], topic: str, partition: Optional[int], headers: Optional[Dict[str, str]], reply_to: str, # Publisher args - broker_middlewares: Iterable["BrokerMiddleware[Tuple[ConfluentMsg, ...]]"], + broker_middlewares: Iterable["BrokerMiddleware[ConfluentMsg]"], middlewares: Iterable["PublisherMiddleware"], # AsyncAPI args schema_: Optional[Any], title_: Optional[str], description_: Optional[str], include_in_schema: bool, - ) -> "AsyncAPIBatchPublisher": ... + ) -> "AsyncAPIDefaultPublisher": ... @overload @staticmethod def create( *, - batch: Literal[False], + batch: Literal[True], key: Optional[bytes], topic: str, partition: Optional[int], headers: Optional[Dict[str, str]], reply_to: str, # Publisher args - broker_middlewares: Iterable["BrokerMiddleware[ConfluentMsg]"], + broker_middlewares: Iterable["BrokerMiddleware[Tuple[ConfluentMsg, ...]]"], middlewares: Iterable["PublisherMiddleware"], # AsyncAPI args schema_: Optional[Any], title_: Optional[str], description_: Optional[str], include_in_schema: bool, - ) -> "AsyncAPIDefaultPublisher": ... + ) -> "AsyncAPIBatchPublisher": ... @overload @staticmethod @@ -111,8 +112,9 @@ def create( headers: Optional[Dict[str, str]], reply_to: str, # Publisher args - broker_middlewares: Iterable[ - "BrokerMiddleware[Union[Tuple[ConfluentMsg, ...], ConfluentMsg]]" + broker_middlewares: Union[ + Iterable["BrokerMiddleware[Tuple[ConfluentMsg, ...]]"], + Iterable["BrokerMiddleware[ConfluentMsg]"], ], middlewares: Iterable["PublisherMiddleware"], # AsyncAPI args @@ -136,8 +138,9 @@ def create( headers: Optional[Dict[str, str]], reply_to: str, # Publisher args - broker_middlewares: Iterable[ - "BrokerMiddleware[Union[Tuple[ConfluentMsg, ...], ConfluentMsg]]" + broker_middlewares: Union[ + Iterable["BrokerMiddleware[Tuple[ConfluentMsg, ...]]"], + Iterable["BrokerMiddleware[ConfluentMsg]"], ], middlewares: Iterable["PublisherMiddleware"], # AsyncAPI args @@ -158,7 +161,10 @@ def create( partition=partition, headers=headers, reply_to=reply_to, - broker_middlewares=broker_middlewares, + broker_middlewares=cast( + Iterable["BrokerMiddleware[Tuple[ConfluentMsg, ...]]"], + broker_middlewares, + ), middlewares=middlewares, schema_=schema_, title_=title_, @@ -173,7 +179,9 @@ def create( partition=partition, headers=headers, reply_to=reply_to, - broker_middlewares=broker_middlewares, + broker_middlewares=cast( + Iterable["BrokerMiddleware[ConfluentMsg]"], broker_middlewares + ), middlewares=middlewares, schema_=schema_, title_=title_, diff --git a/faststream/confluent/router.py b/faststream/confluent/router.py index 2c66a38992..76364390ea 100644 --- a/faststream/confluent/router.py +++ b/faststream/confluent/router.py @@ -518,7 +518,7 @@ def __init__( # basic args prefix=prefix, dependencies=dependencies, - middlewares=middlewares, + middlewares=middlewares, # type: ignore[arg-type] parser=parser, decoder=decoder, include_in_schema=include_in_schema, diff --git a/faststream/confluent/subscriber/factory.py b/faststream/confluent/subscriber/factory.py index dcb7e414b3..b0c72deb8d 100644 --- a/faststream/confluent/subscriber/factory.py +++ b/faststream/confluent/subscriber/factory.py @@ -6,6 +6,7 @@ Sequence, Tuple, Union, + cast, overload, ) @@ -87,8 +88,9 @@ def create_subscriber( no_reply: bool, retry: bool, broker_dependencies: Iterable["Depends"], - broker_middlewares: Iterable[ - "BrokerMiddleware[Union[ConfluentMsg, Tuple[ConfluentMsg, ...]]]" + broker_middlewares: Union[ + Iterable["BrokerMiddleware[Tuple[ConfluentMsg, ...]]"], + Iterable["BrokerMiddleware[ConfluentMsg]"], ], # AsyncAPI args title_: Optional[str], @@ -115,8 +117,9 @@ def create_subscriber( no_reply: bool, retry: bool, broker_dependencies: Iterable["Depends"], - broker_middlewares: Iterable[ - "BrokerMiddleware[Union[ConfluentMsg, Tuple[ConfluentMsg, ...]]]" + broker_middlewares: Union[ + Iterable["BrokerMiddleware[Tuple[ConfluentMsg, ...]]"], + Iterable["BrokerMiddleware[ConfluentMsg]"], ], # AsyncAPI args title_: Optional[str], @@ -139,7 +142,10 @@ def create_subscriber( no_reply=no_reply, retry=retry, broker_dependencies=broker_dependencies, - broker_middlewares=broker_middlewares, + broker_middlewares=cast( + Iterable["BrokerMiddleware[Tuple[ConfluentMsg, ...]]"], + broker_middlewares, + ), title_=title_, description_=description_, include_in_schema=include_in_schema, @@ -156,7 +162,9 @@ def create_subscriber( no_reply=no_reply, retry=retry, broker_dependencies=broker_dependencies, - broker_middlewares=broker_middlewares, + broker_middlewares=cast( + Iterable["BrokerMiddleware[ConfluentMsg]"], broker_middlewares + ), title_=title_, description_=description_, include_in_schema=include_in_schema, diff --git a/faststream/confluent/subscriber/usecase.py b/faststream/confluent/subscriber/usecase.py index 3540be9bdf..2e11b9851a 100644 --- a/faststream/confluent/subscriber/usecase.py +++ b/faststream/confluent/subscriber/usecase.py @@ -173,7 +173,7 @@ async def get_one( self, *, timeout: float = 5.0, - ) -> "Optional[StreamMessage[Message]]": + ) -> "Optional[StreamMessage[MsgType]]": assert self.consumer, "You should start subscriber at first." # nosec B101 assert ( # nosec B101 not self.calls @@ -181,13 +181,12 @@ async def get_one( raw_message = await self.consumer.getone(timeout=timeout) - msg = await process_msg( - msg=raw_message, + return await process_msg( + msg=raw_message, # type: ignore[arg-type] middlewares=self._broker_middlewares, parser=self._parser, decoder=self._decoder, ) - return msg def _make_response_publisher( self, diff --git a/faststream/kafka/broker/registrator.py b/faststream/kafka/broker/registrator.py index 1cb3fa38e2..0c2c3e1ce2 100644 --- a/faststream/kafka/broker/registrator.py +++ b/faststream/kafka/broker/registrator.py @@ -41,6 +41,7 @@ ) from faststream.kafka.subscriber.asyncapi import ( AsyncAPIBatchSubscriber, + AsyncAPIConcurrentDefaultSubscriber, AsyncAPIDefaultSubscriber, ) @@ -57,7 +58,11 @@ class KafkaRegistrator( _subscribers: Dict[ int, - Union["AsyncAPIBatchSubscriber", "AsyncAPIDefaultSubscriber"], + Union[ + "AsyncAPIBatchSubscriber", + "AsyncAPIDefaultSubscriber", + "AsyncAPIConcurrentDefaultSubscriber", + ], ] _publishers: Dict[ int, @@ -1548,6 +1553,10 @@ def subscriber( Iterable["SubscriberMiddleware[KafkaMessage]"], Doc("Subscriber middlewares to wrap incoming message processing."), ] = (), + max_workers: Annotated[ + int, + Doc("Number of workers to process messages concurrently."), + ] = 1, filter: Annotated[ "Filter[KafkaMessage]", Doc( @@ -1592,11 +1601,13 @@ def subscriber( ) -> Union[ "AsyncAPIDefaultSubscriber", "AsyncAPIBatchSubscriber", + "AsyncAPIConcurrentDefaultSubscriber", ]: subscriber = super().subscriber( create_subscriber( *topics, batch=batch, + max_workers=max_workers, batch_timeout_ms=batch_timeout_ms, max_records=max_records, group_id=group_id, @@ -1648,13 +1659,22 @@ def subscriber( ) else: - return cast("AsyncAPIDefaultSubscriber", subscriber).add_call( - filter_=filter, - parser_=parser or self._parser, - decoder_=decoder or self._decoder, - dependencies_=dependencies, - middlewares_=middlewares, - ) + if max_workers > 1: + return cast("AsyncAPIConcurrentDefaultSubscriber", subscriber).add_call( + filter_=filter, + parser_=parser or self._parser, + decoder_=decoder or self._decoder, + dependencies_=dependencies, + middlewares_=middlewares, + ) + else: + return cast("AsyncAPIDefaultSubscriber", subscriber).add_call( + filter_=filter, + parser_=parser or self._parser, + decoder_=decoder or self._decoder, + dependencies_=dependencies, + middlewares_=middlewares, + ) @overload # type: ignore[override] def publisher( diff --git a/faststream/kafka/fastapi/fastapi.py b/faststream/kafka/fastapi/fastapi.py index 17b8c03192..5bad796902 100644 --- a/faststream/kafka/fastapi/fastapi.py +++ b/faststream/kafka/fastapi/fastapi.py @@ -60,6 +60,7 @@ ) from faststream.kafka.subscriber.asyncapi import ( AsyncAPIBatchSubscriber, + AsyncAPIConcurrentDefaultSubscriber, AsyncAPIDefaultSubscriber, ) from faststream.security import BaseSecurity @@ -2618,13 +2619,19 @@ def subscriber( """ ), ] = False, + max_workers: Annotated[ + int, + Doc("Number of workers to process messages concurrently."), + ] = 1, ) -> Union[ "AsyncAPIBatchSubscriber", "AsyncAPIDefaultSubscriber", + "AsyncAPIConcurrentDefaultSubscriber", ]: subscriber = super().subscriber( *topics, group_id=group_id, + max_workers=max_workers, key_deserializer=key_deserializer, value_deserializer=value_deserializer, fetch_max_wait_ms=fetch_max_wait_ms, @@ -2675,7 +2682,10 @@ def subscriber( if batch: return cast("AsyncAPIBatchSubscriber", subscriber) else: - return cast("AsyncAPIDefaultSubscriber", subscriber) + if max_workers > 1: + return cast("AsyncAPIConcurrentDefaultSubscriber", subscriber) + else: + return cast("AsyncAPIDefaultSubscriber", subscriber) @overload # type: ignore[override] def publisher( diff --git a/faststream/kafka/router.py b/faststream/kafka/router.py index cef54442c8..102240e2ca 100644 --- a/faststream/kafka/router.py +++ b/faststream/kafka/router.py @@ -525,11 +525,16 @@ def __init__( bool, Doc("Whetever to include operation in AsyncAPI schema or not."), ] = True, + max_workers: Annotated[ + int, + Doc("Number of workers to process messages concurrently."), + ] = 1, ) -> None: super().__init__( call, *topics, publishers=publishers, + max_workers=max_workers, group_id=group_id, key_deserializer=key_deserializer, value_deserializer=value_deserializer, diff --git a/faststream/kafka/subscriber/asyncapi.py b/faststream/kafka/subscriber/asyncapi.py index 87fff1c232..1c3ad53ce7 100644 --- a/faststream/kafka/subscriber/asyncapi.py +++ b/faststream/kafka/subscriber/asyncapi.py @@ -17,6 +17,7 @@ from faststream.broker.types import MsgType from faststream.kafka.subscriber.usecase import ( BatchSubscriber, + ConcurrentDefaultSubscriber, DefaultSubscriber, LogicSubscriber, ) @@ -72,3 +73,10 @@ class AsyncAPIBatchSubscriber( AsyncAPISubscriber[Tuple["ConsumerRecord", ...]], ): pass + + +class AsyncAPIConcurrentDefaultSubscriber( + AsyncAPISubscriber["ConsumerRecord"], + ConcurrentDefaultSubscriber, +): + pass diff --git a/faststream/kafka/subscriber/factory.py b/faststream/kafka/subscriber/factory.py index 0f504667f4..162866cf39 100644 --- a/faststream/kafka/subscriber/factory.py +++ b/faststream/kafka/subscriber/factory.py @@ -11,6 +11,7 @@ from faststream.exceptions import SetupError from faststream.kafka.subscriber.asyncapi import ( AsyncAPIBatchSubscriber, + AsyncAPIConcurrentDefaultSubscriber, AsyncAPIDefaultSubscriber, ) @@ -37,6 +38,7 @@ def create_subscriber( partitions: Iterable["TopicPartition"], is_manual: bool, # Subscriber args + max_workers: int, no_ack: bool, no_reply: bool, retry: bool, @@ -63,6 +65,7 @@ def create_subscriber( partitions: Iterable["TopicPartition"], is_manual: bool, # Subscriber args + max_workers: int, no_ack: bool, no_reply: bool, retry: bool, @@ -89,6 +92,7 @@ def create_subscriber( partitions: Iterable["TopicPartition"], is_manual: bool, # Subscriber args + max_workers: int, no_ack: bool, no_reply: bool, retry: bool, @@ -119,6 +123,7 @@ def create_subscriber( partitions: Iterable["TopicPartition"], is_manual: bool, # Subscriber args + max_workers: int, no_ack: bool, no_reply: bool, retry: bool, @@ -133,10 +138,14 @@ def create_subscriber( ) -> Union[ "AsyncAPIDefaultSubscriber", "AsyncAPIBatchSubscriber", + "AsyncAPIConcurrentDefaultSubscriber", ]: if is_manual and not group_id: raise SetupError("You must use `group_id` with manual commit mode.") + if is_manual and max_workers > 1: + raise SetupError("Max workers not work with manual commit mode.") + if not topics and not partitions and not pattern: raise SetupError( "You should provide either `topics` or `partitions` or `pattern`." @@ -170,20 +179,40 @@ def create_subscriber( ) else: - return AsyncAPIDefaultSubscriber( - *topics, - group_id=group_id, - listener=listener, - pattern=pattern, - connection_args=connection_args, - partitions=partitions, - is_manual=is_manual, - no_ack=no_ack, - no_reply=no_reply, - retry=retry, - broker_dependencies=broker_dependencies, - broker_middlewares=broker_middlewares, - title_=title_, - description_=description_, - include_in_schema=include_in_schema, - ) + if max_workers > 1: + return AsyncAPIConcurrentDefaultSubscriber( + *topics, + max_workers=max_workers, + group_id=group_id, + listener=listener, + pattern=pattern, + connection_args=connection_args, + partitions=partitions, + is_manual=is_manual, + no_ack=no_ack, + no_reply=no_reply, + retry=retry, + broker_dependencies=broker_dependencies, + broker_middlewares=broker_middlewares, + title_=title_, + description_=description_, + include_in_schema=include_in_schema, + ) + else: + return AsyncAPIDefaultSubscriber( + *topics, + group_id=group_id, + listener=listener, + pattern=pattern, + connection_args=connection_args, + partitions=partitions, + is_manual=is_manual, + no_ack=no_ack, + no_reply=no_reply, + retry=retry, + broker_dependencies=broker_dependencies, + broker_middlewares=broker_middlewares, + title_=title_, + description_=description_, + include_in_schema=include_in_schema, + ) diff --git a/faststream/kafka/subscriber/usecase.py b/faststream/kafka/subscriber/usecase.py index b14e107faf..2bd4d7162e 100644 --- a/faststream/kafka/subscriber/usecase.py +++ b/faststream/kafka/subscriber/usecase.py @@ -1,4 +1,3 @@ -import asyncio from abc import ABC, abstractmethod from itertools import chain from typing import ( @@ -19,6 +18,7 @@ from typing_extensions import override from faststream.broker.publisher.fake import FakePublisher +from faststream.broker.subscriber.mixins import ConcurrentMixin, TasksMixin from faststream.broker.subscriber.usecase import SubscriberUsecase from faststream.broker.types import ( AsyncCallable, @@ -41,7 +41,7 @@ from faststream.types import AnyDict, Decorator, LoggerProto -class LogicSubscriber(ABC, SubscriberUsecase[MsgType]): +class LogicSubscriber(ABC, TasksMixin, SubscriberUsecase[MsgType]): """A class to handle logic for consuming messages from Kafka.""" topics: Sequence[str] @@ -50,7 +50,6 @@ class LogicSubscriber(ABC, SubscriberUsecase[MsgType]): builder: Optional[Callable[..., "AIOKafkaConsumer"]] consumer: Optional["AIOKafkaConsumer"] - task: Optional["asyncio.Task[None]"] client_id: Optional[str] batch: bool @@ -104,7 +103,6 @@ def __init__( self.builder = None self.consumer = None - self.task = None @override def setup( # type: ignore[override] @@ -166,7 +164,7 @@ async def start(self) -> None: await super().start() if self.calls: - self.task = asyncio.create_task(self._consume()) + self.add_task(self._consume()) async def close(self) -> None: await super().close() @@ -175,11 +173,6 @@ async def close(self) -> None: await self.consumer.stop() self.consumer = None - if self.task is not None and not self.task.done(): - self.task.cancel() - - self.task = None - @override async def get_one( self, @@ -200,13 +193,12 @@ async def get_one( ((raw_message,),) = raw_messages.values() - msg: StreamMessage[MsgType] = await process_msg( + return await process_msg( msg=raw_message, middlewares=self._broker_middlewares, parser=self._parser, decoder=self._decoder, ) - return msg def _make_response_publisher( self, @@ -250,7 +242,10 @@ async def _consume(self) -> None: connected = True if msg: - await self.consume(msg) + await self.consume_one(msg) + + async def consume_one(self, msg: MsgType) -> None: + await self.consume(msg) @staticmethod def get_routing_hash( @@ -471,3 +466,55 @@ def get_log_context( topic=topic, group_id=self.group_id, ) + + +class ConcurrentDefaultSubscriber(ConcurrentMixin, DefaultSubscriber): + def __init__( + self, + *topics: str, + # Kafka information + group_id: Optional[str], + listener: Optional["ConsumerRebalanceListener"], + pattern: Optional[str], + connection_args: "AnyDict", + partitions: Iterable["TopicPartition"], + is_manual: bool, + # Subscriber args + max_workers: int, + no_ack: bool, + no_reply: bool, + retry: bool, + broker_dependencies: Iterable["Depends"], + broker_middlewares: Iterable["BrokerMiddleware[ConsumerRecord]"], + # AsyncAPI args + title_: Optional[str], + description_: Optional[str], + include_in_schema: bool, + ) -> None: + super().__init__( + *topics, + group_id=group_id, + listener=listener, + pattern=pattern, + connection_args=connection_args, + partitions=partitions, + is_manual=is_manual, + # Propagated args + no_ack=no_ack, + no_reply=no_reply, + retry=retry, + broker_middlewares=broker_middlewares, + broker_dependencies=broker_dependencies, + # AsyncAPI args + title_=title_, + description_=description_, + include_in_schema=include_in_schema, + max_workers=max_workers, + ) + + async def start(self) -> None: + await super().start() + self.start_consume_task() + + async def consume_one(self, msg: "ConsumerRecord") -> None: + await self._put_msg(msg) diff --git a/faststream/nats/subscriber/usecase.py b/faststream/nats/subscriber/usecase.py index efbe0342b9..5d16299f14 100644 --- a/faststream/nats/subscriber/usecase.py +++ b/faststream/nats/subscriber/usecase.py @@ -1034,7 +1034,7 @@ async def get_one( else: fetch_sub = self._fetch_sub - raw_message = None + raw_message: Optional[KeyValue.Entry] = None sleep_interval = timeout / 10 with anyio.move_on_after(timeout): while ( # noqa: ASYNC110 @@ -1042,13 +1042,12 @@ async def get_one( ) is None: await anyio.sleep(sleep_interval) - msg: NatsKvMessage = await process_msg( + return await process_msg( # type: ignore[return-value] msg=raw_message, middlewares=self._broker_middlewares, parser=self._parser, decoder=self._decoder, ) - return msg @override async def _create_subscription( @@ -1192,7 +1191,7 @@ async def get_one( else: fetch_sub = self._fetch_sub - raw_message = None + raw_message: Optional[ObjectInfo] = None sleep_interval = timeout / 10 with anyio.move_on_after(timeout): while ( # noqa: ASYNC110 @@ -1200,13 +1199,12 @@ async def get_one( ) is None: await anyio.sleep(sleep_interval) - msg: NatsObjMessage = await process_msg( + return await process_msg( # type: ignore[return-value] msg=raw_message, middlewares=self._broker_middlewares, parser=self._parser, decoder=self._decoder, ) - return msg @override async def _create_subscription( diff --git a/tests/brokers/kafka/test_consume.py b/tests/brokers/kafka/test_consume.py index 7da9f90a5f..9593f61a48 100644 --- a/tests/brokers/kafka/test_consume.py +++ b/tests/brokers/kafka/test_consume.py @@ -1,5 +1,5 @@ import asyncio -from unittest.mock import patch +from unittest.mock import MagicMock, patch import pytest from aiokafka import AIOKafkaConsumer @@ -312,3 +312,42 @@ async def handler(msg: KafkaMessage): m.mock.assert_not_called() assert event.is_set() + + @pytest.mark.asyncio + @pytest.mark.slow + async def test_concurrent_consume(self, queue: str, mock: MagicMock): + event = asyncio.Event() + event2 = asyncio.Event() + + consume_broker = self.get_broker() + + args, kwargs = self.get_subscriber_params(queue, max_workers=2) + + @consume_broker.subscriber(*args, **kwargs) + async def handler(msg): + mock() + if event.is_set(): + event2.set() + else: + event.set() + + # probably, we should increase it + await asyncio.sleep(0.1) + + async with self.patch_broker(consume_broker) as br: + await br.start() + + for i in range(5): + await br.publish(i, queue) + + await asyncio.wait( + ( + asyncio.create_task(event.wait()), + asyncio.create_task(event2.wait()), + ), + timeout=3, + ) + + assert event.is_set() + assert event2.is_set() + assert mock.call_count == 2, mock.call_count diff --git a/tests/brokers/kafka/test_misconfigure.py b/tests/brokers/kafka/test_misconfigure.py new file mode 100644 index 0000000000..771c45426f --- /dev/null +++ b/tests/brokers/kafka/test_misconfigure.py @@ -0,0 +1,11 @@ +import pytest + +from faststream.exceptions import SetupError +from faststream.kafka import KafkaBroker + + +def test_max_workers_with_manual(queue: str) -> None: + broker = KafkaBroker() + + with pytest.raises(SetupError): + broker.subscriber(queue, max_workers=3, auto_commit=False)