From 865185e790808a11b07b21c56b1f256bbdfa992b Mon Sep 17 00:00:00 2001 From: Pastukhov Nikita Date: Fri, 20 Sep 2024 18:35:41 +0300 Subject: [PATCH] logging fsm * feat: refactor logging with FSM * refactor: delete subscribers without dict * refactor: new kafka logging * refactor: new confluent logging * docs: generate API References * tests: fix tests * tests: fix in-memory mocks * confluent: fix self.logger usage * confluent: fix logger usage * confluent: fix logger usage * confluent: make laze logging * confluent: check producer before setup * confluent: fix producer * tests: correct setup call order * fix: remove Confluent producer logger * fix: remove useless option * tests: fix confluent * fix confluent --------- Co-authored-by: Lancetnik --- docs/docs/SUMMARY.md | 12 +- ...LoggingBroker.md => KafkaParamsStorage.md} | 2 +- .../IncorrectState.md} | 2 +- .../broker/logging/KafkaParamsStorage.md} | 2 +- .../logging/LoggingMiddleware.md} | 2 +- .../broker/logging/NatsParamsStorage.md} | 2 +- .../broker/logging/RabbitParamsStorage.md | 11 ++ .../broker/logging/RedisParamsStorage.md | 11 ++ faststream/_internal/broker/abc_broker.py | 60 +++---- faststream/_internal/broker/broker.py | 133 +++++++-------- faststream/_internal/broker/logging_mixin.py | 93 ---------- faststream/_internal/cli/main.py | 1 + faststream/_internal/cli/utils/logs.py | 9 +- faststream/_internal/constants.py | 6 +- faststream/_internal/fastapi/router.py | 5 +- faststream/_internal/log/logging.py | 22 +-- faststream/_internal/proto.py | 5 +- faststream/_internal/setup/__init__.py | 17 ++ faststream/_internal/setup/fast_depends.py | 13 ++ faststream/_internal/setup/logger.py | 159 ++++++++++++++++++ faststream/_internal/setup/proto.py | 7 + faststream/_internal/setup/state.py | 99 +++++++++++ faststream/_internal/subscriber/call_item.py | 2 +- faststream/_internal/subscriber/usecase.py | 17 +- faststream/_internal/testing/broker.py | 14 +- faststream/app.py | 5 +- faststream/confluent/broker/broker.py | 46 ++--- faststream/confluent/broker/logging.py | 111 ++++++------ faststream/confluent/client.py | 33 ++-- faststream/confluent/subscriber/usecase.py | 13 +- faststream/confluent/testing.py | 4 + faststream/exceptions.py | 4 + faststream/kafka/broker/broker.py | 34 ++-- faststream/kafka/broker/logging.py | 110 ++++++------ faststream/kafka/subscriber/usecase.py | 13 +- faststream/middlewares/logging.py | 75 ++++----- faststream/nats/broker/broker.py | 54 +++--- faststream/nats/broker/logging.py | 118 ++++++------- faststream/nats/subscriber/usecase.py | 12 +- faststream/rabbit/broker/broker.py | 56 +++--- faststream/rabbit/broker/logging.py | 97 +++++------ faststream/rabbit/subscriber/usecase.py | 14 +- faststream/redis/broker/broker.py | 31 ++-- faststream/redis/broker/logging.py | 87 +++++----- faststream/redis/subscriber/usecase.py | 13 +- .../specification/asyncapi/v2_6_0/generate.py | 2 +- .../specification/asyncapi/v3_0_0/generate.py | 2 +- faststream/specification/proto.py | 13 ++ tests/brokers/base/testclient.py | 18 ++ tests/brokers/confluent/test_logger.py | 23 +-- tests/cli/rabbit/test_app.py | 68 +++++--- tests/cli/rabbit/test_logs.py | 7 +- tests/cli/test_publish.py | 42 ++--- tests/opentelemetry/basic.py | 35 ++-- .../opentelemetry/confluent/test_confluent.py | 38 ++--- tests/opentelemetry/kafka/test_kafka.py | 44 +++-- tests/opentelemetry/nats/test_nats.py | 21 ++- tests/opentelemetry/rabbit/test_rabbit.py | 12 +- tests/opentelemetry/redis/test_redis.py | 47 +++--- 59 files changed, 1116 insertions(+), 892 deletions(-) rename docs/docs/en/api/faststream/confluent/broker/logging/{KafkaLoggingBroker.md => KafkaParamsStorage.md} (65%) rename docs/docs/en/api/faststream/{nats/broker/logging/NatsLoggingBroker.md => exceptions/IncorrectState.md} (67%) rename docs/docs/en/api/faststream/{redis/broker/logging/RedisLoggingBroker.md => kafka/broker/logging/KafkaParamsStorage.md} (66%) rename docs/docs/en/api/faststream/{kafka/broker/logging/KafkaLoggingBroker.md => middlewares/logging/LoggingMiddleware.md} (66%) rename docs/docs/en/api/faststream/{rabbit/broker/logging/RabbitLoggingBroker.md => nats/broker/logging/NatsParamsStorage.md} (65%) create mode 100644 docs/docs/en/api/faststream/rabbit/broker/logging/RabbitParamsStorage.md create mode 100644 docs/docs/en/api/faststream/redis/broker/logging/RedisParamsStorage.md delete mode 100644 faststream/_internal/broker/logging_mixin.py create mode 100644 faststream/_internal/setup/__init__.py create mode 100644 faststream/_internal/setup/fast_depends.py create mode 100644 faststream/_internal/setup/logger.py create mode 100644 faststream/_internal/setup/proto.py create mode 100644 faststream/_internal/setup/state.py diff --git a/docs/docs/SUMMARY.md b/docs/docs/SUMMARY.md index 5739aad3e6..ee574189c4 100644 --- a/docs/docs/SUMMARY.md +++ b/docs/docs/SUMMARY.md @@ -245,7 +245,7 @@ search: - broker - [KafkaBroker](api/faststream/confluent/broker/broker/KafkaBroker.md) - logging - - [KafkaLoggingBroker](api/faststream/confluent/broker/logging/KafkaLoggingBroker.md) + - [KafkaParamsStorage](api/faststream/confluent/broker/logging/KafkaParamsStorage.md) - registrator - [KafkaRegistrator](api/faststream/confluent/broker/registrator/KafkaRegistrator.md) - client @@ -335,6 +335,7 @@ search: - [FastStreamException](api/faststream/exceptions/FastStreamException.md) - [HandlerException](api/faststream/exceptions/HandlerException.md) - [IgnoredException](api/faststream/exceptions/IgnoredException.md) + - [IncorrectState](api/faststream/exceptions/IncorrectState.md) - [NackMessage](api/faststream/exceptions/NackMessage.md) - [OperationForbiddenError](api/faststream/exceptions/OperationForbiddenError.md) - [RejectMessage](api/faststream/exceptions/RejectMessage.md) @@ -358,7 +359,7 @@ search: - broker - [KafkaBroker](api/faststream/kafka/broker/broker/KafkaBroker.md) - logging - - [KafkaLoggingBroker](api/faststream/kafka/broker/logging/KafkaLoggingBroker.md) + - [KafkaParamsStorage](api/faststream/kafka/broker/logging/KafkaParamsStorage.md) - registrator - [KafkaRegistrator](api/faststream/kafka/broker/registrator/KafkaRegistrator.md) - fastapi @@ -444,6 +445,7 @@ search: - [ignore_handler](api/faststream/middlewares/exception/ignore_handler.md) - logging - [CriticalLogMiddleware](api/faststream/middlewares/logging/CriticalLogMiddleware.md) + - [LoggingMiddleware](api/faststream/middlewares/logging/LoggingMiddleware.md) - nats - [AckPolicy](api/faststream/nats/AckPolicy.md) - [ConsumerConfig](api/faststream/nats/ConsumerConfig.md) @@ -473,7 +475,7 @@ search: - broker - [NatsBroker](api/faststream/nats/broker/broker/NatsBroker.md) - logging - - [NatsLoggingBroker](api/faststream/nats/broker/logging/NatsLoggingBroker.md) + - [NatsParamsStorage](api/faststream/nats/broker/logging/NatsParamsStorage.md) - registrator - [NatsRegistrator](api/faststream/nats/broker/registrator/NatsRegistrator.md) - fastapi @@ -617,7 +619,7 @@ search: - broker - [RabbitBroker](api/faststream/rabbit/broker/broker/RabbitBroker.md) - logging - - [RabbitLoggingBroker](api/faststream/rabbit/broker/logging/RabbitLoggingBroker.md) + - [RabbitParamsStorage](api/faststream/rabbit/broker/logging/RabbitParamsStorage.md) - registrator - [RabbitRegistrator](api/faststream/rabbit/broker/registrator/RabbitRegistrator.md) - fastapi @@ -699,7 +701,7 @@ search: - broker - [RedisBroker](api/faststream/redis/broker/broker/RedisBroker.md) - logging - - [RedisLoggingBroker](api/faststream/redis/broker/logging/RedisLoggingBroker.md) + - [RedisParamsStorage](api/faststream/redis/broker/logging/RedisParamsStorage.md) - registrator - [RedisRegistrator](api/faststream/redis/broker/registrator/RedisRegistrator.md) - fastapi diff --git a/docs/docs/en/api/faststream/confluent/broker/logging/KafkaLoggingBroker.md b/docs/docs/en/api/faststream/confluent/broker/logging/KafkaParamsStorage.md similarity index 65% rename from docs/docs/en/api/faststream/confluent/broker/logging/KafkaLoggingBroker.md rename to docs/docs/en/api/faststream/confluent/broker/logging/KafkaParamsStorage.md index ea238b6b85..900f945037 100644 --- a/docs/docs/en/api/faststream/confluent/broker/logging/KafkaLoggingBroker.md +++ b/docs/docs/en/api/faststream/confluent/broker/logging/KafkaParamsStorage.md @@ -8,4 +8,4 @@ search: boost: 0.5 --- -::: faststream.confluent.broker.logging.KafkaLoggingBroker +::: faststream.confluent.broker.logging.KafkaParamsStorage diff --git a/docs/docs/en/api/faststream/nats/broker/logging/NatsLoggingBroker.md b/docs/docs/en/api/faststream/exceptions/IncorrectState.md similarity index 67% rename from docs/docs/en/api/faststream/nats/broker/logging/NatsLoggingBroker.md rename to docs/docs/en/api/faststream/exceptions/IncorrectState.md index cd31396a61..2c890d5358 100644 --- a/docs/docs/en/api/faststream/nats/broker/logging/NatsLoggingBroker.md +++ b/docs/docs/en/api/faststream/exceptions/IncorrectState.md @@ -8,4 +8,4 @@ search: boost: 0.5 --- -::: faststream.nats.broker.logging.NatsLoggingBroker +::: faststream.exceptions.IncorrectState diff --git a/docs/docs/en/api/faststream/redis/broker/logging/RedisLoggingBroker.md b/docs/docs/en/api/faststream/kafka/broker/logging/KafkaParamsStorage.md similarity index 66% rename from docs/docs/en/api/faststream/redis/broker/logging/RedisLoggingBroker.md rename to docs/docs/en/api/faststream/kafka/broker/logging/KafkaParamsStorage.md index 58500b3c1f..f7c8136115 100644 --- a/docs/docs/en/api/faststream/redis/broker/logging/RedisLoggingBroker.md +++ b/docs/docs/en/api/faststream/kafka/broker/logging/KafkaParamsStorage.md @@ -8,4 +8,4 @@ search: boost: 0.5 --- -::: faststream.redis.broker.logging.RedisLoggingBroker +::: faststream.kafka.broker.logging.KafkaParamsStorage diff --git a/docs/docs/en/api/faststream/kafka/broker/logging/KafkaLoggingBroker.md b/docs/docs/en/api/faststream/middlewares/logging/LoggingMiddleware.md similarity index 66% rename from docs/docs/en/api/faststream/kafka/broker/logging/KafkaLoggingBroker.md rename to docs/docs/en/api/faststream/middlewares/logging/LoggingMiddleware.md index 1f8d5921b7..62e6dfa604 100644 --- a/docs/docs/en/api/faststream/kafka/broker/logging/KafkaLoggingBroker.md +++ b/docs/docs/en/api/faststream/middlewares/logging/LoggingMiddleware.md @@ -8,4 +8,4 @@ search: boost: 0.5 --- -::: faststream.kafka.broker.logging.KafkaLoggingBroker +::: faststream.middlewares.logging.LoggingMiddleware diff --git a/docs/docs/en/api/faststream/rabbit/broker/logging/RabbitLoggingBroker.md b/docs/docs/en/api/faststream/nats/broker/logging/NatsParamsStorage.md similarity index 65% rename from docs/docs/en/api/faststream/rabbit/broker/logging/RabbitLoggingBroker.md rename to docs/docs/en/api/faststream/nats/broker/logging/NatsParamsStorage.md index a3b3151d4b..25b77d4331 100644 --- a/docs/docs/en/api/faststream/rabbit/broker/logging/RabbitLoggingBroker.md +++ b/docs/docs/en/api/faststream/nats/broker/logging/NatsParamsStorage.md @@ -8,4 +8,4 @@ search: boost: 0.5 --- -::: faststream.rabbit.broker.logging.RabbitLoggingBroker +::: faststream.nats.broker.logging.NatsParamsStorage diff --git a/docs/docs/en/api/faststream/rabbit/broker/logging/RabbitParamsStorage.md b/docs/docs/en/api/faststream/rabbit/broker/logging/RabbitParamsStorage.md new file mode 100644 index 0000000000..e9e46da6af --- /dev/null +++ b/docs/docs/en/api/faststream/rabbit/broker/logging/RabbitParamsStorage.md @@ -0,0 +1,11 @@ +--- +# 0.5 - API +# 2 - Release +# 3 - Contributing +# 5 - Template Page +# 10 - Default +search: + boost: 0.5 +--- + +::: faststream.rabbit.broker.logging.RabbitParamsStorage diff --git a/docs/docs/en/api/faststream/redis/broker/logging/RedisParamsStorage.md b/docs/docs/en/api/faststream/redis/broker/logging/RedisParamsStorage.md new file mode 100644 index 0000000000..b7d1bb680a --- /dev/null +++ b/docs/docs/en/api/faststream/redis/broker/logging/RedisParamsStorage.md @@ -0,0 +1,11 @@ +--- +# 0.5 - API +# 2 - Release +# 3 - Contributing +# 5 - Template Page +# 10 - Default +search: + boost: 0.5 +--- + +::: faststream.redis.broker.logging.RedisParamsStorage diff --git a/faststream/_internal/broker/abc_broker.py b/faststream/_internal/broker/abc_broker.py index 9e9f268038..9bd6961515 100644 --- a/faststream/_internal/broker/abc_broker.py +++ b/faststream/_internal/broker/abc_broker.py @@ -90,43 +90,37 @@ def include_router( for h in router._subscribers.values(): h.add_prefix("".join((self.prefix, prefix))) - if (key := hash(h)) not in self._subscribers: - if include_in_schema is None: - h.include_in_schema = self._solve_include_in_schema( - h.include_in_schema - ) - else: - h.include_in_schema = include_in_schema - - h._broker_middlewares = ( - *self._middlewares, - *middlewares, - *h._broker_middlewares, - ) - h._broker_dependencies = ( - *self._dependencies, - *dependencies, - *h._broker_dependencies, - ) - self._subscribers = {**self._subscribers, key: h} + if include_in_schema is None: + h.include_in_schema = self._solve_include_in_schema(h.include_in_schema) + else: + h.include_in_schema = include_in_schema + + h._broker_middlewares = ( + *self._middlewares, + *middlewares, + *h._broker_middlewares, + ) + h._broker_dependencies = ( + *self._dependencies, + *dependencies, + *h._broker_dependencies, + ) + self._subscribers = {**self._subscribers, hash(h): h} for p in router._publishers.values(): p.add_prefix(self.prefix) - if (key := hash(p)) not in self._publishers: - if include_in_schema is None: - p.include_in_schema = self._solve_include_in_schema( - p.include_in_schema - ) - else: - p.include_in_schema = include_in_schema - - p._broker_middlewares = ( - *self._middlewares, - *middlewares, - *p._broker_middlewares, - ) - self._publishers = {**self._publishers, key: p} + if include_in_schema is None: + p.include_in_schema = self._solve_include_in_schema(p.include_in_schema) + else: + p.include_in_schema = include_in_schema + + p._broker_middlewares = ( + *self._middlewares, + *middlewares, + *p._broker_middlewares, + ) + self._publishers = {**self._publishers, hash(p): p} def include_routers( self, diff --git a/faststream/_internal/broker/broker.py b/faststream/_internal/broker/broker.py index 20d6424246..fb19fd61d1 100644 --- a/faststream/_internal/broker/broker.py +++ b/faststream/_internal/broker/broker.py @@ -1,4 +1,3 @@ -import logging from abc import abstractmethod from contextlib import AsyncExitStack from functools import partial @@ -19,8 +18,14 @@ from typing_extensions import Annotated, Doc, Self from faststream._internal._compat import is_test_env -from faststream._internal.log.logging import set_logger_fmt -from faststream._internal.proto import SetupAble +from faststream._internal.setup import ( + EmptyState, + FastDependsData, + LoggerState, + SetupAble, + SetupState, +) +from faststream._internal.setup.state import BaseState from faststream._internal.subscriber.proto import SubscriberProto from faststream._internal.types import ( AsyncCustomCallable, @@ -33,14 +38,14 @@ from faststream.exceptions import NOT_CONNECTED_YET from faststream.middlewares.logging import CriticalLogMiddleware -from .logging_mixin import LoggingBroker +from .abc_broker import ABCBroker if TYPE_CHECKING: from types import TracebackType from fast_depends.dependencies import Depends - from faststream._internal.basic_types import AnyDict, Decorator, LoggerProto + from faststream._internal.basic_types import AnyDict, Decorator from faststream._internal.publisher.proto import ( ProducerProto, PublisherProto, @@ -51,7 +56,7 @@ class BrokerUsecase( - LoggingBroker[MsgType], + ABCBroker[MsgType], SetupAble, Generic[MsgType, ConnectionType], ): @@ -60,6 +65,7 @@ class BrokerUsecase( url: Union[str, Sequence[str]] _connection: Optional[ConnectionType] _producer: Optional["ProducerProto"] + _state: BaseState def __init__( self, @@ -87,22 +93,7 @@ def __init__( ), ], # Logging args - default_logger: Annotated[ - logging.Logger, - Doc("Logger object to use if `logger` is not set."), - ], - logger: Annotated[ - Optional["LoggerProto"], - Doc("User specified logger to pass into Context and log service messages."), - ], - log_level: Annotated[ - int, - Doc("Service messages log level."), - ], - log_fmt: Annotated[ - Optional[str], - Doc("Default logger log format."), - ], + logger_state: LoggerState, # FastDepends args apply_types: Annotated[ bool, @@ -163,11 +154,6 @@ def __init__( # Broker is a root router include_in_schema=True, prefix="", - # Logging args - default_logger=default_logger, - log_level=log_level, - log_fmt=log_fmt, - logger=logger, ) self.running = False @@ -180,15 +166,19 @@ def __init__( # TODO: remove useless middleware filter if not is_test_env(): self._middlewares = ( - CriticalLogMiddleware(self.logger, log_level), + CriticalLogMiddleware(logger_state), *self._middlewares, ) - # FastDepends args - self._is_apply_types = apply_types - self._is_validate = validate - self._get_dependant = _get_dependant - self._call_decorators = _call_decorators + self._state = EmptyState( + depends_params=FastDependsData( + apply_types=apply_types, + is_validate=validate, + get_dependent=_get_dependant, + call_decorators=_call_decorators, + ), + logger_state=logger_state, + ) # AsyncAPI information self.url = specification_url @@ -213,8 +203,13 @@ async def __aexit__( @abstractmethod async def start(self) -> None: """Start the broker async use case.""" - self._abc_start() - await self.connect() + # TODO: filter by already running handlers after TestClient refactor + for handler in self._subscribers.values(): + self._state.logger_state.log( + f"`{handler.call_name}` waiting for messages", + extra=handler.get_log_context(None), + ) + await handler.start() async def connect(self, **kwargs: Any) -> ConnectionType: """Connect to a remote server.""" @@ -222,7 +217,7 @@ async def connect(self, **kwargs: Any) -> ConnectionType: connection_kwargs = self._connection_kwargs.copy() connection_kwargs.update(kwargs) self._connection = await self._connect(**connection_kwargs) - self._setup() + return self._connection @abstractmethod @@ -230,8 +225,33 @@ async def _connect(self) -> ConnectionType: """Connect to a resource.""" raise NotImplementedError() - def _setup(self) -> None: + def _setup(self, state: Optional[BaseState] = None) -> None: """Prepare all Broker entities to startup.""" + if not self._state: + # Fallback to default state if there no + # parent container like FastStream object + default_state = self._state.copy_to_state(SetupState) + + if state: + self._state = state.copy_with_params( + depends_params=default_state.depends_params, + logger_state=default_state.logger_state, + ) + else: + self._state = default_state + + if not self.running: + self.running = True + + for h in self._subscribers.values(): + log_context = h.get_log_context(None) + log_context.pop("message_id", None) + self._state.logger_state.params_storage.setup_log_contest(log_context) + + self._state._setup() + + # TODO: why we can't move it to running? + # TODO: can we setup subscriber in running broker automatically? for h in self._subscribers.values(): self.setup_subscriber(h) @@ -261,21 +281,18 @@ def setup_publisher( @property def _subscriber_setup_extra(self) -> "AnyDict": return { - "logger": self.logger, + "logger": self._state.logger_state.logger.logger, "producer": self._producer, "graceful_timeout": self.graceful_timeout, "extra_context": { "broker": self, - "logger": self.logger, + "logger": self._state.logger_state.logger.logger, }, # broker options "broker_parser": self._parser, "broker_decoder": self._decoder, # dependant args - "apply_types": self._is_apply_types, - "is_validate": self._is_validate, - "_get_dependant": self._get_dependant, - "_call_decorators": self._call_decorators, + "state": self._state, } @property @@ -290,21 +307,6 @@ def publisher(self, *args: Any, **kwargs: Any) -> "PublisherProto[MsgType]": self.setup_publisher(pub) return pub - def _abc_start(self) -> None: - for h in self._subscribers.values(): - log_context = h.get_log_context(None) - log_context.pop("message_id", None) - self._setup_log_context(**log_context) - - if not self.running: - self.running = True - - if not self.use_custom and self.logger is not None: - set_logger_fmt( - cast(logging.Logger, self.logger), - self._get_fmt(), - ) - async def close( self, exc_type: Optional[Type[BaseException]] = None, @@ -312,23 +314,10 @@ async def close( exc_tb: Optional["TracebackType"] = None, ) -> None: """Closes the object.""" - self.running = False - for h in self._subscribers.values(): await h.close() - if self._connection is not None: - await self._close(exc_type, exc_val, exc_tb) - - @abstractmethod - async def _close( - self, - exc_type: Optional[Type[BaseException]] = None, - exc_val: Optional[BaseException] = None, - exc_tb: Optional["TracebackType"] = None, - ) -> None: - """Close the object.""" - self._connection = None + self.running = False async def publish( self, diff --git a/faststream/_internal/broker/logging_mixin.py b/faststream/_internal/broker/logging_mixin.py deleted file mode 100644 index 3b8f401465..0000000000 --- a/faststream/_internal/broker/logging_mixin.py +++ /dev/null @@ -1,93 +0,0 @@ -import logging -from abc import abstractmethod -from typing import TYPE_CHECKING, Any, Optional - -from typing_extensions import Annotated, Doc - -from faststream._internal.constants import EMPTY -from faststream._internal.types import MsgType - -from .abc_broker import ABCBroker - -if TYPE_CHECKING: - from faststream._internal.basic_types import AnyDict, LoggerProto - - -class LoggingBroker(ABCBroker[MsgType]): - """A mixin class for logging.""" - - logger: Optional["LoggerProto"] - - @abstractmethod - def get_fmt(self) -> str: - """Fallback method to get log format if `log_fmt` if not specified.""" - raise NotImplementedError() - - @abstractmethod - def _setup_log_context(self) -> None: - raise NotImplementedError() - - def __init__( - self, - *args: Any, - default_logger: Annotated[ - logging.Logger, - Doc("Logger object to use if `logger` is not set."), - ], - logger: Annotated[ - Optional["LoggerProto"], - Doc("User specified logger to pass into Context and log service messages."), - ], - log_level: Annotated[ - int, - Doc("Service messages log level."), - ], - log_fmt: Annotated[ - Optional[str], - Doc("Default logger log format."), - ], - **kwargs: Any, - ) -> None: - if logger is not EMPTY: - self.logger = logger - self.use_custom = True - else: - self.logger = default_logger - self.use_custom = False - - self._msg_log_level = log_level - self._fmt = log_fmt - - super().__init__(*args, **kwargs) - - def _get_fmt(self) -> str: - """Get default logger format at broker startup.""" - return self._fmt or self.get_fmt() - - def _log( - self, - message: Annotated[ - str, - Doc("Log message."), - ], - log_level: Annotated[ - Optional[int], - Doc("Log record level. Use `__init__: log_level` option if not specified."), - ] = None, - extra: Annotated[ - Optional["AnyDict"], - Doc("Log record extra information."), - ] = None, - exc_info: Annotated[ - Optional[Exception], - Doc("Exception object to log traceback."), - ] = None, - ) -> None: - """Logs a message.""" - if self.logger is not None: - self.logger.log( - (log_level or self._msg_log_level), - message, - extra=extra, - exc_info=exc_info, - ) diff --git a/faststream/_internal/cli/main.py b/faststream/_internal/cli/main.py index 3abb63a398..29a874e083 100644 --- a/faststream/_internal/cli/main.py +++ b/faststream/_internal/cli/main.py @@ -250,6 +250,7 @@ def publish( if not app_obj.broker: raise ValueError("Broker instance not found in the app.") + app_obj._setup() result = anyio.run(publish_message, app_obj.broker, rpc, extra) if rpc: diff --git a/faststream/_internal/cli/utils/logs.py b/faststream/_internal/cli/utils/logs.py index 96f9f48bfd..4330f9cdb9 100644 --- a/faststream/_internal/cli/utils/logs.py +++ b/faststream/_internal/cli/utils/logs.py @@ -69,6 +69,9 @@ def set_log_level(level: int, app: "FastStream") -> None: if app.logger and getattr(app.logger, "setLevel", None): app.logger.setLevel(level) # type: ignore[attr-defined] - broker_logger: Optional[LoggerProto] = getattr(app.broker, "logger", None) - if broker_logger is not None and getattr(broker_logger, "setLevel", None): - broker_logger.setLevel(level) # type: ignore[attr-defined] + if app.broker: + broker_logger: Optional[LoggerProto] = ( + app.broker._state.logger_state.logger.logger + ) + if broker_logger is not None and getattr(broker_logger, "setLevel", None): + broker_logger.setLevel(level) # type: ignore[attr-defined] diff --git a/faststream/_internal/constants.py b/faststream/_internal/constants.py index c3d2f73b4a..16d984165e 100644 --- a/faststream/_internal/constants.py +++ b/faststream/_internal/constants.py @@ -11,7 +11,7 @@ class ContentTypes(str, Enum): json = "application/json" -class _EmptyPlaceholder: +class EmptyPlaceholder: def __repr__(self) -> str: return "EMPTY" @@ -19,10 +19,10 @@ def __bool__(self) -> bool: return False def __eq__(self, other: object) -> bool: - if not isinstance(other, _EmptyPlaceholder): + if not isinstance(other, EmptyPlaceholder): return NotImplemented return True -EMPTY: Any = _EmptyPlaceholder() +EMPTY: Any = EmptyPlaceholder() diff --git a/faststream/_internal/fastapi/router.py b/faststream/_internal/fastapi/router.py index 7bf691b96d..b03c23cc7d 100644 --- a/faststream/_internal/fastapi/router.py +++ b/faststream/_internal/fastapi/router.py @@ -36,6 +36,7 @@ from faststream._internal.fastapi.route import ( wrap_callable_to_fastapi_compatible, ) +from faststream._internal.setup import EmptyState from faststream._internal.types import ( MsgType, P_HandlerParams, @@ -166,6 +167,8 @@ def __init__( self.schema = None + self._state = EmptyState() + super().__init__( prefix=prefix, tags=tags, @@ -316,7 +319,7 @@ async def start_broker_lifespan( context = dict(maybe_context) context.update({"broker": self.broker}) - await self.broker.start() + await self._start_broker() for h in self._after_startup_hooks: h_context = await h(app) diff --git a/faststream/_internal/log/logging.py b/faststream/_internal/log/logging.py index a6139a0f2d..56c5a6d31b 100644 --- a/faststream/_internal/log/logging.py +++ b/faststream/_internal/log/logging.py @@ -49,24 +49,18 @@ def get_broker_logger( name: str, default_context: Mapping[str, str], message_id_ln: int, + fmt: str, ) -> logging.Logger: logger = logging.getLogger(f"faststream.access.{name}") + logger.setLevel(logging.INFO) logger.propagate = False logger.addFilter(ExtendedFilter(default_context, message_id_ln)) - logger.setLevel(logging.INFO) - return logger - - -def set_logger_fmt( - logger: logging.Logger, - fmt: str = "%(asctime)s %(levelname)s - %(message)s", -) -> None: handler = logging.StreamHandler(stream=sys.stdout) - - formatter = ColourizedFormatter( - fmt=fmt, - use_colors=True, + handler.setFormatter( + ColourizedFormatter( + fmt=fmt, + use_colors=True, + ) ) - handler.setFormatter(formatter) - logger.addHandler(handler) + return logger diff --git a/faststream/_internal/proto.py b/faststream/_internal/proto.py index cc6811167b..2ee3cb5106 100644 --- a/faststream/_internal/proto.py +++ b/faststream/_internal/proto.py @@ -1,10 +1,7 @@ from abc import abstractmethod from typing import Any, Optional, Protocol, Type, TypeVar, Union, overload - -class SetupAble(Protocol): - @abstractmethod - def _setup(self) -> None: ... +from .setup import SetupAble class Endpoint(SetupAble, Protocol): diff --git a/faststream/_internal/setup/__init__.py b/faststream/_internal/setup/__init__.py new file mode 100644 index 0000000000..7462d2c586 --- /dev/null +++ b/faststream/_internal/setup/__init__.py @@ -0,0 +1,17 @@ +from .fast_depends import FastDependsData +from .logger import LoggerParamsStorage, LoggerState +from .proto import SetupAble +from .state import EmptyState, SetupState + +__all__ = ( + # state + "SetupState", + "EmptyState", + # proto + "SetupAble", + # FastDepend + "FastDependsData", + # logging + "LoggerState", + "LoggerParamsStorage", +) diff --git a/faststream/_internal/setup/fast_depends.py b/faststream/_internal/setup/fast_depends.py new file mode 100644 index 0000000000..9122492539 --- /dev/null +++ b/faststream/_internal/setup/fast_depends.py @@ -0,0 +1,13 @@ +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Callable, Optional, Sequence + +if TYPE_CHECKING: + from faststream._internal.basic_types import Decorator + + +@dataclass +class FastDependsData: + apply_types: bool + is_validate: bool + get_dependent: Optional[Callable[..., Any]] + call_decorators: Sequence["Decorator"] diff --git a/faststream/_internal/setup/logger.py b/faststream/_internal/setup/logger.py new file mode 100644 index 0000000000..216ddc1267 --- /dev/null +++ b/faststream/_internal/setup/logger.py @@ -0,0 +1,159 @@ +import warnings +from dataclasses import dataclass, field +from typing import Optional, Protocol, Type + +from faststream._internal.basic_types import AnyDict, LoggerProto +from faststream._internal.constants import EMPTY +from faststream.exceptions import IncorrectState + +from .proto import SetupAble + +__all__ = ( + "make_logger_state", + "LoggerState", + "LoggerParamsStorage", + "DefaultLoggerStorage", +) + + +def make_logger_state( + logger: Optional["LoggerProto"], + log_level: int, + log_fmt: Optional[str], + default_storag_cls: Type["DefaultLoggerStorage"], +) -> "LoggerState": + if logger is not EMPTY and log_fmt: + warnings.warn( + message="You can't set custom `logger` with `log_fmt` both.", + category=RuntimeWarning, + stacklevel=1, + ) + + if logger is EMPTY: + storage = default_storag_cls(log_fmt) + elif logger is None: + storage = _EmptyLoggerStorage() + else: + storage = _ManualLoggerStorage(logger) + + return LoggerState( + log_level=log_level, + params_storage=storage, + ) + + +class _LoggerObject(Protocol): + logger: Optional["LoggerProto"] + + def log( + self, + message: str, + log_level: int, + extra: Optional["AnyDict"] = None, + exc_info: Optional[Exception] = None, + ) -> None: ... + + +class _NotSetLoggerObject(_LoggerObject): + def __init__(self) -> None: + self.logger = None + + def log( + self, + message: str, + log_level: int, + extra: Optional["AnyDict"] = None, + exc_info: Optional[Exception] = None, + ) -> None: + raise IncorrectState("Logger object was not set up.") + + +class _EmptyLoggerObject(_LoggerObject): + def __init__(self) -> None: + self.logger = None + + def log( + self, + message: str, + log_level: int, + extra: Optional["AnyDict"] = None, + exc_info: Optional[Exception] = None, + ) -> None: + pass + + +class _RealLoggerObject(_LoggerObject): + def __init__(self, logger: "LoggerProto") -> None: + self.logger = logger + + def log( + self, + message: str, + log_level: int, + extra: Optional["AnyDict"] = None, + exc_info: Optional[Exception] = None, + ) -> None: + self.logger.log( + log_level, + message, + extra=extra, + exc_info=exc_info, + ) + + +class LoggerParamsStorage(Protocol): + def setup_log_contest(self, params: "AnyDict") -> None: ... + + def get_logger(self) -> Optional["LoggerProto"]: ... + + +class _EmptyLoggerStorage(LoggerParamsStorage): + def setup_log_contest(self, params: AnyDict) -> None: + pass + + def get_logger(self) -> None: + return None + + +class _ManualLoggerStorage(LoggerParamsStorage): + def __init__(self, logger: "LoggerProto") -> None: + self.__logger = logger + + def setup_log_contest(self, params: AnyDict) -> None: + pass + + def get_logger(self) -> LoggerProto: + return self.__logger + + +class DefaultLoggerStorage(LoggerParamsStorage): + def __init__(self, log_fmt: Optional[str]) -> None: + self._log_fmt = log_fmt + + +@dataclass +class LoggerState(SetupAble): + log_level: int + params_storage: LoggerParamsStorage + + logger: _LoggerObject = field(default=_NotSetLoggerObject(), init=False) + + def log( + self, + message: str, + log_level: Optional[int] = None, + extra: Optional["AnyDict"] = None, + exc_info: Optional[Exception] = None, + ) -> None: + self.logger.log( + log_level=(log_level or self.log_level), + message=message, + extra=extra, + exc_info=exc_info, + ) + + def _setup(self) -> None: + if logger := self.params_storage.get_logger(): + self.logger = _RealLoggerObject(logger) + else: + self.logger = _EmptyLoggerObject() diff --git a/faststream/_internal/setup/proto.py b/faststream/_internal/setup/proto.py new file mode 100644 index 0000000000..21b5eda882 --- /dev/null +++ b/faststream/_internal/setup/proto.py @@ -0,0 +1,7 @@ +from abc import abstractmethod +from typing import Protocol + + +class SetupAble(Protocol): + @abstractmethod + def _setup(self) -> None: ... diff --git a/faststream/_internal/setup/state.py b/faststream/_internal/setup/state.py new file mode 100644 index 0000000000..a2c33d5a3c --- /dev/null +++ b/faststream/_internal/setup/state.py @@ -0,0 +1,99 @@ +from abc import abstractmethod, abstractproperty +from typing import Optional, Type + +from faststream.exceptions import IncorrectState + +from .fast_depends import FastDependsData +from .logger import LoggerState +from .proto import SetupAble + + +class BaseState(SetupAble): + _depends_params: FastDependsData + _logger_params: LoggerState + + @abstractproperty + def depends_params(self) -> FastDependsData: + raise NotImplementedError + + @abstractproperty + def logger_state(self) -> LoggerState: + raise NotImplementedError + + @abstractmethod + def __bool__(self) -> bool: + raise NotImplementedError + + def _setup(self) -> None: + self.logger_state._setup() + + def copy_with_params( + self, + *, + depends_params: Optional[FastDependsData] = None, + logger_state: Optional[LoggerState] = None, + ) -> "SetupState": + return self.__class__( + logger_state=logger_state or self._logger_params, + depends_params=depends_params or self._depends_params, + ) + + def copy_to_state(self, state_cls: Type["SetupState"]) -> "SetupState": + return state_cls( + depends_params=self._depends_params, + logger_state=self._logger_params, + ) + + +class SetupState(BaseState): + """State after broker._setup() called.""" + + def __init__( + self, + *, + logger_state: LoggerState, + depends_params: FastDependsData, + ) -> None: + self._depends_params = depends_params + self._logger_params = logger_state + + @property + def depends_params(self) -> FastDependsData: + return self._depends_params + + @property + def logger_state(self) -> LoggerState: + return self._logger_params + + def __bool__(self) -> bool: + return True + + +class EmptyState(BaseState): + """Initial state for App, broker, etc.""" + + def __init__( + self, + *, + logger_state: Optional[LoggerState] = None, + depends_params: Optional[FastDependsData] = None, + ) -> None: + self._depends_params = depends_params + self._logger_params = logger_state + + @property + def depends_params(self) -> FastDependsData: + if not self._depends_params: + raise IncorrectState + + return self._depends_params + + @property + def logger_state(self) -> LoggerState: + if not self._logger_params: + raise IncorrectState + + return self._logger_params + + def __bool__(self) -> bool: + return False diff --git a/faststream/_internal/subscriber/call_item.py b/faststream/_internal/subscriber/call_item.py index 6b550747d1..264a5cf68b 100644 --- a/faststream/_internal/subscriber/call_item.py +++ b/faststream/_internal/subscriber/call_item.py @@ -14,7 +14,7 @@ from typing_extensions import override -from faststream._internal.proto import SetupAble +from faststream._internal.setup import SetupAble from faststream._internal.types import MsgType from faststream.exceptions import IgnoredException, SetupError diff --git a/faststream/_internal/subscriber/usecase.py b/faststream/_internal/subscriber/usecase.py index 6bc39cb158..27404749c7 100644 --- a/faststream/_internal/subscriber/usecase.py +++ b/faststream/_internal/subscriber/usecase.py @@ -46,6 +46,7 @@ BasePublisherProto, ProducerProto, ) + from faststream._internal.setup import SetupState from faststream._internal.types import ( AsyncCallable, BrokerMiddleware, @@ -154,10 +155,7 @@ def _setup( # type: ignore[override] broker_parser: Optional["CustomCallable"], broker_decoder: Optional["CustomCallable"], # dependant args - apply_types: bool, - is_validate: bool, - _get_dependant: Optional[Callable[..., Any]], - _call_decorators: Iterable["Decorator"], + state: "SetupState", ) -> None: self.lock = MultiLock() @@ -184,10 +182,13 @@ def _setup( # type: ignore[override] call._setup( parser=async_parser, decoder=async_decoder, - apply_types=apply_types, - is_validate=is_validate, - _get_dependant=_get_dependant, - _call_decorators=(*self._call_decorators, *_call_decorators), + apply_types=state.depends_params.apply_types, + is_validate=state.depends_params.is_validate, + _get_dependant=state.depends_params.get_dependent, + _call_decorators=( + *self._call_decorators, + *state.depends_params.call_decorators, + ), broker_dependencies=self._broker_dependencies, ) diff --git a/faststream/_internal/testing/broker.py b/faststream/_internal/testing/broker.py index 2be6ab8f81..b079a240bf 100644 --- a/faststream/_internal/testing/broker.py +++ b/faststream/_internal/testing/broker.py @@ -117,11 +117,10 @@ def _patch_broker(self, broker: Broker) -> Generator[None, None, None]: "ping", return_value=True, ): + broker._setup() yield def _fake_start(self, broker: Broker, *args: Any, **kwargs: Any) -> None: - broker._setup() - patch_broker_calls(broker) for p in broker._publishers.values(): @@ -169,9 +168,12 @@ def _fake_close( if getattr(p, "_fake_handler", None): p.reset_test() # type: ignore[attr-defined] - for sub in self._fake_subscribers: - self.broker._subscribers.pop(hash(sub), None) # type: ignore[attr-defined] - self._fake_subscribers = [] + self.broker._subscribers = { + hash(sub): sub + for sub in self.broker._subscribers.values() + if sub not in self._fake_subscribers + } + self._fake_subscribers.clear() for h in broker._subscribers.values(): h.running = False @@ -193,7 +195,7 @@ async def _fake_connect(broker: Broker, *args: Any, **kwargs: Any) -> None: def patch_broker_calls(broker: "BrokerUsecase[Any, Any]") -> None: """Patch broker calls.""" - broker._abc_start() + broker._setup() for handler in broker._subscribers.values(): for h in handler.calls: diff --git a/faststream/app.py b/faststream/app.py index cea10b9449..ef45aeef79 100644 --- a/faststream/app.py +++ b/faststream/app.py @@ -19,6 +19,7 @@ from faststream._internal.cli.supervisors.utils import set_exit from faststream._internal.context import context from faststream._internal.log.logging import logger +from faststream._internal.setup import EmptyState from faststream._internal.utils import apply_types from faststream._internal.utils.functions import ( drop_response_type, @@ -103,6 +104,7 @@ def __init__( ) self._should_exit = anyio.Event() + self._state = EmptyState() # Specification information self.title = title @@ -192,8 +194,7 @@ async def start( for func in self._on_startup_calling: await func(**run_extra_options) - if self.broker is not None: - await self.broker.start() + await self._start_broker() for func in self._after_startup_calling: await func() diff --git a/faststream/confluent/broker/broker.py b/faststream/confluent/broker/broker.py index 7b31cdf5ae..59b1063f0d 100644 --- a/faststream/confluent/broker/broker.py +++ b/faststream/confluent/broker/broker.py @@ -16,17 +16,14 @@ ) import anyio +import confluent_kafka from typing_extensions import Annotated, Doc, override from faststream.__about__ import SERVICE_NAME +from faststream._internal.broker.broker import BrokerUsecase from faststream._internal.constants import EMPTY from faststream._internal.utils.data import filter_by_dict -from faststream.confluent.broker.logging import KafkaLoggingBroker -from faststream.confluent.broker.registrator import KafkaRegistrator -from faststream.confluent.client import ( - AsyncConfluentConsumer, - AsyncConfluentProducer, -) +from faststream.confluent.client import AsyncConfluentConsumer, AsyncConfluentProducer from faststream.confluent.config import ConfluentFastConfig from faststream.confluent.publisher.producer import AsyncConfluentFastProducer from faststream.confluent.schemas.params import ConsumerConnectionParams @@ -34,6 +31,9 @@ from faststream.exceptions import NOT_CONNECTED_YET from faststream.message import gen_cor_id +from .logging import make_kafka_logger_state +from .registrator import KafkaRegistrator + if TYPE_CHECKING: from types import TracebackType @@ -60,7 +60,13 @@ class KafkaBroker( KafkaRegistrator, - KafkaLoggingBroker, + BrokerUsecase[ + Union[ + confluent_kafka.Message, + Tuple[confluent_kafka.Message, ...], + ], + Callable[..., AsyncConfluentConsumer], + ], ): url: List[str] _producer: Optional[AsyncConfluentFastProducer] @@ -384,9 +390,11 @@ def __init__( security=security, tags=tags, # Logging args - logger=logger, - log_level=log_level, - log_fmt=log_fmt, + logger_state=make_kafka_logger_state( + logger=logger, + log_level=log_level, + log_fmt=log_fmt, + ), # FastDepends args _get_dependant=_get_dependant, _call_decorators=_call_decorators, @@ -397,17 +405,19 @@ def __init__( self._producer = None self.config = ConfluentFastConfig(config) - async def _close( + async def close( self, exc_type: Optional[Type[BaseException]] = None, exc_val: Optional[BaseException] = None, exc_tb: Optional["TracebackType"] = None, ) -> None: + await super().close(exc_type, exc_val, exc_tb) + if self._producer is not None: # pragma: no branch await self._producer.stop() self._producer = None - await super()._close(exc_type, exc_val, exc_tb) + self._connection = None async def connect( self, @@ -435,7 +445,6 @@ async def _connect( # type: ignore[override] native_producer = AsyncConfluentProducer( **kwargs, client_id=client_id, - logger=self.logger, config=self.config, ) @@ -448,20 +457,15 @@ async def _connect( # type: ignore[override] return partial( AsyncConfluentConsumer, **filter_by_dict(ConsumerConnectionParams, kwargs), - logger=self.logger, + logger=self._state.logger_state, config=self.config, ) async def start(self) -> None: + await self.connect() + self._setup() await super().start() - for handler in self._subscribers.values(): - self._log( - f"`{handler.call_name}` waiting for messages", - extra=handler.get_log_context(None), - ) - await handler.start() - @property def _subscriber_setup_extra(self) -> "AnyDict": return { diff --git a/faststream/confluent/broker/logging.py b/faststream/confluent/broker/logging.py index 6c549f5dc6..ec17abb215 100644 --- a/faststream/confluent/broker/logging.py +++ b/faststream/confluent/broker/logging.py @@ -1,72 +1,67 @@ -import logging -from typing import TYPE_CHECKING, Any, Callable, ClassVar, Optional, Tuple, Union +from functools import partial +from typing import TYPE_CHECKING, Optional -from faststream._internal.broker.broker import BrokerUsecase -from faststream._internal.constants import EMPTY from faststream._internal.log.logging import get_broker_logger -from faststream.confluent.client import AsyncConfluentConsumer +from faststream._internal.setup.logger import ( + DefaultLoggerStorage, + make_logger_state, +) if TYPE_CHECKING: - import confluent_kafka + from faststream._internal.basic_types import AnyDict, LoggerProto - from faststream._internal.basic_types import LoggerProto - - -class KafkaLoggingBroker( - BrokerUsecase[ - Union["confluent_kafka.Message", Tuple["confluent_kafka.Message", ...]], - Callable[..., AsyncConfluentConsumer], - ] -): - """A class that extends the LoggingMixin class and adds additional functionality for logging Kafka related information.""" - - _max_topic_len: int - _max_group_len: int - __max_msg_id_ln: ClassVar[int] = 10 +class KafkaParamsStorage(DefaultLoggerStorage): def __init__( self, - *args: Any, - logger: Optional["LoggerProto"] = EMPTY, - log_level: int = logging.INFO, - log_fmt: Optional[str] = None, - **kwargs: Any, + log_fmt: Optional[str], ) -> None: - """Initialize the class.""" - super().__init__( - *args, - logger=logger, - # TODO: generate unique logger names to not share between brokers - default_logger=get_broker_logger( - name="confluent", - default_context={ - "topic": "", - "group_id": "", - }, - message_id_ln=self.__max_msg_id_ln, - ), - log_level=log_level, - log_fmt=log_fmt, - **kwargs, - ) + super().__init__(log_fmt) + self._max_topic_len = 4 self._max_group_len = 0 - def get_fmt(self) -> str: - return ( - "%(asctime)s %(levelname)-8s - " - + f"%(topic)-{self._max_topic_len}s | " - + (f"%(group_id)-{self._max_group_len}s | " if self._max_group_len else "") - + f"%(message_id)-{self.__max_msg_id_ln}s " - + "- %(message)s" + def setup_log_contest(self, params: "AnyDict") -> None: + self._max_topic_len = max( + ( + self._max_topic_len, + len(params.get("topic", "")), + ) + ) + self._max_group_len = max( + ( + self._max_group_len, + len(params.get("group_id", "")), + ) + ) + + def get_logger(self) -> Optional["LoggerProto"]: + message_id_ln = 10 + + # TODO: generate unique logger names to not share between brokers + return get_broker_logger( + name="confluent", + default_context={ + "topic": "", + "group_id": "", + }, + message_id_ln=message_id_ln, + fmt=self._log_fmt + or ( + "%(asctime)s %(levelname)-8s - " + + f"%(topic)-{self._max_topic_len}s | " + + ( + f"%(group_id)-{self._max_group_len}s | " + if self._max_group_len + else "" + ) + + f"%(message_id)-{message_id_ln}s " + + "- %(message)s" + ), ) - def _setup_log_context( - self, - *, - topic: str = "", - group_id: Optional[str] = None, - ) -> None: - """Set up log context.""" - self._max_topic_len = max((self._max_topic_len, len(topic))) - self._max_group_len = max((self._max_group_len, len(group_id or ""))) + +make_kafka_logger_state = partial( + make_logger_state, + default_storag_cls=KafkaParamsStorage, +) diff --git a/faststream/confluent/client.py b/faststream/confluent/client.py index 2c62873ef3..bd30ef2a56 100644 --- a/faststream/confluent/client.py +++ b/faststream/confluent/client.py @@ -30,6 +30,7 @@ from typing_extensions import NotRequired, TypedDict from faststream._internal.basic_types import AnyDict, LoggerProto + from faststream._internal.setup.logger import LoggerState class _SendKwargs(TypedDict): value: Optional[Union[str, bytes]] @@ -46,7 +47,6 @@ class AsyncConfluentProducer: def __init__( self, *, - logger: Optional["LoggerProto"], config: config_module.ConfluentFastConfig, bootstrap_servers: Union[str, List[str]] = "localhost", client_id: Optional[str] = None, @@ -68,8 +68,6 @@ def __init__( sasl_plain_password: Optional[str] = None, sasl_plain_username: Optional[str] = None, ) -> None: - self.logger = logger - if isinstance(bootstrap_servers, Iterable) and not isinstance( bootstrap_servers, str ): @@ -112,7 +110,7 @@ def __init__( } ) - self.producer = Producer(final_config, logger=self.logger) + self.producer = Producer(final_config) self.__running = True self._poll_task = asyncio.create_task(self._poll_loop()) @@ -223,7 +221,7 @@ def __init__( self, *topics: str, partitions: Sequence["TopicPartition"], - logger: Optional["LoggerProto"], + logger: "LoggerState", config: config_module.ConfluentFastConfig, bootstrap_servers: Union[str, List[str]] = "localhost", client_id: Optional[str] = "confluent-kafka-consumer", @@ -251,7 +249,7 @@ def __init__( sasl_plain_password: Optional[str] = None, sasl_plain_username: Optional[str] = None, ) -> None: - self.logger = logger + self.logger_state = logger if isinstance(bootstrap_servers, Iterable) and not isinstance( bootstrap_servers, str @@ -312,7 +310,7 @@ def __init__( ) self.config = final_config - self.consumer = Consumer(final_config, logger=self.logger) + self.consumer = Consumer(final_config) @property def topics_to_create(self) -> List[str]: @@ -322,13 +320,16 @@ async def start(self) -> None: """Starts the Kafka consumer and subscribes to the specified topics.""" if self.allow_auto_create_topics: await call_or_await( - create_topics, self.topics_to_create, self.config, self.logger + create_topics, + self.topics_to_create, + self.config, + self.logger_state.logger.logger, ) - elif self.logger: - self.logger.log( - logging.WARNING, - "Auto create topics is disabled. Make sure the topics exist.", + else: + self.logger_state.log( + log_level=logging.WARNING, + message="Auto create topics is disabled. Make sure the topics exist.", ) if self.topics: @@ -359,10 +360,10 @@ async def stop(self) -> None: # No offset stored issue is not a problem - https://github.com/confluentinc/confluent-kafka-python/issues/295#issuecomment-355907183 if "No offset stored" in str(e): pass - elif self.logger: - self.logger.log( - logging.ERROR, - "Consumer closing error occurred.", + else: + self.logger_state.log( + log_level=logging.ERROR, + message="Consumer closing error occurred.", exc_info=e, ) diff --git a/faststream/confluent/subscriber/usecase.py b/faststream/confluent/subscriber/usecase.py index 3bc8e25cdd..5b4cf0a6ff 100644 --- a/faststream/confluent/subscriber/usecase.py +++ b/faststream/confluent/subscriber/usecase.py @@ -26,8 +26,9 @@ if TYPE_CHECKING: from fast_depends.dependencies import Depends - from faststream._internal.basic_types import AnyDict, Decorator, LoggerProto + from faststream._internal.basic_types import AnyDict, LoggerProto from faststream._internal.publisher.proto import ProducerProto + from faststream._internal.setup import SetupState from faststream._internal.types import ( AsyncCallable, BrokerMiddleware, @@ -116,10 +117,7 @@ def _setup( # type: ignore[override] broker_parser: Optional["CustomCallable"], broker_decoder: Optional["CustomCallable"], # dependant args - apply_types: bool, - is_validate: bool, - _get_dependant: Optional[Callable[..., Any]], - _call_decorators: Iterable["Decorator"], + state: "SetupState", ) -> None: self.client_id = client_id self.builder = builder @@ -131,10 +129,7 @@ def _setup( # type: ignore[override] extra_context=extra_context, broker_parser=broker_parser, broker_decoder=broker_decoder, - apply_types=apply_types, - is_validate=is_validate, - _get_dependant=_get_dependant, - _call_decorators=_call_decorators, + state=state, ) @override diff --git a/faststream/confluent/testing.py b/faststream/confluent/testing.py index 13be8c316c..58704c3d6e 100644 --- a/faststream/confluent/testing.py +++ b/faststream/confluent/testing.py @@ -18,6 +18,7 @@ if TYPE_CHECKING: from faststream._internal.basic_types import SendableMessage + from faststream._internal.setup.logger import LoggerState from faststream.confluent.publisher.publisher import SpecificationPublisher from faststream.confluent.subscriber.usecase import LogicSubscriber @@ -88,6 +89,9 @@ def __init__(self, broker: KafkaBroker) -> None: self._parser = resolve_custom_func(broker._parser, default.parse_message) self._decoder = resolve_custom_func(broker._decoder, default.decode_message) + def _setup(self, logger_stater: "LoggerState") -> None: + pass + @override async def publish( # type: ignore[override] self, diff --git a/faststream/exceptions.py b/faststream/exceptions.py index f7774ed818..8c66a92808 100644 --- a/faststream/exceptions.py +++ b/faststream/exceptions.py @@ -111,6 +111,10 @@ class SubscriberNotFound(FastStreamException): """Raises as a service message or in tests.""" +class IncorrectState(FastStreamException): + """Raises in FSM at wrong state calling.""" + + class ContextError(FastStreamException, KeyError): """Raises if context exception occurred.""" diff --git a/faststream/kafka/broker/broker.py b/faststream/kafka/broker/broker.py index 156bae9964..6cb06bb057 100644 --- a/faststream/kafka/broker/broker.py +++ b/faststream/kafka/broker/broker.py @@ -22,16 +22,18 @@ from typing_extensions import Annotated, Doc, override from faststream.__about__ import SERVICE_NAME +from faststream._internal.broker.broker import BrokerUsecase from faststream._internal.constants import EMPTY from faststream._internal.utils.data import filter_by_dict from faststream.exceptions import NOT_CONNECTED_YET -from faststream.kafka.broker.logging import KafkaLoggingBroker -from faststream.kafka.broker.registrator import KafkaRegistrator from faststream.kafka.publisher.producer import AioKafkaFastProducer from faststream.kafka.schemas.params import ConsumerConnectionParams from faststream.kafka.security import parse_security from faststream.message import gen_cor_id +from .logging import make_kafka_logger_state +from .registrator import KafkaRegistrator + Partition = TypeVar("Partition") if TYPE_CHECKING: @@ -230,7 +232,10 @@ class KafkaInitKwargs(TypedDict, total=False): class KafkaBroker( KafkaRegistrator, - KafkaLoggingBroker, + BrokerUsecase[ + Union[aiokafka.ConsumerRecord, Tuple[aiokafka.ConsumerRecord, ...]], + Callable[..., aiokafka.AIOKafkaConsumer], + ], ): url: List[str] _producer: Optional["AioKafkaFastProducer"] @@ -565,9 +570,11 @@ def __init__( security=security, tags=tags, # Logging args - logger=logger, - log_level=log_level, - log_fmt=log_fmt, + logger_state=make_kafka_logger_state( + logger=logger, + log_level=log_level, + log_fmt=log_fmt, + ), # FastDepends args _get_dependant=_get_dependant, _call_decorators=_call_decorators, @@ -578,17 +585,19 @@ def __init__( self.client_id = client_id self._producer = None - async def _close( + async def close( self, exc_type: Optional[Type[BaseException]] = None, exc_val: Optional[BaseException] = None, exc_tb: Optional["TracebackType"] = None, ) -> None: + await super().close(exc_type, exc_val, exc_tb) + if self._producer is not None: # pragma: no branch await self._producer.stop() self._producer = None - await super()._close(exc_type, exc_val, exc_tb) + self._connection = None @override async def connect( # type: ignore[override] @@ -643,15 +652,10 @@ async def _connect( # type: ignore[override] async def start(self) -> None: """Connect broker to Kafka and startup all subscribers.""" + await self.connect() + self._setup() await super().start() - for handler in self._subscribers.values(): - self._log( - f"`{handler.call_name}` waiting for messages", - extra=handler.get_log_context(None), - ) - await handler.start() - @property def _subscriber_setup_extra(self) -> "AnyDict": return { diff --git a/faststream/kafka/broker/logging.py b/faststream/kafka/broker/logging.py index bb7fd2c6e7..5334f4a70d 100644 --- a/faststream/kafka/broker/logging.py +++ b/faststream/kafka/broker/logging.py @@ -1,71 +1,67 @@ -import logging -from typing import TYPE_CHECKING, Any, Callable, ClassVar, Optional, Tuple, Union +from functools import partial +from typing import TYPE_CHECKING, Optional -from faststream._internal.broker.broker import BrokerUsecase -from faststream._internal.constants import EMPTY from faststream._internal.log.logging import get_broker_logger +from faststream._internal.setup.logger import ( + DefaultLoggerStorage, + make_logger_state, +) if TYPE_CHECKING: - import aiokafka + from faststream._internal.basic_types import AnyDict, LoggerProto - from faststream._internal.basic_types import LoggerProto - - -class KafkaLoggingBroker( - BrokerUsecase[ - Union["aiokafka.ConsumerRecord", Tuple["aiokafka.ConsumerRecord", ...]], - Callable[..., "aiokafka.AIOKafkaConsumer"], - ] -): - """A class that extends the LoggingMixin class and adds additional functionality for logging Kafka related information.""" - - _max_topic_len: int - _max_group_len: int - __max_msg_id_ln: ClassVar[int] = 10 +class KafkaParamsStorage(DefaultLoggerStorage): def __init__( self, - *args: Any, - logger: Optional["LoggerProto"] = EMPTY, - log_level: int = logging.INFO, - log_fmt: Optional[str] = None, - **kwargs: Any, + log_fmt: Optional[str], ) -> None: - """Initialize the class.""" - super().__init__( - *args, - logger=logger, - # TODO: generate unique logger names to not share between brokers - default_logger=get_broker_logger( - name="kafka", - default_context={ - "topic": "", - "group_id": "", - }, - message_id_ln=self.__max_msg_id_ln, - ), - log_level=log_level, - log_fmt=log_fmt, - **kwargs, - ) + super().__init__(log_fmt) + self._max_topic_len = 4 self._max_group_len = 0 - def get_fmt(self) -> str: - return ( - "%(asctime)s %(levelname)-8s - " - + f"%(topic)-{self._max_topic_len}s | " - + (f"%(group_id)-{self._max_group_len}s | " if self._max_group_len else "") - + f"%(message_id)-{self.__max_msg_id_ln}s " - + "- %(message)s" + def setup_log_contest(self, params: "AnyDict") -> None: + self._max_topic_len = max( + ( + self._max_topic_len, + len(params.get("topic", "")), + ) + ) + self._max_group_len = max( + ( + self._max_group_len, + len(params.get("group_id", "")), + ) + ) + + def get_logger(self) -> Optional["LoggerProto"]: + message_id_ln = 10 + + # TODO: generate unique logger names to not share between brokers + return get_broker_logger( + name="kafka", + default_context={ + "topic": "", + "group_id": "", + }, + message_id_ln=message_id_ln, + fmt=self._log_fmt + or ( + "%(asctime)s %(levelname)-8s - " + + f"%(topic)-{self._max_topic_len}s | " + + ( + f"%(group_id)-{self._max_group_len}s | " + if self._max_group_len + else "" + ) + + f"%(message_id)-{message_id_ln}s " + + "- %(message)s" + ), ) - def _setup_log_context( - self, - *, - topic: str = "", - group_id: Optional[str] = None, - ) -> None: - """Set up log context.""" - self._max_topic_len = max((self._max_topic_len, len(topic))) - self._max_group_len = max((self._max_group_len, len(group_id or ""))) + +make_kafka_logger_state = partial( + make_logger_state, + default_storag_cls=KafkaParamsStorage, +) diff --git a/faststream/kafka/subscriber/usecase.py b/faststream/kafka/subscriber/usecase.py index cabc505d29..3a32f6016c 100644 --- a/faststream/kafka/subscriber/usecase.py +++ b/faststream/kafka/subscriber/usecase.py @@ -36,8 +36,9 @@ from aiokafka.abc import ConsumerRebalanceListener from fast_depends.dependencies import Depends - from faststream._internal.basic_types import AnyDict, Decorator, LoggerProto + from faststream._internal.basic_types import AnyDict, LoggerProto from faststream._internal.publisher.proto import ProducerProto + from faststream._internal.setup import SetupState from faststream.message import StreamMessage @@ -121,10 +122,7 @@ def _setup( # type: ignore[override] broker_parser: Optional["CustomCallable"], broker_decoder: Optional["CustomCallable"], # dependant args - apply_types: bool, - is_validate: bool, - _get_dependant: Optional[Callable[..., Any]], - _call_decorators: Iterable["Decorator"], + state: "SetupState", ) -> None: self.client_id = client_id self.builder = builder @@ -136,10 +134,7 @@ def _setup( # type: ignore[override] extra_context=extra_context, broker_parser=broker_parser, broker_decoder=broker_decoder, - apply_types=apply_types, - is_validate=is_validate, - _get_dependant=_get_dependant, - _call_decorators=_call_decorators, + state=state, ) async def start(self) -> None: diff --git a/faststream/middlewares/logging.py b/faststream/middlewares/logging.py index c0ab3c8479..df47735543 100644 --- a/faststream/middlewares/logging.py +++ b/faststream/middlewares/logging.py @@ -1,9 +1,8 @@ import logging from typing import TYPE_CHECKING, Any, Optional, Type -from typing_extensions import Self - from faststream._internal.context.repository import context +from faststream._internal.setup.logger import LoggerState from faststream.exceptions import IgnoredException from .base import BaseMiddleware @@ -11,63 +10,59 @@ if TYPE_CHECKING: from types import TracebackType - from faststream._internal.basic_types import LoggerProto from faststream.message import StreamMessage -class CriticalLogMiddleware(BaseMiddleware): - """A middleware class for logging critical errors.""" - - def __init__( - self, - logger: Optional["LoggerProto"], - log_level: int, - ) -> None: +class CriticalLogMiddleware: + def __init__(self, logger: LoggerState) -> None: """Initialize the class.""" self.logger = logger - self.log_level = log_level - def __call__(self, msg: Optional[Any]) -> Self: - """Call the object with a message.""" - self.msg = msg - return self + def __call__(self, msg: Optional[Any] = None) -> Any: + return LoggingMiddleware(logger=self.logger) + + +class LoggingMiddleware(BaseMiddleware): + """A middleware class for logging critical errors.""" + + def __init__(self, logger: LoggerState) -> None: + self.logger = logger async def on_consume( self, msg: "StreamMessage[Any]", ) -> "StreamMessage[Any]": - if self.logger is not None: - c = context.get_local("log_context", {}) - self.logger.log(self.log_level, "Received", extra=c) - + self.logger.log( + "Received", + extra=context.get_local("log_context", {}), + ) return await super().on_consume(msg) - async def after_processed( + async def __aexit__( self, exc_type: Optional[Type[BaseException]] = None, exc_val: Optional[BaseException] = None, exc_tb: Optional["TracebackType"] = None, ) -> bool: """Asynchronously called after processing.""" - if self.logger is not None: - c = context.get_local("log_context", {}) - - if exc_type: - if issubclass(exc_type, IgnoredException): - self.logger.log( - logging.INFO, - exc_val, - extra=c, - ) - else: - self.logger.log( - logging.ERROR, - f"{exc_type.__name__}: {exc_val}", - exc_info=exc_val, - extra=c, - ) - - self.logger.log(self.log_level, "Processed", extra=c) + c = context.get_local("log_context", {}) + + if exc_type: + if issubclass(exc_type, IgnoredException): + self.logger.log( + logging.INFO, + exc_val, + extra=c, + ) + else: + self.logger.log( + logging.ERROR, + f"{exc_type.__name__}: {exc_val}", + exc_info=exc_val, + extra=c, + ) + + self.logger.log("Processed", extra=c) await super().after_processed(exc_type, exc_val, exc_tb) diff --git a/faststream/nats/broker/broker.py b/faststream/nats/broker/broker.py index f07a278d6c..46942f6e5f 100644 --- a/faststream/nats/broker/broker.py +++ b/faststream/nats/broker/broker.py @@ -24,21 +24,25 @@ DEFAULT_PENDING_SIZE, DEFAULT_PING_INTERVAL, DEFAULT_RECONNECT_TIME_WAIT, + Client, ) +from nats.aio.msg import Msg from nats.errors import Error from nats.js.errors import BadRequestError from typing_extensions import Annotated, Doc, override from faststream.__about__ import SERVICE_NAME +from faststream._internal.broker.broker import BrokerUsecase from faststream._internal.constants import EMPTY from faststream.message import gen_cor_id -from faststream.nats.broker.logging import NatsLoggingBroker -from faststream.nats.broker.registrator import NatsRegistrator from faststream.nats.helpers import KVBucketDeclarer, OSBucketDeclarer from faststream.nats.publisher.producer import NatsFastProducer, NatsJSFastProducer from faststream.nats.security import parse_security from faststream.nats.subscriber.subscriber import SpecificationSubscriber +from .logging import make_nats_logger_state +from .registrator import NatsRegistrator + if TYPE_CHECKING: import ssl from types import TracebackType @@ -46,13 +50,11 @@ from fast_depends.dependencies import Depends from nats.aio.client import ( Callback, - Client, Credentials, ErrorCallback, JWTCallback, SignatureCallback, ) - from nats.aio.msg import Msg from nats.js.api import Placement, RePublish, StorageType from nats.js.client import JetStreamContext from nats.js.kv import KeyValue @@ -214,7 +216,7 @@ class NatsInitKwargs(TypedDict, total=False): class NatsBroker( NatsRegistrator, - NatsLoggingBroker, + BrokerUsecase[Msg, Client], ): """A class to represent a NATS broker.""" @@ -532,9 +534,11 @@ def __init__( security=security, tags=tags, # logging - logger=logger, - log_level=log_level, - log_fmt=log_fmt, + logger_state=make_nats_logger_state( + logger=logger, + log_level=log_level, + log_fmt=log_fmt, + ), # FastDepends args apply_types=apply_types, validate=validate, @@ -597,29 +601,29 @@ async def _connect(self, **kwargs: Any) -> "Client": return connection - async def _close( + async def close( self, exc_type: Optional[Type[BaseException]] = None, exc_val: Optional[BaseException] = None, exc_tb: Optional["TracebackType"] = None, ) -> None: - self._producer = None - self._js_producer = None - self.stream = None + await super().close(exc_type, exc_val, exc_tb) if self._connection is not None: await self._connection.drain() + self._connection = None - await super()._close(exc_type, exc_val, exc_tb) + self.stream = None + self._producer = None + self._js_producer = None self.__is_connected = False async def start(self) -> None: """Connect broker to NATS cluster and startup all subscribers.""" - await super().start() + await self.connect() + self._setup() - assert self._connection # nosec B101 assert self.stream, "Broker should be started already" # nosec B101 - assert self._producer, "Broker should be started already" # nosec B101 for stream in filter( lambda x: x.declare, @@ -645,7 +649,7 @@ async def start(self) -> None: ): old_config = (await self.stream.stream_info(stream.name)).config - self._log(str(e), logging.WARNING, log_context) + self._state.logger_state.log(str(e), logging.WARNING, log_context) await self.stream.update_stream( config=stream.config, subjects=tuple( @@ -654,19 +658,15 @@ async def start(self) -> None: ) else: # pragma: no cover - self._log(str(e), logging.ERROR, log_context, exc_info=e) + self._state.logger_state.log( + str(e), logging.ERROR, log_context, exc_info=e + ) finally: # prevent from double declaration stream.declare = False - # TODO: filter by already running handlers after TestClient refactor - for handler in self._subscribers.values(): - self._log( - f"`{handler.call_name}` waiting for messages", - extra=handler.get_log_context(None), - ) - await handler.start() + await super().start() @override async def publish( # type: ignore[override] @@ -923,7 +923,7 @@ async def wrapper(err: Exception) -> None: await error_cb(err) if isinstance(err, Error) and self.__is_connected: - self._log( + self._state.logger_state.log( f"Connection broken with {err!r}", logging.WARNING, c, exc_info=err ) self.__is_connected = False @@ -941,7 +941,7 @@ async def wrapper() -> None: await cb() if not self.__is_connected: - self._log("Connection established", logging.INFO, c) + self._state.logger_state.log("Connection established", logging.INFO, c) self.__is_connected = True return wrapper diff --git a/faststream/nats/broker/logging.py b/faststream/nats/broker/logging.py index dd163a1c77..0cf9ad45c3 100644 --- a/faststream/nats/broker/logging.py +++ b/faststream/nats/broker/logging.py @@ -1,74 +1,76 @@ -import logging -from typing import TYPE_CHECKING, Any, ClassVar, Optional +from functools import partial +from typing import TYPE_CHECKING, Optional -from nats.aio.client import Client -from nats.aio.msg import Msg - -from faststream._internal.broker.broker import BrokerUsecase -from faststream._internal.constants import EMPTY from faststream._internal.log.logging import get_broker_logger +from faststream._internal.setup.logger import ( + DefaultLoggerStorage, + make_logger_state, +) if TYPE_CHECKING: - from faststream._internal.basic_types import LoggerProto - + from faststream._internal.basic_types import AnyDict, LoggerProto -class NatsLoggingBroker(BrokerUsecase[Msg, Client]): - """A class that extends the LoggingMixin class and adds additional functionality for logging NATS related information.""" - - _max_queue_len: int - _max_subject_len: int - __max_msg_id_ln: ClassVar[int] = 10 +class NatsParamsStorage(DefaultLoggerStorage): def __init__( self, - *args: Any, - logger: Optional["LoggerProto"] = EMPTY, - log_level: int = logging.INFO, - log_fmt: Optional[str] = None, - **kwargs: Any, + log_fmt: Optional[str], ) -> None: - """Initialize the NATS logging mixin.""" - super().__init__( - *args, - logger=logger, - # TODO: generate unique logger names to not share between brokers - default_logger=get_broker_logger( - name="nats", - default_context={ - "subject": "", - "stream": "", - "queue": "", - }, - message_id_ln=self.__max_msg_id_ln, - ), - log_level=log_level, - log_fmt=log_fmt, - **kwargs, - ) + super().__init__(log_fmt) self._max_queue_len = 0 self._max_stream_len = 0 self._max_subject_len = 4 - def get_fmt(self) -> str: - """Fallback method to get log format if `log_fmt` if not specified.""" - return ( - "%(asctime)s %(levelname)-8s - " - + (f"%(stream)-{self._max_stream_len}s | " if self._max_stream_len else "") - + (f"%(queue)-{self._max_queue_len}s | " if self._max_queue_len else "") - + f"%(subject)-{self._max_subject_len}s | " - + f"%(message_id)-{self.__max_msg_id_ln}s - " - "%(message)s" + def setup_log_contest(self, params: "AnyDict") -> None: + self._max_subject_len = max( + ( + self._max_subject_len, + len(params.get("subject", "")), + ) + ) + self._max_queue_len = max( + ( + self._max_queue_len, + len(params.get("queue", "")), + ) + ) + self._max_stream_len = max( + ( + self._max_stream_len, + len(params.get("stream", "")), + ) + ) + + def get_logger(self) -> Optional["LoggerProto"]: + message_id_ln = 10 + + # TODO: generate unique logger names to not share between brokers + return get_broker_logger( + name="nats", + default_context={ + "subject": "", + "stream": "", + "queue": "", + }, + message_id_ln=message_id_ln, + fmt=self._log_fmt + or ( + "%(asctime)s %(levelname)-8s - " + + ( + f"%(stream)-{self._max_stream_len}s | " + if self._max_stream_len + else "" + ) + + (f"%(queue)-{self._max_queue_len}s | " if self._max_queue_len else "") + + f"%(subject)-{self._max_subject_len}s | " + + f"%(message_id)-{message_id_ln}s - " + "%(message)s" + ), ) - def _setup_log_context( - self, - *, - queue: Optional[str] = None, - subject: Optional[str] = None, - stream: Optional[str] = None, - ) -> None: - """Setup subscriber's information to generate default log format.""" - self._max_subject_len = max((self._max_subject_len, len(subject or ""))) - self._max_queue_len = max((self._max_queue_len, len(queue or ""))) - self._max_stream_len = max((self._max_stream_len, len(stream or ""))) + +make_nats_logger_state = partial( + make_logger_state, + default_storag_cls=NatsParamsStorage, +) diff --git a/faststream/nats/subscriber/usecase.py b/faststream/nats/subscriber/usecase.py index de1e5e8137..19c0778092 100644 --- a/faststream/nats/subscriber/usecase.py +++ b/faststream/nats/subscriber/usecase.py @@ -54,11 +54,11 @@ from faststream._internal.basic_types import ( AnyDict, - Decorator, LoggerProto, SendableMessage, ) from faststream._internal.publisher.proto import ProducerProto + from faststream._internal.setup import SetupState from faststream._internal.types import ( AsyncCallable, BrokerMiddleware, @@ -139,10 +139,7 @@ def _setup( # type: ignore[override] broker_parser: Optional["CustomCallable"], broker_decoder: Optional["CustomCallable"], # dependant args - apply_types: bool, - is_validate: bool, - _get_dependant: Optional[Callable[..., Any]], - _call_decorators: Iterable["Decorator"], + state: "SetupState", ) -> None: self._connection = connection @@ -153,10 +150,7 @@ def _setup( # type: ignore[override] extra_context=extra_context, broker_parser=broker_parser, broker_decoder=broker_decoder, - apply_types=apply_types, - is_validate=is_validate, - _get_dependant=_get_dependant, - _call_decorators=_call_decorators, + state=state, ) @property diff --git a/faststream/rabbit/broker/broker.py b/faststream/rabbit/broker/broker.py index 18b23f6a53..1548d42243 100644 --- a/faststream/rabbit/broker/broker.py +++ b/faststream/rabbit/broker/broker.py @@ -12,15 +12,14 @@ from urllib.parse import urlparse import anyio -from aio_pika import connect_robust +from aio_pika import IncomingMessage, RobustConnection, connect_robust from typing_extensions import Annotated, Doc, override from faststream.__about__ import SERVICE_NAME +from faststream._internal.broker.broker import BrokerUsecase from faststream._internal.constants import EMPTY from faststream.exceptions import NOT_CONNECTED_YET from faststream.message import gen_cor_id -from faststream.rabbit.broker.logging import RabbitLoggingBroker -from faststream.rabbit.broker.registrator import RabbitRegistrator from faststream.rabbit.helpers.declarer import RabbitDeclarer from faststream.rabbit.publisher.producer import AioPikaFastProducer from faststream.rabbit.schemas import ( @@ -29,18 +28,18 @@ RabbitQueue, ) from faststream.rabbit.security import parse_security -from faststream.rabbit.subscriber.subscriber import SpecificationSubscriber from faststream.rabbit.utils import build_url +from .logging import make_rabbit_logger_state +from .registrator import RabbitRegistrator + if TYPE_CHECKING: from ssl import SSLContext from types import TracebackType import aiormq from aio_pika import ( - IncomingMessage, RobustChannel, - RobustConnection, RobustExchange, RobustQueue, ) @@ -62,7 +61,7 @@ class RabbitBroker( RabbitRegistrator, - RabbitLoggingBroker, + BrokerUsecase[IncomingMessage, RobustConnection], ): """A class to represent a RabbitMQ broker.""" @@ -274,9 +273,11 @@ def __init__( security=security, tags=tags, # Logging args - logger=logger, - log_level=log_level, - log_fmt=log_fmt, + logger_state=make_rabbit_logger_state( + logger=logger, + log_level=log_level, + log_fmt=log_fmt, + ), # FastDepends args apply_types=apply_types, validate=validate, @@ -453,7 +454,6 @@ async def _connect( # type: ignore[override] ) if self._channel is None: # pragma: no branch - max_consumers = self._max_consumers channel = self._channel = cast( "RobustChannel", await connection.channel( @@ -472,40 +472,39 @@ async def _connect( # type: ignore[override] parser=self._parser, ) - if max_consumers: - c = SpecificationSubscriber.build_log_context( - None, - RabbitQueue(""), - RabbitExchange(""), - ) - self._log(f"Set max consumers to {max_consumers}", extra=c) - await channel.set_qos(prefetch_count=int(max_consumers)) + if self._max_consumers: + await channel.set_qos(prefetch_count=int(self._max_consumers)) return connection - async def _close( + async def close( self, exc_type: Optional[Type[BaseException]] = None, exc_val: Optional[BaseException] = None, exc_tb: Optional["TracebackType"] = None, ) -> None: + await super().close(exc_type, exc_val, exc_tb) + if self._channel is not None: if not self._channel.is_closed: await self._channel.close() self._channel = None - self.declarer = None - self._producer = None - if self._connection is not None: await self._connection.close() + self._connection = None - await super()._close(exc_type, exc_val, exc_tb) + self.declarer = None + self._producer = None async def start(self) -> None: """Connect broker to RabbitMQ and startup all subscribers.""" - await super().start() + await self.connect() + self._setup() + + if self._max_consumers: + self._state.logger_state.log(f"Set max consumers to {self._max_consumers}") assert self.declarer, NOT_CONNECTED_YET # nosec B101 @@ -513,12 +512,7 @@ async def start(self) -> None: if publisher.exchange is not None: await self.declare_exchange(publisher.exchange) - for subscriber in self._subscribers.values(): - self._log( - f"`{subscriber.call_name}` waiting for messages", - extra=subscriber.get_log_context(None), - ) - await subscriber.start() + await super().start() @override async def publish( # type: ignore[override] diff --git a/faststream/rabbit/broker/logging.py b/faststream/rabbit/broker/logging.py index cf104338a7..4074d3e6df 100644 --- a/faststream/rabbit/broker/logging.py +++ b/faststream/rabbit/broker/logging.py @@ -1,66 +1,59 @@ -import logging -from typing import TYPE_CHECKING, Any, ClassVar, Optional +from functools import partial +from typing import TYPE_CHECKING, Optional -from aio_pika import IncomingMessage, RobustConnection - -from faststream._internal.broker.broker import BrokerUsecase -from faststream._internal.constants import EMPTY from faststream._internal.log.logging import get_broker_logger +from faststream._internal.setup.logger import ( + DefaultLoggerStorage, + make_logger_state, +) if TYPE_CHECKING: - from faststream._internal.basic_types import LoggerProto - + from faststream._internal.basic_types import AnyDict, LoggerProto -class RabbitLoggingBroker(BrokerUsecase[IncomingMessage, RobustConnection]): - """A class that extends the LoggingMixin class and adds additional functionality for logging RabbitMQ related information.""" - - _max_queue_len: int - _max_exchange_len: int - __max_msg_id_ln: ClassVar[int] = 10 +class RabbitParamsStorage(DefaultLoggerStorage): def __init__( self, - *args: Any, - logger: Optional["LoggerProto"] = EMPTY, - log_level: int = logging.INFO, - log_fmt: Optional[str] = None, - **kwargs: Any, + log_fmt: Optional[str], ) -> None: - super().__init__( - *args, - logger=logger, - # TODO: generate unique logger names to not share between brokers - default_logger=get_broker_logger( - name="rabbit", - default_context={ - "queue": "", - "exchange": "", - }, - message_id_ln=self.__max_msg_id_ln, - ), - log_level=log_level, - log_fmt=log_fmt, - **kwargs, - ) + super().__init__(log_fmt) - self._max_queue_len = 4 self._max_exchange_len = 4 + self._max_queue_len = 4 - def get_fmt(self) -> str: - return ( - "%(asctime)s %(levelname)-8s - " - f"%(exchange)-{self._max_exchange_len}s | " - f"%(queue)-{self._max_queue_len}s | " - f"%(message_id)-{self.__max_msg_id_ln}s " - "- %(message)s" + def setup_log_contest(self, params: "AnyDict") -> None: + self._max_exchange_len = max( + self._max_exchange_len, + len(params.get("exchange", "")), + ) + self._max_queue_len = max( + self._max_queue_len, + len(params.get("queue", "")), ) - def _setup_log_context( - self, - *, - queue: Optional[str] = None, - exchange: Optional[str] = None, - ) -> None: - """Set up log context.""" - self._max_exchange_len = max(self._max_exchange_len, len(exchange or "")) - self._max_queue_len = max(self._max_queue_len, len(queue or "")) + def get_logger(self) -> "LoggerProto": + message_id_ln = 10 + + # TODO: generate unique logger names to not share between brokers + return get_broker_logger( + name="rabbit", + default_context={ + "queue": "", + "exchange": "", + }, + message_id_ln=message_id_ln, + fmt=self._log_fmt + or ( + "%(asctime)s %(levelname)-8s - " + f"%(exchange)-{self._max_exchange_len}s | " + f"%(queue)-{self._max_queue_len}s | " + f"%(message_id)-{message_id_ln}s " + "- %(message)s" + ), + ) + + +make_rabbit_logger_state = partial( + make_logger_state, + default_storag_cls=RabbitParamsStorage, +) diff --git a/faststream/rabbit/subscriber/usecase.py b/faststream/rabbit/subscriber/usecase.py index 6a8c346db3..f1bfff8a02 100644 --- a/faststream/rabbit/subscriber/usecase.py +++ b/faststream/rabbit/subscriber/usecase.py @@ -1,7 +1,6 @@ from typing import ( TYPE_CHECKING, Any, - Callable, Dict, Iterable, Optional, @@ -23,7 +22,8 @@ from aio_pika import IncomingMessage, RobustQueue from fast_depends.dependencies import Depends - from faststream._internal.basic_types import AnyDict, Decorator, LoggerProto + from faststream._internal.basic_types import AnyDict, LoggerProto + from faststream._internal.setup import SetupState from faststream._internal.types import BrokerMiddleware, CustomCallable from faststream.message import StreamMessage from faststream.rabbit.helpers.declarer import RabbitDeclarer @@ -111,10 +111,7 @@ def _setup( # type: ignore[override] broker_parser: Optional["CustomCallable"], broker_decoder: Optional["CustomCallable"], # dependant args - apply_types: bool, - is_validate: bool, - _get_dependant: Optional[Callable[..., Any]], - _call_decorators: Iterable["Decorator"], + state: "SetupState", ) -> None: self.app_id = app_id self.virtual_host = virtual_host @@ -127,10 +124,7 @@ def _setup( # type: ignore[override] extra_context=extra_context, broker_parser=broker_parser, broker_decoder=broker_decoder, - apply_types=apply_types, - is_validate=is_validate, - _get_dependant=_get_dependant, - _call_decorators=_call_decorators, + state=state, ) @override diff --git a/faststream/redis/broker/broker.py b/faststream/redis/broker/broker.py index 82e6423b14..18b96ff29b 100644 --- a/faststream/redis/broker/broker.py +++ b/faststream/redis/broker/broker.py @@ -26,14 +26,17 @@ from typing_extensions import Annotated, Doc, TypeAlias, override from faststream.__about__ import __version__ +from faststream._internal.broker.broker import BrokerUsecase from faststream._internal.constants import EMPTY from faststream.exceptions import NOT_CONNECTED_YET from faststream.message import gen_cor_id -from faststream.redis.broker.logging import RedisLoggingBroker -from faststream.redis.broker.registrator import RedisRegistrator +from faststream.redis.message import UnifyRedisDict from faststream.redis.publisher.producer import RedisFastProducer from faststream.redis.security import parse_security +from .logging import make_redis_logger_state +from .registrator import RedisRegistrator + if TYPE_CHECKING: from types import TracebackType @@ -83,7 +86,7 @@ class RedisInitKwargs(TypedDict, total=False): class RedisBroker( RedisRegistrator, - RedisLoggingBroker, + BrokerUsecase[UnifyRedisDict, "Redis[bytes]"], ): """Redis broker.""" @@ -239,9 +242,9 @@ def __init__( security=security, tags=tags, # logging - logger=logger, - log_level=log_level, - log_fmt=log_fmt, + logger_state=make_redis_logger_state( + logger=logger, log_level=log_level, log_fmt=log_fmt + ), # FastDepends args apply_types=apply_types, validate=validate, @@ -334,27 +337,23 @@ async def _connect( # type: ignore[override] ) return client - async def _close( + async def close( self, exc_type: Optional[Type[BaseException]] = None, exc_val: Optional[BaseException] = None, exc_tb: Optional["TracebackType"] = None, ) -> None: + await super().close(exc_type, exc_val, exc_tb) + if self._connection is not None: await self._connection.aclose() # type: ignore[attr-defined] - - await super()._close(exc_type, exc_val, exc_tb) + self._connection = None async def start(self) -> None: + await self.connect() + self._setup() await super().start() - for handler in self._subscribers.values(): - self._log( - f"`{handler.call_name}` waiting for messages", - extra=handler.get_log_context(None), - ) - await handler.start() - @property def _subscriber_setup_extra(self) -> "AnyDict": return { diff --git a/faststream/redis/broker/logging.py b/faststream/redis/broker/logging.py index 468da8162e..f3f89ad324 100644 --- a/faststream/redis/broker/logging.py +++ b/faststream/redis/broker/logging.py @@ -1,59 +1,54 @@ -import logging -from typing import TYPE_CHECKING, Any, ClassVar, Optional +from functools import partial +from typing import TYPE_CHECKING, Optional -from faststream._internal.broker.broker import BrokerUsecase -from faststream._internal.constants import EMPTY from faststream._internal.log.logging import get_broker_logger -from faststream.redis.message import UnifyRedisDict +from faststream._internal.setup.logger import ( + DefaultLoggerStorage, + make_logger_state, +) if TYPE_CHECKING: - from redis.asyncio.client import Redis # noqa: F401 + from faststream._internal.basic_types import AnyDict, LoggerProto - from faststream._internal.basic_types import LoggerProto - - -class RedisLoggingBroker(BrokerUsecase[UnifyRedisDict, "Redis[bytes]"]): - """A class that extends the LoggingMixin class and adds additional functionality for logging Redis related information.""" - - _max_channel_name: int - __max_msg_id_ln: ClassVar[int] = 10 +class RedisParamsStorage(DefaultLoggerStorage): def __init__( self, - *args: Any, - logger: Optional["LoggerProto"] = EMPTY, - log_level: int = logging.INFO, - log_fmt: Optional[str] = None, - **kwargs: Any, + log_fmt: Optional[str], ) -> None: - super().__init__( - *args, - logger=logger, - # TODO: generate unique logger names to not share between brokers - default_logger=get_broker_logger( - name="redis", - default_context={ - "channel": "", - }, - message_id_ln=self.__max_msg_id_ln, - ), - log_level=log_level, - log_fmt=log_fmt, - **kwargs, - ) + super().__init__(log_fmt) + self._max_channel_name = 4 - def get_fmt(self) -> str: - return ( - "%(asctime)s %(levelname)-8s - " - f"%(channel)-{self._max_channel_name}s | " - f"%(message_id)-{self.__max_msg_id_ln}s " - "- %(message)s" + def setup_log_contest(self, params: "AnyDict") -> None: + self._max_channel_name = max( + ( + self._max_channel_name, + len(params.get("channel", "")), + ) ) - def _setup_log_context( - self, - *, - channel: Optional[str] = None, - ) -> None: - self._max_channel_name = max((self._max_channel_name, len(channel or ""))) + def get_logger(self) -> Optional["LoggerProto"]: + message_id_ln = 10 + + # TODO: generate unique logger names to not share between brokers + return get_broker_logger( + name="redis", + default_context={ + "channel": "", + }, + message_id_ln=message_id_ln, + fmt=self._log_fmt + or ( + "%(asctime)s %(levelname)-8s - " + f"%(channel)-{self._max_channel_name}s | " + f"%(message_id)-{message_id_ln}s " + "- %(message)s" + ), + ) + + +make_redis_logger_state = partial( + make_logger_state, + default_storag_cls=RedisParamsStorage, +) diff --git a/faststream/redis/subscriber/usecase.py b/faststream/redis/subscriber/usecase.py index f576c47217..796dbeee2e 100644 --- a/faststream/redis/subscriber/usecase.py +++ b/faststream/redis/subscriber/usecase.py @@ -48,8 +48,9 @@ if TYPE_CHECKING: from fast_depends.dependencies import Depends - from faststream._internal.basic_types import AnyDict, Decorator, LoggerProto + from faststream._internal.basic_types import AnyDict, LoggerProto from faststream._internal.publisher.proto import ProducerProto + from faststream._internal.setup import SetupState from faststream._internal.types import ( AsyncCallable, BrokerMiddleware, @@ -115,10 +116,7 @@ def _setup( # type: ignore[override] broker_parser: Optional["CustomCallable"], broker_decoder: Optional["CustomCallable"], # dependant args - apply_types: bool, - is_validate: bool, - _get_dependant: Optional[Callable[..., Any]], - _call_decorators: Iterable["Decorator"], + state: "SetupState", ) -> None: self._client = connection @@ -129,10 +127,7 @@ def _setup( # type: ignore[override] extra_context=extra_context, broker_parser=broker_parser, broker_decoder=broker_decoder, - apply_types=apply_types, - is_validate=is_validate, - _get_dependant=_get_dependant, - _call_decorators=_call_decorators, + state=state, ) def _make_response_publisher( diff --git a/faststream/specification/asyncapi/v2_6_0/generate.py b/faststream/specification/asyncapi/v2_6_0/generate.py index c8d8dc823f..69fcbe84a5 100644 --- a/faststream/specification/asyncapi/v2_6_0/generate.py +++ b/faststream/specification/asyncapi/v2_6_0/generate.py @@ -32,7 +32,7 @@ def get_app_schema(app: Application) -> Schema: if broker is None: # pragma: no cover raise RuntimeError() - broker._setup() + app._setup() servers = get_broker_server(broker) channels = get_broker_channels(broker) diff --git a/faststream/specification/asyncapi/v3_0_0/generate.py b/faststream/specification/asyncapi/v3_0_0/generate.py index 2bfcdc3eb5..08ce96e6c4 100644 --- a/faststream/specification/asyncapi/v3_0_0/generate.py +++ b/faststream/specification/asyncapi/v3_0_0/generate.py @@ -40,7 +40,7 @@ def get_app_schema(app: Application) -> Schema: broker = app.broker if broker is None: # pragma: no cover raise RuntimeError() - broker._setup() + app._setup() servers = get_broker_server(broker) channels = get_broker_channels(broker) diff --git a/faststream/specification/proto.py b/faststream/specification/proto.py index de9514573b..776eb9ad16 100644 --- a/faststream/specification/proto.py +++ b/faststream/specification/proto.py @@ -11,6 +11,7 @@ AnyHttpUrl, ) from faststream._internal.broker.broker import BrokerUsecase + from faststream._internal.setup import SetupState from faststream.specification.schema.contact import Contact, ContactDict from faststream.specification.schema.docs import ExternalDocs, ExternalDocsDict from faststream.specification.schema.license import License, LicenseDict @@ -18,6 +19,8 @@ class Application(Protocol): + _state: "SetupState" + broker: Optional["BrokerUsecase[Any, Any]"] title: str @@ -30,6 +33,16 @@ class Application(Protocol): external_docs: Optional[Union["ExternalDocs", "ExternalDocsDict", "AnyDict"]] identifier: Optional[str] + def _setup(self) -> None: + if self.broker is not None: + self.broker._setup(self._state) + + async def _start_broker(self) -> None: + if self.broker is not None: + await self.broker.connect() + self._setup() + await self.broker.start() + class SpecificationProto(Protocol): """A class representing an asynchronous API operation.""" diff --git a/tests/brokers/base/testclient.py b/tests/brokers/base/testclient.py index 543a400b2a..f6d45d7cc7 100644 --- a/tests/brokers/base/testclient.py +++ b/tests/brokers/base/testclient.py @@ -14,6 +14,24 @@ class BrokerTestclientTestcase(BrokerPublishTestcase, BrokerConsumeTestcase): def get_fake_producer_class(self) -> type: raise NotImplementedError + @pytest.mark.asyncio + async def test_correct_clean_fake_subscribers(self): + broker = self.get_broker() + + @broker.subscriber("test") + async def handler1(msg): ... + + broker.publisher("test2") + broker.publisher("test") + + assert len(broker._subscribers) == 1 + + test_client = self.patch_broker(broker) + async with test_client as br: + assert len(br._subscribers) == 2 + + assert len(broker._subscribers) == 1 + @pytest.mark.asyncio async def test_subscriber_mock(self, queue: str): test_broker = self.get_broker() diff --git a/tests/brokers/confluent/test_logger.py b/tests/brokers/confluent/test_logger.py index 33698d522e..a956c8d40e 100644 --- a/tests/brokers/confluent/test_logger.py +++ b/tests/brokers/confluent/test_logger.py @@ -1,5 +1,4 @@ import logging -from typing import Any import pytest @@ -12,28 +11,20 @@ class TestLogger(ConfluentTestcaseConfig): """A class to represent a test Kafka broker.""" - def get_broker(self, apply_types: bool = False, **kwargs: Any) -> KafkaBroker: - return KafkaBroker(apply_types=apply_types, **kwargs) - - def patch_broker(self, broker: KafkaBroker, **kwargs: Any) -> KafkaBroker: - return broker - @pytest.mark.asyncio async def test_custom_logger(self, queue: str): test_logger = logging.getLogger("test_logger") - consume_broker = self.get_broker(logger=test_logger) + broker = KafkaBroker(logger=test_logger) args, kwargs = self.get_subscriber_params(queue) - @consume_broker.subscriber(*args, **kwargs) + @broker.subscriber(*args, **kwargs) def subscriber(m): ... - async with self.patch_broker(consume_broker) as br: - await br.start() + await broker.start() - for sub in br._subscribers.values(): - consumer_logger = sub.consumer.logger - assert consumer_logger == test_logger + for sub in broker._subscribers.values(): + consumer_logger = sub.consumer.logger_state.logger.logger + assert consumer_logger == test_logger - producer_logger = br._producer._producer.logger - assert producer_logger == test_logger + await broker.close() diff --git a/tests/cli/rabbit/test_app.py b/tests/cli/rabbit/test_app.py index 104b515c0e..ca59ff5d4d 100644 --- a/tests/cli/rabbit/test_app.py +++ b/tests/cli/rabbit/test_app.py @@ -122,7 +122,9 @@ async def call2(): test_app = FastStream(broker=broker, after_startup=[call1, call2]) - with patch.object(test_app.broker, "start", async_mock.broker_start): + with patch.object(test_app.broker, "start", async_mock.broker_start), patch.object( + test_app.broker, "connect", async_mock.broker_connect + ): await test_app.start() mock.after_startup1.assert_called_once() @@ -130,7 +132,9 @@ async def call2(): @pytest.mark.asyncio -async def test_startup_lifespan_before_broker_started(async_mock, app: FastStream): +async def test_startup_lifespan_before_broker_started( + async_mock: AsyncMock, app: FastStream +): @app.on_startup async def call(): await async_mock.before() @@ -142,7 +146,9 @@ async def call_after(): async_mock.before.assert_awaited_once() async_mock.broker_start.assert_called_once() - with patch.object(app.broker, "start", async_mock.broker_start): + with patch.object(app.broker, "start", async_mock.broker_start), patch.object( + app.broker, "connect", async_mock.broker_connect + ): await app.start() async_mock.broker_start.assert_called_once() @@ -162,7 +168,9 @@ async def call2(): test_app = FastStream(broker=broker, after_shutdown=[call1, call2]) - with patch.object(test_app.broker, "start", async_mock.broker_start): + with patch.object(test_app.broker, "start", async_mock.broker_start), patch.object( + test_app.broker, "connect", async_mock.broker_connect + ): await test_app.stop() mock.after_shutdown1.assert_called_once() @@ -171,7 +179,7 @@ async def call2(): @pytest.mark.asyncio async def test_shutdown_lifespan_after_broker_stopped( - mock, async_mock, app: FastStream + mock, async_mock: AsyncMock, app: FastStream ): @app.after_shutdown async def call(): @@ -192,14 +200,15 @@ async def call_before(): @pytest.mark.asyncio -async def test_running(async_mock, app: FastStream): +async def test_running(async_mock: AsyncMock, app: FastStream): app.exit() with patch.object(app.broker, "start", async_mock.broker_run), patch.object( - app.broker, "close", async_mock.broker_stopped - ): + app.broker, "connect", async_mock.broker_connect + ), patch.object(app.broker, "close", async_mock.broker_stopped): await app.run() + async_mock.broker_connect.assert_called_once() async_mock.broker_run.assert_called_once() async_mock.broker_stopped.assert_called_once() @@ -217,7 +226,9 @@ async def f(): @pytest.mark.asyncio -async def test_running_lifespan_contextmanager(async_mock, mock: Mock, app: FastStream): +async def test_running_lifespan_contextmanager( + async_mock: AsyncMock, mock: Mock, app: FastStream +): @asynccontextmanager async def lifespan(env: str): mock.on(env) @@ -228,10 +239,11 @@ async def lifespan(env: str): app.exit() with patch.object(app.broker, "start", async_mock.broker_run), patch.object( - app.broker, "close", async_mock.broker_stopped - ): + app.broker, "connect", async_mock.broker_connect + ), patch.object(app.broker, "close", async_mock.broker_stopped): await app.run(run_extra_options={"env": "test"}) + async_mock.broker_connect.assert_called_once() async_mock.broker_run.assert_called_once() async_mock.broker_stopped.assert_called_once() @@ -241,30 +253,32 @@ async def lifespan(env: str): @pytest.mark.asyncio @pytest.mark.skipif(IS_WINDOWS, reason="does not run on windows") -async def test_stop_with_sigint(async_mock, app: FastStream): - with patch.object(app.broker, "start", async_mock.broker_run_sigint), patch.object( - app.broker, "close", async_mock.broker_stopped_sigint - ): +async def test_stop_with_sigint(async_mock: AsyncMock, app: FastStream): + with patch.object(app.broker, "start", async_mock.broker_run), patch.object( + app.broker, "connect", async_mock.broker_connect + ), patch.object(app.broker, "close", async_mock.broker_stopped): async with anyio.create_task_group() as tg: tg.start_soon(app.run) tg.start_soon(_kill, signal.SIGINT) - async_mock.broker_run_sigint.assert_called_once() - async_mock.broker_stopped_sigint.assert_called_once() + async_mock.broker_connect.assert_called_once() + async_mock.broker_run.assert_called_once() + async_mock.broker_stopped.assert_called_once() @pytest.mark.asyncio @pytest.mark.skipif(IS_WINDOWS, reason="does not run on windows") -async def test_stop_with_sigterm(async_mock, app: FastStream): - with patch.object(app.broker, "start", async_mock.broker_run_sigterm), patch.object( - app.broker, "close", async_mock.broker_stopped_sigterm - ): +async def test_stop_with_sigterm(async_mock: AsyncMock, app: FastStream): + with patch.object(app.broker, "start", async_mock.broker_run), patch.object( + app.broker, "connect", async_mock.broker_connect + ), patch.object(app.broker, "close", async_mock.broker_stopped): async with anyio.create_task_group() as tg: tg.start_soon(app.run) tg.start_soon(_kill, signal.SIGTERM) - async_mock.broker_run_sigterm.assert_called_once() - async_mock.broker_stopped_sigterm.assert_called_once() + async_mock.broker_connect.assert_called_once() + async_mock.broker_run.assert_called_once() + async_mock.broker_stopped.assert_called_once() @pytest.mark.asyncio @@ -333,8 +347,8 @@ async def lifespan(env: str): app = FastStream(app.broker, lifespan=lifespan) with patch.object(app.broker, "start", async_mock.broker_run), patch.object( - app.broker, "close", async_mock.broker_stopped - ): + app.broker, "connect", async_mock.broker_connect + ), patch.object(app.broker, "close", async_mock.broker_stopped): async with TestApp(app, {"env": "test"}): pass @@ -355,7 +369,9 @@ async def lifespan(env: str): with patch.object(app.broker, "start", async_mock.broker_run), patch.object( app.broker, "close", async_mock.broker_stopped - ), TestApp(app, {"env": "test"}): + ), patch.object(app.broker, "connect", async_mock.broker_connect), TestApp( + app, {"env": "test"} + ): pass async_mock.on.assert_awaited_once_with("test") diff --git a/tests/cli/rabbit/test_logs.py b/tests/cli/rabbit/test_logs.py index 92bf5beb67..6506f88524 100644 --- a/tests/cli/rabbit/test_logs.py +++ b/tests/cli/rabbit/test_logs.py @@ -20,12 +20,14 @@ ) def test_set_level(level, app: FastStream): level = get_log_level(level) + app._setup() set_log_level(level, app) - assert app.logger.level is app.broker.logger.level is level + broker_logger = app.broker._state.logger_state.logger.logger + assert app.logger.level is broker_logger.level is level @pytest.mark.parametrize( - ("level", "broker"), + ("level", "app"), ( # noqa: PT007 pytest.param( logging.CRITICAL, @@ -50,6 +52,7 @@ def test_set_level(level, app: FastStream): ), ) def test_set_level_to_none(level, app: FastStream): + app._setup() set_log_level(get_log_level(level), app) diff --git a/tests/cli/test_publish.py b/tests/cli/test_publish.py index a718d6e693..d73382a771 100644 --- a/tests/cli/test_publish.py +++ b/tests/cli/test_publish.py @@ -1,3 +1,4 @@ +from typing import Tuple from unittest.mock import AsyncMock, patch from dirty_equals import IsPartialDict @@ -14,7 +15,7 @@ ) -def get_mock_app(broker_type, producer_type) -> FastStream: +def get_mock_app(broker_type, producer_type) -> Tuple[FastStream, AsyncMock]: broker = broker_type() broker.connect = AsyncMock() mock_producer = AsyncMock(spec=producer_type) @@ -22,7 +23,7 @@ def get_mock_app(broker_type, producer_type) -> FastStream: mock_producer._parser = AsyncMock() mock_producer._decoder = AsyncMock() broker._producer = mock_producer - return FastStream(broker) + return FastStream(broker), mock_producer @require_redis @@ -30,7 +31,7 @@ def test_publish_command_with_redis_options(runner): from faststream.redis import RedisBroker from faststream.redis.publisher.producer import RedisFastProducer - mock_app = get_mock_app(RedisBroker, RedisFastProducer) + mock_app, producer_mock = get_mock_app(RedisBroker, RedisFastProducer) with patch( "faststream._internal.cli.main.import_from_string", @@ -57,8 +58,8 @@ def test_publish_command_with_redis_options(runner): assert result.exit_code == 0 - assert mock_app.broker._producer.publish.call_args.args[0] == "hello world" - assert mock_app.broker._producer.publish.call_args.kwargs == IsPartialDict( + assert producer_mock.publish.call_args.args[0] == "hello world" + assert producer_mock.publish.call_args.kwargs == IsPartialDict( reply_to="tester", stream="streamname", list="listname", @@ -72,7 +73,7 @@ def test_publish_command_with_confluent_options(runner): from faststream.confluent import KafkaBroker as ConfluentBroker from faststream.confluent.publisher.producer import AsyncConfluentFastProducer - mock_app = get_mock_app(ConfluentBroker, AsyncConfluentFastProducer) + mock_app, producer_mock = get_mock_app(ConfluentBroker, AsyncConfluentFastProducer) with patch( "faststream._internal.cli.main.import_from_string", @@ -92,8 +93,9 @@ def test_publish_command_with_confluent_options(runner): ) assert result.exit_code == 0 - assert mock_app.broker._producer.publish.call_args.args[0] == "hello world" - assert mock_app.broker._producer.publish.call_args.kwargs == IsPartialDict( + + assert producer_mock.publish.call_args.args[0] == "hello world" + assert producer_mock.publish.call_args.kwargs == IsPartialDict( topic="topicname", correlation_id="someId", ) @@ -104,7 +106,7 @@ def test_publish_command_with_kafka_options(runner): from faststream.kafka import KafkaBroker from faststream.kafka.publisher.producer import AioKafkaFastProducer - mock_app = get_mock_app(KafkaBroker, AioKafkaFastProducer) + mock_app, producer_mock = get_mock_app(KafkaBroker, AioKafkaFastProducer) with patch( "faststream._internal.cli.main.import_from_string", @@ -124,8 +126,8 @@ def test_publish_command_with_kafka_options(runner): ) assert result.exit_code == 0 - assert mock_app.broker._producer.publish.call_args.args[0] == "hello world" - assert mock_app.broker._producer.publish.call_args.kwargs == IsPartialDict( + assert producer_mock.publish.call_args.args[0] == "hello world" + assert producer_mock.publish.call_args.kwargs == IsPartialDict( topic="topicname", correlation_id="someId", ) @@ -136,7 +138,7 @@ def test_publish_command_with_nats_options(runner): from faststream.nats import NatsBroker from faststream.nats.publisher.producer import NatsFastProducer - mock_app = get_mock_app(NatsBroker, NatsFastProducer) + mock_app, producer_mock = get_mock_app(NatsBroker, NatsFastProducer) with patch( "faststream._internal.cli.main.import_from_string", @@ -159,8 +161,8 @@ def test_publish_command_with_nats_options(runner): assert result.exit_code == 0 - assert mock_app.broker._producer.publish.call_args.args[0] == "hello world" - assert mock_app.broker._producer.publish.call_args.kwargs == IsPartialDict( + assert producer_mock.publish.call_args.args[0] == "hello world" + assert producer_mock.publish.call_args.kwargs == IsPartialDict( subject="subjectname", reply_to="tester", correlation_id="someId", @@ -172,7 +174,7 @@ def test_publish_command_with_rabbit_options(runner): from faststream.rabbit import RabbitBroker from faststream.rabbit.publisher.producer import AioPikaFastProducer - mock_app = get_mock_app(RabbitBroker, AioPikaFastProducer) + mock_app, producer_mock = get_mock_app(RabbitBroker, AioPikaFastProducer) with patch( "faststream._internal.cli.main.import_from_string", @@ -191,8 +193,8 @@ def test_publish_command_with_rabbit_options(runner): assert result.exit_code == 0 - assert mock_app.broker._producer.publish.call_args.args[0] == "hello world" - assert mock_app.broker._producer.publish.call_args.kwargs == IsPartialDict( + assert producer_mock.publish.call_args.args[0] == "hello world" + assert producer_mock.publish.call_args.kwargs == IsPartialDict( { "correlation_id": "someId", } @@ -204,7 +206,7 @@ def test_publish_nats_request_command(runner: CliRunner): from faststream.nats import NatsBroker from faststream.nats.publisher.producer import NatsFastProducer - mock_app = get_mock_app(NatsBroker, NatsFastProducer) + mock_app, producer_mock = get_mock_app(NatsBroker, NatsFastProducer) with patch( "faststream._internal.cli.main.import_from_string", @@ -224,8 +226,8 @@ def test_publish_nats_request_command(runner: CliRunner): ], ) - assert mock_app.broker._producer.request.call_args.args[0] == "hello world" - assert mock_app.broker._producer.request.call_args.kwargs == IsPartialDict( + assert producer_mock.request.call_args.args[0] == "hello world" + assert producer_mock.request.call_args.kwargs == IsPartialDict( subject="subjectname", timeout=1.0, ) diff --git a/tests/opentelemetry/basic.py b/tests/opentelemetry/basic.py index 2207f7b0dc..5cd8b6b0ad 100644 --- a/tests/opentelemetry/basic.py +++ b/tests/opentelemetry/basic.py @@ -1,5 +1,5 @@ import asyncio -from typing import List, Optional, Tuple, Type, cast +from typing import Any, List, Optional, Tuple, cast from unittest.mock import Mock import pytest @@ -29,12 +29,21 @@ class LocalTelemetryTestcase(BaseTestcaseConfig): messaging_system: str include_messages_counters: bool - broker_class: Type[BrokerUsecase] resource: Resource = Resource.create(attributes={"service.name": "faststream.test"}) - telemetry_middleware_class: TelemetryMiddleware - def patch_broker(self, broker: BrokerUsecase) -> BrokerUsecase: + def get_broker( + self, + apply_types: bool = False, + **kwargs: Any, + ) -> BrokerUsecase[Any, Any]: + raise NotImplementedError + + def patch_broker( + self, + broker: BrokerUsecase[Any, Any], + **kwargs: Any, + ) -> BrokerUsecase[Any, Any]: return broker def destination_name(self, queue: str) -> str: @@ -163,7 +172,7 @@ async def test_subscriber_create_publish_process_span( trace_exporter: InMemorySpanExporter, ): mid = self.telemetry_middleware_class(tracer_provider=tracer_provider) - broker = self.broker_class(middlewares=(mid,)) + broker = self.get_broker(middlewares=(mid,)) args, kwargs = self.get_subscriber_params(queue) @@ -202,7 +211,7 @@ async def test_chain_subscriber_publisher( trace_exporter: InMemorySpanExporter, ): mid = self.telemetry_middleware_class(tracer_provider=tracer_provider) - broker = self.broker_class(middlewares=(mid,)) + broker = self.get_broker(middlewares=(mid,)) first_queue = queue second_queue = queue + "2" @@ -262,7 +271,7 @@ async def test_no_trace_context_create_process_span( trace_exporter: InMemorySpanExporter, ): mid = self.telemetry_middleware_class(tracer_provider=tracer_provider) - broker = self.broker_class(middlewares=(mid,)) + broker = self.get_broker(middlewares=(mid,)) args, kwargs = self.get_subscriber_params(queue) @@ -301,7 +310,7 @@ async def test_metrics( metric_reader: InMemoryMetricReader, ): mid = self.telemetry_middleware_class(meter_provider=meter_provider) - broker = self.broker_class(middlewares=(mid,)) + broker = self.get_broker(middlewares=(mid,)) args, kwargs = self.get_subscriber_params(queue) @@ -337,7 +346,7 @@ async def test_error_metrics( metric_reader: InMemoryMetricReader, ): mid = self.telemetry_middleware_class(meter_provider=meter_provider) - broker = self.broker_class(middlewares=(mid,)) + broker = self.get_broker(middlewares=(mid,)) expected_value_type = "ValueError" args, kwargs = self.get_subscriber_params(queue) @@ -377,7 +386,7 @@ async def test_span_in_context( trace_exporter: InMemorySpanExporter, ): mid = self.telemetry_middleware_class(tracer_provider=tracer_provider) - broker = self.broker_class(middlewares=(mid,)) + broker = self.get_broker(middlewares=(mid,), apply_types=True) args, kwargs = self.get_subscriber_params(queue) @@ -408,7 +417,7 @@ async def test_get_baggage( mock: Mock, ): mid = self.telemetry_middleware_class() - broker = self.broker_class(middlewares=(mid,)) + broker = self.get_broker(middlewares=(mid,), apply_types=True) expected_baggage = {"foo": "bar"} args, kwargs = self.get_subscriber_params(queue) @@ -447,7 +456,7 @@ async def test_clear_baggage( mock: Mock, ): mid = self.telemetry_middleware_class() - broker = self.broker_class(middlewares=(mid,)) + broker = self.get_broker(middlewares=(mid,), apply_types=True) first_queue = queue + "1" second_queue = queue + "2" @@ -494,7 +503,7 @@ async def test_modify_baggage( mock: Mock, ): mid = self.telemetry_middleware_class() - broker = self.broker_class(middlewares=(mid,)) + broker = self.get_broker(middlewares=(mid,), apply_types=True) expected_baggage = {"baz": "bar", "bar": "baz"} first_queue = queue + "1" diff --git a/tests/opentelemetry/confluent/test_confluent.py b/tests/opentelemetry/confluent/test_confluent.py index af914ebc9d..4d9844a6e6 100644 --- a/tests/opentelemetry/confluent/test_confluent.py +++ b/tests/opentelemetry/confluent/test_confluent.py @@ -1,5 +1,5 @@ import asyncio -from typing import Optional +from typing import Any, Optional from unittest.mock import Mock import pytest @@ -24,9 +24,11 @@ class TestTelemetry(ConfluentTestcaseConfig, LocalTelemetryTestcase): messaging_system = "kafka" include_messages_counters = True - broker_class = KafkaBroker telemetry_middleware_class = KafkaTelemetryMiddleware + def get_broker(self, apply_types: bool = False, **kwargs: Any) -> KafkaBroker: + return KafkaBroker(apply_types=apply_types, **kwargs) + def assert_span( self, span: Span, @@ -73,7 +75,7 @@ async def test_batch( mid = self.telemetry_middleware_class( meter_provider=meter_provider, tracer_provider=tracer_provider ) - broker = self.broker_class(middlewares=(mid,)) + broker = self.get_broker(middlewares=(mid,), apply_types=True) expected_msg_count = 3 expected_link_count = 1 expected_link_attrs = {"messaging.batch.message_count": 3} @@ -91,13 +93,11 @@ async def handler(m, baggage: CurrentBaggage): mock(m) event.set() - broker = self.patch_broker(broker) - - async with broker: - await broker.start() + async with self.patch_broker(broker) as br: + await br.start() tasks = ( asyncio.create_task( - broker.publish_batch( + br.publish_batch( 1, "hi", 3, @@ -139,7 +139,7 @@ async def test_batch_publish_with_single_consume( mid = self.telemetry_middleware_class( meter_provider=meter_provider, tracer_provider=tracer_provider ) - broker = self.broker_class(middlewares=(mid,)) + broker = self.get_broker(middlewares=(mid,), apply_types=True) msgs_queue = asyncio.Queue(maxsize=3) expected_msg_count = 3 expected_link_count = 1 @@ -155,11 +155,9 @@ async def handler(msg, baggage: CurrentBaggage): assert baggage.get_all_batch() == [] await msgs_queue.put(msg) - broker = self.patch_broker(broker) - - async with broker: - await broker.start() - await broker.publish_batch( + async with self.patch_broker(broker) as br: + await br.start() + await br.publish_batch( 1, "hi", 3, topic=queue, headers=Baggage({"foo": "bar"}).to_headers() ) result, _ = await asyncio.wait( @@ -205,7 +203,7 @@ async def test_single_publish_with_batch_consume( mid = self.telemetry_middleware_class( meter_provider=meter_provider, tracer_provider=tracer_provider ) - broker = self.broker_class(middlewares=(mid,)) + broker = self.get_broker(middlewares=(mid,), apply_types=True) expected_msg_count = 2 expected_link_count = 2 expected_span_count = 6 @@ -222,18 +220,16 @@ async def handler(m, baggage: CurrentBaggage): mock(m) event.set() - broker = self.patch_broker(broker) - - async with broker: - await broker.start() + async with self.patch_broker(broker) as br: + await br.start() tasks = ( asyncio.create_task( - broker.publish( + br.publish( "hi", topic=queue, headers=Baggage({"foo": "bar"}).to_headers() ) ), asyncio.create_task( - broker.publish( + br.publish( "buy", topic=queue, headers=Baggage({"bar": "baz"}).to_headers() ) ), diff --git a/tests/opentelemetry/kafka/test_kafka.py b/tests/opentelemetry/kafka/test_kafka.py index a4f83748c4..db2249c1a2 100644 --- a/tests/opentelemetry/kafka/test_kafka.py +++ b/tests/opentelemetry/kafka/test_kafka.py @@ -1,5 +1,5 @@ import asyncio -from typing import Optional +from typing import Any, Optional from unittest.mock import Mock import pytest @@ -25,9 +25,11 @@ class TestTelemetry(LocalTelemetryTestcase): messaging_system = "kafka" include_messages_counters = True - broker_class = KafkaBroker telemetry_middleware_class = KafkaTelemetryMiddleware + def get_broker(self, apply_types: bool = False, **kwargs: Any) -> KafkaBroker: + return KafkaBroker(apply_types=apply_types, **kwargs) + def assert_span( self, span: Span, @@ -74,7 +76,7 @@ async def test_batch( mid = self.telemetry_middleware_class( meter_provider=meter_provider, tracer_provider=tracer_provider ) - broker = self.broker_class(middlewares=(mid,)) + broker = self.get_broker(middlewares=(mid,), apply_types=True) expected_msg_count = 3 expected_link_count = 1 expected_link_attrs = {"messaging.batch.message_count": 3} @@ -92,13 +94,11 @@ async def handler(m, baggage: CurrentBaggage): mock(m) event.set() - broker = self.patch_broker(broker) - - async with broker: - await broker.start() + async with self.patch_broker(broker) as br: + await br.start() tasks = ( asyncio.create_task( - broker.publish_batch( + br.publish_batch( 1, "hi", 3, @@ -140,7 +140,7 @@ async def test_batch_publish_with_single_consume( mid = self.telemetry_middleware_class( meter_provider=meter_provider, tracer_provider=tracer_provider ) - broker = self.broker_class(middlewares=(mid,)) + broker = self.get_broker(middlewares=(mid,), apply_types=True) msgs_queue = asyncio.Queue(maxsize=3) expected_msg_count = 3 expected_link_count = 1 @@ -156,11 +156,9 @@ async def handler(msg, baggage: CurrentBaggage): assert baggage.get_all_batch() == [] await msgs_queue.put(msg) - broker = self.patch_broker(broker) - - async with broker: - await broker.start() - await broker.publish_batch( + async with self.patch_broker(broker) as br: + await br.start() + await br.publish_batch( 1, "hi", 3, topic=queue, headers=Baggage({"foo": "bar"}).to_headers() ) result, _ = await asyncio.wait( @@ -206,7 +204,7 @@ async def test_single_publish_with_batch_consume( mid = self.telemetry_middleware_class( meter_provider=meter_provider, tracer_provider=tracer_provider ) - broker = self.broker_class(middlewares=(mid,)) + broker = self.get_broker(middlewares=(mid,), apply_types=True) expected_msg_count = 2 expected_link_count = 2 expected_span_count = 6 @@ -223,18 +221,16 @@ async def handler(m, baggage: CurrentBaggage): mock(m) event.set() - broker = self.patch_broker(broker) - - async with broker: - await broker.start() + async with self.patch_broker(broker) as br: + await br.start() tasks = ( asyncio.create_task( - broker.publish( + br.publish( "hi", topic=queue, headers=Baggage({"foo": "bar"}).to_headers() ) ), asyncio.create_task( - broker.publish( + br.publish( "buy", topic=queue, headers=Baggage({"bar": "baz"}).to_headers() ) ), @@ -260,17 +256,19 @@ async def handler(m, baggage: CurrentBaggage): @pytest.mark.kafka class TestPublishWithTelemetry(TestPublish): - def get_broker(self, apply_types: bool = False): + def get_broker(self, apply_types: bool = False, **kwargs: Any) -> KafkaBroker: return KafkaBroker( middlewares=(KafkaTelemetryMiddleware(),), apply_types=apply_types, + **kwargs, ) @pytest.mark.kafka class TestConsumeWithTelemetry(TestConsume): - def get_broker(self, apply_types: bool = False): + def get_broker(self, apply_types: bool = False, **kwargs: Any) -> KafkaBroker: return KafkaBroker( middlewares=(KafkaTelemetryMiddleware(),), apply_types=apply_types, + **kwargs, ) diff --git a/tests/opentelemetry/nats/test_nats.py b/tests/opentelemetry/nats/test_nats.py index 88dc22d49c..6d03d901c8 100644 --- a/tests/opentelemetry/nats/test_nats.py +++ b/tests/opentelemetry/nats/test_nats.py @@ -1,4 +1,5 @@ import asyncio +from typing import Any from unittest.mock import Mock import pytest @@ -24,9 +25,11 @@ def stream(queue): class TestTelemetry(LocalTelemetryTestcase): messaging_system = "nats" include_messages_counters = True - broker_class = NatsBroker telemetry_middleware_class = NatsTelemetryMiddleware + def get_broker(self, apply_types: bool = False, **kwargs: Any) -> NatsBroker: + return NatsBroker(apply_types=apply_types, **kwargs) + async def test_batch( self, event: asyncio.Event, @@ -41,7 +44,7 @@ async def test_batch( mid = self.telemetry_middleware_class( meter_provider=meter_provider, tracer_provider=tracer_provider ) - broker = self.broker_class(middlewares=(mid,)) + broker = self.get_broker(middlewares=(mid,)) expected_msg_count = 1 expected_span_count = 4 expected_proc_batch_count = 1 @@ -57,12 +60,10 @@ async def handler(m): mock(m) event.set() - broker = self.patch_broker(broker) - - async with broker: - await broker.start() + async with self.patch_broker(broker) as br: + await br.start() tasks = ( - asyncio.create_task(broker.publish("hi", queue)), + asyncio.create_task(br.publish("hi", queue)), asyncio.create_task(event.wait()), ) await asyncio.wait(tasks, timeout=self.timeout) @@ -90,17 +91,19 @@ async def handler(m): @pytest.mark.nats class TestPublishWithTelemetry(TestPublish): - def get_broker(self, apply_types: bool = False): + def get_broker(self, apply_types: bool = False, **kwargs: Any) -> NatsBroker: return NatsBroker( middlewares=(NatsTelemetryMiddleware(),), apply_types=apply_types, + **kwargs, ) @pytest.mark.nats class TestConsumeWithTelemetry(TestConsume): - def get_broker(self, apply_types: bool = False): + def get_broker(self, apply_types: bool = False, **kwargs: Any) -> NatsBroker: return NatsBroker( middlewares=(NatsTelemetryMiddleware(),), apply_types=apply_types, + **kwargs, ) diff --git a/tests/opentelemetry/rabbit/test_rabbit.py b/tests/opentelemetry/rabbit/test_rabbit.py index 2a779cdd4b..582c2e6832 100644 --- a/tests/opentelemetry/rabbit/test_rabbit.py +++ b/tests/opentelemetry/rabbit/test_rabbit.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Any, Optional import pytest from dirty_equals import IsInt, IsUUID @@ -24,9 +24,11 @@ def exchange(queue): class TestTelemetry(LocalTelemetryTestcase): messaging_system = "rabbitmq" include_messages_counters = False - broker_class = RabbitBroker telemetry_middleware_class = RabbitTelemetryMiddleware + def get_broker(self, apply_types: bool = False, **kwargs: Any) -> RabbitBroker: + return RabbitBroker(apply_types=apply_types, **kwargs) + def destination_name(self, queue: str) -> str: return f"default.{queue}" @@ -66,17 +68,19 @@ def assert_span( @pytest.mark.rabbit class TestPublishWithTelemetry(TestPublish): - def get_broker(self, apply_types: bool = False): + def get_broker(self, apply_types: bool = False, **kwargs: Any) -> RabbitBroker: return RabbitBroker( middlewares=(RabbitTelemetryMiddleware(),), apply_types=apply_types, + **kwargs, ) @pytest.mark.rabbit class TestConsumeWithTelemetry(TestConsume): - def get_broker(self, apply_types: bool = False): + def get_broker(self, apply_types: bool = False, **kwargs: Any) -> RabbitBroker: return RabbitBroker( middlewares=(RabbitTelemetryMiddleware(),), apply_types=apply_types, + **kwargs, ) diff --git a/tests/opentelemetry/redis/test_redis.py b/tests/opentelemetry/redis/test_redis.py index 8d8366ba10..871e347a2b 100644 --- a/tests/opentelemetry/redis/test_redis.py +++ b/tests/opentelemetry/redis/test_redis.py @@ -1,4 +1,5 @@ import asyncio +from typing import Any from unittest.mock import Mock import pytest @@ -24,9 +25,11 @@ class TestTelemetry(LocalTelemetryTestcase): messaging_system = "redis" include_messages_counters = True - broker_class = RedisBroker telemetry_middleware_class = RedisTelemetryMiddleware + def get_broker(self, apply_types: bool = False, **kwargs: Any) -> RedisBroker: + return RedisBroker(apply_types=apply_types, **kwargs) + async def test_batch( self, event: asyncio.Event, @@ -40,7 +43,7 @@ async def test_batch( mid = self.telemetry_middleware_class( meter_provider=meter_provider, tracer_provider=tracer_provider ) - broker = self.broker_class(middlewares=(mid,)) + broker = self.get_broker(middlewares=(mid,), apply_types=True) expected_msg_count = 3 expected_link_count = 1 expected_link_attrs = {"messaging.batch.message_count": 3} @@ -56,12 +59,10 @@ async def handler(m, baggage: CurrentBaggage): mock(m) event.set() - broker = self.patch_broker(broker) - - async with broker: - await broker.start() + async with self.patch_broker(broker) as br: + await br.start() tasks = ( - asyncio.create_task(broker.publish_batch(1, "hi", 3, list=queue)), + asyncio.create_task(br.publish_batch(1, "hi", 3, list=queue)), asyncio.create_task(event.wait()), ) await asyncio.wait(tasks, timeout=self.timeout) @@ -96,7 +97,7 @@ async def test_batch_publish_with_single_consume( mid = self.telemetry_middleware_class( meter_provider=meter_provider, tracer_provider=tracer_provider ) - broker = self.broker_class(middlewares=(mid,)) + broker = self.get_broker(middlewares=(mid,), apply_types=True) msgs_queue = asyncio.Queue(maxsize=3) expected_msg_count = 3 expected_link_count = 1 @@ -113,11 +114,9 @@ async def handler(msg, baggage: CurrentBaggage): assert baggage.get_all_batch() == expected_baggage_batch await msgs_queue.put(msg) - broker = self.patch_broker(broker) - - async with broker: - await broker.start() - await broker.publish_batch(1, "hi", 3, list=queue) + async with self.patch_broker(broker) as br: + await br.start() + await br.publish_batch(1, "hi", 3, list=queue) result, _ = await asyncio.wait( ( asyncio.create_task(msgs_queue.get()), @@ -161,7 +160,7 @@ async def test_single_publish_with_batch_consume( mid = self.telemetry_middleware_class( meter_provider=meter_provider, tracer_provider=tracer_provider ) - broker = self.broker_class(middlewares=(mid,)) + broker = self.get_broker(middlewares=(mid,), apply_types=True) expected_msg_count = 2 expected_link_count = 2 expected_span_count = 6 @@ -178,17 +177,15 @@ async def handler(m, baggage: CurrentBaggage): mock(m) event.set() - broker = self.patch_broker(broker) - - async with broker: + async with self.patch_broker(broker) as br: tasks = ( asyncio.create_task( - broker.publish( + br.publish( "hi", list=queue, headers=Baggage({"foo": "bar"}).to_headers() ) ), asyncio.create_task( - broker.publish( + br.publish( "buy", list=queue, headers=Baggage({"bar": "baz"}).to_headers() ) ), @@ -217,35 +214,39 @@ async def handler(m, baggage: CurrentBaggage): @pytest.mark.redis class TestPublishWithTelemetry(TestPublish): - def get_broker(self, apply_types: bool = False): + def get_broker(self, apply_types: bool = False, **kwargs: Any) -> RedisBroker: return RedisBroker( middlewares=(RedisTelemetryMiddleware(),), apply_types=apply_types, + **kwargs, ) @pytest.mark.redis class TestConsumeWithTelemetry(TestConsume): - def get_broker(self, apply_types: bool = False): + def get_broker(self, apply_types: bool = False, **kwargs: Any) -> RedisBroker: return RedisBroker( middlewares=(RedisTelemetryMiddleware(),), apply_types=apply_types, + **kwargs, ) @pytest.mark.redis class TestConsumeListWithTelemetry(TestConsumeList): - def get_broker(self, apply_types: bool = False): + def get_broker(self, apply_types: bool = False, **kwargs: Any) -> RedisBroker: return RedisBroker( middlewares=(RedisTelemetryMiddleware(),), apply_types=apply_types, + **kwargs, ) @pytest.mark.redis class TestConsumeStreamWithTelemetry(TestConsumeStream): - def get_broker(self, apply_types: bool = False): + def get_broker(self, apply_types: bool = False, **kwargs: Any) -> RedisBroker: return RedisBroker( middlewares=(RedisTelemetryMiddleware(),), apply_types=apply_types, + **kwargs, )