Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rabbit sync new #120

Open
wants to merge 18 commits into
base: rabbit-sync
Choose a base branch
from
4 changes: 2 additions & 2 deletions propan/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@
else:
RabbitBroker = RabbitRouter = about.INSTALL_RABBIT # type: ignore

if is_installed("pika"):
try:
from propan.brokers.rabbit.rabbit_broker_sync import RabbitSyncBroker
from propan.brokers.rabbit.routing import RabbitRouter
else:
except ImportError:
RabbitSyncBroker = RabbitRouter = about.INSTALL_RABBIT_SYNC # type: ignore

if is_installed("nats"):
Expand Down
3 changes: 2 additions & 1 deletion propan/brokers/_model/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from propan.brokers._model.broker_usecase import BrokerAsyncUsecase
from propan.brokers._model.broker_usecase import BrokerAsyncUsecase, BrokerSyncUsecase
from propan.brokers._model.routing import BrokerRouter
from propan.brokers._model.schemas import PropanMessage, Queue

__all__ = (
"Queue",
"BrokerAsyncUsecase",
"BrokerSyncUsecase",
"PropanMessage",
"BrokerRouter",
)
44 changes: 36 additions & 8 deletions propan/brokers/_model/broker_usecase.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import logging
import warnings
from abc import ABC, abstractmethod
from contextlib import AsyncExitStack
from contextlib import AsyncExitStack, ExitStack
from functools import wraps
from itertools import chain
from types import TracebackType
Expand All @@ -24,7 +24,6 @@
from fast_depends._compat import PYDANTIC_V2
from fast_depends.core import CallModel, build_call_model
from fast_depends.dependencies import Depends
from fast_depends.utils import args_to_kwargs
from typing_extensions import Self, TypeAlias, TypeVar

from propan.brokers._model.routing import BrokerRouter
Expand Down Expand Up @@ -150,13 +149,19 @@ def __init__(

def _resolve_connection_kwargs(self, *args: Any, **kwargs: AnyDict) -> AnyDict:
arguments = get_function_positional_arguments(self.__init__) # type: ignore
init_kwargs = args_to_kwargs(
arguments,
*self._connection_args,

init_kwargs = {
**self._connection_kwargs,
)
connect_kwargs = args_to_kwargs(arguments, *args, **kwargs)
return {**init_kwargs, **connect_kwargs}
**dict(zip(arguments, self._connection_args)),
}

connect_kwargs = {
**kwargs,
**dict(zip(arguments, args)),
}

final_kwargs = {**init_kwargs, **connect_kwargs}
return final_kwargs

@staticmethod
def _decode_message(message: PropanMessage[Any]) -> DecodedMessage:
Expand Down Expand Up @@ -647,12 +652,19 @@ async def middleware_wrapper(msg: PropanMessage[MsgType]) -> T_HandlerReturn:
class BrokerSyncUsecase(BrokerUsecase[MsgType, ConnectionType]):
_global_parser: SyncParser[MsgType]
_global_decoder: SyncDecoder[MsgType]
_inited: bool

@abstractmethod
def start(self) -> None:
super().start()
if not self._inited:
self._init_handlers()
self.connect()

@abstractmethod
def _init_handlers(self) -> Any:
self._inited = True

@abstractmethod
def _connect(self, **kwargs: AnyDict) -> Any:
raise NotImplementedError()
Expand All @@ -665,6 +677,7 @@ def close(
exec_tb: Optional[TracebackType] = None,
) -> None:
super().close()
self._inited = False

@abstractmethod
def _process_message( # type: ignore[override]
Expand Down Expand Up @@ -750,6 +763,7 @@ def __init__(
url_=url_,
**kwargs,
)
self._inited = False

def connect(self, *args: Any, **kwargs: AnyDict) -> ConnectionType:
if self._connection is None:
Expand Down Expand Up @@ -866,6 +880,20 @@ def log_wrapper(message: PropanMessage[MsgType]) -> T_HandlerReturn:

return log_wrapper

def _wrap_middleware(
self,
func: Callable[[MsgType], Union[T_HandlerReturn, Awaitable[T_HandlerReturn]]],
) -> Callable[[MsgType], Union[T_HandlerReturn, Awaitable[T_HandlerReturn]]]:
@wraps(func)
def middleware_wrapper(message: MsgType) -> T_HandlerReturn:
with ExitStack() as stack:
for m in self.middlewares:
stack.enter_context(m(message))

return func(message)

return middleware_wrapper


def extend_dependencies(extra: Sequence[CallModel]) -> Callable[[CallModel], CallModel]:
def dependant_wrapper(dependant: CallModel) -> CallModel:
Expand Down
27 changes: 27 additions & 0 deletions propan/brokers/push_back_watcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,3 +134,30 @@ async def __aexit__(

else:
await call_or_await(self.on_error, self.message, **self.extra_args)

def __enter__(self) -> None:
self.watcher.add(self.message.message_id)

def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> None:
if not exc_type:
if callable(self.on_success):
try:
self.on_success(self.message, **self.extra_args)
except TypeError as e:
print("on_success type error", e)
self.watcher.remove(self.message.message_id)

elif isinstance(exc_val, SkipMessage):
self.watcher.remove(self.message.message_id)

elif self.watcher.is_max(self.message.message_id):
self.on_max(self.message, **self.extra_args)
self.watcher.remove(self.message.message_id)

else:
self.on_error(self.message, **self.extra_args)
2 changes: 2 additions & 0 deletions propan/brokers/rabbit/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from propan.brokers.rabbit.rabbit_broker import RabbitBroker, RabbitMessage
from propan.brokers.rabbit.routing import RabbitRouter
from propan.brokers.rabbit.schemas import ExchangeType, RabbitExchange, RabbitQueue
from propan.brokers.rabbit.rabbit_broker_sync import RabbitSyncBroker

__all__ = (
"RabbitBroker",
Expand All @@ -9,4 +10,5 @@
"RabbitExchange",
"ExchangeType",
"RabbitMessage",
"RabbitSyncBroker",
)
2 changes: 1 addition & 1 deletion propan/brokers/rabbit/rabbit_broker.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@ async def _parse_message(
body=message.body,
headers=message.headers,
reply_to=message.reply_to or "",
message_id=message.message_id,
message_id=message.message_id or str(uuid4()),
content_type=message.content_type or "",
raw_message=message,
)
Expand Down
69 changes: 45 additions & 24 deletions propan/brokers/rabbit/rabbit_broker_sync.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,23 @@
from functools import wraps
from threading import Event
from types import TracebackType
from typing import Any, Callable, List, Optional, Sequence, Tuple, Type, Union
from typing import (
Any,
Callable,
List,
Optional,
Sequence,
Tuple,
Type,
Union,
)
from uuid import uuid4

import pika
from fast_depends.dependencies import Depends
from pika import spec
from pika.adapters import blocking_connection
from typing_extensions import TypeAlias
from typing_extensions import TypeAlias, Never

from propan._compat import model_to_dict
from propan.brokers._model.broker_usecase import (
Expand All @@ -29,6 +38,7 @@
from propan.types import AnyDict, DecodedMessage, SendableMessage
from propan.utils import context


PIKA_RAW_MESSAGE: TypeAlias = Tuple[
blocking_connection.BlockingChannel,
spec.Basic.Deliver,
Expand Down Expand Up @@ -71,13 +81,14 @@ def __init__(

self._channel = None

self.__max_queue_len = 4
self.__max_exchange_len = 4
self._max_queue_len = 4
self._max_exchange_len = 4
self._queues = {}
self._exchanges = {}

def _connect(self, **kwargs: Any) -> blocking_connection.BlockingConnection:
connection = pika.BlockingConnection()
# TODO: use all kwargs here
connection = pika.BlockingConnection(pika.URLParameters(kwargs["url"]))

if self._channel is None:
max_consumers = self._max_consumers
Expand Down Expand Up @@ -105,13 +116,17 @@ def close(
self._connection.close()
self._connection = None

def start(self) -> None:
def start(self) -> Never:
context.set_local(
"log_context",
self._get_log_context(None, RabbitQueue(""), RabbitExchange("")),
)

super().start()
self._channel.start_consuming()

def _init_handlers(self):
super()._init_handlers()

for handler in self.handlers:
self._init_handler(handler)
Expand All @@ -125,16 +140,14 @@ def start(self) -> None:
on_message_callback=func,
)

self._channel.start_consuming()

def handle(
self,
queue: Union[str, RabbitQueue],
exchange: Union[str, RabbitExchange, None] = None,
*,
dependencies: Sequence[Depends] = (),
description: str = "",
**original_kwargs: Any,
**original_kwargs: AnyDict,
) -> Callable[
[HandlerCallable[T_HandlerReturn]],
Callable[[Any, bool], T_HandlerReturn],
Expand Down Expand Up @@ -186,21 +199,22 @@ def _process_message(
watcher = NotPushBackWatcher()

@wraps(func)
def wrapper(message):
channel: blocking_connection.BlockingChannel
method_frame: spec.Basic.Deliver
header_frame: spec.BasicProperties
def wrapper(message: PikaMessage):
channel, method_frame, header_frame, _ = message.raw_message

context = WatcherContext(
watcher,
message.message_id,
on_success=lambda: channel.basic_ack(method_frame.delivery_tag),
on_error=lambda: channel.basic_nack(
method_frame.delivery_tag, requeue=True
message,
on_success=lambda msg: channel.basic_ack(
method_frame.delivery_tag
),
on_error=lambda msg: channel.basic_nack(
method_frame.delivery_tag,
requeue=True,
),
on_max=lambda: channel.basic_reject(
method_frame.delivery_tag, requeue=False
on_max=lambda msg: channel.basic_reject(
method_frame.delivery_tag,
requeue=False,
),
)

Expand Down Expand Up @@ -245,7 +259,7 @@ def publish(
type_: Optional[str] = None,
user_id: Optional[str] = None,
app_id: Optional[str] = None,
) -> DecodedMessage | None:
) -> Optional[DecodedMessage]:
if self._channel is None:
raise ValueError("RabbitBroker channel not started yet")

Expand All @@ -254,6 +268,7 @@ def publish(
message, content_type = super()._encode_message(message)

response_event: Optional[Event] = None
response_msg: Optional[DecodedMessage] = None
if callback is True:
if reply_to is not None:
raise WRONG_PUBLISH_ARGS
Expand All @@ -267,7 +282,9 @@ def handle_response(
header_frame: spec.BasicProperties,
body: bytes,
):
# TODO: return message
nonlocal response_msg
msg = self._parse_message((channel, method_frame, header_frame, body))
response_msg = self._decode_message(msg)
response_event.set()

response_consumer_tag = self._channel.basic_consume(
Expand Down Expand Up @@ -300,8 +317,11 @@ def handle_response(

if response_event is not None:
self._channel._process_data_events(callback_timeout)
response_event.wait()
self._channel.basic_cancel(response_consumer_tag)
try:
response_event.wait()
return response_msg
finally:
self._channel.basic_cancel(response_consumer_tag)

def _init_handler(self, handler: Handler) -> None:
self.declare_queue(handler.queue)
Expand Down Expand Up @@ -376,7 +396,8 @@ def pika_handler_func_wrapper(
reraise_exc: bool = False,
):
return func(
(channel, method_frame, header_frame, body), reraise_exc=reraise_exc
(channel, method_frame, header_frame, body),
reraise_exc=reraise_exc,
)

return pika_handler_func_wrapper
3 changes: 2 additions & 1 deletion propan/brokers/sqs/sqs_broker.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,7 +407,8 @@ def _get_log_context(


async def delete_message(
message: SQSMessage, connection: Optional[AioBaseClient]
message: SQSMessage,
connection: Optional[AioBaseClient],
) -> None:
if connection:
await connection.delete_message(
Expand Down
5 changes: 4 additions & 1 deletion propan/test/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@

try:
from propan.test.rabbit import TestRabbitBroker
from propan.test.rabbit_sync import TestRabbitSyncBroker
except Exception:
TestRabbitBroker = about.INSTALL_RABBIT
TestRabbitSyncBroker = about.INSTALL_RABBIT

try:
from propan.test.redis import TestRedisBroker
Expand All @@ -26,7 +28,7 @@
TestSQSBroker = about.INSTALL_SQS

assert any(
(TestRabbitBroker, TestRedisBroker, TestKafkaBroker, TestNatsBroker, TestSQSBroker)
(TestRabbitBroker, TestRedisBroker, TestKafkaBroker, TestNatsBroker, TestSQSBroker, TestRabbitSyncBroker)
), about.INSTALL_MESSAGE

__all__ = (
Expand All @@ -35,4 +37,5 @@
"TestKafkaBroker",
"TestNatsBroker",
"TestSQSBroker",
"TestRabbitSyncBroker",
)
Loading